diff options
Diffstat (limited to 'internal/llm/tools')
| -rw-r--r-- | internal/llm/tools/bash.go | 4 | ||||
| -rw-r--r-- | internal/llm/tools/edit.go | 234 | ||||
| -rw-r--r-- | internal/llm/tools/edit_test.go | 48 | ||||
| -rw-r--r-- | internal/llm/tools/file.go | 10 | ||||
| -rw-r--r-- | internal/llm/tools/tools.go | 15 | ||||
| -rw-r--r-- | internal/llm/tools/write.go | 29 |
6 files changed, 175 insertions, 165 deletions
diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index d20afb7f2..d55cb241b 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -22,7 +22,7 @@ type BashPermissionsParams struct { Timeout int `json:"timeout"` } -type BashToolResponseMetadata struct { +type BashResponseMetadata struct { Took int64 `json:"took"` } type bashTool struct { @@ -310,7 +310,7 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) stdout += "\n" + errorMessage } - metadata := BashToolResponseMetadata{ + metadata := BashResponseMetadata{ Took: took, } if stdout == "" { diff --git a/internal/llm/tools/edit.go b/internal/llm/tools/edit.go index 32e2034e4..c9a0be079 100644 --- a/internal/llm/tools/edit.go +++ b/internal/llm/tools/edit.go @@ -10,9 +10,9 @@ import ( "time" "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/termai/internal/git" "github.com/kujtimiihoxha/termai/internal/lsp" "github.com/kujtimiihoxha/termai/internal/permission" - "github.com/sergi/go-diff/diffmatchpatch" ) type EditParams struct { @@ -22,10 +22,13 @@ type EditParams struct { } type EditPermissionsParams struct { - FilePath string `json:"file_path"` - OldString string `json:"old_string"` - NewString string `json:"new_string"` - Diff string `json:"diff"` + FilePath string `json:"file_path"` + Diff string `json:"diff"` +} + +type EditResponseMetadata struct { + Additions int `json:"additions"` + Removals int `json:"removals"` } type editTool struct { @@ -129,48 +132,77 @@ func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) } if params.OldString == "" { - result, err := e.createNewFile(params.FilePath, params.NewString) + result, err := e.createNewFile(ctx, params.FilePath, params.NewString) if err != nil { return NewTextErrorResponse(fmt.Sprintf("error creating file: %s", err)), nil } - return NewTextResponse(result), nil + return WithResponseMetadata(NewTextResponse(result.text), EditResponseMetadata{ + Additions: result.additions, + Removals: result.removals, + }), nil } if params.NewString == "" { - result, err := e.deleteContent(params.FilePath, params.OldString) + result, err := e.deleteContent(ctx, params.FilePath, params.OldString) if err != nil { return NewTextErrorResponse(fmt.Sprintf("error deleting content: %s", err)), nil } - return NewTextResponse(result), nil + return WithResponseMetadata(NewTextResponse(result.text), EditResponseMetadata{ + Additions: result.additions, + Removals: result.removals, + }), nil } - result, err := e.replaceContent(params.FilePath, params.OldString, params.NewString) + result, err := e.replaceContent(ctx, params.FilePath, params.OldString, params.NewString) if err != nil { return NewTextErrorResponse(fmt.Sprintf("error replacing content: %s", err)), nil } waitForLspDiagnostics(ctx, params.FilePath, e.lspClients) - result = fmt.Sprintf("<result>\n%s\n</result>\n", result) - result += appendDiagnostics(params.FilePath, e.lspClients) - return NewTextResponse(result), nil + text := fmt.Sprintf("<result>\n%s\n</result>\n", result.text) + text += appendDiagnostics(params.FilePath, e.lspClients) + return WithResponseMetadata(NewTextResponse(text), EditResponseMetadata{ + Additions: result.additions, + Removals: result.removals, + }), nil +} + +type editResponse struct { + text string + additions int + removals int } -func (e *editTool) createNewFile(filePath, content string) (string, error) { +func (e *editTool) createNewFile(ctx context.Context, filePath, content string) (editResponse, error) { + er := editResponse{} fileInfo, err := os.Stat(filePath) if err == nil { if fileInfo.IsDir() { - return "", fmt.Errorf("path is a directory, not a file: %s", filePath) + return er, fmt.Errorf("path is a directory, not a file: %s", filePath) } - return "", fmt.Errorf("file already exists: %s. Use the Replace tool to overwrite an existing file", filePath) + return er, fmt.Errorf("file already exists: %s. Use the Replace tool to overwrite an existing file", filePath) } else if !os.IsNotExist(err) { - return "", fmt.Errorf("failed to access file: %w", err) + return er, fmt.Errorf("failed to access file: %w", err) } dir := filepath.Dir(filePath) if err = os.MkdirAll(dir, 0o755); err != nil { - return "", fmt.Errorf("failed to create parent directories: %w", err) + return er, fmt.Errorf("failed to create parent directories: %w", err) } + sessionID, messageID := getContextValues(ctx) + if sessionID == "" || messageID == "" { + return er, fmt.Errorf("session ID and message ID are required for creating a new file") + } + + diff, stats, err := git.GenerateGitDiffWithStats( + removeWorkingDirectoryPrefix(filePath), + "", + content, + ) + if err != nil { + return er, fmt.Errorf("failed to get file diff: %w", err) + } p := e.permissions.Request( permission.CreatePermissionRequest{ Path: filepath.Dir(filePath), @@ -178,71 +210,88 @@ func (e *editTool) createNewFile(filePath, content string) (string, error) { Action: "create", Description: fmt.Sprintf("Create file %s", filePath), Params: EditPermissionsParams{ - FilePath: filePath, - OldString: "", - NewString: content, - Diff: GenerateDiff("", content), + FilePath: filePath, + Diff: diff, }, }, ) if !p { - return "", fmt.Errorf("permission denied") + return er, fmt.Errorf("permission denied") } err = os.WriteFile(filePath, []byte(content), 0o644) if err != nil { - return "", fmt.Errorf("failed to write file: %w", err) + return er, fmt.Errorf("failed to write file: %w", err) } recordFileWrite(filePath) recordFileRead(filePath) - return "File created: " + filePath, nil + er.text = "File created: " + filePath + er.additions = stats.Additions + er.removals = stats.Removals + return er, nil } -func (e *editTool) deleteContent(filePath, oldString string) (string, error) { +func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string) (editResponse, error) { + er := editResponse{} fileInfo, err := os.Stat(filePath) if err != nil { if os.IsNotExist(err) { - return "", fmt.Errorf("file not found: %s", filePath) + return er, fmt.Errorf("file not found: %s", filePath) } - return "", fmt.Errorf("failed to access file: %w", err) + return er, fmt.Errorf("failed to access file: %w", err) } if fileInfo.IsDir() { - return "", fmt.Errorf("path is a directory, not a file: %s", filePath) + return er, fmt.Errorf("path is a directory, not a file: %s", filePath) } if getLastReadTime(filePath).IsZero() { - return "", fmt.Errorf("you must read the file before editing it. Use the View tool first") + return er, fmt.Errorf("you must read the file before editing it. Use the View tool first") } modTime := fileInfo.ModTime() lastRead := getLastReadTime(filePath) if modTime.After(lastRead) { - return "", fmt.Errorf("file %s has been modified since it was last read (mod time: %s, last read: %s)", + return er, fmt.Errorf("file %s has been modified since it was last read (mod time: %s, last read: %s)", filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339)) } content, err := os.ReadFile(filePath) if err != nil { - return "", fmt.Errorf("failed to read file: %w", err) + return er, fmt.Errorf("failed to read file: %w", err) } oldContent := string(content) index := strings.Index(oldContent, oldString) if index == -1 { - return "", fmt.Errorf("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks") + return er, fmt.Errorf("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks") } lastIndex := strings.LastIndex(oldContent, oldString) if index != lastIndex { - return "", fmt.Errorf("old_string appears multiple times in the file. Please provide more context to ensure a unique match") + return er, fmt.Errorf("old_string appears multiple times in the file. Please provide more context to ensure a unique match") } newContent := oldContent[:index] + oldContent[index+len(oldString):] + sessionID, messageID := getContextValues(ctx) + + if sessionID == "" || messageID == "" { + return er, fmt.Errorf("session ID and message ID are required for creating a new file") + } + + diff, stats, err := git.GenerateGitDiffWithStats( + removeWorkingDirectoryPrefix(filePath), + oldContent, + newContent, + ) + if err != nil { + return er, fmt.Errorf("failed to get file diff: %w", err) + } + p := e.permissions.Request( permission.CreatePermissionRequest{ Path: filepath.Dir(filePath), @@ -250,76 +299,85 @@ func (e *editTool) deleteContent(filePath, oldString string) (string, error) { Action: "delete", Description: fmt.Sprintf("Delete content from file %s", filePath), Params: EditPermissionsParams{ - FilePath: filePath, - OldString: oldString, - NewString: "", - Diff: GenerateDiff(oldContent, newContent), + FilePath: filePath, + Diff: diff, }, }, ) if !p { - return "", fmt.Errorf("permission denied") + return er, fmt.Errorf("permission denied") } err = os.WriteFile(filePath, []byte(newContent), 0o644) if err != nil { - return "", fmt.Errorf("failed to write file: %w", err) + return er, fmt.Errorf("failed to write file: %w", err) } - recordFileWrite(filePath) recordFileRead(filePath) - return "Content deleted from file: " + filePath, nil + er.text = "Content deleted from file: " + filePath + er.additions = stats.Additions + er.removals = stats.Removals + return er, nil } -func (e *editTool) replaceContent(filePath, oldString, newString string) (string, error) { +func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newString string) (editResponse, error) { + er := editResponse{} fileInfo, err := os.Stat(filePath) if err != nil { if os.IsNotExist(err) { - return "", fmt.Errorf("file not found: %s", filePath) + return er, fmt.Errorf("file not found: %s", filePath) } - return "", fmt.Errorf("failed to access file: %w", err) + return er, fmt.Errorf("failed to access file: %w", err) } if fileInfo.IsDir() { - return "", fmt.Errorf("path is a directory, not a file: %s", filePath) + return er, fmt.Errorf("path is a directory, not a file: %s", filePath) } if getLastReadTime(filePath).IsZero() { - return "", fmt.Errorf("you must read the file before editing it. Use the View tool first") + return er, fmt.Errorf("you must read the file before editing it. Use the View tool first") } modTime := fileInfo.ModTime() lastRead := getLastReadTime(filePath) if modTime.After(lastRead) { - return "", fmt.Errorf("file %s has been modified since it was last read (mod time: %s, last read: %s)", + return er, fmt.Errorf("file %s has been modified since it was last read (mod time: %s, last read: %s)", filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339)) } content, err := os.ReadFile(filePath) if err != nil { - return "", fmt.Errorf("failed to read file: %w", err) + return er, fmt.Errorf("failed to read file: %w", err) } oldContent := string(content) index := strings.Index(oldContent, oldString) if index == -1 { - return "", fmt.Errorf("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks") + return er, fmt.Errorf("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks") } lastIndex := strings.LastIndex(oldContent, oldString) if index != lastIndex { - return "", fmt.Errorf("old_string appears multiple times in the file. Please provide more context to ensure a unique match") + return er, fmt.Errorf("old_string appears multiple times in the file. Please provide more context to ensure a unique match") } newContent := oldContent[:index] + newString + oldContent[index+len(oldString):] - startIndex := max(0, index-3) - oldEndIndex := min(len(oldContent), index+len(oldString)+3) - newEndIndex := min(len(newContent), index+len(newString)+3) + sessionID, messageID := getContextValues(ctx) - diff := GenerateDiff(oldContent[startIndex:oldEndIndex], newContent[startIndex:newEndIndex]) + if sessionID == "" || messageID == "" { + return er, fmt.Errorf("session ID and message ID are required for creating a new file") + } + diff, stats, err := git.GenerateGitDiffWithStats( + removeWorkingDirectoryPrefix(filePath), + oldContent, + newContent, + ) + if err != nil { + return er, fmt.Errorf("failed to get file diff: %w", err) + } p := e.permissions.Request( permission.CreatePermissionRequest{ @@ -328,75 +386,27 @@ func (e *editTool) replaceContent(filePath, oldString, newString string) (string Action: "replace", Description: fmt.Sprintf("Replace content in file %s", filePath), Params: EditPermissionsParams{ - FilePath: filePath, - OldString: oldString, - NewString: newString, - Diff: diff, + FilePath: filePath, + + Diff: diff, }, }, ) if !p { - return "", fmt.Errorf("permission denied") + return er, fmt.Errorf("permission denied") } err = os.WriteFile(filePath, []byte(newContent), 0o644) if err != nil { - return "", fmt.Errorf("failed to write file: %w", err) + return er, fmt.Errorf("failed to write file: %w", err) } recordFileWrite(filePath) recordFileRead(filePath) + er.text = "Content replaced in file: " + filePath + er.additions = stats.Additions + er.removals = stats.Removals - return "Content replaced in file: " + filePath, nil + return er, nil } -func GenerateDiff(oldContent, newContent string) string { - dmp := diffmatchpatch.New() - fileAdmp, fileBdmp, dmpStrings := dmp.DiffLinesToChars(oldContent, newContent) - diffs := dmp.DiffMain(fileAdmp, fileBdmp, false) - diffs = dmp.DiffCharsToLines(diffs, dmpStrings) - diffs = dmp.DiffCleanupSemantic(diffs) - buff := strings.Builder{} - - buff.WriteString("Changes:\n") - - for _, diff := range diffs { - text := diff.Text - - switch diff.Type { - case diffmatchpatch.DiffInsert: - for line := range strings.SplitSeq(text, "\n") { - if line == "" { - continue - } - _, _ = buff.WriteString("+ " + line + "\n") - } - case diffmatchpatch.DiffDelete: - for line := range strings.SplitSeq(text, "\n") { - if line == "" { - continue - } - _, _ = buff.WriteString("- " + line + "\n") - } - case diffmatchpatch.DiffEqual: - lines := strings.Split(text, "\n") - if len(lines) > 3 { - if lines[0] != "" { - _, _ = buff.WriteString(" " + lines[0] + "\n") - } - _, _ = buff.WriteString(" ...\n") - if lines[len(lines)-1] != "" { - _, _ = buff.WriteString(" " + lines[len(lines)-1] + "\n") - } - } else { - for _, line := range lines { - if line == "" { - continue - } - _, _ = buff.WriteString(" " + line + "\n") - } - } - } - } - return buff.String() -} diff --git a/internal/llm/tools/edit_test.go b/internal/llm/tools/edit_test.go index dbc6e488f..48a34ed75 100644 --- a/internal/llm/tools/edit_test.go +++ b/internal/llm/tools/edit_test.go @@ -459,51 +459,3 @@ func TestEditTool_Run(t *testing.T) { assert.Equal(t, initialContent, string(fileContent)) }) } - -func TestGenerateDiff(t *testing.T) { - testCases := []struct { - name string - oldContent string - newContent string - expectedDiff string - }{ - { - name: "add content", - oldContent: "Line 1\nLine 2\n", - newContent: "Line 1\nLine 2\nLine 3\n", - expectedDiff: "Changes:\n Line 1\n Line 2\n+ Line 3\n", - }, - { - name: "remove content", - oldContent: "Line 1\nLine 2\nLine 3\n", - newContent: "Line 1\nLine 3\n", - expectedDiff: "Changes:\n Line 1\n- Line 2\n Line 3\n", - }, - { - name: "replace content", - oldContent: "Line 1\nLine 2\nLine 3\n", - newContent: "Line 1\nModified Line\nLine 3\n", - expectedDiff: "Changes:\n Line 1\n- Line 2\n+ Modified Line\n Line 3\n", - }, - { - name: "empty to content", - oldContent: "", - newContent: "Line 1\nLine 2\n", - expectedDiff: "Changes:\n+ Line 1\n+ Line 2\n", - }, - { - name: "content to empty", - oldContent: "Line 1\nLine 2\n", - newContent: "", - expectedDiff: "Changes:\n- Line 1\n- Line 2\n", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - diff := GenerateDiff(tc.oldContent, tc.newContent) - assert.Contains(t, diff, tc.expectedDiff) - }) - } -} - diff --git a/internal/llm/tools/file.go b/internal/llm/tools/file.go index 7f34fdc1f..9c9707c9c 100644 --- a/internal/llm/tools/file.go +++ b/internal/llm/tools/file.go @@ -3,6 +3,8 @@ package tools import ( "sync" "time" + + "github.com/kujtimiihoxha/termai/internal/config" ) // File record to track when files were read/written @@ -17,6 +19,14 @@ var ( fileRecordMutex sync.RWMutex ) +func removeWorkingDirectoryPrefix(path string) string { + wd := config.WorkingDirectory() + if len(path) > len(wd) && path[:len(wd)] == wd { + return path[len(wd)+1:] + } + return path +} + func recordFileRead(path string) { fileRecordMutex.Lock() defer fileRecordMutex.Unlock() diff --git a/internal/llm/tools/tools.go b/internal/llm/tools/tools.go index 6bb528686..473b787bb 100644 --- a/internal/llm/tools/tools.go +++ b/internal/llm/tools/tools.go @@ -17,6 +17,9 @@ type toolResponseType string const ( ToolResponseTypeText toolResponseType = "text" ToolResponseTypeImage toolResponseType = "image" + + SessionIDContextKey = "session_id" + MessageIDContextKey = "message_id" ) type ToolResponse struct { @@ -62,3 +65,15 @@ type BaseTool interface { Info() ToolInfo Run(ctx context.Context, params ToolCall) (ToolResponse, error) } + +func getContextValues(ctx context.Context) (string, string) { + sessionID := ctx.Value(SessionIDContextKey) + messageID := ctx.Value(MessageIDContextKey) + if sessionID == nil { + return "", "" + } + if messageID == nil { + return sessionID.(string), "" + } + return sessionID.(string), messageID.(string) +} diff --git a/internal/llm/tools/write.go b/internal/llm/tools/write.go index 7b698d2d8..27c98bb9d 100644 --- a/internal/llm/tools/write.go +++ b/internal/llm/tools/write.go @@ -9,6 +9,7 @@ import ( "time" "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/termai/internal/git" "github.com/kujtimiihoxha/termai/internal/lsp" "github.com/kujtimiihoxha/termai/internal/permission" ) @@ -20,7 +21,7 @@ type WriteParams struct { type WritePermissionsParams struct { FilePath string `json:"file_path"` - Content string `json:"content"` + Diff string `json:"diff"` } type writeTool struct { @@ -28,6 +29,11 @@ type writeTool struct { permissions permission.Service } +type WriteResponseMetadata struct { + Additions int `json:"additions"` + Removals int `json:"removals"` +} + const ( WriteToolName = "write" writeDescription = `File writing tool that creates or updates files in the filesystem, allowing you to save or modify text content. @@ -138,6 +144,18 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error } } + sessionID, messageID := getContextValues(ctx) + if sessionID == "" || messageID == "" { + return NewTextErrorResponse("session ID or message ID is missing"), nil + } + diff, stats, err := git.GenerateGitDiffWithStats( + removeWorkingDirectoryPrefix(filePath), + oldContent, + params.Content, + ) + if err != nil { + return NewTextErrorResponse(fmt.Sprintf("Failed to get file diff: %s", err)), nil + } p := w.permissions.Request( permission.CreatePermissionRequest{ Path: filePath, @@ -146,7 +164,7 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error Description: fmt.Sprintf("Create file %s", filePath), Params: WritePermissionsParams{ FilePath: filePath, - Content: GenerateDiff(oldContent, params.Content), + Diff: diff, }, }, ) @@ -166,5 +184,10 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error result := fmt.Sprintf("File successfully written: %s", filePath) result = fmt.Sprintf("<result>\n%s\n</result>", result) result += appendDiagnostics(filePath, w.lspClients) - return NewTextResponse(result), nil + return WithResponseMetadata(NewTextResponse(result), + WriteResponseMetadata{ + Additions: stats.Additions, + Removals: stats.Removals, + }, + ), nil } |
