diff --git a/packages/opencode/src/session/index.ts b/packages/opencode/src/session/index.ts index 5ad143c93e..443c47a534 100644 --- a/packages/opencode/src/session/index.ts +++ b/packages/opencode/src/session/index.ts @@ -256,9 +256,17 @@ export namespace Session { export const touch = fn(Identifier.schema("session"), async (sessionID) => { const now = Date.now() - Database.use((db) => db.update(SessionTable).set({ time_updated: now }).where(eq(SessionTable.id, sessionID)).run()) - const info = await get(sessionID) - Bus.publish(Event.Updated, { info }) + Database.use((db) => { + const row = db + .update(SessionTable) + .set({ time_updated: now }) + .where(eq(SessionTable.id, sessionID)) + .returning() + .get() + if (!row) throw new NotFoundError({ message: `Session not found: ${sessionID}` }) + const info = fromRow(row) + Database.effect(() => Bus.publish(Event.Updated, { info })) + }) }) export async function createNext(input: { @@ -283,9 +291,13 @@ export namespace Session { }, } log.info("created", result) - Database.use((db) => db.insert(SessionTable).values(toRow(result)).run()) - Bus.publish(Event.Created, { - info: result, + Database.use((db) => { + db.insert(SessionTable).values(toRow(result)).run() + Database.effect(() => + Bus.publish(Event.Created, { + info: result, + }), + ) }) const cfg = await Config.get() if (!result.parentID && (Flag.OPENCODE_AUTO_SHARE || cfg.share === "auto")) @@ -323,9 +335,12 @@ export namespace Session { } const { ShareNext } = await import("@/share/share-next") const share = await ShareNext.create(id) - Database.use((db) => db.update(SessionTable).set({ share_url: share.url }).where(eq(SessionTable.id, id)).run()) - const info = await get(id) - Bus.publish(Event.Updated, { info }) + Database.use((db) => { + const row = db.update(SessionTable).set({ share_url: share.url }).where(eq(SessionTable.id, id)).returning().get() + if (!row) throw new NotFoundError({ message: `Session not found: ${id}` }) + const info = fromRow(row) + Database.effect(() => Bus.publish(Event.Updated, { info })) + }) return share }) @@ -333,9 +348,12 @@ export namespace Session { // Use ShareNext to remove the share (same as share function uses ShareNext to create) const { ShareNext } = await import("@/share/share-next") await ShareNext.remove(id) - Database.use((db) => db.update(SessionTable).set({ share_url: null }).where(eq(SessionTable.id, id)).run()) - const info = await get(id) - Bus.publish(Event.Updated, { info }) + Database.use((db) => { + const row = db.update(SessionTable).set({ share_url: null }).where(eq(SessionTable.id, id)).returning().get() + if (!row) throw new NotFoundError({ message: `Session not found: ${id}` }) + const info = fromRow(row) + Database.effect(() => Bus.publish(Event.Updated, { info })) + }) }) export const setTitle = fn( @@ -344,12 +362,18 @@ export namespace Session { title: z.string(), }), async (input) => { - Database.use((db) => - db.update(SessionTable).set({ title: input.title }).where(eq(SessionTable.id, input.sessionID)).run(), - ) - const info = await get(input.sessionID) - Bus.publish(Event.Updated, { info }) - return info + return Database.use((db) => { + const row = db + .update(SessionTable) + .set({ title: input.title }) + .where(eq(SessionTable.id, input.sessionID)) + .returning() + .get() + if (!row) throw new NotFoundError({ message: `Session not found: ${input.sessionID}` }) + const info = fromRow(row) + Database.effect(() => Bus.publish(Event.Updated, { info })) + return info + }) }, ) @@ -359,12 +383,18 @@ export namespace Session { time: z.number().optional(), }), async (input) => { - Database.use((db) => - db.update(SessionTable).set({ time_archived: input.time }).where(eq(SessionTable.id, input.sessionID)).run(), - ) - const info = await get(input.sessionID) - Bus.publish(Event.Updated, { info }) - return info + return Database.use((db) => { + const row = db + .update(SessionTable) + .set({ time_archived: input.time }) + .where(eq(SessionTable.id, input.sessionID)) + .returning() + .get() + if (!row) throw new NotFoundError({ message: `Session not found: ${input.sessionID}` }) + const info = fromRow(row) + Database.effect(() => Bus.publish(Event.Updated, { info })) + return info + }) }, ) @@ -374,16 +404,18 @@ export namespace Session { permission: PermissionNext.Ruleset, }), async (input) => { - Database.use((db) => - db + return Database.use((db) => { + const row = db .update(SessionTable) .set({ permission: input.permission, time_updated: Date.now() }) .where(eq(SessionTable.id, input.sessionID)) - .run(), - ) - const info = await get(input.sessionID) - Bus.publish(Event.Updated, { info }) - return info + .returning() + .get() + if (!row) throw new NotFoundError({ message: `Session not found: ${input.sessionID}` }) + const info = fromRow(row) + Database.effect(() => Bus.publish(Event.Updated, { info })) + return info + }) }, ) @@ -394,8 +426,8 @@ export namespace Session { summary: Info.shape.summary, }), async (input) => { - Database.use((db) => - db + return Database.use((db) => { + const row = db .update(SessionTable) .set({ revert_message_id: input.revert?.messageID ?? null, @@ -408,17 +440,19 @@ export namespace Session { time_updated: Date.now(), }) .where(eq(SessionTable.id, input.sessionID)) - .run(), - ) - const info = await get(input.sessionID) - Bus.publish(Event.Updated, { info }) - return info + .returning() + .get() + if (!row) throw new NotFoundError({ message: `Session not found: ${input.sessionID}` }) + const info = fromRow(row) + Database.effect(() => Bus.publish(Event.Updated, { info })) + return info + }) }, ) export const clearRevert = fn(Identifier.schema("session"), async (sessionID) => { - Database.use((db) => - db + return Database.use((db) => { + const row = db .update(SessionTable) .set({ revert_message_id: null, @@ -428,11 +462,13 @@ export namespace Session { time_updated: Date.now(), }) .where(eq(SessionTable.id, sessionID)) - .run(), - ) - const info = await get(sessionID) - Bus.publish(Event.Updated, { info }) - return info + .returning() + .get() + if (!row) throw new NotFoundError({ message: `Session not found: ${sessionID}` }) + const info = fromRow(row) + Database.effect(() => Bus.publish(Event.Updated, { info })) + return info + }) }) export const setSummary = fn( @@ -441,8 +477,8 @@ export namespace Session { summary: Info.shape.summary, }), async (input) => { - Database.use((db) => - db + return Database.use((db) => { + const row = db .update(SessionTable) .set({ summary_additions: input.summary?.additions, @@ -451,11 +487,13 @@ export namespace Session { time_updated: Date.now(), }) .where(eq(SessionTable.id, input.sessionID)) - .run(), - ) - const info = await get(input.sessionID) - Bus.publish(Event.Updated, { info }) - return info + .returning() + .get() + if (!row) throw new NotFoundError({ message: `Session not found: ${input.sessionID}` }) + const info = fromRow(row) + Database.effect(() => Bus.publish(Event.Updated, { info })) + return info + }) }, ) @@ -506,9 +544,13 @@ export namespace Session { } await unshare(sessionID).catch(() => {}) // CASCADE delete handles messages and parts automatically - Database.use((db) => db.delete(SessionTable).where(eq(SessionTable.id, sessionID)).run()) - Bus.publish(Event.Deleted, { - info: session, + Database.use((db) => { + db.delete(SessionTable).where(eq(SessionTable.id, sessionID)).run() + Database.effect(() => + Bus.publish(Event.Deleted, { + info: session, + }), + ) }) } catch (e) { log.error(e) @@ -517,9 +559,8 @@ export namespace Session { export const updateMessage = fn(MessageV2.Info, async (msg) => { const created_at = msg.role === "user" ? msg.time.created : msg.time.created - Database.use((db) => - db - .insert(MessageTable) + Database.use((db) => { + db.insert(MessageTable) .values({ id: msg.id, session_id: msg.sessionID, @@ -527,10 +568,12 @@ export namespace Session { data: msg, }) .onConflictDoUpdate({ target: MessageTable.id, set: { data: msg } }) - .run(), - ) - Bus.publish(MessageV2.Event.Updated, { - info: msg, + .run() + Database.effect(() => + Bus.publish(MessageV2.Event.Updated, { + info: msg, + }), + ) }) return msg }) @@ -542,10 +585,14 @@ export namespace Session { }), async (input) => { // CASCADE delete handles parts automatically - Database.use((db) => db.delete(MessageTable).where(eq(MessageTable.id, input.messageID)).run()) - Bus.publish(MessageV2.Event.Removed, { - sessionID: input.sessionID, - messageID: input.messageID, + Database.use((db) => { + db.delete(MessageTable).where(eq(MessageTable.id, input.messageID)).run() + Database.effect(() => + Bus.publish(MessageV2.Event.Removed, { + sessionID: input.sessionID, + messageID: input.messageID, + }), + ) }) return input.messageID }, @@ -558,11 +605,15 @@ export namespace Session { partID: Identifier.schema("part"), }), async (input) => { - Database.use((db) => db.delete(PartTable).where(eq(PartTable.id, input.partID)).run()) - Bus.publish(MessageV2.Event.PartRemoved, { - sessionID: input.sessionID, - messageID: input.messageID, - partID: input.partID, + Database.use((db) => { + db.delete(PartTable).where(eq(PartTable.id, input.partID)).run() + Database.effect(() => + Bus.publish(MessageV2.Event.PartRemoved, { + sessionID: input.sessionID, + messageID: input.messageID, + partID: input.partID, + }), + ) }) return input.partID }, @@ -583,9 +634,8 @@ export namespace Session { export const updatePart = fn(UpdatePartInput, async (input) => { const part = "delta" in input ? input.part : input const delta = "delta" in input ? input.delta : undefined - Database.use((db) => - db - .insert(PartTable) + Database.use((db) => { + db.insert(PartTable) .values({ id: part.id, message_id: part.messageID, @@ -593,11 +643,13 @@ export namespace Session { data: part, }) .onConflictDoUpdate({ target: PartTable.id, set: { data: part } }) - .run(), - ) - Bus.publish(MessageV2.Event.PartUpdated, { - part, - delta, + .run() + Database.effect(() => + Bus.publish(MessageV2.Event.PartUpdated, { + part, + delta, + }), + ) }) return part }) diff --git a/packages/opencode/src/storage/db.ts b/packages/opencode/src/storage/db.ts index 1f6cc30807..f49028a18b 100644 --- a/packages/opencode/src/storage/db.ts +++ b/packages/opencode/src/storage/db.ts @@ -108,7 +108,7 @@ export namespace Database { } } - export function effect(fn: () => void | Promise) { + export function effect(fn: () => any | Promise) { try { ctx.use().effects.push(fn) } catch {