summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--packages/opencode/src/session/prompt.ts81
-rw-r--r--packages/opencode/src/tool/registry.ts27
-rw-r--r--packages/opencode/src/tool/tool.ts9
-rw-r--r--packages/opencode/test/session/prompt-effect.test.ts147
4 files changed, 185 insertions, 79 deletions
diff --git a/packages/opencode/src/session/prompt.ts b/packages/opencode/src/session/prompt.ts
index 3c9988ea3..a18f9e379 100644
--- a/packages/opencode/src/session/prompt.ts
+++ b/packages/opencode/src/session/prompt.ts
@@ -559,7 +559,7 @@ NOTE: At any point in time through this workflow you should feel free to ask the
}) {
const { task, model, lastUser, sessionID, session, msgs } = input
const ctx = yield* InstanceState.context
- const taskTool = yield* registry.fromID(TaskTool.id)
+ const { task: taskTool } = yield* registry.named()
const taskModel = task.model ? yield* getModel(task.model.providerID, task.model.modelID, sessionID) : model
const assistantMessage: MessageV2.Assistant = yield* sessions.updateMessage({
id: MessageID.ascending(),
@@ -1080,6 +1080,21 @@ NOTE: At any point in time through this workflow you should feel free to ask the
const filepath = fileURLToPath(part.url)
if (yield* fsys.isDir(filepath)) part.mime = "application/x-directory"
+ const { read } = yield* registry.named()
+ const execRead = (args: Parameters<typeof read.execute>[0], extra?: Tool.Context["extra"]) =>
+ Effect.promise((signal: AbortSignal) =>
+ read.execute(args, {
+ sessionID: input.sessionID,
+ abort: signal,
+ agent: input.agent!,
+ messageID: info.id,
+ extra: { bypassCwdCheck: true, ...extra },
+ messages: [],
+ metadata: async () => {},
+ ask: async () => {},
+ }),
+ )
+
if (part.mime === "text/plain") {
let offset: number | undefined
let limit: number | undefined
@@ -1116,29 +1131,12 @@ NOTE: At any point in time through this workflow you should feel free to ask the
text: `Called the Read tool with the following input: ${JSON.stringify(args)}`,
},
]
- const read = yield* registry.fromID("read").pipe(
- Effect.flatMap((t) =>
- provider.getModel(info.model.providerID, info.model.modelID).pipe(
- Effect.flatMap((mdl) =>
- Effect.promise(() =>
- t.execute(args, {
- sessionID: input.sessionID,
- abort: new AbortController().signal,
- agent: input.agent!,
- messageID: info.id,
- extra: { bypassCwdCheck: true, model: mdl },
- messages: [],
- metadata: async () => {},
- ask: async () => {},
- }),
- ),
- ),
- ),
- ),
+ const exit = yield* provider.getModel(info.model.providerID, info.model.modelID).pipe(
+ Effect.flatMap((mdl) => execRead(args, { model: mdl })),
Effect.exit,
)
- if (Exit.isSuccess(read)) {
- const result = read.value
+ if (Exit.isSuccess(exit)) {
+ const result = exit.value
pieces.push({
messageID: info.id,
sessionID: input.sessionID,
@@ -1160,7 +1158,7 @@ NOTE: At any point in time through this workflow you should feel free to ask the
pieces.push({ ...part, messageID: info.id, sessionID: input.sessionID })
}
} else {
- const error = Cause.squash(read.cause)
+ const error = Cause.squash(exit.cause)
log.error("failed to read file", { error })
const message = error instanceof Error ? error.message : String(error)
yield* bus.publish(Session.Event.Error, {
@@ -1180,22 +1178,25 @@ NOTE: At any point in time through this workflow you should feel free to ask the
if (part.mime === "application/x-directory") {
const args = { filePath: filepath }
- const result = yield* registry.fromID("read").pipe(
- Effect.flatMap((t) =>
- Effect.promise(() =>
- t.execute(args, {
- sessionID: input.sessionID,
- abort: new AbortController().signal,
- agent: input.agent!,
- messageID: info.id,
- extra: { bypassCwdCheck: true },
- messages: [],
- metadata: async () => {},
- ask: async () => {},
- }),
- ),
- ),
- )
+ const exit = yield* execRead(args).pipe(Effect.exit)
+ if (Exit.isFailure(exit)) {
+ const error = Cause.squash(exit.cause)
+ log.error("failed to read directory", { error })
+ const message = error instanceof Error ? error.message : String(error)
+ yield* bus.publish(Session.Event.Error, {
+ sessionID: input.sessionID,
+ error: new NamedError.Unknown({ message }).toObject(),
+ })
+ return [
+ {
+ messageID: info.id,
+ sessionID: input.sessionID,
+ type: "text",
+ synthetic: true,
+ text: `Read tool failed to read ${filepath} with the following error: ${message}`,
+ },
+ ]
+ }
return [
{
messageID: info.id,
@@ -1209,7 +1210,7 @@ NOTE: At any point in time through this workflow you should feel free to ask the
sessionID: input.sessionID,
type: "text",
synthetic: true,
- text: result.output,
+ text: exit.value.output,
},
{ ...part, messageID: info.id, sessionID: input.sessionID },
]
diff --git a/packages/opencode/src/tool/registry.ts b/packages/opencode/src/tool/registry.ts
index 63e1a97ea..800c45ced 100644
--- a/packages/opencode/src/tool/registry.ts
+++ b/packages/opencode/src/tool/registry.ts
@@ -42,24 +42,25 @@ import { Agent } from "../agent/agent"
export namespace ToolRegistry {
const log = Log.create({ service: "tool.registry" })
+ type TaskDef = Tool.InferDef<typeof TaskTool>
+ type ReadDef = Tool.InferDef<typeof ReadTool>
+
type State = {
custom: Tool.Def[]
builtin: Tool.Def[]
+ task: TaskDef
+ read: ReadDef
}
export interface Interface {
readonly ids: () => Effect.Effect<string[]>
readonly all: () => Effect.Effect<Tool.Def[]>
- readonly named: {
- task: Tool.Info
- read: Tool.Info
- }
+ readonly named: () => Effect.Effect<{ task: TaskDef; read: ReadDef }>
readonly tools: (model: {
providerID: ProviderID
modelID: ModelID
agent: Agent.Info
}) => Effect.Effect<Tool.Def[]>
- readonly fromID: (id: string) => Effect.Effect<Tool.Def>
}
export class Service extends ServiceMap.Service<Service, Interface>()("@opencode/ToolRegistry") {}
@@ -183,6 +184,8 @@ export namespace ToolRegistry {
...(Flag.OPENCODE_EXPERIMENTAL_LSP_TOOL ? [tool.lsp] : []),
...(Flag.OPENCODE_EXPERIMENTAL_PLAN_MODE && Flag.OPENCODE_CLIENT === "cli" ? [tool.plan] : []),
],
+ task: tool.task,
+ read: tool.read,
}
}),
)
@@ -192,13 +195,6 @@ export namespace ToolRegistry {
return [...s.builtin, ...s.custom] as Tool.Def[]
})
- const fromID: Interface["fromID"] = Effect.fn("ToolRegistry.fromID")(function* (id: string) {
- const tools = yield* all()
- const match = tools.find((tool) => tool.id === id)
- if (!match) return yield* Effect.die(`Tool not found: ${id}`)
- return match
- })
-
const ids: Interface["ids"] = Effect.fn("ToolRegistry.ids")(function* () {
return (yield* all()).map((tool) => tool.id)
})
@@ -245,7 +241,12 @@ export namespace ToolRegistry {
)
})
- return Service.of({ ids, all, named: { task, read }, tools, fromID })
+ const named: Interface["named"] = Effect.fn("ToolRegistry.named")(function* () {
+ const s = yield* InstanceState.get(state)
+ return { task: s.task, read: s.read }
+ })
+
+ return Service.of({ ids, all, named, tools })
}),
)
diff --git a/packages/opencode/src/tool/tool.ts b/packages/opencode/src/tool/tool.ts
index 66e1b8e78..ae347341c 100644
--- a/packages/opencode/src/tool/tool.ts
+++ b/packages/opencode/src/tool/tool.ts
@@ -60,6 +60,13 @@ export namespace Tool {
export type InferMetadata<T> =
T extends Info<any, infer M> ? M : T extends Effect.Effect<Info<any, infer M>, any, any> ? M : never
+ export type InferDef<T> =
+ T extends Info<infer P, infer M>
+ ? Def<P, M>
+ : T extends Effect.Effect<Info<infer P, infer M>, any, any>
+ ? Def<P, M>
+ : never
+
function wrap<Parameters extends z.ZodType, Result extends Metadata>(
id: string,
init: (() => Promise<DefWithoutID<Parameters, Result>>) | DefWithoutID<Parameters, Result>,
@@ -118,7 +125,7 @@ export namespace Tool {
)
}
- export function init(info: Info): Effect.Effect<Def> {
+ export function init<P extends z.ZodType, M extends Metadata>(info: Info<P, M>): Effect.Effect<Def<P, M>> {
return Effect.gen(function* () {
const init = yield* Effect.promise(() => info.init())
return {
diff --git a/packages/opencode/test/session/prompt-effect.test.ts b/packages/opencode/test/session/prompt-effect.test.ts
index 5693e139d..38d7ed9f5 100644
--- a/packages/opencode/test/session/prompt-effect.test.ts
+++ b/packages/opencode/test/session/prompt-effect.test.ts
@@ -631,31 +631,22 @@ it.live(
const ready = defer<void>()
const aborted = defer<void>()
const registry = yield* ToolRegistry.Service
- const init = registry.named.task.init
- registry.named.task.init = async () => ({
- description: "task",
- parameters: z.object({
- description: z.string(),
- prompt: z.string(),
- subagent_type: z.string(),
- task_id: z.string().optional(),
- command: z.string().optional(),
- }),
- execute: async (_args, ctx) => {
- ready.resolve()
- ctx.abort.addEventListener("abort", () => aborted.resolve(), { once: true })
- await new Promise<void>(() => {})
- return {
- title: "",
- metadata: {
- sessionId: SessionID.make("task"),
- model: ref,
- },
- output: "",
- }
- },
- })
- yield* Effect.addFinalizer(() => Effect.sync(() => void (registry.named.task.init = init)))
+ const { task } = yield* registry.named()
+ const original = task.execute
+ task.execute = async (_args, ctx) => {
+ ready.resolve()
+ ctx.abort.addEventListener("abort", () => aborted.resolve(), { once: true })
+ await new Promise<void>(() => {})
+ return {
+ title: "",
+ metadata: {
+ sessionId: SessionID.make("task"),
+ model: ref,
+ },
+ output: "",
+ }
+ }
+ yield* Effect.addFinalizer(() => Effect.sync(() => void (task.execute = original)))
const { prompt, chat } = yield* boot()
const msg = yield* user(chat.id, "hello")
@@ -1240,3 +1231,109 @@ unix(
),
30_000,
)
+
+// Abort signal propagation tests for inline tool execution
+
+/** Override a tool's execute to hang until aborted. Returns ready/aborted defers and a finalizer. */
+function hangUntilAborted(tool: { execute: (...args: any[]) => any }) {
+ const ready = defer<void>()
+ const aborted = defer<void>()
+ const original = tool.execute
+ tool.execute = async (_args: any, ctx: any) => {
+ ready.resolve()
+ ctx.abort.addEventListener("abort", () => aborted.resolve(), { once: true })
+ await new Promise<void>(() => {})
+ return { title: "", metadata: {}, output: "" }
+ }
+ const restore = Effect.addFinalizer(() => Effect.sync(() => void (tool.execute = original)))
+ return { ready, aborted, restore }
+}
+
+it.live(
+ "interrupt propagates abort signal to read tool via file part (text/plain)",
+ () =>
+ provideTmpdirInstance(
+ (dir) =>
+ Effect.gen(function* () {
+ const registry = yield* ToolRegistry.Service
+ const { read } = yield* registry.named()
+ const { ready, aborted, restore } = hangUntilAborted(read)
+ yield* restore
+
+ const prompt = yield* SessionPrompt.Service
+ const sessions = yield* Session.Service
+ const chat = yield* sessions.create({ title: "Abort Test" })
+
+ const testFile = path.join(dir, "test.txt")
+ yield* Effect.promise(() => Bun.write(testFile, "hello world"))
+
+ const fiber = yield* prompt
+ .prompt({
+ sessionID: chat.id,
+ agent: "build",
+ parts: [
+ { type: "text", text: "read this" },
+ { type: "file", url: `file://${testFile}`, filename: "test.txt", mime: "text/plain" },
+ ],
+ })
+ .pipe(Effect.forkChild)
+
+ yield* Effect.promise(() => ready.promise)
+ yield* Fiber.interrupt(fiber)
+
+ yield* Effect.promise(() =>
+ Promise.race([
+ aborted.promise,
+ new Promise<void>((_, reject) =>
+ setTimeout(() => reject(new Error("abort signal not propagated within 2s")), 2_000),
+ ),
+ ]),
+ )
+ }),
+ { git: true, config: cfg },
+ ),
+ 30_000,
+)
+
+it.live(
+ "interrupt propagates abort signal to read tool via file part (directory)",
+ () =>
+ provideTmpdirInstance(
+ (dir) =>
+ Effect.gen(function* () {
+ const registry = yield* ToolRegistry.Service
+ const { read } = yield* registry.named()
+ const { ready, aborted, restore } = hangUntilAborted(read)
+ yield* restore
+
+ const prompt = yield* SessionPrompt.Service
+ const sessions = yield* Session.Service
+ const chat = yield* sessions.create({ title: "Abort Test" })
+
+ const fiber = yield* prompt
+ .prompt({
+ sessionID: chat.id,
+ agent: "build",
+ parts: [
+ { type: "text", text: "read this" },
+ { type: "file", url: `file://${dir}`, filename: "dir", mime: "application/x-directory" },
+ ],
+ })
+ .pipe(Effect.forkChild)
+
+ yield* Effect.promise(() => ready.promise)
+ yield* Fiber.interrupt(fiber)
+
+ yield* Effect.promise(() =>
+ Promise.race([
+ aborted.promise,
+ new Promise<void>((_, reject) =>
+ setTimeout(() => reject(new Error("abort signal not propagated within 2s")), 2_000),
+ ),
+ ]),
+ )
+ }),
+ { git: true, config: cfg },
+ ),
+ 30_000,
+)