diff options
| author | Adam <[email protected]> | 2026-03-13 11:05:08 -0500 |
|---|---|---|
| committer | GitHub <[email protected]> | 2026-03-13 11:05:08 -0500 |
| commit | 4ad8116ce37a0e77e7f3c0e9e4e1002bba05b15e (patch) | |
| tree | b7e5ed2b05aabb5ed5134520c4eb485c52eb5333 /packages/app/src | |
| parent | 5c7088338c07ad632834ebd4a87feb23d255fb8a (diff) | |
| download | opencode-4ad8116ce37a0e77e7f3c0e9e4e1002bba05b15e.tar.gz opencode-4ad8116ce37a0e77e7f3c0e9e4e1002bba05b15e.zip | |
fix(app): model selection persist by session (#17348)
Diffstat (limited to 'packages/app/src')
| -rw-r--r-- | packages/app/src/components/dialog-select-model-unpaid.tsx | 12 | ||||
| -rw-r--r-- | packages/app/src/components/dialog-select-model.tsx | 19 | ||||
| -rw-r--r-- | packages/app/src/components/prompt-input.tsx | 155 | ||||
| -rw-r--r-- | packages/app/src/components/prompt-input/submit.test.ts | 12 | ||||
| -rw-r--r-- | packages/app/src/components/prompt-input/submit.ts | 3 | ||||
| -rw-r--r-- | packages/app/src/context/local.tsx | 551 | ||||
| -rw-r--r-- | packages/app/src/context/model-variant.test.ts | 20 | ||||
| -rw-r--r-- | packages/app/src/context/model-variant.ts | 4 | ||||
| -rw-r--r-- | packages/app/src/pages/directory-layout.tsx | 6 | ||||
| -rw-r--r-- | packages/app/src/pages/session.tsx | 4 | ||||
| -rw-r--r-- | packages/app/src/pages/session/session-model-helpers.test.ts | 133 | ||||
| -rw-r--r-- | packages/app/src/pages/session/session-model-helpers.ts | 42 | ||||
| -rw-r--r-- | packages/app/src/pages/session/use-session-commands.tsx | 2 | ||||
| -rw-r--r-- | packages/app/src/testing/model-selection.ts | 80 | ||||
| -rw-r--r-- | packages/app/src/testing/terminal.ts | 6 |
15 files changed, 609 insertions, 440 deletions
diff --git a/packages/app/src/components/dialog-select-model-unpaid.tsx b/packages/app/src/components/dialog-select-model-unpaid.tsx index bcee3f501..2106b3a01 100644 --- a/packages/app/src/components/dialog-select-model-unpaid.tsx +++ b/packages/app/src/components/dialog-select-model-unpaid.tsx @@ -13,8 +13,10 @@ import { DialogSelectProvider } from "./dialog-select-provider" import { ModelTooltip } from "./model-tooltip" import { useLanguage } from "@/context/language" -export const DialogSelectModelUnpaid: Component = () => { - const local = useLocal() +type ModelState = ReturnType<typeof useLocal>["model"] + +export const DialogSelectModelUnpaid: Component<{ model?: ModelState }> = (props) => { + const model = props.model ?? useLocal().model const dialog = useDialog() const providers = useProviders() const language = useLanguage() @@ -35,8 +37,8 @@ export const DialogSelectModelUnpaid: Component = () => { <List class="[&_[data-slot=list-scroll]]:overflow-visible" ref={(ref) => (listRef = ref)} - items={local.model.list} - current={local.model.current()} + items={model.list} + current={model.current()} key={(x) => `${x.provider.id}:${x.id}`} itemWrapper={(item, node) => ( <Tooltip @@ -55,7 +57,7 @@ export const DialogSelectModelUnpaid: Component = () => { </Tooltip> )} onSelect={(x) => { - local.model.set(x ? { modelID: x.id, providerID: x.provider.id } : undefined, { + model.set(x ? { modelID: x.id, providerID: x.provider.id } : undefined, { recent: true, }) dialog.close() diff --git a/packages/app/src/components/dialog-select-model.tsx b/packages/app/src/components/dialog-select-model.tsx index 9f7afb8cd..3654aab85 100644 --- a/packages/app/src/components/dialog-select-model.tsx +++ b/packages/app/src/components/dialog-select-model.tsx @@ -18,19 +18,22 @@ import { useLanguage } from "@/context/language" const isFree = (provider: string, cost: { input: number } | undefined) => provider === "opencode" && (!cost || cost.input === 0) +type ModelState = ReturnType<typeof useLocal>["model"] + const ModelList: Component<{ provider?: string class?: string onSelect: () => void action?: JSX.Element + model?: ModelState }> = (props) => { - const local = useLocal() + const model = props.model ?? useLocal().model const language = useLanguage() const models = createMemo(() => - local.model + model .list() - .filter((m) => local.model.visible({ modelID: m.id, providerID: m.provider.id })) + .filter((m) => model.visible({ modelID: m.id, providerID: m.provider.id })) .filter((m) => (props.provider ? m.provider.id === props.provider : true)), ) @@ -41,7 +44,7 @@ const ModelList: Component<{ emptyMessage={language.t("dialog.model.empty")} key={(x) => `${x.provider.id}:${x.id}`} items={models} - current={local.model.current()} + current={model.current()} filterKeys={["provider.name", "name", "id"]} sortBy={(a, b) => a.name.localeCompare(b.name)} groupBy={(x) => x.provider.name} @@ -63,7 +66,7 @@ const ModelList: Component<{ </Tooltip> )} onSelect={(x) => { - local.model.set(x ? { modelID: x.id, providerID: x.provider.id } : undefined, { + model.set(x ? { modelID: x.id, providerID: x.provider.id } : undefined, { recent: true, }) props.onSelect() @@ -88,6 +91,7 @@ type ModelSelectorTriggerProps = Omit<ComponentProps<typeof Kobalte.Trigger>, "a export function ModelSelectorPopover(props: { provider?: string + model?: ModelState children?: JSX.Element triggerAs?: ValidComponent triggerProps?: ModelSelectorTriggerProps @@ -151,6 +155,7 @@ export function ModelSelectorPopover(props: { <Kobalte.Title class="sr-only">{language.t("dialog.model.select.title")}</Kobalte.Title> <ModelList provider={props.provider} + model={props.model} onSelect={() => setStore("open", false)} class="p-1" action={ @@ -184,7 +189,7 @@ export function ModelSelectorPopover(props: { ) } -export const DialogSelectModel: Component<{ provider?: string }> = (props) => { +export const DialogSelectModel: Component<{ provider?: string; model?: ModelState }> = (props) => { const dialog = useDialog() const language = useLanguage() @@ -202,7 +207,7 @@ export const DialogSelectModel: Component<{ provider?: string }> = (props) => { </Button> } > - <ModelList provider={props.provider} onSelect={() => dialog.close()} /> + <ModelList provider={props.provider} model={props.model} onSelect={() => dialog.close()} /> <Button variant="ghost" class="ml-3 mt-5 mb-6 text-text-base self-start" diff --git a/packages/app/src/components/prompt-input.tsx b/packages/app/src/components/prompt-input.tsx index fd54de9a0..9048fa895 100644 --- a/packages/app/src/components/prompt-input.tsx +++ b/packages/app/src/components/prompt-input.tsx @@ -1430,39 +1430,76 @@ export const PromptInput: Component<PromptInputProps> = (props) => { <div class="size-4 shrink-0" /> </div> <div class="flex items-center gap-1.5 min-w-0 flex-1"> - <TooltipKeybind - placement="top" - gutter={4} - title={language.t("command.agent.cycle")} - keybind={command.keybind("agent.cycle")} - > - <Select - size="normal" - options={agentNames()} - current={local.agent.current()?.name ?? ""} - onSelect={local.agent.set} - class="capitalize max-w-[160px] text-text-base" - valueClass="truncate text-13-regular text-text-base" - triggerStyle={control()} - variant="ghost" - /> - </TooltipKeybind> - <Show - when={providers.paid().length > 0} - fallback={ + <div data-component="prompt-agent-control"> + <TooltipKeybind + placement="top" + gutter={4} + title={language.t("command.agent.cycle")} + keybind={command.keybind("agent.cycle")} + > + <Select + size="normal" + options={agentNames()} + current={local.agent.current()?.name ?? ""} + onSelect={local.agent.set} + class="capitalize max-w-[160px] text-text-base" + valueClass="truncate text-13-regular text-text-base" + triggerStyle={control()} + triggerProps={{ "data-action": "prompt-agent" }} + variant="ghost" + /> + </TooltipKeybind> + </div> + <div data-component="prompt-model-control"> + <Show + when={providers.paid().length > 0} + fallback={ + <TooltipKeybind + placement="top" + gutter={4} + title={language.t("command.model.choose")} + keybind={command.keybind("model.choose")} + > + <Button + data-action="prompt-model" + as="div" + variant="ghost" + size="normal" + class="min-w-0 max-w-[320px] text-13-regular text-text-base group" + style={control()} + onClick={() => dialog.show(() => <DialogSelectModelUnpaid model={local.model} />)} + > + <Show when={local.model.current()?.provider?.id}> + <ProviderIcon + id={local.model.current()!.provider.id} + class="size-4 shrink-0 opacity-40 group-hover:opacity-100 transition-opacity duration-150" + style={{ "will-change": "opacity", transform: "translateZ(0)" }} + /> + </Show> + <span class="truncate"> + {local.model.current()?.name ?? language.t("dialog.model.select.title")} + </span> + <Icon name="chevron-down" size="small" class="shrink-0" /> + </Button> + </TooltipKeybind> + } + > <TooltipKeybind placement="top" gutter={4} title={language.t("command.model.choose")} keybind={command.keybind("model.choose")} > - <Button - as="div" - variant="ghost" - size="normal" - class="min-w-0 max-w-[320px] text-13-regular text-text-base group" - style={control()} - onClick={() => dialog.show(() => <DialogSelectModelUnpaid />)} + <ModelSelectorPopover + model={local.model} + triggerAs={Button} + triggerProps={{ + variant: "ghost", + size: "normal", + style: control(), + class: "min-w-0 max-w-[320px] text-13-regular text-text-base group", + "data-action": "prompt-model", + }} > <Show when={local.model.current()?.provider?.id}> <ProviderIcon @@ -1475,57 +1512,31 @@ export const PromptInput: Component<PromptInputProps> = (props) => { {local.model.current()?.name ?? language.t("dialog.model.select.title")} </span> <Icon name="chevron-down" size="small" class="shrink-0" /> - </Button> + </ModelSelectorPopover> </TooltipKeybind> - } - > + </Show> + </div> + <div data-component="prompt-variant-control"> <TooltipKeybind placement="top" gutter={4} - title={language.t("command.model.choose")} - keybind={command.keybind("model.choose")} + title={language.t("command.model.variant.cycle")} + keybind={command.keybind("model.variant.cycle")} > - <ModelSelectorPopover - triggerAs={Button} - triggerProps={{ - variant: "ghost", - size: "normal", - style: control(), - class: "min-w-0 max-w-[320px] text-13-regular text-text-base group", - }} - > - <Show when={local.model.current()?.provider?.id}> - <ProviderIcon - id={local.model.current()!.provider.id} - class="size-4 shrink-0 opacity-40 group-hover:opacity-100 transition-opacity duration-150" - style={{ "will-change": "opacity", transform: "translateZ(0)" }} - /> - </Show> - <span class="truncate"> - {local.model.current()?.name ?? language.t("dialog.model.select.title")} - </span> - <Icon name="chevron-down" size="small" class="shrink-0" /> - </ModelSelectorPopover> + <Select + size="normal" + options={variants()} + current={local.model.variant.current() ?? "default"} + label={(x) => (x === "default" ? language.t("common.default") : x)} + onSelect={(x) => local.model.variant.set(x === "default" ? undefined : x)} + class="capitalize max-w-[160px] text-text-base" + valueClass="truncate text-13-regular text-text-base" + triggerStyle={control()} + triggerProps={{ "data-action": "prompt-model-variant" }} + variant="ghost" + /> </TooltipKeybind> - </Show> - <TooltipKeybind - placement="top" - gutter={4} - title={language.t("command.model.variant.cycle")} - keybind={command.keybind("model.variant.cycle")} - > - <Select - size="normal" - options={variants()} - current={local.model.variant.current() ?? "default"} - label={(x) => (x === "default" ? language.t("common.default") : x)} - onSelect={(x) => local.model.variant.set(x === "default" ? undefined : x)} - class="capitalize max-w-[160px] text-text-base" - valueClass="truncate text-13-regular text-text-base" - triggerStyle={control()} - variant="ghost" - /> - </TooltipKeybind> + </div> <TooltipKeybind placement="top" gutter={8} diff --git a/packages/app/src/components/prompt-input/submit.test.ts b/packages/app/src/components/prompt-input/submit.test.ts index 9f7fac69d..b0166c43a 100644 --- a/packages/app/src/components/prompt-input/submit.test.ts +++ b/packages/app/src/components/prompt-input/submit.test.ts @@ -17,6 +17,7 @@ const optimistic: Array<{ }> = [] const optimisticSeeded: boolean[] = [] const storedSessions: Record<string, Array<{ id: string; title?: string }>> = {} +const promoted: Array<{ directory: string; sessionID: string }> = [] const sentShell: string[] = [] const syncedDirectories: string[] = [] @@ -86,6 +87,11 @@ beforeAll(async () => { agent: { current: () => ({ name: "agent" }), }, + session: { + promote(directory: string, sessionID: string) { + promoted.push({ directory, sessionID }) + }, + }, }), })) @@ -201,6 +207,7 @@ beforeEach(() => { enabledAutoAccept.length = 0 optimistic.length = 0 optimisticSeeded.length = 0 + promoted.length = 0 params = {} sentShell.length = 0 syncedDirectories.length = 0 @@ -240,6 +247,11 @@ describe("prompt submit worktree selection", () => { expect(createdSessions).toEqual(["/repo/worktree-a", "/repo/worktree-b"]) expect(sentShell).toEqual(["/repo/worktree-a", "/repo/worktree-b"]) expect(syncedDirectories).toEqual(["/repo/worktree-a", "/repo/worktree-a", "/repo/worktree-b", "/repo/worktree-b"]) + expect(promoted).toEqual([ + { directory: "/repo/worktree-a", sessionID: "session-1" }, + { directory: "/repo/worktree-b", sessionID: "session-2" }, + ]) + expect(syncedDirectories).toEqual(["/repo/worktree-a", "/repo/worktree-a", "/repo/worktree-b", "/repo/worktree-b"]) }) test("applies auto-accept to newly created sessions", async () => { diff --git a/packages/app/src/components/prompt-input/submit.ts b/packages/app/src/components/prompt-input/submit.ts index e8d765cd9..ba299fe36 100644 --- a/packages/app/src/components/prompt-input/submit.ts +++ b/packages/app/src/components/prompt-input/submit.ts @@ -296,6 +296,7 @@ export function createPromptSubmit(input: PromptSubmitInput) { const currentModel = local.model.current() const currentAgent = local.agent.current() + const variant = local.model.variant.current() if (!currentModel || !currentAgent) { showToast({ title: language.t("prompt.toast.modelAgentRequired.title"), @@ -370,6 +371,7 @@ export function createPromptSubmit(input: PromptSubmitInput) { seed(sessionDirectory, created) session = created if (shouldAutoAccept) permission.enableAutoAccept(session.id, sessionDirectory) + local.session.promote(sessionDirectory, session.id) layout.handoff.setTabs(base64Encode(sessionDirectory), session.id) navigate(`/${base64Encode(sessionDirectory)}/session/${session.id}`) } @@ -387,7 +389,6 @@ export function createPromptSubmit(input: PromptSubmitInput) { providerID: currentModel.provider.id, } const agent = currentAgent.name - const variant = local.model.variant.current() const context = prompt.context.items().slice() const draft: FollowupDraft = { sessionID: session.id, diff --git a/packages/app/src/context/local.tsx b/packages/app/src/context/local.tsx index 75d1334a5..bed7ecd15 100644 --- a/packages/app/src/context/local.tsx +++ b/packages/app/src/context/local.tsx @@ -1,252 +1,421 @@ -import { createStore } from "solid-js/store" -import { batch, createMemo } from "solid-js" import { createSimpleContext } from "@opencode-ai/ui/context" -import { useSDK } from "./sdk" -import { useSync } from "./sync" import { base64Encode } from "@opencode-ai/util/encode" -import { useProviders } from "@/hooks/use-providers" +import { useParams } from "@solidjs/router" +import { batch, createEffect, createMemo, onCleanup } from "solid-js" +import { createStore } from "solid-js/store" import { useModels } from "@/context/models" +import { useProviders } from "@/hooks/use-providers" +import { modelEnabled, modelProbe } from "@/testing/model-selection" +import { Persist, persisted } from "@/utils/persist" import { cycleModelVariant, getConfiguredAgentVariant, resolveModelVariant } from "./model-variant" +import { useSDK } from "./sdk" +import { useSync } from "./sync" export type ModelKey = { providerID: string; modelID: string } +type State = { + agent?: string + model?: ModelKey + variant?: string | null +} + +type Saved = { + session: Record<string, State | undefined> +} + +const WORKSPACE_KEY = "__workspace__" +const handoff = new Map<string, State>() + +const handoffKey = (dir: string, id: string) => `${dir}\n${id}` + +const migrate = (value: unknown) => { + if (!value || typeof value !== "object") return { session: {} } + + const item = value as { + session?: Record<string, State | undefined> + pick?: Record<string, State | undefined> + } + + if (item.session && typeof item.session === "object") return { session: item.session } + if (!item.pick || typeof item.pick !== "object") return { session: {} } + + return { + session: Object.fromEntries(Object.entries(item.pick).filter(([key]) => key !== WORKSPACE_KEY)), + } +} + +const clone = (value: State | undefined) => { + if (!value) return undefined + return { + ...value, + model: value.model ? { ...value.model } : undefined, + } satisfies State +} + export const { use: useLocal, provider: LocalProvider } = createSimpleContext({ name: "Local", init: () => { + const params = useParams() const sdk = useSDK() const sync = useSync() const providers = useProviders() - const connected = createMemo(() => new Set(providers.connected().map((provider) => provider.id))) + const models = useModels() + + const id = createMemo(() => params.id || undefined) + const list = createMemo(() => sync.data.agent.filter((item) => item.mode !== "subagent" && !item.hidden)) + const connected = createMemo(() => new Set(providers.connected().map((item) => item.id))) - function isModelValid(model: ModelKey) { - const provider = providers.all().find((x) => x.id === model.providerID) + const [saved, setSaved] = persisted( + { + ...Persist.workspace(sdk.directory, "model-selection", ["model-selection.v1"]), + migrate, + }, + createStore<Saved>({ + session: {}, + }), + ) + + const [store, setStore] = createStore<{ + current?: string + draft?: State + last?: { + type: "agent" | "model" | "variant" + agent?: string + model?: ModelKey | null + variant?: string | null + } + }>({ + current: list()[0]?.name, + draft: undefined, + last: undefined, + }) + + const validModel = (model: ModelKey) => { + const provider = providers.all().find((item) => item.id === model.providerID) return !!provider?.models[model.modelID] && connected().has(model.providerID) } - function getFirstValidModel(...modelFns: (() => ModelKey | undefined)[]) { - for (const modelFn of modelFns) { - const model = modelFn() + const firstModel = (...items: Array<() => ModelKey | undefined>) => { + for (const item of items) { + const model = item() if (!model) continue - if (isModelValid(model)) return model + if (validModel(model)) return model } } - let setModel: (model: ModelKey | undefined, options?: { recent?: boolean }) => void = () => undefined + const pickAgent = (name: string | undefined) => { + const items = list() + if (items.length === 0) return undefined + return items.find((item) => item.name === name) ?? items[0] + } - const agent = (() => { - const list = createMemo(() => sync.data.agent.filter((x) => x.mode !== "subagent" && !x.hidden)) - const models = useModels() + createEffect(() => { + const items = list() + if (items.length === 0) { + if (store.current !== undefined) setStore("current", undefined) + return + } + if (items.some((item) => item.name === store.current)) return + setStore("current", items[0]?.name) + }) - const [store, setStore] = createStore<{ - current?: string - }>({ - current: list()[0]?.name, - }) - return { - list, - current() { - const available = list() - if (available.length === 0) return undefined - return available.find((x) => x.name === store.current) ?? available[0] - }, - set(name: string | undefined) { - const available = list() - if (available.length === 0) { - setStore("current", undefined) - return - } - const match = name ? available.find((x) => x.name === name) : undefined - const value = match ?? available[0] - if (!value) return - setStore("current", value.name) - if (!value.model) return - setModel({ - providerID: value.model.providerID, - modelID: value.model.modelID, - }) - if (value.variant) - models.variant.set({ providerID: value.model.providerID, modelID: value.model.modelID }, value.variant) - }, - move(direction: 1 | -1) { - const available = list() - if (available.length === 0) { - setStore("current", undefined) - return - } - let next = available.findIndex((x) => x.name === store.current) + direction - if (next < 0) next = available.length - 1 - if (next >= available.length) next = 0 - const value = available[next] - if (!value) return - setStore("current", value.name) - if (!value.model) return - setModel({ - providerID: value.model.providerID, - modelID: value.model.modelID, - }) - if (value.variant) - models.variant.set({ providerID: value.model.providerID, modelID: value.model.modelID }, value.variant) - }, + const scope = createMemo<State | undefined>(() => { + const session = id() + if (!session) return store.draft + return saved.session[session] ?? handoff.get(handoffKey(sdk.directory, session)) + }) + + createEffect(() => { + const session = id() + if (!session) return + + const key = handoffKey(sdk.directory, session) + const next = handoff.get(key) + if (!next) return + if (saved.session[session] !== undefined) { + handoff.delete(key) + return } - })() - const model = (() => { - const models = useModels() + setSaved("session", session, clone(next)) + handoff.delete(key) + }) - const [ephemeral, setEphemeral] = createStore<{ - model: Record<string, ModelKey | undefined> - }>({ - model: {}, - }) + const configuredModel = () => { + if (!sync.data.config.model) return + const [providerID, modelID] = sync.data.config.model.split("/") + const model = { providerID, modelID } + if (validModel(model)) return model + } - const resolveConfigured = () => { - if (!sync.data.config.model) return - const [providerID, modelID] = sync.data.config.model.split("/") - const key = { providerID, modelID } - if (isModelValid(key)) return key + const recentModel = () => { + for (const item of models.recent.list()) { + if (validModel(item)) return item } + } - const resolveRecent = () => { - for (const item of models.recent.list()) { - if (isModelValid(item)) return item + const defaultModel = () => { + const defaults = providers.default() + for (const provider of providers.connected()) { + const configured = defaults[provider.id] + if (configured) { + const model = { providerID: provider.id, modelID: configured } + if (validModel(model)) return model } + + const first = Object.values(provider.models)[0] + if (!first) continue + const model = { providerID: provider.id, modelID: first.id } + if (validModel(model)) return model } + } - const resolveDefault = () => { - const defaults = providers.default() - for (const provider of providers.connected()) { - const configured = defaults[provider.id] - if (configured) { - const key = { providerID: provider.id, modelID: configured } - if (isModelValid(key)) return key - } + const fallback = createMemo<ModelKey | undefined>(() => configuredModel() ?? recentModel() ?? defaultModel()) - const first = Object.values(provider.models)[0] - if (!first) continue - const key = { providerID: provider.id, modelID: first.id } - if (isModelValid(key)) return key + const agent = { + list, + current() { + return pickAgent(scope()?.agent ?? store.current) + }, + set(name: string | undefined) { + const item = pickAgent(name) + if (!item) { + setStore("current", undefined) + return } - } - const fallbackModel = createMemo<ModelKey | undefined>(() => { - return resolveConfigured() ?? resolveRecent() ?? resolveDefault() - }) + batch(() => { + setStore("current", item.name) + setStore("last", { + type: "agent", + agent: item.name, + model: item.model, + variant: item.variant ?? null, + }) + const next = { + agent: item.name, + model: item.model, + variant: item.variant, + } satisfies State + const session = id() + if (session) { + setSaved("session", session, next) + return + } + setStore("draft", next) + }) + }, + move(direction: 1 | -1) { + const items = list() + if (items.length === 0) { + setStore("current", undefined) + return + } - const current = createMemo(() => { - const a = agent.current() - if (!a) return undefined - const key = getFirstValidModel( - () => ephemeral.model[a.name], - () => a.model, - fallbackModel, - ) - if (!key) return undefined - return models.find(key) - }) + let next = items.findIndex((item) => item.name === agent.current()?.name) + direction + if (next < 0) next = items.length - 1 + if (next >= items.length) next = 0 + const item = items[next] + if (!item) return + agent.set(item.name) + }, + } - const recent = createMemo(() => models.recent.list().map(models.find).filter(Boolean)) + const current = () => { + const item = firstModel( + () => scope()?.model, + () => agent.current()?.model, + fallback, + ) + if (!item) return undefined + return models.find(item) + } - const cycle = (direction: 1 | -1) => { - const recentList = recent() - const currentModel = current() - if (!currentModel) return + const configured = () => { + const item = agent.current() + const model = current() + if (!item || !model) return undefined + return getConfiguredAgentVariant({ + agent: { model: item.model, variant: item.variant }, + model: { providerID: model.provider.id, modelID: model.id, variants: model.variants }, + }) + } - const index = recentList.findIndex( - (x) => x?.provider.id === currentModel.provider.id && x?.id === currentModel.id, - ) - if (index === -1) return + const selected = () => scope()?.variant - let next = index + direction - if (next < 0) next = recentList.length - 1 - if (next >= recentList.length) next = 0 + const snapshot = () => { + const model = current() + return { + agent: agent.current()?.name, + model: model ? { providerID: model.provider.id, modelID: model.id } : undefined, + variant: selected(), + } satisfies State + } - const val = recentList[next] - if (!val) return + const write = (next: Partial<State>) => { + const state = { + ...(scope() ?? { agent: agent.current()?.name }), + ...next, + } satisfies State - model.set({ - providerID: val.provider.id, - modelID: val.id, - }) + const session = id() + if (session) { + setSaved("session", session, state) + return } + setStore("draft", state) + } - const set = (model: ModelKey | undefined, options?: { recent?: boolean }) => { - batch(() => { - const currentAgent = agent.current() - const next = model ?? fallbackModel() - if (currentAgent) setEphemeral("model", currentAgent.name, next) - if (model) models.setVisibility(model, true) - if (options?.recent && model) models.recent.push(model) - }) - } + const recent = createMemo(() => models.recent.list().map(models.find).filter(Boolean)) - setModel = set + const model = { + ready: models.ready, + current, + recent, + list: models.list, + cycle(direction: 1 | -1) { + const items = recent() + const item = current() + if (!item) return - return { - ready: models.ready, - current, - recent, - list: models.list, - cycle, - set, - visible(model: ModelKey) { - return models.visible(model) + const index = items.findIndex((entry) => entry?.provider.id === item.provider.id && entry?.id === item.id) + if (index === -1) return + + let next = index + direction + if (next < 0) next = items.length - 1 + if (next >= items.length) next = 0 + + const entry = items[next] + if (!entry) return + model.set({ providerID: entry.provider.id, modelID: entry.id }) + }, + set(item: ModelKey | undefined, options?: { recent?: boolean }) { + batch(() => { + setStore("last", { + type: "model", + agent: agent.current()?.name, + model: item ?? null, + variant: selected(), + }) + write({ model: item }) + if (!item) return + models.setVisibility(item, true) + if (!options?.recent) return + models.recent.push(item) + }) + }, + visible(item: ModelKey) { + return models.visible(item) + }, + setVisibility(item: ModelKey, visible: boolean) { + models.setVisibility(item, visible) + }, + variant: { + configured, + selected, + current() { + return resolveModelVariant({ + variants: this.list(), + selected: this.selected(), + configured: this.configured(), + }) }, - setVisibility(model: ModelKey, visible: boolean) { - models.setVisibility(model, visible) + list() { + const item = current() + if (!item?.variants) return [] + return Object.keys(item.variants) }, - variant: { - configured() { - const a = agent.current() - const m = current() - if (!a || !m) return undefined - return getConfiguredAgentVariant({ - agent: { model: a.model, variant: a.variant }, - model: { providerID: m.provider.id, modelID: m.id, variants: m.variants }, + set(value: string | undefined) { + batch(() => { + const model = current() + setStore("last", { + type: "variant", + agent: agent.current()?.name, + model: model ? { providerID: model.provider.id, modelID: model.id } : null, + variant: value ?? null, }) - }, - selected() { - const m = current() - if (!m) return undefined - return models.variant.get({ providerID: m.provider.id, modelID: m.id }) - }, - current() { - return resolveModelVariant({ - variants: this.list(), + write({ variant: value ?? null }) + }) + }, + cycle() { + const items = this.list() + if (items.length === 0) return + this.set( + cycleModelVariant({ + variants: items, selected: this.selected(), configured: this.configured(), - }) - }, - list() { - const m = current() - if (!m) return [] - if (!m.variants) return [] - return Object.keys(m.variants) - }, - set(value: string | undefined) { - const m = current() - if (!m) return - models.variant.set({ providerID: m.provider.id, modelID: m.id }, value) - }, - cycle() { - const variants = this.list() - if (variants.length === 0) return - this.set( - cycleModelVariant({ - variants, - selected: this.selected(), - configured: this.configured(), - }), - ) - }, + }), + ) }, - } - })() + }, + } const result = { slug: createMemo(() => base64Encode(sdk.directory)), model, agent, + session: { + reset() { + setStore("draft", undefined) + }, + promote(dir: string, session: string) { + const next = clone(snapshot()) + if (!next) return + + if (dir === sdk.directory) { + setSaved("session", session, next) + setStore("draft", undefined) + return + } + + handoff.set(handoffKey(dir, session), next) + setStore("draft", undefined) + }, + restore(msg: { sessionID: string; agent: string; model: ModelKey; variant?: string }) { + const session = id() + if (!session) return + if (msg.sessionID !== session) return + if (saved.session[session] !== undefined) return + if (handoff.has(handoffKey(sdk.directory, session))) return + + setSaved("session", session, { + agent: msg.agent, + model: msg.model, + variant: msg.variant ?? null, + }) + }, + }, } + + if (modelEnabled()) { + createEffect(() => { + const agent = result.agent.current() + const model = result.model.current() + modelProbe.set({ + dir: sdk.directory, + sessionID: id(), + last: store.last, + agent: agent?.name, + model: model + ? { + providerID: model.provider.id, + modelID: model.id, + name: model.name, + } + : undefined, + variant: result.model.variant.current() ?? null, + selected: result.model.variant.selected(), + configured: result.model.variant.configured(), + pick: scope(), + base: undefined, + current: store.current, + }) + }) + + onCleanup(() => modelProbe.clear()) + } + return result }, }) diff --git a/packages/app/src/context/model-variant.test.ts b/packages/app/src/context/model-variant.test.ts index 01b149fd2..583bc5c3d 100644 --- a/packages/app/src/context/model-variant.test.ts +++ b/packages/app/src/context/model-variant.test.ts @@ -44,6 +44,16 @@ describe("model variant", () => { expect(value).toBe("high") }) + test("lets an explicit default override the configured variant", () => { + const value = resolveModelVariant({ + variants: ["low", "high", "xhigh"], + selected: null, + configured: "xhigh", + }) + + expect(value).toBeUndefined() + }) + test("cycles from configured variant to next", () => { const value = cycleModelVariant({ variants: ["low", "high", "xhigh"], @@ -63,4 +73,14 @@ describe("model variant", () => { expect(value).toBe("low") }) + + test("cycles from an explicit default to the first variant", () => { + const value = cycleModelVariant({ + variants: ["low", "high", "xhigh"], + selected: null, + configured: "xhigh", + }) + + expect(value).toBe("low") + }) }) diff --git a/packages/app/src/context/model-variant.ts b/packages/app/src/context/model-variant.ts index 6b7ae7256..525acbba3 100644 --- a/packages/app/src/context/model-variant.ts +++ b/packages/app/src/context/model-variant.ts @@ -14,7 +14,7 @@ type Model = AgentModel & { type VariantInput = { variants: string[] - selected: string | undefined + selected: string | null | undefined configured: string | undefined } @@ -29,6 +29,7 @@ export function getConfiguredAgentVariant(input: { agent: Agent | undefined; mod } export function resolveModelVariant(input: VariantInput) { + if (input.selected === null) return undefined if (input.selected && input.variants.includes(input.selected)) return input.selected if (input.configured && input.variants.includes(input.configured)) return input.configured return undefined @@ -36,6 +37,7 @@ export function resolveModelVariant(input: VariantInput) { export function cycleModelVariant(input: VariantInput) { if (input.variants.length === 0) return undefined + if (input.selected === null) return input.variants[0] if (input.selected && input.variants.includes(input.selected)) { const index = input.variants.indexOf(input.selected) if (index === input.variants.length - 1) return undefined diff --git a/packages/app/src/pages/directory-layout.tsx b/packages/app/src/pages/directory-layout.tsx index fdf321f2d..f993ffcd8 100644 --- a/packages/app/src/pages/directory-layout.tsx +++ b/packages/app/src/pages/directory-layout.tsx @@ -80,11 +80,11 @@ export default function Layout(props: ParentProps) { }) return ( - <Show when={state.resolved}> + <Show when={state.resolved} keyed> {(resolved) => ( - <SDKProvider directory={resolved}> + <SDKProvider directory={() => resolved}> <SyncProvider> - <DirectoryDataProvider directory={resolved()}>{props.children}</DirectoryDataProvider> + <DirectoryDataProvider directory={resolved}>{props.children}</DirectoryDataProvider> </SyncProvider> </SDKProvider> )} diff --git a/packages/app/src/pages/session.tsx b/packages/app/src/pages/session.tsx index 8399a1367..6d2917008 100644 --- a/packages/app/src/pages/session.tsx +++ b/packages/app/src/pages/session.tsx @@ -44,7 +44,7 @@ import { createOpenReviewFile, createSessionTabs, createSizing, focusTerminalByI import { MessageTimeline } from "@/pages/session/message-timeline" import { type DiffStyle, SessionReviewTab, type SessionReviewTabProps } from "@/pages/session/review-tab" import { useSessionLayout } from "@/pages/session/session-layout" -import { resetSessionModel, syncSessionModel } from "@/pages/session/session-model-helpers" +import { syncSessionModel } from "@/pages/session/session-model-helpers" import { SessionSidePanel } from "@/pages/session/session-side-panel" import { TerminalPanel } from "@/pages/session/terminal-panel" import { useSessionCommands } from "@/pages/session/use-session-commands" @@ -490,7 +490,7 @@ export default function Page() { (next, prev) => { if (!prev) return if (next.dir === prev.dir && next.id === prev.id) return - if (!next.id) resetSessionModel(local) + if (prev.id && !next.id) local.session.reset() }, { defer: true }, ), diff --git a/packages/app/src/pages/session/session-model-helpers.test.ts b/packages/app/src/pages/session/session-model-helpers.test.ts index 5f554dcd3..319db805d 100644 --- a/packages/app/src/pages/session/session-model-helpers.test.ts +++ b/packages/app/src/pages/session/session-model-helpers.test.ts @@ -14,145 +14,38 @@ const message = (input?: Partial<Pick<UserMessage, "agent" | "model" | "variant" }) as UserMessage describe("syncSessionModel", () => { - test("restores the last message model and variant", () => { + test("restores the last message through session state", () => { const calls: unknown[] = [] syncSessionModel( { - agent: { - current() { - return undefined - }, - set(value) { - calls.push(["agent", value]) - }, - }, - model: { - set(value) { - calls.push(["model", value]) - }, - current() { - return { id: "claude-sonnet-4", provider: { id: "anthropic" } } - }, - variant: { - set(value) { - calls.push(["variant", value]) - }, - }, - }, - }, - message({ variant: "high" }), - ) - - expect(calls).toEqual([ - ["agent", "build"], - ["model", { providerID: "anthropic", modelID: "claude-sonnet-4" }], - ["variant", "high"], - ]) - }) - - test("skips variant when the model falls back", () => { - const calls: unknown[] = [] - - syncSessionModel( - { - agent: { - current() { - return undefined - }, - set(value) { - calls.push(["agent", value]) - }, - }, - model: { - set(value) { - calls.push(["model", value]) - }, - current() { - return { id: "gpt-5", provider: { id: "openai" } } - }, - variant: { - set(value) { - calls.push(["variant", value]) - }, + session: { + restore(value) { + calls.push(value) }, + reset() {}, }, }, message({ variant: "high" }), ) - expect(calls).toEqual([ - ["agent", "build"], - ["model", { providerID: "anthropic", modelID: "claude-sonnet-4" }], - ]) + expect(calls).toEqual([message({ variant: "high" })]) }) }) describe("resetSessionModel", () => { - test("restores the current agent defaults", () => { - const calls: unknown[] = [] + test("clears draft session state", () => { + const calls: string[] = [] resetSessionModel({ - agent: { - current() { - return { - model: { providerID: "anthropic", modelID: "claude-sonnet-4" }, - variant: "high", - } - }, - set() {}, - }, - model: { - set(value) { - calls.push(["model", value]) - }, - current() { - return undefined - }, - variant: { - set(value) { - calls.push(["variant", value]) - }, - }, - }, - }) - - expect(calls).toEqual([ - ["model", { providerID: "anthropic", modelID: "claude-sonnet-4" }], - ["variant", "high"], - ]) - }) - - test("clears the variant when the agent has none", () => { - const calls: unknown[] = [] - - resetSessionModel({ - agent: { - current() { - return { - model: { providerID: "anthropic", modelID: "claude-sonnet-4" }, - } - }, - set() {}, - }, - model: { - set(value) { - calls.push(["model", value]) - }, - current() { - return undefined - }, - variant: { - set(value) { - calls.push(["variant", value]) - }, + session: { + reset() { + calls.push("reset") }, + restore() {}, }, }) - expect(calls).toEqual([ - ["model", { providerID: "anthropic", modelID: "claude-sonnet-4" }], - ["variant", undefined], - ]) + expect(calls).toEqual(["reset"]) }) }) diff --git a/packages/app/src/pages/session/session-model-helpers.ts b/packages/app/src/pages/session/session-model-helpers.ts index 7600f16d5..c9e2e1dbd 100644 --- a/packages/app/src/pages/session/session-model-helpers.ts +++ b/packages/app/src/pages/session/session-model-helpers.ts @@ -1,48 +1,16 @@ import type { UserMessage } from "@opencode-ai/sdk/v2" -import { batch } from "solid-js" type Local = { - agent: { - current(): - | { - model?: UserMessage["model"] - variant?: string - } - | undefined - set(name: string | undefined): void - } - model: { - set(model: UserMessage["model"] | undefined): void - current(): - | { - id: string - provider: { id: string } - } - | undefined - variant: { - set(value: string | undefined): void - } + session: { + reset(): void + restore(msg: UserMessage): void } } export const resetSessionModel = (local: Local) => { - const agent = local.agent.current() - if (!agent) return - batch(() => { - local.model.set(agent.model) - local.model.variant.set(agent.variant) - }) + local.session.reset() } export const syncSessionModel = (local: Local, msg: UserMessage) => { - batch(() => { - local.agent.set(msg.agent) - local.model.set(msg.model) - }) - - const model = local.model.current() - if (!model) return - if (model.provider.id !== msg.model.providerID) return - if (model.id !== msg.model.modelID) return - local.model.variant.set(msg.variant) + local.session.restore(msg) } diff --git a/packages/app/src/pages/session/use-session-commands.tsx b/packages/app/src/pages/session/use-session-commands.tsx index f5a4c0576..1a2e777f5 100644 --- a/packages/app/src/pages/session/use-session-commands.tsx +++ b/packages/app/src/pages/session/use-session-commands.tsx @@ -351,7 +351,7 @@ export const useSessionCommands = (actions: SessionCommandContext) => { description: language.t("command.model.choose.description"), keybind: "mod+'", slash: "model", - onSelect: () => dialog.show(() => <DialogSelectModel />), + onSelect: () => dialog.show(() => <DialogSelectModel model={local.model} />), }), mcpCommand({ id: "mcp.toggle", diff --git a/packages/app/src/testing/model-selection.ts b/packages/app/src/testing/model-selection.ts new file mode 100644 index 000000000..a5ea199ac --- /dev/null +++ b/packages/app/src/testing/model-selection.ts @@ -0,0 +1,80 @@ +type ModelKey = { + providerID: string + modelID: string +} + +type State = { + agent?: string + model?: ModelKey | null + variant?: string | null +} + +export type ModelProbeState = { + dir?: string + sessionID?: string + last?: { + type: "agent" | "model" | "variant" + agent?: string + model?: ModelKey | null + variant?: string | null + } + agent?: string + model?: (ModelKey & { name?: string }) | undefined + variant?: string | null + selected?: string | null + configured?: string + pick?: State + base?: State + current?: string +} + +export type ModelWindow = Window & { + __opencode_e2e?: { + model?: { + enabled?: boolean + current?: ModelProbeState + } + } +} + +const clone = (state?: State) => { + if (!state) return undefined + return { + ...state, + model: state.model ? { ...state.model } : state.model, + } +} + +export const modelEnabled = () => { + if (typeof window === "undefined") return false + return (window as ModelWindow).__opencode_e2e?.model?.enabled === true +} + +const root = () => { + if (!modelEnabled()) return + return (window as ModelWindow).__opencode_e2e?.model +} + +export const modelProbe = { + set(input: ModelProbeState) { + const state = root() + if (!state) return + state.current = { + ...input, + model: input.model ? { ...input.model } : undefined, + last: input.last + ? { + ...input.last, + model: input.last.model ? { ...input.last.model } : input.last.model, + } + : undefined, + pick: clone(input.pick), + base: clone(input.base), + } + }, + clear() { + const state = root() + if (!state) return + state.current = undefined + }, +} diff --git a/packages/app/src/testing/terminal.ts b/packages/app/src/testing/terminal.ts index 4c179dee3..af1c33309 100644 --- a/packages/app/src/testing/terminal.ts +++ b/packages/app/src/testing/terminal.ts @@ -1,3 +1,5 @@ +import type { ModelProbeState } from "./model-selection" + export const terminalAttr = "data-pty-id" export type TerminalProbeState = { @@ -13,6 +15,10 @@ type TerminalProbeControl = { export type E2EWindow = Window & { __opencode_e2e?: { + model?: { + enabled?: boolean + current?: ModelProbeState + } terminal?: { enabled?: boolean terminals?: Record<string, TerminalProbeState> |
