diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/assets/diff/themes/dark.json | 73 | ||||
| -rw-r--r-- | internal/assets/embed.go | 6 | ||||
| -rw-r--r-- | internal/assets/write.go | 60 | ||||
| -rw-r--r-- | internal/git/diff.go | 265 | ||||
| -rw-r--r-- | internal/llm/agent/agent.go | 3 | ||||
| -rw-r--r-- | internal/llm/tools/bash.go | 4 | ||||
| -rw-r--r-- | internal/llm/tools/edit.go | 234 | ||||
| -rw-r--r-- | internal/llm/tools/edit_test.go | 48 | ||||
| -rw-r--r-- | internal/llm/tools/file.go | 10 | ||||
| -rw-r--r-- | internal/llm/tools/tools.go | 15 | ||||
| -rw-r--r-- | internal/llm/tools/write.go | 29 | ||||
| -rw-r--r-- | internal/tui/components/dialog/permission.go | 17 |
12 files changed, 593 insertions, 171 deletions
diff --git a/internal/assets/diff/themes/dark.json b/internal/assets/diff/themes/dark.json new file mode 100644 index 000000000..05c18e08c --- /dev/null +++ b/internal/assets/diff/themes/dark.json @@ -0,0 +1,73 @@ +{ + "SYNTAX_HIGHLIGHTING_THEME": "dark-plus", + "DEFAULT_COLOR": { + "color": "#ffffff", + "backgroundColor": "#212121" + }, + "COMMIT_HEADER_COLOR": { + "color": "#cccccc" + }, + "COMMIT_HEADER_LABEL_COLOR": { + "color": "#00000022" + }, + "COMMIT_SHA_COLOR": { + "color": "#00eeaa" + }, + "COMMIT_AUTHOR_COLOR": { + "color": "#00aaee" + }, + "COMMIT_DATE_COLOR": { + "color": "#cccccc" + }, + "COMMIT_MESSAGE_COLOR": { + "color": "#cccccc" + }, + "COMMIT_TITLE_COLOR": { + "modifiers": [ + "bold" + ] + }, + "FILE_NAME_COLOR": { + "color": "#ffdd99" + }, + "BORDER_COLOR": { + "color": "#ffdd9966", + "modifiers": [ + "dim" + ] + }, + "HUNK_HEADER_COLOR": { + "modifiers": [ + "dim" + ] + }, + "DELETED_WORD_COLOR": { + "color": "#ffcccc", + "backgroundColor": "#ff000033" + }, + "INSERTED_WORD_COLOR": { + "color": "#ccffcc", + "backgroundColor": "#00ff0033" + }, + "DELETED_LINE_NO_COLOR": { + "color": "#00000022", + "backgroundColor": "#00000022" + }, + "INSERTED_LINE_NO_COLOR": { + "color": "#00000022", + "backgroundColor": "#00000022" + }, + "UNMODIFIED_LINE_NO_COLOR": { + "color": "#666666" + }, + "DELETED_LINE_COLOR": { + "color": "#cc6666", + "backgroundColor": "#3a3030" + }, + "INSERTED_LINE_COLOR": { + "color": "#66cc66", + "backgroundColor": "#303a30" + }, + "UNMODIFIED_LINE_COLOR": {}, + "MISSING_LINE_COLOR": {} +} diff --git a/internal/assets/embed.go b/internal/assets/embed.go new file mode 100644 index 000000000..9e1316d08 --- /dev/null +++ b/internal/assets/embed.go @@ -0,0 +1,6 @@ +package assets + +import "embed" + +//go:embed diff +var FS embed.FS diff --git a/internal/assets/write.go b/internal/assets/write.go new file mode 100644 index 000000000..602b589ce --- /dev/null +++ b/internal/assets/write.go @@ -0,0 +1,60 @@ +package assets + +import ( + "os" + "path/filepath" + + "github.com/kujtimiihoxha/termai/internal/config" +) + +func WriteAssets() error { + appCfg := config.Get() + appWd := config.WorkingDirectory() + scriptDir := filepath.Join( + appWd, + appCfg.Data.Directory, + "diff", + ) + scriptPath := filepath.Join(scriptDir, "index.mjs") + // Before, run the script in cmd/diff/main.go to build this file + if _, err := os.Stat(scriptPath); err != nil { + scriptData, err := FS.ReadFile("diff/index.mjs") + if err != nil { + return err + } + + err = os.MkdirAll(scriptDir, 0o755) + if err != nil { + return err + } + err = os.WriteFile(scriptPath, scriptData, 0o755) + if err != nil { + return err + } + } + + themeDir := filepath.Join( + appWd, + appCfg.Data.Directory, + "themes", + ) + + themePath := filepath.Join(themeDir, "dark.json") + + if _, err := os.Stat(themePath); err != nil { + themeData, err := FS.ReadFile("diff/themes/dark.json") + if err != nil { + return err + } + + err = os.MkdirAll(themeDir, 0o755) + if err != nil { + return err + } + err = os.WriteFile(themePath, themeData, 0o755) + if err != nil { + return err + } + } + return nil +} diff --git a/internal/git/diff.go b/internal/git/diff.go new file mode 100644 index 000000000..d87956f01 --- /dev/null +++ b/internal/git/diff.go @@ -0,0 +1,265 @@ +package git + +import ( + "bytes" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "time" + + "github.com/go-git/go-git/v5" + "github.com/go-git/go-git/v5/plumbing/object" + "github.com/kujtimiihoxha/termai/internal/config" +) + +type DiffStats struct { + Additions int + Removals int +} + +func GenerateGitDiff(filePath string, contentBefore string, contentAfter string) (string, error) { + tempDir, err := os.MkdirTemp("", "git-diff-temp") + if err != nil { + return "", fmt.Errorf("failed to create temp dir: %w", err) + } + defer os.RemoveAll(tempDir) + + repo, err := git.PlainInit(tempDir, false) + if err != nil { + return "", fmt.Errorf("failed to initialize git repo: %w", err) + } + + wt, err := repo.Worktree() + if err != nil { + return "", fmt.Errorf("failed to get worktree: %w", err) + } + + fullPath := filepath.Join(tempDir, filePath) + if err = os.MkdirAll(filepath.Dir(fullPath), 0o755); err != nil { + return "", fmt.Errorf("failed to create directories: %w", err) + } + if err = os.WriteFile(fullPath, []byte(contentBefore), 0o644); err != nil { + return "", fmt.Errorf("failed to write 'before' content: %w", err) + } + + _, err = wt.Add(filePath) + if err != nil { + return "", fmt.Errorf("failed to add file to git: %w", err) + } + + beforeCommit, err := wt.Commit("Before", &git.CommitOptions{ + Author: &object.Signature{ + Name: "OpenCode", + Email: "[email protected]", + When: time.Now(), + }, + }) + if err != nil { + return "", fmt.Errorf("failed to commit 'before' version: %w", err) + } + + if err = os.WriteFile(fullPath, []byte(contentAfter), 0o644); err != nil { + return "", fmt.Errorf("failed to write 'after' content: %w", err) + } + + _, err = wt.Add(filePath) + if err != nil { + return "", fmt.Errorf("failed to add updated file to git: %w", err) + } + + afterCommit, err := wt.Commit("After", &git.CommitOptions{ + Author: &object.Signature{ + Name: "OpenCode", + Email: "[email protected]", + When: time.Now(), + }, + }) + if err != nil { + return "", fmt.Errorf("failed to commit 'after' version: %w", err) + } + + beforeCommitObj, err := repo.CommitObject(beforeCommit) + if err != nil { + return "", fmt.Errorf("failed to get 'before' commit: %w", err) + } + + afterCommitObj, err := repo.CommitObject(afterCommit) + if err != nil { + return "", fmt.Errorf("failed to get 'after' commit: %w", err) + } + + patch, err := beforeCommitObj.Patch(afterCommitObj) + if err != nil { + return "", fmt.Errorf("failed to generate patch: %w", err) + } + + return patch.String(), nil +} + +func GenerateGitDiffWithStats(filePath string, contentBefore string, contentAfter string) (string, DiffStats, error) { + tempDir, err := os.MkdirTemp("", "git-diff-temp") + if err != nil { + return "", DiffStats{}, fmt.Errorf("failed to create temp dir: %w", err) + } + defer os.RemoveAll(tempDir) + + repo, err := git.PlainInit(tempDir, false) + if err != nil { + return "", DiffStats{}, fmt.Errorf("failed to initialize git repo: %w", err) + } + + wt, err := repo.Worktree() + if err != nil { + return "", DiffStats{}, fmt.Errorf("failed to get worktree: %w", err) + } + + fullPath := filepath.Join(tempDir, filePath) + if err = os.MkdirAll(filepath.Dir(fullPath), 0o755); err != nil { + return "", DiffStats{}, fmt.Errorf("failed to create directories: %w", err) + } + if err = os.WriteFile(fullPath, []byte(contentBefore), 0o644); err != nil { + return "", DiffStats{}, fmt.Errorf("failed to write 'before' content: %w", err) + } + + _, err = wt.Add(filePath) + if err != nil { + return "", DiffStats{}, fmt.Errorf("failed to add file to git: %w", err) + } + + beforeCommit, err := wt.Commit("Before", &git.CommitOptions{ + Author: &object.Signature{ + Name: "OpenCode", + Email: "[email protected]", + When: time.Now(), + }, + }) + if err != nil { + return "", DiffStats{}, fmt.Errorf("failed to commit 'before' version: %w", err) + } + + if err = os.WriteFile(fullPath, []byte(contentAfter), 0o644); err != nil { + return "", DiffStats{}, fmt.Errorf("failed to write 'after' content: %w", err) + } + + _, err = wt.Add(filePath) + if err != nil { + return "", DiffStats{}, fmt.Errorf("failed to add updated file to git: %w", err) + } + + afterCommit, err := wt.Commit("After", &git.CommitOptions{ + Author: &object.Signature{ + Name: "OpenCode", + Email: "[email protected]", + When: time.Now(), + }, + }) + if err != nil { + return "", DiffStats{}, fmt.Errorf("failed to commit 'after' version: %w", err) + } + + beforeCommitObj, err := repo.CommitObject(beforeCommit) + if err != nil { + return "", DiffStats{}, fmt.Errorf("failed to get 'before' commit: %w", err) + } + + afterCommitObj, err := repo.CommitObject(afterCommit) + if err != nil { + return "", DiffStats{}, fmt.Errorf("failed to get 'after' commit: %w", err) + } + + patch, err := beforeCommitObj.Patch(afterCommitObj) + if err != nil { + return "", DiffStats{}, fmt.Errorf("failed to generate patch: %w", err) + } + + stats := DiffStats{} + for _, fileStat := range patch.Stats() { + stats.Additions += fileStat.Addition + stats.Removals += fileStat.Deletion + } + + return patch.String(), stats, nil +} + +func FormatDiff(diffText string, width int) (string, error) { + if isSplitDiffsAvailable() { + return formatWithSplitDiffs(diffText, width) + } + + return formatSimple(diffText), nil +} + +func isSplitDiffsAvailable() bool { + _, err := exec.LookPath("node") + return err == nil +} + +func formatWithSplitDiffs(diffText string, width int) (string, error) { + var cmd *exec.Cmd + + appCfg := config.Get() + appWd := config.WorkingDirectory() + script := filepath.Join( + appWd, + appCfg.Data.Directory, + "diff", + "index.mjs", + ) + + cmd = exec.Command("node", script, "--color") + + cmd.Env = append(os.Environ(), fmt.Sprintf("COLUMNS=%d", width)) + + cmd.Stdin = strings.NewReader(diffText) + + var out bytes.Buffer + cmd.Stdout = &out + + var stderr bytes.Buffer + cmd.Stderr = &stderr + + err := cmd.Run() + if err != nil { + return "", fmt.Errorf("git-split-diffs error: %v, stderr: %s", err, stderr.String()) + } + + return out.String(), nil +} + +func formatSimple(diffText string) string { + lines := strings.Split(diffText, "\n") + var result strings.Builder + + for _, line := range lines { + if len(line) == 0 { + result.WriteString("\n") + continue + } + + switch line[0] { + case '+': + result.WriteString("\033[32m" + line + "\033[0m\n") + case '-': + result.WriteString("\033[31m" + line + "\033[0m\n") + case '@': + result.WriteString("\033[36m" + line + "\033[0m\n") + case 'd': + if strings.HasPrefix(line, "diff --git") { + result.WriteString("\033[1m" + line + "\033[0m\n") + } else { + result.WriteString(line + "\n") + } + default: + result.WriteString(line + "\n") + } + } + + if !strings.HasSuffix(diffText, "\n") { + output := result.String() + return output[:len(output)-1] + } + + return result.String() +} diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index b01ffec3c..89de627f7 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -246,6 +246,7 @@ func (c *agent) handleToolExecution( } func (c *agent) generate(ctx context.Context, sessionID string, content string) error { + ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID) messages, err := c.Messages.List(sessionID) if err != nil { return err @@ -310,6 +311,8 @@ func (c *agent) generate(ctx context.Context, sessionID string, content string) if err != nil { return err } + + ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID) for event := range eventChan { err = c.processEvent(sessionID, &assistantMsg, event) if err != nil { diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index d20afb7f2..d55cb241b 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -22,7 +22,7 @@ type BashPermissionsParams struct { Timeout int `json:"timeout"` } -type BashToolResponseMetadata struct { +type BashResponseMetadata struct { Took int64 `json:"took"` } type bashTool struct { @@ -310,7 +310,7 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) stdout += "\n" + errorMessage } - metadata := BashToolResponseMetadata{ + metadata := BashResponseMetadata{ Took: took, } if stdout == "" { diff --git a/internal/llm/tools/edit.go b/internal/llm/tools/edit.go index 32e2034e4..c9a0be079 100644 --- a/internal/llm/tools/edit.go +++ b/internal/llm/tools/edit.go @@ -10,9 +10,9 @@ import ( "time" "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/termai/internal/git" "github.com/kujtimiihoxha/termai/internal/lsp" "github.com/kujtimiihoxha/termai/internal/permission" - "github.com/sergi/go-diff/diffmatchpatch" ) type EditParams struct { @@ -22,10 +22,13 @@ type EditParams struct { } type EditPermissionsParams struct { - FilePath string `json:"file_path"` - OldString string `json:"old_string"` - NewString string `json:"new_string"` - Diff string `json:"diff"` + FilePath string `json:"file_path"` + Diff string `json:"diff"` +} + +type EditResponseMetadata struct { + Additions int `json:"additions"` + Removals int `json:"removals"` } type editTool struct { @@ -129,48 +132,77 @@ func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) } if params.OldString == "" { - result, err := e.createNewFile(params.FilePath, params.NewString) + result, err := e.createNewFile(ctx, params.FilePath, params.NewString) if err != nil { return NewTextErrorResponse(fmt.Sprintf("error creating file: %s", err)), nil } - return NewTextResponse(result), nil + return WithResponseMetadata(NewTextResponse(result.text), EditResponseMetadata{ + Additions: result.additions, + Removals: result.removals, + }), nil } if params.NewString == "" { - result, err := e.deleteContent(params.FilePath, params.OldString) + result, err := e.deleteContent(ctx, params.FilePath, params.OldString) if err != nil { return NewTextErrorResponse(fmt.Sprintf("error deleting content: %s", err)), nil } - return NewTextResponse(result), nil + return WithResponseMetadata(NewTextResponse(result.text), EditResponseMetadata{ + Additions: result.additions, + Removals: result.removals, + }), nil } - result, err := e.replaceContent(params.FilePath, params.OldString, params.NewString) + result, err := e.replaceContent(ctx, params.FilePath, params.OldString, params.NewString) if err != nil { return NewTextErrorResponse(fmt.Sprintf("error replacing content: %s", err)), nil } waitForLspDiagnostics(ctx, params.FilePath, e.lspClients) - result = fmt.Sprintf("<result>\n%s\n</result>\n", result) - result += appendDiagnostics(params.FilePath, e.lspClients) - return NewTextResponse(result), nil + text := fmt.Sprintf("<result>\n%s\n</result>\n", result.text) + text += appendDiagnostics(params.FilePath, e.lspClients) + return WithResponseMetadata(NewTextResponse(text), EditResponseMetadata{ + Additions: result.additions, + Removals: result.removals, + }), nil +} + +type editResponse struct { + text string + additions int + removals int } -func (e *editTool) createNewFile(filePath, content string) (string, error) { +func (e *editTool) createNewFile(ctx context.Context, filePath, content string) (editResponse, error) { + er := editResponse{} fileInfo, err := os.Stat(filePath) if err == nil { if fileInfo.IsDir() { - return "", fmt.Errorf("path is a directory, not a file: %s", filePath) + return er, fmt.Errorf("path is a directory, not a file: %s", filePath) } - return "", fmt.Errorf("file already exists: %s. Use the Replace tool to overwrite an existing file", filePath) + return er, fmt.Errorf("file already exists: %s. Use the Replace tool to overwrite an existing file", filePath) } else if !os.IsNotExist(err) { - return "", fmt.Errorf("failed to access file: %w", err) + return er, fmt.Errorf("failed to access file: %w", err) } dir := filepath.Dir(filePath) if err = os.MkdirAll(dir, 0o755); err != nil { - return "", fmt.Errorf("failed to create parent directories: %w", err) + return er, fmt.Errorf("failed to create parent directories: %w", err) } + sessionID, messageID := getContextValues(ctx) + if sessionID == "" || messageID == "" { + return er, fmt.Errorf("session ID and message ID are required for creating a new file") + } + + diff, stats, err := git.GenerateGitDiffWithStats( + removeWorkingDirectoryPrefix(filePath), + "", + content, + ) + if err != nil { + return er, fmt.Errorf("failed to get file diff: %w", err) + } p := e.permissions.Request( permission.CreatePermissionRequest{ Path: filepath.Dir(filePath), @@ -178,71 +210,88 @@ func (e *editTool) createNewFile(filePath, content string) (string, error) { Action: "create", Description: fmt.Sprintf("Create file %s", filePath), Params: EditPermissionsParams{ - FilePath: filePath, - OldString: "", - NewString: content, - Diff: GenerateDiff("", content), + FilePath: filePath, + Diff: diff, }, }, ) if !p { - return "", fmt.Errorf("permission denied") + return er, fmt.Errorf("permission denied") } err = os.WriteFile(filePath, []byte(content), 0o644) if err != nil { - return "", fmt.Errorf("failed to write file: %w", err) + return er, fmt.Errorf("failed to write file: %w", err) } recordFileWrite(filePath) recordFileRead(filePath) - return "File created: " + filePath, nil + er.text = "File created: " + filePath + er.additions = stats.Additions + er.removals = stats.Removals + return er, nil } -func (e *editTool) deleteContent(filePath, oldString string) (string, error) { +func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string) (editResponse, error) { + er := editResponse{} fileInfo, err := os.Stat(filePath) if err != nil { if os.IsNotExist(err) { - return "", fmt.Errorf("file not found: %s", filePath) + return er, fmt.Errorf("file not found: %s", filePath) } - return "", fmt.Errorf("failed to access file: %w", err) + return er, fmt.Errorf("failed to access file: %w", err) } if fileInfo.IsDir() { - return "", fmt.Errorf("path is a directory, not a file: %s", filePath) + return er, fmt.Errorf("path is a directory, not a file: %s", filePath) } if getLastReadTime(filePath).IsZero() { - return "", fmt.Errorf("you must read the file before editing it. Use the View tool first") + return er, fmt.Errorf("you must read the file before editing it. Use the View tool first") } modTime := fileInfo.ModTime() lastRead := getLastReadTime(filePath) if modTime.After(lastRead) { - return "", fmt.Errorf("file %s has been modified since it was last read (mod time: %s, last read: %s)", + return er, fmt.Errorf("file %s has been modified since it was last read (mod time: %s, last read: %s)", filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339)) } content, err := os.ReadFile(filePath) if err != nil { - return "", fmt.Errorf("failed to read file: %w", err) + return er, fmt.Errorf("failed to read file: %w", err) } oldContent := string(content) index := strings.Index(oldContent, oldString) if index == -1 { - return "", fmt.Errorf("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks") + return er, fmt.Errorf("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks") } lastIndex := strings.LastIndex(oldContent, oldString) if index != lastIndex { - return "", fmt.Errorf("old_string appears multiple times in the file. Please provide more context to ensure a unique match") + return er, fmt.Errorf("old_string appears multiple times in the file. Please provide more context to ensure a unique match") } newContent := oldContent[:index] + oldContent[index+len(oldString):] + sessionID, messageID := getContextValues(ctx) + + if sessionID == "" || messageID == "" { + return er, fmt.Errorf("session ID and message ID are required for creating a new file") + } + + diff, stats, err := git.GenerateGitDiffWithStats( + removeWorkingDirectoryPrefix(filePath), + oldContent, + newContent, + ) + if err != nil { + return er, fmt.Errorf("failed to get file diff: %w", err) + } + p := e.permissions.Request( permission.CreatePermissionRequest{ Path: filepath.Dir(filePath), @@ -250,76 +299,85 @@ func (e *editTool) deleteContent(filePath, oldString string) (string, error) { Action: "delete", Description: fmt.Sprintf("Delete content from file %s", filePath), Params: EditPermissionsParams{ - FilePath: filePath, - OldString: oldString, - NewString: "", - Diff: GenerateDiff(oldContent, newContent), + FilePath: filePath, + Diff: diff, }, }, ) if !p { - return "", fmt.Errorf("permission denied") + return er, fmt.Errorf("permission denied") } err = os.WriteFile(filePath, []byte(newContent), 0o644) if err != nil { - return "", fmt.Errorf("failed to write file: %w", err) + return er, fmt.Errorf("failed to write file: %w", err) } - recordFileWrite(filePath) recordFileRead(filePath) - return "Content deleted from file: " + filePath, nil + er.text = "Content deleted from file: " + filePath + er.additions = stats.Additions + er.removals = stats.Removals + return er, nil } -func (e *editTool) replaceContent(filePath, oldString, newString string) (string, error) { +func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newString string) (editResponse, error) { + er := editResponse{} fileInfo, err := os.Stat(filePath) if err != nil { if os.IsNotExist(err) { - return "", fmt.Errorf("file not found: %s", filePath) + return er, fmt.Errorf("file not found: %s", filePath) } - return "", fmt.Errorf("failed to access file: %w", err) + return er, fmt.Errorf("failed to access file: %w", err) } if fileInfo.IsDir() { - return "", fmt.Errorf("path is a directory, not a file: %s", filePath) + return er, fmt.Errorf("path is a directory, not a file: %s", filePath) } if getLastReadTime(filePath).IsZero() { - return "", fmt.Errorf("you must read the file before editing it. Use the View tool first") + return er, fmt.Errorf("you must read the file before editing it. Use the View tool first") } modTime := fileInfo.ModTime() lastRead := getLastReadTime(filePath) if modTime.After(lastRead) { - return "", fmt.Errorf("file %s has been modified since it was last read (mod time: %s, last read: %s)", + return er, fmt.Errorf("file %s has been modified since it was last read (mod time: %s, last read: %s)", filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339)) } content, err := os.ReadFile(filePath) if err != nil { - return "", fmt.Errorf("failed to read file: %w", err) + return er, fmt.Errorf("failed to read file: %w", err) } oldContent := string(content) index := strings.Index(oldContent, oldString) if index == -1 { - return "", fmt.Errorf("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks") + return er, fmt.Errorf("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks") } lastIndex := strings.LastIndex(oldContent, oldString) if index != lastIndex { - return "", fmt.Errorf("old_string appears multiple times in the file. Please provide more context to ensure a unique match") + return er, fmt.Errorf("old_string appears multiple times in the file. Please provide more context to ensure a unique match") } newContent := oldContent[:index] + newString + oldContent[index+len(oldString):] - startIndex := max(0, index-3) - oldEndIndex := min(len(oldContent), index+len(oldString)+3) - newEndIndex := min(len(newContent), index+len(newString)+3) + sessionID, messageID := getContextValues(ctx) - diff := GenerateDiff(oldContent[startIndex:oldEndIndex], newContent[startIndex:newEndIndex]) + if sessionID == "" || messageID == "" { + return er, fmt.Errorf("session ID and message ID are required for creating a new file") + } + diff, stats, err := git.GenerateGitDiffWithStats( + removeWorkingDirectoryPrefix(filePath), + oldContent, + newContent, + ) + if err != nil { + return er, fmt.Errorf("failed to get file diff: %w", err) + } p := e.permissions.Request( permission.CreatePermissionRequest{ @@ -328,75 +386,27 @@ func (e *editTool) replaceContent(filePath, oldString, newString string) (string Action: "replace", Description: fmt.Sprintf("Replace content in file %s", filePath), Params: EditPermissionsParams{ - FilePath: filePath, - OldString: oldString, - NewString: newString, - Diff: diff, + FilePath: filePath, + + Diff: diff, }, }, ) if !p { - return "", fmt.Errorf("permission denied") + return er, fmt.Errorf("permission denied") } err = os.WriteFile(filePath, []byte(newContent), 0o644) if err != nil { - return "", fmt.Errorf("failed to write file: %w", err) + return er, fmt.Errorf("failed to write file: %w", err) } recordFileWrite(filePath) recordFileRead(filePath) + er.text = "Content replaced in file: " + filePath + er.additions = stats.Additions + er.removals = stats.Removals - return "Content replaced in file: " + filePath, nil + return er, nil } -func GenerateDiff(oldContent, newContent string) string { - dmp := diffmatchpatch.New() - fileAdmp, fileBdmp, dmpStrings := dmp.DiffLinesToChars(oldContent, newContent) - diffs := dmp.DiffMain(fileAdmp, fileBdmp, false) - diffs = dmp.DiffCharsToLines(diffs, dmpStrings) - diffs = dmp.DiffCleanupSemantic(diffs) - buff := strings.Builder{} - - buff.WriteString("Changes:\n") - - for _, diff := range diffs { - text := diff.Text - - switch diff.Type { - case diffmatchpatch.DiffInsert: - for line := range strings.SplitSeq(text, "\n") { - if line == "" { - continue - } - _, _ = buff.WriteString("+ " + line + "\n") - } - case diffmatchpatch.DiffDelete: - for line := range strings.SplitSeq(text, "\n") { - if line == "" { - continue - } - _, _ = buff.WriteString("- " + line + "\n") - } - case diffmatchpatch.DiffEqual: - lines := strings.Split(text, "\n") - if len(lines) > 3 { - if lines[0] != "" { - _, _ = buff.WriteString(" " + lines[0] + "\n") - } - _, _ = buff.WriteString(" ...\n") - if lines[len(lines)-1] != "" { - _, _ = buff.WriteString(" " + lines[len(lines)-1] + "\n") - } - } else { - for _, line := range lines { - if line == "" { - continue - } - _, _ = buff.WriteString(" " + line + "\n") - } - } - } - } - return buff.String() -} diff --git a/internal/llm/tools/edit_test.go b/internal/llm/tools/edit_test.go index dbc6e488f..48a34ed75 100644 --- a/internal/llm/tools/edit_test.go +++ b/internal/llm/tools/edit_test.go @@ -459,51 +459,3 @@ func TestEditTool_Run(t *testing.T) { assert.Equal(t, initialContent, string(fileContent)) }) } - -func TestGenerateDiff(t *testing.T) { - testCases := []struct { - name string - oldContent string - newContent string - expectedDiff string - }{ - { - name: "add content", - oldContent: "Line 1\nLine 2\n", - newContent: "Line 1\nLine 2\nLine 3\n", - expectedDiff: "Changes:\n Line 1\n Line 2\n+ Line 3\n", - }, - { - name: "remove content", - oldContent: "Line 1\nLine 2\nLine 3\n", - newContent: "Line 1\nLine 3\n", - expectedDiff: "Changes:\n Line 1\n- Line 2\n Line 3\n", - }, - { - name: "replace content", - oldContent: "Line 1\nLine 2\nLine 3\n", - newContent: "Line 1\nModified Line\nLine 3\n", - expectedDiff: "Changes:\n Line 1\n- Line 2\n+ Modified Line\n Line 3\n", - }, - { - name: "empty to content", - oldContent: "", - newContent: "Line 1\nLine 2\n", - expectedDiff: "Changes:\n+ Line 1\n+ Line 2\n", - }, - { - name: "content to empty", - oldContent: "Line 1\nLine 2\n", - newContent: "", - expectedDiff: "Changes:\n- Line 1\n- Line 2\n", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - diff := GenerateDiff(tc.oldContent, tc.newContent) - assert.Contains(t, diff, tc.expectedDiff) - }) - } -} - diff --git a/internal/llm/tools/file.go b/internal/llm/tools/file.go index 7f34fdc1f..9c9707c9c 100644 --- a/internal/llm/tools/file.go +++ b/internal/llm/tools/file.go @@ -3,6 +3,8 @@ package tools import ( "sync" "time" + + "github.com/kujtimiihoxha/termai/internal/config" ) // File record to track when files were read/written @@ -17,6 +19,14 @@ var ( fileRecordMutex sync.RWMutex ) +func removeWorkingDirectoryPrefix(path string) string { + wd := config.WorkingDirectory() + if len(path) > len(wd) && path[:len(wd)] == wd { + return path[len(wd)+1:] + } + return path +} + func recordFileRead(path string) { fileRecordMutex.Lock() defer fileRecordMutex.Unlock() diff --git a/internal/llm/tools/tools.go b/internal/llm/tools/tools.go index 6bb528686..473b787bb 100644 --- a/internal/llm/tools/tools.go +++ b/internal/llm/tools/tools.go @@ -17,6 +17,9 @@ type toolResponseType string const ( ToolResponseTypeText toolResponseType = "text" ToolResponseTypeImage toolResponseType = "image" + + SessionIDContextKey = "session_id" + MessageIDContextKey = "message_id" ) type ToolResponse struct { @@ -62,3 +65,15 @@ type BaseTool interface { Info() ToolInfo Run(ctx context.Context, params ToolCall) (ToolResponse, error) } + +func getContextValues(ctx context.Context) (string, string) { + sessionID := ctx.Value(SessionIDContextKey) + messageID := ctx.Value(MessageIDContextKey) + if sessionID == nil { + return "", "" + } + if messageID == nil { + return sessionID.(string), "" + } + return sessionID.(string), messageID.(string) +} diff --git a/internal/llm/tools/write.go b/internal/llm/tools/write.go index 7b698d2d8..27c98bb9d 100644 --- a/internal/llm/tools/write.go +++ b/internal/llm/tools/write.go @@ -9,6 +9,7 @@ import ( "time" "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/termai/internal/git" "github.com/kujtimiihoxha/termai/internal/lsp" "github.com/kujtimiihoxha/termai/internal/permission" ) @@ -20,7 +21,7 @@ type WriteParams struct { type WritePermissionsParams struct { FilePath string `json:"file_path"` - Content string `json:"content"` + Diff string `json:"diff"` } type writeTool struct { @@ -28,6 +29,11 @@ type writeTool struct { permissions permission.Service } +type WriteResponseMetadata struct { + Additions int `json:"additions"` + Removals int `json:"removals"` +} + const ( WriteToolName = "write" writeDescription = `File writing tool that creates or updates files in the filesystem, allowing you to save or modify text content. @@ -138,6 +144,18 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error } } + sessionID, messageID := getContextValues(ctx) + if sessionID == "" || messageID == "" { + return NewTextErrorResponse("session ID or message ID is missing"), nil + } + diff, stats, err := git.GenerateGitDiffWithStats( + removeWorkingDirectoryPrefix(filePath), + oldContent, + params.Content, + ) + if err != nil { + return NewTextErrorResponse(fmt.Sprintf("Failed to get file diff: %s", err)), nil + } p := w.permissions.Request( permission.CreatePermissionRequest{ Path: filePath, @@ -146,7 +164,7 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error Description: fmt.Sprintf("Create file %s", filePath), Params: WritePermissionsParams{ FilePath: filePath, - Content: GenerateDiff(oldContent, params.Content), + Diff: diff, }, }, ) @@ -166,5 +184,10 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error result := fmt.Sprintf("File successfully written: %s", filePath) result = fmt.Sprintf("<result>\n%s\n</result>", result) result += appendDiagnostics(filePath, w.lspClients) - return NewTextResponse(result), nil + return WithResponseMetadata(NewTextResponse(result), + WriteResponseMetadata{ + Additions: stats.Additions, + Removals: stats.Removals, + }, + ), nil } diff --git a/internal/tui/components/dialog/permission.go b/internal/tui/components/dialog/permission.go index 088697d55..344310eb6 100644 --- a/internal/tui/components/dialog/permission.go +++ b/internal/tui/components/dialog/permission.go @@ -9,6 +9,7 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/glamour" "github.com/charmbracelet/lipgloss" + "github.com/kujtimiihoxha/termai/internal/git" "github.com/kujtimiihoxha/termai/internal/llm/tools" "github.com/kujtimiihoxha/termai/internal/permission" "github.com/kujtimiihoxha/termai/internal/tui/components/core" @@ -234,7 +235,6 @@ func (p *permissionDialogCmp) render() string { headerContent = lipgloss.NewStyle().Padding(0, 1).Render(lipgloss.JoinVertical(lipgloss.Left, headerParts...)) // Format the diff with colors - formattedDiff := formatDiff(pr.Diff) // Set up viewport for the diff content p.contentViewPort.Width = p.width - 2 - 2 @@ -242,7 +242,11 @@ func (p *permissionDialogCmp) render() string { // Calculate content height dynamically based on window size maxContentHeight := p.height - lipgloss.Height(headerContent) - lipgloss.Height(form) - 2 - 2 - 1 p.contentViewPort.Height = maxContentHeight - p.contentViewPort.SetContent(formattedDiff) + diff, err := git.FormatDiff(pr.Diff, p.contentViewPort.Width) + if err != nil { + diff = fmt.Sprintf("Error formatting diff: %v", err) + } + p.contentViewPort.SetContent(diff) // Style the viewport var contentBorder lipgloss.Border @@ -281,16 +285,17 @@ func (p *permissionDialogCmp) render() string { // Recreate header content with the updated headerParts headerContent = lipgloss.NewStyle().Padding(0, 1).Render(lipgloss.JoinVertical(lipgloss.Left, headerParts...)) - // Format the diff with colors - formattedDiff := formatDiff(pr.Content) - // Set up viewport for the content p.contentViewPort.Width = p.width - 2 - 2 // Calculate content height dynamically based on window size maxContentHeight := p.height - lipgloss.Height(headerContent) - lipgloss.Height(form) - 2 - 2 - 1 p.contentViewPort.Height = maxContentHeight - p.contentViewPort.SetContent(formattedDiff) + diff, err := git.FormatDiff(pr.Diff, p.contentViewPort.Width) + if err != nil { + diff = fmt.Sprintf("Error formatting diff: %v", err) + } + p.contentViewPort.SetContent(diff) // Style the viewport var contentBorder lipgloss.Border |
