summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDax Raad <[email protected]>2025-10-06 19:37:30 -0400
committerDax Raad <[email protected]>2025-10-06 19:37:44 -0400
commit10998d62b9f0964926d4da967a21889eefe82a87 (patch)
treec19089af280cd54abcb1b6a980055e26ebef1c05
parentaee240150bb75ba40a069e94e2e707c8bd25ecd7 (diff)
downloadopencode-10998d62b9f0964926d4da967a21889eefe82a87.tar.gz
opencode-10998d62b9f0964926d4da967a21889eefe82a87.zip
core: improve session API reliability with proper input validation
-rw-r--r--packages/opencode/src/cli/cmd/run.ts2
-rw-r--r--packages/opencode/src/server/server.ts34
-rw-r--r--packages/opencode/src/session/compaction.ts2
-rw-r--r--packages/opencode/src/session/index.ts206
-rw-r--r--packages/opencode/src/session/prompt.ts6
-rw-r--r--packages/opencode/src/tool/task.ts7
6 files changed, 139 insertions, 118 deletions
diff --git a/packages/opencode/src/cli/cmd/run.ts b/packages/opencode/src/cli/cmd/run.ts
index e04ed8103..7d0e68dee 100644
--- a/packages/opencode/src/cli/cmd/run.ts
+++ b/packages/opencode/src/cli/cmd/run.ts
@@ -106,7 +106,7 @@ export const RunCommand = cmd({
if (args.session) return Session.get(args.session)
- return Session.create()
+ return Session.create({})
})()
if (!session) {
diff --git a/packages/opencode/src/server/server.ts b/packages/opencode/src/server/server.ts
index cba186dd9..26cbb5d71 100644
--- a/packages/opencode/src/server/server.ts
+++ b/packages/opencode/src/server/server.ts
@@ -31,7 +31,6 @@ import { SessionRevert } from "../session/revert"
import { lazy } from "../util/lazy"
import { Todo } from "../session/todo"
import { InstanceBootstrap } from "../project/bootstrap"
-import { Identifier } from "@/id/id"
const ERRORS = {
400: {
@@ -308,7 +307,7 @@ export namespace Server {
validator(
"param",
z.object({
- id: z.string(),
+ id: Session.get.schema,
}),
),
async (c) => {
@@ -336,7 +335,7 @@ export namespace Server {
validator(
"param",
z.object({
- id: z.string(),
+ id: Session.children.schema,
}),
),
async (c) => {
@@ -390,18 +389,10 @@ export namespace Server {
},
},
}),
- validator(
- "json",
- z
- .object({
- parentID: z.string().optional(),
- title: z.string().optional(),
- })
- .optional(),
- ),
+ validator("json", Session.create.schema.optional()),
async (c) => {
const body = c.req.valid("json") ?? {}
- const session = await Session.create(body.parentID, body.title)
+ const session = await Session.create(body)
return c.json(session)
},
)
@@ -424,7 +415,7 @@ export namespace Server {
validator(
"param",
z.object({
- id: z.string(),
+ id: Session.remove.schema,
}),
),
async (c) => {
@@ -495,14 +486,7 @@ export namespace Server {
id: z.string().meta({ description: "Session ID" }),
}),
),
- validator(
- "json",
- z.object({
- messageID: z.string(),
- providerID: z.string(),
- modelID: z.string(),
- }),
- ),
+ validator("json", Session.initialize.schema.omit({ sessionID: true })),
async (c) => {
const sessionID = c.req.valid("param").id
const body = c.req.valid("json")
@@ -529,7 +513,7 @@ export namespace Server {
validator(
"param",
z.object({
- id: Identifier.schema("session").meta({ description: "Session ID" }),
+ id: Session.fork.schema.shape.sessionID,
}),
),
validator("json", Session.fork.schema.omit({ sessionID: true })),
@@ -614,7 +598,7 @@ export namespace Server {
validator(
"param",
z.object({
- id: z.string(),
+ id: Session.unshare.schema,
}),
),
async (c) => {
@@ -717,7 +701,7 @@ export namespace Server {
),
async (c) => {
const params = c.req.valid("param")
- const message = await Session.getMessage(params.id, params.messageID)
+ const message = await Session.getMessage({ sessionID: params.id, messageID: params.messageID })
return c.json(message)
},
)
diff --git a/packages/opencode/src/session/compaction.ts b/packages/opencode/src/session/compaction.ts
index e9b120c96..9282d8243 100644
--- a/packages/opencode/src/session/compaction.ts
+++ b/packages/opencode/src/session/compaction.ts
@@ -144,7 +144,7 @@ export namespace SessionCompaction {
},
],
})
- const usage = Session.getUsage(model.info, generated.usage, generated.providerMetadata)
+ const usage = Session.getUsage({ model: model.info, usage: generated.usage, metadata: generated.providerMetadata })
msg.cost += usage.cost
msg.tokens = usage.tokens
msg.summary = true
diff --git a/packages/opencode/src/session/index.ts b/packages/opencode/src/session/index.ts
index c8e6d4ad4..521dcfe72 100644
--- a/packages/opencode/src/session/index.ts
+++ b/packages/opencode/src/session/index.ts
@@ -93,13 +93,21 @@ export namespace Session {
),
}
- export async function create(parentID?: string, title?: string) {
- return createNext({
- parentID,
- directory: Instance.directory,
- title,
- })
- }
+ export const create = fn(
+ z
+ .object({
+ parentID: Identifier.schema("session").optional(),
+ title: z.string().optional(),
+ })
+ .optional(),
+ async (input) => {
+ return createNext({
+ parentID: input?.parentID,
+ directory: Instance.directory,
+ title: input?.title,
+ })
+ },
+ )
export const fork = fn(
z.object({
@@ -132,11 +140,11 @@ export namespace Session {
},
)
- export async function touch(sessionID: string) {
+ export const touch = fn(Identifier.schema("session"), async (sessionID) => {
await update(sessionID, (draft) => {
draft.time.updated = Date.now()
})
- }
+ })
export async function createNext(input: { id?: string; title?: string; parentID?: string; directory: string }) {
const result: Info = {
@@ -170,16 +178,16 @@ export namespace Session {
return result
}
- export async function get(id: string) {
+ export const get = fn(Identifier.schema("session"), async (id) => {
const read = await Storage.read<Info>(["session", Instance.project.id, id])
return read as Info
- }
+ })
- export async function getShare(id: string) {
+ export const getShare = fn(Identifier.schema("session"), async (id) => {
return Storage.read<ShareInfo>(["share", id])
- }
+ })
- export async function share(id: string) {
+ export const share = fn(Identifier.schema("session"), async (id) => {
const cfg = await Config.get()
if (cfg.share === "disabled") {
throw new Error("Sharing is disabled in configuration")
@@ -202,9 +210,9 @@ export namespace Session {
}
}
return share
- }
+ })
- export async function unshare(id: string) {
+ export const unshare = fn(Identifier.schema("session"), async (id) => {
const share = await getShare(id)
if (!share) return
await Storage.remove(["share", id])
@@ -212,7 +220,7 @@ export namespace Session {
draft.share = undefined
})
await Share.remove(id, share.secret)
- }
+ })
export async function update(id: string, editor: (session: Info) => void) {
const project = Instance.project
@@ -226,7 +234,7 @@ export namespace Session {
return result
}
- export async function messages(sessionID: string) {
+ export const messages = fn(Identifier.schema("session"), async (sessionID) => {
const result = [] as MessageV2.WithParts[]
for (const p of await Storage.list(["message", sessionID])) {
const read = await Storage.read<MessageV2.Info>(p)
@@ -237,16 +245,22 @@ export namespace Session {
}
result.sort((a, b) => (a.info.id > b.info.id ? 1 : -1))
return result
- }
+ })
- export async function getMessage(sessionID: string, messageID: string) {
- return {
- info: await Storage.read<MessageV2.Info>(["message", sessionID, messageID]),
- parts: await getParts(messageID),
- }
- }
+ export const getMessage = fn(
+ z.object({
+ sessionID: Identifier.schema("session"),
+ messageID: Identifier.schema("message"),
+ }),
+ async (input) => {
+ return {
+ info: await Storage.read<MessageV2.Info>(["message", input.sessionID, input.messageID]),
+ parts: await getParts(input.messageID),
+ }
+ },
+ )
- export async function getParts(messageID: string) {
+ export const getParts = fn(Identifier.schema("message"), async (messageID) => {
const result = [] as MessageV2.Part[]
for (const item of await Storage.list(["part", messageID])) {
const read = await Storage.read<MessageV2.Part>(item)
@@ -254,7 +268,7 @@ export namespace Session {
}
result.sort((a, b) => (a.id > b.id ? 1 : -1))
return result
- }
+ })
export async function* list() {
const project = Instance.project
@@ -263,7 +277,7 @@ export namespace Session {
}
}
- export async function children(parentID: string) {
+ export const children = fn(Identifier.schema("session"), async (parentID) => {
const project = Instance.project
const result = [] as Session.Info[]
for (const item of await Storage.list(["session", project.id])) {
@@ -272,9 +286,9 @@ export namespace Session {
result.push(session)
}
return result
- }
+ })
- export async function remove(sessionID: string) {
+ export const remove = fn(Identifier.schema("session"), async (sessionID) => {
const project = Instance.project
try {
const session = await get(sessionID)
@@ -295,56 +309,69 @@ export namespace Session {
} catch (e) {
log.error(e)
}
- }
+ })
- export async function updateMessage(msg: MessageV2.Info) {
+ export const updateMessage = fn(MessageV2.Info, async (msg) => {
await Storage.write(["message", msg.sessionID, msg.id], msg)
Bus.publish(MessageV2.Event.Updated, {
info: msg,
})
return msg
- }
+ })
- export async function removeMessage(sessionID: string, messageID: string) {
- await Storage.remove(["message", sessionID, messageID])
- Bus.publish(MessageV2.Event.Removed, {
- sessionID,
- messageID,
- })
- return messageID
- }
+ export const removeMessage = fn(
+ z.object({
+ sessionID: Identifier.schema("session"),
+ messageID: Identifier.schema("message"),
+ }),
+ async (input) => {
+ await Storage.remove(["message", input.sessionID, input.messageID])
+ Bus.publish(MessageV2.Event.Removed, {
+ sessionID: input.sessionID,
+ messageID: input.messageID,
+ })
+ return input.messageID
+ },
+ )
- export async function updatePart(part: MessageV2.Part) {
+ export const updatePart = fn(MessageV2.Part, async (part) => {
await Storage.write(["part", part.messageID, part.id], part)
Bus.publish(MessageV2.Event.PartUpdated, {
part,
})
return part
- }
+ })
- export function getUsage(model: ModelsDev.Model, usage: LanguageModelUsage, metadata?: ProviderMetadata) {
- const tokens = {
- input: usage.inputTokens ?? 0,
- output: usage.outputTokens ?? 0,
- reasoning: usage?.reasoningTokens ?? 0,
- cache: {
- write: (metadata?.["anthropic"]?.["cacheCreationInputTokens"] ??
- // @ts-expect-error
- metadata?.["bedrock"]?.["usage"]?.["cacheWriteInputTokens"] ??
- 0) as number,
- read: usage.cachedInputTokens ?? 0,
- },
- }
- return {
- cost: new Decimal(0)
- .add(new Decimal(tokens.input).mul(model.cost?.input ?? 0).div(1_000_000))
- .add(new Decimal(tokens.output).mul(model.cost?.output ?? 0).div(1_000_000))
- .add(new Decimal(tokens.cache.read).mul(model.cost?.cache_read ?? 0).div(1_000_000))
- .add(new Decimal(tokens.cache.write).mul(model.cost?.cache_write ?? 0).div(1_000_000))
- .toNumber(),
- tokens,
- }
- }
+ export const getUsage = fn(
+ z.object({
+ model: z.custom<ModelsDev.Model>(),
+ usage: z.custom<LanguageModelUsage>(),
+ metadata: z.custom<ProviderMetadata>().optional(),
+ }),
+ (input) => {
+ const tokens = {
+ input: input.usage.inputTokens ?? 0,
+ output: input.usage.outputTokens ?? 0,
+ reasoning: input.usage?.reasoningTokens ?? 0,
+ cache: {
+ write: (input.metadata?.["anthropic"]?.["cacheCreationInputTokens"] ??
+ // @ts-expect-error
+ input.metadata?.["bedrock"]?.["usage"]?.["cacheWriteInputTokens"] ??
+ 0) as number,
+ read: input.usage.cachedInputTokens ?? 0,
+ },
+ }
+ return {
+ cost: new Decimal(0)
+ .add(new Decimal(tokens.input).mul(input.model.cost?.input ?? 0).div(1_000_000))
+ .add(new Decimal(tokens.output).mul(input.model.cost?.output ?? 0).div(1_000_000))
+ .add(new Decimal(tokens.cache.read).mul(input.model.cost?.cache_read ?? 0).div(1_000_000))
+ .add(new Decimal(tokens.cache.write).mul(input.model.cost?.cache_write ?? 0).div(1_000_000))
+ .toNumber(),
+ tokens,
+ }
+ },
+ )
export class BusyError extends Error {
constructor(public readonly sessionID: string) {
@@ -352,27 +379,30 @@ export namespace Session {
}
}
- export async function initialize(input: {
- sessionID: string
- modelID: string
- providerID: string
- messageID: string
- }) {
- await SessionPrompt.prompt({
- sessionID: input.sessionID,
- messageID: input.messageID,
- model: {
- providerID: input.providerID,
- modelID: input.modelID,
- },
- parts: [
- {
- id: Identifier.ascending("part"),
- type: "text",
- text: PROMPT_INITIALIZE.replace("${path}", Instance.worktree),
+ export const initialize = fn(
+ z.object({
+ sessionID: Identifier.schema("session"),
+ modelID: z.string(),
+ providerID: z.string(),
+ messageID: Identifier.schema("message"),
+ }),
+ async (input) => {
+ await SessionPrompt.prompt({
+ sessionID: input.sessionID,
+ messageID: input.messageID,
+ model: {
+ providerID: input.providerID,
+ modelID: input.modelID,
},
- ],
- })
- await Project.setInitialized(Instance.project.id)
- }
+ parts: [
+ {
+ id: Identifier.ascending("part"),
+ type: "text",
+ text: PROMPT_INITIALIZE.replace("${path}", Instance.worktree),
+ },
+ ],
+ })
+ await Project.setInitialized(Instance.project.id)
+ },
+ )
}
diff --git a/packages/opencode/src/session/prompt.ts b/packages/opencode/src/session/prompt.ts
index 474843dd9..9ba06f010 100644
--- a/packages/opencode/src/session/prompt.ts
+++ b/packages/opencode/src/session/prompt.ts
@@ -1031,7 +1031,11 @@ export namespace SessionPrompt {
break
case "finish-step":
- const usage = Session.getUsage(input.model, value.usage, value.providerMetadata)
+ const usage = Session.getUsage({
+ model: input.model,
+ usage: value.usage,
+ metadata: value.providerMetadata,
+ })
assistantMsg.cost += usage.cost
assistantMsg.tokens = usage.tokens
await Session.updatePart({
diff --git a/packages/opencode/src/tool/task.ts b/packages/opencode/src/tool/task.ts
index 5875722f8..302e0cce3 100644
--- a/packages/opencode/src/tool/task.ts
+++ b/packages/opencode/src/tool/task.ts
@@ -26,8 +26,11 @@ export const TaskTool = Tool.define("task", async () => {
async execute(params, ctx) {
const agent = await Agent.get(params.subagent_type)
if (!agent) throw new Error(`Unknown agent type: ${params.subagent_type} is not a valid agent type`)
- const session = await Session.create(ctx.sessionID, params.description + ` (@${agent.name} subagent)`)
- const msg = await Session.getMessage(ctx.sessionID, ctx.messageID)
+ const session = await Session.create({
+ parentID: ctx.sessionID,
+ title: params.description + ` (@${agent.name} subagent)`,
+ })
+ const msg = await Session.getMessage({ sessionID: ctx.sessionID, messageID: ctx.messageID })
if (msg.info.role !== "assistant") throw new Error("Not an assistant message")
const messageID = Identifier.ascending("message")
const parts: Record<string, MessageV2.ToolPart> = {}