refactor(prompt): use Provider service in effect layers (#20167)
parent
6314f09c14
commit
181b5f6236
|
|
@ -75,6 +75,7 @@ export namespace Agent {
|
||||||
const config = yield* Config.Service
|
const config = yield* Config.Service
|
||||||
const auth = yield* Auth.Service
|
const auth = yield* Auth.Service
|
||||||
const skill = yield* Skill.Service
|
const skill = yield* Skill.Service
|
||||||
|
const provider = yield* Provider.Service
|
||||||
|
|
||||||
const state = yield* InstanceState.make<State>(
|
const state = yield* InstanceState.make<State>(
|
||||||
Effect.fn("Agent.state")(function* (ctx) {
|
Effect.fn("Agent.state")(function* (ctx) {
|
||||||
|
|
@ -330,9 +331,9 @@ export namespace Agent {
|
||||||
model?: { providerID: ProviderID; modelID: ModelID }
|
model?: { providerID: ProviderID; modelID: ModelID }
|
||||||
}) {
|
}) {
|
||||||
const cfg = yield* config.get()
|
const cfg = yield* config.get()
|
||||||
const model = input.model ?? (yield* Effect.promise(() => Provider.defaultModel()))
|
const model = input.model ?? (yield* provider.defaultModel())
|
||||||
const resolved = yield* Effect.promise(() => Provider.getModel(model.providerID, model.modelID))
|
const resolved = yield* provider.getModel(model.providerID, model.modelID)
|
||||||
const language = yield* Effect.promise(() => Provider.getLanguage(resolved))
|
const language = yield* provider.getLanguage(resolved)
|
||||||
|
|
||||||
const system = [PROMPT_GENERATE]
|
const system = [PROMPT_GENERATE]
|
||||||
yield* Effect.promise(() =>
|
yield* Effect.promise(() =>
|
||||||
|
|
@ -393,6 +394,7 @@ export namespace Agent {
|
||||||
)
|
)
|
||||||
|
|
||||||
export const defaultLayer = layer.pipe(
|
export const defaultLayer = layer.pipe(
|
||||||
|
Layer.provide(Provider.defaultLayer),
|
||||||
Layer.provide(Auth.defaultLayer),
|
Layer.provide(Auth.defaultLayer),
|
||||||
Layer.provide(Config.defaultLayer),
|
Layer.provide(Config.defaultLayer),
|
||||||
Layer.provide(Skill.defaultLayer),
|
Layer.provide(Skill.defaultLayer),
|
||||||
|
|
|
||||||
|
|
@ -1541,10 +1541,9 @@ export namespace Provider {
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
|
||||||
const { runPromise } = makeRuntime(
|
export const defaultLayer = layer.pipe(Layer.provide(Config.defaultLayer), Layer.provide(Auth.defaultLayer))
|
||||||
Service,
|
|
||||||
layer.pipe(Layer.provide(Config.defaultLayer), Layer.provide(Auth.defaultLayer)),
|
const { runPromise } = makeRuntime(Service, defaultLayer)
|
||||||
)
|
|
||||||
|
|
||||||
export async function list() {
|
export async function list() {
|
||||||
return runPromise((svc) => svc.list())
|
return runPromise((svc) => svc.list())
|
||||||
|
|
|
||||||
|
|
@ -63,7 +63,13 @@ export namespace SessionCompaction {
|
||||||
export const layer: Layer.Layer<
|
export const layer: Layer.Layer<
|
||||||
Service,
|
Service,
|
||||||
never,
|
never,
|
||||||
Bus.Service | Config.Service | Session.Service | Agent.Service | Plugin.Service | SessionProcessor.Service
|
| Bus.Service
|
||||||
|
| Config.Service
|
||||||
|
| Session.Service
|
||||||
|
| Agent.Service
|
||||||
|
| Plugin.Service
|
||||||
|
| SessionProcessor.Service
|
||||||
|
| Provider.Service
|
||||||
> = Layer.effect(
|
> = Layer.effect(
|
||||||
Service,
|
Service,
|
||||||
Effect.gen(function* () {
|
Effect.gen(function* () {
|
||||||
|
|
@ -73,6 +79,7 @@ export namespace SessionCompaction {
|
||||||
const agents = yield* Agent.Service
|
const agents = yield* Agent.Service
|
||||||
const plugin = yield* Plugin.Service
|
const plugin = yield* Plugin.Service
|
||||||
const processors = yield* SessionProcessor.Service
|
const processors = yield* SessionProcessor.Service
|
||||||
|
const provider = yield* Provider.Service
|
||||||
|
|
||||||
const isOverflow = Effect.fn("SessionCompaction.isOverflow")(function* (input: {
|
const isOverflow = Effect.fn("SessionCompaction.isOverflow")(function* (input: {
|
||||||
tokens: MessageV2.Assistant["tokens"]
|
tokens: MessageV2.Assistant["tokens"]
|
||||||
|
|
@ -170,11 +177,9 @@ export namespace SessionCompaction {
|
||||||
}
|
}
|
||||||
|
|
||||||
const agent = yield* agents.get("compaction")
|
const agent = yield* agents.get("compaction")
|
||||||
const model = yield* Effect.promise(() =>
|
const model = agent.model
|
||||||
agent.model
|
? yield* provider.getModel(agent.model.providerID, agent.model.modelID)
|
||||||
? Provider.getModel(agent.model.providerID, agent.model.modelID)
|
: yield* provider.getModel(userMessage.model.providerID, userMessage.model.modelID)
|
||||||
: Provider.getModel(userMessage.model.providerID, userMessage.model.modelID),
|
|
||||||
)
|
|
||||||
// Allow plugins to inject context or replace compaction prompt.
|
// Allow plugins to inject context or replace compaction prompt.
|
||||||
const compacting = yield* plugin.trigger(
|
const compacting = yield* plugin.trigger(
|
||||||
"experimental.session.compacting",
|
"experimental.session.compacting",
|
||||||
|
|
@ -377,6 +382,7 @@ When constructing the summary, try to stick to this template:
|
||||||
export const defaultLayer = Layer.unwrap(
|
export const defaultLayer = Layer.unwrap(
|
||||||
Effect.sync(() =>
|
Effect.sync(() =>
|
||||||
layer.pipe(
|
layer.pipe(
|
||||||
|
Layer.provide(Provider.defaultLayer),
|
||||||
Layer.provide(Session.defaultLayer),
|
Layer.provide(Session.defaultLayer),
|
||||||
Layer.provide(SessionProcessor.defaultLayer),
|
Layer.provide(SessionProcessor.defaultLayer),
|
||||||
Layer.provide(Agent.defaultLayer),
|
Layer.provide(Agent.defaultLayer),
|
||||||
|
|
|
||||||
|
|
@ -84,6 +84,7 @@ export namespace SessionPrompt {
|
||||||
const status = yield* SessionStatus.Service
|
const status = yield* SessionStatus.Service
|
||||||
const sessions = yield* Session.Service
|
const sessions = yield* Session.Service
|
||||||
const agents = yield* Agent.Service
|
const agents = yield* Agent.Service
|
||||||
|
const provider = yield* Provider.Service
|
||||||
const processor = yield* SessionProcessor.Service
|
const processor = yield* SessionProcessor.Service
|
||||||
const compaction = yield* SessionCompaction.Service
|
const compaction = yield* SessionCompaction.Service
|
||||||
const plugin = yield* Plugin.Service
|
const plugin = yield* Plugin.Service
|
||||||
|
|
@ -206,14 +207,14 @@ export namespace SessionPrompt {
|
||||||
|
|
||||||
const ag = yield* agents.get("title")
|
const ag = yield* agents.get("title")
|
||||||
if (!ag) return
|
if (!ag) return
|
||||||
const text = yield* Effect.promise(async (signal) => {
|
|
||||||
const mdl = ag.model
|
const mdl = ag.model
|
||||||
? await Provider.getModel(ag.model.providerID, ag.model.modelID)
|
? yield* provider.getModel(ag.model.providerID, ag.model.modelID)
|
||||||
: ((await Provider.getSmallModel(input.providerID)) ??
|
: ((yield* provider.getSmallModel(input.providerID)) ??
|
||||||
(await Provider.getModel(input.providerID, input.modelID)))
|
(yield* provider.getModel(input.providerID, input.modelID)))
|
||||||
const msgs = onlySubtasks
|
const msgs = onlySubtasks
|
||||||
? [{ role: "user" as const, content: subtasks.map((p) => p.prompt).join("\n") }]
|
? [{ role: "user" as const, content: subtasks.map((p) => p.prompt).join("\n") }]
|
||||||
: await MessageV2.toModelMessages(context, mdl)
|
: yield* Effect.promise(() => MessageV2.toModelMessages(context, mdl))
|
||||||
|
const text = yield* Effect.promise(async (signal) => {
|
||||||
const result = await LLM.stream({
|
const result = await LLM.stream({
|
||||||
agent: ag,
|
agent: ag,
|
||||||
user: firstInfo,
|
user: firstInfo,
|
||||||
|
|
@ -932,21 +933,35 @@ NOTE: At any point in time through this workflow you should feel free to ask the
|
||||||
return { info: msg, parts: [part] }
|
return { info: msg, parts: [part] }
|
||||||
})
|
})
|
||||||
|
|
||||||
const getModel = (providerID: ProviderID, modelID: ModelID, sessionID: SessionID) =>
|
const getModel = Effect.fn("SessionPrompt.getModel")(function* (
|
||||||
Effect.promise(() =>
|
providerID: ProviderID,
|
||||||
Provider.getModel(providerID, modelID).catch((e) => {
|
modelID: ModelID,
|
||||||
if (Provider.ModelNotFoundError.isInstance(e)) {
|
sessionID: SessionID,
|
||||||
const hint = e.data.suggestions?.length ? ` Did you mean: ${e.data.suggestions.join(", ")}?` : ""
|
) {
|
||||||
Bus.publish(Session.Event.Error, {
|
const exit = yield* provider.getModel(providerID, modelID).pipe(Effect.exit)
|
||||||
|
if (Exit.isSuccess(exit)) return exit.value
|
||||||
|
const err = Cause.squash(exit.cause)
|
||||||
|
if (Provider.ModelNotFoundError.isInstance(err)) {
|
||||||
|
const hint = err.data.suggestions?.length ? ` Did you mean: ${err.data.suggestions.join(", ")}?` : ""
|
||||||
|
yield* bus.publish(Session.Event.Error, {
|
||||||
sessionID,
|
sessionID,
|
||||||
error: new NamedError.Unknown({
|
error: new NamedError.Unknown({
|
||||||
message: `Model not found: ${e.data.providerID}/${e.data.modelID}.${hint}`,
|
message: `Model not found: ${err.data.providerID}/${err.data.modelID}.${hint}`,
|
||||||
}).toObject(),
|
}).toObject(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
throw e
|
return yield* Effect.failCause(exit.cause)
|
||||||
}),
|
})
|
||||||
)
|
|
||||||
|
const lastModel = Effect.fnUntraced(function* (sessionID: SessionID) {
|
||||||
|
const model = yield* Effect.promise(async () => {
|
||||||
|
for await (const item of MessageV2.stream(sessionID)) {
|
||||||
|
if (item.info.role === "user" && item.info.model) return item.info.model
|
||||||
|
}
|
||||||
|
})
|
||||||
|
if (model) return model
|
||||||
|
return yield* provider.defaultModel()
|
||||||
|
})
|
||||||
|
|
||||||
const createUserMessage = Effect.fn("SessionPrompt.createUserMessage")(function* (input: PromptInput) {
|
const createUserMessage = Effect.fn("SessionPrompt.createUserMessage")(function* (input: PromptInput) {
|
||||||
const agentName = input.agent || (yield* agents.defaultAgent())
|
const agentName = input.agent || (yield* agents.defaultAgent())
|
||||||
|
|
@ -960,9 +975,12 @@ NOTE: At any point in time through this workflow you should feel free to ask the
|
||||||
}
|
}
|
||||||
|
|
||||||
const model = input.model ?? ag.model ?? (yield* lastModel(input.sessionID))
|
const model = input.model ?? ag.model ?? (yield* lastModel(input.sessionID))
|
||||||
|
const same = ag.model && model.providerID === ag.model.providerID && model.modelID === ag.model.modelID
|
||||||
const full =
|
const full =
|
||||||
!input.variant && ag.variant
|
!input.variant && ag.variant && same
|
||||||
? yield* Effect.promise(() => Provider.getModel(model.providerID, model.modelID).catch(() => undefined))
|
? yield* provider
|
||||||
|
.getModel(model.providerID, model.modelID)
|
||||||
|
.pipe(Effect.catch(() => Effect.succeed(undefined)))
|
||||||
: undefined
|
: undefined
|
||||||
const variant = input.variant ?? (ag.variant && full?.variants?.[ag.variant] ? ag.variant : undefined)
|
const variant = input.variant ?? (ag.variant && full?.variants?.[ag.variant] ? ag.variant : undefined)
|
||||||
|
|
||||||
|
|
@ -1109,7 +1127,7 @@ NOTE: At any point in time through this workflow you should feel free to ask the
|
||||||
]
|
]
|
||||||
const read = yield* Effect.promise(() => ReadTool.init()).pipe(
|
const read = yield* Effect.promise(() => ReadTool.init()).pipe(
|
||||||
Effect.flatMap((t) =>
|
Effect.flatMap((t) =>
|
||||||
Effect.promise(() => Provider.getModel(info.model.providerID, info.model.modelID)).pipe(
|
provider.getModel(info.model.providerID, info.model.modelID).pipe(
|
||||||
Effect.flatMap((mdl) =>
|
Effect.flatMap((mdl) =>
|
||||||
Effect.promise(() =>
|
Effect.promise(() =>
|
||||||
t.execute(args, {
|
t.execute(args, {
|
||||||
|
|
@ -1711,6 +1729,7 @@ NOTE: At any point in time through this workflow you should feel free to ask the
|
||||||
Layer.provide(FileTime.defaultLayer),
|
Layer.provide(FileTime.defaultLayer),
|
||||||
Layer.provide(ToolRegistry.defaultLayer),
|
Layer.provide(ToolRegistry.defaultLayer),
|
||||||
Layer.provide(Truncate.layer),
|
Layer.provide(Truncate.layer),
|
||||||
|
Layer.provide(Provider.defaultLayer),
|
||||||
Layer.provide(AppFileSystem.defaultLayer),
|
Layer.provide(AppFileSystem.defaultLayer),
|
||||||
Layer.provide(Plugin.defaultLayer),
|
Layer.provide(Plugin.defaultLayer),
|
||||||
Layer.provide(Session.defaultLayer),
|
Layer.provide(Session.defaultLayer),
|
||||||
|
|
@ -1856,15 +1875,6 @@ NOTE: At any point in time through this workflow you should feel free to ask the
|
||||||
return runPromise((svc) => svc.command(CommandInput.parse(input)))
|
return runPromise((svc) => svc.command(CommandInput.parse(input)))
|
||||||
}
|
}
|
||||||
|
|
||||||
const lastModel = Effect.fnUntraced(function* (sessionID: SessionID) {
|
|
||||||
return yield* Effect.promise(async () => {
|
|
||||||
for await (const item of MessageV2.stream(sessionID)) {
|
|
||||||
if (item.info.role === "user" && item.info.model) return item.info.model
|
|
||||||
}
|
|
||||||
return Provider.defaultModel()
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
/** @internal Exported for testing */
|
/** @internal Exported for testing */
|
||||||
export function createStructuredOutputTool(input: {
|
export function createStructuredOutputTool(input: {
|
||||||
schema: Record<string, any>
|
schema: Record<string, any>
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,81 @@
|
||||||
|
import { Effect, Layer } from "effect"
|
||||||
|
import { Provider } from "../../src/provider/provider"
|
||||||
|
import { ModelID, ProviderID } from "../../src/provider/schema"
|
||||||
|
|
||||||
|
export namespace ProviderTest {
|
||||||
|
export function model(override: Partial<Provider.Model> = {}): Provider.Model {
|
||||||
|
const id = override.id ?? ModelID.make("gpt-5.2")
|
||||||
|
const providerID = override.providerID ?? ProviderID.make("openai")
|
||||||
|
return {
|
||||||
|
id,
|
||||||
|
providerID,
|
||||||
|
name: "Test Model",
|
||||||
|
capabilities: {
|
||||||
|
toolcall: true,
|
||||||
|
attachment: false,
|
||||||
|
reasoning: false,
|
||||||
|
temperature: true,
|
||||||
|
interleaved: false,
|
||||||
|
input: { text: true, image: false, audio: false, video: false, pdf: false },
|
||||||
|
output: { text: true, image: false, audio: false, video: false, pdf: false },
|
||||||
|
},
|
||||||
|
api: { id, url: "https://example.com", npm: "@ai-sdk/openai" },
|
||||||
|
cost: { input: 0, output: 0, cache: { read: 0, write: 0 } },
|
||||||
|
limit: { context: 200_000, output: 10_000 },
|
||||||
|
status: "active",
|
||||||
|
options: {},
|
||||||
|
headers: {},
|
||||||
|
release_date: "2025-01-01",
|
||||||
|
...override,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export function info(override: Partial<Provider.Info> = {}, mdl = model()): Provider.Info {
|
||||||
|
const id = override.id ?? mdl.providerID
|
||||||
|
return {
|
||||||
|
id,
|
||||||
|
name: "Test Provider",
|
||||||
|
source: "config",
|
||||||
|
env: [],
|
||||||
|
options: {},
|
||||||
|
models: { [mdl.id]: mdl },
|
||||||
|
...override,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export function fake(override: Partial<Provider.Interface> & { model?: Provider.Model; info?: Provider.Info } = {}) {
|
||||||
|
const mdl = override.model ?? model()
|
||||||
|
const row = override.info ?? info({}, mdl)
|
||||||
|
return {
|
||||||
|
model: mdl,
|
||||||
|
info: row,
|
||||||
|
layer: Layer.succeed(
|
||||||
|
Provider.Service,
|
||||||
|
Provider.Service.of({
|
||||||
|
list: Effect.fn("TestProvider.list")(() => Effect.succeed({ [row.id]: row })),
|
||||||
|
getProvider: Effect.fn("TestProvider.getProvider")((providerID) => {
|
||||||
|
if (providerID === row.id) return Effect.succeed(row)
|
||||||
|
return Effect.die(new Error(`Unknown test provider: ${providerID}`))
|
||||||
|
}),
|
||||||
|
getModel: Effect.fn("TestProvider.getModel")((providerID, modelID) => {
|
||||||
|
if (providerID === row.id && modelID === mdl.id) return Effect.succeed(mdl)
|
||||||
|
return Effect.die(new Error(`Unknown test model: ${providerID}/${modelID}`))
|
||||||
|
}),
|
||||||
|
getLanguage: Effect.fn("TestProvider.getLanguage")(() =>
|
||||||
|
Effect.die(new Error("ProviderTest.getLanguage not configured")),
|
||||||
|
),
|
||||||
|
closest: Effect.fn("TestProvider.closest")((providerID) =>
|
||||||
|
Effect.succeed(providerID === row.id ? { providerID: row.id, modelID: mdl.id } : undefined),
|
||||||
|
),
|
||||||
|
getSmallModel: Effect.fn("TestProvider.getSmallModel")((providerID) =>
|
||||||
|
Effect.succeed(providerID === row.id ? mdl : undefined),
|
||||||
|
),
|
||||||
|
defaultModel: Effect.fn("TestProvider.defaultModel")(() =>
|
||||||
|
Effect.succeed({ providerID: row.id, modelID: mdl.id }),
|
||||||
|
),
|
||||||
|
...override,
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
import { afterEach, describe, expect, mock, spyOn, test } from "bun:test"
|
import { afterEach, describe, expect, mock, test } from "bun:test"
|
||||||
import { APICallError } from "ai"
|
import { APICallError } from "ai"
|
||||||
import { Cause, Effect, Exit, Layer, ManagedRuntime } from "effect"
|
import { Cause, Effect, Exit, Layer, ManagedRuntime } from "effect"
|
||||||
import * as Stream from "effect/Stream"
|
import * as Stream from "effect/Stream"
|
||||||
|
|
@ -20,9 +20,9 @@ import { MessageID, PartID, SessionID } from "../../src/session/schema"
|
||||||
import { SessionStatus } from "../../src/session/status"
|
import { SessionStatus } from "../../src/session/status"
|
||||||
import { ModelID, ProviderID } from "../../src/provider/schema"
|
import { ModelID, ProviderID } from "../../src/provider/schema"
|
||||||
import type { Provider } from "../../src/provider/provider"
|
import type { Provider } from "../../src/provider/provider"
|
||||||
import * as ProviderModule from "../../src/provider/provider"
|
|
||||||
import * as SessionProcessorModule from "../../src/session/processor"
|
import * as SessionProcessorModule from "../../src/session/processor"
|
||||||
import { Snapshot } from "../../src/snapshot"
|
import { Snapshot } from "../../src/snapshot"
|
||||||
|
import { ProviderTest } from "../fake/provider"
|
||||||
|
|
||||||
Log.init({ print: false })
|
Log.init({ print: false })
|
||||||
|
|
||||||
|
|
@ -65,6 +65,8 @@ function createModel(opts: {
|
||||||
} as Provider.Model
|
} as Provider.Model
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const wide = () => ProviderTest.fake({ model: createModel({ context: 100_000, output: 32_000 }) })
|
||||||
|
|
||||||
async function user(sessionID: SessionID, text: string) {
|
async function user(sessionID: SessionID, text: string) {
|
||||||
const msg = await Session.updateMessage({
|
const msg = await Session.updateMessage({
|
||||||
id: MessageID.ascending(),
|
id: MessageID.ascending(),
|
||||||
|
|
@ -162,10 +164,11 @@ function layer(result: "continue" | "compact") {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
function runtime(result: "continue" | "compact", plugin = Plugin.defaultLayer) {
|
function runtime(result: "continue" | "compact", plugin = Plugin.defaultLayer, provider = ProviderTest.fake()) {
|
||||||
const bus = Bus.layer
|
const bus = Bus.layer
|
||||||
return ManagedRuntime.make(
|
return ManagedRuntime.make(
|
||||||
Layer.mergeAll(SessionCompaction.layer, bus).pipe(
|
Layer.mergeAll(SessionCompaction.layer, bus).pipe(
|
||||||
|
Layer.provide(provider.layer),
|
||||||
Layer.provide(Session.defaultLayer),
|
Layer.provide(Session.defaultLayer),
|
||||||
Layer.provide(layer(result)),
|
Layer.provide(layer(result)),
|
||||||
Layer.provide(Agent.defaultLayer),
|
Layer.provide(Agent.defaultLayer),
|
||||||
|
|
@ -198,12 +201,13 @@ function llm() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function liveRuntime(layer: Layer.Layer<LLM.Service>) {
|
function liveRuntime(layer: Layer.Layer<LLM.Service>, provider = ProviderTest.fake()) {
|
||||||
const bus = Bus.layer
|
const bus = Bus.layer
|
||||||
const status = SessionStatus.layer.pipe(Layer.provide(bus))
|
const status = SessionStatus.layer.pipe(Layer.provide(bus))
|
||||||
const processor = SessionProcessorModule.SessionProcessor.layer
|
const processor = SessionProcessorModule.SessionProcessor.layer
|
||||||
return ManagedRuntime.make(
|
return ManagedRuntime.make(
|
||||||
Layer.mergeAll(SessionCompaction.layer.pipe(Layer.provide(processor)), processor, bus, status).pipe(
|
Layer.mergeAll(SessionCompaction.layer.pipe(Layer.provide(processor)), processor, bus, status).pipe(
|
||||||
|
Layer.provide(provider.layer),
|
||||||
Layer.provide(Session.defaultLayer),
|
Layer.provide(Session.defaultLayer),
|
||||||
Layer.provide(Snapshot.defaultLayer),
|
Layer.provide(Snapshot.defaultLayer),
|
||||||
Layer.provide(layer),
|
Layer.provide(layer),
|
||||||
|
|
@ -544,14 +548,12 @@ describe("session.compaction.process", () => {
|
||||||
await Instance.provide({
|
await Instance.provide({
|
||||||
directory: tmp.path,
|
directory: tmp.path,
|
||||||
fn: async () => {
|
fn: async () => {
|
||||||
spyOn(ProviderModule.Provider, "getModel").mockResolvedValue(createModel({ context: 100_000, output: 32_000 }))
|
|
||||||
|
|
||||||
const session = await Session.create({})
|
const session = await Session.create({})
|
||||||
const msg = await user(session.id, "hello")
|
const msg = await user(session.id, "hello")
|
||||||
const msgs = await Session.messages({ sessionID: session.id })
|
const msgs = await Session.messages({ sessionID: session.id })
|
||||||
const done = defer()
|
const done = defer()
|
||||||
let seen = false
|
let seen = false
|
||||||
const rt = runtime("continue")
|
const rt = runtime("continue", Plugin.defaultLayer, wide())
|
||||||
let unsub: (() => void) | undefined
|
let unsub: (() => void) | undefined
|
||||||
try {
|
try {
|
||||||
unsub = await rt.runPromise(
|
unsub = await rt.runPromise(
|
||||||
|
|
@ -596,11 +598,9 @@ describe("session.compaction.process", () => {
|
||||||
await Instance.provide({
|
await Instance.provide({
|
||||||
directory: tmp.path,
|
directory: tmp.path,
|
||||||
fn: async () => {
|
fn: async () => {
|
||||||
spyOn(ProviderModule.Provider, "getModel").mockResolvedValue(createModel({ context: 100_000, output: 32_000 }))
|
|
||||||
|
|
||||||
const session = await Session.create({})
|
const session = await Session.create({})
|
||||||
const msg = await user(session.id, "hello")
|
const msg = await user(session.id, "hello")
|
||||||
const rt = runtime("compact")
|
const rt = runtime("compact", Plugin.defaultLayer, wide())
|
||||||
try {
|
try {
|
||||||
const msgs = await Session.messages({ sessionID: session.id })
|
const msgs = await Session.messages({ sessionID: session.id })
|
||||||
const result = await rt.runPromise(
|
const result = await rt.runPromise(
|
||||||
|
|
@ -636,11 +636,9 @@ describe("session.compaction.process", () => {
|
||||||
await Instance.provide({
|
await Instance.provide({
|
||||||
directory: tmp.path,
|
directory: tmp.path,
|
||||||
fn: async () => {
|
fn: async () => {
|
||||||
spyOn(ProviderModule.Provider, "getModel").mockResolvedValue(createModel({ context: 100_000, output: 32_000 }))
|
|
||||||
|
|
||||||
const session = await Session.create({})
|
const session = await Session.create({})
|
||||||
const msg = await user(session.id, "hello")
|
const msg = await user(session.id, "hello")
|
||||||
const rt = runtime("continue")
|
const rt = runtime("continue", Plugin.defaultLayer, wide())
|
||||||
try {
|
try {
|
||||||
const msgs = await Session.messages({ sessionID: session.id })
|
const msgs = await Session.messages({ sessionID: session.id })
|
||||||
const result = await rt.runPromise(
|
const result = await rt.runPromise(
|
||||||
|
|
@ -678,8 +676,6 @@ describe("session.compaction.process", () => {
|
||||||
await Instance.provide({
|
await Instance.provide({
|
||||||
directory: tmp.path,
|
directory: tmp.path,
|
||||||
fn: async () => {
|
fn: async () => {
|
||||||
spyOn(ProviderModule.Provider, "getModel").mockResolvedValue(createModel({ context: 100_000, output: 32_000 }))
|
|
||||||
|
|
||||||
const session = await Session.create({})
|
const session = await Session.create({})
|
||||||
await user(session.id, "root")
|
await user(session.id, "root")
|
||||||
const replay = await user(session.id, "image")
|
const replay = await user(session.id, "image")
|
||||||
|
|
@ -693,7 +689,7 @@ describe("session.compaction.process", () => {
|
||||||
url: "https://example.com/cat.png",
|
url: "https://example.com/cat.png",
|
||||||
})
|
})
|
||||||
const msg = await user(session.id, "current")
|
const msg = await user(session.id, "current")
|
||||||
const rt = runtime("continue")
|
const rt = runtime("continue", Plugin.defaultLayer, wide())
|
||||||
try {
|
try {
|
||||||
const msgs = await Session.messages({ sessionID: session.id })
|
const msgs = await Session.messages({ sessionID: session.id })
|
||||||
const result = await rt.runPromise(
|
const result = await rt.runPromise(
|
||||||
|
|
@ -728,13 +724,11 @@ describe("session.compaction.process", () => {
|
||||||
await Instance.provide({
|
await Instance.provide({
|
||||||
directory: tmp.path,
|
directory: tmp.path,
|
||||||
fn: async () => {
|
fn: async () => {
|
||||||
spyOn(ProviderModule.Provider, "getModel").mockResolvedValue(createModel({ context: 100_000, output: 32_000 }))
|
|
||||||
|
|
||||||
const session = await Session.create({})
|
const session = await Session.create({})
|
||||||
await user(session.id, "earlier")
|
await user(session.id, "earlier")
|
||||||
const msg = await user(session.id, "current")
|
const msg = await user(session.id, "current")
|
||||||
|
|
||||||
const rt = runtime("continue")
|
const rt = runtime("continue", Plugin.defaultLayer, wide())
|
||||||
try {
|
try {
|
||||||
const msgs = await Session.messages({ sessionID: session.id })
|
const msgs = await Session.messages({ sessionID: session.id })
|
||||||
const result = await rt.runPromise(
|
const result = await rt.runPromise(
|
||||||
|
|
@ -790,13 +784,11 @@ describe("session.compaction.process", () => {
|
||||||
await Instance.provide({
|
await Instance.provide({
|
||||||
directory: tmp.path,
|
directory: tmp.path,
|
||||||
fn: async () => {
|
fn: async () => {
|
||||||
spyOn(ProviderModule.Provider, "getModel").mockResolvedValue(createModel({ context: 100_000, output: 32_000 }))
|
|
||||||
|
|
||||||
const session = await Session.create({})
|
const session = await Session.create({})
|
||||||
const msg = await user(session.id, "hello")
|
const msg = await user(session.id, "hello")
|
||||||
const msgs = await Session.messages({ sessionID: session.id })
|
const msgs = await Session.messages({ sessionID: session.id })
|
||||||
const abort = new AbortController()
|
const abort = new AbortController()
|
||||||
const rt = liveRuntime(stub.layer)
|
const rt = liveRuntime(stub.layer, wide())
|
||||||
let off: (() => void) | undefined
|
let off: (() => void) | undefined
|
||||||
let run: Promise<"continue" | "stop"> | undefined
|
let run: Promise<"continue" | "stop"> | undefined
|
||||||
try {
|
try {
|
||||||
|
|
@ -866,13 +858,11 @@ describe("session.compaction.process", () => {
|
||||||
await Instance.provide({
|
await Instance.provide({
|
||||||
directory: tmp.path,
|
directory: tmp.path,
|
||||||
fn: async () => {
|
fn: async () => {
|
||||||
spyOn(ProviderModule.Provider, "getModel").mockResolvedValue(createModel({ context: 100_000, output: 32_000 }))
|
|
||||||
|
|
||||||
const session = await Session.create({})
|
const session = await Session.create({})
|
||||||
const msg = await user(session.id, "hello")
|
const msg = await user(session.id, "hello")
|
||||||
const msgs = await Session.messages({ sessionID: session.id })
|
const msgs = await Session.messages({ sessionID: session.id })
|
||||||
const abort = new AbortController()
|
const abort = new AbortController()
|
||||||
const rt = runtime("continue", plugin(ready))
|
const rt = runtime("continue", plugin(ready), wide())
|
||||||
let run: Promise<"continue" | "stop"> | undefined
|
let run: Promise<"continue" | "stop"> | undefined
|
||||||
try {
|
try {
|
||||||
run = rt
|
run = rt
|
||||||
|
|
@ -970,11 +960,9 @@ describe("session.compaction.process", () => {
|
||||||
await Instance.provide({
|
await Instance.provide({
|
||||||
directory: tmp.path,
|
directory: tmp.path,
|
||||||
fn: async () => {
|
fn: async () => {
|
||||||
spyOn(ProviderModule.Provider, "getModel").mockResolvedValue(createModel({ context: 100_000, output: 32_000 }))
|
|
||||||
|
|
||||||
const session = await Session.create({})
|
const session = await Session.create({})
|
||||||
const msg = await user(session.id, "hello")
|
const msg = await user(session.id, "hello")
|
||||||
const rt = liveRuntime(stub.layer)
|
const rt = liveRuntime(stub.layer, wide())
|
||||||
try {
|
try {
|
||||||
const msgs = await Session.messages({ sessionID: session.id })
|
const msgs = await Session.messages({ sessionID: session.id })
|
||||||
await rt.runPromise(
|
await rt.runPromise(
|
||||||
|
|
|
||||||
|
|
@ -1,247 +0,0 @@
|
||||||
import { describe, expect, spyOn, test } from "bun:test"
|
|
||||||
import { Instance } from "../../src/project/instance"
|
|
||||||
import { Provider } from "../../src/provider/provider"
|
|
||||||
import { Session } from "../../src/session"
|
|
||||||
import { MessageV2 } from "../../src/session/message-v2"
|
|
||||||
import { SessionPrompt } from "../../src/session/prompt"
|
|
||||||
import { SessionStatus } from "../../src/session/status"
|
|
||||||
import { MessageID, PartID, SessionID } from "../../src/session/schema"
|
|
||||||
import { Log } from "../../src/util/log"
|
|
||||||
import { tmpdir } from "../fixture/fixture"
|
|
||||||
|
|
||||||
Log.init({ print: false })
|
|
||||||
|
|
||||||
function deferred() {
|
|
||||||
let resolve!: () => void
|
|
||||||
const promise = new Promise<void>((done) => {
|
|
||||||
resolve = done
|
|
||||||
})
|
|
||||||
return { promise, resolve }
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper: seed a session with a user message + finished assistant message
|
|
||||||
// so loop() exits immediately without calling any LLM
|
|
||||||
async function seed(sessionID: SessionID) {
|
|
||||||
const userMsg: MessageV2.Info = {
|
|
||||||
id: MessageID.ascending(),
|
|
||||||
role: "user",
|
|
||||||
sessionID,
|
|
||||||
time: { created: Date.now() },
|
|
||||||
agent: "build",
|
|
||||||
model: { providerID: "openai" as any, modelID: "gpt-5.2" as any },
|
|
||||||
}
|
|
||||||
await Session.updateMessage(userMsg)
|
|
||||||
await Session.updatePart({
|
|
||||||
id: PartID.ascending(),
|
|
||||||
messageID: userMsg.id,
|
|
||||||
sessionID,
|
|
||||||
type: "text",
|
|
||||||
text: "hello",
|
|
||||||
})
|
|
||||||
|
|
||||||
const assistantMsg: MessageV2.Info = {
|
|
||||||
id: MessageID.ascending(),
|
|
||||||
role: "assistant",
|
|
||||||
parentID: userMsg.id,
|
|
||||||
sessionID,
|
|
||||||
mode: "build",
|
|
||||||
agent: "build",
|
|
||||||
cost: 0,
|
|
||||||
path: { cwd: "/tmp", root: "/tmp" },
|
|
||||||
tokens: { input: 0, output: 0, reasoning: 0, cache: { read: 0, write: 0 } },
|
|
||||||
modelID: "gpt-5.2" as any,
|
|
||||||
providerID: "openai" as any,
|
|
||||||
time: { created: Date.now(), completed: Date.now() },
|
|
||||||
finish: "stop",
|
|
||||||
}
|
|
||||||
await Session.updateMessage(assistantMsg)
|
|
||||||
await Session.updatePart({
|
|
||||||
id: PartID.ascending(),
|
|
||||||
messageID: assistantMsg.id,
|
|
||||||
sessionID,
|
|
||||||
type: "text",
|
|
||||||
text: "hi there",
|
|
||||||
})
|
|
||||||
|
|
||||||
return { userMsg, assistantMsg }
|
|
||||||
}
|
|
||||||
|
|
||||||
describe("session.prompt concurrency", () => {
|
|
||||||
test("loop returns assistant message and sets status to idle", async () => {
|
|
||||||
await using tmp = await tmpdir({ git: true })
|
|
||||||
await Instance.provide({
|
|
||||||
directory: tmp.path,
|
|
||||||
fn: async () => {
|
|
||||||
const session = await Session.create({})
|
|
||||||
await seed(session.id)
|
|
||||||
|
|
||||||
const result = await SessionPrompt.loop({ sessionID: session.id })
|
|
||||||
expect(result.info.role).toBe("assistant")
|
|
||||||
if (result.info.role === "assistant") expect(result.info.finish).toBe("stop")
|
|
||||||
|
|
||||||
const status = await SessionStatus.get(session.id)
|
|
||||||
expect(status.type).toBe("idle")
|
|
||||||
},
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
test("concurrent loop callers get the same result", async () => {
|
|
||||||
await using tmp = await tmpdir({ git: true })
|
|
||||||
await Instance.provide({
|
|
||||||
directory: tmp.path,
|
|
||||||
fn: async () => {
|
|
||||||
const session = await Session.create({})
|
|
||||||
await seed(session.id)
|
|
||||||
|
|
||||||
const [a, b] = await Promise.all([
|
|
||||||
SessionPrompt.loop({ sessionID: session.id }),
|
|
||||||
SessionPrompt.loop({ sessionID: session.id }),
|
|
||||||
])
|
|
||||||
|
|
||||||
expect(a.info.id).toBe(b.info.id)
|
|
||||||
expect(a.info.role).toBe("assistant")
|
|
||||||
},
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
test("assertNotBusy throws when loop is running", async () => {
|
|
||||||
await using tmp = await tmpdir({ git: true })
|
|
||||||
await Instance.provide({
|
|
||||||
directory: tmp.path,
|
|
||||||
fn: async () => {
|
|
||||||
const session = await Session.create({})
|
|
||||||
const userMsg: MessageV2.Info = {
|
|
||||||
id: MessageID.ascending(),
|
|
||||||
role: "user",
|
|
||||||
sessionID: session.id,
|
|
||||||
time: { created: Date.now() },
|
|
||||||
agent: "build",
|
|
||||||
model: { providerID: "openai" as any, modelID: "gpt-5.2" as any },
|
|
||||||
}
|
|
||||||
await Session.updateMessage(userMsg)
|
|
||||||
await Session.updatePart({
|
|
||||||
id: PartID.ascending(),
|
|
||||||
messageID: userMsg.id,
|
|
||||||
sessionID: session.id,
|
|
||||||
type: "text",
|
|
||||||
text: "hello",
|
|
||||||
})
|
|
||||||
|
|
||||||
const ready = deferred()
|
|
||||||
const gate = deferred()
|
|
||||||
const getModel = spyOn(Provider, "getModel").mockImplementation(async () => {
|
|
||||||
ready.resolve()
|
|
||||||
await gate.promise
|
|
||||||
throw new Error("test stop")
|
|
||||||
})
|
|
||||||
|
|
||||||
try {
|
|
||||||
const loopPromise = SessionPrompt.loop({ sessionID: session.id }).catch(() => undefined)
|
|
||||||
await ready.promise
|
|
||||||
|
|
||||||
await expect(SessionPrompt.assertNotBusy(session.id)).rejects.toBeInstanceOf(Session.BusyError)
|
|
||||||
|
|
||||||
gate.resolve()
|
|
||||||
await loopPromise
|
|
||||||
} finally {
|
|
||||||
gate.resolve()
|
|
||||||
getModel.mockRestore()
|
|
||||||
}
|
|
||||||
|
|
||||||
// After loop completes, assertNotBusy should succeed
|
|
||||||
await SessionPrompt.assertNotBusy(session.id)
|
|
||||||
},
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
test("cancel sets status to idle", async () => {
|
|
||||||
await using tmp = await tmpdir({ git: true })
|
|
||||||
await Instance.provide({
|
|
||||||
directory: tmp.path,
|
|
||||||
fn: async () => {
|
|
||||||
const session = await Session.create({})
|
|
||||||
// Seed only a user message — loop must call getModel to proceed
|
|
||||||
const userMsg: MessageV2.Info = {
|
|
||||||
id: MessageID.ascending(),
|
|
||||||
role: "user",
|
|
||||||
sessionID: session.id,
|
|
||||||
time: { created: Date.now() },
|
|
||||||
agent: "build",
|
|
||||||
model: { providerID: "openai" as any, modelID: "gpt-5.2" as any },
|
|
||||||
}
|
|
||||||
await Session.updateMessage(userMsg)
|
|
||||||
await Session.updatePart({
|
|
||||||
id: PartID.ascending(),
|
|
||||||
messageID: userMsg.id,
|
|
||||||
sessionID: session.id,
|
|
||||||
type: "text",
|
|
||||||
text: "hello",
|
|
||||||
})
|
|
||||||
// Also seed an assistant message so lastAssistant() fallback can find it
|
|
||||||
const assistantMsg: MessageV2.Info = {
|
|
||||||
id: MessageID.ascending(),
|
|
||||||
role: "assistant",
|
|
||||||
parentID: userMsg.id,
|
|
||||||
sessionID: session.id,
|
|
||||||
mode: "build",
|
|
||||||
agent: "build",
|
|
||||||
cost: 0,
|
|
||||||
path: { cwd: "/tmp", root: "/tmp" },
|
|
||||||
tokens: { input: 0, output: 0, reasoning: 0, cache: { read: 0, write: 0 } },
|
|
||||||
modelID: "gpt-5.2" as any,
|
|
||||||
providerID: "openai" as any,
|
|
||||||
time: { created: Date.now() },
|
|
||||||
}
|
|
||||||
await Session.updateMessage(assistantMsg)
|
|
||||||
await Session.updatePart({
|
|
||||||
id: PartID.ascending(),
|
|
||||||
messageID: assistantMsg.id,
|
|
||||||
sessionID: session.id,
|
|
||||||
type: "text",
|
|
||||||
text: "hi there",
|
|
||||||
})
|
|
||||||
|
|
||||||
const ready = deferred()
|
|
||||||
const gate = deferred()
|
|
||||||
const getModel = spyOn(Provider, "getModel").mockImplementation(async () => {
|
|
||||||
ready.resolve()
|
|
||||||
await gate.promise
|
|
||||||
throw new Error("test stop")
|
|
||||||
})
|
|
||||||
|
|
||||||
try {
|
|
||||||
// Start loop — it will block in getModel (assistant has no finish, so loop continues)
|
|
||||||
const loopPromise = SessionPrompt.loop({ sessionID: session.id })
|
|
||||||
|
|
||||||
await ready.promise
|
|
||||||
|
|
||||||
await SessionPrompt.cancel(session.id)
|
|
||||||
|
|
||||||
const status = await SessionStatus.get(session.id)
|
|
||||||
expect(status.type).toBe("idle")
|
|
||||||
|
|
||||||
// loop should resolve cleanly, not throw "All fibers interrupted"
|
|
||||||
const result = await loopPromise
|
|
||||||
expect(result.info.role).toBe("assistant")
|
|
||||||
expect(result.info.id).toBe(assistantMsg.id)
|
|
||||||
} finally {
|
|
||||||
gate.resolve()
|
|
||||||
getModel.mockRestore()
|
|
||||||
}
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}, 10000)
|
|
||||||
|
|
||||||
test("cancel on idle session just sets idle", async () => {
|
|
||||||
await using tmp = await tmpdir({ git: true })
|
|
||||||
await Instance.provide({
|
|
||||||
directory: tmp.path,
|
|
||||||
fn: async () => {
|
|
||||||
const session = await Session.create({})
|
|
||||||
await SessionPrompt.cancel(session.id)
|
|
||||||
const status = await SessionStatus.get(session.id)
|
|
||||||
expect(status.type).toBe("idle")
|
|
||||||
},
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
@ -12,6 +12,7 @@ import { LSP } from "../../src/lsp"
|
||||||
import { MCP } from "../../src/mcp"
|
import { MCP } from "../../src/mcp"
|
||||||
import { Permission } from "../../src/permission"
|
import { Permission } from "../../src/permission"
|
||||||
import { Plugin } from "../../src/plugin"
|
import { Plugin } from "../../src/plugin"
|
||||||
|
import { Provider as ProviderSvc } from "../../src/provider/provider"
|
||||||
import type { Provider } from "../../src/provider/provider"
|
import type { Provider } from "../../src/provider/provider"
|
||||||
import { ModelID, ProviderID } from "../../src/provider/schema"
|
import { ModelID, ProviderID } from "../../src/provider/schema"
|
||||||
import { Session } from "../../src/session"
|
import { Session } from "../../src/session"
|
||||||
|
|
@ -151,6 +152,7 @@ function makeHttp() {
|
||||||
Permission.layer,
|
Permission.layer,
|
||||||
Plugin.defaultLayer,
|
Plugin.defaultLayer,
|
||||||
Config.defaultLayer,
|
Config.defaultLayer,
|
||||||
|
ProviderSvc.defaultLayer,
|
||||||
filetime,
|
filetime,
|
||||||
lsp,
|
lsp,
|
||||||
mcp,
|
mcp,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue