fix(task): preserve subtask cancellation during startup

pull/21116/head
Kit Langton 2026-04-07 11:12:29 -04:00
parent a3bf978919
commit 6fe0f52a05
3 changed files with 110 additions and 9 deletions

View File

@ -35,6 +35,8 @@ type Input = {
prompt: string prompt: string
agent: Agent.Info agent: Agent.Info
model: Ref model: Ref
abort?: AbortSignal
cancel?: (sessionID: SessionID) => Promise<void> | void
start?: (sessionID: SessionID, model: Ref) => Promise<void> | void start?: (sessionID: SessionID, model: Ref) => Promise<void> | void
} }
@ -61,15 +63,28 @@ export function output(sessionID: SessionID, text: string) {
export const run = Effect.fn("Subtask.run")(function* (deps: Deps, input: Input) { export const run = Effect.fn("Subtask.run")(function* (deps: Deps, input: Input) {
const cfg = yield* deps.cfg const cfg = yield* deps.cfg
const model = input.agent.model ?? input.model const model = input.agent.model ?? input.model
const found = input.taskID ? yield* deps.get(input.taskID) : undefined const session = yield* Effect.uninterruptibleMask((restore) =>
const session = found Effect.gen(function* () {
? found const found = input.taskID ? yield* restore(deps.get(input.taskID)) : undefined
: yield* deps.create({ const session = found
parentID: input.parentID, ? found
title: input.description + ` (@${input.agent.name} subagent)`, : yield* restore(
}) deps.create({
parentID: input.parentID,
title: input.description + ` (@${input.agent.name} subagent)`,
}),
)
yield* Effect.promise(() => Promise.resolve(input.start?.(session.id, model))) const start = input.start?.(session.id, model)
if (start) yield* Effect.promise(() => Promise.resolve(start))
return session
}),
)
if (input.abort?.aborted) {
const cancel = input.cancel?.(session.id)
if (cancel) yield* Effect.promise(() => Promise.resolve(cancel))
}
const result = yield* deps.prompt({ const result = yield* deps.prompt({
sessionID: session.id, sessionID: session.id,

View File

@ -88,6 +88,8 @@ export const TaskTool = Tool.define("task", async (ctx) => {
description: params.description, description: params.description,
prompt: params.prompt, prompt: params.prompt,
agent, agent,
abort: ctx.abort,
cancel: SessionPrompt.cancel,
model: { model: {
modelID: msg.info.modelID, modelID: msg.info.modelID,
providerID: msg.info.providerID, providerID: msg.info.providerID,

View File

@ -1,13 +1,28 @@
import { afterEach, describe, expect, test } from "bun:test" import { afterEach, describe, expect, mock, spyOn, test } from "bun:test"
import { Agent } from "../../src/agent/agent" import { Agent } from "../../src/agent/agent"
import { Config } from "../../src/config/config"
import { Instance } from "../../src/project/instance" import { Instance } from "../../src/project/instance"
import { ModelID, ProviderID } from "../../src/provider/schema"
import { MessageV2 } from "../../src/session/message-v2"
import { SessionPrompt } from "../../src/session/prompt"
import { MessageID, SessionID } from "../../src/session/schema"
import { Session } from "../../src/session"
import { TaskTool } from "../../src/tool/task" import { TaskTool } from "../../src/tool/task"
import { tmpdir } from "../fixture/fixture" import { tmpdir } from "../fixture/fixture"
afterEach(async () => { afterEach(async () => {
mock.restore()
await Instance.disposeAll() await Instance.disposeAll()
}) })
function wait<T>() {
let done!: (value: T | PromiseLike<T>) => void
const promise = new Promise<T>((resolve) => {
done = resolve
})
return { promise, done }
}
describe("tool.task", () => { describe("tool.task", () => {
test("description sorts subagents by name and is stable across calls", async () => { test("description sorts subagents by name and is stable across calls", async () => {
await using tmp = await tmpdir({ await using tmp = await tmpdir({
@ -46,4 +61,73 @@ describe("tool.task", () => {
}, },
}) })
}) })
test("cancels child session when aborted during creation", async () => {
const started = wait<void>()
const gate = wait<void>()
const parent = SessionID.make("parent")
const child = SessionID.make("child")
const messageID = MessageID.ascending()
const abort = new AbortController()
const agent: Agent.Info = {
name: "general",
description: "General agent",
mode: "subagent",
options: {},
permission: [],
}
const ref = {
providerID: ProviderID.make("test"),
modelID: ModelID.make("test-model"),
}
spyOn(Agent, "list").mockResolvedValue([agent])
spyOn(Agent, "get").mockResolvedValue(agent)
spyOn(Config, "get").mockResolvedValue({ experimental: {} } as Awaited<ReturnType<typeof Config.get>>)
spyOn(MessageV2, "get").mockResolvedValue({
info: {
role: "assistant",
providerID: ref.providerID,
modelID: ref.modelID,
},
} as Awaited<ReturnType<typeof MessageV2.get>>)
spyOn(Session, "get").mockRejectedValue(new Error("missing"))
spyOn(Session, "create").mockImplementation(async () => {
started.done()
await gate.promise
return { id: child } as Awaited<ReturnType<typeof Session.create>>
})
const cancel = spyOn(SessionPrompt, "cancel").mockResolvedValue()
spyOn(SessionPrompt, "resolvePromptParts").mockResolvedValue(
[] as Awaited<ReturnType<typeof SessionPrompt.resolvePromptParts>>,
)
spyOn(SessionPrompt, "prompt").mockResolvedValue({
parts: [{ type: "text", text: "done" }],
} as Awaited<ReturnType<typeof SessionPrompt.prompt>>)
const tool = await TaskTool.init()
const run = tool.execute(
{
description: "inspect bug",
prompt: "check it",
subagent_type: "general",
},
{
sessionID: parent,
messageID,
agent: "build",
abort: abort.signal,
messages: [],
metadata: () => {},
ask: async () => {},
},
)
await started.promise
abort.abort()
gate.done()
await run
expect(cancel).toHaveBeenCalledWith(child)
})
}) })