diff options
| author | Kujtim Hoxha <[email protected]> | 2025-04-03 15:20:15 +0200 |
|---|---|---|
| committer | Kujtim Hoxha <[email protected]> | 2025-04-03 17:23:41 +0200 |
| commit | cfdd687216799cb5b47f099f1e7cd5dd16b3bdd0 (patch) | |
| tree | a822bfde1463a7080c0ea06dd17796d7a1617d3d /internal/llm | |
| parent | afd9ad0560d76c2a6d161dad52553b10ff428905 (diff) | |
| download | opencode-cfdd687216799cb5b47f099f1e7cd5dd16b3bdd0.tar.gz opencode-cfdd687216799cb5b47f099f1e7cd5dd16b3bdd0.zip | |
add initial lsp support
Diffstat (limited to 'internal/llm')
| -rw-r--r-- | internal/llm/agent/agent-tool.go | 2 | ||||
| -rw-r--r-- | internal/llm/agent/agent.go | 49 | ||||
| -rw-r--r-- | internal/llm/agent/coder.go | 13 | ||||
| -rw-r--r-- | internal/llm/agent/task.go | 2 | ||||
| -rw-r--r-- | internal/llm/prompt/coder.go | 23 | ||||
| -rw-r--r-- | internal/llm/provider/anthropic.go | 21 | ||||
| -rw-r--r-- | internal/llm/provider/gemini.go | 256 | ||||
| -rw-r--r-- | internal/llm/provider/openai.go | 15 | ||||
| -rw-r--r-- | internal/llm/provider/provider.go | 7 | ||||
| -rw-r--r-- | internal/llm/tools/diagnostics.go | 229 | ||||
| -rw-r--r-- | internal/llm/tools/edit.go | 21 | ||||
| -rw-r--r-- | internal/llm/tools/shell/shell.go | 2 | ||||
| -rw-r--r-- | internal/llm/tools/view.go | 26 | ||||
| -rw-r--r-- | internal/llm/tools/write.go | 18 | ||||
| -rw-r--r-- | internal/llm/tools/write_test.go | 76 |
15 files changed, 511 insertions, 249 deletions
diff --git a/internal/llm/agent/agent-tool.go b/internal/llm/agent/agent-tool.go index a87f48339..bf5e31f8f 100644 --- a/internal/llm/agent/agent-tool.go +++ b/internal/llm/agent/agent-tool.go @@ -91,7 +91,7 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes if err != nil { return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil } - return tools.NewTextResponse(response.Content), nil + return tools.NewTextResponse(response.Content().String()), nil } func NewAgentTool(parentSessionID string, app *app.App) tools.BaseTool { diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 06dbca4e8..cb123e78c 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "log" + "strings" "sync" "github.com/kujtimiihoxha/termai/internal/app" @@ -33,8 +34,12 @@ func (c *agent) handleTitleGeneration(sessionID, content string) { c.Context, []message.Message{ { - Role: message.User, - Content: content, + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{ + Text: content, + }, + }, }, }, nil, @@ -49,6 +54,8 @@ func (c *agent) handleTitleGeneration(sessionID, content string) { } if response.Content != "" { session.Title = response.Content + session.Title = strings.TrimSpace(session.Title) + session.Title = strings.ReplaceAll(session.Title, "\n", " ") c.Sessions.Save(session) } } @@ -79,17 +86,18 @@ func (c *agent) processEvent( ) error { switch event.Type { case provider.EventThinkingDelta: - assistantMsg.Thinking += event.Thinking + assistantMsg.AppendReasoningContent(event.Content) return c.Messages.Update(*assistantMsg) case provider.EventContentDelta: - assistantMsg.Content += event.Content + assistantMsg.AppendContent(event.Content) return c.Messages.Update(*assistantMsg) case provider.EventError: log.Println("error", event.Error) return event.Error case provider.EventComplete: - assistantMsg.ToolCalls = event.Response.ToolCalls + assistantMsg.SetToolCalls(event.Response.ToolCalls) + assistantMsg.AddFinish(event.Response.FinishReason) err := c.Messages.Update(*assistantMsg) if err != nil { return err @@ -157,18 +165,21 @@ func (c *agent) handleToolExecution( ctx context.Context, assistantMsg message.Message, ) (*message.Message, error) { - if len(assistantMsg.ToolCalls) == 0 { + if len(assistantMsg.ToolCalls()) == 0 { return nil, nil } - toolResults, err := c.ExecuteTools(ctx, assistantMsg.ToolCalls, c.tools) + toolResults, err := c.ExecuteTools(ctx, assistantMsg.ToolCalls(), c.tools) if err != nil { return nil, err } - + parts := make([]message.ContentPart, 0) + for _, toolResult := range toolResults { + parts = append(parts, toolResult) + } msg, err := c.Messages.Create(assistantMsg.SessionID, message.CreateMessageParams{ - Role: message.Tool, - ToolResults: toolResults, + Role: message.Tool, + Parts: parts, }) return &msg, err @@ -185,8 +196,12 @@ func (c *agent) generate(sessionID string, content string) error { } userMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{ - Role: message.User, - Content: content, + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{ + Text: content, + }, + }, }) if err != nil { return err @@ -201,8 +216,8 @@ func (c *agent) generate(sessionID string, content string) error { } assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{ - Role: message.Assistant, - Content: "", + Role: message.Assistant, + Parts: []message.ContentPart{}, }) if err != nil { return err @@ -210,20 +225,20 @@ func (c *agent) generate(sessionID string, content string) error { for event := range eventChan { err = c.processEvent(sessionID, &assistantMsg, event) if err != nil { - assistantMsg.Finished = true + assistantMsg.AddFinish("error:" + err.Error()) c.Messages.Update(assistantMsg) return err } } msg, err := c.handleToolExecution(c.Context, assistantMsg) - assistantMsg.Finished = true + c.Messages.Update(assistantMsg) if err != nil { return err } - if len(assistantMsg.ToolCalls) == 0 { + if len(assistantMsg.ToolCalls()) == 0 { break } diff --git a/internal/llm/agent/coder.go b/internal/llm/agent/coder.go index 38dfd2de1..d167ede99 100644 --- a/internal/llm/agent/coder.go +++ b/internal/llm/agent/coder.go @@ -44,20 +44,23 @@ func NewCoderAgent(app *app.App) (Agent, error) { return nil, err } - mcpTools := GetMcpTools(app.Context) + otherTools := GetMcpTools(app.Context) + if len(app.LSPClients) > 0 { + otherTools = append(otherTools, tools.NewDiagnosticsTool(app.LSPClients)) + } return &coderAgent{ agent: &agent{ App: app, tools: append( []tools.BaseTool{ tools.NewBashTool(), - tools.NewEditTool(), + tools.NewEditTool(app.LSPClients), tools.NewGlobTool(), tools.NewGrepTool(), tools.NewLsTool(), - tools.NewViewTool(), - tools.NewWriteTool(), - }, mcpTools..., + tools.NewViewTool(app.LSPClients), + tools.NewWriteTool(app.LSPClients), + }, otherTools..., ), model: model, agent: agentProvider, diff --git a/internal/llm/agent/task.go b/internal/llm/agent/task.go index 9e0311ed1..97611e62b 100644 --- a/internal/llm/agent/task.go +++ b/internal/llm/agent/task.go @@ -34,7 +34,7 @@ func NewTaskAgent(app *app.App) (Agent, error) { tools.NewGlobTool(), tools.NewGrepTool(), tools.NewLsTool(), - tools.NewViewTool(), + tools.NewViewTool(app.LSPClients), }, model: model, agent: agentProvider, diff --git a/internal/llm/prompt/coder.go b/internal/llm/prompt/coder.go index cd71bc005..27bb7e431 100644 --- a/internal/llm/prompt/coder.go +++ b/internal/llm/prompt/coder.go @@ -67,7 +67,7 @@ Never commit changes unless the user explicitly asks you to.` envInfo := getEnvironmentInfo() - return fmt.Sprintf("%s\n\n%s", basePrompt, envInfo) + return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation()) } func CoderAnthropicSystemPrompt() string { @@ -168,7 +168,7 @@ You MUST answer concisely with fewer than 4 lines of text (not including tool us envInfo := getEnvironmentInfo() - return fmt.Sprintf("%s\n\n%s", basePrompt, envInfo) + return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation()) } func getEnvironmentInfo() string { @@ -198,6 +198,25 @@ func isGitRepo(dir string) bool { return err == nil } +func lspInformation() string { + cfg := config.Get() + hasLSP := false + for _, v := range cfg.LSP { + if !v.Disabled { + hasLSP = true + break + } + } + if !hasLSP { + return "" + } + return `# LSP Information +Tools that support it will also include useful diagnostics such as linting and typechecking. +These diagnostics will be automatically enabled when you run the tool, and will be displayed in the output at the bottom within the <file_diagnostics></file_diagnostics> and <project_diagnostics></project_diagnostics> tags. +Take necessary actions to fix the issues. +` +} + func boolToYesNo(b bool) string { if b { return "Yes" diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go index 63a68b92b..2b960ebca 100644 --- a/internal/llm/provider/anthropic.go +++ b/internal/llm/provider/anthropic.go @@ -111,7 +111,7 @@ func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []messa var thinkingParam anthropic.ThinkingConfigParamUnion lastMessage := messages[len(messages)-1] temperature := anthropic.Float(0) - if lastMessage.Role == message.User && strings.Contains(strings.ToLower(lastMessage.Content), "think") { + if lastMessage.Role == message.User && strings.Contains(strings.ToLower(lastMessage.Content().String()), "think") { thinkingParam = anthropic.ThinkingConfigParamUnion{ OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{ BudgetTokens: int64(float64(a.maxTokens) * 0.8), @@ -187,9 +187,10 @@ func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []messa eventChan <- ProviderEvent{ Type: EventComplete, Response: &ProviderResponse{ - Content: content, - ToolCalls: toolCalls, - Usage: tokenUsage, + Content: content, + ToolCalls: toolCalls, + Usage: tokenUsage, + FinishReason: string(accumulatedMessage.StopReason), }, } } @@ -263,7 +264,7 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag for i, msg := range messages { switch msg.Role { case message.User: - content := anthropic.NewTextBlock(msg.Content) + content := anthropic.NewTextBlock(msg.Content().String()) if cachedBlocks < 2 { content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{ Type: "ephemeral", @@ -274,8 +275,8 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag case message.Assistant: blocks := []anthropic.ContentBlockParamUnion{} - if msg.Content != "" { - content := anthropic.NewTextBlock(msg.Content) + if msg.Content().String() != "" { + content := anthropic.NewTextBlock(msg.Content().String()) if cachedBlocks < 2 { content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{ Type: "ephemeral", @@ -285,7 +286,7 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag blocks = append(blocks, content) } - for _, toolCall := range msg.ToolCalls { + for _, toolCall := range msg.ToolCalls() { var inputMap map[string]any err := json.Unmarshal([]byte(toolCall.Input), &inputMap) if err != nil { @@ -297,8 +298,8 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag anthropicMessages[i] = anthropic.NewAssistantMessage(blocks...) case message.Tool: - results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults)) - for i, toolResult := range msg.ToolResults { + results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults())) + for i, toolResult := range msg.ToolResults() { results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError) } anthropicMessages[i] = anthropic.NewUserMessage(results...) diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go index 6b252b581..53ffa154e 100644 --- a/internal/llm/provider/gemini.go +++ b/internal/llm/provider/gemini.go @@ -78,7 +78,6 @@ func (p *geminiProvider) Close() { } } -// convertToGeminiHistory converts the message history to Gemini's format func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*genai.Content { var history []*genai.Content @@ -86,7 +85,7 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g switch msg.Role { case message.User: history = append(history, &genai.Content{ - Parts: []genai.Part{genai.Text(msg.Content)}, + Parts: []genai.Part{genai.Text(msg.Content().String())}, Role: "user", }) case message.Assistant: @@ -95,14 +94,12 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g Parts: []genai.Part{}, } - // Handle regular content - if msg.Content != "" { - content.Parts = append(content.Parts, genai.Text(msg.Content)) + if msg.Content().String() != "" { + content.Parts = append(content.Parts, genai.Text(msg.Content().String())) } - // Handle tool calls if any - if len(msg.ToolCalls) > 0 { - for _, call := range msg.ToolCalls { + if len(msg.ToolCalls()) > 0 { + for _, call := range msg.ToolCalls() { args, _ := parseJsonToMap(call.Input) content.Parts = append(content.Parts, genai.FunctionCall{ Name: call.Name, @@ -113,8 +110,7 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g history = append(history, content) case message.Tool: - for _, result := range msg.ToolResults { - // Parse response content to map if possible + for _, result := range msg.ToolResults() { response := map[string]interface{}{"result": result.Content} parsed, err := parseJsonToMap(result.Content) if err == nil { @@ -123,7 +119,7 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g var toolCall message.ToolCall for _, msg := range messages { if msg.Role == message.Assistant { - for _, call := range msg.ToolCalls { + for _, call := range msg.ToolCalls() { if call.ID == result.ToolCallID { toolCall = call break @@ -146,108 +142,6 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g return history } -// convertToolsToGeminiFunctionDeclarations converts tool definitions to Gemini's function declarations -func (p *geminiProvider) convertToolsToGeminiFunctionDeclarations(tools []tools.BaseTool) []*genai.FunctionDeclaration { - declarations := make([]*genai.FunctionDeclaration, len(tools)) - - for i, tool := range tools { - info := tool.Info() - - // Convert parameters to genai.Schema format - properties := make(map[string]*genai.Schema) - for name, param := range info.Parameters { - // Try to extract type and description from the parameter - paramMap, ok := param.(map[string]interface{}) - if !ok { - // Default to string if unable to determine type - properties[name] = &genai.Schema{Type: genai.TypeString} - continue - } - - schemaType := genai.TypeString // Default - var description string - var itemsTypeSchema *genai.Schema - if typeVal, found := paramMap["type"]; found { - if typeStr, ok := typeVal.(string); ok { - switch typeStr { - case "string": - schemaType = genai.TypeString - case "number": - schemaType = genai.TypeNumber - case "integer": - schemaType = genai.TypeInteger - case "boolean": - schemaType = genai.TypeBoolean - case "array": - schemaType = genai.TypeArray - items, found := paramMap["items"] - if found { - itemsMap, ok := items.(map[string]interface{}) - if ok { - itemsType, found := itemsMap["type"] - if found { - itemsTypeStr, ok := itemsType.(string) - if ok { - switch itemsTypeStr { - case "string": - itemsTypeSchema = &genai.Schema{ - Type: genai.TypeString, - } - case "number": - itemsTypeSchema = &genai.Schema{ - Type: genai.TypeNumber, - } - case "integer": - itemsTypeSchema = &genai.Schema{ - Type: genai.TypeInteger, - } - case "boolean": - itemsTypeSchema = &genai.Schema{ - Type: genai.TypeBoolean, - } - } - } - } - } - } - case "object": - schemaType = genai.TypeObject - if _, found := paramMap["properties"]; !found { - continue - } - // TODO: Add support for other types - } - } - } - - if desc, found := paramMap["description"]; found { - if descStr, ok := desc.(string); ok { - description = descStr - } - } - - properties[name] = &genai.Schema{ - Type: schemaType, - Description: description, - Items: itemsTypeSchema, - } - } - - declarations[i] = &genai.FunctionDeclaration{ - Name: info.Name, - Description: info.Description, - Parameters: &genai.Schema{ - Type: genai.TypeObject, - Properties: properties, - Required: info.Required, - }, - } - } - - return declarations -} - -// extractTokenUsage extracts token usage information from Gemini's response func (p *geminiProvider) extractTokenUsage(resp *genai.GenerateContentResponse) TokenUsage { if resp == nil || resp.UsageMetadata == nil { return TokenUsage{} @@ -261,41 +155,28 @@ func (p *geminiProvider) extractTokenUsage(resp *genai.GenerateContentResponse) } } -// SendMessages sends a batch of messages to Gemini and returns the response func (p *geminiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { - // Create a generative model model := p.client.GenerativeModel(p.model.APIModel) model.SetMaxOutputTokens(p.maxTokens) - // Set system instruction model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage)) - // Set up tools if provided if len(tools) > 0 { declarations := p.convertToolsToGeminiFunctionDeclarations(tools) - model.Tools = []*genai.Tool{{FunctionDeclarations: declarations}} + for _, declaration := range declarations { + model.Tools = append(model.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{declaration}}) + } } - // Create chat session and set history chat := model.StartChat() chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message - // Get the most recent user message - var lastUserMsg message.Message - for i := len(messages) - 1; i >= 0; i-- { - if messages[i].Role == message.User { - lastUserMsg = messages[i] - break - } - } - - // Send the message - resp, err := chat.SendMessage(ctx, genai.Text(lastUserMsg.Content)) + lastUserMsg := messages[len(messages)-1] + resp, err := chat.SendMessage(ctx, genai.Text(lastUserMsg.Content().String())) if err != nil { return nil, err } - // Process the response var content string var toolCalls []message.ToolCall @@ -317,7 +198,6 @@ func (p *geminiProvider) SendMessages(ctx context.Context, messages []message.Me } } - // Extract token usage tokenUsage := p.extractTokenUsage(resp) return &ProviderResponse{ @@ -327,16 +207,12 @@ func (p *geminiProvider) SendMessages(ctx context.Context, messages []message.Me }, nil } -// StreamResponse streams the response from Gemini func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) { - // Create a generative model model := p.client.GenerativeModel(p.model.APIModel) model.SetMaxOutputTokens(p.maxTokens) - // Set system instruction model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage)) - // Set up tools if provided if len(tools) > 0 { declarations := p.convertToolsToGeminiFunctionDeclarations(tools) for _, declaration := range declarations { @@ -344,14 +220,12 @@ func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message. } } - // Create chat session and set history chat := model.StartChat() chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message lastUserMsg := messages[len(messages)-1] - // Start streaming - iter := chat.SendMessageStream(ctx, genai.Text(lastUserMsg.Content)) + iter := chat.SendMessageStream(ctx, genai.Text(lastUserMsg.Content().String())) eventChan := make(chan ProviderEvent) @@ -392,7 +266,6 @@ func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message. } currentContent += newText case genai.FunctionCall: - // For function calls, we assume they come complete, not streamed in parts id := "call_" + uuid.New().String() args, _ := json.Marshal(p.Args) newCall := message.ToolCall{ @@ -402,7 +275,6 @@ func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message. Type: "function", } - // Check if this is a new tool call isNew := true for _, existing := range toolCalls { if existing.Name == newCall.Name && existing.Input == newCall.Input { @@ -419,15 +291,15 @@ func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message. } } - // Extract token usage from the final response tokenUsage := p.extractTokenUsage(finalResp) eventChan <- ProviderEvent{ Type: EventComplete, Response: &ProviderResponse{ - Content: currentContent, - ToolCalls: toolCalls, - Usage: tokenUsage, + Content: currentContent, + ToolCalls: toolCalls, + Usage: tokenUsage, + FinishReason: string(finalResp.Candidates[0].FinishReason.String()), }, } }() @@ -435,7 +307,99 @@ func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message. return eventChan, nil } -// Helper function to parse JSON string into map +func (p *geminiProvider) convertToolsToGeminiFunctionDeclarations(tools []tools.BaseTool) []*genai.FunctionDeclaration { + declarations := make([]*genai.FunctionDeclaration, len(tools)) + + for i, tool := range tools { + info := tool.Info() + declarations[i] = &genai.FunctionDeclaration{ + Name: info.Name, + Description: info.Description, + Parameters: &genai.Schema{ + Type: genai.TypeObject, + Properties: convertSchemaProperties(info.Parameters), + Required: info.Required, + }, + } + } + + return declarations +} + +func convertSchemaProperties(parameters map[string]interface{}) map[string]*genai.Schema { + properties := make(map[string]*genai.Schema) + + for name, param := range parameters { + properties[name] = convertToSchema(param) + } + + return properties +} + +func convertToSchema(param interface{}) *genai.Schema { + schema := &genai.Schema{Type: genai.TypeString} + + paramMap, ok := param.(map[string]interface{}) + if !ok { + return schema + } + + if desc, ok := paramMap["description"].(string); ok { + schema.Description = desc + } + + typeVal, hasType := paramMap["type"] + if !hasType { + return schema + } + + typeStr, ok := typeVal.(string) + if !ok { + return schema + } + + schema.Type = mapJSONTypeToGenAI(typeStr) + + switch typeStr { + case "array": + schema.Items = processArrayItems(paramMap) + case "object": + if props, ok := paramMap["properties"].(map[string]interface{}); ok { + schema.Properties = convertSchemaProperties(props) + } + } + + return schema +} + +func processArrayItems(paramMap map[string]interface{}) *genai.Schema { + items, ok := paramMap["items"].(map[string]interface{}) + if !ok { + return nil + } + + return convertToSchema(items) +} + +func mapJSONTypeToGenAI(jsonType string) genai.Type { + switch jsonType { + case "string": + return genai.TypeString + case "number": + return genai.TypeNumber + case "integer": + return genai.TypeInteger + case "boolean": + return genai.TypeBoolean + case "array": + return genai.TypeArray + case "object": + return genai.TypeObject + default: + return genai.TypeString // Default to string for unknown types + } +} + func parseJsonToMap(jsonStr string) (map[string]interface{}, error) { var result map[string]interface{} err := json.Unmarshal([]byte(jsonStr), &result) diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index d86a58690..c8e04d5ee 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -84,22 +84,22 @@ func (p *openaiProvider) convertToOpenAIMessages(messages []message.Message) []o for _, msg := range messages { switch msg.Role { case message.User: - chatMessages = append(chatMessages, openai.UserMessage(msg.Content)) + chatMessages = append(chatMessages, openai.UserMessage(msg.Content().String())) case message.Assistant: assistantMsg := openai.ChatCompletionAssistantMessageParam{ Role: "assistant", } - if msg.Content != "" { + if msg.Content().String() != "" { assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{ - OfString: openai.String(msg.Content), + OfString: openai.String(msg.Content().String()), } } - if len(msg.ToolCalls) > 0 { - assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls)) - for i, call := range msg.ToolCalls { + if len(msg.ToolCalls()) > 0 { + assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls())) + for i, call := range msg.ToolCalls() { assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{ ID: call.ID, Type: "function", @@ -116,7 +116,7 @@ func (p *openaiProvider) convertToOpenAIMessages(messages []message.Message) []o }) case message.Tool: - for _, result := range msg.ToolResults { + for _, result := range msg.ToolResults() { chatMessages = append(chatMessages, openai.ToolMessage(result.Content, result.ToolCallID), ) @@ -276,3 +276,4 @@ func (p *openaiProvider) StreamResponse(ctx context.Context, messages []message. return eventChan, nil } + diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index 9ac1def37..f40429738 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -27,9 +27,10 @@ type TokenUsage struct { } type ProviderResponse struct { - Content string - ToolCalls []message.ToolCall - Usage TokenUsage + Content string + ToolCalls []message.ToolCall + Usage TokenUsage + FinishReason string } type ProviderEvent struct { diff --git a/internal/llm/tools/diagnostics.go b/internal/llm/tools/diagnostics.go new file mode 100644 index 000000000..dc90e5860 --- /dev/null +++ b/internal/llm/tools/diagnostics.go @@ -0,0 +1,229 @@ +package tools + +import ( + "context" + "encoding/json" + "fmt" + "sort" + "strings" + "time" + + "github.com/kujtimiihoxha/termai/internal/lsp" + "github.com/kujtimiihoxha/termai/internal/lsp/protocol" +) + +type diagnosticsTool struct { + lspClients map[string]*lsp.Client +} + +const ( + DiagnosticsToolName = "diagnostics" +) + +type DiagnosticsParams struct { + FilePath string `json:"file_path"` +} + +func (b *diagnosticsTool) Info() ToolInfo { + return ToolInfo{ + Name: DiagnosticsToolName, + Description: "Get diagnostics for a file and/or project.", + Parameters: map[string]any{ + "file_path": map[string]any{ + "type": "string", + "description": "The path to the file to get diagnostics for (leave w empty for project diagnostics)", + }, + }, + Required: []string{}, + } +} + +func (b *diagnosticsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { + var params DiagnosticsParams + if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil { + return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil + } + + lsps := b.lspClients + + if len(lsps) == 0 { + return NewTextErrorResponse("no LSP clients available"), nil + } + + if params.FilePath == "" { + notifyLspOpenFile(ctx, params.FilePath, lsps) + } + + output := appendDiagnostics(params.FilePath, lsps) + + return NewTextResponse(output), nil +} + +func notifyLspOpenFile(ctx context.Context, filePath string, lsps map[string]*lsp.Client) { + for _, client := range lsps { + err := client.OpenFile(ctx, filePath) + if err != nil { + // Wait for the file to be opened and diagnostics to be received + // TODO: see if we can do this in a more efficient way + time.Sleep(2 * time.Second) + } + + } +} + +func appendDiagnostics(filePath string, lsps map[string]*lsp.Client) string { + fileDiagnostics := []string{} + projectDiagnostics := []string{} + + // Enhanced format function that includes more diagnostic information + formatDiagnostic := func(pth string, diagnostic protocol.Diagnostic, source string) string { + // Base components + severity := "Info" + switch diagnostic.Severity { + case protocol.SeverityError: + severity = "Error" + case protocol.SeverityWarning: + severity = "Warn" + case protocol.SeverityHint: + severity = "Hint" + } + + // Location information + location := fmt.Sprintf("%s:%d:%d", pth, diagnostic.Range.Start.Line+1, diagnostic.Range.Start.Character+1) + + // Source information (LSP name) + sourceInfo := "" + if diagnostic.Source != "" { + sourceInfo = diagnostic.Source + } else if source != "" { + sourceInfo = source + } + + // Code information + codeInfo := "" + if diagnostic.Code != nil { + codeInfo = fmt.Sprintf("[%v]", diagnostic.Code) + } + + // Tags information + tagsInfo := "" + if len(diagnostic.Tags) > 0 { + tags := []string{} + for _, tag := range diagnostic.Tags { + switch tag { + case protocol.Unnecessary: + tags = append(tags, "unnecessary") + case protocol.Deprecated: + tags = append(tags, "deprecated") + } + } + if len(tags) > 0 { + tagsInfo = fmt.Sprintf(" (%s)", strings.Join(tags, ", ")) + } + } + + // Assemble the full diagnostic message + return fmt.Sprintf("%s: %s [%s]%s%s %s", + severity, + location, + sourceInfo, + codeInfo, + tagsInfo, + diagnostic.Message) + } + + for lspName, client := range lsps { + diagnostics := client.GetDiagnostics() + if len(diagnostics) > 0 { + for location, diags := range diagnostics { + isCurrentFile := location.Path() == filePath + + // Group diagnostics by severity for better organization + for _, diag := range diags { + formattedDiag := formatDiagnostic(location.Path(), diag, lspName) + + if isCurrentFile { + fileDiagnostics = append(fileDiagnostics, formattedDiag) + } else { + projectDiagnostics = append(projectDiagnostics, formattedDiag) + } + } + } + } + } + + // Sort diagnostics by severity (errors first) and then by location + sort.Slice(fileDiagnostics, func(i, j int) bool { + iIsError := strings.HasPrefix(fileDiagnostics[i], "Error") + jIsError := strings.HasPrefix(fileDiagnostics[j], "Error") + if iIsError != jIsError { + return iIsError // Errors come first + } + return fileDiagnostics[i] < fileDiagnostics[j] // Then alphabetically + }) + + sort.Slice(projectDiagnostics, func(i, j int) bool { + iIsError := strings.HasPrefix(projectDiagnostics[i], "Error") + jIsError := strings.HasPrefix(projectDiagnostics[j], "Error") + if iIsError != jIsError { + return iIsError + } + return projectDiagnostics[i] < projectDiagnostics[j] + }) + + output := "" + + if len(fileDiagnostics) > 0 { + output += "\n<file_diagnostics>\n" + if len(fileDiagnostics) > 10 { + output += strings.Join(fileDiagnostics[:10], "\n") + output += fmt.Sprintf("\n... and %d more diagnostics", len(fileDiagnostics)-10) + } else { + output += strings.Join(fileDiagnostics, "\n") + } + output += "\n</file_diagnostics>\n" + } + + if len(projectDiagnostics) > 0 { + output += "\n<project_diagnostics>\n" + if len(projectDiagnostics) > 10 { + output += strings.Join(projectDiagnostics[:10], "\n") + output += fmt.Sprintf("\n... and %d more diagnostics", len(projectDiagnostics)-10) + } else { + output += strings.Join(projectDiagnostics, "\n") + } + output += "\n</project_diagnostics>\n" + } + + // Add summary counts + if len(fileDiagnostics) > 0 || len(projectDiagnostics) > 0 { + fileErrors := countSeverity(fileDiagnostics, "Error") + fileWarnings := countSeverity(fileDiagnostics, "Warn") + projectErrors := countSeverity(projectDiagnostics, "Error") + projectWarnings := countSeverity(projectDiagnostics, "Warn") + + output += "\n<diagnostic_summary>\n" + output += fmt.Sprintf("Current file: %d errors, %d warnings\n", fileErrors, fileWarnings) + output += fmt.Sprintf("Project: %d errors, %d warnings\n", projectErrors, projectWarnings) + output += "</diagnostic_summary>\n" + } + + return output +} + +// Helper function to count diagnostics by severity +func countSeverity(diagnostics []string, severity string) int { + count := 0 + for _, diag := range diagnostics { + if strings.HasPrefix(diag, severity) { + count++ + } + } + return count +} + +func NewDiagnosticsTool(lspClients map[string]*lsp.Client) BaseTool { + return &diagnosticsTool{ + lspClients, + } +} diff --git a/internal/llm/tools/edit.go b/internal/llm/tools/edit.go index 8c5427a58..c84bbd7a0 100644 --- a/internal/llm/tools/edit.go +++ b/internal/llm/tools/edit.go @@ -10,11 +10,14 @@ import ( "time" "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/termai/internal/lsp" "github.com/kujtimiihoxha/termai/internal/permission" "github.com/sergi/go-diff/diffmatchpatch" ) -type editTool struct{} +type editTool struct { + lspClients map[string]*lsp.Client +} const ( EditToolName = "edit" @@ -71,6 +74,7 @@ func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) params.FilePath = filepath.Join(wd, params.FilePath) } + notifyLspOpenFile(ctx, params.FilePath, e.lspClients) if params.OldString == "" { result, err := createNewFile(params.FilePath, params.NewString) if err != nil { @@ -91,6 +95,9 @@ func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) if err != nil { return NewTextErrorResponse(fmt.Sprintf("error replacing content: %s", err)), nil } + + result = fmt.Sprintf("<result>\n%s\n</result>\n", result) + result += appendDiagnostics(params.FilePath, e.lspClients) return NewTextResponse(result), nil } @@ -296,18 +303,18 @@ func GenerateDiff(oldContent, newContent string) string { switch diff.Type { case diffmatchpatch.DiffInsert: - for _, line := range strings.Split(text, "\n") { + for line := range strings.SplitSeq(text, "\n") { _, _ = buff.WriteString("+ " + line + "\n") } case diffmatchpatch.DiffDelete: - for _, line := range strings.Split(text, "\n") { + for line := range strings.SplitSeq(text, "\n") { _, _ = buff.WriteString("- " + line + "\n") } case diffmatchpatch.DiffEqual: if len(text) > 40 { _, _ = buff.WriteString(" " + text[:20] + "..." + text[len(text)-20:] + "\n") } else { - for _, line := range strings.Split(text, "\n") { + for line := range strings.SplitSeq(text, "\n") { _, _ = buff.WriteString(" " + line + "\n") } } @@ -366,6 +373,8 @@ When making edits: Remember: when making multiple file edits in a row to the same file, you should prefer to send all edits in a single message with multiple calls to this tool, rather than multiple messages with a single call each.` } -func NewEditTool() BaseTool { - return &editTool{} +func NewEditTool(lspClients map[string]*lsp.Client) BaseTool { + return &editTool{ + lspClients, + } } diff --git a/internal/llm/tools/shell/shell.go b/internal/llm/tools/shell/shell.go index d76cb1a2e..64592f67d 100644 --- a/internal/llm/tools/shell/shell.go +++ b/internal/llm/tools/shell/shell.go @@ -221,7 +221,7 @@ func (s *PersistentShell) killChildren() { return } - for _, pidStr := range strings.Split(string(output), "\n") { + for pidStr := range strings.SplitSeq(string(output), "\n") { if pidStr = strings.TrimSpace(pidStr); pidStr != "" { var pid int fmt.Sscanf(pidStr, "%d", &pid) diff --git a/internal/llm/tools/view.go b/internal/llm/tools/view.go index dca522b9c..743cef6f4 100644 --- a/internal/llm/tools/view.go +++ b/internal/llm/tools/view.go @@ -11,9 +11,12 @@ import ( "strings" "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/termai/internal/lsp" ) -type viewTool struct{} +type viewTool struct { + lspClients map[string]*lsp.Client +} const ( ViewToolName = "view" @@ -127,15 +130,18 @@ func (v *viewTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) return NewTextErrorResponse(fmt.Sprintf("Failed to read file: %s", err)), nil } + notifyLspOpenFile(ctx, filePath, v.lspClients) + output := "<file>\n" // Format the output with line numbers - output := addLineNumbers(content, params.Offset+1) + output += addLineNumbers(content, params.Offset+1) // Add a note if the content was truncated if lineCount > params.Offset+len(strings.Split(content, "\n")) { output += fmt.Sprintf("\n\n(File has more lines. Use 'offset' parameter to read beyond line %d)", params.Offset+len(strings.Split(content, "\n"))) } - + output += "\n</file>\n" + output += appendDiagnostics(filePath, v.lspClients) recordFileRead(filePath) return NewTextResponse(output), nil } @@ -155,10 +161,10 @@ func addLineNumbers(content string, startLine int) string { numStr := fmt.Sprintf("%d", lineNum) if len(numStr) >= 6 { - result = append(result, fmt.Sprintf("%s\t%s", numStr, line)) + result = append(result, fmt.Sprintf("%s|%s", numStr, line)) } else { paddedNum := fmt.Sprintf("%6s", numStr) - result = append(result, fmt.Sprintf("%s\t|%s", paddedNum, line)) + result = append(result, fmt.Sprintf("%s|%s", paddedNum, line)) } } @@ -173,8 +179,9 @@ func readTextFile(filePath string, offset, limit int) (string, int, error) { defer file.Close() lineCount := 0 + + scanner := NewLineScanner(file) if offset > 0 { - scanner := NewLineScanner(file) for lineCount < offset && scanner.Scan() { lineCount++ } @@ -192,7 +199,6 @@ func readTextFile(filePath string, offset, limit int) (string, int, error) { var lines []string lineCount = offset - scanner := NewLineScanner(file) for scanner.Scan() && len(lines) < limit { lineCount++ @@ -290,6 +296,8 @@ TIPS: - When viewing large files, use the offset parameter to read specific sections` } -func NewViewTool() BaseTool { - return &viewTool{} +func NewViewTool(lspClients map[string]*lsp.Client) BaseTool { + return &viewTool{ + lspClients, + } } diff --git a/internal/llm/tools/write.go b/internal/llm/tools/write.go index 003753d08..3d66d64e2 100644 --- a/internal/llm/tools/write.go +++ b/internal/llm/tools/write.go @@ -9,10 +9,13 @@ import ( "time" "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/termai/internal/lsp" "github.com/kujtimiihoxha/termai/internal/permission" ) -type writeTool struct{} +type writeTool struct { + lspClients map[string]*lsp.Client +} const ( WriteToolName = "write" @@ -96,6 +99,8 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error if err = os.MkdirAll(dir, 0o755); err != nil { return NewTextErrorResponse(fmt.Sprintf("Failed to create parent directories: %s", err)), nil } + + notifyLspOpenFile(ctx, filePath, w.lspClients) p := permission.Default.Request( permission.CreatePermissionRequest{ Path: filePath, @@ -122,7 +127,10 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error recordFileWrite(filePath) recordFileRead(filePath) - return NewTextResponse(fmt.Sprintf("File successfully written: %s", filePath)), nil + result := fmt.Sprintf("File successfully written: %s", filePath) + result = fmt.Sprintf("<result>\n%s\n</result>", result) + result += appendDiagnostics(filePath, w.lspClients) + return NewTextResponse(result), nil } func writeDescription() string { @@ -156,6 +164,8 @@ TIPS: - Always include descriptive comments when making changes to existing code` } -func NewWriteTool() BaseTool { - return &writeTool{} +func NewWriteTool(lspClients map[string]*lsp.Client) BaseTool { + return &writeTool{ + lspClients, + } } diff --git a/internal/llm/tools/write_test.go b/internal/llm/tools/write_test.go index 1c92e3baa..893a48b62 100644 --- a/internal/llm/tools/write_test.go +++ b/internal/llm/tools/write_test.go @@ -8,13 +8,14 @@ import ( "testing" "time" + "github.com/kujtimiihoxha/termai/internal/lsp" "github.com/kujtimiihoxha/termai/internal/permission" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestWriteTool_Info(t *testing.T) { - tool := NewWriteTool() + tool := NewWriteTool(make(map[string]*lsp.Client)) info := tool.Info() assert.Equal(t, WriteToolName, info.Name) @@ -40,11 +41,11 @@ func TestWriteTool_Run(t *testing.T) { t.Run("creates a new file successfully", func(t *testing.T) { permission.Default = newMockPermissionService(true) - tool := NewWriteTool() - + tool := NewWriteTool(make(map[string]*lsp.Client)) + filePath := filepath.Join(tempDir, "new_file.txt") content := "This is a test content" - + params := WriteParams{ FilePath: filePath, Content: content, @@ -70,11 +71,11 @@ func TestWriteTool_Run(t *testing.T) { t.Run("creates file with nested directories", func(t *testing.T) { permission.Default = newMockPermissionService(true) - tool := NewWriteTool() - + tool := NewWriteTool(make(map[string]*lsp.Client)) + filePath := filepath.Join(tempDir, "nested/dirs/new_file.txt") content := "Content in nested directory" - + params := WriteParams{ FilePath: filePath, Content: content, @@ -100,17 +101,17 @@ func TestWriteTool_Run(t *testing.T) { t.Run("updates existing file", func(t *testing.T) { permission.Default = newMockPermissionService(true) - tool := NewWriteTool() - + tool := NewWriteTool(make(map[string]*lsp.Client)) + // Create a file first filePath := filepath.Join(tempDir, "existing_file.txt") initialContent := "Initial content" - err := os.WriteFile(filePath, []byte(initialContent), 0644) + err := os.WriteFile(filePath, []byte(initialContent), 0o644) require.NoError(t, err) - + // Record the file read to avoid modification time check failure recordFileRead(filePath) - + // Update the file updatedContent := "Updated content" params := WriteParams{ @@ -138,8 +139,8 @@ func TestWriteTool_Run(t *testing.T) { t.Run("handles invalid parameters", func(t *testing.T) { permission.Default = newMockPermissionService(true) - tool := NewWriteTool() - + tool := NewWriteTool(make(map[string]*lsp.Client)) + call := ToolCall{ Name: WriteToolName, Input: "invalid json", @@ -152,8 +153,8 @@ func TestWriteTool_Run(t *testing.T) { t.Run("handles missing file_path", func(t *testing.T) { permission.Default = newMockPermissionService(true) - tool := NewWriteTool() - + tool := NewWriteTool(make(map[string]*lsp.Client)) + params := WriteParams{ FilePath: "", Content: "Some content", @@ -174,8 +175,8 @@ func TestWriteTool_Run(t *testing.T) { t.Run("handles missing content", func(t *testing.T) { permission.Default = newMockPermissionService(true) - tool := NewWriteTool() - + tool := NewWriteTool(make(map[string]*lsp.Client)) + params := WriteParams{ FilePath: filepath.Join(tempDir, "file.txt"), Content: "", @@ -196,13 +197,13 @@ func TestWriteTool_Run(t *testing.T) { t.Run("handles writing to a directory path", func(t *testing.T) { permission.Default = newMockPermissionService(true) - tool := NewWriteTool() - + tool := NewWriteTool(make(map[string]*lsp.Client)) + // Create a directory dirPath := filepath.Join(tempDir, "test_dir") - err := os.Mkdir(dirPath, 0755) + err := os.Mkdir(dirPath, 0o755) require.NoError(t, err) - + params := WriteParams{ FilePath: dirPath, Content: "Some content", @@ -223,8 +224,8 @@ func TestWriteTool_Run(t *testing.T) { t.Run("handles permission denied", func(t *testing.T) { permission.Default = newMockPermissionService(false) - tool := NewWriteTool() - + tool := NewWriteTool(make(map[string]*lsp.Client)) + filePath := filepath.Join(tempDir, "permission_denied.txt") params := WriteParams{ FilePath: filePath, @@ -242,7 +243,7 @@ func TestWriteTool_Run(t *testing.T) { response, err := tool.Run(context.Background(), call) require.NoError(t, err) assert.Contains(t, response.Content, "Permission denied") - + // Verify file was not created _, err = os.Stat(filePath) assert.True(t, os.IsNotExist(err)) @@ -250,14 +251,14 @@ func TestWriteTool_Run(t *testing.T) { t.Run("detects file modified since last read", func(t *testing.T) { permission.Default = newMockPermissionService(true) - tool := NewWriteTool() - + tool := NewWriteTool(make(map[string]*lsp.Client)) + // Create a file filePath := filepath.Join(tempDir, "modified_file.txt") initialContent := "Initial content" - err := os.WriteFile(filePath, []byte(initialContent), 0644) + err := os.WriteFile(filePath, []byte(initialContent), 0o644) require.NoError(t, err) - + // Record an old read time fileRecordMutex.Lock() fileRecords[filePath] = fileRecord{ @@ -265,7 +266,7 @@ func TestWriteTool_Run(t *testing.T) { readTime: time.Now().Add(-1 * time.Hour), } fileRecordMutex.Unlock() - + // Try to update the file params := WriteParams{ FilePath: filePath, @@ -283,7 +284,7 @@ func TestWriteTool_Run(t *testing.T) { response, err := tool.Run(context.Background(), call) require.NoError(t, err) assert.Contains(t, response.Content, "has been modified since it was last read") - + // Verify file was not modified fileContent, err := os.ReadFile(filePath) require.NoError(t, err) @@ -292,17 +293,17 @@ func TestWriteTool_Run(t *testing.T) { t.Run("skips writing when content is identical", func(t *testing.T) { permission.Default = newMockPermissionService(true) - tool := NewWriteTool() - + tool := NewWriteTool(make(map[string]*lsp.Client)) + // Create a file filePath := filepath.Join(tempDir, "identical_content.txt") content := "Content that won't change" - err := os.WriteFile(filePath, []byte(content), 0644) + err := os.WriteFile(filePath, []byte(content), 0o644) require.NoError(t, err) - + // Record a read time recordFileRead(filePath) - + // Try to write the same content params := WriteParams{ FilePath: filePath, @@ -321,4 +322,5 @@ func TestWriteTool_Run(t *testing.T) { require.NoError(t, err) assert.Contains(t, response.Content, "already contains the exact content") }) -}
\ No newline at end of file +} + |
