diff options
| author | Kujtim Hoxha <[email protected]> | 2025-04-03 15:20:15 +0200 |
|---|---|---|
| committer | Kujtim Hoxha <[email protected]> | 2025-04-03 17:23:41 +0200 |
| commit | cfdd687216799cb5b47f099f1e7cd5dd16b3bdd0 (patch) | |
| tree | a822bfde1463a7080c0ea06dd17796d7a1617d3d /internal/llm/tools | |
| parent | afd9ad0560d76c2a6d161dad52553b10ff428905 (diff) | |
| download | opencode-cfdd687216799cb5b47f099f1e7cd5dd16b3bdd0.tar.gz opencode-cfdd687216799cb5b47f099f1e7cd5dd16b3bdd0.zip | |
add initial lsp support
Diffstat (limited to 'internal/llm/tools')
| -rw-r--r-- | internal/llm/tools/diagnostics.go | 229 | ||||
| -rw-r--r-- | internal/llm/tools/edit.go | 21 | ||||
| -rw-r--r-- | internal/llm/tools/shell/shell.go | 2 | ||||
| -rw-r--r-- | internal/llm/tools/view.go | 26 | ||||
| -rw-r--r-- | internal/llm/tools/write.go | 18 | ||||
| -rw-r--r-- | internal/llm/tools/write_test.go | 76 |
6 files changed, 315 insertions, 57 deletions
diff --git a/internal/llm/tools/diagnostics.go b/internal/llm/tools/diagnostics.go new file mode 100644 index 000000000..dc90e5860 --- /dev/null +++ b/internal/llm/tools/diagnostics.go @@ -0,0 +1,229 @@ +package tools + +import ( + "context" + "encoding/json" + "fmt" + "sort" + "strings" + "time" + + "github.com/kujtimiihoxha/termai/internal/lsp" + "github.com/kujtimiihoxha/termai/internal/lsp/protocol" +) + +type diagnosticsTool struct { + lspClients map[string]*lsp.Client +} + +const ( + DiagnosticsToolName = "diagnostics" +) + +type DiagnosticsParams struct { + FilePath string `json:"file_path"` +} + +func (b *diagnosticsTool) Info() ToolInfo { + return ToolInfo{ + Name: DiagnosticsToolName, + Description: "Get diagnostics for a file and/or project.", + Parameters: map[string]any{ + "file_path": map[string]any{ + "type": "string", + "description": "The path to the file to get diagnostics for (leave w empty for project diagnostics)", + }, + }, + Required: []string{}, + } +} + +func (b *diagnosticsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { + var params DiagnosticsParams + if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil { + return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil + } + + lsps := b.lspClients + + if len(lsps) == 0 { + return NewTextErrorResponse("no LSP clients available"), nil + } + + if params.FilePath == "" { + notifyLspOpenFile(ctx, params.FilePath, lsps) + } + + output := appendDiagnostics(params.FilePath, lsps) + + return NewTextResponse(output), nil +} + +func notifyLspOpenFile(ctx context.Context, filePath string, lsps map[string]*lsp.Client) { + for _, client := range lsps { + err := client.OpenFile(ctx, filePath) + if err != nil { + // Wait for the file to be opened and diagnostics to be received + // TODO: see if we can do this in a more efficient way + time.Sleep(2 * time.Second) + } + + } +} + +func appendDiagnostics(filePath string, lsps map[string]*lsp.Client) string { + fileDiagnostics := []string{} + projectDiagnostics := []string{} + + // Enhanced format function that includes more diagnostic information + formatDiagnostic := func(pth string, diagnostic protocol.Diagnostic, source string) string { + // Base components + severity := "Info" + switch diagnostic.Severity { + case protocol.SeverityError: + severity = "Error" + case protocol.SeverityWarning: + severity = "Warn" + case protocol.SeverityHint: + severity = "Hint" + } + + // Location information + location := fmt.Sprintf("%s:%d:%d", pth, diagnostic.Range.Start.Line+1, diagnostic.Range.Start.Character+1) + + // Source information (LSP name) + sourceInfo := "" + if diagnostic.Source != "" { + sourceInfo = diagnostic.Source + } else if source != "" { + sourceInfo = source + } + + // Code information + codeInfo := "" + if diagnostic.Code != nil { + codeInfo = fmt.Sprintf("[%v]", diagnostic.Code) + } + + // Tags information + tagsInfo := "" + if len(diagnostic.Tags) > 0 { + tags := []string{} + for _, tag := range diagnostic.Tags { + switch tag { + case protocol.Unnecessary: + tags = append(tags, "unnecessary") + case protocol.Deprecated: + tags = append(tags, "deprecated") + } + } + if len(tags) > 0 { + tagsInfo = fmt.Sprintf(" (%s)", strings.Join(tags, ", ")) + } + } + + // Assemble the full diagnostic message + return fmt.Sprintf("%s: %s [%s]%s%s %s", + severity, + location, + sourceInfo, + codeInfo, + tagsInfo, + diagnostic.Message) + } + + for lspName, client := range lsps { + diagnostics := client.GetDiagnostics() + if len(diagnostics) > 0 { + for location, diags := range diagnostics { + isCurrentFile := location.Path() == filePath + + // Group diagnostics by severity for better organization + for _, diag := range diags { + formattedDiag := formatDiagnostic(location.Path(), diag, lspName) + + if isCurrentFile { + fileDiagnostics = append(fileDiagnostics, formattedDiag) + } else { + projectDiagnostics = append(projectDiagnostics, formattedDiag) + } + } + } + } + } + + // Sort diagnostics by severity (errors first) and then by location + sort.Slice(fileDiagnostics, func(i, j int) bool { + iIsError := strings.HasPrefix(fileDiagnostics[i], "Error") + jIsError := strings.HasPrefix(fileDiagnostics[j], "Error") + if iIsError != jIsError { + return iIsError // Errors come first + } + return fileDiagnostics[i] < fileDiagnostics[j] // Then alphabetically + }) + + sort.Slice(projectDiagnostics, func(i, j int) bool { + iIsError := strings.HasPrefix(projectDiagnostics[i], "Error") + jIsError := strings.HasPrefix(projectDiagnostics[j], "Error") + if iIsError != jIsError { + return iIsError + } + return projectDiagnostics[i] < projectDiagnostics[j] + }) + + output := "" + + if len(fileDiagnostics) > 0 { + output += "\n<file_diagnostics>\n" + if len(fileDiagnostics) > 10 { + output += strings.Join(fileDiagnostics[:10], "\n") + output += fmt.Sprintf("\n... and %d more diagnostics", len(fileDiagnostics)-10) + } else { + output += strings.Join(fileDiagnostics, "\n") + } + output += "\n</file_diagnostics>\n" + } + + if len(projectDiagnostics) > 0 { + output += "\n<project_diagnostics>\n" + if len(projectDiagnostics) > 10 { + output += strings.Join(projectDiagnostics[:10], "\n") + output += fmt.Sprintf("\n... and %d more diagnostics", len(projectDiagnostics)-10) + } else { + output += strings.Join(projectDiagnostics, "\n") + } + output += "\n</project_diagnostics>\n" + } + + // Add summary counts + if len(fileDiagnostics) > 0 || len(projectDiagnostics) > 0 { + fileErrors := countSeverity(fileDiagnostics, "Error") + fileWarnings := countSeverity(fileDiagnostics, "Warn") + projectErrors := countSeverity(projectDiagnostics, "Error") + projectWarnings := countSeverity(projectDiagnostics, "Warn") + + output += "\n<diagnostic_summary>\n" + output += fmt.Sprintf("Current file: %d errors, %d warnings\n", fileErrors, fileWarnings) + output += fmt.Sprintf("Project: %d errors, %d warnings\n", projectErrors, projectWarnings) + output += "</diagnostic_summary>\n" + } + + return output +} + +// Helper function to count diagnostics by severity +func countSeverity(diagnostics []string, severity string) int { + count := 0 + for _, diag := range diagnostics { + if strings.HasPrefix(diag, severity) { + count++ + } + } + return count +} + +func NewDiagnosticsTool(lspClients map[string]*lsp.Client) BaseTool { + return &diagnosticsTool{ + lspClients, + } +} diff --git a/internal/llm/tools/edit.go b/internal/llm/tools/edit.go index 8c5427a58..c84bbd7a0 100644 --- a/internal/llm/tools/edit.go +++ b/internal/llm/tools/edit.go @@ -10,11 +10,14 @@ import ( "time" "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/termai/internal/lsp" "github.com/kujtimiihoxha/termai/internal/permission" "github.com/sergi/go-diff/diffmatchpatch" ) -type editTool struct{} +type editTool struct { + lspClients map[string]*lsp.Client +} const ( EditToolName = "edit" @@ -71,6 +74,7 @@ func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) params.FilePath = filepath.Join(wd, params.FilePath) } + notifyLspOpenFile(ctx, params.FilePath, e.lspClients) if params.OldString == "" { result, err := createNewFile(params.FilePath, params.NewString) if err != nil { @@ -91,6 +95,9 @@ func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) if err != nil { return NewTextErrorResponse(fmt.Sprintf("error replacing content: %s", err)), nil } + + result = fmt.Sprintf("<result>\n%s\n</result>\n", result) + result += appendDiagnostics(params.FilePath, e.lspClients) return NewTextResponse(result), nil } @@ -296,18 +303,18 @@ func GenerateDiff(oldContent, newContent string) string { switch diff.Type { case diffmatchpatch.DiffInsert: - for _, line := range strings.Split(text, "\n") { + for line := range strings.SplitSeq(text, "\n") { _, _ = buff.WriteString("+ " + line + "\n") } case diffmatchpatch.DiffDelete: - for _, line := range strings.Split(text, "\n") { + for line := range strings.SplitSeq(text, "\n") { _, _ = buff.WriteString("- " + line + "\n") } case diffmatchpatch.DiffEqual: if len(text) > 40 { _, _ = buff.WriteString(" " + text[:20] + "..." + text[len(text)-20:] + "\n") } else { - for _, line := range strings.Split(text, "\n") { + for line := range strings.SplitSeq(text, "\n") { _, _ = buff.WriteString(" " + line + "\n") } } @@ -366,6 +373,8 @@ When making edits: Remember: when making multiple file edits in a row to the same file, you should prefer to send all edits in a single message with multiple calls to this tool, rather than multiple messages with a single call each.` } -func NewEditTool() BaseTool { - return &editTool{} +func NewEditTool(lspClients map[string]*lsp.Client) BaseTool { + return &editTool{ + lspClients, + } } diff --git a/internal/llm/tools/shell/shell.go b/internal/llm/tools/shell/shell.go index d76cb1a2e..64592f67d 100644 --- a/internal/llm/tools/shell/shell.go +++ b/internal/llm/tools/shell/shell.go @@ -221,7 +221,7 @@ func (s *PersistentShell) killChildren() { return } - for _, pidStr := range strings.Split(string(output), "\n") { + for pidStr := range strings.SplitSeq(string(output), "\n") { if pidStr = strings.TrimSpace(pidStr); pidStr != "" { var pid int fmt.Sscanf(pidStr, "%d", &pid) diff --git a/internal/llm/tools/view.go b/internal/llm/tools/view.go index dca522b9c..743cef6f4 100644 --- a/internal/llm/tools/view.go +++ b/internal/llm/tools/view.go @@ -11,9 +11,12 @@ import ( "strings" "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/termai/internal/lsp" ) -type viewTool struct{} +type viewTool struct { + lspClients map[string]*lsp.Client +} const ( ViewToolName = "view" @@ -127,15 +130,18 @@ func (v *viewTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) return NewTextErrorResponse(fmt.Sprintf("Failed to read file: %s", err)), nil } + notifyLspOpenFile(ctx, filePath, v.lspClients) + output := "<file>\n" // Format the output with line numbers - output := addLineNumbers(content, params.Offset+1) + output += addLineNumbers(content, params.Offset+1) // Add a note if the content was truncated if lineCount > params.Offset+len(strings.Split(content, "\n")) { output += fmt.Sprintf("\n\n(File has more lines. Use 'offset' parameter to read beyond line %d)", params.Offset+len(strings.Split(content, "\n"))) } - + output += "\n</file>\n" + output += appendDiagnostics(filePath, v.lspClients) recordFileRead(filePath) return NewTextResponse(output), nil } @@ -155,10 +161,10 @@ func addLineNumbers(content string, startLine int) string { numStr := fmt.Sprintf("%d", lineNum) if len(numStr) >= 6 { - result = append(result, fmt.Sprintf("%s\t%s", numStr, line)) + result = append(result, fmt.Sprintf("%s|%s", numStr, line)) } else { paddedNum := fmt.Sprintf("%6s", numStr) - result = append(result, fmt.Sprintf("%s\t|%s", paddedNum, line)) + result = append(result, fmt.Sprintf("%s|%s", paddedNum, line)) } } @@ -173,8 +179,9 @@ func readTextFile(filePath string, offset, limit int) (string, int, error) { defer file.Close() lineCount := 0 + + scanner := NewLineScanner(file) if offset > 0 { - scanner := NewLineScanner(file) for lineCount < offset && scanner.Scan() { lineCount++ } @@ -192,7 +199,6 @@ func readTextFile(filePath string, offset, limit int) (string, int, error) { var lines []string lineCount = offset - scanner := NewLineScanner(file) for scanner.Scan() && len(lines) < limit { lineCount++ @@ -290,6 +296,8 @@ TIPS: - When viewing large files, use the offset parameter to read specific sections` } -func NewViewTool() BaseTool { - return &viewTool{} +func NewViewTool(lspClients map[string]*lsp.Client) BaseTool { + return &viewTool{ + lspClients, + } } diff --git a/internal/llm/tools/write.go b/internal/llm/tools/write.go index 003753d08..3d66d64e2 100644 --- a/internal/llm/tools/write.go +++ b/internal/llm/tools/write.go @@ -9,10 +9,13 @@ import ( "time" "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/termai/internal/lsp" "github.com/kujtimiihoxha/termai/internal/permission" ) -type writeTool struct{} +type writeTool struct { + lspClients map[string]*lsp.Client +} const ( WriteToolName = "write" @@ -96,6 +99,8 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error if err = os.MkdirAll(dir, 0o755); err != nil { return NewTextErrorResponse(fmt.Sprintf("Failed to create parent directories: %s", err)), nil } + + notifyLspOpenFile(ctx, filePath, w.lspClients) p := permission.Default.Request( permission.CreatePermissionRequest{ Path: filePath, @@ -122,7 +127,10 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error recordFileWrite(filePath) recordFileRead(filePath) - return NewTextResponse(fmt.Sprintf("File successfully written: %s", filePath)), nil + result := fmt.Sprintf("File successfully written: %s", filePath) + result = fmt.Sprintf("<result>\n%s\n</result>", result) + result += appendDiagnostics(filePath, w.lspClients) + return NewTextResponse(result), nil } func writeDescription() string { @@ -156,6 +164,8 @@ TIPS: - Always include descriptive comments when making changes to existing code` } -func NewWriteTool() BaseTool { - return &writeTool{} +func NewWriteTool(lspClients map[string]*lsp.Client) BaseTool { + return &writeTool{ + lspClients, + } } diff --git a/internal/llm/tools/write_test.go b/internal/llm/tools/write_test.go index 1c92e3baa..893a48b62 100644 --- a/internal/llm/tools/write_test.go +++ b/internal/llm/tools/write_test.go @@ -8,13 +8,14 @@ import ( "testing" "time" + "github.com/kujtimiihoxha/termai/internal/lsp" "github.com/kujtimiihoxha/termai/internal/permission" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestWriteTool_Info(t *testing.T) { - tool := NewWriteTool() + tool := NewWriteTool(make(map[string]*lsp.Client)) info := tool.Info() assert.Equal(t, WriteToolName, info.Name) @@ -40,11 +41,11 @@ func TestWriteTool_Run(t *testing.T) { t.Run("creates a new file successfully", func(t *testing.T) { permission.Default = newMockPermissionService(true) - tool := NewWriteTool() - + tool := NewWriteTool(make(map[string]*lsp.Client)) + filePath := filepath.Join(tempDir, "new_file.txt") content := "This is a test content" - + params := WriteParams{ FilePath: filePath, Content: content, @@ -70,11 +71,11 @@ func TestWriteTool_Run(t *testing.T) { t.Run("creates file with nested directories", func(t *testing.T) { permission.Default = newMockPermissionService(true) - tool := NewWriteTool() - + tool := NewWriteTool(make(map[string]*lsp.Client)) + filePath := filepath.Join(tempDir, "nested/dirs/new_file.txt") content := "Content in nested directory" - + params := WriteParams{ FilePath: filePath, Content: content, @@ -100,17 +101,17 @@ func TestWriteTool_Run(t *testing.T) { t.Run("updates existing file", func(t *testing.T) { permission.Default = newMockPermissionService(true) - tool := NewWriteTool() - + tool := NewWriteTool(make(map[string]*lsp.Client)) + // Create a file first filePath := filepath.Join(tempDir, "existing_file.txt") initialContent := "Initial content" - err := os.WriteFile(filePath, []byte(initialContent), 0644) + err := os.WriteFile(filePath, []byte(initialContent), 0o644) require.NoError(t, err) - + // Record the file read to avoid modification time check failure recordFileRead(filePath) - + // Update the file updatedContent := "Updated content" params := WriteParams{ @@ -138,8 +139,8 @@ func TestWriteTool_Run(t *testing.T) { t.Run("handles invalid parameters", func(t *testing.T) { permission.Default = newMockPermissionService(true) - tool := NewWriteTool() - + tool := NewWriteTool(make(map[string]*lsp.Client)) + call := ToolCall{ Name: WriteToolName, Input: "invalid json", @@ -152,8 +153,8 @@ func TestWriteTool_Run(t *testing.T) { t.Run("handles missing file_path", func(t *testing.T) { permission.Default = newMockPermissionService(true) - tool := NewWriteTool() - + tool := NewWriteTool(make(map[string]*lsp.Client)) + params := WriteParams{ FilePath: "", Content: "Some content", @@ -174,8 +175,8 @@ func TestWriteTool_Run(t *testing.T) { t.Run("handles missing content", func(t *testing.T) { permission.Default = newMockPermissionService(true) - tool := NewWriteTool() - + tool := NewWriteTool(make(map[string]*lsp.Client)) + params := WriteParams{ FilePath: filepath.Join(tempDir, "file.txt"), Content: "", @@ -196,13 +197,13 @@ func TestWriteTool_Run(t *testing.T) { t.Run("handles writing to a directory path", func(t *testing.T) { permission.Default = newMockPermissionService(true) - tool := NewWriteTool() - + tool := NewWriteTool(make(map[string]*lsp.Client)) + // Create a directory dirPath := filepath.Join(tempDir, "test_dir") - err := os.Mkdir(dirPath, 0755) + err := os.Mkdir(dirPath, 0o755) require.NoError(t, err) - + params := WriteParams{ FilePath: dirPath, Content: "Some content", @@ -223,8 +224,8 @@ func TestWriteTool_Run(t *testing.T) { t.Run("handles permission denied", func(t *testing.T) { permission.Default = newMockPermissionService(false) - tool := NewWriteTool() - + tool := NewWriteTool(make(map[string]*lsp.Client)) + filePath := filepath.Join(tempDir, "permission_denied.txt") params := WriteParams{ FilePath: filePath, @@ -242,7 +243,7 @@ func TestWriteTool_Run(t *testing.T) { response, err := tool.Run(context.Background(), call) require.NoError(t, err) assert.Contains(t, response.Content, "Permission denied") - + // Verify file was not created _, err = os.Stat(filePath) assert.True(t, os.IsNotExist(err)) @@ -250,14 +251,14 @@ func TestWriteTool_Run(t *testing.T) { t.Run("detects file modified since last read", func(t *testing.T) { permission.Default = newMockPermissionService(true) - tool := NewWriteTool() - + tool := NewWriteTool(make(map[string]*lsp.Client)) + // Create a file filePath := filepath.Join(tempDir, "modified_file.txt") initialContent := "Initial content" - err := os.WriteFile(filePath, []byte(initialContent), 0644) + err := os.WriteFile(filePath, []byte(initialContent), 0o644) require.NoError(t, err) - + // Record an old read time fileRecordMutex.Lock() fileRecords[filePath] = fileRecord{ @@ -265,7 +266,7 @@ func TestWriteTool_Run(t *testing.T) { readTime: time.Now().Add(-1 * time.Hour), } fileRecordMutex.Unlock() - + // Try to update the file params := WriteParams{ FilePath: filePath, @@ -283,7 +284,7 @@ func TestWriteTool_Run(t *testing.T) { response, err := tool.Run(context.Background(), call) require.NoError(t, err) assert.Contains(t, response.Content, "has been modified since it was last read") - + // Verify file was not modified fileContent, err := os.ReadFile(filePath) require.NoError(t, err) @@ -292,17 +293,17 @@ func TestWriteTool_Run(t *testing.T) { t.Run("skips writing when content is identical", func(t *testing.T) { permission.Default = newMockPermissionService(true) - tool := NewWriteTool() - + tool := NewWriteTool(make(map[string]*lsp.Client)) + // Create a file filePath := filepath.Join(tempDir, "identical_content.txt") content := "Content that won't change" - err := os.WriteFile(filePath, []byte(content), 0644) + err := os.WriteFile(filePath, []byte(content), 0o644) require.NoError(t, err) - + // Record a read time recordFileRead(filePath) - + // Try to write the same content params := WriteParams{ FilePath: filePath, @@ -321,4 +322,5 @@ func TestWriteTool_Run(t *testing.T) { require.NoError(t, err) assert.Contains(t, response.Content, "already contains the exact content") }) -}
\ No newline at end of file +} + |
