fix(task): preserve subtask cancellation during startup
parent
a3bf978919
commit
6fe0f52a05
|
|
@ -35,6 +35,8 @@ type Input = {
|
|||
prompt: string
|
||||
agent: Agent.Info
|
||||
model: Ref
|
||||
abort?: AbortSignal
|
||||
cancel?: (sessionID: SessionID) => Promise<void> | void
|
||||
start?: (sessionID: SessionID, model: Ref) => Promise<void> | void
|
||||
}
|
||||
|
||||
|
|
@ -61,15 +63,28 @@ export function output(sessionID: SessionID, text: string) {
|
|||
export const run = Effect.fn("Subtask.run")(function* (deps: Deps, input: Input) {
|
||||
const cfg = yield* deps.cfg
|
||||
const model = input.agent.model ?? input.model
|
||||
const found = input.taskID ? yield* deps.get(input.taskID) : undefined
|
||||
const session = yield* Effect.uninterruptibleMask((restore) =>
|
||||
Effect.gen(function* () {
|
||||
const found = input.taskID ? yield* restore(deps.get(input.taskID)) : undefined
|
||||
const session = found
|
||||
? found
|
||||
: yield* deps.create({
|
||||
: yield* restore(
|
||||
deps.create({
|
||||
parentID: input.parentID,
|
||||
title: input.description + ` (@${input.agent.name} subagent)`,
|
||||
})
|
||||
}),
|
||||
)
|
||||
|
||||
yield* Effect.promise(() => Promise.resolve(input.start?.(session.id, model)))
|
||||
const start = input.start?.(session.id, model)
|
||||
if (start) yield* Effect.promise(() => Promise.resolve(start))
|
||||
return session
|
||||
}),
|
||||
)
|
||||
|
||||
if (input.abort?.aborted) {
|
||||
const cancel = input.cancel?.(session.id)
|
||||
if (cancel) yield* Effect.promise(() => Promise.resolve(cancel))
|
||||
}
|
||||
|
||||
const result = yield* deps.prompt({
|
||||
sessionID: session.id,
|
||||
|
|
|
|||
|
|
@ -88,6 +88,8 @@ export const TaskTool = Tool.define("task", async (ctx) => {
|
|||
description: params.description,
|
||||
prompt: params.prompt,
|
||||
agent,
|
||||
abort: ctx.abort,
|
||||
cancel: SessionPrompt.cancel,
|
||||
model: {
|
||||
modelID: msg.info.modelID,
|
||||
providerID: msg.info.providerID,
|
||||
|
|
|
|||
|
|
@ -1,13 +1,28 @@
|
|||
import { afterEach, describe, expect, test } from "bun:test"
|
||||
import { afterEach, describe, expect, mock, spyOn, test } from "bun:test"
|
||||
import { Agent } from "../../src/agent/agent"
|
||||
import { Config } from "../../src/config/config"
|
||||
import { Instance } from "../../src/project/instance"
|
||||
import { ModelID, ProviderID } from "../../src/provider/schema"
|
||||
import { MessageV2 } from "../../src/session/message-v2"
|
||||
import { SessionPrompt } from "../../src/session/prompt"
|
||||
import { MessageID, SessionID } from "../../src/session/schema"
|
||||
import { Session } from "../../src/session"
|
||||
import { TaskTool } from "../../src/tool/task"
|
||||
import { tmpdir } from "../fixture/fixture"
|
||||
|
||||
afterEach(async () => {
|
||||
mock.restore()
|
||||
await Instance.disposeAll()
|
||||
})
|
||||
|
||||
function wait<T>() {
|
||||
let done!: (value: T | PromiseLike<T>) => void
|
||||
const promise = new Promise<T>((resolve) => {
|
||||
done = resolve
|
||||
})
|
||||
return { promise, done }
|
||||
}
|
||||
|
||||
describe("tool.task", () => {
|
||||
test("description sorts subagents by name and is stable across calls", async () => {
|
||||
await using tmp = await tmpdir({
|
||||
|
|
@ -46,4 +61,73 @@ describe("tool.task", () => {
|
|||
},
|
||||
})
|
||||
})
|
||||
|
||||
test("cancels child session when aborted during creation", async () => {
|
||||
const started = wait<void>()
|
||||
const gate = wait<void>()
|
||||
const parent = SessionID.make("parent")
|
||||
const child = SessionID.make("child")
|
||||
const messageID = MessageID.ascending()
|
||||
const abort = new AbortController()
|
||||
const agent: Agent.Info = {
|
||||
name: "general",
|
||||
description: "General agent",
|
||||
mode: "subagent",
|
||||
options: {},
|
||||
permission: [],
|
||||
}
|
||||
const ref = {
|
||||
providerID: ProviderID.make("test"),
|
||||
modelID: ModelID.make("test-model"),
|
||||
}
|
||||
|
||||
spyOn(Agent, "list").mockResolvedValue([agent])
|
||||
spyOn(Agent, "get").mockResolvedValue(agent)
|
||||
spyOn(Config, "get").mockResolvedValue({ experimental: {} } as Awaited<ReturnType<typeof Config.get>>)
|
||||
spyOn(MessageV2, "get").mockResolvedValue({
|
||||
info: {
|
||||
role: "assistant",
|
||||
providerID: ref.providerID,
|
||||
modelID: ref.modelID,
|
||||
},
|
||||
} as Awaited<ReturnType<typeof MessageV2.get>>)
|
||||
spyOn(Session, "get").mockRejectedValue(new Error("missing"))
|
||||
spyOn(Session, "create").mockImplementation(async () => {
|
||||
started.done()
|
||||
await gate.promise
|
||||
return { id: child } as Awaited<ReturnType<typeof Session.create>>
|
||||
})
|
||||
const cancel = spyOn(SessionPrompt, "cancel").mockResolvedValue()
|
||||
spyOn(SessionPrompt, "resolvePromptParts").mockResolvedValue(
|
||||
[] as Awaited<ReturnType<typeof SessionPrompt.resolvePromptParts>>,
|
||||
)
|
||||
spyOn(SessionPrompt, "prompt").mockResolvedValue({
|
||||
parts: [{ type: "text", text: "done" }],
|
||||
} as Awaited<ReturnType<typeof SessionPrompt.prompt>>)
|
||||
|
||||
const tool = await TaskTool.init()
|
||||
const run = tool.execute(
|
||||
{
|
||||
description: "inspect bug",
|
||||
prompt: "check it",
|
||||
subagent_type: "general",
|
||||
},
|
||||
{
|
||||
sessionID: parent,
|
||||
messageID,
|
||||
agent: "build",
|
||||
abort: abort.signal,
|
||||
messages: [],
|
||||
metadata: () => {},
|
||||
ask: async () => {},
|
||||
},
|
||||
)
|
||||
|
||||
await started.promise
|
||||
abort.abort()
|
||||
gate.done()
|
||||
await run
|
||||
|
||||
expect(cancel).toHaveBeenCalledWith(child)
|
||||
})
|
||||
})
|
||||
|
|
|
|||
Loading…
Reference in New Issue