refactor(prompt): use Provider service in effect layers (#20167)
parent
6314f09c14
commit
181b5f6236
|
|
@ -75,6 +75,7 @@ export namespace Agent {
|
|||
const config = yield* Config.Service
|
||||
const auth = yield* Auth.Service
|
||||
const skill = yield* Skill.Service
|
||||
const provider = yield* Provider.Service
|
||||
|
||||
const state = yield* InstanceState.make<State>(
|
||||
Effect.fn("Agent.state")(function* (ctx) {
|
||||
|
|
@ -330,9 +331,9 @@ export namespace Agent {
|
|||
model?: { providerID: ProviderID; modelID: ModelID }
|
||||
}) {
|
||||
const cfg = yield* config.get()
|
||||
const model = input.model ?? (yield* Effect.promise(() => Provider.defaultModel()))
|
||||
const resolved = yield* Effect.promise(() => Provider.getModel(model.providerID, model.modelID))
|
||||
const language = yield* Effect.promise(() => Provider.getLanguage(resolved))
|
||||
const model = input.model ?? (yield* provider.defaultModel())
|
||||
const resolved = yield* provider.getModel(model.providerID, model.modelID)
|
||||
const language = yield* provider.getLanguage(resolved)
|
||||
|
||||
const system = [PROMPT_GENERATE]
|
||||
yield* Effect.promise(() =>
|
||||
|
|
@ -393,6 +394,7 @@ export namespace Agent {
|
|||
)
|
||||
|
||||
export const defaultLayer = layer.pipe(
|
||||
Layer.provide(Provider.defaultLayer),
|
||||
Layer.provide(Auth.defaultLayer),
|
||||
Layer.provide(Config.defaultLayer),
|
||||
Layer.provide(Skill.defaultLayer),
|
||||
|
|
|
|||
|
|
@ -1541,10 +1541,9 @@ export namespace Provider {
|
|||
}),
|
||||
)
|
||||
|
||||
const { runPromise } = makeRuntime(
|
||||
Service,
|
||||
layer.pipe(Layer.provide(Config.defaultLayer), Layer.provide(Auth.defaultLayer)),
|
||||
)
|
||||
export const defaultLayer = layer.pipe(Layer.provide(Config.defaultLayer), Layer.provide(Auth.defaultLayer))
|
||||
|
||||
const { runPromise } = makeRuntime(Service, defaultLayer)
|
||||
|
||||
export async function list() {
|
||||
return runPromise((svc) => svc.list())
|
||||
|
|
|
|||
|
|
@ -63,7 +63,13 @@ export namespace SessionCompaction {
|
|||
export const layer: Layer.Layer<
|
||||
Service,
|
||||
never,
|
||||
Bus.Service | Config.Service | Session.Service | Agent.Service | Plugin.Service | SessionProcessor.Service
|
||||
| Bus.Service
|
||||
| Config.Service
|
||||
| Session.Service
|
||||
| Agent.Service
|
||||
| Plugin.Service
|
||||
| SessionProcessor.Service
|
||||
| Provider.Service
|
||||
> = Layer.effect(
|
||||
Service,
|
||||
Effect.gen(function* () {
|
||||
|
|
@ -73,6 +79,7 @@ export namespace SessionCompaction {
|
|||
const agents = yield* Agent.Service
|
||||
const plugin = yield* Plugin.Service
|
||||
const processors = yield* SessionProcessor.Service
|
||||
const provider = yield* Provider.Service
|
||||
|
||||
const isOverflow = Effect.fn("SessionCompaction.isOverflow")(function* (input: {
|
||||
tokens: MessageV2.Assistant["tokens"]
|
||||
|
|
@ -170,11 +177,9 @@ export namespace SessionCompaction {
|
|||
}
|
||||
|
||||
const agent = yield* agents.get("compaction")
|
||||
const model = yield* Effect.promise(() =>
|
||||
agent.model
|
||||
? Provider.getModel(agent.model.providerID, agent.model.modelID)
|
||||
: Provider.getModel(userMessage.model.providerID, userMessage.model.modelID),
|
||||
)
|
||||
const model = agent.model
|
||||
? yield* provider.getModel(agent.model.providerID, agent.model.modelID)
|
||||
: yield* provider.getModel(userMessage.model.providerID, userMessage.model.modelID)
|
||||
// Allow plugins to inject context or replace compaction prompt.
|
||||
const compacting = yield* plugin.trigger(
|
||||
"experimental.session.compacting",
|
||||
|
|
@ -377,6 +382,7 @@ When constructing the summary, try to stick to this template:
|
|||
export const defaultLayer = Layer.unwrap(
|
||||
Effect.sync(() =>
|
||||
layer.pipe(
|
||||
Layer.provide(Provider.defaultLayer),
|
||||
Layer.provide(Session.defaultLayer),
|
||||
Layer.provide(SessionProcessor.defaultLayer),
|
||||
Layer.provide(Agent.defaultLayer),
|
||||
|
|
|
|||
|
|
@ -84,6 +84,7 @@ export namespace SessionPrompt {
|
|||
const status = yield* SessionStatus.Service
|
||||
const sessions = yield* Session.Service
|
||||
const agents = yield* Agent.Service
|
||||
const provider = yield* Provider.Service
|
||||
const processor = yield* SessionProcessor.Service
|
||||
const compaction = yield* SessionCompaction.Service
|
||||
const plugin = yield* Plugin.Service
|
||||
|
|
@ -206,14 +207,14 @@ export namespace SessionPrompt {
|
|||
|
||||
const ag = yield* agents.get("title")
|
||||
if (!ag) return
|
||||
const text = yield* Effect.promise(async (signal) => {
|
||||
const mdl = ag.model
|
||||
? await Provider.getModel(ag.model.providerID, ag.model.modelID)
|
||||
: ((await Provider.getSmallModel(input.providerID)) ??
|
||||
(await Provider.getModel(input.providerID, input.modelID)))
|
||||
? yield* provider.getModel(ag.model.providerID, ag.model.modelID)
|
||||
: ((yield* provider.getSmallModel(input.providerID)) ??
|
||||
(yield* provider.getModel(input.providerID, input.modelID)))
|
||||
const msgs = onlySubtasks
|
||||
? [{ role: "user" as const, content: subtasks.map((p) => p.prompt).join("\n") }]
|
||||
: await MessageV2.toModelMessages(context, mdl)
|
||||
: yield* Effect.promise(() => MessageV2.toModelMessages(context, mdl))
|
||||
const text = yield* Effect.promise(async (signal) => {
|
||||
const result = await LLM.stream({
|
||||
agent: ag,
|
||||
user: firstInfo,
|
||||
|
|
@ -932,21 +933,35 @@ NOTE: At any point in time through this workflow you should feel free to ask the
|
|||
return { info: msg, parts: [part] }
|
||||
})
|
||||
|
||||
const getModel = (providerID: ProviderID, modelID: ModelID, sessionID: SessionID) =>
|
||||
Effect.promise(() =>
|
||||
Provider.getModel(providerID, modelID).catch((e) => {
|
||||
if (Provider.ModelNotFoundError.isInstance(e)) {
|
||||
const hint = e.data.suggestions?.length ? ` Did you mean: ${e.data.suggestions.join(", ")}?` : ""
|
||||
Bus.publish(Session.Event.Error, {
|
||||
const getModel = Effect.fn("SessionPrompt.getModel")(function* (
|
||||
providerID: ProviderID,
|
||||
modelID: ModelID,
|
||||
sessionID: SessionID,
|
||||
) {
|
||||
const exit = yield* provider.getModel(providerID, modelID).pipe(Effect.exit)
|
||||
if (Exit.isSuccess(exit)) return exit.value
|
||||
const err = Cause.squash(exit.cause)
|
||||
if (Provider.ModelNotFoundError.isInstance(err)) {
|
||||
const hint = err.data.suggestions?.length ? ` Did you mean: ${err.data.suggestions.join(", ")}?` : ""
|
||||
yield* bus.publish(Session.Event.Error, {
|
||||
sessionID,
|
||||
error: new NamedError.Unknown({
|
||||
message: `Model not found: ${e.data.providerID}/${e.data.modelID}.${hint}`,
|
||||
message: `Model not found: ${err.data.providerID}/${err.data.modelID}.${hint}`,
|
||||
}).toObject(),
|
||||
})
|
||||
}
|
||||
throw e
|
||||
}),
|
||||
)
|
||||
return yield* Effect.failCause(exit.cause)
|
||||
})
|
||||
|
||||
const lastModel = Effect.fnUntraced(function* (sessionID: SessionID) {
|
||||
const model = yield* Effect.promise(async () => {
|
||||
for await (const item of MessageV2.stream(sessionID)) {
|
||||
if (item.info.role === "user" && item.info.model) return item.info.model
|
||||
}
|
||||
})
|
||||
if (model) return model
|
||||
return yield* provider.defaultModel()
|
||||
})
|
||||
|
||||
const createUserMessage = Effect.fn("SessionPrompt.createUserMessage")(function* (input: PromptInput) {
|
||||
const agentName = input.agent || (yield* agents.defaultAgent())
|
||||
|
|
@ -960,9 +975,12 @@ NOTE: At any point in time through this workflow you should feel free to ask the
|
|||
}
|
||||
|
||||
const model = input.model ?? ag.model ?? (yield* lastModel(input.sessionID))
|
||||
const same = ag.model && model.providerID === ag.model.providerID && model.modelID === ag.model.modelID
|
||||
const full =
|
||||
!input.variant && ag.variant
|
||||
? yield* Effect.promise(() => Provider.getModel(model.providerID, model.modelID).catch(() => undefined))
|
||||
!input.variant && ag.variant && same
|
||||
? yield* provider
|
||||
.getModel(model.providerID, model.modelID)
|
||||
.pipe(Effect.catch(() => Effect.succeed(undefined)))
|
||||
: undefined
|
||||
const variant = input.variant ?? (ag.variant && full?.variants?.[ag.variant] ? ag.variant : undefined)
|
||||
|
||||
|
|
@ -1109,7 +1127,7 @@ NOTE: At any point in time through this workflow you should feel free to ask the
|
|||
]
|
||||
const read = yield* Effect.promise(() => ReadTool.init()).pipe(
|
||||
Effect.flatMap((t) =>
|
||||
Effect.promise(() => Provider.getModel(info.model.providerID, info.model.modelID)).pipe(
|
||||
provider.getModel(info.model.providerID, info.model.modelID).pipe(
|
||||
Effect.flatMap((mdl) =>
|
||||
Effect.promise(() =>
|
||||
t.execute(args, {
|
||||
|
|
@ -1711,6 +1729,7 @@ NOTE: At any point in time through this workflow you should feel free to ask the
|
|||
Layer.provide(FileTime.defaultLayer),
|
||||
Layer.provide(ToolRegistry.defaultLayer),
|
||||
Layer.provide(Truncate.layer),
|
||||
Layer.provide(Provider.defaultLayer),
|
||||
Layer.provide(AppFileSystem.defaultLayer),
|
||||
Layer.provide(Plugin.defaultLayer),
|
||||
Layer.provide(Session.defaultLayer),
|
||||
|
|
@ -1856,15 +1875,6 @@ NOTE: At any point in time through this workflow you should feel free to ask the
|
|||
return runPromise((svc) => svc.command(CommandInput.parse(input)))
|
||||
}
|
||||
|
||||
const lastModel = Effect.fnUntraced(function* (sessionID: SessionID) {
|
||||
return yield* Effect.promise(async () => {
|
||||
for await (const item of MessageV2.stream(sessionID)) {
|
||||
if (item.info.role === "user" && item.info.model) return item.info.model
|
||||
}
|
||||
return Provider.defaultModel()
|
||||
})
|
||||
})
|
||||
|
||||
/** @internal Exported for testing */
|
||||
export function createStructuredOutputTool(input: {
|
||||
schema: Record<string, any>
|
||||
|
|
|
|||
|
|
@ -0,0 +1,81 @@
|
|||
import { Effect, Layer } from "effect"
|
||||
import { Provider } from "../../src/provider/provider"
|
||||
import { ModelID, ProviderID } from "../../src/provider/schema"
|
||||
|
||||
export namespace ProviderTest {
|
||||
export function model(override: Partial<Provider.Model> = {}): Provider.Model {
|
||||
const id = override.id ?? ModelID.make("gpt-5.2")
|
||||
const providerID = override.providerID ?? ProviderID.make("openai")
|
||||
return {
|
||||
id,
|
||||
providerID,
|
||||
name: "Test Model",
|
||||
capabilities: {
|
||||
toolcall: true,
|
||||
attachment: false,
|
||||
reasoning: false,
|
||||
temperature: true,
|
||||
interleaved: false,
|
||||
input: { text: true, image: false, audio: false, video: false, pdf: false },
|
||||
output: { text: true, image: false, audio: false, video: false, pdf: false },
|
||||
},
|
||||
api: { id, url: "https://example.com", npm: "@ai-sdk/openai" },
|
||||
cost: { input: 0, output: 0, cache: { read: 0, write: 0 } },
|
||||
limit: { context: 200_000, output: 10_000 },
|
||||
status: "active",
|
||||
options: {},
|
||||
headers: {},
|
||||
release_date: "2025-01-01",
|
||||
...override,
|
||||
}
|
||||
}
|
||||
|
||||
export function info(override: Partial<Provider.Info> = {}, mdl = model()): Provider.Info {
|
||||
const id = override.id ?? mdl.providerID
|
||||
return {
|
||||
id,
|
||||
name: "Test Provider",
|
||||
source: "config",
|
||||
env: [],
|
||||
options: {},
|
||||
models: { [mdl.id]: mdl },
|
||||
...override,
|
||||
}
|
||||
}
|
||||
|
||||
export function fake(override: Partial<Provider.Interface> & { model?: Provider.Model; info?: Provider.Info } = {}) {
|
||||
const mdl = override.model ?? model()
|
||||
const row = override.info ?? info({}, mdl)
|
||||
return {
|
||||
model: mdl,
|
||||
info: row,
|
||||
layer: Layer.succeed(
|
||||
Provider.Service,
|
||||
Provider.Service.of({
|
||||
list: Effect.fn("TestProvider.list")(() => Effect.succeed({ [row.id]: row })),
|
||||
getProvider: Effect.fn("TestProvider.getProvider")((providerID) => {
|
||||
if (providerID === row.id) return Effect.succeed(row)
|
||||
return Effect.die(new Error(`Unknown test provider: ${providerID}`))
|
||||
}),
|
||||
getModel: Effect.fn("TestProvider.getModel")((providerID, modelID) => {
|
||||
if (providerID === row.id && modelID === mdl.id) return Effect.succeed(mdl)
|
||||
return Effect.die(new Error(`Unknown test model: ${providerID}/${modelID}`))
|
||||
}),
|
||||
getLanguage: Effect.fn("TestProvider.getLanguage")(() =>
|
||||
Effect.die(new Error("ProviderTest.getLanguage not configured")),
|
||||
),
|
||||
closest: Effect.fn("TestProvider.closest")((providerID) =>
|
||||
Effect.succeed(providerID === row.id ? { providerID: row.id, modelID: mdl.id } : undefined),
|
||||
),
|
||||
getSmallModel: Effect.fn("TestProvider.getSmallModel")((providerID) =>
|
||||
Effect.succeed(providerID === row.id ? mdl : undefined),
|
||||
),
|
||||
defaultModel: Effect.fn("TestProvider.defaultModel")(() =>
|
||||
Effect.succeed({ providerID: row.id, modelID: mdl.id }),
|
||||
),
|
||||
...override,
|
||||
}),
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
import { afterEach, describe, expect, mock, spyOn, test } from "bun:test"
|
||||
import { afterEach, describe, expect, mock, test } from "bun:test"
|
||||
import { APICallError } from "ai"
|
||||
import { Cause, Effect, Exit, Layer, ManagedRuntime } from "effect"
|
||||
import * as Stream from "effect/Stream"
|
||||
|
|
@ -20,9 +20,9 @@ import { MessageID, PartID, SessionID } from "../../src/session/schema"
|
|||
import { SessionStatus } from "../../src/session/status"
|
||||
import { ModelID, ProviderID } from "../../src/provider/schema"
|
||||
import type { Provider } from "../../src/provider/provider"
|
||||
import * as ProviderModule from "../../src/provider/provider"
|
||||
import * as SessionProcessorModule from "../../src/session/processor"
|
||||
import { Snapshot } from "../../src/snapshot"
|
||||
import { ProviderTest } from "../fake/provider"
|
||||
|
||||
Log.init({ print: false })
|
||||
|
||||
|
|
@ -65,6 +65,8 @@ function createModel(opts: {
|
|||
} as Provider.Model
|
||||
}
|
||||
|
||||
const wide = () => ProviderTest.fake({ model: createModel({ context: 100_000, output: 32_000 }) })
|
||||
|
||||
async function user(sessionID: SessionID, text: string) {
|
||||
const msg = await Session.updateMessage({
|
||||
id: MessageID.ascending(),
|
||||
|
|
@ -162,10 +164,11 @@ function layer(result: "continue" | "compact") {
|
|||
)
|
||||
}
|
||||
|
||||
function runtime(result: "continue" | "compact", plugin = Plugin.defaultLayer) {
|
||||
function runtime(result: "continue" | "compact", plugin = Plugin.defaultLayer, provider = ProviderTest.fake()) {
|
||||
const bus = Bus.layer
|
||||
return ManagedRuntime.make(
|
||||
Layer.mergeAll(SessionCompaction.layer, bus).pipe(
|
||||
Layer.provide(provider.layer),
|
||||
Layer.provide(Session.defaultLayer),
|
||||
Layer.provide(layer(result)),
|
||||
Layer.provide(Agent.defaultLayer),
|
||||
|
|
@ -198,12 +201,13 @@ function llm() {
|
|||
}
|
||||
}
|
||||
|
||||
function liveRuntime(layer: Layer.Layer<LLM.Service>) {
|
||||
function liveRuntime(layer: Layer.Layer<LLM.Service>, provider = ProviderTest.fake()) {
|
||||
const bus = Bus.layer
|
||||
const status = SessionStatus.layer.pipe(Layer.provide(bus))
|
||||
const processor = SessionProcessorModule.SessionProcessor.layer
|
||||
return ManagedRuntime.make(
|
||||
Layer.mergeAll(SessionCompaction.layer.pipe(Layer.provide(processor)), processor, bus, status).pipe(
|
||||
Layer.provide(provider.layer),
|
||||
Layer.provide(Session.defaultLayer),
|
||||
Layer.provide(Snapshot.defaultLayer),
|
||||
Layer.provide(layer),
|
||||
|
|
@ -544,14 +548,12 @@ describe("session.compaction.process", () => {
|
|||
await Instance.provide({
|
||||
directory: tmp.path,
|
||||
fn: async () => {
|
||||
spyOn(ProviderModule.Provider, "getModel").mockResolvedValue(createModel({ context: 100_000, output: 32_000 }))
|
||||
|
||||
const session = await Session.create({})
|
||||
const msg = await user(session.id, "hello")
|
||||
const msgs = await Session.messages({ sessionID: session.id })
|
||||
const done = defer()
|
||||
let seen = false
|
||||
const rt = runtime("continue")
|
||||
const rt = runtime("continue", Plugin.defaultLayer, wide())
|
||||
let unsub: (() => void) | undefined
|
||||
try {
|
||||
unsub = await rt.runPromise(
|
||||
|
|
@ -596,11 +598,9 @@ describe("session.compaction.process", () => {
|
|||
await Instance.provide({
|
||||
directory: tmp.path,
|
||||
fn: async () => {
|
||||
spyOn(ProviderModule.Provider, "getModel").mockResolvedValue(createModel({ context: 100_000, output: 32_000 }))
|
||||
|
||||
const session = await Session.create({})
|
||||
const msg = await user(session.id, "hello")
|
||||
const rt = runtime("compact")
|
||||
const rt = runtime("compact", Plugin.defaultLayer, wide())
|
||||
try {
|
||||
const msgs = await Session.messages({ sessionID: session.id })
|
||||
const result = await rt.runPromise(
|
||||
|
|
@ -636,11 +636,9 @@ describe("session.compaction.process", () => {
|
|||
await Instance.provide({
|
||||
directory: tmp.path,
|
||||
fn: async () => {
|
||||
spyOn(ProviderModule.Provider, "getModel").mockResolvedValue(createModel({ context: 100_000, output: 32_000 }))
|
||||
|
||||
const session = await Session.create({})
|
||||
const msg = await user(session.id, "hello")
|
||||
const rt = runtime("continue")
|
||||
const rt = runtime("continue", Plugin.defaultLayer, wide())
|
||||
try {
|
||||
const msgs = await Session.messages({ sessionID: session.id })
|
||||
const result = await rt.runPromise(
|
||||
|
|
@ -678,8 +676,6 @@ describe("session.compaction.process", () => {
|
|||
await Instance.provide({
|
||||
directory: tmp.path,
|
||||
fn: async () => {
|
||||
spyOn(ProviderModule.Provider, "getModel").mockResolvedValue(createModel({ context: 100_000, output: 32_000 }))
|
||||
|
||||
const session = await Session.create({})
|
||||
await user(session.id, "root")
|
||||
const replay = await user(session.id, "image")
|
||||
|
|
@ -693,7 +689,7 @@ describe("session.compaction.process", () => {
|
|||
url: "https://example.com/cat.png",
|
||||
})
|
||||
const msg = await user(session.id, "current")
|
||||
const rt = runtime("continue")
|
||||
const rt = runtime("continue", Plugin.defaultLayer, wide())
|
||||
try {
|
||||
const msgs = await Session.messages({ sessionID: session.id })
|
||||
const result = await rt.runPromise(
|
||||
|
|
@ -728,13 +724,11 @@ describe("session.compaction.process", () => {
|
|||
await Instance.provide({
|
||||
directory: tmp.path,
|
||||
fn: async () => {
|
||||
spyOn(ProviderModule.Provider, "getModel").mockResolvedValue(createModel({ context: 100_000, output: 32_000 }))
|
||||
|
||||
const session = await Session.create({})
|
||||
await user(session.id, "earlier")
|
||||
const msg = await user(session.id, "current")
|
||||
|
||||
const rt = runtime("continue")
|
||||
const rt = runtime("continue", Plugin.defaultLayer, wide())
|
||||
try {
|
||||
const msgs = await Session.messages({ sessionID: session.id })
|
||||
const result = await rt.runPromise(
|
||||
|
|
@ -790,13 +784,11 @@ describe("session.compaction.process", () => {
|
|||
await Instance.provide({
|
||||
directory: tmp.path,
|
||||
fn: async () => {
|
||||
spyOn(ProviderModule.Provider, "getModel").mockResolvedValue(createModel({ context: 100_000, output: 32_000 }))
|
||||
|
||||
const session = await Session.create({})
|
||||
const msg = await user(session.id, "hello")
|
||||
const msgs = await Session.messages({ sessionID: session.id })
|
||||
const abort = new AbortController()
|
||||
const rt = liveRuntime(stub.layer)
|
||||
const rt = liveRuntime(stub.layer, wide())
|
||||
let off: (() => void) | undefined
|
||||
let run: Promise<"continue" | "stop"> | undefined
|
||||
try {
|
||||
|
|
@ -866,13 +858,11 @@ describe("session.compaction.process", () => {
|
|||
await Instance.provide({
|
||||
directory: tmp.path,
|
||||
fn: async () => {
|
||||
spyOn(ProviderModule.Provider, "getModel").mockResolvedValue(createModel({ context: 100_000, output: 32_000 }))
|
||||
|
||||
const session = await Session.create({})
|
||||
const msg = await user(session.id, "hello")
|
||||
const msgs = await Session.messages({ sessionID: session.id })
|
||||
const abort = new AbortController()
|
||||
const rt = runtime("continue", plugin(ready))
|
||||
const rt = runtime("continue", plugin(ready), wide())
|
||||
let run: Promise<"continue" | "stop"> | undefined
|
||||
try {
|
||||
run = rt
|
||||
|
|
@ -970,11 +960,9 @@ describe("session.compaction.process", () => {
|
|||
await Instance.provide({
|
||||
directory: tmp.path,
|
||||
fn: async () => {
|
||||
spyOn(ProviderModule.Provider, "getModel").mockResolvedValue(createModel({ context: 100_000, output: 32_000 }))
|
||||
|
||||
const session = await Session.create({})
|
||||
const msg = await user(session.id, "hello")
|
||||
const rt = liveRuntime(stub.layer)
|
||||
const rt = liveRuntime(stub.layer, wide())
|
||||
try {
|
||||
const msgs = await Session.messages({ sessionID: session.id })
|
||||
await rt.runPromise(
|
||||
|
|
|
|||
|
|
@ -1,247 +0,0 @@
|
|||
import { describe, expect, spyOn, test } from "bun:test"
|
||||
import { Instance } from "../../src/project/instance"
|
||||
import { Provider } from "../../src/provider/provider"
|
||||
import { Session } from "../../src/session"
|
||||
import { MessageV2 } from "../../src/session/message-v2"
|
||||
import { SessionPrompt } from "../../src/session/prompt"
|
||||
import { SessionStatus } from "../../src/session/status"
|
||||
import { MessageID, PartID, SessionID } from "../../src/session/schema"
|
||||
import { Log } from "../../src/util/log"
|
||||
import { tmpdir } from "../fixture/fixture"
|
||||
|
||||
Log.init({ print: false })
|
||||
|
||||
function deferred() {
|
||||
let resolve!: () => void
|
||||
const promise = new Promise<void>((done) => {
|
||||
resolve = done
|
||||
})
|
||||
return { promise, resolve }
|
||||
}
|
||||
|
||||
// Helper: seed a session with a user message + finished assistant message
|
||||
// so loop() exits immediately without calling any LLM
|
||||
async function seed(sessionID: SessionID) {
|
||||
const userMsg: MessageV2.Info = {
|
||||
id: MessageID.ascending(),
|
||||
role: "user",
|
||||
sessionID,
|
||||
time: { created: Date.now() },
|
||||
agent: "build",
|
||||
model: { providerID: "openai" as any, modelID: "gpt-5.2" as any },
|
||||
}
|
||||
await Session.updateMessage(userMsg)
|
||||
await Session.updatePart({
|
||||
id: PartID.ascending(),
|
||||
messageID: userMsg.id,
|
||||
sessionID,
|
||||
type: "text",
|
||||
text: "hello",
|
||||
})
|
||||
|
||||
const assistantMsg: MessageV2.Info = {
|
||||
id: MessageID.ascending(),
|
||||
role: "assistant",
|
||||
parentID: userMsg.id,
|
||||
sessionID,
|
||||
mode: "build",
|
||||
agent: "build",
|
||||
cost: 0,
|
||||
path: { cwd: "/tmp", root: "/tmp" },
|
||||
tokens: { input: 0, output: 0, reasoning: 0, cache: { read: 0, write: 0 } },
|
||||
modelID: "gpt-5.2" as any,
|
||||
providerID: "openai" as any,
|
||||
time: { created: Date.now(), completed: Date.now() },
|
||||
finish: "stop",
|
||||
}
|
||||
await Session.updateMessage(assistantMsg)
|
||||
await Session.updatePart({
|
||||
id: PartID.ascending(),
|
||||
messageID: assistantMsg.id,
|
||||
sessionID,
|
||||
type: "text",
|
||||
text: "hi there",
|
||||
})
|
||||
|
||||
return { userMsg, assistantMsg }
|
||||
}
|
||||
|
||||
describe("session.prompt concurrency", () => {
|
||||
test("loop returns assistant message and sets status to idle", async () => {
|
||||
await using tmp = await tmpdir({ git: true })
|
||||
await Instance.provide({
|
||||
directory: tmp.path,
|
||||
fn: async () => {
|
||||
const session = await Session.create({})
|
||||
await seed(session.id)
|
||||
|
||||
const result = await SessionPrompt.loop({ sessionID: session.id })
|
||||
expect(result.info.role).toBe("assistant")
|
||||
if (result.info.role === "assistant") expect(result.info.finish).toBe("stop")
|
||||
|
||||
const status = await SessionStatus.get(session.id)
|
||||
expect(status.type).toBe("idle")
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
test("concurrent loop callers get the same result", async () => {
|
||||
await using tmp = await tmpdir({ git: true })
|
||||
await Instance.provide({
|
||||
directory: tmp.path,
|
||||
fn: async () => {
|
||||
const session = await Session.create({})
|
||||
await seed(session.id)
|
||||
|
||||
const [a, b] = await Promise.all([
|
||||
SessionPrompt.loop({ sessionID: session.id }),
|
||||
SessionPrompt.loop({ sessionID: session.id }),
|
||||
])
|
||||
|
||||
expect(a.info.id).toBe(b.info.id)
|
||||
expect(a.info.role).toBe("assistant")
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
test("assertNotBusy throws when loop is running", async () => {
|
||||
await using tmp = await tmpdir({ git: true })
|
||||
await Instance.provide({
|
||||
directory: tmp.path,
|
||||
fn: async () => {
|
||||
const session = await Session.create({})
|
||||
const userMsg: MessageV2.Info = {
|
||||
id: MessageID.ascending(),
|
||||
role: "user",
|
||||
sessionID: session.id,
|
||||
time: { created: Date.now() },
|
||||
agent: "build",
|
||||
model: { providerID: "openai" as any, modelID: "gpt-5.2" as any },
|
||||
}
|
||||
await Session.updateMessage(userMsg)
|
||||
await Session.updatePart({
|
||||
id: PartID.ascending(),
|
||||
messageID: userMsg.id,
|
||||
sessionID: session.id,
|
||||
type: "text",
|
||||
text: "hello",
|
||||
})
|
||||
|
||||
const ready = deferred()
|
||||
const gate = deferred()
|
||||
const getModel = spyOn(Provider, "getModel").mockImplementation(async () => {
|
||||
ready.resolve()
|
||||
await gate.promise
|
||||
throw new Error("test stop")
|
||||
})
|
||||
|
||||
try {
|
||||
const loopPromise = SessionPrompt.loop({ sessionID: session.id }).catch(() => undefined)
|
||||
await ready.promise
|
||||
|
||||
await expect(SessionPrompt.assertNotBusy(session.id)).rejects.toBeInstanceOf(Session.BusyError)
|
||||
|
||||
gate.resolve()
|
||||
await loopPromise
|
||||
} finally {
|
||||
gate.resolve()
|
||||
getModel.mockRestore()
|
||||
}
|
||||
|
||||
// After loop completes, assertNotBusy should succeed
|
||||
await SessionPrompt.assertNotBusy(session.id)
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
test("cancel sets status to idle", async () => {
|
||||
await using tmp = await tmpdir({ git: true })
|
||||
await Instance.provide({
|
||||
directory: tmp.path,
|
||||
fn: async () => {
|
||||
const session = await Session.create({})
|
||||
// Seed only a user message — loop must call getModel to proceed
|
||||
const userMsg: MessageV2.Info = {
|
||||
id: MessageID.ascending(),
|
||||
role: "user",
|
||||
sessionID: session.id,
|
||||
time: { created: Date.now() },
|
||||
agent: "build",
|
||||
model: { providerID: "openai" as any, modelID: "gpt-5.2" as any },
|
||||
}
|
||||
await Session.updateMessage(userMsg)
|
||||
await Session.updatePart({
|
||||
id: PartID.ascending(),
|
||||
messageID: userMsg.id,
|
||||
sessionID: session.id,
|
||||
type: "text",
|
||||
text: "hello",
|
||||
})
|
||||
// Also seed an assistant message so lastAssistant() fallback can find it
|
||||
const assistantMsg: MessageV2.Info = {
|
||||
id: MessageID.ascending(),
|
||||
role: "assistant",
|
||||
parentID: userMsg.id,
|
||||
sessionID: session.id,
|
||||
mode: "build",
|
||||
agent: "build",
|
||||
cost: 0,
|
||||
path: { cwd: "/tmp", root: "/tmp" },
|
||||
tokens: { input: 0, output: 0, reasoning: 0, cache: { read: 0, write: 0 } },
|
||||
modelID: "gpt-5.2" as any,
|
||||
providerID: "openai" as any,
|
||||
time: { created: Date.now() },
|
||||
}
|
||||
await Session.updateMessage(assistantMsg)
|
||||
await Session.updatePart({
|
||||
id: PartID.ascending(),
|
||||
messageID: assistantMsg.id,
|
||||
sessionID: session.id,
|
||||
type: "text",
|
||||
text: "hi there",
|
||||
})
|
||||
|
||||
const ready = deferred()
|
||||
const gate = deferred()
|
||||
const getModel = spyOn(Provider, "getModel").mockImplementation(async () => {
|
||||
ready.resolve()
|
||||
await gate.promise
|
||||
throw new Error("test stop")
|
||||
})
|
||||
|
||||
try {
|
||||
// Start loop — it will block in getModel (assistant has no finish, so loop continues)
|
||||
const loopPromise = SessionPrompt.loop({ sessionID: session.id })
|
||||
|
||||
await ready.promise
|
||||
|
||||
await SessionPrompt.cancel(session.id)
|
||||
|
||||
const status = await SessionStatus.get(session.id)
|
||||
expect(status.type).toBe("idle")
|
||||
|
||||
// loop should resolve cleanly, not throw "All fibers interrupted"
|
||||
const result = await loopPromise
|
||||
expect(result.info.role).toBe("assistant")
|
||||
expect(result.info.id).toBe(assistantMsg.id)
|
||||
} finally {
|
||||
gate.resolve()
|
||||
getModel.mockRestore()
|
||||
}
|
||||
},
|
||||
})
|
||||
}, 10000)
|
||||
|
||||
test("cancel on idle session just sets idle", async () => {
|
||||
await using tmp = await tmpdir({ git: true })
|
||||
await Instance.provide({
|
||||
directory: tmp.path,
|
||||
fn: async () => {
|
||||
const session = await Session.create({})
|
||||
await SessionPrompt.cancel(session.id)
|
||||
const status = await SessionStatus.get(session.id)
|
||||
expect(status.type).toBe("idle")
|
||||
},
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
@ -12,6 +12,7 @@ import { LSP } from "../../src/lsp"
|
|||
import { MCP } from "../../src/mcp"
|
||||
import { Permission } from "../../src/permission"
|
||||
import { Plugin } from "../../src/plugin"
|
||||
import { Provider as ProviderSvc } from "../../src/provider/provider"
|
||||
import type { Provider } from "../../src/provider/provider"
|
||||
import { ModelID, ProviderID } from "../../src/provider/schema"
|
||||
import { Session } from "../../src/session"
|
||||
|
|
@ -151,6 +152,7 @@ function makeHttp() {
|
|||
Permission.layer,
|
||||
Plugin.defaultLayer,
|
||||
Config.defaultLayer,
|
||||
ProviderSvc.defaultLayer,
|
||||
filetime,
|
||||
lsp,
|
||||
mcp,
|
||||
|
|
|
|||
Loading…
Reference in New Issue