summaryrefslogtreecommitdiffhomepage
path: root/internal/llm
diff options
context:
space:
mode:
authormineo <[email protected]>2025-05-16 03:25:21 +0900
committeradamdottv <[email protected]>2025-05-15 13:35:06 -0500
commit87237b6462b9dfd379b22e69712e8dc516afad9d (patch)
treeffce4fab0e86ad05684738834c52de2f7f1f7a76 /internal/llm
parent5f5f9dad877300bab3fe5442ea141551ba89421b (diff)
downloadopencode-87237b6462b9dfd379b22e69712e8dc516afad9d.tar.gz
opencode-87237b6462b9dfd379b22e69712e8dc516afad9d.zip
feat: support VertexAI provider (#153)
* support: vertexai fix fix set default for vertexai added comment fix fix * create schema * fix README.md * fix order * added pupularity * set tools if tools is exists restore commentout * fix comment * set summarizer model
Diffstat (limited to 'internal/llm')
-rw-r--r--internal/llm/models/models.go2
-rw-r--r--internal/llm/models/vertexai.go38
-rw-r--r--internal/llm/provider/gemini.go18
-rw-r--r--internal/llm/provider/provider.go5
-rw-r--r--internal/llm/provider/vertexai.go34
5 files changed, 91 insertions, 6 deletions
diff --git a/internal/llm/models/models.go b/internal/llm/models/models.go
index 16fd406c8..e4fe6603b 100644
--- a/internal/llm/models/models.go
+++ b/internal/llm/models/models.go
@@ -43,6 +43,7 @@ var ProviderPopularity = map[ModelProvider]int{
ProviderOpenRouter: 5,
ProviderBedrock: 6,
ProviderAzure: 7,
+ ProviderVertexAI: 8,
}
var SupportedModels = map[ModelID]Model{
@@ -95,4 +96,5 @@ func init() {
maps.Copy(SupportedModels, AzureModels)
maps.Copy(SupportedModels, OpenRouterModels)
maps.Copy(SupportedModels, XAIModels)
+ maps.Copy(SupportedModels, VertexAIGeminiModels)
}
diff --git a/internal/llm/models/vertexai.go b/internal/llm/models/vertexai.go
new file mode 100644
index 000000000..d71dfc0be
--- /dev/null
+++ b/internal/llm/models/vertexai.go
@@ -0,0 +1,38 @@
+package models
+
+const (
+ ProviderVertexAI ModelProvider = "vertexai"
+
+ // Models
+ VertexAIGemini25Flash ModelID = "vertexai.gemini-2.5-flash"
+ VertexAIGemini25 ModelID = "vertexai.gemini-2.5"
+)
+
+var VertexAIGeminiModels = map[ModelID]Model{
+ VertexAIGemini25Flash: {
+ ID: VertexAIGemini25Flash,
+ Name: "VertexAI: Gemini 2.5 Flash",
+ Provider: ProviderVertexAI,
+ APIModel: "gemini-2.5-flash-preview-04-17",
+ CostPer1MIn: GeminiModels[Gemini25Flash].CostPer1MIn,
+ CostPer1MInCached: GeminiModels[Gemini25Flash].CostPer1MInCached,
+ CostPer1MOut: GeminiModels[Gemini25Flash].CostPer1MOut,
+ CostPer1MOutCached: GeminiModels[Gemini25Flash].CostPer1MOutCached,
+ ContextWindow: GeminiModels[Gemini25Flash].ContextWindow,
+ DefaultMaxTokens: GeminiModels[Gemini25Flash].DefaultMaxTokens,
+ SupportsAttachments: true,
+ },
+ VertexAIGemini25: {
+ ID: VertexAIGemini25,
+ Name: "VertexAI: Gemini 2.5 Pro",
+ Provider: ProviderVertexAI,
+ APIModel: "gemini-2.5-pro-preview-03-25",
+ CostPer1MIn: GeminiModels[Gemini25].CostPer1MIn,
+ CostPer1MInCached: GeminiModels[Gemini25].CostPer1MInCached,
+ CostPer1MOut: GeminiModels[Gemini25].CostPer1MOut,
+ CostPer1MOutCached: GeminiModels[Gemini25].CostPer1MOutCached,
+ ContextWindow: GeminiModels[Gemini25].ContextWindow,
+ DefaultMaxTokens: GeminiModels[Gemini25].DefaultMaxTokens,
+ SupportsAttachments: true,
+ },
+}
diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go
index cc97463d4..8b8e33698 100644
--- a/internal/llm/provider/gemini.go
+++ b/internal/llm/provider/gemini.go
@@ -176,13 +176,16 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
history := geminiMessages[:len(geminiMessages)-1] // All but last message
lastMsg := geminiMessages[len(geminiMessages)-1]
- chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, &genai.GenerateContentConfig{
+ config := &genai.GenerateContentConfig{
MaxOutputTokens: int32(g.providerOptions.maxTokens),
SystemInstruction: &genai.Content{
Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
},
- Tools: g.convertTools(tools),
- }, history)
+ }
+ if len(tools) > 0 {
+ config.Tools = g.convertTools(tools)
+ }
+ chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, config, history)
attempts := 0
for {
@@ -262,13 +265,16 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
history := geminiMessages[:len(geminiMessages)-1] // All but last message
lastMsg := geminiMessages[len(geminiMessages)-1]
- chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, &genai.GenerateContentConfig{
+ config := &genai.GenerateContentConfig{
MaxOutputTokens: int32(g.providerOptions.maxTokens),
SystemInstruction: &genai.Content{
Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
},
- Tools: g.convertTools(tools),
- }, history)
+ }
+ if len(tools) > 0 {
+ config.Tools = g.convertTools(tools)
+ }
+ chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, config, history)
attempts := 0
eventChan := make(chan ProviderEvent)
diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go
index 6f2a20bd9..f21e051c2 100644
--- a/internal/llm/provider/provider.go
+++ b/internal/llm/provider/provider.go
@@ -123,6 +123,11 @@ func NewProvider(providerName models.ModelProvider, opts ...ProviderClientOption
options: clientOptions,
client: newAzureClient(clientOptions),
}, nil
+ case models.ProviderVertexAI:
+ return &baseProvider[VertexAIClient]{
+ options: clientOptions,
+ client: newVertexAIClient(clientOptions),
+ }, nil
case models.ProviderOpenRouter:
clientOptions.openaiOptions = append(clientOptions.openaiOptions,
WithOpenAIBaseURL("https://openrouter.ai/api/v1"),
diff --git a/internal/llm/provider/vertexai.go b/internal/llm/provider/vertexai.go
new file mode 100644
index 000000000..2a13a9572
--- /dev/null
+++ b/internal/llm/provider/vertexai.go
@@ -0,0 +1,34 @@
+package provider
+
+import (
+ "context"
+ "os"
+
+ "github.com/opencode-ai/opencode/internal/logging"
+ "google.golang.org/genai"
+)
+
+type VertexAIClient ProviderClient
+
+func newVertexAIClient(opts providerClientOptions) VertexAIClient {
+ geminiOpts := geminiOptions{}
+ for _, o := range opts.geminiOptions {
+ o(&geminiOpts)
+ }
+
+ client, err := genai.NewClient(context.Background(), &genai.ClientConfig{
+ Project: os.Getenv("VERTEXAI_PROJECT"),
+ Location: os.Getenv("VERTEXAI_LOCATION"),
+ Backend: genai.BackendVertexAI,
+ })
+ if err != nil {
+ logging.Error("Failed to create VertexAI client", "error", err)
+ return nil
+ }
+
+ return &geminiClient{
+ providerOptions: opts,
+ options: geminiOpts,
+ client: client,
+ }
+}