summary refs log tree commit diff homepage
path: root/internal/llm
diff options
context:
space:
mode:
author Kujtim Hoxha <[email protected]> 2025-03-24 11:47:39 +0100
committer Kujtim Hoxha <[email protected]> 2025-03-24 11:47:39 +0100
commit 005b8ac16776512b2d4b1f22bd989da162ca1bad (patch)
tree cfe0d1da344ac31a467f1bea788ce80c723cd980 /internal/llm
parent e7258e38aeb46281fda474b8b7fcc3eee35edd9f (diff)
download opencode-005b8ac16776512b2d4b1f22bd989da162ca1bad.tar.gz
opencode-005b8ac16776512b2d4b1f22bd989da162ca1bad.zip
initial working agent
Diffstat (limited to 'internal/llm')
-rw-r--r-- internal/llm/agent/title.go 31
-rw-r--r-- internal/llm/llm.go 31
-rw-r--r-- internal/llm/models/models.go 48
3 files changed, 99 insertions, 11 deletions
diff --git a/internal/llm/agent/title.go b/internal/llm/agent/title.go
new file mode 100644
index 000000000..1b9840cc2
--- /dev/null
+++ b/internal/llm/agent/title.go
@@ -0,0 +1,31 @@
+package agent
+
+import (
+ "context"
+
+ "github.com/cloudwego/eino/schema"
+ "github.com/kujtimiihoxha/termai/internal/llm/models"
+ "github.com/spf13/viper"
+)
+
+func GenerateTitle(ctx context.Context, content string) (string, error) {
+ model, err := models.GetModel(ctx, models.ModelID(viper.GetString("models.small")))
+ if err != nil {
+ return "", err
+ }
+ out, err := model.Generate(
+ ctx,
+ []*schema.Message{
+ schema.SystemMessage(`- you will generate a short title based on the first message a user begins a conversation with
+ - ensure it is not more than 80 characters long
+ - the title should be a summary of the user's message
+ - do not use quotes or colons
+ - the entire text you return will be used as the title`),
+ schema.UserMessage(content),
+ },
+ )
+ if err != nil {
+ return "", err
+ }
+ return out.Content, nil
+}
diff --git a/internal/llm/llm.go b/internal/llm/llm.go
index bbf9961ea..2f87b225e 100644
--- a/internal/llm/llm.go
+++ b/internal/llm/llm.go
@@ -11,6 +11,7 @@ import (
"github.com/cloudwego/eino/schema"
"github.com/google/uuid"
"github.com/kujtimiihoxha/termai/internal/llm/agent"
+ "github.com/kujtimiihoxha/termai/internal/llm/models"
"github.com/kujtimiihoxha/termai/internal/logging"
"github.com/kujtimiihoxha/termai/internal/message"
"github.com/kujtimiihoxha/termai/internal/pubsub"
@@ -88,7 +89,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
}
log.Printf("Request: %s", content)
- agent, systemMessage, err := agent.GetAgent(s.ctx, viper.GetString("agents.default"))
+ currentAgent, systemMessage, err := agent.GetAgent(s.ctx, viper.GetString("agents.default"))
if err != nil {
s.Publish(AgentErrorEvent, AgentEvent{
ID: id,
@@ -110,6 +111,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
for _, m := range history {
messages = append(messages, &m.MessageData)
}
+
builder := callbacks.NewHandlerBuilder()
builder.OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
i, ok := input.(*eModel.CallbackInput)
@@ -140,7 +142,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
return ctx
})
- out, err := agent.Generate(s.ctx, messages, enioAgent.WithComposeOptions(compose.WithCallbacks(builder.Build())))
+ out, err := currentAgent.Generate(s.ctx, messages, enioAgent.WithComposeOptions(compose.WithCallbacks(builder.Build())))
if err != nil {
s.Publish(AgentErrorEvent, AgentEvent{
ID: id,
@@ -153,6 +155,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
return
}
usage := out.ResponseMeta.Usage
+ s.messages.Create(sessionID, *out)
if usage != nil {
log.Printf("Prompt Tokens: %d, Completion Tokens: %d, Total Tokens: %d", usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens)
session, err := s.sessions.Get(sessionID)
@@ -170,6 +173,29 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
session.PromptTokens += int64(usage.PromptTokens)
session.CompletionTokens += int64(usage.CompletionTokens)
// TODO: calculate cost
+ model := models.SupportedModels[models.ModelID(viper.GetString("models.big"))]
+ session.Cost += float64(usage.PromptTokens)*(model.CostPer1MIn/1_000_000) +
+ float64(usage.CompletionTokens)*(model.CostPer1MOut/1_000_000)
+ var newTitle string
+ if len(history) == 1 {
+ // first message generate the title
+ newTitle, err = agent.GenerateTitle(s.ctx, content)
+ if err != nil {
+ s.Publish(AgentErrorEvent, AgentEvent{
+ ID: id,
+ Type: AgentMessageTypeError,
+ AgentID: RootAgent,
+ MessageID: "",
+ SessionID: sessionID,
+ Content: err.Error(),
+ })
+ return
+ }
+ }
+ if newTitle != "" {
+ session.Title = newTitle
+ }
+
_, err = s.sessions.Save(session)
if err != nil {
s.Publish(AgentErrorEvent, AgentEvent{
@@ -183,7 +209,6 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
return
}
}
- s.messages.Create(sessionID, *out)
}
func (s *service) SendRequest(sessionID string, content string) {
diff --git a/internal/llm/models/models.go b/internal/llm/models/models.go
index 1895e256b..e59da194b 100644
--- a/internal/llm/models/models.go
+++ b/internal/llm/models/models.go
@@ -3,6 +3,7 @@ package models
import (
"context"
"errors"
+ "log"
"github.com/cloudwego/eino-ext/components/model/claude"
"github.com/cloudwego/eino-ext/components/model/openai"
@@ -16,10 +17,12 @@ type (
)
type Model struct {
- ID ModelID `json:"id"`
- Name string `json:"name"`
- Provider ModelProvider `json:"provider"`
- APIModel string `json:"api_model"` // Actual value used when calling the API
+ ID ModelID `json:"id"`
+ Name string `json:"name"`
+ Provider ModelProvider `json:"provider"`
+ APIModel string `json:"api_model"`
+ CostPer1MIn float64 `json:"cost_per_1m_in"`
+ CostPer1MOut float64 `json:"cost_per_1m_out"`
}
const (
@@ -52,6 +55,9 @@ const (
// Meta
Llama3 ModelID = "llama-3"
Llama270B ModelID = "llama-2-70b"
+ // GROQ
+ GroqLlama3SpecDec ModelID = "groq-llama-3-spec-dec"
+ GroqQwen32BCoder ModelID = "qwen-2.5-coder-32b"
)
const (
@@ -61,6 +67,7 @@ const (
ProviderXAI ModelProvider = "xai"
ProviderDeepSeek ModelProvider = "deepseek"
ProviderMeta ModelProvider = "meta"
+ ProviderGroq ModelProvider = "groq"
)
var SupportedModels = map[ModelID]Model{
@@ -72,10 +79,12 @@ var SupportedModels = map[ModelID]Model{
APIModel: "gpt-4o",
},
GPT4oMini: {
- ID: GPT4oMini,
- Name: "GPT-4o Mini",
- Provider: ProviderOpenAI,
- APIModel: "gpt-4o-mini",
+ ID: GPT4oMini,
+ Name: "GPT-4o Mini",
+ Provider: ProviderOpenAI,
+ APIModel: "gpt-4o-mini",
+ CostPer1MIn: 0.150,
+ CostPer1MOut: 0.600,
},
GPT45: {
ID: GPT45,
@@ -172,10 +181,25 @@ var SupportedModels = map[ModelID]Model{
Provider: ProviderMeta,
APIModel: "llama-2-70b",
},
+
+ // GROQ
+ GroqLlama3SpecDec: {
+ ID: GroqLlama3SpecDec,
+ Name: "GROQ LLaMA 3 SpecDec",
+ Provider: ProviderGroq,
+ APIModel: "llama-3.3-70b-specdec",
+ },
+ GroqQwen32BCoder: {
+ ID: GroqQwen32BCoder,
+ Name: "GROQ Qwen 2.5 Coder 32B",
+ Provider: ProviderGroq,
+ APIModel: "qwen-2.5-coder-32b",
+ },
}
func GetModel(ctx context.Context, model ModelID) (model.ChatModel, error) {
provider := SupportedModels[model].Provider
+ log.Printf("Provider: %s", provider)
maxTokens := viper.GetInt("providers.common.max_tokens")
switch provider {
case ProviderOpenAI:
@@ -191,6 +215,14 @@ func GetModel(ctx context.Context, model ModelID) (model.ChatModel, error) {
MaxTokens: maxTokens,
})
+ case ProviderGroq:
+ return openai.NewChatModel(ctx, &openai.ChatModelConfig{
+ BaseURL: "https://api.groq.com/openai/v1",
+ APIKey: viper.GetString("providers.groq.key"),
+ Model: string(SupportedModels[model].APIModel),
+ MaxTokens: &maxTokens,
+ })
+
}
return nil, errors.New("unsupported provider")
}