summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--packages/opencode/src/session/revert.ts246
-rw-r--r--packages/opencode/test/session/revert-compact.test.ts146
2 files changed, 287 insertions, 105 deletions
diff --git a/packages/opencode/src/session/revert.ts b/packages/opencode/src/session/revert.ts
index b1e9840e4..92049b12b 100644
--- a/packages/opencode/src/session/revert.ts
+++ b/packages/opencode/src/session/revert.ts
@@ -1,12 +1,14 @@
import z from "zod"
-import { SessionID, MessageID, PartID } from "./schema"
+import { Effect, Layer, ServiceMap } from "effect"
+import { makeRuntime } from "@/effect/run-service"
+import { Bus } from "../bus"
import { Snapshot } from "../snapshot"
-import { MessageV2 } from "./message-v2"
-import { Session } from "."
-import { Log } from "../util/log"
-import { SyncEvent } from "../sync"
import { Storage } from "@/storage/storage"
-import { Bus } from "../bus"
+import { SyncEvent } from "../sync"
+import { Log } from "../util/log"
+import { Session } from "."
+import { MessageV2 } from "./message-v2"
+import { SessionID, MessageID, PartID } from "./schema"
import { SessionPrompt } from "./prompt"
import { SessionSummary } from "./summary"
@@ -20,116 +22,152 @@ export namespace SessionRevert {
})
export type RevertInput = z.infer<typeof RevertInput>
- export async function revert(input: RevertInput) {
- await SessionPrompt.assertNotBusy(input.sessionID)
- const all = await Session.messages({ sessionID: input.sessionID })
- let lastUser: MessageV2.User | undefined
- const session = await Session.get(input.sessionID)
-
- let revert: Session.Info["revert"]
- const patches: Snapshot.Patch[] = []
- for (const msg of all) {
- if (msg.info.role === "user") lastUser = msg.info
- const remaining = []
- for (const part of msg.parts) {
- if (revert) {
- if (part.type === "patch") {
- patches.push(part)
+ export interface Interface {
+ readonly revert: (input: RevertInput) => Effect.Effect<Session.Info>
+ readonly unrevert: (input: { sessionID: SessionID }) => Effect.Effect<Session.Info>
+ readonly cleanup: (session: Session.Info) => Effect.Effect<void>
+ }
+
+ export class Service extends ServiceMap.Service<Service, Interface>()("@opencode/SessionRevert") {}
+
+ export const layer = Layer.effect(
+ Service,
+ Effect.gen(function* () {
+ const sessions = yield* Session.Service
+ const snap = yield* Snapshot.Service
+ const storage = yield* Storage.Service
+ const bus = yield* Bus.Service
+
+ const revert = Effect.fn("SessionRevert.revert")(function* (input: RevertInput) {
+ yield* Effect.promise(() => SessionPrompt.assertNotBusy(input.sessionID))
+ const all = yield* sessions.messages({ sessionID: input.sessionID })
+ let lastUser: MessageV2.User | undefined
+ const session = yield* sessions.get(input.sessionID)
+
+ let rev: Session.Info["revert"]
+ const patches: Snapshot.Patch[] = []
+ for (const msg of all) {
+ if (msg.info.role === "user") lastUser = msg.info
+ const remaining = []
+ for (const part of msg.parts) {
+ if (rev) {
+ if (part.type === "patch") patches.push(part)
+ continue
+ }
+
+ if (!rev) {
+ if ((msg.info.id === input.messageID && !input.partID) || part.id === input.partID) {
+ const partID = remaining.some((item) => ["text", "tool"].includes(item.type)) ? input.partID : undefined
+ rev = {
+ messageID: !partID && lastUser ? lastUser.id : msg.info.id,
+ partID,
+ }
+ }
+ remaining.push(part)
+ }
}
- continue
}
- if (!revert) {
- if ((msg.info.id === input.messageID && !input.partID) || part.id === input.partID) {
- // if no useful parts left in message, same as reverting whole message
- const partID = remaining.some((item) => ["text", "tool"].includes(item.type)) ? input.partID : undefined
- revert = {
- messageID: !partID && lastUser ? lastUser.id : msg.info.id,
- partID,
+ if (!rev) return session
+
+ rev.snapshot = session.revert?.snapshot ?? (yield* snap.track())
+ yield* snap.revert(patches)
+ if (rev.snapshot) rev.diff = yield* snap.diff(rev.snapshot as string)
+ const range = all.filter((msg) => msg.info.id >= rev!.messageID)
+ const diffs = yield* Effect.promise(() => SessionSummary.computeDiff({ messages: range }))
+ yield* storage.write(["session_diff", input.sessionID], diffs).pipe(Effect.ignore)
+ yield* bus.publish(Session.Event.Diff, { sessionID: input.sessionID, diff: diffs })
+ yield* sessions.setRevert({
+ sessionID: input.sessionID,
+ revert: rev,
+ summary: {
+ additions: diffs.reduce((sum, x) => sum + x.additions, 0),
+ deletions: diffs.reduce((sum, x) => sum + x.deletions, 0),
+ files: diffs.length,
+ },
+ })
+ return yield* sessions.get(input.sessionID)
+ })
+
+ const unrevert = Effect.fn("SessionRevert.unrevert")(function* (input: { sessionID: SessionID }) {
+ log.info("unreverting", input)
+ yield* Effect.promise(() => SessionPrompt.assertNotBusy(input.sessionID))
+ const session = yield* sessions.get(input.sessionID)
+ if (!session.revert) return session
+ if (session.revert.snapshot) yield* snap.restore(session.revert!.snapshot!)
+ yield* sessions.clearRevert(input.sessionID)
+ return yield* sessions.get(input.sessionID)
+ })
+
+ const cleanup = Effect.fn("SessionRevert.cleanup")(function* (session: Session.Info) {
+ if (!session.revert) return
+ const sessionID = session.id
+ const msgs = yield* sessions.messages({ sessionID })
+ const messageID = session.revert.messageID
+ const remove = [] as MessageV2.WithParts[]
+ let target: MessageV2.WithParts | undefined
+ for (const msg of msgs) {
+ if (msg.info.id < messageID) continue
+ if (msg.info.id > messageID) {
+ remove.push(msg)
+ continue
+ }
+ if (session.revert.partID) {
+ target = msg
+ continue
+ }
+ remove.push(msg)
+ }
+ for (const msg of remove) {
+ SyncEvent.run(MessageV2.Event.Removed, {
+ sessionID,
+ messageID: msg.info.id,
+ })
+ }
+ if (session.revert.partID && target) {
+ const partID = session.revert.partID
+ const idx = target.parts.findIndex((part) => part.id === partID)
+ if (idx >= 0) {
+ const removeParts = target.parts.slice(idx)
+ target.parts = target.parts.slice(0, idx)
+ for (const part of removeParts) {
+ SyncEvent.run(MessageV2.Event.PartRemoved, {
+ sessionID,
+ messageID: target.info.id,
+ partID: part.id,
+ })
}
}
- remaining.push(part)
}
- }
- }
-
- if (revert) {
- const session = await Session.get(input.sessionID)
- revert.snapshot = session.revert?.snapshot ?? (await Snapshot.track())
- await Snapshot.revert(patches)
- if (revert.snapshot) revert.diff = await Snapshot.diff(revert.snapshot)
- const rangeMessages = all.filter((msg) => msg.info.id >= revert!.messageID)
- const diffs = await SessionSummary.computeDiff({ messages: rangeMessages })
- await Storage.write(["session_diff", input.sessionID], diffs)
- Bus.publish(Session.Event.Diff, {
- sessionID: input.sessionID,
- diff: diffs,
- })
- return Session.setRevert({
- sessionID: input.sessionID,
- revert,
- summary: {
- additions: diffs.reduce((sum, x) => sum + x.additions, 0),
- deletions: diffs.reduce((sum, x) => sum + x.deletions, 0),
- files: diffs.length,
- },
+ yield* sessions.clearRevert(sessionID)
})
- }
- return session
+
+ return Service.of({ revert, unrevert, cleanup })
+ }),
+ )
+
+ export const defaultLayer = Layer.unwrap(
+ Effect.sync(() =>
+ layer.pipe(
+ Layer.provide(Session.defaultLayer),
+ Layer.provide(Snapshot.defaultLayer),
+ Layer.provide(Storage.defaultLayer),
+ Layer.provide(Bus.layer),
+ ),
+ ),
+ )
+
+ const { runPromise } = makeRuntime(Service, defaultLayer)
+
+ export async function revert(input: RevertInput) {
+ return runPromise((svc) => svc.revert(input))
}
export async function unrevert(input: { sessionID: SessionID }) {
- log.info("unreverting", input)
- await SessionPrompt.assertNotBusy(input.sessionID)
- const session = await Session.get(input.sessionID)
- if (!session.revert) return session
- if (session.revert.snapshot) await Snapshot.restore(session.revert.snapshot)
- return Session.clearRevert(input.sessionID)
+ return runPromise((svc) => svc.unrevert(input))
}
export async function cleanup(session: Session.Info) {
- if (!session.revert) return
- const sessionID = session.id
- const msgs = await Session.messages({ sessionID })
- const messageID = session.revert.messageID
- const remove = [] as MessageV2.WithParts[]
- let target: MessageV2.WithParts | undefined
- for (const msg of msgs) {
- if (msg.info.id < messageID) {
- continue
- }
- if (msg.info.id > messageID) {
- remove.push(msg)
- continue
- }
- if (session.revert.partID) {
- target = msg
- continue
- }
- remove.push(msg)
- }
- for (const msg of remove) {
- SyncEvent.run(MessageV2.Event.Removed, {
- sessionID: sessionID,
- messageID: msg.info.id,
- })
- }
- if (session.revert.partID && target) {
- const partID = session.revert.partID
- const removeStart = target.parts.findIndex((part) => part.id === partID)
- if (removeStart >= 0) {
- const preserveParts = target.parts.slice(0, removeStart)
- const removeParts = target.parts.slice(removeStart)
- target.parts = preserveParts
- for (const part of removeParts) {
- SyncEvent.run(MessageV2.Event.PartRemoved, {
- sessionID: sessionID,
- messageID: target.info.id,
- partID: part.id,
- })
- }
- }
- }
- await Session.clearRevert(sessionID)
+ return runPromise((svc) => svc.cleanup(session))
}
}
diff --git a/packages/opencode/test/session/revert-compact.test.ts b/packages/opencode/test/session/revert-compact.test.ts
index fb37a3a8d..fe7055779 100644
--- a/packages/opencode/test/session/revert-compact.test.ts
+++ b/packages/opencode/test/session/revert-compact.test.ts
@@ -10,9 +10,59 @@ import { Instance } from "../../src/project/instance"
import { MessageID, PartID } from "../../src/session/schema"
import { tmpdir } from "../fixture/fixture"
-const projectRoot = path.join(__dirname, "../..")
Log.init({ print: false })
+function user(sessionID: string, agent = "default") {
+ return Session.updateMessage({
+ id: MessageID.ascending(),
+ role: "user" as const,
+ sessionID: sessionID as any,
+ agent,
+ model: { providerID: ProviderID.make("openai"), modelID: ModelID.make("gpt-4") },
+ time: { created: Date.now() },
+ })
+}
+
+function assistant(sessionID: string, parentID: string, dir: string) {
+ return Session.updateMessage({
+ id: MessageID.ascending(),
+ role: "assistant" as const,
+ sessionID: sessionID as any,
+ mode: "default",
+ agent: "default",
+ path: { cwd: dir, root: dir },
+ cost: 0,
+ tokens: { output: 0, input: 0, reasoning: 0, cache: { read: 0, write: 0 } },
+ modelID: ModelID.make("gpt-4"),
+ providerID: ProviderID.make("openai"),
+ parentID: parentID as any,
+ time: { created: Date.now() },
+ finish: "end_turn",
+ })
+}
+
+function text(sessionID: string, messageID: string, content: string) {
+ return Session.updatePart({
+ id: PartID.ascending(),
+ messageID: messageID as any,
+ sessionID: sessionID as any,
+ type: "text" as const,
+ text: content,
+ })
+}
+
+function tool(sessionID: string, messageID: string) {
+ return Session.updatePart({
+ id: PartID.ascending(),
+ messageID: messageID as any,
+ sessionID: sessionID as any,
+ type: "tool" as const,
+ tool: "bash",
+ callID: "call-1",
+ state: { status: "completed" as const, input: {}, output: "done", title: "", metadata: {}, time: { start: 0, end: 1 } },
+ })
+}
+
describe("revert + compact workflow", () => {
test("should properly handle compact command after revert", async () => {
await using tmp = await tmpdir({ git: true })
@@ -283,4 +333,98 @@ describe("revert + compact workflow", () => {
},
})
})
+
+ test("cleanup with partID removes parts from the revert point onward", async () => {
+ await using tmp = await tmpdir({ git: true })
+ await Instance.provide({
+ directory: tmp.path,
+ fn: async () => {
+ const session = await Session.create({})
+ const sid = session.id
+
+ const u1 = await user(sid)
+ const p1 = await text(sid, u1.id, "first part")
+ const p2 = await tool(sid, u1.id)
+ const p3 = await text(sid, u1.id, "third part")
+
+ // Set revert state pointing at a specific part
+ await Session.setRevert({
+ sessionID: sid,
+ revert: { messageID: u1.id, partID: p2.id },
+ summary: { additions: 0, deletions: 0, files: 0 },
+ })
+
+ const info = await Session.get(sid)
+ await SessionRevert.cleanup(info)
+
+ const msgs = await Session.messages({ sessionID: sid })
+ expect(msgs.length).toBe(1)
+ // Only the first part should remain (before the revert partID)
+ expect(msgs[0].parts.length).toBe(1)
+ expect(msgs[0].parts[0].id).toBe(p1.id)
+
+ const cleared = await Session.get(sid)
+ expect(cleared.revert).toBeUndefined()
+ },
+ })
+ })
+
+ test("cleanup removes messages after revert point but keeps earlier ones", async () => {
+ await using tmp = await tmpdir({ git: true })
+ await Instance.provide({
+ directory: tmp.path,
+ fn: async () => {
+ const session = await Session.create({})
+ const sid = session.id
+
+ const u1 = await user(sid)
+ await text(sid, u1.id, "hello")
+ const a1 = await assistant(sid, u1.id, tmp.path)
+ await text(sid, a1.id, "hi back")
+
+ const u2 = await user(sid)
+ await text(sid, u2.id, "second question")
+ const a2 = await assistant(sid, u2.id, tmp.path)
+ await text(sid, a2.id, "second answer")
+
+ // Revert from u2 onward
+ await Session.setRevert({
+ sessionID: sid,
+ revert: { messageID: u2.id },
+ summary: { additions: 0, deletions: 0, files: 0 },
+ })
+
+ const info = await Session.get(sid)
+ await SessionRevert.cleanup(info)
+
+ const msgs = await Session.messages({ sessionID: sid })
+ const ids = msgs.map((m) => m.info.id)
+ expect(ids).toContain(u1.id)
+ expect(ids).toContain(a1.id)
+ expect(ids).not.toContain(u2.id)
+ expect(ids).not.toContain(a2.id)
+ },
+ })
+ })
+
+ test("cleanup is a no-op when session has no revert state", async () => {
+ await using tmp = await tmpdir({ git: true })
+ await Instance.provide({
+ directory: tmp.path,
+ fn: async () => {
+ const session = await Session.create({})
+ const sid = session.id
+
+ const u1 = await user(sid)
+ await text(sid, u1.id, "hello")
+
+ const info = await Session.get(sid)
+ expect(info.revert).toBeUndefined()
+ await SessionRevert.cleanup(info)
+
+ const msgs = await Session.messages({ sessionID: sid })
+ expect(msgs.length).toBe(1)
+ },
+ })
+ })
})