summaryrefslogtreecommitdiffhomepage
path: root/internal/llm
diff options
context:
space:
mode:
Diffstat (limited to 'internal/llm')
-rw-r--r--internal/llm/agent/agent-tool.go18
-rw-r--r--internal/llm/agent/agent.go861
-rw-r--r--internal/llm/agent/coder.go63
-rw-r--r--internal/llm/agent/mcp-tools.go4
-rw-r--r--internal/llm/agent/task.go47
-rw-r--r--internal/llm/agent/tools.go50
-rw-r--r--internal/llm/models/anthropic.go71
-rw-r--r--internal/llm/models/models.go190
-rw-r--r--internal/llm/prompt/coder.go28
-rw-r--r--internal/llm/prompt/prompt.go19
-rw-r--r--internal/llm/prompt/task.go5
-rw-r--r--internal/llm/prompt/title.go4
-rw-r--r--internal/llm/provider/anthropic.go531
-rw-r--r--internal/llm/provider/bedrock.go101
-rw-r--r--internal/llm/provider/gemini.go533
-rw-r--r--internal/llm/provider/openai.go401
-rw-r--r--internal/llm/provider/provider.go169
-rw-r--r--internal/llm/tools/bash.go7
-rw-r--r--internal/llm/tools/bash_test.go31
-rw-r--r--internal/llm/tools/edit.go75
-rw-r--r--internal/llm/tools/edit_test.go30
-rw-r--r--internal/llm/tools/file.go10
-rw-r--r--internal/llm/tools/glob.go4
-rw-r--r--internal/llm/tools/grep.go4
-rw-r--r--internal/llm/tools/ls.go4
-rw-r--r--internal/llm/tools/mocks_test.go246
-rw-r--r--internal/llm/tools/shell/shell.go12
-rw-r--r--internal/llm/tools/sourcegraph.go2
-rw-r--r--internal/llm/tools/tools.go9
-rw-r--r--internal/llm/tools/write.go27
-rw-r--r--internal/llm/tools/write_test.go22
31 files changed, 2037 insertions, 1541 deletions
diff --git a/internal/llm/agent/agent-tool.go b/internal/llm/agent/agent-tool.go
index 83160bb64..308412bde 100644
--- a/internal/llm/agent/agent-tool.go
+++ b/internal/llm/agent/agent-tool.go
@@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
+ "github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
"github.com/kujtimiihoxha/termai/internal/lsp"
"github.com/kujtimiihoxha/termai/internal/message"
@@ -53,7 +54,7 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes
return tools.ToolResponse{}, fmt.Errorf("session_id and message_id are required")
}
- agent, err := NewTaskAgent(b.messages, b.sessions, b.lspClients)
+ agent, err := NewAgent(config.AgentTask, b.sessions, b.messages, TaskAgentTools(b.lspClients))
if err != nil {
return tools.ToolResponse{}, fmt.Errorf("error creating agent: %s", err)
}
@@ -63,21 +64,16 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes
return tools.ToolResponse{}, fmt.Errorf("error creating session: %s", err)
}
- err = agent.Generate(ctx, session.ID, params.Prompt)
+ done, err := agent.Run(ctx, session.ID, params.Prompt)
if err != nil {
return tools.ToolResponse{}, fmt.Errorf("error generating agent: %s", err)
}
-
- messages, err := b.messages.List(ctx, session.ID)
- if err != nil {
- return tools.ToolResponse{}, fmt.Errorf("error listing messages: %s", err)
- }
-
- if len(messages) == 0 {
- return tools.NewTextErrorResponse("no response"), nil
+ result := <-done
+ if result.Err() != nil {
+ return tools.ToolResponse{}, fmt.Errorf("error generating agent: %s", result.Err())
}
- response := messages[len(messages)-1]
+ response := result.Response()
if response.Role != message.Assistant {
return tools.NewTextErrorResponse("no response"), nil
}
diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go
index 1958111a1..ab2742ec1 100644
--- a/internal/llm/agent/agent.go
+++ b/internal/llm/agent/agent.go
@@ -4,8 +4,6 @@ import (
"context"
"errors"
"fmt"
- "os"
- "runtime/debug"
"strings"
"sync"
@@ -16,133 +14,101 @@ import (
"github.com/kujtimiihoxha/termai/internal/llm/tools"
"github.com/kujtimiihoxha/termai/internal/logging"
"github.com/kujtimiihoxha/termai/internal/message"
+ "github.com/kujtimiihoxha/termai/internal/permission"
"github.com/kujtimiihoxha/termai/internal/session"
)
// Common errors
var (
- ErrProviderNotEnabled = errors.New("provider is not enabled")
- ErrRequestCancelled = errors.New("request cancelled by user")
- ErrSessionBusy = errors.New("session is currently processing another request")
+ ErrRequestCancelled = errors.New("request cancelled by user")
+ ErrSessionBusy = errors.New("session is currently processing another request")
)
-// Service defines the interface for generating responses
+type AgentEvent struct {
+ message message.Message
+ err error
+}
+
+func (e *AgentEvent) Err() error {
+ return e.err
+}
+
+func (e *AgentEvent) Response() message.Message {
+ return e.message
+}
+
type Service interface {
- Generate(ctx context.Context, sessionID string, content string) error
- Cancel(sessionID string) error
+ Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error)
+ Cancel(sessionID string)
+ IsSessionBusy(sessionID string) bool
}
type agent struct {
- sessions session.Service
- messages message.Service
- model models.Model
- tools []tools.BaseTool
- agent provider.Provider
- titleGenerator provider.Provider
- activeRequests sync.Map // map[sessionID]context.CancelFunc
+ sessions session.Service
+ messages message.Service
+
+ tools []tools.BaseTool
+ provider provider.Provider
+
+ titleProvider provider.Provider
+
+ activeRequests sync.Map
}
-// NewAgent creates a new agent instance with the given model and tools
-func NewAgent(ctx context.Context, sessions session.Service, messages message.Service, model models.Model, tools []tools.BaseTool) (Service, error) {
- agentProvider, titleGenerator, err := getAgentProviders(ctx, model)
+func NewAgent(
+ agentName config.AgentName,
+ sessions session.Service,
+ messages message.Service,
+ agentTools []tools.BaseTool,
+) (Service, error) {
+ agentProvider, err := createAgentProvider(agentName)
if err != nil {
- return nil, fmt.Errorf("failed to initialize providers: %w", err)
+ return nil, err
+ }
+ var titleProvider provider.Provider
+ // Only generate titles for the coder agent
+ if agentName == config.AgentCoder {
+ titleProvider, err = createAgentProvider(config.AgentTitle)
+ if err != nil {
+ return nil, err
+ }
}
- return &agent{
- model: model,
- tools: tools,
- sessions: sessions,
+ agent := &agent{
+ provider: agentProvider,
messages: messages,
- agent: agentProvider,
- titleGenerator: titleGenerator,
+ sessions: sessions,
+ tools: agentTools,
+ titleProvider: titleProvider,
activeRequests: sync.Map{},
- }, nil
+ }
+
+ return agent, nil
}
-// Cancel cancels an active request by session ID
-func (a *agent) Cancel(sessionID string) error {
+func (a *agent) Cancel(sessionID string) {
if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
if cancel, ok := cancelFunc.(context.CancelFunc); ok {
logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
cancel()
- return nil
}
}
- return errors.New("no active request found for this session")
}
-// Generate starts the generation process
-func (a *agent) Generate(ctx context.Context, sessionID string, content string) error {
- // Check if this session already has an active request
- if _, busy := a.activeRequests.Load(sessionID); busy {
- return ErrSessionBusy
- }
-
- // Create a cancellable context
- genCtx, cancel := context.WithCancel(ctx)
-
- // Store cancel function to allow user cancellation
- a.activeRequests.Store(sessionID, cancel)
-
- // Launch the generation in a goroutine
- go func() {
- defer func() {
- if r := recover(); r != nil {
- logging.ErrorPersist(fmt.Sprintf("Panic in Generate: %v", r))
-
- // dump stack trace into a file
- file, err := os.Create("panic.log")
- if err != nil {
- logging.ErrorPersist(fmt.Sprintf("Failed to create panic log: %v", err))
- return
- }
-
- defer file.Close()
-
- stackTrace := debug.Stack()
- if _, err := file.Write(stackTrace); err != nil {
- logging.ErrorPersist(fmt.Sprintf("Failed to write panic log: %v", err))
- }
-
- }
- }()
- defer a.activeRequests.Delete(sessionID)
- defer cancel()
-
- if err := a.generate(genCtx, sessionID, content); err != nil {
- if !errors.Is(err, ErrRequestCancelled) && !errors.Is(err, context.Canceled) {
- // Log the error (avoid logging cancellations as they're expected)
- logging.ErrorPersist(fmt.Sprintf("Generation error for session %s: %v", sessionID, err))
-
- // You may want to create an error message in the chat
- bgCtx := context.Background()
- errorMsg := fmt.Sprintf("Sorry, an error occurred: %v", err)
- _, createErr := a.messages.Create(bgCtx, sessionID, message.CreateMessageParams{
- Role: message.System,
- Parts: []message.ContentPart{
- message.TextContent{
- Text: errorMsg,
- },
- },
- })
- if createErr != nil {
- logging.ErrorPersist(fmt.Sprintf("Failed to create error message: %v", createErr))
- }
- }
- }
- }()
-
- return nil
-}
-
-// IsSessionBusy checks if a session currently has an active request
func (a *agent) IsSessionBusy(sessionID string) bool {
_, busy := a.activeRequests.Load(sessionID)
return busy
-} // handleTitleGeneration asynchronously generates a title for new sessions
-func (a *agent) handleTitleGeneration(ctx context.Context, sessionID, content string) {
- response, err := a.titleGenerator.SendMessages(
+}
+
+func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
+ if a.titleProvider == nil {
+ return nil
+ }
+ session, err := a.sessions.Get(ctx, sessionID)
+ if err != nil {
+ return err
+ }
+ response, err := a.titleProvider.SendMessages(
ctx,
[]message.Message{
{
@@ -154,121 +120,152 @@ func (a *agent) handleTitleGeneration(ctx context.Context, sessionID, content st
},
},
},
- nil,
+ make([]tools.BaseTool, 0),
)
if err != nil {
- logging.ErrorPersist(fmt.Sprintf("Failed to generate title: %v", err))
- return
+ return err
}
- session, err := a.sessions.Get(ctx, sessionID)
- if err != nil {
- logging.ErrorPersist(fmt.Sprintf("Failed to get session: %v", err))
- return
+ title := strings.TrimSpace(strings.ReplaceAll(response.Content, "\n", " "))
+ if title == "" {
+ return nil
}
- if response.Content != "" {
- session.Title = strings.TrimSpace(response.Content)
- session.Title = strings.ReplaceAll(session.Title, "\n", " ")
- if _, err := a.sessions.Save(ctx, session); err != nil {
- logging.ErrorPersist(fmt.Sprintf("Failed to save session title: %v", err))
- }
+ session.Title = title
+ _, err = a.sessions.Save(ctx, session)
+ return err
+}
+
+func (a *agent) err(err error) AgentEvent {
+ return AgentEvent{
+ err: err,
}
}
-// TrackUsage updates token usage statistics for the session
-func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
- session, err := a.sessions.Get(ctx, sessionID)
- if err != nil {
- return fmt.Errorf("failed to get session: %w", err)
+func (a *agent) Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error) {
+ events := make(chan AgentEvent)
+ if a.IsSessionBusy(sessionID) {
+ return nil, ErrSessionBusy
}
- cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
- model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
- model.CostPer1MIn/1e6*float64(usage.InputTokens) +
- model.CostPer1MOut/1e6*float64(usage.OutputTokens)
+ genCtx, cancel := context.WithCancel(ctx)
+
+ a.activeRequests.Store(sessionID, cancel)
+ go func() {
+ logging.Debug("Request started", "sessionID", sessionID)
+ defer logging.RecoverPanic("agent.Run", func() {
+ events <- a.err(fmt.Errorf("panic while running the agent"))
+ })
- session.Cost += cost
- session.CompletionTokens += usage.OutputTokens
- session.PromptTokens += usage.InputTokens
+ result := a.processGeneration(genCtx, sessionID, content)
+ if result.Err() != nil && !errors.Is(result.Err(), ErrRequestCancelled) && !errors.Is(result.Err(), context.Canceled) {
+ logging.ErrorPersist(fmt.Sprintf("Generation error for session %s: %v", sessionID, result))
+ }
+ logging.Debug("Request completed", "sessionID", sessionID)
+ a.activeRequests.Delete(sessionID)
+ cancel()
+ events <- result
+ close(events)
+ }()
+ return events, nil
+}
- _, err = a.sessions.Save(ctx, session)
+func (a *agent) processGeneration(ctx context.Context, sessionID, content string) AgentEvent {
+ // List existing messages; if none, start title generation asynchronously.
+ msgs, err := a.messages.List(ctx, sessionID)
if err != nil {
- return fmt.Errorf("failed to save session: %w", err)
+ return a.err(fmt.Errorf("failed to list messages: %w", err))
+ }
+ if len(msgs) == 0 {
+ go func() {
+ defer logging.RecoverPanic("agent.Run", func() {
+ logging.ErrorPersist("panic while generating title")
+ })
+ titleErr := a.generateTitle(context.Background(), sessionID, content)
+ if titleErr != nil {
+ logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
+ }
+ }()
}
- return nil
-}
-// processEvent handles different types of events during generation
-func (a *agent) processEvent(
- ctx context.Context,
- sessionID string,
- assistantMsg *message.Message,
- event provider.ProviderEvent,
-) error {
- select {
- case <-ctx.Done():
- return ctx.Err()
- default:
- // Continue processing
+ userMsg, err := a.createUserMessage(ctx, sessionID, content)
+ if err != nil {
+ return a.err(fmt.Errorf("failed to create user message: %w", err))
}
- switch event.Type {
- case provider.EventThinkingDelta:
- assistantMsg.AppendReasoningContent(event.Content)
- return a.messages.Update(ctx, *assistantMsg)
- case provider.EventContentDelta:
- assistantMsg.AppendContent(event.Content)
- return a.messages.Update(ctx, *assistantMsg)
- case provider.EventError:
- if errors.Is(event.Error, context.Canceled) {
- logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
- return context.Canceled
+ // Append the new user message to the conversation history.
+ msgHistory := append(msgs, userMsg)
+ for {
+ // Check for cancellation before each iteration
+ select {
+ case <-ctx.Done():
+ return a.err(ctx.Err())
+ default:
+ // Continue processing
}
- logging.ErrorPersist(event.Error.Error())
- return event.Error
- case provider.EventWarning:
- logging.WarnPersist(event.Info)
- case provider.EventInfo:
- logging.InfoPersist(event.Info)
- case provider.EventComplete:
- assistantMsg.SetToolCalls(event.Response.ToolCalls)
- assistantMsg.AddFinish(event.Response.FinishReason)
- if err := a.messages.Update(ctx, *assistantMsg); err != nil {
- return fmt.Errorf("failed to update message: %w", err)
+ agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
+ if err != nil {
+ if errors.Is(err, context.Canceled) {
+ return a.err(ErrRequestCancelled)
+ }
+ return a.err(fmt.Errorf("failed to process events: %w", err))
+ }
+ logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
+ if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
+ // We are not done, we need to respond with the tool response
+ msgHistory = append(msgHistory, agentMessage, *toolResults)
+ continue
+ }
+ return AgentEvent{
+ message: agentMessage,
}
- return a.TrackUsage(ctx, sessionID, a.model, event.Response.Usage)
}
+}
- return nil
+func (a *agent) createUserMessage(ctx context.Context, sessionID, content string) (message.Message, error) {
+ return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
+ Role: message.User,
+ Parts: []message.ContentPart{
+ message.TextContent{Text: content},
+ },
+ })
}
-// ExecuteTools runs all tool calls sequentially and returns the results
-func (a *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) {
- toolResults := make([]message.ToolResult, len(toolCalls))
+func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
+ eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
+
+ assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
+ Role: message.Assistant,
+ Parts: []message.ContentPart{},
+ Model: a.provider.Model().ID,
+ })
+ if err != nil {
+ return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
+ }
- // Create a child context that can be canceled
- ctx, cancel := context.WithCancel(ctx)
- defer cancel()
+ // Add the session and message ID into the context if needed by tools.
+ ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
+ ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
- // Check if already canceled before starting any execution
- if ctx.Err() != nil {
- // Mark all tools as canceled
- for i, toolCall := range toolCalls {
- toolResults[i] = message.ToolResult{
- ToolCallID: toolCall.ID,
- Content: "Tool execution canceled by user",
- IsError: true,
- }
+ // Process each event in the stream.
+ for event := range eventChan {
+ if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
+ a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
+ return assistantMsg, nil, processErr
+ }
+ if ctx.Err() != nil {
+ a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
+ return assistantMsg, nil, ctx.Err()
}
- return toolResults, ctx.Err()
}
+ toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
+ toolCalls := assistantMsg.ToolCalls()
for i, toolCall := range toolCalls {
- // Check for cancellation before executing each tool
select {
case <-ctx.Done():
- // Mark this and all remaining tools as canceled
+ a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
+ // Make all future tool calls cancelled
for j := i; j < len(toolCalls); j++ {
toolResults[j] = message.ToolResult{
ToolCallID: toolCalls[j].ID,
@@ -276,412 +273,180 @@ func (a *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall,
IsError: true,
}
}
- return toolResults, ctx.Err()
+ goto out
default:
// Continue processing
- }
-
- response := ""
- isError := false
- found := false
-
- // Find and execute the appropriate tool
- for _, tool := range tls {
- if tool.Info().Name == toolCall.Name {
- found = true
- toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
- ID: toolCall.ID,
- Name: toolCall.Name,
- Input: toolCall.Input,
- })
-
- if toolErr != nil {
- if errors.Is(toolErr, context.Canceled) {
- response = "Tool execution canceled by user"
- } else {
- response = fmt.Sprintf("Error running tool: %s", toolErr)
- }
- isError = true
- } else {
- response = toolResult.Content
- isError = toolResult.IsError
+ var tool tools.BaseTool
+ for _, availableTools := range a.tools {
+ if availableTools.Info().Name == toolCall.Name {
+ tool = availableTools
}
- break
}
- }
-
- if !found {
- response = fmt.Sprintf("Tool not found: %s", toolCall.Name)
- isError = true
- }
-
- toolResults[i] = message.ToolResult{
- ToolCallID: toolCall.ID,
- Content: response,
- IsError: isError,
- }
- }
- return toolResults, nil
-}
-
-// handleToolExecution processes tool calls and creates tool result messages
-func (a *agent) handleToolExecution(
- ctx context.Context,
- assistantMsg message.Message,
-) (*message.Message, error) {
- select {
- case <-ctx.Done():
- // If cancelled, create tool results that indicate cancellation
- if len(assistantMsg.ToolCalls()) > 0 {
- toolResults := make([]message.ToolResult, 0, len(assistantMsg.ToolCalls()))
- for _, tc := range assistantMsg.ToolCalls() {
- toolResults = append(toolResults, message.ToolResult{
- ToolCallID: tc.ID,
- Content: "Tool execution canceled by user",
+ // Tool not found
+ if tool == nil {
+ toolResults[i] = message.ToolResult{
+ ToolCallID: toolCall.ID,
+ Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
IsError: true,
- })
+ }
+ continue
}
- // Use background context to ensure the message is created even if original context is cancelled
- bgCtx := context.Background()
- parts := make([]message.ContentPart, 0)
- for _, toolResult := range toolResults {
- parts = append(parts, toolResult)
- }
- msg, err := a.messages.Create(bgCtx, assistantMsg.SessionID, message.CreateMessageParams{
- Role: message.Tool,
- Parts: parts,
+ toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
+ ID: toolCall.ID,
+ Name: toolCall.Name,
+ Input: toolCall.Input,
})
- if err != nil {
- return nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
- }
- return &msg, ctx.Err()
- }
- return nil, ctx.Err()
- default:
- // Continue processing
- }
-
- if len(assistantMsg.ToolCalls()) == 0 {
- return nil, nil
- }
-
- toolResults, err := a.ExecuteTools(ctx, assistantMsg.ToolCalls(), a.tools)
- if err != nil {
- // If error is from cancellation, still return the partial results we have
- if errors.Is(err, context.Canceled) {
- // Use background context to ensure the message is created even if original context is cancelled
- bgCtx := context.Background()
- parts := make([]message.ContentPart, 0)
- for _, toolResult := range toolResults {
- parts = append(parts, toolResult)
+ if toolErr != nil {
+ if errors.Is(toolErr, permission.ErrorPermissionDenied) {
+ toolResults[i] = message.ToolResult{
+ ToolCallID: toolCall.ID,
+ Content: "Permission denied",
+ IsError: true,
+ }
+ for j := i + 1; j < len(toolCalls); j++ {
+ toolResults[j] = message.ToolResult{
+ ToolCallID: toolCalls[j].ID,
+ Content: "Tool execution canceled by user",
+ IsError: true,
+ }
+ }
+ a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
+ } else {
+ toolResults[i] = message.ToolResult{
+ ToolCallID: toolCall.ID,
+ Content: toolErr.Error(),
+ IsError: true,
+ }
+ for j := i; j < len(toolCalls); j++ {
+ toolResults[j] = message.ToolResult{
+ ToolCallID: toolCalls[j].ID,
+ Content: "Previous tool failed",
+ IsError: true,
+ }
+ }
+ a.finishMessage(ctx, &assistantMsg, message.FinishReasonError)
+ }
+ // If permission is denied or an error happens we cancel all the following tools
+ break
}
-
- msg, createErr := a.messages.Create(bgCtx, assistantMsg.SessionID, message.CreateMessageParams{
- Role: message.Tool,
- Parts: parts,
- })
- if createErr != nil {
- logging.ErrorPersist(fmt.Sprintf("Failed to create tool message after cancellation: %v", createErr))
- return nil, err
+ toolResults[i] = message.ToolResult{
+ ToolCallID: toolCall.ID,
+ Content: toolResult.Content,
+ Metadata: toolResult.Metadata,
+ IsError: toolResult.IsError,
}
- return &msg, err
}
- return nil, err
}
-
- parts := make([]message.ContentPart, 0, len(toolResults))
- for _, toolResult := range toolResults {
- parts = append(parts, toolResult)
+out:
+ if len(toolResults) == 0 {
+ return assistantMsg, nil, nil
}
-
- msg, err := a.messages.Create(ctx, assistantMsg.SessionID, message.CreateMessageParams{
+ parts := make([]message.ContentPart, 0)
+ for _, tr := range toolResults {
+ parts = append(parts, tr)
+ }
+ msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
Role: message.Tool,
Parts: parts,
})
if err != nil {
- return nil, fmt.Errorf("failed to create tool message: %w", err)
+ return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
}
- return &msg, nil
+ return assistantMsg, &msg, err
}
-// generate handles the main generation workflow
-func (a *agent) generate(ctx context.Context, sessionID string, content string) error {
- ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
+func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
+ msg.AddFinish(finishReson)
+ _ = a.messages.Update(ctx, *msg)
+}
- // Handle context cancellation at any point
- if err := ctx.Err(); err != nil {
- return ErrRequestCancelled
+func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ default:
+ // Continue processing.
}
- messages, err := a.messages.List(ctx, sessionID)
- if err != nil {
- return fmt.Errorf("failed to list messages: %w", err)
+ switch event.Type {
+ case provider.EventThinkingDelta:
+ assistantMsg.AppendReasoningContent(event.Content)
+ return a.messages.Update(ctx, *assistantMsg)
+ case provider.EventContentDelta:
+ assistantMsg.AppendContent(event.Content)
+ return a.messages.Update(ctx, *assistantMsg)
+ case provider.EventError:
+ if errors.Is(event.Error, context.Canceled) {
+ logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
+ return context.Canceled
+ }
+ logging.ErrorPersist(event.Error.Error())
+ return event.Error
+ case provider.EventComplete:
+ assistantMsg.SetToolCalls(event.Response.ToolCalls)
+ assistantMsg.AddFinish(event.Response.FinishReason)
+ if err := a.messages.Update(ctx, *assistantMsg); err != nil {
+ return fmt.Errorf("failed to update message: %w", err)
+ }
+ return a.TrackUsage(ctx, sessionID, a.provider.Model(), event.Response.Usage)
}
- if len(messages) == 0 {
- titleCtx := context.Background()
- go a.handleTitleGeneration(titleCtx, sessionID, content)
- }
+ return nil
+}
- userMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
- Role: message.User,
- Parts: []message.ContentPart{
- message.TextContent{
- Text: content,
- },
- },
- })
+func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
+ sess, err := a.sessions.Get(ctx, sessionID)
if err != nil {
- return fmt.Errorf("failed to create user message: %w", err)
+ return fmt.Errorf("failed to get session: %w", err)
}
- messages = append(messages, userMsg)
-
- for {
- // Check for cancellation before each iteration
- select {
- case <-ctx.Done():
- return ErrRequestCancelled
- default:
- // Continue processing
- }
-
- eventChan, err := a.agent.StreamResponse(ctx, messages, a.tools)
- if err != nil {
- if errors.Is(err, context.Canceled) {
- return ErrRequestCancelled
- }
- return fmt.Errorf("failed to stream response: %w", err)
- }
-
- assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
- Role: message.Assistant,
- Parts: []message.ContentPart{},
- Model: a.model.ID,
- })
- if err != nil {
- return fmt.Errorf("failed to create assistant message: %w", err)
- }
-
- ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
-
- // Process events from the LLM provider
- for event := range eventChan {
- if err := a.processEvent(ctx, sessionID, &assistantMsg, event); err != nil {
- if errors.Is(err, context.Canceled) {
- // Mark as canceled but don't create separate message
- assistantMsg.AddFinish("canceled")
- _ = a.messages.Update(context.Background(), assistantMsg)
- return ErrRequestCancelled
- }
- assistantMsg.AddFinish("error:" + err.Error())
- _ = a.messages.Update(ctx, assistantMsg)
- return fmt.Errorf("event processing error: %w", err)
- }
-
- // Check for cancellation during event processing
- select {
- case <-ctx.Done():
- // Mark as canceled
- assistantMsg.AddFinish("canceled")
- _ = a.messages.Update(context.Background(), assistantMsg)
- return ErrRequestCancelled
- default:
- }
- }
-
- // Check for cancellation before tool execution
- select {
- case <-ctx.Done():
- assistantMsg.AddFinish("canceled_by_user")
- _ = a.messages.Update(context.Background(), assistantMsg)
- return ErrRequestCancelled
- default:
- }
-
- // Execute any tool calls
- toolMsg, err := a.handleToolExecution(ctx, assistantMsg)
- if err != nil {
- if errors.Is(err, context.Canceled) {
- assistantMsg.AddFinish("canceled_by_user")
- _ = a.messages.Update(context.Background(), assistantMsg)
- return ErrRequestCancelled
- }
- return fmt.Errorf("tool execution error: %w", err)
- }
-
- if err := a.messages.Update(ctx, assistantMsg); err != nil {
- return fmt.Errorf("failed to update assistant message: %w", err)
- }
-
- // If no tool calls, we're done
- if len(assistantMsg.ToolCalls()) == 0 {
- break
- }
+ cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
+ model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
+ model.CostPer1MIn/1e6*float64(usage.InputTokens) +
+ model.CostPer1MOut/1e6*float64(usage.OutputTokens)
- // Add messages for next iteration
- messages = append(messages, assistantMsg)
- if toolMsg != nil {
- messages = append(messages, *toolMsg)
- }
+ sess.Cost += cost
+ sess.CompletionTokens += usage.OutputTokens
+ sess.PromptTokens += usage.InputTokens
- // Check for cancellation after tool execution
- select {
- case <-ctx.Done():
- return ErrRequestCancelled
- default:
- }
+ _, err = a.sessions.Save(ctx, sess)
+ if err != nil {
+ return fmt.Errorf("failed to save session: %w", err)
}
-
return nil
}
-// getAgentProviders initializes the LLM providers based on the chosen model
-func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) {
- maxTokens := config.Get().Model.CoderMaxTokens
-
- providerConfig, ok := config.Get().Providers[model.Provider]
- if !ok || providerConfig.Disabled {
- return nil, nil, ErrProviderNotEnabled
+func createAgentProvider(agentName config.AgentName) (provider.Provider, error) {
+ cfg := config.Get()
+ agentConfig, ok := cfg.Agents[agentName]
+ if !ok {
+ return nil, fmt.Errorf("agent %s not found", agentName)
+ }
+ model, ok := models.SupportedModels[agentConfig.Model]
+ if !ok {
+ return nil, fmt.Errorf("model %s not supported", agentConfig.Model)
}
- var agentProvider provider.Provider
- var titleGenerator provider.Provider
- var err error
-
- switch model.Provider {
- case models.ProviderOpenAI:
- agentProvider, err = provider.NewOpenAIProvider(
- provider.WithOpenAISystemMessage(
- prompt.CoderOpenAISystemPrompt(),
- ),
- provider.WithOpenAIMaxTokens(maxTokens),
- provider.WithOpenAIModel(model),
- provider.WithOpenAIKey(providerConfig.APIKey),
- )
- if err != nil {
- return nil, nil, fmt.Errorf("failed to create OpenAI agent provider: %w", err)
- }
-
- titleGenerator, err = provider.NewOpenAIProvider(
- provider.WithOpenAISystemMessage(
- prompt.TitlePrompt(),
- ),
- provider.WithOpenAIMaxTokens(80),
- provider.WithOpenAIModel(model),
- provider.WithOpenAIKey(providerConfig.APIKey),
- )
- if err != nil {
- return nil, nil, fmt.Errorf("failed to create OpenAI title generator: %w", err)
- }
-
- case models.ProviderAnthropic:
- agentProvider, err = provider.NewAnthropicProvider(
- provider.WithAnthropicSystemMessage(
- prompt.CoderAnthropicSystemPrompt(),
- ),
- provider.WithAnthropicMaxTokens(maxTokens),
- provider.WithAnthropicKey(providerConfig.APIKey),
- provider.WithAnthropicModel(model),
- )
- if err != nil {
- return nil, nil, fmt.Errorf("failed to create Anthropic agent provider: %w", err)
- }
-
- titleGenerator, err = provider.NewAnthropicProvider(
- provider.WithAnthropicSystemMessage(
- prompt.TitlePrompt(),
- ),
- provider.WithAnthropicMaxTokens(80),
- provider.WithAnthropicKey(providerConfig.APIKey),
- provider.WithAnthropicModel(model),
- )
- if err != nil {
- return nil, nil, fmt.Errorf("failed to create Anthropic title generator: %w", err)
- }
-
- case models.ProviderGemini:
- agentProvider, err = provider.NewGeminiProvider(
- ctx,
- provider.WithGeminiSystemMessage(
- prompt.CoderOpenAISystemPrompt(),
- ),
- provider.WithGeminiMaxTokens(int32(maxTokens)),
- provider.WithGeminiKey(providerConfig.APIKey),
- provider.WithGeminiModel(model),
- )
- if err != nil {
- return nil, nil, fmt.Errorf("failed to create Gemini agent provider: %w", err)
- }
-
- titleGenerator, err = provider.NewGeminiProvider(
- ctx,
- provider.WithGeminiSystemMessage(
- prompt.TitlePrompt(),
- ),
- provider.WithGeminiMaxTokens(80),
- provider.WithGeminiKey(providerConfig.APIKey),
- provider.WithGeminiModel(model),
- )
- if err != nil {
- return nil, nil, fmt.Errorf("failed to create Gemini title generator: %w", err)
- }
-
- case models.ProviderGROQ:
- agentProvider, err = provider.NewOpenAIProvider(
- provider.WithOpenAISystemMessage(
- prompt.CoderAnthropicSystemPrompt(),
- ),
- provider.WithOpenAIMaxTokens(maxTokens),
- provider.WithOpenAIModel(model),
- provider.WithOpenAIKey(providerConfig.APIKey),
- provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
- )
- if err != nil {
- return nil, nil, fmt.Errorf("failed to create GROQ agent provider: %w", err)
- }
-
- titleGenerator, err = provider.NewOpenAIProvider(
- provider.WithOpenAISystemMessage(
- prompt.TitlePrompt(),
- ),
- provider.WithOpenAIMaxTokens(80),
- provider.WithOpenAIModel(model),
- provider.WithOpenAIKey(providerConfig.APIKey),
- provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
- )
- if err != nil {
- return nil, nil, fmt.Errorf("failed to create GROQ title generator: %w", err)
- }
-
- case models.ProviderBedrock:
- agentProvider, err = provider.NewBedrockProvider(
- provider.WithBedrockSystemMessage(
- prompt.CoderAnthropicSystemPrompt(),
- ),
- provider.WithBedrockMaxTokens(maxTokens),
- provider.WithBedrockModel(model),
- )
- if err != nil {
- return nil, nil, fmt.Errorf("failed to create Bedrock agent provider: %w", err)
- }
-
- titleGenerator, err = provider.NewBedrockProvider(
- provider.WithBedrockSystemMessage(
- prompt.TitlePrompt(),
- ),
- provider.WithBedrockMaxTokens(80),
- provider.WithBedrockModel(model),
- )
- if err != nil {
- return nil, nil, fmt.Errorf("failed to create Bedrock title generator: %w", err)
- }
- default:
- return nil, nil, fmt.Errorf("unsupported provider: %s", model.Provider)
+ providerCfg, ok := cfg.Providers[model.Provider]
+ if !ok {
+ return nil, fmt.Errorf("provider %s not supported", model.Provider)
+ }
+ if providerCfg.Disabled {
+ return nil, fmt.Errorf("provider %s is not enabled", model.Provider)
+ }
+ agentProvider, err := provider.NewProvider(
+ model.Provider,
+ provider.WithAPIKey(providerCfg.APIKey),
+ provider.WithModel(model),
+ provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)),
+ provider.WithMaxTokens(agentConfig.MaxTokens),
+ )
+ if err != nil {
+ return nil, fmt.Errorf("could not create provider: %v", err)
}
- return agentProvider, titleGenerator, nil
+ return agentProvider, nil
}
diff --git a/internal/llm/agent/coder.go b/internal/llm/agent/coder.go
deleted file mode 100644
index a3db6b55c..000000000
--- a/internal/llm/agent/coder.go
+++ /dev/null
@@ -1,63 +0,0 @@
-package agent
-
-import (
- "context"
- "errors"
-
- "github.com/kujtimiihoxha/termai/internal/config"
- "github.com/kujtimiihoxha/termai/internal/llm/models"
- "github.com/kujtimiihoxha/termai/internal/llm/tools"
- "github.com/kujtimiihoxha/termai/internal/lsp"
- "github.com/kujtimiihoxha/termai/internal/message"
- "github.com/kujtimiihoxha/termai/internal/permission"
- "github.com/kujtimiihoxha/termai/internal/session"
-)
-
-type coderAgent struct {
- Service
-}
-
-func NewCoderAgent(
- permissions permission.Service,
- sessions session.Service,
- messages message.Service,
- lspClients map[string]*lsp.Client,
-) (Service, error) {
- model, ok := models.SupportedModels[config.Get().Model.Coder]
- if !ok {
- return nil, errors.New("model not supported")
- }
-
- ctx := context.Background()
- otherTools := GetMcpTools(ctx, permissions)
- if len(lspClients) > 0 {
- otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
- }
- agent, err := NewAgent(
- ctx,
- sessions,
- messages,
- model,
- append(
- []tools.BaseTool{
- tools.NewBashTool(permissions),
- tools.NewEditTool(lspClients, permissions),
- tools.NewFetchTool(permissions),
- tools.NewGlobTool(),
- tools.NewGrepTool(),
- tools.NewLsTool(),
- tools.NewSourcegraphTool(),
- tools.NewViewTool(lspClients),
- tools.NewWriteTool(lspClients, permissions),
- NewAgentTool(sessions, messages, lspClients),
- }, otherTools...,
- ),
- )
- if err != nil {
- return nil, err
- }
-
- return &coderAgent{
- agent,
- }, nil
-}
diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go
index b1c97b512..c7ea4916c 100644
--- a/internal/llm/agent/mcp-tools.go
+++ b/internal/llm/agent/mcp-tools.go
@@ -46,7 +46,7 @@ func runTool(ctx context.Context, c MCPClient, toolName string, input string) (t
initRequest := mcp.InitializeRequest{}
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
initRequest.Params.ClientInfo = mcp.Implementation{
- Name: "termai",
+ Name: "OpenCode",
Version: version.Version,
}
@@ -135,7 +135,7 @@ func getTools(ctx context.Context, name string, m config.MCPServer, permissions
initRequest := mcp.InitializeRequest{}
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
initRequest.Params.ClientInfo = mcp.Implementation{
- Name: "termai",
+ Name: "OpenCode",
Version: version.Version,
}
diff --git a/internal/llm/agent/task.go b/internal/llm/agent/task.go
deleted file mode 100644
index fca1f223f..000000000
--- a/internal/llm/agent/task.go
+++ /dev/null
@@ -1,47 +0,0 @@
-package agent
-
-import (
- "context"
- "errors"
-
- "github.com/kujtimiihoxha/termai/internal/config"
- "github.com/kujtimiihoxha/termai/internal/llm/models"
- "github.com/kujtimiihoxha/termai/internal/llm/tools"
- "github.com/kujtimiihoxha/termai/internal/lsp"
- "github.com/kujtimiihoxha/termai/internal/message"
- "github.com/kujtimiihoxha/termai/internal/session"
-)
-
-type taskAgent struct {
- Service
-}
-
-func NewTaskAgent(messages message.Service, sessions session.Service, lspClients map[string]*lsp.Client) (Service, error) {
- model, ok := models.SupportedModels[config.Get().Model.Coder]
- if !ok {
- return nil, errors.New("model not supported")
- }
-
- ctx := context.Background()
-
- agent, err := NewAgent(
- ctx,
- sessions,
- messages,
- model,
- []tools.BaseTool{
- tools.NewGlobTool(),
- tools.NewGrepTool(),
- tools.NewLsTool(),
- tools.NewSourcegraphTool(),
- tools.NewViewTool(lspClients),
- },
- )
- if err != nil {
- return nil, err
- }
-
- return &taskAgent{
- agent,
- }, nil
-}
diff --git a/internal/llm/agent/tools.go b/internal/llm/agent/tools.go
new file mode 100644
index 000000000..a37f1d65d
--- /dev/null
+++ b/internal/llm/agent/tools.go
@@ -0,0 +1,50 @@
+package agent
+
+import (
+ "context"
+
+ "github.com/kujtimiihoxha/termai/internal/history"
+ "github.com/kujtimiihoxha/termai/internal/llm/tools"
+ "github.com/kujtimiihoxha/termai/internal/lsp"
+ "github.com/kujtimiihoxha/termai/internal/message"
+ "github.com/kujtimiihoxha/termai/internal/permission"
+ "github.com/kujtimiihoxha/termai/internal/session"
+)
+
+func CoderAgentTools(
+ permissions permission.Service,
+ sessions session.Service,
+ messages message.Service,
+ history history.Service,
+ lspClients map[string]*lsp.Client,
+) []tools.BaseTool {
+ ctx := context.Background()
+ otherTools := GetMcpTools(ctx, permissions)
+ if len(lspClients) > 0 {
+ otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
+ }
+ return append(
+ []tools.BaseTool{
+ tools.NewBashTool(permissions),
+ tools.NewEditTool(lspClients, permissions, history),
+ tools.NewFetchTool(permissions),
+ tools.NewGlobTool(),
+ tools.NewGrepTool(),
+ tools.NewLsTool(),
+ tools.NewSourcegraphTool(),
+ tools.NewViewTool(lspClients),
+ tools.NewWriteTool(lspClients, permissions, history),
+ NewAgentTool(sessions, messages, lspClients),
+ }, otherTools...,
+ )
+}
+
+func TaskAgentTools(lspClients map[string]*lsp.Client) []tools.BaseTool {
+ return []tools.BaseTool{
+ tools.NewGlobTool(),
+ tools.NewGrepTool(),
+ tools.NewLsTool(),
+ tools.NewSourcegraphTool(),
+ tools.NewViewTool(lspClients),
+ }
+}
diff --git a/internal/llm/models/anthropic.go b/internal/llm/models/anthropic.go
new file mode 100644
index 000000000..48307e6d3
--- /dev/null
+++ b/internal/llm/models/anthropic.go
@@ -0,0 +1,71 @@
+package models
+
+const (
+ ProviderAnthropic ModelProvider = "anthropic"
+
+ // Models
+ Claude35Sonnet ModelID = "claude-3.5-sonnet"
+ Claude3Haiku ModelID = "claude-3-haiku"
+ Claude37Sonnet ModelID = "claude-3.7-sonnet"
+ Claude35Haiku ModelID = "claude-3.5-haiku"
+ Claude3Opus ModelID = "claude-3-opus"
+)
+
+var AnthropicModels = map[ModelID]Model{
+ // Anthropic
+ Claude35Sonnet: {
+ ID: Claude35Sonnet,
+ Name: "Claude 3.5 Sonnet",
+ Provider: ProviderAnthropic,
+ APIModel: "claude-3-5-sonnet-latest",
+ CostPer1MIn: 3.0,
+ CostPer1MInCached: 3.75,
+ CostPer1MOutCached: 0.30,
+ CostPer1MOut: 15.0,
+ ContextWindow: 200000,
+ },
+ Claude3Haiku: {
+ ID: Claude3Haiku,
+ Name: "Claude 3 Haiku",
+ Provider: ProviderAnthropic,
+ APIModel: "claude-3-haiku-latest",
+ CostPer1MIn: 0.25,
+ CostPer1MInCached: 0.30,
+ CostPer1MOutCached: 0.03,
+ CostPer1MOut: 1.25,
+ ContextWindow: 200000,
+ },
+ Claude37Sonnet: {
+ ID: Claude37Sonnet,
+ Name: "Claude 3.7 Sonnet",
+ Provider: ProviderAnthropic,
+ APIModel: "claude-3-7-sonnet-latest",
+ CostPer1MIn: 3.0,
+ CostPer1MInCached: 3.75,
+ CostPer1MOutCached: 0.30,
+ CostPer1MOut: 15.0,
+ ContextWindow: 200000,
+ },
+ Claude35Haiku: {
+ ID: Claude35Haiku,
+ Name: "Claude 3.5 Haiku",
+ Provider: ProviderAnthropic,
+ APIModel: "claude-3-5-haiku-latest",
+ CostPer1MIn: 0.80,
+ CostPer1MInCached: 1.0,
+ CostPer1MOutCached: 0.08,
+ CostPer1MOut: 4.0,
+ ContextWindow: 200000,
+ },
+ Claude3Opus: {
+ ID: Claude3Opus,
+ Name: "Claude 3 Opus",
+ Provider: ProviderAnthropic,
+ APIModel: "claude-3-opus-latest",
+ CostPer1MIn: 15.0,
+ CostPer1MInCached: 18.75,
+ CostPer1MOutCached: 1.50,
+ CostPer1MOut: 75.0,
+ ContextWindow: 200000,
+ },
+}
diff --git a/internal/llm/models/models.go b/internal/llm/models/models.go
index 140693237..4d4589bfd 100644
--- a/internal/llm/models/models.go
+++ b/internal/llm/models/models.go
@@ -1,5 +1,7 @@
package models
+import "maps"
+
type (
ModelID string
ModelProvider string
@@ -14,15 +16,13 @@ type Model struct {
CostPer1MOut float64 `json:"cost_per_1m_out"`
CostPer1MInCached float64 `json:"cost_per_1m_in_cached"`
CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
+ ContextWindow int64 `json:"context_window"`
}
// Model IDs
const (
- // Anthropic
- Claude35Sonnet ModelID = "claude-3.5-sonnet"
- Claude3Haiku ModelID = "claude-3-haiku"
- Claude37Sonnet ModelID = "claude-3.7-sonnet"
// OpenAI
+ GPT4o ModelID = "gpt-4o"
GPT41 ModelID = "gpt-4.1"
// GEMINI
@@ -37,47 +37,59 @@ const (
)
const (
- ProviderOpenAI ModelProvider = "openai"
- ProviderAnthropic ModelProvider = "anthropic"
- ProviderBedrock ModelProvider = "bedrock"
- ProviderGemini ModelProvider = "gemini"
- ProviderGROQ ModelProvider = "groq"
+ ProviderOpenAI ModelProvider = "openai"
+ ProviderBedrock ModelProvider = "bedrock"
+ ProviderGemini ModelProvider = "gemini"
+ ProviderGROQ ModelProvider = "groq"
+
+ // For tests
+ ProviderMock ModelProvider = "__mock"
)
var SupportedModels = map[ModelID]Model{
- // Anthropic
- Claude35Sonnet: {
- ID: Claude35Sonnet,
- Name: "Claude 3.5 Sonnet",
- Provider: ProviderAnthropic,
- APIModel: "claude-3-5-sonnet-latest",
- CostPer1MIn: 3.0,
- CostPer1MInCached: 3.75,
- CostPer1MOutCached: 0.30,
- CostPer1MOut: 15.0,
- },
- Claude3Haiku: {
- ID: Claude3Haiku,
- Name: "Claude 3 Haiku",
- Provider: ProviderAnthropic,
- APIModel: "claude-3-haiku-latest",
- CostPer1MIn: 0.80,
- CostPer1MInCached: 1,
- CostPer1MOutCached: 0.08,
- CostPer1MOut: 4,
- },
- Claude37Sonnet: {
- ID: Claude37Sonnet,
- Name: "Claude 3.7 Sonnet",
- Provider: ProviderAnthropic,
- APIModel: "claude-3-7-sonnet-latest",
- CostPer1MIn: 3.0,
- CostPer1MInCached: 3.75,
- CostPer1MOutCached: 0.30,
- CostPer1MOut: 15.0,
+ // // Anthropic
+ // Claude35Sonnet: {
+ // ID: Claude35Sonnet,
+ // Name: "Claude 3.5 Sonnet",
+ // Provider: ProviderAnthropic,
+ // APIModel: "claude-3-5-sonnet-latest",
+ // CostPer1MIn: 3.0,
+ // CostPer1MInCached: 3.75,
+ // CostPer1MOutCached: 0.30,
+ // CostPer1MOut: 15.0,
+ // },
+ // Claude3Haiku: {
+ // ID: Claude3Haiku,
+ // Name: "Claude 3 Haiku",
+ // Provider: ProviderAnthropic,
+ // APIModel: "claude-3-haiku-latest",
+ // CostPer1MIn: 0.80,
+ // CostPer1MInCached: 1,
+ // CostPer1MOutCached: 0.08,
+ // CostPer1MOut: 4,
+ // },
+ // Claude37Sonnet: {
+ // ID: Claude37Sonnet,
+ // Name: "Claude 3.7 Sonnet",
+ // Provider: ProviderAnthropic,
+ // APIModel: "claude-3-7-sonnet-latest",
+ // CostPer1MIn: 3.0,
+ // CostPer1MInCached: 3.75,
+ // CostPer1MOutCached: 0.30,
+ // CostPer1MOut: 15.0,
+ // },
+ //
+ // // OpenAI
+ GPT4o: {
+ ID: GPT4o,
+ Name: "GPT-4o",
+ Provider: ProviderOpenAI,
+ APIModel: "gpt-4.1",
+ CostPer1MIn: 2.00,
+ CostPer1MInCached: 0.50,
+ CostPer1MOutCached: 0,
+ CostPer1MOut: 8.00,
},
-
- // OpenAI
GPT41: {
ID: GPT41,
Name: "GPT-4.1",
@@ -88,51 +100,55 @@ var SupportedModels = map[ModelID]Model{
CostPer1MOutCached: 0,
CostPer1MOut: 8.00,
},
+ //
+ // // GEMINI
+ // GEMINI25: {
+ // ID: GEMINI25,
+ // Name: "Gemini 2.5 Pro",
+ // Provider: ProviderGemini,
+ // APIModel: "gemini-2.5-pro-exp-03-25",
+ // CostPer1MIn: 0,
+ // CostPer1MInCached: 0,
+ // CostPer1MOutCached: 0,
+ // CostPer1MOut: 0,
+ // },
+ //
+ // GRMINI20Flash: {
+ // ID: GRMINI20Flash,
+ // Name: "Gemini 2.0 Flash",
+ // Provider: ProviderGemini,
+ // APIModel: "gemini-2.0-flash",
+ // CostPer1MIn: 0.1,
+ // CostPer1MInCached: 0,
+ // CostPer1MOutCached: 0.025,
+ // CostPer1MOut: 0.4,
+ // },
+ //
+ // // GROQ
+ // QWENQwq: {
+ // ID: QWENQwq,
+ // Name: "Qwen Qwq",
+ // Provider: ProviderGROQ,
+ // APIModel: "qwen-qwq-32b",
+ // CostPer1MIn: 0,
+ // CostPer1MInCached: 0,
+ // CostPer1MOutCached: 0,
+ // CostPer1MOut: 0,
+ // },
+ //
+ // // Bedrock
+ // BedrockClaude37Sonnet: {
+ // ID: BedrockClaude37Sonnet,
+ // Name: "Bedrock: Claude 3.7 Sonnet",
+ // Provider: ProviderBedrock,
+ // APIModel: "anthropic.claude-3-7-sonnet-20250219-v1:0",
+ // CostPer1MIn: 3.0,
+ // CostPer1MInCached: 3.75,
+ // CostPer1MOutCached: 0.30,
+ // CostPer1MOut: 15.0,
+ // },
+}
- // GEMINI
- GEMINI25: {
- ID: GEMINI25,
- Name: "Gemini 2.5 Pro",
- Provider: ProviderGemini,
- APIModel: "gemini-2.5-pro-exp-03-25",
- CostPer1MIn: 0,
- CostPer1MInCached: 0,
- CostPer1MOutCached: 0,
- CostPer1MOut: 0,
- },
-
- GRMINI20Flash: {
- ID: GRMINI20Flash,
- Name: "Gemini 2.0 Flash",
- Provider: ProviderGemini,
- APIModel: "gemini-2.0-flash",
- CostPer1MIn: 0.1,
- CostPer1MInCached: 0,
- CostPer1MOutCached: 0.025,
- CostPer1MOut: 0.4,
- },
-
- // GROQ
- QWENQwq: {
- ID: QWENQwq,
- Name: "Qwen Qwq",
- Provider: ProviderGROQ,
- APIModel: "qwen-qwq-32b",
- CostPer1MIn: 0,
- CostPer1MInCached: 0,
- CostPer1MOutCached: 0,
- CostPer1MOut: 0,
- },
-
- // Bedrock
- BedrockClaude37Sonnet: {
- ID: BedrockClaude37Sonnet,
- Name: "Bedrock: Claude 3.7 Sonnet",
- Provider: ProviderBedrock,
- APIModel: "anthropic.claude-3-7-sonnet-20250219-v1:0",
- CostPer1MIn: 3.0,
- CostPer1MInCached: 3.75,
- CostPer1MOutCached: 0.30,
- CostPer1MOut: 15.0,
- },
+func init() {
+ maps.Copy(SupportedModels, AnthropicModels)
}
diff --git a/internal/llm/prompt/coder.go b/internal/llm/prompt/coder.go
index 47941f976..7439fd570 100644
--- a/internal/llm/prompt/coder.go
+++ b/internal/llm/prompt/coder.go
@@ -9,11 +9,22 @@ import (
"time"
"github.com/kujtimiihoxha/termai/internal/config"
+ "github.com/kujtimiihoxha/termai/internal/llm/models"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
)
-func CoderOpenAISystemPrompt() string {
- basePrompt := `You are termAI, an autonomous CLI-based software engineer. Your job is to reduce user effort by proactively reasoning, inferring context, and solving software engineering tasks end-to-end with minimal prompting.
+func CoderPrompt(provider models.ModelProvider) string {
+ basePrompt := baseAnthropicCoderPrompt
+ switch provider {
+ case models.ProviderOpenAI:
+ basePrompt = baseOpenAICoderPrompt
+ }
+ envInfo := getEnvironmentInfo()
+
+ return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation())
+}
+
+const baseOpenAICoderPrompt = `You are termAI, an autonomous CLI-based software engineer. Your job is to reduce user effort by proactively reasoning, inferring context, and solving software engineering tasks end-to-end with minimal prompting.
# Your mindset
Act like a competent, efficient software engineer who is familiar with large codebases. You should:
@@ -65,13 +76,7 @@ assistant: [searches repo for references, returns file paths and lines]
Never commit changes unless the user explicitly asks you to.`
- envInfo := getEnvironmentInfo()
-
- return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation())
-}
-
-func CoderAnthropicSystemPrompt() string {
- basePrompt := `You are termAI, an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user.
+const baseAnthropicCoderPrompt = `You are termAI, an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user.
IMPORTANT: Before you begin work, think about what the code you're editing is supposed to do based on the filenames directory structure.
@@ -166,11 +171,6 @@ NEVER commit changes unless the user explicitly asks you to. It is VERY IMPORTAN
You MUST answer concisely with fewer than 4 lines of text (not including tool use or code generation), unless user asks for detail.`
- envInfo := getEnvironmentInfo()
-
- return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation())
-}
-
func getEnvironmentInfo() string {
cwd := config.WorkingDirectory()
isGit := isGitRepo(cwd)
diff --git a/internal/llm/prompt/prompt.go b/internal/llm/prompt/prompt.go
new file mode 100644
index 000000000..63fc2df7b
--- /dev/null
+++ b/internal/llm/prompt/prompt.go
@@ -0,0 +1,19 @@
+package prompt
+
+import (
+ "github.com/kujtimiihoxha/termai/internal/config"
+ "github.com/kujtimiihoxha/termai/internal/llm/models"
+)
+
+func GetAgentPrompt(agentName config.AgentName, provider models.ModelProvider) string {
+ switch agentName {
+ case config.AgentCoder:
+ return CoderPrompt(provider)
+ case config.AgentTitle:
+ return TitlePrompt(provider)
+ case config.AgentTask:
+ return TaskPrompt(provider)
+ default:
+ return "You are a helpful assistant"
+ }
+}
diff --git a/internal/llm/prompt/task.go b/internal/llm/prompt/task.go
index ee3c707fa..8bf604ad9 100644
--- a/internal/llm/prompt/task.go
+++ b/internal/llm/prompt/task.go
@@ -2,11 +2,12 @@ package prompt
import (
"fmt"
+
+ "github.com/kujtimiihoxha/termai/internal/llm/models"
)
-func TaskAgentSystemPrompt() string {
+func TaskPrompt(_ models.ModelProvider) string {
agentPrompt := `You are an agent for termAI. Given the user's prompt, you should use the tools available to you to answer the user's question.
-
Notes:
1. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is <answer>.", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...".
2. When relevant, share file names and code snippets relevant to the query
diff --git a/internal/llm/prompt/title.go b/internal/llm/prompt/title.go
index 5c47f4d64..3023a8550 100644
--- a/internal/llm/prompt/title.go
+++ b/internal/llm/prompt/title.go
@@ -1,6 +1,8 @@
package prompt
-func TitlePrompt() string {
+import "github.com/kujtimiihoxha/termai/internal/llm/models"
+
+func TitlePrompt(_ models.ModelProvider) string {
return `you will generate a short title based on the first message a user begins a conversation with
- ensure it is not more than 50 characters long
- the title should be a summary of the user's message
diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go
index 93c4308ad..c3a4efc49 100644
--- a/internal/llm/provider/anthropic.go
+++ b/internal/llm/provider/anthropic.go
@@ -12,187 +12,257 @@ import (
"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/bedrock"
"github.com/anthropics/anthropic-sdk-go/option"
- "github.com/kujtimiihoxha/termai/internal/llm/models"
+ "github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
+ "github.com/kujtimiihoxha/termai/internal/logging"
"github.com/kujtimiihoxha/termai/internal/message"
)
-type anthropicProvider struct {
- client anthropic.Client
- model models.Model
- maxTokens int64
- apiKey string
- systemMessage string
- useBedrock bool
- disableCache bool
+type anthropicOptions struct {
+ useBedrock bool
+ disableCache bool
+ shouldThink func(userMessage string) bool
}
-type AnthropicOption func(*anthropicProvider)
+type AnthropicOption func(*anthropicOptions)
-func WithAnthropicSystemMessage(message string) AnthropicOption {
- return func(a *anthropicProvider) {
- a.systemMessage = message
- }
+type anthropicClient struct {
+ providerOptions providerClientOptions
+ options anthropicOptions
+ client anthropic.Client
}
-func WithAnthropicMaxTokens(maxTokens int64) AnthropicOption {
- return func(a *anthropicProvider) {
- a.maxTokens = maxTokens
- }
-}
+type AnthropicClient ProviderClient
-func WithAnthropicModel(model models.Model) AnthropicOption {
- return func(a *anthropicProvider) {
- a.model = model
+func newAnthropicClient(opts providerClientOptions) AnthropicClient {
+ anthropicOpts := anthropicOptions{}
+ for _, o := range opts.anthropicOptions {
+ o(&anthropicOpts)
}
-}
-func WithAnthropicKey(apiKey string) AnthropicOption {
- return func(a *anthropicProvider) {
- a.apiKey = apiKey
+ anthropicClientOptions := []option.RequestOption{}
+ if opts.apiKey != "" {
+ anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey))
}
-}
-
-func WithAnthropicBedrock() AnthropicOption {
- return func(a *anthropicProvider) {
- a.useBedrock = true
+ if anthropicOpts.useBedrock {
+ anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background()))
}
-}
-func WithAnthropicDisableCache() AnthropicOption {
- return func(a *anthropicProvider) {
- a.disableCache = true
+ client := anthropic.NewClient(anthropicClientOptions...)
+ return &anthropicClient{
+ providerOptions: opts,
+ options: anthropicOpts,
+ client: client,
}
}
-func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) {
- provider := &anthropicProvider{
- maxTokens: 1024,
- }
+func (a *anthropicClient) convertMessages(messages []message.Message) (anthropicMessages []anthropic.MessageParam) {
+ cachedBlocks := 0
+ for _, msg := range messages {
+ switch msg.Role {
+ case message.User:
+ content := anthropic.NewTextBlock(msg.Content().String())
+ if cachedBlocks < 2 && !a.options.disableCache {
+ content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
+ Type: "ephemeral",
+ }
+ cachedBlocks++
+ }
+ anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(content))
- for _, opt := range opts {
- opt(provider)
- }
+ case message.Assistant:
+ blocks := []anthropic.ContentBlockParamUnion{}
+ if msg.Content().String() != "" {
+ content := anthropic.NewTextBlock(msg.Content().String())
+ if cachedBlocks < 2 && !a.options.disableCache {
+ content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
+ Type: "ephemeral",
+ }
+ cachedBlocks++
+ }
+ blocks = append(blocks, content)
+ }
- if provider.systemMessage == "" {
- return nil, errors.New("system message is required")
- }
+ for _, toolCall := range msg.ToolCalls() {
+ var inputMap map[string]any
+ err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
+ if err != nil {
+ continue
+ }
+ blocks = append(blocks, anthropic.ContentBlockParamOfRequestToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
+ }
- anthropicOptions := []option.RequestOption{}
+ if len(blocks) == 0 {
+ logging.Warn("There is a message without content, investigate")
+ // This should never happen, but we log it because we might have a bug in our cleanup method
+ continue
+ }
+ anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
- if provider.apiKey != "" {
- anthropicOptions = append(anthropicOptions, option.WithAPIKey(provider.apiKey))
- }
- if provider.useBedrock {
- anthropicOptions = append(anthropicOptions, bedrock.WithLoadDefaultConfig(context.Background()))
+ case message.Tool:
+ results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults()))
+ for i, toolResult := range msg.ToolResults() {
+ results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
+ }
+ anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...))
+ }
}
-
- provider.client = anthropic.NewClient(anthropicOptions...)
- return provider, nil
+ return
}
-func (a *anthropicProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
- messages = cleanupMessages(messages)
- anthropicMessages := a.convertToAnthropicMessages(messages)
- anthropicTools := a.convertToAnthropicTools(tools)
-
- response, err := a.client.Messages.New(
- ctx,
- anthropic.MessageNewParams{
- Model: anthropic.Model(a.model.APIModel),
- MaxTokens: a.maxTokens,
- Temperature: anthropic.Float(0),
- Messages: anthropicMessages,
- Tools: anthropicTools,
- System: []anthropic.TextBlockParam{
- {
- Text: a.systemMessage,
- CacheControl: anthropic.CacheControlEphemeralParam{
- Type: "ephemeral",
- },
- },
+func (a *anthropicClient) convertTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
+ anthropicTools := make([]anthropic.ToolUnionParam, len(tools))
+
+ for i, tool := range tools {
+ info := tool.Info()
+ toolParam := anthropic.ToolParam{
+ Name: info.Name,
+ Description: anthropic.String(info.Description),
+ InputSchema: anthropic.ToolInputSchemaParam{
+ Properties: info.Parameters,
+ // TODO: figure out how to tell Claude which fields are required
},
- },
- )
- if err != nil {
- return nil, err
- }
+ }
- content := ""
- for _, block := range response.Content {
- if text, ok := block.AsAny().(anthropic.TextBlock); ok {
- content += text.Text
+ if i == len(tools)-1 && !a.options.disableCache {
+ toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
+ Type: "ephemeral",
+ }
}
- }
- toolCalls := a.extractToolCalls(response.Content)
- tokenUsage := a.extractTokenUsage(response.Usage)
+ anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
+ }
- return &ProviderResponse{
- Content: content,
- ToolCalls: toolCalls,
- Usage: tokenUsage,
- }, nil
+ return anthropicTools
}
-func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
- messages = cleanupMessages(messages)
- anthropicMessages := a.convertToAnthropicMessages(messages)
- anthropicTools := a.convertToAnthropicTools(tools)
+func (a *anthropicClient) finishReason(reason string) message.FinishReason {
+ switch reason {
+ case "end_turn":
+ return message.FinishReasonEndTurn
+ case "max_tokens":
+ return message.FinishReasonMaxTokens
+ case "tool_use":
+ return message.FinishReasonToolUse
+ case "stop_sequence":
+ return message.FinishReasonEndTurn
+ default:
+ return message.FinishReasonUnknown
+ }
+}
+func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, tools []anthropic.ToolUnionParam) anthropic.MessageNewParams {
var thinkingParam anthropic.ThinkingConfigParamUnion
lastMessage := messages[len(messages)-1]
+ isUser := lastMessage.Role == anthropic.MessageParamRoleUser
+ messageContent := ""
temperature := anthropic.Float(0)
- if lastMessage.Role == message.User && strings.Contains(strings.ToLower(lastMessage.Content().String()), "think") {
- thinkingParam = anthropic.ThinkingConfigParamUnion{
- OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{
- BudgetTokens: int64(float64(a.maxTokens) * 0.8),
- Type: "enabled",
- },
+ if isUser {
+ for _, m := range lastMessage.Content {
+ if m.OfRequestTextBlock != nil && m.OfRequestTextBlock.Text != "" {
+ messageContent = m.OfRequestTextBlock.Text
+ }
+ }
+ if messageContent != "" && a.options.shouldThink != nil && a.options.shouldThink(messageContent) {
+ thinkingParam = anthropic.ThinkingConfigParamUnion{
+ OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{
+ BudgetTokens: int64(float64(a.providerOptions.maxTokens) * 0.8),
+ Type: "enabled",
+ },
+ }
+ temperature = anthropic.Float(1)
}
- temperature = anthropic.Float(1)
}
- eventChan := make(chan ProviderEvent)
+ return anthropic.MessageNewParams{
+ Model: anthropic.Model(a.providerOptions.model.APIModel),
+ MaxTokens: a.providerOptions.maxTokens,
+ Temperature: temperature,
+ Messages: messages,
+ Tools: tools,
+ Thinking: thinkingParam,
+ System: []anthropic.TextBlockParam{
+ {
+ Text: a.providerOptions.systemMessage,
+ CacheControl: anthropic.CacheControlEphemeralParam{
+ Type: "ephemeral",
+ },
+ },
+ },
+ }
+}
- go func() {
- defer close(eventChan)
+func (a *anthropicClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (resposne *ProviderResponse, err error) {
+ preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
+ cfg := config.Get()
+ if cfg.Debug {
+ jsonData, _ := json.Marshal(preparedMessages)
+ logging.Debug("Prepared messages", "messages", string(jsonData))
+ }
+ attempts := 0
+ for {
+ attempts++
+ anthropicResponse, err := a.client.Messages.New(
+ ctx,
+ preparedMessages,
+ )
+ // If there is an error we are going to see if we can retry the call
+ if err != nil {
+ retry, after, retryErr := a.shouldRetry(attempts, err)
+ if retryErr != nil {
+ return nil, retryErr
+ }
+ if retry {
+ logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ case <-time.After(time.Duration(after) * time.Millisecond):
+ continue
+ }
+ }
+ return nil, retryErr
+ }
- const maxRetries = 8
- attempts := 0
+ content := ""
+ for _, block := range anthropicResponse.Content {
+ if text, ok := block.AsAny().(anthropic.TextBlock); ok {
+ content += text.Text
+ }
+ }
- for {
+ return &ProviderResponse{
+ Content: content,
+ ToolCalls: a.toolCalls(*anthropicResponse),
+ Usage: a.usage(*anthropicResponse),
+ }, nil
+ }
+}
+func (a *anthropicClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
+ preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
+ cfg := config.Get()
+ if cfg.Debug {
+ jsonData, _ := json.Marshal(preparedMessages)
+ logging.Debug("Prepared messages", "messages", string(jsonData))
+ }
+ attempts := 0
+ eventChan := make(chan ProviderEvent)
+ go func() {
+ for {
attempts++
-
- stream := a.client.Messages.NewStreaming(
+ anthropicStream := a.client.Messages.NewStreaming(
ctx,
- anthropic.MessageNewParams{
- Model: anthropic.Model(a.model.APIModel),
- MaxTokens: a.maxTokens,
- Temperature: temperature,
- Messages: anthropicMessages,
- Tools: anthropicTools,
- Thinking: thinkingParam,
- System: []anthropic.TextBlockParam{
- {
- Text: a.systemMessage,
- CacheControl: anthropic.CacheControlEphemeralParam{
- Type: "ephemeral",
- },
- },
- },
- },
+ preparedMessages,
)
-
accumulatedMessage := anthropic.Message{}
- for stream.Next() {
- event := stream.Current()
+ for anthropicStream.Next() {
+ event := anthropicStream.Current()
err := accumulatedMessage.Accumulate(event)
if err != nil {
eventChan <- ProviderEvent{Type: EventError, Error: err}
- return // Don't retry on accumulation errors
+ continue
}
switch event := event.AsAny().(type) {
@@ -211,6 +281,7 @@ func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []messa
Content: event.Delta.Text,
}
}
+ // TODO: check if we can somehow stream tool calls
case anthropic.ContentBlockStopEvent:
eventChan <- ProviderEvent{Type: EventContentStop}
@@ -223,84 +294,87 @@ func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []messa
}
}
- toolCalls := a.extractToolCalls(accumulatedMessage.Content)
- tokenUsage := a.extractTokenUsage(accumulatedMessage.Usage)
-
eventChan <- ProviderEvent{
Type: EventComplete,
Response: &ProviderResponse{
Content: content,
- ToolCalls: toolCalls,
- Usage: tokenUsage,
- FinishReason: string(accumulatedMessage.StopReason),
+ ToolCalls: a.toolCalls(accumulatedMessage),
+ Usage: a.usage(accumulatedMessage),
+ FinishReason: a.finishReason(string(accumulatedMessage.StopReason)),
},
}
}
}
- err := stream.Err()
+ err := anthropicStream.Err()
if err == nil || errors.Is(err, io.EOF) {
+ close(eventChan)
return
}
-
- var apierr *anthropic.Error
- if !errors.As(err, &apierr) {
- eventChan <- ProviderEvent{Type: EventError, Error: err}
- return
- }
-
- if apierr.StatusCode != 429 && apierr.StatusCode != 529 {
- eventChan <- ProviderEvent{Type: EventError, Error: err}
+ // If there is an error we are going to see if we can retry the call
+ retry, after, retryErr := a.shouldRetry(attempts, err)
+ if retryErr != nil {
+ eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
+ close(eventChan)
return
}
-
- if attempts > maxRetries {
- eventChan <- ProviderEvent{
- Type: EventError,
- Error: errors.New("maximum retry attempts reached for rate limit (429)"),
- }
- return
- }
-
- retryMs := 0
- retryAfterValues := apierr.Response.Header.Values("Retry-After")
- if len(retryAfterValues) > 0 {
- var retryAfterSec int
- if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryAfterSec); err == nil {
- retryMs = retryAfterSec * 1000
- eventChan <- ProviderEvent{
- Type: EventWarning,
- Info: fmt.Sprintf("[Rate limited: waiting %d seconds as specified by API]", retryAfterSec),
+ if retry {
+ logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
+ select {
+ case <-ctx.Done():
+ // context cancelled
+ if ctx.Err() != nil {
+ eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
}
+ close(eventChan)
+ return
+ case <-time.After(time.Duration(after) * time.Millisecond):
+ continue
}
- } else {
- eventChan <- ProviderEvent{
- Type: EventWarning,
- Info: fmt.Sprintf("[Retrying due to rate limit... attempt %d of %d]", attempts, maxRetries),
- }
-
- backoffMs := 2000 * (1 << (attempts - 1))
- jitterMs := int(float64(backoffMs) * 0.2)
- retryMs = backoffMs + jitterMs
}
- select {
- case <-ctx.Done():
+ if ctx.Err() != nil {
eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
- return
- case <-time.After(time.Duration(retryMs) * time.Millisecond):
- continue
}
+ close(eventChan)
+ return
}
}()
+ return eventChan
+}
- return eventChan, nil
+func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, error) {
+ var apierr *anthropic.Error
+ if !errors.As(err, &apierr) {
+ return false, 0, err
+ }
+
+ if apierr.StatusCode != 429 && apierr.StatusCode != 529 {
+ return false, 0, err
+ }
+
+ if attempts > maxRetries {
+ return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
+ }
+
+ retryMs := 0
+ retryAfterValues := apierr.Response.Header.Values("Retry-After")
+
+ backoffMs := 2000 * (1 << (attempts - 1))
+ jitterMs := int(float64(backoffMs) * 0.2)
+ retryMs = backoffMs + jitterMs
+ if len(retryAfterValues) > 0 {
+ if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
+ retryMs = retryMs * 1000
+ }
+ }
+ return true, int64(retryMs), nil
}
-func (a *anthropicProvider) extractToolCalls(content []anthropic.ContentBlockUnion) []message.ToolCall {
+func (a *anthropicClient) toolCalls(msg anthropic.Message) []message.ToolCall {
var toolCalls []message.ToolCall
- for _, block := range content {
+ for _, block := range msg.Content {
switch variant := block.AsAny().(type) {
case anthropic.ToolUseBlock:
toolCall := message.ToolCall{
@@ -316,90 +390,33 @@ func (a *anthropicProvider) extractToolCalls(content []anthropic.ContentBlockUni
return toolCalls
}
-func (a *anthropicProvider) extractTokenUsage(usage anthropic.Usage) TokenUsage {
+func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage {
return TokenUsage{
- InputTokens: usage.InputTokens,
- OutputTokens: usage.OutputTokens,
- CacheCreationTokens: usage.CacheCreationInputTokens,
- CacheReadTokens: usage.CacheReadInputTokens,
+ InputTokens: msg.Usage.InputTokens,
+ OutputTokens: msg.Usage.OutputTokens,
+ CacheCreationTokens: msg.Usage.CacheCreationInputTokens,
+ CacheReadTokens: msg.Usage.CacheReadInputTokens,
}
}
-func (a *anthropicProvider) convertToAnthropicTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
- anthropicTools := make([]anthropic.ToolUnionParam, len(tools))
-
- for i, tool := range tools {
- info := tool.Info()
- toolParam := anthropic.ToolParam{
- Name: info.Name,
- Description: anthropic.String(info.Description),
- InputSchema: anthropic.ToolInputSchemaParam{
- Properties: info.Parameters,
- },
- }
-
- if i == len(tools)-1 && !a.disableCache {
- toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
- Type: "ephemeral",
- }
- }
-
- anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
+func WithAnthropicBedrock(useBedrock bool) AnthropicOption {
+ return func(options *anthropicOptions) {
+ options.useBedrock = useBedrock
}
-
- return anthropicTools
}
-func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Message) []anthropic.MessageParam {
- anthropicMessages := make([]anthropic.MessageParam, 0, len(messages))
- cachedBlocks := 0
-
- for _, msg := range messages {
- switch msg.Role {
- case message.User:
- content := anthropic.NewTextBlock(msg.Content().String())
- if cachedBlocks < 2 && !a.disableCache {
- content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
- Type: "ephemeral",
- }
- cachedBlocks++
- }
- anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(content))
-
- case message.Assistant:
- blocks := []anthropic.ContentBlockParamUnion{}
- if msg.Content().String() != "" {
- content := anthropic.NewTextBlock(msg.Content().String())
- if cachedBlocks < 2 && !a.disableCache {
- content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
- Type: "ephemeral",
- }
- cachedBlocks++
- }
- blocks = append(blocks, content)
- }
-
- for _, toolCall := range msg.ToolCalls() {
- var inputMap map[string]any
- err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
- if err != nil {
- continue
- }
- blocks = append(blocks, anthropic.ContentBlockParamOfRequestToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
- }
+func WithAnthropicDisableCache() AnthropicOption {
+ return func(options *anthropicOptions) {
+ options.disableCache = true
+ }
+}
- if len(blocks) > 0 {
- anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
- }
+func DefaultShouldThinkFn(s string) bool {
+ return strings.Contains(strings.ToLower(s), "think")
+}
- case message.Tool:
- results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults()))
- for i, toolResult := range msg.ToolResults() {
- results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
- }
- anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...))
- }
+func WithAnthropicShouldThinkFn(fn func(string) bool) AnthropicOption {
+ return func(options *anthropicOptions) {
+ options.shouldThink = fn
}
-
- return anthropicMessages
}
diff --git a/internal/llm/provider/bedrock.go b/internal/llm/provider/bedrock.go
index 677f4676b..d76925ad1 100644
--- a/internal/llm/provider/bedrock.go
+++ b/internal/llm/provider/bedrock.go
@@ -7,33 +7,29 @@ import (
"os"
"strings"
- "github.com/kujtimiihoxha/termai/internal/llm/models"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
"github.com/kujtimiihoxha/termai/internal/message"
)
-type bedrockProvider struct {
- childProvider Provider
- model models.Model
- maxTokens int64
- systemMessage string
+type bedrockOptions struct {
+ // Bedrock specific options can be added here
}
-func (b *bedrockProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
- return b.childProvider.SendMessages(ctx, messages, tools)
-}
+type BedrockOption func(*bedrockOptions)
-func (b *bedrockProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
- return b.childProvider.StreamResponse(ctx, messages, tools)
+type bedrockClient struct {
+ providerOptions providerClientOptions
+ options bedrockOptions
+ childProvider ProviderClient
}
-func NewBedrockProvider(opts ...BedrockOption) (Provider, error) {
- provider := &bedrockProvider{}
- for _, opt := range opts {
- opt(provider)
- }
+type BedrockClient ProviderClient
+
+func newBedrockClient(opts providerClientOptions) BedrockClient {
+ bedrockOpts := bedrockOptions{}
+ // Apply bedrock specific options if they are added in the future
- // based on the AWS region prefix the model name with, us, eu, ap, sa, etc.
+ // Get AWS region from environment
region := os.Getenv("AWS_REGION")
if region == "" {
region = os.Getenv("AWS_DEFAULT_REGION")
@@ -43,45 +39,62 @@ func NewBedrockProvider(opts ...BedrockOption) (Provider, error) {
region = "us-east-1" // default region
}
if len(region) < 2 {
- return nil, errors.New("AWS_REGION or AWS_DEFAULT_REGION environment variable is invalid")
+ return &bedrockClient{
+ providerOptions: opts,
+ options: bedrockOpts,
+ childProvider: nil, // Will cause an error when used
+ }
}
+
+ // Prefix the model name with region
regionPrefix := region[:2]
- provider.model.APIModel = fmt.Sprintf("%s.%s", regionPrefix, provider.model.APIModel)
+ modelName := opts.model.APIModel
+ opts.model.APIModel = fmt.Sprintf("%s.%s", regionPrefix, modelName)
- if strings.Contains(string(provider.model.APIModel), "anthropic") {
- anthropic, err := NewAnthropicProvider(
- WithAnthropicModel(provider.model),
- WithAnthropicMaxTokens(provider.maxTokens),
- WithAnthropicSystemMessage(provider.systemMessage),
- WithAnthropicBedrock(),
+ // Determine which provider to use based on the model
+ if strings.Contains(string(opts.model.APIModel), "anthropic") {
+ // Create Anthropic client with Bedrock configuration
+ anthropicOpts := opts
+ anthropicOpts.anthropicOptions = append(anthropicOpts.anthropicOptions,
+ WithAnthropicBedrock(true),
WithAnthropicDisableCache(),
)
- provider.childProvider = anthropic
- if err != nil {
- return nil, err
+ return &bedrockClient{
+ providerOptions: opts,
+ options: bedrockOpts,
+ childProvider: newAnthropicClient(anthropicOpts),
}
- } else {
- return nil, errors.New("unsupported model for bedrock provider")
}
- return provider, nil
-}
-
-type BedrockOption func(*bedrockProvider)
-func WithBedrockSystemMessage(message string) BedrockOption {
- return func(a *bedrockProvider) {
- a.systemMessage = message
+ // Return client with nil childProvider if model is not supported
+ // This will cause an error when used
+ return &bedrockClient{
+ providerOptions: opts,
+ options: bedrockOpts,
+ childProvider: nil,
}
}
-func WithBedrockMaxTokens(maxTokens int64) BedrockOption {
- return func(a *bedrockProvider) {
- a.maxTokens = maxTokens
+func (b *bedrockClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
+ if b.childProvider == nil {
+ return nil, errors.New("unsupported model for bedrock provider")
}
+ return b.childProvider.send(ctx, messages, tools)
}
-func WithBedrockModel(model models.Model) BedrockOption {
- return func(a *bedrockProvider) {
- a.model = model
+func (b *bedrockClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
+ eventChan := make(chan ProviderEvent)
+
+ if b.childProvider == nil {
+ go func() {
+ eventChan <- ProviderEvent{
+ Type: EventError,
+ Error: errors.New("unsupported model for bedrock provider"),
+ }
+ close(eventChan)
+ }()
+ return eventChan
}
-}
+
+ return b.childProvider.stream(ctx, messages, tools)
+} \ No newline at end of file
diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go
index 2d1db2b64..804baea28 100644
--- a/internal/llm/provider/gemini.go
+++ b/internal/llm/provider/gemini.go
@@ -4,80 +4,68 @@ import (
"context"
"encoding/json"
"errors"
+ "fmt"
+ "io"
+ "strings"
+ "time"
"github.com/google/generative-ai-go/genai"
"github.com/google/uuid"
- "github.com/kujtimiihoxha/termai/internal/llm/models"
+ "github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
+ "github.com/kujtimiihoxha/termai/internal/logging"
"github.com/kujtimiihoxha/termai/internal/message"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
)
-type geminiProvider struct {
- client *genai.Client
- model models.Model
- maxTokens int32
- apiKey string
- systemMessage string
+type geminiOptions struct {
+ disableCache bool
}
-type GeminiOption func(*geminiProvider)
+type GeminiOption func(*geminiOptions)
-func NewGeminiProvider(ctx context.Context, opts ...GeminiOption) (Provider, error) {
- provider := &geminiProvider{
- maxTokens: 5000,
- }
+type geminiClient struct {
+ providerOptions providerClientOptions
+ options geminiOptions
+ client *genai.Client
+}
- for _, opt := range opts {
- opt(provider)
- }
+type GeminiClient ProviderClient
- if provider.systemMessage == "" {
- return nil, errors.New("system message is required")
+func newGeminiClient(opts providerClientOptions) GeminiClient {
+ geminiOpts := geminiOptions{}
+ for _, o := range opts.geminiOptions {
+ o(&geminiOpts)
}
- client, err := genai.NewClient(ctx, option.WithAPIKey(provider.apiKey))
+ client, err := genai.NewClient(context.Background(), option.WithAPIKey(opts.apiKey))
if err != nil {
- return nil, err
- }
- provider.client = client
-
- return provider, nil
-}
-
-func WithGeminiSystemMessage(message string) GeminiOption {
- return func(p *geminiProvider) {
- p.systemMessage = message
+ logging.Error("Failed to create Gemini client", "error", err)
+ return nil
}
-}
-func WithGeminiMaxTokens(maxTokens int32) GeminiOption {
- return func(p *geminiProvider) {
- p.maxTokens = maxTokens
+ return &geminiClient{
+ providerOptions: opts,
+ options: geminiOpts,
+ client: client,
}
}
-func WithGeminiModel(model models.Model) GeminiOption {
- return func(p *geminiProvider) {
- p.model = model
- }
-}
-
-func WithGeminiKey(apiKey string) GeminiOption {
- return func(p *geminiProvider) {
- p.apiKey = apiKey
- }
-}
+func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
+ var history []*genai.Content
-func (p *geminiProvider) Close() {
- if p.client != nil {
- p.client.Close()
- }
-}
+ // Add system message first
+ history = append(history, &genai.Content{
+ Parts: []genai.Part{genai.Text(g.providerOptions.systemMessage)},
+ Role: "user",
+ })
-func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*genai.Content {
- var history []*genai.Content
+ // Add a system response to acknowledge the system message
+ history = append(history, &genai.Content{
+ Parts: []genai.Part{genai.Text("I'll help you with that.")},
+ Role: "model",
+ })
for _, msg := range messages {
switch msg.Role {
@@ -86,6 +74,7 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g
Parts: []genai.Part{genai.Text(msg.Content().String())},
Role: "user",
})
+
case message.Assistant:
content := &genai.Content{
Role: "model",
@@ -107,6 +96,7 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g
}
history = append(history, content)
+
case message.Tool:
for _, result := range msg.ToolResults() {
response := map[string]interface{}{"result": result.Content}
@@ -114,10 +104,11 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g
if err == nil {
response = parsed
}
+
var toolCall message.ToolCall
- for _, msg := range messages {
- if msg.Role == message.Assistant {
- for _, call := range msg.ToolCalls() {
+ for _, m := range messages {
+ if m.Role == message.Assistant {
+ for _, call := range m.ToolCalls() {
if call.ID == result.ToolCallID {
toolCall = call
break
@@ -140,186 +131,358 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g
return history
}
-func (p *geminiProvider) extractTokenUsage(resp *genai.GenerateContentResponse) TokenUsage {
- if resp == nil || resp.UsageMetadata == nil {
- return TokenUsage{}
- }
+func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
+ geminiTools := make([]*genai.Tool, 0, len(tools))
- return TokenUsage{
- InputTokens: int64(resp.UsageMetadata.PromptTokenCount),
- OutputTokens: int64(resp.UsageMetadata.CandidatesTokenCount),
- CacheCreationTokens: 0, // Not directly provided by Gemini
- CacheReadTokens: int64(resp.UsageMetadata.CachedContentTokenCount),
+ for _, tool := range tools {
+ info := tool.Info()
+ declaration := &genai.FunctionDeclaration{
+ Name: info.Name,
+ Description: info.Description,
+ Parameters: &genai.Schema{
+ Type: genai.TypeObject,
+ Properties: convertSchemaProperties(info.Parameters),
+ Required: info.Required,
+ },
+ }
+
+ geminiTools = append(geminiTools, &genai.Tool{
+ FunctionDeclarations: []*genai.FunctionDeclaration{declaration},
+ })
}
+
+ return geminiTools
}
-func (p *geminiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
- messages = cleanupMessages(messages)
- model := p.client.GenerativeModel(p.model.APIModel)
- model.SetMaxOutputTokens(p.maxTokens)
+func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
+ reasonStr := reason.String()
+ switch {
+ case reasonStr == "STOP":
+ return message.FinishReasonEndTurn
+ case reasonStr == "MAX_TOKENS":
+ return message.FinishReasonMaxTokens
+ case strings.Contains(reasonStr, "FUNCTION") || strings.Contains(reasonStr, "TOOL"):
+ return message.FinishReasonToolUse
+ default:
+ return message.FinishReasonUnknown
+ }
+}
- model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))
+func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
+ model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
+ model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
+ // Convert tools
if len(tools) > 0 {
- declarations := p.convertToolsToGeminiFunctionDeclarations(tools)
- for _, declaration := range declarations {
- model.Tools = append(model.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{declaration}})
- }
+ model.Tools = g.convertTools(tools)
}
- chat := model.StartChat()
- chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message
+ // Convert messages
+ geminiMessages := g.convertMessages(messages)
- lastUserMsg := messages[len(messages)-1]
- resp, err := chat.SendMessage(ctx, genai.Text(lastUserMsg.Content().String()))
- if err != nil {
- return nil, err
+ cfg := config.Get()
+ if cfg.Debug {
+ jsonData, _ := json.Marshal(geminiMessages)
+ logging.Debug("Prepared messages", "messages", string(jsonData))
}
- var content string
- var toolCalls []message.ToolCall
+ attempts := 0
+ for {
+ attempts++
+ chat := model.StartChat()
+ chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message
+
+ lastMsg := geminiMessages[len(geminiMessages)-1]
+ var lastText string
+ for _, part := range lastMsg.Parts {
+ if text, ok := part.(genai.Text); ok {
+ lastText = string(text)
+ break
+ }
+ }
- if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
- for _, part := range resp.Candidates[0].Content.Parts {
- switch p := part.(type) {
- case genai.Text:
- content = string(p)
- case genai.FunctionCall:
- id := "call_" + uuid.New().String()
- args, _ := json.Marshal(p.Args)
- toolCalls = append(toolCalls, message.ToolCall{
- ID: id,
- Name: p.Name,
- Input: string(args),
- Type: "function",
- })
+ resp, err := chat.SendMessage(ctx, genai.Text(lastText))
+ // If there is an error we are going to see if we can retry the call
+ if err != nil {
+ retry, after, retryErr := g.shouldRetry(attempts, err)
+ if retryErr != nil {
+ return nil, retryErr
}
+ if retry {
+ logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ case <-time.After(time.Duration(after) * time.Millisecond):
+ continue
+ }
+ }
+ return nil, retryErr
}
- }
- tokenUsage := p.extractTokenUsage(resp)
+ content := ""
+ var toolCalls []message.ToolCall
+
+ if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
+ for _, part := range resp.Candidates[0].Content.Parts {
+ switch p := part.(type) {
+ case genai.Text:
+ content = string(p)
+ case genai.FunctionCall:
+ id := "call_" + uuid.New().String()
+ args, _ := json.Marshal(p.Args)
+ toolCalls = append(toolCalls, message.ToolCall{
+ ID: id,
+ Name: p.Name,
+ Input: string(args),
+ Type: "function",
+ })
+ }
+ }
+ }
- return &ProviderResponse{
- Content: content,
- ToolCalls: toolCalls,
- Usage: tokenUsage,
- }, nil
+ return &ProviderResponse{
+ Content: content,
+ ToolCalls: toolCalls,
+ Usage: g.usage(resp),
+ FinishReason: g.finishReason(resp.Candidates[0].FinishReason),
+ }, nil
+ }
}
-func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
- messages = cleanupMessages(messages)
- model := p.client.GenerativeModel(p.model.APIModel)
- model.SetMaxOutputTokens(p.maxTokens)
-
- model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))
+func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
+ model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
+ model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
+ // Convert tools
if len(tools) > 0 {
- declarations := p.convertToolsToGeminiFunctionDeclarations(tools)
- for _, declaration := range declarations {
- model.Tools = append(model.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{declaration}})
- }
+ model.Tools = g.convertTools(tools)
}
- chat := model.StartChat()
- chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message
+ // Convert messages
+ geminiMessages := g.convertMessages(messages)
- lastUserMsg := messages[len(messages)-1]
-
- iter := chat.SendMessageStream(ctx, genai.Text(lastUserMsg.Content().String()))
+ cfg := config.Get()
+ if cfg.Debug {
+ jsonData, _ := json.Marshal(geminiMessages)
+ logging.Debug("Prepared messages", "messages", string(jsonData))
+ }
+ attempts := 0
eventChan := make(chan ProviderEvent)
go func() {
defer close(eventChan)
- var finalResp *genai.GenerateContentResponse
- currentContent := ""
- toolCalls := []message.ToolCall{}
-
for {
- resp, err := iter.Next()
- if err == iterator.Done {
- break
- }
- if err != nil {
- eventChan <- ProviderEvent{
- Type: EventError,
- Error: err,
+ attempts++
+ chat := model.StartChat()
+ chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message
+
+ lastMsg := geminiMessages[len(geminiMessages)-1]
+ var lastText string
+ for _, part := range lastMsg.Parts {
+ if text, ok := part.(genai.Text); ok {
+ lastText = string(text)
+ break
}
- return
}
- finalResp = resp
+ iter := chat.SendMessageStream(ctx, genai.Text(lastText))
- if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
- for _, part := range resp.Candidates[0].Content.Parts {
- switch p := part.(type) {
- case genai.Text:
- newText := string(p)
- eventChan <- ProviderEvent{
- Type: EventContentDelta,
- Content: newText,
- }
- currentContent += newText
- case genai.FunctionCall:
- id := "call_" + uuid.New().String()
- args, _ := json.Marshal(p.Args)
- newCall := message.ToolCall{
- ID: id,
- Name: p.Name,
- Input: string(args),
- Type: "function",
- }
+ currentContent := ""
+ toolCalls := []message.ToolCall{}
+ var finalResp *genai.GenerateContentResponse
- isNew := true
- for _, existing := range toolCalls {
- if existing.Name == newCall.Name && existing.Input == newCall.Input {
- isNew = false
- break
+ eventChan <- ProviderEvent{Type: EventContentStart}
+
+ for {
+ resp, err := iter.Next()
+ if err == iterator.Done {
+ break
+ }
+ if err != nil {
+ retry, after, retryErr := g.shouldRetry(attempts, err)
+ if retryErr != nil {
+ eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
+ return
+ }
+ if retry {
+ logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
+ select {
+ case <-ctx.Done():
+ if ctx.Err() != nil {
+ eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
}
+
+ return
+ case <-time.After(time.Duration(after) * time.Millisecond):
+ break
}
+ } else {
+ eventChan <- ProviderEvent{Type: EventError, Error: err}
+ return
+ }
+ }
+
+ finalResp = resp
+
+ if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
+ for _, part := range resp.Candidates[0].Content.Parts {
+ switch p := part.(type) {
+ case genai.Text:
+ newText := string(p)
+ delta := newText[len(currentContent):]
+ if delta != "" {
+ eventChan <- ProviderEvent{
+ Type: EventContentDelta,
+ Content: delta,
+ }
+ currentContent = newText
+ }
+ case genai.FunctionCall:
+ id := "call_" + uuid.New().String()
+ args, _ := json.Marshal(p.Args)
+ newCall := message.ToolCall{
+ ID: id,
+ Name: p.Name,
+ Input: string(args),
+ Type: "function",
+ }
- if isNew {
- toolCalls = append(toolCalls, newCall)
+ isNew := true
+ for _, existing := range toolCalls {
+ if existing.Name == newCall.Name && existing.Input == newCall.Input {
+ isNew = false
+ break
+ }
+ }
+
+ if isNew {
+ toolCalls = append(toolCalls, newCall)
+ }
}
}
}
}
- }
- tokenUsage := p.extractTokenUsage(finalResp)
+ eventChan <- ProviderEvent{Type: EventContentStop}
- eventChan <- ProviderEvent{
- Type: EventComplete,
- Response: &ProviderResponse{
- Content: currentContent,
- ToolCalls: toolCalls,
- Usage: tokenUsage,
- FinishReason: string(finalResp.Candidates[0].FinishReason.String()),
- },
+ if finalResp != nil {
+ eventChan <- ProviderEvent{
+ Type: EventComplete,
+ Response: &ProviderResponse{
+ Content: currentContent,
+ ToolCalls: toolCalls,
+ Usage: g.usage(finalResp),
+ FinishReason: g.finishReason(finalResp.Candidates[0].FinishReason),
+ },
+ }
+ return
+ }
+
+ // If we get here, we need to retry
+ if attempts > maxRetries {
+ eventChan <- ProviderEvent{
+ Type: EventError,
+ Error: fmt.Errorf("maximum retry attempts reached: %d retries", maxRetries),
+ }
+ return
+ }
+
+ // Wait before retrying
+ select {
+ case <-ctx.Done():
+ if ctx.Err() != nil {
+ eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
+ }
+ return
+ case <-time.After(time.Duration(2000*(1<<(attempts-1))) * time.Millisecond):
+ continue
+ }
}
}()
- return eventChan, nil
+ return eventChan
}
-func (p *geminiProvider) convertToolsToGeminiFunctionDeclarations(tools []tools.BaseTool) []*genai.FunctionDeclaration {
- declarations := make([]*genai.FunctionDeclaration, len(tools))
+func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
+ // Check if error is a rate limit error
+ if attempts > maxRetries {
+ return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
+ }
- for i, tool := range tools {
- info := tool.Info()
- declarations[i] = &genai.FunctionDeclaration{
- Name: info.Name,
- Description: info.Description,
- Parameters: &genai.Schema{
- Type: genai.TypeObject,
- Properties: convertSchemaProperties(info.Parameters),
- Required: info.Required,
- },
+ // Gemini doesn't have a standard error type we can check against
+ // So we'll check the error message for rate limit indicators
+ if errors.Is(err, io.EOF) {
+ return false, 0, err
+ }
+
+ errMsg := err.Error()
+ isRateLimit := false
+
+ // Check for common rate limit error messages
+ if contains(errMsg, "rate limit", "quota exceeded", "too many requests") {
+ isRateLimit = true
+ }
+
+ if !isRateLimit {
+ return false, 0, err
+ }
+
+ // Calculate backoff with jitter
+ backoffMs := 2000 * (1 << (attempts - 1))
+ jitterMs := int(float64(backoffMs) * 0.2)
+ retryMs := backoffMs + jitterMs
+
+ return true, int64(retryMs), nil
+}
+
+func (g *geminiClient) toolCalls(resp *genai.GenerateContentResponse) []message.ToolCall {
+ var toolCalls []message.ToolCall
+
+ if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
+ for _, part := range resp.Candidates[0].Content.Parts {
+ if funcCall, ok := part.(genai.FunctionCall); ok {
+ id := "call_" + uuid.New().String()
+ args, _ := json.Marshal(funcCall.Args)
+ toolCalls = append(toolCalls, message.ToolCall{
+ ID: id,
+ Name: funcCall.Name,
+ Input: string(args),
+ Type: "function",
+ })
+ }
}
}
- return declarations
+ return toolCalls
+}
+
+func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
+ if resp == nil || resp.UsageMetadata == nil {
+ return TokenUsage{}
+ }
+
+ return TokenUsage{
+ InputTokens: int64(resp.UsageMetadata.PromptTokenCount),
+ OutputTokens: int64(resp.UsageMetadata.CandidatesTokenCount),
+ CacheCreationTokens: 0, // Not directly provided by Gemini
+ CacheReadTokens: int64(resp.UsageMetadata.CachedContentTokenCount),
+ }
+}
+
+func WithGeminiDisableCache() GeminiOption {
+ return func(options *geminiOptions) {
+ options.disableCache = true
+ }
+}
+
+// Helper functions
+func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
+ var result map[string]interface{}
+ err := json.Unmarshal([]byte(jsonStr), &result)
+ return result, err
}
func convertSchemaProperties(parameters map[string]interface{}) map[string]*genai.Schema {
@@ -396,8 +559,12 @@ func mapJSONTypeToGenAI(jsonType string) genai.Type {
}
}
-func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
- var result map[string]interface{}
- err := json.Unmarshal([]byte(jsonStr), &result)
- return result, err
+func contains(s string, substrs ...string) bool {
+ for _, substr := range substrs {
+ if strings.Contains(strings.ToLower(s), strings.ToLower(substr)) {
+ return true
+ }
+ }
+ return false
}
+
diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go
index dbfde3fa8..9c2ad2012 100644
--- a/internal/llm/provider/openai.go
+++ b/internal/llm/provider/openai.go
@@ -2,89 +2,65 @@ package provider
import (
"context"
+ "encoding/json"
"errors"
+ "fmt"
+ "io"
+ "time"
- "github.com/kujtimiihoxha/termai/internal/llm/models"
+ "github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
+ "github.com/kujtimiihoxha/termai/internal/logging"
"github.com/kujtimiihoxha/termai/internal/message"
"github.com/openai/openai-go"
"github.com/openai/openai-go/option"
)
-type openaiProvider struct {
- client openai.Client
- model models.Model
- maxTokens int64
- baseURL string
- apiKey string
- systemMessage string
+type openaiOptions struct {
+ baseURL string
+ disableCache bool
}
-type OpenAIOption func(*openaiProvider)
+type OpenAIOption func(*openaiOptions)
-func NewOpenAIProvider(opts ...OpenAIOption) (Provider, error) {
- provider := &openaiProvider{
- maxTokens: 5000,
- }
-
- for _, opt := range opts {
- opt(provider)
- }
-
- clientOpts := []option.RequestOption{
- option.WithAPIKey(provider.apiKey),
- }
- if provider.baseURL != "" {
- clientOpts = append(clientOpts, option.WithBaseURL(provider.baseURL))
- }
-
- provider.client = openai.NewClient(clientOpts...)
- if provider.systemMessage == "" {
- return nil, errors.New("system message is required")
- }
-
- return provider, nil
+type openaiClient struct {
+ providerOptions providerClientOptions
+ options openaiOptions
+ client openai.Client
}
-func WithOpenAISystemMessage(message string) OpenAIOption {
- return func(p *openaiProvider) {
- p.systemMessage = message
- }
-}
+type OpenAIClient ProviderClient
-func WithOpenAIMaxTokens(maxTokens int64) OpenAIOption {
- return func(p *openaiProvider) {
- p.maxTokens = maxTokens
+func newOpenAIClient(opts providerClientOptions) OpenAIClient {
+ openaiOpts := openaiOptions{}
+ for _, o := range opts.openaiOptions {
+ o(&openaiOpts)
}
-}
-func WithOpenAIModel(model models.Model) OpenAIOption {
- return func(p *openaiProvider) {
- p.model = model
+ openaiClientOptions := []option.RequestOption{}
+ if opts.apiKey != "" {
+ openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
}
-}
-
-func WithOpenAIBaseURL(baseURL string) OpenAIOption {
- return func(p *openaiProvider) {
- p.baseURL = baseURL
+ if openaiOpts.baseURL != "" {
+ openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(openaiOpts.baseURL))
}
-}
-func WithOpenAIKey(apiKey string) OpenAIOption {
- return func(p *openaiProvider) {
- p.apiKey = apiKey
+ client := openai.NewClient(openaiClientOptions...)
+ return &openaiClient{
+ providerOptions: opts,
+ options: openaiOpts,
+ client: client,
}
}
-func (p *openaiProvider) convertToOpenAIMessages(messages []message.Message) []openai.ChatCompletionMessageParamUnion {
- var chatMessages []openai.ChatCompletionMessageParamUnion
-
- chatMessages = append(chatMessages, openai.SystemMessage(p.systemMessage))
+func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
+ // Add system message first
+ openaiMessages = append(openaiMessages, openai.SystemMessage(o.providerOptions.systemMessage))
for _, msg := range messages {
switch msg.Role {
case message.User:
- chatMessages = append(chatMessages, openai.UserMessage(msg.Content().String()))
+ openaiMessages = append(openaiMessages, openai.UserMessage(msg.Content().String()))
case message.Assistant:
assistantMsg := openai.ChatCompletionAssistantMessageParam{
@@ -111,23 +87,23 @@ func (p *openaiProvider) convertToOpenAIMessages(messages []message.Message) []o
}
}
- chatMessages = append(chatMessages, openai.ChatCompletionMessageParamUnion{
+ openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{
OfAssistant: &assistantMsg,
})
case message.Tool:
for _, result := range msg.ToolResults() {
- chatMessages = append(chatMessages,
+ openaiMessages = append(openaiMessages,
openai.ToolMessage(result.Content, result.ToolCallID),
)
}
}
}
- return chatMessages
+ return
}
-func (p *openaiProvider) convertToOpenAITools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
+func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
openaiTools := make([]openai.ChatCompletionToolParam, len(tools))
for i, tool := range tools {
@@ -148,133 +124,238 @@ func (p *openaiProvider) convertToOpenAITools(tools []tools.BaseTool) []openai.C
return openaiTools
}
-func (p *openaiProvider) extractTokenUsage(usage openai.CompletionUsage) TokenUsage {
- cachedTokens := int64(0)
-
- cachedTokens = usage.PromptTokensDetails.CachedTokens
- inputTokens := usage.PromptTokens - cachedTokens
-
- return TokenUsage{
- InputTokens: inputTokens,
- OutputTokens: usage.CompletionTokens,
- CacheCreationTokens: 0, // OpenAI doesn't provide this directly
- CacheReadTokens: cachedTokens,
+// finishReason maps the finish-reason string reported by OpenAI onto the
+// message.FinishReason values used by the rest of the application.
+// Any value other than "stop", "length" or "tool_calls" maps to
+// message.FinishReasonUnknown.
+func (o *openaiClient) finishReason(reason string) message.FinishReason {
+ switch reason {
+ case "stop":
+ return message.FinishReasonEndTurn
+ case "length":
+ return message.FinishReasonMaxTokens
+ case "tool_calls":
+ return message.FinishReasonToolUse
+ default:
+ return message.FinishReasonUnknown
+ }
}
-func (p *openaiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
- messages = cleanupMessages(messages)
- chatMessages := p.convertToOpenAIMessages(messages)
- openaiTools := p.convertToOpenAITools(tools)
-
- params := openai.ChatCompletionNewParams{
- Model: openai.ChatModel(p.model.APIModel),
- Messages: chatMessages,
- MaxTokens: openai.Int(p.maxTokens),
- Tools: openaiTools,
- }
-
- response, err := p.client.Chat.Completions.New(ctx, params)
- if err != nil {
- return nil, err
+// preparedParams builds the chat-completion request shared by send and stream:
+// the model and max-token limit come from the provider options, the messages
+// and tools are passed through as given.
+func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
+ return openai.ChatCompletionNewParams{
+ Model: openai.ChatModel(o.providerOptions.model.APIModel),
+ Messages: messages,
+ MaxTokens: openai.Int(o.providerOptions.maxTokens),
+ Tools: tools,
}
+}
- content := ""
- if response.Choices[0].Message.Content != "" {
- content = response.Choices[0].Message.Content
+// send performs a single (non-streaming) chat completion and converts the
+// result into a ProviderResponse.
+//
+// When the API call fails, shouldRetry decides whether the call is repeated.
+// Retryable failures wait for the delay shouldRetry returns and then loop;
+// the wait is abandoned with ctx.Err() if ctx is cancelled first. Any other
+// failure is returned to the caller. A response without choices is reported
+// as an error instead of being indexed.
+func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
+ params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
+ cfg := config.Get()
+ if cfg.Debug {
+ // Best-effort debug dump; a marshalling failure only yields an empty log value.
+ jsonData, _ := json.Marshal(params)
+ logging.Debug("Prepared messages", "messages", string(jsonData))
}
-
- var toolCalls []message.ToolCall
- if len(response.Choices[0].Message.ToolCalls) > 0 {
- toolCalls = make([]message.ToolCall, len(response.Choices[0].Message.ToolCalls))
- for i, call := range response.Choices[0].Message.ToolCalls {
- toolCalls[i] = message.ToolCall{
- ID: call.ID,
- Name: call.Function.Name,
- Input: call.Function.Arguments,
- Type: "function",
+ attempts := 0
+ for {
+ attempts++
+ openaiResponse, err := o.client.Chat.Completions.New(
+ ctx,
+ params,
+ )
+ // If there is an error we are going to see if we can retry the call
+ if err != nil {
+ retry, after, retryErr := o.shouldRetry(attempts, err)
+ if retryErr != nil {
+ return nil, retryErr
}
+ if retry {
+ // The message is formatted here because the remaining arguments are a key/value pair, not format operands.
+ logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ case <-time.After(time.Duration(after) * time.Millisecond):
+ continue
+ }
+ }
+ // Not retryable: surface the original error rather than the nil retryErr.
+ return nil, err
}
- }
- tokenUsage := p.extractTokenUsage(response.Usage)
+ // Guard the index below: an empty Choices slice would otherwise panic.
+ if len(openaiResponse.Choices) == 0 {
+ return nil, errors.New("openai response contained no choices")
+ }
+ content := openaiResponse.Choices[0].Message.Content
- return &ProviderResponse{
- Content: content,
- ToolCalls: toolCalls,
- Usage: tokenUsage,
- }, nil
+ return &ProviderResponse{
+ Content: content,
+ ToolCalls: o.toolCalls(*openaiResponse),
+ Usage: o.usage(*openaiResponse),
+ FinishReason: o.finishReason(string(openaiResponse.Choices[0].FinishReason)),
+ }, nil
+ }
+}
-func (p *openaiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
- messages = cleanupMessages(messages)
- chatMessages := p.convertToOpenAIMessages(messages)
- openaiTools := p.convertToOpenAITools(tools)
-
- params := openai.ChatCompletionNewParams{
- Model: openai.ChatModel(p.model.APIModel),
- Messages: chatMessages,
- MaxTokens: openai.Int(p.maxTokens),
- Tools: openaiTools,
- StreamOptions: openai.ChatCompletionStreamOptionsParam{
- IncludeUsage: openai.Bool(true),
- },
+// stream starts a streaming chat completion and returns a channel of events.
+//
+// Each non-empty content delta is forwarded as an EventContentDelta. When the
+// stream ends without error (or with io.EOF) a single EventComplete carrying
+// the accumulated content, tool calls, usage and finish reason is sent. On a
+// stream error shouldRetry decides whether the whole request is repeated after
+// a delay; otherwise an EventError is sent. The channel is closed after the
+// terminal event.
+//
+// NOTE(review): a retry restarts the request from scratch, so deltas already
+// forwarded before the failure are sent again — confirm consumers tolerate this.
+func (o *openaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
+ params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
+ params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
+ IncludeUsage: openai.Bool(true),
}
- stream := p.client.Chat.Completions.NewStreaming(ctx, params)
+ cfg := config.Get()
+ if cfg.Debug {
+ // Best-effort debug dump; a marshalling failure only yields an empty log value.
+ jsonData, _ := json.Marshal(params)
+ logging.Debug("Prepared messages", "messages", string(jsonData))
+ }
+ attempts := 0
eventChan := make(chan ProviderEvent)
- toolCalls := make([]message.ToolCall, 0)
go func() {
- defer close(eventChan)
-
- acc := openai.ChatCompletionAccumulator{}
- currentContent := ""
-
- for stream.Next() {
- chunk := stream.Current()
- acc.AddChunk(chunk)
-
- if tool, ok := acc.JustFinishedToolCall(); ok {
- toolCalls = append(toolCalls, message.ToolCall{
- ID: tool.Id,
- Name: tool.Name,
- Input: tool.Arguments,
- Type: "function",
- })
- }
+ for {
+ attempts++
+ openaiStream := o.client.Chat.Completions.NewStreaming(
+ ctx,
+ params,
+ )
+
+ // Per-attempt state: a retry starts with a fresh accumulator.
+ acc := openai.ChatCompletionAccumulator{}
+ currentContent := ""
+ toolCalls := make([]message.ToolCall, 0)
+
+ for openaiStream.Next() {
+ chunk := openaiStream.Current()
+ acc.AddChunk(chunk)
+
+ if tool, ok := acc.JustFinishedToolCall(); ok {
+ toolCalls = append(toolCalls, message.ToolCall{
+ ID: tool.Id,
+ Name: tool.Name,
+ Input: tool.Arguments,
+ Type: "function",
+ })
+ }
- for _, choice := range chunk.Choices {
- if choice.Delta.Content != "" {
- eventChan <- ProviderEvent{
- Type: EventContentDelta,
- Content: choice.Delta.Content,
+ for _, choice := range chunk.Choices {
+ if choice.Delta.Content != "" {
+ eventChan <- ProviderEvent{
+ Type: EventContentDelta,
+ Content: choice.Delta.Content,
+ }
+ currentContent += choice.Delta.Content
}
- currentContent += choice.Delta.Content
}
}
- }
- if err := stream.Err(); err != nil {
- eventChan <- ProviderEvent{
- Type: EventError,
- Error: err,
+ err := openaiStream.Err()
+ if err == nil || errors.Is(err, io.EOF) {
+ // The accumulator holds no choices if no chunk carried one; avoid indexing an empty slice.
+ reason := message.FinishReasonUnknown
+ if len(acc.ChatCompletion.Choices) > 0 {
+ reason = o.finishReason(string(acc.ChatCompletion.Choices[0].FinishReason))
+ }
+ // Stream completed successfully
+ eventChan <- ProviderEvent{
+ Type: EventComplete,
+ Response: &ProviderResponse{
+ Content: currentContent,
+ ToolCalls: toolCalls,
+ Usage: o.usage(acc.ChatCompletion),
+ FinishReason: reason,
+ },
+ }
+ close(eventChan)
+ return
}
+
+ // If there is an error we are going to see if we can retry the call
+ retry, after, retryErr := o.shouldRetry(attempts, err)
+ if retryErr != nil {
+ eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
+ close(eventChan)
+ return
+ }
+ if retry {
+ // The message is formatted here because the remaining arguments are a key/value pair, not format operands.
+ logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
+ select {
+ case <-ctx.Done():
+ // context cancelled: ctx.Err() is non-nil once Done is closed, so report it
+ if ctxErr := ctx.Err(); ctxErr != nil {
+ eventChan <- ProviderEvent{Type: EventError, Error: ctxErr}
+ }
+ close(eventChan)
+ return
+ case <-time.After(time.Duration(after) * time.Millisecond):
+ continue
+ }
+ }
+ // Not retryable: report the original stream error rather than the nil retryErr.
+ eventChan <- ProviderEvent{Type: EventError, Error: err}
+ close(eventChan)
return
}
+ }()
- tokenUsage := p.extractTokenUsage(acc.Usage)
+ return eventChan
+}
- eventChan <- ProviderEvent{
- Type: EventComplete,
- Response: &ProviderResponse{
- Content: currentContent,
- ToolCalls: toolCalls,
- Usage: tokenUsage,
- },
+// shouldRetry decides whether a failed API call should be repeated.
+//
+// It returns (retry, delay in milliseconds, error). Errors that are not an
+// *openai.Error, or whose status code is neither 429 nor 500, are returned
+// unchanged with retry=false. Once attempts exceeds maxRetries a dedicated
+// error is returned. Otherwise retry is true and the delay is an exponential
+// backoff (2000ms doubled per attempt, plus a fixed 20%), unless the response
+// carries a Retry-After header that parses as an integer number of seconds,
+// in which case that value (converted to milliseconds) is used instead.
+func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
+ var apierr *openai.Error
+ if !errors.As(err, &apierr) {
+ return false, 0, err
+ }
+
+ if apierr.StatusCode != 429 && apierr.StatusCode != 500 {
+ return false, 0, err
+ }
+
+ if attempts > maxRetries {
+ return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
+ }
+
+ retryMs := 0
+ retryAfterValues := apierr.Response.Header.Values("Retry-After")
+
+ // NOTE(review): the "jitter" is a constant 20% of the backoff, not a random amount.
+ backoffMs := 2000 * (1 << (attempts - 1))
+ jitterMs := int(float64(backoffMs) * 0.2)
+ retryMs = backoffMs + jitterMs
+ if len(retryAfterValues) > 0 {
+ // Only the integer-seconds form is handled; if parsing fails the computed backoff is kept.
+ if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
+ retryMs = retryMs * 1000
}
- }()
+ }
+ return true, int64(retryMs), nil
+}
- return eventChan, nil
+// toolCalls converts the tool calls of the first choice of completion into
+// message.ToolCall values with Type "function". It returns nil when the
+// completion has no choices or the first choice has no tool calls.
+func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall {
+ var toolCalls []message.ToolCall
+
+ if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
+ for _, call := range completion.Choices[0].Message.ToolCalls {
+ toolCall := message.ToolCall{
+ ID: call.ID,
+ Name: call.Function.Name,
+ Input: call.Function.Arguments,
+ Type: "function",
+ }
+ toolCalls = append(toolCalls, toolCall)
+ }
+ }
+
+ return toolCalls
}
+
+// usage converts the usage block of completion into a TokenUsage.
+// Cached prompt tokens are reported as CacheReadTokens and subtracted from
+// the prompt tokens to obtain InputTokens.
+func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
+ cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens
+ inputTokens := completion.Usage.PromptTokens - cachedTokens
+
+ return TokenUsage{
+ InputTokens: inputTokens,
+ OutputTokens: completion.Usage.CompletionTokens,
+ CacheCreationTokens: 0, // OpenAI doesn't provide this directly
+ CacheReadTokens: cachedTokens,
+ }
+}
+
+// WithOpenAIBaseURL returns an OpenAIOption that sets the baseURL field of the
+// OpenAI client options.
+func WithOpenAIBaseURL(baseURL string) OpenAIOption {
+ return func(options *openaiOptions) {
+ options.baseURL = baseURL
+ }
+}
+
+// WithOpenAIDisableCache returns an OpenAIOption that sets the disableCache
+// flag of the OpenAI client options.
+func WithOpenAIDisableCache() OpenAIOption {
+ return func(options *openaiOptions) {
+ options.disableCache = true
+ }
+}
+
diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go
index 34d91f2b7..1a5b3dc8a 100644
--- a/internal/llm/provider/provider.go
+++ b/internal/llm/provider/provider.go
@@ -2,14 +2,17 @@ package provider
import (
"context"
+ "fmt"
+ "github.com/kujtimiihoxha/termai/internal/llm/models"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
"github.com/kujtimiihoxha/termai/internal/message"
)
-// EventType represents the type of streaming event
type EventType string
+// maxRetries is the attempt count above which shouldRetry stops retrying a failed call.
+const maxRetries = 8
+
const (
EventContentStart EventType = "content_start"
EventContentDelta EventType = "content_delta"
@@ -18,7 +21,6 @@ const (
EventComplete EventType = "complete"
EventError EventType = "error"
EventWarning EventType = "warning"
- EventInfo EventType = "info"
)
type TokenUsage struct {
@@ -32,61 +34,152 @@ type ProviderResponse struct {
Content string
ToolCalls []message.ToolCall
Usage TokenUsage
- FinishReason string
+ FinishReason message.FinishReason
}
+// ProviderEvent is one event delivered on the channel returned by
+// StreamResponse. Type says which of the remaining fields is meaningful.
type ProviderEvent struct {
- Type EventType
+ Type EventType
+
Content string
Thinking string
- ToolCall *message.ToolCall
- Error error
Response *ProviderResponse
- // Used for giving users info on e.x retry
- Info string
+ Error error
}
-
+// Provider is the interface the rest of the application uses to talk to an
+// LLM backend: a blocking call, a streaming call, and the configured model.
type Provider interface {
SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
- StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error)
+ StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
+
+ Model() models.Model
+}
+
+// providerClientOptions collects the settings shared by every provider client
+// plus the provider-specific option lists. It is filled in by the
+// ProviderClientOption functions passed to NewProvider.
+type providerClientOptions struct {
+ apiKey string
+ model models.Model
+ maxTokens int64
+ systemMessage string
+
+ anthropicOptions []AnthropicOption
+ openaiOptions []OpenAIOption
+ geminiOptions []GeminiOption
+ bedrockOptions []BedrockOption
+}
+
+// ProviderClientOption mutates the providerClientOptions used by NewProvider.
+type ProviderClientOption func(*providerClientOptions)
+
+// ProviderClient is the backend-specific half of a provider: baseProvider
+// forwards SendMessages to send and StreamResponse to stream.
+type ProviderClient interface {
+ send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
+ stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
+}
+
+// baseProvider implements Provider on top of a concrete ProviderClient C,
+// keeping the options the client was built from.
+type baseProvider[C ProviderClient] struct {
+ options providerClientOptions
+ client C
+}
+
+// NewProvider applies opts to a fresh providerClientOptions and returns a
+// Provider backed by the client for providerName.
+//
+// It returns an error for providers that are unknown or, like
+// models.ProviderMock, not implemented yet.
+func NewProvider(providerName models.ModelProvider, opts ...ProviderClientOption) (Provider, error) {
+ clientOptions := providerClientOptions{}
+ for _, o := range opts {
+ o(&clientOptions)
+ }
+ switch providerName {
+ case models.ProviderAnthropic:
+ return &baseProvider[AnthropicClient]{
+ options: clientOptions,
+ client: newAnthropicClient(clientOptions),
+ }, nil
+ case models.ProviderOpenAI:
+ return &baseProvider[OpenAIClient]{
+ options: clientOptions,
+ client: newOpenAIClient(clientOptions),
+ }, nil
+ case models.ProviderGemini:
+ return &baseProvider[GeminiClient]{
+ options: clientOptions,
+ client: newGeminiClient(clientOptions),
+ }, nil
+ case models.ProviderBedrock:
+ return &baseProvider[BedrockClient]{
+ options: clientOptions,
+ client: newBedrockClient(clientOptions),
+ }, nil
+ case models.ProviderMock:
+ // TODO: implement mock client for test
+ // Report the gap through the error result instead of panicking in the caller.
+ return nil, fmt.Errorf("provider not implemented: %s", providerName)
+ }
+ return nil, fmt.Errorf("provider not supported: %s", providerName)
}
-func cleanupMessages(messages []message.Message) []message.Message {
- // First pass: filter out canceled messages
- var cleanedMessages []message.Message
+// cleanMessages returns messages without the entries that have no parts,
+// preserving the order of the rest. The result is nil when nothing is kept.
+func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
for _, msg := range messages {
- if msg.FinishReason() != "canceled" || len(msg.ToolCalls()) > 0 {
- // if there are toolCalls this means we want to return it to the LLM telling it that those tools have been
- // cancelled
- cleanedMessages = append(cleanedMessages, msg)
+ // The message has no content
+ if len(msg.Parts) == 0 {
+ continue
}
+ cleaned = append(cleaned, msg)
}
+ return
+}
- // Second pass: filter out tool messages without a corresponding tool call
- var result []message.Message
- toolMessageIDs := make(map[string]bool)
+// SendMessages drops messages without parts and forwards the rest to the
+// client's send method.
+func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
+ messages = p.cleanMessages(messages)
+ return p.client.send(ctx, messages, tools)
+}
- for _, msg := range cleanedMessages {
- if msg.Role == message.Assistant {
- for _, toolCall := range msg.ToolCalls() {
- toolMessageIDs[toolCall.ID] = true // Mark as referenced
- }
- }
+// Model returns the model this provider was configured with.
+func (p *baseProvider[C]) Model() models.Model {
+ return p.options.model
+}
+
+// StreamResponse drops messages without parts and forwards the rest to the
+// client's stream method, returning its event channel.
+func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
+ messages = p.cleanMessages(messages)
+ return p.client.stream(ctx, messages, tools)
+}
+
+// WithAPIKey returns an option that sets the API key.
+func WithAPIKey(apiKey string) ProviderClientOption {
+ return func(options *providerClientOptions) {
+ options.apiKey = apiKey
}
+}
- // Keep only messages that aren't unreferenced tool messages
- for _, msg := range cleanedMessages {
- if msg.Role == message.Tool {
- for _, toolCall := range msg.ToolResults() {
- if referenced, exists := toolMessageIDs[toolCall.ToolCallID]; exists && referenced {
- result = append(result, msg)
- }
- }
- } else {
- result = append(result, msg)
- }
+// WithModel returns an option that sets the model.
+func WithModel(model models.Model) ProviderClientOption {
+ return func(options *providerClientOptions) {
+ options.model = model
+ }
+}
+
+// WithMaxTokens returns an option that sets the max-token limit.
+func WithMaxTokens(maxTokens int64) ProviderClientOption {
+ return func(options *providerClientOptions) {
+ options.maxTokens = maxTokens
+ }
+}
+
+// WithSystemMessage returns an option that sets the system message.
+func WithSystemMessage(systemMessage string) ProviderClientOption {
+ return func(options *providerClientOptions) {
+ options.systemMessage = systemMessage
+ }
+}
+
+// WithAnthropicOptions returns an option that replaces the Anthropic-specific
+// option list with anthropicOptions.
+func WithAnthropicOptions(anthropicOptions ...AnthropicOption) ProviderClientOption {
+ return func(options *providerClientOptions) {
+ options.anthropicOptions = anthropicOptions
+ }
+}
+
+// WithOpenAIOptions returns an option that replaces the OpenAI-specific
+// option list with openaiOptions.
+func WithOpenAIOptions(openaiOptions ...OpenAIOption) ProviderClientOption {
+ return func(options *providerClientOptions) {
+ options.openaiOptions = openaiOptions
+ }
+}
+
+// WithGeminiOptions returns an option that replaces the Gemini-specific
+// option list with geminiOptions.
+func WithGeminiOptions(geminiOptions ...GeminiOption) ProviderClientOption {
+ return func(options *providerClientOptions) {
+ options.geminiOptions = geminiOptions
+ }
+}
+
+// WithBedrockOptions returns an option that replaces the Bedrock-specific
+// option list with bedrockOptions.
+func WithBedrockOptions(bedrockOptions ...BedrockOption) ProviderClientOption {
+ return func(options *providerClientOptions) {
+ options.bedrockOptions = bedrockOptions
}
- return result
}
diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go
index 0cea20878..c7c970e5a 100644
--- a/internal/llm/tools/bash.go
+++ b/internal/llm/tools/bash.go
@@ -23,7 +23,8 @@ type BashPermissionsParams struct {
}
+// BashResponseMetadata records when a bash command started and finished,
+// both as Unix timestamps in milliseconds.
type BashResponseMetadata struct {
- Took int64 `json:"took"`
+ StartTime int64 `json:"start_time"`
+ EndTime int64 `json:"end_time"`
}
type bashTool struct {
permissions permission.Service
@@ -282,7 +283,6 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
if err != nil {
return ToolResponse{}, fmt.Errorf("error executing command: %w", err)
}
- took := time.Since(startTime).Milliseconds()
stdout = truncateOutput(stdout)
stderr = truncateOutput(stderr)
@@ -311,7 +311,8 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
}
metadata := BashResponseMetadata{
- Took: took,
+ StartTime: startTime.UnixMilli(),
+ EndTime: time.Now().UnixMilli(),
}
if stdout == "" {
return WithResponseMetadata(NewTextResponse("no output"), metadata), nil
diff --git a/internal/llm/tools/bash_test.go b/internal/llm/tools/bash_test.go
index 97be3683a..dafb0ccc5 100644
--- a/internal/llm/tools/bash_test.go
+++ b/internal/llm/tools/bash_test.go
@@ -8,8 +8,6 @@ import (
"testing"
"time"
- "github.com/kujtimiihoxha/termai/internal/permission"
- "github.com/kujtimiihoxha/termai/internal/pubsub"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -340,32 +338,3 @@ func TestCountLines(t *testing.T) {
})
}
}
-
-// Mock permission service for testing
-type mockPermissionService struct {
- *pubsub.Broker[permission.PermissionRequest]
- allow bool
-}
-
-func (m *mockPermissionService) GrantPersistant(permission permission.PermissionRequest) {
- // Not needed for tests
-}
-
-func (m *mockPermissionService) Grant(permission permission.PermissionRequest) {
- // Not needed for tests
-}
-
-func (m *mockPermissionService) Deny(permission permission.PermissionRequest) {
- // Not needed for tests
-}
-
-func (m *mockPermissionService) Request(opts permission.CreatePermissionRequest) bool {
- return m.allow
-}
-
-func newMockPermissionService(allow bool) permission.Service {
- return &mockPermissionService{
- Broker: pubsub.NewBroker[permission.PermissionRequest](),
- allow: allow,
- }
-}
diff --git a/internal/llm/tools/edit.go b/internal/llm/tools/edit.go
index 08d6d446c..148e7aba7 100644
--- a/internal/llm/tools/edit.go
+++ b/internal/llm/tools/edit.go
@@ -11,6 +11,7 @@ import (
"github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/diff"
+ "github.com/kujtimiihoxha/termai/internal/history"
"github.com/kujtimiihoxha/termai/internal/lsp"
"github.com/kujtimiihoxha/termai/internal/permission"
)
@@ -35,6 +36,7 @@ type EditResponseMetadata struct {
+// editTool is the edit tool implementation; files is the history service
+// used to record file versions.
type editTool struct {
lspClients map[string]*lsp.Client
permissions permission.Service
+ files history.Service
}
const (
@@ -88,10 +90,11 @@ When making edits:
Remember: when making multiple file edits in a row to the same file, you should prefer to send all edits in a single message with multiple calls to this tool, rather than multiple messages with a single call each.`
)
+// NewEditTool returns the edit tool wired to the given LSP clients,
+// permission service and file-history service.
-func NewEditTool(lspClients map[string]*lsp.Client, permissions permission.Service) BaseTool {
+func NewEditTool(lspClients map[string]*lsp.Client, permissions permission.Service, files history.Service) BaseTool {
return &editTool{
lspClients: lspClients,
permissions: permissions,
+ files: files,
}
}
@@ -153,6 +156,11 @@ func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
if err != nil {
return response, nil
}
+ if response.IsError {
+ // Return early if there was an error during content replacement
+ // This prevents unnecessary LSP diagnostics processing
+ return response, nil
+ }
waitForLspDiagnostics(ctx, params.FilePath, e.lspClients)
text := fmt.Sprintf("<result>\n%s\n</result>\n", response.Content)
@@ -208,6 +216,20 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string)
return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
}
+ // File can't be in the history so we create a new file history
+ _, err = e.files.Create(ctx, sessionID, filePath, "")
+ if err != nil {
+ // The history entry could not be created; abort and return the error
+ return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
+ }
+
+ // Add the new content to the file history
+ _, err = e.files.CreateVersion(ctx, sessionID, filePath, content)
+ if err != nil {
+ // Log error but don't fail the operation
+ fmt.Printf("Error creating file history version: %v\n", err)
+ }
+
recordFileWrite(filePath)
recordFileRead(filePath)
@@ -298,6 +320,29 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string
if err != nil {
return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
}
+
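+ // Check if file exists in history; if the lookup fails, start a new history from the old content.
+ // NOTE(review): `file` is not reassigned from Create's result, so the comparison below runs against the value returned by the failed lookup — confirm this is intended.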
+ file, err := e.files.GetByPathAndSession(ctx, filePath, sessionID)
+ if err != nil {
+ _, err = e.files.Create(ctx, sessionID, filePath, oldContent)
+ if err != nil {
+ // The history entry could not be created; abort and return the error
+ return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
+ }
+ }
+ if file.Content != oldContent {
+ // The user manually changed the content; store an intermediate version
+ _, err = e.files.CreateVersion(ctx, sessionID, filePath, oldContent)
+ if err != nil {
+ fmt.Printf("Error creating file history version: %v\n", err)
+ }
+ }
+ // Store the new version
+ _, err = e.files.CreateVersion(ctx, sessionID, filePath, "")
+ if err != nil {
+ fmt.Printf("Error creating file history version: %v\n", err)
+ }
+
recordFileWrite(filePath)
recordFileRead(filePath)
@@ -356,6 +401,9 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS
newContent := oldContent[:index] + newString + oldContent[index+len(oldString):]
+ if oldContent == newContent {
+ return NewTextErrorResponse("new content is the same as old content. No changes made."), nil
+ }
sessionID, messageID := GetContextValues(ctx)
if sessionID == "" || messageID == "" {
@@ -374,8 +422,7 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS
Description: fmt.Sprintf("Replace content in file %s", filePath),
Params: EditPermissionsParams{
FilePath: filePath,
-
- Diff: diff,
+ Diff: diff,
},
},
)
@@ -388,6 +435,28 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS
return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
}
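+ // Check if file exists in history; if the lookup fails, start a new history from the old content.
+ // NOTE(review): `file` is not reassigned from Create's result, so the comparison below runs against the value returned by the failed lookup — confirm this is intended.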
+ file, err := e.files.GetByPathAndSession(ctx, filePath, sessionID)
+ if err != nil {
+ _, err = e.files.Create(ctx, sessionID, filePath, oldContent)
+ if err != nil {
+ // The history entry could not be created; abort and return the error
+ return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
+ }
+ }
+ if file.Content != oldContent {
+ // The user manually changed the content; store an intermediate version
+ _, err = e.files.CreateVersion(ctx, sessionID, filePath, oldContent)
+ if err != nil {
+ fmt.Printf("Error creating file history version: %v\n", err)
+ }
+ }
+ // Store the new version
+ _, err = e.files.CreateVersion(ctx, sessionID, filePath, newContent)
+ if err != nil {
+ fmt.Printf("Error creating file history version: %v\n", err)
+ }
+
recordFileWrite(filePath)
recordFileRead(filePath)
diff --git a/internal/llm/tools/edit_test.go b/internal/llm/tools/edit_test.go
index 48a34ed75..0971775dd 100644
--- a/internal/llm/tools/edit_test.go
+++ b/internal/llm/tools/edit_test.go
@@ -14,7 +14,7 @@ import (
)
func TestEditTool_Info(t *testing.T) {
- tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
info := tool.Info()
assert.Equal(t, EditToolName, info.Name)
@@ -34,7 +34,7 @@ func TestEditTool_Run(t *testing.T) {
defer os.RemoveAll(tempDir)
t.Run("creates a new file successfully", func(t *testing.T) {
- tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
filePath := filepath.Join(tempDir, "new_file.txt")
content := "This is a test content"
@@ -64,7 +64,7 @@ func TestEditTool_Run(t *testing.T) {
})
t.Run("creates file with nested directories", func(t *testing.T) {
- tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
filePath := filepath.Join(tempDir, "nested/dirs/new_file.txt")
content := "Content in nested directory"
@@ -94,7 +94,7 @@ func TestEditTool_Run(t *testing.T) {
})
t.Run("fails to create file that already exists", func(t *testing.T) {
- tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
// Create a file first
filePath := filepath.Join(tempDir, "existing_file.txt")
@@ -123,7 +123,7 @@ func TestEditTool_Run(t *testing.T) {
})
t.Run("fails to create file when path is a directory", func(t *testing.T) {
- tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
// Create a directory
dirPath := filepath.Join(tempDir, "test_dir")
@@ -151,7 +151,7 @@ func TestEditTool_Run(t *testing.T) {
})
t.Run("replaces content successfully", func(t *testing.T) {
- tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
// Create a file first
filePath := filepath.Join(tempDir, "replace_content.txt")
@@ -191,7 +191,7 @@ func TestEditTool_Run(t *testing.T) {
})
t.Run("deletes content successfully", func(t *testing.T) {
- tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
// Create a file first
filePath := filepath.Join(tempDir, "delete_content.txt")
@@ -230,7 +230,7 @@ func TestEditTool_Run(t *testing.T) {
})
t.Run("handles invalid parameters", func(t *testing.T) {
- tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
call := ToolCall{
Name: EditToolName,
@@ -243,7 +243,7 @@ func TestEditTool_Run(t *testing.T) {
})
t.Run("handles missing file_path", func(t *testing.T) {
- tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
params := EditParams{
FilePath: "",
@@ -265,7 +265,7 @@ func TestEditTool_Run(t *testing.T) {
})
t.Run("handles file not found", func(t *testing.T) {
- tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
filePath := filepath.Join(tempDir, "non_existent_file.txt")
params := EditParams{
@@ -288,7 +288,7 @@ func TestEditTool_Run(t *testing.T) {
})
t.Run("handles old_string not found in file", func(t *testing.T) {
- tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
// Create a file first
filePath := filepath.Join(tempDir, "content_not_found.txt")
@@ -320,7 +320,7 @@ func TestEditTool_Run(t *testing.T) {
})
t.Run("handles multiple occurrences of old_string", func(t *testing.T) {
- tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
// Create a file with duplicate content
filePath := filepath.Join(tempDir, "duplicate_content.txt")
@@ -352,7 +352,7 @@ func TestEditTool_Run(t *testing.T) {
})
t.Run("handles file modified since last read", func(t *testing.T) {
- tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
// Create a file
filePath := filepath.Join(tempDir, "modified_file.txt")
@@ -394,7 +394,7 @@ func TestEditTool_Run(t *testing.T) {
})
t.Run("handles file not read before editing", func(t *testing.T) {
- tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
// Create a file
filePath := filepath.Join(tempDir, "not_read_file.txt")
@@ -423,7 +423,7 @@ func TestEditTool_Run(t *testing.T) {
})
t.Run("handles permission denied", func(t *testing.T) {
- tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(false))
+ tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(false), newMockFileHistoryService())
// Create a file
filePath := filepath.Join(tempDir, "permission_denied.txt")
diff --git a/internal/llm/tools/file.go b/internal/llm/tools/file.go
index 9c9707c9c..7f34fdc1f 100644
--- a/internal/llm/tools/file.go
+++ b/internal/llm/tools/file.go
@@ -3,8 +3,6 @@ package tools
import (
"sync"
"time"
-
- "github.com/kujtimiihoxha/termai/internal/config"
)
// File record to track when files were read/written
@@ -19,14 +17,6 @@ var (
fileRecordMutex sync.RWMutex
)
-func removeWorkingDirectoryPrefix(path string) string {
- wd := config.WorkingDirectory()
- if len(path) > len(wd) && path[:len(wd)] == wd {
- return path[len(wd)+1:]
- }
- return path
-}
-
func recordFileRead(path string) {
fileRecordMutex.Lock()
defer fileRecordMutex.Unlock()
diff --git a/internal/llm/tools/glob.go b/internal/llm/tools/glob.go
index bdfc23b4a..7b4fb1187 100644
--- a/internal/llm/tools/glob.go
+++ b/internal/llm/tools/glob.go
@@ -63,7 +63,7 @@ type GlobParams struct {
Path string `json:"path"`
}
+// GlobResponseMetadata is the metadata attached to a glob tool response:
+// how many files were found and whether the output was truncated.
-type GlobMetadata struct {
+type GlobResponseMetadata struct {
NumberOfFiles int `json:"number_of_files"`
Truncated bool `json:"truncated"`
}
@@ -124,7 +124,7 @@ func (g *globTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
return WithResponseMetadata(
NewTextResponse(output),
- GlobMetadata{
+ GlobResponseMetadata{
NumberOfFiles: len(files),
Truncated: truncated,
},
diff --git a/internal/llm/tools/grep.go b/internal/llm/tools/grep.go
index 7e52821d0..19333f50b 100644
--- a/internal/llm/tools/grep.go
+++ b/internal/llm/tools/grep.go
@@ -27,7 +27,7 @@ type grepMatch struct {
modTime time.Time
}
+// GrepResponseMetadata is the metadata attached to a grep tool response:
+// how many matches were found and whether the output was truncated.
-type GrepMetadata struct {
+type GrepResponseMetadata struct {
NumberOfMatches int `json:"number_of_matches"`
Truncated bool `json:"truncated"`
}
@@ -134,7 +134,7 @@ func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
return WithResponseMetadata(
NewTextResponse(output),
- GrepMetadata{
+ GrepResponseMetadata{
NumberOfMatches: len(matches),
Truncated: truncated,
},
diff --git a/internal/llm/tools/ls.go b/internal/llm/tools/ls.go
index a679f261b..a63bf0eeb 100644
--- a/internal/llm/tools/ls.go
+++ b/internal/llm/tools/ls.go
@@ -23,7 +23,7 @@ type TreeNode struct {
Children []*TreeNode `json:"children,omitempty"`
}
+// LSResponseMetadata is the metadata attached to an ls tool response:
+// how many files were listed and whether the output was truncated.
-type LSMetadata struct {
+type LSResponseMetadata struct {
NumberOfFiles int `json:"number_of_files"`
Truncated bool `json:"truncated"`
}
@@ -121,7 +121,7 @@ func (l *lsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
return WithResponseMetadata(
NewTextResponse(output),
- LSMetadata{
+ LSResponseMetadata{
NumberOfFiles: len(files),
Truncated: truncated,
},
diff --git a/internal/llm/tools/mocks_test.go b/internal/llm/tools/mocks_test.go
new file mode 100644
index 000000000..321f09ac1
--- /dev/null
+++ b/internal/llm/tools/mocks_test.go
@@ -0,0 +1,246 @@
+package tools
+
+import (
+ "context"
+ "fmt"
+ "sort"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/google/uuid"
+ "github.com/kujtimiihoxha/termai/internal/history"
+ "github.com/kujtimiihoxha/termai/internal/permission"
+ "github.com/kujtimiihoxha/termai/internal/pubsub"
+)
+
+// Mock permission service for testing
+type mockPermissionService struct {
+ *pubsub.Broker[permission.PermissionRequest]
+ allow bool
+}
+
+func (m *mockPermissionService) GrantPersistant(permission permission.PermissionRequest) {
+ // Not needed for tests
+}
+
+func (m *mockPermissionService) Grant(permission permission.PermissionRequest) {
+ // Not needed for tests
+}
+
+func (m *mockPermissionService) Deny(permission permission.PermissionRequest) {
+ // Not needed for tests
+}
+
+func (m *mockPermissionService) Request(opts permission.CreatePermissionRequest) bool {
+ return m.allow
+}
+
+func newMockPermissionService(allow bool) permission.Service {
+ return &mockPermissionService{
+ Broker: pubsub.NewBroker[permission.PermissionRequest](),
+ allow: allow,
+ }
+}
+
+type mockFileHistoryService struct {
+ *pubsub.Broker[history.File]
+ files map[string]history.File // ID -> File
+ timeNow func() int64
+}
+
+// Create implements history.Service.
+func (m *mockFileHistoryService) Create(ctx context.Context, sessionID string, path string, content string) (history.File, error) {
+ return m.createWithVersion(ctx, sessionID, path, content, history.InitialVersion)
+}
+
+// CreateVersion implements history.Service.
+func (m *mockFileHistoryService) CreateVersion(ctx context.Context, sessionID string, path string, content string) (history.File, error) {
+ var files []history.File
+ for _, file := range m.files {
+ if file.Path == path {
+ files = append(files, file)
+ }
+ }
+
+ if len(files) == 0 {
+ // No previous versions, create initial
+ return m.Create(ctx, sessionID, path, content)
+ }
+
+ // Sort files by CreatedAt in descending order
+ sort.Slice(files, func(i, j int) bool {
+ return files[i].CreatedAt > files[j].CreatedAt
+ })
+
+ // Get the latest version
+ latestFile := files[0]
+ latestVersion := latestFile.Version
+
+ // Generate the next version
+ var nextVersion string
+ if latestVersion == history.InitialVersion {
+ nextVersion = "v1"
+ } else if strings.HasPrefix(latestVersion, "v") {
+ versionNum, err := strconv.Atoi(latestVersion[1:])
+ if err != nil {
+ // If we can't parse the version, just use a timestamp-based version
+ nextVersion = fmt.Sprintf("v%d", latestFile.CreatedAt)
+ } else {
+ nextVersion = fmt.Sprintf("v%d", versionNum+1)
+ }
+ } else {
+ // If the version format is unexpected, use a timestamp-based version
+ nextVersion = fmt.Sprintf("v%d", latestFile.CreatedAt)
+ }
+
+ return m.createWithVersion(ctx, sessionID, path, content, nextVersion)
+}
+
+func (m *mockFileHistoryService) createWithVersion(_ context.Context, sessionID, path, content, version string) (history.File, error) {
+ now := m.timeNow()
+ file := history.File{
+ ID: uuid.New().String(),
+ SessionID: sessionID,
+ Path: path,
+ Content: content,
+ Version: version,
+ CreatedAt: now,
+ UpdatedAt: now,
+ }
+
+ m.files[file.ID] = file
+ m.Publish(pubsub.CreatedEvent, file)
+ return file, nil
+}
+
+// Delete implements history.Service.
+func (m *mockFileHistoryService) Delete(ctx context.Context, id string) error {
+ file, ok := m.files[id]
+ if !ok {
+ return fmt.Errorf("file not found: %s", id)
+ }
+
+ delete(m.files, id)
+ m.Publish(pubsub.DeletedEvent, file)
+ return nil
+}
+
+// DeleteSessionFiles implements history.Service.
+func (m *mockFileHistoryService) DeleteSessionFiles(ctx context.Context, sessionID string) error {
+ files, err := m.ListBySession(ctx, sessionID)
+ if err != nil {
+ return err
+ }
+
+ for _, file := range files {
+ err = m.Delete(ctx, file.ID)
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// Get implements history.Service.
+func (m *mockFileHistoryService) Get(ctx context.Context, id string) (history.File, error) {
+ file, ok := m.files[id]
+ if !ok {
+ return history.File{}, fmt.Errorf("file not found: %s", id)
+ }
+ return file, nil
+}
+
+// GetByPathAndSession implements history.Service.
+func (m *mockFileHistoryService) GetByPathAndSession(ctx context.Context, path string, sessionID string) (history.File, error) {
+ var latestFile history.File
+ var found bool
+ var latestTime int64
+
+ for _, file := range m.files {
+ if file.Path == path && file.SessionID == sessionID {
+ if !found || file.CreatedAt > latestTime {
+ latestFile = file
+ latestTime = file.CreatedAt
+ found = true
+ }
+ }
+ }
+
+ if !found {
+ return history.File{}, fmt.Errorf("file not found: %s for session %s", path, sessionID)
+ }
+ return latestFile, nil
+}
+
+// ListBySession implements history.Service.
+func (m *mockFileHistoryService) ListBySession(ctx context.Context, sessionID string) ([]history.File, error) {
+ var files []history.File
+ for _, file := range m.files {
+ if file.SessionID == sessionID {
+ files = append(files, file)
+ }
+ }
+
+ // Sort by CreatedAt in descending order
+ sort.Slice(files, func(i, j int) bool {
+ return files[i].CreatedAt > files[j].CreatedAt
+ })
+
+ return files, nil
+}
+
+// ListLatestSessionFiles implements history.Service.
+func (m *mockFileHistoryService) ListLatestSessionFiles(ctx context.Context, sessionID string) ([]history.File, error) {
+ // Map to track the latest file for each path
+ latestFiles := make(map[string]history.File)
+
+ for _, file := range m.files {
+ if file.SessionID == sessionID {
+ existing, ok := latestFiles[file.Path]
+ if !ok || file.CreatedAt > existing.CreatedAt {
+ latestFiles[file.Path] = file
+ }
+ }
+ }
+
+ // Convert map to slice
+ var result []history.File
+ for _, file := range latestFiles {
+ result = append(result, file)
+ }
+
+ // Sort by CreatedAt in descending order
+ sort.Slice(result, func(i, j int) bool {
+ return result[i].CreatedAt > result[j].CreatedAt
+ })
+
+ return result, nil
+}
+
+// Subscribe implements history.Service.
+func (m *mockFileHistoryService) Subscribe(ctx context.Context) <-chan pubsub.Event[history.File] {
+ return m.Broker.Subscribe(ctx)
+}
+
+// Update implements history.Service.
+func (m *mockFileHistoryService) Update(ctx context.Context, file history.File) (history.File, error) {
+ _, ok := m.files[file.ID]
+ if !ok {
+ return history.File{}, fmt.Errorf("file not found: %s", file.ID)
+ }
+
+ file.UpdatedAt = m.timeNow()
+ m.files[file.ID] = file
+ m.Publish(pubsub.UpdatedEvent, file)
+ return file, nil
+}
+
+func newMockFileHistoryService() history.Service {
+ return &mockFileHistoryService{
+ Broker: pubsub.NewBroker[history.File](),
+ files: make(map[string]history.File),
+ timeNow: func() int64 { return time.Now().Unix() },
+ }
+}
diff --git a/internal/llm/tools/shell/shell.go b/internal/llm/tools/shell/shell.go
index 64592f67d..4a776478a 100644
--- a/internal/llm/tools/shell/shell.go
+++ b/internal/llm/tools/shell/shell.go
@@ -83,11 +83,21 @@ func newPersistentShell(cwd string) *PersistentShell {
commandQueue: make(chan *commandExecution, 10),
}
- go shell.processCommands()
+ go func() {
+ defer func() {
+ if r := recover(); r != nil {
+ fmt.Fprintf(os.Stderr, "Panic in shell command processor: %v\n", r)
+ shell.isAlive = false
+ close(shell.commandQueue)
+ }
+ }()
+ shell.processCommands()
+ }()
go func() {
err := cmd.Wait()
if err != nil {
+ // Log the error if needed
}
shell.isAlive = false
close(shell.commandQueue)
diff --git a/internal/llm/tools/sourcegraph.go b/internal/llm/tools/sourcegraph.go
index 17bc610ea..a6f2c8afb 100644
--- a/internal/llm/tools/sourcegraph.go
+++ b/internal/llm/tools/sourcegraph.go
@@ -18,7 +18,7 @@ type SourcegraphParams struct {
Timeout int `json:"timeout,omitempty"`
}
-type SourcegraphMetadata struct {
+type SourcegraphResponseMetadata struct {
NumberOfMatches int `json:"number_of_matches"`
Truncated bool `json:"truncated"`
}
diff --git a/internal/llm/tools/tools.go b/internal/llm/tools/tools.go
index 07afe1363..bf0f8df0b 100644
--- a/internal/llm/tools/tools.go
+++ b/internal/llm/tools/tools.go
@@ -14,12 +14,17 @@ type ToolInfo struct {
type toolResponseType string
+type (
+ sessionIDContextKey string
+ messageIDContextKey string
+)
+
const (
ToolResponseTypeText toolResponseType = "text"
ToolResponseTypeImage toolResponseType = "image"
- SessionIDContextKey = "session_id"
- MessageIDContextKey = "message_id"
+ SessionIDContextKey sessionIDContextKey = "session_id"
+ MessageIDContextKey messageIDContextKey = "message_id"
)
type ToolResponse struct {
diff --git a/internal/llm/tools/write.go b/internal/llm/tools/write.go
index 889561d2a..bb49381fd 100644
--- a/internal/llm/tools/write.go
+++ b/internal/llm/tools/write.go
@@ -10,6 +10,7 @@ import (
"github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/diff"
+ "github.com/kujtimiihoxha/termai/internal/history"
"github.com/kujtimiihoxha/termai/internal/lsp"
"github.com/kujtimiihoxha/termai/internal/permission"
)
@@ -27,6 +28,7 @@ type WritePermissionsParams struct {
type writeTool struct {
lspClients map[string]*lsp.Client
permissions permission.Service
+ files history.Service
}
type WriteResponseMetadata struct {
@@ -67,10 +69,11 @@ TIPS:
- Always include descriptive comments when making changes to existing code`
)
-func NewWriteTool(lspClients map[string]*lsp.Client, permissions permission.Service) BaseTool {
+func NewWriteTool(lspClients map[string]*lsp.Client, permissions permission.Service, files history.Service) BaseTool {
return &writeTool{
lspClients: lspClients,
permissions: permissions,
+ files: files,
}
}
@@ -176,6 +179,28 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
return ToolResponse{}, fmt.Errorf("error writing file: %w", err)
}
+ // Check if file exists in history
+ file, err := w.files.GetByPathAndSession(ctx, filePath, sessionID)
+ if err != nil {
+ _, err = w.files.Create(ctx, sessionID, filePath, oldContent)
+ if err != nil {
+ // Unlike the version writes below, a failure to create the initial history entry aborts the operation
+ return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
+ }
+ }
+ if file.Content != oldContent {
+ // User manually changed the content; store an intermediate version
+ _, err = w.files.CreateVersion(ctx, sessionID, filePath, oldContent)
+ if err != nil {
+ fmt.Printf("Error creating file history version: %v\n", err)
+ }
+ }
+ // Store the new version
+ _, err = w.files.CreateVersion(ctx, sessionID, filePath, params.Content)
+ if err != nil {
+ fmt.Printf("Error creating file history version: %v\n", err)
+ }
+
recordFileWrite(filePath)
recordFileRead(filePath)
waitForLspDiagnostics(ctx, filePath, w.lspClients)
diff --git a/internal/llm/tools/write_test.go b/internal/llm/tools/write_test.go
index 50dafc14f..2264f36fb 100644
--- a/internal/llm/tools/write_test.go
+++ b/internal/llm/tools/write_test.go
@@ -14,7 +14,7 @@ import (
)
func TestWriteTool_Info(t *testing.T) {
- tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
info := tool.Info()
assert.Equal(t, WriteToolName, info.Name)
@@ -32,7 +32,7 @@ func TestWriteTool_Run(t *testing.T) {
defer os.RemoveAll(tempDir)
t.Run("creates a new file successfully", func(t *testing.T) {
- tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
filePath := filepath.Join(tempDir, "new_file.txt")
content := "This is a test content"
@@ -61,7 +61,7 @@ func TestWriteTool_Run(t *testing.T) {
})
t.Run("creates file with nested directories", func(t *testing.T) {
- tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
filePath := filepath.Join(tempDir, "nested/dirs/new_file.txt")
content := "Content in nested directory"
@@ -90,7 +90,7 @@ func TestWriteTool_Run(t *testing.T) {
})
t.Run("updates existing file", func(t *testing.T) {
- tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
// Create a file first
filePath := filepath.Join(tempDir, "existing_file.txt")
@@ -127,7 +127,7 @@ func TestWriteTool_Run(t *testing.T) {
})
t.Run("handles invalid parameters", func(t *testing.T) {
- tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
call := ToolCall{
Name: WriteToolName,
@@ -140,7 +140,7 @@ func TestWriteTool_Run(t *testing.T) {
})
t.Run("handles missing file_path", func(t *testing.T) {
- tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
params := WriteParams{
FilePath: "",
@@ -161,7 +161,7 @@ func TestWriteTool_Run(t *testing.T) {
})
t.Run("handles missing content", func(t *testing.T) {
- tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
params := WriteParams{
FilePath: filepath.Join(tempDir, "file.txt"),
@@ -182,7 +182,7 @@ func TestWriteTool_Run(t *testing.T) {
})
t.Run("handles writing to a directory path", func(t *testing.T) {
- tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
// Create a directory
dirPath := filepath.Join(tempDir, "test_dir")
@@ -208,7 +208,7 @@ func TestWriteTool_Run(t *testing.T) {
})
t.Run("handles permission denied", func(t *testing.T) {
- tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(false))
+ tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(false), newMockFileHistoryService())
filePath := filepath.Join(tempDir, "permission_denied.txt")
params := WriteParams{
@@ -234,7 +234,7 @@ func TestWriteTool_Run(t *testing.T) {
})
t.Run("detects file modified since last read", func(t *testing.T) {
- tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
// Create a file
filePath := filepath.Join(tempDir, "modified_file.txt")
@@ -275,7 +275,7 @@ func TestWriteTool_Run(t *testing.T) {
})
t.Run("skips writing when content is identical", func(t *testing.T) {
- tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true))
+ tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService())
// Create a file
filePath := filepath.Join(tempDir, "identical_content.txt")