diff options
| author | Kujtim Hoxha <[email protected]> | 2025-04-13 14:37:05 +0200 |
|---|---|---|
| committer | Kujtim Hoxha <[email protected]> | 2025-04-21 13:41:27 +0200 |
| commit | cdc5f209dccdc980714f2ca1aeb52133d6e93cce (patch) | |
| tree | 02fe97994dfce7f2e842be7b4c5170f534220eee /internal/llm | |
| parent | 3ad983db0f2c08826d56cb5de274d706c95b3353 (diff) | |
| download | opencode-cdc5f209dccdc980714f2ca1aeb52133d6e93cce.tar.gz opencode-cdc5f209dccdc980714f2ca1aeb52133d6e93cce.zip | |
cleanup diff, cleanup agent
Diffstat (limited to 'internal/llm')
| -rw-r--r-- | internal/llm/agent/agent-tool.go | 34 | ||||
| -rw-r--r-- | internal/llm/agent/agent.go | 522 | ||||
| -rw-r--r-- | internal/llm/agent/coder.go | 83 | ||||
| -rw-r--r-- | internal/llm/agent/task.go | 7 | ||||
| -rw-r--r-- | internal/llm/provider/provider.go | 4 | ||||
| -rw-r--r-- | internal/llm/tools/edit.go | 7 | ||||
| -rw-r--r-- | internal/llm/tools/tools.go | 2 | ||||
| -rw-r--r-- | internal/llm/tools/write.go | 2 |
8 files changed, 404 insertions, 257 deletions
diff --git a/internal/llm/agent/agent-tool.go b/internal/llm/agent/agent-tool.go index 91c46da8b..a9c6f93a7 100644 --- a/internal/llm/agent/agent-tool.go +++ b/internal/llm/agent/agent-tool.go @@ -5,14 +5,16 @@ import ( "encoding/json" "fmt" - "github.com/kujtimiihoxha/termai/internal/app" "github.com/kujtimiihoxha/termai/internal/llm/tools" + "github.com/kujtimiihoxha/termai/internal/lsp" "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/termai/internal/session" ) type agentTool struct { - parentSessionID string - app *app.App + sessions session.Service + messages message.Service + lspClients map[string]*lsp.Client } const ( @@ -46,12 +48,17 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes return tools.NewTextErrorResponse("prompt is required"), nil } - agent, err := NewTaskAgent(b.app) + sessionID, messageID := tools.GetContextValues(ctx) + if sessionID == "" || messageID == "" { + return tools.NewTextErrorResponse("session ID and message ID are required"), nil + } + + agent, err := NewTaskAgent(b.lspClients) if err != nil { return tools.NewTextErrorResponse(fmt.Sprintf("error creating agent: %s", err)), nil } - session, err := b.app.Sessions.CreateTaskSession(ctx, call.ID, b.parentSessionID, "New Agent Session") + session, err := b.sessions.CreateTaskSession(ctx, call.ID, sessionID, "New Agent Session") if err != nil { return tools.NewTextErrorResponse(fmt.Sprintf("error creating session: %s", err)), nil } @@ -61,7 +68,7 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes return tools.NewTextErrorResponse(fmt.Sprintf("error generating agent: %s", err)), nil } - messages, err := b.app.Messages.List(ctx, session.ID) + messages, err := b.messages.List(ctx, session.ID) if err != nil { return tools.NewTextErrorResponse(fmt.Sprintf("error listing messages: %s", err)), nil } @@ -74,11 +81,11 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes return tools.NewTextErrorResponse("no assistant message found"), nil } - updatedSession, err := b.app.Sessions.Get(ctx, session.ID) + updatedSession, err := b.sessions.Get(ctx, session.ID) if err != nil { return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil } - parentSession, err := b.app.Sessions.Get(ctx, b.parentSessionID) + parentSession, err := b.sessions.Get(ctx, sessionID) if err != nil { return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil } @@ -87,16 +94,19 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes parentSession.PromptTokens += updatedSession.PromptTokens parentSession.CompletionTokens += updatedSession.CompletionTokens - _, err = b.app.Sessions.Save(ctx, parentSession) + _, err = b.sessions.Save(ctx, parentSession) if err != nil { return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil } return tools.NewTextResponse(response.Content().String()), nil } -func NewAgentTool(parentSessionID string, app *app.App) tools.BaseTool { +func NewAgentTool( + Sessions session.Service, + Messages message.Service, +) tools.BaseTool { return &agentTool{ - parentSessionID: parentSessionID, - app: app, + sessions: Sessions, + messages: Messages, } } diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index b7c736e6c..997004e12 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -7,7 +7,6 @@ import ( "strings" "sync" - "github.com/kujtimiihoxha/termai/internal/app" "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/llm/models" "github.com/kujtimiihoxha/termai/internal/llm/prompt" @@ -15,22 +14,118 @@ import ( "github.com/kujtimiihoxha/termai/internal/llm/tools" "github.com/kujtimiihoxha/termai/internal/logging" "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/termai/internal/session" ) -type Agent interface { +// Common errors +var ( + ErrProviderNotEnabled = errors.New("provider is not enabled") + ErrRequestCancelled = errors.New("request cancelled by user") + ErrSessionBusy = errors.New("session is currently processing another request") +) + +// Service defines the interface for generating responses +type Service interface { Generate(ctx context.Context, sessionID string, content string) error + Cancel(sessionID string) error } type agent struct { - *app.App + sessions session.Service + messages message.Service model models.Model tools []tools.BaseTool agent provider.Provider titleGenerator provider.Provider + activeRequests sync.Map // map[sessionID]context.CancelFunc +} + +// NewAgent creates a new agent instance with the given model and tools +func NewAgent(ctx context.Context, sessions session.Service, messages message.Service, model models.Model, tools []tools.BaseTool) (Service, error) { + agentProvider, titleGenerator, err := getAgentProviders(ctx, model) + if err != nil { + return nil, fmt.Errorf("failed to initialize providers: %w", err) + } + + return &agent{ + model: model, + tools: tools, + sessions: sessions, + messages: messages, + agent: agentProvider, + titleGenerator: titleGenerator, + activeRequests: sync.Map{}, + }, nil +} + +// Cancel cancels an active request by session ID +func (a *agent) Cancel(sessionID string) error { + if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists { + if cancel, ok := cancelFunc.(context.CancelFunc); ok { + logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID)) + cancel() + return nil + } + } + return errors.New("no active request found for this session") } -func (c *agent) handleTitleGeneration(ctx context.Context, sessionID, content string) { - response, err := c.titleGenerator.SendMessages( +// Generate starts the generation process +func (a *agent) Generate(ctx context.Context, sessionID string, content string) error { + // Check if this session already has an active request + if _, busy := a.activeRequests.Load(sessionID); busy { + return ErrSessionBusy + } + + // Create a cancellable context + genCtx, cancel := context.WithCancel(ctx) + + // Store cancel function to allow user cancellation + a.activeRequests.Store(sessionID, cancel) + + // Launch the generation in a goroutine + go func() { + defer func() { + if r := recover(); r != nil { + logging.ErrorPersist(fmt.Sprintf("Panic in Generate: %v", r)) + } + }() + defer a.activeRequests.Delete(sessionID) + defer cancel() + + if err := a.generate(genCtx, sessionID, content); err != nil { + if !errors.Is(err, ErrRequestCancelled) && !errors.Is(err, context.Canceled) { + // Log the error (avoid logging cancellations as they're expected) + logging.ErrorPersist(fmt.Sprintf("Generation error for session %s: %v", sessionID, err)) + + // You may want to create an error message in the chat + bgCtx := context.Background() + errorMsg := fmt.Sprintf("Sorry, an error occurred: %v", err) + _, createErr := a.messages.Create(bgCtx, sessionID, message.CreateMessageParams{ + Role: message.System, + Parts: []message.ContentPart{ + message.TextContent{ + Text: errorMsg, + }, + }, + }) + if createErr != nil { + logging.ErrorPersist(fmt.Sprintf("Failed to create error message: %v", createErr)) + } + } + } + }() + + return nil +} + +// IsSessionBusy checks if a session currently has an active request +func (a *agent) IsSessionBusy(sessionID string) bool { + _, busy := a.activeRequests.Load(sessionID) + return busy +} // handleTitleGeneration asynchronously generates a title for new sessions +func (a *agent) handleTitleGeneration(ctx context.Context, sessionID, content string) { + response, err := a.titleGenerator.SendMessages( ctx, []message.Message{ { @@ -45,25 +140,30 @@ func (c *agent) handleTitleGeneration(ctx context.Context, sessionID, content st nil, ) if err != nil { + logging.ErrorPersist(fmt.Sprintf("Failed to generate title: %v", err)) return } - session, err := c.Sessions.Get(ctx, sessionID) + session, err := a.sessions.Get(ctx, sessionID) if err != nil { + logging.ErrorPersist(fmt.Sprintf("Failed to get session: %v", err)) return } + if response.Content != "" { - session.Title = response.Content - session.Title = strings.TrimSpace(session.Title) + session.Title = strings.TrimSpace(response.Content) session.Title = strings.ReplaceAll(session.Title, "\n", " ") - c.Sessions.Save(ctx, session) + if _, err := a.sessions.Save(ctx, session); err != nil { + logging.ErrorPersist(fmt.Sprintf("Failed to save session title: %v", err)) + } } } -func (c *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error { - session, err := c.Sessions.Get(ctx, sessionID) +// TrackUsage updates token usage statistics for the session +func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error { + session, err := a.sessions.Get(ctx, sessionID) if err != nil { - return err + return fmt.Errorf("failed to get session: %w", err) } cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) + @@ -75,189 +175,241 @@ func (c *agent) TrackUsage(ctx context.Context, sessionID string, model models.M session.CompletionTokens += usage.OutputTokens session.PromptTokens += usage.InputTokens - _, err = c.Sessions.Save(ctx, session) - return err + _, err = a.sessions.Save(ctx, session) + if err != nil { + return fmt.Errorf("failed to save session: %w", err) + } + return nil } -func (c *agent) processEvent( +// processEvent handles different types of events during generation +func (a *agent) processEvent( ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent, ) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + // Continue processing + } + switch event.Type { case provider.EventThinkingDelta: assistantMsg.AppendReasoningContent(event.Content) - return c.Messages.Update(ctx, *assistantMsg) + return a.messages.Update(ctx, *assistantMsg) case provider.EventContentDelta: assistantMsg.AppendContent(event.Content) - return c.Messages.Update(ctx, *assistantMsg) + return a.messages.Update(ctx, *assistantMsg) case provider.EventError: if errors.Is(event.Error, context.Canceled) { - return nil + logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID)) + return context.Canceled } logging.ErrorPersist(event.Error.Error()) return event.Error case provider.EventWarning: logging.WarnPersist(event.Info) - return nil case provider.EventInfo: logging.InfoPersist(event.Info) case provider.EventComplete: assistantMsg.SetToolCalls(event.Response.ToolCalls) assistantMsg.AddFinish(event.Response.FinishReason) - err := c.Messages.Update(ctx, *assistantMsg) - if err != nil { - return err + if err := a.messages.Update(ctx, *assistantMsg); err != nil { + return fmt.Errorf("failed to update message: %w", err) } - return c.TrackUsage(ctx, sessionID, c.model, event.Response.Usage) + return a.TrackUsage(ctx, sessionID, a.model, event.Response.Usage) } return nil } -func (c *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) { - var wg sync.WaitGroup +// ExecuteTools runs all tool calls sequentially and returns the results +func (a *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) { toolResults := make([]message.ToolResult, len(toolCalls)) - mutex := &sync.Mutex{} - errChan := make(chan error, 1) // Create a child context that can be canceled ctx, cancel := context.WithCancel(ctx) defer cancel() - for i, tc := range toolCalls { - wg.Add(1) - go func(index int, toolCall message.ToolCall) { - defer wg.Done() + // Check if already canceled before starting any execution + if ctx.Err() != nil { + // Mark all tools as canceled + for i, toolCall := range toolCalls { + toolResults[i] = message.ToolResult{ + ToolCallID: toolCall.ID, + Content: "Tool execution canceled by user", + IsError: true, + } + } + return toolResults, ctx.Err() + } - // Check if context is already canceled - select { - case <-ctx.Done(): - mutex.Lock() - toolResults[index] = message.ToolResult{ - ToolCallID: toolCall.ID, - Content: "Tool execution canceled", + for i, toolCall := range toolCalls { + // Check for cancellation before executing each tool + select { + case <-ctx.Done(): + // Mark this and all remaining tools as canceled + for j := i; j < len(toolCalls); j++ { + toolResults[j] = message.ToolResult{ + ToolCallID: toolCalls[j].ID, + Content: "Tool execution canceled by user", IsError: true, } - mutex.Unlock() - - // Send cancellation error to error channel if it's empty - select { - case errChan <- ctx.Err(): - default: - } - return - default: } + return toolResults, ctx.Err() + default: + // Continue processing + } - response := "" - isError := false - found := false - - for _, tool := range tls { - if tool.Info().Name == toolCall.Name { - found = true - toolResult, toolErr := tool.Run(ctx, tools.ToolCall{ - ID: toolCall.ID, - Name: toolCall.Name, - Input: toolCall.Input, - }) - - if toolErr != nil { - if errors.Is(toolErr, context.Canceled) { - response = "Tool execution canceled" - - // Send cancellation error to error channel if it's empty - select { - case errChan <- ctx.Err(): - default: - } - } else { - response = fmt.Sprintf("error running tool: %s", toolErr) - } - isError = true + response := "" + isError := false + found := false + + // Find and execute the appropriate tool + for _, tool := range tls { + if tool.Info().Name == toolCall.Name { + found = true + toolResult, toolErr := tool.Run(ctx, tools.ToolCall{ + ID: toolCall.ID, + Name: toolCall.Name, + Input: toolCall.Input, + }) + + if toolErr != nil { + if errors.Is(toolErr, context.Canceled) { + response = "Tool execution canceled by user" } else { - response = toolResult.Content - isError = toolResult.IsError + response = fmt.Sprintf("Error running tool: %s", toolErr) } - break + isError = true + } else { + response = toolResult.Content + isError = toolResult.IsError } + break } + } - if !found { - response = fmt.Sprintf("tool not found: %s", toolCall.Name) - isError = true - } - - mutex.Lock() - defer mutex.Unlock() - - toolResults[index] = message.ToolResult{ - ToolCallID: toolCall.ID, - Content: response, - IsError: isError, - } - }(i, tc) - } - - // Wait for all goroutines to finish or context to be canceled - done := make(chan struct{}) - go func() { - wg.Wait() - close(done) - }() + if !found { + response = fmt.Sprintf("Tool not found: %s", toolCall.Name) + isError = true + } - select { - case <-done: - // All tools completed successfully - case err := <-errChan: - // One of the tools encountered a cancellation - return toolResults, err - case <-ctx.Done(): - // Context was canceled externally - return toolResults, ctx.Err() + toolResults[i] = message.ToolResult{ + ToolCallID: toolCall.ID, + Content: response, + IsError: isError, + } } return toolResults, nil } -func (c *agent) handleToolExecution( +// handleToolExecution processes tool calls and creates tool result messages +func (a *agent) handleToolExecution( ctx context.Context, assistantMsg message.Message, ) (*message.Message, error) { + select { + case <-ctx.Done(): + // If cancelled, create tool results that indicate cancellation + if len(assistantMsg.ToolCalls()) > 0 { + toolResults := make([]message.ToolResult, 0, len(assistantMsg.ToolCalls())) + for _, tc := range assistantMsg.ToolCalls() { + toolResults = append(toolResults, message.ToolResult{ + ToolCallID: tc.ID, + Content: "Tool execution canceled by user", + IsError: true, + }) + } + + // Use background context to ensure the message is created even if original context is cancelled + bgCtx := context.Background() + parts := make([]message.ContentPart, 0) + for _, toolResult := range toolResults { + parts = append(parts, toolResult) + } + msg, err := a.messages.Create(bgCtx, assistantMsg.SessionID, message.CreateMessageParams{ + Role: message.Tool, + Parts: parts, + }) + if err != nil { + return nil, fmt.Errorf("failed to create cancelled tool message: %w", err) + } + return &msg, ctx.Err() + } + return nil, ctx.Err() + default: + // Continue processing + } + if len(assistantMsg.ToolCalls()) == 0 { return nil, nil } - toolResults, err := c.ExecuteTools(ctx, assistantMsg.ToolCalls(), c.tools) + toolResults, err := a.ExecuteTools(ctx, assistantMsg.ToolCalls(), a.tools) if err != nil { + // If error is from cancellation, still return the partial results we have + if errors.Is(err, context.Canceled) { + // Use background context to ensure the message is created even if original context is cancelled + bgCtx := context.Background() + parts := make([]message.ContentPart, 0) + for _, toolResult := range toolResults { + parts = append(parts, toolResult) + } + + msg, createErr := a.messages.Create(bgCtx, assistantMsg.SessionID, message.CreateMessageParams{ + Role: message.Tool, + Parts: parts, + }) + if createErr != nil { + logging.ErrorPersist(fmt.Sprintf("Failed to create tool message after cancellation: %v", createErr)) + return nil, err + } + return &msg, err + } return nil, err } - parts := make([]message.ContentPart, 0) + + parts := make([]message.ContentPart, 0, len(toolResults)) for _, toolResult := range toolResults { parts = append(parts, toolResult) } - msg, err := c.Messages.Create(ctx, assistantMsg.SessionID, message.CreateMessageParams{ + + msg, err := a.messages.Create(ctx, assistantMsg.SessionID, message.CreateMessageParams{ Role: message.Tool, Parts: parts, }) + if err != nil { + return nil, fmt.Errorf("failed to create tool message: %w", err) + } - return &msg, err + return &msg, nil } -func (c *agent) generate(ctx context.Context, sessionID string, content string) error { +// generate handles the main generation workflow +func (a *agent) generate(ctx context.Context, sessionID string, content string) error { ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID) - messages, err := c.Messages.List(ctx, sessionID) + + // Handle context cancellation at any point + if err := ctx.Err(); err != nil { + return ErrRequestCancelled + } + + messages, err := a.messages.List(ctx, sessionID) if err != nil { - return err + return fmt.Errorf("failed to list messages: %w", err) } if len(messages) == 0 { - go c.handleTitleGeneration(ctx, sessionID, content) + titleCtx := context.Background() + go a.handleTitleGeneration(titleCtx, sessionID, content) } - userMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{ + userMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{ Role: message.User, Parts: []message.ContentPart{ message.TextContent{ @@ -266,133 +418,125 @@ func (c *agent) generate(ctx context.Context, sessionID string, content string) }, }) if err != nil { - return err + return fmt.Errorf("failed to create user message: %w", err) } messages = append(messages, userMsg) + for { + // Check for cancellation before each iteration select { case <-ctx.Done(): - assistantMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{ - Role: message.Assistant, - Parts: []message.ContentPart{}, - }) - if err != nil { - return err - } - assistantMsg.AddFinish("canceled") - c.Messages.Update(ctx, assistantMsg) - return context.Canceled + return ErrRequestCancelled default: // Continue processing } - eventChan, err := c.agent.StreamResponse(ctx, messages, c.tools) + eventChan, err := a.agent.StreamResponse(ctx, messages, a.tools) if err != nil { if errors.Is(err, context.Canceled) { - assistantMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{ - Role: message.Assistant, - Parts: []message.ContentPart{}, - }) - if err != nil { - return err - } - assistantMsg.AddFinish("canceled") - c.Messages.Update(ctx, assistantMsg) - return context.Canceled + return ErrRequestCancelled } - return err + return fmt.Errorf("failed to stream response: %w", err) } - assistantMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{ + assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{ Role: message.Assistant, Parts: []message.ContentPart{}, - Model: c.model.ID, + Model: a.model.ID, }) if err != nil { - return err + return fmt.Errorf("failed to create assistant message: %w", err) } ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID) + + // Process events from the LLM provider for event := range eventChan { - err = c.processEvent(ctx, sessionID, &assistantMsg, event) - if err != nil { + if err := a.processEvent(ctx, sessionID, &assistantMsg, event); err != nil { if errors.Is(err, context.Canceled) { + // Mark as canceled but don't create separate message assistantMsg.AddFinish("canceled") - c.Messages.Update(ctx, assistantMsg) - return context.Canceled + _ = a.messages.Update(context.Background(), assistantMsg) + return ErrRequestCancelled } assistantMsg.AddFinish("error:" + err.Error()) - c.Messages.Update(ctx, assistantMsg) - return err + _ = a.messages.Update(ctx, assistantMsg) + return fmt.Errorf("event processing error: %w", err) } + // Check for cancellation during event processing select { case <-ctx.Done(): + // Mark as canceled assistantMsg.AddFinish("canceled") - c.Messages.Update(ctx, assistantMsg) - return context.Canceled + _ = a.messages.Update(context.Background(), assistantMsg) + return ErrRequestCancelled default: } } - // Check for context cancellation before tool execution + // Check for cancellation before tool execution select { case <-ctx.Done(): - assistantMsg.AddFinish("canceled") - c.Messages.Update(ctx, assistantMsg) - return context.Canceled + assistantMsg.AddFinish("canceled_by_user") + _ = a.messages.Update(context.Background(), assistantMsg) + return ErrRequestCancelled default: - // Continue processing } - msg, err := c.handleToolExecution(ctx, assistantMsg) + // Execute any tool calls + toolMsg, err := a.handleToolExecution(ctx, assistantMsg) if err != nil { if errors.Is(err, context.Canceled) { - assistantMsg.AddFinish("canceled") - c.Messages.Update(ctx, assistantMsg) - return context.Canceled + assistantMsg.AddFinish("canceled_by_user") + _ = a.messages.Update(context.Background(), assistantMsg) + return ErrRequestCancelled } - return err + return fmt.Errorf("tool execution error: %w", err) } - c.Messages.Update(ctx, assistantMsg) + if err := a.messages.Update(ctx, assistantMsg); err != nil { + return fmt.Errorf("failed to update assistant message: %w", err) + } + // If no tool calls, we're done if len(assistantMsg.ToolCalls()) == 0 { break } + // Add messages for next iteration messages = append(messages, assistantMsg) - if msg != nil { - messages = append(messages, *msg) + if toolMsg != nil { + messages = append(messages, *toolMsg) } - // Check for context cancellation after tool execution + // Check for cancellation after tool execution select { case <-ctx.Done(): - assistantMsg.AddFinish("canceled") - c.Messages.Update(ctx, assistantMsg) - return context.Canceled + return ErrRequestCancelled default: - // Continue processing } } + return nil } +// getAgentProviders initializes the LLM providers based on the chosen model func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) { maxTokens := config.Get().Model.CoderMaxTokens providerConfig, ok := config.Get().Providers[model.Provider] if !ok || providerConfig.Disabled { - return nil, nil, errors.New("provider is not enabled") + return nil, nil, ErrProviderNotEnabled } + var agentProvider provider.Provider var titleGenerator provider.Provider + var err error switch model.Provider { case models.ProviderOpenAI: - var err error agentProvider, err = provider.NewOpenAIProvider( provider.WithOpenAISystemMessage( prompt.CoderOpenAISystemPrompt(), @@ -402,8 +546,9 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid provider.WithOpenAIKey(providerConfig.APIKey), ) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("failed to create OpenAI agent provider: %w", err) } + titleGenerator, err = provider.NewOpenAIProvider( provider.WithOpenAISystemMessage( prompt.TitlePrompt(), @@ -413,10 +558,10 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid provider.WithOpenAIKey(providerConfig.APIKey), ) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("failed to create OpenAI title generator: %w", err) } + case models.ProviderAnthropic: - var err error agentProvider, err = provider.NewAnthropicProvider( provider.WithAnthropicSystemMessage( prompt.CoderAnthropicSystemPrompt(), @@ -426,8 +571,9 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid provider.WithAnthropicModel(model), ) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("failed to create Anthropic agent provider: %w", err) } + titleGenerator, err = provider.NewAnthropicProvider( provider.WithAnthropicSystemMessage( prompt.TitlePrompt(), @@ -437,11 +583,10 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid provider.WithAnthropicModel(model), ) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("failed to create Anthropic title generator: %w", err) } case models.ProviderGemini: - var err error agentProvider, err = provider.NewGeminiProvider( ctx, provider.WithGeminiSystemMessage( @@ -452,8 +597,9 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid provider.WithGeminiModel(model), ) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("failed to create Gemini agent provider: %w", err) } + titleGenerator, err = provider.NewGeminiProvider( ctx, provider.WithGeminiSystemMessage( @@ -464,10 +610,10 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid provider.WithGeminiModel(model), ) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("failed to create Gemini title generator: %w", err) } + case models.ProviderGROQ: - var err error agentProvider, err = provider.NewOpenAIProvider( provider.WithOpenAISystemMessage( prompt.CoderAnthropicSystemPrompt(), @@ -478,8 +624,9 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"), ) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("failed to create GROQ agent provider: %w", err) } + titleGenerator, err = provider.NewOpenAIProvider( provider.WithOpenAISystemMessage( prompt.TitlePrompt(), @@ -490,11 +637,10 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"), ) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("failed to create GROQ title generator: %w", err) } case models.ProviderBedrock: - var err error agentProvider, err = provider.NewBedrockProvider( provider.WithBedrockSystemMessage( prompt.CoderAnthropicSystemPrompt(), @@ -503,19 +649,21 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid provider.WithBedrockModel(model), ) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("failed to create Bedrock agent provider: %w", err) } + titleGenerator, err = provider.NewBedrockProvider( provider.WithBedrockSystemMessage( prompt.TitlePrompt(), ), - provider.WithBedrockMaxTokens(maxTokens), + provider.WithBedrockMaxTokens(80), provider.WithBedrockModel(model), ) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("failed to create Bedrock title generator: %w", err) } - + default: + return nil, nil, fmt.Errorf("unsupported provider: %s", model.Provider) } return agentProvider, titleGenerator, nil diff --git a/internal/llm/agent/coder.go b/internal/llm/agent/coder.go index f8e1c40a0..8eea57041 100644 --- a/internal/llm/agent/coder.go +++ b/internal/llm/agent/coder.go @@ -4,71 +4,60 @@ import ( "context" "errors" - "github.com/kujtimiihoxha/termai/internal/app" "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/llm/models" "github.com/kujtimiihoxha/termai/internal/llm/tools" + "github.com/kujtimiihoxha/termai/internal/lsp" + "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/termai/internal/permission" + "github.com/kujtimiihoxha/termai/internal/session" ) type coderAgent struct { - *agent + Service } -func (c *coderAgent) setAgentTool(sessionID string) { - inx := -1 - for i, tool := range c.tools { - if tool.Info().Name == AgentToolName { - inx = i - break - } - } - if inx == -1 { - c.tools = append(c.tools, NewAgentTool(sessionID, c.App)) - } else { - c.tools[inx] = NewAgentTool(sessionID, c.App) - } -} - -func (c *coderAgent) Generate(ctx context.Context, sessionID string, content string) error { - c.setAgentTool(sessionID) - return c.generate(ctx, sessionID, content) -} - -func NewCoderAgent(app *app.App) (Agent, error) { +func NewCoderAgent( + permissions permission.Service, + sessions session.Service, + messages message.Service, + lspClients map[string]*lsp.Client, +) (Service, error) { model, ok := models.SupportedModels[config.Get().Model.Coder] if !ok { return nil, errors.New("model not supported") } ctx := context.Background() - agentProvider, titleGenerator, err := getAgentProviders(ctx, model) + otherTools := GetMcpTools(ctx, permissions) + if len(lspClients) > 0 { + otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients)) + } + agent, err := NewAgent( + ctx, + sessions, + messages, + model, + append( + []tools.BaseTool{ + tools.NewBashTool(permissions), + tools.NewEditTool(lspClients, permissions), + tools.NewFetchTool(permissions), + tools.NewGlobTool(), + tools.NewGrepTool(), + tools.NewLsTool(), + tools.NewSourcegraphTool(), + tools.NewViewTool(lspClients), + tools.NewWriteTool(lspClients, permissions), + NewAgentTool(sessions, messages), + }, otherTools..., + ), + ) if err != nil { return nil, err } - otherTools := GetMcpTools(ctx, app.Permissions) - if len(app.LSPClients) > 0 { - otherTools = append(otherTools, tools.NewDiagnosticsTool(app.LSPClients)) - } return &coderAgent{ - agent: &agent{ - App: app, - tools: append( - []tools.BaseTool{ - tools.NewBashTool(app.Permissions), - tools.NewEditTool(app.LSPClients, app.Permissions), - tools.NewFetchTool(app.Permissions), - tools.NewGlobTool(), - tools.NewGrepTool(), - tools.NewLsTool(), - tools.NewSourcegraphTool(), - tools.NewViewTool(app.LSPClients), - tools.NewWriteTool(app.LSPClients, app.Permissions), - }, otherTools..., - ), - model: model, - agent: agentProvider, - titleGenerator: titleGenerator, - }, + agent, }, nil } diff --git a/internal/llm/agent/task.go b/internal/llm/agent/task.go index c196cb107..0a072044c 100644 --- a/internal/llm/agent/task.go +++ b/internal/llm/agent/task.go @@ -4,10 +4,10 @@ import ( "context" "errors" - "github.com/kujtimiihoxha/termai/internal/app" "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/llm/models" "github.com/kujtimiihoxha/termai/internal/llm/tools" + "github.com/kujtimiihoxha/termai/internal/lsp" ) type taskAgent struct { @@ -18,7 +18,7 @@ func (c *taskAgent) Generate(ctx context.Context, sessionID string, content stri return c.generate(ctx, sessionID, content) } -func NewTaskAgent(app *app.App) (Agent, error) { +func NewTaskAgent(lspClients map[string]*lsp.Client) (Service, error) { model, ok := models.SupportedModels[config.Get().Model.Coder] if !ok { return nil, errors.New("model not supported") @@ -31,13 +31,12 @@ func NewTaskAgent(app *app.App) (Agent, error) { } return &taskAgent{ agent: &agent{ - App: app, tools: []tools.BaseTool{ tools.NewGlobTool(), tools.NewGrepTool(), tools.NewLsTool(), tools.NewSourcegraphTool(), - tools.NewViewTool(app.LSPClients), + tools.NewViewTool(lspClients), }, model: model, agent: agentProvider, diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index 938a8c0ad..34d91f2b7 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -57,7 +57,9 @@ func cleanupMessages(messages []message.Message) []message.Message { // First pass: filter out canceled messages var cleanedMessages []message.Message for _, msg := range messages { - if msg.FinishReason() != "canceled" { + if msg.FinishReason() != "canceled" || len(msg.ToolCalls()) > 0 { + // if there are toolCalls this means we want to return it to the LLM telling it that those tools have been + // cancelled cleanedMessages = append(cleanedMessages, msg) } } diff --git a/internal/llm/tools/edit.go b/internal/llm/tools/edit.go index c9a0be079..647b8d35f 100644 --- a/internal/llm/tools/edit.go +++ b/internal/llm/tools/edit.go @@ -190,7 +190,7 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string) return er, fmt.Errorf("failed to create parent directories: %w", err) } - sessionID, messageID := getContextValues(ctx) + sessionID, messageID := GetContextValues(ctx) if sessionID == "" || messageID == "" { return er, fmt.Errorf("session ID and message ID are required for creating a new file") } @@ -277,7 +277,7 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string newContent := oldContent[:index] + oldContent[index+len(oldString):] - sessionID, messageID := getContextValues(ctx) + sessionID, messageID := GetContextValues(ctx) if sessionID == "" || messageID == "" { return er, fmt.Errorf("session ID and message ID are required for creating a new file") @@ -365,7 +365,7 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS newContent := oldContent[:index] + newString + oldContent[index+len(oldString):] - sessionID, messageID := getContextValues(ctx) + sessionID, messageID := GetContextValues(ctx) if sessionID == "" || messageID == "" { return er, fmt.Errorf("session ID and message ID are required for creating a new file") @@ -409,4 +409,3 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS return er, nil } - diff --git a/internal/llm/tools/tools.go b/internal/llm/tools/tools.go index 473b787bb..07afe1363 100644 --- a/internal/llm/tools/tools.go +++ b/internal/llm/tools/tools.go @@ -66,7 +66,7 @@ type BaseTool interface { Run(ctx context.Context, params ToolCall) (ToolResponse, error) } -func getContextValues(ctx context.Context) (string, string) { +func GetContextValues(ctx context.Context) (string, string) { sessionID := ctx.Value(SessionIDContextKey) messageID := ctx.Value(MessageIDContextKey) if sessionID == nil { diff --git a/internal/llm/tools/write.go b/internal/llm/tools/write.go index 27c98bb9d..1b087c193 100644 --- a/internal/llm/tools/write.go +++ b/internal/llm/tools/write.go @@ -144,7 +144,7 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error } } - sessionID, messageID := getContextValues(ctx) + sessionID, messageID := GetContextValues(ctx) if sessionID == "" || messageID == "" { return NewTextErrorResponse("session ID or message ID is missing"), nil } |
