summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDominik Engelhardt <[email protected]>2025-08-02 16:29:03 +0200
committerGitHub <[email protected]>2025-08-02 09:29:03 -0500
commit42a5fcead42c01c932d00d59647844e29c137ac0 (patch)
treebb27f84b9a591c9cd18e7fa410af43dcdf50dd27
parent8ad83f71a9f2e02e705301a72c3d0c39c3d9055d (diff)
downloadopencode-42a5fcead42c01c932d00d59647844e29c137ac0.tar.gz
opencode-42a5fcead42c01c932d00d59647844e29c137ac0.zip
Choose model according to the docs (#1536)
-rw-r--r--packages/tui/internal/app/app.go161
-rw-r--r--packages/tui/internal/app/app_test.go228
-rw-r--r--packages/web/src/content/docs/docs/models.mdx10
3 files changed, 343 insertions, 56 deletions
diff --git a/packages/tui/internal/app/app.go b/packages/tui/internal/app/app.go
index 7ef31fd54..a0e68b53a 100644
--- a/packages/tui/internal/app/app.go
+++ b/packages/tui/internal/app/app.go
@@ -270,37 +270,58 @@ func (a *App) SwitchModeReverse() (*App, tea.Cmd) {
return a.cycleMode(false)
}
-func (a *App) InitializeProvider() tea.Cmd {
- providersResponse, err := a.Client.App.Providers(context.Background())
- if err != nil {
- slog.Error("Failed to list providers", "error", err)
- // TODO: notify user
- return nil
+// findModelByFullID finds a model by its full ID in the format "provider/model"
+func findModelByFullID(providers []opencode.Provider, fullModelID string) (*opencode.Provider, *opencode.Model) {
+ modelParts := strings.SplitN(fullModelID, "/", 2)
+ if len(modelParts) < 2 {
+ return nil, nil
}
- providers := providersResponse.Providers
- var defaultProvider *opencode.Provider
- var defaultModel *opencode.Model
- var anthropic *opencode.Provider
+ providerID := modelParts[0]
+ modelID := modelParts[1]
+
+ return findModelByProviderAndModelID(providers, providerID, modelID)
+}
+
+// findModelByProviderAndModelID finds a model by provider ID and model ID
+func findModelByProviderAndModelID(providers []opencode.Provider, providerID, modelID string) (*opencode.Provider, *opencode.Model) {
for _, provider := range providers {
- if provider.ID == "anthropic" {
- anthropic = &provider
+ if provider.ID != providerID {
+ continue
}
- }
- // default to anthropic if available
- if anthropic != nil {
- defaultProvider = anthropic
- defaultModel = getDefaultModel(providersResponse, *anthropic)
+ for _, model := range provider.Models {
+ if model.ID == modelID {
+ return &provider, &model
+ }
+ }
+
+ // Provider found but model not found
+ return nil, nil
}
+ // Provider not found
+ return nil, nil
+}
+
+// findProviderByID finds a provider by its ID
+func findProviderByID(providers []opencode.Provider, providerID string) *opencode.Provider {
for _, provider := range providers {
- if defaultProvider == nil || defaultModel == nil {
- defaultProvider = &provider
- defaultModel = getDefaultModel(providersResponse, provider)
+ if provider.ID == providerID {
+ return &provider
}
- providers = append(providers, provider)
}
+ return nil
+}
+
+func (a *App) InitializeProvider() tea.Cmd {
+ providersResponse, err := a.Client.App.Providers(context.Background())
+ if err != nil {
+ slog.Error("Failed to list providers", "error", err)
+ // TODO: notify user
+ return nil
+ }
+ providers := providersResponse.Providers
if len(providers) == 0 {
slog.Error("No providers configured")
return nil
@@ -314,50 +335,86 @@ func (a *App) InitializeProvider() tea.Cmd {
a.State.Model = model.ModelID
}
- var currentProvider *opencode.Provider
- var currentModel *opencode.Model
- for _, provider := range providers {
- if provider.ID == a.State.Provider {
- currentProvider = &provider
+ var selectedProvider *opencode.Provider
+ var selectedModel *opencode.Model
- for _, model := range provider.Models {
- if model.ID == a.State.Model {
- currentModel = &model
- }
- }
+ // Priority 1: Command line --model flag (InitialModel)
+ if a.InitialModel != nil && *a.InitialModel != "" {
+ if provider, model := findModelByFullID(providers, *a.InitialModel); provider != nil && model != nil {
+ selectedProvider = provider
+ selectedModel = model
+ slog.Debug("Selected model from command line", "provider", provider.ID, "model", model.ID)
+ } else {
+ slog.Debug("Command line model not found", "model", *a.InitialModel)
}
}
- if currentProvider == nil || currentModel == nil {
- currentProvider = defaultProvider
- currentModel = defaultModel
+
+ // Priority 2: Config file model setting
+ if selectedProvider == nil && a.Config.Model != "" {
+ if provider, model := findModelByFullID(providers, a.Config.Model); provider != nil && model != nil {
+ selectedProvider = provider
+ selectedModel = model
+ slog.Debug("Selected model from config", "provider", provider.ID, "model", model.ID)
+ } else {
+ slog.Debug("Config model not found", "model", a.Config.Model)
+ }
}
- var initialProvider *opencode.Provider
- var initialModel *opencode.Model
- if a.InitialModel != nil && *a.InitialModel != "" {
- splits := strings.Split(*a.InitialModel, "/")
- for _, provider := range providers {
- if provider.ID == splits[0] {
- initialProvider = &provider
- for _, model := range provider.Models {
- modelID := strings.Join(splits[1:], "/")
- if model.ID == modelID {
- initialModel = &model
- }
- }
+ // Priority 3: Recent model usage (most recently used model)
+ if selectedProvider == nil && len(a.State.RecentlyUsedModels) > 0 {
+ recentUsage := a.State.RecentlyUsedModels[0] // Most recent is first
+ if provider, model := findModelByProviderAndModelID(providers, recentUsage.ProviderID, recentUsage.ModelID); provider != nil && model != nil {
+ selectedProvider = provider
+ selectedModel = model
+ slog.Debug("Selected model from recent usage", "provider", provider.ID, "model", model.ID)
+ } else {
+ slog.Debug("Recent model not found", "provider", recentUsage.ProviderID, "model", recentUsage.ModelID)
+ }
+ }
+
+ // Priority 4: State-based model (backwards compatibility)
+ if selectedProvider == nil && a.State.Provider != "" && a.State.Model != "" {
+ if provider, model := findModelByProviderAndModelID(providers, a.State.Provider, a.State.Model); provider != nil && model != nil {
+ selectedProvider = provider
+ selectedModel = model
+ slog.Debug("Selected model from state", "provider", provider.ID, "model", model.ID)
+ } else {
+ slog.Debug("State model not found", "provider", a.State.Provider, "model", a.State.Model)
+ }
+ }
+
+ // Priority 5: Internal priority fallback (Anthropic preferred, then first available)
+ if selectedProvider == nil {
+ // Try Anthropic first as internal priority
+ if provider := findProviderByID(providers, "anthropic"); provider != nil {
+ if model := getDefaultModel(providersResponse, *provider); model != nil {
+ selectedProvider = provider
+ selectedModel = model
+ slog.Debug("Selected model from internal priority (Anthropic)", "provider", provider.ID, "model", model.ID)
+ }
+ }
+
+ // If Anthropic not available, use first available provider
+ if selectedProvider == nil && len(providers) > 0 {
+ provider := &providers[0]
+ if model := getDefaultModel(providersResponse, *provider); model != nil {
+ selectedProvider = provider
+ selectedModel = model
+ slog.Debug("Selected model from fallback (first available)", "provider", provider.ID, "model", model.ID)
}
}
}
- if initialProvider != nil && initialModel != nil {
- currentProvider = initialProvider
- currentModel = initialModel
+ // Final safety check
+ if selectedProvider == nil || selectedModel == nil {
+ slog.Error("Failed to select any model")
+ return nil
}
var cmds []tea.Cmd
cmds = append(cmds, util.CmdHandler(ModelSelectedMsg{
- Provider: *currentProvider,
- Model: *currentModel,
+ Provider: *selectedProvider,
+ Model: *selectedModel,
}))
if a.InitialPrompt != nil && *a.InitialPrompt != "" {
cmds = append(cmds, util.CmdHandler(SendPrompt{Text: *a.InitialPrompt}))
diff --git a/packages/tui/internal/app/app_test.go b/packages/tui/internal/app/app_test.go
new file mode 100644
index 000000000..9260a9915
--- /dev/null
+++ b/packages/tui/internal/app/app_test.go
@@ -0,0 +1,228 @@
+package app
+
+import (
+ "testing"
+
+ "github.com/sst/opencode-sdk-go"
+)
+
+// TestFindModelByFullID tests the findModelByFullID function
+func TestFindModelByFullID(t *testing.T) {
+ // Create test providers with models
+ providers := []opencode.Provider{
+ {
+ ID: "anthropic",
+ Models: map[string]opencode.Model{
+ "claude-3-opus-20240229": {ID: "claude-3-opus-20240229"},
+ "claude-3-sonnet-20240229": {ID: "claude-3-sonnet-20240229"},
+ },
+ },
+ {
+ ID: "openai",
+ Models: map[string]opencode.Model{
+ "gpt-4": {ID: "gpt-4"},
+ "gpt-3.5-turbo": {ID: "gpt-3.5-turbo"},
+ },
+ },
+ }
+
+ tests := []struct {
+ name string
+ fullModelID string
+ expectedFound bool
+ expectedProviderID string
+ expectedModelID string
+ }{
+ {
+ name: "valid full model ID",
+ fullModelID: "anthropic/claude-3-opus-20240229",
+ expectedFound: true,
+ expectedProviderID: "anthropic",
+ expectedModelID: "claude-3-opus-20240229",
+ },
+ {
+ name: "valid full model ID with slash in model name",
+ fullModelID: "openai/gpt-3.5-turbo",
+ expectedFound: true,
+ expectedProviderID: "openai",
+ expectedModelID: "gpt-3.5-turbo",
+ },
+ {
+ name: "invalid format - missing slash",
+ fullModelID: "anthropic",
+ expectedFound: false,
+ },
+ {
+ name: "invalid format - empty string",
+ fullModelID: "",
+ expectedFound: false,
+ },
+ {
+ name: "provider not found",
+ fullModelID: "nonexistent/model",
+ expectedFound: false,
+ },
+ {
+ name: "model not found",
+ fullModelID: "anthropic/nonexistent-model",
+ expectedFound: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ provider, model := findModelByFullID(providers, tt.fullModelID)
+
+ if tt.expectedFound {
+ if provider == nil || model == nil {
+ t.Errorf("Expected to find provider/model, but got nil")
+ return
+ }
+
+ if provider.ID != tt.expectedProviderID {
+ t.Errorf("Expected provider ID %s, got %s", tt.expectedProviderID, provider.ID)
+ }
+
+ if model.ID != tt.expectedModelID {
+ t.Errorf("Expected model ID %s, got %s", tt.expectedModelID, model.ID)
+ }
+ } else {
+ if provider != nil || model != nil {
+ t.Errorf("Expected not to find provider/model, but got provider: %v, model: %v", provider, model)
+ }
+ }
+ })
+ }
+}
+
+// TestFindModelByProviderAndModelID tests the findModelByProviderAndModelID function
+func TestFindModelByProviderAndModelID(t *testing.T) {
+ // Create test providers with models
+ providers := []opencode.Provider{
+ {
+ ID: "anthropic",
+ Models: map[string]opencode.Model{
+ "claude-3-opus-20240229": {ID: "claude-3-opus-20240229"},
+ "claude-3-sonnet-20240229": {ID: "claude-3-sonnet-20240229"},
+ },
+ },
+ {
+ ID: "openai",
+ Models: map[string]opencode.Model{
+ "gpt-4": {ID: "gpt-4"},
+ "gpt-3.5-turbo": {ID: "gpt-3.5-turbo"},
+ },
+ },
+ }
+
+ tests := []struct {
+ name string
+ providerID string
+ modelID string
+ expectedFound bool
+ expectedProviderID string
+ expectedModelID string
+ }{
+ {
+ name: "valid provider and model",
+ providerID: "anthropic",
+ modelID: "claude-3-opus-20240229",
+ expectedFound: true,
+ expectedProviderID: "anthropic",
+ expectedModelID: "claude-3-opus-20240229",
+ },
+ {
+ name: "provider not found",
+ providerID: "nonexistent",
+ modelID: "claude-3-opus-20240229",
+ expectedFound: false,
+ },
+ {
+ name: "model not found",
+ providerID: "anthropic",
+ modelID: "nonexistent-model",
+ expectedFound: false,
+ },
+ {
+ name: "both provider and model not found",
+ providerID: "nonexistent",
+ modelID: "nonexistent-model",
+ expectedFound: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ provider, model := findModelByProviderAndModelID(providers, tt.providerID, tt.modelID)
+
+ if tt.expectedFound {
+ if provider == nil || model == nil {
+ t.Errorf("Expected to find provider/model, but got nil")
+ return
+ }
+
+ if provider.ID != tt.expectedProviderID {
+ t.Errorf("Expected provider ID %s, got %s", tt.expectedProviderID, provider.ID)
+ }
+
+ if model.ID != tt.expectedModelID {
+ t.Errorf("Expected model ID %s, got %s", tt.expectedModelID, model.ID)
+ }
+ } else {
+ if provider != nil || model != nil {
+ t.Errorf("Expected not to find provider/model, but got provider: %v, model: %v", provider, model)
+ }
+ }
+ })
+ }
+}
+
+// TestFindProviderByID tests the findProviderByID function
+func TestFindProviderByID(t *testing.T) {
+ // Create test providers
+ providers := []opencode.Provider{
+ {ID: "anthropic"},
+ {ID: "openai"},
+ {ID: "google"},
+ }
+
+ tests := []struct {
+ name string
+ providerID string
+ expectedFound bool
+ expectedProviderID string
+ }{
+ {
+ name: "provider found",
+ providerID: "anthropic",
+ expectedFound: true,
+ expectedProviderID: "anthropic",
+ },
+ {
+ name: "provider not found",
+ providerID: "nonexistent",
+ expectedFound: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ provider := findProviderByID(providers, tt.providerID)
+
+ if tt.expectedFound {
+ if provider == nil {
+ t.Errorf("Expected to find provider, but got nil")
+ return
+ }
+
+ if provider.ID != tt.expectedProviderID {
+ t.Errorf("Expected provider ID %s, got %s", tt.expectedProviderID, provider.ID)
+ }
+ } else {
+ if provider != nil {
+ t.Errorf("Expected not to find provider, but got %v", provider)
+ }
+ }
+ })
+ }
+}
diff --git a/packages/web/src/content/docs/docs/models.mdx b/packages/web/src/content/docs/docs/models.mdx
index 591625f8f..5308921a3 100644
--- a/packages/web/src/content/docs/docs/models.mdx
+++ b/packages/web/src/content/docs/docs/models.mdx
@@ -66,9 +66,11 @@ If you've configured a [custom provider](/docs/providers#custom), the `provider_
## Loading models
-When opencode starts up, it checks for the following:
+When opencode starts up, it checks for models in the following priority order:
-1. The model list in the opencode config.
+1. The `--model` or `-m` command line flag. The format is the same as in the config file: `provider_id/model_id`.
+
+2. The model list in the opencode config.
```json title="opencode.json"
{
@@ -79,6 +81,6 @@ When opencode starts up, it checks for the following:
The format here is `provider/model`.
-2. The last used model.
+3. The last used model.
-3. The first model using an internal priority.
+4. The first model using an internal priority.