summaryrefslogtreecommitdiffhomepage
path: root/internal/llm
diff options
context:
space:
mode:
authorKujtim Hoxha <[email protected]>2025-04-09 17:45:41 +0200
committerKujtim Hoxha <[email protected]>2025-04-09 17:45:41 +0200
commit939ae03f42e61d0944da80381219e6bbdfc2d850 (patch)
tree3c45cdad120f4e799e92f0a83fad97f8a025cf4f /internal/llm
parentfde04bbf85ea641a33a282b354d63f227f9945fb (diff)
downloadopencode-939ae03f42e61d0944da80381219e6bbdfc2d850.tar.gz
opencode-939ae03f42e61d0944da80381219e6bbdfc2d850.zip
add bedrock support
Diffstat (limited to 'internal/llm')
-rw-r--r--internal/llm/agent/agent.go23
-rw-r--r--internal/llm/models/models.go16
-rw-r--r--internal/llm/provider/anthropic.go33
-rw-r--r--internal/llm/provider/bedrock.go87
4 files changed, 154 insertions, 5 deletions
diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go
index baf78be65..78062d060 100644
--- a/internal/llm/agent/agent.go
+++ b/internal/llm/agent/agent.go
@@ -380,6 +380,29 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid
return nil, nil, err
}
+ case models.ProviderBedrock:
+ var err error
+ agentProvider, err = provider.NewBedrockProvider(
+ provider.WithBedrockSystemMessage(
+ prompt.CoderAnthropicSystemPrompt(),
+ ),
+ provider.WithBedrockMaxTokens(maxTokens),
+ provider.WithBedrockModel(model),
+ )
+ if err != nil {
+ return nil, nil, err
+ }
+ titleGenerator, err = provider.NewBedrockProvider(
+ provider.WithBedrockSystemMessage(
+ prompt.TitlePrompt(),
+ ),
+ provider.WithBedrockMaxTokens(maxTokens),
+ provider.WithBedrockModel(model),
+ )
+ if err != nil {
+ return nil, nil, err
+ }
+
}
return agentProvider, titleGenerator, nil
diff --git a/internal/llm/models/models.go b/internal/llm/models/models.go
index 2f75db9c8..4791218c4 100644
--- a/internal/llm/models/models.go
+++ b/internal/llm/models/models.go
@@ -31,11 +31,15 @@ const (
// GROQ
QWENQwq ModelID = "qwen-qwq"
+
+ // Bedrock
+ BedrockClaude37Sonnet ModelID = "bedrock.claude-3.7-sonnet"
)
const (
ProviderOpenAI ModelProvider = "openai"
ProviderAnthropic ModelProvider = "anthropic"
+ ProviderBedrock ModelProvider = "bedrock"
ProviderGemini ModelProvider = "gemini"
ProviderGROQ ModelProvider = "groq"
)
@@ -119,4 +123,16 @@ var SupportedModels = map[ModelID]Model{
CostPer1MOutCached: 0,
CostPer1MOut: 0,
},
+
+ // Bedrock
+ BedrockClaude37Sonnet: {
+ ID: BedrockClaude37Sonnet,
+ Name: "Bedrock: Claude 3.7 Sonnet",
+ Provider: ProviderBedrock,
+ APIModel: "anthropic.claude-3-7-sonnet-20250219-v1:0",
+ CostPer1MIn: 3.0,
+ CostPer1MInCached: 3.75,
+ CostPer1MOutCached: 0.30,
+ CostPer1MOut: 15.0,
+ },
}
diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go
index 02bd572f1..625976a95 100644
--- a/internal/llm/provider/anthropic.go
+++ b/internal/llm/provider/anthropic.go
@@ -9,6 +9,7 @@ import (
"time"
"github.com/anthropics/anthropic-sdk-go"
+ "github.com/anthropics/anthropic-sdk-go/bedrock"
"github.com/anthropics/anthropic-sdk-go/option"
"github.com/kujtimiihoxha/termai/internal/llm/models"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
@@ -21,6 +22,8 @@ type anthropicProvider struct {
maxTokens int64
apiKey string
systemMessage string
+ useBedrock bool
+ disableCache bool
}
type AnthropicOption func(*anthropicProvider)
@@ -49,6 +52,18 @@ func WithAnthropicKey(apiKey string) AnthropicOption {
}
}
+func WithAnthropicBedrock() AnthropicOption {
+ return func(a *anthropicProvider) {
+ a.useBedrock = true
+ }
+}
+
+func WithAnthropicDisableCache() AnthropicOption {
+ return func(a *anthropicProvider) {
+ a.disableCache = true
+ }
+}
+
func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) {
provider := &anthropicProvider{
maxTokens: 1024,
@@ -62,7 +77,16 @@ func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) {
return nil, errors.New("system message is required")
}
- provider.client = anthropic.NewClient(option.WithAPIKey(provider.apiKey))
+ anthropicOptions := []option.RequestOption{}
+
+ if provider.apiKey != "" {
+ anthropicOptions = append(anthropicOptions, option.WithAPIKey(provider.apiKey))
+ }
+ if provider.useBedrock {
+ anthropicOptions = append(anthropicOptions, bedrock.WithLoadDefaultConfig(context.Background()))
+ }
+
+ provider.client = anthropic.NewClient(anthropicOptions...)
return provider, nil
}
@@ -338,7 +362,7 @@ func (a *anthropicProvider) convertToAnthropicTools(tools []tools.BaseTool) []an
},
}
- if i == len(tools)-1 {
+ if i == len(tools)-1 && !a.disableCache {
toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
}
@@ -358,7 +382,7 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag
switch msg.Role {
case message.User:
content := anthropic.NewTextBlock(msg.Content().String())
- if cachedBlocks < 2 {
+ if cachedBlocks < 2 && !a.disableCache {
content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
}
@@ -370,7 +394,7 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag
blocks := []anthropic.ContentBlockParamUnion{}
if msg.Content().String() != "" {
content := anthropic.NewTextBlock(msg.Content().String())
- if cachedBlocks < 2 {
+ if cachedBlocks < 2 && !a.disableCache {
content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
}
@@ -404,4 +428,3 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag
return anthropicMessages
}
-
diff --git a/internal/llm/provider/bedrock.go b/internal/llm/provider/bedrock.go
new file mode 100644
index 000000000..f1afefdc4
--- /dev/null
+++ b/internal/llm/provider/bedrock.go
@@ -0,0 +1,87 @@
+package provider
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "os"
+ "strings"
+
+ "github.com/kujtimiihoxha/termai/internal/llm/models"
+ "github.com/kujtimiihoxha/termai/internal/llm/tools"
+ "github.com/kujtimiihoxha/termai/internal/message"
+)
+
+type bedrockProvider struct {
+ childProvider Provider
+ model models.Model
+ maxTokens int64
+ systemMessage string
+}
+
+func (b *bedrockProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
+ return b.childProvider.SendMessages(ctx, messages, tools)
+}
+
+func (b *bedrockProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
+ return b.childProvider.StreamResponse(ctx, messages, tools)
+}
+
+func NewBedrockProvider(opts ...BedrockOption) (Provider, error) {
+ provider := &bedrockProvider{}
+ for _, opt := range opts {
+ opt(provider)
+ }
+
+ // based on the AWS region prefix the model name with, us, eu, ap, sa, etc.
+ region := os.Getenv("AWS_REGION")
+ if region == "" {
+ region = os.Getenv("AWS_DEFAULT_REGION")
+ }
+
+ if region == "" {
+ return nil, errors.New("AWS_REGION or AWS_DEFAULT_REGION environment variable is required")
+ }
+ if len(region) < 2 {
+ return nil, errors.New("AWS_REGION or AWS_DEFAULT_REGION environment variable is invalid")
+ }
+ regionPrefix := region[:2]
+ provider.model.APIModel = fmt.Sprintf("%s.%s", regionPrefix, provider.model.APIModel)
+
+ if strings.Contains(string(provider.model.APIModel), "anthropic") {
+ anthropic, err := NewAnthropicProvider(
+ WithAnthropicModel(provider.model),
+ WithAnthropicMaxTokens(provider.maxTokens),
+ WithAnthropicSystemMessage(provider.systemMessage),
+ WithAnthropicBedrock(),
+ WithAnthropicDisableCache(),
+ )
+ provider.childProvider = anthropic
+ if err != nil {
+ return nil, err
+ }
+ } else {
+ return nil, errors.New("unsupported model for bedrock provider")
+ }
+ return provider, nil
+}
+
+type BedrockOption func(*bedrockProvider)
+
+func WithBedrockSystemMessage(message string) BedrockOption {
+ return func(a *bedrockProvider) {
+ a.systemMessage = message
+ }
+}
+
+func WithBedrockMaxTokens(maxTokens int64) BedrockOption {
+ return func(a *bedrockProvider) {
+ a.maxTokens = maxTokens
+ }
+}
+
+func WithBedrockModel(model models.Model) BedrockOption {
+ return func(a *bedrockProvider) {
+ a.model = model
+ }
+}