diff options
| author | Kujtim Hoxha <[email protected]> | 2025-03-24 11:47:39 +0100 |
|---|---|---|
| committer | Kujtim Hoxha <[email protected]> | 2025-03-24 11:47:39 +0100 |
| commit | 005b8ac16776512b2d4b1f22bd989da162ca1bad (patch) | |
| tree | cfe0d1da344ac31a467f1bea788ce80c723cd980 /internal/llm | |
| parent | e7258e38aeb46281fda474b8b7fcc3eee35edd9f (diff) | |
| download | opencode-005b8ac16776512b2d4b1f22bd989da162ca1bad.tar.gz opencode-005b8ac16776512b2d4b1f22bd989da162ca1bad.zip | |
initial working agent
Diffstat (limited to 'internal/llm')
| -rw-r--r-- | internal/llm/agent/title.go | 31 | ||||
| -rw-r--r-- | internal/llm/llm.go | 31 | ||||
| -rw-r--r-- | internal/llm/models/models.go | 48 |
3 files changed, 99 insertions, 11 deletions
diff --git a/internal/llm/agent/title.go b/internal/llm/agent/title.go new file mode 100644 index 000000000..1b9840cc2 --- /dev/null +++ b/internal/llm/agent/title.go @@ -0,0 +1,31 @@ +package agent + +import ( + "context" + + "github.com/cloudwego/eino/schema" + "github.com/kujtimiihoxha/termai/internal/llm/models" + "github.com/spf13/viper" +) + +func GenerateTitle(ctx context.Context, content string) (string, error) { + model, err := models.GetModel(ctx, models.ModelID(viper.GetString("models.small"))) + if err != nil { + return "", err + } + out, err := model.Generate( + ctx, + []*schema.Message{ + schema.SystemMessage(`- you will generate a short title based on the first message a user begins a conversation with + - ensure it is not more than 80 characters long + - the title should be a summary of the user's message + - do not use quotes or colons + - the entire text you return will be used as the title`), + schema.UserMessage(content), + }, + ) + if err != nil { + return "", err + } + return out.Content, nil +} diff --git a/internal/llm/llm.go b/internal/llm/llm.go index bbf9961ea..2f87b225e 100644 --- a/internal/llm/llm.go +++ b/internal/llm/llm.go @@ -11,6 +11,7 @@ import ( "github.com/cloudwego/eino/schema" "github.com/google/uuid" "github.com/kujtimiihoxha/termai/internal/llm/agent" + "github.com/kujtimiihoxha/termai/internal/llm/models" "github.com/kujtimiihoxha/termai/internal/logging" "github.com/kujtimiihoxha/termai/internal/message" "github.com/kujtimiihoxha/termai/internal/pubsub" @@ -88,7 +89,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) { } log.Printf("Request: %s", content) - agent, systemMessage, err := agent.GetAgent(s.ctx, viper.GetString("agents.default")) + currentAgent, systemMessage, err := agent.GetAgent(s.ctx, viper.GetString("agents.default")) if err != nil { s.Publish(AgentErrorEvent, AgentEvent{ ID: id, @@ -110,6 +111,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) { for _, m := range history { messages = append(messages, &m.MessageData) } + builder := callbacks.NewHandlerBuilder() builder.OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { i, ok := input.(*eModel.CallbackInput) @@ -140,7 +142,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) { return ctx }) - out, err := agent.Generate(s.ctx, messages, enioAgent.WithComposeOptions(compose.WithCallbacks(builder.Build()))) + out, err := currentAgent.Generate(s.ctx, messages, enioAgent.WithComposeOptions(compose.WithCallbacks(builder.Build()))) if err != nil { s.Publish(AgentErrorEvent, AgentEvent{ ID: id, @@ -153,6 +155,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) { return } usage := out.ResponseMeta.Usage + s.messages.Create(sessionID, *out) if usage != nil { log.Printf("Prompt Tokens: %d, Completion Tokens: %d, Total Tokens: %d", usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens) session, err := s.sessions.Get(sessionID) @@ -170,6 +173,29 @@ func (s *service) handleRequest(id string, sessionID string, content string) { session.PromptTokens += int64(usage.PromptTokens) session.CompletionTokens += int64(usage.CompletionTokens) // TODO: calculate cost + model := models.SupportedModels[models.ModelID(viper.GetString("models.big"))] + session.Cost += float64(usage.PromptTokens)*(model.CostPer1MIn/1_000_000) + + float64(usage.CompletionTokens)*(model.CostPer1MOut/1_000_000) + var newTitle string + if len(history) == 1 { + // first message generate the title + newTitle, err = agent.GenerateTitle(s.ctx, content) + if err != nil { + s.Publish(AgentErrorEvent, AgentEvent{ + ID: id, + Type: AgentMessageTypeError, + AgentID: RootAgent, + MessageID: "", + SessionID: sessionID, + Content: err.Error(), + }) + return + } + } + if newTitle != "" { + session.Title = newTitle + } + _, err = s.sessions.Save(session) if err != nil { s.Publish(AgentErrorEvent, AgentEvent{ @@ -183,7 +209,6 @@ func (s *service) handleRequest(id string, sessionID string, content string) { return } } - s.messages.Create(sessionID, *out) } func (s *service) SendRequest(sessionID string, content string) { diff --git a/internal/llm/models/models.go b/internal/llm/models/models.go index 1895e256b..e59da194b 100644 --- a/internal/llm/models/models.go +++ b/internal/llm/models/models.go @@ -3,6 +3,7 @@ package models import ( "context" "errors" + "log" "github.com/cloudwego/eino-ext/components/model/claude" "github.com/cloudwego/eino-ext/components/model/openai" @@ -16,10 +17,12 @@ type ( ) type Model struct { - ID ModelID `json:"id"` - Name string `json:"name"` - Provider ModelProvider `json:"provider"` - APIModel string `json:"api_model"` // Actual value used when calling the API + ID ModelID `json:"id"` + Name string `json:"name"` + Provider ModelProvider `json:"provider"` + APIModel string `json:"api_model"` + CostPer1MIn float64 `json:"cost_per_1m_in"` + CostPer1MOut float64 `json:"cost_per_1m_out"` } const ( @@ -52,6 +55,9 @@ const ( // Meta Llama3 ModelID = "llama-3" Llama270B ModelID = "llama-2-70b" + // GROQ + GroqLlama3SpecDec ModelID = "groq-llama-3-spec-dec" + GroqQwen32BCoder ModelID = "qwen-2.5-coder-32b" ) const ( @@ -61,6 +67,7 @@ const ( ProviderXAI ModelProvider = "xai" ProviderDeepSeek ModelProvider = "deepseek" ProviderMeta ModelProvider = "meta" + ProviderGroq ModelProvider = "groq" ) var SupportedModels = map[ModelID]Model{ @@ -72,10 +79,12 @@ var SupportedModels = map[ModelID]Model{ APIModel: "gpt-4o", }, GPT4oMini: { - ID: GPT4oMini, - Name: "GPT-4o Mini", - Provider: ProviderOpenAI, - APIModel: "gpt-4o-mini", + ID: GPT4oMini, + Name: "GPT-4o Mini", + Provider: ProviderOpenAI, + APIModel: "gpt-4o-mini", + CostPer1MIn: 0.150, + CostPer1MOut: 0.600, }, GPT45: { ID: GPT45, @@ -172,10 +181,25 @@ var SupportedModels = map[ModelID]Model{ Provider: ProviderMeta, APIModel: "llama-2-70b", }, + + // GROQ + GroqLlama3SpecDec: { + ID: GroqLlama3SpecDec, + Name: "GROQ LLaMA 3 SpecDec", + Provider: ProviderGroq, + APIModel: "llama-3.3-70b-specdec", + }, + GroqQwen32BCoder: { + ID: GroqQwen32BCoder, + Name: "GROQ Qwen 2.5 Coder 32B", + Provider: ProviderGroq, + APIModel: "qwen-2.5-coder-32b", + }, } func GetModel(ctx context.Context, model ModelID) (model.ChatModel, error) { provider := SupportedModels[model].Provider + log.Printf("Provider: %s", provider) maxTokens := viper.GetInt("providers.common.max_tokens") switch provider { case ProviderOpenAI: @@ -191,6 +215,14 @@ func GetModel(ctx context.Context, model ModelID) (model.ChatModel, error) { MaxTokens: maxTokens, }) + case ProviderGroq: + return openai.NewChatModel(ctx, &openai.ChatModelConfig{ + BaseURL: "https://api.groq.com/openai/v1", + APIKey: viper.GetString("providers.groq.key"), + Model: string(SupportedModels[model].APIModel), + MaxTokens: &maxTokens, + }) + } return nil, errors.New("unsupported provider") } |
