fix(effect): preserve task cancellation during prompt
parent
8616818e37
commit
9a5cf96b7a
|
|
@ -7,7 +7,6 @@ import { SessionID, MessageID } from "../session/schema"
|
|||
import { MessageV2 } from "../session/message-v2"
|
||||
import { Agent } from "../agent/agent"
|
||||
import { SessionPrompt } from "../session/prompt"
|
||||
import { iife } from "@/util/iife"
|
||||
import { Config } from "../config/config"
|
||||
import { Permission } from "@/permission"
|
||||
|
||||
|
|
@ -68,7 +67,7 @@ export const TaskTool = Tool.defineEffect(
|
|||
)
|
||||
}
|
||||
|
||||
const next = yield* agent.get(params.subagent_type).pipe(Effect.catch(() => Effect.succeed(undefined)))
|
||||
const next = yield* agent.get(params.subagent_type)
|
||||
if (!next) {
|
||||
return yield* Effect.fail(new Error(`Unknown agent type: ${params.subagent_type} is not a valid agent type`))
|
||||
}
|
||||
|
|
@ -76,14 +75,17 @@ export const TaskTool = Tool.defineEffect(
|
|||
const hasTask = next.permission.some((rule) => rule.permission === "task")
|
||||
const hasTodo = next.permission.some((rule) => rule.permission === "todowrite")
|
||||
|
||||
const session = yield* Effect.promise(() =>
|
||||
iife(async () => {
|
||||
if (params.task_id) {
|
||||
const found = await Session.get(SessionID.make(params.task_id)).catch(() => {})
|
||||
if (found) return found
|
||||
}
|
||||
|
||||
return Session.create({
|
||||
const taskID = params.task_id
|
||||
const session = taskID
|
||||
? yield* Effect.promise(() => {
|
||||
const id = SessionID.make(taskID)
|
||||
return Session.get(id).catch(() => undefined)
|
||||
})
|
||||
: undefined
|
||||
const nextSession =
|
||||
session ??
|
||||
(yield* Effect.promise(() =>
|
||||
Session.create({
|
||||
parentID: ctx.sessionID,
|
||||
title: params.description + ` (@${next.name} subagent)`,
|
||||
permission: [
|
||||
|
|
@ -111,9 +113,8 @@ export const TaskTool = Tool.defineEffect(
|
|||
permission: item,
|
||||
})) ?? []),
|
||||
],
|
||||
})
|
||||
}),
|
||||
)
|
||||
}),
|
||||
))
|
||||
|
||||
const msg = yield* Effect.sync(() => MessageV2.get({ sessionID: ctx.sessionID, messageID: ctx.messageID }))
|
||||
if (msg.info.role !== "assistant") return yield* Effect.fail(new Error("Not an assistant message"))
|
||||
|
|
@ -126,7 +127,7 @@ export const TaskTool = Tool.defineEffect(
|
|||
ctx.metadata({
|
||||
title: params.description,
|
||||
metadata: {
|
||||
sessionId: session.id,
|
||||
sessionId: nextSession.id,
|
||||
model,
|
||||
},
|
||||
})
|
||||
|
|
@ -134,54 +135,51 @@ export const TaskTool = Tool.defineEffect(
|
|||
const messageID = MessageID.ascending()
|
||||
|
||||
function cancel() {
|
||||
SessionPrompt.cancel(session.id)
|
||||
SessionPrompt.cancel(nextSession.id)
|
||||
}
|
||||
return yield* Effect.acquireUseRelease(
|
||||
Effect.sync(() => {
|
||||
ctx.abort.addEventListener("abort", cancel)
|
||||
}),
|
||||
() => Effect.promise(() => SessionPrompt.resolvePromptParts(params.prompt)),
|
||||
() =>
|
||||
Effect.gen(function* () {
|
||||
const parts = yield* Effect.promise(() => SessionPrompt.resolvePromptParts(params.prompt))
|
||||
const result = yield* Effect.promise(() =>
|
||||
SessionPrompt.prompt({
|
||||
messageID,
|
||||
sessionID: nextSession.id,
|
||||
model: {
|
||||
modelID: model.modelID,
|
||||
providerID: model.providerID,
|
||||
},
|
||||
agent: next.name,
|
||||
tools: {
|
||||
...(hasTodo ? {} : { todowrite: false }),
|
||||
...(hasTask ? {} : { task: false }),
|
||||
...Object.fromEntries((cfg.experimental?.primary_tools ?? []).map((item) => [item, false])),
|
||||
},
|
||||
parts,
|
||||
}),
|
||||
)
|
||||
return {
|
||||
title: params.description,
|
||||
metadata: {
|
||||
sessionId: nextSession.id,
|
||||
model,
|
||||
},
|
||||
output: [
|
||||
`task_id: ${nextSession.id} (for resuming to continue this task if needed)`,
|
||||
"",
|
||||
"<task_result>",
|
||||
result.parts.findLast((item) => item.type === "text")?.text ?? "",
|
||||
"</task_result>",
|
||||
].join("\n"),
|
||||
}
|
||||
}),
|
||||
() =>
|
||||
Effect.sync(() => {
|
||||
ctx.abort.removeEventListener("abort", cancel)
|
||||
}),
|
||||
).pipe(
|
||||
Effect.flatMap((parts) =>
|
||||
Effect.promise(() =>
|
||||
SessionPrompt.prompt({
|
||||
messageID,
|
||||
sessionID: session.id,
|
||||
model: {
|
||||
modelID: model.modelID,
|
||||
providerID: model.providerID,
|
||||
},
|
||||
agent: next.name,
|
||||
tools: {
|
||||
...(hasTodo ? {} : { todowrite: false }),
|
||||
...(hasTask ? {} : { task: false }),
|
||||
...Object.fromEntries((cfg.experimental?.primary_tools ?? []).map((item) => [item, false])),
|
||||
},
|
||||
parts,
|
||||
}),
|
||||
),
|
||||
),
|
||||
Effect.map((result) => {
|
||||
const text = result.parts.findLast((item) => item.type === "text")?.text ?? ""
|
||||
return {
|
||||
title: params.description,
|
||||
metadata: {
|
||||
sessionId: session.id,
|
||||
model,
|
||||
},
|
||||
output: [
|
||||
`task_id: ${session.id} (for resuming to continue this task if needed)`,
|
||||
"",
|
||||
"<task_result>",
|
||||
text,
|
||||
"</task_result>",
|
||||
].join("\n"),
|
||||
}
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import { NodeFileSystem } from "@effect/platform-node"
|
||||
import { expect, spyOn } from "bun:test"
|
||||
import { expect } from "bun:test"
|
||||
import { Cause, Effect, Exit, Fiber, Layer } from "effect"
|
||||
import path from "path"
|
||||
import z from "zod"
|
||||
|
|
@ -13,7 +13,6 @@ import { MCP } from "../../src/mcp"
|
|||
import { Permission } from "../../src/permission"
|
||||
import { Plugin } from "../../src/plugin"
|
||||
import { Provider as ProviderSvc } from "../../src/provider/provider"
|
||||
import type { Provider } from "../../src/provider/provider"
|
||||
import { ModelID, ProviderID } from "../../src/provider/schema"
|
||||
import { Question } from "../../src/question"
|
||||
import { Todo } from "../../src/session/todo"
|
||||
|
|
@ -626,7 +625,7 @@ it.live(
|
|||
"cancel finalizes subtask tool state",
|
||||
() =>
|
||||
provideTmpdirInstance(
|
||||
(dir) =>
|
||||
() =>
|
||||
Effect.gen(function* () {
|
||||
const ready = defer<void>()
|
||||
const aborted = defer<void>()
|
||||
|
|
@ -642,6 +641,13 @@ it.live(
|
|||
command: z.string().optional(),
|
||||
}),
|
||||
execute: async (_args, ctx) => {
|
||||
ctx.metadata({
|
||||
title: "inspect bug",
|
||||
metadata: {
|
||||
sessionId: SessionID.make("task"),
|
||||
model: ref,
|
||||
},
|
||||
})
|
||||
ready.resolve()
|
||||
ctx.abort.addEventListener("abort", () => aborted.resolve(), { once: true })
|
||||
await new Promise<void>(() => {})
|
||||
|
|
@ -674,11 +680,19 @@ it.live(
|
|||
expect(taskMsg?.info.role).toBe("assistant")
|
||||
if (!taskMsg || taskMsg.info.role !== "assistant") return
|
||||
|
||||
const tool = toolPart(taskMsg.parts)
|
||||
expect(tool?.type).toBe("tool")
|
||||
const tool = errorTool(taskMsg.parts)
|
||||
if (!tool) return
|
||||
|
||||
expect(tool.state.status).not.toBe("running")
|
||||
expect(tool.state.error).toBe("Cancelled")
|
||||
expect(tool.state.input).toEqual({
|
||||
description: "inspect bug",
|
||||
prompt: "look into the cache key path",
|
||||
subagent_type: "general",
|
||||
})
|
||||
expect(tool.state.metadata).toEqual({
|
||||
sessionId: SessionID.make("task"),
|
||||
model: ref,
|
||||
})
|
||||
expect(taskMsg.info.time.completed).toBeDefined()
|
||||
expect(taskMsg.info.finish).toBeDefined()
|
||||
}),
|
||||
|
|
|
|||
Loading…
Reference in New Issue