diff options
| author | Adam <[email protected]> | 2026-03-13 11:05:08 -0500 |
|---|---|---|
| committer | GitHub <[email protected]> | 2026-03-13 11:05:08 -0500 |
| commit | 4ad8116ce37a0e77e7f3c0e9e4e1002bba05b15e (patch) | |
| tree | b7e5ed2b05aabb5ed5134520c4eb485c52eb5333 /packages/app/src/context | |
| parent | 5c7088338c07ad632834ebd4a87feb23d255fb8a (diff) | |
| download | opencode-4ad8116ce37a0e77e7f3c0e9e4e1002bba05b15e.tar.gz opencode-4ad8116ce37a0e77e7f3c0e9e4e1002bba05b15e.zip | |
fix(app): model selection persist by session (#17348)
Diffstat (limited to 'packages/app/src/context')
| -rw-r--r-- | packages/app/src/context/local.tsx | 551 | ||||
| -rw-r--r-- | packages/app/src/context/model-variant.test.ts | 20 | ||||
| -rw-r--r-- | packages/app/src/context/model-variant.ts | 4 |
3 files changed, 383 insertions, 192 deletions
diff --git a/packages/app/src/context/local.tsx b/packages/app/src/context/local.tsx index 75d1334a5..bed7ecd15 100644 --- a/packages/app/src/context/local.tsx +++ b/packages/app/src/context/local.tsx @@ -1,252 +1,421 @@ -import { createStore } from "solid-js/store" -import { batch, createMemo } from "solid-js" import { createSimpleContext } from "@opencode-ai/ui/context" -import { useSDK } from "./sdk" -import { useSync } from "./sync" import { base64Encode } from "@opencode-ai/util/encode" -import { useProviders } from "@/hooks/use-providers" +import { useParams } from "@solidjs/router" +import { batch, createEffect, createMemo, onCleanup } from "solid-js" +import { createStore } from "solid-js/store" import { useModels } from "@/context/models" +import { useProviders } from "@/hooks/use-providers" +import { modelEnabled, modelProbe } from "@/testing/model-selection" +import { Persist, persisted } from "@/utils/persist" import { cycleModelVariant, getConfiguredAgentVariant, resolveModelVariant } from "./model-variant" +import { useSDK } from "./sdk" +import { useSync } from "./sync" export type ModelKey = { providerID: string; modelID: string } +type State = { + agent?: string + model?: ModelKey + variant?: string | null +} + +type Saved = { + session: Record<string, State | undefined> +} + +const WORKSPACE_KEY = "__workspace__" +const handoff = new Map<string, State>() + +const handoffKey = (dir: string, id: string) => `${dir}\n${id}` + +const migrate = (value: unknown) => { + if (!value || typeof value !== "object") return { session: {} } + + const item = value as { + session?: Record<string, State | undefined> + pick?: Record<string, State | undefined> + } + + if (item.session && typeof item.session === "object") return { session: item.session } + if (!item.pick || typeof item.pick !== "object") return { session: {} } + + return { + session: Object.fromEntries(Object.entries(item.pick).filter(([key]) => key !== WORKSPACE_KEY)), + } +} + +const clone = (value: State | undefined) => { + if (!value) return undefined + return { + ...value, + model: value.model ? { ...value.model } : undefined, + } satisfies State +} + export const { use: useLocal, provider: LocalProvider } = createSimpleContext({ name: "Local", init: () => { + const params = useParams() const sdk = useSDK() const sync = useSync() const providers = useProviders() - const connected = createMemo(() => new Set(providers.connected().map((provider) => provider.id))) + const models = useModels() + + const id = createMemo(() => params.id || undefined) + const list = createMemo(() => sync.data.agent.filter((item) => item.mode !== "subagent" && !item.hidden)) + const connected = createMemo(() => new Set(providers.connected().map((item) => item.id))) - function isModelValid(model: ModelKey) { - const provider = providers.all().find((x) => x.id === model.providerID) + const [saved, setSaved] = persisted( + { + ...Persist.workspace(sdk.directory, "model-selection", ["model-selection.v1"]), + migrate, + }, + createStore<Saved>({ + session: {}, + }), + ) + + const [store, setStore] = createStore<{ + current?: string + draft?: State + last?: { + type: "agent" | "model" | "variant" + agent?: string + model?: ModelKey | null + variant?: string | null + } + }>({ + current: list()[0]?.name, + draft: undefined, + last: undefined, + }) + + const validModel = (model: ModelKey) => { + const provider = providers.all().find((item) => item.id === model.providerID) return !!provider?.models[model.modelID] && connected().has(model.providerID) } - function getFirstValidModel(...modelFns: (() => ModelKey | undefined)[]) { - for (const modelFn of modelFns) { - const model = modelFn() + const firstModel = (...items: Array<() => ModelKey | undefined>) => { + for (const item of items) { + const model = item() if (!model) continue - if (isModelValid(model)) return model + if (validModel(model)) return model } } - let setModel: (model: ModelKey | undefined, options?: { recent?: boolean }) => void = () => undefined + const pickAgent = (name: string | undefined) => { + const items = list() + if (items.length === 0) return undefined + return items.find((item) => item.name === name) ?? items[0] + } - const agent = (() => { - const list = createMemo(() => sync.data.agent.filter((x) => x.mode !== "subagent" && !x.hidden)) - const models = useModels() + createEffect(() => { + const items = list() + if (items.length === 0) { + if (store.current !== undefined) setStore("current", undefined) + return + } + if (items.some((item) => item.name === store.current)) return + setStore("current", items[0]?.name) + }) - const [store, setStore] = createStore<{ - current?: string - }>({ - current: list()[0]?.name, - }) - return { - list, - current() { - const available = list() - if (available.length === 0) return undefined - return available.find((x) => x.name === store.current) ?? available[0] - }, - set(name: string | undefined) { - const available = list() - if (available.length === 0) { - setStore("current", undefined) - return - } - const match = name ? available.find((x) => x.name === name) : undefined - const value = match ?? available[0] - if (!value) return - setStore("current", value.name) - if (!value.model) return - setModel({ - providerID: value.model.providerID, - modelID: value.model.modelID, - }) - if (value.variant) - models.variant.set({ providerID: value.model.providerID, modelID: value.model.modelID }, value.variant) - }, - move(direction: 1 | -1) { - const available = list() - if (available.length === 0) { - setStore("current", undefined) - return - } - let next = available.findIndex((x) => x.name === store.current) + direction - if (next < 0) next = available.length - 1 - if (next >= available.length) next = 0 - const value = available[next] - if (!value) return - setStore("current", value.name) - if (!value.model) return - setModel({ - providerID: value.model.providerID, - modelID: value.model.modelID, - }) - if (value.variant) - models.variant.set({ providerID: value.model.providerID, modelID: value.model.modelID }, value.variant) - }, + const scope = createMemo<State | undefined>(() => { + const session = id() + if (!session) return store.draft + return saved.session[session] ?? handoff.get(handoffKey(sdk.directory, session)) + }) + + createEffect(() => { + const session = id() + if (!session) return + + const key = handoffKey(sdk.directory, session) + const next = handoff.get(key) + if (!next) return + if (saved.session[session] !== undefined) { + handoff.delete(key) + return } - })() - const model = (() => { - const models = useModels() + setSaved("session", session, clone(next)) + handoff.delete(key) + }) - const [ephemeral, setEphemeral] = createStore<{ - model: Record<string, ModelKey | undefined> - }>({ - model: {}, - }) + const configuredModel = () => { + if (!sync.data.config.model) return + const [providerID, modelID] = sync.data.config.model.split("/") + const model = { providerID, modelID } + if (validModel(model)) return model + } - const resolveConfigured = () => { - if (!sync.data.config.model) return - const [providerID, modelID] = sync.data.config.model.split("/") - const key = { providerID, modelID } - if (isModelValid(key)) return key + const recentModel = () => { + for (const item of models.recent.list()) { + if (validModel(item)) return item } + } - const resolveRecent = () => { - for (const item of models.recent.list()) { - if (isModelValid(item)) return item + const defaultModel = () => { + const defaults = providers.default() + for (const provider of providers.connected()) { + const configured = defaults[provider.id] + if (configured) { + const model = { providerID: provider.id, modelID: configured } + if (validModel(model)) return model } + + const first = Object.values(provider.models)[0] + if (!first) continue + const model = { providerID: provider.id, modelID: first.id } + if (validModel(model)) return model } + } - const resolveDefault = () => { - const defaults = providers.default() - for (const provider of providers.connected()) { - const configured = defaults[provider.id] - if (configured) { - const key = { providerID: provider.id, modelID: configured } - if (isModelValid(key)) return key - } + const fallback = createMemo<ModelKey | undefined>(() => configuredModel() ?? recentModel() ?? defaultModel()) - const first = Object.values(provider.models)[0] - if (!first) continue - const key = { providerID: provider.id, modelID: first.id } - if (isModelValid(key)) return key + const agent = { + list, + current() { + return pickAgent(scope()?.agent ?? store.current) + }, + set(name: string | undefined) { + const item = pickAgent(name) + if (!item) { + setStore("current", undefined) + return } - } - const fallbackModel = createMemo<ModelKey | undefined>(() => { - return resolveConfigured() ?? resolveRecent() ?? resolveDefault() - }) + batch(() => { + setStore("current", item.name) + setStore("last", { + type: "agent", + agent: item.name, + model: item.model, + variant: item.variant ?? null, + }) + const next = { + agent: item.name, + model: item.model, + variant: item.variant, + } satisfies State + const session = id() + if (session) { + setSaved("session", session, next) + return + } + setStore("draft", next) + }) + }, + move(direction: 1 | -1) { + const items = list() + if (items.length === 0) { + setStore("current", undefined) + return + } - const current = createMemo(() => { - const a = agent.current() - if (!a) return undefined - const key = getFirstValidModel( - () => ephemeral.model[a.name], - () => a.model, - fallbackModel, - ) - if (!key) return undefined - return models.find(key) - }) + let next = items.findIndex((item) => item.name === agent.current()?.name) + direction + if (next < 0) next = items.length - 1 + if (next >= items.length) next = 0 + const item = items[next] + if (!item) return + agent.set(item.name) + }, + } - const recent = createMemo(() => models.recent.list().map(models.find).filter(Boolean)) + const current = () => { + const item = firstModel( + () => scope()?.model, + () => agent.current()?.model, + fallback, + ) + if (!item) return undefined + return models.find(item) + } - const cycle = (direction: 1 | -1) => { - const recentList = recent() - const currentModel = current() - if (!currentModel) return + const configured = () => { + const item = agent.current() + const model = current() + if (!item || !model) return undefined + return getConfiguredAgentVariant({ + agent: { model: item.model, variant: item.variant }, + model: { providerID: model.provider.id, modelID: model.id, variants: model.variants }, + }) + } - const index = recentList.findIndex( - (x) => x?.provider.id === currentModel.provider.id && x?.id === currentModel.id, - ) - if (index === -1) return + const selected = () => scope()?.variant - let next = index + direction - if (next < 0) next = recentList.length - 1 - if (next >= recentList.length) next = 0 + const snapshot = () => { + const model = current() + return { + agent: agent.current()?.name, + model: model ? { providerID: model.provider.id, modelID: model.id } : undefined, + variant: selected(), + } satisfies State + } - const val = recentList[next] - if (!val) return + const write = (next: Partial<State>) => { + const state = { + ...(scope() ?? { agent: agent.current()?.name }), + ...next, + } satisfies State - model.set({ - providerID: val.provider.id, - modelID: val.id, - }) + const session = id() + if (session) { + setSaved("session", session, state) + return } + setStore("draft", state) + } - const set = (model: ModelKey | undefined, options?: { recent?: boolean }) => { - batch(() => { - const currentAgent = agent.current() - const next = model ?? fallbackModel() - if (currentAgent) setEphemeral("model", currentAgent.name, next) - if (model) models.setVisibility(model, true) - if (options?.recent && model) models.recent.push(model) - }) - } + const recent = createMemo(() => models.recent.list().map(models.find).filter(Boolean)) - setModel = set + const model = { + ready: models.ready, + current, + recent, + list: models.list, + cycle(direction: 1 | -1) { + const items = recent() + const item = current() + if (!item) return - return { - ready: models.ready, - current, - recent, - list: models.list, - cycle, - set, - visible(model: ModelKey) { - return models.visible(model) + const index = items.findIndex((entry) => entry?.provider.id === item.provider.id && entry?.id === item.id) + if (index === -1) return + + let next = index + direction + if (next < 0) next = items.length - 1 + if (next >= items.length) next = 0 + + const entry = items[next] + if (!entry) return + model.set({ providerID: entry.provider.id, modelID: entry.id }) + }, + set(item: ModelKey | undefined, options?: { recent?: boolean }) { + batch(() => { + setStore("last", { + type: "model", + agent: agent.current()?.name, + model: item ?? null, + variant: selected(), + }) + write({ model: item }) + if (!item) return + models.setVisibility(item, true) + if (!options?.recent) return + models.recent.push(item) + }) + }, + visible(item: ModelKey) { + return models.visible(item) + }, + setVisibility(item: ModelKey, visible: boolean) { + models.setVisibility(item, visible) + }, + variant: { + configured, + selected, + current() { + return resolveModelVariant({ + variants: this.list(), + selected: this.selected(), + configured: this.configured(), + }) }, - setVisibility(model: ModelKey, visible: boolean) { - models.setVisibility(model, visible) + list() { + const item = current() + if (!item?.variants) return [] + return Object.keys(item.variants) }, - variant: { - configured() { - const a = agent.current() - const m = current() - if (!a || !m) return undefined - return getConfiguredAgentVariant({ - agent: { model: a.model, variant: a.variant }, - model: { providerID: m.provider.id, modelID: m.id, variants: m.variants }, + set(value: string | undefined) { + batch(() => { + const model = current() + setStore("last", { + type: "variant", + agent: agent.current()?.name, + model: model ? { providerID: model.provider.id, modelID: model.id } : null, + variant: value ?? null, }) - }, - selected() { - const m = current() - if (!m) return undefined - return models.variant.get({ providerID: m.provider.id, modelID: m.id }) - }, - current() { - return resolveModelVariant({ - variants: this.list(), + write({ variant: value ?? null }) + }) + }, + cycle() { + const items = this.list() + if (items.length === 0) return + this.set( + cycleModelVariant({ + variants: items, selected: this.selected(), configured: this.configured(), - }) - }, - list() { - const m = current() - if (!m) return [] - if (!m.variants) return [] - return Object.keys(m.variants) - }, - set(value: string | undefined) { - const m = current() - if (!m) return - models.variant.set({ providerID: m.provider.id, modelID: m.id }, value) - }, - cycle() { - const variants = this.list() - if (variants.length === 0) return - this.set( - cycleModelVariant({ - variants, - selected: this.selected(), - configured: this.configured(), - }), - ) - }, + }), + ) }, - } - })() + }, + } const result = { slug: createMemo(() => base64Encode(sdk.directory)), model, agent, + session: { + reset() { + setStore("draft", undefined) + }, + promote(dir: string, session: string) { + const next = clone(snapshot()) + if (!next) return + + if (dir === sdk.directory) { + setSaved("session", session, next) + setStore("draft", undefined) + return + } + + handoff.set(handoffKey(dir, session), next) + setStore("draft", undefined) + }, + restore(msg: { sessionID: string; agent: string; model: ModelKey; variant?: string }) { + const session = id() + if (!session) return + if (msg.sessionID !== session) return + if (saved.session[session] !== undefined) return + if (handoff.has(handoffKey(sdk.directory, session))) return + + setSaved("session", session, { + agent: msg.agent, + model: msg.model, + variant: msg.variant ?? null, + }) + }, + }, } + + if (modelEnabled()) { + createEffect(() => { + const agent = result.agent.current() + const model = result.model.current() + modelProbe.set({ + dir: sdk.directory, + sessionID: id(), + last: store.last, + agent: agent?.name, + model: model + ? { + providerID: model.provider.id, + modelID: model.id, + name: model.name, + } + : undefined, + variant: result.model.variant.current() ?? null, + selected: result.model.variant.selected(), + configured: result.model.variant.configured(), + pick: scope(), + base: undefined, + current: store.current, + }) + }) + + onCleanup(() => modelProbe.clear()) + } + return result }, }) diff --git a/packages/app/src/context/model-variant.test.ts b/packages/app/src/context/model-variant.test.ts index 01b149fd2..583bc5c3d 100644 --- a/packages/app/src/context/model-variant.test.ts +++ b/packages/app/src/context/model-variant.test.ts @@ -44,6 +44,16 @@ describe("model variant", () => { expect(value).toBe("high") }) + test("lets an explicit default override the configured variant", () => { + const value = resolveModelVariant({ + variants: ["low", "high", "xhigh"], + selected: null, + configured: "xhigh", + }) + + expect(value).toBeUndefined() + }) + test("cycles from configured variant to next", () => { const value = cycleModelVariant({ variants: ["low", "high", "xhigh"], @@ -63,4 +73,14 @@ describe("model variant", () => { expect(value).toBe("low") }) + + test("cycles from an explicit default to the first variant", () => { + const value = cycleModelVariant({ + variants: ["low", "high", "xhigh"], + selected: null, + configured: "xhigh", + }) + + expect(value).toBe("low") + }) }) diff --git a/packages/app/src/context/model-variant.ts b/packages/app/src/context/model-variant.ts index 6b7ae7256..525acbba3 100644 --- a/packages/app/src/context/model-variant.ts +++ b/packages/app/src/context/model-variant.ts @@ -14,7 +14,7 @@ type Model = AgentModel & { type VariantInput = { variants: string[] - selected: string | undefined + selected: string | null | undefined configured: string | undefined } @@ -29,6 +29,7 @@ export function getConfiguredAgentVariant(input: { agent: Agent | undefined; mod } export function resolveModelVariant(input: VariantInput) { + if (input.selected === null) return undefined if (input.selected && input.variants.includes(input.selected)) return input.selected if (input.configured && input.variants.includes(input.configured)) return input.configured return undefined @@ -36,6 +37,7 @@ export function resolveModelVariant(input: VariantInput) { export function cycleModelVariant(input: VariantInput) { if (input.variants.length === 0) return undefined + if (input.selected === null) return input.variants[0] if (input.selected && input.variants.includes(input.selected)) { const index = input.variants.indexOf(input.selected) if (index === input.variants.length - 1) return undefined |
