summaryrefslogtreecommitdiffhomepage
path: root/internal/llm/tools
diff options
context:
space:
mode:
Diffstat (limited to 'internal/llm/tools')
-rw-r--r--internal/llm/tools/bash.go4
-rw-r--r--internal/llm/tools/edit.go234
-rw-r--r--internal/llm/tools/edit_test.go48
-rw-r--r--internal/llm/tools/file.go10
-rw-r--r--internal/llm/tools/tools.go15
-rw-r--r--internal/llm/tools/write.go29
6 files changed, 175 insertions, 165 deletions
diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go
index d20afb7f2..d55cb241b 100644
--- a/internal/llm/tools/bash.go
+++ b/internal/llm/tools/bash.go
@@ -22,7 +22,7 @@ type BashPermissionsParams struct {
Timeout int `json:"timeout"`
}
-type BashToolResponseMetadata struct {
+type BashResponseMetadata struct {
Took int64 `json:"took"`
}
type bashTool struct {
@@ -310,7 +310,7 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
stdout += "\n" + errorMessage
}
- metadata := BashToolResponseMetadata{
+ metadata := BashResponseMetadata{
Took: took,
}
if stdout == "" {
diff --git a/internal/llm/tools/edit.go b/internal/llm/tools/edit.go
index 32e2034e4..c9a0be079 100644
--- a/internal/llm/tools/edit.go
+++ b/internal/llm/tools/edit.go
@@ -10,9 +10,9 @@ import (
"time"
"github.com/kujtimiihoxha/termai/internal/config"
+ "github.com/kujtimiihoxha/termai/internal/git"
"github.com/kujtimiihoxha/termai/internal/lsp"
"github.com/kujtimiihoxha/termai/internal/permission"
- "github.com/sergi/go-diff/diffmatchpatch"
)
type EditParams struct {
@@ -22,10 +22,13 @@ type EditParams struct {
}
type EditPermissionsParams struct {
- FilePath string `json:"file_path"`
- OldString string `json:"old_string"`
- NewString string `json:"new_string"`
- Diff string `json:"diff"`
+ FilePath string `json:"file_path"`
+ Diff string `json:"diff"`
+}
+
+type EditResponseMetadata struct {
+ Additions int `json:"additions"`
+ Removals int `json:"removals"`
}
type editTool struct {
@@ -129,48 +132,77 @@ func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
}
if params.OldString == "" {
- result, err := e.createNewFile(params.FilePath, params.NewString)
+ result, err := e.createNewFile(ctx, params.FilePath, params.NewString)
if err != nil {
return NewTextErrorResponse(fmt.Sprintf("error creating file: %s", err)), nil
}
- return NewTextResponse(result), nil
+ return WithResponseMetadata(NewTextResponse(result.text), EditResponseMetadata{
+ Additions: result.additions,
+ Removals: result.removals,
+ }), nil
}
if params.NewString == "" {
- result, err := e.deleteContent(params.FilePath, params.OldString)
+ result, err := e.deleteContent(ctx, params.FilePath, params.OldString)
if err != nil {
return NewTextErrorResponse(fmt.Sprintf("error deleting content: %s", err)), nil
}
- return NewTextResponse(result), nil
+ return WithResponseMetadata(NewTextResponse(result.text), EditResponseMetadata{
+ Additions: result.additions,
+ Removals: result.removals,
+ }), nil
}
- result, err := e.replaceContent(params.FilePath, params.OldString, params.NewString)
+ result, err := e.replaceContent(ctx, params.FilePath, params.OldString, params.NewString)
if err != nil {
return NewTextErrorResponse(fmt.Sprintf("error replacing content: %s", err)), nil
}
waitForLspDiagnostics(ctx, params.FilePath, e.lspClients)
- result = fmt.Sprintf("<result>\n%s\n</result>\n", result)
- result += appendDiagnostics(params.FilePath, e.lspClients)
- return NewTextResponse(result), nil
+ text := fmt.Sprintf("<result>\n%s\n</result>\n", result.text)
+ text += appendDiagnostics(params.FilePath, e.lspClients)
+ return WithResponseMetadata(NewTextResponse(text), EditResponseMetadata{
+ Additions: result.additions,
+ Removals: result.removals,
+ }), nil
+}
+
+type editResponse struct {
+ text string
+ additions int
+ removals int
}
-func (e *editTool) createNewFile(filePath, content string) (string, error) {
+func (e *editTool) createNewFile(ctx context.Context, filePath, content string) (editResponse, error) {
+ er := editResponse{}
fileInfo, err := os.Stat(filePath)
if err == nil {
if fileInfo.IsDir() {
- return "", fmt.Errorf("path is a directory, not a file: %s", filePath)
+ return er, fmt.Errorf("path is a directory, not a file: %s", filePath)
}
- return "", fmt.Errorf("file already exists: %s. Use the Replace tool to overwrite an existing file", filePath)
+ return er, fmt.Errorf("file already exists: %s. Use the Replace tool to overwrite an existing file", filePath)
} else if !os.IsNotExist(err) {
- return "", fmt.Errorf("failed to access file: %w", err)
+ return er, fmt.Errorf("failed to access file: %w", err)
}
dir := filepath.Dir(filePath)
if err = os.MkdirAll(dir, 0o755); err != nil {
- return "", fmt.Errorf("failed to create parent directories: %w", err)
+ return er, fmt.Errorf("failed to create parent directories: %w", err)
}
+ sessionID, messageID := getContextValues(ctx)
+ if sessionID == "" || messageID == "" {
+ return er, fmt.Errorf("session ID and message ID are required for creating a new file")
+ }
+
+ diff, stats, err := git.GenerateGitDiffWithStats(
+ removeWorkingDirectoryPrefix(filePath),
+ "",
+ content,
+ )
+ if err != nil {
+ return er, fmt.Errorf("failed to get file diff: %w", err)
+ }
p := e.permissions.Request(
permission.CreatePermissionRequest{
Path: filepath.Dir(filePath),
@@ -178,71 +210,88 @@ func (e *editTool) createNewFile(filePath, content string) (string, error) {
Action: "create",
Description: fmt.Sprintf("Create file %s", filePath),
Params: EditPermissionsParams{
- FilePath: filePath,
- OldString: "",
- NewString: content,
- Diff: GenerateDiff("", content),
+ FilePath: filePath,
+ Diff: diff,
},
},
)
if !p {
- return "", fmt.Errorf("permission denied")
+ return er, fmt.Errorf("permission denied")
}
err = os.WriteFile(filePath, []byte(content), 0o644)
if err != nil {
- return "", fmt.Errorf("failed to write file: %w", err)
+ return er, fmt.Errorf("failed to write file: %w", err)
}
recordFileWrite(filePath)
recordFileRead(filePath)
- return "File created: " + filePath, nil
+ er.text = "File created: " + filePath
+ er.additions = stats.Additions
+ er.removals = stats.Removals
+ return er, nil
}
-func (e *editTool) deleteContent(filePath, oldString string) (string, error) {
+func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string) (editResponse, error) {
+ er := editResponse{}
fileInfo, err := os.Stat(filePath)
if err != nil {
if os.IsNotExist(err) {
- return "", fmt.Errorf("file not found: %s", filePath)
+ return er, fmt.Errorf("file not found: %s", filePath)
}
- return "", fmt.Errorf("failed to access file: %w", err)
+ return er, fmt.Errorf("failed to access file: %w", err)
}
if fileInfo.IsDir() {
- return "", fmt.Errorf("path is a directory, not a file: %s", filePath)
+ return er, fmt.Errorf("path is a directory, not a file: %s", filePath)
}
if getLastReadTime(filePath).IsZero() {
- return "", fmt.Errorf("you must read the file before editing it. Use the View tool first")
+ return er, fmt.Errorf("you must read the file before editing it. Use the View tool first")
}
modTime := fileInfo.ModTime()
lastRead := getLastReadTime(filePath)
if modTime.After(lastRead) {
- return "", fmt.Errorf("file %s has been modified since it was last read (mod time: %s, last read: %s)",
+ return er, fmt.Errorf("file %s has been modified since it was last read (mod time: %s, last read: %s)",
filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339))
}
content, err := os.ReadFile(filePath)
if err != nil {
- return "", fmt.Errorf("failed to read file: %w", err)
+ return er, fmt.Errorf("failed to read file: %w", err)
}
oldContent := string(content)
index := strings.Index(oldContent, oldString)
if index == -1 {
- return "", fmt.Errorf("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks")
+ return er, fmt.Errorf("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks")
}
lastIndex := strings.LastIndex(oldContent, oldString)
if index != lastIndex {
- return "", fmt.Errorf("old_string appears multiple times in the file. Please provide more context to ensure a unique match")
+ return er, fmt.Errorf("old_string appears multiple times in the file. Please provide more context to ensure a unique match")
}
newContent := oldContent[:index] + oldContent[index+len(oldString):]
+ sessionID, messageID := getContextValues(ctx)
+
+ if sessionID == "" || messageID == "" {
+ return er, fmt.Errorf("session ID and message ID are required for creating a new file")
+ }
+
+ diff, stats, err := git.GenerateGitDiffWithStats(
+ removeWorkingDirectoryPrefix(filePath),
+ oldContent,
+ newContent,
+ )
+ if err != nil {
+ return er, fmt.Errorf("failed to get file diff: %w", err)
+ }
+
p := e.permissions.Request(
permission.CreatePermissionRequest{
Path: filepath.Dir(filePath),
@@ -250,76 +299,85 @@ func (e *editTool) deleteContent(filePath, oldString string) (string, error) {
Action: "delete",
Description: fmt.Sprintf("Delete content from file %s", filePath),
Params: EditPermissionsParams{
- FilePath: filePath,
- OldString: oldString,
- NewString: "",
- Diff: GenerateDiff(oldContent, newContent),
+ FilePath: filePath,
+ Diff: diff,
},
},
)
if !p {
- return "", fmt.Errorf("permission denied")
+ return er, fmt.Errorf("permission denied")
}
err = os.WriteFile(filePath, []byte(newContent), 0o644)
if err != nil {
- return "", fmt.Errorf("failed to write file: %w", err)
+ return er, fmt.Errorf("failed to write file: %w", err)
}
-
recordFileWrite(filePath)
recordFileRead(filePath)
- return "Content deleted from file: " + filePath, nil
+ er.text = "Content deleted from file: " + filePath
+ er.additions = stats.Additions
+ er.removals = stats.Removals
+ return er, nil
}
-func (e *editTool) replaceContent(filePath, oldString, newString string) (string, error) {
+func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newString string) (editResponse, error) {
+ er := editResponse{}
fileInfo, err := os.Stat(filePath)
if err != nil {
if os.IsNotExist(err) {
- return "", fmt.Errorf("file not found: %s", filePath)
+ return er, fmt.Errorf("file not found: %s", filePath)
}
- return "", fmt.Errorf("failed to access file: %w", err)
+ return er, fmt.Errorf("failed to access file: %w", err)
}
if fileInfo.IsDir() {
- return "", fmt.Errorf("path is a directory, not a file: %s", filePath)
+ return er, fmt.Errorf("path is a directory, not a file: %s", filePath)
}
if getLastReadTime(filePath).IsZero() {
- return "", fmt.Errorf("you must read the file before editing it. Use the View tool first")
+ return er, fmt.Errorf("you must read the file before editing it. Use the View tool first")
}
modTime := fileInfo.ModTime()
lastRead := getLastReadTime(filePath)
if modTime.After(lastRead) {
- return "", fmt.Errorf("file %s has been modified since it was last read (mod time: %s, last read: %s)",
+ return er, fmt.Errorf("file %s has been modified since it was last read (mod time: %s, last read: %s)",
filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339))
}
content, err := os.ReadFile(filePath)
if err != nil {
- return "", fmt.Errorf("failed to read file: %w", err)
+ return er, fmt.Errorf("failed to read file: %w", err)
}
oldContent := string(content)
index := strings.Index(oldContent, oldString)
if index == -1 {
- return "", fmt.Errorf("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks")
+ return er, fmt.Errorf("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks")
}
lastIndex := strings.LastIndex(oldContent, oldString)
if index != lastIndex {
- return "", fmt.Errorf("old_string appears multiple times in the file. Please provide more context to ensure a unique match")
+ return er, fmt.Errorf("old_string appears multiple times in the file. Please provide more context to ensure a unique match")
}
newContent := oldContent[:index] + newString + oldContent[index+len(oldString):]
- startIndex := max(0, index-3)
- oldEndIndex := min(len(oldContent), index+len(oldString)+3)
- newEndIndex := min(len(newContent), index+len(newString)+3)
+ sessionID, messageID := getContextValues(ctx)
- diff := GenerateDiff(oldContent[startIndex:oldEndIndex], newContent[startIndex:newEndIndex])
+ if sessionID == "" || messageID == "" {
+ return er, fmt.Errorf("session ID and message ID are required for creating a new file")
+ }
+ diff, stats, err := git.GenerateGitDiffWithStats(
+ removeWorkingDirectoryPrefix(filePath),
+ oldContent,
+ newContent,
+ )
+ if err != nil {
+ return er, fmt.Errorf("failed to get file diff: %w", err)
+ }
p := e.permissions.Request(
permission.CreatePermissionRequest{
@@ -328,75 +386,27 @@ func (e *editTool) replaceContent(filePath, oldString, newString string) (string
Action: "replace",
Description: fmt.Sprintf("Replace content in file %s", filePath),
Params: EditPermissionsParams{
- FilePath: filePath,
- OldString: oldString,
- NewString: newString,
- Diff: diff,
+ FilePath: filePath,
+
+ Diff: diff,
},
},
)
if !p {
- return "", fmt.Errorf("permission denied")
+ return er, fmt.Errorf("permission denied")
}
err = os.WriteFile(filePath, []byte(newContent), 0o644)
if err != nil {
- return "", fmt.Errorf("failed to write file: %w", err)
+ return er, fmt.Errorf("failed to write file: %w", err)
}
recordFileWrite(filePath)
recordFileRead(filePath)
+ er.text = "Content replaced in file: " + filePath
+ er.additions = stats.Additions
+ er.removals = stats.Removals
- return "Content replaced in file: " + filePath, nil
+ return er, nil
}
-func GenerateDiff(oldContent, newContent string) string {
- dmp := diffmatchpatch.New()
- fileAdmp, fileBdmp, dmpStrings := dmp.DiffLinesToChars(oldContent, newContent)
- diffs := dmp.DiffMain(fileAdmp, fileBdmp, false)
- diffs = dmp.DiffCharsToLines(diffs, dmpStrings)
- diffs = dmp.DiffCleanupSemantic(diffs)
- buff := strings.Builder{}
-
- buff.WriteString("Changes:\n")
-
- for _, diff := range diffs {
- text := diff.Text
-
- switch diff.Type {
- case diffmatchpatch.DiffInsert:
- for line := range strings.SplitSeq(text, "\n") {
- if line == "" {
- continue
- }
- _, _ = buff.WriteString("+ " + line + "\n")
- }
- case diffmatchpatch.DiffDelete:
- for line := range strings.SplitSeq(text, "\n") {
- if line == "" {
- continue
- }
- _, _ = buff.WriteString("- " + line + "\n")
- }
- case diffmatchpatch.DiffEqual:
- lines := strings.Split(text, "\n")
- if len(lines) > 3 {
- if lines[0] != "" {
- _, _ = buff.WriteString(" " + lines[0] + "\n")
- }
- _, _ = buff.WriteString(" ...\n")
- if lines[len(lines)-1] != "" {
- _, _ = buff.WriteString(" " + lines[len(lines)-1] + "\n")
- }
- } else {
- for _, line := range lines {
- if line == "" {
- continue
- }
- _, _ = buff.WriteString(" " + line + "\n")
- }
- }
- }
- }
- return buff.String()
-}
diff --git a/internal/llm/tools/edit_test.go b/internal/llm/tools/edit_test.go
index dbc6e488f..48a34ed75 100644
--- a/internal/llm/tools/edit_test.go
+++ b/internal/llm/tools/edit_test.go
@@ -459,51 +459,3 @@ func TestEditTool_Run(t *testing.T) {
assert.Equal(t, initialContent, string(fileContent))
})
}
-
-func TestGenerateDiff(t *testing.T) {
- testCases := []struct {
- name string
- oldContent string
- newContent string
- expectedDiff string
- }{
- {
- name: "add content",
- oldContent: "Line 1\nLine 2\n",
- newContent: "Line 1\nLine 2\nLine 3\n",
- expectedDiff: "Changes:\n Line 1\n Line 2\n+ Line 3\n",
- },
- {
- name: "remove content",
- oldContent: "Line 1\nLine 2\nLine 3\n",
- newContent: "Line 1\nLine 3\n",
- expectedDiff: "Changes:\n Line 1\n- Line 2\n Line 3\n",
- },
- {
- name: "replace content",
- oldContent: "Line 1\nLine 2\nLine 3\n",
- newContent: "Line 1\nModified Line\nLine 3\n",
- expectedDiff: "Changes:\n Line 1\n- Line 2\n+ Modified Line\n Line 3\n",
- },
- {
- name: "empty to content",
- oldContent: "",
- newContent: "Line 1\nLine 2\n",
- expectedDiff: "Changes:\n+ Line 1\n+ Line 2\n",
- },
- {
- name: "content to empty",
- oldContent: "Line 1\nLine 2\n",
- newContent: "",
- expectedDiff: "Changes:\n- Line 1\n- Line 2\n",
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- diff := GenerateDiff(tc.oldContent, tc.newContent)
- assert.Contains(t, diff, tc.expectedDiff)
- })
- }
-}
-
diff --git a/internal/llm/tools/file.go b/internal/llm/tools/file.go
index 7f34fdc1f..9c9707c9c 100644
--- a/internal/llm/tools/file.go
+++ b/internal/llm/tools/file.go
@@ -3,6 +3,8 @@ package tools
import (
"sync"
"time"
+
+ "github.com/kujtimiihoxha/termai/internal/config"
)
// File record to track when files were read/written
@@ -17,6 +19,14 @@ var (
fileRecordMutex sync.RWMutex
)
+func removeWorkingDirectoryPrefix(path string) string {
+ wd := config.WorkingDirectory()
+ if len(path) > len(wd) && path[:len(wd)] == wd {
+ return path[len(wd)+1:]
+ }
+ return path
+}
+
func recordFileRead(path string) {
fileRecordMutex.Lock()
defer fileRecordMutex.Unlock()
diff --git a/internal/llm/tools/tools.go b/internal/llm/tools/tools.go
index 6bb528686..473b787bb 100644
--- a/internal/llm/tools/tools.go
+++ b/internal/llm/tools/tools.go
@@ -17,6 +17,9 @@ type toolResponseType string
const (
ToolResponseTypeText toolResponseType = "text"
ToolResponseTypeImage toolResponseType = "image"
+
+ SessionIDContextKey = "session_id"
+ MessageIDContextKey = "message_id"
)
type ToolResponse struct {
@@ -62,3 +65,15 @@ type BaseTool interface {
Info() ToolInfo
Run(ctx context.Context, params ToolCall) (ToolResponse, error)
}
+
+func getContextValues(ctx context.Context) (string, string) {
+ sessionID := ctx.Value(SessionIDContextKey)
+ messageID := ctx.Value(MessageIDContextKey)
+ if sessionID == nil {
+ return "", ""
+ }
+ if messageID == nil {
+ return sessionID.(string), ""
+ }
+ return sessionID.(string), messageID.(string)
+}
diff --git a/internal/llm/tools/write.go b/internal/llm/tools/write.go
index 7b698d2d8..27c98bb9d 100644
--- a/internal/llm/tools/write.go
+++ b/internal/llm/tools/write.go
@@ -9,6 +9,7 @@ import (
"time"
"github.com/kujtimiihoxha/termai/internal/config"
+ "github.com/kujtimiihoxha/termai/internal/git"
"github.com/kujtimiihoxha/termai/internal/lsp"
"github.com/kujtimiihoxha/termai/internal/permission"
)
@@ -20,7 +21,7 @@ type WriteParams struct {
type WritePermissionsParams struct {
FilePath string `json:"file_path"`
- Content string `json:"content"`
+ Diff string `json:"diff"`
}
type writeTool struct {
@@ -28,6 +29,11 @@ type writeTool struct {
permissions permission.Service
}
+type WriteResponseMetadata struct {
+ Additions int `json:"additions"`
+ Removals int `json:"removals"`
+}
+
const (
WriteToolName = "write"
writeDescription = `File writing tool that creates or updates files in the filesystem, allowing you to save or modify text content.
@@ -138,6 +144,18 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
}
}
+ sessionID, messageID := getContextValues(ctx)
+ if sessionID == "" || messageID == "" {
+ return NewTextErrorResponse("session ID or message ID is missing"), nil
+ }
+ diff, stats, err := git.GenerateGitDiffWithStats(
+ removeWorkingDirectoryPrefix(filePath),
+ oldContent,
+ params.Content,
+ )
+ if err != nil {
+ return NewTextErrorResponse(fmt.Sprintf("Failed to get file diff: %s", err)), nil
+ }
p := w.permissions.Request(
permission.CreatePermissionRequest{
Path: filePath,
@@ -146,7 +164,7 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
Description: fmt.Sprintf("Create file %s", filePath),
Params: WritePermissionsParams{
FilePath: filePath,
- Content: GenerateDiff(oldContent, params.Content),
+ Diff: diff,
},
},
)
@@ -166,5 +184,10 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
result := fmt.Sprintf("File successfully written: %s", filePath)
result = fmt.Sprintf("<result>\n%s\n</result>", result)
result += appendDiagnostics(filePath, w.lspClients)
- return NewTextResponse(result), nil
+ return WithResponseMetadata(NewTextResponse(result),
+ WriteResponseMetadata{
+ Additions: stats.Additions,
+ Removals: stats.Removals,
+ },
+ ), nil
}