diff options
| author | Kujtim Hoxha <[email protected]> | 2025-04-16 20:06:23 +0200 |
|---|---|---|
| committer | Kujtim Hoxha <[email protected]> | 2025-04-21 13:42:00 +0200 |
| commit | bbfa60c787f2ec459f1689b9a650ddbec9693ed9 (patch) | |
| tree | f7f2aa31c460c8cc22ec40cc299c386277152241 /internal/llm/agent | |
| parent | 76b4065f17b87a63092acfd98c997bab53700b35 (diff) | |
| download | opencode-bbfa60c787f2ec459f1689b9a650ddbec9693ed9.tar.gz opencode-bbfa60c787f2ec459f1689b9a650ddbec9693ed9.zip | |
reimplement agent,provider and add file history
Diffstat (limited to 'internal/llm/agent')
| -rw-r--r-- | internal/llm/agent/agent-tool.go | 18 | ||||
| -rw-r--r-- | internal/llm/agent/agent.go | 861 | ||||
| -rw-r--r-- | internal/llm/agent/coder.go | 63 | ||||
| -rw-r--r-- | internal/llm/agent/mcp-tools.go | 4 | ||||
| -rw-r--r-- | internal/llm/agent/task.go | 47 | ||||
| -rw-r--r-- | internal/llm/agent/tools.go | 50 |
6 files changed, 372 insertions, 671 deletions
diff --git a/internal/llm/agent/agent-tool.go b/internal/llm/agent/agent-tool.go index 83160bb64..308412bde 100644 --- a/internal/llm/agent/agent-tool.go +++ b/internal/llm/agent/agent-tool.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" + "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/llm/tools" "github.com/kujtimiihoxha/termai/internal/lsp" "github.com/kujtimiihoxha/termai/internal/message" @@ -53,7 +54,7 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes return tools.ToolResponse{}, fmt.Errorf("session_id and message_id are required") } - agent, err := NewTaskAgent(b.messages, b.sessions, b.lspClients) + agent, err := NewAgent(config.AgentTask, b.sessions, b.messages, TaskAgentTools(b.lspClients)) if err != nil { return tools.ToolResponse{}, fmt.Errorf("error creating agent: %s", err) } @@ -63,21 +64,16 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes return tools.ToolResponse{}, fmt.Errorf("error creating session: %s", err) } - err = agent.Generate(ctx, session.ID, params.Prompt) + done, err := agent.Run(ctx, session.ID, params.Prompt) if err != nil { return tools.ToolResponse{}, fmt.Errorf("error generating agent: %s", err) } - - messages, err := b.messages.List(ctx, session.ID) - if err != nil { - return tools.ToolResponse{}, fmt.Errorf("error listing messages: %s", err) - } - - if len(messages) == 0 { - return tools.NewTextErrorResponse("no response"), nil + result := <-done + if result.Err() != nil { + return tools.ToolResponse{}, fmt.Errorf("error generating agent: %s", result.Err()) } - response := messages[len(messages)-1] + response := result.Response() if response.Role != message.Assistant { return tools.NewTextErrorResponse("no response"), nil } diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 1958111a1..ab2742ec1 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -4,8 +4,6 @@ import ( "context" "errors" "fmt" - "os" - "runtime/debug" "strings" "sync" @@ -16,133 +14,101 @@ import ( "github.com/kujtimiihoxha/termai/internal/llm/tools" "github.com/kujtimiihoxha/termai/internal/logging" "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/termai/internal/permission" "github.com/kujtimiihoxha/termai/internal/session" ) // Common errors var ( - ErrProviderNotEnabled = errors.New("provider is not enabled") - ErrRequestCancelled = errors.New("request cancelled by user") - ErrSessionBusy = errors.New("session is currently processing another request") + ErrRequestCancelled = errors.New("request cancelled by user") + ErrSessionBusy = errors.New("session is currently processing another request") ) -// Service defines the interface for generating responses +type AgentEvent struct { + message message.Message + err error +} + +func (e *AgentEvent) Err() error { + return e.err +} + +func (e *AgentEvent) Response() message.Message { + return e.message +} + type Service interface { - Generate(ctx context.Context, sessionID string, content string) error - Cancel(sessionID string) error + Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error) + Cancel(sessionID string) + IsSessionBusy(sessionID string) bool } type agent struct { - sessions session.Service - messages message.Service - model models.Model - tools []tools.BaseTool - agent provider.Provider - titleGenerator provider.Provider - activeRequests sync.Map // map[sessionID]context.CancelFunc + sessions session.Service + messages message.Service + + tools []tools.BaseTool + provider provider.Provider + + titleProvider provider.Provider + + activeRequests sync.Map } -// NewAgent creates a new agent instance with the given model and tools -func NewAgent(ctx context.Context, sessions session.Service, messages message.Service, model models.Model, tools []tools.BaseTool) (Service, error) { - agentProvider, titleGenerator, err := getAgentProviders(ctx, model) +func NewAgent( + agentName config.AgentName, + sessions session.Service, + messages message.Service, + agentTools []tools.BaseTool, +) (Service, error) { + agentProvider, err := createAgentProvider(agentName) if err != nil { - return nil, fmt.Errorf("failed to initialize providers: %w", err) + return nil, err + } + var titleProvider provider.Provider + // Only generate titles for the coder agent + if agentName == config.AgentCoder { + titleProvider, err = createAgentProvider(config.AgentTitle) + if err != nil { + return nil, err + } } - return &agent{ - model: model, - tools: tools, - sessions: sessions, + agent := &agent{ + provider: agentProvider, messages: messages, - agent: agentProvider, - titleGenerator: titleGenerator, + sessions: sessions, + tools: agentTools, + titleProvider: titleProvider, activeRequests: sync.Map{}, - }, nil + } + + return agent, nil } -// Cancel cancels an active request by session ID -func (a *agent) Cancel(sessionID string) error { +func (a *agent) Cancel(sessionID string) { if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists { if cancel, ok := cancelFunc.(context.CancelFunc); ok { logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID)) cancel() - return nil } } - return errors.New("no active request found for this session") } -// Generate starts the generation process -func (a *agent) Generate(ctx context.Context, sessionID string, content string) error { - // Check if this session already has an active request - if _, busy := a.activeRequests.Load(sessionID); busy { - return ErrSessionBusy - } - - // Create a cancellable context - genCtx, cancel := context.WithCancel(ctx) - - // Store cancel function to allow user cancellation - a.activeRequests.Store(sessionID, cancel) - - // Launch the generation in a goroutine - go func() { - defer func() { - if r := recover(); r != nil { - logging.ErrorPersist(fmt.Sprintf("Panic in Generate: %v", r)) - - // dump stack trace into a file - file, err := os.Create("panic.log") - if err != nil { - logging.ErrorPersist(fmt.Sprintf("Failed to create panic log: %v", err)) - return - } - - defer file.Close() - - stackTrace := debug.Stack() - if _, err := file.Write(stackTrace); err != nil { - logging.ErrorPersist(fmt.Sprintf("Failed to write panic log: %v", err)) - } - - } - }() - defer a.activeRequests.Delete(sessionID) - defer cancel() - - if err := a.generate(genCtx, sessionID, content); err != nil { - if !errors.Is(err, ErrRequestCancelled) && !errors.Is(err, context.Canceled) { - // Log the error (avoid logging cancellations as they're expected) - logging.ErrorPersist(fmt.Sprintf("Generation error for session %s: %v", sessionID, err)) - - // You may want to create an error message in the chat - bgCtx := context.Background() - errorMsg := fmt.Sprintf("Sorry, an error occurred: %v", err) - _, createErr := a.messages.Create(bgCtx, sessionID, message.CreateMessageParams{ - Role: message.System, - Parts: []message.ContentPart{ - message.TextContent{ - Text: errorMsg, - }, - }, - }) - if createErr != nil { - logging.ErrorPersist(fmt.Sprintf("Failed to create error message: %v", createErr)) - } - } - } - }() - - return nil -} - -// IsSessionBusy checks if a session currently has an active request func (a *agent) IsSessionBusy(sessionID string) bool { _, busy := a.activeRequests.Load(sessionID) return busy -} // handleTitleGeneration asynchronously generates a title for new sessions -func (a *agent) handleTitleGeneration(ctx context.Context, sessionID, content string) { - response, err := a.titleGenerator.SendMessages( +} + +func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error { + if a.titleProvider == nil { + return nil + } + session, err := a.sessions.Get(ctx, sessionID) + if err != nil { + return err + } + response, err := a.titleProvider.SendMessages( ctx, []message.Message{ { @@ -154,121 +120,152 @@ func (a *agent) handleTitleGeneration(ctx context.Context, sessionID, content st }, }, }, - nil, + make([]tools.BaseTool, 0), ) if err != nil { - logging.ErrorPersist(fmt.Sprintf("Failed to generate title: %v", err)) - return + return err } - session, err := a.sessions.Get(ctx, sessionID) - if err != nil { - logging.ErrorPersist(fmt.Sprintf("Failed to get session: %v", err)) - return + title := strings.TrimSpace(strings.ReplaceAll(response.Content, "\n", " ")) + if title == "" { + return nil } - if response.Content != "" { - session.Title = strings.TrimSpace(response.Content) - session.Title = strings.ReplaceAll(session.Title, "\n", " ") - if _, err := a.sessions.Save(ctx, session); err != nil { - logging.ErrorPersist(fmt.Sprintf("Failed to save session title: %v", err)) - } + session.Title = title + _, err = a.sessions.Save(ctx, session) + return err +} + +func (a *agent) err(err error) AgentEvent { + return AgentEvent{ + err: err, } } -// TrackUsage updates token usage statistics for the session -func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error { - session, err := a.sessions.Get(ctx, sessionID) - if err != nil { - return fmt.Errorf("failed to get session: %w", err) +func (a *agent) Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error) { + events := make(chan AgentEvent) + if a.IsSessionBusy(sessionID) { + return nil, ErrSessionBusy } - cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) + - model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) + - model.CostPer1MIn/1e6*float64(usage.InputTokens) + - model.CostPer1MOut/1e6*float64(usage.OutputTokens) + genCtx, cancel := context.WithCancel(ctx) + + a.activeRequests.Store(sessionID, cancel) + go func() { + logging.Debug("Request started", "sessionID", sessionID) + defer logging.RecoverPanic("agent.Run", func() { + events <- a.err(fmt.Errorf("panic while running the agent")) + }) - session.Cost += cost - session.CompletionTokens += usage.OutputTokens - session.PromptTokens += usage.InputTokens + result := a.processGeneration(genCtx, sessionID, content) + if result.Err() != nil && !errors.Is(result.Err(), ErrRequestCancelled) && !errors.Is(result.Err(), context.Canceled) { + logging.ErrorPersist(fmt.Sprintf("Generation error for session %s: %v", sessionID, result)) + } + logging.Debug("Request completed", "sessionID", sessionID) + a.activeRequests.Delete(sessionID) + cancel() + events <- result + close(events) + }() + return events, nil +} - _, err = a.sessions.Save(ctx, session) +func (a *agent) processGeneration(ctx context.Context, sessionID, content string) AgentEvent { + // List existing messages; if none, start title generation asynchronously. + msgs, err := a.messages.List(ctx, sessionID) if err != nil { - return fmt.Errorf("failed to save session: %w", err) + return a.err(fmt.Errorf("failed to list messages: %w", err)) + } + if len(msgs) == 0 { + go func() { + defer logging.RecoverPanic("agent.Run", func() { + logging.ErrorPersist("panic while generating title") + }) + titleErr := a.generateTitle(context.Background(), sessionID, content) + if titleErr != nil { + logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr)) + } + }() } - return nil -} -// processEvent handles different types of events during generation -func (a *agent) processEvent( - ctx context.Context, - sessionID string, - assistantMsg *message.Message, - event provider.ProviderEvent, -) error { - select { - case <-ctx.Done(): - return ctx.Err() - default: - // Continue processing + userMsg, err := a.createUserMessage(ctx, sessionID, content) + if err != nil { + return a.err(fmt.Errorf("failed to create user message: %w", err)) } - switch event.Type { - case provider.EventThinkingDelta: - assistantMsg.AppendReasoningContent(event.Content) - return a.messages.Update(ctx, *assistantMsg) - case provider.EventContentDelta: - assistantMsg.AppendContent(event.Content) - return a.messages.Update(ctx, *assistantMsg) - case provider.EventError: - if errors.Is(event.Error, context.Canceled) { - logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID)) - return context.Canceled + // Append the new user message to the conversation history. + msgHistory := append(msgs, userMsg) + for { + // Check for cancellation before each iteration + select { + case <-ctx.Done(): + return a.err(ctx.Err()) + default: + // Continue processing } - logging.ErrorPersist(event.Error.Error()) - return event.Error - case provider.EventWarning: - logging.WarnPersist(event.Info) - case provider.EventInfo: - logging.InfoPersist(event.Info) - case provider.EventComplete: - assistantMsg.SetToolCalls(event.Response.ToolCalls) - assistantMsg.AddFinish(event.Response.FinishReason) - if err := a.messages.Update(ctx, *assistantMsg); err != nil { - return fmt.Errorf("failed to update message: %w", err) + agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory) + if err != nil { + if errors.Is(err, context.Canceled) { + return a.err(ErrRequestCancelled) + } + return a.err(fmt.Errorf("failed to process events: %w", err)) + } + logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults) + if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil { + // We are not done, we need to respond with the tool response + msgHistory = append(msgHistory, agentMessage, *toolResults) + continue + } + return AgentEvent{ + message: agentMessage, } - return a.TrackUsage(ctx, sessionID, a.model, event.Response.Usage) } +} - return nil +func (a *agent) createUserMessage(ctx context.Context, sessionID, content string) (message.Message, error) { + return a.messages.Create(ctx, sessionID, message.CreateMessageParams{ + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: content}, + }, + }) } -// ExecuteTools runs all tool calls sequentially and returns the results -func (a *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) { - toolResults := make([]message.ToolResult, len(toolCalls)) +func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) { + eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools) + + assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{ + Role: message.Assistant, + Parts: []message.ContentPart{}, + Model: a.provider.Model().ID, + }) + if err != nil { + return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err) + } - // Create a child context that can be canceled - ctx, cancel := context.WithCancel(ctx) - defer cancel() + // Add the session and message ID into the context if needed by tools. + ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID) + ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID) - // Check if already canceled before starting any execution - if ctx.Err() != nil { - // Mark all tools as canceled - for i, toolCall := range toolCalls { - toolResults[i] = message.ToolResult{ - ToolCallID: toolCall.ID, - Content: "Tool execution canceled by user", - IsError: true, - } + // Process each event in the stream. + for event := range eventChan { + if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil { + a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled) + return assistantMsg, nil, processErr + } + if ctx.Err() != nil { + a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled) + return assistantMsg, nil, ctx.Err() } - return toolResults, ctx.Err() } + toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls())) + toolCalls := assistantMsg.ToolCalls() for i, toolCall := range toolCalls { - // Check for cancellation before executing each tool select { case <-ctx.Done(): - // Mark this and all remaining tools as canceled + a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled) + // Make all future tool calls cancelled for j := i; j < len(toolCalls); j++ { toolResults[j] = message.ToolResult{ ToolCallID: toolCalls[j].ID, @@ -276,412 +273,180 @@ func (a *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, IsError: true, } } - return toolResults, ctx.Err() + goto out default: // Continue processing - } - - response := "" - isError := false - found := false - - // Find and execute the appropriate tool - for _, tool := range tls { - if tool.Info().Name == toolCall.Name { - found = true - toolResult, toolErr := tool.Run(ctx, tools.ToolCall{ - ID: toolCall.ID, - Name: toolCall.Name, - Input: toolCall.Input, - }) - - if toolErr != nil { - if errors.Is(toolErr, context.Canceled) { - response = "Tool execution canceled by user" - } else { - response = fmt.Sprintf("Error running tool: %s", toolErr) - } - isError = true - } else { - response = toolResult.Content - isError = toolResult.IsError + var tool tools.BaseTool + for _, availableTools := range a.tools { + if availableTools.Info().Name == toolCall.Name { + tool = availableTools } - break } - } - - if !found { - response = fmt.Sprintf("Tool not found: %s", toolCall.Name) - isError = true - } - - toolResults[i] = message.ToolResult{ - ToolCallID: toolCall.ID, - Content: response, - IsError: isError, - } - } - return toolResults, nil -} - -// handleToolExecution processes tool calls and creates tool result messages -func (a *agent) handleToolExecution( - ctx context.Context, - assistantMsg message.Message, -) (*message.Message, error) { - select { - case <-ctx.Done(): - // If cancelled, create tool results that indicate cancellation - if len(assistantMsg.ToolCalls()) > 0 { - toolResults := make([]message.ToolResult, 0, len(assistantMsg.ToolCalls())) - for _, tc := range assistantMsg.ToolCalls() { - toolResults = append(toolResults, message.ToolResult{ - ToolCallID: tc.ID, - Content: "Tool execution canceled by user", + // Tool not found + if tool == nil { + toolResults[i] = message.ToolResult{ + ToolCallID: toolCall.ID, + Content: fmt.Sprintf("Tool not found: %s", toolCall.Name), IsError: true, - }) + } + continue } - // Use background context to ensure the message is created even if original context is cancelled - bgCtx := context.Background() - parts := make([]message.ContentPart, 0) - for _, toolResult := range toolResults { - parts = append(parts, toolResult) - } - msg, err := a.messages.Create(bgCtx, assistantMsg.SessionID, message.CreateMessageParams{ - Role: message.Tool, - Parts: parts, + toolResult, toolErr := tool.Run(ctx, tools.ToolCall{ + ID: toolCall.ID, + Name: toolCall.Name, + Input: toolCall.Input, }) - if err != nil { - return nil, fmt.Errorf("failed to create cancelled tool message: %w", err) - } - return &msg, ctx.Err() - } - return nil, ctx.Err() - default: - // Continue processing - } - - if len(assistantMsg.ToolCalls()) == 0 { - return nil, nil - } - - toolResults, err := a.ExecuteTools(ctx, assistantMsg.ToolCalls(), a.tools) - if err != nil { - // If error is from cancellation, still return the partial results we have - if errors.Is(err, context.Canceled) { - // Use background context to ensure the message is created even if original context is cancelled - bgCtx := context.Background() - parts := make([]message.ContentPart, 0) - for _, toolResult := range toolResults { - parts = append(parts, toolResult) + if toolErr != nil { + if errors.Is(toolErr, permission.ErrorPermissionDenied) { + toolResults[i] = message.ToolResult{ + ToolCallID: toolCall.ID, + Content: "Permission denied", + IsError: true, + } + for j := i + 1; j < len(toolCalls); j++ { + toolResults[j] = message.ToolResult{ + ToolCallID: toolCalls[j].ID, + Content: "Tool execution canceled by user", + IsError: true, + } + } + a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied) + } else { + toolResults[i] = message.ToolResult{ + ToolCallID: toolCall.ID, + Content: toolErr.Error(), + IsError: true, + } + for j := i; j < len(toolCalls); j++ { + toolResults[j] = message.ToolResult{ + ToolCallID: toolCalls[j].ID, + Content: "Previous tool failed", + IsError: true, + } + } + a.finishMessage(ctx, &assistantMsg, message.FinishReasonError) + } + // If permission is denied or an error happens we cancel all the following tools + break } - - msg, createErr := a.messages.Create(bgCtx, assistantMsg.SessionID, message.CreateMessageParams{ - Role: message.Tool, - Parts: parts, - }) - if createErr != nil { - logging.ErrorPersist(fmt.Sprintf("Failed to create tool message after cancellation: %v", createErr)) - return nil, err + toolResults[i] = message.ToolResult{ + ToolCallID: toolCall.ID, + Content: toolResult.Content, + Metadata: toolResult.Metadata, + IsError: toolResult.IsError, } - return &msg, err } - return nil, err } - - parts := make([]message.ContentPart, 0, len(toolResults)) - for _, toolResult := range toolResults { - parts = append(parts, toolResult) +out: + if len(toolResults) == 0 { + return assistantMsg, nil, nil } - - msg, err := a.messages.Create(ctx, assistantMsg.SessionID, message.CreateMessageParams{ + parts := make([]message.ContentPart, 0) + for _, tr := range toolResults { + parts = append(parts, tr) + } + msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{ Role: message.Tool, Parts: parts, }) if err != nil { - return nil, fmt.Errorf("failed to create tool message: %w", err) + return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err) } - return &msg, nil + return assistantMsg, &msg, err } -// generate handles the main generation workflow -func (a *agent) generate(ctx context.Context, sessionID string, content string) error { - ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID) +func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) { + msg.AddFinish(finishReson) + _ = a.messages.Update(ctx, *msg) +} - // Handle context cancellation at any point - if err := ctx.Err(); err != nil { - return ErrRequestCancelled +func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + // Continue processing. } - messages, err := a.messages.List(ctx, sessionID) - if err != nil { - return fmt.Errorf("failed to list messages: %w", err) + switch event.Type { + case provider.EventThinkingDelta: + assistantMsg.AppendReasoningContent(event.Content) + return a.messages.Update(ctx, *assistantMsg) + case provider.EventContentDelta: + assistantMsg.AppendContent(event.Content) + return a.messages.Update(ctx, *assistantMsg) + case provider.EventError: + if errors.Is(event.Error, context.Canceled) { + logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID)) + return context.Canceled + } + logging.ErrorPersist(event.Error.Error()) + return event.Error + case provider.EventComplete: + assistantMsg.SetToolCalls(event.Response.ToolCalls) + assistantMsg.AddFinish(event.Response.FinishReason) + if err := a.messages.Update(ctx, *assistantMsg); err != nil { + return fmt.Errorf("failed to update message: %w", err) + } + return a.TrackUsage(ctx, sessionID, a.provider.Model(), event.Response.Usage) } - if len(messages) == 0 { - titleCtx := context.Background() - go a.handleTitleGeneration(titleCtx, sessionID, content) - } + return nil +} - userMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{ - Role: message.User, - Parts: []message.ContentPart{ - message.TextContent{ - Text: content, - }, - }, - }) +func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error { + sess, err := a.sessions.Get(ctx, sessionID) if err != nil { - return fmt.Errorf("failed to create user message: %w", err) + return fmt.Errorf("failed to get session: %w", err) } - messages = append(messages, userMsg) - - for { - // Check for cancellation before each iteration - select { - case <-ctx.Done(): - return ErrRequestCancelled - default: - // Continue processing - } - - eventChan, err := a.agent.StreamResponse(ctx, messages, a.tools) - if err != nil { - if errors.Is(err, context.Canceled) { - return ErrRequestCancelled - } - return fmt.Errorf("failed to stream response: %w", err) - } - - assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{ - Role: message.Assistant, - Parts: []message.ContentPart{}, - Model: a.model.ID, - }) - if err != nil { - return fmt.Errorf("failed to create assistant message: %w", err) - } - - ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID) - - // Process events from the LLM provider - for event := range eventChan { - if err := a.processEvent(ctx, sessionID, &assistantMsg, event); err != nil { - if errors.Is(err, context.Canceled) { - // Mark as canceled but don't create separate message - assistantMsg.AddFinish("canceled") - _ = a.messages.Update(context.Background(), assistantMsg) - return ErrRequestCancelled - } - assistantMsg.AddFinish("error:" + err.Error()) - _ = a.messages.Update(ctx, assistantMsg) - return fmt.Errorf("event processing error: %w", err) - } - - // Check for cancellation during event processing - select { - case <-ctx.Done(): - // Mark as canceled - assistantMsg.AddFinish("canceled") - _ = a.messages.Update(context.Background(), assistantMsg) - return ErrRequestCancelled - default: - } - } - - // Check for cancellation before tool execution - select { - case <-ctx.Done(): - assistantMsg.AddFinish("canceled_by_user") - _ = a.messages.Update(context.Background(), assistantMsg) - return ErrRequestCancelled - default: - } - - // Execute any tool calls - toolMsg, err := a.handleToolExecution(ctx, assistantMsg) - if err != nil { - if errors.Is(err, context.Canceled) { - assistantMsg.AddFinish("canceled_by_user") - _ = a.messages.Update(context.Background(), assistantMsg) - return ErrRequestCancelled - } - return fmt.Errorf("tool execution error: %w", err) - } - - if err := a.messages.Update(ctx, assistantMsg); err != nil { - return fmt.Errorf("failed to update assistant message: %w", err) - } - - // If no tool calls, we're done - if len(assistantMsg.ToolCalls()) == 0 { - break - } + cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) + + model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) + + model.CostPer1MIn/1e6*float64(usage.InputTokens) + + model.CostPer1MOut/1e6*float64(usage.OutputTokens) - // Add messages for next iteration - messages = append(messages, assistantMsg) - if toolMsg != nil { - messages = append(messages, *toolMsg) - } + sess.Cost += cost + sess.CompletionTokens += usage.OutputTokens + sess.PromptTokens += usage.InputTokens - // Check for cancellation after tool execution - select { - case <-ctx.Done(): - return ErrRequestCancelled - default: - } + _, err = a.sessions.Save(ctx, sess) + if err != nil { + return fmt.Errorf("failed to save session: %w", err) } - return nil } -// getAgentProviders initializes the LLM providers based on the chosen model -func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) { - maxTokens := config.Get().Model.CoderMaxTokens - - providerConfig, ok := config.Get().Providers[model.Provider] - if !ok || providerConfig.Disabled { - return nil, nil, ErrProviderNotEnabled +func createAgentProvider(agentName config.AgentName) (provider.Provider, error) { + cfg := config.Get() + agentConfig, ok := cfg.Agents[agentName] + if !ok { + return nil, fmt.Errorf("agent %s not found", agentName) + } + model, ok := models.SupportedModels[agentConfig.Model] + if !ok { + return nil, fmt.Errorf("model %s not supported", agentConfig.Model) } - var agentProvider provider.Provider - var titleGenerator provider.Provider - var err error - - switch model.Provider { - case models.ProviderOpenAI: - agentProvider, err = provider.NewOpenAIProvider( - provider.WithOpenAISystemMessage( - prompt.CoderOpenAISystemPrompt(), - ), - provider.WithOpenAIMaxTokens(maxTokens), - provider.WithOpenAIModel(model), - provider.WithOpenAIKey(providerConfig.APIKey), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create OpenAI agent provider: %w", err) - } - - titleGenerator, err = provider.NewOpenAIProvider( - provider.WithOpenAISystemMessage( - prompt.TitlePrompt(), - ), - provider.WithOpenAIMaxTokens(80), - provider.WithOpenAIModel(model), - provider.WithOpenAIKey(providerConfig.APIKey), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create OpenAI title generator: %w", err) - } - - case models.ProviderAnthropic: - agentProvider, err = provider.NewAnthropicProvider( - provider.WithAnthropicSystemMessage( - prompt.CoderAnthropicSystemPrompt(), - ), - provider.WithAnthropicMaxTokens(maxTokens), - provider.WithAnthropicKey(providerConfig.APIKey), - provider.WithAnthropicModel(model), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create Anthropic agent provider: %w", err) - } - - titleGenerator, err = provider.NewAnthropicProvider( - provider.WithAnthropicSystemMessage( - prompt.TitlePrompt(), - ), - provider.WithAnthropicMaxTokens(80), - provider.WithAnthropicKey(providerConfig.APIKey), - provider.WithAnthropicModel(model), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create Anthropic title generator: %w", err) - } - - case models.ProviderGemini: - agentProvider, err = provider.NewGeminiProvider( - ctx, - provider.WithGeminiSystemMessage( - prompt.CoderOpenAISystemPrompt(), - ), - provider.WithGeminiMaxTokens(int32(maxTokens)), - provider.WithGeminiKey(providerConfig.APIKey), - provider.WithGeminiModel(model), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create Gemini agent provider: %w", err) - } - - titleGenerator, err = provider.NewGeminiProvider( - ctx, - provider.WithGeminiSystemMessage( - prompt.TitlePrompt(), - ), - provider.WithGeminiMaxTokens(80), - provider.WithGeminiKey(providerConfig.APIKey), - provider.WithGeminiModel(model), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create Gemini title generator: %w", err) - } - - case models.ProviderGROQ: - agentProvider, err = provider.NewOpenAIProvider( - provider.WithOpenAISystemMessage( - prompt.CoderAnthropicSystemPrompt(), - ), - provider.WithOpenAIMaxTokens(maxTokens), - provider.WithOpenAIModel(model), - provider.WithOpenAIKey(providerConfig.APIKey), - provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create GROQ agent provider: %w", err) - } - - titleGenerator, err = provider.NewOpenAIProvider( - provider.WithOpenAISystemMessage( - prompt.TitlePrompt(), - ), - provider.WithOpenAIMaxTokens(80), - provider.WithOpenAIModel(model), - provider.WithOpenAIKey(providerConfig.APIKey), - provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create GROQ title generator: %w", err) - } - - case models.ProviderBedrock: - agentProvider, err = provider.NewBedrockProvider( - provider.WithBedrockSystemMessage( - prompt.CoderAnthropicSystemPrompt(), - ), - provider.WithBedrockMaxTokens(maxTokens), - provider.WithBedrockModel(model), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create Bedrock agent provider: %w", err) - } - - titleGenerator, err = provider.NewBedrockProvider( - provider.WithBedrockSystemMessage( - prompt.TitlePrompt(), - ), - provider.WithBedrockMaxTokens(80), - provider.WithBedrockModel(model), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create Bedrock title generator: %w", err) - } - default: - return nil, nil, fmt.Errorf("unsupported provider: %s", model.Provider) + providerCfg, ok := cfg.Providers[model.Provider] + if !ok { + return nil, fmt.Errorf("provider %s not supported", model.Provider) + } + if providerCfg.Disabled { + return nil, fmt.Errorf("provider %s is not enabled", model.Provider) + } + agentProvider, err := provider.NewProvider( + model.Provider, + provider.WithAPIKey(providerCfg.APIKey), + provider.WithModel(model), + provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)), + provider.WithMaxTokens(agentConfig.MaxTokens), + ) + if err != nil { + return nil, fmt.Errorf("could not create provider: %v", err) } - return agentProvider, titleGenerator, nil + return agentProvider, nil } diff --git a/internal/llm/agent/coder.go b/internal/llm/agent/coder.go deleted file mode 100644 index a3db6b55c..000000000 --- a/internal/llm/agent/coder.go +++ /dev/null @@ -1,63 +0,0 @@ -package agent - -import ( - "context" - "errors" - - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/models" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/message" - "github.com/kujtimiihoxha/termai/internal/permission" - "github.com/kujtimiihoxha/termai/internal/session" -) - -type coderAgent struct { - Service -} - -func NewCoderAgent( - permissions permission.Service, - sessions session.Service, - messages message.Service, - lspClients map[string]*lsp.Client, -) (Service, error) { - model, ok := models.SupportedModels[config.Get().Model.Coder] - if !ok { - return nil, errors.New("model not supported") - } - - ctx := context.Background() - otherTools := GetMcpTools(ctx, permissions) - if len(lspClients) > 0 { - otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients)) - } - agent, err := NewAgent( - ctx, - sessions, - messages, - model, - append( - []tools.BaseTool{ - tools.NewBashTool(permissions), - tools.NewEditTool(lspClients, permissions), - tools.NewFetchTool(permissions), - tools.NewGlobTool(), - tools.NewGrepTool(), - tools.NewLsTool(), - tools.NewSourcegraphTool(), - tools.NewViewTool(lspClients), - tools.NewWriteTool(lspClients, permissions), - NewAgentTool(sessions, messages, lspClients), - }, otherTools..., - ), - ) - if err != nil { - return nil, err - } - - return &coderAgent{ - agent, - }, nil -} diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go index b1c97b512..c7ea4916c 100644 --- a/internal/llm/agent/mcp-tools.go +++ b/internal/llm/agent/mcp-tools.go @@ -46,7 +46,7 @@ func runTool(ctx context.Context, c MCPClient, toolName string, input string) (t initRequest := mcp.InitializeRequest{} initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION initRequest.Params.ClientInfo = mcp.Implementation{ - Name: "termai", + Name: "OpenCode", Version: version.Version, } @@ -135,7 +135,7 @@ func getTools(ctx context.Context, name string, m config.MCPServer, permissions initRequest := mcp.InitializeRequest{} initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION initRequest.Params.ClientInfo = mcp.Implementation{ - Name: "termai", + Name: "OpenCode", Version: version.Version, } diff --git a/internal/llm/agent/task.go b/internal/llm/agent/task.go deleted file mode 100644 index fca1f223f..000000000 --- a/internal/llm/agent/task.go +++ /dev/null @@ -1,47 +0,0 @@ -package agent - -import ( - "context" - "errors" - - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/models" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/message" - "github.com/kujtimiihoxha/termai/internal/session" -) - -type taskAgent struct { - Service -} - -func NewTaskAgent(messages message.Service, sessions session.Service, lspClients map[string]*lsp.Client) (Service, error) { - model, ok := models.SupportedModels[config.Get().Model.Coder] - if !ok { - return nil, errors.New("model not supported") - } - - ctx := context.Background() - - agent, err := NewAgent( - ctx, - sessions, - messages, - model, - []tools.BaseTool{ - tools.NewGlobTool(), - tools.NewGrepTool(), - tools.NewLsTool(), - tools.NewSourcegraphTool(), - tools.NewViewTool(lspClients), - }, - ) - if err != nil { - return nil, err - } - - return &taskAgent{ - agent, - }, nil -} diff --git a/internal/llm/agent/tools.go b/internal/llm/agent/tools.go new file mode 100644 index 000000000..a37f1d65d --- /dev/null +++ b/internal/llm/agent/tools.go @@ -0,0 +1,50 @@ +package agent + +import ( + "context" + + "github.com/kujtimiihoxha/termai/internal/history" + "github.com/kujtimiihoxha/termai/internal/llm/tools" + "github.com/kujtimiihoxha/termai/internal/lsp" + "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/termai/internal/permission" + "github.com/kujtimiihoxha/termai/internal/session" +) + +func CoderAgentTools( + permissions permission.Service, + sessions session.Service, + messages message.Service, + history history.Service, + lspClients map[string]*lsp.Client, +) []tools.BaseTool { + ctx := context.Background() + otherTools := GetMcpTools(ctx, permissions) + if len(lspClients) > 0 { + otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients)) + } + return append( + []tools.BaseTool{ + tools.NewBashTool(permissions), + tools.NewEditTool(lspClients, permissions, history), + tools.NewFetchTool(permissions), + tools.NewGlobTool(), + tools.NewGrepTool(), + tools.NewLsTool(), + tools.NewSourcegraphTool(), + tools.NewViewTool(lspClients), + tools.NewWriteTool(lspClients, permissions, history), + NewAgentTool(sessions, messages, lspClients), + }, otherTools..., + ) +} + +func TaskAgentTools(lspClients map[string]*lsp.Client) []tools.BaseTool { + return []tools.BaseTool{ + tools.NewGlobTool(), + tools.NewGrepTool(), + tools.NewLsTool(), + tools.NewSourcegraphTool(), + tools.NewViewTool(lspClients), + } +} |
