diff options
| author | Adam Malczewski <[email protected]> | 2026-06-04 22:58:19 +0900 |
|---|---|---|
| committer | Adam Malczewski <[email protected]> | 2026-06-04 22:58:19 +0900 |
| commit | ae22da591474d4be7daf16be552ad7437ef1828b (patch) | |
| tree | bb3725770a75e7cef3a17523fc576a63e0342185 /packages/kernel | |
| parent | 75f78873425ada97bc8428e9fe80760ab7be7fc7 (diff) | |
| download | dispatch-ae22da591474d4be7daf16be552ad7437ef1828b.tar.gz dispatch-ae22da591474d4be7daf16be552ad7437ef1828b.zip | |
feat(kernel): runTurn turn loop — tool dispatch policy (eager/semaphore/dedup/concurrencySafe/abort), 16 tests
Diffstat (limited to 'packages/kernel')
| -rw-r--r-- | packages/kernel/src/index.ts | 1 | ||||
| -rw-r--r-- | packages/kernel/src/runtime/dispatch.ts | 124 | ||||
| -rw-r--r-- | packages/kernel/src/runtime/events.ts | 57 | ||||
| -rw-r--r-- | packages/kernel/src/runtime/index.ts | 12 | ||||
| -rw-r--r-- | packages/kernel/src/runtime/run-turn.test.ts | 755 | ||||
| -rw-r--r-- | packages/kernel/src/runtime/run-turn.ts | 270 |
6 files changed, 1219 insertions, 0 deletions
diff --git a/packages/kernel/src/index.ts b/packages/kernel/src/index.ts index ac1a734..064eb11 100644 --- a/packages/kernel/src/index.ts +++ b/packages/kernel/src/index.ts @@ -5,3 +5,4 @@ export * from "./bus/index.js"; export * from "./contracts/index.js"; +export * from "./runtime/index.js"; diff --git a/packages/kernel/src/runtime/dispatch.ts b/packages/kernel/src/runtime/dispatch.ts new file mode 100644 index 0000000..c6c5f8e --- /dev/null +++ b/packages/kernel/src/runtime/dispatch.ts @@ -0,0 +1,124 @@ +import type { ToolDispatchPolicy } from "../contracts/dispatch.js"; +import type { EventEmitter } from "../contracts/runtime.js"; +import type { ToolCall, ToolContract, ToolExecuteContext, ToolResult } from "../contracts/tool.js"; +import { toolOutputEvent } from "./events.js"; + +export interface StepDispatcher { + submit(call: ToolCall): void; + drain(): Promise<Map<string, ToolResult>>; +} + +export async function executeToolCall( + call: ToolCall, + tool: ToolContract | undefined, + signal: AbortSignal, + emit: EventEmitter, + tabId: string, + turnId: string, +): Promise<ToolResult> { + if (tool === undefined) { + return { content: `Unknown tool: ${call.name}`, isError: true }; + } + if (signal.aborted) { + return { content: "Aborted", isError: true }; + } + const ctx: ToolExecuteContext = { + toolCallId: call.id, + signal, + onOutput: (data, stream) => { + emit(toolOutputEvent(tabId, turnId, call.id, data, stream)); + }, + }; + try { + return await tool.execute(call.input, ctx); + } catch (err) { + const message = err instanceof Error ? err.message : String(err); + return { content: `Tool execution error: ${message}`, isError: true }; + } +} + +interface QueueEntry { + readonly call: ToolCall; + readonly tool: ToolContract | undefined; + readonly resolve: (result: ToolResult) => void; +} + +export function createStepDispatcher( + toolMap: Map<string, ToolContract>, + policy: ToolDispatchPolicy, + signal: AbortSignal, + emit: EventEmitter, + tabId: string, + turnId: string, +): StepDispatcher { + let activeCount = 0; + let unsafeRunning = false; + const queue: QueueEntry[] = []; + const allPromises: Array<{ id: string; promise: Promise<ToolResult> }> = []; + const dedupMap = new Map<string, Promise<ToolResult>>(); + + function canStart(isConcurrencySafe: boolean): boolean { + if (unsafeRunning) return false; + if (!isConcurrencySafe && activeCount > 0) return false; + if (policy.maxConcurrent === 0) return true; + return activeCount < policy.maxConcurrent; + } + + function tryStartNext(): void { + while (queue.length > 0) { + const next = queue[0]; + if (next === undefined) break; + const isSafe = next.tool?.concurrencySafe !== false; + if (!canStart(isSafe)) break; + queue.shift(); + activeCount++; + if (!isSafe) unsafeRunning = true; + void runAndResolve(next); + } + } + + async function runAndResolve(entry: QueueEntry): Promise<void> { + const result = await executeToolCall(entry.call, entry.tool, signal, emit, tabId, turnId); + activeCount--; + if (entry.tool?.concurrencySafe === false) unsafeRunning = false; + entry.resolve(result); + tryStartNext(); + } + + function submit(call: ToolCall): void { + const tool = toolMap.get(call.name); + const key = `${call.name}:${JSON.stringify(call.input)}`; + + const existing = dedupMap.get(key); + if (existing !== undefined) { + allPromises.push({ id: call.id, promise: existing }); + return; + } + + const promise = new Promise<ToolResult>((resolve) => { + queue.push({ call, tool, resolve }); + tryStartNext(); + }); + + dedupMap.set(key, promise); + allPromises.push({ id: call.id, promise }); + } + + async function drain(): Promise<Map<string, ToolResult>> { + if (signal.aborted) { + for (const item of queue) { + item.resolve({ content: "Aborted", isError: true }); + } + queue.length = 0; + } + + const results = new Map<string, ToolResult>(); + for (const entry of allPromises) { + const result = await entry.promise; + results.set(entry.id, result); + } + return results; + } + + return { submit, drain }; +} diff --git a/packages/kernel/src/runtime/events.ts b/packages/kernel/src/runtime/events.ts new file mode 100644 index 0000000..62218be --- /dev/null +++ b/packages/kernel/src/runtime/events.ts @@ -0,0 +1,57 @@ +import type { AgentEvent } from "../contracts/events.js"; +import type { Usage } from "../contracts/provider.js"; + +export function textDeltaEvent(tabId: string, turnId: string, delta: string): AgentEvent { + return { type: "text-delta", tabId, turnId, delta }; +} + +export function reasoningDeltaEvent(tabId: string, turnId: string, delta: string): AgentEvent { + return { type: "reasoning-delta", tabId, turnId, delta }; +} + +export function toolCallEvent( + tabId: string, + turnId: string, + toolCallId: string, + toolName: string, + input: unknown, +): AgentEvent { + return { type: "tool-call", tabId, turnId, toolCallId, toolName, input }; +} + +export function toolResultEvent( + tabId: string, + turnId: string, + toolCallId: string, + toolName: string, + content: string, + isError: boolean, +): AgentEvent { + return { type: "tool-result", tabId, turnId, toolCallId, toolName, content, isError }; +} + +export function toolOutputEvent( + tabId: string, + turnId: string, + toolCallId: string, + data: string, + stream: "stdout" | "stderr", +): AgentEvent { + return { type: "tool-output", tabId, turnId, toolCallId, data, stream }; +} + +export function usageEvent(tabId: string, turnId: string, usage: Usage): AgentEvent { + return { type: "usage", tabId, turnId, usage }; +} + +export function errorEvent( + tabId: string, + turnId: string, + message: string, + code?: string, +): AgentEvent { + if (code !== undefined) { + return { type: "error", tabId, turnId, message, code }; + } + return { type: "error", tabId, turnId, message }; +} diff --git a/packages/kernel/src/runtime/index.ts b/packages/kernel/src/runtime/index.ts new file mode 100644 index 0000000..e1156e3 --- /dev/null +++ b/packages/kernel/src/runtime/index.ts @@ -0,0 +1,12 @@ +export type { StepDispatcher } from "./dispatch.js"; +export { createStepDispatcher, executeToolCall } from "./dispatch.js"; +export { + errorEvent, + reasoningDeltaEvent, + textDeltaEvent, + toolCallEvent, + toolOutputEvent, + toolResultEvent, + usageEvent, +} from "./events.js"; +export { MAX_STEPS, runTurn } from "./run-turn.js"; diff --git a/packages/kernel/src/runtime/run-turn.test.ts b/packages/kernel/src/runtime/run-turn.test.ts new file mode 100644 index 0000000..eba05e6 --- /dev/null +++ b/packages/kernel/src/runtime/run-turn.test.ts @@ -0,0 +1,755 @@ +import { describe, expect, it } from "vitest"; +import type { ChatMessage } from "../contracts/conversation.js"; +import type { AgentEvent } from "../contracts/events.js"; +import type { ProviderContract, ProviderEvent } from "../contracts/provider.js"; +import type { ToolContract, ToolExecuteContext, ToolResult } from "../contracts/tool.js"; +import { runTurn } from "./run-turn.js"; + +function delay(ms: number): Promise<void> { + return new Promise((resolve) => { + setTimeout(resolve, ms); + }); +} + +function createFakeProvider(script: ProviderEvent[][]): ProviderContract { + let callIndex = 0; + return { + id: "fake", + stream(_messages, _tools) { + const events = script[callIndex] ?? []; + callIndex++; + return (async function* () { + for (const event of events) { + yield event; + } + })(); + }, + }; +} + +function createFakeTool( + name: string, + handler?: (input: unknown, ctx: ToolExecuteContext) => Promise<ToolResult>, + opts?: { concurrencySafe?: boolean }, +): ToolContract { + return { + name, + description: `Fake tool: ${name}`, + parameters: { type: "object" }, + ...(opts?.concurrencySafe !== undefined ? { concurrencySafe: opts.concurrencySafe } : {}), + execute: handler ?? (async (input) => ({ content: `${name}: ${JSON.stringify(input)}` })), + }; +} + +function createCollectingEmit(): { events: AgentEvent[]; emit: (event: AgentEvent) => void } { + const events: AgentEvent[] = []; + return { events, emit: (event) => events.push(event) }; +} + +const userMessage: ChatMessage = { + role: "user", + chunks: [{ type: "text", text: "hello" }], +}; + +describe("runTurn", () => { + it("text-only turn emits correct events and returns correct result", async () => { + const provider = createFakeProvider([ + [ + { type: "text-delta", delta: "Hello" }, + { type: "text-delta", delta: " world" }, + { type: "reasoning-delta", delta: "thinking..." }, + { type: "usage", usage: { inputTokens: 10, outputTokens: 5 } }, + { type: "finish", reason: "stop" }, + ], + ]); + + const { events, emit } = createCollectingEmit(); + + const result = await runTurn({ + provider, + messages: [userMessage], + tools: [], + dispatch: { maxConcurrent: 1, eager: false }, + emit, + }); + + expect(result.finishReason).toBe("stop"); + expect(result.messages).toHaveLength(1); + expect(result.messages[0]?.role).toBe("assistant"); + + const chunks = result.messages[0]?.chunks ?? []; + expect(chunks).toHaveLength(2); + expect(chunks[0]).toEqual({ type: "text", text: "Hello world" }); + expect(chunks[1]).toEqual({ type: "thinking", text: "thinking..." }); + + expect(result.usage).toEqual({ inputTokens: 10, outputTokens: 5 }); + + const eventTypes = events.map((e) => e.type); + expect(eventTypes).toEqual(["text-delta", "text-delta", "reasoning-delta", "usage"]); + }); + + it("turn with one tool call executes tool, feeds result back, then finishes", async () => { + const tool = createFakeTool("greet", async (input) => ({ + content: `Hello, ${(input as { name: string }).name}!`, + })); + + const provider = createFakeProvider([ + [ + { type: "tool-call", toolCallId: "tc1", toolName: "greet", input: { name: "World" } }, + { type: "finish", reason: "tool-calls" }, + ], + [ + { type: "text-delta", delta: "Done." }, + { type: "finish", reason: "stop" }, + ], + ]); + + const { events, emit } = createCollectingEmit(); + + const result = await runTurn({ + provider, + messages: [userMessage], + tools: [tool], + dispatch: { maxConcurrent: 1, eager: false }, + emit, + }); + + expect(result.finishReason).toBe("stop"); + expect(result.messages).toHaveLength(3); + expect(result.messages[0]?.role).toBe("assistant"); + expect(result.messages[1]?.role).toBe("tool"); + expect(result.messages[2]?.role).toBe("assistant"); + + const toolResultChunk = result.messages[1]?.chunks[0]; + expect(toolResultChunk?.type).toBe("tool-result"); + if (toolResultChunk?.type === "tool-result") { + expect(toolResultChunk.content).toBe("Hello, World!"); + expect(toolResultChunk.toolCallId).toBe("tc1"); + expect(toolResultChunk.isError).toBe(false); + } + + const eventTypes = events.map((e) => e.type); + expect(eventTypes).toContain("tool-call"); + expect(eventTypes).toContain("tool-result"); + expect(eventTypes).toContain("text-delta"); + }); + + it("passes updated messages to subsequent provider calls", async () => { + const capturedMessages: ChatMessage[][] = []; + let callIndex = 0; + const script: ProviderEvent[][] = [ + [ + { type: "tool-call", toolCallId: "tc1", toolName: "echo", input: {} }, + { type: "finish", reason: "tool-calls" }, + ], + [ + { type: "text-delta", delta: "done" }, + { type: "finish", reason: "stop" }, + ], + ]; + + const provider: ProviderContract = { + id: "fake", + stream(messages, _tools) { + capturedMessages.push([...messages]); + const events = script[callIndex] ?? []; + callIndex++; + return (async function* () { + for (const event of events) yield event; + })(); + }, + }; + + const tool = createFakeTool("echo", async () => ({ content: "echoed" })); + + await runTurn({ + provider, + messages: [userMessage], + tools: [tool], + dispatch: { maxConcurrent: 1, eager: false }, + emit: () => {}, + }); + + expect(capturedMessages).toHaveLength(2); + expect(capturedMessages[0] ?? []).toHaveLength(1); + expect(capturedMessages[0]?.[0]?.role).toBe("user"); + + expect(capturedMessages[1] ?? []).toHaveLength(3); + expect(capturedMessages[1]?.[0]?.role).toBe("user"); + expect(capturedMessages[1]?.[1]?.role).toBe("assistant"); + expect(capturedMessages[1]?.[2]?.role).toBe("tool"); + }); + + it("maxConcurrent: 1 runs tools sequentially", async () => { + const log: string[] = []; + + const toolA = createFakeTool("a", async () => { + log.push("a:start"); + await delay(10); + log.push("a:end"); + return { content: "a" }; + }); + + const toolB = createFakeTool("b", async () => { + log.push("b:start"); + await delay(10); + log.push("b:end"); + return { content: "b" }; + }); + + const provider = createFakeProvider([ + [ + { type: "tool-call", toolCallId: "tc1", toolName: "a", input: {} }, + { type: "tool-call", toolCallId: "tc2", toolName: "b", input: {} }, + { type: "finish", reason: "tool-calls" }, + ], + [ + { type: "text-delta", delta: "done" }, + { type: "finish", reason: "stop" }, + ], + ]); + + await runTurn({ + provider, + messages: [userMessage], + tools: [toolA, toolB], + dispatch: { maxConcurrent: 1, eager: false }, + emit: () => {}, + }); + + const aEndIdx = log.indexOf("a:end"); + const bStartIdx = log.indexOf("b:start"); + expect(aEndIdx).toBeLessThan(bStartIdx); + }); + + it("maxConcurrent: 2 runs tools in parallel", async () => { + const log: string[] = []; + + const toolA = createFakeTool("a", async () => { + log.push("a:start"); + await delay(20); + log.push("a:end"); + return { content: "a" }; + }); + + const toolB = createFakeTool("b", async () => { + log.push("b:start"); + await delay(20); + log.push("b:end"); + return { content: "b" }; + }); + + const provider = createFakeProvider([ + [ + { type: "tool-call", toolCallId: "tc1", toolName: "a", input: {} }, + { type: "tool-call", toolCallId: "tc2", toolName: "b", input: {} }, + { type: "finish", reason: "tool-calls" }, + ], + [ + { type: "text-delta", delta: "done" }, + { type: "finish", reason: "stop" }, + ], + ]); + + await runTurn({ + provider, + messages: [userMessage], + tools: [toolA, toolB], + dispatch: { maxConcurrent: 2, eager: false }, + emit: () => {}, + }); + + const aStartIdx = log.indexOf("a:start"); + const bStartIdx = log.indexOf("b:start"); + const aEndIdx = log.indexOf("a:end"); + const bEndIdx = log.indexOf("b:end"); + + expect(aStartIdx).toBeLessThan(aEndIdx); + expect(bStartIdx).toBeLessThan(bEndIdx); + expect(aStartIdx).toBeLessThan(bEndIdx); + expect(bStartIdx).toBeLessThan(aEndIdx); + }); + + it("maxConcurrent: 0 runs all tools in parallel (unlimited)", async () => { + const log: string[] = []; + + const toolA = createFakeTool("a", async () => { + log.push("a:start"); + await delay(20); + log.push("a:end"); + return { content: "a" }; + }); + + const toolB = createFakeTool("b", async () => { + log.push("b:start"); + await delay(20); + log.push("b:end"); + return { content: "b" }; + }); + + const toolC = createFakeTool("c", async () => { + log.push("c:start"); + await delay(20); + log.push("c:end"); + return { content: "c" }; + }); + + const provider = createFakeProvider([ + [ + { type: "tool-call", toolCallId: "tc1", toolName: "a", input: {} }, + { type: "tool-call", toolCallId: "tc2", toolName: "b", input: {} }, + { type: "tool-call", toolCallId: "tc3", toolName: "c", input: {} }, + { type: "finish", reason: "tool-calls" }, + ], + [ + { type: "text-delta", delta: "done" }, + { type: "finish", reason: "stop" }, + ], + ]); + + await runTurn({ + provider, + messages: [userMessage], + tools: [toolA, toolB, toolC], + dispatch: { maxConcurrent: 0, eager: false }, + emit: () => {}, + }); + + const aStartIdx = log.indexOf("a:start"); + const bStartIdx = log.indexOf("b:start"); + const cStartIdx = log.indexOf("c:start"); + const aEndIdx = log.indexOf("a:end"); + const bEndIdx = log.indexOf("b:end"); + const cEndIdx = log.indexOf("c:end"); + + expect(aStartIdx).toBeLessThan(aEndIdx); + expect(bStartIdx).toBeLessThan(bEndIdx); + expect(cStartIdx).toBeLessThan(cEndIdx); + expect(aStartIdx).toBeLessThan(bEndIdx); + expect(bStartIdx).toBeLessThan(aEndIdx); + expect(cStartIdx).toBeLessThan(aEndIdx); + }); + + it("eager: true launches tool before step finish", async () => { + const log: string[] = []; + + const tool = createFakeTool("test", async () => { + log.push("tool:start"); + await delay(5); + log.push("tool:end"); + return { content: "done" }; + }); + + let callCount = 0; + const provider: ProviderContract = { + id: "fake", + stream(_messages, _tools) { + const idx = callCount++; + if (idx === 0) { + return (async function* () { + yield { + type: "tool-call", + toolCallId: "tc1", + toolName: "test", + input: {}, + } as ProviderEvent; + log.push("provider:after-tool-call"); + await delay(50); + yield { type: "finish", reason: "tool-calls" } as ProviderEvent; + log.push("provider:finish"); + })(); + } + return (async function* () { + yield { type: "text-delta", delta: "done" } as ProviderEvent; + yield { type: "finish", reason: "stop" } as ProviderEvent; + })(); + }, + }; + + await runTurn({ + provider, + messages: [userMessage], + tools: [tool], + dispatch: { maxConcurrent: 1, eager: true }, + emit: () => {}, + }); + + const toolStartIdx = log.indexOf("tool:start"); + const finishIdx = log.indexOf("provider:finish"); + expect(toolStartIdx).toBeLessThan(finishIdx); + }); + + it("eager: false does not launch tool before step finish", async () => { + const log: string[] = []; + + const tool = createFakeTool("test", async () => { + log.push("tool:start"); + await delay(5); + log.push("tool:end"); + return { content: "done" }; + }); + + let callCount = 0; + const provider: ProviderContract = { + id: "fake", + stream(_messages, _tools) { + const idx = callCount++; + if (idx === 0) { + return (async function* () { + yield { + type: "tool-call", + toolCallId: "tc1", + toolName: "test", + input: {}, + } as ProviderEvent; + log.push("provider:after-tool-call"); + await delay(50); + yield { type: "finish", reason: "tool-calls" } as ProviderEvent; + log.push("provider:finish"); + })(); + } + return (async function* () { + yield { type: "text-delta", delta: "done" } as ProviderEvent; + yield { type: "finish", reason: "stop" } as ProviderEvent; + })(); + }, + }; + + await runTurn({ + provider, + messages: [userMessage], + tools: [tool], + dispatch: { maxConcurrent: 1, eager: false }, + emit: () => {}, + }); + + const toolStartIdx = log.indexOf("tool:start"); + const finishIdx = log.indexOf("provider:finish"); + expect(toolStartIdx).toBeGreaterThan(finishIdx); + }); + + it("abort mid-turn synthesizes error results for unresolved tool calls", async () => { + const ac = new AbortController(); + + const tool = createFakeTool("slow", async (_input, ctx) => { + await delay(200); + if (ctx.signal.aborted) return { content: "Aborted", isError: true }; + return { content: "done" }; + }); + + const provider: ProviderContract = { + id: "fake", + stream(_messages, _tools) { + return (async function* () { + yield { + type: "tool-call", + toolCallId: "tc1", + toolName: "slow", + input: {}, + } as ProviderEvent; + yield { + type: "tool-call", + toolCallId: "tc2", + toolName: "slow", + input: { x: 1 }, + } as ProviderEvent; + ac.abort(); + await delay(10); + yield { type: "finish", reason: "tool-calls" } as ProviderEvent; + })(); + }, + }; + + const { events, emit } = createCollectingEmit(); + + const result = await runTurn({ + provider, + messages: [userMessage], + tools: [tool], + dispatch: { maxConcurrent: 1, eager: false }, + emit, + signal: ac.signal, + }); + + expect(result.finishReason).toBe("aborted"); + + const toolResults = events.filter((e) => e.type === "tool-result"); + for (const tr of toolResults) { + if (tr.type === "tool-result") { + expect(tr.isError).toBe(true); + } + } + }); + + it("abort before any step returns aborted immediately", async () => { + const ac = new AbortController(); + ac.abort(); + + const provider = createFakeProvider([ + [ + { type: "text-delta", delta: "should not appear" }, + { type: "finish", reason: "stop" }, + ], + ]); + + const result = await runTurn({ + provider, + messages: [userMessage], + tools: [], + dispatch: { maxConcurrent: 1, eager: false }, + emit: () => {}, + signal: ac.signal, + }); + + expect(result.finishReason).toBe("aborted"); + expect(result.messages).toHaveLength(0); + }); + + it("de-duplicates identical tool calls in a batch", async () => { + let execCount = 0; + + const tool = createFakeTool("dedup", async (_input) => { + execCount++; + return { content: `result-${execCount}` }; + }); + + const provider = createFakeProvider([ + [ + { type: "tool-call", toolCallId: "tc1", toolName: "dedup", input: { x: 1 } }, + { type: "tool-call", toolCallId: "tc2", toolName: "dedup", input: { x: 1 } }, + { type: "tool-call", toolCallId: "tc3", toolName: "dedup", input: { x: 2 } }, + { type: "finish", reason: "tool-calls" }, + ], + [ + { type: "text-delta", delta: "done" }, + { type: "finish", reason: "stop" }, + ], + ]); + + const { events, emit } = createCollectingEmit(); + + const result = await runTurn({ + provider, + messages: [userMessage], + tools: [tool], + dispatch: { maxConcurrent: 1, eager: false }, + emit, + }); + + expect(execCount).toBe(2); + + const toolResults = events.filter((e) => e.type === "tool-result"); + expect(toolResults).toHaveLength(3); + + const tc1Result = toolResults.find((e) => e.type === "tool-result" && e.toolCallId === "tc1"); + const tc2Result = toolResults.find((e) => e.type === "tool-result" && e.toolCallId === "tc2"); + const tc3Result = toolResults.find((e) => e.type === "tool-result" && e.toolCallId === "tc3"); + + expect(tc1Result).toBeDefined(); + expect(tc2Result).toBeDefined(); + expect(tc3Result).toBeDefined(); + + if (tc1Result?.type === "tool-result" && tc2Result?.type === "tool-result") { + expect(tc1Result.content).toBe(tc2Result.content); + expect(tc1Result.content).toBe("result-1"); + } + if (tc3Result?.type === "tool-result") { + expect(tc3Result.content).toBe("result-2"); + } + + expect(result.finishReason).toBe("stop"); + }); + + it("serializes non-concurrency-safe tools even with maxConcurrent > 1", async () => { + const log: string[] = []; + + const unsafeTool: ToolContract = { + name: "unsafe", + description: "Unsafe tool", + parameters: { type: "object" }, + concurrencySafe: false, + execute: async () => { + log.push("unsafe:start"); + await delay(10); + log.push("unsafe:end"); + return { content: "done" }; + }, + }; + + const safeTool: ToolContract = { + name: "safe", + description: "Safe tool", + parameters: { type: "object" }, + execute: async () => { + log.push("safe:start"); + await delay(10); + log.push("safe:end"); + return { content: "done" }; + }, + }; + + const provider = createFakeProvider([ + [ + { type: "tool-call", toolCallId: "tc1", toolName: "unsafe", input: {} }, + { type: "tool-call", toolCallId: "tc2", toolName: "safe", input: {} }, + { type: "finish", reason: "tool-calls" }, + ], + [ + { type: "text-delta", delta: "done" }, + { type: "finish", reason: "stop" }, + ], + ]); + + await runTurn({ + provider, + messages: [userMessage], + tools: [unsafeTool, safeTool], + dispatch: { maxConcurrent: 5, eager: false }, + emit: () => {}, + }); + + const unsafeEndIdx = log.indexOf("unsafe:end"); + const safeStartIdx = log.indexOf("safe:start"); + expect(unsafeEndIdx).toBeLessThan(safeStartIdx); + }); + + it("handles unknown tool name gracefully", async () => { + const provider = createFakeProvider([ + [ + { + type: "tool-call", + toolCallId: "tc1", + toolName: "nonexistent", + input: {}, + }, + { type: "finish", reason: "tool-calls" }, + ], + [ + { type: "text-delta", delta: "done" }, + { type: "finish", reason: "stop" }, + ], + ]); + + const { events, emit } = createCollectingEmit(); + + const result = await runTurn({ + provider, + messages: [userMessage], + tools: [], + dispatch: { maxConcurrent: 1, eager: false }, + emit, + }); + + const toolResults = events.filter((e) => e.type === "tool-result"); + expect(toolResults).toHaveLength(1); + if (toolResults[0]?.type === "tool-result") { + expect(toolResults[0]?.isError).toBe(true); + expect(toolResults[0]?.content).toContain("Unknown tool"); + } + + expect(result.finishReason).toBe("stop"); + }); + + it("handles provider error gracefully", async () => { + const provider: ProviderContract = { + id: "fake", + stream() { + return (async function* () { + yield { type: "text-delta", delta: "partial" } as ProviderEvent; + throw new Error("provider crashed"); + })(); + }, + }; + + const { events, emit } = createCollectingEmit(); + + const result = await runTurn({ + provider, + messages: [userMessage], + tools: [], + dispatch: { maxConcurrent: 1, eager: false }, + emit, + }); + + expect(result.finishReason).toBe("error"); + + const errorEvents = events.filter((e) => e.type === "error"); + expect(errorEvents).toHaveLength(1); + if (errorEvents[0]?.type === "error") { + expect(errorEvents[0]?.message).toContain("provider crashed"); + } + }); + + it("aggregates usage across multiple steps", async () => { + const provider = createFakeProvider([ + [ + { type: "tool-call", toolCallId: "tc1", toolName: "echo", input: {} }, + { type: "usage", usage: { inputTokens: 10, outputTokens: 5 } }, + { type: "finish", reason: "tool-calls" }, + ], + [ + { type: "text-delta", delta: "done" }, + { type: "usage", usage: { inputTokens: 20, outputTokens: 10 } }, + { type: "finish", reason: "stop" }, + ], + ]); + + const tool = createFakeTool("echo", async () => ({ content: "echoed" })); + + const result = await runTurn({ + provider, + messages: [userMessage], + tools: [tool], + dispatch: { maxConcurrent: 1, eager: false }, + emit: () => {}, + }); + + expect(result.usage).toEqual({ inputTokens: 30, outputTokens: 15 }); + }); + + it("emits tool-output events from tool ctx.onOutput", async () => { + const tool: ToolContract = { + name: "streaming", + description: "A tool that streams output", + parameters: { type: "object" }, + execute: async (_input, ctx) => { + ctx.onOutput("line 1\n", "stdout"); + ctx.onOutput("err 1\n", "stderr"); + return { content: "done" }; + }, + }; + + const provider = createFakeProvider([ + [ + { type: "tool-call", toolCallId: "tc1", toolName: "streaming", input: {} }, + { type: "finish", reason: "tool-calls" }, + ], + [ + { type: "text-delta", delta: "done" }, + { type: "finish", reason: "stop" }, + ], + ]); + + const { events, emit } = createCollectingEmit(); + + await runTurn({ + provider, + messages: [userMessage], + tools: [tool], + dispatch: { maxConcurrent: 1, eager: false }, + emit, + }); + + const outputs = events.filter((e) => e.type === "tool-output"); + expect(outputs).toHaveLength(2); + if (outputs[0]?.type === "tool-output") { + expect(outputs[0]?.data).toBe("line 1\n"); + expect(outputs[0]?.stream).toBe("stdout"); + expect(outputs[0]?.toolCallId).toBe("tc1"); + } + if (outputs[1]?.type === "tool-output") { + expect(outputs[1]?.data).toBe("err 1\n"); + expect(outputs[1]?.stream).toBe("stderr"); + } + }); +}); diff --git a/packages/kernel/src/runtime/run-turn.ts b/packages/kernel/src/runtime/run-turn.ts new file mode 100644 index 0000000..46b1465 --- /dev/null +++ b/packages/kernel/src/runtime/run-turn.ts @@ -0,0 +1,270 @@ +import type { ChatMessage, Chunk } from "../contracts/conversation.js"; +import type { ProviderContract, ProviderEvent, Usage } from "../contracts/provider.js"; +import type { EventEmitter, RunTurnInput, RunTurnResult } from "../contracts/runtime.js"; +import type { ToolCall, ToolContract } from "../contracts/tool.js"; +import { createStepDispatcher, type StepDispatcher } from "./dispatch.js"; +import { + errorEvent, + reasoningDeltaEvent, + textDeltaEvent, + toolCallEvent, + toolResultEvent, + usageEvent, +} from "./events.js"; + +export const MAX_STEPS = 50; + +function zeroUsage(): Usage { + return { inputTokens: 0, outputTokens: 0 }; +} + +function addUsage(a: Usage, b: Usage): Usage { + const inputTokens = a.inputTokens + b.inputTokens; + const outputTokens = a.outputTokens + b.outputTokens; + + if (a.cacheReadTokens !== undefined || b.cacheReadTokens !== undefined) { + const cacheReadTokens = (a.cacheReadTokens ?? 0) + (b.cacheReadTokens ?? 0); + if (a.cacheWriteTokens !== undefined || b.cacheWriteTokens !== undefined) { + return { + inputTokens, + outputTokens, + cacheReadTokens, + cacheWriteTokens: (a.cacheWriteTokens ?? 0) + (b.cacheWriteTokens ?? 0), + }; + } + return { inputTokens, outputTokens, cacheReadTokens }; + } + + if (a.cacheWriteTokens !== undefined || b.cacheWriteTokens !== undefined) { + return { + inputTokens, + outputTokens, + cacheWriteTokens: (a.cacheWriteTokens ?? 0) + (b.cacheWriteTokens ?? 0), + }; + } + + return { inputTokens, outputTokens }; +} + +function appendTextDelta(chunks: Chunk[], delta: string): void { + const lastIdx = chunks.length - 1; + const last = chunks[lastIdx]; + if (last !== undefined && last.type === "text") { + chunks[lastIdx] = { type: "text", text: last.text + delta }; + } else { + chunks.push({ type: "text", text: delta }); + } +} + +function appendThinkingDelta(chunks: Chunk[], delta: string): void { + const lastIdx = chunks.length - 1; + const last = chunks[lastIdx]; + if (last !== undefined && last.type === "thinking") { + chunks[lastIdx] = { type: "thinking", text: last.text + delta }; + } else { + chunks.push({ type: "thinking", text: delta }); + } +} + +interface StepContext { + readonly provider: ProviderContract; + readonly messages: ChatMessage[]; + readonly tools: readonly ToolContract[]; + readonly toolMap: Map<string, ToolContract>; + readonly dispatch: RunTurnInput["dispatch"]; + readonly emit: EventEmitter; + readonly signal: AbortSignal; + readonly tabId: string; + readonly turnId: string; +} + +interface StepResult { + readonly assistantMessage: ChatMessage | undefined; + readonly toolCalls: ToolCall[]; + readonly toolMessages: ChatMessage[]; + readonly usage: Usage; + readonly finishReason: string; +} + +function processEvent( + event: ProviderEvent, + chunks: Chunk[], + toolCalls: ToolCall[], + dispatcher: StepDispatcher, + ctx: StepContext, +): void { + switch (event.type) { + case "text-delta": + appendTextDelta(chunks, event.delta); + ctx.emit(textDeltaEvent(ctx.tabId, ctx.turnId, event.delta)); + break; + case "reasoning-delta": + appendThinkingDelta(chunks, event.delta); + ctx.emit(reasoningDeltaEvent(ctx.tabId, ctx.turnId, event.delta)); + break; + case "tool-call": { + const call: ToolCall = { + id: event.toolCallId, + name: event.toolName, + input: event.input, + }; + toolCalls.push(call); + chunks.push({ + type: "tool-call", + toolCallId: event.toolCallId, + toolName: event.toolName, + input: event.input, + }); + ctx.emit(toolCallEvent(ctx.tabId, ctx.turnId, event.toolCallId, event.toolName, event.input)); + if (ctx.dispatch.eager) { + dispatcher.submit(call); + } + break; + } + case "usage": + ctx.emit(usageEvent(ctx.tabId, ctx.turnId, event.usage)); + break; + case "finish": + break; + case "error": + if (event.code !== undefined) { + chunks.push({ type: "error", message: event.message, code: event.code }); + } else { + chunks.push({ type: "error", message: event.message }); + } + ctx.emit(errorEvent(ctx.tabId, ctx.turnId, event.message, event.code)); + break; + } +} + +async function executeStep(ctx: StepContext): Promise<StepResult> { + const chunks: Chunk[] = []; + const toolCalls: ToolCall[] = []; + let stepUsage = zeroUsage(); + let finishReason = "stop"; + + const dispatcher = createStepDispatcher( + ctx.toolMap, + ctx.dispatch, + ctx.signal, + ctx.emit, + ctx.tabId, + ctx.turnId, + ); + + try { + const stream = ctx.provider.stream(ctx.messages, ctx.tools); + for await (const event of stream) { + if (ctx.signal.aborted) break; + processEvent(event, chunks, toolCalls, dispatcher, ctx); + if (event.type === "usage") { + stepUsage = addUsage(stepUsage, event.usage); + } + if (event.type === "finish") { + finishReason = event.reason; + } + } + } catch (err) { + const message = err instanceof Error ? err.message : String(err); + chunks.push({ type: "error", message }); + ctx.emit(errorEvent(ctx.tabId, ctx.turnId, message)); + finishReason = "error"; + } + + if (!ctx.dispatch.eager) { + for (const call of toolCalls) { + dispatcher.submit(call); + } + } + + const results = await dispatcher.drain(); + + const toolMessages: ChatMessage[] = []; + for (const call of toolCalls) { + const result = results.get(call.id); + if (result !== undefined) { + const isError = result.isError ?? false; + ctx.emit(toolResultEvent(ctx.tabId, ctx.turnId, call.id, call.name, result.content, isError)); + toolMessages.push({ + role: "tool", + chunks: [ + { + type: "tool-result", + toolCallId: call.id, + toolName: call.name, + content: result.content, + isError, + }, + ], + }); + } + } + + const assistantMessage: ChatMessage | undefined = + chunks.length > 0 ? { role: "assistant", chunks } : undefined; + + return { assistantMessage, toolCalls, toolMessages, usage: stepUsage, finishReason }; +} + +export async function runTurn(input: RunTurnInput): Promise<RunTurnResult> { + const messages: ChatMessage[] = [...input.messages]; + const resultMessages: ChatMessage[] = []; + let totalUsage = zeroUsage(); + let finishReason = "stop"; + + const toolMap = new Map<string, ToolContract>(); + for (const tool of input.tools) { + toolMap.set(tool.name, tool); + } + + const tabId = ""; + const turnId = ""; + const signal = input.signal ?? new AbortController().signal; + + for (let step = 0; step < MAX_STEPS; step++) { + if (signal.aborted) { + finishReason = "aborted"; + break; + } + + const stepResult = await executeStep({ + provider: input.provider, + messages, + tools: input.tools, + toolMap, + dispatch: input.dispatch, + emit: input.emit, + signal, + tabId, + turnId, + }); + + totalUsage = addUsage(totalUsage, stepResult.usage); + + if (stepResult.assistantMessage !== undefined) { + messages.push(stepResult.assistantMessage); + resultMessages.push(stepResult.assistantMessage); + } + + for (const msg of stepResult.toolMessages) { + messages.push(msg); + resultMessages.push(msg); + } + + if (signal.aborted) { + finishReason = "aborted"; + break; + } + + if (stepResult.toolCalls.length === 0) { + finishReason = stepResult.finishReason; + break; + } + + if (step === MAX_STEPS - 1) { + finishReason = "max-steps"; + } + } + + return { messages: resultMessages, usage: totalUsage, finishReason }; +} |
