fix(task): preserve subtask cancellation during startup
parent
a3bf978919
commit
6fe0f52a05
|
|
@ -35,6 +35,8 @@ type Input = {
|
||||||
prompt: string
|
prompt: string
|
||||||
agent: Agent.Info
|
agent: Agent.Info
|
||||||
model: Ref
|
model: Ref
|
||||||
|
abort?: AbortSignal
|
||||||
|
cancel?: (sessionID: SessionID) => Promise<void> | void
|
||||||
start?: (sessionID: SessionID, model: Ref) => Promise<void> | void
|
start?: (sessionID: SessionID, model: Ref) => Promise<void> | void
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -61,15 +63,28 @@ export function output(sessionID: SessionID, text: string) {
|
||||||
export const run = Effect.fn("Subtask.run")(function* (deps: Deps, input: Input) {
|
export const run = Effect.fn("Subtask.run")(function* (deps: Deps, input: Input) {
|
||||||
const cfg = yield* deps.cfg
|
const cfg = yield* deps.cfg
|
||||||
const model = input.agent.model ?? input.model
|
const model = input.agent.model ?? input.model
|
||||||
const found = input.taskID ? yield* deps.get(input.taskID) : undefined
|
const session = yield* Effect.uninterruptibleMask((restore) =>
|
||||||
|
Effect.gen(function* () {
|
||||||
|
const found = input.taskID ? yield* restore(deps.get(input.taskID)) : undefined
|
||||||
const session = found
|
const session = found
|
||||||
? found
|
? found
|
||||||
: yield* deps.create({
|
: yield* restore(
|
||||||
|
deps.create({
|
||||||
parentID: input.parentID,
|
parentID: input.parentID,
|
||||||
title: input.description + ` (@${input.agent.name} subagent)`,
|
title: input.description + ` (@${input.agent.name} subagent)`,
|
||||||
})
|
}),
|
||||||
|
)
|
||||||
|
|
||||||
yield* Effect.promise(() => Promise.resolve(input.start?.(session.id, model)))
|
const start = input.start?.(session.id, model)
|
||||||
|
if (start) yield* Effect.promise(() => Promise.resolve(start))
|
||||||
|
return session
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
|
||||||
|
if (input.abort?.aborted) {
|
||||||
|
const cancel = input.cancel?.(session.id)
|
||||||
|
if (cancel) yield* Effect.promise(() => Promise.resolve(cancel))
|
||||||
|
}
|
||||||
|
|
||||||
const result = yield* deps.prompt({
|
const result = yield* deps.prompt({
|
||||||
sessionID: session.id,
|
sessionID: session.id,
|
||||||
|
|
|
||||||
|
|
@ -88,6 +88,8 @@ export const TaskTool = Tool.define("task", async (ctx) => {
|
||||||
description: params.description,
|
description: params.description,
|
||||||
prompt: params.prompt,
|
prompt: params.prompt,
|
||||||
agent,
|
agent,
|
||||||
|
abort: ctx.abort,
|
||||||
|
cancel: SessionPrompt.cancel,
|
||||||
model: {
|
model: {
|
||||||
modelID: msg.info.modelID,
|
modelID: msg.info.modelID,
|
||||||
providerID: msg.info.providerID,
|
providerID: msg.info.providerID,
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,28 @@
|
||||||
import { afterEach, describe, expect, test } from "bun:test"
|
import { afterEach, describe, expect, mock, spyOn, test } from "bun:test"
|
||||||
import { Agent } from "../../src/agent/agent"
|
import { Agent } from "../../src/agent/agent"
|
||||||
|
import { Config } from "../../src/config/config"
|
||||||
import { Instance } from "../../src/project/instance"
|
import { Instance } from "../../src/project/instance"
|
||||||
|
import { ModelID, ProviderID } from "../../src/provider/schema"
|
||||||
|
import { MessageV2 } from "../../src/session/message-v2"
|
||||||
|
import { SessionPrompt } from "../../src/session/prompt"
|
||||||
|
import { MessageID, SessionID } from "../../src/session/schema"
|
||||||
|
import { Session } from "../../src/session"
|
||||||
import { TaskTool } from "../../src/tool/task"
|
import { TaskTool } from "../../src/tool/task"
|
||||||
import { tmpdir } from "../fixture/fixture"
|
import { tmpdir } from "../fixture/fixture"
|
||||||
|
|
||||||
afterEach(async () => {
|
afterEach(async () => {
|
||||||
|
mock.restore()
|
||||||
await Instance.disposeAll()
|
await Instance.disposeAll()
|
||||||
})
|
})
|
||||||
|
|
||||||
|
function wait<T>() {
|
||||||
|
let done!: (value: T | PromiseLike<T>) => void
|
||||||
|
const promise = new Promise<T>((resolve) => {
|
||||||
|
done = resolve
|
||||||
|
})
|
||||||
|
return { promise, done }
|
||||||
|
}
|
||||||
|
|
||||||
describe("tool.task", () => {
|
describe("tool.task", () => {
|
||||||
test("description sorts subagents by name and is stable across calls", async () => {
|
test("description sorts subagents by name and is stable across calls", async () => {
|
||||||
await using tmp = await tmpdir({
|
await using tmp = await tmpdir({
|
||||||
|
|
@ -46,4 +61,73 @@ describe("tool.task", () => {
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
test("cancels child session when aborted during creation", async () => {
|
||||||
|
const started = wait<void>()
|
||||||
|
const gate = wait<void>()
|
||||||
|
const parent = SessionID.make("parent")
|
||||||
|
const child = SessionID.make("child")
|
||||||
|
const messageID = MessageID.ascending()
|
||||||
|
const abort = new AbortController()
|
||||||
|
const agent: Agent.Info = {
|
||||||
|
name: "general",
|
||||||
|
description: "General agent",
|
||||||
|
mode: "subagent",
|
||||||
|
options: {},
|
||||||
|
permission: [],
|
||||||
|
}
|
||||||
|
const ref = {
|
||||||
|
providerID: ProviderID.make("test"),
|
||||||
|
modelID: ModelID.make("test-model"),
|
||||||
|
}
|
||||||
|
|
||||||
|
spyOn(Agent, "list").mockResolvedValue([agent])
|
||||||
|
spyOn(Agent, "get").mockResolvedValue(agent)
|
||||||
|
spyOn(Config, "get").mockResolvedValue({ experimental: {} } as Awaited<ReturnType<typeof Config.get>>)
|
||||||
|
spyOn(MessageV2, "get").mockResolvedValue({
|
||||||
|
info: {
|
||||||
|
role: "assistant",
|
||||||
|
providerID: ref.providerID,
|
||||||
|
modelID: ref.modelID,
|
||||||
|
},
|
||||||
|
} as Awaited<ReturnType<typeof MessageV2.get>>)
|
||||||
|
spyOn(Session, "get").mockRejectedValue(new Error("missing"))
|
||||||
|
spyOn(Session, "create").mockImplementation(async () => {
|
||||||
|
started.done()
|
||||||
|
await gate.promise
|
||||||
|
return { id: child } as Awaited<ReturnType<typeof Session.create>>
|
||||||
|
})
|
||||||
|
const cancel = spyOn(SessionPrompt, "cancel").mockResolvedValue()
|
||||||
|
spyOn(SessionPrompt, "resolvePromptParts").mockResolvedValue(
|
||||||
|
[] as Awaited<ReturnType<typeof SessionPrompt.resolvePromptParts>>,
|
||||||
|
)
|
||||||
|
spyOn(SessionPrompt, "prompt").mockResolvedValue({
|
||||||
|
parts: [{ type: "text", text: "done" }],
|
||||||
|
} as Awaited<ReturnType<typeof SessionPrompt.prompt>>)
|
||||||
|
|
||||||
|
const tool = await TaskTool.init()
|
||||||
|
const run = tool.execute(
|
||||||
|
{
|
||||||
|
description: "inspect bug",
|
||||||
|
prompt: "check it",
|
||||||
|
subagent_type: "general",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sessionID: parent,
|
||||||
|
messageID,
|
||||||
|
agent: "build",
|
||||||
|
abort: abort.signal,
|
||||||
|
messages: [],
|
||||||
|
metadata: () => {},
|
||||||
|
ask: async () => {},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
await started.promise
|
||||||
|
abort.abort()
|
||||||
|
gate.done()
|
||||||
|
await run
|
||||||
|
|
||||||
|
expect(cancel).toHaveBeenCalledWith(child)
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue