fix(core): fix restoring earlier messages in a reverted chain (#20780)

pull/14251/head^2
Nate Williams 2026-04-03 08:53:00 -04:00 committed by GitHub
parent b969066a20
commit 6359d00fb4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 185 additions and 0 deletions

View File

@ -72,6 +72,7 @@ export namespace SessionRevert {
if (!rev) return session
rev.snapshot = session.revert?.snapshot ?? (yield* snap.track())
if (session.revert?.snapshot) yield* snap.restore(session.revert.snapshot)
yield* snap.revert(patches)
if (rev.snapshot) rev.diff = yield* snap.diff(rev.snapshot as string)
const range = all.filter((msg) => msg.info.id >= rev!.messageID)

View File

@ -1,10 +1,12 @@
import { describe, expect, test, beforeEach, afterEach } from "bun:test"
import fs from "fs/promises"
import path from "path"
import { Session } from "../../src/session"
import { ModelID, ProviderID } from "../../src/provider/schema"
import { SessionRevert } from "../../src/session/revert"
import { SessionCompaction } from "../../src/session/compaction"
import { MessageV2 } from "../../src/session/message-v2"
import { Snapshot } from "../../src/snapshot"
import { Log } from "../../src/util/log"
import { Instance } from "../../src/project/instance"
import { MessageID, PartID } from "../../src/session/schema"
@ -70,6 +72,13 @@ function tool(sessionID: string, messageID: string) {
})
}
const tokens = {
input: 0,
output: 0,
reasoning: 0,
cache: { read: 0, write: 0 },
}
describe("revert + compact workflow", () => {
test("should properly handle compact command after revert", async () => {
await using tmp = await tmpdir({ git: true })
@ -434,4 +443,179 @@ describe("revert + compact workflow", () => {
},
})
})
test("restore messages in sequential order", async () => {
await using tmp = await tmpdir({ git: true })
await Instance.provide({
directory: tmp.path,
fn: async () => {
await fs.writeFile(path.join(tmp.path, "a.txt"), "a0")
await fs.writeFile(path.join(tmp.path, "b.txt"), "b0")
await fs.writeFile(path.join(tmp.path, "c.txt"), "c0")
const session = await Session.create({})
const sid = session.id
const turn = async (file: string, next: string) => {
const u = await user(sid)
await text(sid, u.id, `${file}:${next}`)
const a = await assistant(sid, u.id, tmp.path)
const before = await Snapshot.track()
if (!before) throw new Error("expected snapshot")
await fs.writeFile(path.join(tmp.path, file), next)
const after = await Snapshot.track()
if (!after) throw new Error("expected snapshot")
const patch = await Snapshot.patch(before)
await Session.updatePart({
id: PartID.ascending(),
messageID: a.id,
sessionID: sid,
type: "step-start",
snapshot: before,
})
await Session.updatePart({
id: PartID.ascending(),
messageID: a.id,
sessionID: sid,
type: "step-finish",
reason: "stop",
snapshot: after,
cost: 0,
tokens,
})
await Session.updatePart({
id: PartID.ascending(),
messageID: a.id,
sessionID: sid,
type: "patch",
hash: patch.hash,
files: patch.files,
})
return u.id
}
const first = await turn("a.txt", "a1")
const second = await turn("b.txt", "b2")
const third = await turn("c.txt", "c3")
await SessionRevert.revert({
sessionID: sid,
messageID: first,
})
expect((await Session.get(sid)).revert?.messageID).toBe(first)
expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a0")
expect(await fs.readFile(path.join(tmp.path, "b.txt"), "utf-8")).toBe("b0")
expect(await fs.readFile(path.join(tmp.path, "c.txt"), "utf-8")).toBe("c0")
await SessionRevert.revert({
sessionID: sid,
messageID: second,
})
expect((await Session.get(sid)).revert?.messageID).toBe(second)
expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a1")
expect(await fs.readFile(path.join(tmp.path, "b.txt"), "utf-8")).toBe("b0")
expect(await fs.readFile(path.join(tmp.path, "c.txt"), "utf-8")).toBe("c0")
await SessionRevert.revert({
sessionID: sid,
messageID: third,
})
expect((await Session.get(sid)).revert?.messageID).toBe(third)
expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a1")
expect(await fs.readFile(path.join(tmp.path, "b.txt"), "utf-8")).toBe("b2")
expect(await fs.readFile(path.join(tmp.path, "c.txt"), "utf-8")).toBe("c0")
await SessionRevert.unrevert({
sessionID: sid,
})
expect((await Session.get(sid)).revert).toBeUndefined()
expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a1")
expect(await fs.readFile(path.join(tmp.path, "b.txt"), "utf-8")).toBe("b2")
expect(await fs.readFile(path.join(tmp.path, "c.txt"), "utf-8")).toBe("c3")
},
})
})
test("restore same file in sequential order", async () => {
await using tmp = await tmpdir({ git: true })
await Instance.provide({
directory: tmp.path,
fn: async () => {
await fs.writeFile(path.join(tmp.path, "a.txt"), "a0")
const session = await Session.create({})
const sid = session.id
const turn = async (next: string) => {
const u = await user(sid)
await text(sid, u.id, `a.txt:${next}`)
const a = await assistant(sid, u.id, tmp.path)
const before = await Snapshot.track()
if (!before) throw new Error("expected snapshot")
await fs.writeFile(path.join(tmp.path, "a.txt"), next)
const after = await Snapshot.track()
if (!after) throw new Error("expected snapshot")
const patch = await Snapshot.patch(before)
await Session.updatePart({
id: PartID.ascending(),
messageID: a.id,
sessionID: sid,
type: "step-start",
snapshot: before,
})
await Session.updatePart({
id: PartID.ascending(),
messageID: a.id,
sessionID: sid,
type: "step-finish",
reason: "stop",
snapshot: after,
cost: 0,
tokens,
})
await Session.updatePart({
id: PartID.ascending(),
messageID: a.id,
sessionID: sid,
type: "patch",
hash: patch.hash,
files: patch.files,
})
return u.id
}
const first = await turn("a1")
const second = await turn("a2")
const third = await turn("a3")
expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a3")
await SessionRevert.revert({
sessionID: sid,
messageID: first,
})
expect((await Session.get(sid)).revert?.messageID).toBe(first)
expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a0")
await SessionRevert.revert({
sessionID: sid,
messageID: second,
})
expect((await Session.get(sid)).revert?.messageID).toBe(second)
expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a1")
await SessionRevert.revert({
sessionID: sid,
messageID: third,
})
expect((await Session.get(sid)).revert?.messageID).toBe(third)
expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a2")
await SessionRevert.unrevert({
sessionID: sid,
})
expect((await Session.get(sid)).revert).toBeUndefined()
expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a3")
},
})
})
})