refactor(session): effectify SessionRevert service (#20143)
parent
954a6ca88e
commit
3fc0367b93
|
|
@ -1,12 +1,14 @@
|
||||||
import z from "zod"
|
import z from "zod"
|
||||||
import { SessionID, MessageID, PartID } from "./schema"
|
import { Effect, Layer, ServiceMap } from "effect"
|
||||||
import { Snapshot } from "../snapshot"
|
import { makeRuntime } from "@/effect/run-service"
|
||||||
import { MessageV2 } from "./message-v2"
|
|
||||||
import { Session } from "."
|
|
||||||
import { Log } from "../util/log"
|
|
||||||
import { SyncEvent } from "../sync"
|
|
||||||
import { Storage } from "@/storage/storage"
|
|
||||||
import { Bus } from "../bus"
|
import { Bus } from "../bus"
|
||||||
|
import { Snapshot } from "../snapshot"
|
||||||
|
import { Storage } from "@/storage/storage"
|
||||||
|
import { SyncEvent } from "../sync"
|
||||||
|
import { Log } from "../util/log"
|
||||||
|
import { Session } from "."
|
||||||
|
import { MessageV2 } from "./message-v2"
|
||||||
|
import { SessionID, MessageID, PartID } from "./schema"
|
||||||
import { SessionPrompt } from "./prompt"
|
import { SessionPrompt } from "./prompt"
|
||||||
import { SessionSummary } from "./summary"
|
import { SessionSummary } from "./summary"
|
||||||
|
|
||||||
|
|
@ -20,30 +22,43 @@ export namespace SessionRevert {
|
||||||
})
|
})
|
||||||
export type RevertInput = z.infer<typeof RevertInput>
|
export type RevertInput = z.infer<typeof RevertInput>
|
||||||
|
|
||||||
export async function revert(input: RevertInput) {
|
export interface Interface {
|
||||||
await SessionPrompt.assertNotBusy(input.sessionID)
|
readonly revert: (input: RevertInput) => Effect.Effect<Session.Info>
|
||||||
const all = await Session.messages({ sessionID: input.sessionID })
|
readonly unrevert: (input: { sessionID: SessionID }) => Effect.Effect<Session.Info>
|
||||||
let lastUser: MessageV2.User | undefined
|
readonly cleanup: (session: Session.Info) => Effect.Effect<void>
|
||||||
const session = await Session.get(input.sessionID)
|
}
|
||||||
|
|
||||||
let revert: Session.Info["revert"]
|
export class Service extends ServiceMap.Service<Service, Interface>()("@opencode/SessionRevert") {}
|
||||||
|
|
||||||
|
export const layer = Layer.effect(
|
||||||
|
Service,
|
||||||
|
Effect.gen(function* () {
|
||||||
|
const sessions = yield* Session.Service
|
||||||
|
const snap = yield* Snapshot.Service
|
||||||
|
const storage = yield* Storage.Service
|
||||||
|
const bus = yield* Bus.Service
|
||||||
|
|
||||||
|
const revert = Effect.fn("SessionRevert.revert")(function* (input: RevertInput) {
|
||||||
|
yield* Effect.promise(() => SessionPrompt.assertNotBusy(input.sessionID))
|
||||||
|
const all = yield* sessions.messages({ sessionID: input.sessionID })
|
||||||
|
let lastUser: MessageV2.User | undefined
|
||||||
|
const session = yield* sessions.get(input.sessionID)
|
||||||
|
|
||||||
|
let rev: Session.Info["revert"]
|
||||||
const patches: Snapshot.Patch[] = []
|
const patches: Snapshot.Patch[] = []
|
||||||
for (const msg of all) {
|
for (const msg of all) {
|
||||||
if (msg.info.role === "user") lastUser = msg.info
|
if (msg.info.role === "user") lastUser = msg.info
|
||||||
const remaining = []
|
const remaining = []
|
||||||
for (const part of msg.parts) {
|
for (const part of msg.parts) {
|
||||||
if (revert) {
|
if (rev) {
|
||||||
if (part.type === "patch") {
|
if (part.type === "patch") patches.push(part)
|
||||||
patches.push(part)
|
|
||||||
}
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!revert) {
|
if (!rev) {
|
||||||
if ((msg.info.id === input.messageID && !input.partID) || part.id === input.partID) {
|
if ((msg.info.id === input.messageID && !input.partID) || part.id === input.partID) {
|
||||||
// if no useful parts left in message, same as reverting whole message
|
|
||||||
const partID = remaining.some((item) => ["text", "tool"].includes(item.type)) ? input.partID : undefined
|
const partID = remaining.some((item) => ["text", "tool"].includes(item.type)) ? input.partID : undefined
|
||||||
revert = {
|
rev = {
|
||||||
messageID: !partID && lastUser ? lastUser.id : msg.info.id,
|
messageID: !partID && lastUser ? lastUser.id : msg.info.id,
|
||||||
partID,
|
partID,
|
||||||
}
|
}
|
||||||
|
|
@ -53,51 +68,46 @@ export namespace SessionRevert {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (revert) {
|
if (!rev) return session
|
||||||
const session = await Session.get(input.sessionID)
|
|
||||||
revert.snapshot = session.revert?.snapshot ?? (await Snapshot.track())
|
rev.snapshot = session.revert?.snapshot ?? (yield* snap.track())
|
||||||
await Snapshot.revert(patches)
|
yield* snap.revert(patches)
|
||||||
if (revert.snapshot) revert.diff = await Snapshot.diff(revert.snapshot)
|
if (rev.snapshot) rev.diff = yield* snap.diff(rev.snapshot as string)
|
||||||
const rangeMessages = all.filter((msg) => msg.info.id >= revert!.messageID)
|
const range = all.filter((msg) => msg.info.id >= rev!.messageID)
|
||||||
const diffs = await SessionSummary.computeDiff({ messages: rangeMessages })
|
const diffs = yield* Effect.promise(() => SessionSummary.computeDiff({ messages: range }))
|
||||||
await Storage.write(["session_diff", input.sessionID], diffs)
|
yield* storage.write(["session_diff", input.sessionID], diffs).pipe(Effect.ignore)
|
||||||
Bus.publish(Session.Event.Diff, {
|
yield* bus.publish(Session.Event.Diff, { sessionID: input.sessionID, diff: diffs })
|
||||||
|
yield* sessions.setRevert({
|
||||||
sessionID: input.sessionID,
|
sessionID: input.sessionID,
|
||||||
diff: diffs,
|
revert: rev,
|
||||||
})
|
|
||||||
return Session.setRevert({
|
|
||||||
sessionID: input.sessionID,
|
|
||||||
revert,
|
|
||||||
summary: {
|
summary: {
|
||||||
additions: diffs.reduce((sum, x) => sum + x.additions, 0),
|
additions: diffs.reduce((sum, x) => sum + x.additions, 0),
|
||||||
deletions: diffs.reduce((sum, x) => sum + x.deletions, 0),
|
deletions: diffs.reduce((sum, x) => sum + x.deletions, 0),
|
||||||
files: diffs.length,
|
files: diffs.length,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
return yield* sessions.get(input.sessionID)
|
||||||
return session
|
})
|
||||||
}
|
|
||||||
|
|
||||||
export async function unrevert(input: { sessionID: SessionID }) {
|
const unrevert = Effect.fn("SessionRevert.unrevert")(function* (input: { sessionID: SessionID }) {
|
||||||
log.info("unreverting", input)
|
log.info("unreverting", input)
|
||||||
await SessionPrompt.assertNotBusy(input.sessionID)
|
yield* Effect.promise(() => SessionPrompt.assertNotBusy(input.sessionID))
|
||||||
const session = await Session.get(input.sessionID)
|
const session = yield* sessions.get(input.sessionID)
|
||||||
if (!session.revert) return session
|
if (!session.revert) return session
|
||||||
if (session.revert.snapshot) await Snapshot.restore(session.revert.snapshot)
|
if (session.revert.snapshot) yield* snap.restore(session.revert!.snapshot!)
|
||||||
return Session.clearRevert(input.sessionID)
|
yield* sessions.clearRevert(input.sessionID)
|
||||||
}
|
return yield* sessions.get(input.sessionID)
|
||||||
|
})
|
||||||
|
|
||||||
export async function cleanup(session: Session.Info) {
|
const cleanup = Effect.fn("SessionRevert.cleanup")(function* (session: Session.Info) {
|
||||||
if (!session.revert) return
|
if (!session.revert) return
|
||||||
const sessionID = session.id
|
const sessionID = session.id
|
||||||
const msgs = await Session.messages({ sessionID })
|
const msgs = yield* sessions.messages({ sessionID })
|
||||||
const messageID = session.revert.messageID
|
const messageID = session.revert.messageID
|
||||||
const remove = [] as MessageV2.WithParts[]
|
const remove = [] as MessageV2.WithParts[]
|
||||||
let target: MessageV2.WithParts | undefined
|
let target: MessageV2.WithParts | undefined
|
||||||
for (const msg of msgs) {
|
for (const msg of msgs) {
|
||||||
if (msg.info.id < messageID) {
|
if (msg.info.id < messageID) continue
|
||||||
continue
|
|
||||||
}
|
|
||||||
if (msg.info.id > messageID) {
|
if (msg.info.id > messageID) {
|
||||||
remove.push(msg)
|
remove.push(msg)
|
||||||
continue
|
continue
|
||||||
|
|
@ -110,26 +120,54 @@ export namespace SessionRevert {
|
||||||
}
|
}
|
||||||
for (const msg of remove) {
|
for (const msg of remove) {
|
||||||
SyncEvent.run(MessageV2.Event.Removed, {
|
SyncEvent.run(MessageV2.Event.Removed, {
|
||||||
sessionID: sessionID,
|
sessionID,
|
||||||
messageID: msg.info.id,
|
messageID: msg.info.id,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
if (session.revert.partID && target) {
|
if (session.revert.partID && target) {
|
||||||
const partID = session.revert.partID
|
const partID = session.revert.partID
|
||||||
const removeStart = target.parts.findIndex((part) => part.id === partID)
|
const idx = target.parts.findIndex((part) => part.id === partID)
|
||||||
if (removeStart >= 0) {
|
if (idx >= 0) {
|
||||||
const preserveParts = target.parts.slice(0, removeStart)
|
const removeParts = target.parts.slice(idx)
|
||||||
const removeParts = target.parts.slice(removeStart)
|
target.parts = target.parts.slice(0, idx)
|
||||||
target.parts = preserveParts
|
|
||||||
for (const part of removeParts) {
|
for (const part of removeParts) {
|
||||||
SyncEvent.run(MessageV2.Event.PartRemoved, {
|
SyncEvent.run(MessageV2.Event.PartRemoved, {
|
||||||
sessionID: sessionID,
|
sessionID,
|
||||||
messageID: target.info.id,
|
messageID: target.info.id,
|
||||||
partID: part.id,
|
partID: part.id,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
await Session.clearRevert(sessionID)
|
yield* sessions.clearRevert(sessionID)
|
||||||
|
})
|
||||||
|
|
||||||
|
return Service.of({ revert, unrevert, cleanup })
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
|
||||||
|
export const defaultLayer = Layer.unwrap(
|
||||||
|
Effect.sync(() =>
|
||||||
|
layer.pipe(
|
||||||
|
Layer.provide(Session.defaultLayer),
|
||||||
|
Layer.provide(Snapshot.defaultLayer),
|
||||||
|
Layer.provide(Storage.defaultLayer),
|
||||||
|
Layer.provide(Bus.layer),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
const { runPromise } = makeRuntime(Service, defaultLayer)
|
||||||
|
|
||||||
|
export async function revert(input: RevertInput) {
|
||||||
|
return runPromise((svc) => svc.revert(input))
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function unrevert(input: { sessionID: SessionID }) {
|
||||||
|
return runPromise((svc) => svc.unrevert(input))
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function cleanup(session: Session.Info) {
|
||||||
|
return runPromise((svc) => svc.cleanup(session))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -10,9 +10,59 @@ import { Instance } from "../../src/project/instance"
|
||||||
import { MessageID, PartID } from "../../src/session/schema"
|
import { MessageID, PartID } from "../../src/session/schema"
|
||||||
import { tmpdir } from "../fixture/fixture"
|
import { tmpdir } from "../fixture/fixture"
|
||||||
|
|
||||||
const projectRoot = path.join(__dirname, "../..")
|
|
||||||
Log.init({ print: false })
|
Log.init({ print: false })
|
||||||
|
|
||||||
|
function user(sessionID: string, agent = "default") {
|
||||||
|
return Session.updateMessage({
|
||||||
|
id: MessageID.ascending(),
|
||||||
|
role: "user" as const,
|
||||||
|
sessionID: sessionID as any,
|
||||||
|
agent,
|
||||||
|
model: { providerID: ProviderID.make("openai"), modelID: ModelID.make("gpt-4") },
|
||||||
|
time: { created: Date.now() },
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
function assistant(sessionID: string, parentID: string, dir: string) {
|
||||||
|
return Session.updateMessage({
|
||||||
|
id: MessageID.ascending(),
|
||||||
|
role: "assistant" as const,
|
||||||
|
sessionID: sessionID as any,
|
||||||
|
mode: "default",
|
||||||
|
agent: "default",
|
||||||
|
path: { cwd: dir, root: dir },
|
||||||
|
cost: 0,
|
||||||
|
tokens: { output: 0, input: 0, reasoning: 0, cache: { read: 0, write: 0 } },
|
||||||
|
modelID: ModelID.make("gpt-4"),
|
||||||
|
providerID: ProviderID.make("openai"),
|
||||||
|
parentID: parentID as any,
|
||||||
|
time: { created: Date.now() },
|
||||||
|
finish: "end_turn",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
function text(sessionID: string, messageID: string, content: string) {
|
||||||
|
return Session.updatePart({
|
||||||
|
id: PartID.ascending(),
|
||||||
|
messageID: messageID as any,
|
||||||
|
sessionID: sessionID as any,
|
||||||
|
type: "text" as const,
|
||||||
|
text: content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
function tool(sessionID: string, messageID: string) {
|
||||||
|
return Session.updatePart({
|
||||||
|
id: PartID.ascending(),
|
||||||
|
messageID: messageID as any,
|
||||||
|
sessionID: sessionID as any,
|
||||||
|
type: "tool" as const,
|
||||||
|
tool: "bash",
|
||||||
|
callID: "call-1",
|
||||||
|
state: { status: "completed" as const, input: {}, output: "done", title: "", metadata: {}, time: { start: 0, end: 1 } },
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
describe("revert + compact workflow", () => {
|
describe("revert + compact workflow", () => {
|
||||||
test("should properly handle compact command after revert", async () => {
|
test("should properly handle compact command after revert", async () => {
|
||||||
await using tmp = await tmpdir({ git: true })
|
await using tmp = await tmpdir({ git: true })
|
||||||
|
|
@ -283,4 +333,98 @@ describe("revert + compact workflow", () => {
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
test("cleanup with partID removes parts from the revert point onward", async () => {
|
||||||
|
await using tmp = await tmpdir({ git: true })
|
||||||
|
await Instance.provide({
|
||||||
|
directory: tmp.path,
|
||||||
|
fn: async () => {
|
||||||
|
const session = await Session.create({})
|
||||||
|
const sid = session.id
|
||||||
|
|
||||||
|
const u1 = await user(sid)
|
||||||
|
const p1 = await text(sid, u1.id, "first part")
|
||||||
|
const p2 = await tool(sid, u1.id)
|
||||||
|
const p3 = await text(sid, u1.id, "third part")
|
||||||
|
|
||||||
|
// Set revert state pointing at a specific part
|
||||||
|
await Session.setRevert({
|
||||||
|
sessionID: sid,
|
||||||
|
revert: { messageID: u1.id, partID: p2.id },
|
||||||
|
summary: { additions: 0, deletions: 0, files: 0 },
|
||||||
|
})
|
||||||
|
|
||||||
|
const info = await Session.get(sid)
|
||||||
|
await SessionRevert.cleanup(info)
|
||||||
|
|
||||||
|
const msgs = await Session.messages({ sessionID: sid })
|
||||||
|
expect(msgs.length).toBe(1)
|
||||||
|
// Only the first part should remain (before the revert partID)
|
||||||
|
expect(msgs[0].parts.length).toBe(1)
|
||||||
|
expect(msgs[0].parts[0].id).toBe(p1.id)
|
||||||
|
|
||||||
|
const cleared = await Session.get(sid)
|
||||||
|
expect(cleared.revert).toBeUndefined()
|
||||||
|
},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
test("cleanup removes messages after revert point but keeps earlier ones", async () => {
|
||||||
|
await using tmp = await tmpdir({ git: true })
|
||||||
|
await Instance.provide({
|
||||||
|
directory: tmp.path,
|
||||||
|
fn: async () => {
|
||||||
|
const session = await Session.create({})
|
||||||
|
const sid = session.id
|
||||||
|
|
||||||
|
const u1 = await user(sid)
|
||||||
|
await text(sid, u1.id, "hello")
|
||||||
|
const a1 = await assistant(sid, u1.id, tmp.path)
|
||||||
|
await text(sid, a1.id, "hi back")
|
||||||
|
|
||||||
|
const u2 = await user(sid)
|
||||||
|
await text(sid, u2.id, "second question")
|
||||||
|
const a2 = await assistant(sid, u2.id, tmp.path)
|
||||||
|
await text(sid, a2.id, "second answer")
|
||||||
|
|
||||||
|
// Revert from u2 onward
|
||||||
|
await Session.setRevert({
|
||||||
|
sessionID: sid,
|
||||||
|
revert: { messageID: u2.id },
|
||||||
|
summary: { additions: 0, deletions: 0, files: 0 },
|
||||||
|
})
|
||||||
|
|
||||||
|
const info = await Session.get(sid)
|
||||||
|
await SessionRevert.cleanup(info)
|
||||||
|
|
||||||
|
const msgs = await Session.messages({ sessionID: sid })
|
||||||
|
const ids = msgs.map((m) => m.info.id)
|
||||||
|
expect(ids).toContain(u1.id)
|
||||||
|
expect(ids).toContain(a1.id)
|
||||||
|
expect(ids).not.toContain(u2.id)
|
||||||
|
expect(ids).not.toContain(a2.id)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
test("cleanup is a no-op when session has no revert state", async () => {
|
||||||
|
await using tmp = await tmpdir({ git: true })
|
||||||
|
await Instance.provide({
|
||||||
|
directory: tmp.path,
|
||||||
|
fn: async () => {
|
||||||
|
const session = await Session.create({})
|
||||||
|
const sid = session.id
|
||||||
|
|
||||||
|
const u1 = await user(sid)
|
||||||
|
await text(sid, u1.id, "hello")
|
||||||
|
|
||||||
|
const info = await Session.get(sid)
|
||||||
|
expect(info.revert).toBeUndefined()
|
||||||
|
await SessionRevert.cleanup(info)
|
||||||
|
|
||||||
|
const msgs = await Session.messages({ sessionID: sid })
|
||||||
|
expect(msgs.length).toBe(1)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue