diff options
| author | Adam <[email protected]> | 2026-03-13 11:05:08 -0500 |
|---|---|---|
| committer | GitHub <[email protected]> | 2026-03-13 11:05:08 -0500 |
| commit | 4ad8116ce37a0e77e7f3c0e9e4e1002bba05b15e (patch) | |
| tree | b7e5ed2b05aabb5ed5134520c4eb485c52eb5333 | |
| parent | 5c7088338c07ad632834ebd4a87feb23d255fb8a (diff) | |
| download | opencode-4ad8116ce37a0e77e7f3c0e9e4e1002bba05b15e.tar.gz opencode-4ad8116ce37a0e77e7f3c0e9e4e1002bba05b15e.zip | |
fix(app): model selection persist by session (#17348)
19 files changed, 969 insertions, 440 deletions
diff --git a/packages/app/e2e/fixtures.ts b/packages/app/e2e/fixtures.ts index cf59eeb47..efefd479e 100644 --- a/packages/app/e2e/fixtures.ts +++ b/packages/app/e2e/fixtures.ts @@ -95,6 +95,9 @@ async function seedStorage(page: Page, input: { directory: string; extra?: strin const win = window as E2EWindow win.__opencode_e2e = { ...win.__opencode_e2e, + model: { + enabled: true, + }, terminal: { enabled: true, terminals: {}, diff --git a/packages/app/e2e/selectors.ts b/packages/app/e2e/selectors.ts index 64b7bfe54..80b6c473d 100644 --- a/packages/app/e2e/selectors.ts +++ b/packages/app/e2e/selectors.ts @@ -13,6 +13,9 @@ export const sessionTodoToggleButtonSelector = '[data-action="session-todo-toggl export const sessionTodoListSelector = '[data-slot="session-todo-list"]' export const modelVariantCycleSelector = '[data-action="model-variant-cycle"]' +export const promptAgentSelector = '[data-component="prompt-agent-control"]' +export const promptModelSelector = '[data-component="prompt-model-control"]' +export const promptVariantSelector = '[data-component="prompt-variant-control"]' export const settingsLanguageSelectSelector = '[data-action="settings-language"]' export const settingsColorSchemeSelector = '[data-action="settings-color-scheme"]' export const settingsThemeSelector = '[data-action="settings-theme"]' diff --git a/packages/app/e2e/session/session-model-persistence.spec.ts b/packages/app/e2e/session/session-model-persistence.spec.ts new file mode 100644 index 000000000..4b09a5287 --- /dev/null +++ b/packages/app/e2e/session/session-model-persistence.spec.ts @@ -0,0 +1,351 @@ +import { base64Decode } from "@opencode-ai/util/encode" +import type { Locator, Page } from "@playwright/test" +import { test, expect } from "../fixtures" +import { openSidebar, sessionIDFromUrl, setWorkspacesEnabled, waitSessionIdle, waitSlug } from "../actions" +import { + promptAgentSelector, + promptModelSelector, + promptSelector, + promptVariantSelector, + workspaceItemSelector, + workspaceNewSessionSelector, +} from "../selectors" +import { createSdk, sessionPath } from "../utils" + +type Footer = { + agent: string + model: string + variant: string +} + +type Probe = { + dir?: string + sessionID?: string + model?: { providerID: string; modelID: string } +} + +const escape = (value: string) => value.replace(/[.*+?^${}()|[\]\\]/g, "\\$&") + +const text = async (locator: Locator) => ((await locator.textContent()) ?? "").trim() + +const modelKey = (state: Probe | null) => (state?.model ? `${state.model.providerID}:${state.model.modelID}` : null) + +const dirKey = (state: Probe | null) => state?.dir ?? "" + +async function probe(page: Page): Promise<Probe | null> { + return page.evaluate(() => { + const win = window as Window & { + __opencode_e2e?: { + model?: { + current?: Probe + } + } + } + return win.__opencode_e2e?.model?.current ?? null + }) +} + +async function currentDir(page: Page) { + let hit = "" + await expect + .poll( + async () => { + const next = dirKey(await probe(page)) + if (next) hit = next + return next + }, + { timeout: 30_000 }, + ) + .not.toBe("") + return hit +} + +async function read(page: Page): Promise<Footer> { + return { + agent: await text(page.locator(`${promptAgentSelector} [data-slot="select-select-trigger-value"]`).first()), + model: await text(page.locator(`${promptModelSelector} [data-action="prompt-model"] span`).first()), + variant: await text(page.locator(`${promptVariantSelector} [data-slot="select-select-trigger-value"]`).first()), + } +} + +async function waitFooter(page: Page, expected: Partial<Footer>) { + let hit: Footer | null = null + await expect + .poll( + async () => { + const state = await read(page) + const ok = Object.entries(expected).every(([key, value]) => state[key as keyof Footer] === value) + if (ok) hit = state + return ok + }, + { timeout: 30_000 }, + ) + .toBe(true) + if (!hit) throw new Error("Failed to resolve prompt footer state") + return hit +} + +async function waitModel(page: Page, value: string) { + await expect.poll(() => probe(page).then(modelKey), { timeout: 30_000 }).toBe(value) +} + +async function choose(page: Page, root: string, value: string) { + const select = page.locator(root) + await expect(select).toBeVisible() + await select.locator('[data-action], [data-slot="select-select-trigger"]').first().click() + const item = page + .locator('[data-slot="select-select-item"]') + .filter({ hasText: new RegExp(`^\\s*${escape(value)}\\s*$`) }) + .first() + await expect(item).toBeVisible() + await item.click() +} + +async function variantCount(page: Page) { + const select = page.locator(promptVariantSelector) + await expect(select).toBeVisible() + await select.locator('[data-slot="select-select-trigger"]').click() + const count = await page.locator('[data-slot="select-select-item"]').count() + await page.keyboard.press("Escape") + return count +} + +async function agents(page: Page) { + const select = page.locator(promptAgentSelector) + await expect(select).toBeVisible() + await select.locator('[data-action], [data-slot="select-select-trigger"]').first().click() + const labels = await page.locator('[data-slot="select-select-item-label"]').allTextContents() + await page.keyboard.press("Escape") + return labels.map((item) => item.trim()).filter(Boolean) +} + +async function ensureVariant(page: Page, directory: string): Promise<Footer> { + const current = await read(page) + if ((await variantCount(page)) >= 2) return current + + const cfg = await createSdk(directory) + .config.get() + .then((x) => x.data) + const visible = new Set(await agents(page)) + const entry = Object.entries(cfg?.agent ?? {}).find((item) => { + const value = item[1] + return !!value && typeof value === "object" && "variant" in value && "model" in value && visible.has(item[0]) + }) + const name = entry?.[0] + test.skip(!name, "no agent with alternate variants available") + if (!name) return current + + await choose(page, promptAgentSelector, name) + await expect.poll(() => variantCount(page), { timeout: 30_000 }).toBeGreaterThanOrEqual(2) + return waitFooter(page, { agent: name }) +} + +async function chooseDifferentVariant(page: Page): Promise<Footer> { + const current = await read(page) + const select = page.locator(promptVariantSelector) + await expect(select).toBeVisible() + await select.locator('[data-slot="select-select-trigger"]').click() + + const items = page.locator('[data-slot="select-select-item"]') + const count = await items.count() + if (count < 2) throw new Error("Current model has no alternate variant to select") + + for (let i = 0; i < count; i++) { + const item = items.nth(i) + const next = await text(item.locator('[data-slot="select-select-item-label"]').first()) + if (!next || next === current.variant) continue + await item.click() + return waitFooter(page, { agent: current.agent, model: current.model, variant: next }) + } + + throw new Error("Failed to choose a different variant") +} + +async function chooseOtherModel(page: Page): Promise<Footer> { + const current = await read(page) + const button = page.locator(`${promptModelSelector} [data-action="prompt-model"]`) + await expect(button).toBeVisible() + await button.click() + + const dialog = page.getByRole("dialog") + await expect(dialog).toBeVisible() + const items = dialog.locator('[data-slot="list-item"]') + const count = await items.count() + expect(count).toBeGreaterThan(1) + + for (let i = 0; i < count; i++) { + const item = items.nth(i) + const selected = (await item.getAttribute("data-selected")) === "true" + if (selected) continue + await item.click() + await expect(dialog).toHaveCount(0) + await expect.poll(async () => (await read(page)).model !== current.model, { timeout: 30_000 }).toBe(true) + return read(page) + } + + throw new Error("Failed to choose a different model") +} + +async function goto(page: Page, directory: string, sessionID?: string) { + await page.goto(sessionPath(directory, sessionID)) + await expect(page.locator(promptSelector)).toBeVisible() + await expect.poll(async () => dirKey(await probe(page)), { timeout: 30_000 }).toBe(directory) +} + +async function submit(page: Page, value: string) { + const prompt = page.locator(promptSelector) + await expect(prompt).toBeVisible() + await prompt.click() + await prompt.fill(value) + await prompt.press("Enter") + + await expect.poll(() => sessionIDFromUrl(page.url()) ?? "", { timeout: 30_000 }).not.toBe("") + const id = sessionIDFromUrl(page.url()) + if (!id) throw new Error(`Failed to resolve session id from ${page.url()}`) + return id +} + +async function waitUser(directory: string, sessionID: string) { + const sdk = createSdk(directory) + await expect + .poll( + async () => { + const items = await sdk.session.messages({ sessionID, limit: 20 }).then((x) => x.data ?? []) + return items.some((item) => item.info.role === "user") + }, + { timeout: 30_000 }, + ) + .toBe(true) + await sdk.session.abort({ sessionID }).catch(() => undefined) + await waitSessionIdle(sdk, sessionID, 30_000).catch(() => undefined) +} + +async function createWorkspace(page: Page, root: string, seen: string[]) { + await openSidebar(page) + await page.getByRole("button", { name: "New workspace" }).first().click() + + const slug = await waitSlug(page, [root, ...seen]) + const directory = base64Decode(slug) + if (!directory) throw new Error(`Failed to decode workspace slug: ${slug}`) + return { slug, directory } +} + +async function waitWorkspace(page: Page, slug: string) { + await openSidebar(page) + await expect + .poll( + async () => { + const item = page.locator(workspaceItemSelector(slug)).first() + try { + await item.hover({ timeout: 500 }) + return true + } catch { + return false + } + }, + { timeout: 60_000 }, + ) + .toBe(true) +} + +async function newWorkspaceSession(page: Page, slug: string) { + await waitWorkspace(page, slug) + const item = page.locator(workspaceItemSelector(slug)).first() + await item.hover() + + const button = page.locator(workspaceNewSessionSelector(slug)).first() + await expect(button).toBeVisible() + await button.click({ force: true }) + + const next = await waitSlug(page) + await expect(page).toHaveURL(new RegExp(`/${next}/session(?:[/?#]|$)`)) + await expect(page.locator(promptSelector)).toBeVisible() + return currentDir(page) +} + +test("session model and variant restore per session without leaking into new sessions", async ({ + page, + withProject, +}) => { + await page.setViewportSize({ width: 1440, height: 900 }) + + await withProject(async ({ directory, gotoSession, trackSession }) => { + await gotoSession() + + await ensureVariant(page, directory) + const firstState = await chooseDifferentVariant(page) + const first = await submit(page, `session variant ${Date.now()}`) + trackSession(first) + await waitUser(directory, first) + + await page.reload() + await expect(page.locator(promptSelector)).toBeVisible() + await waitFooter(page, firstState) + + await gotoSession() + const fresh = await ensureVariant(page, directory) + expect(fresh.variant).not.toBe(firstState.variant) + + const secondState = await chooseOtherModel(page) + const second = await submit(page, `session model ${Date.now()}`) + trackSession(second) + await waitUser(directory, second) + + await goto(page, directory, first) + await waitFooter(page, firstState) + + await goto(page, directory, second) + await waitFooter(page, secondState) + + await gotoSession() + await waitFooter(page, fresh) + }) +}) + +test("session model restore across workspaces", async ({ page, withProject }) => { + await page.setViewportSize({ width: 1440, height: 900 }) + + await withProject(async ({ directory: root, slug, gotoSession, trackDirectory, trackSession }) => { + await gotoSession() + + await ensureVariant(page, root) + const firstState = await chooseDifferentVariant(page) + const first = await submit(page, `root session ${Date.now()}`) + trackSession(first, root) + await waitUser(root, first) + + await openSidebar(page) + await setWorkspacesEnabled(page, slug, true) + + const one = await createWorkspace(page, slug, []) + const oneDir = await newWorkspaceSession(page, one.slug) + trackDirectory(oneDir) + + const secondState = await chooseOtherModel(page) + const second = await submit(page, `workspace one ${Date.now()}`) + trackSession(second, oneDir) + await waitUser(oneDir, second) + + const two = await createWorkspace(page, slug, [one.slug]) + const twoDir = await newWorkspaceSession(page, two.slug) + trackDirectory(twoDir) + + await ensureVariant(page, twoDir) + const thirdState = await chooseDifferentVariant(page) + const third = await submit(page, `workspace two ${Date.now()}`) + trackSession(third, twoDir) + await waitUser(twoDir, third) + + await goto(page, root, first) + await waitFooter(page, firstState) + + await goto(page, oneDir, second) + await waitFooter(page, secondState) + + await goto(page, twoDir, third) + await waitFooter(page, thirdState) + + await goto(page, root, first) + await waitFooter(page, firstState) + }) +}) diff --git a/packages/app/src/components/dialog-select-model-unpaid.tsx b/packages/app/src/components/dialog-select-model-unpaid.tsx index bcee3f501..2106b3a01 100644 --- a/packages/app/src/components/dialog-select-model-unpaid.tsx +++ b/packages/app/src/components/dialog-select-model-unpaid.tsx @@ -13,8 +13,10 @@ import { DialogSelectProvider } from "./dialog-select-provider" import { ModelTooltip } from "./model-tooltip" import { useLanguage } from "@/context/language" -export const DialogSelectModelUnpaid: Component = () => { - const local = useLocal() +type ModelState = ReturnType<typeof useLocal>["model"] + +export const DialogSelectModelUnpaid: Component<{ model?: ModelState }> = (props) => { + const model = props.model ?? useLocal().model const dialog = useDialog() const providers = useProviders() const language = useLanguage() @@ -35,8 +37,8 @@ export const DialogSelectModelUnpaid: Component = () => { <List class="[&_[data-slot=list-scroll]]:overflow-visible" ref={(ref) => (listRef = ref)} - items={local.model.list} - current={local.model.current()} + items={model.list} + current={model.current()} key={(x) => `${x.provider.id}:${x.id}`} itemWrapper={(item, node) => ( <Tooltip @@ -55,7 +57,7 @@ export const DialogSelectModelUnpaid: Component = () => { </Tooltip> )} onSelect={(x) => { - local.model.set(x ? { modelID: x.id, providerID: x.provider.id } : undefined, { + model.set(x ? { modelID: x.id, providerID: x.provider.id } : undefined, { recent: true, }) dialog.close() diff --git a/packages/app/src/components/dialog-select-model.tsx b/packages/app/src/components/dialog-select-model.tsx index 9f7afb8cd..3654aab85 100644 --- a/packages/app/src/components/dialog-select-model.tsx +++ b/packages/app/src/components/dialog-select-model.tsx @@ -18,19 +18,22 @@ import { useLanguage } from "@/context/language" const isFree = (provider: string, cost: { input: number } | undefined) => provider === "opencode" && (!cost || cost.input === 0) +type ModelState = ReturnType<typeof useLocal>["model"] + const ModelList: Component<{ provider?: string class?: string onSelect: () => void action?: JSX.Element + model?: ModelState }> = (props) => { - const local = useLocal() + const model = props.model ?? useLocal().model const language = useLanguage() const models = createMemo(() => - local.model + model .list() - .filter((m) => local.model.visible({ modelID: m.id, providerID: m.provider.id })) + .filter((m) => model.visible({ modelID: m.id, providerID: m.provider.id })) .filter((m) => (props.provider ? m.provider.id === props.provider : true)), ) @@ -41,7 +44,7 @@ const ModelList: Component<{ emptyMessage={language.t("dialog.model.empty")} key={(x) => `${x.provider.id}:${x.id}`} items={models} - current={local.model.current()} + current={model.current()} filterKeys={["provider.name", "name", "id"]} sortBy={(a, b) => a.name.localeCompare(b.name)} groupBy={(x) => x.provider.name} @@ -63,7 +66,7 @@ const ModelList: Component<{ </Tooltip> )} onSelect={(x) => { - local.model.set(x ? { modelID: x.id, providerID: x.provider.id } : undefined, { + model.set(x ? { modelID: x.id, providerID: x.provider.id } : undefined, { recent: true, }) props.onSelect() @@ -88,6 +91,7 @@ type ModelSelectorTriggerProps = Omit<ComponentProps<typeof Kobalte.Trigger>, "a export function ModelSelectorPopover(props: { provider?: string + model?: ModelState children?: JSX.Element triggerAs?: ValidComponent triggerProps?: ModelSelectorTriggerProps @@ -151,6 +155,7 @@ export function ModelSelectorPopover(props: { <Kobalte.Title class="sr-only">{language.t("dialog.model.select.title")}</Kobalte.Title> <ModelList provider={props.provider} + model={props.model} onSelect={() => setStore("open", false)} class="p-1" action={ @@ -184,7 +189,7 @@ export function ModelSelectorPopover(props: { ) } -export const DialogSelectModel: Component<{ provider?: string }> = (props) => { +export const DialogSelectModel: Component<{ provider?: string; model?: ModelState }> = (props) => { const dialog = useDialog() const language = useLanguage() @@ -202,7 +207,7 @@ export const DialogSelectModel: Component<{ provider?: string }> = (props) => { </Button> } > - <ModelList provider={props.provider} onSelect={() => dialog.close()} /> + <ModelList provider={props.provider} model={props.model} onSelect={() => dialog.close()} /> <Button variant="ghost" class="ml-3 mt-5 mb-6 text-text-base self-start" diff --git a/packages/app/src/components/prompt-input.tsx b/packages/app/src/components/prompt-input.tsx index fd54de9a0..9048fa895 100644 --- a/packages/app/src/components/prompt-input.tsx +++ b/packages/app/src/components/prompt-input.tsx @@ -1430,39 +1430,76 @@ export const PromptInput: Component<PromptInputProps> = (props) => { <div class="size-4 shrink-0" /> </div> <div class="flex items-center gap-1.5 min-w-0 flex-1"> - <TooltipKeybind - placement="top" - gutter={4} - title={language.t("command.agent.cycle")} - keybind={command.keybind("agent.cycle")} - > - <Select - size="normal" - options={agentNames()} - current={local.agent.current()?.name ?? ""} - onSelect={local.agent.set} - class="capitalize max-w-[160px] text-text-base" - valueClass="truncate text-13-regular text-text-base" - triggerStyle={control()} - variant="ghost" - /> - </TooltipKeybind> - <Show - when={providers.paid().length > 0} - fallback={ + <div data-component="prompt-agent-control"> + <TooltipKeybind + placement="top" + gutter={4} + title={language.t("command.agent.cycle")} + keybind={command.keybind("agent.cycle")} + > + <Select + size="normal" + options={agentNames()} + current={local.agent.current()?.name ?? ""} + onSelect={local.agent.set} + class="capitalize max-w-[160px] text-text-base" + valueClass="truncate text-13-regular text-text-base" + triggerStyle={control()} + triggerProps={{ "data-action": "prompt-agent" }} + variant="ghost" + /> + </TooltipKeybind> + </div> + <div data-component="prompt-model-control"> + <Show + when={providers.paid().length > 0} + fallback={ + <TooltipKeybind + placement="top" + gutter={4} + title={language.t("command.model.choose")} + keybind={command.keybind("model.choose")} + > + <Button + data-action="prompt-model" + as="div" + variant="ghost" + size="normal" + class="min-w-0 max-w-[320px] text-13-regular text-text-base group" + style={control()} + onClick={() => dialog.show(() => <DialogSelectModelUnpaid model={local.model} />)} + > + <Show when={local.model.current()?.provider?.id}> + <ProviderIcon + id={local.model.current()!.provider.id} + class="size-4 shrink-0 opacity-40 group-hover:opacity-100 transition-opacity duration-150" + style={{ "will-change": "opacity", transform: "translateZ(0)" }} + /> + </Show> + <span class="truncate"> + {local.model.current()?.name ?? language.t("dialog.model.select.title")} + </span> + <Icon name="chevron-down" size="small" class="shrink-0" /> + </Button> + </TooltipKeybind> + } + > <TooltipKeybind placement="top" gutter={4} title={language.t("command.model.choose")} keybind={command.keybind("model.choose")} > - <Button - as="div" - variant="ghost" - size="normal" - class="min-w-0 max-w-[320px] text-13-regular text-text-base group" - style={control()} - onClick={() => dialog.show(() => <DialogSelectModelUnpaid />)} + <ModelSelectorPopover + model={local.model} + triggerAs={Button} + triggerProps={{ + variant: "ghost", + size: "normal", + style: control(), + class: "min-w-0 max-w-[320px] text-13-regular text-text-base group", + "data-action": "prompt-model", + }} > <Show when={local.model.current()?.provider?.id}> <ProviderIcon @@ -1475,57 +1512,31 @@ export const PromptInput: Component<PromptInputProps> = (props) => { {local.model.current()?.name ?? language.t("dialog.model.select.title")} </span> <Icon name="chevron-down" size="small" class="shrink-0" /> - </Button> + </ModelSelectorPopover> </TooltipKeybind> - } - > + </Show> + </div> + <div data-component="prompt-variant-control"> <TooltipKeybind placement="top" gutter={4} - title={language.t("command.model.choose")} - keybind={command.keybind("model.choose")} + title={language.t("command.model.variant.cycle")} + keybind={command.keybind("model.variant.cycle")} > - <ModelSelectorPopover - triggerAs={Button} - triggerProps={{ - variant: "ghost", - size: "normal", - style: control(), - class: "min-w-0 max-w-[320px] text-13-regular text-text-base group", - }} - > - <Show when={local.model.current()?.provider?.id}> - <ProviderIcon - id={local.model.current()!.provider.id} - class="size-4 shrink-0 opacity-40 group-hover:opacity-100 transition-opacity duration-150" - style={{ "will-change": "opacity", transform: "translateZ(0)" }} - /> - </Show> - <span class="truncate"> - {local.model.current()?.name ?? language.t("dialog.model.select.title")} - </span> - <Icon name="chevron-down" size="small" class="shrink-0" /> - </ModelSelectorPopover> + <Select + size="normal" + options={variants()} + current={local.model.variant.current() ?? "default"} + label={(x) => (x === "default" ? language.t("common.default") : x)} + onSelect={(x) => local.model.variant.set(x === "default" ? undefined : x)} + class="capitalize max-w-[160px] text-text-base" + valueClass="truncate text-13-regular text-text-base" + triggerStyle={control()} + triggerProps={{ "data-action": "prompt-model-variant" }} + variant="ghost" + /> </TooltipKeybind> - </Show> - <TooltipKeybind - placement="top" - gutter={4} - title={language.t("command.model.variant.cycle")} - keybind={command.keybind("model.variant.cycle")} - > - <Select - size="normal" - options={variants()} - current={local.model.variant.current() ?? "default"} - label={(x) => (x === "default" ? language.t("common.default") : x)} - onSelect={(x) => local.model.variant.set(x === "default" ? undefined : x)} - class="capitalize max-w-[160px] text-text-base" - valueClass="truncate text-13-regular text-text-base" - triggerStyle={control()} - variant="ghost" - /> - </TooltipKeybind> + </div> <TooltipKeybind placement="top" gutter={8} diff --git a/packages/app/src/components/prompt-input/submit.test.ts b/packages/app/src/components/prompt-input/submit.test.ts index 9f7fac69d..b0166c43a 100644 --- a/packages/app/src/components/prompt-input/submit.test.ts +++ b/packages/app/src/components/prompt-input/submit.test.ts @@ -17,6 +17,7 @@ const optimistic: Array<{ }> = [] const optimisticSeeded: boolean[] = [] const storedSessions: Record<string, Array<{ id: string; title?: string }>> = {} +const promoted: Array<{ directory: string; sessionID: string }> = [] const sentShell: string[] = [] const syncedDirectories: string[] = [] @@ -86,6 +87,11 @@ beforeAll(async () => { agent: { current: () => ({ name: "agent" }), }, + session: { + promote(directory: string, sessionID: string) { + promoted.push({ directory, sessionID }) + }, + }, }), })) @@ -201,6 +207,7 @@ beforeEach(() => { enabledAutoAccept.length = 0 optimistic.length = 0 optimisticSeeded.length = 0 + promoted.length = 0 params = {} sentShell.length = 0 syncedDirectories.length = 0 @@ -240,6 +247,11 @@ describe("prompt submit worktree selection", () => { expect(createdSessions).toEqual(["/repo/worktree-a", "/repo/worktree-b"]) expect(sentShell).toEqual(["/repo/worktree-a", "/repo/worktree-b"]) expect(syncedDirectories).toEqual(["/repo/worktree-a", "/repo/worktree-a", "/repo/worktree-b", "/repo/worktree-b"]) + expect(promoted).toEqual([ + { directory: "/repo/worktree-a", sessionID: "session-1" }, + { directory: "/repo/worktree-b", sessionID: "session-2" }, + ]) + expect(syncedDirectories).toEqual(["/repo/worktree-a", "/repo/worktree-a", "/repo/worktree-b", "/repo/worktree-b"]) }) test("applies auto-accept to newly created sessions", async () => { diff --git a/packages/app/src/components/prompt-input/submit.ts b/packages/app/src/components/prompt-input/submit.ts index e8d765cd9..ba299fe36 100644 --- a/packages/app/src/components/prompt-input/submit.ts +++ b/packages/app/src/components/prompt-input/submit.ts @@ -296,6 +296,7 @@ export function createPromptSubmit(input: PromptSubmitInput) { const currentModel = local.model.current() const currentAgent = local.agent.current() + const variant = local.model.variant.current() if (!currentModel || !currentAgent) { showToast({ title: language.t("prompt.toast.modelAgentRequired.title"), @@ -370,6 +371,7 @@ export function createPromptSubmit(input: PromptSubmitInput) { seed(sessionDirectory, created) session = created if (shouldAutoAccept) permission.enableAutoAccept(session.id, sessionDirectory) + local.session.promote(sessionDirectory, session.id) layout.handoff.setTabs(base64Encode(sessionDirectory), session.id) navigate(`/${base64Encode(sessionDirectory)}/session/${session.id}`) } @@ -387,7 +389,6 @@ export function createPromptSubmit(input: PromptSubmitInput) { providerID: currentModel.provider.id, } const agent = currentAgent.name - const variant = local.model.variant.current() const context = prompt.context.items().slice() const draft: FollowupDraft = { sessionID: session.id, diff --git a/packages/app/src/context/local.tsx b/packages/app/src/context/local.tsx index 75d1334a5..bed7ecd15 100644 --- a/packages/app/src/context/local.tsx +++ b/packages/app/src/context/local.tsx @@ -1,252 +1,421 @@ -import { createStore } from "solid-js/store" -import { batch, createMemo } from "solid-js" import { createSimpleContext } from "@opencode-ai/ui/context" -import { useSDK } from "./sdk" -import { useSync } from "./sync" import { base64Encode } from "@opencode-ai/util/encode" -import { useProviders } from "@/hooks/use-providers" +import { useParams } from "@solidjs/router" +import { batch, createEffect, createMemo, onCleanup } from "solid-js" +import { createStore } from "solid-js/store" import { useModels } from "@/context/models" +import { useProviders } from "@/hooks/use-providers" +import { modelEnabled, modelProbe } from "@/testing/model-selection" +import { Persist, persisted } from "@/utils/persist" import { cycleModelVariant, getConfiguredAgentVariant, resolveModelVariant } from "./model-variant" +import { useSDK } from "./sdk" +import { useSync } from "./sync" export type ModelKey = { providerID: string; modelID: string } +type State = { + agent?: string + model?: ModelKey + variant?: string | null +} + +type Saved = { + session: Record<string, State | undefined> +} + +const WORKSPACE_KEY = "__workspace__" +const handoff = new Map<string, State>() + +const handoffKey = (dir: string, id: string) => `${dir}\n${id}` + +const migrate = (value: unknown) => { + if (!value || typeof value !== "object") return { session: {} } + + const item = value as { + session?: Record<string, State | undefined> + pick?: Record<string, State | undefined> + } + + if (item.session && typeof item.session === "object") return { session: item.session } + if (!item.pick || typeof item.pick !== "object") return { session: {} } + + return { + session: Object.fromEntries(Object.entries(item.pick).filter(([key]) => key !== WORKSPACE_KEY)), + } +} + +const clone = (value: State | undefined) => { + if (!value) return undefined + return { + ...value, + model: value.model ? { ...value.model } : undefined, + } satisfies State +} + export const { use: useLocal, provider: LocalProvider } = createSimpleContext({ name: "Local", init: () => { + const params = useParams() const sdk = useSDK() const sync = useSync() const providers = useProviders() - const connected = createMemo(() => new Set(providers.connected().map((provider) => provider.id))) + const models = useModels() + + const id = createMemo(() => params.id || undefined) + const list = createMemo(() => sync.data.agent.filter((item) => item.mode !== "subagent" && !item.hidden)) + const connected = createMemo(() => new Set(providers.connected().map((item) => item.id))) - function isModelValid(model: ModelKey) { - const provider = providers.all().find((x) => x.id === model.providerID) + const [saved, setSaved] = persisted( + { + ...Persist.workspace(sdk.directory, "model-selection", ["model-selection.v1"]), + migrate, + }, + createStore<Saved>({ + session: {}, + }), + ) + + const [store, setStore] = createStore<{ + current?: string + draft?: State + last?: { + type: "agent" | "model" | "variant" + agent?: string + model?: ModelKey | null + variant?: string | null + } + }>({ + current: list()[0]?.name, + draft: undefined, + last: undefined, + }) + + const validModel = (model: ModelKey) => { + const provider = providers.all().find((item) => item.id === model.providerID) return !!provider?.models[model.modelID] && connected().has(model.providerID) } - function getFirstValidModel(...modelFns: (() => ModelKey | undefined)[]) { - for (const modelFn of modelFns) { - const model = modelFn() + const firstModel = (...items: Array<() => ModelKey | undefined>) => { + for (const item of items) { + const model = item() if (!model) continue - if (isModelValid(model)) return model + if (validModel(model)) return model } } - let setModel: (model: ModelKey | undefined, options?: { recent?: boolean }) => void = () => undefined + const pickAgent = (name: string | undefined) => { + const items = list() + if (items.length === 0) return undefined + return items.find((item) => item.name === name) ?? items[0] + } - const agent = (() => { - const list = createMemo(() => sync.data.agent.filter((x) => x.mode !== "subagent" && !x.hidden)) - const models = useModels() + createEffect(() => { + const items = list() + if (items.length === 0) { + if (store.current !== undefined) setStore("current", undefined) + return + } + if (items.some((item) => item.name === store.current)) return + setStore("current", items[0]?.name) + }) - const [store, setStore] = createStore<{ - current?: string - }>({ - current: list()[0]?.name, - }) - return { - list, - current() { - const available = list() - if (available.length === 0) return undefined - return available.find((x) => x.name === store.current) ?? available[0] - }, - set(name: string | undefined) { - const available = list() - if (available.length === 0) { - setStore("current", undefined) - return - } - const match = name ? available.find((x) => x.name === name) : undefined - const value = match ?? available[0] - if (!value) return - setStore("current", value.name) - if (!value.model) return - setModel({ - providerID: value.model.providerID, - modelID: value.model.modelID, - }) - if (value.variant) - models.variant.set({ providerID: value.model.providerID, modelID: value.model.modelID }, value.variant) - }, - move(direction: 1 | -1) { - const available = list() - if (available.length === 0) { - setStore("current", undefined) - return - } - let next = available.findIndex((x) => x.name === store.current) + direction - if (next < 0) next = available.length - 1 - if (next >= available.length) next = 0 - const value = available[next] - if (!value) return - setStore("current", value.name) - if (!value.model) return - setModel({ - providerID: value.model.providerID, - modelID: value.model.modelID, - }) - if (value.variant) - models.variant.set({ providerID: value.model.providerID, modelID: value.model.modelID }, value.variant) - }, + const scope = createMemo<State | undefined>(() => { + const session = id() + if (!session) return store.draft + return saved.session[session] ?? handoff.get(handoffKey(sdk.directory, session)) + }) + + createEffect(() => { + const session = id() + if (!session) return + + const key = handoffKey(sdk.directory, session) + const next = handoff.get(key) + if (!next) return + if (saved.session[session] !== undefined) { + handoff.delete(key) + return } - })() - const model = (() => { - const models = useModels() + setSaved("session", session, clone(next)) + handoff.delete(key) + }) - const [ephemeral, setEphemeral] = createStore<{ - model: Record<string, ModelKey | undefined> - }>({ - model: {}, - }) + const configuredModel = () => { + if (!sync.data.config.model) return + const [providerID, modelID] = sync.data.config.model.split("/") + const model = { providerID, modelID } + if (validModel(model)) return model + } - const resolveConfigured = () => { - if (!sync.data.config.model) return - const [providerID, modelID] = sync.data.config.model.split("/") - const key = { providerID, modelID } - if (isModelValid(key)) return key + const recentModel = () => { + for (const item of models.recent.list()) { + if (validModel(item)) return item } + } - const resolveRecent = () => { - for (const item of models.recent.list()) { - if (isModelValid(item)) return item + const defaultModel = () => { + const defaults = providers.default() + for (const provider of providers.connected()) { + const configured = defaults[provider.id] + if (configured) { + const model = { providerID: provider.id, modelID: configured } + if (validModel(model)) return model } + + const first = Object.values(provider.models)[0] + if (!first) continue + const model = { providerID: provider.id, modelID: first.id } + if (validModel(model)) return model } + } - const resolveDefault = () => { - const defaults = providers.default() - for (const provider of providers.connected()) { - const configured = defaults[provider.id] - if (configured) { - const key = { providerID: provider.id, modelID: configured } - if (isModelValid(key)) return key - } + const fallback = createMemo<ModelKey | undefined>(() => configuredModel() ?? recentModel() ?? defaultModel()) - const first = Object.values(provider.models)[0] - if (!first) continue - const key = { providerID: provider.id, modelID: first.id } - if (isModelValid(key)) return key + const agent = { + list, + current() { + return pickAgent(scope()?.agent ?? store.current) + }, + set(name: string | undefined) { + const item = pickAgent(name) + if (!item) { + setStore("current", undefined) + return } - } - const fallbackModel = createMemo<ModelKey | undefined>(() => { - return resolveConfigured() ?? resolveRecent() ?? resolveDefault() - }) + batch(() => { + setStore("current", item.name) + setStore("last", { + type: "agent", + agent: item.name, + model: item.model, + variant: item.variant ?? null, + }) + const next = { + agent: item.name, + model: item.model, + variant: item.variant, + } satisfies State + const session = id() + if (session) { + setSaved("session", session, next) + return + } + setStore("draft", next) + }) + }, + move(direction: 1 | -1) { + const items = list() + if (items.length === 0) { + setStore("current", undefined) + return + } - const current = createMemo(() => { - const a = agent.current() - if (!a) return undefined - const key = getFirstValidModel( - () => ephemeral.model[a.name], - () => a.model, - fallbackModel, - ) - if (!key) return undefined - return models.find(key) - }) + let next = items.findIndex((item) => item.name === agent.current()?.name) + direction + if (next < 0) next = items.length - 1 + if (next >= items.length) next = 0 + const item = items[next] + if (!item) return + agent.set(item.name) + }, + } - const recent = createMemo(() => models.recent.list().map(models.find).filter(Boolean)) + const current = () => { + const item = firstModel( + () => scope()?.model, + () => agent.current()?.model, + fallback, + ) + if (!item) return undefined + return models.find(item) + } - const cycle = (direction: 1 | -1) => { - const recentList = recent() - const currentModel = current() - if (!currentModel) return + const configured = () => { + const item = agent.current() + const model = current() + if (!item || !model) return undefined + return getConfiguredAgentVariant({ + agent: { model: item.model, variant: item.variant }, + model: { providerID: model.provider.id, modelID: model.id, variants: model.variants }, + }) + } - const index = recentList.findIndex( - (x) => x?.provider.id === currentModel.provider.id && x?.id === currentModel.id, - ) - if (index === -1) return + const selected = () => scope()?.variant - let next = index + direction - if (next < 0) next = recentList.length - 1 - if (next >= recentList.length) next = 0 + const snapshot = () => { + const model = current() + return { + agent: agent.current()?.name, + model: model ? { providerID: model.provider.id, modelID: model.id } : undefined, + variant: selected(), + } satisfies State + } - const val = recentList[next] - if (!val) return + const write = (next: Partial<State>) => { + const state = { + ...(scope() ?? { agent: agent.current()?.name }), + ...next, + } satisfies State - model.set({ - providerID: val.provider.id, - modelID: val.id, - }) + const session = id() + if (session) { + setSaved("session", session, state) + return } + setStore("draft", state) + } - const set = (model: ModelKey | undefined, options?: { recent?: boolean }) => { - batch(() => { - const currentAgent = agent.current() - const next = model ?? fallbackModel() - if (currentAgent) setEphemeral("model", currentAgent.name, next) - if (model) models.setVisibility(model, true) - if (options?.recent && model) models.recent.push(model) - }) - } + const recent = createMemo(() => models.recent.list().map(models.find).filter(Boolean)) - setModel = set + const model = { + ready: models.ready, + current, + recent, + list: models.list, + cycle(direction: 1 | -1) { + const items = recent() + const item = current() + if (!item) return - return { - ready: models.ready, - current, - recent, - list: models.list, - cycle, - set, - visible(model: ModelKey) { - return models.visible(model) + const index = items.findIndex((entry) => entry?.provider.id === item.provider.id && entry?.id === item.id) + if (index === -1) return + + let next = index + direction + if (next < 0) next = items.length - 1 + if (next >= items.length) next = 0 + + const entry = items[next] + if (!entry) return + model.set({ providerID: entry.provider.id, modelID: entry.id }) + }, + set(item: ModelKey | undefined, options?: { recent?: boolean }) { + batch(() => { + setStore("last", { + type: "model", + agent: agent.current()?.name, + model: item ?? null, + variant: selected(), + }) + write({ model: item }) + if (!item) return + models.setVisibility(item, true) + if (!options?.recent) return + models.recent.push(item) + }) + }, + visible(item: ModelKey) { + return models.visible(item) + }, + setVisibility(item: ModelKey, visible: boolean) { + models.setVisibility(item, visible) + }, + variant: { + configured, + selected, + current() { + return resolveModelVariant({ + variants: this.list(), + selected: this.selected(), + configured: this.configured(), + }) }, - setVisibility(model: ModelKey, visible: boolean) { - models.setVisibility(model, visible) + list() { + const item = current() + if (!item?.variants) return [] + return Object.keys(item.variants) }, - variant: { - configured() { - const a = agent.current() - const m = current() - if (!a || !m) return undefined - return getConfiguredAgentVariant({ - agent: { model: a.model, variant: a.variant }, - model: { providerID: m.provider.id, modelID: m.id, variants: m.variants }, + set(value: string | undefined) { + batch(() => { + const model = current() + setStore("last", { + type: "variant", + agent: agent.current()?.name, + model: model ? { providerID: model.provider.id, modelID: model.id } : null, + variant: value ?? null, }) - }, - selected() { - const m = current() - if (!m) return undefined - return models.variant.get({ providerID: m.provider.id, modelID: m.id }) - }, - current() { - return resolveModelVariant({ - variants: this.list(), + write({ variant: value ?? null }) + }) + }, + cycle() { + const items = this.list() + if (items.length === 0) return + this.set( + cycleModelVariant({ + variants: items, selected: this.selected(), configured: this.configured(), - }) - }, - list() { - const m = current() - if (!m) return [] - if (!m.variants) return [] - return Object.keys(m.variants) - }, - set(value: string | undefined) { - const m = current() - if (!m) return - models.variant.set({ providerID: m.provider.id, modelID: m.id }, value) - }, - cycle() { - const variants = this.list() - if (variants.length === 0) return - this.set( - cycleModelVariant({ - variants, - selected: this.selected(), - configured: this.configured(), - }), - ) - }, + }), + ) }, - } - })() + }, + } const result = { slug: createMemo(() => base64Encode(sdk.directory)), model, agent, + session: { + reset() { + setStore("draft", undefined) + }, + promote(dir: string, session: string) { + const next = clone(snapshot()) + if (!next) return + + if (dir === sdk.directory) { + setSaved("session", session, next) + setStore("draft", undefined) + return + } + + handoff.set(handoffKey(dir, session), next) + setStore("draft", undefined) + }, + restore(msg: { sessionID: string; agent: string; model: ModelKey; variant?: string }) { + const session = id() + if (!session) return + if (msg.sessionID !== session) return + if (saved.session[session] !== undefined) return + if (handoff.has(handoffKey(sdk.directory, session))) return + + setSaved("session", session, { + agent: msg.agent, + model: msg.model, + variant: msg.variant ?? null, + }) + }, + }, } + + if (modelEnabled()) { + createEffect(() => { + const agent = result.agent.current() + const model = result.model.current() + modelProbe.set({ + dir: sdk.directory, + sessionID: id(), + last: store.last, + agent: agent?.name, + model: model + ? { + providerID: model.provider.id, + modelID: model.id, + name: model.name, + } + : undefined, + variant: result.model.variant.current() ?? null, + selected: result.model.variant.selected(), + configured: result.model.variant.configured(), + pick: scope(), + base: undefined, + current: store.current, + }) + }) + + onCleanup(() => modelProbe.clear()) + } + return result }, }) diff --git a/packages/app/src/context/model-variant.test.ts b/packages/app/src/context/model-variant.test.ts index 01b149fd2..583bc5c3d 100644 --- a/packages/app/src/context/model-variant.test.ts +++ b/packages/app/src/context/model-variant.test.ts @@ -44,6 +44,16 @@ describe("model variant", () => { expect(value).toBe("high") }) + test("lets an explicit default override the configured variant", () => { + const value = resolveModelVariant({ + variants: ["low", "high", "xhigh"], + selected: null, + configured: "xhigh", + }) + + expect(value).toBeUndefined() + }) + test("cycles from configured variant to next", () => { const value = cycleModelVariant({ variants: ["low", "high", "xhigh"], @@ -63,4 +73,14 @@ describe("model variant", () => { expect(value).toBe("low") }) + + test("cycles from an explicit default to the first variant", () => { + const value = cycleModelVariant({ + variants: ["low", "high", "xhigh"], + selected: null, + configured: "xhigh", + }) + + expect(value).toBe("low") + }) }) diff --git a/packages/app/src/context/model-variant.ts b/packages/app/src/context/model-variant.ts index 6b7ae7256..525acbba3 100644 --- a/packages/app/src/context/model-variant.ts +++ b/packages/app/src/context/model-variant.ts @@ -14,7 +14,7 @@ type Model = AgentModel & { type VariantInput = { variants: string[] - selected: string | undefined + selected: string | null | undefined configured: string | undefined } @@ -29,6 +29,7 @@ export function getConfiguredAgentVariant(input: { agent: Agent | undefined; mod } export function resolveModelVariant(input: VariantInput) { + if (input.selected === null) return undefined if (input.selected && input.variants.includes(input.selected)) return input.selected if (input.configured && input.variants.includes(input.configured)) return input.configured return undefined @@ -36,6 +37,7 @@ export function resolveModelVariant(input: VariantInput) { export function cycleModelVariant(input: VariantInput) { if (input.variants.length === 0) return undefined + if (input.selected === null) return input.variants[0] if (input.selected && input.variants.includes(input.selected)) { const index = input.variants.indexOf(input.selected) if (index === input.variants.length - 1) return undefined diff --git a/packages/app/src/pages/directory-layout.tsx b/packages/app/src/pages/directory-layout.tsx index fdf321f2d..f993ffcd8 100644 --- a/packages/app/src/pages/directory-layout.tsx +++ b/packages/app/src/pages/directory-layout.tsx @@ -80,11 +80,11 @@ export default function Layout(props: ParentProps) { }) return ( - <Show when={state.resolved}> + <Show when={state.resolved} keyed> {(resolved) => ( - <SDKProvider directory={resolved}> + <SDKProvider directory={() => resolved}> <SyncProvider> - <DirectoryDataProvider directory={resolved()}>{props.children}</DirectoryDataProvider> + <DirectoryDataProvider directory={resolved}>{props.children}</DirectoryDataProvider> </SyncProvider> </SDKProvider> )} diff --git a/packages/app/src/pages/session.tsx b/packages/app/src/pages/session.tsx index 8399a1367..6d2917008 100644 --- a/packages/app/src/pages/session.tsx +++ b/packages/app/src/pages/session.tsx @@ -44,7 +44,7 @@ import { createOpenReviewFile, createSessionTabs, createSizing, focusTerminalByI import { MessageTimeline } from "@/pages/session/message-timeline" import { type DiffStyle, SessionReviewTab, type SessionReviewTabProps } from "@/pages/session/review-tab" import { useSessionLayout } from "@/pages/session/session-layout" -import { resetSessionModel, syncSessionModel } from "@/pages/session/session-model-helpers" +import { syncSessionModel } from "@/pages/session/session-model-helpers" import { SessionSidePanel } from "@/pages/session/session-side-panel" import { TerminalPanel } from "@/pages/session/terminal-panel" import { useSessionCommands } from "@/pages/session/use-session-commands" @@ -490,7 +490,7 @@ export default function Page() { (next, prev) => { if (!prev) return if (next.dir === prev.dir && next.id === prev.id) return - if (!next.id) resetSessionModel(local) + if (prev.id && !next.id) local.session.reset() }, { defer: true }, ), diff --git a/packages/app/src/pages/session/session-model-helpers.test.ts b/packages/app/src/pages/session/session-model-helpers.test.ts index 5f554dcd3..319db805d 100644 --- a/packages/app/src/pages/session/session-model-helpers.test.ts +++ b/packages/app/src/pages/session/session-model-helpers.test.ts @@ -14,145 +14,38 @@ const message = (input?: Partial<Pick<UserMessage, "agent" | "model" | "variant" }) as UserMessage describe("syncSessionModel", () => { - test("restores the last message model and variant", () => { + test("restores the last message through session state", () => { const calls: unknown[] = [] syncSessionModel( { - agent: { - current() { - return undefined - }, - set(value) { - calls.push(["agent", value]) - }, - }, - model: { - set(value) { - calls.push(["model", value]) - }, - current() { - return { id: "claude-sonnet-4", provider: { id: "anthropic" } } - }, - variant: { - set(value) { - calls.push(["variant", value]) - }, - }, - }, - }, - message({ variant: "high" }), - ) - - expect(calls).toEqual([ - ["agent", "build"], - ["model", { providerID: "anthropic", modelID: "claude-sonnet-4" }], - ["variant", "high"], - ]) - }) - - test("skips variant when the model falls back", () => { - const calls: unknown[] = [] - - syncSessionModel( - { - agent: { - current() { - return undefined - }, - set(value) { - calls.push(["agent", value]) - }, - }, - model: { - set(value) { - calls.push(["model", value]) - }, - current() { - return { id: "gpt-5", provider: { id: "openai" } } - }, - variant: { - set(value) { - calls.push(["variant", value]) - }, + session: { + restore(value) { + calls.push(value) }, + reset() {}, }, }, message({ variant: "high" }), ) - expect(calls).toEqual([ - ["agent", "build"], - ["model", { providerID: "anthropic", modelID: "claude-sonnet-4" }], - ]) + expect(calls).toEqual([message({ variant: "high" })]) }) }) describe("resetSessionModel", () => { - test("restores the current agent defaults", () => { - const calls: unknown[] = [] + test("clears draft session state", () => { + const calls: string[] = [] resetSessionModel({ - agent: { - current() { - return { - model: { providerID: "anthropic", modelID: "claude-sonnet-4" }, - variant: "high", - } - }, - set() {}, - }, - model: { - set(value) { - calls.push(["model", value]) - }, - current() { - return undefined - }, - variant: { - set(value) { - calls.push(["variant", value]) - }, - }, - }, - }) - - expect(calls).toEqual([ - ["model", { providerID: "anthropic", modelID: "claude-sonnet-4" }], - ["variant", "high"], - ]) - }) - - test("clears the variant when the agent has none", () => { - const calls: unknown[] = [] - - resetSessionModel({ - agent: { - current() { - return { - model: { providerID: "anthropic", modelID: "claude-sonnet-4" }, - } - }, - set() {}, - }, - model: { - set(value) { - calls.push(["model", value]) - }, - current() { - return undefined - }, - variant: { - set(value) { - calls.push(["variant", value]) - }, + session: { + reset() { + calls.push("reset") }, + restore() {}, }, }) - expect(calls).toEqual([ - ["model", { providerID: "anthropic", modelID: "claude-sonnet-4" }], - ["variant", undefined], - ]) + expect(calls).toEqual(["reset"]) }) }) diff --git a/packages/app/src/pages/session/session-model-helpers.ts b/packages/app/src/pages/session/session-model-helpers.ts index 7600f16d5..c9e2e1dbd 100644 --- a/packages/app/src/pages/session/session-model-helpers.ts +++ b/packages/app/src/pages/session/session-model-helpers.ts @@ -1,48 +1,16 @@ import type { UserMessage } from "@opencode-ai/sdk/v2" -import { batch } from "solid-js" type Local = { - agent: { - current(): - | { - model?: UserMessage["model"] - variant?: string - } - | undefined - set(name: string | undefined): void - } - model: { - set(model: UserMessage["model"] | undefined): void - current(): - | { - id: string - provider: { id: string } - } - | undefined - variant: { - set(value: string | undefined): void - } + session: { + reset(): void + restore(msg: UserMessage): void } } export const resetSessionModel = (local: Local) => { - const agent = local.agent.current() - if (!agent) return - batch(() => { - local.model.set(agent.model) - local.model.variant.set(agent.variant) - }) + local.session.reset() } export const syncSessionModel = (local: Local, msg: UserMessage) => { - batch(() => { - local.agent.set(msg.agent) - local.model.set(msg.model) - }) - - const model = local.model.current() - if (!model) return - if (model.provider.id !== msg.model.providerID) return - if (model.id !== msg.model.modelID) return - local.model.variant.set(msg.variant) + local.session.restore(msg) } diff --git a/packages/app/src/pages/session/use-session-commands.tsx b/packages/app/src/pages/session/use-session-commands.tsx index f5a4c0576..1a2e777f5 100644 --- a/packages/app/src/pages/session/use-session-commands.tsx +++ b/packages/app/src/pages/session/use-session-commands.tsx @@ -351,7 +351,7 @@ export const useSessionCommands = (actions: SessionCommandContext) => { description: language.t("command.model.choose.description"), keybind: "mod+'", slash: "model", - onSelect: () => dialog.show(() => <DialogSelectModel />), + onSelect: () => dialog.show(() => <DialogSelectModel model={local.model} />), }), mcpCommand({ id: "mcp.toggle", diff --git a/packages/app/src/testing/model-selection.ts b/packages/app/src/testing/model-selection.ts new file mode 100644 index 000000000..a5ea199ac --- /dev/null +++ b/packages/app/src/testing/model-selection.ts @@ -0,0 +1,80 @@ +type ModelKey = { + providerID: string + modelID: string +} + +type State = { + agent?: string + model?: ModelKey | null + variant?: string | null +} + +export type ModelProbeState = { + dir?: string + sessionID?: string + last?: { + type: "agent" | "model" | "variant" + agent?: string + model?: ModelKey | null + variant?: string | null + } + agent?: string + model?: (ModelKey & { name?: string }) | undefined + variant?: string | null + selected?: string | null + configured?: string + pick?: State + base?: State + current?: string +} + +export type ModelWindow = Window & { + __opencode_e2e?: { + model?: { + enabled?: boolean + current?: ModelProbeState + } + } +} + +const clone = (state?: State) => { + if (!state) return undefined + return { + ...state, + model: state.model ? { ...state.model } : state.model, + } +} + +export const modelEnabled = () => { + if (typeof window === "undefined") return false + return (window as ModelWindow).__opencode_e2e?.model?.enabled === true +} + +const root = () => { + if (!modelEnabled()) return + return (window as ModelWindow).__opencode_e2e?.model +} + +export const modelProbe = { + set(input: ModelProbeState) { + const state = root() + if (!state) return + state.current = { + ...input, + model: input.model ? { ...input.model } : undefined, + last: input.last + ? { + ...input.last, + model: input.last.model ? { ...input.last.model } : input.last.model, + } + : undefined, + pick: clone(input.pick), + base: clone(input.base), + } + }, + clear() { + const state = root() + if (!state) return + state.current = undefined + }, +} diff --git a/packages/app/src/testing/terminal.ts b/packages/app/src/testing/terminal.ts index 4c179dee3..af1c33309 100644 --- a/packages/app/src/testing/terminal.ts +++ b/packages/app/src/testing/terminal.ts @@ -1,3 +1,5 @@ +import type { ModelProbeState } from "./model-selection" + export const terminalAttr = "data-pty-id" export type TerminalProbeState = { @@ -13,6 +15,10 @@ type TerminalProbeControl = { export type E2EWindow = Window & { __opencode_e2e?: { + model?: { + enabled?: boolean + current?: ModelProbeState + } terminal?: { enabled?: boolean terminals?: Record<string, TerminalProbeState> diff --git a/packages/ui/src/components/select.tsx b/packages/ui/src/components/select.tsx index b370dbb64..61804a951 100644 --- a/packages/ui/src/components/select.tsx +++ b/packages/ui/src/components/select.tsx @@ -19,6 +19,7 @@ export type SelectProps<T> = Omit<ComponentProps<typeof Kobalte<T>>, "value" | " children?: (item: T | undefined) => JSX.Element triggerStyle?: JSX.CSSProperties triggerVariant?: "settings" + triggerProps?: Record<string, string | number | boolean | undefined> } export function Select<T>(props: SelectProps<T> & Omit<ButtonProps, "children">) { @@ -38,6 +39,7 @@ export function Select<T>(props: SelectProps<T> & Omit<ButtonProps, "children">) "children", "triggerStyle", "triggerVariant", + "triggerProps", ]) const state = { @@ -131,6 +133,7 @@ export function Select<T>(props: SelectProps<T> & Omit<ButtonProps, "children">) }} > <Kobalte.Trigger + {...local.triggerProps} disabled={props.disabled} data-slot="select-select-trigger" as={Button} |
