summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAdam <[email protected]>2026-03-13 11:05:08 -0500
committerGitHub <[email protected]>2026-03-13 11:05:08 -0500
commit4ad8116ce37a0e77e7f3c0e9e4e1002bba05b15e (patch)
treeb7e5ed2b05aabb5ed5134520c4eb485c52eb5333
parent5c7088338c07ad632834ebd4a87feb23d255fb8a (diff)
downloadopencode-4ad8116ce37a0e77e7f3c0e9e4e1002bba05b15e.tar.gz
opencode-4ad8116ce37a0e77e7f3c0e9e4e1002bba05b15e.zip
fix(app): model selection persist by session (#17348)
-rw-r--r--packages/app/e2e/fixtures.ts3
-rw-r--r--packages/app/e2e/selectors.ts3
-rw-r--r--packages/app/e2e/session/session-model-persistence.spec.ts351
-rw-r--r--packages/app/src/components/dialog-select-model-unpaid.tsx12
-rw-r--r--packages/app/src/components/dialog-select-model.tsx19
-rw-r--r--packages/app/src/components/prompt-input.tsx155
-rw-r--r--packages/app/src/components/prompt-input/submit.test.ts12
-rw-r--r--packages/app/src/components/prompt-input/submit.ts3
-rw-r--r--packages/app/src/context/local.tsx551
-rw-r--r--packages/app/src/context/model-variant.test.ts20
-rw-r--r--packages/app/src/context/model-variant.ts4
-rw-r--r--packages/app/src/pages/directory-layout.tsx6
-rw-r--r--packages/app/src/pages/session.tsx4
-rw-r--r--packages/app/src/pages/session/session-model-helpers.test.ts133
-rw-r--r--packages/app/src/pages/session/session-model-helpers.ts42
-rw-r--r--packages/app/src/pages/session/use-session-commands.tsx2
-rw-r--r--packages/app/src/testing/model-selection.ts80
-rw-r--r--packages/app/src/testing/terminal.ts6
-rw-r--r--packages/ui/src/components/select.tsx3
19 files changed, 969 insertions, 440 deletions
diff --git a/packages/app/e2e/fixtures.ts b/packages/app/e2e/fixtures.ts
index cf59eeb47..efefd479e 100644
--- a/packages/app/e2e/fixtures.ts
+++ b/packages/app/e2e/fixtures.ts
@@ -95,6 +95,9 @@ async function seedStorage(page: Page, input: { directory: string; extra?: strin
const win = window as E2EWindow
win.__opencode_e2e = {
...win.__opencode_e2e,
+ model: {
+ enabled: true,
+ },
terminal: {
enabled: true,
terminals: {},
diff --git a/packages/app/e2e/selectors.ts b/packages/app/e2e/selectors.ts
index 64b7bfe54..80b6c473d 100644
--- a/packages/app/e2e/selectors.ts
+++ b/packages/app/e2e/selectors.ts
@@ -13,6 +13,9 @@ export const sessionTodoToggleButtonSelector = '[data-action="session-todo-toggl
export const sessionTodoListSelector = '[data-slot="session-todo-list"]'
export const modelVariantCycleSelector = '[data-action="model-variant-cycle"]'
+export const promptAgentSelector = '[data-component="prompt-agent-control"]'
+export const promptModelSelector = '[data-component="prompt-model-control"]'
+export const promptVariantSelector = '[data-component="prompt-variant-control"]'
export const settingsLanguageSelectSelector = '[data-action="settings-language"]'
export const settingsColorSchemeSelector = '[data-action="settings-color-scheme"]'
export const settingsThemeSelector = '[data-action="settings-theme"]'
diff --git a/packages/app/e2e/session/session-model-persistence.spec.ts b/packages/app/e2e/session/session-model-persistence.spec.ts
new file mode 100644
index 000000000..4b09a5287
--- /dev/null
+++ b/packages/app/e2e/session/session-model-persistence.spec.ts
@@ -0,0 +1,351 @@
+import { base64Decode } from "@opencode-ai/util/encode"
+import type { Locator, Page } from "@playwright/test"
+import { test, expect } from "../fixtures"
+import { openSidebar, sessionIDFromUrl, setWorkspacesEnabled, waitSessionIdle, waitSlug } from "../actions"
+import {
+ promptAgentSelector,
+ promptModelSelector,
+ promptSelector,
+ promptVariantSelector,
+ workspaceItemSelector,
+ workspaceNewSessionSelector,
+} from "../selectors"
+import { createSdk, sessionPath } from "../utils"
+
+type Footer = {
+ agent: string
+ model: string
+ variant: string
+}
+
+type Probe = {
+ dir?: string
+ sessionID?: string
+ model?: { providerID: string; modelID: string }
+}
+
+const escape = (value: string) => value.replace(/[.*+?^${}()|[\]\\]/g, "\\$&")
+
+const text = async (locator: Locator) => ((await locator.textContent()) ?? "").trim()
+
+const modelKey = (state: Probe | null) => (state?.model ? `${state.model.providerID}:${state.model.modelID}` : null)
+
+const dirKey = (state: Probe | null) => state?.dir ?? ""
+
+async function probe(page: Page): Promise<Probe | null> {
+ return page.evaluate(() => {
+ const win = window as Window & {
+ __opencode_e2e?: {
+ model?: {
+ current?: Probe
+ }
+ }
+ }
+ return win.__opencode_e2e?.model?.current ?? null
+ })
+}
+
+async function currentDir(page: Page) {
+ let hit = ""
+ await expect
+ .poll(
+ async () => {
+ const next = dirKey(await probe(page))
+ if (next) hit = next
+ return next
+ },
+ { timeout: 30_000 },
+ )
+ .not.toBe("")
+ return hit
+}
+
+async function read(page: Page): Promise<Footer> {
+ return {
+ agent: await text(page.locator(`${promptAgentSelector} [data-slot="select-select-trigger-value"]`).first()),
+ model: await text(page.locator(`${promptModelSelector} [data-action="prompt-model"] span`).first()),
+ variant: await text(page.locator(`${promptVariantSelector} [data-slot="select-select-trigger-value"]`).first()),
+ }
+}
+
+async function waitFooter(page: Page, expected: Partial<Footer>) {
+ let hit: Footer | null = null
+ await expect
+ .poll(
+ async () => {
+ const state = await read(page)
+ const ok = Object.entries(expected).every(([key, value]) => state[key as keyof Footer] === value)
+ if (ok) hit = state
+ return ok
+ },
+ { timeout: 30_000 },
+ )
+ .toBe(true)
+ if (!hit) throw new Error("Failed to resolve prompt footer state")
+ return hit
+}
+
+async function waitModel(page: Page, value: string) {
+ await expect.poll(() => probe(page).then(modelKey), { timeout: 30_000 }).toBe(value)
+}
+
+async function choose(page: Page, root: string, value: string) {
+ const select = page.locator(root)
+ await expect(select).toBeVisible()
+ await select.locator('[data-action], [data-slot="select-select-trigger"]').first().click()
+ const item = page
+ .locator('[data-slot="select-select-item"]')
+ .filter({ hasText: new RegExp(`^\\s*${escape(value)}\\s*$`) })
+ .first()
+ await expect(item).toBeVisible()
+ await item.click()
+}
+
+async function variantCount(page: Page) {
+ const select = page.locator(promptVariantSelector)
+ await expect(select).toBeVisible()
+ await select.locator('[data-slot="select-select-trigger"]').click()
+ const count = await page.locator('[data-slot="select-select-item"]').count()
+ await page.keyboard.press("Escape")
+ return count
+}
+
+async function agents(page: Page) {
+ const select = page.locator(promptAgentSelector)
+ await expect(select).toBeVisible()
+ await select.locator('[data-action], [data-slot="select-select-trigger"]').first().click()
+ const labels = await page.locator('[data-slot="select-select-item-label"]').allTextContents()
+ await page.keyboard.press("Escape")
+ return labels.map((item) => item.trim()).filter(Boolean)
+}
+
+async function ensureVariant(page: Page, directory: string): Promise<Footer> {
+ const current = await read(page)
+ if ((await variantCount(page)) >= 2) return current
+
+ const cfg = await createSdk(directory)
+ .config.get()
+ .then((x) => x.data)
+ const visible = new Set(await agents(page))
+ const entry = Object.entries(cfg?.agent ?? {}).find((item) => {
+ const value = item[1]
+ return !!value && typeof value === "object" && "variant" in value && "model" in value && visible.has(item[0])
+ })
+ const name = entry?.[0]
+ test.skip(!name, "no agent with alternate variants available")
+ if (!name) return current
+
+ await choose(page, promptAgentSelector, name)
+ await expect.poll(() => variantCount(page), { timeout: 30_000 }).toBeGreaterThanOrEqual(2)
+ return waitFooter(page, { agent: name })
+}
+
+async function chooseDifferentVariant(page: Page): Promise<Footer> {
+ const current = await read(page)
+ const select = page.locator(promptVariantSelector)
+ await expect(select).toBeVisible()
+ await select.locator('[data-slot="select-select-trigger"]').click()
+
+ const items = page.locator('[data-slot="select-select-item"]')
+ const count = await items.count()
+ if (count < 2) throw new Error("Current model has no alternate variant to select")
+
+ for (let i = 0; i < count; i++) {
+ const item = items.nth(i)
+ const next = await text(item.locator('[data-slot="select-select-item-label"]').first())
+ if (!next || next === current.variant) continue
+ await item.click()
+ return waitFooter(page, { agent: current.agent, model: current.model, variant: next })
+ }
+
+ throw new Error("Failed to choose a different variant")
+}
+
+async function chooseOtherModel(page: Page): Promise<Footer> {
+ const current = await read(page)
+ const button = page.locator(`${promptModelSelector} [data-action="prompt-model"]`)
+ await expect(button).toBeVisible()
+ await button.click()
+
+ const dialog = page.getByRole("dialog")
+ await expect(dialog).toBeVisible()
+ const items = dialog.locator('[data-slot="list-item"]')
+ const count = await items.count()
+ expect(count).toBeGreaterThan(1)
+
+ for (let i = 0; i < count; i++) {
+ const item = items.nth(i)
+ const selected = (await item.getAttribute("data-selected")) === "true"
+ if (selected) continue
+ await item.click()
+ await expect(dialog).toHaveCount(0)
+ await expect.poll(async () => (await read(page)).model !== current.model, { timeout: 30_000 }).toBe(true)
+ return read(page)
+ }
+
+ throw new Error("Failed to choose a different model")
+}
+
+async function goto(page: Page, directory: string, sessionID?: string) {
+ await page.goto(sessionPath(directory, sessionID))
+ await expect(page.locator(promptSelector)).toBeVisible()
+ await expect.poll(async () => dirKey(await probe(page)), { timeout: 30_000 }).toBe(directory)
+}
+
+async function submit(page: Page, value: string) {
+ const prompt = page.locator(promptSelector)
+ await expect(prompt).toBeVisible()
+ await prompt.click()
+ await prompt.fill(value)
+ await prompt.press("Enter")
+
+ await expect.poll(() => sessionIDFromUrl(page.url()) ?? "", { timeout: 30_000 }).not.toBe("")
+ const id = sessionIDFromUrl(page.url())
+ if (!id) throw new Error(`Failed to resolve session id from ${page.url()}`)
+ return id
+}
+
+async function waitUser(directory: string, sessionID: string) {
+ const sdk = createSdk(directory)
+ await expect
+ .poll(
+ async () => {
+ const items = await sdk.session.messages({ sessionID, limit: 20 }).then((x) => x.data ?? [])
+ return items.some((item) => item.info.role === "user")
+ },
+ { timeout: 30_000 },
+ )
+ .toBe(true)
+ await sdk.session.abort({ sessionID }).catch(() => undefined)
+ await waitSessionIdle(sdk, sessionID, 30_000).catch(() => undefined)
+}
+
+async function createWorkspace(page: Page, root: string, seen: string[]) {
+ await openSidebar(page)
+ await page.getByRole("button", { name: "New workspace" }).first().click()
+
+ const slug = await waitSlug(page, [root, ...seen])
+ const directory = base64Decode(slug)
+ if (!directory) throw new Error(`Failed to decode workspace slug: ${slug}`)
+ return { slug, directory }
+}
+
+async function waitWorkspace(page: Page, slug: string) {
+ await openSidebar(page)
+ await expect
+ .poll(
+ async () => {
+ const item = page.locator(workspaceItemSelector(slug)).first()
+ try {
+ await item.hover({ timeout: 500 })
+ return true
+ } catch {
+ return false
+ }
+ },
+ { timeout: 60_000 },
+ )
+ .toBe(true)
+}
+
+async function newWorkspaceSession(page: Page, slug: string) {
+ await waitWorkspace(page, slug)
+ const item = page.locator(workspaceItemSelector(slug)).first()
+ await item.hover()
+
+ const button = page.locator(workspaceNewSessionSelector(slug)).first()
+ await expect(button).toBeVisible()
+ await button.click({ force: true })
+
+ const next = await waitSlug(page)
+ await expect(page).toHaveURL(new RegExp(`/${next}/session(?:[/?#]|$)`))
+ await expect(page.locator(promptSelector)).toBeVisible()
+ return currentDir(page)
+}
+
+test("session model and variant restore per session without leaking into new sessions", async ({
+ page,
+ withProject,
+}) => {
+ await page.setViewportSize({ width: 1440, height: 900 })
+
+ await withProject(async ({ directory, gotoSession, trackSession }) => {
+ await gotoSession()
+
+ await ensureVariant(page, directory)
+ const firstState = await chooseDifferentVariant(page)
+ const first = await submit(page, `session variant ${Date.now()}`)
+ trackSession(first)
+ await waitUser(directory, first)
+
+ await page.reload()
+ await expect(page.locator(promptSelector)).toBeVisible()
+ await waitFooter(page, firstState)
+
+ await gotoSession()
+ const fresh = await ensureVariant(page, directory)
+ expect(fresh.variant).not.toBe(firstState.variant)
+
+ const secondState = await chooseOtherModel(page)
+ const second = await submit(page, `session model ${Date.now()}`)
+ trackSession(second)
+ await waitUser(directory, second)
+
+ await goto(page, directory, first)
+ await waitFooter(page, firstState)
+
+ await goto(page, directory, second)
+ await waitFooter(page, secondState)
+
+ await gotoSession()
+ await waitFooter(page, fresh)
+ })
+})
+
+test("session model restore across workspaces", async ({ page, withProject }) => {
+ await page.setViewportSize({ width: 1440, height: 900 })
+
+ await withProject(async ({ directory: root, slug, gotoSession, trackDirectory, trackSession }) => {
+ await gotoSession()
+
+ await ensureVariant(page, root)
+ const firstState = await chooseDifferentVariant(page)
+ const first = await submit(page, `root session ${Date.now()}`)
+ trackSession(first, root)
+ await waitUser(root, first)
+
+ await openSidebar(page)
+ await setWorkspacesEnabled(page, slug, true)
+
+ const one = await createWorkspace(page, slug, [])
+ const oneDir = await newWorkspaceSession(page, one.slug)
+ trackDirectory(oneDir)
+
+ const secondState = await chooseOtherModel(page)
+ const second = await submit(page, `workspace one ${Date.now()}`)
+ trackSession(second, oneDir)
+ await waitUser(oneDir, second)
+
+ const two = await createWorkspace(page, slug, [one.slug])
+ const twoDir = await newWorkspaceSession(page, two.slug)
+ trackDirectory(twoDir)
+
+ await ensureVariant(page, twoDir)
+ const thirdState = await chooseDifferentVariant(page)
+ const third = await submit(page, `workspace two ${Date.now()}`)
+ trackSession(third, twoDir)
+ await waitUser(twoDir, third)
+
+ await goto(page, root, first)
+ await waitFooter(page, firstState)
+
+ await goto(page, oneDir, second)
+ await waitFooter(page, secondState)
+
+ await goto(page, twoDir, third)
+ await waitFooter(page, thirdState)
+
+ await goto(page, root, first)
+ await waitFooter(page, firstState)
+ })
+})
diff --git a/packages/app/src/components/dialog-select-model-unpaid.tsx b/packages/app/src/components/dialog-select-model-unpaid.tsx
index bcee3f501..2106b3a01 100644
--- a/packages/app/src/components/dialog-select-model-unpaid.tsx
+++ b/packages/app/src/components/dialog-select-model-unpaid.tsx
@@ -13,8 +13,10 @@ import { DialogSelectProvider } from "./dialog-select-provider"
import { ModelTooltip } from "./model-tooltip"
import { useLanguage } from "@/context/language"
-export const DialogSelectModelUnpaid: Component = () => {
- const local = useLocal()
+type ModelState = ReturnType<typeof useLocal>["model"]
+
+export const DialogSelectModelUnpaid: Component<{ model?: ModelState }> = (props) => {
+ const model = props.model ?? useLocal().model
const dialog = useDialog()
const providers = useProviders()
const language = useLanguage()
@@ -35,8 +37,8 @@ export const DialogSelectModelUnpaid: Component = () => {
<List
class="[&_[data-slot=list-scroll]]:overflow-visible"
ref={(ref) => (listRef = ref)}
- items={local.model.list}
- current={local.model.current()}
+ items={model.list}
+ current={model.current()}
key={(x) => `${x.provider.id}:${x.id}`}
itemWrapper={(item, node) => (
<Tooltip
@@ -55,7 +57,7 @@ export const DialogSelectModelUnpaid: Component = () => {
</Tooltip>
)}
onSelect={(x) => {
- local.model.set(x ? { modelID: x.id, providerID: x.provider.id } : undefined, {
+ model.set(x ? { modelID: x.id, providerID: x.provider.id } : undefined, {
recent: true,
})
dialog.close()
diff --git a/packages/app/src/components/dialog-select-model.tsx b/packages/app/src/components/dialog-select-model.tsx
index 9f7afb8cd..3654aab85 100644
--- a/packages/app/src/components/dialog-select-model.tsx
+++ b/packages/app/src/components/dialog-select-model.tsx
@@ -18,19 +18,22 @@ import { useLanguage } from "@/context/language"
const isFree = (provider: string, cost: { input: number } | undefined) =>
provider === "opencode" && (!cost || cost.input === 0)
+type ModelState = ReturnType<typeof useLocal>["model"]
+
const ModelList: Component<{
provider?: string
class?: string
onSelect: () => void
action?: JSX.Element
+ model?: ModelState
}> = (props) => {
- const local = useLocal()
+ const model = props.model ?? useLocal().model
const language = useLanguage()
const models = createMemo(() =>
- local.model
+ model
.list()
- .filter((m) => local.model.visible({ modelID: m.id, providerID: m.provider.id }))
+ .filter((m) => model.visible({ modelID: m.id, providerID: m.provider.id }))
.filter((m) => (props.provider ? m.provider.id === props.provider : true)),
)
@@ -41,7 +44,7 @@ const ModelList: Component<{
emptyMessage={language.t("dialog.model.empty")}
key={(x) => `${x.provider.id}:${x.id}`}
items={models}
- current={local.model.current()}
+ current={model.current()}
filterKeys={["provider.name", "name", "id"]}
sortBy={(a, b) => a.name.localeCompare(b.name)}
groupBy={(x) => x.provider.name}
@@ -63,7 +66,7 @@ const ModelList: Component<{
</Tooltip>
)}
onSelect={(x) => {
- local.model.set(x ? { modelID: x.id, providerID: x.provider.id } : undefined, {
+ model.set(x ? { modelID: x.id, providerID: x.provider.id } : undefined, {
recent: true,
})
props.onSelect()
@@ -88,6 +91,7 @@ type ModelSelectorTriggerProps = Omit<ComponentProps<typeof Kobalte.Trigger>, "a
export function ModelSelectorPopover(props: {
provider?: string
+ model?: ModelState
children?: JSX.Element
triggerAs?: ValidComponent
triggerProps?: ModelSelectorTriggerProps
@@ -151,6 +155,7 @@ export function ModelSelectorPopover(props: {
<Kobalte.Title class="sr-only">{language.t("dialog.model.select.title")}</Kobalte.Title>
<ModelList
provider={props.provider}
+ model={props.model}
onSelect={() => setStore("open", false)}
class="p-1"
action={
@@ -184,7 +189,7 @@ export function ModelSelectorPopover(props: {
)
}
-export const DialogSelectModel: Component<{ provider?: string }> = (props) => {
+export const DialogSelectModel: Component<{ provider?: string; model?: ModelState }> = (props) => {
const dialog = useDialog()
const language = useLanguage()
@@ -202,7 +207,7 @@ export const DialogSelectModel: Component<{ provider?: string }> = (props) => {
</Button>
}
>
- <ModelList provider={props.provider} onSelect={() => dialog.close()} />
+ <ModelList provider={props.provider} model={props.model} onSelect={() => dialog.close()} />
<Button
variant="ghost"
class="ml-3 mt-5 mb-6 text-text-base self-start"
diff --git a/packages/app/src/components/prompt-input.tsx b/packages/app/src/components/prompt-input.tsx
index fd54de9a0..9048fa895 100644
--- a/packages/app/src/components/prompt-input.tsx
+++ b/packages/app/src/components/prompt-input.tsx
@@ -1430,39 +1430,76 @@ export const PromptInput: Component<PromptInputProps> = (props) => {
<div class="size-4 shrink-0" />
</div>
<div class="flex items-center gap-1.5 min-w-0 flex-1">
- <TooltipKeybind
- placement="top"
- gutter={4}
- title={language.t("command.agent.cycle")}
- keybind={command.keybind("agent.cycle")}
- >
- <Select
- size="normal"
- options={agentNames()}
- current={local.agent.current()?.name ?? ""}
- onSelect={local.agent.set}
- class="capitalize max-w-[160px] text-text-base"
- valueClass="truncate text-13-regular text-text-base"
- triggerStyle={control()}
- variant="ghost"
- />
- </TooltipKeybind>
- <Show
- when={providers.paid().length > 0}
- fallback={
+ <div data-component="prompt-agent-control">
+ <TooltipKeybind
+ placement="top"
+ gutter={4}
+ title={language.t("command.agent.cycle")}
+ keybind={command.keybind("agent.cycle")}
+ >
+ <Select
+ size="normal"
+ options={agentNames()}
+ current={local.agent.current()?.name ?? ""}
+ onSelect={local.agent.set}
+ class="capitalize max-w-[160px] text-text-base"
+ valueClass="truncate text-13-regular text-text-base"
+ triggerStyle={control()}
+ triggerProps={{ "data-action": "prompt-agent" }}
+ variant="ghost"
+ />
+ </TooltipKeybind>
+ </div>
+ <div data-component="prompt-model-control">
+ <Show
+ when={providers.paid().length > 0}
+ fallback={
+ <TooltipKeybind
+ placement="top"
+ gutter={4}
+ title={language.t("command.model.choose")}
+ keybind={command.keybind("model.choose")}
+ >
+ <Button
+ data-action="prompt-model"
+ as="div"
+ variant="ghost"
+ size="normal"
+ class="min-w-0 max-w-[320px] text-13-regular text-text-base group"
+ style={control()}
+ onClick={() => dialog.show(() => <DialogSelectModelUnpaid model={local.model} />)}
+ >
+ <Show when={local.model.current()?.provider?.id}>
+ <ProviderIcon
+ id={local.model.current()!.provider.id}
+ class="size-4 shrink-0 opacity-40 group-hover:opacity-100 transition-opacity duration-150"
+ style={{ "will-change": "opacity", transform: "translateZ(0)" }}
+ />
+ </Show>
+ <span class="truncate">
+ {local.model.current()?.name ?? language.t("dialog.model.select.title")}
+ </span>
+ <Icon name="chevron-down" size="small" class="shrink-0" />
+ </Button>
+ </TooltipKeybind>
+ }
+ >
<TooltipKeybind
placement="top"
gutter={4}
title={language.t("command.model.choose")}
keybind={command.keybind("model.choose")}
>
- <Button
- as="div"
- variant="ghost"
- size="normal"
- class="min-w-0 max-w-[320px] text-13-regular text-text-base group"
- style={control()}
- onClick={() => dialog.show(() => <DialogSelectModelUnpaid />)}
+ <ModelSelectorPopover
+ model={local.model}
+ triggerAs={Button}
+ triggerProps={{
+ variant: "ghost",
+ size: "normal",
+ style: control(),
+ class: "min-w-0 max-w-[320px] text-13-regular text-text-base group",
+ "data-action": "prompt-model",
+ }}
>
<Show when={local.model.current()?.provider?.id}>
<ProviderIcon
@@ -1475,57 +1512,31 @@ export const PromptInput: Component<PromptInputProps> = (props) => {
{local.model.current()?.name ?? language.t("dialog.model.select.title")}
</span>
<Icon name="chevron-down" size="small" class="shrink-0" />
- </Button>
+ </ModelSelectorPopover>
</TooltipKeybind>
- }
- >
+ </Show>
+ </div>
+ <div data-component="prompt-variant-control">
<TooltipKeybind
placement="top"
gutter={4}
- title={language.t("command.model.choose")}
- keybind={command.keybind("model.choose")}
+ title={language.t("command.model.variant.cycle")}
+ keybind={command.keybind("model.variant.cycle")}
>
- <ModelSelectorPopover
- triggerAs={Button}
- triggerProps={{
- variant: "ghost",
- size: "normal",
- style: control(),
- class: "min-w-0 max-w-[320px] text-13-regular text-text-base group",
- }}
- >
- <Show when={local.model.current()?.provider?.id}>
- <ProviderIcon
- id={local.model.current()!.provider.id}
- class="size-4 shrink-0 opacity-40 group-hover:opacity-100 transition-opacity duration-150"
- style={{ "will-change": "opacity", transform: "translateZ(0)" }}
- />
- </Show>
- <span class="truncate">
- {local.model.current()?.name ?? language.t("dialog.model.select.title")}
- </span>
- <Icon name="chevron-down" size="small" class="shrink-0" />
- </ModelSelectorPopover>
+ <Select
+ size="normal"
+ options={variants()}
+ current={local.model.variant.current() ?? "default"}
+ label={(x) => (x === "default" ? language.t("common.default") : x)}
+ onSelect={(x) => local.model.variant.set(x === "default" ? undefined : x)}
+ class="capitalize max-w-[160px] text-text-base"
+ valueClass="truncate text-13-regular text-text-base"
+ triggerStyle={control()}
+ triggerProps={{ "data-action": "prompt-model-variant" }}
+ variant="ghost"
+ />
</TooltipKeybind>
- </Show>
- <TooltipKeybind
- placement="top"
- gutter={4}
- title={language.t("command.model.variant.cycle")}
- keybind={command.keybind("model.variant.cycle")}
- >
- <Select
- size="normal"
- options={variants()}
- current={local.model.variant.current() ?? "default"}
- label={(x) => (x === "default" ? language.t("common.default") : x)}
- onSelect={(x) => local.model.variant.set(x === "default" ? undefined : x)}
- class="capitalize max-w-[160px] text-text-base"
- valueClass="truncate text-13-regular text-text-base"
- triggerStyle={control()}
- variant="ghost"
- />
- </TooltipKeybind>
+ </div>
<TooltipKeybind
placement="top"
gutter={8}
diff --git a/packages/app/src/components/prompt-input/submit.test.ts b/packages/app/src/components/prompt-input/submit.test.ts
index 9f7fac69d..b0166c43a 100644
--- a/packages/app/src/components/prompt-input/submit.test.ts
+++ b/packages/app/src/components/prompt-input/submit.test.ts
@@ -17,6 +17,7 @@ const optimistic: Array<{
}> = []
const optimisticSeeded: boolean[] = []
const storedSessions: Record<string, Array<{ id: string; title?: string }>> = {}
+const promoted: Array<{ directory: string; sessionID: string }> = []
const sentShell: string[] = []
const syncedDirectories: string[] = []
@@ -86,6 +87,11 @@ beforeAll(async () => {
agent: {
current: () => ({ name: "agent" }),
},
+ session: {
+ promote(directory: string, sessionID: string) {
+ promoted.push({ directory, sessionID })
+ },
+ },
}),
}))
@@ -201,6 +207,7 @@ beforeEach(() => {
enabledAutoAccept.length = 0
optimistic.length = 0
optimisticSeeded.length = 0
+ promoted.length = 0
params = {}
sentShell.length = 0
syncedDirectories.length = 0
@@ -240,6 +247,11 @@ describe("prompt submit worktree selection", () => {
expect(createdSessions).toEqual(["/repo/worktree-a", "/repo/worktree-b"])
expect(sentShell).toEqual(["/repo/worktree-a", "/repo/worktree-b"])
expect(syncedDirectories).toEqual(["/repo/worktree-a", "/repo/worktree-a", "/repo/worktree-b", "/repo/worktree-b"])
+ expect(promoted).toEqual([
+ { directory: "/repo/worktree-a", sessionID: "session-1" },
+ { directory: "/repo/worktree-b", sessionID: "session-2" },
+ ])
+ expect(syncedDirectories).toEqual(["/repo/worktree-a", "/repo/worktree-a", "/repo/worktree-b", "/repo/worktree-b"])
})
test("applies auto-accept to newly created sessions", async () => {
diff --git a/packages/app/src/components/prompt-input/submit.ts b/packages/app/src/components/prompt-input/submit.ts
index e8d765cd9..ba299fe36 100644
--- a/packages/app/src/components/prompt-input/submit.ts
+++ b/packages/app/src/components/prompt-input/submit.ts
@@ -296,6 +296,7 @@ export function createPromptSubmit(input: PromptSubmitInput) {
const currentModel = local.model.current()
const currentAgent = local.agent.current()
+ const variant = local.model.variant.current()
if (!currentModel || !currentAgent) {
showToast({
title: language.t("prompt.toast.modelAgentRequired.title"),
@@ -370,6 +371,7 @@ export function createPromptSubmit(input: PromptSubmitInput) {
seed(sessionDirectory, created)
session = created
if (shouldAutoAccept) permission.enableAutoAccept(session.id, sessionDirectory)
+ local.session.promote(sessionDirectory, session.id)
layout.handoff.setTabs(base64Encode(sessionDirectory), session.id)
navigate(`/${base64Encode(sessionDirectory)}/session/${session.id}`)
}
@@ -387,7 +389,6 @@ export function createPromptSubmit(input: PromptSubmitInput) {
providerID: currentModel.provider.id,
}
const agent = currentAgent.name
- const variant = local.model.variant.current()
const context = prompt.context.items().slice()
const draft: FollowupDraft = {
sessionID: session.id,
diff --git a/packages/app/src/context/local.tsx b/packages/app/src/context/local.tsx
index 75d1334a5..bed7ecd15 100644
--- a/packages/app/src/context/local.tsx
+++ b/packages/app/src/context/local.tsx
@@ -1,252 +1,421 @@
-import { createStore } from "solid-js/store"
-import { batch, createMemo } from "solid-js"
import { createSimpleContext } from "@opencode-ai/ui/context"
-import { useSDK } from "./sdk"
-import { useSync } from "./sync"
import { base64Encode } from "@opencode-ai/util/encode"
-import { useProviders } from "@/hooks/use-providers"
+import { useParams } from "@solidjs/router"
+import { batch, createEffect, createMemo, onCleanup } from "solid-js"
+import { createStore } from "solid-js/store"
import { useModels } from "@/context/models"
+import { useProviders } from "@/hooks/use-providers"
+import { modelEnabled, modelProbe } from "@/testing/model-selection"
+import { Persist, persisted } from "@/utils/persist"
import { cycleModelVariant, getConfiguredAgentVariant, resolveModelVariant } from "./model-variant"
+import { useSDK } from "./sdk"
+import { useSync } from "./sync"
export type ModelKey = { providerID: string; modelID: string }
+type State = {
+ agent?: string
+ model?: ModelKey
+ variant?: string | null
+}
+
+type Saved = {
+ session: Record<string, State | undefined>
+}
+
+const WORKSPACE_KEY = "__workspace__"
+const handoff = new Map<string, State>()
+
+const handoffKey = (dir: string, id: string) => `${dir}\n${id}`
+
+const migrate = (value: unknown) => {
+ if (!value || typeof value !== "object") return { session: {} }
+
+ const item = value as {
+ session?: Record<string, State | undefined>
+ pick?: Record<string, State | undefined>
+ }
+
+ if (item.session && typeof item.session === "object") return { session: item.session }
+ if (!item.pick || typeof item.pick !== "object") return { session: {} }
+
+ return {
+ session: Object.fromEntries(Object.entries(item.pick).filter(([key]) => key !== WORKSPACE_KEY)),
+ }
+}
+
+const clone = (value: State | undefined) => {
+ if (!value) return undefined
+ return {
+ ...value,
+ model: value.model ? { ...value.model } : undefined,
+ } satisfies State
+}
+
export const { use: useLocal, provider: LocalProvider } = createSimpleContext({
name: "Local",
init: () => {
+ const params = useParams()
const sdk = useSDK()
const sync = useSync()
const providers = useProviders()
- const connected = createMemo(() => new Set(providers.connected().map((provider) => provider.id)))
+ const models = useModels()
+
+ const id = createMemo(() => params.id || undefined)
+ const list = createMemo(() => sync.data.agent.filter((item) => item.mode !== "subagent" && !item.hidden))
+ const connected = createMemo(() => new Set(providers.connected().map((item) => item.id)))
- function isModelValid(model: ModelKey) {
- const provider = providers.all().find((x) => x.id === model.providerID)
+ const [saved, setSaved] = persisted(
+ {
+ ...Persist.workspace(sdk.directory, "model-selection", ["model-selection.v1"]),
+ migrate,
+ },
+ createStore<Saved>({
+ session: {},
+ }),
+ )
+
+ const [store, setStore] = createStore<{
+ current?: string
+ draft?: State
+ last?: {
+ type: "agent" | "model" | "variant"
+ agent?: string
+ model?: ModelKey | null
+ variant?: string | null
+ }
+ }>({
+ current: list()[0]?.name,
+ draft: undefined,
+ last: undefined,
+ })
+
+ const validModel = (model: ModelKey) => {
+ const provider = providers.all().find((item) => item.id === model.providerID)
return !!provider?.models[model.modelID] && connected().has(model.providerID)
}
- function getFirstValidModel(...modelFns: (() => ModelKey | undefined)[]) {
- for (const modelFn of modelFns) {
- const model = modelFn()
+ const firstModel = (...items: Array<() => ModelKey | undefined>) => {
+ for (const item of items) {
+ const model = item()
if (!model) continue
- if (isModelValid(model)) return model
+ if (validModel(model)) return model
}
}
- let setModel: (model: ModelKey | undefined, options?: { recent?: boolean }) => void = () => undefined
+ const pickAgent = (name: string | undefined) => {
+ const items = list()
+ if (items.length === 0) return undefined
+ return items.find((item) => item.name === name) ?? items[0]
+ }
- const agent = (() => {
- const list = createMemo(() => sync.data.agent.filter((x) => x.mode !== "subagent" && !x.hidden))
- const models = useModels()
+ createEffect(() => {
+ const items = list()
+ if (items.length === 0) {
+ if (store.current !== undefined) setStore("current", undefined)
+ return
+ }
+ if (items.some((item) => item.name === store.current)) return
+ setStore("current", items[0]?.name)
+ })
- const [store, setStore] = createStore<{
- current?: string
- }>({
- current: list()[0]?.name,
- })
- return {
- list,
- current() {
- const available = list()
- if (available.length === 0) return undefined
- return available.find((x) => x.name === store.current) ?? available[0]
- },
- set(name: string | undefined) {
- const available = list()
- if (available.length === 0) {
- setStore("current", undefined)
- return
- }
- const match = name ? available.find((x) => x.name === name) : undefined
- const value = match ?? available[0]
- if (!value) return
- setStore("current", value.name)
- if (!value.model) return
- setModel({
- providerID: value.model.providerID,
- modelID: value.model.modelID,
- })
- if (value.variant)
- models.variant.set({ providerID: value.model.providerID, modelID: value.model.modelID }, value.variant)
- },
- move(direction: 1 | -1) {
- const available = list()
- if (available.length === 0) {
- setStore("current", undefined)
- return
- }
- let next = available.findIndex((x) => x.name === store.current) + direction
- if (next < 0) next = available.length - 1
- if (next >= available.length) next = 0
- const value = available[next]
- if (!value) return
- setStore("current", value.name)
- if (!value.model) return
- setModel({
- providerID: value.model.providerID,
- modelID: value.model.modelID,
- })
- if (value.variant)
- models.variant.set({ providerID: value.model.providerID, modelID: value.model.modelID }, value.variant)
- },
+ const scope = createMemo<State | undefined>(() => {
+ const session = id()
+ if (!session) return store.draft
+ return saved.session[session] ?? handoff.get(handoffKey(sdk.directory, session))
+ })
+
+ createEffect(() => {
+ const session = id()
+ if (!session) return
+
+ const key = handoffKey(sdk.directory, session)
+ const next = handoff.get(key)
+ if (!next) return
+ if (saved.session[session] !== undefined) {
+ handoff.delete(key)
+ return
}
- })()
- const model = (() => {
- const models = useModels()
+ setSaved("session", session, clone(next))
+ handoff.delete(key)
+ })
- const [ephemeral, setEphemeral] = createStore<{
- model: Record<string, ModelKey | undefined>
- }>({
- model: {},
- })
+ const configuredModel = () => {
+ if (!sync.data.config.model) return
+ const [providerID, modelID] = sync.data.config.model.split("/")
+ const model = { providerID, modelID }
+ if (validModel(model)) return model
+ }
- const resolveConfigured = () => {
- if (!sync.data.config.model) return
- const [providerID, modelID] = sync.data.config.model.split("/")
- const key = { providerID, modelID }
- if (isModelValid(key)) return key
+ const recentModel = () => {
+ for (const item of models.recent.list()) {
+ if (validModel(item)) return item
}
+ }
- const resolveRecent = () => {
- for (const item of models.recent.list()) {
- if (isModelValid(item)) return item
+ const defaultModel = () => {
+ const defaults = providers.default()
+ for (const provider of providers.connected()) {
+ const configured = defaults[provider.id]
+ if (configured) {
+ const model = { providerID: provider.id, modelID: configured }
+ if (validModel(model)) return model
}
+
+ const first = Object.values(provider.models)[0]
+ if (!first) continue
+ const model = { providerID: provider.id, modelID: first.id }
+ if (validModel(model)) return model
}
+ }
- const resolveDefault = () => {
- const defaults = providers.default()
- for (const provider of providers.connected()) {
- const configured = defaults[provider.id]
- if (configured) {
- const key = { providerID: provider.id, modelID: configured }
- if (isModelValid(key)) return key
- }
+ const fallback = createMemo<ModelKey | undefined>(() => configuredModel() ?? recentModel() ?? defaultModel())
- const first = Object.values(provider.models)[0]
- if (!first) continue
- const key = { providerID: provider.id, modelID: first.id }
- if (isModelValid(key)) return key
+ const agent = {
+ list,
+ current() {
+ return pickAgent(scope()?.agent ?? store.current)
+ },
+ set(name: string | undefined) {
+ const item = pickAgent(name)
+ if (!item) {
+ setStore("current", undefined)
+ return
}
- }
- const fallbackModel = createMemo<ModelKey | undefined>(() => {
- return resolveConfigured() ?? resolveRecent() ?? resolveDefault()
- })
+ batch(() => {
+ setStore("current", item.name)
+ setStore("last", {
+ type: "agent",
+ agent: item.name,
+ model: item.model,
+ variant: item.variant ?? null,
+ })
+ const next = {
+ agent: item.name,
+ model: item.model,
+ variant: item.variant,
+ } satisfies State
+ const session = id()
+ if (session) {
+ setSaved("session", session, next)
+ return
+ }
+ setStore("draft", next)
+ })
+ },
+ move(direction: 1 | -1) {
+ const items = list()
+ if (items.length === 0) {
+ setStore("current", undefined)
+ return
+ }
- const current = createMemo(() => {
- const a = agent.current()
- if (!a) return undefined
- const key = getFirstValidModel(
- () => ephemeral.model[a.name],
- () => a.model,
- fallbackModel,
- )
- if (!key) return undefined
- return models.find(key)
- })
+ let next = items.findIndex((item) => item.name === agent.current()?.name) + direction
+ if (next < 0) next = items.length - 1
+ if (next >= items.length) next = 0
+ const item = items[next]
+ if (!item) return
+ agent.set(item.name)
+ },
+ }
- const recent = createMemo(() => models.recent.list().map(models.find).filter(Boolean))
+ const current = () => {
+ const item = firstModel(
+ () => scope()?.model,
+ () => agent.current()?.model,
+ fallback,
+ )
+ if (!item) return undefined
+ return models.find(item)
+ }
- const cycle = (direction: 1 | -1) => {
- const recentList = recent()
- const currentModel = current()
- if (!currentModel) return
+ const configured = () => {
+ const item = agent.current()
+ const model = current()
+ if (!item || !model) return undefined
+ return getConfiguredAgentVariant({
+ agent: { model: item.model, variant: item.variant },
+ model: { providerID: model.provider.id, modelID: model.id, variants: model.variants },
+ })
+ }
- const index = recentList.findIndex(
- (x) => x?.provider.id === currentModel.provider.id && x?.id === currentModel.id,
- )
- if (index === -1) return
+ const selected = () => scope()?.variant
- let next = index + direction
- if (next < 0) next = recentList.length - 1
- if (next >= recentList.length) next = 0
+ const snapshot = () => {
+ const model = current()
+ return {
+ agent: agent.current()?.name,
+ model: model ? { providerID: model.provider.id, modelID: model.id } : undefined,
+ variant: selected(),
+ } satisfies State
+ }
- const val = recentList[next]
- if (!val) return
+ const write = (next: Partial<State>) => {
+ const state = {
+ ...(scope() ?? { agent: agent.current()?.name }),
+ ...next,
+ } satisfies State
- model.set({
- providerID: val.provider.id,
- modelID: val.id,
- })
+ const session = id()
+ if (session) {
+ setSaved("session", session, state)
+ return
}
+ setStore("draft", state)
+ }
- const set = (model: ModelKey | undefined, options?: { recent?: boolean }) => {
- batch(() => {
- const currentAgent = agent.current()
- const next = model ?? fallbackModel()
- if (currentAgent) setEphemeral("model", currentAgent.name, next)
- if (model) models.setVisibility(model, true)
- if (options?.recent && model) models.recent.push(model)
- })
- }
+ const recent = createMemo(() => models.recent.list().map(models.find).filter(Boolean))
- setModel = set
+ const model = {
+ ready: models.ready,
+ current,
+ recent,
+ list: models.list,
+ cycle(direction: 1 | -1) {
+ const items = recent()
+ const item = current()
+ if (!item) return
- return {
- ready: models.ready,
- current,
- recent,
- list: models.list,
- cycle,
- set,
- visible(model: ModelKey) {
- return models.visible(model)
+ const index = items.findIndex((entry) => entry?.provider.id === item.provider.id && entry?.id === item.id)
+ if (index === -1) return
+
+ let next = index + direction
+ if (next < 0) next = items.length - 1
+ if (next >= items.length) next = 0
+
+ const entry = items[next]
+ if (!entry) return
+ model.set({ providerID: entry.provider.id, modelID: entry.id })
+ },
+ set(item: ModelKey | undefined, options?: { recent?: boolean }) {
+ batch(() => {
+ setStore("last", {
+ type: "model",
+ agent: agent.current()?.name,
+ model: item ?? null,
+ variant: selected(),
+ })
+ write({ model: item })
+ if (!item) return
+ models.setVisibility(item, true)
+ if (!options?.recent) return
+ models.recent.push(item)
+ })
+ },
+ visible(item: ModelKey) {
+ return models.visible(item)
+ },
+ setVisibility(item: ModelKey, visible: boolean) {
+ models.setVisibility(item, visible)
+ },
+ variant: {
+ configured,
+ selected,
+ current() {
+ return resolveModelVariant({
+ variants: this.list(),
+ selected: this.selected(),
+ configured: this.configured(),
+ })
},
- setVisibility(model: ModelKey, visible: boolean) {
- models.setVisibility(model, visible)
+ list() {
+ const item = current()
+ if (!item?.variants) return []
+ return Object.keys(item.variants)
},
- variant: {
- configured() {
- const a = agent.current()
- const m = current()
- if (!a || !m) return undefined
- return getConfiguredAgentVariant({
- agent: { model: a.model, variant: a.variant },
- model: { providerID: m.provider.id, modelID: m.id, variants: m.variants },
+ set(value: string | undefined) {
+ batch(() => {
+ const model = current()
+ setStore("last", {
+ type: "variant",
+ agent: agent.current()?.name,
+ model: model ? { providerID: model.provider.id, modelID: model.id } : null,
+ variant: value ?? null,
})
- },
- selected() {
- const m = current()
- if (!m) return undefined
- return models.variant.get({ providerID: m.provider.id, modelID: m.id })
- },
- current() {
- return resolveModelVariant({
- variants: this.list(),
+ write({ variant: value ?? null })
+ })
+ },
+ cycle() {
+ const items = this.list()
+ if (items.length === 0) return
+ this.set(
+ cycleModelVariant({
+ variants: items,
selected: this.selected(),
configured: this.configured(),
- })
- },
- list() {
- const m = current()
- if (!m) return []
- if (!m.variants) return []
- return Object.keys(m.variants)
- },
- set(value: string | undefined) {
- const m = current()
- if (!m) return
- models.variant.set({ providerID: m.provider.id, modelID: m.id }, value)
- },
- cycle() {
- const variants = this.list()
- if (variants.length === 0) return
- this.set(
- cycleModelVariant({
- variants,
- selected: this.selected(),
- configured: this.configured(),
- }),
- )
- },
+ }),
+ )
},
- }
- })()
+ },
+ }
const result = {
slug: createMemo(() => base64Encode(sdk.directory)),
model,
agent,
+ session: {
+ reset() {
+ setStore("draft", undefined)
+ },
+ promote(dir: string, session: string) {
+ const next = clone(snapshot())
+ if (!next) return
+
+ if (dir === sdk.directory) {
+ setSaved("session", session, next)
+ setStore("draft", undefined)
+ return
+ }
+
+ handoff.set(handoffKey(dir, session), next)
+ setStore("draft", undefined)
+ },
+ restore(msg: { sessionID: string; agent: string; model: ModelKey; variant?: string }) {
+ const session = id()
+ if (!session) return
+ if (msg.sessionID !== session) return
+ if (saved.session[session] !== undefined) return
+ if (handoff.has(handoffKey(sdk.directory, session))) return
+
+ setSaved("session", session, {
+ agent: msg.agent,
+ model: msg.model,
+ variant: msg.variant ?? null,
+ })
+ },
+ },
}
+
+ if (modelEnabled()) {
+ createEffect(() => {
+ const agent = result.agent.current()
+ const model = result.model.current()
+ modelProbe.set({
+ dir: sdk.directory,
+ sessionID: id(),
+ last: store.last,
+ agent: agent?.name,
+ model: model
+ ? {
+ providerID: model.provider.id,
+ modelID: model.id,
+ name: model.name,
+ }
+ : undefined,
+ variant: result.model.variant.current() ?? null,
+ selected: result.model.variant.selected(),
+ configured: result.model.variant.configured(),
+ pick: scope(),
+ base: undefined,
+ current: store.current,
+ })
+ })
+
+ onCleanup(() => modelProbe.clear())
+ }
+
return result
},
})
diff --git a/packages/app/src/context/model-variant.test.ts b/packages/app/src/context/model-variant.test.ts
index 01b149fd2..583bc5c3d 100644
--- a/packages/app/src/context/model-variant.test.ts
+++ b/packages/app/src/context/model-variant.test.ts
@@ -44,6 +44,16 @@ describe("model variant", () => {
expect(value).toBe("high")
})
+ test("lets an explicit default override the configured variant", () => {
+ const value = resolveModelVariant({
+ variants: ["low", "high", "xhigh"],
+ selected: null,
+ configured: "xhigh",
+ })
+
+ expect(value).toBeUndefined()
+ })
+
test("cycles from configured variant to next", () => {
const value = cycleModelVariant({
variants: ["low", "high", "xhigh"],
@@ -63,4 +73,14 @@ describe("model variant", () => {
expect(value).toBe("low")
})
+
+ test("cycles from an explicit default to the first variant", () => {
+ const value = cycleModelVariant({
+ variants: ["low", "high", "xhigh"],
+ selected: null,
+ configured: "xhigh",
+ })
+
+ expect(value).toBe("low")
+ })
})
diff --git a/packages/app/src/context/model-variant.ts b/packages/app/src/context/model-variant.ts
index 6b7ae7256..525acbba3 100644
--- a/packages/app/src/context/model-variant.ts
+++ b/packages/app/src/context/model-variant.ts
@@ -14,7 +14,7 @@ type Model = AgentModel & {
type VariantInput = {
variants: string[]
- selected: string | undefined
+ selected: string | null | undefined
configured: string | undefined
}
@@ -29,6 +29,7 @@ export function getConfiguredAgentVariant(input: { agent: Agent | undefined; mod
}
export function resolveModelVariant(input: VariantInput) {
+ if (input.selected === null) return undefined
if (input.selected && input.variants.includes(input.selected)) return input.selected
if (input.configured && input.variants.includes(input.configured)) return input.configured
return undefined
@@ -36,6 +37,7 @@ export function resolveModelVariant(input: VariantInput) {
export function cycleModelVariant(input: VariantInput) {
if (input.variants.length === 0) return undefined
+ if (input.selected === null) return input.variants[0]
if (input.selected && input.variants.includes(input.selected)) {
const index = input.variants.indexOf(input.selected)
if (index === input.variants.length - 1) return undefined
diff --git a/packages/app/src/pages/directory-layout.tsx b/packages/app/src/pages/directory-layout.tsx
index fdf321f2d..f993ffcd8 100644
--- a/packages/app/src/pages/directory-layout.tsx
+++ b/packages/app/src/pages/directory-layout.tsx
@@ -80,11 +80,11 @@ export default function Layout(props: ParentProps) {
})
return (
- <Show when={state.resolved}>
+ <Show when={state.resolved} keyed>
{(resolved) => (
- <SDKProvider directory={resolved}>
+ <SDKProvider directory={() => resolved}>
<SyncProvider>
- <DirectoryDataProvider directory={resolved()}>{props.children}</DirectoryDataProvider>
+ <DirectoryDataProvider directory={resolved}>{props.children}</DirectoryDataProvider>
</SyncProvider>
</SDKProvider>
)}
diff --git a/packages/app/src/pages/session.tsx b/packages/app/src/pages/session.tsx
index 8399a1367..6d2917008 100644
--- a/packages/app/src/pages/session.tsx
+++ b/packages/app/src/pages/session.tsx
@@ -44,7 +44,7 @@ import { createOpenReviewFile, createSessionTabs, createSizing, focusTerminalByI
import { MessageTimeline } from "@/pages/session/message-timeline"
import { type DiffStyle, SessionReviewTab, type SessionReviewTabProps } from "@/pages/session/review-tab"
import { useSessionLayout } from "@/pages/session/session-layout"
-import { resetSessionModel, syncSessionModel } from "@/pages/session/session-model-helpers"
+import { syncSessionModel } from "@/pages/session/session-model-helpers"
import { SessionSidePanel } from "@/pages/session/session-side-panel"
import { TerminalPanel } from "@/pages/session/terminal-panel"
import { useSessionCommands } from "@/pages/session/use-session-commands"
@@ -490,7 +490,7 @@ export default function Page() {
(next, prev) => {
if (!prev) return
if (next.dir === prev.dir && next.id === prev.id) return
- if (!next.id) resetSessionModel(local)
+ if (prev.id && !next.id) local.session.reset()
},
{ defer: true },
),
diff --git a/packages/app/src/pages/session/session-model-helpers.test.ts b/packages/app/src/pages/session/session-model-helpers.test.ts
index 5f554dcd3..319db805d 100644
--- a/packages/app/src/pages/session/session-model-helpers.test.ts
+++ b/packages/app/src/pages/session/session-model-helpers.test.ts
@@ -14,145 +14,38 @@ const message = (input?: Partial<Pick<UserMessage, "agent" | "model" | "variant"
}) as UserMessage
describe("syncSessionModel", () => {
- test("restores the last message model and variant", () => {
+ test("restores the last message through session state", () => {
const calls: unknown[] = []
syncSessionModel(
{
- agent: {
- current() {
- return undefined
- },
- set(value) {
- calls.push(["agent", value])
- },
- },
- model: {
- set(value) {
- calls.push(["model", value])
- },
- current() {
- return { id: "claude-sonnet-4", provider: { id: "anthropic" } }
- },
- variant: {
- set(value) {
- calls.push(["variant", value])
- },
- },
- },
- },
- message({ variant: "high" }),
- )
-
- expect(calls).toEqual([
- ["agent", "build"],
- ["model", { providerID: "anthropic", modelID: "claude-sonnet-4" }],
- ["variant", "high"],
- ])
- })
-
- test("skips variant when the model falls back", () => {
- const calls: unknown[] = []
-
- syncSessionModel(
- {
- agent: {
- current() {
- return undefined
- },
- set(value) {
- calls.push(["agent", value])
- },
- },
- model: {
- set(value) {
- calls.push(["model", value])
- },
- current() {
- return { id: "gpt-5", provider: { id: "openai" } }
- },
- variant: {
- set(value) {
- calls.push(["variant", value])
- },
+ session: {
+ restore(value) {
+ calls.push(value)
},
+ reset() {},
},
},
message({ variant: "high" }),
)
- expect(calls).toEqual([
- ["agent", "build"],
- ["model", { providerID: "anthropic", modelID: "claude-sonnet-4" }],
- ])
+ expect(calls).toEqual([message({ variant: "high" })])
})
})
describe("resetSessionModel", () => {
- test("restores the current agent defaults", () => {
- const calls: unknown[] = []
+ test("clears draft session state", () => {
+ const calls: string[] = []
resetSessionModel({
- agent: {
- current() {
- return {
- model: { providerID: "anthropic", modelID: "claude-sonnet-4" },
- variant: "high",
- }
- },
- set() {},
- },
- model: {
- set(value) {
- calls.push(["model", value])
- },
- current() {
- return undefined
- },
- variant: {
- set(value) {
- calls.push(["variant", value])
- },
- },
- },
- })
-
- expect(calls).toEqual([
- ["model", { providerID: "anthropic", modelID: "claude-sonnet-4" }],
- ["variant", "high"],
- ])
- })
-
- test("clears the variant when the agent has none", () => {
- const calls: unknown[] = []
-
- resetSessionModel({
- agent: {
- current() {
- return {
- model: { providerID: "anthropic", modelID: "claude-sonnet-4" },
- }
- },
- set() {},
- },
- model: {
- set(value) {
- calls.push(["model", value])
- },
- current() {
- return undefined
- },
- variant: {
- set(value) {
- calls.push(["variant", value])
- },
+ session: {
+ reset() {
+ calls.push("reset")
},
+ restore() {},
},
})
- expect(calls).toEqual([
- ["model", { providerID: "anthropic", modelID: "claude-sonnet-4" }],
- ["variant", undefined],
- ])
+ expect(calls).toEqual(["reset"])
})
})
diff --git a/packages/app/src/pages/session/session-model-helpers.ts b/packages/app/src/pages/session/session-model-helpers.ts
index 7600f16d5..c9e2e1dbd 100644
--- a/packages/app/src/pages/session/session-model-helpers.ts
+++ b/packages/app/src/pages/session/session-model-helpers.ts
@@ -1,48 +1,16 @@
import type { UserMessage } from "@opencode-ai/sdk/v2"
-import { batch } from "solid-js"
type Local = {
- agent: {
- current():
- | {
- model?: UserMessage["model"]
- variant?: string
- }
- | undefined
- set(name: string | undefined): void
- }
- model: {
- set(model: UserMessage["model"] | undefined): void
- current():
- | {
- id: string
- provider: { id: string }
- }
- | undefined
- variant: {
- set(value: string | undefined): void
- }
+ session: {
+ reset(): void
+ restore(msg: UserMessage): void
}
}
export const resetSessionModel = (local: Local) => {
- const agent = local.agent.current()
- if (!agent) return
- batch(() => {
- local.model.set(agent.model)
- local.model.variant.set(agent.variant)
- })
+ local.session.reset()
}
export const syncSessionModel = (local: Local, msg: UserMessage) => {
- batch(() => {
- local.agent.set(msg.agent)
- local.model.set(msg.model)
- })
-
- const model = local.model.current()
- if (!model) return
- if (model.provider.id !== msg.model.providerID) return
- if (model.id !== msg.model.modelID) return
- local.model.variant.set(msg.variant)
+ local.session.restore(msg)
}
diff --git a/packages/app/src/pages/session/use-session-commands.tsx b/packages/app/src/pages/session/use-session-commands.tsx
index f5a4c0576..1a2e777f5 100644
--- a/packages/app/src/pages/session/use-session-commands.tsx
+++ b/packages/app/src/pages/session/use-session-commands.tsx
@@ -351,7 +351,7 @@ export const useSessionCommands = (actions: SessionCommandContext) => {
description: language.t("command.model.choose.description"),
keybind: "mod+'",
slash: "model",
- onSelect: () => dialog.show(() => <DialogSelectModel />),
+ onSelect: () => dialog.show(() => <DialogSelectModel model={local.model} />),
}),
mcpCommand({
id: "mcp.toggle",
diff --git a/packages/app/src/testing/model-selection.ts b/packages/app/src/testing/model-selection.ts
new file mode 100644
index 000000000..a5ea199ac
--- /dev/null
+++ b/packages/app/src/testing/model-selection.ts
@@ -0,0 +1,80 @@
+type ModelKey = {
+ providerID: string
+ modelID: string
+}
+
+type State = {
+ agent?: string
+ model?: ModelKey | null
+ variant?: string | null
+}
+
+export type ModelProbeState = {
+ dir?: string
+ sessionID?: string
+ last?: {
+ type: "agent" | "model" | "variant"
+ agent?: string
+ model?: ModelKey | null
+ variant?: string | null
+ }
+ agent?: string
+ model?: (ModelKey & { name?: string }) | undefined
+ variant?: string | null
+ selected?: string | null
+ configured?: string
+ pick?: State
+ base?: State
+ current?: string
+}
+
+export type ModelWindow = Window & {
+ __opencode_e2e?: {
+ model?: {
+ enabled?: boolean
+ current?: ModelProbeState
+ }
+ }
+}
+
+const clone = (state?: State) => {
+ if (!state) return undefined
+ return {
+ ...state,
+ model: state.model ? { ...state.model } : state.model,
+ }
+}
+
+export const modelEnabled = () => {
+ if (typeof window === "undefined") return false
+ return (window as ModelWindow).__opencode_e2e?.model?.enabled === true
+}
+
+const root = () => {
+ if (!modelEnabled()) return
+ return (window as ModelWindow).__opencode_e2e?.model
+}
+
+export const modelProbe = {
+ set(input: ModelProbeState) {
+ const state = root()
+ if (!state) return
+ state.current = {
+ ...input,
+ model: input.model ? { ...input.model } : undefined,
+ last: input.last
+ ? {
+ ...input.last,
+ model: input.last.model ? { ...input.last.model } : input.last.model,
+ }
+ : undefined,
+ pick: clone(input.pick),
+ base: clone(input.base),
+ }
+ },
+ clear() {
+ const state = root()
+ if (!state) return
+ state.current = undefined
+ },
+}
diff --git a/packages/app/src/testing/terminal.ts b/packages/app/src/testing/terminal.ts
index 4c179dee3..af1c33309 100644
--- a/packages/app/src/testing/terminal.ts
+++ b/packages/app/src/testing/terminal.ts
@@ -1,3 +1,5 @@
+import type { ModelProbeState } from "./model-selection"
+
export const terminalAttr = "data-pty-id"
export type TerminalProbeState = {
@@ -13,6 +15,10 @@ type TerminalProbeControl = {
export type E2EWindow = Window & {
__opencode_e2e?: {
+ model?: {
+ enabled?: boolean
+ current?: ModelProbeState
+ }
terminal?: {
enabled?: boolean
terminals?: Record<string, TerminalProbeState>
diff --git a/packages/ui/src/components/select.tsx b/packages/ui/src/components/select.tsx
index b370dbb64..61804a951 100644
--- a/packages/ui/src/components/select.tsx
+++ b/packages/ui/src/components/select.tsx
@@ -19,6 +19,7 @@ export type SelectProps<T> = Omit<ComponentProps<typeof Kobalte<T>>, "value" | "
children?: (item: T | undefined) => JSX.Element
triggerStyle?: JSX.CSSProperties
triggerVariant?: "settings"
+ triggerProps?: Record<string, string | number | boolean | undefined>
}
export function Select<T>(props: SelectProps<T> & Omit<ButtonProps, "children">) {
@@ -38,6 +39,7 @@ export function Select<T>(props: SelectProps<T> & Omit<ButtonProps, "children">)
"children",
"triggerStyle",
"triggerVariant",
+ "triggerProps",
])
const state = {
@@ -131,6 +133,7 @@ export function Select<T>(props: SelectProps<T> & Omit<ButtonProps, "children">)
}}
>
<Kobalte.Trigger
+ {...local.triggerProps}
disabled={props.disabled}
data-slot="select-select-trigger"
as={Button}