diff options
| author | mineo <[email protected]> | 2025-05-09 21:15:38 +0900 |
|---|---|---|
| committer | adamdottv <[email protected]> | 2025-05-12 09:56:30 -0500 |
| commit | f92b2b76dc0836b8ad9f4a47a16941efdb2accf6 (patch) | |
| tree | d4299212ede8c9d3a2734a02317944bae43dfd8c /internal | |
| parent | 1d1a1ddcbf2ce5bca04fc8ccc6877b2c1c93ef59 (diff) | |
| download | opencode-f92b2b76dc0836b8ad9f4a47a16941efdb2accf6.tar.gz opencode-f92b2b76dc0836b8ad9f4a47a16941efdb2accf6.zip | |
replace `github.com/google/generative-ai-go` with `github.com/googleapis/go-genai` (#138)
* replace to github.com/googleapis/go-genai
* fix history logic
* small fixes
---------
Co-authored-by: Kujtim Hoxha <[email protected]>
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/llm/provider/gemini.go | 139 |
1 files changed, 69 insertions, 70 deletions
diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go index c37aee4b6..05ce76e2c 100644 --- a/internal/llm/provider/gemini.go +++ b/internal/llm/provider/gemini.go @@ -9,14 +9,12 @@ import ( "strings" "time" - "github.com/google/generative-ai-go/genai" "github.com/google/uuid" "github.com/opencode-ai/opencode/internal/config" "github.com/opencode-ai/opencode/internal/llm/tools" "github.com/opencode-ai/opencode/internal/message" "github.com/opencode-ai/opencode/internal/status" - "google.golang.org/api/iterator" - "google.golang.org/api/option" + "google.golang.org/genai" "log/slog" ) @@ -40,7 +38,7 @@ func newGeminiClient(opts providerClientOptions) GeminiClient { o(&geminiOpts) } - client, err := genai.NewClient(context.Background(), option.WithAPIKey(opts.apiKey)) + client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI}) if err != nil { slog.Error("Failed to create Gemini client", "error", err) return nil @@ -58,11 +56,14 @@ func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Cont for _, msg := range messages { switch msg.Role { case message.User: - var parts []genai.Part - parts = append(parts, genai.Text(msg.Content().String())) + var parts []*genai.Part + parts = append(parts, &genai.Part{Text: msg.Content().String()}) for _, binaryContent := range msg.BinaryContent() { imageFormat := strings.Split(binaryContent.MIMEType, "/") - parts = append(parts, genai.ImageData(imageFormat[1], binaryContent.Data)) + parts = append(parts, &genai.Part{InlineData: &genai.Blob{ + MIMEType: imageFormat[1], + Data: binaryContent.Data, + }}) } history = append(history, &genai.Content{ Parts: parts, @@ -71,19 +72,21 @@ func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Cont case message.Assistant: content := &genai.Content{ Role: "model", - Parts: []genai.Part{}, + Parts: []*genai.Part{}, } if msg.Content().String() != "" { - content.Parts = append(content.Parts, genai.Text(msg.Content().String())) + content.Parts = append(content.Parts, &genai.Part{Text: msg.Content().String()}) } if len(msg.ToolCalls()) > 0 { for _, call := range msg.ToolCalls() { args, _ := parseJsonToMap(call.Input) - content.Parts = append(content.Parts, genai.FunctionCall{ - Name: call.Name, - Args: args, + content.Parts = append(content.Parts, &genai.Part{ + FunctionCall: &genai.FunctionCall{ + Name: call.Name, + Args: args, + }, }) } } @@ -111,10 +114,14 @@ func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Cont } history = append(history, &genai.Content{ - Parts: []genai.Part{genai.FunctionResponse{ - Name: toolCall.Name, - Response: response, - }}, + Parts: []*genai.Part{ + { + FunctionResponse: &genai.FunctionResponse{ + Name: toolCall.Name, + Response: response, + }, + }, + }, Role: "function", }) } @@ -158,18 +165,6 @@ func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishRea } func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { - model := g.client.GenerativeModel(g.providerOptions.model.APIModel) - model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens)) - model.SystemInstruction = &genai.Content{ - Parts: []genai.Part{ - genai.Text(g.providerOptions.systemMessage), - }, - } - // Convert tools - if len(tools) > 0 { - model.Tools = g.convertTools(tools) - } - // Convert messages geminiMessages := g.convertMessages(messages) @@ -179,16 +174,26 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too slog.Debug("Prepared messages", "messages", string(jsonData)) } + history := geminiMessages[:len(geminiMessages)-1] // All but last message + lastMsg := geminiMessages[len(geminiMessages)-1] + chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, &genai.GenerateContentConfig{ + MaxOutputTokens: int32(g.providerOptions.maxTokens), + SystemInstruction: &genai.Content{ + Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}}, + }, + Tools: g.convertTools(tools), + }, history) + attempts := 0 for { attempts++ var toolCalls []message.ToolCall - chat := model.StartChat() - chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message - - lastMsg := geminiMessages[len(geminiMessages)-1] - resp, err := chat.SendMessage(ctx, lastMsg.Parts...) + var lastMsgParts []genai.Part + for _, part := range lastMsg.Parts { + lastMsgParts = append(lastMsgParts, *part) + } + resp, err := chat.SendMessage(ctx, lastMsgParts...) // If there is an error we are going to see if we can retry the call if err != nil { retry, after, retryErr := g.shouldRetry(attempts, err) @@ -211,15 +216,15 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil { for _, part := range resp.Candidates[0].Content.Parts { - switch p := part.(type) { - case genai.Text: - content = string(p) - case genai.FunctionCall: + switch { + case part.Text != "": + content = string(part.Text) + case part.FunctionCall != nil: id := "call_" + uuid.New().String() - args, _ := json.Marshal(p.Args) + args, _ := json.Marshal(part.FunctionCall.Args) toolCalls = append(toolCalls, message.ToolCall{ ID: id, - Name: p.Name, + Name: part.FunctionCall.Name, Input: string(args), Type: "function", Finished: true, @@ -245,18 +250,6 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too } func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { - model := g.client.GenerativeModel(g.providerOptions.model.APIModel) - model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens)) - model.SystemInstruction = &genai.Content{ - Parts: []genai.Part{ - genai.Text(g.providerOptions.systemMessage), - }, - } - // Convert tools - if len(tools) > 0 { - model.Tools = g.convertTools(tools) - } - // Convert messages geminiMessages := g.convertMessages(messages) @@ -266,6 +259,16 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t slog.Debug("Prepared messages", "messages", string(jsonData)) } + history := geminiMessages[:len(geminiMessages)-1] // All but last message + lastMsg := geminiMessages[len(geminiMessages)-1] + chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, &genai.GenerateContentConfig{ + MaxOutputTokens: int32(g.providerOptions.maxTokens), + SystemInstruction: &genai.Content{ + Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}}, + }, + Tools: g.convertTools(tools), + }, history) + attempts := 0 eventChan := make(chan ProviderEvent) @@ -274,11 +277,6 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t for { attempts++ - chat := model.StartChat() - chat.History = geminiMessages[:len(geminiMessages)-1] - lastMsg := geminiMessages[len(geminiMessages)-1] - - iter := chat.SendMessageStream(ctx, lastMsg.Parts...) currentContent := "" toolCalls := []message.ToolCall{} @@ -286,11 +284,12 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t eventChan <- ProviderEvent{Type: EventContentStart} - for { - resp, err := iter.Next() - if err == iterator.Done { - break - } + var lastMsgParts []genai.Part + + for _, part := range lastMsg.Parts { + lastMsgParts = append(lastMsgParts, *part) + } + for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) { if err != nil { retry, after, retryErr := g.shouldRetry(attempts, err) if retryErr != nil { @@ -319,9 +318,9 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil { for _, part := range resp.Candidates[0].Content.Parts { - switch p := part.(type) { - case genai.Text: - delta := string(p) + switch { + case part.Text != "": + delta := string(part.Text) if delta != "" { eventChan <- ProviderEvent{ Type: EventContentDelta, @@ -329,12 +328,12 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t } currentContent += delta } - case genai.FunctionCall: + case part.FunctionCall != nil: id := "call_" + uuid.New().String() - args, _ := json.Marshal(p.Args) + args, _ := json.Marshal(part.FunctionCall.Args) newCall := message.ToolCall{ ID: id, - Name: p.Name, + Name: part.FunctionCall.Name, Input: string(args), Type: "function", Finished: true, @@ -422,12 +421,12 @@ func (g *geminiClient) toolCalls(resp *genai.GenerateContentResponse) []message. if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil { for _, part := range resp.Candidates[0].Content.Parts { - if funcCall, ok := part.(genai.FunctionCall); ok { + if part.FunctionCall != nil { id := "call_" + uuid.New().String() - args, _ := json.Marshal(funcCall.Args) + args, _ := json.Marshal(part.FunctionCall.Args) toolCalls = append(toolCalls, message.ToolCall{ ID: id, - Name: funcCall.Name, + Name: part.FunctionCall.Name, Input: string(args), Type: "function", }) |
