diff options
Diffstat (limited to 'internal/llm')
31 files changed, 2037 insertions, 1541 deletions
diff --git a/internal/llm/agent/agent-tool.go b/internal/llm/agent/agent-tool.go index 83160bb64..308412bde 100644 --- a/internal/llm/agent/agent-tool.go +++ b/internal/llm/agent/agent-tool.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" + "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/llm/tools" "github.com/kujtimiihoxha/termai/internal/lsp" "github.com/kujtimiihoxha/termai/internal/message" @@ -53,7 +54,7 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes return tools.ToolResponse{}, fmt.Errorf("session_id and message_id are required") } - agent, err := NewTaskAgent(b.messages, b.sessions, b.lspClients) + agent, err := NewAgent(config.AgentTask, b.sessions, b.messages, TaskAgentTools(b.lspClients)) if err != nil { return tools.ToolResponse{}, fmt.Errorf("error creating agent: %s", err) } @@ -63,21 +64,16 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes return tools.ToolResponse{}, fmt.Errorf("error creating session: %s", err) } - err = agent.Generate(ctx, session.ID, params.Prompt) + done, err := agent.Run(ctx, session.ID, params.Prompt) if err != nil { return tools.ToolResponse{}, fmt.Errorf("error generating agent: %s", err) } - - messages, err := b.messages.List(ctx, session.ID) - if err != nil { - return tools.ToolResponse{}, fmt.Errorf("error listing messages: %s", err) - } - - if len(messages) == 0 { - return tools.NewTextErrorResponse("no response"), nil + result := <-done + if result.Err() != nil { + return tools.ToolResponse{}, fmt.Errorf("error generating agent: %s", result.Err()) } - response := messages[len(messages)-1] + response := result.Response() if response.Role != message.Assistant { return tools.NewTextErrorResponse("no response"), nil } diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 1958111a1..ab2742ec1 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -4,8 +4,6 @@ import ( "context" "errors" "fmt" - "os" - "runtime/debug" "strings" "sync" @@ -16,133 +14,101 @@ import ( "github.com/kujtimiihoxha/termai/internal/llm/tools" "github.com/kujtimiihoxha/termai/internal/logging" "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/termai/internal/permission" "github.com/kujtimiihoxha/termai/internal/session" ) // Common errors var ( - ErrProviderNotEnabled = errors.New("provider is not enabled") - ErrRequestCancelled = errors.New("request cancelled by user") - ErrSessionBusy = errors.New("session is currently processing another request") + ErrRequestCancelled = errors.New("request cancelled by user") + ErrSessionBusy = errors.New("session is currently processing another request") ) -// Service defines the interface for generating responses +type AgentEvent struct { + message message.Message + err error +} + +func (e *AgentEvent) Err() error { + return e.err +} + +func (e *AgentEvent) Response() message.Message { + return e.message +} + type Service interface { - Generate(ctx context.Context, sessionID string, content string) error - Cancel(sessionID string) error + Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error) + Cancel(sessionID string) + IsSessionBusy(sessionID string) bool } type agent struct { - sessions session.Service - messages message.Service - model models.Model - tools []tools.BaseTool - agent provider.Provider - titleGenerator provider.Provider - activeRequests sync.Map // map[sessionID]context.CancelFunc + sessions session.Service + messages message.Service + + tools []tools.BaseTool + provider provider.Provider + + titleProvider provider.Provider + + activeRequests sync.Map } -// NewAgent creates a new agent instance with the given model and tools -func NewAgent(ctx context.Context, sessions session.Service, messages message.Service, model models.Model, tools []tools.BaseTool) (Service, error) { - agentProvider, titleGenerator, err := getAgentProviders(ctx, model) +func NewAgent( + agentName config.AgentName, + sessions session.Service, + messages message.Service, + agentTools []tools.BaseTool, +) (Service, error) { + agentProvider, err := createAgentProvider(agentName) if err != nil { - return nil, fmt.Errorf("failed to initialize providers: %w", err) + return nil, err + } + var titleProvider provider.Provider + // Only generate titles for the coder agent + if agentName == config.AgentCoder { + titleProvider, err = createAgentProvider(config.AgentTitle) + if err != nil { + return nil, err + } } - return &agent{ - model: model, - tools: tools, - sessions: sessions, + agent := &agent{ + provider: agentProvider, messages: messages, - agent: agentProvider, - titleGenerator: titleGenerator, + sessions: sessions, + tools: agentTools, + titleProvider: titleProvider, activeRequests: sync.Map{}, - }, nil + } + + return agent, nil } -// Cancel cancels an active request by session ID -func (a *agent) Cancel(sessionID string) error { +func (a *agent) Cancel(sessionID string) { if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists { if cancel, ok := cancelFunc.(context.CancelFunc); ok { logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID)) cancel() - return nil } } - return errors.New("no active request found for this session") } -// Generate starts the generation process -func (a *agent) Generate(ctx context.Context, sessionID string, content string) error { - // Check if this session already has an active request - if _, busy := a.activeRequests.Load(sessionID); busy { - return ErrSessionBusy - } - - // Create a cancellable context - genCtx, cancel := context.WithCancel(ctx) - - // Store cancel function to allow user cancellation - a.activeRequests.Store(sessionID, cancel) - - // Launch the generation in a goroutine - go func() { - defer func() { - if r := recover(); r != nil { - logging.ErrorPersist(fmt.Sprintf("Panic in Generate: %v", r)) - - // dump stack trace into a file - file, err := os.Create("panic.log") - if err != nil { - logging.ErrorPersist(fmt.Sprintf("Failed to create panic log: %v", err)) - return - } - - defer file.Close() - - stackTrace := debug.Stack() - if _, err := file.Write(stackTrace); err != nil { - logging.ErrorPersist(fmt.Sprintf("Failed to write panic log: %v", err)) - } - - } - }() - defer a.activeRequests.Delete(sessionID) - defer cancel() - - if err := a.generate(genCtx, sessionID, content); err != nil { - if !errors.Is(err, ErrRequestCancelled) && !errors.Is(err, context.Canceled) { - // Log the error (avoid logging cancellations as they're expected) - logging.ErrorPersist(fmt.Sprintf("Generation error for session %s: %v", sessionID, err)) - - // You may want to create an error message in the chat - bgCtx := context.Background() - errorMsg := fmt.Sprintf("Sorry, an error occurred: %v", err) - _, createErr := a.messages.Create(bgCtx, sessionID, message.CreateMessageParams{ - Role: message.System, - Parts: []message.ContentPart{ - message.TextContent{ - Text: errorMsg, - }, - }, - }) - if createErr != nil { - logging.ErrorPersist(fmt.Sprintf("Failed to create error message: %v", createErr)) - } - } - } - }() - - return nil -} - -// IsSessionBusy checks if a session currently has an active request func (a *agent) IsSessionBusy(sessionID string) bool { _, busy := a.activeRequests.Load(sessionID) return busy -} // handleTitleGeneration asynchronously generates a title for new sessions -func (a *agent) handleTitleGeneration(ctx context.Context, sessionID, content string) { - response, err := a.titleGenerator.SendMessages( +} + +func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error { + if a.titleProvider == nil { + return nil + } + session, err := a.sessions.Get(ctx, sessionID) + if err != nil { + return err + } + response, err := a.titleProvider.SendMessages( ctx, []message.Message{ { @@ -154,121 +120,152 @@ func (a *agent) handleTitleGeneration(ctx context.Context, sessionID, content st }, }, }, - nil, + make([]tools.BaseTool, 0), ) if err != nil { - logging.ErrorPersist(fmt.Sprintf("Failed to generate title: %v", err)) - return + return err } - session, err := a.sessions.Get(ctx, sessionID) - if err != nil { - logging.ErrorPersist(fmt.Sprintf("Failed to get session: %v", err)) - return + title := strings.TrimSpace(strings.ReplaceAll(response.Content, "\n", " ")) + if title == "" { + return nil } - if response.Content != "" { - session.Title = strings.TrimSpace(response.Content) - session.Title = strings.ReplaceAll(session.Title, "\n", " ") - if _, err := a.sessions.Save(ctx, session); err != nil { - logging.ErrorPersist(fmt.Sprintf("Failed to save session title: %v", err)) - } + session.Title = title + _, err = a.sessions.Save(ctx, session) + return err +} + +func (a *agent) err(err error) AgentEvent { + return AgentEvent{ + err: err, } } -// TrackUsage updates token usage statistics for the session -func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error { - session, err := a.sessions.Get(ctx, sessionID) - if err != nil { - return fmt.Errorf("failed to get session: %w", err) +func (a *agent) Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error) { + events := make(chan AgentEvent) + if a.IsSessionBusy(sessionID) { + return nil, ErrSessionBusy } - cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) + - model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) + - model.CostPer1MIn/1e6*float64(usage.InputTokens) + - model.CostPer1MOut/1e6*float64(usage.OutputTokens) + genCtx, cancel := context.WithCancel(ctx) + + a.activeRequests.Store(sessionID, cancel) + go func() { + logging.Debug("Request started", "sessionID", sessionID) + defer logging.RecoverPanic("agent.Run", func() { + events <- a.err(fmt.Errorf("panic while running the agent")) + }) - session.Cost += cost - session.CompletionTokens += usage.OutputTokens - session.PromptTokens += usage.InputTokens + result := a.processGeneration(genCtx, sessionID, content) + if result.Err() != nil && !errors.Is(result.Err(), ErrRequestCancelled) && !errors.Is(result.Err(), context.Canceled) { + logging.ErrorPersist(fmt.Sprintf("Generation error for session %s: %v", sessionID, result)) + } + logging.Debug("Request completed", "sessionID", sessionID) + a.activeRequests.Delete(sessionID) + cancel() + events <- result + close(events) + }() + return events, nil +} - _, err = a.sessions.Save(ctx, session) +func (a *agent) processGeneration(ctx context.Context, sessionID, content string) AgentEvent { + // List existing messages; if none, start title generation asynchronously. + msgs, err := a.messages.List(ctx, sessionID) if err != nil { - return fmt.Errorf("failed to save session: %w", err) + return a.err(fmt.Errorf("failed to list messages: %w", err)) + } + if len(msgs) == 0 { + go func() { + defer logging.RecoverPanic("agent.Run", func() { + logging.ErrorPersist("panic while generating title") + }) + titleErr := a.generateTitle(context.Background(), sessionID, content) + if titleErr != nil { + logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr)) + } + }() } - return nil -} -// processEvent handles different types of events during generation -func (a *agent) processEvent( - ctx context.Context, - sessionID string, - assistantMsg *message.Message, - event provider.ProviderEvent, -) error { - select { - case <-ctx.Done(): - return ctx.Err() - default: - // Continue processing + userMsg, err := a.createUserMessage(ctx, sessionID, content) + if err != nil { + return a.err(fmt.Errorf("failed to create user message: %w", err)) } - switch event.Type { - case provider.EventThinkingDelta: - assistantMsg.AppendReasoningContent(event.Content) - return a.messages.Update(ctx, *assistantMsg) - case provider.EventContentDelta: - assistantMsg.AppendContent(event.Content) - return a.messages.Update(ctx, *assistantMsg) - case provider.EventError: - if errors.Is(event.Error, context.Canceled) { - logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID)) - return context.Canceled + // Append the new user message to the conversation history. + msgHistory := append(msgs, userMsg) + for { + // Check for cancellation before each iteration + select { + case <-ctx.Done(): + return a.err(ctx.Err()) + default: + // Continue processing } - logging.ErrorPersist(event.Error.Error()) - return event.Error - case provider.EventWarning: - logging.WarnPersist(event.Info) - case provider.EventInfo: - logging.InfoPersist(event.Info) - case provider.EventComplete: - assistantMsg.SetToolCalls(event.Response.ToolCalls) - assistantMsg.AddFinish(event.Response.FinishReason) - if err := a.messages.Update(ctx, *assistantMsg); err != nil { - return fmt.Errorf("failed to update message: %w", err) + agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory) + if err != nil { + if errors.Is(err, context.Canceled) { + return a.err(ErrRequestCancelled) + } + return a.err(fmt.Errorf("failed to process events: %w", err)) + } + logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults) + if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil { + // We are not done, we need to respond with the tool response + msgHistory = append(msgHistory, agentMessage, *toolResults) + continue + } + return AgentEvent{ + message: agentMessage, } - return a.TrackUsage(ctx, sessionID, a.model, event.Response.Usage) } +} - return nil +func (a *agent) createUserMessage(ctx context.Context, sessionID, content string) (message.Message, error) { + return a.messages.Create(ctx, sessionID, message.CreateMessageParams{ + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: content}, + }, + }) } -// ExecuteTools runs all tool calls sequentially and returns the results -func (a *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) { - toolResults := make([]message.ToolResult, len(toolCalls)) +func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) { + eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools) + + assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{ + Role: message.Assistant, + Parts: []message.ContentPart{}, + Model: a.provider.Model().ID, + }) + if err != nil { + return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err) + } - // Create a child context that can be canceled - ctx, cancel := context.WithCancel(ctx) - defer cancel() + // Add the session and message ID into the context if needed by tools. + ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID) + ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID) - // Check if already canceled before starting any execution - if ctx.Err() != nil { - // Mark all tools as canceled - for i, toolCall := range toolCalls { - toolResults[i] = message.ToolResult{ - ToolCallID: toolCall.ID, - Content: "Tool execution canceled by user", - IsError: true, - } + // Process each event in the stream. + for event := range eventChan { + if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil { + a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled) + return assistantMsg, nil, processErr + } + if ctx.Err() != nil { + a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled) + return assistantMsg, nil, ctx.Err() } - return toolResults, ctx.Err() } + toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls())) + toolCalls := assistantMsg.ToolCalls() for i, toolCall := range toolCalls { - // Check for cancellation before executing each tool select { case <-ctx.Done(): - // Mark this and all remaining tools as canceled + a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled) + // Make all future tool calls cancelled for j := i; j < len(toolCalls); j++ { toolResults[j] = message.ToolResult{ ToolCallID: toolCalls[j].ID, @@ -276,412 +273,180 @@ func (a *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, IsError: true, } } - return toolResults, ctx.Err() + goto out default: // Continue processing - } - - response := "" - isError := false - found := false - - // Find and execute the appropriate tool - for _, tool := range tls { - if tool.Info().Name == toolCall.Name { - found = true - toolResult, toolErr := tool.Run(ctx, tools.ToolCall{ - ID: toolCall.ID, - Name: toolCall.Name, - Input: toolCall.Input, - }) - - if toolErr != nil { - if errors.Is(toolErr, context.Canceled) { - response = "Tool execution canceled by user" - } else { - response = fmt.Sprintf("Error running tool: %s", toolErr) - } - isError = true - } else { - response = toolResult.Content - isError = toolResult.IsError + var tool tools.BaseTool + for _, availableTools := range a.tools { + if availableTools.Info().Name == toolCall.Name { + tool = availableTools } - break } - } - - if !found { - response = fmt.Sprintf("Tool not found: %s", toolCall.Name) - isError = true - } - - toolResults[i] = message.ToolResult{ - ToolCallID: toolCall.ID, - Content: response, - IsError: isError, - } - } - return toolResults, nil -} - -// handleToolExecution processes tool calls and creates tool result messages -func (a *agent) handleToolExecution( - ctx context.Context, - assistantMsg message.Message, -) (*message.Message, error) { - select { - case <-ctx.Done(): - // If cancelled, create tool results that indicate cancellation - if len(assistantMsg.ToolCalls()) > 0 { - toolResults := make([]message.ToolResult, 0, len(assistantMsg.ToolCalls())) - for _, tc := range assistantMsg.ToolCalls() { - toolResults = append(toolResults, message.ToolResult{ - ToolCallID: tc.ID, - Content: "Tool execution canceled by user", + // Tool not found + if tool == nil { + toolResults[i] = message.ToolResult{ + ToolCallID: toolCall.ID, + Content: fmt.Sprintf("Tool not found: %s", toolCall.Name), IsError: true, - }) + } + continue } - // Use background context to ensure the message is created even if original context is cancelled - bgCtx := context.Background() - parts := make([]message.ContentPart, 0) - for _, toolResult := range toolResults { - parts = append(parts, toolResult) - } - msg, err := a.messages.Create(bgCtx, assistantMsg.SessionID, message.CreateMessageParams{ - Role: message.Tool, - Parts: parts, + toolResult, toolErr := tool.Run(ctx, tools.ToolCall{ + ID: toolCall.ID, + Name: toolCall.Name, + Input: toolCall.Input, }) - if err != nil { - return nil, fmt.Errorf("failed to create cancelled tool message: %w", err) - } - return &msg, ctx.Err() - } - return nil, ctx.Err() - default: - // Continue processing - } - - if len(assistantMsg.ToolCalls()) == 0 { - return nil, nil - } - - toolResults, err := a.ExecuteTools(ctx, assistantMsg.ToolCalls(), a.tools) - if err != nil { - // If error is from cancellation, still return the partial results we have - if errors.Is(err, context.Canceled) { - // Use background context to ensure the message is created even if original context is cancelled - bgCtx := context.Background() - parts := make([]message.ContentPart, 0) - for _, toolResult := range toolResults { - parts = append(parts, toolResult) + if toolErr != nil { + if errors.Is(toolErr, permission.ErrorPermissionDenied) { + toolResults[i] = message.ToolResult{ + ToolCallID: toolCall.ID, + Content: "Permission denied", + IsError: true, + } + for j := i + 1; j < len(toolCalls); j++ { + toolResults[j] = message.ToolResult{ + ToolCallID: toolCalls[j].ID, + Content: "Tool execution canceled by user", + IsError: true, + } + } + a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied) + } else { + toolResults[i] = message.ToolResult{ + ToolCallID: toolCall.ID, + Content: toolErr.Error(), + IsError: true, + } + for j := i; j < len(toolCalls); j++ { + toolResults[j] = message.ToolResult{ + ToolCallID: toolCalls[j].ID, + Content: "Previous tool failed", + IsError: true, + } + } + a.finishMessage(ctx, &assistantMsg, message.FinishReasonError) + } + // If permission is denied or an error happens we cancel all the following tools + break } - - msg, createErr := a.messages.Create(bgCtx, assistantMsg.SessionID, message.CreateMessageParams{ - Role: message.Tool, - Parts: parts, - }) - if createErr != nil { - logging.ErrorPersist(fmt.Sprintf("Failed to create tool message after cancellation: %v", createErr)) - return nil, err + toolResults[i] = message.ToolResult{ + ToolCallID: toolCall.ID, + Content: toolResult.Content, + Metadata: toolResult.Metadata, + IsError: toolResult.IsError, } - return &msg, err } - return nil, err } - - parts := make([]message.ContentPart, 0, len(toolResults)) - for _, toolResult := range toolResults { - parts = append(parts, toolResult) +out: + if len(toolResults) == 0 { + return assistantMsg, nil, nil } - - msg, err := a.messages.Create(ctx, assistantMsg.SessionID, message.CreateMessageParams{ + parts := make([]message.ContentPart, 0) + for _, tr := range toolResults { + parts = append(parts, tr) + } + msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{ Role: message.Tool, Parts: parts, }) if err != nil { - return nil, fmt.Errorf("failed to create tool message: %w", err) + return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err) } - return &msg, nil + return assistantMsg, &msg, err } -// generate handles the main generation workflow -func (a *agent) generate(ctx context.Context, sessionID string, content string) error { - ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID) +func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) { + msg.AddFinish(finishReson) + _ = a.messages.Update(ctx, *msg) +} - // Handle context cancellation at any point - if err := ctx.Err(); err != nil { - return ErrRequestCancelled +func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + // Continue processing. } - messages, err := a.messages.List(ctx, sessionID) - if err != nil { - return fmt.Errorf("failed to list messages: %w", err) + switch event.Type { + case provider.EventThinkingDelta: + assistantMsg.AppendReasoningContent(event.Content) + return a.messages.Update(ctx, *assistantMsg) + case provider.EventContentDelta: + assistantMsg.AppendContent(event.Content) + return a.messages.Update(ctx, *assistantMsg) + case provider.EventError: + if errors.Is(event.Error, context.Canceled) { + logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID)) + return context.Canceled + } + logging.ErrorPersist(event.Error.Error()) + return event.Error + case provider.EventComplete: + assistantMsg.SetToolCalls(event.Response.ToolCalls) + assistantMsg.AddFinish(event.Response.FinishReason) + if err := a.messages.Update(ctx, *assistantMsg); err != nil { + return fmt.Errorf("failed to update message: %w", err) + } + return a.TrackUsage(ctx, sessionID, a.provider.Model(), event.Response.Usage) } - if len(messages) == 0 { - titleCtx := context.Background() - go a.handleTitleGeneration(titleCtx, sessionID, content) - } + return nil +} - userMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{ - Role: message.User, - Parts: []message.ContentPart{ - message.TextContent{ - Text: content, - }, - }, - }) +func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error { + sess, err := a.sessions.Get(ctx, sessionID) if err != nil { - return fmt.Errorf("failed to create user message: %w", err) + return fmt.Errorf("failed to get session: %w", err) } - messages = append(messages, userMsg) - - for { - // Check for cancellation before each iteration - select { - case <-ctx.Done(): - return ErrRequestCancelled - default: - // Continue processing - } - - eventChan, err := a.agent.StreamResponse(ctx, messages, a.tools) - if err != nil { - if errors.Is(err, context.Canceled) { - return ErrRequestCancelled - } - return fmt.Errorf("failed to stream response: %w", err) - } - - assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{ - Role: message.Assistant, - Parts: []message.ContentPart{}, - Model: a.model.ID, - }) - if err != nil { - return fmt.Errorf("failed to create assistant message: %w", err) - } - - ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID) - - // Process events from the LLM provider - for event := range eventChan { - if err := a.processEvent(ctx, sessionID, &assistantMsg, event); err != nil { - if errors.Is(err, context.Canceled) { - // Mark as canceled but don't create separate message - assistantMsg.AddFinish("canceled") - _ = a.messages.Update(context.Background(), assistantMsg) - return ErrRequestCancelled - } - assistantMsg.AddFinish("error:" + err.Error()) - _ = a.messages.Update(ctx, assistantMsg) - return fmt.Errorf("event processing error: %w", err) - } - - // Check for cancellation during event processing - select { - case <-ctx.Done(): - // Mark as canceled - assistantMsg.AddFinish("canceled") - _ = a.messages.Update(context.Background(), assistantMsg) - return ErrRequestCancelled - default: - } - } - - // Check for cancellation before tool execution - select { - case <-ctx.Done(): - assistantMsg.AddFinish("canceled_by_user") - _ = a.messages.Update(context.Background(), assistantMsg) - return ErrRequestCancelled - default: - } - - // Execute any tool calls - toolMsg, err := a.handleToolExecution(ctx, assistantMsg) - if err != nil { - if errors.Is(err, context.Canceled) { - assistantMsg.AddFinish("canceled_by_user") - _ = a.messages.Update(context.Background(), assistantMsg) - return ErrRequestCancelled - } - return fmt.Errorf("tool execution error: %w", err) - } - - if err := a.messages.Update(ctx, assistantMsg); err != nil { - return fmt.Errorf("failed to update assistant message: %w", err) - } - - // If no tool calls, we're done - if len(assistantMsg.ToolCalls()) == 0 { - break - } + cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) + + model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) + + model.CostPer1MIn/1e6*float64(usage.InputTokens) + + model.CostPer1MOut/1e6*float64(usage.OutputTokens) - // Add messages for next iteration - messages = append(messages, assistantMsg) - if toolMsg != nil { - messages = append(messages, *toolMsg) - } + sess.Cost += cost + sess.CompletionTokens += usage.OutputTokens + sess.PromptTokens += usage.InputTokens - // Check for cancellation after tool execution - select { - case <-ctx.Done(): - return ErrRequestCancelled - default: - } + _, err = a.sessions.Save(ctx, sess) + if err != nil { + return fmt.Errorf("failed to save session: %w", err) } - return nil } -// getAgentProviders initializes the LLM providers based on the chosen model -func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) { - maxTokens := config.Get().Model.CoderMaxTokens - - providerConfig, ok := config.Get().Providers[model.Provider] - if !ok || providerConfig.Disabled { - return nil, nil, ErrProviderNotEnabled +func createAgentProvider(agentName config.AgentName) (provider.Provider, error) { + cfg := config.Get() + agentConfig, ok := cfg.Agents[agentName] + if !ok { + return nil, fmt.Errorf("agent %s not found", agentName) + } + model, ok := models.SupportedModels[agentConfig.Model] + if !ok { + return nil, fmt.Errorf("model %s not supported", agentConfig.Model) } - var agentProvider provider.Provider - var titleGenerator provider.Provider - var err error - - switch model.Provider { - case models.ProviderOpenAI: - agentProvider, err = provider.NewOpenAIProvider( - provider.WithOpenAISystemMessage( - prompt.CoderOpenAISystemPrompt(), - ), - provider.WithOpenAIMaxTokens(maxTokens), - provider.WithOpenAIModel(model), - provider.WithOpenAIKey(providerConfig.APIKey), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create OpenAI agent provider: %w", err) - } - - titleGenerator, err = provider.NewOpenAIProvider( - provider.WithOpenAISystemMessage( - prompt.TitlePrompt(), - ), - provider.WithOpenAIMaxTokens(80), - provider.WithOpenAIModel(model), - provider.WithOpenAIKey(providerConfig.APIKey), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create OpenAI title generator: %w", err) - } - - case models.ProviderAnthropic: - agentProvider, err = provider.NewAnthropicProvider( - provider.WithAnthropicSystemMessage( - prompt.CoderAnthropicSystemPrompt(), - ), - provider.WithAnthropicMaxTokens(maxTokens), - provider.WithAnthropicKey(providerConfig.APIKey), - provider.WithAnthropicModel(model), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create Anthropic agent provider: %w", err) - } - - titleGenerator, err = provider.NewAnthropicProvider( - provider.WithAnthropicSystemMessage( - prompt.TitlePrompt(), - ), - provider.WithAnthropicMaxTokens(80), - provider.WithAnthropicKey(providerConfig.APIKey), - provider.WithAnthropicModel(model), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create Anthropic title generator: %w", err) - } - - case models.ProviderGemini: - agentProvider, err = provider.NewGeminiProvider( - ctx, - provider.WithGeminiSystemMessage( - prompt.CoderOpenAISystemPrompt(), - ), - provider.WithGeminiMaxTokens(int32(maxTokens)), - provider.WithGeminiKey(providerConfig.APIKey), - provider.WithGeminiModel(model), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create Gemini agent provider: %w", err) - } - - titleGenerator, err = provider.NewGeminiProvider( - ctx, - provider.WithGeminiSystemMessage( - prompt.TitlePrompt(), - ), - provider.WithGeminiMaxTokens(80), - provider.WithGeminiKey(providerConfig.APIKey), - provider.WithGeminiModel(model), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create Gemini title generator: %w", err) - } - - case models.ProviderGROQ: - agentProvider, err = provider.NewOpenAIProvider( - provider.WithOpenAISystemMessage( - prompt.CoderAnthropicSystemPrompt(), - ), - provider.WithOpenAIMaxTokens(maxTokens), - provider.WithOpenAIModel(model), - provider.WithOpenAIKey(providerConfig.APIKey), - provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create GROQ agent provider: %w", err) - } - - titleGenerator, err = provider.NewOpenAIProvider( - provider.WithOpenAISystemMessage( - prompt.TitlePrompt(), - ), - provider.WithOpenAIMaxTokens(80), - provider.WithOpenAIModel(model), - provider.WithOpenAIKey(providerConfig.APIKey), - provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create GROQ title generator: %w", err) - } - - case models.ProviderBedrock: - agentProvider, err = provider.NewBedrockProvider( - provider.WithBedrockSystemMessage( - prompt.CoderAnthropicSystemPrompt(), - ), - provider.WithBedrockMaxTokens(maxTokens), - provider.WithBedrockModel(model), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create Bedrock agent provider: %w", err) - } - - titleGenerator, err = provider.NewBedrockProvider( - provider.WithBedrockSystemMessage( - prompt.TitlePrompt(), - ), - provider.WithBedrockMaxTokens(80), - provider.WithBedrockModel(model), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create Bedrock title generator: %w", err) - } - default: - return nil, nil, fmt.Errorf("unsupported provider: %s", model.Provider) + providerCfg, ok := cfg.Providers[model.Provider] + if !ok { + return nil, fmt.Errorf("provider %s not supported", model.Provider) + } + if providerCfg.Disabled { + return nil, fmt.Errorf("provider %s is not enabled", model.Provider) + } + agentProvider, err := provider.NewProvider( + model.Provider, + provider.WithAPIKey(providerCfg.APIKey), + provider.WithModel(model), + provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)), + provider.WithMaxTokens(agentConfig.MaxTokens), + ) + if err != nil { + return nil, fmt.Errorf("could not create provider: %v", err) } - return agentProvider, titleGenerator, nil + return agentProvider, nil } diff --git a/internal/llm/agent/coder.go b/internal/llm/agent/coder.go deleted file mode 100644 index a3db6b55c..000000000 --- a/internal/llm/agent/coder.go +++ /dev/null @@ -1,63 +0,0 @@ -package agent - -import ( - "context" - "errors" - - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/models" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/message" - "github.com/kujtimiihoxha/termai/internal/permission" - "github.com/kujtimiihoxha/termai/internal/session" -) - -type coderAgent struct { - Service -} - -func NewCoderAgent( - permissions permission.Service, - sessions session.Service, - messages message.Service, - lspClients map[string]*lsp.Client, -) (Service, error) { - model, ok := models.SupportedModels[config.Get().Model.Coder] - if !ok { - return nil, errors.New("model not supported") - } - - ctx := context.Background() - otherTools := GetMcpTools(ctx, permissions) - if len(lspClients) > 0 { - otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients)) - } - agent, err := NewAgent( - ctx, - sessions, - messages, - model, - append( - []tools.BaseTool{ - tools.NewBashTool(permissions), - tools.NewEditTool(lspClients, permissions), - tools.NewFetchTool(permissions), - tools.NewGlobTool(), - tools.NewGrepTool(), - tools.NewLsTool(), - tools.NewSourcegraphTool(), - tools.NewViewTool(lspClients), - tools.NewWriteTool(lspClients, permissions), - NewAgentTool(sessions, messages, lspClients), - }, otherTools..., - ), - ) - if err != nil { - return nil, err - } - - return &coderAgent{ - agent, - }, nil -} diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go index b1c97b512..c7ea4916c 100644 --- a/internal/llm/agent/mcp-tools.go +++ b/internal/llm/agent/mcp-tools.go @@ -46,7 +46,7 @@ func runTool(ctx context.Context, c MCPClient, toolName string, input string) (t initRequest := mcp.InitializeRequest{} initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION initRequest.Params.ClientInfo = mcp.Implementation{ - Name: "termai", + Name: "OpenCode", Version: version.Version, } @@ -135,7 +135,7 @@ func getTools(ctx context.Context, name string, m config.MCPServer, permissions initRequest := mcp.InitializeRequest{} initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION initRequest.Params.ClientInfo = mcp.Implementation{ - Name: "termai", + Name: "OpenCode", Version: version.Version, } diff --git a/internal/llm/agent/task.go b/internal/llm/agent/task.go deleted file mode 100644 index fca1f223f..000000000 --- a/internal/llm/agent/task.go +++ /dev/null @@ -1,47 +0,0 @@ -package agent - -import ( - "context" - "errors" - - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/models" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/message" - "github.com/kujtimiihoxha/termai/internal/session" -) - -type taskAgent struct { - Service -} - -func NewTaskAgent(messages message.Service, sessions session.Service, lspClients map[string]*lsp.Client) (Service, error) { - model, ok := models.SupportedModels[config.Get().Model.Coder] - if !ok { - return nil, errors.New("model not supported") - } - - ctx := context.Background() - - agent, err := NewAgent( - ctx, - sessions, - messages, - model, - []tools.BaseTool{ - tools.NewGlobTool(), - tools.NewGrepTool(), - tools.NewLsTool(), - tools.NewSourcegraphTool(), - tools.NewViewTool(lspClients), - }, - ) - if err != nil { - return nil, err - } - - return &taskAgent{ - agent, - }, nil -} diff --git a/internal/llm/agent/tools.go b/internal/llm/agent/tools.go new file mode 100644 index 000000000..a37f1d65d --- /dev/null +++ b/internal/llm/agent/tools.go @@ -0,0 +1,50 @@ +package agent + +import ( + "context" + + "github.com/kujtimiihoxha/termai/internal/history" + "github.com/kujtimiihoxha/termai/internal/llm/tools" + "github.com/kujtimiihoxha/termai/internal/lsp" + "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/termai/internal/permission" + "github.com/kujtimiihoxha/termai/internal/session" +) + +func CoderAgentTools( + permissions permission.Service, + sessions session.Service, + messages message.Service, + history history.Service, + lspClients map[string]*lsp.Client, +) []tools.BaseTool { + ctx := context.Background() + otherTools := GetMcpTools(ctx, permissions) + if len(lspClients) > 0 { + otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients)) + } + return append( + []tools.BaseTool{ + tools.NewBashTool(permissions), + tools.NewEditTool(lspClients, permissions, history), + tools.NewFetchTool(permissions), + tools.NewGlobTool(), + tools.NewGrepTool(), + tools.NewLsTool(), + tools.NewSourcegraphTool(), + tools.NewViewTool(lspClients), + tools.NewWriteTool(lspClients, permissions, history), + NewAgentTool(sessions, messages, lspClients), + }, otherTools..., + ) +} + +func TaskAgentTools(lspClients map[string]*lsp.Client) []tools.BaseTool { + return []tools.BaseTool{ + tools.NewGlobTool(), + tools.NewGrepTool(), + tools.NewLsTool(), + tools.NewSourcegraphTool(), + tools.NewViewTool(lspClients), + } +} diff --git a/internal/llm/models/anthropic.go b/internal/llm/models/anthropic.go new file mode 100644 index 000000000..48307e6d3 --- /dev/null +++ b/internal/llm/models/anthropic.go @@ -0,0 +1,71 @@ +package models + +const ( + ProviderAnthropic ModelProvider = "anthropic" + + // Models + Claude35Sonnet ModelID = "claude-3.5-sonnet" + Claude3Haiku ModelID = "claude-3-haiku" + Claude37Sonnet ModelID = "claude-3.7-sonnet" + Claude35Haiku ModelID = "claude-3.5-haiku" + Claude3Opus ModelID = "claude-3-opus" +) + +var AnthropicModels = map[ModelID]Model{ + // Anthropic + Claude35Sonnet: { + ID: Claude35Sonnet, + Name: "Claude 3.5 Sonnet", + Provider: ProviderAnthropic, + APIModel: "claude-3-5-sonnet-latest", + CostPer1MIn: 3.0, + CostPer1MInCached: 3.75, + CostPer1MOutCached: 0.30, + CostPer1MOut: 15.0, + ContextWindow: 200000, + }, + Claude3Haiku: { + ID: Claude3Haiku, + Name: "Claude 3 Haiku", + Provider: ProviderAnthropic, + APIModel: "claude-3-haiku-latest", + CostPer1MIn: 0.25, + CostPer1MInCached: 0.30, + CostPer1MOutCached: 0.03, + CostPer1MOut: 1.25, + ContextWindow: 200000, + }, + Claude37Sonnet: { + ID: Claude37Sonnet, + Name: "Claude 3.7 Sonnet", + Provider: ProviderAnthropic, + APIModel: "claude-3-7-sonnet-latest", + CostPer1MIn: 3.0, + CostPer1MInCached: 3.75, + CostPer1MOutCached: 0.30, + CostPer1MOut: 15.0, + ContextWindow: 200000, + }, + Claude35Haiku: { + ID: Claude35Haiku, + Name: "Claude 3.5 Haiku", + Provider: ProviderAnthropic, + APIModel: "claude-3-5-haiku-latest", + CostPer1MIn: 0.80, + CostPer1MInCached: 1.0, + CostPer1MOutCached: 0.08, + CostPer1MOut: 4.0, + ContextWindow: 200000, + }, + Claude3Opus: { + ID: Claude3Opus, + Name: "Claude 3 Opus", + Provider: ProviderAnthropic, + APIModel: "claude-3-opus-latest", + CostPer1MIn: 15.0, + CostPer1MInCached: 18.75, + CostPer1MOutCached: 1.50, + CostPer1MOut: 75.0, + ContextWindow: 200000, + }, +} diff --git a/internal/llm/models/models.go b/internal/llm/models/models.go index 140693237..4d4589bfd 100644 --- a/internal/llm/models/models.go +++ b/internal/llm/models/models.go @@ -1,5 +1,7 @@ package models +import "maps" + type ( ModelID string ModelProvider string @@ -14,15 +16,13 @@ type Model struct { CostPer1MOut float64 `json:"cost_per_1m_out"` CostPer1MInCached float64 `json:"cost_per_1m_in_cached"` CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"` + ContextWindow int64 `json:"context_window"` } // Model IDs const ( - // Anthropic - Claude35Sonnet ModelID = "claude-3.5-sonnet" - Claude3Haiku ModelID = "claude-3-haiku" - Claude37Sonnet ModelID = "claude-3.7-sonnet" // OpenAI + GPT4o ModelID = "gpt-4o" GPT41 ModelID = "gpt-4.1" // GEMINI @@ -37,47 +37,59 @@ const ( ) const ( - ProviderOpenAI ModelProvider = "openai" - ProviderAnthropic ModelProvider = "anthropic" - ProviderBedrock ModelProvider = "bedrock" - ProviderGemini ModelProvider = "gemini" - ProviderGROQ ModelProvider = "groq" + ProviderOpenAI ModelProvider = "openai" + ProviderBedrock ModelProvider = "bedrock" + ProviderGemini ModelProvider = "gemini" + ProviderGROQ ModelProvider = "groq" + + // ForTests + ProviderMock ModelProvider = "__mock" ) var SupportedModels = map[ModelID]Model{ - // Anthropic - Claude35Sonnet: { - ID: Claude35Sonnet, - Name: "Claude 3.5 Sonnet", - Provider: ProviderAnthropic, - APIModel: "claude-3-5-sonnet-latest", - CostPer1MIn: 3.0, - CostPer1MInCached: 3.75, - CostPer1MOutCached: 0.30, - CostPer1MOut: 15.0, - }, - Claude3Haiku: { - ID: Claude3Haiku, - Name: "Claude 3 Haiku", - Provider: ProviderAnthropic, - APIModel: "claude-3-haiku-latest", - CostPer1MIn: 0.80, - CostPer1MInCached: 1, - CostPer1MOutCached: 0.08, - CostPer1MOut: 4, - }, - Claude37Sonnet: { - ID: Claude37Sonnet, - Name: "Claude 3.7 Sonnet", - Provider: ProviderAnthropic, - APIModel: "claude-3-7-sonnet-latest", - CostPer1MIn: 3.0, - CostPer1MInCached: 3.75, - CostPer1MOutCached: 0.30, - CostPer1MOut: 15.0, + // // Anthropic + // Claude35Sonnet: { + // ID: Claude35Sonnet, + // Name: "Claude 3.5 Sonnet", + // Provider: ProviderAnthropic, + // APIModel: "claude-3-5-sonnet-latest", + // CostPer1MIn: 3.0, + // CostPer1MInCached: 3.75, + // CostPer1MOutCached: 0.30, + // CostPer1MOut: 15.0, + // }, + // Claude3Haiku: { + // ID: Claude3Haiku, + // Name: "Claude 3 Haiku", + // Provider: ProviderAnthropic, + // APIModel: "claude-3-haiku-latest", + // CostPer1MIn: 0.80, + // CostPer1MInCached: 1, + // CostPer1MOutCached: 0.08, + // CostPer1MOut: 4, + // }, + // Claude37Sonnet: { + // ID: Claude37Sonnet, + // Name: "Claude 3.7 Sonnet", + // Provider: ProviderAnthropic, + // APIModel: "claude-3-7-sonnet-latest", + // CostPer1MIn: 3.0, + // CostPer1MInCached: 3.75, + // CostPer1MOutCached: 0.30, + // CostPer1MOut: 15.0, + // }, + // + // // OpenAI + GPT4o: { + ID: GPT4o, + Name: "GPT-4o", + Provider: ProviderOpenAI, + APIModel: "gpt-4.1", + CostPer1MIn: 2.00, + CostPer1MInCached: 0.50, + CostPer1MOutCached: 0, + CostPer1MOut: 8.00, }, - - // OpenAI GPT41: { ID: GPT41, Name: "GPT-4.1", @@ -88,51 +100,55 @@ var SupportedModels = map[ModelID]Model{ CostPer1MOutCached: 0, CostPer1MOut: 8.00, }, + // + // // GEMINI + // GEMINI25: { + // ID: GEMINI25, + // Name: "Gemini 2.5 Pro", + // Provider: ProviderGemini, + // APIModel: "gemini-2.5-pro-exp-03-25", + // CostPer1MIn: 0, + // CostPer1MInCached: 0, + // CostPer1MOutCached: 0, + // CostPer1MOut: 0, + // }, + // + // GRMINI20Flash: { + // ID: GRMINI20Flash, + // Name: "Gemini 2.0 Flash", + // Provider: ProviderGemini, + // APIModel: "gemini-2.0-flash", + // CostPer1MIn: 0.1, + // CostPer1MInCached: 0, + // CostPer1MOutCached: 0.025, + // CostPer1MOut: 0.4, + // }, + // + // // GROQ + // QWENQwq: { + // ID: QWENQwq, + // Name: "Qwen Qwq", + // Provider: ProviderGROQ, + // APIModel: "qwen-qwq-32b", + // CostPer1MIn: 0, + // CostPer1MInCached: 0, + // CostPer1MOutCached: 0, + // CostPer1MOut: 0, + // }, + // + // // Bedrock + // BedrockClaude37Sonnet: { + // ID: BedrockClaude37Sonnet, + // Name: "Bedrock: Claude 3.7 Sonnet", + // Provider: ProviderBedrock, + // APIModel: "anthropic.claude-3-7-sonnet-20250219-v1:0", + // CostPer1MIn: 3.0, + // CostPer1MInCached: 3.75, + // CostPer1MOutCached: 0.30, + // CostPer1MOut: 15.0, + // }, +} - // GEMINI - GEMINI25: { - ID: GEMINI25, - Name: "Gemini 2.5 Pro", - Provider: ProviderGemini, - APIModel: "gemini-2.5-pro-exp-03-25", - CostPer1MIn: 0, - CostPer1MInCached: 0, - CostPer1MOutCached: 0, - CostPer1MOut: 0, - }, - - GRMINI20Flash: { - ID: GRMINI20Flash, - Name: "Gemini 2.0 Flash", - Provider: ProviderGemini, - APIModel: "gemini-2.0-flash", - CostPer1MIn: 0.1, - CostPer1MInCached: 0, - CostPer1MOutCached: 0.025, - CostPer1MOut: 0.4, - }, - - // GROQ - QWENQwq: { - ID: QWENQwq, - Name: "Qwen Qwq", - Provider: ProviderGROQ, - APIModel: "qwen-qwq-32b", - CostPer1MIn: 0, - CostPer1MInCached: 0, - CostPer1MOutCached: 0, - CostPer1MOut: 0, - }, - - // Bedrock - BedrockClaude37Sonnet: { - ID: BedrockClaude37Sonnet, - Name: "Bedrock: Claude 3.7 Sonnet", - Provider: ProviderBedrock, - APIModel: "anthropic.claude-3-7-sonnet-20250219-v1:0", - CostPer1MIn: 3.0, - CostPer1MInCached: 3.75, - CostPer1MOutCached: 0.30, - CostPer1MOut: 15.0, - }, +func init() { + maps.Copy(SupportedModels, AnthropicModels) } diff --git a/internal/llm/prompt/coder.go b/internal/llm/prompt/coder.go index 47941f976..7439fd570 100644 --- a/internal/llm/prompt/coder.go +++ b/internal/llm/prompt/coder.go @@ -9,11 +9,22 @@ import ( "time" "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/termai/internal/llm/models" "github.com/kujtimiihoxha/termai/internal/llm/tools" ) -func CoderOpenAISystemPrompt() string { - basePrompt := `You are termAI, an autonomous CLI-based software engineer. Your job is to reduce user effort by proactively reasoning, inferring context, and solving software engineering tasks end-to-end with minimal prompting. +func CoderPrompt(provider models.ModelProvider) string { + basePrompt := baseAnthropicCoderPrompt + switch provider { + case models.ProviderOpenAI: + basePrompt = baseOpenAICoderPrompt + } + envInfo := getEnvironmentInfo() + + return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation()) +} + +const baseOpenAICoderPrompt = `You are termAI, an autonomous CLI-based software engineer. Your job is to reduce user effort by proactively reasoning, inferring context, and solving software engineering tasks end-to-end with minimal prompting. # Your mindset Act like a competent, efficient software engineer who is familiar with large codebases. You should: @@ -65,13 +76,7 @@ assistant: [searches repo for references, returns file paths and lines] Never commit changes unless the user explicitly asks you to.` - envInfo := getEnvironmentInfo() - - return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation()) -} - -func CoderAnthropicSystemPrompt() string { - basePrompt := `You are termAI, an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user. +const baseAnthropicCoderPrompt = `You are termAI, an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user. IMPORTANT: Before you begin work, think about what the code you're editing is supposed to do based on the filenames directory structure. @@ -166,11 +171,6 @@ NEVER commit changes unless the user explicitly asks you to. It is VERY IMPORTAN You MUST answer concisely with fewer than 4 lines of text (not including tool use or code generation), unless user asks for detail.` - envInfo := getEnvironmentInfo() - - return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation()) -} - func getEnvironmentInfo() string { cwd := config.WorkingDirectory() isGit := isGitRepo(cwd) diff --git a/internal/llm/prompt/prompt.go b/internal/llm/prompt/prompt.go new file mode 100644 index 000000000..63fc2df7b --- /dev/null +++ b/internal/llm/prompt/prompt.go @@ -0,0 +1,19 @@ +package prompt + +import ( + "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/termai/internal/llm/models" +) + +func GetAgentPrompt(agentName config.AgentName, provider models.ModelProvider) string { + switch agentName { + case config.AgentCoder: + return CoderPrompt(provider) + case config.AgentTitle: + return TitlePrompt(provider) + case config.AgentTask: + return TaskPrompt(provider) + default: + return "You are a helpful assistant" + } +} diff --git a/internal/llm/prompt/task.go b/internal/llm/prompt/task.go index ee3c707fa..8bf604ad9 100644 --- a/internal/llm/prompt/task.go +++ b/internal/llm/prompt/task.go @@ -2,11 +2,12 @@ package prompt import ( "fmt" + + "github.com/kujtimiihoxha/termai/internal/llm/models" ) -func TaskAgentSystemPrompt() string { +func TaskPrompt(_ models.ModelProvider) string { agentPrompt := `You are an agent for termAI. Given the user's prompt, you should use the tools available to you to answer the user's question. - Notes: 1. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is <answer>.", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...". 2. When relevant, share file names and code snippets relevant to the query diff --git a/internal/llm/prompt/title.go b/internal/llm/prompt/title.go index 5c47f4d64..3023a8550 100644 --- a/internal/llm/prompt/title.go +++ b/internal/llm/prompt/title.go @@ -1,6 +1,8 @@ package prompt -func TitlePrompt() string { +import "github.com/kujtimiihoxha/termai/internal/llm/models" + +func TitlePrompt(_ models.ModelProvider) string { return `you will generate a short title based on the first message a user begins a conversation with - ensure it is not more than 50 characters long - the title should be a summary of the user's message diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go index 93c4308ad..c3a4efc49 100644 --- a/internal/llm/provider/anthropic.go +++ b/internal/llm/provider/anthropic.go @@ -12,187 +12,257 @@ import ( "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/bedrock" "github.com/anthropics/anthropic-sdk-go/option" - "github.com/kujtimiihoxha/termai/internal/llm/models" + "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/llm/tools" + "github.com/kujtimiihoxha/termai/internal/logging" "github.com/kujtimiihoxha/termai/internal/message" ) -type anthropicProvider struct { - client anthropic.Client - model models.Model - maxTokens int64 - apiKey string - systemMessage string - useBedrock bool - disableCache bool +type anthropicOptions struct { + useBedrock bool + disableCache bool + shouldThink func(userMessage string) bool } -type AnthropicOption func(*anthropicProvider) +type AnthropicOption func(*anthropicOptions) -func WithAnthropicSystemMessage(message string) AnthropicOption { - return func(a *anthropicProvider) { - a.systemMessage = message - } +type anthropicClient struct { + providerOptions providerClientOptions + options anthropicOptions + client anthropic.Client } -func WithAnthropicMaxTokens(maxTokens int64) AnthropicOption { - return func(a *anthropicProvider) { - a.maxTokens = maxTokens - } -} +type AnthropicClient ProviderClient -func WithAnthropicModel(model models.Model) AnthropicOption { - return func(a *anthropicProvider) { - a.model = model +func newAnthropicClient(opts providerClientOptions) AnthropicClient { + anthropicOpts := anthropicOptions{} + for _, o := range opts.anthropicOptions { + o(&anthropicOpts) } -} -func WithAnthropicKey(apiKey string) AnthropicOption { - return func(a *anthropicProvider) { - a.apiKey = apiKey + anthropicClientOptions := []option.RequestOption{} + if opts.apiKey != "" { + anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey)) } -} - -func WithAnthropicBedrock() AnthropicOption { - return func(a *anthropicProvider) { - a.useBedrock = true + if anthropicOpts.useBedrock { + anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background())) } -} -func WithAnthropicDisableCache() AnthropicOption { - return func(a *anthropicProvider) { - a.disableCache = true + client := anthropic.NewClient(anthropicClientOptions...) + return &anthropicClient{ + providerOptions: opts, + options: anthropicOpts, + client: client, } } -func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) { - provider := &anthropicProvider{ - maxTokens: 1024, - } +func (a *anthropicClient) convertMessages(messages []message.Message) (anthropicMessages []anthropic.MessageParam) { + cachedBlocks := 0 + for _, msg := range messages { + switch msg.Role { + case message.User: + content := anthropic.NewTextBlock(msg.Content().String()) + if cachedBlocks < 2 && !a.options.disableCache { + content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{ + Type: "ephemeral", + } + cachedBlocks++ + } + anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(content)) - for _, opt := range opts { - opt(provider) - } + case message.Assistant: + blocks := []anthropic.ContentBlockParamUnion{} + if msg.Content().String() != "" { + content := anthropic.NewTextBlock(msg.Content().String()) + if cachedBlocks < 2 && !a.options.disableCache { + content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{ + Type: "ephemeral", + } + cachedBlocks++ + } + blocks = append(blocks, content) + } - if provider.systemMessage == "" { - return nil, errors.New("system message is required") - } + for _, toolCall := range msg.ToolCalls() { + var inputMap map[string]any + err := json.Unmarshal([]byte(toolCall.Input), &inputMap) + if err != nil { + continue + } + blocks = append(blocks, anthropic.ContentBlockParamOfRequestToolUseBlock(toolCall.ID, inputMap, toolCall.Name)) + } - anthropicOptions := []option.RequestOption{} + if len(blocks) == 0 { + logging.Warn("There is a message without content, investigate") + // This should never happend but we log this because we might have a bug in our cleanup method + continue + } + anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...)) - if provider.apiKey != "" { - anthropicOptions = append(anthropicOptions, option.WithAPIKey(provider.apiKey)) - } - if provider.useBedrock { - anthropicOptions = append(anthropicOptions, bedrock.WithLoadDefaultConfig(context.Background())) + case message.Tool: + results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults())) + for i, toolResult := range msg.ToolResults() { + results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError) + } + anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...)) + } } - - provider.client = anthropic.NewClient(anthropicOptions...) - return provider, nil + return } -func (a *anthropicProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { - messages = cleanupMessages(messages) - anthropicMessages := a.convertToAnthropicMessages(messages) - anthropicTools := a.convertToAnthropicTools(tools) - - response, err := a.client.Messages.New( - ctx, - anthropic.MessageNewParams{ - Model: anthropic.Model(a.model.APIModel), - MaxTokens: a.maxTokens, - Temperature: anthropic.Float(0), - Messages: anthropicMessages, - Tools: anthropicTools, - System: []anthropic.TextBlockParam{ - { - Text: a.systemMessage, - CacheControl: anthropic.CacheControlEphemeralParam{ - Type: "ephemeral", - }, - }, +func (a *anthropicClient) convertTools(tools []tools.BaseTool) []anthropic.ToolUnionParam { + anthropicTools := make([]anthropic.ToolUnionParam, len(tools)) + + for i, tool := range tools { + info := tool.Info() + toolParam := anthropic.ToolParam{ + Name: info.Name, + Description: anthropic.String(info.Description), + InputSchema: anthropic.ToolInputSchemaParam{ + Properties: info.Parameters, + // TODO: figure out how we can tell claude the required fields? }, - }, - ) - if err != nil { - return nil, err - } + } - content := "" - for _, block := range response.Content { - if text, ok := block.AsAny().(anthropic.TextBlock); ok { - content += text.Text + if i == len(tools)-1 && !a.options.disableCache { + toolParam.CacheControl = anthropic.CacheControlEphemeralParam{ + Type: "ephemeral", + } } - } - toolCalls := a.extractToolCalls(response.Content) - tokenUsage := a.extractTokenUsage(response.Usage) + anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam} + } - return &ProviderResponse{ - Content: content, - ToolCalls: toolCalls, - Usage: tokenUsage, - }, nil + return anthropicTools } -func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) { - messages = cleanupMessages(messages) - anthropicMessages := a.convertToAnthropicMessages(messages) - anthropicTools := a.convertToAnthropicTools(tools) +func (a *anthropicClient) finishReason(reason string) message.FinishReason { + switch reason { + case "end_turn": + return message.FinishReasonEndTurn + case "max_tokens": + return message.FinishReasonMaxTokens + case "tool_use": + return message.FinishReasonToolUse + case "stop_sequence": + return message.FinishReasonEndTurn + default: + return message.FinishReasonUnknown + } +} +func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, tools []anthropic.ToolUnionParam) anthropic.MessageNewParams { var thinkingParam anthropic.ThinkingConfigParamUnion lastMessage := messages[len(messages)-1] + isUser := lastMessage.Role == anthropic.MessageParamRoleUser + messageContent := "" temperature := anthropic.Float(0) - if lastMessage.Role == message.User && strings.Contains(strings.ToLower(lastMessage.Content().String()), "think") { - thinkingParam = anthropic.ThinkingConfigParamUnion{ - OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{ - BudgetTokens: int64(float64(a.maxTokens) * 0.8), - Type: "enabled", - }, + if isUser { + for _, m := range lastMessage.Content { + if m.OfRequestTextBlock != nil && m.OfRequestTextBlock.Text != "" { + messageContent = m.OfRequestTextBlock.Text + } + } + if messageContent != "" && a.options.shouldThink != nil && a.options.shouldThink(messageContent) { + thinkingParam = anthropic.ThinkingConfigParamUnion{ + OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{ + BudgetTokens: int64(float64(a.providerOptions.maxTokens) * 0.8), + Type: "enabled", + }, + } + temperature = anthropic.Float(1) } - temperature = anthropic.Float(1) } - eventChan := make(chan ProviderEvent) + return anthropic.MessageNewParams{ + Model: anthropic.Model(a.providerOptions.model.APIModel), + MaxTokens: a.providerOptions.maxTokens, + Temperature: temperature, + Messages: messages, + Tools: tools, + Thinking: thinkingParam, + System: []anthropic.TextBlockParam{ + { + Text: a.providerOptions.systemMessage, + CacheControl: anthropic.CacheControlEphemeralParam{ + Type: "ephemeral", + }, + }, + }, + } +} - go func() { - defer close(eventChan) +func (a *anthropicClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (resposne *ProviderResponse, err error) { + preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools)) + cfg := config.Get() + if cfg.Debug { + jsonData, _ := json.Marshal(preparedMessages) + logging.Debug("Prepared messages", "messages", string(jsonData)) + } + attempts := 0 + for { + attempts++ + anthropicResponse, err := a.client.Messages.New( + ctx, + preparedMessages, + ) + // If there is an error we are going to see if we can retry the call + if err != nil { + retry, after, retryErr := a.shouldRetry(attempts, err) + if retryErr != nil { + return nil, retryErr + } + if retry { + logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100)) + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(time.Duration(after) * time.Millisecond): + continue + } + } + return nil, retryErr + } - const maxRetries = 8 - attempts := 0 + content := "" + for _, block := range anthropicResponse.Content { + if text, ok := block.AsAny().(anthropic.TextBlock); ok { + content += text.Text + } + } - for { + return &ProviderResponse{ + Content: content, + ToolCalls: a.toolCalls(*anthropicResponse), + Usage: a.usage(*anthropicResponse), + }, nil + } +} +func (a *anthropicClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { + preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools)) + cfg := config.Get() + if cfg.Debug { + jsonData, _ := json.Marshal(preparedMessages) + logging.Debug("Prepared messages", "messages", string(jsonData)) + } + attempts := 0 + eventChan := make(chan ProviderEvent) + go func() { + for { attempts++ - - stream := a.client.Messages.NewStreaming( + anthropicStream := a.client.Messages.NewStreaming( ctx, - anthropic.MessageNewParams{ - Model: anthropic.Model(a.model.APIModel), - MaxTokens: a.maxTokens, - Temperature: temperature, - Messages: anthropicMessages, - Tools: anthropicTools, - Thinking: thinkingParam, - System: []anthropic.TextBlockParam{ - { - Text: a.systemMessage, - CacheControl: anthropic.CacheControlEphemeralParam{ - Type: "ephemeral", - }, - }, - }, - }, + preparedMessages, ) - accumulatedMessage := anthropic.Message{} - for stream.Next() { - event := stream.Current() + for anthropicStream.Next() { + event := anthropicStream.Current() err := accumulatedMessage.Accumulate(event) if err != nil { eventChan <- ProviderEvent{Type: EventError, Error: err} - return // Don't retry on accumulation errors + continue } switch event := event.AsAny().(type) { @@ -211,6 +281,7 @@ func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []messa Content: event.Delta.Text, } } + // TODO: check if we can somehow stream tool calls case anthropic.ContentBlockStopEvent: eventChan <- ProviderEvent{Type: EventContentStop} @@ -223,84 +294,87 @@ func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []messa } } - toolCalls := a.extractToolCalls(accumulatedMessage.Content) - tokenUsage := a.extractTokenUsage(accumulatedMessage.Usage) - eventChan <- ProviderEvent{ Type: EventComplete, Response: &ProviderResponse{ Content: content, - ToolCalls: toolCalls, - Usage: tokenUsage, - FinishReason: string(accumulatedMessage.StopReason), + ToolCalls: a.toolCalls(accumulatedMessage), + Usage: a.usage(accumulatedMessage), + FinishReason: a.finishReason(string(accumulatedMessage.StopReason)), }, } } } - err := stream.Err() + err := anthropicStream.Err() if err == nil || errors.Is(err, io.EOF) { + close(eventChan) return } - - var apierr *anthropic.Error - if !errors.As(err, &apierr) { - eventChan <- ProviderEvent{Type: EventError, Error: err} - return - } - - if apierr.StatusCode != 429 && apierr.StatusCode != 529 { - eventChan <- ProviderEvent{Type: EventError, Error: err} + // If there is an error we are going to see if we can retry the call + retry, after, retryErr := a.shouldRetry(attempts, err) + if retryErr != nil { + eventChan <- ProviderEvent{Type: EventError, Error: retryErr} + close(eventChan) return } - - if attempts > maxRetries { - eventChan <- ProviderEvent{ - Type: EventError, - Error: errors.New("maximum retry attempts reached for rate limit (429)"), - } - return - } - - retryMs := 0 - retryAfterValues := apierr.Response.Header.Values("Retry-After") - if len(retryAfterValues) > 0 { - var retryAfterSec int - if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryAfterSec); err == nil { - retryMs = retryAfterSec * 1000 - eventChan <- ProviderEvent{ - Type: EventWarning, - Info: fmt.Sprintf("[Rate limited: waiting %d seconds as specified by API]", retryAfterSec), + if retry { + logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100)) + select { + case <-ctx.Done(): + // context cancelled + if ctx.Err() != nil { + eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()} } + close(eventChan) + return + case <-time.After(time.Duration(after) * time.Millisecond): + continue } - } else { - eventChan <- ProviderEvent{ - Type: EventWarning, - Info: fmt.Sprintf("[Retrying due to rate limit... attempt %d of %d]", attempts, maxRetries), - } - - backoffMs := 2000 * (1 << (attempts - 1)) - jitterMs := int(float64(backoffMs) * 0.2) - retryMs = backoffMs + jitterMs } - select { - case <-ctx.Done(): + if ctx.Err() != nil { eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()} - return - case <-time.After(time.Duration(retryMs) * time.Millisecond): - continue } + close(eventChan) + return } }() + return eventChan +} - return eventChan, nil +func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, error) { + var apierr *anthropic.Error + if !errors.As(err, &apierr) { + return false, 0, err + } + + if apierr.StatusCode != 429 && apierr.StatusCode != 529 { + return false, 0, err + } + + if attempts > maxRetries { + return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries) + } + + retryMs := 0 + retryAfterValues := apierr.Response.Header.Values("Retry-After") + + backoffMs := 2000 * (1 << (attempts - 1)) + jitterMs := int(float64(backoffMs) * 0.2) + retryMs = backoffMs + jitterMs + if len(retryAfterValues) > 0 { + if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil { + retryMs = retryMs * 1000 + } + } + return true, int64(retryMs), nil } -func (a *anthropicProvider) extractToolCalls(content []anthropic.ContentBlockUnion) []message.ToolCall { +func (a *anthropicClient) toolCalls(msg anthropic.Message) []message.ToolCall { var toolCalls []message.ToolCall - for _, block := range content { + for _, block := range msg.Content { switch variant := block.AsAny().(type) { case anthropic.ToolUseBlock: toolCall := message.ToolCall{ @@ -316,90 +390,33 @@ func (a *anthropicProvider) extractToolCalls(content []anthropic.ContentBlockUni return toolCalls } -func (a *anthropicProvider) extractTokenUsage(usage anthropic.Usage) TokenUsage { +func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage { return TokenUsage{ - InputTokens: usage.InputTokens, - OutputTokens: usage.OutputTokens, - CacheCreationTokens: usage.CacheCreationInputTokens, - CacheReadTokens: usage.CacheReadInputTokens, + InputTokens: msg.Usage.InputTokens, + OutputTokens: msg.Usage.OutputTokens, + CacheCreationTokens: msg.Usage.CacheCreationInputTokens, + CacheReadTokens: msg.Usage.CacheReadInputTokens, } } -func (a *anthropicProvider) convertToAnthropicTools(tools []tools.BaseTool) []anthropic.ToolUnionParam { - anthropicTools := make([]anthropic.ToolUnionParam, len(tools)) - - for i, tool := range tools { - info := tool.Info() - toolParam := anthropic.ToolParam{ - Name: info.Name, - Description: anthropic.String(info.Description), - InputSchema: anthropic.ToolInputSchemaParam{ - Properties: info.Parameters, - }, - } - - if i == len(tools)-1 && !a.disableCache { - toolParam.CacheControl = anthropic.CacheControlEphemeralParam{ - Type: "ephemeral", - } - } - - anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam} +func WithAnthropicBedrock(useBedrock bool) AnthropicOption { + return func(options *anthropicOptions) { + options.useBedrock = useBedrock } - - return anthropicTools } -func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Message) []anthropic.MessageParam { - anthropicMessages := make([]anthropic.MessageParam, 0, len(messages)) - cachedBlocks := 0 - - for _, msg := range messages { - switch msg.Role { - case message.User: - content := anthropic.NewTextBlock(msg.Content().String()) - if cachedBlocks < 2 && !a.disableCache { - content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{ - Type: "ephemeral", - } - cachedBlocks++ - } - anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(content)) - - case message.Assistant: - blocks := []anthropic.ContentBlockParamUnion{} - if msg.Content().String() != "" { - content := anthropic.NewTextBlock(msg.Content().String()) - if cachedBlocks < 2 && !a.disableCache { - content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{ - Type: "ephemeral", - } - cachedBlocks++ - } - blocks = append(blocks, content) - } - - for _, toolCall := range msg.ToolCalls() { - var inputMap map[string]any - err := json.Unmarshal([]byte(toolCall.Input), &inputMap) - if err != nil { - continue - } - blocks = append(blocks, anthropic.ContentBlockParamOfRequestToolUseBlock(toolCall.ID, inputMap, toolCall.Name)) - } +func WithAnthropicDisableCache() AnthropicOption { + return func(options *anthropicOptions) { + options.disableCache = true + } +} - if len(blocks) > 0 { - anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...)) - } +func DefaultShouldThinkFn(s string) bool { + return strings.Contains(strings.ToLower(s), "think") +} - case message.Tool: - results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults())) - for i, toolResult := range msg.ToolResults() { - results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError) - } - anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...)) - } +func WithAnthropicShouldThinkFn(fn func(string) bool) AnthropicOption { + return func(options *anthropicOptions) { + options.shouldThink = fn } - - return anthropicMessages } diff --git a/internal/llm/provider/bedrock.go b/internal/llm/provider/bedrock.go index 677f4676b..d76925ad1 100644 --- a/internal/llm/provider/bedrock.go +++ b/internal/llm/provider/bedrock.go @@ -7,33 +7,29 @@ import ( "os" "strings" - "github.com/kujtimiihoxha/termai/internal/llm/models" "github.com/kujtimiihoxha/termai/internal/llm/tools" "github.com/kujtimiihoxha/termai/internal/message" ) -type bedrockProvider struct { - childProvider Provider - model models.Model - maxTokens int64 - systemMessage string +type bedrockOptions struct { + // Bedrock specific options can be added here } -func (b *bedrockProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { - return b.childProvider.SendMessages(ctx, messages, tools) -} +type BedrockOption func(*bedrockOptions) -func (b *bedrockProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) { - return b.childProvider.StreamResponse(ctx, messages, tools) +type bedrockClient struct { + providerOptions providerClientOptions + options bedrockOptions + childProvider ProviderClient } -func NewBedrockProvider(opts ...BedrockOption) (Provider, error) { - provider := &bedrockProvider{} - for _, opt := range opts { - opt(provider) - } +type BedrockClient ProviderClient + +func newBedrockClient(opts providerClientOptions) BedrockClient { + bedrockOpts := bedrockOptions{} + // Apply bedrock specific options if they are added in the future - // based on the AWS region prefix the model name with, us, eu, ap, sa, etc. + // Get AWS region from environment region := os.Getenv("AWS_REGION") if region == "" { region = os.Getenv("AWS_DEFAULT_REGION") @@ -43,45 +39,62 @@ func NewBedrockProvider(opts ...BedrockOption) (Provider, error) { region = "us-east-1" // default region } if len(region) < 2 { - return nil, errors.New("AWS_REGION or AWS_DEFAULT_REGION environment variable is invalid") + return &bedrockClient{ + providerOptions: opts, + options: bedrockOpts, + childProvider: nil, // Will cause an error when used + } } + + // Prefix the model name with region regionPrefix := region[:2] - provider.model.APIModel = fmt.Sprintf("%s.%s", regionPrefix, provider.model.APIModel) + modelName := opts.model.APIModel + opts.model.APIModel = fmt.Sprintf("%s.%s", regionPrefix, modelName) - if strings.Contains(string(provider.model.APIModel), "anthropic") { - anthropic, err := NewAnthropicProvider( - WithAnthropicModel(provider.model), - WithAnthropicMaxTokens(provider.maxTokens), - WithAnthropicSystemMessage(provider.systemMessage), - WithAnthropicBedrock(), + // Determine which provider to use based on the model + if strings.Contains(string(opts.model.APIModel), "anthropic") { + // Create Anthropic client with Bedrock configuration + anthropicOpts := opts + anthropicOpts.anthropicOptions = append(anthropicOpts.anthropicOptions, + WithAnthropicBedrock(true), WithAnthropicDisableCache(), ) - provider.childProvider = anthropic - if err != nil { - return nil, err + return &bedrockClient{ + providerOptions: opts, + options: bedrockOpts, + childProvider: newAnthropicClient(anthropicOpts), } - } else { - return nil, errors.New("unsupported model for bedrock provider") } - return provider, nil -} - -type BedrockOption func(*bedrockProvider) -func WithBedrockSystemMessage(message string) BedrockOption { - return func(a *bedrockProvider) { - a.systemMessage = message + // Return client with nil childProvider if model is not supported + // This will cause an error when used + return &bedrockClient{ + providerOptions: opts, + options: bedrockOpts, + childProvider: nil, } } -func WithBedrockMaxTokens(maxTokens int64) BedrockOption { - return func(a *bedrockProvider) { - a.maxTokens = maxTokens +func (b *bedrockClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { + if b.childProvider == nil { + return nil, errors.New("unsupported model for bedrock provider") } + return b.childProvider.send(ctx, messages, tools) } -func WithBedrockModel(model models.Model) BedrockOption { - return func(a *bedrockProvider) { - a.model = model +func (b *bedrockClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { + eventChan := make(chan ProviderEvent) + + if b.childProvider == nil { + go func() { + eventChan <- ProviderEvent{ + Type: EventError, + Error: errors.New("unsupported model for bedrock provider"), + } + close(eventChan) + }() + return eventChan } -} + + return b.childProvider.stream(ctx, messages, tools) +}
\ No newline at end of file diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go index 2d1db2b64..804baea28 100644 --- a/internal/llm/provider/gemini.go +++ b/internal/llm/provider/gemini.go @@ -4,80 +4,68 @@ import ( "context" "encoding/json" "errors" + "fmt" + "io" + "strings" + "time" "github.com/google/generative-ai-go/genai" "github.com/google/uuid" - "github.com/kujtimiihoxha/termai/internal/llm/models" + "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/llm/tools" + "github.com/kujtimiihoxha/termai/internal/logging" "github.com/kujtimiihoxha/termai/internal/message" "google.golang.org/api/iterator" "google.golang.org/api/option" ) -type geminiProvider struct { - client *genai.Client - model models.Model - maxTokens int32 - apiKey string - systemMessage string +type geminiOptions struct { + disableCache bool } -type GeminiOption func(*geminiProvider) +type GeminiOption func(*geminiOptions) -func NewGeminiProvider(ctx context.Context, opts ...GeminiOption) (Provider, error) { - provider := &geminiProvider{ - maxTokens: 5000, - } +type geminiClient struct { + providerOptions providerClientOptions + options geminiOptions + client *genai.Client +} - for _, opt := range opts { - opt(provider) - } +type GeminiClient ProviderClient - if provider.systemMessage == "" { - return nil, errors.New("system message is required") +func newGeminiClient(opts providerClientOptions) GeminiClient { + geminiOpts := geminiOptions{} + for _, o := range opts.geminiOptions { + o(&geminiOpts) } - client, err := genai.NewClient(ctx, option.WithAPIKey(provider.apiKey)) + client, err := genai.NewClient(context.Background(), option.WithAPIKey(opts.apiKey)) if err != nil { - return nil, err - } - provider.client = client - - return provider, nil -} - -func WithGeminiSystemMessage(message string) GeminiOption { - return func(p *geminiProvider) { - p.systemMessage = message + logging.Error("Failed to create Gemini client", "error", err) + return nil } -} -func WithGeminiMaxTokens(maxTokens int32) GeminiOption { - return func(p *geminiProvider) { - p.maxTokens = maxTokens + return &geminiClient{ + providerOptions: opts, + options: geminiOpts, + client: client, } } -func WithGeminiModel(model models.Model) GeminiOption { - return func(p *geminiProvider) { - p.model = model - } -} - -func WithGeminiKey(apiKey string) GeminiOption { - return func(p *geminiProvider) { - p.apiKey = apiKey - } -} +func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content { + var history []*genai.Content -func (p *geminiProvider) Close() { - if p.client != nil { - p.client.Close() - } -} + // Add system message first + history = append(history, &genai.Content{ + Parts: []genai.Part{genai.Text(g.providerOptions.systemMessage)}, + Role: "user", + }) -func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*genai.Content { - var history []*genai.Content + // Add a system response to acknowledge the system message + history = append(history, &genai.Content{ + Parts: []genai.Part{genai.Text("I'll help you with that.")}, + Role: "model", + }) for _, msg := range messages { switch msg.Role { @@ -86,6 +74,7 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g Parts: []genai.Part{genai.Text(msg.Content().String())}, Role: "user", }) + case message.Assistant: content := &genai.Content{ Role: "model", @@ -107,6 +96,7 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g } history = append(history, content) + case message.Tool: for _, result := range msg.ToolResults() { response := map[string]interface{}{"result": result.Content} @@ -114,10 +104,11 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g if err == nil { response = parsed } + var toolCall message.ToolCall - for _, msg := range messages { - if msg.Role == message.Assistant { - for _, call := range msg.ToolCalls() { + for _, m := range messages { + if m.Role == message.Assistant { + for _, call := range m.ToolCalls() { if call.ID == result.ToolCallID { toolCall = call break @@ -140,186 +131,358 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g return history } -func (p *geminiProvider) extractTokenUsage(resp *genai.GenerateContentResponse) TokenUsage { - if resp == nil || resp.UsageMetadata == nil { - return TokenUsage{} - } +func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool { + geminiTools := make([]*genai.Tool, 0, len(tools)) - return TokenUsage{ - InputTokens: int64(resp.UsageMetadata.PromptTokenCount), - OutputTokens: int64(resp.UsageMetadata.CandidatesTokenCount), - CacheCreationTokens: 0, // Not directly provided by Gemini - CacheReadTokens: int64(resp.UsageMetadata.CachedContentTokenCount), + for _, tool := range tools { + info := tool.Info() + declaration := &genai.FunctionDeclaration{ + Name: info.Name, + Description: info.Description, + Parameters: &genai.Schema{ + Type: genai.TypeObject, + Properties: convertSchemaProperties(info.Parameters), + Required: info.Required, + }, + } + + geminiTools = append(geminiTools, &genai.Tool{ + FunctionDeclarations: []*genai.FunctionDeclaration{declaration}, + }) } + + return geminiTools } -func (p *geminiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { - messages = cleanupMessages(messages) - model := p.client.GenerativeModel(p.model.APIModel) - model.SetMaxOutputTokens(p.maxTokens) +func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason { + reasonStr := reason.String() + switch { + case reasonStr == "STOP": + return message.FinishReasonEndTurn + case reasonStr == "MAX_TOKENS": + return message.FinishReasonMaxTokens + case strings.Contains(reasonStr, "FUNCTION") || strings.Contains(reasonStr, "TOOL"): + return message.FinishReasonToolUse + default: + return message.FinishReasonUnknown + } +} - model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage)) +func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { + model := g.client.GenerativeModel(g.providerOptions.model.APIModel) + model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens)) + // Convert tools if len(tools) > 0 { - declarations := p.convertToolsToGeminiFunctionDeclarations(tools) - for _, declaration := range declarations { - model.Tools = append(model.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{declaration}}) - } + model.Tools = g.convertTools(tools) } - chat := model.StartChat() - chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message + // Convert messages + geminiMessages := g.convertMessages(messages) - lastUserMsg := messages[len(messages)-1] - resp, err := chat.SendMessage(ctx, genai.Text(lastUserMsg.Content().String())) - if err != nil { - return nil, err + cfg := config.Get() + if cfg.Debug { + jsonData, _ := json.Marshal(geminiMessages) + logging.Debug("Prepared messages", "messages", string(jsonData)) } - var content string - var toolCalls []message.ToolCall + attempts := 0 + for { + attempts++ + chat := model.StartChat() + chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message + + lastMsg := geminiMessages[len(geminiMessages)-1] + var lastText string + for _, part := range lastMsg.Parts { + if text, ok := part.(genai.Text); ok { + lastText = string(text) + break + } + } - if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil { - for _, part := range resp.Candidates[0].Content.Parts { - switch p := part.(type) { - case genai.Text: - content = string(p) - case genai.FunctionCall: - id := "call_" + uuid.New().String() - args, _ := json.Marshal(p.Args) - toolCalls = append(toolCalls, message.ToolCall{ - ID: id, - Name: p.Name, - Input: string(args), - Type: "function", - }) + resp, err := chat.SendMessage(ctx, genai.Text(lastText)) + // If there is an error we are going to see if we can retry the call + if err != nil { + retry, after, retryErr := g.shouldRetry(attempts, err) + if retryErr != nil { + return nil, retryErr } + if retry { + logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100)) + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(time.Duration(after) * time.Millisecond): + continue + } + } + return nil, retryErr } - } - tokenUsage := p.extractTokenUsage(resp) + content := "" + var toolCalls []message.ToolCall + + if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil { + for _, part := range resp.Candidates[0].Content.Parts { + switch p := part.(type) { + case genai.Text: + content = string(p) + case genai.FunctionCall: + id := "call_" + uuid.New().String() + args, _ := json.Marshal(p.Args) + toolCalls = append(toolCalls, message.ToolCall{ + ID: id, + Name: p.Name, + Input: string(args), + Type: "function", + }) + } + } + } - return &ProviderResponse{ - Content: content, - ToolCalls: toolCalls, - Usage: tokenUsage, - }, nil + return &ProviderResponse{ + Content: content, + ToolCalls: toolCalls, + Usage: g.usage(resp), + FinishReason: g.finishReason(resp.Candidates[0].FinishReason), + }, nil + } } -func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) { - messages = cleanupMessages(messages) - model := p.client.GenerativeModel(p.model.APIModel) - model.SetMaxOutputTokens(p.maxTokens) - - model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage)) +func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { + model := g.client.GenerativeModel(g.providerOptions.model.APIModel) + model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens)) + // Convert tools if len(tools) > 0 { - declarations := p.convertToolsToGeminiFunctionDeclarations(tools) - for _, declaration := range declarations { - model.Tools = append(model.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{declaration}}) - } + model.Tools = g.convertTools(tools) } - chat := model.StartChat() - chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message + // Convert messages + geminiMessages := g.convertMessages(messages) - lastUserMsg := messages[len(messages)-1] - - iter := chat.SendMessageStream(ctx, genai.Text(lastUserMsg.Content().String())) + cfg := config.Get() + if cfg.Debug { + jsonData, _ := json.Marshal(geminiMessages) + logging.Debug("Prepared messages", "messages", string(jsonData)) + } + attempts := 0 eventChan := make(chan ProviderEvent) go func() { defer close(eventChan) - var finalResp *genai.GenerateContentResponse - currentContent := "" - toolCalls := []message.ToolCall{} - for { - resp, err := iter.Next() - if err == iterator.Done { - break - } - if err != nil { - eventChan <- ProviderEvent{ - Type: EventError, - Error: err, + attempts++ + chat := model.StartChat() + chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message + + lastMsg := geminiMessages[len(geminiMessages)-1] + var lastText string + for _, part := range lastMsg.Parts { + if text, ok := part.(genai.Text); ok { + lastText = string(text) + break } - return } - finalResp = resp + iter := chat.SendMessageStream(ctx, genai.Text(lastText)) - if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil { - for _, part := range resp.Candidates[0].Content.Parts { - switch p := part.(type) { - case genai.Text: - newText := string(p) - eventChan <- ProviderEvent{ - Type: EventContentDelta, - Content: newText, - } - currentContent += newText - case genai.FunctionCall: - id := "call_" + uuid.New().String() - args, _ := json.Marshal(p.Args) - newCall := message.ToolCall{ - ID: id, - Name: p.Name, - Input: string(args), - Type: "function", - } + currentContent := "" + toolCalls := []message.ToolCall{} + var finalResp *genai.GenerateContentResponse - isNew := true - for _, existing := range toolCalls { - if existing.Name == newCall.Name && existing.Input == newCall.Input { - isNew = false - break + eventChan <- ProviderEvent{Type: EventContentStart} + + for { + resp, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + retry, after, retryErr := g.shouldRetry(attempts, err) + if retryErr != nil { + eventChan <- ProviderEvent{Type: EventError, Error: retryErr} + return + } + if retry { + logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100)) + select { + case <-ctx.Done(): + if ctx.Err() != nil { + eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()} } + + return + case <-time.After(time.Duration(after) * time.Millisecond): + break } + } else { + eventChan <- ProviderEvent{Type: EventError, Error: err} + return + } + } + + finalResp = resp + + if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil { + for _, part := range resp.Candidates[0].Content.Parts { + switch p := part.(type) { + case genai.Text: + newText := string(p) + delta := newText[len(currentContent):] + if delta != "" { + eventChan <- ProviderEvent{ + Type: EventContentDelta, + Content: delta, + } + currentContent = newText + } + case genai.FunctionCall: + id := "call_" + uuid.New().String() + args, _ := json.Marshal(p.Args) + newCall := message.ToolCall{ + ID: id, + Name: p.Name, + Input: string(args), + Type: "function", + } - if isNew { - toolCalls = append(toolCalls, newCall) + isNew := true + for _, existing := range toolCalls { + if existing.Name == newCall.Name && existing.Input == newCall.Input { + isNew = false + break + } + } + + if isNew { + toolCalls = append(toolCalls, newCall) + } } } } } - } - tokenUsage := p.extractTokenUsage(finalResp) + eventChan <- ProviderEvent{Type: EventContentStop} - eventChan <- ProviderEvent{ - Type: EventComplete, - Response: &ProviderResponse{ - Content: currentContent, - ToolCalls: toolCalls, - Usage: tokenUsage, - FinishReason: string(finalResp.Candidates[0].FinishReason.String()), - }, + if finalResp != nil { + eventChan <- ProviderEvent{ + Type: EventComplete, + Response: &ProviderResponse{ + Content: currentContent, + ToolCalls: toolCalls, + Usage: g.usage(finalResp), + FinishReason: g.finishReason(finalResp.Candidates[0].FinishReason), + }, + } + return + } + + // If we get here, we need to retry + if attempts > maxRetries { + eventChan <- ProviderEvent{ + Type: EventError, + Error: fmt.Errorf("maximum retry attempts reached: %d retries", maxRetries), + } + return + } + + // Wait before retrying + select { + case <-ctx.Done(): + if ctx.Err() != nil { + eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()} + } + return + case <-time.After(time.Duration(2000*(1<<(attempts-1))) * time.Millisecond): + continue + } } }() - return eventChan, nil + return eventChan } -func (p *geminiProvider) convertToolsToGeminiFunctionDeclarations(tools []tools.BaseTool) []*genai.FunctionDeclaration { - declarations := make([]*genai.FunctionDeclaration, len(tools)) +func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) { + // Check if error is a rate limit error + if attempts > maxRetries { + return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries) + } - for i, tool := range tools { - info := tool.Info() - declarations[i] = &genai.FunctionDeclaration{ - Name: info.Name, - Description: info.Description, - Parameters: &genai.Schema{ - Type: genai.TypeObject, - Properties: convertSchemaProperties(info.Parameters), - Required: info.Required, - }, + // Gemini doesn't have a standard error type we can check against + // So we'll check the error message for rate limit indicators + if errors.Is(err, io.EOF) { + return false, 0, err + } + + errMsg := err.Error() + isRateLimit := false + + // Check for common rate limit error messages + if contains(errMsg, "rate limit", "quota exceeded", "too many requests") { + isRateLimit = true + } + + if !isRateLimit { + return false, 0, err + } + + // Calculate backoff with jitter + backoffMs := 2000 * (1 << (attempts - 1)) + jitterMs := int(float64(backoffMs) * 0.2) + retryMs := backoffMs + jitterMs + + return true, int64(retryMs), nil +} + +func (g *geminiClient) toolCalls(resp *genai.GenerateContentResponse) []message.ToolCall { + var toolCalls []message.ToolCall + + if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil { + for _, part := range resp.Candidates[0].Content.Parts { + if funcCall, ok := part.(genai.FunctionCall); ok { + id := "call_" + uuid.New().String() + args, _ := json.Marshal(funcCall.Args) + toolCalls = append(toolCalls, message.ToolCall{ + ID: id, + Name: funcCall.Name, + Input: string(args), + Type: "function", + }) + } } } - return declarations + return toolCalls +} + +func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage { + if resp == nil || resp.UsageMetadata == nil { + return TokenUsage{} + } + + return TokenUsage{ + InputTokens: int64(resp.UsageMetadata.PromptTokenCount), + OutputTokens: int64(resp.UsageMetadata.CandidatesTokenCount), + CacheCreationTokens: 0, // Not directly provided by Gemini + CacheReadTokens: int64(resp.UsageMetadata.CachedContentTokenCount), + } +} + +func WithGeminiDisableCache() GeminiOption { + return func(options *geminiOptions) { + options.disableCache = true + } +} + +// Helper functions +func parseJsonToMap(jsonStr string) (map[string]interface{}, error) { + var result map[string]interface{} + err := json.Unmarshal([]byte(jsonStr), &result) + return result, err } func convertSchemaProperties(parameters map[string]interface{}) map[string]*genai.Schema { @@ -396,8 +559,12 @@ func mapJSONTypeToGenAI(jsonType string) genai.Type { } } -func parseJsonToMap(jsonStr string) (map[string]interface{}, error) { - var result map[string]interface{} - err := json.Unmarshal([]byte(jsonStr), &result) - return result, err +func contains(s string, substrs ...string) bool { + for _, substr := range substrs { + if strings.Contains(strings.ToLower(s), strings.ToLower(substr)) { + return true + } + } + return false } + diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index dbfde3fa8..9c2ad2012 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -2,89 +2,65 @@ package provider import ( "context" + "encoding/json" "errors" + "fmt" + "io" + "time" - "github.com/kujtimiihoxha/termai/internal/llm/models" + "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/llm/tools" + "github.com/kujtimiihoxha/termai/internal/logging" "github.com/kujtimiihoxha/termai/internal/message" "github.com/openai/openai-go" "github.com/openai/openai-go/option" ) -type openaiProvider struct { - client openai.Client - model models.Model - maxTokens int64 - baseURL string - apiKey string - systemMessage string +type openaiOptions struct { + baseURL string + disableCache bool } -type OpenAIOption func(*openaiProvider) +type OpenAIOption func(*openaiOptions) -func NewOpenAIProvider(opts ...OpenAIOption) (Provider, error) { - provider := &openaiProvider{ - maxTokens: 5000, - } - - for _, opt := range opts { - opt(provider) - } - - clientOpts := []option.RequestOption{ - option.WithAPIKey(provider.apiKey), - } - if provider.baseURL != "" { - clientOpts = append(clientOpts, option.WithBaseURL(provider.baseURL)) - } - - provider.client = openai.NewClient(clientOpts...) - if provider.systemMessage == "" { - return nil, errors.New("system message is required") - } - - return provider, nil +type openaiClient struct { + providerOptions providerClientOptions + options openaiOptions + client openai.Client } -func WithOpenAISystemMessage(message string) OpenAIOption { - return func(p *openaiProvider) { - p.systemMessage = message - } -} +type OpenAIClient ProviderClient -func WithOpenAIMaxTokens(maxTokens int64) OpenAIOption { - return func(p *openaiProvider) { - p.maxTokens = maxTokens +func newOpenAIClient(opts providerClientOptions) OpenAIClient { + openaiOpts := openaiOptions{} + for _, o := range opts.openaiOptions { + o(&openaiOpts) } -} -func WithOpenAIModel(model models.Model) OpenAIOption { - return func(p *openaiProvider) { - p.model = model + openaiClientOptions := []option.RequestOption{} + if opts.apiKey != "" { + openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey)) } -} - -func WithOpenAIBaseURL(baseURL string) OpenAIOption { - return func(p *openaiProvider) { - p.baseURL = baseURL + if openaiOpts.baseURL != "" { + openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(openaiOpts.baseURL)) } -} -func WithOpenAIKey(apiKey string) OpenAIOption { - return func(p *openaiProvider) { - p.apiKey = apiKey + client := openai.NewClient(openaiClientOptions...) + return &openaiClient{ + providerOptions: opts, + options: openaiOpts, + client: client, } } -func (p *openaiProvider) convertToOpenAIMessages(messages []message.Message) []openai.ChatCompletionMessageParamUnion { - var chatMessages []openai.ChatCompletionMessageParamUnion - - chatMessages = append(chatMessages, openai.SystemMessage(p.systemMessage)) +func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) { + // Add system message first + openaiMessages = append(openaiMessages, openai.SystemMessage(o.providerOptions.systemMessage)) for _, msg := range messages { switch msg.Role { case message.User: - chatMessages = append(chatMessages, openai.UserMessage(msg.Content().String())) + openaiMessages = append(openaiMessages, openai.UserMessage(msg.Content().String())) case message.Assistant: assistantMsg := openai.ChatCompletionAssistantMessageParam{ @@ -111,23 +87,23 @@ func (p *openaiProvider) convertToOpenAIMessages(messages []message.Message) []o } } - chatMessages = append(chatMessages, openai.ChatCompletionMessageParamUnion{ + openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{ OfAssistant: &assistantMsg, }) case message.Tool: for _, result := range msg.ToolResults() { - chatMessages = append(chatMessages, + openaiMessages = append(openaiMessages, openai.ToolMessage(result.Content, result.ToolCallID), ) } } } - return chatMessages + return } -func (p *openaiProvider) convertToOpenAITools(tools []tools.BaseTool) []openai.ChatCompletionToolParam { +func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam { openaiTools := make([]openai.ChatCompletionToolParam, len(tools)) for i, tool := range tools { @@ -148,133 +124,238 @@ func (p *openaiProvider) convertToOpenAITools(tools []tools.BaseTool) []openai.C return openaiTools } -func (p *openaiProvider) extractTokenUsage(usage openai.CompletionUsage) TokenUsage { - cachedTokens := int64(0) - - cachedTokens = usage.PromptTokensDetails.CachedTokens - inputTokens := usage.PromptTokens - cachedTokens - - return TokenUsage{ - InputTokens: inputTokens, - OutputTokens: usage.CompletionTokens, - CacheCreationTokens: 0, // OpenAI doesn't provide this directly - CacheReadTokens: cachedTokens, +func (o *openaiClient) finishReason(reason string) message.FinishReason { + switch reason { + case "stop": + return message.FinishReasonEndTurn + case "length": + return message.FinishReasonMaxTokens + case "tool_calls": + return message.FinishReasonToolUse + default: + return message.FinishReasonUnknown } } -func (p *openaiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { - messages = cleanupMessages(messages) - chatMessages := p.convertToOpenAIMessages(messages) - openaiTools := p.convertToOpenAITools(tools) - - params := openai.ChatCompletionNewParams{ - Model: openai.ChatModel(p.model.APIModel), - Messages: chatMessages, - MaxTokens: openai.Int(p.maxTokens), - Tools: openaiTools, - } - - response, err := p.client.Chat.Completions.New(ctx, params) - if err != nil { - return nil, err +func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams { + return openai.ChatCompletionNewParams{ + Model: openai.ChatModel(o.providerOptions.model.APIModel), + Messages: messages, + MaxTokens: openai.Int(o.providerOptions.maxTokens), + Tools: tools, } +} - content := "" - if response.Choices[0].Message.Content != "" { - content = response.Choices[0].Message.Content +func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) { + params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools)) + cfg := config.Get() + if cfg.Debug { + jsonData, _ := json.Marshal(params) + logging.Debug("Prepared messages", "messages", string(jsonData)) } - - var toolCalls []message.ToolCall - if len(response.Choices[0].Message.ToolCalls) > 0 { - toolCalls = make([]message.ToolCall, len(response.Choices[0].Message.ToolCalls)) - for i, call := range response.Choices[0].Message.ToolCalls { - toolCalls[i] = message.ToolCall{ - ID: call.ID, - Name: call.Function.Name, - Input: call.Function.Arguments, - Type: "function", + attempts := 0 + for { + attempts++ + openaiResponse, err := o.client.Chat.Completions.New( + ctx, + params, + ) + // If there is an error we are going to see if we can retry the call + if err != nil { + retry, after, retryErr := o.shouldRetry(attempts, err) + if retryErr != nil { + return nil, retryErr } + if retry { + logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100)) + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(time.Duration(after) * time.Millisecond): + continue + } + } + return nil, retryErr } - } - tokenUsage := p.extractTokenUsage(response.Usage) + content := "" + if openaiResponse.Choices[0].Message.Content != "" { + content = openaiResponse.Choices[0].Message.Content + } - return &ProviderResponse{ - Content: content, - ToolCalls: toolCalls, - Usage: tokenUsage, - }, nil + return &ProviderResponse{ + Content: content, + ToolCalls: o.toolCalls(*openaiResponse), + Usage: o.usage(*openaiResponse), + FinishReason: o.finishReason(string(openaiResponse.Choices[0].FinishReason)), + }, nil + } } -func (p *openaiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) { - messages = cleanupMessages(messages) - chatMessages := p.convertToOpenAIMessages(messages) - openaiTools := p.convertToOpenAITools(tools) - - params := openai.ChatCompletionNewParams{ - Model: openai.ChatModel(p.model.APIModel), - Messages: chatMessages, - MaxTokens: openai.Int(p.maxTokens), - Tools: openaiTools, - StreamOptions: openai.ChatCompletionStreamOptionsParam{ - IncludeUsage: openai.Bool(true), - }, +func (o *openaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { + params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools)) + params.StreamOptions = openai.ChatCompletionStreamOptionsParam{ + IncludeUsage: openai.Bool(true), } - stream := p.client.Chat.Completions.NewStreaming(ctx, params) + cfg := config.Get() + if cfg.Debug { + jsonData, _ := json.Marshal(params) + logging.Debug("Prepared messages", "messages", string(jsonData)) + } + attempts := 0 eventChan := make(chan ProviderEvent) - toolCalls := make([]message.ToolCall, 0) go func() { - defer close(eventChan) - - acc := openai.ChatCompletionAccumulator{} - currentContent := "" - - for stream.Next() { - chunk := stream.Current() - acc.AddChunk(chunk) - - if tool, ok := acc.JustFinishedToolCall(); ok { - toolCalls = append(toolCalls, message.ToolCall{ - ID: tool.Id, - Name: tool.Name, - Input: tool.Arguments, - Type: "function", - }) - } + for { + attempts++ + openaiStream := o.client.Chat.Completions.NewStreaming( + ctx, + params, + ) + + acc := openai.ChatCompletionAccumulator{} + currentContent := "" + toolCalls := make([]message.ToolCall, 0) + + for openaiStream.Next() { + chunk := openaiStream.Current() + acc.AddChunk(chunk) + + if tool, ok := acc.JustFinishedToolCall(); ok { + toolCalls = append(toolCalls, message.ToolCall{ + ID: tool.Id, + Name: tool.Name, + Input: tool.Arguments, + Type: "function", + }) + } - for _, choice := range chunk.Choices { - if choice.Delta.Content != "" { - eventChan <- ProviderEvent{ - Type: EventContentDelta, - Content: choice.Delta.Content, + for _, choice := range chunk.Choices { + if choice.Delta.Content != "" { + eventChan <- ProviderEvent{ + Type: EventContentDelta, + Content: choice.Delta.Content, + } + currentContent += choice.Delta.Content } - currentContent += choice.Delta.Content } } - } - if err := stream.Err(); err != nil { - eventChan <- ProviderEvent{ - Type: EventError, - Error: err, + err := openaiStream.Err() + if err == nil || errors.Is(err, io.EOF) { + // Stream completed successfully + eventChan <- ProviderEvent{ + Type: EventComplete, + Response: &ProviderResponse{ + Content: currentContent, + ToolCalls: toolCalls, + Usage: o.usage(acc.ChatCompletion), + FinishReason: o.finishReason(string(acc.ChatCompletion.Choices[0].FinishReason)), + }, + } + close(eventChan) + return } + + // If there is an error we are going to see if we can retry the call + retry, after, retryErr := o.shouldRetry(attempts, err) + if retryErr != nil { + eventChan <- ProviderEvent{Type: EventError, Error: retryErr} + close(eventChan) + return + } + if retry { + logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100)) + select { + case <-ctx.Done(): + // context cancelled + if ctx.Err() == nil { + eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()} + } + close(eventChan) + return + case <-time.After(time.Duration(after) * time.Millisecond): + continue + } + } + eventChan <- ProviderEvent{Type: EventError, Error: retryErr} + close(eventChan) return } + }() - tokenUsage := p.extractTokenUsage(acc.Usage) + return eventChan +} - eventChan <- ProviderEvent{ - Type: EventComplete, - Response: &ProviderResponse{ - Content: currentContent, - ToolCalls: toolCalls, - Usage: tokenUsage, - }, +func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) { + var apierr *openai.Error + if !errors.As(err, &apierr) { + return false, 0, err + } + + if apierr.StatusCode != 429 && apierr.StatusCode != 500 { + return false, 0, err + } + + if attempts > maxRetries { + return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries) + } + + retryMs := 0 + retryAfterValues := apierr.Response.Header.Values("Retry-After") + + backoffMs := 2000 * (1 << (attempts - 1)) + jitterMs := int(float64(backoffMs) * 0.2) + retryMs = backoffMs + jitterMs + if len(retryAfterValues) > 0 { + if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil { + retryMs = retryMs * 1000 } - }() + } + return true, int64(retryMs), nil +} - return eventChan, nil +func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall { + var toolCalls []message.ToolCall + + if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 { + for _, call := range completion.Choices[0].Message.ToolCalls { + toolCall := message.ToolCall{ + ID: call.ID, + Name: call.Function.Name, + Input: call.Function.Arguments, + Type: "function", + } + toolCalls = append(toolCalls, toolCall) + } + } + + return toolCalls } + +func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage { + cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens + inputTokens := completion.Usage.PromptTokens - cachedTokens + + return TokenUsage{ + InputTokens: inputTokens, + OutputTokens: completion.Usage.CompletionTokens, + CacheCreationTokens: 0, // OpenAI doesn't provide this directly + CacheReadTokens: cachedTokens, + } +} + +func WithOpenAIBaseURL(baseURL string) OpenAIOption { + return func(options *openaiOptions) { + options.baseURL = baseURL + } +} + +func WithOpenAIDisableCache() OpenAIOption { + return func(options *openaiOptions) { + options.disableCache = true + } +} + diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index 34d91f2b7..1a5b3dc8a 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -2,14 +2,17 @@ package provider import ( "context" + "fmt" + "github.com/kujtimiihoxha/termai/internal/llm/models" "github.com/kujtimiihoxha/termai/internal/llm/tools" "github.com/kujtimiihoxha/termai/internal/message" ) -// EventType represents the type of streaming event type EventType string +const maxRetries = 8 + const ( EventContentStart EventType = "content_start" EventContentDelta EventType = "content_delta" @@ -18,7 +21,6 @@ const ( EventComplete EventType = "complete" EventError EventType = "error" EventWarning EventType = "warning" - EventInfo EventType = "info" ) type TokenUsage struct { @@ -32,61 +34,152 @@ type ProviderResponse struct { Content string ToolCalls []message.ToolCall Usage TokenUsage - FinishReason string + FinishReason message.FinishReason } type ProviderEvent struct { - Type EventType + Type EventType + Content string Thinking string - ToolCall *message.ToolCall - Error error Response *ProviderResponse - // Used for giving users info on e.x retry - Info string + Error error } - type Provider interface { SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) - StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) + StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent + + Model() models.Model +} + +type providerClientOptions struct { + apiKey string + model models.Model + maxTokens int64 + systemMessage string + + anthropicOptions []AnthropicOption + openaiOptions []OpenAIOption + geminiOptions []GeminiOption + bedrockOptions []BedrockOption +} + +type ProviderClientOption func(*providerClientOptions) + +type ProviderClient interface { + send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) + stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent +} + +type baseProvider[C ProviderClient] struct { + options providerClientOptions + client C +} + +func NewProvider(providerName models.ModelProvider, opts ...ProviderClientOption) (Provider, error) { + clientOptions := providerClientOptions{} + for _, o := range opts { + o(&clientOptions) + } + switch providerName { + case models.ProviderAnthropic: + return &baseProvider[AnthropicClient]{ + options: clientOptions, + client: newAnthropicClient(clientOptions), + }, nil + case models.ProviderOpenAI: + return &baseProvider[OpenAIClient]{ + options: clientOptions, + client: newOpenAIClient(clientOptions), + }, nil + case models.ProviderGemini: + return &baseProvider[GeminiClient]{ + options: clientOptions, + client: newGeminiClient(clientOptions), + }, nil + case models.ProviderBedrock: + return &baseProvider[BedrockClient]{ + options: clientOptions, + client: newBedrockClient(clientOptions), + }, nil + case models.ProviderMock: + // TODO: implement mock client for test + panic("not implemented") + } + return nil, fmt.Errorf("provider not supported: %s", providerName) } -func cleanupMessages(messages []message.Message) []message.Message { - // First pass: filter out canceled messages - var cleanedMessages []message.Message +func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) { for _, msg := range messages { - if msg.FinishReason() != "canceled" || len(msg.ToolCalls()) > 0 { - // if there are toolCalls this means we want to return it to the LLM telling it that those tools have been - // cancelled - cleanedMessages = append(cleanedMessages, msg) + // The message has no content + if len(msg.Parts) == 0 { + continue } + cleaned = append(cleaned, msg) } + return +} - // Second pass: filter out tool messages without a corresponding tool call - var result []message.Message - toolMessageIDs := make(map[string]bool) +func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { + messages = p.cleanMessages(messages) + return p.client.send(ctx, messages, tools) +} - for _, msg := range cleanedMessages { - if msg.Role == message.Assistant { - for _, toolCall := range msg.ToolCalls() { - toolMessageIDs[toolCall.ID] = true // Mark as referenced - } - } +func (p *baseProvider[C]) Model() models.Model { + return p.options.model +} + +func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { + messages = p.cleanMessages(messages) + return p.client.stream(ctx, messages, tools) +} + +func WithAPIKey(apiKey string) ProviderClientOption { + return func(options *providerClientOptions) { + options.apiKey = apiKey } +} - // Keep only messages that aren't unreferenced tool messages - for _, msg := range cleanedMessages { - if msg.Role == message.Tool { - for _, toolCall := range msg.ToolResults() { - if referenced, exists := toolMessageIDs[toolCall.ToolCallID]; exists && referenced { - result = append(result, msg) - } - } - } else { - result = append(result, msg) - } +func WithModel(model models.Model) ProviderClientOption { + return func(options *providerClientOptions) { + options.model = model + } +} + +func WithMaxTokens(maxTokens int64) ProviderClientOption { + return func(options *providerClientOptions) { + options.maxTokens = maxTokens + } +} + +func WithSystemMessage(systemMessage string) ProviderClientOption { + return func(options *providerClientOptions) { + options.systemMessage = systemMessage + } +} + +func WithAnthropicOptions(anthropicOptions ...AnthropicOption) ProviderClientOption { + return func(options *providerClientOptions) { + options.anthropicOptions = anthropicOptions + } +} + +func WithOpenAIOptions(openaiOptions ...OpenAIOption) ProviderClientOption { + return func(options *providerClientOptions) { + options.openaiOptions = openaiOptions + } +} + +func WithGeminiOptions(geminiOptions ...GeminiOption) ProviderClientOption { + return func(options *providerClientOptions) { + options.geminiOptions = geminiOptions + } +} + +func WithBedrockOptions(bedrockOptions ...BedrockOption) ProviderClientOption { + return func(options *providerClientOptions) { + options.bedrockOptions = bedrockOptions } - return result } diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index 0cea20878..c7c970e5a 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -23,7 +23,8 @@ type BashPermissionsParams struct { } type BashResponseMetadata struct { - Took int64 `json:"took"` + StartTime int64 `json:"start_time"` + EndTime int64 `json:"end_time"` } type bashTool struct { permissions permission.Service @@ -282,7 +283,6 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) if err != nil { return ToolResponse{}, fmt.Errorf("error executing command: %w", err) } - took := time.Since(startTime).Milliseconds() stdout = truncateOutput(stdout) stderr = truncateOutput(stderr) @@ -311,7 +311,8 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) } metadata := BashResponseMetadata{ - Took: took, + StartTime: startTime.UnixMilli(), + EndTime: time.Now().UnixMilli(), } if stdout == "" { return WithResponseMetadata(NewTextResponse("no output"), metadata), nil diff --git a/internal/llm/tools/bash_test.go b/internal/llm/tools/bash_test.go index 97be3683a..dafb0ccc5 100644 --- a/internal/llm/tools/bash_test.go +++ b/internal/llm/tools/bash_test.go @@ -8,8 +8,6 @@ import ( "testing" "time" - "github.com/kujtimiihoxha/termai/internal/permission" - "github.com/kujtimiihoxha/termai/internal/pubsub" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -340,32 +338,3 @@ func TestCountLines(t *testing.T) { }) } } - -// Mock permission service for testing -type mockPermissionService struct { - *pubsub.Broker[permission.PermissionRequest] - allow bool -} - -func (m *mockPermissionService) GrantPersistant(permission permission.PermissionRequest) { - // Not needed for tests -} - -func (m *mockPermissionService) Grant(permission permission.PermissionRequest) { - // Not needed for tests -} - -func (m *mockPermissionService) Deny(permission permission.PermissionRequest) { - // Not needed for tests -} - -func (m *mockPermissionService) Request(opts permission.CreatePermissionRequest) bool { - return m.allow -} - -func newMockPermissionService(allow bool) permission.Service { - return &mockPermissionService{ - Broker: pubsub.NewBroker[permission.PermissionRequest](), - allow: allow, - } -} diff --git a/internal/llm/tools/edit.go b/internal/llm/tools/edit.go index 08d6d446c..148e7aba7 100644 --- a/internal/llm/tools/edit.go +++ b/internal/llm/tools/edit.go @@ -11,6 +11,7 @@ import ( "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/diff" + "github.com/kujtimiihoxha/termai/internal/history" "github.com/kujtimiihoxha/termai/internal/lsp" "github.com/kujtimiihoxha/termai/internal/permission" ) @@ -35,6 +36,7 @@ type EditResponseMetadata struct { type editTool struct { lspClients map[string]*lsp.Client permissions permission.Service + files history.Service } const ( @@ -88,10 +90,11 @@ When making edits: Remember: when making multiple file edits in a row to the same file, you should prefer to send all edits in a single message with multiple calls to this tool, rather than multiple messages with a single call each.` ) -func NewEditTool(lspClients map[string]*lsp.Client, permissions permission.Service) BaseTool { +func NewEditTool(lspClients map[string]*lsp.Client, permissions permission.Service, files history.Service) BaseTool { return &editTool{ lspClients: lspClients, permissions: permissions, + files: files, } } @@ -153,6 +156,11 @@ func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) if err != nil { return response, nil } + if response.IsError { + // Return early if there was an error during content replacement + // This prevents unnecessary LSP diagnostics processing + return response, nil + } waitForLspDiagnostics(ctx, params.FilePath, e.lspClients) text := fmt.Sprintf("<result>\n%s\n</result>\n", response.Content) @@ -208,6 +216,20 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string) return ToolResponse{}, fmt.Errorf("failed to write file: %w", err) } + // File can't be in the history so we create a new file history + _, err = e.files.Create(ctx, sessionID, filePath, "") + if err != nil { + // Log error but don't fail the operation + return ToolResponse{}, fmt.Errorf("error creating file history: %w", err) + } + + // Add the new content to the file history + _, err = e.files.CreateVersion(ctx, sessionID, filePath, content) + if err != nil { + // Log error but don't fail the operation + fmt.Printf("Error creating file history version: %v\n", err) + } + recordFileWrite(filePath) recordFileRead(filePath) @@ -298,6 +320,29 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string if err != nil { return ToolResponse{}, fmt.Errorf("failed to write file: %w", err) } + + // Check if file exists in history + file, err := e.files.GetByPathAndSession(ctx, filePath, sessionID) + if err != nil { + _, err = e.files.Create(ctx, sessionID, filePath, oldContent) + if err != nil { + // Log error but don't fail the operation + return ToolResponse{}, fmt.Errorf("error creating file history: %w", err) + } + } + if file.Content != oldContent { + // User Manually changed the content store an intermediate version + _, err = e.files.CreateVersion(ctx, sessionID, filePath, oldContent) + if err != nil { + fmt.Printf("Error creating file history version: %v\n", err) + } + } + // Store the new version + _, err = e.files.CreateVersion(ctx, sessionID, filePath, "") + if err != nil { + fmt.Printf("Error creating file history version: %v\n", err) + } + recordFileWrite(filePath) recordFileRead(filePath) @@ -356,6 +401,9 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS newContent := oldContent[:index] + newString + oldContent[index+len(oldString):] + if oldContent == newContent { + return NewTextErrorResponse("new content is the same as old content. No changes made."), nil + } sessionID, messageID := GetContextValues(ctx) if sessionID == "" || messageID == "" { @@ -374,8 +422,7 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS Description: fmt.Sprintf("Replace content in file %s", filePath), Params: EditPermissionsParams{ FilePath: filePath, - - Diff: diff, + Diff: diff, }, }, ) @@ -388,6 +435,28 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS return ToolResponse{}, fmt.Errorf("failed to write file: %w", err) } + // Check if file exists in history + file, err := e.files.GetByPathAndSession(ctx, filePath, sessionID) + if err != nil { + _, err = e.files.Create(ctx, sessionID, filePath, oldContent) + if err != nil { + // Log error but don't fail the operation + return ToolResponse{}, fmt.Errorf("error creating file history: %w", err) + } + } + if file.Content != oldContent { + // User Manually changed the content store an intermediate version + _, err = e.files.CreateVersion(ctx, sessionID, filePath, oldContent) + if err != nil { + fmt.Printf("Error creating file history version: %v\n", err) + } + } + // Store the new version + _, err = e.files.CreateVersion(ctx, sessionID, filePath, newContent) + if err != nil { + fmt.Printf("Error creating file history version: %v\n", err) + } + recordFileWrite(filePath) recordFileRead(filePath) diff --git a/internal/llm/tools/edit_test.go b/internal/llm/tools/edit_test.go index 48a34ed75..0971775dd 100644 --- a/internal/llm/tools/edit_test.go +++ b/internal/llm/tools/edit_test.go @@ -14,7 +14,7 @@ import ( ) func TestEditTool_Info(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) info := tool.Info() assert.Equal(t, EditToolName, info.Name) @@ -34,7 +34,7 @@ func TestEditTool_Run(t *testing.T) { defer os.RemoveAll(tempDir) t.Run("creates a new file successfully", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) filePath := filepath.Join(tempDir, "new_file.txt") content := "This is a test content" @@ -64,7 +64,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("creates file with nested directories", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) filePath := filepath.Join(tempDir, "nested/dirs/new_file.txt") content := "Content in nested directory" @@ -94,7 +94,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("fails to create file that already exists", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file first filePath := filepath.Join(tempDir, "existing_file.txt") @@ -123,7 +123,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("fails to create file when path is a directory", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a directory dirPath := filepath.Join(tempDir, "test_dir") @@ -151,7 +151,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("replaces content successfully", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file first filePath := filepath.Join(tempDir, "replace_content.txt") @@ -191,7 +191,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("deletes content successfully", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file first filePath := filepath.Join(tempDir, "delete_content.txt") @@ -230,7 +230,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles invalid parameters", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) call := ToolCall{ Name: EditToolName, @@ -243,7 +243,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles missing file_path", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) params := EditParams{ FilePath: "", @@ -265,7 +265,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles file not found", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) filePath := filepath.Join(tempDir, "non_existent_file.txt") params := EditParams{ @@ -288,7 +288,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles old_string not found in file", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file first filePath := filepath.Join(tempDir, "content_not_found.txt") @@ -320,7 +320,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles multiple occurrences of old_string", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file with duplicate content filePath := filepath.Join(tempDir, "duplicate_content.txt") @@ -352,7 +352,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles file modified since last read", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file filePath := filepath.Join(tempDir, "modified_file.txt") @@ -394,7 +394,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles file not read before editing", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file filePath := filepath.Join(tempDir, "not_read_file.txt") @@ -423,7 +423,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles permission denied", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(false)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(false), newMockFileHistoryService()) // Create a file filePath := filepath.Join(tempDir, "permission_denied.txt") diff --git a/internal/llm/tools/file.go b/internal/llm/tools/file.go index 9c9707c9c..7f34fdc1f 100644 --- a/internal/llm/tools/file.go +++ b/internal/llm/tools/file.go @@ -3,8 +3,6 @@ package tools import ( "sync" "time" - - "github.com/kujtimiihoxha/termai/internal/config" ) // File record to track when files were read/written @@ -19,14 +17,6 @@ var ( fileRecordMutex sync.RWMutex ) -func removeWorkingDirectoryPrefix(path string) string { - wd := config.WorkingDirectory() - if len(path) > len(wd) && path[:len(wd)] == wd { - return path[len(wd)+1:] - } - return path -} - func recordFileRead(path string) { fileRecordMutex.Lock() defer fileRecordMutex.Unlock() diff --git a/internal/llm/tools/glob.go b/internal/llm/tools/glob.go index bdfc23b4a..7b4fb1187 100644 --- a/internal/llm/tools/glob.go +++ b/internal/llm/tools/glob.go @@ -63,7 +63,7 @@ type GlobParams struct { Path string `json:"path"` } -type GlobMetadata struct { +type GlobResponseMetadata struct { NumberOfFiles int `json:"number_of_files"` Truncated bool `json:"truncated"` } @@ -124,7 +124,7 @@ func (g *globTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) return WithResponseMetadata( NewTextResponse(output), - GlobMetadata{ + GlobResponseMetadata{ NumberOfFiles: len(files), Truncated: truncated, }, diff --git a/internal/llm/tools/grep.go b/internal/llm/tools/grep.go index 7e52821d0..19333f50b 100644 --- a/internal/llm/tools/grep.go +++ b/internal/llm/tools/grep.go @@ -27,7 +27,7 @@ type grepMatch struct { modTime time.Time } -type GrepMetadata struct { +type GrepResponseMetadata struct { NumberOfMatches int `json:"number_of_matches"` Truncated bool `json:"truncated"` } @@ -134,7 +134,7 @@ func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) return WithResponseMetadata( NewTextResponse(output), - GrepMetadata{ + GrepResponseMetadata{ NumberOfMatches: len(matches), Truncated: truncated, }, diff --git a/internal/llm/tools/ls.go b/internal/llm/tools/ls.go index a679f261b..a63bf0eeb 100644 --- a/internal/llm/tools/ls.go +++ b/internal/llm/tools/ls.go @@ -23,7 +23,7 @@ type TreeNode struct { Children []*TreeNode `json:"children,omitempty"` } -type LSMetadata struct { +type LSResponseMetadata struct { NumberOfFiles int `json:"number_of_files"` Truncated bool `json:"truncated"` } @@ -121,7 +121,7 @@ func (l *lsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { return WithResponseMetadata( NewTextResponse(output), - LSMetadata{ + LSResponseMetadata{ NumberOfFiles: len(files), Truncated: truncated, }, diff --git a/internal/llm/tools/mocks_test.go b/internal/llm/tools/mocks_test.go new file mode 100644 index 000000000..321f09ac1 --- /dev/null +++ b/internal/llm/tools/mocks_test.go @@ -0,0 +1,246 @@ +package tools + +import ( + "context" + "fmt" + "sort" + "strconv" + "strings" + "time" + + "github.com/google/uuid" + "github.com/kujtimiihoxha/termai/internal/history" + "github.com/kujtimiihoxha/termai/internal/permission" + "github.com/kujtimiihoxha/termai/internal/pubsub" +) + +// Mock permission service for testing +type mockPermissionService struct { + *pubsub.Broker[permission.PermissionRequest] + allow bool +} + +func (m *mockPermissionService) GrantPersistant(permission permission.PermissionRequest) { + // Not needed for tests +} + +func (m *mockPermissionService) Grant(permission permission.PermissionRequest) { + // Not needed for tests +} + +func (m *mockPermissionService) Deny(permission permission.PermissionRequest) { + // Not needed for tests +} + +func (m *mockPermissionService) Request(opts permission.CreatePermissionRequest) bool { + return m.allow +} + +func newMockPermissionService(allow bool) permission.Service { + return &mockPermissionService{ + Broker: pubsub.NewBroker[permission.PermissionRequest](), + allow: allow, + } +} + +type mockFileHistoryService struct { + *pubsub.Broker[history.File] + files map[string]history.File // ID -> File + timeNow func() int64 +} + +// Create implements history.Service. +func (m *mockFileHistoryService) Create(ctx context.Context, sessionID string, path string, content string) (history.File, error) { + return m.createWithVersion(ctx, sessionID, path, content, history.InitialVersion) +} + +// CreateVersion implements history.Service. +func (m *mockFileHistoryService) CreateVersion(ctx context.Context, sessionID string, path string, content string) (history.File, error) { + var files []history.File + for _, file := range m.files { + if file.Path == path { + files = append(files, file) + } + } + + if len(files) == 0 { + // No previous versions, create initial + return m.Create(ctx, sessionID, path, content) + } + + // Sort files by CreatedAt in descending order + sort.Slice(files, func(i, j int) bool { + return files[i].CreatedAt > files[j].CreatedAt + }) + + // Get the latest version + latestFile := files[0] + latestVersion := latestFile.Version + + // Generate the next version + var nextVersion string + if latestVersion == history.InitialVersion { + nextVersion = "v1" + } else if strings.HasPrefix(latestVersion, "v") { + versionNum, err := strconv.Atoi(latestVersion[1:]) + if err != nil { + // If we can't parse the version, just use a timestamp-based version + nextVersion = fmt.Sprintf("v%d", latestFile.CreatedAt) + } else { + nextVersion = fmt.Sprintf("v%d", versionNum+1) + } + } else { + // If the version format is unexpected, use a timestamp-based version + nextVersion = fmt.Sprintf("v%d", latestFile.CreatedAt) + } + + return m.createWithVersion(ctx, sessionID, path, content, nextVersion) +} + +func (m *mockFileHistoryService) createWithVersion(_ context.Context, sessionID, path, content, version string) (history.File, error) { + now := m.timeNow() + file := history.File{ + ID: uuid.New().String(), + SessionID: sessionID, + Path: path, + Content: content, + Version: version, + CreatedAt: now, + UpdatedAt: now, + } + + m.files[file.ID] = file + m.Publish(pubsub.CreatedEvent, file) + return file, nil +} + +// Delete implements history.Service. +func (m *mockFileHistoryService) Delete(ctx context.Context, id string) error { + file, ok := m.files[id] + if !ok { + return fmt.Errorf("file not found: %s", id) + } + + delete(m.files, id) + m.Publish(pubsub.DeletedEvent, file) + return nil +} + +// DeleteSessionFiles implements history.Service. +func (m *mockFileHistoryService) DeleteSessionFiles(ctx context.Context, sessionID string) error { + files, err := m.ListBySession(ctx, sessionID) + if err != nil { + return err + } + + for _, file := range files { + err = m.Delete(ctx, file.ID) + if err != nil { + return err + } + } + + return nil +} + +// Get implements history.Service. +func (m *mockFileHistoryService) Get(ctx context.Context, id string) (history.File, error) { + file, ok := m.files[id] + if !ok { + return history.File{}, fmt.Errorf("file not found: %s", id) + } + return file, nil +} + +// GetByPathAndSession implements history.Service. +func (m *mockFileHistoryService) GetByPathAndSession(ctx context.Context, path string, sessionID string) (history.File, error) { + var latestFile history.File + var found bool + var latestTime int64 + + for _, file := range m.files { + if file.Path == path && file.SessionID == sessionID { + if !found || file.CreatedAt > latestTime { + latestFile = file + latestTime = file.CreatedAt + found = true + } + } + } + + if !found { + return history.File{}, fmt.Errorf("file not found: %s for session %s", path, sessionID) + } + return latestFile, nil +} + +// ListBySession implements history.Service. +func (m *mockFileHistoryService) ListBySession(ctx context.Context, sessionID string) ([]history.File, error) { + var files []history.File + for _, file := range m.files { + if file.SessionID == sessionID { + files = append(files, file) + } + } + + // Sort by CreatedAt in descending order + sort.Slice(files, func(i, j int) bool { + return files[i].CreatedAt > files[j].CreatedAt + }) + + return files, nil +} + +// ListLatestSessionFiles implements history.Service. +func (m *mockFileHistoryService) ListLatestSessionFiles(ctx context.Context, sessionID string) ([]history.File, error) { + // Map to track the latest file for each path + latestFiles := make(map[string]history.File) + + for _, file := range m.files { + if file.SessionID == sessionID { + existing, ok := latestFiles[file.Path] + if !ok || file.CreatedAt > existing.CreatedAt { + latestFiles[file.Path] = file + } + } + } + + // Convert map to slice + var result []history.File + for _, file := range latestFiles { + result = append(result, file) + } + + // Sort by CreatedAt in descending order + sort.Slice(result, func(i, j int) bool { + return result[i].CreatedAt > result[j].CreatedAt + }) + + return result, nil +} + +// Subscribe implements history.Service. +func (m *mockFileHistoryService) Subscribe(ctx context.Context) <-chan pubsub.Event[history.File] { + return m.Broker.Subscribe(ctx) +} + +// Update implements history.Service. +func (m *mockFileHistoryService) Update(ctx context.Context, file history.File) (history.File, error) { + _, ok := m.files[file.ID] + if !ok { + return history.File{}, fmt.Errorf("file not found: %s", file.ID) + } + + file.UpdatedAt = m.timeNow() + m.files[file.ID] = file + m.Publish(pubsub.UpdatedEvent, file) + return file, nil +} + +func newMockFileHistoryService() history.Service { + return &mockFileHistoryService{ + Broker: pubsub.NewBroker[history.File](), + files: make(map[string]history.File), + timeNow: func() int64 { return time.Now().Unix() }, + } +} diff --git a/internal/llm/tools/shell/shell.go b/internal/llm/tools/shell/shell.go index 64592f67d..4a776478a 100644 --- a/internal/llm/tools/shell/shell.go +++ b/internal/llm/tools/shell/shell.go @@ -83,11 +83,21 @@ func newPersistentShell(cwd string) *PersistentShell { commandQueue: make(chan *commandExecution, 10), } - go shell.processCommands() + go func() { + defer func() { + if r := recover(); r != nil { + fmt.Fprintf(os.Stderr, "Panic in shell command processor: %v\n", r) + shell.isAlive = false + close(shell.commandQueue) + } + }() + shell.processCommands() + }() go func() { err := cmd.Wait() if err != nil { + // Log the error if needed } shell.isAlive = false close(shell.commandQueue) diff --git a/internal/llm/tools/sourcegraph.go b/internal/llm/tools/sourcegraph.go index 17bc610ea..a6f2c8afb 100644 --- a/internal/llm/tools/sourcegraph.go +++ b/internal/llm/tools/sourcegraph.go @@ -18,7 +18,7 @@ type SourcegraphParams struct { Timeout int `json:"timeout,omitempty"` } -type SourcegraphMetadata struct { +type SourcegraphResponseMetadata struct { NumberOfMatches int `json:"number_of_matches"` Truncated bool `json:"truncated"` } diff --git a/internal/llm/tools/tools.go b/internal/llm/tools/tools.go index 07afe1363..bf0f8df0b 100644 --- a/internal/llm/tools/tools.go +++ b/internal/llm/tools/tools.go @@ -14,12 +14,17 @@ type ToolInfo struct { type toolResponseType string +type ( + sessionIDContextKey string + messageIDContextKey string +) + const ( ToolResponseTypeText toolResponseType = "text" ToolResponseTypeImage toolResponseType = "image" - SessionIDContextKey = "session_id" - MessageIDContextKey = "message_id" + SessionIDContextKey sessionIDContextKey = "session_id" + MessageIDContextKey messageIDContextKey = "message_id" ) type ToolResponse struct { diff --git a/internal/llm/tools/write.go b/internal/llm/tools/write.go index 889561d2a..bb49381fd 100644 --- a/internal/llm/tools/write.go +++ b/internal/llm/tools/write.go @@ -10,6 +10,7 @@ import ( "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/diff" + "github.com/kujtimiihoxha/termai/internal/history" "github.com/kujtimiihoxha/termai/internal/lsp" "github.com/kujtimiihoxha/termai/internal/permission" ) @@ -27,6 +28,7 @@ type WritePermissionsParams struct { type writeTool struct { lspClients map[string]*lsp.Client permissions permission.Service + files history.Service } type WriteResponseMetadata struct { @@ -67,10 +69,11 @@ TIPS: - Always include descriptive comments when making changes to existing code` ) -func NewWriteTool(lspClients map[string]*lsp.Client, permissions permission.Service) BaseTool { +func NewWriteTool(lspClients map[string]*lsp.Client, permissions permission.Service, files history.Service) BaseTool { return &writeTool{ lspClients: lspClients, permissions: permissions, + files: files, } } @@ -176,6 +179,28 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error return ToolResponse{}, fmt.Errorf("error writing file: %w", err) } + // Check if file exists in history + file, err := w.files.GetByPathAndSession(ctx, filePath, sessionID) + if err != nil { + _, err = w.files.Create(ctx, sessionID, filePath, oldContent) + if err != nil { + // Log error but don't fail the operation + return ToolResponse{}, fmt.Errorf("error creating file history: %w", err) + } + } + if file.Content != oldContent { + // User Manually changed the content store an intermediate version + _, err = w.files.CreateVersion(ctx, sessionID, filePath, oldContent) + if err != nil { + fmt.Printf("Error creating file history version: %v\n", err) + } + } + // Store the new version + _, err = w.files.CreateVersion(ctx, sessionID, filePath, params.Content) + if err != nil { + fmt.Printf("Error creating file history version: %v\n", err) + } + recordFileWrite(filePath) recordFileRead(filePath) waitForLspDiagnostics(ctx, filePath, w.lspClients) diff --git a/internal/llm/tools/write_test.go b/internal/llm/tools/write_test.go index 50dafc14f..2264f36fb 100644 --- a/internal/llm/tools/write_test.go +++ b/internal/llm/tools/write_test.go @@ -14,7 +14,7 @@ import ( ) func TestWriteTool_Info(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) info := tool.Info() assert.Equal(t, WriteToolName, info.Name) @@ -32,7 +32,7 @@ func TestWriteTool_Run(t *testing.T) { defer os.RemoveAll(tempDir) t.Run("creates a new file successfully", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) filePath := filepath.Join(tempDir, "new_file.txt") content := "This is a test content" @@ -61,7 +61,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("creates file with nested directories", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) filePath := filepath.Join(tempDir, "nested/dirs/new_file.txt") content := "Content in nested directory" @@ -90,7 +90,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("updates existing file", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file first filePath := filepath.Join(tempDir, "existing_file.txt") @@ -127,7 +127,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("handles invalid parameters", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) call := ToolCall{ Name: WriteToolName, @@ -140,7 +140,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("handles missing file_path", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) params := WriteParams{ FilePath: "", @@ -161,7 +161,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("handles missing content", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) params := WriteParams{ FilePath: filepath.Join(tempDir, "file.txt"), @@ -182,7 +182,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("handles writing to a directory path", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a directory dirPath := filepath.Join(tempDir, "test_dir") @@ -208,7 +208,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("handles permission denied", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(false)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(false), newMockFileHistoryService()) filePath := filepath.Join(tempDir, "permission_denied.txt") params := WriteParams{ @@ -234,7 +234,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("detects file modified since last read", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file filePath := filepath.Join(tempDir, "modified_file.txt") @@ -275,7 +275,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("skips writing when content is identical", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file filePath := filepath.Join(tempDir, "identical_content.txt") |
