diff options
| author | mineo <[email protected]> | 2025-05-16 03:25:21 +0900 |
|---|---|---|
| committer | adamdottv <[email protected]> | 2025-05-15 13:35:06 -0500 |
| commit | 87237b6462b9dfd379b22e69712e8dc516afad9d (patch) | |
| tree | ffce4fab0e86ad05684738834c52de2f7f1f7a76 /internal | |
| parent | 5f5f9dad877300bab3fe5442ea141551ba89421b (diff) | |
| download | opencode-87237b6462b9dfd379b22e69712e8dc516afad9d.tar.gz opencode-87237b6462b9dfd379b22e69712e8dc516afad9d.zip | |
feat: support VertexAI provider (#153)
* support: vertexai
fix
fix
set default for vertexai
added comment
fix
fix
* create schema
* fix README.md
* fix order
* added pupularity
* set tools if tools is exists
restore commentout
* fix comment
* set summarizer model
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/config/config.go | 45 | ||||
| -rw-r--r-- | internal/llm/models/models.go | 2 | ||||
| -rw-r--r-- | internal/llm/models/vertexai.go | 38 | ||||
| -rw-r--r-- | internal/llm/provider/gemini.go | 18 | ||||
| -rw-r--r-- | internal/llm/provider/provider.go | 5 | ||||
| -rw-r--r-- | internal/llm/provider/vertexai.go | 34 |
6 files changed, 136 insertions, 6 deletions
diff --git a/internal/config/config.go b/internal/config/config.go index f9aba238d..1d741bc91 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -235,6 +235,7 @@ func setProviderDefaults() { // 5. OpenRouter // 6. AWS Bedrock // 7. Azure + // 8. Google Cloud VertexAI // Anthropic configuration if key := viper.GetString("providers.anthropic.apiKey"); strings.TrimSpace(key) != "" { @@ -299,6 +300,15 @@ func setProviderDefaults() { viper.SetDefault("agents.title.model", models.AzureGPT41Mini) return } + + // Google Cloud VertexAI configuration + if hasVertexAICredentials() { + viper.SetDefault("agents.coder.model", models.VertexAIGemini25) + viper.SetDefault("agents.summarizer.model", models.VertexAIGemini25) + viper.SetDefault("agents.task.model", models.VertexAIGemini25Flash) + viper.SetDefault("agents.title.model", models.VertexAIGemini25Flash) + return + } } // hasAWSCredentials checks if AWS credentials are available in the environment. @@ -327,6 +337,19 @@ func hasAWSCredentials() bool { return false } +// hasVertexAICredentials checks if VertexAI credentials are available in the environment. +func hasVertexAICredentials() bool { + // Check for explicit VertexAI parameters + if os.Getenv("VERTEXAI_PROJECT") != "" && os.Getenv("VERTEXAI_LOCATION") != "" { + return true + } + // Check for Google Cloud project and location + if os.Getenv("GOOGLE_CLOUD_PROJECT") != "" && (os.Getenv("GOOGLE_CLOUD_REGION") != "" || os.Getenv("GOOGLE_CLOUD_LOCATION") != "") { + return true + } + return false +} + // readConfig handles the result of reading a configuration file. func readConfig(err error) error { if err == nil { @@ -549,6 +572,10 @@ func getProviderAPIKey(provider models.ModelProvider) string { if hasAWSCredentials() { return "aws-credentials-available" } + case models.ProviderVertexAI: + if hasVertexAICredentials() { + return "vertex-ai-credentials-available" + } } return "" } @@ -669,6 +696,24 @@ func setDefaultModelForAgent(agent AgentName) bool { return true } + if hasVertexAICredentials() { + var model models.ModelID + maxTokens := int64(5000) + + if agent == AgentTitle { + model = models.VertexAIGemini25Flash + maxTokens = 80 + } else { + model = models.VertexAIGemini25 + } + + cfg.Agents[agent] = Agent{ + Model: model, + MaxTokens: maxTokens, + } + return true + } + return false } diff --git a/internal/llm/models/models.go b/internal/llm/models/models.go index 16fd406c8..e4fe6603b 100644 --- a/internal/llm/models/models.go +++ b/internal/llm/models/models.go @@ -43,6 +43,7 @@ var ProviderPopularity = map[ModelProvider]int{ ProviderOpenRouter: 5, ProviderBedrock: 6, ProviderAzure: 7, + ProviderVertexAI: 8, } var SupportedModels = map[ModelID]Model{ @@ -95,4 +96,5 @@ func init() { maps.Copy(SupportedModels, AzureModels) maps.Copy(SupportedModels, OpenRouterModels) maps.Copy(SupportedModels, XAIModels) + maps.Copy(SupportedModels, VertexAIGeminiModels) } diff --git a/internal/llm/models/vertexai.go b/internal/llm/models/vertexai.go new file mode 100644 index 000000000..d71dfc0be --- /dev/null +++ b/internal/llm/models/vertexai.go @@ -0,0 +1,38 @@ +package models + +const ( + ProviderVertexAI ModelProvider = "vertexai" + + // Models + VertexAIGemini25Flash ModelID = "vertexai.gemini-2.5-flash" + VertexAIGemini25 ModelID = "vertexai.gemini-2.5" +) + +var VertexAIGeminiModels = map[ModelID]Model{ + VertexAIGemini25Flash: { + ID: VertexAIGemini25Flash, + Name: "VertexAI: Gemini 2.5 Flash", + Provider: ProviderVertexAI, + APIModel: "gemini-2.5-flash-preview-04-17", + CostPer1MIn: GeminiModels[Gemini25Flash].CostPer1MIn, + CostPer1MInCached: GeminiModels[Gemini25Flash].CostPer1MInCached, + CostPer1MOut: GeminiModels[Gemini25Flash].CostPer1MOut, + CostPer1MOutCached: GeminiModels[Gemini25Flash].CostPer1MOutCached, + ContextWindow: GeminiModels[Gemini25Flash].ContextWindow, + DefaultMaxTokens: GeminiModels[Gemini25Flash].DefaultMaxTokens, + SupportsAttachments: true, + }, + VertexAIGemini25: { + ID: VertexAIGemini25, + Name: "VertexAI: Gemini 2.5 Pro", + Provider: ProviderVertexAI, + APIModel: "gemini-2.5-pro-preview-03-25", + CostPer1MIn: GeminiModels[Gemini25].CostPer1MIn, + CostPer1MInCached: GeminiModels[Gemini25].CostPer1MInCached, + CostPer1MOut: GeminiModels[Gemini25].CostPer1MOut, + CostPer1MOutCached: GeminiModels[Gemini25].CostPer1MOutCached, + ContextWindow: GeminiModels[Gemini25].ContextWindow, + DefaultMaxTokens: GeminiModels[Gemini25].DefaultMaxTokens, + SupportsAttachments: true, + }, +} diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go index cc97463d4..8b8e33698 100644 --- a/internal/llm/provider/gemini.go +++ b/internal/llm/provider/gemini.go @@ -176,13 +176,16 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too history := geminiMessages[:len(geminiMessages)-1] // All but last message lastMsg := geminiMessages[len(geminiMessages)-1] - chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, &genai.GenerateContentConfig{ + config := &genai.GenerateContentConfig{ MaxOutputTokens: int32(g.providerOptions.maxTokens), SystemInstruction: &genai.Content{ Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}}, }, - Tools: g.convertTools(tools), - }, history) + } + if len(tools) > 0 { + config.Tools = g.convertTools(tools) + } + chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, config, history) attempts := 0 for { @@ -262,13 +265,16 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t history := geminiMessages[:len(geminiMessages)-1] // All but last message lastMsg := geminiMessages[len(geminiMessages)-1] - chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, &genai.GenerateContentConfig{ + config := &genai.GenerateContentConfig{ MaxOutputTokens: int32(g.providerOptions.maxTokens), SystemInstruction: &genai.Content{ Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}}, }, - Tools: g.convertTools(tools), - }, history) + } + if len(tools) > 0 { + config.Tools = g.convertTools(tools) + } + chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, config, history) attempts := 0 eventChan := make(chan ProviderEvent) diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index 6f2a20bd9..f21e051c2 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -123,6 +123,11 @@ func NewProvider(providerName models.ModelProvider, opts ...ProviderClientOption options: clientOptions, client: newAzureClient(clientOptions), }, nil + case models.ProviderVertexAI: + return &baseProvider[VertexAIClient]{ + options: clientOptions, + client: newVertexAIClient(clientOptions), + }, nil case models.ProviderOpenRouter: clientOptions.openaiOptions = append(clientOptions.openaiOptions, WithOpenAIBaseURL("https://openrouter.ai/api/v1"), diff --git a/internal/llm/provider/vertexai.go b/internal/llm/provider/vertexai.go new file mode 100644 index 000000000..2a13a9572 --- /dev/null +++ b/internal/llm/provider/vertexai.go @@ -0,0 +1,34 @@ +package provider + +import ( + "context" + "os" + + "github.com/opencode-ai/opencode/internal/logging" + "google.golang.org/genai" +) + +type VertexAIClient ProviderClient + +func newVertexAIClient(opts providerClientOptions) VertexAIClient { + geminiOpts := geminiOptions{} + for _, o := range opts.geminiOptions { + o(&geminiOpts) + } + + client, err := genai.NewClient(context.Background(), &genai.ClientConfig{ + Project: os.Getenv("VERTEXAI_PROJECT"), + Location: os.Getenv("VERTEXAI_LOCATION"), + Backend: genai.BackendVertexAI, + }) + if err != nil { + logging.Error("Failed to create VertexAI client", "error", err) + return nil + } + + return &geminiClient{ + providerOptions: opts, + options: geminiOpts, + client: client, + } +} |
