feature/workspace-domain
Dax Raad 2026-01-25 20:16:56 -05:00
parent 1e7b4768b1
commit 105688bf90
14 changed files with 275 additions and 187 deletions

View File

@ -97,8 +97,8 @@
"chokidar": "4.0.3",
"clipboardy": "4.0.0",
"decimal.js": "10.5.0",
"drizzle-orm": "0.44.2",
"diff": "catalog:",
"drizzle-orm": "0.44.2",
"fuzzysort": "3.1.0",
"gray-matter": "4.0.3",
"hono": "catalog:",

View File

@ -2,7 +2,7 @@ import type { Argv } from "yargs"
import { cmd } from "./cmd"
import { bootstrap } from "../bootstrap"
import { UI } from "../ui"
import { db } from "../../storage/db"
import { Database } from "../../storage/db"
import { ProjectTable } from "../../project/project.sql"
import { Project } from "../../project/project"
import {
@ -56,7 +56,7 @@ const ExportCommand = cmd({
// Export projects
const projectDir = path.join(outDir, "project")
await fs.mkdir(projectDir, { recursive: true })
for (const row of db().select().from(ProjectTable).all()) {
for (const row of Database.use((db) => db.select().from(ProjectTable).all())) {
const project = Project.fromRow(row)
await Bun.write(path.join(projectDir, `${row.id}.json`), JSON.stringify(project, null, 2))
stats.projects++
@ -64,7 +64,7 @@ const ExportCommand = cmd({
// Export sessions (organized by projectID)
const sessionDir = path.join(outDir, "session")
for (const row of db().select().from(SessionTable).all()) {
for (const row of Database.use((db) => db.select().from(SessionTable).all())) {
const dir = path.join(sessionDir, row.projectID)
await fs.mkdir(dir, { recursive: true })
await Bun.write(path.join(dir, `${row.id}.json`), JSON.stringify(Session.fromRow(row), null, 2))
@ -73,7 +73,7 @@ const ExportCommand = cmd({
// Export messages (organized by sessionID)
const messageDir = path.join(outDir, "message")
for (const row of db().select().from(MessageTable).all()) {
for (const row of Database.use((db) => db.select().from(MessageTable).all())) {
const dir = path.join(messageDir, row.sessionID)
await fs.mkdir(dir, { recursive: true })
await Bun.write(path.join(dir, `${row.id}.json`), JSON.stringify(row.data, null, 2))
@ -82,7 +82,7 @@ const ExportCommand = cmd({
// Export parts (organized by messageID)
const partDir = path.join(outDir, "part")
for (const row of db().select().from(PartTable).all()) {
for (const row of Database.use((db) => db.select().from(PartTable).all())) {
const dir = path.join(partDir, row.messageID)
await fs.mkdir(dir, { recursive: true })
await Bun.write(path.join(dir, `${row.id}.json`), JSON.stringify(row.data, null, 2))
@ -92,7 +92,7 @@ const ExportCommand = cmd({
// Export session diffs
const diffDir = path.join(outDir, "session_diff")
await fs.mkdir(diffDir, { recursive: true })
for (const row of db().select().from(SessionDiffTable).all()) {
for (const row of Database.use((db) => db.select().from(SessionDiffTable).all())) {
await Bun.write(path.join(diffDir, `${row.sessionID}.json`), JSON.stringify(row.data, null, 2))
stats.diffs++
}
@ -100,7 +100,7 @@ const ExportCommand = cmd({
// Export todos
const todoDir = path.join(outDir, "todo")
await fs.mkdir(todoDir, { recursive: true })
for (const row of db().select().from(TodoTable).all()) {
for (const row of Database.use((db) => db.select().from(TodoTable).all())) {
await Bun.write(path.join(todoDir, `${row.sessionID}.json`), JSON.stringify(row.data, null, 2))
stats.todos++
}
@ -108,7 +108,7 @@ const ExportCommand = cmd({
// Export permissions
const permDir = path.join(outDir, "permission")
await fs.mkdir(permDir, { recursive: true })
for (const row of db().select().from(PermissionTable).all()) {
for (const row of Database.use((db) => db.select().from(PermissionTable).all())) {
await Bun.write(path.join(permDir, `${row.projectID}.json`), JSON.stringify(row.data, null, 2))
stats.permissions++
}
@ -116,7 +116,7 @@ const ExportCommand = cmd({
// Export session shares
const sessionShareDir = path.join(outDir, "session_share")
await fs.mkdir(sessionShareDir, { recursive: true })
for (const row of db().select().from(SessionShareTable).all()) {
for (const row of Database.use((db) => db.select().from(SessionShareTable).all())) {
await Bun.write(path.join(sessionShareDir, `${row.sessionID}.json`), JSON.stringify(row.data, null, 2))
stats.sessionShares++
}
@ -124,7 +124,7 @@ const ExportCommand = cmd({
// Export shares
const shareDir = path.join(outDir, "share")
await fs.mkdir(shareDir, { recursive: true })
for (const row of db().select().from(ShareTable).all()) {
for (const row of Database.use((db) => db.select().from(ShareTable).all())) {
await Bun.write(path.join(shareDir, `${row.sessionID}.json`), JSON.stringify(row.data, null, 2))
stats.shares++
}

View File

@ -2,7 +2,7 @@ import type { Argv } from "yargs"
import { Session } from "../../session"
import { cmd } from "./cmd"
import { bootstrap } from "../bootstrap"
import { db } from "../../storage/db"
import { Database } from "../../storage/db"
import { SessionTable, MessageTable, PartTable } from "../../session/session.sql"
import { Instance } from "../../project/instance"
import { EOL } from "os"
@ -82,31 +82,35 @@ export const ImportCommand = cmd({
return
}
db().insert(SessionTable).values(Session.toRow(exportData.info)).onConflictDoNothing().run()
Database.use((db) => db.insert(SessionTable).values(Session.toRow(exportData.info)).onConflictDoNothing().run())
for (const msg of exportData.messages) {
db()
.insert(MessageTable)
.values({
id: msg.info.id,
sessionID: exportData.info.id,
createdAt: msg.info.time?.created ?? Date.now(),
data: msg.info,
})
.onConflictDoNothing()
.run()
for (const part of msg.parts) {
db()
.insert(PartTable)
Database.use((db) =>
db
.insert(MessageTable)
.values({
id: part.id,
messageID: msg.info.id,
id: msg.info.id,
sessionID: exportData.info.id,
data: part,
createdAt: msg.info.time?.created ?? Date.now(),
data: msg.info,
})
.onConflictDoNothing()
.run()
.run(),
)
for (const part of msg.parts) {
Database.use((db) =>
db
.insert(PartTable)
.values({
id: part.id,
messageID: msg.info.id,
sessionID: exportData.info.id,
data: part,
})
.onConflictDoNothing()
.run(),
)
}
}

View File

@ -2,7 +2,7 @@ import type { Argv } from "yargs"
import { cmd } from "./cmd"
import { Session } from "../../session"
import { bootstrap } from "../bootstrap"
import { db } from "../../storage/db"
import { Database } from "../../storage/db"
import { SessionTable } from "../../session/session.sql"
import { Project } from "../../project/project"
import { Instance } from "../../project/instance"
@ -84,7 +84,7 @@ async function getCurrentProject(): Promise<Project.Info> {
}
async function getAllSessions(): Promise<Session.Info[]> {
const rows = db().select().from(SessionTable).all()
const rows = Database.use((db) => db.select().from(SessionTable).all())
return rows.map((row) => Session.fromRow(row))
}

View File

@ -3,9 +3,8 @@ import { BusEvent } from "@/bus/bus-event"
import { Config } from "@/config/config"
import { Identifier } from "@/id/id"
import { Instance } from "@/project/instance"
import { db } from "@/storage/db"
import { Database, eq } from "@/storage/db"
import { PermissionTable } from "@/session/session.sql"
import { eq } from "drizzle-orm"
import { fn } from "@/util/fn"
import { Log } from "@/util/log"
import { Wildcard } from "@/util/wildcard"
@ -109,7 +108,9 @@ export namespace PermissionNext {
const state = Instance.state(() => {
const projectID = Instance.project.id
const row = db().select().from(PermissionTable).where(eq(PermissionTable.projectID, projectID)).get()
const row = Database.use((db) =>
db.select().from(PermissionTable).where(eq(PermissionTable.projectID, projectID)).get(),
)
const stored = row?.data ?? ([] as Ruleset)
const pending: Record<

View File

@ -3,10 +3,9 @@ import fs from "fs/promises"
import { Filesystem } from "../util/filesystem"
import path from "path"
import { $ } from "bun"
import { db } from "../storage/db"
import { Database, eq } from "../storage/db"
import { ProjectTable } from "./project.sql"
import { SessionTable } from "../session/session.sql"
import { eq } from "drizzle-orm"
import { Log } from "../util/log"
import { Flag } from "@/flag/flag"
import { work } from "../util/queue"
@ -199,7 +198,7 @@ export namespace Project {
}
})
const row = db().select().from(ProjectTable).where(eq(ProjectTable.id, id)).get()
const row = Database.use((db) => db.select().from(ProjectTable).where(eq(ProjectTable.id, id)).get())
const existing = await iife(async () => {
if (row) return fromRow(row)
const fresh: Info = {
@ -253,7 +252,9 @@ export namespace Project {
time_initialized: result.time.initialized,
sandboxes: result.sandboxes,
}
db().insert(ProjectTable).values(insert).onConflictDoUpdate({ target: ProjectTable.id, set: updateSet }).run()
Database.use((db) =>
db.insert(ProjectTable).values(insert).onConflictDoUpdate({ target: ProjectTable.id, set: updateSet }).run(),
)
GlobalBus.emit("event", {
payload: {
type: Event.Updated.type,
@ -294,10 +295,12 @@ export namespace Project {
}
async function migrateFromGlobal(newProjectID: string, worktree: string) {
const globalRow = db().select().from(ProjectTable).where(eq(ProjectTable.id, "global")).get()
const globalRow = Database.use((db) => db.select().from(ProjectTable).where(eq(ProjectTable.id, "global")).get())
if (!globalRow) return
const globalSessions = db().select().from(SessionTable).where(eq(SessionTable.projectID, "global")).all()
const globalSessions = Database.use((db) =>
db.select().from(SessionTable).where(eq(SessionTable.projectID, "global")).all(),
)
if (globalSessions.length === 0) return
log.info("migrating sessions from global", { newProjectID, worktree, count: globalSessions.length })
@ -307,32 +310,38 @@ export namespace Project {
if (row.directory && row.directory !== worktree) return
log.info("migrating session", { sessionID: row.id, from: "global", to: newProjectID })
db().update(SessionTable).set({ projectID: newProjectID }).where(eq(SessionTable.id, row.id)).run()
Database.use((db) =>
db.update(SessionTable).set({ projectID: newProjectID }).where(eq(SessionTable.id, row.id)).run(),
)
}).catch((error) => {
log.error("failed to migrate sessions from global to project", { error, projectId: newProjectID })
})
}
export function setInitialized(projectID: string) {
db()
.update(ProjectTable)
.set({
time_initialized: Date.now(),
})
.where(eq(ProjectTable.id, projectID))
.run()
Database.use((db) =>
db
.update(ProjectTable)
.set({
time_initialized: Date.now(),
})
.where(eq(ProjectTable.id, projectID))
.run(),
)
}
export function list() {
return db()
.select()
.from(ProjectTable)
.all()
.map((row) => fromRow(row))
return Database.use((db) =>
db
.select()
.from(ProjectTable)
.all()
.map((row) => fromRow(row)),
)
}
export function get(projectID: string): Info | undefined {
const row = db().select().from(ProjectTable).where(eq(ProjectTable.id, projectID)).get()
const row = Database.use((db) => db.select().from(ProjectTable).where(eq(ProjectTable.id, projectID)).get())
if (!row) return undefined
return fromRow(row)
}
@ -345,17 +354,19 @@ export namespace Project {
commands: Info.shape.commands.optional(),
}),
async (input) => {
const result = db()
.update(ProjectTable)
.set({
name: input.name,
icon_url: input.icon?.url,
icon_color: input.icon?.color,
time_updated: Date.now(),
})
.where(eq(ProjectTable.id, input.projectID))
.returning()
.get()
const result = Database.use((db) =>
db
.update(ProjectTable)
.set({
name: input.name,
icon_url: input.icon?.url,
icon_color: input.icon?.color,
time_updated: Date.now(),
})
.where(eq(ProjectTable.id, input.projectID))
.returning()
.get(),
)
if (!result) throw new Error(`Project not found: ${input.projectID}`)
const data = fromRow(result)
GlobalBus.emit("event", {
@ -369,7 +380,7 @@ export namespace Project {
)
export async function sandboxes(projectID: string) {
const row = db().select().from(ProjectTable).where(eq(ProjectTable.id, projectID)).get()
const row = Database.use((db) => db.select().from(ProjectTable).where(eq(ProjectTable.id, projectID)).get())
if (!row) return []
const data = fromRow(row)
const valid: string[] = []
@ -381,16 +392,18 @@ export namespace Project {
}
export async function addSandbox(projectID: string, directory: string) {
const row = db().select().from(ProjectTable).where(eq(ProjectTable.id, projectID)).get()
const row = Database.use((db) => db.select().from(ProjectTable).where(eq(ProjectTable.id, projectID)).get())
if (!row) throw new Error(`Project not found: ${projectID}`)
const sandboxes = [...row.sandboxes]
if (!sandboxes.includes(directory)) sandboxes.push(directory)
const result = db()
.update(ProjectTable)
.set({ sandboxes, time_updated: Date.now() })
.where(eq(ProjectTable.id, projectID))
.returning()
.get()
const result = Database.use((db) =>
db
.update(ProjectTable)
.set({ sandboxes, time_updated: Date.now() })
.where(eq(ProjectTable.id, projectID))
.returning()
.get(),
)
if (!result) throw new Error(`Project not found: ${projectID}`)
const data = fromRow(result)
GlobalBus.emit("event", {
@ -403,15 +416,17 @@ export namespace Project {
}
export async function removeSandbox(projectID: string, directory: string) {
const row = db().select().from(ProjectTable).where(eq(ProjectTable.id, projectID)).get()
const row = Database.use((db) => db.select().from(ProjectTable).where(eq(ProjectTable.id, projectID)).get())
if (!row) throw new Error(`Project not found: ${projectID}`)
const sandboxes = row.sandboxes.filter((s: string) => s !== directory)
const result = db()
.update(ProjectTable)
.set({ sandboxes, time_updated: Date.now() })
.where(eq(ProjectTable.id, projectID))
.returning()
.get()
const result = Database.use((db) =>
db
.update(ProjectTable)
.set({ sandboxes, time_updated: Date.now() })
.where(eq(ProjectTable.id, projectID))
.returning()
.get(),
)
if (!result) throw new Error(`Project not found: ${projectID}`)
const data = fromRow(result)
GlobalBus.emit("event", {

View File

@ -10,10 +10,9 @@ import { Flag } from "../flag/flag"
import { Identifier } from "../id/id"
import { Installation } from "../installation"
import { db, NotFoundError } from "../storage/db"
import { Database, NotFoundError, eq } from "../storage/db"
import { SessionTable, MessageTable, PartTable, SessionDiffTable } from "./session.sql"
import { ShareTable } from "../share/share.sql"
import { eq } from "drizzle-orm"
import { Log } from "../util/log"
import { MessageV2 } from "./message-v2"
import { Instance } from "../project/instance"
@ -283,7 +282,7 @@ export namespace Session {
},
}
log.info("created", result)
db().insert(SessionTable).values(toRow(result)).run()
Database.use((db) => db.insert(SessionTable).values(toRow(result)).run())
Bus.publish(Event.Created, {
info: result,
})
@ -312,13 +311,13 @@ export namespace Session {
}
export const get = fn(Identifier.schema("session"), async (id) => {
const row = db().select().from(SessionTable).where(eq(SessionTable.id, id)).get()
const row = Database.use((db) => db.select().from(SessionTable).where(eq(SessionTable.id, id)).get())
if (!row) throw new NotFoundError({ message: `Session not found: ${id}` })
return fromRow(row)
})
export const getShare = fn(Identifier.schema("session"), async (id) => {
const row = db().select().from(ShareTable).where(eq(ShareTable.sessionID, id)).get()
const row = Database.use((db) => db.select().from(ShareTable).where(eq(ShareTable.sessionID, id)).get())
return row?.data
})
@ -355,14 +354,14 @@ export namespace Session {
})
export function update(id: string, editor: (session: Info) => void, options?: { touch?: boolean }) {
const row = db().select().from(SessionTable).where(eq(SessionTable.id, id)).get()
const row = Database.use((db) => db.select().from(SessionTable).where(eq(SessionTable.id, id)).get())
if (!row) throw new Error(`Session not found: ${id}`)
const data = fromRow(row)
editor(data)
if (options?.touch !== false) {
data.time.updated = Date.now()
}
db().update(SessionTable).set(toRow(data)).where(eq(SessionTable.id, id)).run()
Database.use((db) => db.update(SessionTable).set(toRow(data)).where(eq(SessionTable.id, id)).run())
Bus.publish(Event.Updated, {
info: data,
})
@ -370,7 +369,9 @@ export namespace Session {
}
export const diff = fn(Identifier.schema("session"), async (sessionID) => {
const row = db().select().from(SessionDiffTable).where(eq(SessionDiffTable.sessionID, sessionID)).get()
const row = Database.use((db) =>
db.select().from(SessionDiffTable).where(eq(SessionDiffTable.sessionID, sessionID)).get(),
)
return row?.data ?? []
})
@ -392,14 +393,16 @@ export namespace Session {
export function* list() {
const project = Instance.project
const rows = db().select().from(SessionTable).where(eq(SessionTable.projectID, project.id)).all()
const rows = Database.use((db) =>
db.select().from(SessionTable).where(eq(SessionTable.projectID, project.id)).all(),
)
for (const row of rows) {
yield fromRow(row)
}
}
export const children = fn(Identifier.schema("session"), async (parentID) => {
const rows = db().select().from(SessionTable).where(eq(SessionTable.parentID, parentID)).all()
const rows = Database.use((db) => db.select().from(SessionTable).where(eq(SessionTable.parentID, parentID)).all())
return rows.map((row) => fromRow(row))
})
@ -412,7 +415,7 @@ export namespace Session {
}
await unshare(sessionID).catch(() => {})
// CASCADE delete handles messages and parts automatically
db().delete(SessionTable).where(eq(SessionTable.id, sessionID)).run()
Database.use((db) => db.delete(SessionTable).where(eq(SessionTable.id, sessionID)).run())
Bus.publish(Event.Deleted, {
info: session,
})
@ -423,16 +426,18 @@ export namespace Session {
export const updateMessage = fn(MessageV2.Info, async (msg) => {
const createdAt = msg.role === "user" ? msg.time.created : msg.time.created
db()
.insert(MessageTable)
.values({
id: msg.id,
sessionID: msg.sessionID,
createdAt,
data: msg,
})
.onConflictDoUpdate({ target: MessageTable.id, set: { data: msg } })
.run()
Database.use((db) =>
db
.insert(MessageTable)
.values({
id: msg.id,
sessionID: msg.sessionID,
createdAt,
data: msg,
})
.onConflictDoUpdate({ target: MessageTable.id, set: { data: msg } })
.run(),
)
Bus.publish(MessageV2.Event.Updated, {
info: msg,
})
@ -446,7 +451,7 @@ export namespace Session {
}),
async (input) => {
// CASCADE delete handles parts automatically
db().delete(MessageTable).where(eq(MessageTable.id, input.messageID)).run()
Database.use((db) => db.delete(MessageTable).where(eq(MessageTable.id, input.messageID)).run())
Bus.publish(MessageV2.Event.Removed, {
sessionID: input.sessionID,
messageID: input.messageID,
@ -462,7 +467,7 @@ export namespace Session {
partID: Identifier.schema("part"),
}),
async (input) => {
db().delete(PartTable).where(eq(PartTable.id, input.partID)).run()
Database.use((db) => db.delete(PartTable).where(eq(PartTable.id, input.partID)).run())
Bus.publish(MessageV2.Event.PartRemoved, {
sessionID: input.sessionID,
messageID: input.messageID,
@ -487,16 +492,18 @@ export namespace Session {
export const updatePart = fn(UpdatePartInput, async (input) => {
const part = "delta" in input ? input.part : input
const delta = "delta" in input ? input.delta : undefined
db()
.insert(PartTable)
.values({
id: part.id,
messageID: part.messageID,
sessionID: part.sessionID,
data: part,
})
.onConflictDoUpdate({ target: PartTable.id, set: { data: part } })
.run()
Database.use((db) =>
db
.insert(PartTable)
.values({
id: part.id,
messageID: part.messageID,
sessionID: part.sessionID,
data: part,
})
.onConflictDoUpdate({ target: PartTable.id, set: { data: part } })
.run(),
)
Bus.publish(MessageV2.Event.PartUpdated, {
part,
delta,

View File

@ -6,9 +6,8 @@ import { Identifier } from "../id/id"
import { LSP } from "../lsp"
import { Snapshot } from "@/snapshot"
import { fn } from "@/util/fn"
import { db } from "@/storage/db"
import { Database, eq, desc } from "@/storage/db"
import { MessageTable, PartTable } from "./session.sql"
import { eq, desc } from "drizzle-orm"
import { ProviderTransform } from "@/provider/transform"
import { STATUS_CODES } from "http"
import { iife } from "@/util/iife"
@ -609,12 +608,14 @@ export namespace MessageV2 {
}
export const stream = fn(Identifier.schema("session"), async function* (sessionID) {
const rows = db()
.select()
.from(MessageTable)
.where(eq(MessageTable.sessionID, sessionID))
.orderBy(desc(MessageTable.createdAt))
.all()
const rows = Database.use((db) =>
db
.select()
.from(MessageTable)
.where(eq(MessageTable.sessionID, sessionID))
.orderBy(desc(MessageTable.createdAt))
.all(),
)
for (const row of rows) {
yield {
info: row.data,
@ -624,7 +625,7 @@ export namespace MessageV2 {
})
export const parts = fn(Identifier.schema("message"), async (messageID) => {
const rows = db().select().from(PartTable).where(eq(PartTable.messageID, messageID)).all()
const rows = Database.use((db) => db.select().from(PartTable).where(eq(PartTable.messageID, messageID)).all())
const result = rows.map((row) => row.data)
result.sort((a, b) => (a.id > b.id ? 1 : -1))
return result
@ -636,7 +637,7 @@ export namespace MessageV2 {
messageID: Identifier.schema("message"),
}),
async (input) => {
const row = db().select().from(MessageTable).where(eq(MessageTable.id, input.messageID)).get()
const row = Database.use((db) => db.select().from(MessageTable).where(eq(MessageTable.id, input.messageID)).get())
if (!row) throw new Error(`Message not found: ${input.messageID}`)
return {
info: row.data,

View File

@ -5,9 +5,8 @@ import { MessageV2 } from "./message-v2"
import { Session } from "."
import { Log } from "../util/log"
import { splitWhen } from "remeda"
import { db } from "../storage/db"
import { Database, eq } from "../storage/db"
import { SessionDiffTable, MessageTable, PartTable } from "./session.sql"
import { eq } from "drizzle-orm"
import { Bus } from "../bus"
import { SessionPrompt } from "./prompt"
import { SessionSummary } from "./summary"
@ -62,11 +61,13 @@ export namespace SessionRevert {
if (revert.snapshot) revert.diff = await Snapshot.diff(revert.snapshot)
const rangeMessages = all.filter((msg) => msg.info.id >= revert!.messageID)
const diffs = await SessionSummary.computeDiff({ messages: rangeMessages })
db()
.insert(SessionDiffTable)
.values({ sessionID: input.sessionID, data: diffs })
.onConflictDoUpdate({ target: SessionDiffTable.sessionID, set: { data: diffs } })
.run()
Database.use((db) =>
db
.insert(SessionDiffTable)
.values({ sessionID: input.sessionID, data: diffs })
.onConflictDoUpdate({ target: SessionDiffTable.sessionID, set: { data: diffs } })
.run(),
)
Bus.publish(Session.Event.Diff, {
sessionID: input.sessionID,
diff: diffs,
@ -103,7 +104,7 @@ export namespace SessionRevert {
const [preserve, remove] = splitWhen(msgs, (x) => x.info.id === messageID)
msgs = preserve
for (const msg of remove) {
db().delete(MessageTable).where(eq(MessageTable.id, msg.info.id)).run()
Database.use((db) => db.delete(MessageTable).where(eq(MessageTable.id, msg.info.id)).run())
await Bus.publish(MessageV2.Event.Removed, { sessionID: sessionID, messageID: msg.info.id })
}
const last = preserve.at(-1)
@ -112,7 +113,7 @@ export namespace SessionRevert {
const [preserveParts, removeParts] = splitWhen(last.parts, (x) => x.id === partID)
last.parts = preserveParts
for (const part of removeParts) {
db().delete(PartTable).where(eq(PartTable.id, part.id)).run()
Database.use((db) => db.delete(PartTable).where(eq(PartTable.id, part.id)).run())
await Bus.publish(MessageV2.Event.PartRemoved, {
sessionID: sessionID,
messageID: last.info.id,

View File

@ -11,9 +11,8 @@ import { Snapshot } from "@/snapshot"
import { Log } from "@/util/log"
import path from "path"
import { Instance } from "@/project/instance"
import { db } from "@/storage/db"
import { Database, eq } from "@/storage/db"
import { SessionDiffTable } from "./session.sql"
import { eq } from "drizzle-orm"
import { Bus } from "@/bus"
import { LLM } from "./llm"
@ -56,11 +55,13 @@ export namespace SessionSummary {
files: diffs.length,
}
})
db()
.insert(SessionDiffTable)
.values({ sessionID: input.sessionID, data: diffs })
.onConflictDoUpdate({ target: SessionDiffTable.sessionID, set: { data: diffs } })
.run()
Database.use((db) =>
db
.insert(SessionDiffTable)
.values({ sessionID: input.sessionID, data: diffs })
.onConflictDoUpdate({ target: SessionDiffTable.sessionID, set: { data: diffs } })
.run(),
)
Bus.publish(Session.Event.Diff, {
sessionID: input.sessionID,
diff: diffs,
@ -122,7 +123,9 @@ export namespace SessionSummary {
messageID: Identifier.schema("message").optional(),
}),
async (input) => {
const row = db().select().from(SessionDiffTable).where(eq(SessionDiffTable.sessionID, input.sessionID)).get()
const row = Database.use((db) =>
db.select().from(SessionDiffTable).where(eq(SessionDiffTable.sessionID, input.sessionID)).get(),
)
return row?.data ?? []
},
)

View File

@ -1,9 +1,8 @@
import { BusEvent } from "@/bus/bus-event"
import { Bus } from "@/bus"
import z from "zod"
import { db } from "../storage/db"
import { Database, eq } from "../storage/db"
import { TodoTable } from "./session.sql"
import { eq } from "drizzle-orm"
export namespace Todo {
export const Info = z
@ -27,16 +26,18 @@ export namespace Todo {
}
export function update(input: { sessionID: string; todos: Info[] }) {
db()
.insert(TodoTable)
.values({ sessionID: input.sessionID, data: input.todos })
.onConflictDoUpdate({ target: TodoTable.sessionID, set: { data: input.todos } })
.run()
Database.use((db) =>
db
.insert(TodoTable)
.values({ sessionID: input.sessionID, data: input.todos })
.onConflictDoUpdate({ target: TodoTable.sessionID, set: { data: input.todos } })
.run(),
)
Bus.publish(Event.Updated, input)
}
export function get(sessionID: string) {
const row = db().select().from(TodoTable).where(eq(TodoTable.sessionID, sessionID)).get()
const row = Database.use((db) => db.select().from(TodoTable).where(eq(TodoTable.sessionID, sessionID)).get())
return row?.data ?? []
}
}

View File

@ -4,9 +4,8 @@ import { ulid } from "ulid"
import { Provider } from "@/provider/provider"
import { Session } from "@/session"
import { MessageV2 } from "@/session/message-v2"
import { db } from "@/storage/db"
import { Database, eq } from "@/storage/db"
import { SessionShareTable } from "./share.sql"
import { eq } from "drizzle-orm"
import { Log } from "@/util/log"
import type * as SDK from "@opencode-ai/sdk/v2"
@ -79,17 +78,21 @@ export namespace ShareNext {
})
.then((x) => x.json())
.then((x) => x as { id: string; url: string; secret: string })
db()
.insert(SessionShareTable)
.values({ sessionID, data: result })
.onConflictDoUpdate({ target: SessionShareTable.sessionID, set: { data: result } })
.run()
Database.use((db) =>
db
.insert(SessionShareTable)
.values({ sessionID, data: result })
.onConflictDoUpdate({ target: SessionShareTable.sessionID, set: { data: result } })
.run(),
)
fullSync(sessionID)
return result
}
function get(sessionID: string) {
const row = db().select().from(SessionShareTable).where(eq(SessionShareTable.sessionID, sessionID)).get()
const row = Database.use((db) =>
db.select().from(SessionShareTable).where(eq(SessionShareTable.sessionID, sessionID)).get(),
)
return row?.data
}
@ -166,7 +169,7 @@ export namespace ShareNext {
secret: share.secret,
}),
})
db().delete(SessionShareTable).where(eq(SessionShareTable.sessionID, sessionID)).run()
Database.use((db) => db.delete(SessionShareTable).where(eq(SessionShareTable.sessionID, sessionID)).run())
}
async function fullSync(sessionID: string) {

View File

@ -1,5 +1,9 @@
import { Database } from "bun:sqlite"
import { drizzle } from "drizzle-orm/bun-sqlite"
import { Database as BunDatabase } from "bun:sqlite"
import { drizzle, type BunSQLiteDatabase } from "drizzle-orm/bun-sqlite"
import type { SQLiteTransaction } from "drizzle-orm/sqlite-core"
import type { ExtractTablesWithRelations } from "drizzle-orm"
export * from "drizzle-orm"
import { Context } from "../util/context"
import { lazy } from "../util/lazy"
import { Global } from "../global"
import { Log } from "../util/log"
@ -18,29 +22,85 @@ export const NotFoundError = NamedError.create(
const log = Log.create({ service: "db" })
export type DB = ReturnType<typeof drizzle>
export namespace Database {
export type Transaction = SQLiteTransaction<
"sync",
void,
Record<string, never>,
ExtractTablesWithRelations<Record<string, never>>
>
const connection = lazy(() => {
const dbPath = path.join(Global.Path.data, "opencode.db")
log.info("opening database", { path: dbPath })
type Client = BunSQLiteDatabase<Record<string, never>>
const sqlite = new Database(dbPath, { create: true })
const client = lazy(() => {
const dbPath = path.join(Global.Path.data, "opencode.db")
log.info("opening database", { path: dbPath })
sqlite.run("PRAGMA journal_mode = WAL")
sqlite.run("PRAGMA synchronous = NORMAL")
sqlite.run("PRAGMA busy_timeout = 5000")
sqlite.run("PRAGMA cache_size = -64000")
sqlite.run("PRAGMA foreign_keys = ON")
const sqlite = new BunDatabase(dbPath, { create: true })
migrate(sqlite)
sqlite.run("PRAGMA journal_mode = WAL")
sqlite.run("PRAGMA synchronous = NORMAL")
sqlite.run("PRAGMA busy_timeout = 5000")
sqlite.run("PRAGMA cache_size = -64000")
sqlite.run("PRAGMA foreign_keys = ON")
// Run JSON migration asynchronously after schema is ready
migrateFromJson(sqlite).catch((e) => log.error("json migration failed", { error: e }))
migrate(sqlite)
return drizzle(sqlite)
})
migrateFromJson(sqlite).catch((e) => log.error("json migration failed", { error: e }))
function migrate(sqlite: Database) {
return drizzle(sqlite)
})
export type TxOrDb = Transaction | Client
const TransactionContext = Context.create<{
tx: TxOrDb
effects: (() => void | Promise<void>)[]
}>("database")
export function use<T>(callback: (trx: TxOrDb) => T): T {
try {
const { tx } = TransactionContext.use()
return callback(tx)
} catch (err) {
if (err instanceof Context.NotFound) {
const effects: (() => void | Promise<void>)[] = []
const result = TransactionContext.provide({ effects, tx: client() }, () => callback(client()))
for (const effect of effects) effect()
return result
}
throw err
}
}
export function effect(effect: () => void | Promise<void>) {
try {
const { effects } = TransactionContext.use()
effects.push(effect)
} catch {
effect()
}
}
export function transaction<T>(callback: (tx: TxOrDb) => T): T {
try {
const { tx } = TransactionContext.use()
return callback(tx)
} catch (err) {
if (err instanceof Context.NotFound) {
const effects: (() => void | Promise<void>)[] = []
const result = client().transaction((tx) => {
return TransactionContext.provide({ tx, effects }, () => callback(tx))
})
for (const effect of effects) effect()
return result
}
throw err
}
}
}
function migrate(sqlite: BunDatabase) {
sqlite.exec(`
CREATE TABLE IF NOT EXISTS _migrations (
name TEXT PRIMARY KEY,
@ -59,8 +119,6 @@ function migrate(sqlite: Database) {
if (applied.has(migration.name)) continue
log.info("applying migration", { name: migration.name })
// Split by statement breakpoint and execute each statement
// Use IF NOT EXISTS variants to handle partial migrations
const statements = migration.sql.split("--> statement-breakpoint")
for (const stmt of statements) {
const trimmed = stmt.trim()
@ -69,7 +127,6 @@ function migrate(sqlite: Database) {
try {
sqlite.exec(trimmed)
} catch (e: any) {
// Ignore "already exists" errors for idempotency
if (e?.message?.includes("already exists")) {
log.info("skipping existing object", { statement: trimmed.slice(0, 50) })
continue
@ -81,7 +138,3 @@ function migrate(sqlite: Database) {
sqlite.run("INSERT INTO _migrations (name, applied_at) VALUES (?, ?)", [migration.name, Date.now()])
}
}
export function db() {
return connection()
}

View File

@ -7,9 +7,8 @@ import { Global } from "../global"
import { Instance } from "../project/instance"
import { InstanceBootstrap } from "../project/bootstrap"
import { Project } from "../project/project"
import { db } from "../storage/db"
import { Database, eq } from "../storage/db"
import { ProjectTable } from "../project/project.sql"
import { eq } from "drizzle-orm"
import { fn } from "../util/fn"
import { Log } from "../util/log"
import { BusEvent } from "@/bus/bus-event"
@ -320,7 +319,7 @@ export namespace Worktree {
},
})
const row = db().select().from(ProjectTable).where(eq(ProjectTable.id, projectID)).get()
const row = Database.use((db) => db.select().from(ProjectTable).where(eq(ProjectTable.id, projectID)).get())
const project = row ? Project.fromRow(row) : undefined
const startup = project?.commands?.start?.trim() ?? ""