diff --git a/packages/opencode/src/session/revert.ts b/packages/opencode/src/session/revert.ts index 50c3781e9c..9df3f36eb8 100644 --- a/packages/opencode/src/session/revert.ts +++ b/packages/opencode/src/session/revert.ts @@ -72,6 +72,7 @@ export namespace SessionRevert { if (!rev) return session rev.snapshot = session.revert?.snapshot ?? (yield* snap.track()) + if (session.revert?.snapshot) yield* snap.restore(session.revert.snapshot) yield* snap.revert(patches) if (rev.snapshot) rev.diff = yield* snap.diff(rev.snapshot as string) const range = all.filter((msg) => msg.info.id >= rev!.messageID) diff --git a/packages/opencode/test/session/revert-compact.test.ts b/packages/opencode/test/session/revert-compact.test.ts index c7230772df..95d90325ad 100644 --- a/packages/opencode/test/session/revert-compact.test.ts +++ b/packages/opencode/test/session/revert-compact.test.ts @@ -1,10 +1,12 @@ import { describe, expect, test, beforeEach, afterEach } from "bun:test" +import fs from "fs/promises" import path from "path" import { Session } from "../../src/session" import { ModelID, ProviderID } from "../../src/provider/schema" import { SessionRevert } from "../../src/session/revert" import { SessionCompaction } from "../../src/session/compaction" import { MessageV2 } from "../../src/session/message-v2" +import { Snapshot } from "../../src/snapshot" import { Log } from "../../src/util/log" import { Instance } from "../../src/project/instance" import { MessageID, PartID } from "../../src/session/schema" @@ -70,6 +72,13 @@ function tool(sessionID: string, messageID: string) { }) } +const tokens = { + input: 0, + output: 0, + reasoning: 0, + cache: { read: 0, write: 0 }, +} + describe("revert + compact workflow", () => { test("should properly handle compact command after revert", async () => { await using tmp = await tmpdir({ git: true }) @@ -434,4 +443,179 @@ describe("revert + compact workflow", () => { }, }) }) + + test("restore messages in sequential order", async () => { + await using tmp = await tmpdir({ git: true }) + await Instance.provide({ + directory: tmp.path, + fn: async () => { + await fs.writeFile(path.join(tmp.path, "a.txt"), "a0") + await fs.writeFile(path.join(tmp.path, "b.txt"), "b0") + await fs.writeFile(path.join(tmp.path, "c.txt"), "c0") + + const session = await Session.create({}) + const sid = session.id + + const turn = async (file: string, next: string) => { + const u = await user(sid) + await text(sid, u.id, `${file}:${next}`) + const a = await assistant(sid, u.id, tmp.path) + const before = await Snapshot.track() + if (!before) throw new Error("expected snapshot") + await fs.writeFile(path.join(tmp.path, file), next) + const after = await Snapshot.track() + if (!after) throw new Error("expected snapshot") + const patch = await Snapshot.patch(before) + await Session.updatePart({ + id: PartID.ascending(), + messageID: a.id, + sessionID: sid, + type: "step-start", + snapshot: before, + }) + await Session.updatePart({ + id: PartID.ascending(), + messageID: a.id, + sessionID: sid, + type: "step-finish", + reason: "stop", + snapshot: after, + cost: 0, + tokens, + }) + await Session.updatePart({ + id: PartID.ascending(), + messageID: a.id, + sessionID: sid, + type: "patch", + hash: patch.hash, + files: patch.files, + }) + return u.id + } + + const first = await turn("a.txt", "a1") + const second = await turn("b.txt", "b2") + const third = await turn("c.txt", "c3") + + await SessionRevert.revert({ + sessionID: sid, + messageID: first, + }) + expect((await Session.get(sid)).revert?.messageID).toBe(first) + expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a0") + expect(await fs.readFile(path.join(tmp.path, "b.txt"), "utf-8")).toBe("b0") + expect(await fs.readFile(path.join(tmp.path, "c.txt"), "utf-8")).toBe("c0") + + await SessionRevert.revert({ + sessionID: sid, + messageID: second, + }) + expect((await Session.get(sid)).revert?.messageID).toBe(second) + expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a1") + expect(await fs.readFile(path.join(tmp.path, "b.txt"), "utf-8")).toBe("b0") + expect(await fs.readFile(path.join(tmp.path, "c.txt"), "utf-8")).toBe("c0") + + await SessionRevert.revert({ + sessionID: sid, + messageID: third, + }) + expect((await Session.get(sid)).revert?.messageID).toBe(third) + expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a1") + expect(await fs.readFile(path.join(tmp.path, "b.txt"), "utf-8")).toBe("b2") + expect(await fs.readFile(path.join(tmp.path, "c.txt"), "utf-8")).toBe("c0") + + await SessionRevert.unrevert({ + sessionID: sid, + }) + expect((await Session.get(sid)).revert).toBeUndefined() + expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a1") + expect(await fs.readFile(path.join(tmp.path, "b.txt"), "utf-8")).toBe("b2") + expect(await fs.readFile(path.join(tmp.path, "c.txt"), "utf-8")).toBe("c3") + }, + }) + }) + + test("restore same file in sequential order", async () => { + await using tmp = await tmpdir({ git: true }) + await Instance.provide({ + directory: tmp.path, + fn: async () => { + await fs.writeFile(path.join(tmp.path, "a.txt"), "a0") + + const session = await Session.create({}) + const sid = session.id + + const turn = async (next: string) => { + const u = await user(sid) + await text(sid, u.id, `a.txt:${next}`) + const a = await assistant(sid, u.id, tmp.path) + const before = await Snapshot.track() + if (!before) throw new Error("expected snapshot") + await fs.writeFile(path.join(tmp.path, "a.txt"), next) + const after = await Snapshot.track() + if (!after) throw new Error("expected snapshot") + const patch = await Snapshot.patch(before) + await Session.updatePart({ + id: PartID.ascending(), + messageID: a.id, + sessionID: sid, + type: "step-start", + snapshot: before, + }) + await Session.updatePart({ + id: PartID.ascending(), + messageID: a.id, + sessionID: sid, + type: "step-finish", + reason: "stop", + snapshot: after, + cost: 0, + tokens, + }) + await Session.updatePart({ + id: PartID.ascending(), + messageID: a.id, + sessionID: sid, + type: "patch", + hash: patch.hash, + files: patch.files, + }) + return u.id + } + + const first = await turn("a1") + const second = await turn("a2") + const third = await turn("a3") + expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a3") + + await SessionRevert.revert({ + sessionID: sid, + messageID: first, + }) + expect((await Session.get(sid)).revert?.messageID).toBe(first) + expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a0") + + await SessionRevert.revert({ + sessionID: sid, + messageID: second, + }) + expect((await Session.get(sid)).revert?.messageID).toBe(second) + expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a1") + + await SessionRevert.revert({ + sessionID: sid, + messageID: third, + }) + expect((await Session.get(sid)).revert?.messageID).toBe(third) + expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a2") + + await SessionRevert.unrevert({ + sessionID: sid, + }) + expect((await Session.get(sid)).revert).toBeUndefined() + expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a3") + }, + }) + }) })