From 333ea6ec4b2abfc2c1a9c3f6b0918ca5d296347f Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 18 Apr 2025 20:17:38 +0200 Subject: implement patch, update ui, improve rendering --- internal/diff/patch.go | 739 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 739 insertions(+) create mode 100644 internal/diff/patch.go (limited to 'internal/diff/patch.go') diff --git a/internal/diff/patch.go b/internal/diff/patch.go new file mode 100644 index 000000000..aab0f956d --- /dev/null +++ b/internal/diff/patch.go @@ -0,0 +1,739 @@ +package diff + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "strings" +) + +type ActionType string + +const ( + ActionAdd ActionType = "add" + ActionDelete ActionType = "delete" + ActionUpdate ActionType = "update" +) + +type FileChange struct { + Type ActionType + OldContent *string + NewContent *string + MovePath *string +} + +type Commit struct { + Changes map[string]FileChange +} + +type Chunk struct { + OrigIndex int // line index of the first line in the original file + DelLines []string // lines to delete + InsLines []string // lines to insert +} + +type PatchAction struct { + Type ActionType + NewFile *string + Chunks []Chunk + MovePath *string +} + +type Patch struct { + Actions map[string]PatchAction +} + +type DiffError struct { + message string +} + +func (e DiffError) Error() string { + return e.message +} + +// Helper functions for error handling +func NewDiffError(message string) DiffError { + return DiffError{message: message} +} + +func fileError(action, reason, path string) DiffError { + return NewDiffError(fmt.Sprintf("%s File Error: %s: %s", action, reason, path)) +} + +func contextError(index int, context string, isEOF bool) DiffError { + prefix := "Invalid Context" + if isEOF { + prefix = "Invalid EOF Context" + } + return NewDiffError(fmt.Sprintf("%s %d:\n%s", prefix, index, context)) +} + +type Parser struct { + currentFiles map[string]string + lines []string + index int + patch Patch + fuzz int +} + +func NewParser(currentFiles map[string]string, lines []string) *Parser { + return &Parser{ + currentFiles: currentFiles, + lines: lines, + index: 0, + patch: Patch{Actions: make(map[string]PatchAction, len(currentFiles))}, + fuzz: 0, + } +} + +func (p *Parser) isDone(prefixes []string) bool { + if p.index >= len(p.lines) { + return true + } + if prefixes != nil { + for _, prefix := range prefixes { + if strings.HasPrefix(p.lines[p.index], prefix) { + return true + } + } + } + return false +} + +func (p *Parser) startsWith(prefix any) bool { + var prefixes []string + switch v := prefix.(type) { + case string: + prefixes = []string{v} + case []string: + prefixes = v + } + + for _, pfx := range prefixes { + if strings.HasPrefix(p.lines[p.index], pfx) { + return true + } + } + return false +} + +func (p *Parser) readStr(prefix string, returnEverything bool) string { + if p.index >= len(p.lines) { + return "" // Changed from panic to return empty string for safer operation + } + if strings.HasPrefix(p.lines[p.index], prefix) { + var text string + if returnEverything { + text = p.lines[p.index] + } else { + text = p.lines[p.index][len(prefix):] + } + p.index++ + return text + } + return "" +} + +func (p *Parser) Parse() error { + endPatchPrefixes := []string{"*** End Patch"} + + for !p.isDone(endPatchPrefixes) { + path := p.readStr("*** Update File: ", false) + if path != "" { + if _, exists := p.patch.Actions[path]; exists { + return fileError("Update", "Duplicate Path", path) + } + moveTo := p.readStr("*** Move to: ", false) + if _, exists := p.currentFiles[path]; !exists { + return fileError("Update", "Missing File", path) + } + text := p.currentFiles[path] + action, err := p.parseUpdateFile(text) + if err != nil { + return err + } + if moveTo != "" { + action.MovePath = &moveTo + } + p.patch.Actions[path] = action + continue + } + + path = p.readStr("*** Delete File: ", false) + if path != "" { + if _, exists := p.patch.Actions[path]; exists { + return fileError("Delete", "Duplicate Path", path) + } + if _, exists := p.currentFiles[path]; !exists { + return fileError("Delete", "Missing File", path) + } + p.patch.Actions[path] = PatchAction{Type: ActionDelete, Chunks: []Chunk{}} + continue + } + + path = p.readStr("*** Add File: ", false) + if path != "" { + if _, exists := p.patch.Actions[path]; exists { + return fileError("Add", "Duplicate Path", path) + } + if _, exists := p.currentFiles[path]; exists { + return fileError("Add", "File already exists", path) + } + action, err := p.parseAddFile() + if err != nil { + return err + } + p.patch.Actions[path] = action + continue + } + + return NewDiffError(fmt.Sprintf("Unknown Line: %s", p.lines[p.index])) + } + + if !p.startsWith("*** End Patch") { + return NewDiffError("Missing End Patch") + } + p.index++ + + return nil +} + +func (p *Parser) parseUpdateFile(text string) (PatchAction, error) { + action := PatchAction{Type: ActionUpdate, Chunks: []Chunk{}} + fileLines := strings.Split(text, "\n") + index := 0 + + endPrefixes := []string{ + "*** End Patch", + "*** Update File:", + "*** Delete File:", + "*** Add File:", + "*** End of File", + } + + for !p.isDone(endPrefixes) { + defStr := p.readStr("@@ ", false) + sectionStr := "" + if defStr == "" && p.index < len(p.lines) && p.lines[p.index] == "@@" { + sectionStr = p.lines[p.index] + p.index++ + } + if !(defStr != "" || sectionStr != "" || index == 0) { + return action, NewDiffError(fmt.Sprintf("Invalid Line:\n%s", p.lines[p.index])) + } + if strings.TrimSpace(defStr) != "" { + found := false + for i := range fileLines[:index] { + if fileLines[i] == defStr { + found = true + break + } + } + + if !found { + for i := index; i < len(fileLines); i++ { + if fileLines[i] == defStr { + index = i + 1 + found = true + break + } + } + } + + if !found { + for i := range fileLines[:index] { + if strings.TrimSpace(fileLines[i]) == strings.TrimSpace(defStr) { + found = true + break + } + } + } + + if !found { + for i := index; i < len(fileLines); i++ { + if strings.TrimSpace(fileLines[i]) == strings.TrimSpace(defStr) { + index = i + 1 + p.fuzz++ + found = true + break + } + } + } + } + + nextChunkContext, chunks, endPatchIndex, eof := peekNextSection(p.lines, p.index) + newIndex, fuzz := findContext(fileLines, nextChunkContext, index, eof) + if newIndex == -1 { + ctxText := strings.Join(nextChunkContext, "\n") + return action, contextError(index, ctxText, eof) + } + p.fuzz += fuzz + + for _, ch := range chunks { + ch.OrigIndex += newIndex + action.Chunks = append(action.Chunks, ch) + } + index = newIndex + len(nextChunkContext) + p.index = endPatchIndex + } + return action, nil +} + +func (p *Parser) parseAddFile() (PatchAction, error) { + lines := make([]string, 0, 16) // Preallocate space for better performance + endPrefixes := []string{ + "*** End Patch", + "*** Update File:", + "*** Delete File:", + "*** Add File:", + } + + for !p.isDone(endPrefixes) { + s := p.readStr("", true) + if !strings.HasPrefix(s, "+") { + return PatchAction{}, NewDiffError(fmt.Sprintf("Invalid Add File Line: %s", s)) + } + lines = append(lines, s[1:]) + } + + newFile := strings.Join(lines, "\n") + return PatchAction{ + Type: ActionAdd, + NewFile: &newFile, + Chunks: []Chunk{}, + }, nil +} + +// Refactored to use a matcher function for each comparison type +func findContextCore(lines []string, context []string, start int) (int, int) { + if len(context) == 0 { + return start, 0 + } + + // Try exact match + if idx, fuzz := tryFindMatch(lines, context, start, func(a, b string) bool { + return a == b + }); idx >= 0 { + return idx, fuzz + } + + // Try trimming right whitespace + if idx, fuzz := tryFindMatch(lines, context, start, func(a, b string) bool { + return strings.TrimRight(a, " \t") == strings.TrimRight(b, " \t") + }); idx >= 0 { + return idx, fuzz + } + + // Try trimming all whitespace + if idx, fuzz := tryFindMatch(lines, context, start, func(a, b string) bool { + return strings.TrimSpace(a) == strings.TrimSpace(b) + }); idx >= 0 { + return idx, fuzz + } + + return -1, 0 +} + +// Helper function to DRY up the match logic +func tryFindMatch(lines []string, context []string, start int, + compareFunc func(string, string) bool, +) (int, int) { + for i := start; i < len(lines); i++ { + if i+len(context) <= len(lines) { + match := true + for j := range context { + if !compareFunc(lines[i+j], context[j]) { + match = false + break + } + } + if match { + // Return fuzz level: 0 for exact, 1 for trimRight, 100 for trimSpace + var fuzz int + if compareFunc("a ", "a") && !compareFunc("a", "b") { + fuzz = 1 + } else if compareFunc("a ", "a") { + fuzz = 100 + } + return i, fuzz + } + } + } + return -1, 0 +} + +func findContext(lines []string, context []string, start int, eof bool) (int, int) { + if eof { + newIndex, fuzz := findContextCore(lines, context, len(lines)-len(context)) + if newIndex != -1 { + return newIndex, fuzz + } + newIndex, fuzz = findContextCore(lines, context, start) + return newIndex, fuzz + 10000 + } + return findContextCore(lines, context, start) +} + +func peekNextSection(lines []string, initialIndex int) ([]string, []Chunk, int, bool) { + index := initialIndex + old := make([]string, 0, 32) // Preallocate for better performance + delLines := make([]string, 0, 8) + insLines := make([]string, 0, 8) + chunks := make([]Chunk, 0, 4) + mode := "keep" + + // End conditions for the section + endSectionConditions := func(s string) bool { + return strings.HasPrefix(s, "@@") || + strings.HasPrefix(s, "*** End Patch") || + strings.HasPrefix(s, "*** Update File:") || + strings.HasPrefix(s, "*** Delete File:") || + strings.HasPrefix(s, "*** Add File:") || + strings.HasPrefix(s, "*** End of File") || + s == "***" || + strings.HasPrefix(s, "***") + } + + for index < len(lines) { + s := lines[index] + if endSectionConditions(s) { + break + } + index++ + lastMode := mode + line := s + + if len(line) > 0 { + switch line[0] { + case '+': + mode = "add" + case '-': + mode = "delete" + case ' ': + mode = "keep" + default: + mode = "keep" + line = " " + line + } + } else { + mode = "keep" + line = " " + } + + line = line[1:] + if mode == "keep" && lastMode != mode { + if len(insLines) > 0 || len(delLines) > 0 { + chunks = append(chunks, Chunk{ + OrigIndex: len(old) - len(delLines), + DelLines: delLines, + InsLines: insLines, + }) + } + delLines = make([]string, 0, 8) + insLines = make([]string, 0, 8) + } + if mode == "delete" { + delLines = append(delLines, line) + old = append(old, line) + } else if mode == "add" { + insLines = append(insLines, line) + } else { + old = append(old, line) + } + } + + if len(insLines) > 0 || len(delLines) > 0 { + chunks = append(chunks, Chunk{ + OrigIndex: len(old) - len(delLines), + DelLines: delLines, + InsLines: insLines, + }) + } + + if index < len(lines) && lines[index] == "*** End of File" { + index++ + return old, chunks, index, true + } + return old, chunks, index, false +} + +func TextToPatch(text string, orig map[string]string) (Patch, int, error) { + text = strings.TrimSpace(text) + lines := strings.Split(text, "\n") + if len(lines) < 2 || !strings.HasPrefix(lines[0], "*** Begin Patch") || lines[len(lines)-1] != "*** End Patch" { + return Patch{}, 0, NewDiffError("Invalid patch text") + } + parser := NewParser(orig, lines) + parser.index = 1 + if err := parser.Parse(); err != nil { + return Patch{}, 0, err + } + return parser.patch, parser.fuzz, nil +} + +func IdentifyFilesNeeded(text string) []string { + text = strings.TrimSpace(text) + lines := strings.Split(text, "\n") + result := make(map[string]bool) + + for _, line := range lines { + if strings.HasPrefix(line, "*** Update File: ") { + result[line[len("*** Update File: "):]] = true + } + if strings.HasPrefix(line, "*** Delete File: ") { + result[line[len("*** Delete File: "):]] = true + } + } + + files := make([]string, 0, len(result)) + for file := range result { + files = append(files, file) + } + return files +} + +func IdentifyFilesAdded(text string) []string { + text = strings.TrimSpace(text) + lines := strings.Split(text, "\n") + result := make(map[string]bool) + + for _, line := range lines { + if strings.HasPrefix(line, "*** Add File: ") { + result[line[len("*** Add File: "):]] = true + } + } + + files := make([]string, 0, len(result)) + for file := range result { + files = append(files, file) + } + return files +} + +func getUpdatedFile(text string, action PatchAction, path string) (string, error) { + if action.Type != ActionUpdate { + return "", errors.New("Expected UPDATE action") + } + origLines := strings.Split(text, "\n") + destLines := make([]string, 0, len(origLines)) // Preallocate with capacity + origIndex := 0 + + for _, chunk := range action.Chunks { + if chunk.OrigIndex > len(origLines) { + return "", NewDiffError(fmt.Sprintf("%s: chunk.orig_index %d > len(lines) %d", path, chunk.OrigIndex, len(origLines))) + } + if origIndex > chunk.OrigIndex { + return "", NewDiffError(fmt.Sprintf("%s: orig_index %d > chunk.orig_index %d", path, origIndex, chunk.OrigIndex)) + } + destLines = append(destLines, origLines[origIndex:chunk.OrigIndex]...) + delta := chunk.OrigIndex - origIndex + origIndex += delta + + if len(chunk.InsLines) > 0 { + destLines = append(destLines, chunk.InsLines...) + } + origIndex += len(chunk.DelLines) + } + + destLines = append(destLines, origLines[origIndex:]...) + return strings.Join(destLines, "\n"), nil +} + +func PatchToCommit(patch Patch, orig map[string]string) (Commit, error) { + commit := Commit{Changes: make(map[string]FileChange, len(patch.Actions))} + for pathKey, action := range patch.Actions { + if action.Type == ActionDelete { + oldContent := orig[pathKey] + commit.Changes[pathKey] = FileChange{ + Type: ActionDelete, + OldContent: &oldContent, + } + } else if action.Type == ActionAdd { + commit.Changes[pathKey] = FileChange{ + Type: ActionAdd, + NewContent: action.NewFile, + } + } else if action.Type == ActionUpdate { + newContent, err := getUpdatedFile(orig[pathKey], action, pathKey) + if err != nil { + return Commit{}, err + } + oldContent := orig[pathKey] + fileChange := FileChange{ + Type: ActionUpdate, + OldContent: &oldContent, + NewContent: &newContent, + } + if action.MovePath != nil { + fileChange.MovePath = action.MovePath + } + commit.Changes[pathKey] = fileChange + } + } + return commit, nil +} + +func AssembleChanges(orig map[string]string, updatedFiles map[string]string) Commit { + commit := Commit{Changes: make(map[string]FileChange, len(updatedFiles))} + for p, newContent := range updatedFiles { + oldContent, exists := orig[p] + if exists && oldContent == newContent { + continue + } + + if exists && newContent != "" { + commit.Changes[p] = FileChange{ + Type: ActionUpdate, + OldContent: &oldContent, + NewContent: &newContent, + } + } else if newContent != "" { + commit.Changes[p] = FileChange{ + Type: ActionAdd, + NewContent: &newContent, + } + } else if exists { + commit.Changes[p] = FileChange{ + Type: ActionDelete, + OldContent: &oldContent, + } + } else { + return commit // Changed from panic to simply return current commit + } + } + return commit +} + +func LoadFiles(paths []string, openFn func(string) (string, error)) (map[string]string, error) { + orig := make(map[string]string, len(paths)) + for _, p := range paths { + content, err := openFn(p) + if err != nil { + return nil, fileError("Open", "File not found", p) + } + orig[p] = content + } + return orig, nil +} + +func ApplyCommit(commit Commit, writeFn func(string, string) error, removeFn func(string) error) error { + for p, change := range commit.Changes { + if change.Type == ActionDelete { + if err := removeFn(p); err != nil { + return err + } + } else if change.Type == ActionAdd { + if change.NewContent == nil { + return NewDiffError(fmt.Sprintf("Add action for %s has nil new_content", p)) + } + if err := writeFn(p, *change.NewContent); err != nil { + return err + } + } else if change.Type == ActionUpdate { + if change.NewContent == nil { + return NewDiffError(fmt.Sprintf("Update action for %s has nil new_content", p)) + } + if change.MovePath != nil { + if err := writeFn(*change.MovePath, *change.NewContent); err != nil { + return err + } + if err := removeFn(p); err != nil { + return err + } + } else { + if err := writeFn(p, *change.NewContent); err != nil { + return err + } + } + } + } + return nil +} + +func ProcessPatch(text string, openFn func(string) (string, error), writeFn func(string, string) error, removeFn func(string) error) (string, error) { + if !strings.HasPrefix(text, "*** Begin Patch") { + return "", NewDiffError("Patch must start with *** Begin Patch") + } + paths := IdentifyFilesNeeded(text) + orig, err := LoadFiles(paths, openFn) + if err != nil { + return "", err + } + + patch, fuzz, err := TextToPatch(text, orig) + if err != nil { + return "", err + } + + if fuzz > 0 { + return "", NewDiffError(fmt.Sprintf("Patch contains fuzzy matches (fuzz level: %d)", fuzz)) + } + + commit, err := PatchToCommit(patch, orig) + if err != nil { + return "", err + } + + if err := ApplyCommit(commit, writeFn, removeFn); err != nil { + return "", err + } + + return "Patch applied successfully", nil +} + +func OpenFile(p string) (string, error) { + data, err := os.ReadFile(p) + if err != nil { + return "", err + } + return string(data), nil +} + +func WriteFile(p string, content string) error { + if filepath.IsAbs(p) { + return NewDiffError("We do not support absolute paths.") + } + + dir := filepath.Dir(p) + if dir != "." { + if err := os.MkdirAll(dir, 0o755); err != nil { + return err + } + } + + return os.WriteFile(p, []byte(content), 0o644) +} + +func RemoveFile(p string) error { + return os.Remove(p) +} + +func ValidatePatch(patchText string, files map[string]string) (bool, string, error) { + if !strings.HasPrefix(patchText, "*** Begin Patch") { + return false, "Patch must start with *** Begin Patch", nil + } + + neededFiles := IdentifyFilesNeeded(patchText) + for _, filePath := range neededFiles { + if _, exists := files[filePath]; !exists { + return false, fmt.Sprintf("File not found: %s", filePath), nil + } + } + + patch, fuzz, err := TextToPatch(patchText, files) + if err != nil { + return false, err.Error(), nil + } + + if fuzz > 0 { + return false, fmt.Sprintf("Patch contains fuzzy matches (fuzz level: %d)", fuzz), nil + } + + _, err = PatchToCommit(patch, files) + if err != nil { + return false, err.Error(), nil + } + + return true, "Patch is valid", nil +} -- cgit v1.2.3 From 72afeb9f54cee8e248093a52ac0779441c79aea3 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 18 Apr 2025 21:24:35 +0200 Subject: small fixes --- internal/diff/patch.go | 33 ++++++++++--------- internal/llm/agent/agent.go | 2 ++ internal/llm/tools/edit.go | 6 ++-- internal/tui/components/chat/editor.go | 9 +++-- internal/tui/components/chat/list.go | 49 ++++++++++++++++++++-------- internal/tui/components/chat/message.go | 1 + internal/tui/components/dialog/permission.go | 22 +++++++------ internal/tui/components/dialog/session.go | 48 ++++++++++++++------------- internal/tui/page/chat.go | 7 ---- internal/tui/tui.go | 4 ++- 10 files changed, 106 insertions(+), 75 deletions(-) (limited to 'internal/diff/patch.go') diff --git a/internal/diff/patch.go b/internal/diff/patch.go index aab0f956d..49242f7ef 100644 --- a/internal/diff/patch.go +++ b/internal/diff/patch.go @@ -91,11 +91,9 @@ func (p *Parser) isDone(prefixes []string) bool { if p.index >= len(p.lines) { return true } - if prefixes != nil { - for _, prefix := range prefixes { - if strings.HasPrefix(p.lines[p.index], prefix) { - return true - } + for _, prefix := range prefixes { + if strings.HasPrefix(p.lines[p.index], prefix) { + return true } } return false @@ -219,7 +217,7 @@ func (p *Parser) parseUpdateFile(text string) (PatchAction, error) { sectionStr = p.lines[p.index] p.index++ } - if !(defStr != "" || sectionStr != "" || index == 0) { + if defStr == "" && sectionStr == "" && index != 0 { return action, NewDiffError(fmt.Sprintf("Invalid Line:\n%s", p.lines[p.index])) } if strings.TrimSpace(defStr) != "" { @@ -433,12 +431,13 @@ func peekNextSection(lines []string, initialIndex int) ([]string, []Chunk, int, delLines = make([]string, 0, 8) insLines = make([]string, 0, 8) } - if mode == "delete" { + switch mode { + case "delete": delLines = append(delLines, line) old = append(old, line) - } else if mode == "add" { + case "add": insLines = append(insLines, line) - } else { + default: old = append(old, line) } } @@ -513,7 +512,7 @@ func IdentifyFilesAdded(text string) []string { func getUpdatedFile(text string, action PatchAction, path string) (string, error) { if action.Type != ActionUpdate { - return "", errors.New("Expected UPDATE action") + return "", errors.New("expected UPDATE action") } origLines := strings.Split(text, "\n") destLines := make([]string, 0, len(origLines)) // Preallocate with capacity @@ -543,18 +542,19 @@ func getUpdatedFile(text string, action PatchAction, path string) (string, error func PatchToCommit(patch Patch, orig map[string]string) (Commit, error) { commit := Commit{Changes: make(map[string]FileChange, len(patch.Actions))} for pathKey, action := range patch.Actions { - if action.Type == ActionDelete { + switch action.Type { + case ActionDelete: oldContent := orig[pathKey] commit.Changes[pathKey] = FileChange{ Type: ActionDelete, OldContent: &oldContent, } - } else if action.Type == ActionAdd { + case ActionAdd: commit.Changes[pathKey] = FileChange{ Type: ActionAdd, NewContent: action.NewFile, } - } else if action.Type == ActionUpdate { + case ActionUpdate: newContent, err := getUpdatedFile(orig[pathKey], action, pathKey) if err != nil { return Commit{}, err @@ -619,18 +619,19 @@ func LoadFiles(paths []string, openFn func(string) (string, error)) (map[string] func ApplyCommit(commit Commit, writeFn func(string, string) error, removeFn func(string) error) error { for p, change := range commit.Changes { - if change.Type == ActionDelete { + switch change.Type { + case ActionDelete: if err := removeFn(p); err != nil { return err } - } else if change.Type == ActionAdd { + case ActionAdd: if change.NewContent == nil { return NewDiffError(fmt.Sprintf("Add action for %s has nil new_content", p)) } if err := writeFn(p, *change.NewContent); err != nil { return err } - } else if change.Type == ActionUpdate { + case ActionUpdate: if change.NewContent == nil { return NewDiffError(fmt.Sprintf("Update action for %s has nil new_content", p)) } diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 5e9785991..7542d9adf 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -221,6 +221,8 @@ func (a *agent) processGeneration(ctx context.Context, sessionID, content string agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory) if err != nil { if errors.Is(err, context.Canceled) { + agentMessage.AddFinish(message.FinishReasonCanceled) + a.messages.Update(context.Background(), agentMessage) return a.err(ErrRequestCancelled) } return a.err(fmt.Errorf("failed to process events: %w", err)) diff --git a/internal/llm/tools/edit.go b/internal/llm/tools/edit.go index 6a1616010..83cec5dba 100644 --- a/internal/llm/tools/edit.go +++ b/internal/llm/tools/edit.go @@ -141,20 +141,20 @@ func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) if params.OldString == "" { response, err = e.createNewFile(ctx, params.FilePath, params.NewString) if err != nil { - return response, nil + return response, err } } if params.NewString == "" { response, err = e.deleteContent(ctx, params.FilePath, params.OldString) if err != nil { - return response, nil + return response, err } } response, err = e.replaceContent(ctx, params.FilePath, params.OldString, params.NewString) if err != nil { - return response, nil + return response, err } if response.IsError { // Return early if there was an error during content replacement diff --git a/internal/tui/components/chat/editor.go b/internal/tui/components/chat/editor.go index 537ef392c..963fbbdbf 100644 --- a/internal/tui/components/chat/editor.go +++ b/internal/tui/components/chat/editor.go @@ -21,6 +21,8 @@ type editorCmp struct { textarea textarea.Model } +type FocusEditorMsg bool + type focusedEditorKeyMaps struct { Send key.Binding OpenEditor key.Binding @@ -112,7 +114,6 @@ func (m *editorCmp) send() tea.Cmd { util.CmdHandler(SendMsg{ Text: value, }), - util.CmdHandler(EditorFocusMsg(false)), ) } @@ -124,9 +125,13 @@ func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.session = msg } return m, nil + case FocusEditorMsg: + if msg { + m.textarea.Focus() + return m, tea.Batch(textarea.Blink, util.CmdHandler(EditorFocusMsg(true))) + } case tea.KeyMsg: if key.Matches(msg, focusedKeyMaps.OpenEditor) { - m.textarea.Blur() return m, openEditor() } // if the key does not match any binding, return diff --git a/internal/tui/components/chat/list.go b/internal/tui/components/chat/list.go index f95b53731..b7703e2cc 100644 --- a/internal/tui/components/chat/list.go +++ b/internal/tui/components/chat/list.go @@ -22,6 +22,10 @@ import ( "github.com/kujtimiihoxha/opencode/internal/tui/util" ) +type cacheItem struct { + width int + content []uiMessage +} type messagesCmp struct { app *app.App width, height int @@ -32,8 +36,9 @@ type messagesCmp struct { uiMessages []uiMessage currentMsgID string mutex sync.Mutex - cachedContent map[string][]uiMessage + cachedContent map[string]cacheItem spinner spinner.Model + lastUpdate time.Time rendering bool } type renderFinishedMsg struct{} @@ -44,6 +49,8 @@ func (m *messagesCmp) Init() tea.Cmd { func (m *messagesCmp) preloadSessions() tea.Cmd { return func() tea.Msg { + m.mutex.Lock() + defer m.mutex.Unlock() sessions, err := m.app.Sessions.List(context.Background()) if err != nil { return util.ReportError(err)() @@ -67,13 +74,13 @@ func (m *messagesCmp) preloadSessions() tea.Cmd { } logging.Debug("preloaded sessions") - return nil + return func() tea.Msg { + return renderFinishedMsg{} + } } } func (m *messagesCmp) cacheSessionMessages(messages []message.Message, width int) { - m.mutex.Lock() - defer m.mutex.Unlock() pos := 0 if m.width == 0 { return @@ -87,7 +94,10 @@ func (m *messagesCmp) cacheSessionMessages(messages []message.Message, width int width, pos, ) - m.cachedContent[msg.ID] = []uiMessage{userMsg} + m.cachedContent[msg.ID] = cacheItem{ + width: width, + content: []uiMessage{userMsg}, + } pos += userMsg.height + 1 // + 1 for spacing case message.Assistant: assistantMessages := renderAssistantMessage( @@ -102,7 +112,10 @@ func (m *messagesCmp) cacheSessionMessages(messages []message.Message, width int for _, msg := range assistantMessages { pos += msg.height + 1 // + 1 for spacing } - m.cachedContent[msg.ID] = assistantMessages + m.cachedContent[msg.ID] = cacheItem{ + width: width, + content: assistantMessages, + } } } } @@ -223,8 +236,8 @@ func (m *messagesCmp) renderView() { for inx, msg := range m.messages { switch msg.Role { case message.User: - if messages, ok := m.cachedContent[msg.ID]; ok { - m.uiMessages = append(m.uiMessages, messages...) + if cache, ok := m.cachedContent[msg.ID]; ok && cache.width == m.width { + m.uiMessages = append(m.uiMessages, cache.content...) continue } userMsg := renderUserMessage( @@ -234,11 +247,14 @@ func (m *messagesCmp) renderView() { pos, ) m.uiMessages = append(m.uiMessages, userMsg) - m.cachedContent[msg.ID] = []uiMessage{userMsg} + m.cachedContent[msg.ID] = cacheItem{ + width: m.width, + content: []uiMessage{userMsg}, + } pos += userMsg.height + 1 // + 1 for spacing case message.Assistant: - if messages, ok := m.cachedContent[msg.ID]; ok { - m.uiMessages = append(m.uiMessages, messages...) + if cache, ok := m.cachedContent[msg.ID]; ok && cache.width == m.width { + m.uiMessages = append(m.uiMessages, cache.content...) continue } assistantMessages := renderAssistantMessage( @@ -254,7 +270,10 @@ func (m *messagesCmp) renderView() { m.uiMessages = append(m.uiMessages, msg) pos += msg.height + 1 // + 1 for spacing } - m.cachedContent[msg.ID] = assistantMessages + m.cachedContent[msg.ID] = cacheItem{ + width: m.width, + content: assistantMessages, + } } } @@ -418,6 +437,10 @@ func (m *messagesCmp) SetSize(width, height int) tea.Cmd { m.height = height m.viewport.Width = width m.viewport.Height = height - 2 + for _, msg := range m.messages { + delete(m.cachedContent, msg.ID) + } + m.uiMessages = make([]uiMessage, 0) m.renderView() return m.preloadSessions() } @@ -456,7 +479,7 @@ func NewMessagesCmp(app *app.App) tea.Model { return &messagesCmp{ app: app, writingMode: true, - cachedContent: make(map[string][]uiMessage), + cachedContent: make(map[string]cacheItem), viewport: viewport.New(0, 0), spinner: s, } diff --git a/internal/tui/components/chat/message.go b/internal/tui/components/chat/message.go index be6c7ce50..7a840b4ec 100644 --- a/internal/tui/components/chat/message.go +++ b/internal/tui/components/chat/message.go @@ -389,6 +389,7 @@ func renderToolResponse(toolCall message.ToolCall, response message.ToolResult, errContent := fmt.Sprintf("Error: %s", strings.ReplaceAll(response.Content, "\n", " ")) errContent = ansi.Truncate(errContent, width-1, "...") return styles.BaseStyle. + Width(width). Foreground(styles.Error). Render(errContent) } diff --git a/internal/tui/components/dialog/permission.go b/internal/tui/components/dialog/permission.go index f83472e68..1f8df21a0 100644 --- a/internal/tui/components/dialog/permission.go +++ b/internal/tui/components/dialog/permission.go @@ -40,7 +40,8 @@ type PermissionDialogCmp interface { } type permissionsMapping struct { - LeftRight key.Binding + Left key.Binding + Right key.Binding EnterSpace key.Binding Allow key.Binding AllowSession key.Binding @@ -49,9 +50,13 @@ type permissionsMapping struct { } var permissionsKeys = permissionsMapping{ - LeftRight: key.NewBinding( - key.WithKeys("left", "right"), - key.WithHelp("←/→", "switch options"), + Left: key.NewBinding( + key.WithKeys("left"), + key.WithHelp("←", "switch options"), + ), + Right: key.NewBinding( + key.WithKeys("right"), + key.WithHelp("→", "switch options"), ), EnterSpace: key.NewBinding( key.WithKeys("enter", " "), @@ -104,21 +109,18 @@ func (p *permissionDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { p.diffCache = make(map[string]string) case tea.KeyMsg: switch { - case key.Matches(msg, permissionsKeys.LeftRight) || key.Matches(msg, permissionsKeys.Tab): - // Change selected option + case key.Matches(msg, permissionsKeys.Right) || key.Matches(msg, permissionsKeys.Tab): p.selectedOption = (p.selectedOption + 1) % 3 return p, nil + case key.Matches(msg, permissionsKeys.Left): + p.selectedOption = (p.selectedOption + 2) % 3 case key.Matches(msg, permissionsKeys.EnterSpace): - // Select current option return p, p.selectCurrentOption() case key.Matches(msg, permissionsKeys.Allow): - // Select Allow return p, util.CmdHandler(PermissionResponseMsg{Action: PermissionAllow, Permission: p.permission}) case key.Matches(msg, permissionsKeys.AllowSession): - // Select Allow for session return p, util.CmdHandler(PermissionResponseMsg{Action: PermissionAllowForSession, Permission: p.permission}) case key.Matches(msg, permissionsKeys.Deny): - // Select Deny return p, util.CmdHandler(PermissionResponseMsg{Action: PermissionDeny, Permission: p.permission}) default: // Pass other keys to viewport diff --git a/internal/tui/components/dialog/session.go b/internal/tui/components/dialog/session.go index d8c859c49..060875f91 100644 --- a/internal/tui/components/dialog/session.go +++ b/internal/tui/components/dialog/session.go @@ -27,20 +27,20 @@ type SessionDialog interface { } type sessionDialogCmp struct { - sessions []session.Session - selectedIdx int - width int - height int + sessions []session.Session + selectedIdx int + width int + height int selectedSessionID string } type sessionKeyMap struct { - Up key.Binding - Down key.Binding - Enter key.Binding - Escape key.Binding - J key.Binding - K key.Binding + Up key.Binding + Down key.Binding + Enter key.Binding + Escape key.Binding + J key.Binding + K key.Binding } var sessionKeys = sessionKeyMap{ @@ -128,7 +128,7 @@ func (s *sessionDialogCmp) View() string { // Build the session list sessionItems := make([]string, 0, maxVisibleSessions) startIdx := 0 - + // If we have more sessions than can be displayed, adjust the start index if len(s.sessions) > maxVisibleSessions { // Center the selected item when possible @@ -145,30 +145,31 @@ func (s *sessionDialogCmp) View() string { for i := startIdx; i < endIdx; i++ { sess := s.sessions[i] itemStyle := styles.BaseStyle.Width(maxWidth) - + if i == s.selectedIdx { itemStyle = itemStyle. Background(styles.PrimaryColor). Foreground(styles.Background). Bold(true) } - + sessionItems = append(sessionItems, itemStyle.Padding(0, 1).Render(sess.Title)) } title := styles.BaseStyle. Foreground(styles.PrimaryColor). Bold(true). + Width(maxWidth). Padding(0, 1). Render("Switch Session") content := lipgloss.JoinVertical( lipgloss.Left, title, - styles.BaseStyle.Render(""), - lipgloss.JoinVertical(lipgloss.Left, sessionItems...), - styles.BaseStyle.Render(""), - styles.BaseStyle.Foreground(styles.ForgroundDim).Render("↑/k: up ↓/j: down enter: select esc: cancel"), + styles.BaseStyle.Width(maxWidth).Render(""), + styles.BaseStyle.Width(maxWidth).Render(lipgloss.JoinVertical(lipgloss.Left, sessionItems...)), + styles.BaseStyle.Width(maxWidth).Render(""), + styles.BaseStyle.Width(maxWidth).Padding(0, 1).Foreground(styles.ForgroundDim).Render("↑/k: up ↓/j: down enter: select esc: cancel"), ) return styles.BaseStyle.Padding(1, 2). @@ -185,7 +186,7 @@ func (s *sessionDialogCmp) BindingKeys() []key.Binding { func (s *sessionDialogCmp) SetSessions(sessions []session.Session) { s.sessions = sessions - + // If we have a selected session ID, find its index if s.selectedSessionID != "" { for i, sess := range sessions { @@ -195,14 +196,14 @@ func (s *sessionDialogCmp) SetSessions(sessions []session.Session) { } } } - + // Default to first session if selected not found s.selectedIdx = 0 } func (s *sessionDialogCmp) SetSelectedSession(sessionID string) { s.selectedSessionID = sessionID - + // Update the selected index if sessions are already loaded if len(s.sessions) > 0 { for i, sess := range s.sessions { @@ -217,8 +218,9 @@ func (s *sessionDialogCmp) SetSelectedSession(sessionID string) { // NewSessionDialogCmp creates a new session switching dialog func NewSessionDialogCmp() SessionDialog { return &sessionDialogCmp{ - sessions: []session.Session{}, - selectedIdx: 0, + sessions: []session.Session{}, + selectedIdx: 0, selectedSessionID: "", } -} \ No newline at end of file +} + diff --git a/internal/tui/page/chat.go b/internal/tui/page/chat.go index b99dc3dfe..ef826e9a3 100644 --- a/internal/tui/page/chat.go +++ b/internal/tui/page/chat.go @@ -43,13 +43,6 @@ func (p *chatPage) Init() tea.Cmd { cmds := []tea.Cmd{ p.layout.Init(), } - - sessions, _ := p.app.Sessions.List(context.Background()) - if len(sessions) > 0 { - p.session = sessions[0] - cmd := p.setSidebar() - cmds = append(cmds, util.CmdHandler(chat.SessionSelectedMsg(p.session)), cmd) - } return tea.Batch(cmds...) } diff --git a/internal/tui/tui.go b/internal/tui/tui.go index f3a7298cf..2a9ed0d70 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -163,6 +163,7 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { a.showPermissions = true return a, a.permissions.SetPermissions(msg.Payload) case dialog.PermissionResponseMsg: + var cmd tea.Cmd switch msg.Action { case dialog.PermissionAllow: a.app.Permissions.Grant(msg.Permission) @@ -170,9 +171,10 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { a.app.Permissions.GrantPersistant(msg.Permission) case dialog.PermissionDeny: a.app.Permissions.Deny(msg.Permission) + cmd = util.CmdHandler(chat.FocusEditorMsg(true)) } a.showPermissions = false - return a, nil + return a, cmd case page.PageChangeMsg: return a, a.moveToPage(msg.ID) -- cgit v1.2.3