diff options
| author | Kujtim Hoxha <[email protected]> | 2025-04-16 20:06:23 +0200 |
|---|---|---|
| committer | Kujtim Hoxha <[email protected]> | 2025-04-21 13:42:00 +0200 |
| commit | bbfa60c787f2ec459f1689b9a650ddbec9693ed9 (patch) | |
| tree | f7f2aa31c460c8cc22ec40cc299c386277152241 /internal/llm/tools | |
| parent | 76b4065f17b87a63092acfd98c997bab53700b35 (diff) | |
| download | opencode-bbfa60c787f2ec459f1689b9a650ddbec9693ed9.tar.gz opencode-bbfa60c787f2ec459f1689b9a650ddbec9693ed9.zip | |
reimplement agent,provider and add file history
Diffstat (limited to 'internal/llm/tools')
| -rw-r--r-- | internal/llm/tools/bash.go | 7 | ||||
| -rw-r--r-- | internal/llm/tools/bash_test.go | 31 | ||||
| -rw-r--r-- | internal/llm/tools/edit.go | 75 | ||||
| -rw-r--r-- | internal/llm/tools/edit_test.go | 30 | ||||
| -rw-r--r-- | internal/llm/tools/file.go | 10 | ||||
| -rw-r--r-- | internal/llm/tools/glob.go | 4 | ||||
| -rw-r--r-- | internal/llm/tools/grep.go | 4 | ||||
| -rw-r--r-- | internal/llm/tools/ls.go | 4 | ||||
| -rw-r--r-- | internal/llm/tools/mocks_test.go | 246 | ||||
| -rw-r--r-- | internal/llm/tools/shell/shell.go | 12 | ||||
| -rw-r--r-- | internal/llm/tools/sourcegraph.go | 2 | ||||
| -rw-r--r-- | internal/llm/tools/tools.go | 9 | ||||
| -rw-r--r-- | internal/llm/tools/write.go | 27 | ||||
| -rw-r--r-- | internal/llm/tools/write_test.go | 22 |
14 files changed, 399 insertions, 84 deletions
diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index 0cea20878..c7c970e5a 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -23,7 +23,8 @@ type BashPermissionsParams struct { } type BashResponseMetadata struct { - Took int64 `json:"took"` + StartTime int64 `json:"start_time"` + EndTime int64 `json:"end_time"` } type bashTool struct { permissions permission.Service @@ -282,7 +283,6 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) if err != nil { return ToolResponse{}, fmt.Errorf("error executing command: %w", err) } - took := time.Since(startTime).Milliseconds() stdout = truncateOutput(stdout) stderr = truncateOutput(stderr) @@ -311,7 +311,8 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) } metadata := BashResponseMetadata{ - Took: took, + StartTime: startTime.UnixMilli(), + EndTime: time.Now().UnixMilli(), } if stdout == "" { return WithResponseMetadata(NewTextResponse("no output"), metadata), nil diff --git a/internal/llm/tools/bash_test.go b/internal/llm/tools/bash_test.go index 97be3683a..dafb0ccc5 100644 --- a/internal/llm/tools/bash_test.go +++ b/internal/llm/tools/bash_test.go @@ -8,8 +8,6 @@ import ( "testing" "time" - "github.com/kujtimiihoxha/termai/internal/permission" - "github.com/kujtimiihoxha/termai/internal/pubsub" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -340,32 +338,3 @@ func TestCountLines(t *testing.T) { }) } } - -// Mock permission service for testing -type mockPermissionService struct { - *pubsub.Broker[permission.PermissionRequest] - allow bool -} - -func (m *mockPermissionService) GrantPersistant(permission permission.PermissionRequest) { - // Not needed for tests -} - -func (m *mockPermissionService) Grant(permission permission.PermissionRequest) { - // Not needed for tests -} - -func (m *mockPermissionService) Deny(permission permission.PermissionRequest) { - // Not needed for tests -} - -func (m *mockPermissionService) Request(opts permission.CreatePermissionRequest) bool { - return m.allow -} - -func newMockPermissionService(allow bool) permission.Service { - return &mockPermissionService{ - Broker: pubsub.NewBroker[permission.PermissionRequest](), - allow: allow, - } -} diff --git a/internal/llm/tools/edit.go b/internal/llm/tools/edit.go index 08d6d446c..148e7aba7 100644 --- a/internal/llm/tools/edit.go +++ b/internal/llm/tools/edit.go @@ -11,6 +11,7 @@ import ( "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/diff" + "github.com/kujtimiihoxha/termai/internal/history" "github.com/kujtimiihoxha/termai/internal/lsp" "github.com/kujtimiihoxha/termai/internal/permission" ) @@ -35,6 +36,7 @@ type EditResponseMetadata struct { type editTool struct { lspClients map[string]*lsp.Client permissions permission.Service + files history.Service } const ( @@ -88,10 +90,11 @@ When making edits: Remember: when making multiple file edits in a row to the same file, you should prefer to send all edits in a single message with multiple calls to this tool, rather than multiple messages with a single call each.` ) -func NewEditTool(lspClients map[string]*lsp.Client, permissions permission.Service) BaseTool { +func NewEditTool(lspClients map[string]*lsp.Client, permissions permission.Service, files history.Service) BaseTool { return &editTool{ lspClients: lspClients, permissions: permissions, + files: files, } } @@ -153,6 +156,11 @@ func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) if err != nil { return response, nil } + if response.IsError { + // Return early if there was an error during content replacement + // This prevents unnecessary LSP diagnostics processing + return response, nil + } waitForLspDiagnostics(ctx, params.FilePath, e.lspClients) text := fmt.Sprintf("<result>\n%s\n</result>\n", response.Content) @@ -208,6 +216,20 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string) return ToolResponse{}, fmt.Errorf("failed to write file: %w", err) } + // File can't be in the history so we create a new file history + _, err = e.files.Create(ctx, sessionID, filePath, "") + if err != nil { + // Log error but don't fail the operation + return ToolResponse{}, fmt.Errorf("error creating file history: %w", err) + } + + // Add the new content to the file history + _, err = e.files.CreateVersion(ctx, sessionID, filePath, content) + if err != nil { + // Log error but don't fail the operation + fmt.Printf("Error creating file history version: %v\n", err) + } + recordFileWrite(filePath) recordFileRead(filePath) @@ -298,6 +320,29 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string if err != nil { return ToolResponse{}, fmt.Errorf("failed to write file: %w", err) } + + // Check if file exists in history + file, err := e.files.GetByPathAndSession(ctx, filePath, sessionID) + if err != nil { + _, err = e.files.Create(ctx, sessionID, filePath, oldContent) + if err != nil { + // Log error but don't fail the operation + return ToolResponse{}, fmt.Errorf("error creating file history: %w", err) + } + } + if file.Content != oldContent { + // User Manually changed the content store an intermediate version + _, err = e.files.CreateVersion(ctx, sessionID, filePath, oldContent) + if err != nil { + fmt.Printf("Error creating file history version: %v\n", err) + } + } + // Store the new version + _, err = e.files.CreateVersion(ctx, sessionID, filePath, "") + if err != nil { + fmt.Printf("Error creating file history version: %v\n", err) + } + recordFileWrite(filePath) recordFileRead(filePath) @@ -356,6 +401,9 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS newContent := oldContent[:index] + newString + oldContent[index+len(oldString):] + if oldContent == newContent { + return NewTextErrorResponse("new content is the same as old content. No changes made."), nil + } sessionID, messageID := GetContextValues(ctx) if sessionID == "" || messageID == "" { @@ -374,8 +422,7 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS Description: fmt.Sprintf("Replace content in file %s", filePath), Params: EditPermissionsParams{ FilePath: filePath, - - Diff: diff, + Diff: diff, }, }, ) @@ -388,6 +435,28 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS return ToolResponse{}, fmt.Errorf("failed to write file: %w", err) } + // Check if file exists in history + file, err := e.files.GetByPathAndSession(ctx, filePath, sessionID) + if err != nil { + _, err = e.files.Create(ctx, sessionID, filePath, oldContent) + if err != nil { + // Log error but don't fail the operation + return ToolResponse{}, fmt.Errorf("error creating file history: %w", err) + } + } + if file.Content != oldContent { + // User Manually changed the content store an intermediate version + _, err = e.files.CreateVersion(ctx, sessionID, filePath, oldContent) + if err != nil { + fmt.Printf("Error creating file history version: %v\n", err) + } + } + // Store the new version + _, err = e.files.CreateVersion(ctx, sessionID, filePath, newContent) + if err != nil { + fmt.Printf("Error creating file history version: %v\n", err) + } + recordFileWrite(filePath) recordFileRead(filePath) diff --git a/internal/llm/tools/edit_test.go b/internal/llm/tools/edit_test.go index 48a34ed75..0971775dd 100644 --- a/internal/llm/tools/edit_test.go +++ b/internal/llm/tools/edit_test.go @@ -14,7 +14,7 @@ import ( ) func TestEditTool_Info(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) info := tool.Info() assert.Equal(t, EditToolName, info.Name) @@ -34,7 +34,7 @@ func TestEditTool_Run(t *testing.T) { defer os.RemoveAll(tempDir) t.Run("creates a new file successfully", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) filePath := filepath.Join(tempDir, "new_file.txt") content := "This is a test content" @@ -64,7 +64,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("creates file with nested directories", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) filePath := filepath.Join(tempDir, "nested/dirs/new_file.txt") content := "Content in nested directory" @@ -94,7 +94,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("fails to create file that already exists", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file first filePath := filepath.Join(tempDir, "existing_file.txt") @@ -123,7 +123,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("fails to create file when path is a directory", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a directory dirPath := filepath.Join(tempDir, "test_dir") @@ -151,7 +151,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("replaces content successfully", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file first filePath := filepath.Join(tempDir, "replace_content.txt") @@ -191,7 +191,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("deletes content successfully", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file first filePath := filepath.Join(tempDir, "delete_content.txt") @@ -230,7 +230,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles invalid parameters", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) call := ToolCall{ Name: EditToolName, @@ -243,7 +243,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles missing file_path", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) params := EditParams{ FilePath: "", @@ -265,7 +265,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles file not found", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) filePath := filepath.Join(tempDir, "non_existent_file.txt") params := EditParams{ @@ -288,7 +288,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles old_string not found in file", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file first filePath := filepath.Join(tempDir, "content_not_found.txt") @@ -320,7 +320,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles multiple occurrences of old_string", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file with duplicate content filePath := filepath.Join(tempDir, "duplicate_content.txt") @@ -352,7 +352,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles file modified since last read", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file filePath := filepath.Join(tempDir, "modified_file.txt") @@ -394,7 +394,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles file not read before editing", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file filePath := filepath.Join(tempDir, "not_read_file.txt") @@ -423,7 +423,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles permission denied", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(false)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(false), newMockFileHistoryService()) // Create a file filePath := filepath.Join(tempDir, "permission_denied.txt") diff --git a/internal/llm/tools/file.go b/internal/llm/tools/file.go index 9c9707c9c..7f34fdc1f 100644 --- a/internal/llm/tools/file.go +++ b/internal/llm/tools/file.go @@ -3,8 +3,6 @@ package tools import ( "sync" "time" - - "github.com/kujtimiihoxha/termai/internal/config" ) // File record to track when files were read/written @@ -19,14 +17,6 @@ var ( fileRecordMutex sync.RWMutex ) -func removeWorkingDirectoryPrefix(path string) string { - wd := config.WorkingDirectory() - if len(path) > len(wd) && path[:len(wd)] == wd { - return path[len(wd)+1:] - } - return path -} - func recordFileRead(path string) { fileRecordMutex.Lock() defer fileRecordMutex.Unlock() diff --git a/internal/llm/tools/glob.go b/internal/llm/tools/glob.go index bdfc23b4a..7b4fb1187 100644 --- a/internal/llm/tools/glob.go +++ b/internal/llm/tools/glob.go @@ -63,7 +63,7 @@ type GlobParams struct { Path string `json:"path"` } -type GlobMetadata struct { +type GlobResponseMetadata struct { NumberOfFiles int `json:"number_of_files"` Truncated bool `json:"truncated"` } @@ -124,7 +124,7 @@ func (g *globTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) return WithResponseMetadata( NewTextResponse(output), - GlobMetadata{ + GlobResponseMetadata{ NumberOfFiles: len(files), Truncated: truncated, }, diff --git a/internal/llm/tools/grep.go b/internal/llm/tools/grep.go index 7e52821d0..19333f50b 100644 --- a/internal/llm/tools/grep.go +++ b/internal/llm/tools/grep.go @@ -27,7 +27,7 @@ type grepMatch struct { modTime time.Time } -type GrepMetadata struct { +type GrepResponseMetadata struct { NumberOfMatches int `json:"number_of_matches"` Truncated bool `json:"truncated"` } @@ -134,7 +134,7 @@ func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) return WithResponseMetadata( NewTextResponse(output), - GrepMetadata{ + GrepResponseMetadata{ NumberOfMatches: len(matches), Truncated: truncated, }, diff --git a/internal/llm/tools/ls.go b/internal/llm/tools/ls.go index a679f261b..a63bf0eeb 100644 --- a/internal/llm/tools/ls.go +++ b/internal/llm/tools/ls.go @@ -23,7 +23,7 @@ type TreeNode struct { Children []*TreeNode `json:"children,omitempty"` } -type LSMetadata struct { +type LSResponseMetadata struct { NumberOfFiles int `json:"number_of_files"` Truncated bool `json:"truncated"` } @@ -121,7 +121,7 @@ func (l *lsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { return WithResponseMetadata( NewTextResponse(output), - LSMetadata{ + LSResponseMetadata{ NumberOfFiles: len(files), Truncated: truncated, }, diff --git a/internal/llm/tools/mocks_test.go b/internal/llm/tools/mocks_test.go new file mode 100644 index 000000000..321f09ac1 --- /dev/null +++ b/internal/llm/tools/mocks_test.go @@ -0,0 +1,246 @@ +package tools + +import ( + "context" + "fmt" + "sort" + "strconv" + "strings" + "time" + + "github.com/google/uuid" + "github.com/kujtimiihoxha/termai/internal/history" + "github.com/kujtimiihoxha/termai/internal/permission" + "github.com/kujtimiihoxha/termai/internal/pubsub" +) + +// Mock permission service for testing +type mockPermissionService struct { + *pubsub.Broker[permission.PermissionRequest] + allow bool +} + +func (m *mockPermissionService) GrantPersistant(permission permission.PermissionRequest) { + // Not needed for tests +} + +func (m *mockPermissionService) Grant(permission permission.PermissionRequest) { + // Not needed for tests +} + +func (m *mockPermissionService) Deny(permission permission.PermissionRequest) { + // Not needed for tests +} + +func (m *mockPermissionService) Request(opts permission.CreatePermissionRequest) bool { + return m.allow +} + +func newMockPermissionService(allow bool) permission.Service { + return &mockPermissionService{ + Broker: pubsub.NewBroker[permission.PermissionRequest](), + allow: allow, + } +} + +type mockFileHistoryService struct { + *pubsub.Broker[history.File] + files map[string]history.File // ID -> File + timeNow func() int64 +} + +// Create implements history.Service. +func (m *mockFileHistoryService) Create(ctx context.Context, sessionID string, path string, content string) (history.File, error) { + return m.createWithVersion(ctx, sessionID, path, content, history.InitialVersion) +} + +// CreateVersion implements history.Service. +func (m *mockFileHistoryService) CreateVersion(ctx context.Context, sessionID string, path string, content string) (history.File, error) { + var files []history.File + for _, file := range m.files { + if file.Path == path { + files = append(files, file) + } + } + + if len(files) == 0 { + // No previous versions, create initial + return m.Create(ctx, sessionID, path, content) + } + + // Sort files by CreatedAt in descending order + sort.Slice(files, func(i, j int) bool { + return files[i].CreatedAt > files[j].CreatedAt + }) + + // Get the latest version + latestFile := files[0] + latestVersion := latestFile.Version + + // Generate the next version + var nextVersion string + if latestVersion == history.InitialVersion { + nextVersion = "v1" + } else if strings.HasPrefix(latestVersion, "v") { + versionNum, err := strconv.Atoi(latestVersion[1:]) + if err != nil { + // If we can't parse the version, just use a timestamp-based version + nextVersion = fmt.Sprintf("v%d", latestFile.CreatedAt) + } else { + nextVersion = fmt.Sprintf("v%d", versionNum+1) + } + } else { + // If the version format is unexpected, use a timestamp-based version + nextVersion = fmt.Sprintf("v%d", latestFile.CreatedAt) + } + + return m.createWithVersion(ctx, sessionID, path, content, nextVersion) +} + +func (m *mockFileHistoryService) createWithVersion(_ context.Context, sessionID, path, content, version string) (history.File, error) { + now := m.timeNow() + file := history.File{ + ID: uuid.New().String(), + SessionID: sessionID, + Path: path, + Content: content, + Version: version, + CreatedAt: now, + UpdatedAt: now, + } + + m.files[file.ID] = file + m.Publish(pubsub.CreatedEvent, file) + return file, nil +} + +// Delete implements history.Service. +func (m *mockFileHistoryService) Delete(ctx context.Context, id string) error { + file, ok := m.files[id] + if !ok { + return fmt.Errorf("file not found: %s", id) + } + + delete(m.files, id) + m.Publish(pubsub.DeletedEvent, file) + return nil +} + +// DeleteSessionFiles implements history.Service. +func (m *mockFileHistoryService) DeleteSessionFiles(ctx context.Context, sessionID string) error { + files, err := m.ListBySession(ctx, sessionID) + if err != nil { + return err + } + + for _, file := range files { + err = m.Delete(ctx, file.ID) + if err != nil { + return err + } + } + + return nil +} + +// Get implements history.Service. +func (m *mockFileHistoryService) Get(ctx context.Context, id string) (history.File, error) { + file, ok := m.files[id] + if !ok { + return history.File{}, fmt.Errorf("file not found: %s", id) + } + return file, nil +} + +// GetByPathAndSession implements history.Service. +func (m *mockFileHistoryService) GetByPathAndSession(ctx context.Context, path string, sessionID string) (history.File, error) { + var latestFile history.File + var found bool + var latestTime int64 + + for _, file := range m.files { + if file.Path == path && file.SessionID == sessionID { + if !found || file.CreatedAt > latestTime { + latestFile = file + latestTime = file.CreatedAt + found = true + } + } + } + + if !found { + return history.File{}, fmt.Errorf("file not found: %s for session %s", path, sessionID) + } + return latestFile, nil +} + +// ListBySession implements history.Service. +func (m *mockFileHistoryService) ListBySession(ctx context.Context, sessionID string) ([]history.File, error) { + var files []history.File + for _, file := range m.files { + if file.SessionID == sessionID { + files = append(files, file) + } + } + + // Sort by CreatedAt in descending order + sort.Slice(files, func(i, j int) bool { + return files[i].CreatedAt > files[j].CreatedAt + }) + + return files, nil +} + +// ListLatestSessionFiles implements history.Service. +func (m *mockFileHistoryService) ListLatestSessionFiles(ctx context.Context, sessionID string) ([]history.File, error) { + // Map to track the latest file for each path + latestFiles := make(map[string]history.File) + + for _, file := range m.files { + if file.SessionID == sessionID { + existing, ok := latestFiles[file.Path] + if !ok || file.CreatedAt > existing.CreatedAt { + latestFiles[file.Path] = file + } + } + } + + // Convert map to slice + var result []history.File + for _, file := range latestFiles { + result = append(result, file) + } + + // Sort by CreatedAt in descending order + sort.Slice(result, func(i, j int) bool { + return result[i].CreatedAt > result[j].CreatedAt + }) + + return result, nil +} + +// Subscribe implements history.Service. +func (m *mockFileHistoryService) Subscribe(ctx context.Context) <-chan pubsub.Event[history.File] { + return m.Broker.Subscribe(ctx) +} + +// Update implements history.Service. +func (m *mockFileHistoryService) Update(ctx context.Context, file history.File) (history.File, error) { + _, ok := m.files[file.ID] + if !ok { + return history.File{}, fmt.Errorf("file not found: %s", file.ID) + } + + file.UpdatedAt = m.timeNow() + m.files[file.ID] = file + m.Publish(pubsub.UpdatedEvent, file) + return file, nil +} + +func newMockFileHistoryService() history.Service { + return &mockFileHistoryService{ + Broker: pubsub.NewBroker[history.File](), + files: make(map[string]history.File), + timeNow: func() int64 { return time.Now().Unix() }, + } +} diff --git a/internal/llm/tools/shell/shell.go b/internal/llm/tools/shell/shell.go index 64592f67d..4a776478a 100644 --- a/internal/llm/tools/shell/shell.go +++ b/internal/llm/tools/shell/shell.go @@ -83,11 +83,21 @@ func newPersistentShell(cwd string) *PersistentShell { commandQueue: make(chan *commandExecution, 10), } - go shell.processCommands() + go func() { + defer func() { + if r := recover(); r != nil { + fmt.Fprintf(os.Stderr, "Panic in shell command processor: %v\n", r) + shell.isAlive = false + close(shell.commandQueue) + } + }() + shell.processCommands() + }() go func() { err := cmd.Wait() if err != nil { + // Log the error if needed } shell.isAlive = false close(shell.commandQueue) diff --git a/internal/llm/tools/sourcegraph.go b/internal/llm/tools/sourcegraph.go index 17bc610ea..a6f2c8afb 100644 --- a/internal/llm/tools/sourcegraph.go +++ b/internal/llm/tools/sourcegraph.go @@ -18,7 +18,7 @@ type SourcegraphParams struct { Timeout int `json:"timeout,omitempty"` } -type SourcegraphMetadata struct { +type SourcegraphResponseMetadata struct { NumberOfMatches int `json:"number_of_matches"` Truncated bool `json:"truncated"` } diff --git a/internal/llm/tools/tools.go b/internal/llm/tools/tools.go index 07afe1363..bf0f8df0b 100644 --- a/internal/llm/tools/tools.go +++ b/internal/llm/tools/tools.go @@ -14,12 +14,17 @@ type ToolInfo struct { type toolResponseType string +type ( + sessionIDContextKey string + messageIDContextKey string +) + const ( ToolResponseTypeText toolResponseType = "text" ToolResponseTypeImage toolResponseType = "image" - SessionIDContextKey = "session_id" - MessageIDContextKey = "message_id" + SessionIDContextKey sessionIDContextKey = "session_id" + MessageIDContextKey messageIDContextKey = "message_id" ) type ToolResponse struct { diff --git a/internal/llm/tools/write.go b/internal/llm/tools/write.go index 889561d2a..bb49381fd 100644 --- a/internal/llm/tools/write.go +++ b/internal/llm/tools/write.go @@ -10,6 +10,7 @@ import ( "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/diff" + "github.com/kujtimiihoxha/termai/internal/history" "github.com/kujtimiihoxha/termai/internal/lsp" "github.com/kujtimiihoxha/termai/internal/permission" ) @@ -27,6 +28,7 @@ type WritePermissionsParams struct { type writeTool struct { lspClients map[string]*lsp.Client permissions permission.Service + files history.Service } type WriteResponseMetadata struct { @@ -67,10 +69,11 @@ TIPS: - Always include descriptive comments when making changes to existing code` ) -func NewWriteTool(lspClients map[string]*lsp.Client, permissions permission.Service) BaseTool { +func NewWriteTool(lspClients map[string]*lsp.Client, permissions permission.Service, files history.Service) BaseTool { return &writeTool{ lspClients: lspClients, permissions: permissions, + files: files, } } @@ -176,6 +179,28 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error return ToolResponse{}, fmt.Errorf("error writing file: %w", err) } + // Check if file exists in history + file, err := w.files.GetByPathAndSession(ctx, filePath, sessionID) + if err != nil { + _, err = w.files.Create(ctx, sessionID, filePath, oldContent) + if err != nil { + // Log error but don't fail the operation + return ToolResponse{}, fmt.Errorf("error creating file history: %w", err) + } + } + if file.Content != oldContent { + // User Manually changed the content store an intermediate version + _, err = w.files.CreateVersion(ctx, sessionID, filePath, oldContent) + if err != nil { + fmt.Printf("Error creating file history version: %v\n", err) + } + } + // Store the new version + _, err = w.files.CreateVersion(ctx, sessionID, filePath, params.Content) + if err != nil { + fmt.Printf("Error creating file history version: %v\n", err) + } + recordFileWrite(filePath) recordFileRead(filePath) waitForLspDiagnostics(ctx, filePath, w.lspClients) diff --git a/internal/llm/tools/write_test.go b/internal/llm/tools/write_test.go index 50dafc14f..2264f36fb 100644 --- a/internal/llm/tools/write_test.go +++ b/internal/llm/tools/write_test.go @@ -14,7 +14,7 @@ import ( ) func TestWriteTool_Info(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) info := tool.Info() assert.Equal(t, WriteToolName, info.Name) @@ -32,7 +32,7 @@ func TestWriteTool_Run(t *testing.T) { defer os.RemoveAll(tempDir) t.Run("creates a new file successfully", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) filePath := filepath.Join(tempDir, "new_file.txt") content := "This is a test content" @@ -61,7 +61,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("creates file with nested directories", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) filePath := filepath.Join(tempDir, "nested/dirs/new_file.txt") content := "Content in nested directory" @@ -90,7 +90,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("updates existing file", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file first filePath := filepath.Join(tempDir, "existing_file.txt") @@ -127,7 +127,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("handles invalid parameters", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) call := ToolCall{ Name: WriteToolName, @@ -140,7 +140,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("handles missing file_path", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) params := WriteParams{ FilePath: "", @@ -161,7 +161,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("handles missing content", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) params := WriteParams{ FilePath: filepath.Join(tempDir, "file.txt"), @@ -182,7 +182,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("handles writing to a directory path", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a directory dirPath := filepath.Join(tempDir, "test_dir") @@ -208,7 +208,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("handles permission denied", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(false)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(false), newMockFileHistoryService()) filePath := filepath.Join(tempDir, "permission_denied.txt") params := WriteParams{ @@ -234,7 +234,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("detects file modified since last read", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file filePath := filepath.Join(tempDir, "modified_file.txt") @@ -275,7 +275,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("skips writing when content is identical", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file filePath := filepath.Join(tempDir, "identical_content.txt") |
