From 6fe0f52a0587a79135af92d1adc02ae90722040a Mon Sep 17 00:00:00 2001 From: Kit Langton Date: Tue, 7 Apr 2026 11:12:29 -0400 Subject: [PATCH] fix(task): preserve subtask cancellation during startup --- packages/opencode/src/tool/subtask.ts | 31 ++++++--- packages/opencode/src/tool/task.ts | 2 + packages/opencode/test/tool/task.test.ts | 86 +++++++++++++++++++++++- 3 files changed, 110 insertions(+), 9 deletions(-) diff --git a/packages/opencode/src/tool/subtask.ts b/packages/opencode/src/tool/subtask.ts index 1f7f75815f..1173837dce 100644 --- a/packages/opencode/src/tool/subtask.ts +++ b/packages/opencode/src/tool/subtask.ts @@ -35,6 +35,8 @@ type Input = { prompt: string agent: Agent.Info model: Ref + abort?: AbortSignal + cancel?: (sessionID: SessionID) => Promise | void start?: (sessionID: SessionID, model: Ref) => Promise | void } @@ -61,15 +63,28 @@ export function output(sessionID: SessionID, text: string) { export const run = Effect.fn("Subtask.run")(function* (deps: Deps, input: Input) { const cfg = yield* deps.cfg const model = input.agent.model ?? input.model - const found = input.taskID ? yield* deps.get(input.taskID) : undefined - const session = found - ? found - : yield* deps.create({ - parentID: input.parentID, - title: input.description + ` (@${input.agent.name} subagent)`, - }) + const session = yield* Effect.uninterruptibleMask((restore) => + Effect.gen(function* () { + const found = input.taskID ? yield* restore(deps.get(input.taskID)) : undefined + const session = found + ? found + : yield* restore( + deps.create({ + parentID: input.parentID, + title: input.description + ` (@${input.agent.name} subagent)`, + }), + ) - yield* Effect.promise(() => Promise.resolve(input.start?.(session.id, model))) + const start = input.start?.(session.id, model) + if (start) yield* Effect.promise(() => Promise.resolve(start)) + return session + }), + ) + + if (input.abort?.aborted) { + const cancel = input.cancel?.(session.id) + if (cancel) yield* Effect.promise(() => Promise.resolve(cancel)) + } const result = yield* deps.prompt({ sessionID: session.id, diff --git a/packages/opencode/src/tool/task.ts b/packages/opencode/src/tool/task.ts index 280cc2a533..d66756f808 100644 --- a/packages/opencode/src/tool/task.ts +++ b/packages/opencode/src/tool/task.ts @@ -88,6 +88,8 @@ export const TaskTool = Tool.define("task", async (ctx) => { description: params.description, prompt: params.prompt, agent, + abort: ctx.abort, + cancel: SessionPrompt.cancel, model: { modelID: msg.info.modelID, providerID: msg.info.providerID, diff --git a/packages/opencode/test/tool/task.test.ts b/packages/opencode/test/tool/task.test.ts index aae48a30ab..f6cbe2de30 100644 --- a/packages/opencode/test/tool/task.test.ts +++ b/packages/opencode/test/tool/task.test.ts @@ -1,13 +1,28 @@ -import { afterEach, describe, expect, test } from "bun:test" +import { afterEach, describe, expect, mock, spyOn, test } from "bun:test" import { Agent } from "../../src/agent/agent" +import { Config } from "../../src/config/config" import { Instance } from "../../src/project/instance" +import { ModelID, ProviderID } from "../../src/provider/schema" +import { MessageV2 } from "../../src/session/message-v2" +import { SessionPrompt } from "../../src/session/prompt" +import { MessageID, SessionID } from "../../src/session/schema" +import { Session } from "../../src/session" import { TaskTool } from "../../src/tool/task" import { tmpdir } from "../fixture/fixture" afterEach(async () => { + mock.restore() await Instance.disposeAll() }) +function wait() { + let done!: (value: T | PromiseLike) => void + const promise = new Promise((resolve) => { + done = resolve + }) + return { promise, done } +} + describe("tool.task", () => { test("description sorts subagents by name and is stable across calls", async () => { await using tmp = await tmpdir({ @@ -46,4 +61,73 @@ describe("tool.task", () => { }, }) }) + + test("cancels child session when aborted during creation", async () => { + const started = wait() + const gate = wait() + const parent = SessionID.make("parent") + const child = SessionID.make("child") + const messageID = MessageID.ascending() + const abort = new AbortController() + const agent: Agent.Info = { + name: "general", + description: "General agent", + mode: "subagent", + options: {}, + permission: [], + } + const ref = { + providerID: ProviderID.make("test"), + modelID: ModelID.make("test-model"), + } + + spyOn(Agent, "list").mockResolvedValue([agent]) + spyOn(Agent, "get").mockResolvedValue(agent) + spyOn(Config, "get").mockResolvedValue({ experimental: {} } as Awaited>) + spyOn(MessageV2, "get").mockResolvedValue({ + info: { + role: "assistant", + providerID: ref.providerID, + modelID: ref.modelID, + }, + } as Awaited>) + spyOn(Session, "get").mockRejectedValue(new Error("missing")) + spyOn(Session, "create").mockImplementation(async () => { + started.done() + await gate.promise + return { id: child } as Awaited> + }) + const cancel = spyOn(SessionPrompt, "cancel").mockResolvedValue() + spyOn(SessionPrompt, "resolvePromptParts").mockResolvedValue( + [] as Awaited>, + ) + spyOn(SessionPrompt, "prompt").mockResolvedValue({ + parts: [{ type: "text", text: "done" }], + } as Awaited>) + + const tool = await TaskTool.init() + const run = tool.execute( + { + description: "inspect bug", + prompt: "check it", + subagent_type: "general", + }, + { + sessionID: parent, + messageID, + agent: "build", + abort: abort.signal, + messages: [], + metadata: () => {}, + ask: async () => {}, + }, + ) + + await started.promise + abort.abort() + gate.done() + await run + + expect(cancel).toHaveBeenCalledWith(child) + }) })