diff --git a/packages/opencode/package.json b/packages/opencode/package.json index 6c3f3a5cca..a3b00771d3 100644 --- a/packages/opencode/package.json +++ b/packages/opencode/package.json @@ -97,8 +97,8 @@ "chokidar": "4.0.3", "clipboardy": "4.0.0", "decimal.js": "10.5.0", - "drizzle-orm": "0.44.2", "diff": "catalog:", + "drizzle-orm": "0.44.2", "fuzzysort": "3.1.0", "gray-matter": "4.0.3", "hono": "catalog:", diff --git a/packages/opencode/src/cli/cmd/database.ts b/packages/opencode/src/cli/cmd/database.ts index 5b3c1485f3..949f128ba6 100644 --- a/packages/opencode/src/cli/cmd/database.ts +++ b/packages/opencode/src/cli/cmd/database.ts @@ -2,7 +2,7 @@ import type { Argv } from "yargs" import { cmd } from "./cmd" import { bootstrap } from "../bootstrap" import { UI } from "../ui" -import { db } from "../../storage/db" +import { Database } from "../../storage/db" import { ProjectTable } from "../../project/project.sql" import { Project } from "../../project/project" import { @@ -56,7 +56,7 @@ const ExportCommand = cmd({ // Export projects const projectDir = path.join(outDir, "project") await fs.mkdir(projectDir, { recursive: true }) - for (const row of db().select().from(ProjectTable).all()) { + for (const row of Database.use((db) => db.select().from(ProjectTable).all())) { const project = Project.fromRow(row) await Bun.write(path.join(projectDir, `${row.id}.json`), JSON.stringify(project, null, 2)) stats.projects++ @@ -64,7 +64,7 @@ const ExportCommand = cmd({ // Export sessions (organized by projectID) const sessionDir = path.join(outDir, "session") - for (const row of db().select().from(SessionTable).all()) { + for (const row of Database.use((db) => db.select().from(SessionTable).all())) { const dir = path.join(sessionDir, row.projectID) await fs.mkdir(dir, { recursive: true }) await Bun.write(path.join(dir, `${row.id}.json`), JSON.stringify(Session.fromRow(row), null, 2)) @@ -73,7 +73,7 @@ const ExportCommand = cmd({ // Export messages (organized by sessionID) const messageDir = path.join(outDir, "message") - for (const row of db().select().from(MessageTable).all()) { + for (const row of Database.use((db) => db.select().from(MessageTable).all())) { const dir = path.join(messageDir, row.sessionID) await fs.mkdir(dir, { recursive: true }) await Bun.write(path.join(dir, `${row.id}.json`), JSON.stringify(row.data, null, 2)) @@ -82,7 +82,7 @@ const ExportCommand = cmd({ // Export parts (organized by messageID) const partDir = path.join(outDir, "part") - for (const row of db().select().from(PartTable).all()) { + for (const row of Database.use((db) => db.select().from(PartTable).all())) { const dir = path.join(partDir, row.messageID) await fs.mkdir(dir, { recursive: true }) await Bun.write(path.join(dir, `${row.id}.json`), JSON.stringify(row.data, null, 2)) @@ -92,7 +92,7 @@ const ExportCommand = cmd({ // Export session diffs const diffDir = path.join(outDir, "session_diff") await fs.mkdir(diffDir, { recursive: true }) - for (const row of db().select().from(SessionDiffTable).all()) { + for (const row of Database.use((db) => db.select().from(SessionDiffTable).all())) { await Bun.write(path.join(diffDir, `${row.sessionID}.json`), JSON.stringify(row.data, null, 2)) stats.diffs++ } @@ -100,7 +100,7 @@ const ExportCommand = cmd({ // Export todos const todoDir = path.join(outDir, "todo") await fs.mkdir(todoDir, { recursive: true }) - for (const row of db().select().from(TodoTable).all()) { + for (const row of Database.use((db) => db.select().from(TodoTable).all())) { await Bun.write(path.join(todoDir, `${row.sessionID}.json`), JSON.stringify(row.data, null, 2)) stats.todos++ } @@ -108,7 +108,7 @@ const ExportCommand = cmd({ // Export permissions const permDir = path.join(outDir, "permission") await fs.mkdir(permDir, { recursive: true }) - for (const row of db().select().from(PermissionTable).all()) { + for (const row of Database.use((db) => db.select().from(PermissionTable).all())) { await Bun.write(path.join(permDir, `${row.projectID}.json`), JSON.stringify(row.data, null, 2)) stats.permissions++ } @@ -116,7 +116,7 @@ const ExportCommand = cmd({ // Export session shares const sessionShareDir = path.join(outDir, "session_share") await fs.mkdir(sessionShareDir, { recursive: true }) - for (const row of db().select().from(SessionShareTable).all()) { + for (const row of Database.use((db) => db.select().from(SessionShareTable).all())) { await Bun.write(path.join(sessionShareDir, `${row.sessionID}.json`), JSON.stringify(row.data, null, 2)) stats.sessionShares++ } @@ -124,7 +124,7 @@ const ExportCommand = cmd({ // Export shares const shareDir = path.join(outDir, "share") await fs.mkdir(shareDir, { recursive: true }) - for (const row of db().select().from(ShareTable).all()) { + for (const row of Database.use((db) => db.select().from(ShareTable).all())) { await Bun.write(path.join(shareDir, `${row.sessionID}.json`), JSON.stringify(row.data, null, 2)) stats.shares++ } diff --git a/packages/opencode/src/cli/cmd/import.ts b/packages/opencode/src/cli/cmd/import.ts index c78776b9d7..7f97b70c31 100644 --- a/packages/opencode/src/cli/cmd/import.ts +++ b/packages/opencode/src/cli/cmd/import.ts @@ -2,7 +2,7 @@ import type { Argv } from "yargs" import { Session } from "../../session" import { cmd } from "./cmd" import { bootstrap } from "../bootstrap" -import { db } from "../../storage/db" +import { Database } from "../../storage/db" import { SessionTable, MessageTable, PartTable } from "../../session/session.sql" import { Instance } from "../../project/instance" import { EOL } from "os" @@ -82,31 +82,35 @@ export const ImportCommand = cmd({ return } - db().insert(SessionTable).values(Session.toRow(exportData.info)).onConflictDoNothing().run() + Database.use((db) => db.insert(SessionTable).values(Session.toRow(exportData.info)).onConflictDoNothing().run()) for (const msg of exportData.messages) { - db() - .insert(MessageTable) - .values({ - id: msg.info.id, - sessionID: exportData.info.id, - createdAt: msg.info.time?.created ?? Date.now(), - data: msg.info, - }) - .onConflictDoNothing() - .run() - - for (const part of msg.parts) { - db() - .insert(PartTable) + Database.use((db) => + db + .insert(MessageTable) .values({ - id: part.id, - messageID: msg.info.id, + id: msg.info.id, sessionID: exportData.info.id, - data: part, + createdAt: msg.info.time?.created ?? Date.now(), + data: msg.info, }) .onConflictDoNothing() - .run() + .run(), + ) + + for (const part of msg.parts) { + Database.use((db) => + db + .insert(PartTable) + .values({ + id: part.id, + messageID: msg.info.id, + sessionID: exportData.info.id, + data: part, + }) + .onConflictDoNothing() + .run(), + ) } } diff --git a/packages/opencode/src/cli/cmd/stats.ts b/packages/opencode/src/cli/cmd/stats.ts index 21ee97fa82..1f7263b325 100644 --- a/packages/opencode/src/cli/cmd/stats.ts +++ b/packages/opencode/src/cli/cmd/stats.ts @@ -2,7 +2,7 @@ import type { Argv } from "yargs" import { cmd } from "./cmd" import { Session } from "../../session" import { bootstrap } from "../bootstrap" -import { db } from "../../storage/db" +import { Database } from "../../storage/db" import { SessionTable } from "../../session/session.sql" import { Project } from "../../project/project" import { Instance } from "../../project/instance" @@ -84,7 +84,7 @@ async function getCurrentProject(): Promise { } async function getAllSessions(): Promise { - const rows = db().select().from(SessionTable).all() + const rows = Database.use((db) => db.select().from(SessionTable).all()) return rows.map((row) => Session.fromRow(row)) } diff --git a/packages/opencode/src/permission/next.ts b/packages/opencode/src/permission/next.ts index b625bb57fc..98840867b2 100644 --- a/packages/opencode/src/permission/next.ts +++ b/packages/opencode/src/permission/next.ts @@ -3,9 +3,8 @@ import { BusEvent } from "@/bus/bus-event" import { Config } from "@/config/config" import { Identifier } from "@/id/id" import { Instance } from "@/project/instance" -import { db } from "@/storage/db" +import { Database, eq } from "@/storage/db" import { PermissionTable } from "@/session/session.sql" -import { eq } from "drizzle-orm" import { fn } from "@/util/fn" import { Log } from "@/util/log" import { Wildcard } from "@/util/wildcard" @@ -109,7 +108,9 @@ export namespace PermissionNext { const state = Instance.state(() => { const projectID = Instance.project.id - const row = db().select().from(PermissionTable).where(eq(PermissionTable.projectID, projectID)).get() + const row = Database.use((db) => + db.select().from(PermissionTable).where(eq(PermissionTable.projectID, projectID)).get(), + ) const stored = row?.data ?? ([] as Ruleset) const pending: Record< diff --git a/packages/opencode/src/project/project.ts b/packages/opencode/src/project/project.ts index 71c6a9bc7b..fc940b9588 100644 --- a/packages/opencode/src/project/project.ts +++ b/packages/opencode/src/project/project.ts @@ -3,10 +3,9 @@ import fs from "fs/promises" import { Filesystem } from "../util/filesystem" import path from "path" import { $ } from "bun" -import { db } from "../storage/db" +import { Database, eq } from "../storage/db" import { ProjectTable } from "./project.sql" import { SessionTable } from "../session/session.sql" -import { eq } from "drizzle-orm" import { Log } from "../util/log" import { Flag } from "@/flag/flag" import { work } from "../util/queue" @@ -199,7 +198,7 @@ export namespace Project { } }) - const row = db().select().from(ProjectTable).where(eq(ProjectTable.id, id)).get() + const row = Database.use((db) => db.select().from(ProjectTable).where(eq(ProjectTable.id, id)).get()) const existing = await iife(async () => { if (row) return fromRow(row) const fresh: Info = { @@ -253,7 +252,9 @@ export namespace Project { time_initialized: result.time.initialized, sandboxes: result.sandboxes, } - db().insert(ProjectTable).values(insert).onConflictDoUpdate({ target: ProjectTable.id, set: updateSet }).run() + Database.use((db) => + db.insert(ProjectTable).values(insert).onConflictDoUpdate({ target: ProjectTable.id, set: updateSet }).run(), + ) GlobalBus.emit("event", { payload: { type: Event.Updated.type, @@ -294,10 +295,12 @@ export namespace Project { } async function migrateFromGlobal(newProjectID: string, worktree: string) { - const globalRow = db().select().from(ProjectTable).where(eq(ProjectTable.id, "global")).get() + const globalRow = Database.use((db) => db.select().from(ProjectTable).where(eq(ProjectTable.id, "global")).get()) if (!globalRow) return - const globalSessions = db().select().from(SessionTable).where(eq(SessionTable.projectID, "global")).all() + const globalSessions = Database.use((db) => + db.select().from(SessionTable).where(eq(SessionTable.projectID, "global")).all(), + ) if (globalSessions.length === 0) return log.info("migrating sessions from global", { newProjectID, worktree, count: globalSessions.length }) @@ -307,32 +310,38 @@ export namespace Project { if (row.directory && row.directory !== worktree) return log.info("migrating session", { sessionID: row.id, from: "global", to: newProjectID }) - db().update(SessionTable).set({ projectID: newProjectID }).where(eq(SessionTable.id, row.id)).run() + Database.use((db) => + db.update(SessionTable).set({ projectID: newProjectID }).where(eq(SessionTable.id, row.id)).run(), + ) }).catch((error) => { log.error("failed to migrate sessions from global to project", { error, projectId: newProjectID }) }) } export function setInitialized(projectID: string) { - db() - .update(ProjectTable) - .set({ - time_initialized: Date.now(), - }) - .where(eq(ProjectTable.id, projectID)) - .run() + Database.use((db) => + db + .update(ProjectTable) + .set({ + time_initialized: Date.now(), + }) + .where(eq(ProjectTable.id, projectID)) + .run(), + ) } export function list() { - return db() - .select() - .from(ProjectTable) - .all() - .map((row) => fromRow(row)) + return Database.use((db) => + db + .select() + .from(ProjectTable) + .all() + .map((row) => fromRow(row)), + ) } export function get(projectID: string): Info | undefined { - const row = db().select().from(ProjectTable).where(eq(ProjectTable.id, projectID)).get() + const row = Database.use((db) => db.select().from(ProjectTable).where(eq(ProjectTable.id, projectID)).get()) if (!row) return undefined return fromRow(row) } @@ -345,17 +354,19 @@ export namespace Project { commands: Info.shape.commands.optional(), }), async (input) => { - const result = db() - .update(ProjectTable) - .set({ - name: input.name, - icon_url: input.icon?.url, - icon_color: input.icon?.color, - time_updated: Date.now(), - }) - .where(eq(ProjectTable.id, input.projectID)) - .returning() - .get() + const result = Database.use((db) => + db + .update(ProjectTable) + .set({ + name: input.name, + icon_url: input.icon?.url, + icon_color: input.icon?.color, + time_updated: Date.now(), + }) + .where(eq(ProjectTable.id, input.projectID)) + .returning() + .get(), + ) if (!result) throw new Error(`Project not found: ${input.projectID}`) const data = fromRow(result) GlobalBus.emit("event", { @@ -369,7 +380,7 @@ export namespace Project { ) export async function sandboxes(projectID: string) { - const row = db().select().from(ProjectTable).where(eq(ProjectTable.id, projectID)).get() + const row = Database.use((db) => db.select().from(ProjectTable).where(eq(ProjectTable.id, projectID)).get()) if (!row) return [] const data = fromRow(row) const valid: string[] = [] @@ -381,16 +392,18 @@ export namespace Project { } export async function addSandbox(projectID: string, directory: string) { - const row = db().select().from(ProjectTable).where(eq(ProjectTable.id, projectID)).get() + const row = Database.use((db) => db.select().from(ProjectTable).where(eq(ProjectTable.id, projectID)).get()) if (!row) throw new Error(`Project not found: ${projectID}`) const sandboxes = [...row.sandboxes] if (!sandboxes.includes(directory)) sandboxes.push(directory) - const result = db() - .update(ProjectTable) - .set({ sandboxes, time_updated: Date.now() }) - .where(eq(ProjectTable.id, projectID)) - .returning() - .get() + const result = Database.use((db) => + db + .update(ProjectTable) + .set({ sandboxes, time_updated: Date.now() }) + .where(eq(ProjectTable.id, projectID)) + .returning() + .get(), + ) if (!result) throw new Error(`Project not found: ${projectID}`) const data = fromRow(result) GlobalBus.emit("event", { @@ -403,15 +416,17 @@ export namespace Project { } export async function removeSandbox(projectID: string, directory: string) { - const row = db().select().from(ProjectTable).where(eq(ProjectTable.id, projectID)).get() + const row = Database.use((db) => db.select().from(ProjectTable).where(eq(ProjectTable.id, projectID)).get()) if (!row) throw new Error(`Project not found: ${projectID}`) const sandboxes = row.sandboxes.filter((s: string) => s !== directory) - const result = db() - .update(ProjectTable) - .set({ sandboxes, time_updated: Date.now() }) - .where(eq(ProjectTable.id, projectID)) - .returning() - .get() + const result = Database.use((db) => + db + .update(ProjectTable) + .set({ sandboxes, time_updated: Date.now() }) + .where(eq(ProjectTable.id, projectID)) + .returning() + .get(), + ) if (!result) throw new Error(`Project not found: ${projectID}`) const data = fromRow(result) GlobalBus.emit("event", { diff --git a/packages/opencode/src/session/index.ts b/packages/opencode/src/session/index.ts index 84b567b942..db06483eea 100644 --- a/packages/opencode/src/session/index.ts +++ b/packages/opencode/src/session/index.ts @@ -10,10 +10,9 @@ import { Flag } from "../flag/flag" import { Identifier } from "../id/id" import { Installation } from "../installation" -import { db, NotFoundError } from "../storage/db" +import { Database, NotFoundError, eq } from "../storage/db" import { SessionTable, MessageTable, PartTable, SessionDiffTable } from "./session.sql" import { ShareTable } from "../share/share.sql" -import { eq } from "drizzle-orm" import { Log } from "../util/log" import { MessageV2 } from "./message-v2" import { Instance } from "../project/instance" @@ -283,7 +282,7 @@ export namespace Session { }, } log.info("created", result) - db().insert(SessionTable).values(toRow(result)).run() + Database.use((db) => db.insert(SessionTable).values(toRow(result)).run()) Bus.publish(Event.Created, { info: result, }) @@ -312,13 +311,13 @@ export namespace Session { } export const get = fn(Identifier.schema("session"), async (id) => { - const row = db().select().from(SessionTable).where(eq(SessionTable.id, id)).get() + const row = Database.use((db) => db.select().from(SessionTable).where(eq(SessionTable.id, id)).get()) if (!row) throw new NotFoundError({ message: `Session not found: ${id}` }) return fromRow(row) }) export const getShare = fn(Identifier.schema("session"), async (id) => { - const row = db().select().from(ShareTable).where(eq(ShareTable.sessionID, id)).get() + const row = Database.use((db) => db.select().from(ShareTable).where(eq(ShareTable.sessionID, id)).get()) return row?.data }) @@ -355,14 +354,14 @@ export namespace Session { }) export function update(id: string, editor: (session: Info) => void, options?: { touch?: boolean }) { - const row = db().select().from(SessionTable).where(eq(SessionTable.id, id)).get() + const row = Database.use((db) => db.select().from(SessionTable).where(eq(SessionTable.id, id)).get()) if (!row) throw new Error(`Session not found: ${id}`) const data = fromRow(row) editor(data) if (options?.touch !== false) { data.time.updated = Date.now() } - db().update(SessionTable).set(toRow(data)).where(eq(SessionTable.id, id)).run() + Database.use((db) => db.update(SessionTable).set(toRow(data)).where(eq(SessionTable.id, id)).run()) Bus.publish(Event.Updated, { info: data, }) @@ -370,7 +369,9 @@ export namespace Session { } export const diff = fn(Identifier.schema("session"), async (sessionID) => { - const row = db().select().from(SessionDiffTable).where(eq(SessionDiffTable.sessionID, sessionID)).get() + const row = Database.use((db) => + db.select().from(SessionDiffTable).where(eq(SessionDiffTable.sessionID, sessionID)).get(), + ) return row?.data ?? [] }) @@ -392,14 +393,16 @@ export namespace Session { export function* list() { const project = Instance.project - const rows = db().select().from(SessionTable).where(eq(SessionTable.projectID, project.id)).all() + const rows = Database.use((db) => + db.select().from(SessionTable).where(eq(SessionTable.projectID, project.id)).all(), + ) for (const row of rows) { yield fromRow(row) } } export const children = fn(Identifier.schema("session"), async (parentID) => { - const rows = db().select().from(SessionTable).where(eq(SessionTable.parentID, parentID)).all() + const rows = Database.use((db) => db.select().from(SessionTable).where(eq(SessionTable.parentID, parentID)).all()) return rows.map((row) => fromRow(row)) }) @@ -412,7 +415,7 @@ export namespace Session { } await unshare(sessionID).catch(() => {}) // CASCADE delete handles messages and parts automatically - db().delete(SessionTable).where(eq(SessionTable.id, sessionID)).run() + Database.use((db) => db.delete(SessionTable).where(eq(SessionTable.id, sessionID)).run()) Bus.publish(Event.Deleted, { info: session, }) @@ -423,16 +426,18 @@ export namespace Session { export const updateMessage = fn(MessageV2.Info, async (msg) => { const createdAt = msg.role === "user" ? msg.time.created : msg.time.created - db() - .insert(MessageTable) - .values({ - id: msg.id, - sessionID: msg.sessionID, - createdAt, - data: msg, - }) - .onConflictDoUpdate({ target: MessageTable.id, set: { data: msg } }) - .run() + Database.use((db) => + db + .insert(MessageTable) + .values({ + id: msg.id, + sessionID: msg.sessionID, + createdAt, + data: msg, + }) + .onConflictDoUpdate({ target: MessageTable.id, set: { data: msg } }) + .run(), + ) Bus.publish(MessageV2.Event.Updated, { info: msg, }) @@ -446,7 +451,7 @@ export namespace Session { }), async (input) => { // CASCADE delete handles parts automatically - db().delete(MessageTable).where(eq(MessageTable.id, input.messageID)).run() + Database.use((db) => db.delete(MessageTable).where(eq(MessageTable.id, input.messageID)).run()) Bus.publish(MessageV2.Event.Removed, { sessionID: input.sessionID, messageID: input.messageID, @@ -462,7 +467,7 @@ export namespace Session { partID: Identifier.schema("part"), }), async (input) => { - db().delete(PartTable).where(eq(PartTable.id, input.partID)).run() + Database.use((db) => db.delete(PartTable).where(eq(PartTable.id, input.partID)).run()) Bus.publish(MessageV2.Event.PartRemoved, { sessionID: input.sessionID, messageID: input.messageID, @@ -487,16 +492,18 @@ export namespace Session { export const updatePart = fn(UpdatePartInput, async (input) => { const part = "delta" in input ? input.part : input const delta = "delta" in input ? input.delta : undefined - db() - .insert(PartTable) - .values({ - id: part.id, - messageID: part.messageID, - sessionID: part.sessionID, - data: part, - }) - .onConflictDoUpdate({ target: PartTable.id, set: { data: part } }) - .run() + Database.use((db) => + db + .insert(PartTable) + .values({ + id: part.id, + messageID: part.messageID, + sessionID: part.sessionID, + data: part, + }) + .onConflictDoUpdate({ target: PartTable.id, set: { data: part } }) + .run(), + ) Bus.publish(MessageV2.Event.PartUpdated, { part, delta, diff --git a/packages/opencode/src/session/message-v2.ts b/packages/opencode/src/session/message-v2.ts index 2dab09918a..e92252400c 100644 --- a/packages/opencode/src/session/message-v2.ts +++ b/packages/opencode/src/session/message-v2.ts @@ -6,9 +6,8 @@ import { Identifier } from "../id/id" import { LSP } from "../lsp" import { Snapshot } from "@/snapshot" import { fn } from "@/util/fn" -import { db } from "@/storage/db" +import { Database, eq, desc } from "@/storage/db" import { MessageTable, PartTable } from "./session.sql" -import { eq, desc } from "drizzle-orm" import { ProviderTransform } from "@/provider/transform" import { STATUS_CODES } from "http" import { iife } from "@/util/iife" @@ -609,12 +608,14 @@ export namespace MessageV2 { } export const stream = fn(Identifier.schema("session"), async function* (sessionID) { - const rows = db() - .select() - .from(MessageTable) - .where(eq(MessageTable.sessionID, sessionID)) - .orderBy(desc(MessageTable.createdAt)) - .all() + const rows = Database.use((db) => + db + .select() + .from(MessageTable) + .where(eq(MessageTable.sessionID, sessionID)) + .orderBy(desc(MessageTable.createdAt)) + .all(), + ) for (const row of rows) { yield { info: row.data, @@ -624,7 +625,7 @@ export namespace MessageV2 { }) export const parts = fn(Identifier.schema("message"), async (messageID) => { - const rows = db().select().from(PartTable).where(eq(PartTable.messageID, messageID)).all() + const rows = Database.use((db) => db.select().from(PartTable).where(eq(PartTable.messageID, messageID)).all()) const result = rows.map((row) => row.data) result.sort((a, b) => (a.id > b.id ? 1 : -1)) return result @@ -636,7 +637,7 @@ export namespace MessageV2 { messageID: Identifier.schema("message"), }), async (input) => { - const row = db().select().from(MessageTable).where(eq(MessageTable.id, input.messageID)).get() + const row = Database.use((db) => db.select().from(MessageTable).where(eq(MessageTable.id, input.messageID)).get()) if (!row) throw new Error(`Message not found: ${input.messageID}`) return { info: row.data, diff --git a/packages/opencode/src/session/revert.ts b/packages/opencode/src/session/revert.ts index fb6e0a5ec3..6a27fc7b9a 100644 --- a/packages/opencode/src/session/revert.ts +++ b/packages/opencode/src/session/revert.ts @@ -5,9 +5,8 @@ import { MessageV2 } from "./message-v2" import { Session } from "." import { Log } from "../util/log" import { splitWhen } from "remeda" -import { db } from "../storage/db" +import { Database, eq } from "../storage/db" import { SessionDiffTable, MessageTable, PartTable } from "./session.sql" -import { eq } from "drizzle-orm" import { Bus } from "../bus" import { SessionPrompt } from "./prompt" import { SessionSummary } from "./summary" @@ -62,11 +61,13 @@ export namespace SessionRevert { if (revert.snapshot) revert.diff = await Snapshot.diff(revert.snapshot) const rangeMessages = all.filter((msg) => msg.info.id >= revert!.messageID) const diffs = await SessionSummary.computeDiff({ messages: rangeMessages }) - db() - .insert(SessionDiffTable) - .values({ sessionID: input.sessionID, data: diffs }) - .onConflictDoUpdate({ target: SessionDiffTable.sessionID, set: { data: diffs } }) - .run() + Database.use((db) => + db + .insert(SessionDiffTable) + .values({ sessionID: input.sessionID, data: diffs }) + .onConflictDoUpdate({ target: SessionDiffTable.sessionID, set: { data: diffs } }) + .run(), + ) Bus.publish(Session.Event.Diff, { sessionID: input.sessionID, diff: diffs, @@ -103,7 +104,7 @@ export namespace SessionRevert { const [preserve, remove] = splitWhen(msgs, (x) => x.info.id === messageID) msgs = preserve for (const msg of remove) { - db().delete(MessageTable).where(eq(MessageTable.id, msg.info.id)).run() + Database.use((db) => db.delete(MessageTable).where(eq(MessageTable.id, msg.info.id)).run()) await Bus.publish(MessageV2.Event.Removed, { sessionID: sessionID, messageID: msg.info.id }) } const last = preserve.at(-1) @@ -112,7 +113,7 @@ export namespace SessionRevert { const [preserveParts, removeParts] = splitWhen(last.parts, (x) => x.id === partID) last.parts = preserveParts for (const part of removeParts) { - db().delete(PartTable).where(eq(PartTable.id, part.id)).run() + Database.use((db) => db.delete(PartTable).where(eq(PartTable.id, part.id)).run()) await Bus.publish(MessageV2.Event.PartRemoved, { sessionID: sessionID, messageID: last.info.id, diff --git a/packages/opencode/src/session/summary.ts b/packages/opencode/src/session/summary.ts index a79850046d..2472eee7a3 100644 --- a/packages/opencode/src/session/summary.ts +++ b/packages/opencode/src/session/summary.ts @@ -11,9 +11,8 @@ import { Snapshot } from "@/snapshot" import { Log } from "@/util/log" import path from "path" import { Instance } from "@/project/instance" -import { db } from "@/storage/db" +import { Database, eq } from "@/storage/db" import { SessionDiffTable } from "./session.sql" -import { eq } from "drizzle-orm" import { Bus } from "@/bus" import { LLM } from "./llm" @@ -56,11 +55,13 @@ export namespace SessionSummary { files: diffs.length, } }) - db() - .insert(SessionDiffTable) - .values({ sessionID: input.sessionID, data: diffs }) - .onConflictDoUpdate({ target: SessionDiffTable.sessionID, set: { data: diffs } }) - .run() + Database.use((db) => + db + .insert(SessionDiffTable) + .values({ sessionID: input.sessionID, data: diffs }) + .onConflictDoUpdate({ target: SessionDiffTable.sessionID, set: { data: diffs } }) + .run(), + ) Bus.publish(Session.Event.Diff, { sessionID: input.sessionID, diff: diffs, @@ -122,7 +123,9 @@ export namespace SessionSummary { messageID: Identifier.schema("message").optional(), }), async (input) => { - const row = db().select().from(SessionDiffTable).where(eq(SessionDiffTable.sessionID, input.sessionID)).get() + const row = Database.use((db) => + db.select().from(SessionDiffTable).where(eq(SessionDiffTable.sessionID, input.sessionID)).get(), + ) return row?.data ?? [] }, ) diff --git a/packages/opencode/src/session/todo.ts b/packages/opencode/src/session/todo.ts index 3280744662..8ba5a0281c 100644 --- a/packages/opencode/src/session/todo.ts +++ b/packages/opencode/src/session/todo.ts @@ -1,9 +1,8 @@ import { BusEvent } from "@/bus/bus-event" import { Bus } from "@/bus" import z from "zod" -import { db } from "../storage/db" +import { Database, eq } from "../storage/db" import { TodoTable } from "./session.sql" -import { eq } from "drizzle-orm" export namespace Todo { export const Info = z @@ -27,16 +26,18 @@ export namespace Todo { } export function update(input: { sessionID: string; todos: Info[] }) { - db() - .insert(TodoTable) - .values({ sessionID: input.sessionID, data: input.todos }) - .onConflictDoUpdate({ target: TodoTable.sessionID, set: { data: input.todos } }) - .run() + Database.use((db) => + db + .insert(TodoTable) + .values({ sessionID: input.sessionID, data: input.todos }) + .onConflictDoUpdate({ target: TodoTable.sessionID, set: { data: input.todos } }) + .run(), + ) Bus.publish(Event.Updated, input) } export function get(sessionID: string) { - const row = db().select().from(TodoTable).where(eq(TodoTable.sessionID, sessionID)).get() + const row = Database.use((db) => db.select().from(TodoTable).where(eq(TodoTable.sessionID, sessionID)).get()) return row?.data ?? [] } } diff --git a/packages/opencode/src/share/share-next.ts b/packages/opencode/src/share/share-next.ts index 2d16820459..0f18cb974d 100644 --- a/packages/opencode/src/share/share-next.ts +++ b/packages/opencode/src/share/share-next.ts @@ -4,9 +4,8 @@ import { ulid } from "ulid" import { Provider } from "@/provider/provider" import { Session } from "@/session" import { MessageV2 } from "@/session/message-v2" -import { db } from "@/storage/db" +import { Database, eq } from "@/storage/db" import { SessionShareTable } from "./share.sql" -import { eq } from "drizzle-orm" import { Log } from "@/util/log" import type * as SDK from "@opencode-ai/sdk/v2" @@ -79,17 +78,21 @@ export namespace ShareNext { }) .then((x) => x.json()) .then((x) => x as { id: string; url: string; secret: string }) - db() - .insert(SessionShareTable) - .values({ sessionID, data: result }) - .onConflictDoUpdate({ target: SessionShareTable.sessionID, set: { data: result } }) - .run() + Database.use((db) => + db + .insert(SessionShareTable) + .values({ sessionID, data: result }) + .onConflictDoUpdate({ target: SessionShareTable.sessionID, set: { data: result } }) + .run(), + ) fullSync(sessionID) return result } function get(sessionID: string) { - const row = db().select().from(SessionShareTable).where(eq(SessionShareTable.sessionID, sessionID)).get() + const row = Database.use((db) => + db.select().from(SessionShareTable).where(eq(SessionShareTable.sessionID, sessionID)).get(), + ) return row?.data } @@ -166,7 +169,7 @@ export namespace ShareNext { secret: share.secret, }), }) - db().delete(SessionShareTable).where(eq(SessionShareTable.sessionID, sessionID)).run() + Database.use((db) => db.delete(SessionShareTable).where(eq(SessionShareTable.sessionID, sessionID)).run()) } async function fullSync(sessionID: string) { diff --git a/packages/opencode/src/storage/db.ts b/packages/opencode/src/storage/db.ts index 3c1a159305..475da9475b 100644 --- a/packages/opencode/src/storage/db.ts +++ b/packages/opencode/src/storage/db.ts @@ -1,5 +1,9 @@ -import { Database } from "bun:sqlite" -import { drizzle } from "drizzle-orm/bun-sqlite" +import { Database as BunDatabase } from "bun:sqlite" +import { drizzle, type BunSQLiteDatabase } from "drizzle-orm/bun-sqlite" +import type { SQLiteTransaction } from "drizzle-orm/sqlite-core" +import type { ExtractTablesWithRelations } from "drizzle-orm" +export * from "drizzle-orm" +import { Context } from "../util/context" import { lazy } from "../util/lazy" import { Global } from "../global" import { Log } from "../util/log" @@ -18,29 +22,85 @@ export const NotFoundError = NamedError.create( const log = Log.create({ service: "db" }) -export type DB = ReturnType +export namespace Database { + export type Transaction = SQLiteTransaction< + "sync", + void, + Record, + ExtractTablesWithRelations> + > -const connection = lazy(() => { - const dbPath = path.join(Global.Path.data, "opencode.db") - log.info("opening database", { path: dbPath }) + type Client = BunSQLiteDatabase> - const sqlite = new Database(dbPath, { create: true }) + const client = lazy(() => { + const dbPath = path.join(Global.Path.data, "opencode.db") + log.info("opening database", { path: dbPath }) - sqlite.run("PRAGMA journal_mode = WAL") - sqlite.run("PRAGMA synchronous = NORMAL") - sqlite.run("PRAGMA busy_timeout = 5000") - sqlite.run("PRAGMA cache_size = -64000") - sqlite.run("PRAGMA foreign_keys = ON") + const sqlite = new BunDatabase(dbPath, { create: true }) - migrate(sqlite) + sqlite.run("PRAGMA journal_mode = WAL") + sqlite.run("PRAGMA synchronous = NORMAL") + sqlite.run("PRAGMA busy_timeout = 5000") + sqlite.run("PRAGMA cache_size = -64000") + sqlite.run("PRAGMA foreign_keys = ON") - // Run JSON migration asynchronously after schema is ready - migrateFromJson(sqlite).catch((e) => log.error("json migration failed", { error: e })) + migrate(sqlite) - return drizzle(sqlite) -}) + migrateFromJson(sqlite).catch((e) => log.error("json migration failed", { error: e })) -function migrate(sqlite: Database) { + return drizzle(sqlite) + }) + + export type TxOrDb = Transaction | Client + + const TransactionContext = Context.create<{ + tx: TxOrDb + effects: (() => void | Promise)[] + }>("database") + + export function use(callback: (trx: TxOrDb) => T): T { + try { + const { tx } = TransactionContext.use() + return callback(tx) + } catch (err) { + if (err instanceof Context.NotFound) { + const effects: (() => void | Promise)[] = [] + const result = TransactionContext.provide({ effects, tx: client() }, () => callback(client())) + for (const effect of effects) effect() + return result + } + throw err + } + } + + export function effect(effect: () => void | Promise) { + try { + const { effects } = TransactionContext.use() + effects.push(effect) + } catch { + effect() + } + } + + export function transaction(callback: (tx: TxOrDb) => T): T { + try { + const { tx } = TransactionContext.use() + return callback(tx) + } catch (err) { + if (err instanceof Context.NotFound) { + const effects: (() => void | Promise)[] = [] + const result = client().transaction((tx) => { + return TransactionContext.provide({ tx, effects }, () => callback(tx)) + }) + for (const effect of effects) effect() + return result + } + throw err + } + } +} + +function migrate(sqlite: BunDatabase) { sqlite.exec(` CREATE TABLE IF NOT EXISTS _migrations ( name TEXT PRIMARY KEY, @@ -59,8 +119,6 @@ function migrate(sqlite: Database) { if (applied.has(migration.name)) continue log.info("applying migration", { name: migration.name }) - // Split by statement breakpoint and execute each statement - // Use IF NOT EXISTS variants to handle partial migrations const statements = migration.sql.split("--> statement-breakpoint") for (const stmt of statements) { const trimmed = stmt.trim() @@ -69,7 +127,6 @@ function migrate(sqlite: Database) { try { sqlite.exec(trimmed) } catch (e: any) { - // Ignore "already exists" errors for idempotency if (e?.message?.includes("already exists")) { log.info("skipping existing object", { statement: trimmed.slice(0, 50) }) continue @@ -81,7 +138,3 @@ function migrate(sqlite: Database) { sqlite.run("INSERT INTO _migrations (name, applied_at) VALUES (?, ?)", [migration.name, Date.now()]) } } - -export function db() { - return connection() -} diff --git a/packages/opencode/src/worktree/index.ts b/packages/opencode/src/worktree/index.ts index 30443d36b1..0afffda999 100644 --- a/packages/opencode/src/worktree/index.ts +++ b/packages/opencode/src/worktree/index.ts @@ -7,9 +7,8 @@ import { Global } from "../global" import { Instance } from "../project/instance" import { InstanceBootstrap } from "../project/bootstrap" import { Project } from "../project/project" -import { db } from "../storage/db" +import { Database, eq } from "../storage/db" import { ProjectTable } from "../project/project.sql" -import { eq } from "drizzle-orm" import { fn } from "../util/fn" import { Log } from "../util/log" import { BusEvent } from "@/bus/bus-event" @@ -320,7 +319,7 @@ export namespace Worktree { }, }) - const row = db().select().from(ProjectTable).where(eq(ProjectTable.id, projectID)).get() + const row = Database.use((db) => db.select().from(ProjectTable).where(eq(ProjectTable.id, projectID)).get()) const project = row ? Project.fromRow(row) : undefined const startup = project?.commands?.start?.trim() ?? ""