refactor(prompt): use Provider service in effect layers (#20167)

pull/20375/head
Kit Langton 2026-03-31 20:07:58 -04:00 committed by opencode
parent 6314f09c14
commit 181b5f6236
8 changed files with 163 additions and 322 deletions

View File

@ -75,6 +75,7 @@ export namespace Agent {
const config = yield* Config.Service const config = yield* Config.Service
const auth = yield* Auth.Service const auth = yield* Auth.Service
const skill = yield* Skill.Service const skill = yield* Skill.Service
const provider = yield* Provider.Service
const state = yield* InstanceState.make<State>( const state = yield* InstanceState.make<State>(
Effect.fn("Agent.state")(function* (ctx) { Effect.fn("Agent.state")(function* (ctx) {
@ -330,9 +331,9 @@ export namespace Agent {
model?: { providerID: ProviderID; modelID: ModelID } model?: { providerID: ProviderID; modelID: ModelID }
}) { }) {
const cfg = yield* config.get() const cfg = yield* config.get()
const model = input.model ?? (yield* Effect.promise(() => Provider.defaultModel())) const model = input.model ?? (yield* provider.defaultModel())
const resolved = yield* Effect.promise(() => Provider.getModel(model.providerID, model.modelID)) const resolved = yield* provider.getModel(model.providerID, model.modelID)
const language = yield* Effect.promise(() => Provider.getLanguage(resolved)) const language = yield* provider.getLanguage(resolved)
const system = [PROMPT_GENERATE] const system = [PROMPT_GENERATE]
yield* Effect.promise(() => yield* Effect.promise(() =>
@ -393,6 +394,7 @@ export namespace Agent {
) )
export const defaultLayer = layer.pipe( export const defaultLayer = layer.pipe(
Layer.provide(Provider.defaultLayer),
Layer.provide(Auth.defaultLayer), Layer.provide(Auth.defaultLayer),
Layer.provide(Config.defaultLayer), Layer.provide(Config.defaultLayer),
Layer.provide(Skill.defaultLayer), Layer.provide(Skill.defaultLayer),

View File

@ -1541,10 +1541,9 @@ export namespace Provider {
}), }),
) )
const { runPromise } = makeRuntime( export const defaultLayer = layer.pipe(Layer.provide(Config.defaultLayer), Layer.provide(Auth.defaultLayer))
Service,
layer.pipe(Layer.provide(Config.defaultLayer), Layer.provide(Auth.defaultLayer)), const { runPromise } = makeRuntime(Service, defaultLayer)
)
export async function list() { export async function list() {
return runPromise((svc) => svc.list()) return runPromise((svc) => svc.list())

View File

@ -63,7 +63,13 @@ export namespace SessionCompaction {
export const layer: Layer.Layer< export const layer: Layer.Layer<
Service, Service,
never, never,
Bus.Service | Config.Service | Session.Service | Agent.Service | Plugin.Service | SessionProcessor.Service | Bus.Service
| Config.Service
| Session.Service
| Agent.Service
| Plugin.Service
| SessionProcessor.Service
| Provider.Service
> = Layer.effect( > = Layer.effect(
Service, Service,
Effect.gen(function* () { Effect.gen(function* () {
@ -73,6 +79,7 @@ export namespace SessionCompaction {
const agents = yield* Agent.Service const agents = yield* Agent.Service
const plugin = yield* Plugin.Service const plugin = yield* Plugin.Service
const processors = yield* SessionProcessor.Service const processors = yield* SessionProcessor.Service
const provider = yield* Provider.Service
const isOverflow = Effect.fn("SessionCompaction.isOverflow")(function* (input: { const isOverflow = Effect.fn("SessionCompaction.isOverflow")(function* (input: {
tokens: MessageV2.Assistant["tokens"] tokens: MessageV2.Assistant["tokens"]
@ -170,11 +177,9 @@ export namespace SessionCompaction {
} }
const agent = yield* agents.get("compaction") const agent = yield* agents.get("compaction")
const model = yield* Effect.promise(() => const model = agent.model
agent.model ? yield* provider.getModel(agent.model.providerID, agent.model.modelID)
? Provider.getModel(agent.model.providerID, agent.model.modelID) : yield* provider.getModel(userMessage.model.providerID, userMessage.model.modelID)
: Provider.getModel(userMessage.model.providerID, userMessage.model.modelID),
)
// Allow plugins to inject context or replace compaction prompt. // Allow plugins to inject context or replace compaction prompt.
const compacting = yield* plugin.trigger( const compacting = yield* plugin.trigger(
"experimental.session.compacting", "experimental.session.compacting",
@ -377,6 +382,7 @@ When constructing the summary, try to stick to this template:
export const defaultLayer = Layer.unwrap( export const defaultLayer = Layer.unwrap(
Effect.sync(() => Effect.sync(() =>
layer.pipe( layer.pipe(
Layer.provide(Provider.defaultLayer),
Layer.provide(Session.defaultLayer), Layer.provide(Session.defaultLayer),
Layer.provide(SessionProcessor.defaultLayer), Layer.provide(SessionProcessor.defaultLayer),
Layer.provide(Agent.defaultLayer), Layer.provide(Agent.defaultLayer),

View File

@ -84,6 +84,7 @@ export namespace SessionPrompt {
const status = yield* SessionStatus.Service const status = yield* SessionStatus.Service
const sessions = yield* Session.Service const sessions = yield* Session.Service
const agents = yield* Agent.Service const agents = yield* Agent.Service
const provider = yield* Provider.Service
const processor = yield* SessionProcessor.Service const processor = yield* SessionProcessor.Service
const compaction = yield* SessionCompaction.Service const compaction = yield* SessionCompaction.Service
const plugin = yield* Plugin.Service const plugin = yield* Plugin.Service
@ -206,14 +207,14 @@ export namespace SessionPrompt {
const ag = yield* agents.get("title") const ag = yield* agents.get("title")
if (!ag) return if (!ag) return
const text = yield* Effect.promise(async (signal) => {
const mdl = ag.model const mdl = ag.model
? await Provider.getModel(ag.model.providerID, ag.model.modelID) ? yield* provider.getModel(ag.model.providerID, ag.model.modelID)
: ((await Provider.getSmallModel(input.providerID)) ?? : ((yield* provider.getSmallModel(input.providerID)) ??
(await Provider.getModel(input.providerID, input.modelID))) (yield* provider.getModel(input.providerID, input.modelID)))
const msgs = onlySubtasks const msgs = onlySubtasks
? [{ role: "user" as const, content: subtasks.map((p) => p.prompt).join("\n") }] ? [{ role: "user" as const, content: subtasks.map((p) => p.prompt).join("\n") }]
: await MessageV2.toModelMessages(context, mdl) : yield* Effect.promise(() => MessageV2.toModelMessages(context, mdl))
const text = yield* Effect.promise(async (signal) => {
const result = await LLM.stream({ const result = await LLM.stream({
agent: ag, agent: ag,
user: firstInfo, user: firstInfo,
@ -932,21 +933,35 @@ NOTE: At any point in time through this workflow you should feel free to ask the
return { info: msg, parts: [part] } return { info: msg, parts: [part] }
}) })
const getModel = (providerID: ProviderID, modelID: ModelID, sessionID: SessionID) => const getModel = Effect.fn("SessionPrompt.getModel")(function* (
Effect.promise(() => providerID: ProviderID,
Provider.getModel(providerID, modelID).catch((e) => { modelID: ModelID,
if (Provider.ModelNotFoundError.isInstance(e)) { sessionID: SessionID,
const hint = e.data.suggestions?.length ? ` Did you mean: ${e.data.suggestions.join(", ")}?` : "" ) {
Bus.publish(Session.Event.Error, { const exit = yield* provider.getModel(providerID, modelID).pipe(Effect.exit)
if (Exit.isSuccess(exit)) return exit.value
const err = Cause.squash(exit.cause)
if (Provider.ModelNotFoundError.isInstance(err)) {
const hint = err.data.suggestions?.length ? ` Did you mean: ${err.data.suggestions.join(", ")}?` : ""
yield* bus.publish(Session.Event.Error, {
sessionID, sessionID,
error: new NamedError.Unknown({ error: new NamedError.Unknown({
message: `Model not found: ${e.data.providerID}/${e.data.modelID}.${hint}`, message: `Model not found: ${err.data.providerID}/${err.data.modelID}.${hint}`,
}).toObject(), }).toObject(),
}) })
} }
throw e return yield* Effect.failCause(exit.cause)
}), })
)
const lastModel = Effect.fnUntraced(function* (sessionID: SessionID) {
const model = yield* Effect.promise(async () => {
for await (const item of MessageV2.stream(sessionID)) {
if (item.info.role === "user" && item.info.model) return item.info.model
}
})
if (model) return model
return yield* provider.defaultModel()
})
const createUserMessage = Effect.fn("SessionPrompt.createUserMessage")(function* (input: PromptInput) { const createUserMessage = Effect.fn("SessionPrompt.createUserMessage")(function* (input: PromptInput) {
const agentName = input.agent || (yield* agents.defaultAgent()) const agentName = input.agent || (yield* agents.defaultAgent())
@ -960,9 +975,12 @@ NOTE: At any point in time through this workflow you should feel free to ask the
} }
const model = input.model ?? ag.model ?? (yield* lastModel(input.sessionID)) const model = input.model ?? ag.model ?? (yield* lastModel(input.sessionID))
const same = ag.model && model.providerID === ag.model.providerID && model.modelID === ag.model.modelID
const full = const full =
!input.variant && ag.variant !input.variant && ag.variant && same
? yield* Effect.promise(() => Provider.getModel(model.providerID, model.modelID).catch(() => undefined)) ? yield* provider
.getModel(model.providerID, model.modelID)
.pipe(Effect.catch(() => Effect.succeed(undefined)))
: undefined : undefined
const variant = input.variant ?? (ag.variant && full?.variants?.[ag.variant] ? ag.variant : undefined) const variant = input.variant ?? (ag.variant && full?.variants?.[ag.variant] ? ag.variant : undefined)
@ -1109,7 +1127,7 @@ NOTE: At any point in time through this workflow you should feel free to ask the
] ]
const read = yield* Effect.promise(() => ReadTool.init()).pipe( const read = yield* Effect.promise(() => ReadTool.init()).pipe(
Effect.flatMap((t) => Effect.flatMap((t) =>
Effect.promise(() => Provider.getModel(info.model.providerID, info.model.modelID)).pipe( provider.getModel(info.model.providerID, info.model.modelID).pipe(
Effect.flatMap((mdl) => Effect.flatMap((mdl) =>
Effect.promise(() => Effect.promise(() =>
t.execute(args, { t.execute(args, {
@ -1711,6 +1729,7 @@ NOTE: At any point in time through this workflow you should feel free to ask the
Layer.provide(FileTime.defaultLayer), Layer.provide(FileTime.defaultLayer),
Layer.provide(ToolRegistry.defaultLayer), Layer.provide(ToolRegistry.defaultLayer),
Layer.provide(Truncate.layer), Layer.provide(Truncate.layer),
Layer.provide(Provider.defaultLayer),
Layer.provide(AppFileSystem.defaultLayer), Layer.provide(AppFileSystem.defaultLayer),
Layer.provide(Plugin.defaultLayer), Layer.provide(Plugin.defaultLayer),
Layer.provide(Session.defaultLayer), Layer.provide(Session.defaultLayer),
@ -1856,15 +1875,6 @@ NOTE: At any point in time through this workflow you should feel free to ask the
return runPromise((svc) => svc.command(CommandInput.parse(input))) return runPromise((svc) => svc.command(CommandInput.parse(input)))
} }
const lastModel = Effect.fnUntraced(function* (sessionID: SessionID) {
return yield* Effect.promise(async () => {
for await (const item of MessageV2.stream(sessionID)) {
if (item.info.role === "user" && item.info.model) return item.info.model
}
return Provider.defaultModel()
})
})
/** @internal Exported for testing */ /** @internal Exported for testing */
export function createStructuredOutputTool(input: { export function createStructuredOutputTool(input: {
schema: Record<string, any> schema: Record<string, any>

View File

@ -0,0 +1,81 @@
import { Effect, Layer } from "effect"
import { Provider } from "../../src/provider/provider"
import { ModelID, ProviderID } from "../../src/provider/schema"
export namespace ProviderTest {
export function model(override: Partial<Provider.Model> = {}): Provider.Model {
const id = override.id ?? ModelID.make("gpt-5.2")
const providerID = override.providerID ?? ProviderID.make("openai")
return {
id,
providerID,
name: "Test Model",
capabilities: {
toolcall: true,
attachment: false,
reasoning: false,
temperature: true,
interleaved: false,
input: { text: true, image: false, audio: false, video: false, pdf: false },
output: { text: true, image: false, audio: false, video: false, pdf: false },
},
api: { id, url: "https://example.com", npm: "@ai-sdk/openai" },
cost: { input: 0, output: 0, cache: { read: 0, write: 0 } },
limit: { context: 200_000, output: 10_000 },
status: "active",
options: {},
headers: {},
release_date: "2025-01-01",
...override,
}
}
export function info(override: Partial<Provider.Info> = {}, mdl = model()): Provider.Info {
const id = override.id ?? mdl.providerID
return {
id,
name: "Test Provider",
source: "config",
env: [],
options: {},
models: { [mdl.id]: mdl },
...override,
}
}
export function fake(override: Partial<Provider.Interface> & { model?: Provider.Model; info?: Provider.Info } = {}) {
const mdl = override.model ?? model()
const row = override.info ?? info({}, mdl)
return {
model: mdl,
info: row,
layer: Layer.succeed(
Provider.Service,
Provider.Service.of({
list: Effect.fn("TestProvider.list")(() => Effect.succeed({ [row.id]: row })),
getProvider: Effect.fn("TestProvider.getProvider")((providerID) => {
if (providerID === row.id) return Effect.succeed(row)
return Effect.die(new Error(`Unknown test provider: ${providerID}`))
}),
getModel: Effect.fn("TestProvider.getModel")((providerID, modelID) => {
if (providerID === row.id && modelID === mdl.id) return Effect.succeed(mdl)
return Effect.die(new Error(`Unknown test model: ${providerID}/${modelID}`))
}),
getLanguage: Effect.fn("TestProvider.getLanguage")(() =>
Effect.die(new Error("ProviderTest.getLanguage not configured")),
),
closest: Effect.fn("TestProvider.closest")((providerID) =>
Effect.succeed(providerID === row.id ? { providerID: row.id, modelID: mdl.id } : undefined),
),
getSmallModel: Effect.fn("TestProvider.getSmallModel")((providerID) =>
Effect.succeed(providerID === row.id ? mdl : undefined),
),
defaultModel: Effect.fn("TestProvider.defaultModel")(() =>
Effect.succeed({ providerID: row.id, modelID: mdl.id }),
),
...override,
}),
),
}
}
}

View File

@ -1,4 +1,4 @@
import { afterEach, describe, expect, mock, spyOn, test } from "bun:test" import { afterEach, describe, expect, mock, test } from "bun:test"
import { APICallError } from "ai" import { APICallError } from "ai"
import { Cause, Effect, Exit, Layer, ManagedRuntime } from "effect" import { Cause, Effect, Exit, Layer, ManagedRuntime } from "effect"
import * as Stream from "effect/Stream" import * as Stream from "effect/Stream"
@ -20,9 +20,9 @@ import { MessageID, PartID, SessionID } from "../../src/session/schema"
import { SessionStatus } from "../../src/session/status" import { SessionStatus } from "../../src/session/status"
import { ModelID, ProviderID } from "../../src/provider/schema" import { ModelID, ProviderID } from "../../src/provider/schema"
import type { Provider } from "../../src/provider/provider" import type { Provider } from "../../src/provider/provider"
import * as ProviderModule from "../../src/provider/provider"
import * as SessionProcessorModule from "../../src/session/processor" import * as SessionProcessorModule from "../../src/session/processor"
import { Snapshot } from "../../src/snapshot" import { Snapshot } from "../../src/snapshot"
import { ProviderTest } from "../fake/provider"
Log.init({ print: false }) Log.init({ print: false })
@ -65,6 +65,8 @@ function createModel(opts: {
} as Provider.Model } as Provider.Model
} }
const wide = () => ProviderTest.fake({ model: createModel({ context: 100_000, output: 32_000 }) })
async function user(sessionID: SessionID, text: string) { async function user(sessionID: SessionID, text: string) {
const msg = await Session.updateMessage({ const msg = await Session.updateMessage({
id: MessageID.ascending(), id: MessageID.ascending(),
@ -162,10 +164,11 @@ function layer(result: "continue" | "compact") {
) )
} }
function runtime(result: "continue" | "compact", plugin = Plugin.defaultLayer) { function runtime(result: "continue" | "compact", plugin = Plugin.defaultLayer, provider = ProviderTest.fake()) {
const bus = Bus.layer const bus = Bus.layer
return ManagedRuntime.make( return ManagedRuntime.make(
Layer.mergeAll(SessionCompaction.layer, bus).pipe( Layer.mergeAll(SessionCompaction.layer, bus).pipe(
Layer.provide(provider.layer),
Layer.provide(Session.defaultLayer), Layer.provide(Session.defaultLayer),
Layer.provide(layer(result)), Layer.provide(layer(result)),
Layer.provide(Agent.defaultLayer), Layer.provide(Agent.defaultLayer),
@ -198,12 +201,13 @@ function llm() {
} }
} }
function liveRuntime(layer: Layer.Layer<LLM.Service>) { function liveRuntime(layer: Layer.Layer<LLM.Service>, provider = ProviderTest.fake()) {
const bus = Bus.layer const bus = Bus.layer
const status = SessionStatus.layer.pipe(Layer.provide(bus)) const status = SessionStatus.layer.pipe(Layer.provide(bus))
const processor = SessionProcessorModule.SessionProcessor.layer const processor = SessionProcessorModule.SessionProcessor.layer
return ManagedRuntime.make( return ManagedRuntime.make(
Layer.mergeAll(SessionCompaction.layer.pipe(Layer.provide(processor)), processor, bus, status).pipe( Layer.mergeAll(SessionCompaction.layer.pipe(Layer.provide(processor)), processor, bus, status).pipe(
Layer.provide(provider.layer),
Layer.provide(Session.defaultLayer), Layer.provide(Session.defaultLayer),
Layer.provide(Snapshot.defaultLayer), Layer.provide(Snapshot.defaultLayer),
Layer.provide(layer), Layer.provide(layer),
@ -544,14 +548,12 @@ describe("session.compaction.process", () => {
await Instance.provide({ await Instance.provide({
directory: tmp.path, directory: tmp.path,
fn: async () => { fn: async () => {
spyOn(ProviderModule.Provider, "getModel").mockResolvedValue(createModel({ context: 100_000, output: 32_000 }))
const session = await Session.create({}) const session = await Session.create({})
const msg = await user(session.id, "hello") const msg = await user(session.id, "hello")
const msgs = await Session.messages({ sessionID: session.id }) const msgs = await Session.messages({ sessionID: session.id })
const done = defer() const done = defer()
let seen = false let seen = false
const rt = runtime("continue") const rt = runtime("continue", Plugin.defaultLayer, wide())
let unsub: (() => void) | undefined let unsub: (() => void) | undefined
try { try {
unsub = await rt.runPromise( unsub = await rt.runPromise(
@ -596,11 +598,9 @@ describe("session.compaction.process", () => {
await Instance.provide({ await Instance.provide({
directory: tmp.path, directory: tmp.path,
fn: async () => { fn: async () => {
spyOn(ProviderModule.Provider, "getModel").mockResolvedValue(createModel({ context: 100_000, output: 32_000 }))
const session = await Session.create({}) const session = await Session.create({})
const msg = await user(session.id, "hello") const msg = await user(session.id, "hello")
const rt = runtime("compact") const rt = runtime("compact", Plugin.defaultLayer, wide())
try { try {
const msgs = await Session.messages({ sessionID: session.id }) const msgs = await Session.messages({ sessionID: session.id })
const result = await rt.runPromise( const result = await rt.runPromise(
@ -636,11 +636,9 @@ describe("session.compaction.process", () => {
await Instance.provide({ await Instance.provide({
directory: tmp.path, directory: tmp.path,
fn: async () => { fn: async () => {
spyOn(ProviderModule.Provider, "getModel").mockResolvedValue(createModel({ context: 100_000, output: 32_000 }))
const session = await Session.create({}) const session = await Session.create({})
const msg = await user(session.id, "hello") const msg = await user(session.id, "hello")
const rt = runtime("continue") const rt = runtime("continue", Plugin.defaultLayer, wide())
try { try {
const msgs = await Session.messages({ sessionID: session.id }) const msgs = await Session.messages({ sessionID: session.id })
const result = await rt.runPromise( const result = await rt.runPromise(
@ -678,8 +676,6 @@ describe("session.compaction.process", () => {
await Instance.provide({ await Instance.provide({
directory: tmp.path, directory: tmp.path,
fn: async () => { fn: async () => {
spyOn(ProviderModule.Provider, "getModel").mockResolvedValue(createModel({ context: 100_000, output: 32_000 }))
const session = await Session.create({}) const session = await Session.create({})
await user(session.id, "root") await user(session.id, "root")
const replay = await user(session.id, "image") const replay = await user(session.id, "image")
@ -693,7 +689,7 @@ describe("session.compaction.process", () => {
url: "https://example.com/cat.png", url: "https://example.com/cat.png",
}) })
const msg = await user(session.id, "current") const msg = await user(session.id, "current")
const rt = runtime("continue") const rt = runtime("continue", Plugin.defaultLayer, wide())
try { try {
const msgs = await Session.messages({ sessionID: session.id }) const msgs = await Session.messages({ sessionID: session.id })
const result = await rt.runPromise( const result = await rt.runPromise(
@ -728,13 +724,11 @@ describe("session.compaction.process", () => {
await Instance.provide({ await Instance.provide({
directory: tmp.path, directory: tmp.path,
fn: async () => { fn: async () => {
spyOn(ProviderModule.Provider, "getModel").mockResolvedValue(createModel({ context: 100_000, output: 32_000 }))
const session = await Session.create({}) const session = await Session.create({})
await user(session.id, "earlier") await user(session.id, "earlier")
const msg = await user(session.id, "current") const msg = await user(session.id, "current")
const rt = runtime("continue") const rt = runtime("continue", Plugin.defaultLayer, wide())
try { try {
const msgs = await Session.messages({ sessionID: session.id }) const msgs = await Session.messages({ sessionID: session.id })
const result = await rt.runPromise( const result = await rt.runPromise(
@ -790,13 +784,11 @@ describe("session.compaction.process", () => {
await Instance.provide({ await Instance.provide({
directory: tmp.path, directory: tmp.path,
fn: async () => { fn: async () => {
spyOn(ProviderModule.Provider, "getModel").mockResolvedValue(createModel({ context: 100_000, output: 32_000 }))
const session = await Session.create({}) const session = await Session.create({})
const msg = await user(session.id, "hello") const msg = await user(session.id, "hello")
const msgs = await Session.messages({ sessionID: session.id }) const msgs = await Session.messages({ sessionID: session.id })
const abort = new AbortController() const abort = new AbortController()
const rt = liveRuntime(stub.layer) const rt = liveRuntime(stub.layer, wide())
let off: (() => void) | undefined let off: (() => void) | undefined
let run: Promise<"continue" | "stop"> | undefined let run: Promise<"continue" | "stop"> | undefined
try { try {
@ -866,13 +858,11 @@ describe("session.compaction.process", () => {
await Instance.provide({ await Instance.provide({
directory: tmp.path, directory: tmp.path,
fn: async () => { fn: async () => {
spyOn(ProviderModule.Provider, "getModel").mockResolvedValue(createModel({ context: 100_000, output: 32_000 }))
const session = await Session.create({}) const session = await Session.create({})
const msg = await user(session.id, "hello") const msg = await user(session.id, "hello")
const msgs = await Session.messages({ sessionID: session.id }) const msgs = await Session.messages({ sessionID: session.id })
const abort = new AbortController() const abort = new AbortController()
const rt = runtime("continue", plugin(ready)) const rt = runtime("continue", plugin(ready), wide())
let run: Promise<"continue" | "stop"> | undefined let run: Promise<"continue" | "stop"> | undefined
try { try {
run = rt run = rt
@ -970,11 +960,9 @@ describe("session.compaction.process", () => {
await Instance.provide({ await Instance.provide({
directory: tmp.path, directory: tmp.path,
fn: async () => { fn: async () => {
spyOn(ProviderModule.Provider, "getModel").mockResolvedValue(createModel({ context: 100_000, output: 32_000 }))
const session = await Session.create({}) const session = await Session.create({})
const msg = await user(session.id, "hello") const msg = await user(session.id, "hello")
const rt = liveRuntime(stub.layer) const rt = liveRuntime(stub.layer, wide())
try { try {
const msgs = await Session.messages({ sessionID: session.id }) const msgs = await Session.messages({ sessionID: session.id })
await rt.runPromise( await rt.runPromise(

View File

@ -1,247 +0,0 @@
import { describe, expect, spyOn, test } from "bun:test"
import { Instance } from "../../src/project/instance"
import { Provider } from "../../src/provider/provider"
import { Session } from "../../src/session"
import { MessageV2 } from "../../src/session/message-v2"
import { SessionPrompt } from "../../src/session/prompt"
import { SessionStatus } from "../../src/session/status"
import { MessageID, PartID, SessionID } from "../../src/session/schema"
import { Log } from "../../src/util/log"
import { tmpdir } from "../fixture/fixture"
Log.init({ print: false })
function deferred() {
let resolve!: () => void
const promise = new Promise<void>((done) => {
resolve = done
})
return { promise, resolve }
}
// Helper: seed a session with a user message + finished assistant message
// so loop() exits immediately without calling any LLM
async function seed(sessionID: SessionID) {
const userMsg: MessageV2.Info = {
id: MessageID.ascending(),
role: "user",
sessionID,
time: { created: Date.now() },
agent: "build",
model: { providerID: "openai" as any, modelID: "gpt-5.2" as any },
}
await Session.updateMessage(userMsg)
await Session.updatePart({
id: PartID.ascending(),
messageID: userMsg.id,
sessionID,
type: "text",
text: "hello",
})
const assistantMsg: MessageV2.Info = {
id: MessageID.ascending(),
role: "assistant",
parentID: userMsg.id,
sessionID,
mode: "build",
agent: "build",
cost: 0,
path: { cwd: "/tmp", root: "/tmp" },
tokens: { input: 0, output: 0, reasoning: 0, cache: { read: 0, write: 0 } },
modelID: "gpt-5.2" as any,
providerID: "openai" as any,
time: { created: Date.now(), completed: Date.now() },
finish: "stop",
}
await Session.updateMessage(assistantMsg)
await Session.updatePart({
id: PartID.ascending(),
messageID: assistantMsg.id,
sessionID,
type: "text",
text: "hi there",
})
return { userMsg, assistantMsg }
}
describe("session.prompt concurrency", () => {
test("loop returns assistant message and sets status to idle", async () => {
await using tmp = await tmpdir({ git: true })
await Instance.provide({
directory: tmp.path,
fn: async () => {
const session = await Session.create({})
await seed(session.id)
const result = await SessionPrompt.loop({ sessionID: session.id })
expect(result.info.role).toBe("assistant")
if (result.info.role === "assistant") expect(result.info.finish).toBe("stop")
const status = await SessionStatus.get(session.id)
expect(status.type).toBe("idle")
},
})
})
test("concurrent loop callers get the same result", async () => {
await using tmp = await tmpdir({ git: true })
await Instance.provide({
directory: tmp.path,
fn: async () => {
const session = await Session.create({})
await seed(session.id)
const [a, b] = await Promise.all([
SessionPrompt.loop({ sessionID: session.id }),
SessionPrompt.loop({ sessionID: session.id }),
])
expect(a.info.id).toBe(b.info.id)
expect(a.info.role).toBe("assistant")
},
})
})
test("assertNotBusy throws when loop is running", async () => {
await using tmp = await tmpdir({ git: true })
await Instance.provide({
directory: tmp.path,
fn: async () => {
const session = await Session.create({})
const userMsg: MessageV2.Info = {
id: MessageID.ascending(),
role: "user",
sessionID: session.id,
time: { created: Date.now() },
agent: "build",
model: { providerID: "openai" as any, modelID: "gpt-5.2" as any },
}
await Session.updateMessage(userMsg)
await Session.updatePart({
id: PartID.ascending(),
messageID: userMsg.id,
sessionID: session.id,
type: "text",
text: "hello",
})
const ready = deferred()
const gate = deferred()
const getModel = spyOn(Provider, "getModel").mockImplementation(async () => {
ready.resolve()
await gate.promise
throw new Error("test stop")
})
try {
const loopPromise = SessionPrompt.loop({ sessionID: session.id }).catch(() => undefined)
await ready.promise
await expect(SessionPrompt.assertNotBusy(session.id)).rejects.toBeInstanceOf(Session.BusyError)
gate.resolve()
await loopPromise
} finally {
gate.resolve()
getModel.mockRestore()
}
// After loop completes, assertNotBusy should succeed
await SessionPrompt.assertNotBusy(session.id)
},
})
})
test("cancel sets status to idle", async () => {
await using tmp = await tmpdir({ git: true })
await Instance.provide({
directory: tmp.path,
fn: async () => {
const session = await Session.create({})
// Seed only a user message — loop must call getModel to proceed
const userMsg: MessageV2.Info = {
id: MessageID.ascending(),
role: "user",
sessionID: session.id,
time: { created: Date.now() },
agent: "build",
model: { providerID: "openai" as any, modelID: "gpt-5.2" as any },
}
await Session.updateMessage(userMsg)
await Session.updatePart({
id: PartID.ascending(),
messageID: userMsg.id,
sessionID: session.id,
type: "text",
text: "hello",
})
// Also seed an assistant message so lastAssistant() fallback can find it
const assistantMsg: MessageV2.Info = {
id: MessageID.ascending(),
role: "assistant",
parentID: userMsg.id,
sessionID: session.id,
mode: "build",
agent: "build",
cost: 0,
path: { cwd: "/tmp", root: "/tmp" },
tokens: { input: 0, output: 0, reasoning: 0, cache: { read: 0, write: 0 } },
modelID: "gpt-5.2" as any,
providerID: "openai" as any,
time: { created: Date.now() },
}
await Session.updateMessage(assistantMsg)
await Session.updatePart({
id: PartID.ascending(),
messageID: assistantMsg.id,
sessionID: session.id,
type: "text",
text: "hi there",
})
const ready = deferred()
const gate = deferred()
const getModel = spyOn(Provider, "getModel").mockImplementation(async () => {
ready.resolve()
await gate.promise
throw new Error("test stop")
})
try {
// Start loop — it will block in getModel (assistant has no finish, so loop continues)
const loopPromise = SessionPrompt.loop({ sessionID: session.id })
await ready.promise
await SessionPrompt.cancel(session.id)
const status = await SessionStatus.get(session.id)
expect(status.type).toBe("idle")
// loop should resolve cleanly, not throw "All fibers interrupted"
const result = await loopPromise
expect(result.info.role).toBe("assistant")
expect(result.info.id).toBe(assistantMsg.id)
} finally {
gate.resolve()
getModel.mockRestore()
}
},
})
}, 10000)
test("cancel on idle session just sets idle", async () => {
await using tmp = await tmpdir({ git: true })
await Instance.provide({
directory: tmp.path,
fn: async () => {
const session = await Session.create({})
await SessionPrompt.cancel(session.id)
const status = await SessionStatus.get(session.id)
expect(status.type).toBe("idle")
},
})
})
})

View File

@ -12,6 +12,7 @@ import { LSP } from "../../src/lsp"
import { MCP } from "../../src/mcp" import { MCP } from "../../src/mcp"
import { Permission } from "../../src/permission" import { Permission } from "../../src/permission"
import { Plugin } from "../../src/plugin" import { Plugin } from "../../src/plugin"
import { Provider as ProviderSvc } from "../../src/provider/provider"
import type { Provider } from "../../src/provider/provider" import type { Provider } from "../../src/provider/provider"
import { ModelID, ProviderID } from "../../src/provider/schema" import { ModelID, ProviderID } from "../../src/provider/schema"
import { Session } from "../../src/session" import { Session } from "../../src/session"
@ -151,6 +152,7 @@ function makeHttp() {
Permission.layer, Permission.layer,
Plugin.defaultLayer, Plugin.defaultLayer,
Config.defaultLayer, Config.defaultLayer,
ProviderSvc.defaultLayer,
filetime, filetime,
lsp, lsp,
mcp, mcp,