diff --git a/packages/opencode/src/server/middleware.ts b/packages/opencode/src/server/middleware.ts index ebf0163cd6..278740c57d 100644 --- a/packages/opencode/src/server/middleware.ts +++ b/packages/opencode/src/server/middleware.ts @@ -1,6 +1,7 @@ import { Provider } from "../provider/provider" import { NamedError } from "@opencode-ai/util/error" import { NotFoundError } from "../storage/db" +import { Session } from "../session" import type { ContentfulStatusCode } from "hono/utils/http-status" import type { ErrorHandler } from "hono" import { HTTPException } from "hono/http-exception" @@ -20,6 +21,9 @@ export function errorHandler(log: Log.Logger): ErrorHandler { else status = 500 return c.json(err.toObject(), { status }) } + if (err instanceof Session.BusyError) { + return c.json(new NamedError.Unknown({ message: err.message }).toObject(), { status: 400 }) + } if (err instanceof HTTPException) return err.getResponse() const message = err instanceof Error && err.stack ? err.stack : err.toString() return c.json(new NamedError.Unknown({ message }).toObject(), { diff --git a/packages/opencode/src/session/index.ts b/packages/opencode/src/session/index.ts index c986040523..74506c31da 100644 --- a/packages/opencode/src/session/index.ts +++ b/packages/opencode/src/session/index.ts @@ -849,7 +849,8 @@ export namespace Session { export const children = fn(SessionID.zod, (id) => runPromise((svc) => svc.children(id))) export const remove = fn(SessionID.zod, (id) => runPromise((svc) => svc.remove(id))) export async function updateMessage(msg: T): Promise { - return runPromise((svc) => svc.updateMessage(MessageV2.Info.parse(msg) as T)) + MessageV2.Info.parse(msg) + return runPromise((svc) => svc.updateMessage(msg)) } export const removeMessage = fn(z.object({ sessionID: SessionID.zod, messageID: MessageID.zod }), (input) => @@ -862,7 +863,8 @@ export namespace Session { ) export async function updatePart(part: T): Promise { - return runPromise((svc) => svc.updatePart(MessageV2.Part.parse(part) as T)) + MessageV2.Part.parse(part) + return runPromise((svc) => svc.updatePart(part)) } export const updatePartDelta = fn( diff --git a/packages/opencode/src/session/revert.ts b/packages/opencode/src/session/revert.ts index a80ee45201..b1e9840e4f 100644 --- a/packages/opencode/src/session/revert.ts +++ b/packages/opencode/src/session/revert.ts @@ -92,12 +92,10 @@ export namespace SessionRevert { const sessionID = session.id const msgs = await Session.messages({ sessionID }) const messageID = session.revert.messageID - const preserve = [] as MessageV2.WithParts[] const remove = [] as MessageV2.WithParts[] let target: MessageV2.WithParts | undefined for (const msg of msgs) { if (msg.info.id < messageID) { - preserve.push(msg) continue } if (msg.info.id > messageID) { @@ -105,7 +103,6 @@ export namespace SessionRevert { continue } if (session.revert.partID) { - preserve.push(msg) target = msg continue } diff --git a/packages/opencode/test/server/session-actions.test.ts b/packages/opencode/test/server/session-actions.test.ts new file mode 100644 index 0000000000..e6dba676ce --- /dev/null +++ b/packages/opencode/test/server/session-actions.test.ts @@ -0,0 +1,83 @@ +import { afterEach, describe, expect, mock, spyOn, test } from "bun:test" +import { Instance } from "../../src/project/instance" +import { Server } from "../../src/server/server" +import { Session } from "../../src/session" +import { ModelID, ProviderID } from "../../src/provider/schema" +import { MessageID, PartID, type SessionID } from "../../src/session/schema" +import { SessionPrompt } from "../../src/session/prompt" +import { Log } from "../../src/util/log" +import { tmpdir } from "../fixture/fixture" + +Log.init({ print: false }) + +afterEach(async () => { + mock.restore() + await Instance.disposeAll() +}) + +async function user(sessionID: SessionID, text: string) { + const msg = await Session.updateMessage({ + id: MessageID.ascending(), + role: "user", + sessionID, + agent: "build", + model: { providerID: ProviderID.make("test"), modelID: ModelID.make("test") }, + time: { created: Date.now() }, + }) + await Session.updatePart({ + id: PartID.ascending(), + sessionID, + messageID: msg.id, + type: "text", + text, + }) + return msg +} + +describe("session action routes", () => { + test("abort route calls SessionPrompt.cancel", async () => { + await using tmp = await tmpdir({ git: true }) + await Instance.provide({ + directory: tmp.path, + fn: async () => { + const session = await Session.create({}) + const cancel = spyOn(SessionPrompt, "cancel").mockResolvedValue() + const app = Server.Default() + + const res = await app.request(`/session/${session.id}/abort`, { + method: "POST", + }) + + expect(res.status).toBe(200) + expect(await res.json()).toBe(true) + expect(cancel).toHaveBeenCalledWith(session.id) + + await Session.remove(session.id) + }, + }) + }) + + test("delete message route returns 400 when session is busy", async () => { + await using tmp = await tmpdir({ git: true }) + await Instance.provide({ + directory: tmp.path, + fn: async () => { + const session = await Session.create({}) + const msg = await user(session.id, "hello") + const busy = spyOn(SessionPrompt, "assertNotBusy").mockRejectedValue(new Session.BusyError(session.id)) + const remove = spyOn(Session, "removeMessage").mockResolvedValue(msg.id) + const app = Server.Default() + + const res = await app.request(`/session/${session.id}/message/${msg.id}`, { + method: "DELETE", + }) + + expect(res.status).toBe(400) + expect(busy).toHaveBeenCalledWith(session.id) + expect(remove).not.toHaveBeenCalled() + + await Session.remove(session.id) + }, + }) + }) +}) diff --git a/packages/opencode/test/session/compaction.test.ts b/packages/opencode/test/session/compaction.test.ts index c08fef5633..a686d7ccff 100644 --- a/packages/opencode/test/session/compaction.test.ts +++ b/packages/opencode/test/session/compaction.test.ts @@ -509,6 +509,36 @@ describe("session.compaction.prune", () => { }) describe("session.compaction.process", () => { + test("throws when parent is not a user message", async () => { + await using tmp = await tmpdir() + await Instance.provide({ + directory: tmp.path, + fn: async () => { + const session = await Session.create({}) + const msg = await user(session.id, "hello") + const reply = await assistant(session.id, msg.id, tmp.path) + const rt = runtime("continue") + try { + const msgs = await Session.messages({ sessionID: session.id }) + await expect( + rt.runPromise( + SessionCompaction.Service.use((svc) => + svc.process({ + parentID: reply.id, + messages: msgs, + sessionID: session.id, + auto: false, + }), + ), + ), + ).rejects.toThrow(`Compaction parent must be a user message: ${reply.id}`) + } finally { + await rt.dispose() + } + }, + }) + }) + test("publishes compacted event on continue", async () => { await using tmp = await tmpdir() await Instance.provide({ diff --git a/packages/opencode/test/session/prompt-effect.test.ts b/packages/opencode/test/session/prompt-effect.test.ts index 9f35a21f4a..ef664113f3 100644 --- a/packages/opencode/test/session/prompt-effect.test.ts +++ b/packages/opencode/test/session/prompt-effect.test.ts @@ -1,7 +1,8 @@ import { NodeFileSystem } from "@effect/platform-node" -import { expect } from "bun:test" +import { expect, spyOn } from "bun:test" import { Cause, Effect, Exit, Fiber, Layer, ServiceMap } from "effect" import * as Stream from "effect/Stream" +import z from "zod" import type { Agent } from "../../src/agent/agent" import { Agent as AgentSvc } from "../../src/agent/agent" import { Bus } from "../../src/bus" @@ -25,6 +26,7 @@ import { MessageID, PartID, SessionID } from "../../src/session/schema" import { SessionStatus } from "../../src/session/status" import { Shell } from "../../src/shell/shell" import { Snapshot } from "../../src/snapshot" +import { TaskTool } from "../../src/tool/task" import { ToolRegistry } from "../../src/tool/registry" import { Truncate } from "../../src/tool/truncate" import { Log } from "../../src/util/log" @@ -630,6 +632,69 @@ it.effect( 30_000, ) +it.effect( + "cancel finalizes subtask tool state", + () => + provideTmpdirInstance( + (dir) => + Effect.gen(function* () { + const ready = defer() + const aborted = defer() + const init = spyOn(TaskTool, "init").mockImplementation(async () => ({ + description: "task", + parameters: z.object({ + description: z.string(), + prompt: z.string(), + subagent_type: z.string(), + task_id: z.string().optional(), + command: z.string().optional(), + }), + execute: async (_args, ctx) => { + ready.resolve() + ctx.abort.addEventListener("abort", () => aborted.resolve(), { once: true }) + await new Promise(() => {}) + return { + title: "", + metadata: { + sessionId: SessionID.make("task"), + model: ref, + }, + output: "", + } + }, + })) + yield* Effect.addFinalizer(() => Effect.sync(() => init.mockRestore())) + + const { prompt, chat } = yield* boot() + const msg = yield* user(chat.id, "hello") + yield* addSubtask(chat.id, msg.id) + + const fiber = yield* prompt.loop({ sessionID: chat.id }).pipe(Effect.forkChild) + yield* Effect.promise(() => ready.promise) + yield* prompt.cancel(chat.id) + yield* Effect.promise(() => aborted.promise) + + const exit = yield* Fiber.await(fiber) + expect(Exit.isSuccess(exit)).toBe(true) + + const msgs = yield* Effect.promise(() => MessageV2.filterCompacted(MessageV2.stream(chat.id))) + const taskMsg = msgs.find((item) => item.info.role === "assistant" && item.info.agent === "general") + expect(taskMsg?.info.role).toBe("assistant") + if (!taskMsg || taskMsg.info.role !== "assistant") return + + const tool = toolPart(taskMsg.parts) + expect(tool?.type).toBe("tool") + if (!tool) return + + expect(tool.state.status).not.toBe("running") + expect(taskMsg.info.time.completed).toBeDefined() + expect(taskMsg.info.finish).toBeDefined() + }), + { git: true, config: cfg }, + ), + 30_000, +) + it.effect( "cancel with queued callers resolves all cleanly", () =>