summaryrefslogtreecommitdiffhomepage
path: root/internal
diff options
context:
space:
mode:
authormineo <[email protected]>2025-05-09 21:15:38 +0900
committeradamdottv <[email protected]>2025-05-12 09:56:30 -0500
commitf92b2b76dc0836b8ad9f4a47a16941efdb2accf6 (patch)
treed4299212ede8c9d3a2734a02317944bae43dfd8c /internal
parent1d1a1ddcbf2ce5bca04fc8ccc6877b2c1c93ef59 (diff)
downloadopencode-f92b2b76dc0836b8ad9f4a47a16941efdb2accf6.tar.gz
opencode-f92b2b76dc0836b8ad9f4a47a16941efdb2accf6.zip
replace `github.com/google/generative-ai-go` with `github.com/googleapis/go-genai` (#138)
* replace to github.com/googleapis/go-genai * fix history logic * small fixes --------- Co-authored-by: Kujtim Hoxha <[email protected]>
Diffstat (limited to 'internal')
-rw-r--r--internal/llm/provider/gemini.go139
1 files changed, 69 insertions, 70 deletions
diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go
index c37aee4b6..05ce76e2c 100644
--- a/internal/llm/provider/gemini.go
+++ b/internal/llm/provider/gemini.go
@@ -9,14 +9,12 @@ import (
"strings"
"time"
- "github.com/google/generative-ai-go/genai"
"github.com/google/uuid"
"github.com/opencode-ai/opencode/internal/config"
"github.com/opencode-ai/opencode/internal/llm/tools"
"github.com/opencode-ai/opencode/internal/message"
"github.com/opencode-ai/opencode/internal/status"
- "google.golang.org/api/iterator"
- "google.golang.org/api/option"
+ "google.golang.org/genai"
"log/slog"
)
@@ -40,7 +38,7 @@ func newGeminiClient(opts providerClientOptions) GeminiClient {
o(&geminiOpts)
}
- client, err := genai.NewClient(context.Background(), option.WithAPIKey(opts.apiKey))
+ client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI})
if err != nil {
slog.Error("Failed to create Gemini client", "error", err)
return nil
@@ -58,11 +56,14 @@ func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Cont
for _, msg := range messages {
switch msg.Role {
case message.User:
- var parts []genai.Part
- parts = append(parts, genai.Text(msg.Content().String()))
+ var parts []*genai.Part
+ parts = append(parts, &genai.Part{Text: msg.Content().String()})
for _, binaryContent := range msg.BinaryContent() {
imageFormat := strings.Split(binaryContent.MIMEType, "/")
- parts = append(parts, genai.ImageData(imageFormat[1], binaryContent.Data))
+ parts = append(parts, &genai.Part{InlineData: &genai.Blob{
+ MIMEType: imageFormat[1],
+ Data: binaryContent.Data,
+ }})
}
history = append(history, &genai.Content{
Parts: parts,
@@ -71,19 +72,21 @@ func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Cont
case message.Assistant:
content := &genai.Content{
Role: "model",
- Parts: []genai.Part{},
+ Parts: []*genai.Part{},
}
if msg.Content().String() != "" {
- content.Parts = append(content.Parts, genai.Text(msg.Content().String()))
+ content.Parts = append(content.Parts, &genai.Part{Text: msg.Content().String()})
}
if len(msg.ToolCalls()) > 0 {
for _, call := range msg.ToolCalls() {
args, _ := parseJsonToMap(call.Input)
- content.Parts = append(content.Parts, genai.FunctionCall{
- Name: call.Name,
- Args: args,
+ content.Parts = append(content.Parts, &genai.Part{
+ FunctionCall: &genai.FunctionCall{
+ Name: call.Name,
+ Args: args,
+ },
})
}
}
@@ -111,10 +114,14 @@ func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Cont
}
history = append(history, &genai.Content{
- Parts: []genai.Part{genai.FunctionResponse{
- Name: toolCall.Name,
- Response: response,
- }},
+ Parts: []*genai.Part{
+ {
+ FunctionResponse: &genai.FunctionResponse{
+ Name: toolCall.Name,
+ Response: response,
+ },
+ },
+ },
Role: "function",
})
}
@@ -158,18 +165,6 @@ func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishRea
}
func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
- model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
- model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
- model.SystemInstruction = &genai.Content{
- Parts: []genai.Part{
- genai.Text(g.providerOptions.systemMessage),
- },
- }
- // Convert tools
- if len(tools) > 0 {
- model.Tools = g.convertTools(tools)
- }
-
// Convert messages
geminiMessages := g.convertMessages(messages)
@@ -179,16 +174,26 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
slog.Debug("Prepared messages", "messages", string(jsonData))
}
+ history := geminiMessages[:len(geminiMessages)-1] // All but last message
+ lastMsg := geminiMessages[len(geminiMessages)-1]
+ chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, &genai.GenerateContentConfig{
+ MaxOutputTokens: int32(g.providerOptions.maxTokens),
+ SystemInstruction: &genai.Content{
+ Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
+ },
+ Tools: g.convertTools(tools),
+ }, history)
+
attempts := 0
for {
attempts++
var toolCalls []message.ToolCall
- chat := model.StartChat()
- chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message
-
- lastMsg := geminiMessages[len(geminiMessages)-1]
- resp, err := chat.SendMessage(ctx, lastMsg.Parts...)
+ var lastMsgParts []genai.Part
+ for _, part := range lastMsg.Parts {
+ lastMsgParts = append(lastMsgParts, *part)
+ }
+ resp, err := chat.SendMessage(ctx, lastMsgParts...)
// If there is an error we are going to see if we can retry the call
if err != nil {
retry, after, retryErr := g.shouldRetry(attempts, err)
@@ -211,15 +216,15 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
for _, part := range resp.Candidates[0].Content.Parts {
- switch p := part.(type) {
- case genai.Text:
- content = string(p)
- case genai.FunctionCall:
+ switch {
+ case part.Text != "":
+ content = string(part.Text)
+ case part.FunctionCall != nil:
id := "call_" + uuid.New().String()
- args, _ := json.Marshal(p.Args)
+ args, _ := json.Marshal(part.FunctionCall.Args)
toolCalls = append(toolCalls, message.ToolCall{
ID: id,
- Name: p.Name,
+ Name: part.FunctionCall.Name,
Input: string(args),
Type: "function",
Finished: true,
@@ -245,18 +250,6 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
}
func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
- model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
- model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
- model.SystemInstruction = &genai.Content{
- Parts: []genai.Part{
- genai.Text(g.providerOptions.systemMessage),
- },
- }
- // Convert tools
- if len(tools) > 0 {
- model.Tools = g.convertTools(tools)
- }
-
// Convert messages
geminiMessages := g.convertMessages(messages)
@@ -266,6 +259,16 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
slog.Debug("Prepared messages", "messages", string(jsonData))
}
+ history := geminiMessages[:len(geminiMessages)-1] // All but last message
+ lastMsg := geminiMessages[len(geminiMessages)-1]
+ chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, &genai.GenerateContentConfig{
+ MaxOutputTokens: int32(g.providerOptions.maxTokens),
+ SystemInstruction: &genai.Content{
+ Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
+ },
+ Tools: g.convertTools(tools),
+ }, history)
+
attempts := 0
eventChan := make(chan ProviderEvent)
@@ -274,11 +277,6 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
for {
attempts++
- chat := model.StartChat()
- chat.History = geminiMessages[:len(geminiMessages)-1]
- lastMsg := geminiMessages[len(geminiMessages)-1]
-
- iter := chat.SendMessageStream(ctx, lastMsg.Parts...)
currentContent := ""
toolCalls := []message.ToolCall{}
@@ -286,11 +284,12 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
eventChan <- ProviderEvent{Type: EventContentStart}
- for {
- resp, err := iter.Next()
- if err == iterator.Done {
- break
- }
+ var lastMsgParts []genai.Part
+
+ for _, part := range lastMsg.Parts {
+ lastMsgParts = append(lastMsgParts, *part)
+ }
+ for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) {
if err != nil {
retry, after, retryErr := g.shouldRetry(attempts, err)
if retryErr != nil {
@@ -319,9 +318,9 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
for _, part := range resp.Candidates[0].Content.Parts {
- switch p := part.(type) {
- case genai.Text:
- delta := string(p)
+ switch {
+ case part.Text != "":
+ delta := string(part.Text)
if delta != "" {
eventChan <- ProviderEvent{
Type: EventContentDelta,
@@ -329,12 +328,12 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
}
currentContent += delta
}
- case genai.FunctionCall:
+ case part.FunctionCall != nil:
id := "call_" + uuid.New().String()
- args, _ := json.Marshal(p.Args)
+ args, _ := json.Marshal(part.FunctionCall.Args)
newCall := message.ToolCall{
ID: id,
- Name: p.Name,
+ Name: part.FunctionCall.Name,
Input: string(args),
Type: "function",
Finished: true,
@@ -422,12 +421,12 @@ func (g *geminiClient) toolCalls(resp *genai.GenerateContentResponse) []message.
if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
for _, part := range resp.Candidates[0].Content.Parts {
- if funcCall, ok := part.(genai.FunctionCall); ok {
+ if part.FunctionCall != nil {
id := "call_" + uuid.New().String()
- args, _ := json.Marshal(funcCall.Args)
+ args, _ := json.Marshal(part.FunctionCall.Args)
toolCalls = append(toolCalls, message.ToolCall{
ID: id,
- Name: funcCall.Name,
+ Name: part.FunctionCall.Name,
Input: string(args),
Type: "function",
})