fix(effect): preserve task cancellation during prompt

pull/21017/head
Kit Langton 2026-04-04 19:52:07 -04:00
parent 8616818e37
commit 9a5cf96b7a
2 changed files with 71 additions and 59 deletions

View File

@ -7,7 +7,6 @@ import { SessionID, MessageID } from "../session/schema"
import { MessageV2 } from "../session/message-v2"
import { Agent } from "../agent/agent"
import { SessionPrompt } from "../session/prompt"
import { iife } from "@/util/iife"
import { Config } from "../config/config"
import { Permission } from "@/permission"
@ -68,7 +67,7 @@ export const TaskTool = Tool.defineEffect(
)
}
const next = yield* agent.get(params.subagent_type).pipe(Effect.catch(() => Effect.succeed(undefined)))
const next = yield* agent.get(params.subagent_type)
if (!next) {
return yield* Effect.fail(new Error(`Unknown agent type: ${params.subagent_type} is not a valid agent type`))
}
@ -76,14 +75,17 @@ export const TaskTool = Tool.defineEffect(
const hasTask = next.permission.some((rule) => rule.permission === "task")
const hasTodo = next.permission.some((rule) => rule.permission === "todowrite")
const session = yield* Effect.promise(() =>
iife(async () => {
if (params.task_id) {
const found = await Session.get(SessionID.make(params.task_id)).catch(() => {})
if (found) return found
}
return Session.create({
const taskID = params.task_id
const session = taskID
? yield* Effect.promise(() => {
const id = SessionID.make(taskID)
return Session.get(id).catch(() => undefined)
})
: undefined
const nextSession =
session ??
(yield* Effect.promise(() =>
Session.create({
parentID: ctx.sessionID,
title: params.description + ` (@${next.name} subagent)`,
permission: [
@ -111,9 +113,8 @@ export const TaskTool = Tool.defineEffect(
permission: item,
})) ?? []),
],
})
}),
)
}),
))
const msg = yield* Effect.sync(() => MessageV2.get({ sessionID: ctx.sessionID, messageID: ctx.messageID }))
if (msg.info.role !== "assistant") return yield* Effect.fail(new Error("Not an assistant message"))
@ -126,7 +127,7 @@ export const TaskTool = Tool.defineEffect(
ctx.metadata({
title: params.description,
metadata: {
sessionId: session.id,
sessionId: nextSession.id,
model,
},
})
@ -134,54 +135,51 @@ export const TaskTool = Tool.defineEffect(
const messageID = MessageID.ascending()
function cancel() {
SessionPrompt.cancel(session.id)
SessionPrompt.cancel(nextSession.id)
}
return yield* Effect.acquireUseRelease(
Effect.sync(() => {
ctx.abort.addEventListener("abort", cancel)
}),
() => Effect.promise(() => SessionPrompt.resolvePromptParts(params.prompt)),
() =>
Effect.gen(function* () {
const parts = yield* Effect.promise(() => SessionPrompt.resolvePromptParts(params.prompt))
const result = yield* Effect.promise(() =>
SessionPrompt.prompt({
messageID,
sessionID: nextSession.id,
model: {
modelID: model.modelID,
providerID: model.providerID,
},
agent: next.name,
tools: {
...(hasTodo ? {} : { todowrite: false }),
...(hasTask ? {} : { task: false }),
...Object.fromEntries((cfg.experimental?.primary_tools ?? []).map((item) => [item, false])),
},
parts,
}),
)
return {
title: params.description,
metadata: {
sessionId: nextSession.id,
model,
},
output: [
`task_id: ${nextSession.id} (for resuming to continue this task if needed)`,
"",
"<task_result>",
result.parts.findLast((item) => item.type === "text")?.text ?? "",
"</task_result>",
].join("\n"),
}
}),
() =>
Effect.sync(() => {
ctx.abort.removeEventListener("abort", cancel)
}),
).pipe(
Effect.flatMap((parts) =>
Effect.promise(() =>
SessionPrompt.prompt({
messageID,
sessionID: session.id,
model: {
modelID: model.modelID,
providerID: model.providerID,
},
agent: next.name,
tools: {
...(hasTodo ? {} : { todowrite: false }),
...(hasTask ? {} : { task: false }),
...Object.fromEntries((cfg.experimental?.primary_tools ?? []).map((item) => [item, false])),
},
parts,
}),
),
),
Effect.map((result) => {
const text = result.parts.findLast((item) => item.type === "text")?.text ?? ""
return {
title: params.description,
metadata: {
sessionId: session.id,
model,
},
output: [
`task_id: ${session.id} (for resuming to continue this task if needed)`,
"",
"<task_result>",
text,
"</task_result>",
].join("\n"),
}
}),
)
})

View File

@ -1,5 +1,5 @@
import { NodeFileSystem } from "@effect/platform-node"
import { expect, spyOn } from "bun:test"
import { expect } from "bun:test"
import { Cause, Effect, Exit, Fiber, Layer } from "effect"
import path from "path"
import z from "zod"
@ -13,7 +13,6 @@ import { MCP } from "../../src/mcp"
import { Permission } from "../../src/permission"
import { Plugin } from "../../src/plugin"
import { Provider as ProviderSvc } from "../../src/provider/provider"
import type { Provider } from "../../src/provider/provider"
import { ModelID, ProviderID } from "../../src/provider/schema"
import { Question } from "../../src/question"
import { Todo } from "../../src/session/todo"
@ -626,7 +625,7 @@ it.live(
"cancel finalizes subtask tool state",
() =>
provideTmpdirInstance(
(dir) =>
() =>
Effect.gen(function* () {
const ready = defer<void>()
const aborted = defer<void>()
@ -642,6 +641,13 @@ it.live(
command: z.string().optional(),
}),
execute: async (_args, ctx) => {
ctx.metadata({
title: "inspect bug",
metadata: {
sessionId: SessionID.make("task"),
model: ref,
},
})
ready.resolve()
ctx.abort.addEventListener("abort", () => aborted.resolve(), { once: true })
await new Promise<void>(() => {})
@ -674,11 +680,19 @@ it.live(
expect(taskMsg?.info.role).toBe("assistant")
if (!taskMsg || taskMsg.info.role !== "assistant") return
const tool = toolPart(taskMsg.parts)
expect(tool?.type).toBe("tool")
const tool = errorTool(taskMsg.parts)
if (!tool) return
expect(tool.state.status).not.toBe("running")
expect(tool.state.error).toBe("Cancelled")
expect(tool.state.input).toEqual({
description: "inspect bug",
prompt: "look into the cache key path",
subagent_type: "general",
})
expect(tool.state.metadata).toEqual({
sessionId: SessionID.make("task"),
model: ref,
})
expect(taskMsg.info.time.completed).toBeDefined()
expect(taskMsg.info.finish).toBeDefined()
}),