summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDax Raad <[email protected]>2025-07-02 13:00:46 -0400
committerDax Raad <[email protected]>2025-07-02 13:10:36 -0400
commit35d6273fb3eb15801676655acb54f354465119f2 (patch)
treea24df60acbf3069f7a0a182daf43a14b89798730
parentb89d4a16fd338285c4d6e3adf9a4f137d9d88b5c (diff)
downloadopencode-35d6273fb3eb15801676655acb54f354465119f2.tar.gz
opencode-35d6273fb3eb15801676655acb54f354465119f2.zip
wip: session revert/unrevert
-rw-r--r--packages/opencode/src/session/index.ts92
-rw-r--r--packages/opencode/src/session/message.ts14
2 files changed, 100 insertions, 6 deletions
diff --git a/packages/opencode/src/session/index.ts b/packages/opencode/src/session/index.ts
index e5dbffacc..71e894f8a 100644
--- a/packages/opencode/src/session/index.ts
+++ b/packages/opencode/src/session/index.ts
@@ -34,6 +34,7 @@ import type { ModelsDev } from "../provider/models"
import { Installation } from "../installation"
import { Config } from "../config/config"
import { ProviderTransform } from "../provider/transform"
+import { Snapshot } from "../snapshot"
export namespace Session {
const log = Log.create({ service: "session" })
@@ -53,6 +54,13 @@ export namespace Session {
created: z.number(),
updated: z.number(),
}),
+ revert: z
+ .object({
+ messageID: z.string(),
+ part: z.number(),
+ snapshot: z.string().optional(),
+ })
+ .optional(),
})
.openapi({
ref: "Session",
@@ -285,6 +293,37 @@ export namespace Session {
l.info("chatting")
const model = await Provider.getModel(input.providerID, input.modelID)
let msgs = await messages(input.sessionID)
+ const session = await get(input.sessionID)
+
+ if (session.revert) {
+ const trimmed = []
+ for (const msg of msgs) {
+ if (
+ msg.id > session.revert.messageID ||
+ (msg.id === session.revert.messageID && session.revert.part === 0)
+ ) {
+ await Storage.remove(
+ "session/message/" + input.sessionID + "/" + msg.id,
+ )
+ await Bus.publish(Message.Event.Removed, {
+ sessionID: input.sessionID,
+ messageID: msg.id,
+ })
+ continue
+ }
+
+ if (msg.id === session.revert.messageID) {
+ if (session.revert.part === 0) break
+ msg.parts = msg.parts.slice(0, session.revert.part)
+ }
+ trimmed.push(msg)
+ }
+ msgs = trimmed
+ await update(input.sessionID, (draft) => {
+ draft.revert = undefined
+ })
+ }
+
const previous = msgs.at(-1)
// auto summarize if too long
@@ -319,7 +358,6 @@ export namespace Session {
if (lastSummary) msgs = msgs.filter((msg) => msg.id >= lastSummary.id)
const app = App.info()
- const session = await get(input.sessionID)
if (msgs.length === 0 && !session.parentID) {
generateText({
maxTokens: input.providerID === "google" ? 1024 : 20,
@@ -349,6 +387,7 @@ export namespace Session {
})
.catch(() => {})
}
+ const snapshot = await Snapshot.create(input.sessionID)
const msg: Message.Info = {
role: "user",
id: Identifier.ascending("message"),
@@ -359,6 +398,7 @@ export namespace Session {
},
sessionID: input.sessionID,
tool: {},
+ snapshot,
},
}
await updateMessage(msg)
@@ -373,6 +413,7 @@ export namespace Session {
role: "assistant",
parts: [],
metadata: {
+ snapshot,
assistant: {
system,
path: {
@@ -424,6 +465,7 @@ export namespace Session {
})
next.metadata!.tool![opts.toolCallId] = {
...result.metadata,
+ snapshot: await Snapshot.create(input.sessionID),
time: {
start,
end: Date.now(),
@@ -436,6 +478,7 @@ export namespace Session {
error: true,
message: e.toString(),
title: e.toString(),
+ snapshot: await Snapshot.create(input.sessionID),
time: {
start,
end: Date.now(),
@@ -457,6 +500,7 @@ export namespace Session {
const result = await execute(args, opts)
next.metadata!.tool![opts.toolCallId] = {
...result.metadata,
+ snapshot: await Snapshot.create(input.sessionID),
time: {
start,
end: Date.now(),
@@ -471,6 +515,7 @@ export namespace Session {
next.metadata!.tool![opts.toolCallId] = {
error: true,
message: e.toString(),
+ snapshot: await Snapshot.create(input.sessionID),
title: "mcp",
time: {
start,
@@ -735,6 +780,51 @@ export namespace Session {
return next
}
+ export async function revert(input: {
+ sessionID: string
+ messageID: string
+ part: number
+ }) {
+ const message = await getMessage(input.sessionID, input.messageID)
+ if (!message) return
+ const part = message.parts[input.part]
+ if (!part) return
+ const session = await get(input.sessionID)
+ const snapshot =
+ session.revert?.snapshot ?? (await Snapshot.create(input.sessionID))
+ const old = (() => {
+ if (message.role === "assistant") {
+ const lastTool = message.parts.findLast(
+ (part, index) =>
+ part.type === "tool-invocation" && index < input.part,
+ )
+ if (lastTool && lastTool.type === "tool-invocation")
+ return message.metadata.tool[lastTool.toolInvocation.toolCallId]
+ .snapshot
+ }
+ return message.metadata.snapshot
+ })()
+ if (old) await Snapshot.restore(input.sessionID, old)
+ await update(input.sessionID, (draft) => {
+ draft.revert = {
+ messageID: input.messageID,
+ part: input.part,
+ snapshot,
+ }
+ })
+ }
+
+ export async function unrevert(sessionID: string) {
+ const session = await get(sessionID)
+ if (!session) return
+ if (!session.revert) return
+ if (session.revert.snapshot)
+ await Snapshot.restore(sessionID, session.revert.snapshot)
+ update(sessionID, (draft) => {
+ draft.revert = undefined
+ })
+ }
+
export async function summarize(input: {
sessionID: string
providerID: string
diff --git a/packages/opencode/src/session/message.ts b/packages/opencode/src/session/message.ts
index b2171fa44..2d319e87b 100644
--- a/packages/opencode/src/session/message.ts
+++ b/packages/opencode/src/session/message.ts
@@ -159,6 +159,7 @@ export namespace Message {
z
.object({
title: z.string(),
+ snapshot: z.string().optional(),
time: z.object({
start: z.number(),
end: z.number(),
@@ -188,11 +189,7 @@ export namespace Message {
}),
})
.optional(),
- user: z
- .object({
- snapshot: z.string().optional(),
- })
- .optional(),
+ snapshot: z.string().optional(),
})
.openapi({ ref: "MessageMetadata" }),
})
@@ -208,6 +205,13 @@ export namespace Message {
info: Info,
}),
),
+ Removed: Bus.event(
+ "message.removed",
+ z.object({
+ sessionID: z.string(),
+ messageID: z.string(),
+ }),
+ ),
PartUpdated: Bus.event(
"message.part.updated",
z.object({