diff options
| author | Kujtim Hoxha <[email protected]> | 2025-04-09 17:45:41 +0200 |
|---|---|---|
| committer | Kujtim Hoxha <[email protected]> | 2025-04-09 17:45:41 +0200 |
| commit | 939ae03f42e61d0944da80381219e6bbdfc2d850 (patch) | |
| tree | 3c45cdad120f4e799e92f0a83fad97f8a025cf4f /internal/llm | |
| parent | fde04bbf85ea641a33a282b354d63f227f9945fb (diff) | |
| download | opencode-939ae03f42e61d0944da80381219e6bbdfc2d850.tar.gz opencode-939ae03f42e61d0944da80381219e6bbdfc2d850.zip | |
add bedrock support
Diffstat (limited to 'internal/llm')
| -rw-r--r-- | internal/llm/agent/agent.go | 23 | ||||
| -rw-r--r-- | internal/llm/models/models.go | 16 | ||||
| -rw-r--r-- | internal/llm/provider/anthropic.go | 33 | ||||
| -rw-r--r-- | internal/llm/provider/bedrock.go | 87 |
4 files changed, 154 insertions, 5 deletions
diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index baf78be65..78062d060 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -380,6 +380,29 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid return nil, nil, err } + case models.ProviderBedrock: + var err error + agentProvider, err = provider.NewBedrockProvider( + provider.WithBedrockSystemMessage( + prompt.CoderAnthropicSystemPrompt(), + ), + provider.WithBedrockMaxTokens(maxTokens), + provider.WithBedrockModel(model), + ) + if err != nil { + return nil, nil, err + } + titleGenerator, err = provider.NewBedrockProvider( + provider.WithBedrockSystemMessage( + prompt.TitlePrompt(), + ), + provider.WithBedrockMaxTokens(maxTokens), + provider.WithBedrockModel(model), + ) + if err != nil { + return nil, nil, err + } + } return agentProvider, titleGenerator, nil diff --git a/internal/llm/models/models.go b/internal/llm/models/models.go index 2f75db9c8..4791218c4 100644 --- a/internal/llm/models/models.go +++ b/internal/llm/models/models.go @@ -31,11 +31,15 @@ const ( // GROQ QWENQwq ModelID = "qwen-qwq" + + // Bedrock + BedrockClaude37Sonnet ModelID = "bedrock.claude-3.7-sonnet" ) const ( ProviderOpenAI ModelProvider = "openai" ProviderAnthropic ModelProvider = "anthropic" + ProviderBedrock ModelProvider = "bedrock" ProviderGemini ModelProvider = "gemini" ProviderGROQ ModelProvider = "groq" ) @@ -119,4 +123,16 @@ var SupportedModels = map[ModelID]Model{ CostPer1MOutCached: 0, CostPer1MOut: 0, }, + + // Bedrock + BedrockClaude37Sonnet: { + ID: BedrockClaude37Sonnet, + Name: "Bedrock: Claude 3.7 Sonnet", + Provider: ProviderBedrock, + APIModel: "anthropic.claude-3-7-sonnet-20250219-v1:0", + CostPer1MIn: 3.0, + CostPer1MInCached: 3.75, + CostPer1MOutCached: 0.30, + CostPer1MOut: 15.0, + }, } diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go index 02bd572f1..625976a95 100644 --- a/internal/llm/provider/anthropic.go +++ b/internal/llm/provider/anthropic.go @@ -9,6 +9,7 @@ import ( "time" "github.com/anthropics/anthropic-sdk-go" + "github.com/anthropics/anthropic-sdk-go/bedrock" "github.com/anthropics/anthropic-sdk-go/option" "github.com/kujtimiihoxha/termai/internal/llm/models" "github.com/kujtimiihoxha/termai/internal/llm/tools" @@ -21,6 +22,8 @@ type anthropicProvider struct { maxTokens int64 apiKey string systemMessage string + useBedrock bool + disableCache bool } type AnthropicOption func(*anthropicProvider) @@ -49,6 +52,18 @@ func WithAnthropicKey(apiKey string) AnthropicOption { } } +func WithAnthropicBedrock() AnthropicOption { + return func(a *anthropicProvider) { + a.useBedrock = true + } +} + +func WithAnthropicDisableCache() AnthropicOption { + return func(a *anthropicProvider) { + a.disableCache = true + } +} + func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) { provider := &anthropicProvider{ maxTokens: 1024, @@ -62,7 +77,16 @@ func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) { return nil, errors.New("system message is required") } - provider.client = anthropic.NewClient(option.WithAPIKey(provider.apiKey)) + anthropicOptions := []option.RequestOption{} + + if provider.apiKey != "" { + anthropicOptions = append(anthropicOptions, option.WithAPIKey(provider.apiKey)) + } + if provider.useBedrock { + anthropicOptions = append(anthropicOptions, bedrock.WithLoadDefaultConfig(context.Background())) + } + + provider.client = anthropic.NewClient(anthropicOptions...) return provider, nil } @@ -338,7 +362,7 @@ func (a *anthropicProvider) convertToAnthropicTools(tools []tools.BaseTool) []an }, } - if i == len(tools)-1 { + if i == len(tools)-1 && !a.disableCache { toolParam.CacheControl = anthropic.CacheControlEphemeralParam{ Type: "ephemeral", } @@ -358,7 +382,7 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag switch msg.Role { case message.User: content := anthropic.NewTextBlock(msg.Content().String()) - if cachedBlocks < 2 { + if cachedBlocks < 2 && !a.disableCache { content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{ Type: "ephemeral", } @@ -370,7 +394,7 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag blocks := []anthropic.ContentBlockParamUnion{} if msg.Content().String() != "" { content := anthropic.NewTextBlock(msg.Content().String()) - if cachedBlocks < 2 { + if cachedBlocks < 2 && !a.disableCache { content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{ Type: "ephemeral", } @@ -404,4 +428,3 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag return anthropicMessages } - diff --git a/internal/llm/provider/bedrock.go b/internal/llm/provider/bedrock.go new file mode 100644 index 000000000..f1afefdc4 --- /dev/null +++ b/internal/llm/provider/bedrock.go @@ -0,0 +1,87 @@ +package provider + +import ( + "context" + "errors" + "fmt" + "os" + "strings" + + "github.com/kujtimiihoxha/termai/internal/llm/models" + "github.com/kujtimiihoxha/termai/internal/llm/tools" + "github.com/kujtimiihoxha/termai/internal/message" +) + +type bedrockProvider struct { + childProvider Provider + model models.Model + maxTokens int64 + systemMessage string +} + +func (b *bedrockProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { + return b.childProvider.SendMessages(ctx, messages, tools) +} + +func (b *bedrockProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) { + return b.childProvider.StreamResponse(ctx, messages, tools) +} + +func NewBedrockProvider(opts ...BedrockOption) (Provider, error) { + provider := &bedrockProvider{} + for _, opt := range opts { + opt(provider) + } + + // based on the AWS region prefix the model name with, us, eu, ap, sa, etc. + region := os.Getenv("AWS_REGION") + if region == "" { + region = os.Getenv("AWS_DEFAULT_REGION") + } + + if region == "" { + return nil, errors.New("AWS_REGION or AWS_DEFAULT_REGION environment variable is required") + } + if len(region) < 2 { + return nil, errors.New("AWS_REGION or AWS_DEFAULT_REGION environment variable is invalid") + } + regionPrefix := region[:2] + provider.model.APIModel = fmt.Sprintf("%s.%s", regionPrefix, provider.model.APIModel) + + if strings.Contains(string(provider.model.APIModel), "anthropic") { + anthropic, err := NewAnthropicProvider( + WithAnthropicModel(provider.model), + WithAnthropicMaxTokens(provider.maxTokens), + WithAnthropicSystemMessage(provider.systemMessage), + WithAnthropicBedrock(), + WithAnthropicDisableCache(), + ) + provider.childProvider = anthropic + if err != nil { + return nil, err + } + } else { + return nil, errors.New("unsupported model for bedrock provider") + } + return provider, nil +} + +type BedrockOption func(*bedrockProvider) + +func WithBedrockSystemMessage(message string) BedrockOption { + return func(a *bedrockProvider) { + a.systemMessage = message + } +} + +func WithBedrockMaxTokens(maxTokens int64) BedrockOption { + return func(a *bedrockProvider) { + a.maxTokens = maxTokens + } +} + +func WithBedrockModel(model models.Model) BedrockOption { + return func(a *bedrockProvider) { + a.model = model + } +} |
