summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorKit Langton <[email protected]>2026-03-31 20:07:58 -0400
committeropencode <[email protected]>2026-04-01 00:44:15 +0000
commit181b5f62361a6ce2d0a6b3e0ba266ed50a6dd1ab (patch)
tree643f031e43f2798b616c4c589f643181585a719a
parent6314f09c14fdd6a3ab8bedc4f7b7182647551d12 (diff)
downloadopencode-181b5f62361a6ce2d0a6b3e0ba266ed50a6dd1ab.tar.gz
opencode-181b5f62361a6ce2d0a6b3e0ba266ed50a6dd1ab.zip
refactor(prompt): use Provider service in effect layers (#20167)
-rw-r--r--packages/opencode/src/agent/agent.ts8
-rw-r--r--packages/opencode/src/provider/provider.ts7
-rw-r--r--packages/opencode/src/session/compaction.ts18
-rw-r--r--packages/opencode/src/session/prompt.ts78
-rw-r--r--packages/opencode/test/fake/provider.ts81
-rw-r--r--packages/opencode/test/session/compaction.test.ts44
-rw-r--r--packages/opencode/test/session/prompt-concurrency.test.ts247
-rw-r--r--packages/opencode/test/session/prompt-effect.test.ts2
8 files changed, 163 insertions, 322 deletions
diff --git a/packages/opencode/src/agent/agent.ts b/packages/opencode/src/agent/agent.ts
index 96b71f816..0c6fe6ec9 100644
--- a/packages/opencode/src/agent/agent.ts
+++ b/packages/opencode/src/agent/agent.ts
@@ -75,6 +75,7 @@ export namespace Agent {
const config = yield* Config.Service
const auth = yield* Auth.Service
const skill = yield* Skill.Service
+ const provider = yield* Provider.Service
const state = yield* InstanceState.make<State>(
Effect.fn("Agent.state")(function* (ctx) {
@@ -330,9 +331,9 @@ export namespace Agent {
model?: { providerID: ProviderID; modelID: ModelID }
}) {
const cfg = yield* config.get()
- const model = input.model ?? (yield* Effect.promise(() => Provider.defaultModel()))
- const resolved = yield* Effect.promise(() => Provider.getModel(model.providerID, model.modelID))
- const language = yield* Effect.promise(() => Provider.getLanguage(resolved))
+ const model = input.model ?? (yield* provider.defaultModel())
+ const resolved = yield* provider.getModel(model.providerID, model.modelID)
+ const language = yield* provider.getLanguage(resolved)
const system = [PROMPT_GENERATE]
yield* Effect.promise(() =>
@@ -393,6 +394,7 @@ export namespace Agent {
)
export const defaultLayer = layer.pipe(
+ Layer.provide(Provider.defaultLayer),
Layer.provide(Auth.defaultLayer),
Layer.provide(Config.defaultLayer),
Layer.provide(Skill.defaultLayer),
diff --git a/packages/opencode/src/provider/provider.ts b/packages/opencode/src/provider/provider.ts
index c6784f450..b2f7d848d 100644
--- a/packages/opencode/src/provider/provider.ts
+++ b/packages/opencode/src/provider/provider.ts
@@ -1541,10 +1541,9 @@ export namespace Provider {
}),
)
- const { runPromise } = makeRuntime(
- Service,
- layer.pipe(Layer.provide(Config.defaultLayer), Layer.provide(Auth.defaultLayer)),
- )
+ export const defaultLayer = layer.pipe(Layer.provide(Config.defaultLayer), Layer.provide(Auth.defaultLayer))
+
+ const { runPromise } = makeRuntime(Service, defaultLayer)
export async function list() {
return runPromise((svc) => svc.list())
diff --git a/packages/opencode/src/session/compaction.ts b/packages/opencode/src/session/compaction.ts
index 02a8d9484..999a37b12 100644
--- a/packages/opencode/src/session/compaction.ts
+++ b/packages/opencode/src/session/compaction.ts
@@ -63,7 +63,13 @@ export namespace SessionCompaction {
export const layer: Layer.Layer<
Service,
never,
- Bus.Service | Config.Service | Session.Service | Agent.Service | Plugin.Service | SessionProcessor.Service
+ | Bus.Service
+ | Config.Service
+ | Session.Service
+ | Agent.Service
+ | Plugin.Service
+ | SessionProcessor.Service
+ | Provider.Service
> = Layer.effect(
Service,
Effect.gen(function* () {
@@ -73,6 +79,7 @@ export namespace SessionCompaction {
const agents = yield* Agent.Service
const plugin = yield* Plugin.Service
const processors = yield* SessionProcessor.Service
+ const provider = yield* Provider.Service
const isOverflow = Effect.fn("SessionCompaction.isOverflow")(function* (input: {
tokens: MessageV2.Assistant["tokens"]
@@ -170,11 +177,9 @@ export namespace SessionCompaction {
}
const agent = yield* agents.get("compaction")
- const model = yield* Effect.promise(() =>
- agent.model
- ? Provider.getModel(agent.model.providerID, agent.model.modelID)
- : Provider.getModel(userMessage.model.providerID, userMessage.model.modelID),
- )
+ const model = agent.model
+ ? yield* provider.getModel(agent.model.providerID, agent.model.modelID)
+ : yield* provider.getModel(userMessage.model.providerID, userMessage.model.modelID)
// Allow plugins to inject context or replace compaction prompt.
const compacting = yield* plugin.trigger(
"experimental.session.compacting",
@@ -377,6 +382,7 @@ When constructing the summary, try to stick to this template:
export const defaultLayer = Layer.unwrap(
Effect.sync(() =>
layer.pipe(
+ Layer.provide(Provider.defaultLayer),
Layer.provide(Session.defaultLayer),
Layer.provide(SessionProcessor.defaultLayer),
Layer.provide(Agent.defaultLayer),
diff --git a/packages/opencode/src/session/prompt.ts b/packages/opencode/src/session/prompt.ts
index 083c23cc6..78f4fae52 100644
--- a/packages/opencode/src/session/prompt.ts
+++ b/packages/opencode/src/session/prompt.ts
@@ -84,6 +84,7 @@ export namespace SessionPrompt {
const status = yield* SessionStatus.Service
const sessions = yield* Session.Service
const agents = yield* Agent.Service
+ const provider = yield* Provider.Service
const processor = yield* SessionProcessor.Service
const compaction = yield* SessionCompaction.Service
const plugin = yield* Plugin.Service
@@ -206,14 +207,14 @@ export namespace SessionPrompt {
const ag = yield* agents.get("title")
if (!ag) return
+ const mdl = ag.model
+ ? yield* provider.getModel(ag.model.providerID, ag.model.modelID)
+ : ((yield* provider.getSmallModel(input.providerID)) ??
+ (yield* provider.getModel(input.providerID, input.modelID)))
+ const msgs = onlySubtasks
+ ? [{ role: "user" as const, content: subtasks.map((p) => p.prompt).join("\n") }]
+ : yield* Effect.promise(() => MessageV2.toModelMessages(context, mdl))
const text = yield* Effect.promise(async (signal) => {
- const mdl = ag.model
- ? await Provider.getModel(ag.model.providerID, ag.model.modelID)
- : ((await Provider.getSmallModel(input.providerID)) ??
- (await Provider.getModel(input.providerID, input.modelID)))
- const msgs = onlySubtasks
- ? [{ role: "user" as const, content: subtasks.map((p) => p.prompt).join("\n") }]
- : await MessageV2.toModelMessages(context, mdl)
const result = await LLM.stream({
agent: ag,
user: firstInfo,
@@ -932,21 +933,35 @@ NOTE: At any point in time through this workflow you should feel free to ask the
return { info: msg, parts: [part] }
})
- const getModel = (providerID: ProviderID, modelID: ModelID, sessionID: SessionID) =>
- Effect.promise(() =>
- Provider.getModel(providerID, modelID).catch((e) => {
- if (Provider.ModelNotFoundError.isInstance(e)) {
- const hint = e.data.suggestions?.length ? ` Did you mean: ${e.data.suggestions.join(", ")}?` : ""
- Bus.publish(Session.Event.Error, {
- sessionID,
- error: new NamedError.Unknown({
- message: `Model not found: ${e.data.providerID}/${e.data.modelID}.${hint}`,
- }).toObject(),
- })
- }
- throw e
- }),
- )
+ const getModel = Effect.fn("SessionPrompt.getModel")(function* (
+ providerID: ProviderID,
+ modelID: ModelID,
+ sessionID: SessionID,
+ ) {
+ const exit = yield* provider.getModel(providerID, modelID).pipe(Effect.exit)
+ if (Exit.isSuccess(exit)) return exit.value
+ const err = Cause.squash(exit.cause)
+ if (Provider.ModelNotFoundError.isInstance(err)) {
+ const hint = err.data.suggestions?.length ? ` Did you mean: ${err.data.suggestions.join(", ")}?` : ""
+ yield* bus.publish(Session.Event.Error, {
+ sessionID,
+ error: new NamedError.Unknown({
+ message: `Model not found: ${err.data.providerID}/${err.data.modelID}.${hint}`,
+ }).toObject(),
+ })
+ }
+ return yield* Effect.failCause(exit.cause)
+ })
+
+ const lastModel = Effect.fnUntraced(function* (sessionID: SessionID) {
+ const model = yield* Effect.promise(async () => {
+ for await (const item of MessageV2.stream(sessionID)) {
+ if (item.info.role === "user" && item.info.model) return item.info.model
+ }
+ })
+ if (model) return model
+ return yield* provider.defaultModel()
+ })
const createUserMessage = Effect.fn("SessionPrompt.createUserMessage")(function* (input: PromptInput) {
const agentName = input.agent || (yield* agents.defaultAgent())
@@ -960,9 +975,12 @@ NOTE: At any point in time through this workflow you should feel free to ask the
}
const model = input.model ?? ag.model ?? (yield* lastModel(input.sessionID))
+ const same = ag.model && model.providerID === ag.model.providerID && model.modelID === ag.model.modelID
const full =
- !input.variant && ag.variant
- ? yield* Effect.promise(() => Provider.getModel(model.providerID, model.modelID).catch(() => undefined))
+ !input.variant && ag.variant && same
+ ? yield* provider
+ .getModel(model.providerID, model.modelID)
+ .pipe(Effect.catch(() => Effect.succeed(undefined)))
: undefined
const variant = input.variant ?? (ag.variant && full?.variants?.[ag.variant] ? ag.variant : undefined)
@@ -1109,7 +1127,7 @@ NOTE: At any point in time through this workflow you should feel free to ask the
]
const read = yield* Effect.promise(() => ReadTool.init()).pipe(
Effect.flatMap((t) =>
- Effect.promise(() => Provider.getModel(info.model.providerID, info.model.modelID)).pipe(
+ provider.getModel(info.model.providerID, info.model.modelID).pipe(
Effect.flatMap((mdl) =>
Effect.promise(() =>
t.execute(args, {
@@ -1711,6 +1729,7 @@ NOTE: At any point in time through this workflow you should feel free to ask the
Layer.provide(FileTime.defaultLayer),
Layer.provide(ToolRegistry.defaultLayer),
Layer.provide(Truncate.layer),
+ Layer.provide(Provider.defaultLayer),
Layer.provide(AppFileSystem.defaultLayer),
Layer.provide(Plugin.defaultLayer),
Layer.provide(Session.defaultLayer),
@@ -1856,15 +1875,6 @@ NOTE: At any point in time through this workflow you should feel free to ask the
return runPromise((svc) => svc.command(CommandInput.parse(input)))
}
- const lastModel = Effect.fnUntraced(function* (sessionID: SessionID) {
- return yield* Effect.promise(async () => {
- for await (const item of MessageV2.stream(sessionID)) {
- if (item.info.role === "user" && item.info.model) return item.info.model
- }
- return Provider.defaultModel()
- })
- })
-
/** @internal Exported for testing */
export function createStructuredOutputTool(input: {
schema: Record<string, any>
diff --git a/packages/opencode/test/fake/provider.ts b/packages/opencode/test/fake/provider.ts
new file mode 100644
index 000000000..b6f72f53d
--- /dev/null
+++ b/packages/opencode/test/fake/provider.ts
@@ -0,0 +1,81 @@
+import { Effect, Layer } from "effect"
+import { Provider } from "../../src/provider/provider"
+import { ModelID, ProviderID } from "../../src/provider/schema"
+
+export namespace ProviderTest {
+ export function model(override: Partial<Provider.Model> = {}): Provider.Model {
+ const id = override.id ?? ModelID.make("gpt-5.2")
+ const providerID = override.providerID ?? ProviderID.make("openai")
+ return {
+ id,
+ providerID,
+ name: "Test Model",
+ capabilities: {
+ toolcall: true,
+ attachment: false,
+ reasoning: false,
+ temperature: true,
+ interleaved: false,
+ input: { text: true, image: false, audio: false, video: false, pdf: false },
+ output: { text: true, image: false, audio: false, video: false, pdf: false },
+ },
+ api: { id, url: "https://example.com", npm: "@ai-sdk/openai" },
+ cost: { input: 0, output: 0, cache: { read: 0, write: 0 } },
+ limit: { context: 200_000, output: 10_000 },
+ status: "active",
+ options: {},
+ headers: {},
+ release_date: "2025-01-01",
+ ...override,
+ }
+ }
+
+ export function info(override: Partial<Provider.Info> = {}, mdl = model()): Provider.Info {
+ const id = override.id ?? mdl.providerID
+ return {
+ id,
+ name: "Test Provider",
+ source: "config",
+ env: [],
+ options: {},
+ models: { [mdl.id]: mdl },
+ ...override,
+ }
+ }
+
+ export function fake(override: Partial<Provider.Interface> & { model?: Provider.Model; info?: Provider.Info } = {}) {
+ const mdl = override.model ?? model()
+ const row = override.info ?? info({}, mdl)
+ return {
+ model: mdl,
+ info: row,
+ layer: Layer.succeed(
+ Provider.Service,
+ Provider.Service.of({
+ list: Effect.fn("TestProvider.list")(() => Effect.succeed({ [row.id]: row })),
+ getProvider: Effect.fn("TestProvider.getProvider")((providerID) => {
+ if (providerID === row.id) return Effect.succeed(row)
+ return Effect.die(new Error(`Unknown test provider: ${providerID}`))
+ }),
+ getModel: Effect.fn("TestProvider.getModel")((providerID, modelID) => {
+ if (providerID === row.id && modelID === mdl.id) return Effect.succeed(mdl)
+ return Effect.die(new Error(`Unknown test model: ${providerID}/${modelID}`))
+ }),
+ getLanguage: Effect.fn("TestProvider.getLanguage")(() =>
+ Effect.die(new Error("ProviderTest.getLanguage not configured")),
+ ),
+ closest: Effect.fn("TestProvider.closest")((providerID) =>
+ Effect.succeed(providerID === row.id ? { providerID: row.id, modelID: mdl.id } : undefined),
+ ),
+ getSmallModel: Effect.fn("TestProvider.getSmallModel")((providerID) =>
+ Effect.succeed(providerID === row.id ? mdl : undefined),
+ ),
+ defaultModel: Effect.fn("TestProvider.defaultModel")(() =>
+ Effect.succeed({ providerID: row.id, modelID: mdl.id }),
+ ),
+ ...override,
+ }),
+ ),
+ }
+ }
+}
diff --git a/packages/opencode/test/session/compaction.test.ts b/packages/opencode/test/session/compaction.test.ts
index e6d715728..f1d61babf 100644
--- a/packages/opencode/test/session/compaction.test.ts
+++ b/packages/opencode/test/session/compaction.test.ts
@@ -1,4 +1,4 @@
-import { afterEach, describe, expect, mock, spyOn, test } from "bun:test"
+import { afterEach, describe, expect, mock, test } from "bun:test"
import { APICallError } from "ai"
import { Cause, Effect, Exit, Layer, ManagedRuntime } from "effect"
import * as Stream from "effect/Stream"
@@ -20,9 +20,9 @@ import { MessageID, PartID, SessionID } from "../../src/session/schema"
import { SessionStatus } from "../../src/session/status"
import { ModelID, ProviderID } from "../../src/provider/schema"
import type { Provider } from "../../src/provider/provider"
-import * as ProviderModule from "../../src/provider/provider"
import * as SessionProcessorModule from "../../src/session/processor"
import { Snapshot } from "../../src/snapshot"
+import { ProviderTest } from "../fake/provider"
Log.init({ print: false })
@@ -65,6 +65,8 @@ function createModel(opts: {
} as Provider.Model
}
+const wide = () => ProviderTest.fake({ model: createModel({ context: 100_000, output: 32_000 }) })
+
async function user(sessionID: SessionID, text: string) {
const msg = await Session.updateMessage({
id: MessageID.ascending(),
@@ -162,10 +164,11 @@ function layer(result: "continue" | "compact") {
)
}
-function runtime(result: "continue" | "compact", plugin = Plugin.defaultLayer) {
+function runtime(result: "continue" | "compact", plugin = Plugin.defaultLayer, provider = ProviderTest.fake()) {
const bus = Bus.layer
return ManagedRuntime.make(
Layer.mergeAll(SessionCompaction.layer, bus).pipe(
+ Layer.provide(provider.layer),
Layer.provide(Session.defaultLayer),
Layer.provide(layer(result)),
Layer.provide(Agent.defaultLayer),
@@ -198,12 +201,13 @@ function llm() {
}
}
-function liveRuntime(layer: Layer.Layer<LLM.Service>) {
+function liveRuntime(layer: Layer.Layer<LLM.Service>, provider = ProviderTest.fake()) {
const bus = Bus.layer
const status = SessionStatus.layer.pipe(Layer.provide(bus))
const processor = SessionProcessorModule.SessionProcessor.layer
return ManagedRuntime.make(
Layer.mergeAll(SessionCompaction.layer.pipe(Layer.provide(processor)), processor, bus, status).pipe(
+ Layer.provide(provider.layer),
Layer.provide(Session.defaultLayer),
Layer.provide(Snapshot.defaultLayer),
Layer.provide(layer),
@@ -544,14 +548,12 @@ describe("session.compaction.process", () => {
await Instance.provide({
directory: tmp.path,
fn: async () => {
- spyOn(ProviderModule.Provider, "getModel").mockResolvedValue(createModel({ context: 100_000, output: 32_000 }))
-
const session = await Session.create({})
const msg = await user(session.id, "hello")
const msgs = await Session.messages({ sessionID: session.id })
const done = defer()
let seen = false
- const rt = runtime("continue")
+ const rt = runtime("continue", Plugin.defaultLayer, wide())
let unsub: (() => void) | undefined
try {
unsub = await rt.runPromise(
@@ -596,11 +598,9 @@ describe("session.compaction.process", () => {
await Instance.provide({
directory: tmp.path,
fn: async () => {
- spyOn(ProviderModule.Provider, "getModel").mockResolvedValue(createModel({ context: 100_000, output: 32_000 }))
-
const session = await Session.create({})
const msg = await user(session.id, "hello")
- const rt = runtime("compact")
+ const rt = runtime("compact", Plugin.defaultLayer, wide())
try {
const msgs = await Session.messages({ sessionID: session.id })
const result = await rt.runPromise(
@@ -636,11 +636,9 @@ describe("session.compaction.process", () => {
await Instance.provide({
directory: tmp.path,
fn: async () => {
- spyOn(ProviderModule.Provider, "getModel").mockResolvedValue(createModel({ context: 100_000, output: 32_000 }))
-
const session = await Session.create({})
const msg = await user(session.id, "hello")
- const rt = runtime("continue")
+ const rt = runtime("continue", Plugin.defaultLayer, wide())
try {
const msgs = await Session.messages({ sessionID: session.id })
const result = await rt.runPromise(
@@ -678,8 +676,6 @@ describe("session.compaction.process", () => {
await Instance.provide({
directory: tmp.path,
fn: async () => {
- spyOn(ProviderModule.Provider, "getModel").mockResolvedValue(createModel({ context: 100_000, output: 32_000 }))
-
const session = await Session.create({})
await user(session.id, "root")
const replay = await user(session.id, "image")
@@ -693,7 +689,7 @@ describe("session.compaction.process", () => {
url: "https://example.com/cat.png",
})
const msg = await user(session.id, "current")
- const rt = runtime("continue")
+ const rt = runtime("continue", Plugin.defaultLayer, wide())
try {
const msgs = await Session.messages({ sessionID: session.id })
const result = await rt.runPromise(
@@ -728,13 +724,11 @@ describe("session.compaction.process", () => {
await Instance.provide({
directory: tmp.path,
fn: async () => {
- spyOn(ProviderModule.Provider, "getModel").mockResolvedValue(createModel({ context: 100_000, output: 32_000 }))
-
const session = await Session.create({})
await user(session.id, "earlier")
const msg = await user(session.id, "current")
- const rt = runtime("continue")
+ const rt = runtime("continue", Plugin.defaultLayer, wide())
try {
const msgs = await Session.messages({ sessionID: session.id })
const result = await rt.runPromise(
@@ -790,13 +784,11 @@ describe("session.compaction.process", () => {
await Instance.provide({
directory: tmp.path,
fn: async () => {
- spyOn(ProviderModule.Provider, "getModel").mockResolvedValue(createModel({ context: 100_000, output: 32_000 }))
-
const session = await Session.create({})
const msg = await user(session.id, "hello")
const msgs = await Session.messages({ sessionID: session.id })
const abort = new AbortController()
- const rt = liveRuntime(stub.layer)
+ const rt = liveRuntime(stub.layer, wide())
let off: (() => void) | undefined
let run: Promise<"continue" | "stop"> | undefined
try {
@@ -866,13 +858,11 @@ describe("session.compaction.process", () => {
await Instance.provide({
directory: tmp.path,
fn: async () => {
- spyOn(ProviderModule.Provider, "getModel").mockResolvedValue(createModel({ context: 100_000, output: 32_000 }))
-
const session = await Session.create({})
const msg = await user(session.id, "hello")
const msgs = await Session.messages({ sessionID: session.id })
const abort = new AbortController()
- const rt = runtime("continue", plugin(ready))
+ const rt = runtime("continue", plugin(ready), wide())
let run: Promise<"continue" | "stop"> | undefined
try {
run = rt
@@ -970,11 +960,9 @@ describe("session.compaction.process", () => {
await Instance.provide({
directory: tmp.path,
fn: async () => {
- spyOn(ProviderModule.Provider, "getModel").mockResolvedValue(createModel({ context: 100_000, output: 32_000 }))
-
const session = await Session.create({})
const msg = await user(session.id, "hello")
- const rt = liveRuntime(stub.layer)
+ const rt = liveRuntime(stub.layer, wide())
try {
const msgs = await Session.messages({ sessionID: session.id })
await rt.runPromise(
diff --git a/packages/opencode/test/session/prompt-concurrency.test.ts b/packages/opencode/test/session/prompt-concurrency.test.ts
deleted file mode 100644
index 19e1c4bf5..000000000
--- a/packages/opencode/test/session/prompt-concurrency.test.ts
+++ /dev/null
@@ -1,247 +0,0 @@
-import { describe, expect, spyOn, test } from "bun:test"
-import { Instance } from "../../src/project/instance"
-import { Provider } from "../../src/provider/provider"
-import { Session } from "../../src/session"
-import { MessageV2 } from "../../src/session/message-v2"
-import { SessionPrompt } from "../../src/session/prompt"
-import { SessionStatus } from "../../src/session/status"
-import { MessageID, PartID, SessionID } from "../../src/session/schema"
-import { Log } from "../../src/util/log"
-import { tmpdir } from "../fixture/fixture"
-
-Log.init({ print: false })
-
-function deferred() {
- let resolve!: () => void
- const promise = new Promise<void>((done) => {
- resolve = done
- })
- return { promise, resolve }
-}
-
-// Helper: seed a session with a user message + finished assistant message
-// so loop() exits immediately without calling any LLM
-async function seed(sessionID: SessionID) {
- const userMsg: MessageV2.Info = {
- id: MessageID.ascending(),
- role: "user",
- sessionID,
- time: { created: Date.now() },
- agent: "build",
- model: { providerID: "openai" as any, modelID: "gpt-5.2" as any },
- }
- await Session.updateMessage(userMsg)
- await Session.updatePart({
- id: PartID.ascending(),
- messageID: userMsg.id,
- sessionID,
- type: "text",
- text: "hello",
- })
-
- const assistantMsg: MessageV2.Info = {
- id: MessageID.ascending(),
- role: "assistant",
- parentID: userMsg.id,
- sessionID,
- mode: "build",
- agent: "build",
- cost: 0,
- path: { cwd: "/tmp", root: "/tmp" },
- tokens: { input: 0, output: 0, reasoning: 0, cache: { read: 0, write: 0 } },
- modelID: "gpt-5.2" as any,
- providerID: "openai" as any,
- time: { created: Date.now(), completed: Date.now() },
- finish: "stop",
- }
- await Session.updateMessage(assistantMsg)
- await Session.updatePart({
- id: PartID.ascending(),
- messageID: assistantMsg.id,
- sessionID,
- type: "text",
- text: "hi there",
- })
-
- return { userMsg, assistantMsg }
-}
-
-describe("session.prompt concurrency", () => {
- test("loop returns assistant message and sets status to idle", async () => {
- await using tmp = await tmpdir({ git: true })
- await Instance.provide({
- directory: tmp.path,
- fn: async () => {
- const session = await Session.create({})
- await seed(session.id)
-
- const result = await SessionPrompt.loop({ sessionID: session.id })
- expect(result.info.role).toBe("assistant")
- if (result.info.role === "assistant") expect(result.info.finish).toBe("stop")
-
- const status = await SessionStatus.get(session.id)
- expect(status.type).toBe("idle")
- },
- })
- })
-
- test("concurrent loop callers get the same result", async () => {
- await using tmp = await tmpdir({ git: true })
- await Instance.provide({
- directory: tmp.path,
- fn: async () => {
- const session = await Session.create({})
- await seed(session.id)
-
- const [a, b] = await Promise.all([
- SessionPrompt.loop({ sessionID: session.id }),
- SessionPrompt.loop({ sessionID: session.id }),
- ])
-
- expect(a.info.id).toBe(b.info.id)
- expect(a.info.role).toBe("assistant")
- },
- })
- })
-
- test("assertNotBusy throws when loop is running", async () => {
- await using tmp = await tmpdir({ git: true })
- await Instance.provide({
- directory: tmp.path,
- fn: async () => {
- const session = await Session.create({})
- const userMsg: MessageV2.Info = {
- id: MessageID.ascending(),
- role: "user",
- sessionID: session.id,
- time: { created: Date.now() },
- agent: "build",
- model: { providerID: "openai" as any, modelID: "gpt-5.2" as any },
- }
- await Session.updateMessage(userMsg)
- await Session.updatePart({
- id: PartID.ascending(),
- messageID: userMsg.id,
- sessionID: session.id,
- type: "text",
- text: "hello",
- })
-
- const ready = deferred()
- const gate = deferred()
- const getModel = spyOn(Provider, "getModel").mockImplementation(async () => {
- ready.resolve()
- await gate.promise
- throw new Error("test stop")
- })
-
- try {
- const loopPromise = SessionPrompt.loop({ sessionID: session.id }).catch(() => undefined)
- await ready.promise
-
- await expect(SessionPrompt.assertNotBusy(session.id)).rejects.toBeInstanceOf(Session.BusyError)
-
- gate.resolve()
- await loopPromise
- } finally {
- gate.resolve()
- getModel.mockRestore()
- }
-
- // After loop completes, assertNotBusy should succeed
- await SessionPrompt.assertNotBusy(session.id)
- },
- })
- })
-
- test("cancel sets status to idle", async () => {
- await using tmp = await tmpdir({ git: true })
- await Instance.provide({
- directory: tmp.path,
- fn: async () => {
- const session = await Session.create({})
- // Seed only a user message — loop must call getModel to proceed
- const userMsg: MessageV2.Info = {
- id: MessageID.ascending(),
- role: "user",
- sessionID: session.id,
- time: { created: Date.now() },
- agent: "build",
- model: { providerID: "openai" as any, modelID: "gpt-5.2" as any },
- }
- await Session.updateMessage(userMsg)
- await Session.updatePart({
- id: PartID.ascending(),
- messageID: userMsg.id,
- sessionID: session.id,
- type: "text",
- text: "hello",
- })
- // Also seed an assistant message so lastAssistant() fallback can find it
- const assistantMsg: MessageV2.Info = {
- id: MessageID.ascending(),
- role: "assistant",
- parentID: userMsg.id,
- sessionID: session.id,
- mode: "build",
- agent: "build",
- cost: 0,
- path: { cwd: "/tmp", root: "/tmp" },
- tokens: { input: 0, output: 0, reasoning: 0, cache: { read: 0, write: 0 } },
- modelID: "gpt-5.2" as any,
- providerID: "openai" as any,
- time: { created: Date.now() },
- }
- await Session.updateMessage(assistantMsg)
- await Session.updatePart({
- id: PartID.ascending(),
- messageID: assistantMsg.id,
- sessionID: session.id,
- type: "text",
- text: "hi there",
- })
-
- const ready = deferred()
- const gate = deferred()
- const getModel = spyOn(Provider, "getModel").mockImplementation(async () => {
- ready.resolve()
- await gate.promise
- throw new Error("test stop")
- })
-
- try {
- // Start loop — it will block in getModel (assistant has no finish, so loop continues)
- const loopPromise = SessionPrompt.loop({ sessionID: session.id })
-
- await ready.promise
-
- await SessionPrompt.cancel(session.id)
-
- const status = await SessionStatus.get(session.id)
- expect(status.type).toBe("idle")
-
- // loop should resolve cleanly, not throw "All fibers interrupted"
- const result = await loopPromise
- expect(result.info.role).toBe("assistant")
- expect(result.info.id).toBe(assistantMsg.id)
- } finally {
- gate.resolve()
- getModel.mockRestore()
- }
- },
- })
- }, 10000)
-
- test("cancel on idle session just sets idle", async () => {
- await using tmp = await tmpdir({ git: true })
- await Instance.provide({
- directory: tmp.path,
- fn: async () => {
- const session = await Session.create({})
- await SessionPrompt.cancel(session.id)
- const status = await SessionStatus.get(session.id)
- expect(status.type).toBe("idle")
- },
- })
- })
-})
diff --git a/packages/opencode/test/session/prompt-effect.test.ts b/packages/opencode/test/session/prompt-effect.test.ts
index 98111bb3a..28b4cf15c 100644
--- a/packages/opencode/test/session/prompt-effect.test.ts
+++ b/packages/opencode/test/session/prompt-effect.test.ts
@@ -12,6 +12,7 @@ import { LSP } from "../../src/lsp"
import { MCP } from "../../src/mcp"
import { Permission } from "../../src/permission"
import { Plugin } from "../../src/plugin"
+import { Provider as ProviderSvc } from "../../src/provider/provider"
import type { Provider } from "../../src/provider/provider"
import { ModelID, ProviderID } from "../../src/provider/schema"
import { Session } from "../../src/session"
@@ -151,6 +152,7 @@ function makeHttp() {
Permission.layer,
Plugin.defaultLayer,
Config.defaultLayer,
+ ProviderSvc.defaultLayer,
filetime,
lsp,
mcp,