From 0130bde1edabb81d82dbce9d2d562966d2dee133 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 14 Apr 2025 14:09:17 +0200 Subject: remove node dependency and implement diff format --- internal/diff/diff.go | 995 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 995 insertions(+) create mode 100644 internal/diff/diff.go (limited to 'internal/diff') diff --git a/internal/diff/diff.go b/internal/diff/diff.go new file mode 100644 index 000000000..4e6aa9f5b --- /dev/null +++ b/internal/diff/diff.go @@ -0,0 +1,995 @@ +package diff + +import ( + "bytes" + "fmt" + "io" + "regexp" + "strconv" + "strings" + "time" + + "github.com/alecthomas/chroma/v2" + "github.com/alecthomas/chroma/v2/formatters" + "github.com/alecthomas/chroma/v2/lexers" + "github.com/alecthomas/chroma/v2/styles" + "github.com/charmbracelet/lipgloss" + "github.com/charmbracelet/x/ansi" + "github.com/sergi/go-diff/diffmatchpatch" +) + +// LineType represents the kind of line in a diff. +type LineType int + +const ( + // LineContext represents a line that exists in both the old and new file. + LineContext LineType = iota + // LineAdded represents a line added in the new file. + LineAdded + // LineRemoved represents a line removed from the old file. + LineRemoved +) + +// DiffLine represents a single line in a diff, either from the old file, +// the new file, or a context line. +type DiffLine struct { + OldLineNo int // Line number in the old file (0 for added lines) + NewLineNo int // Line number in the new file (0 for removed lines) + Kind LineType // Type of line (added, removed, context) + Content string // Content of the line +} + +// Hunk represents a section of changes in a diff. +type Hunk struct { + Header string + Lines []DiffLine +} + +// DiffResult contains the parsed result of a diff. +type DiffResult struct { + OldFile string + NewFile string + Hunks []Hunk +} + +// HunkDelta represents the change statistics for a hunk. +type HunkDelta struct { + StartLine1 int + LineCount1 int + StartLine2 int + LineCount2 int +} + +// linePair represents a pair of lines to be displayed side by side. +type linePair struct { + left *DiffLine + right *DiffLine +} + +// ------------------------------------------------------------------------- +// Style Configuration with Option Pattern +// ------------------------------------------------------------------------- + +// StyleConfig defines styling for diff rendering. +type StyleConfig struct { + RemovedLineBg lipgloss.Color + AddedLineBg lipgloss.Color + ContextLineBg lipgloss.Color + HunkLineBg lipgloss.Color + HunkLineFg lipgloss.Color + RemovedFg lipgloss.Color + AddedFg lipgloss.Color + LineNumberFg lipgloss.Color + HighlightStyle string + RemovedHighlightBg lipgloss.Color + AddedHighlightBg lipgloss.Color + RemovedLineNumberBg lipgloss.Color + AddedLineNamerBg lipgloss.Color + RemovedHighlightFg lipgloss.Color + AddedHighlightFg lipgloss.Color +} + +// StyleOption defines a function that modifies a StyleConfig. +type StyleOption func(*StyleConfig) + +// NewStyleConfig creates a StyleConfig with default values and applies any provided options. +func NewStyleConfig(opts ...StyleOption) StyleConfig { + // Set default values + config := StyleConfig{ + RemovedLineBg: lipgloss.Color("#3A3030"), + AddedLineBg: lipgloss.Color("#303A30"), + ContextLineBg: lipgloss.Color("#212121"), + HunkLineBg: lipgloss.Color("#2A2822"), + HunkLineFg: lipgloss.Color("#D4AF37"), + RemovedFg: lipgloss.Color("#7C4444"), + AddedFg: lipgloss.Color("#478247"), + LineNumberFg: lipgloss.Color("#888888"), + HighlightStyle: "dracula", + RemovedHighlightBg: lipgloss.Color("#612726"), + AddedHighlightBg: lipgloss.Color("#256125"), + RemovedLineNumberBg: lipgloss.Color("#332929"), + AddedLineNamerBg: lipgloss.Color("#293229"), + RemovedHighlightFg: lipgloss.Color("#FADADD"), + AddedHighlightFg: lipgloss.Color("#DAFADA"), + } + + // Apply all provided options + for _, opt := range opts { + opt(&config) + } + + return config +} + +// WithRemovedLineBg sets the background color for removed lines. +func WithRemovedLineBg(color lipgloss.Color) StyleOption { + return func(s *StyleConfig) { + s.RemovedLineBg = color + } +} + +// WithAddedLineBg sets the background color for added lines. +func WithAddedLineBg(color lipgloss.Color) StyleOption { + return func(s *StyleConfig) { + s.AddedLineBg = color + } +} + +// WithContextLineBg sets the background color for context lines. +func WithContextLineBg(color lipgloss.Color) StyleOption { + return func(s *StyleConfig) { + s.ContextLineBg = color + } +} + +// WithRemovedFg sets the foreground color for removed line markers. +func WithRemovedFg(color lipgloss.Color) StyleOption { + return func(s *StyleConfig) { + s.RemovedFg = color + } +} + +// WithAddedFg sets the foreground color for added line markers. +func WithAddedFg(color lipgloss.Color) StyleOption { + return func(s *StyleConfig) { + s.AddedFg = color + } +} + +// WithLineNumberFg sets the foreground color for line numbers. +func WithLineNumberFg(color lipgloss.Color) StyleOption { + return func(s *StyleConfig) { + s.LineNumberFg = color + } +} + +// WithHighlightStyle sets the syntax highlighting style. +func WithHighlightStyle(style string) StyleOption { + return func(s *StyleConfig) { + s.HighlightStyle = style + } +} + +// WithRemovedHighlightColors sets the colors for highlighted parts in removed text. +func WithRemovedHighlightColors(bg, fg lipgloss.Color) StyleOption { + return func(s *StyleConfig) { + s.RemovedHighlightBg = bg + s.RemovedHighlightFg = fg + } +} + +// WithAddedHighlightColors sets the colors for highlighted parts in added text. +func WithAddedHighlightColors(bg, fg lipgloss.Color) StyleOption { + return func(s *StyleConfig) { + s.AddedHighlightBg = bg + s.AddedHighlightFg = fg + } +} + +// WithRemovedLineNumberBg sets the background color for removed line numbers. +func WithRemovedLineNumberBg(color lipgloss.Color) StyleOption { + return func(s *StyleConfig) { + s.RemovedLineNumberBg = color + } +} + +// WithAddedLineNumberBg sets the background color for added line numbers. +func WithAddedLineNumberBg(color lipgloss.Color) StyleOption { + return func(s *StyleConfig) { + s.AddedLineNamerBg = color + } +} + +func WithHunkLineBg(color lipgloss.Color) StyleOption { + return func(s *StyleConfig) { + s.HunkLineBg = color + } +} + +func WithHunkLineFg(color lipgloss.Color) StyleOption { + return func(s *StyleConfig) { + s.HunkLineFg = color + } +} + +// ------------------------------------------------------------------------- +// Parse Options with Option Pattern +// ------------------------------------------------------------------------- + +// ParseConfig configures the behavior of diff parsing. +type ParseConfig struct { + ContextSize int // Number of context lines to include +} + +// ParseOption defines a function that modifies a ParseConfig. +type ParseOption func(*ParseConfig) + +// NewParseConfig creates a ParseConfig with default values and applies any provided options. +func NewParseConfig(opts ...ParseOption) ParseConfig { + // Set default values + config := ParseConfig{ + ContextSize: 3, + } + + // Apply all provided options + for _, opt := range opts { + opt(&config) + } + + return config +} + +// WithContextSize sets the number of context lines to include. +func WithContextSize(size int) ParseOption { + return func(p *ParseConfig) { + if size >= 0 { + p.ContextSize = size + } + } +} + +// ------------------------------------------------------------------------- +// Side-by-Side Options with Option Pattern +// ------------------------------------------------------------------------- + +// SideBySideConfig configures the rendering of side-by-side diffs. +type SideBySideConfig struct { + TotalWidth int + Style StyleConfig +} + +// SideBySideOption defines a function that modifies a SideBySideConfig. +type SideBySideOption func(*SideBySideConfig) + +// NewSideBySideConfig creates a SideBySideConfig with default values and applies any provided options. +func NewSideBySideConfig(opts ...SideBySideOption) SideBySideConfig { + // Set default values + config := SideBySideConfig{ + TotalWidth: 160, // Default width for side-by-side view + Style: NewStyleConfig(), + } + + // Apply all provided options + for _, opt := range opts { + opt(&config) + } + + return config +} + +// WithTotalWidth sets the total width for side-by-side view. +func WithTotalWidth(width int) SideBySideOption { + return func(s *SideBySideConfig) { + if width > 0 { + s.TotalWidth = width + } + } +} + +// WithStyle sets the styling configuration. +func WithStyle(style StyleConfig) SideBySideOption { + return func(s *SideBySideConfig) { + s.Style = style + } +} + +// WithStyleOptions applies the specified style options. +func WithStyleOptions(opts ...StyleOption) SideBySideOption { + return func(s *SideBySideConfig) { + s.Style = NewStyleConfig(opts...) + } +} + +// ------------------------------------------------------------------------- +// Diff Parsing and Generation +// ------------------------------------------------------------------------- + +// ParseUnifiedDiff parses a unified diff format string into structured data. +func ParseUnifiedDiff(diff string) (DiffResult, error) { + var result DiffResult + var currentHunk *Hunk + + hunkHeaderRe := regexp.MustCompile(`^@@ -(\d+),?(\d*) \+(\d+),?(\d*) @@`) + lines := strings.Split(diff, "\n") + + var oldLine, newLine int + inFileHeader := true + + for _, line := range lines { + // Parse the file headers + if inFileHeader { + if strings.HasPrefix(line, "--- a/") { + result.OldFile = strings.TrimPrefix(line, "--- a/") + continue + } + if strings.HasPrefix(line, "+++ b/") { + result.NewFile = strings.TrimPrefix(line, "+++ b/") + inFileHeader = false + continue + } + } + + // Parse hunk headers + if matches := hunkHeaderRe.FindStringSubmatch(line); matches != nil { + if currentHunk != nil { + result.Hunks = append(result.Hunks, *currentHunk) + } + currentHunk = &Hunk{ + Header: line, + Lines: []DiffLine{}, + } + + oldStart, _ := strconv.Atoi(matches[1]) + newStart, _ := strconv.Atoi(matches[3]) + oldLine = oldStart + newLine = newStart + + continue + } + + if currentHunk == nil { + continue + } + + if len(line) > 0 { + // Process the line based on its prefix + switch line[0] { + case '+': + currentHunk.Lines = append(currentHunk.Lines, DiffLine{ + OldLineNo: 0, + NewLineNo: newLine, + Kind: LineAdded, + Content: line[1:], // skip '+' + }) + newLine++ + case '-': + currentHunk.Lines = append(currentHunk.Lines, DiffLine{ + OldLineNo: oldLine, + NewLineNo: 0, + Kind: LineRemoved, + Content: line[1:], // skip '-' + }) + oldLine++ + default: + currentHunk.Lines = append(currentHunk.Lines, DiffLine{ + OldLineNo: oldLine, + NewLineNo: newLine, + Kind: LineContext, + Content: line, + }) + oldLine++ + newLine++ + } + } else { + // Handle empty lines + currentHunk.Lines = append(currentHunk.Lines, DiffLine{ + OldLineNo: oldLine, + NewLineNo: newLine, + Kind: LineContext, + Content: "", + }) + oldLine++ + newLine++ + } + } + + // Add the last hunk if there is one + if currentHunk != nil { + result.Hunks = append(result.Hunks, *currentHunk) + } + + return result, nil +} + +// HighlightIntralineChanges updates the content of lines in a hunk to show +// character-level differences within lines. +func HighlightIntralineChanges(h *Hunk, style StyleConfig) { + var updated []DiffLine + dmp := diffmatchpatch.New() + + for i := 0; i < len(h.Lines); i++ { + // Look for removed line followed by added line, which might have similar content + if i+1 < len(h.Lines) && + h.Lines[i].Kind == LineRemoved && + h.Lines[i+1].Kind == LineAdded { + + oldLine := h.Lines[i] + newLine := h.Lines[i+1] + + // Find character-level differences + patches := dmp.DiffMain(oldLine.Content, newLine.Content, false) + patches = dmp.DiffCleanupEfficiency(patches) + patches = dmp.DiffCleanupSemantic(patches) + + // Apply highlighting to the differences + oldLine.Content = colorizeSegments(patches, true, style) + newLine.Content = colorizeSegments(patches, false, style) + + updated = append(updated, oldLine, newLine) + i++ // Skip the next line as we've already processed it + } else { + updated = append(updated, h.Lines[i]) + } + } + + h.Lines = updated +} + +// colorizeSegments applies styles to the character-level diff segments. +func colorizeSegments(diffs []diffmatchpatch.Diff, isOld bool, style StyleConfig) string { + var buf strings.Builder + + removeBg := lipgloss.NewStyle(). + Background(style.RemovedHighlightBg). + Foreground(style.RemovedHighlightFg) + + addBg := lipgloss.NewStyle(). + Background(style.AddedHighlightBg). + Foreground(style.AddedHighlightFg) + + removedLineStyle := lipgloss.NewStyle().Background(style.RemovedLineBg) + addedLineStyle := lipgloss.NewStyle().Background(style.AddedLineBg) + + afterBg := false + + for _, d := range diffs { + switch d.Type { + case diffmatchpatch.DiffEqual: + // Handle text that's the same in both versions + if afterBg { + if isOld { + buf.WriteString(removedLineStyle.Render(d.Text)) + } else { + buf.WriteString(addedLineStyle.Render(d.Text)) + } + } else { + buf.WriteString(d.Text) + } + case diffmatchpatch.DiffDelete: + // Handle deleted text (only show in old version) + if isOld { + buf.WriteString(removeBg.Render(d.Text)) + afterBg = true + } + case diffmatchpatch.DiffInsert: + // Handle inserted text (only show in new version) + if !isOld { + buf.WriteString(addBg.Render(d.Text)) + afterBg = true + } + } + } + + return buf.String() +} + +// pairLines converts a flat list of diff lines to pairs for side-by-side display. +func pairLines(lines []DiffLine) []linePair { + var pairs []linePair + i := 0 + + for i < len(lines) { + switch lines[i].Kind { + case LineRemoved: + // Check if the next line is an addition, if so pair them + if i+1 < len(lines) && lines[i+1].Kind == LineAdded { + pairs = append(pairs, linePair{left: &lines[i], right: &lines[i+1]}) + i += 2 + } else { + pairs = append(pairs, linePair{left: &lines[i], right: nil}) + i++ + } + case LineAdded: + pairs = append(pairs, linePair{left: nil, right: &lines[i]}) + i++ + case LineContext: + pairs = append(pairs, linePair{left: &lines[i], right: &lines[i]}) + i++ + } + } + + return pairs +} + +// ------------------------------------------------------------------------- +// Syntax Highlighting +// ------------------------------------------------------------------------- + +// SyntaxHighlight applies syntax highlighting to a string based on the file extension. +func SyntaxHighlight(w io.Writer, source, fileName, formatter string, bg lipgloss.TerminalColor) error { + // Determine the language lexer to use + l := lexers.Match(fileName) + if l == nil { + l = lexers.Analyse(source) + } + if l == nil { + l = lexers.Fallback + } + l = chroma.Coalesce(l) + + // Get the formatter + f := formatters.Get(formatter) + if f == nil { + f = formatters.Fallback + } + + // Get the style + s := styles.Get("dracula") + if s == nil { + s = styles.Fallback + } + + // Modify the style to use the provided background + s, err := s.Builder().Transform( + func(t chroma.StyleEntry) chroma.StyleEntry { + r, g, b, _ := bg.RGBA() + ru8 := uint8(r >> 8) + gu8 := uint8(g >> 8) + bu8 := uint8(b >> 8) + t.Background = chroma.NewColour(ru8, gu8, bu8) + return t + }, + ).Build() + if err != nil { + s = styles.Fallback + } + + // Tokenize and format + it, err := l.Tokenise(nil, source) + if err != nil { + return err + } + + return f.Format(w, s, it) +} + +// highlightLine applies syntax highlighting to a single line. +func highlightLine(fileName string, line string, bg lipgloss.TerminalColor) string { + var buf bytes.Buffer + err := SyntaxHighlight(&buf, line, fileName, "terminal16m", bg) + if err != nil { + return line + } + return buf.String() +} + +// createStyles generates the lipgloss styles needed for rendering diffs. +func createStyles(config StyleConfig) (removedLineStyle, addedLineStyle, contextLineStyle, lineNumberStyle lipgloss.Style) { + removedLineStyle = lipgloss.NewStyle().Background(config.RemovedLineBg) + addedLineStyle = lipgloss.NewStyle().Background(config.AddedLineBg) + contextLineStyle = lipgloss.NewStyle().Background(config.ContextLineBg) + lineNumberStyle = lipgloss.NewStyle().Foreground(config.LineNumberFg) + + return +} + +// renderLeftColumn formats the left side of a side-by-side diff. +func renderLeftColumn(fileName string, dl *DiffLine, colWidth int, styles StyleConfig) string { + if dl == nil { + contextLineStyle := lipgloss.NewStyle().Background(styles.ContextLineBg) + return contextLineStyle.Width(colWidth).Render("") + } + + removedLineStyle, _, contextLineStyle, lineNumberStyle := createStyles(styles) + + var marker string + var bgStyle lipgloss.Style + + switch dl.Kind { + case LineRemoved: + marker = removedLineStyle.Foreground(styles.RemovedFg).Render("-") + bgStyle = removedLineStyle + lineNumberStyle = lineNumberStyle.Foreground(styles.RemovedFg).Background(styles.RemovedLineNumberBg) + case LineAdded: + marker = "?" + bgStyle = contextLineStyle + case LineContext: + marker = contextLineStyle.Render(" ") + bgStyle = contextLineStyle + } + + lineNum := "" + if dl.OldLineNo > 0 { + lineNum = fmt.Sprintf("%6d", dl.OldLineNo) + } + + prefix := lineNumberStyle.Render(lineNum + " " + marker) + content := highlightLine(fileName, dl.Content, bgStyle.GetBackground()) + + if dl.Kind == LineRemoved { + content = bgStyle.Render(" ") + content + } + + lineText := prefix + content + return bgStyle.MaxHeight(1).Width(colWidth).Render(ansi.Truncate(lineText, colWidth, "...")) +} + +// renderRightColumn formats the right side of a side-by-side diff. +func renderRightColumn(fileName string, dl *DiffLine, colWidth int, styles StyleConfig) string { + if dl == nil { + contextLineStyle := lipgloss.NewStyle().Background(styles.ContextLineBg) + return contextLineStyle.Width(colWidth).Render("") + } + + _, addedLineStyle, contextLineStyle, lineNumberStyle := createStyles(styles) + + var marker string + var bgStyle lipgloss.Style + + switch dl.Kind { + case LineAdded: + marker = addedLineStyle.Foreground(styles.AddedFg).Render("+") + bgStyle = addedLineStyle + lineNumberStyle = lineNumberStyle.Foreground(styles.AddedFg).Background(styles.AddedLineNamerBg) + case LineRemoved: + marker = "?" + bgStyle = contextLineStyle + case LineContext: + marker = contextLineStyle.Render(" ") + bgStyle = contextLineStyle + } + + lineNum := "" + if dl.NewLineNo > 0 { + lineNum = fmt.Sprintf("%6d", dl.NewLineNo) + } + + prefix := lineNumberStyle.Render(lineNum + " " + marker) + content := highlightLine(fileName, dl.Content, bgStyle.GetBackground()) + + if dl.Kind == LineAdded { + content = bgStyle.Render(" ") + content + } + + lineText := prefix + content + return bgStyle.MaxHeight(1).Width(colWidth).Render(ansi.Truncate(lineText, colWidth, "...")) +} + +// ------------------------------------------------------------------------- +// Public API Methods +// ------------------------------------------------------------------------- + +// RenderSideBySideHunk formats a hunk for side-by-side display. +func RenderSideBySideHunk(fileName string, h Hunk, opts ...SideBySideOption) string { + // Apply options to create the configuration + config := NewSideBySideConfig(opts...) + + // Make a copy of the hunk so we don't modify the original + hunkCopy := Hunk{Lines: make([]DiffLine, len(h.Lines))} + copy(hunkCopy.Lines, h.Lines) + + // Highlight changes within lines + HighlightIntralineChanges(&hunkCopy, config.Style) + + // Pair lines for side-by-side display + pairs := pairLines(hunkCopy.Lines) + + // Calculate column width + colWidth := config.TotalWidth / 2 + + var sb strings.Builder + for _, p := range pairs { + leftStr := renderLeftColumn(fileName, p.left, colWidth, config.Style) + rightStr := renderRightColumn(fileName, p.right, colWidth, config.Style) + sb.WriteString(leftStr + rightStr + "\n") + } + + return sb.String() +} + +// FormatDiff creates a side-by-side formatted view of a diff. +func FormatDiff(diffText string, opts ...SideBySideOption) (string, error) { + diffResult, err := ParseUnifiedDiff(diffText) + if err != nil { + return "", err + } + + var sb strings.Builder + + config := NewSideBySideConfig(opts...) + for i, h := range diffResult.Hunks { + if i > 0 { + sb.WriteString(lipgloss.NewStyle().Background(config.Style.HunkLineBg).Foreground(config.Style.HunkLineFg).Width(config.TotalWidth).Render(h.Header) + "\n") + } + sb.WriteString(RenderSideBySideHunk(diffResult.OldFile, h, opts...)) + } + + return sb.String(), nil +} + +// GenerateDiff creates a unified diff from two file contents. +func GenerateDiff(beforeContent, afterContent, beforeFilename, afterFilename string, opts ...ParseOption) (string, int, int) { + config := NewParseConfig(opts...) + + var output strings.Builder + + // Ensure we handle newlines correctly + beforeHasNewline := len(beforeContent) > 0 && beforeContent[len(beforeContent)-1] == '\n' + afterHasNewline := len(afterContent) > 0 && afterContent[len(afterContent)-1] == '\n' + + // Split into lines + beforeLines := strings.Split(beforeContent, "\n") + afterLines := strings.Split(afterContent, "\n") + + // Remove empty trailing element from the split if the content ended with a newline + if beforeHasNewline && len(beforeLines) > 0 { + beforeLines = beforeLines[:len(beforeLines)-1] + } + if afterHasNewline && len(afterLines) > 0 { + afterLines = afterLines[:len(afterLines)-1] + } + + dmp := diffmatchpatch.New() + dmp.DiffTimeout = 5 * time.Second + + // Convert lines to characters for efficient diffing + lineArray1, lineArray2, lineArrays := dmp.DiffLinesToChars(beforeContent, afterContent) + diffs := dmp.DiffMain(lineArray1, lineArray2, false) + diffs = dmp.DiffCharsToLines(diffs, lineArrays) + + // Default filenames if not provided + if beforeFilename == "" { + beforeFilename = "a" + } + if afterFilename == "" { + afterFilename = "b" + } + + // Write diff header + output.WriteString(fmt.Sprintf("diff --git a/%s b/%s\n", beforeFilename, afterFilename)) + output.WriteString(fmt.Sprintf("--- a/%s\n", beforeFilename)) + output.WriteString(fmt.Sprintf("+++ b/%s\n", afterFilename)) + + line1 := 0 // Line numbers start from 0 internally + line2 := 0 + additions := 0 + deletions := 0 + + var hunks []string + var currentHunk strings.Builder + var hunkStartLine1, hunkStartLine2 int + var hunkLines1, hunkLines2 int + inHunk := false + + contextSize := config.ContextSize + + // startHunk begins recording a new hunk + startHunk := func(startLine1, startLine2 int) { + inHunk = true + hunkStartLine1 = startLine1 + hunkStartLine2 = startLine2 + hunkLines1 = 0 + hunkLines2 = 0 + currentHunk.Reset() + } + + // writeHunk adds the current hunk to the hunks slice + writeHunk := func() { + if inHunk { + hunkHeader := fmt.Sprintf("@@ -%d,%d +%d,%d @@\n", + hunkStartLine1+1, hunkLines1, + hunkStartLine2+1, hunkLines2) + hunks = append(hunks, hunkHeader+currentHunk.String()) + inHunk = false + } + } + + // Process diffs to create hunks + pendingContext := make([]string, 0, contextSize*2) + var contextLines1, contextLines2 int + + // Helper function to add context lines to the hunk + addContextToHunk := func(lines []string, count int) { + for i := 0; i < count; i++ { + if i < len(lines) { + currentHunk.WriteString(" " + lines[i] + "\n") + hunkLines1++ + hunkLines2++ + } + } + } + + // Process diffs + for _, diff := range diffs { + lines := strings.Split(diff.Text, "\n") + + // Remove empty trailing line that comes from splitting a string that ends with \n + if len(lines) > 0 && lines[len(lines)-1] == "" && diff.Text[len(diff.Text)-1] == '\n' { + lines = lines[:len(lines)-1] + } + + switch diff.Type { + case diffmatchpatch.DiffEqual: + // If we have enough equal lines to serve as context, add them to pending + pendingContext = append(pendingContext, lines...) + + // If pending context grows too large, trim it + if len(pendingContext) > contextSize*2 { + pendingContext = pendingContext[len(pendingContext)-contextSize*2:] + } + + // If we're in a hunk, add the necessary context + if inHunk { + // Only add the first contextSize lines as trailing context + numContextLines := min(contextSize, len(lines)) + addContextToHunk(lines[:numContextLines], numContextLines) + + // If we've added enough trailing context, close the hunk + if numContextLines >= contextSize { + writeHunk() + } + } + + line1 += len(lines) + line2 += len(lines) + contextLines1 += len(lines) + contextLines2 += len(lines) + + case diffmatchpatch.DiffDelete, diffmatchpatch.DiffInsert: + // Start a new hunk if needed + if !inHunk { + // Determine how many context lines we can add before + contextBefore := min(contextSize, len(pendingContext)) + ctxStartIdx := len(pendingContext) - contextBefore + + // Calculate the correct start lines + startLine1 := line1 - contextLines1 + ctxStartIdx + startLine2 := line2 - contextLines2 + ctxStartIdx + + startHunk(startLine1, startLine2) + + // Add the context lines before + addContextToHunk(pendingContext[ctxStartIdx:], contextBefore) + } + + // Reset context tracking when we see a diff + pendingContext = pendingContext[:0] + contextLines1 = 0 + contextLines2 = 0 + + // Add the changes + if diff.Type == diffmatchpatch.DiffDelete { + for _, line := range lines { + currentHunk.WriteString("-" + line + "\n") + hunkLines1++ + deletions++ + } + line1 += len(lines) + } else { // DiffInsert + for _, line := range lines { + currentHunk.WriteString("+" + line + "\n") + hunkLines2++ + additions++ + } + line2 += len(lines) + } + } + } + + // Write the final hunk if there's one pending + if inHunk { + writeHunk() + } + + // Merge hunks that are close to each other (within 2*contextSize lines) + var mergedHunks []string + if len(hunks) > 0 { + mergedHunks = append(mergedHunks, hunks[0]) + + for i := 1; i < len(hunks); i++ { + prevHunk := mergedHunks[len(mergedHunks)-1] + currHunk := hunks[i] + + // Extract line numbers to check proximity + var prevStart, prevLen, currStart, currLen int + fmt.Sscanf(prevHunk, "@@ -%d,%d", &prevStart, &prevLen) + fmt.Sscanf(currHunk, "@@ -%d,%d", &currStart, &currLen) + + prevEnd := prevStart + prevLen - 1 + + // If hunks are close, merge them + if currStart-prevEnd <= contextSize*2 { + // Create a merged hunk - this is a simplification, real git has more complex merging logic + merged := mergeHunks(prevHunk, currHunk) + mergedHunks[len(mergedHunks)-1] = merged + } else { + mergedHunks = append(mergedHunks, currHunk) + } + } + } + + // Write all hunks to output + for _, hunk := range mergedHunks { + output.WriteString(hunk) + } + + // Handle "No newline at end of file" notifications + if !beforeHasNewline && len(beforeLines) > 0 { + // Find the last deletion in the diff and add the notification after it + lastPos := strings.LastIndex(output.String(), "\n-") + if lastPos != -1 { + // Insert the notification after the line + str := output.String() + output.Reset() + output.WriteString(str[:lastPos+1]) + output.WriteString("\\ No newline at end of file\n") + output.WriteString(str[lastPos+1:]) + } + } + + if !afterHasNewline && len(afterLines) > 0 { + // Find the last insertion in the diff and add the notification after it + lastPos := strings.LastIndex(output.String(), "\n+") + if lastPos != -1 { + // Insert the notification after the line + str := output.String() + output.Reset() + output.WriteString(str[:lastPos+1]) + output.WriteString("\\ No newline at end of file\n") + output.WriteString(str[lastPos+1:]) + } + } + + // Return the diff without the summary line + return output.String(), additions, deletions +} + +// Helper function to merge two hunks +func mergeHunks(hunk1, hunk2 string) string { + // This is a simplified implementation + // A full implementation would need to properly recalculate the hunk header + // and remove redundant context lines + + // Extract header info from both hunks + var start1, len1, start2, len2 int + var startB1, lenB1, startB2, lenB2 int + + fmt.Sscanf(hunk1, "@@ -%d,%d +%d,%d @@", &start1, &len1, &startB1, &lenB1) + fmt.Sscanf(hunk2, "@@ -%d,%d +%d,%d @@", &start2, &len2, &startB2, &lenB2) + + // Split the hunks to get content + parts1 := strings.SplitN(hunk1, "\n", 2) + parts2 := strings.SplitN(hunk2, "\n", 2) + + content1 := "" + content2 := "" + + if len(parts1) > 1 { + content1 = parts1[1] + } + if len(parts2) > 1 { + content2 = parts2[1] + } + + // Calculate the new header + newEnd := max(start1+len1-1, start2+len2-1) + newEndB := max(startB1+lenB1-1, startB2+lenB2-1) + + newLen := newEnd - start1 + 1 + newLenB := newEndB - startB1 + 1 + + newHeader := fmt.Sprintf("@@ -%d,%d +%d,%d @@", start1, newLen, startB1, newLenB) + + // Combine the content, potentially with some overlap handling + return newHeader + "\n" + content1 + content2 +} -- cgit v1.2.3 From 013694832f4c5a7819bfd9a801346e4c3fb22e77 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 14 Apr 2025 15:48:01 +0200 Subject: fix diff --- go.mod | 20 ++- go.sum | 57 ++++++- internal/diff/diff.go | 362 ++++++++++---------------------------------- internal/llm/tools/edit.go | 3 - internal/llm/tools/write.go | 1 - 5 files changed, 155 insertions(+), 288 deletions(-) (limited to 'internal/diff') diff --git a/go.mod b/go.mod index b201be800..925a71097 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/charmbracelet/lipgloss v1.1.0 github.com/charmbracelet/x/ansi v0.8.0 github.com/fsnotify/fsnotify v1.8.0 + github.com/go-git/go-git/v5 v5.15.0 github.com/go-logfmt/logfmt v0.6.0 github.com/golang-migrate/migrate/v4 v4.18.2 github.com/google/generative-ai-go v0.19.0 @@ -46,6 +47,9 @@ require ( cloud.google.com/go/auth/oauth2adapt v0.2.6 // indirect cloud.google.com/go/compute/metadata v0.6.0 // indirect cloud.google.com/go/longrunning v0.5.7 // indirect + dario.cat/mergo v1.0.0 // indirect + github.com/Microsoft/go-winio v0.6.2 // indirect + github.com/ProtonMail/go-crypto v1.1.6 // indirect github.com/andybalholm/cascadia v1.3.2 // indirect github.com/atotto/clipboard v0.1.4 // indirect github.com/aws/aws-sdk-go-v2 v1.30.3 // indirect @@ -68,15 +72,20 @@ require ( github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0 // indirect github.com/charmbracelet/x/term v0.2.1 // indirect + github.com/cloudflare/circl v1.6.1 // indirect + github.com/cyphar/filepath-securejoin v0.4.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dlclark/regexp2 v1.11.4 // indirect github.com/dustin/go-humanize v1.0.1 // indirect + github.com/emirpasic/gods v1.18.1 // indirect github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect github.com/felixge/httpsnoop v1.0.4 // indirect + github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect + github.com/go-git/go-billy/v5 v5.6.2 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-viper/mapstructure/v2 v2.2.1 // indirect - github.com/google/go-cmp v0.7.0 // indirect + github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect github.com/google/s2a-go v0.1.8 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect github.com/googleapis/gax-go/v2 v2.14.1 // indirect @@ -84,6 +93,8 @@ require ( github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect + github.com/kevinburke/ssh_config v1.2.0 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-localereader v0.0.1 // indirect @@ -91,11 +102,12 @@ require ( github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect github.com/muesli/cancelreader v0.2.2 // indirect github.com/pelletier/go-toml/v2 v2.2.3 // indirect + github.com/pjbgf/sha1cd v0.3.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect - github.com/rogpeppe/go-internal v1.14.1 // indirect github.com/sagikazarmark/locafero v0.7.0 // indirect github.com/sahilm/fuzzy v0.1.1 // indirect + github.com/skeema/knownhosts v1.3.1 // indirect github.com/sourcegraph/conc v0.3.0 // indirect github.com/spf13/afero v1.12.0 // indirect github.com/spf13/cast v1.7.1 // indirect @@ -105,6 +117,7 @@ require ( github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/sjson v1.2.5 // indirect + github.com/xanzy/ssh-agent v0.3.3 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect github.com/yuin/goldmark v1.7.8 // indirect @@ -118,7 +131,6 @@ require ( go.uber.org/multierr v1.9.0 // indirect golang.design/x/clipboard v0.7.0 // indirect golang.org/x/crypto v0.37.0 // indirect - golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 // indirect golang.org/x/exp/shiny v0.0.0-20250305212735-054e65f0b394 // indirect golang.org/x/image v0.14.0 // indirect golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a // indirect @@ -132,6 +144,6 @@ require ( google.golang.org/genproto/googleapis/rpc v0.0.0-20241223144023-3abc09e42ca8 // indirect google.golang.org/grpc v1.67.3 // indirect google.golang.org/protobuf v1.36.1 // indirect - gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect + gopkg.in/warnings.v0 v0.1.2 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 08e7e7c42..9c2c2df8f 100644 --- a/go.sum +++ b/go.sum @@ -10,10 +10,17 @@ cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4 cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg= cloud.google.com/go/longrunning v0.5.7 h1:WLbHekDbjK1fVFD3ibpFFVoyizlLRl73I7YKuAKilhU= cloud.google.com/go/longrunning v0.5.7/go.mod h1:8GClkudohy1Fxm3owmBGid8W0pSgodEMwEAztp38Xng= +dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= +dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= github.com/JohannesKaufmann/html-to-markdown v1.6.0 h1:04VXMiE50YYfCfLboJCLcgqF5x+rHJnb1ssNmqpLH/k= github.com/JohannesKaufmann/html-to-markdown v1.6.0/go.mod h1:NUI78lGg/a7vpEJTz/0uOcYMaibytE4BUOQS8k78yPQ= github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ= github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE= +github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY= +github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= +github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= +github.com/ProtonMail/go-crypto v1.1.6 h1:ZcV+Ropw6Qn0AX9brlQLAUXfqLBc7Bl+f/DmNxpLfdw= +github.com/ProtonMail/go-crypto v1.1.6/go.mod h1:rA3QumHc/FZ8pAHreoekgiAbzpNsfQAosU5td4SnOrE= github.com/PuerkitoBio/goquery v1.9.2 h1:4/wZksC3KgkQw7SQgkKotmKljk0M6V8TUvA8Wb4yPeE= github.com/PuerkitoBio/goquery v1.9.2/go.mod h1:GHPCaP0ODyyxqcNoFGYlAprUFH81NuRPd0GX3Zu2Mvk= github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0= @@ -24,8 +31,12 @@ github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= github.com/andybalholm/cascadia v1.3.2 h1:3Xi6Dw5lHF15JtdcmAHD3i1+T8plmv7BQ/nsViSLyss= github.com/andybalholm/cascadia v1.3.2/go.mod h1:7gtRlve5FxPPgIgX36uWBX58OdBsSS6lUvCFb+h7KvU= +github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= +github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= github.com/anthropics/anthropic-sdk-go v0.2.0-beta.2 h1:h7qxtumNjKPWFv1QM/HJy60MteeW23iKeEtBoY7bYZk= github.com/anthropics/anthropic-sdk-go v0.2.0-beta.2/go.mod h1:AapDW22irxK2PSumZiQXYUFvsdQgkwIWlpESweWZI/c= +github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= +github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= github.com/aws/aws-sdk-go-v2 v1.30.3 h1:jUeBtG0Ih+ZIFH0F4UkmL9w3cSpaMv9tYYDbzILP8dY= @@ -88,7 +99,11 @@ github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0 h1:qko github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0/go.mod h1:pBhA0ybfXv6hDjQUZ7hk1lVxBiUbupdw5R31yPUViVQ= github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ= github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg= +github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0= +github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/cyphar/filepath-securejoin v0.4.1 h1:JyxxyPEaktOD+GAnqIqTf9A8tHyAG22rowi7HkoSU1s= +github.com/cyphar/filepath-securejoin v0.4.1/go.mod h1:Sdj7gXlvMcPZsbhwhQ33GguGLDGQL7h7bg04C/+u9jI= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -96,6 +111,10 @@ github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yA github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/elazarl/goproxy v1.7.2 h1:Y2o6urb7Eule09PjlhQRGNsqRfPmYI3KKQLFpCAV3+o= +github.com/elazarl/goproxy v1.7.2/go.mod h1:82vkLNir0ALaW14Rc399OTTjyNREgmdL2cVoIbS6XaE= +github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= +github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= @@ -104,6 +123,16 @@ github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHk github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/8M= github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/gliderlabs/ssh v0.3.8 h1:a4YXD1V7xMF9g5nTkdfnja3Sxy1PVDCj1Zg4Wb8vY6c= +github.com/gliderlabs/ssh v0.3.8/go.mod h1:xYoytBv1sV0aL3CavoDuJIQNURXkkfPA/wxQ1pL1fAU= +github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 h1:+zs/tPmkDkHx3U66DAb0lQFJrpS6731Oaa12ikc+DiI= +github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376/go.mod h1:an3vInlBmSxCcxctByoQdvwPiA7DTK7jaaFDBTtu0ic= +github.com/go-git/go-billy/v5 v5.6.2 h1:6Q86EsPXMa7c3YZ3aLAQsMA0VlWmy43r6FHqa/UNbRM= +github.com/go-git/go-billy/v5 v5.6.2/go.mod h1:rcFC2rAsp/erv7CMz9GczHcuD0D32fWzH+MJAU+jaUU= +github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399 h1:eMje31YglSBqCdIqdhKBW8lokaMrL3uTkpGYlE2OOT4= +github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399/go.mod h1:1OCfN199q1Jm3HZlxleg+Dw/mwps2Wbk9frAWm+4FII= +github.com/go-git/go-git/v5 v5.15.0 h1:f5Qn0W0F7ry1iN0ZwIU5m/n7/BKB4hiZfc+zlZx7ly0= +github.com/go-git/go-git/v5 v5.15.0/go.mod h1:4Ge4alE/5gPs30F2H1esi2gPd69R0C39lolkucHBOp8= github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4= github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= @@ -115,6 +144,8 @@ github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIx github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/golang-migrate/migrate/v4 v4.18.2 h1:2VSCMz7x7mjyTXx3m2zPokOY82LTRgxK1yQYKo6wWQ8= github.com/golang-migrate/migrate/v4 v4.18.2/go.mod h1:2CM6tJvn2kqPXwnXO/d3rAQYiyoIm180VsO8PRX6Rpk= +github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 h1:f+oWsMOmNPc8JmEHVZIycC7hBoQxHH9pNKQORJNozsQ= +github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8/go.mod h1:wcDNUvekVysuuOpQKo3191zZyTpiI6se1N1ULghS0sw= github.com/google/generative-ai-go v0.19.0 h1:R71szggh8wHMCUlEMsW2A/3T+5LdEIkiaHSYgSpUgdg= github.com/google/generative-ai-go v0.19.0/go.mod h1:JYolL13VG7j79kM5BtHz4qwONHkeJQzOCkKXnpqtS/E= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= @@ -138,8 +169,11 @@ github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUq github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 h1:BQSFePA1RWJOlocH6Fxy8MmwDt+yVQYULKfN0RoTN8A= +github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99/go.mod h1:1lJo3i6rXxKeerYnT8Nvf0QmHCRC1n8sfWVwXF2Frvo= +github.com/kevinburke/ssh_config v1.2.0 h1:x584FjTGwHzMwvHx18PXxbBVzfnxogHaAReU4gf13a4= +github.com/kevinburke/ssh_config v1.2.0/go.mod h1:CT57kijsi8u/K/BOFA39wgDQJ9CxiF4nAY/ojJ6r6mM= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= @@ -179,11 +213,17 @@ github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s= github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8= github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc= github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= +github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k= +github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY= github.com/openai/openai-go v0.1.0-beta.2 h1:Ra5nCFkbEl9w+UJwAciC4kqnIBUCcJazhmMA0/YN894= github.com/openai/openai-go v0.1.0-beta.2/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y= github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M= github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc= +github.com/pjbgf/sha1cd v0.3.2 h1:a9wb0bp1oC2TGwStyn0Umc/IGKQnEgF0vVaZ8QF8eo4= +github.com/pjbgf/sha1cd v0.3.2/go.mod h1:zQWigSxVmsHEZow5qaLtPYxpcKMMQpa09ixqBxuCS6A= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= @@ -203,6 +243,9 @@ github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAm github.com/sergi/go-diff v1.3.1/go.mod h1:aMJSSKb2lpPvRNec0+w3fl7LP9IOFzdc9Pa4NFbPK1I= github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 h1:n661drycOFuPLCN3Uc8sB6B/s6Z4t2xvBgU1htSHuq8= github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4= +github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= +github.com/skeema/knownhosts v1.3.1 h1:X2osQ+RAjK76shCbvhHHHVl3ZlgDm8apHEHFqRjnBY8= +github.com/skeema/knownhosts v1.3.1/go.mod h1:r7KTdC8l4uxWRyK2TpQZ/1o5HaSzh06ePQNxPwTcfiY= github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= github.com/spf13/afero v1.12.0 h1:UcOPyRBYczmFn6yvphxkn9ZEOY65cpwGKb5mL36mrqs= @@ -216,6 +259,7 @@ github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An github.com/spf13/viper v1.20.0 h1:zrxIyR3RQIOsarIrgL8+sAvALXul9jeEPa06Y0Ph6vY= github.com/spf13/viper v1.20.0/go.mod h1:P9Mdzt1zoHIG8m2eZQinpiBjo6kCmZSKBClNNqjJvu4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= @@ -232,6 +276,8 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM= +github.com/xanzy/ssh-agent v0.3.3/go.mod h1:6dzNDKs0J9rVPHPhaGCukekBHKqfl+L3KghI1Bc68Uw= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= @@ -260,6 +306,7 @@ golang.design/x/clipboard v0.7.0 h1:4Je8M/ys9AJumVnl8m+rZnIvstSnYj1fvzqYrU3TXvo= golang.design/x/clipboard v0.7.0/go.mod h1:PQIvqYO9GP29yINEfsEn5zSQKAz3UgXmZKzDA6dnq2E= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= @@ -277,6 +324,7 @@ golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91 golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= @@ -294,10 +342,14 @@ golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -320,6 +372,7 @@ golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o= golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= @@ -348,6 +401,8 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/warnings.v0 v0.1.2 h1:wFXVbFY8DY5/xOe1ECiWdKCzZlxgshcYVNkBHstARME= +gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= diff --git a/internal/diff/diff.go b/internal/diff/diff.go index 4e6aa9f5b..c4088d329 100644 --- a/internal/diff/diff.go +++ b/internal/diff/diff.go @@ -4,6 +4,8 @@ import ( "bytes" "fmt" "io" + "os" + "path/filepath" "regexp" "strconv" "strings" @@ -15,6 +17,8 @@ import ( "github.com/alecthomas/chroma/v2/styles" "github.com/charmbracelet/lipgloss" "github.com/charmbracelet/x/ansi" + "github.com/go-git/go-git/v5" + "github.com/go-git/go-git/v5/plumbing/object" "github.com/sergi/go-diff/diffmatchpatch" ) @@ -224,21 +228,6 @@ type ParseConfig struct { // ParseOption defines a function that modifies a ParseConfig. type ParseOption func(*ParseConfig) -// NewParseConfig creates a ParseConfig with default values and applies any provided options. -func NewParseConfig(opts ...ParseOption) ParseConfig { - // Set default values - config := ParseConfig{ - ContextSize: 3, - } - - // Apply all provided options - for _, opt := range opts { - opt(&config) - } - - return config -} - // WithContextSize sets the number of context lines to include. func WithContextSize(size int) ParseOption { return func(p *ParseConfig) { @@ -347,6 +336,10 @@ func ParseUnifiedDiff(diff string) (DiffResult, error) { continue } + // ignore the \\ No newline at end of file + if strings.HasPrefix(line, "\\ No newline at end of file") { + continue + } if currentHunk == nil { continue } @@ -450,32 +443,22 @@ func colorizeSegments(diffs []diffmatchpatch.Diff, isOld bool, style StyleConfig removedLineStyle := lipgloss.NewStyle().Background(style.RemovedLineBg) addedLineStyle := lipgloss.NewStyle().Background(style.AddedLineBg) - afterBg := false - for _, d := range diffs { switch d.Type { case diffmatchpatch.DiffEqual: // Handle text that's the same in both versions - if afterBg { - if isOld { - buf.WriteString(removedLineStyle.Render(d.Text)) - } else { - buf.WriteString(addedLineStyle.Render(d.Text)) - } - } else { - buf.WriteString(d.Text) - } + buf.WriteString(d.Text) case diffmatchpatch.DiffDelete: // Handle deleted text (only show in old version) if isOld { buf.WriteString(removeBg.Render(d.Text)) - afterBg = true + buf.WriteString(removedLineStyle.Render("")) } case diffmatchpatch.DiffInsert: // Handle inserted text (only show in new version) if !isOld { buf.WriteString(addBg.Render(d.Text)) - afterBg = true + buf.WriteString(addedLineStyle.Render("")) } } } @@ -621,7 +604,13 @@ func renderLeftColumn(fileName string, dl *DiffLine, colWidth int, styles StyleC } lineText := prefix + content - return bgStyle.MaxHeight(1).Width(colWidth).Render(ansi.Truncate(lineText, colWidth, "...")) + return bgStyle.MaxHeight(1).Width(colWidth).Render( + ansi.Truncate( + lineText, + colWidth, + lipgloss.NewStyle().Background(styles.HunkLineBg).Foreground(styles.HunkLineFg).Render("..."), + ), + ) } // renderRightColumn formats the right side of a side-by-side diff. @@ -662,7 +651,13 @@ func renderRightColumn(fileName string, dl *DiffLine, colWidth int, styles Style } lineText := prefix + content - return bgStyle.MaxHeight(1).Width(colWidth).Render(ansi.Truncate(lineText, colWidth, "...")) + return bgStyle.MaxHeight(1).Width(colWidth).Render( + ansi.Truncate( + lineText, + colWidth, + lipgloss.NewStyle().Background(styles.HunkLineBg).Foreground(styles.HunkLineFg).Render("..."), + ), + ) } // ------------------------------------------------------------------------- @@ -718,278 +713,87 @@ func FormatDiff(diffText string, opts ...SideBySideOption) (string, error) { } // GenerateDiff creates a unified diff from two file contents. -func GenerateDiff(beforeContent, afterContent, beforeFilename, afterFilename string, opts ...ParseOption) (string, int, int) { - config := NewParseConfig(opts...) - - var output strings.Builder - - // Ensure we handle newlines correctly - beforeHasNewline := len(beforeContent) > 0 && beforeContent[len(beforeContent)-1] == '\n' - afterHasNewline := len(afterContent) > 0 && afterContent[len(afterContent)-1] == '\n' - - // Split into lines - beforeLines := strings.Split(beforeContent, "\n") - afterLines := strings.Split(afterContent, "\n") - - // Remove empty trailing element from the split if the content ended with a newline - if beforeHasNewline && len(beforeLines) > 0 { - beforeLines = beforeLines[:len(beforeLines)-1] - } - if afterHasNewline && len(afterLines) > 0 { - afterLines = afterLines[:len(afterLines)-1] +func GenerateDiff(beforeContent, afterContent, fileName string) (string, int, int) { + tempDir, err := os.MkdirTemp("", "git-diff-temp") + if err != nil { + return "", 0, 0 } + defer os.RemoveAll(tempDir) - dmp := diffmatchpatch.New() - dmp.DiffTimeout = 5 * time.Second + repo, err := git.PlainInit(tempDir, false) + if err != nil { + return "", 0, 0 + } - // Convert lines to characters for efficient diffing - lineArray1, lineArray2, lineArrays := dmp.DiffLinesToChars(beforeContent, afterContent) - diffs := dmp.DiffMain(lineArray1, lineArray2, false) - diffs = dmp.DiffCharsToLines(diffs, lineArrays) + wt, err := repo.Worktree() + if err != nil { + return "", 0, 0 + } - // Default filenames if not provided - if beforeFilename == "" { - beforeFilename = "a" + fullPath := filepath.Join(tempDir, fileName) + if err = os.MkdirAll(filepath.Dir(fullPath), 0o755); err != nil { + return "", 0, 0 } - if afterFilename == "" { - afterFilename = "b" + if err = os.WriteFile(fullPath, []byte(beforeContent), 0o644); err != nil { + return "", 0, 0 } - // Write diff header - output.WriteString(fmt.Sprintf("diff --git a/%s b/%s\n", beforeFilename, afterFilename)) - output.WriteString(fmt.Sprintf("--- a/%s\n", beforeFilename)) - output.WriteString(fmt.Sprintf("+++ b/%s\n", afterFilename)) - - line1 := 0 // Line numbers start from 0 internally - line2 := 0 - additions := 0 - deletions := 0 - - var hunks []string - var currentHunk strings.Builder - var hunkStartLine1, hunkStartLine2 int - var hunkLines1, hunkLines2 int - inHunk := false - - contextSize := config.ContextSize - - // startHunk begins recording a new hunk - startHunk := func(startLine1, startLine2 int) { - inHunk = true - hunkStartLine1 = startLine1 - hunkStartLine2 = startLine2 - hunkLines1 = 0 - hunkLines2 = 0 - currentHunk.Reset() - } - - // writeHunk adds the current hunk to the hunks slice - writeHunk := func() { - if inHunk { - hunkHeader := fmt.Sprintf("@@ -%d,%d +%d,%d @@\n", - hunkStartLine1+1, hunkLines1, - hunkStartLine2+1, hunkLines2) - hunks = append(hunks, hunkHeader+currentHunk.String()) - inHunk = false - } + _, err = wt.Add(fileName) + if err != nil { + return "", 0, 0 } - // Process diffs to create hunks - pendingContext := make([]string, 0, contextSize*2) - var contextLines1, contextLines2 int - - // Helper function to add context lines to the hunk - addContextToHunk := func(lines []string, count int) { - for i := 0; i < count; i++ { - if i < len(lines) { - currentHunk.WriteString(" " + lines[i] + "\n") - hunkLines1++ - hunkLines2++ - } - } + beforeCommit, err := wt.Commit("Before", &git.CommitOptions{ + Author: &object.Signature{ + Name: "OpenCode", + Email: "coder@opencode.ai", + When: time.Now(), + }, + }) + if err != nil { + return "", 0, 0 } - // Process diffs - for _, diff := range diffs { - lines := strings.Split(diff.Text, "\n") - - // Remove empty trailing line that comes from splitting a string that ends with \n - if len(lines) > 0 && lines[len(lines)-1] == "" && diff.Text[len(diff.Text)-1] == '\n' { - lines = lines[:len(lines)-1] - } - - switch diff.Type { - case diffmatchpatch.DiffEqual: - // If we have enough equal lines to serve as context, add them to pending - pendingContext = append(pendingContext, lines...) - - // If pending context grows too large, trim it - if len(pendingContext) > contextSize*2 { - pendingContext = pendingContext[len(pendingContext)-contextSize*2:] - } - - // If we're in a hunk, add the necessary context - if inHunk { - // Only add the first contextSize lines as trailing context - numContextLines := min(contextSize, len(lines)) - addContextToHunk(lines[:numContextLines], numContextLines) - - // If we've added enough trailing context, close the hunk - if numContextLines >= contextSize { - writeHunk() - } - } - - line1 += len(lines) - line2 += len(lines) - contextLines1 += len(lines) - contextLines2 += len(lines) - - case diffmatchpatch.DiffDelete, diffmatchpatch.DiffInsert: - // Start a new hunk if needed - if !inHunk { - // Determine how many context lines we can add before - contextBefore := min(contextSize, len(pendingContext)) - ctxStartIdx := len(pendingContext) - contextBefore - - // Calculate the correct start lines - startLine1 := line1 - contextLines1 + ctxStartIdx - startLine2 := line2 - contextLines2 + ctxStartIdx - - startHunk(startLine1, startLine2) - - // Add the context lines before - addContextToHunk(pendingContext[ctxStartIdx:], contextBefore) - } - - // Reset context tracking when we see a diff - pendingContext = pendingContext[:0] - contextLines1 = 0 - contextLines2 = 0 - - // Add the changes - if diff.Type == diffmatchpatch.DiffDelete { - for _, line := range lines { - currentHunk.WriteString("-" + line + "\n") - hunkLines1++ - deletions++ - } - line1 += len(lines) - } else { // DiffInsert - for _, line := range lines { - currentHunk.WriteString("+" + line + "\n") - hunkLines2++ - additions++ - } - line2 += len(lines) - } - } + if err = os.WriteFile(fullPath, []byte(afterContent), 0o644); err != nil { } - // Write the final hunk if there's one pending - if inHunk { - writeHunk() + _, err = wt.Add(fileName) + if err != nil { + return "", 0, 0 } - // Merge hunks that are close to each other (within 2*contextSize lines) - var mergedHunks []string - if len(hunks) > 0 { - mergedHunks = append(mergedHunks, hunks[0]) - - for i := 1; i < len(hunks); i++ { - prevHunk := mergedHunks[len(mergedHunks)-1] - currHunk := hunks[i] - - // Extract line numbers to check proximity - var prevStart, prevLen, currStart, currLen int - fmt.Sscanf(prevHunk, "@@ -%d,%d", &prevStart, &prevLen) - fmt.Sscanf(currHunk, "@@ -%d,%d", &currStart, &currLen) - - prevEnd := prevStart + prevLen - 1 - - // If hunks are close, merge them - if currStart-prevEnd <= contextSize*2 { - // Create a merged hunk - this is a simplification, real git has more complex merging logic - merged := mergeHunks(prevHunk, currHunk) - mergedHunks[len(mergedHunks)-1] = merged - } else { - mergedHunks = append(mergedHunks, currHunk) - } - } + afterCommit, err := wt.Commit("After", &git.CommitOptions{ + Author: &object.Signature{ + Name: "OpenCode", + Email: "coder@opencode.ai", + When: time.Now(), + }, + }) + if err != nil { + return "", 0, 0 } - // Write all hunks to output - for _, hunk := range mergedHunks { - output.WriteString(hunk) + beforeCommitObj, err := repo.CommitObject(beforeCommit) + if err != nil { + return "", 0, 0 } - // Handle "No newline at end of file" notifications - if !beforeHasNewline && len(beforeLines) > 0 { - // Find the last deletion in the diff and add the notification after it - lastPos := strings.LastIndex(output.String(), "\n-") - if lastPos != -1 { - // Insert the notification after the line - str := output.String() - output.Reset() - output.WriteString(str[:lastPos+1]) - output.WriteString("\\ No newline at end of file\n") - output.WriteString(str[lastPos+1:]) - } + afterCommitObj, err := repo.CommitObject(afterCommit) + if err != nil { + return "", 0, 0 } - if !afterHasNewline && len(afterLines) > 0 { - // Find the last insertion in the diff and add the notification after it - lastPos := strings.LastIndex(output.String(), "\n+") - if lastPos != -1 { - // Insert the notification after the line - str := output.String() - output.Reset() - output.WriteString(str[:lastPos+1]) - output.WriteString("\\ No newline at end of file\n") - output.WriteString(str[lastPos+1:]) - } + patch, err := beforeCommitObj.Patch(afterCommitObj) + if err != nil { + return "", 0, 0 } - // Return the diff without the summary line - return output.String(), additions, deletions -} - -// Helper function to merge two hunks -func mergeHunks(hunk1, hunk2 string) string { - // This is a simplified implementation - // A full implementation would need to properly recalculate the hunk header - // and remove redundant context lines - - // Extract header info from both hunks - var start1, len1, start2, len2 int - var startB1, lenB1, startB2, lenB2 int - - fmt.Sscanf(hunk1, "@@ -%d,%d +%d,%d @@", &start1, &len1, &startB1, &lenB1) - fmt.Sscanf(hunk2, "@@ -%d,%d +%d,%d @@", &start2, &len2, &startB2, &lenB2) - - // Split the hunks to get content - parts1 := strings.SplitN(hunk1, "\n", 2) - parts2 := strings.SplitN(hunk2, "\n", 2) - - content1 := "" - content2 := "" - - if len(parts1) > 1 { - content1 = parts1[1] - } - if len(parts2) > 1 { - content2 = parts2[1] + additions := 0 + removals := 0 + for _, fileStat := range patch.Stats() { + additions += fileStat.Addition + removals += fileStat.Deletion } - // Calculate the new header - newEnd := max(start1+len1-1, start2+len2-1) - newEndB := max(startB1+lenB1-1, startB2+lenB2-1) - - newLen := newEnd - start1 + 1 - newLenB := newEndB - startB1 + 1 - - newHeader := fmt.Sprintf("@@ -%d,%d +%d,%d @@", start1, newLen, startB1, newLenB) - - // Combine the content, potentially with some overlap handling - return newHeader + "\n" + content1 + content2 + return patch.String(), additions, removals } diff --git a/internal/llm/tools/edit.go b/internal/llm/tools/edit.go index 1305879b9..08d6d446c 100644 --- a/internal/llm/tools/edit.go +++ b/internal/llm/tools/edit.go @@ -186,7 +186,6 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string) "", content, filePath, - filePath, ) p := e.permissions.Request( permission.CreatePermissionRequest{ @@ -277,7 +276,6 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string oldContent, newContent, filePath, - filePath, ) p := e.permissions.Request( @@ -367,7 +365,6 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS oldContent, newContent, filePath, - filePath, ) p := e.permissions.Request( permission.CreatePermissionRequest{ diff --git a/internal/llm/tools/write.go b/internal/llm/tools/write.go index ef2ca01f4..889561d2a 100644 --- a/internal/llm/tools/write.go +++ b/internal/llm/tools/write.go @@ -154,7 +154,6 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error oldContent, params.Content, filePath, - filePath, ) p := w.permissions.Request( permission.CreatePermissionRequest{ -- cgit v1.2.3 From f6be348bf704ab3d012eec549357f5acd9c74796 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 14 Apr 2025 20:12:28 +0200 Subject: fix segment diff and add new theme --- internal/diff/diff.go | 523 +++++++++++++++++++++++++++++++++----------------- 1 file changed, 348 insertions(+), 175 deletions(-) (limited to 'internal/diff') diff --git a/internal/diff/diff.go b/internal/diff/diff.go index c4088d329..02d4d7140 100644 --- a/internal/diff/diff.go +++ b/internal/diff/diff.go @@ -22,89 +22,95 @@ import ( "github.com/sergi/go-diff/diffmatchpatch" ) +// ------------------------------------------------------------------------- +// Core Types +// ------------------------------------------------------------------------- + // LineType represents the kind of line in a diff. type LineType int const ( - // LineContext represents a line that exists in both the old and new file. - LineContext LineType = iota - // LineAdded represents a line added in the new file. - LineAdded - // LineRemoved represents a line removed from the old file. - LineRemoved + LineContext LineType = iota // Line exists in both files + LineAdded // Line added in the new file + LineRemoved // Line removed from the old file ) -// DiffLine represents a single line in a diff, either from the old file, -// the new file, or a context line. +// Segment represents a portion of a line for intra-line highlighting +type Segment struct { + Start int + End int + Type LineType + Text string +} + +// DiffLine represents a single line in a diff type DiffLine struct { - OldLineNo int // Line number in the old file (0 for added lines) - NewLineNo int // Line number in the new file (0 for removed lines) - Kind LineType // Type of line (added, removed, context) - Content string // Content of the line + OldLineNo int // Line number in old file (0 for added lines) + NewLineNo int // Line number in new file (0 for removed lines) + Kind LineType // Type of line (added, removed, context) + Content string // Content of the line + Segments []Segment // Segments for intraline highlighting } -// Hunk represents a section of changes in a diff. +// Hunk represents a section of changes in a diff type Hunk struct { Header string Lines []DiffLine } -// DiffResult contains the parsed result of a diff. +// DiffResult contains the parsed result of a diff type DiffResult struct { OldFile string NewFile string Hunks []Hunk } -// HunkDelta represents the change statistics for a hunk. -type HunkDelta struct { - StartLine1 int - LineCount1 int - StartLine2 int - LineCount2 int -} - -// linePair represents a pair of lines to be displayed side by side. +// linePair represents a pair of lines for side-by-side display type linePair struct { left *DiffLine right *DiffLine } // ------------------------------------------------------------------------- -// Style Configuration with Option Pattern +// Style Configuration // ------------------------------------------------------------------------- -// StyleConfig defines styling for diff rendering. +// StyleConfig defines styling for diff rendering type StyleConfig struct { + // Background colors RemovedLineBg lipgloss.Color AddedLineBg lipgloss.Color ContextLineBg lipgloss.Color HunkLineBg lipgloss.Color - HunkLineFg lipgloss.Color - RemovedFg lipgloss.Color - AddedFg lipgloss.Color - LineNumberFg lipgloss.Color - HighlightStyle string - RemovedHighlightBg lipgloss.Color - AddedHighlightBg lipgloss.Color RemovedLineNumberBg lipgloss.Color AddedLineNamerBg lipgloss.Color - RemovedHighlightFg lipgloss.Color - AddedHighlightFg lipgloss.Color + + // Foreground colors + HunkLineFg lipgloss.Color + RemovedFg lipgloss.Color + AddedFg lipgloss.Color + LineNumberFg lipgloss.Color + RemovedHighlightFg lipgloss.Color + AddedHighlightFg lipgloss.Color + + // Highlight settings + HighlightStyle string + RemovedHighlightBg lipgloss.Color + AddedHighlightBg lipgloss.Color } -// StyleOption defines a function that modifies a StyleConfig. +// StyleOption is a function that modifies a StyleConfig type StyleOption func(*StyleConfig) -// NewStyleConfig creates a StyleConfig with default values and applies any provided options. +// NewStyleConfig creates a StyleConfig with default values func NewStyleConfig(opts ...StyleOption) StyleConfig { - // Set default values + // Default color scheme config := StyleConfig{ RemovedLineBg: lipgloss.Color("#3A3030"), AddedLineBg: lipgloss.Color("#303A30"), ContextLineBg: lipgloss.Color("#212121"), - HunkLineBg: lipgloss.Color("#2A2822"), - HunkLineFg: lipgloss.Color("#D4AF37"), + HunkLineBg: lipgloss.Color("#23252D"), + HunkLineFg: lipgloss.Color("#8CA3B4"), RemovedFg: lipgloss.Color("#7C4444"), AddedFg: lipgloss.Color("#478247"), LineNumberFg: lipgloss.Color("#888888"), @@ -125,56 +131,35 @@ func NewStyleConfig(opts ...StyleOption) StyleConfig { return config } -// WithRemovedLineBg sets the background color for removed lines. +// Style option functions func WithRemovedLineBg(color lipgloss.Color) StyleOption { - return func(s *StyleConfig) { - s.RemovedLineBg = color - } + return func(s *StyleConfig) { s.RemovedLineBg = color } } -// WithAddedLineBg sets the background color for added lines. func WithAddedLineBg(color lipgloss.Color) StyleOption { - return func(s *StyleConfig) { - s.AddedLineBg = color - } + return func(s *StyleConfig) { s.AddedLineBg = color } } -// WithContextLineBg sets the background color for context lines. func WithContextLineBg(color lipgloss.Color) StyleOption { - return func(s *StyleConfig) { - s.ContextLineBg = color - } + return func(s *StyleConfig) { s.ContextLineBg = color } } -// WithRemovedFg sets the foreground color for removed line markers. func WithRemovedFg(color lipgloss.Color) StyleOption { - return func(s *StyleConfig) { - s.RemovedFg = color - } + return func(s *StyleConfig) { s.RemovedFg = color } } -// WithAddedFg sets the foreground color for added line markers. func WithAddedFg(color lipgloss.Color) StyleOption { - return func(s *StyleConfig) { - s.AddedFg = color - } + return func(s *StyleConfig) { s.AddedFg = color } } -// WithLineNumberFg sets the foreground color for line numbers. func WithLineNumberFg(color lipgloss.Color) StyleOption { - return func(s *StyleConfig) { - s.LineNumberFg = color - } + return func(s *StyleConfig) { s.LineNumberFg = color } } -// WithHighlightStyle sets the syntax highlighting style. func WithHighlightStyle(style string) StyleOption { - return func(s *StyleConfig) { - s.HighlightStyle = style - } + return func(s *StyleConfig) { s.HighlightStyle = style } } -// WithRemovedHighlightColors sets the colors for highlighted parts in removed text. func WithRemovedHighlightColors(bg, fg lipgloss.Color) StyleOption { return func(s *StyleConfig) { s.RemovedHighlightBg = bg @@ -182,7 +167,6 @@ func WithRemovedHighlightColors(bg, fg lipgloss.Color) StyleOption { } } -// WithAddedHighlightColors sets the colors for highlighted parts in added text. func WithAddedHighlightColors(bg, fg lipgloss.Color) StyleOption { return func(s *StyleConfig) { s.AddedHighlightBg = bg @@ -190,45 +174,35 @@ func WithAddedHighlightColors(bg, fg lipgloss.Color) StyleOption { } } -// WithRemovedLineNumberBg sets the background color for removed line numbers. func WithRemovedLineNumberBg(color lipgloss.Color) StyleOption { - return func(s *StyleConfig) { - s.RemovedLineNumberBg = color - } + return func(s *StyleConfig) { s.RemovedLineNumberBg = color } } -// WithAddedLineNumberBg sets the background color for added line numbers. func WithAddedLineNumberBg(color lipgloss.Color) StyleOption { - return func(s *StyleConfig) { - s.AddedLineNamerBg = color - } + return func(s *StyleConfig) { s.AddedLineNamerBg = color } } func WithHunkLineBg(color lipgloss.Color) StyleOption { - return func(s *StyleConfig) { - s.HunkLineBg = color - } + return func(s *StyleConfig) { s.HunkLineBg = color } } func WithHunkLineFg(color lipgloss.Color) StyleOption { - return func(s *StyleConfig) { - s.HunkLineFg = color - } + return func(s *StyleConfig) { s.HunkLineFg = color } } // ------------------------------------------------------------------------- -// Parse Options with Option Pattern +// Parse Configuration // ------------------------------------------------------------------------- -// ParseConfig configures the behavior of diff parsing. +// ParseConfig configures the behavior of diff parsing type ParseConfig struct { ContextSize int // Number of context lines to include } -// ParseOption defines a function that modifies a ParseConfig. +// ParseOption modifies a ParseConfig type ParseOption func(*ParseConfig) -// WithContextSize sets the number of context lines to include. +// WithContextSize sets the number of context lines to include func WithContextSize(size int) ParseOption { return func(p *ParseConfig) { if size >= 0 { @@ -238,27 +212,25 @@ func WithContextSize(size int) ParseOption { } // ------------------------------------------------------------------------- -// Side-by-Side Options with Option Pattern +// Side-by-Side Configuration // ------------------------------------------------------------------------- -// SideBySideConfig configures the rendering of side-by-side diffs. +// SideBySideConfig configures the rendering of side-by-side diffs type SideBySideConfig struct { TotalWidth int Style StyleConfig } -// SideBySideOption defines a function that modifies a SideBySideConfig. +// SideBySideOption modifies a SideBySideConfig type SideBySideOption func(*SideBySideConfig) -// NewSideBySideConfig creates a SideBySideConfig with default values and applies any provided options. +// NewSideBySideConfig creates a SideBySideConfig with default values func NewSideBySideConfig(opts ...SideBySideOption) SideBySideConfig { - // Set default values config := SideBySideConfig{ TotalWidth: 160, // Default width for side-by-side view Style: NewStyleConfig(), } - // Apply all provided options for _, opt := range opts { opt(&config) } @@ -266,7 +238,7 @@ func NewSideBySideConfig(opts ...SideBySideOption) SideBySideConfig { return config } -// WithTotalWidth sets the total width for side-by-side view. +// WithTotalWidth sets the total width for side-by-side view func WithTotalWidth(width int) SideBySideOption { return func(s *SideBySideConfig) { if width > 0 { @@ -275,14 +247,14 @@ func WithTotalWidth(width int) SideBySideOption { } } -// WithStyle sets the styling configuration. +// WithStyle sets the styling configuration func WithStyle(style StyleConfig) SideBySideOption { return func(s *SideBySideConfig) { s.Style = style } } -// WithStyleOptions applies the specified style options. +// WithStyleOptions applies the specified style options func WithStyleOptions(opts ...StyleOption) SideBySideOption { return func(s *SideBySideConfig) { s.Style = NewStyleConfig(opts...) @@ -290,10 +262,10 @@ func WithStyleOptions(opts ...StyleOption) SideBySideOption { } // ------------------------------------------------------------------------- -// Diff Parsing and Generation +// Diff Parsing // ------------------------------------------------------------------------- -// ParseUnifiedDiff parses a unified diff format string into structured data. +// ParseUnifiedDiff parses a unified diff format string into structured data func ParseUnifiedDiff(diff string) (DiffResult, error) { var result DiffResult var currentHunk *Hunk @@ -305,7 +277,7 @@ func ParseUnifiedDiff(diff string) (DiffResult, error) { inFileHeader := true for _, line := range lines { - // Parse the file headers + // Parse file headers if inFileHeader { if strings.HasPrefix(line, "--- a/") { result.OldFile = strings.TrimPrefix(line, "--- a/") @@ -332,27 +304,27 @@ func ParseUnifiedDiff(diff string) (DiffResult, error) { newStart, _ := strconv.Atoi(matches[3]) oldLine = oldStart newLine = newStart - continue } - // ignore the \\ No newline at end of file + // Ignore "No newline at end of file" markers if strings.HasPrefix(line, "\\ No newline at end of file") { continue } + if currentHunk == nil { continue } + // Process the line based on its prefix if len(line) > 0 { - // Process the line based on its prefix switch line[0] { case '+': currentHunk.Lines = append(currentHunk.Lines, DiffLine{ OldLineNo: 0, NewLineNo: newLine, Kind: LineAdded, - Content: line[1:], // skip '+' + Content: line[1:], }) newLine++ case '-': @@ -360,7 +332,7 @@ func ParseUnifiedDiff(diff string) (DiffResult, error) { OldLineNo: oldLine, NewLineNo: 0, Kind: LineRemoved, - Content: line[1:], // skip '-' + Content: line[1:], }) oldLine++ default: @@ -394,14 +366,13 @@ func ParseUnifiedDiff(diff string) (DiffResult, error) { return result, nil } -// HighlightIntralineChanges updates the content of lines in a hunk to show -// character-level differences within lines. +// HighlightIntralineChanges updates lines in a hunk to show character-level differences func HighlightIntralineChanges(h *Hunk, style StyleConfig) { var updated []DiffLine dmp := diffmatchpatch.New() for i := 0; i < len(h.Lines); i++ { - // Look for removed line followed by added line, which might have similar content + // Look for removed line followed by added line if i+1 < len(h.Lines) && h.Lines[i].Kind == LineRemoved && h.Lines[i+1].Kind == LineAdded { @@ -411,12 +382,40 @@ func HighlightIntralineChanges(h *Hunk, style StyleConfig) { // Find character-level differences patches := dmp.DiffMain(oldLine.Content, newLine.Content, false) - patches = dmp.DiffCleanupEfficiency(patches) patches = dmp.DiffCleanupSemantic(patches) + patches = dmp.DiffCleanupMerge(patches) + patches = dmp.DiffCleanupEfficiency(patches) - // Apply highlighting to the differences - oldLine.Content = colorizeSegments(patches, true, style) - newLine.Content = colorizeSegments(patches, false, style) + segments := make([]Segment, 0) + + removeStart := 0 + addStart := 0 + for _, patch := range patches { + switch patch.Type { + case diffmatchpatch.DiffDelete: + segments = append(segments, Segment{ + Start: removeStart, + End: removeStart + len(patch.Text), + Type: LineRemoved, + Text: patch.Text, + }) + removeStart += len(patch.Text) + case diffmatchpatch.DiffInsert: + segments = append(segments, Segment{ + Start: addStart, + End: addStart + len(patch.Text), + Type: LineAdded, + Text: patch.Text, + }) + addStart += len(patch.Text) + default: + // Context text, no highlighting needed + removeStart += len(patch.Text) + addStart += len(patch.Text) + } + } + oldLine.Segments = segments + newLine.Segments = segments updated = append(updated, oldLine, newLine) i++ // Skip the next line as we've already processed it @@ -428,45 +427,7 @@ func HighlightIntralineChanges(h *Hunk, style StyleConfig) { h.Lines = updated } -// colorizeSegments applies styles to the character-level diff segments. -func colorizeSegments(diffs []diffmatchpatch.Diff, isOld bool, style StyleConfig) string { - var buf strings.Builder - - removeBg := lipgloss.NewStyle(). - Background(style.RemovedHighlightBg). - Foreground(style.RemovedHighlightFg) - - addBg := lipgloss.NewStyle(). - Background(style.AddedHighlightBg). - Foreground(style.AddedHighlightFg) - - removedLineStyle := lipgloss.NewStyle().Background(style.RemovedLineBg) - addedLineStyle := lipgloss.NewStyle().Background(style.AddedLineBg) - - for _, d := range diffs { - switch d.Type { - case diffmatchpatch.DiffEqual: - // Handle text that's the same in both versions - buf.WriteString(d.Text) - case diffmatchpatch.DiffDelete: - // Handle deleted text (only show in old version) - if isOld { - buf.WriteString(removeBg.Render(d.Text)) - buf.WriteString(removedLineStyle.Render("")) - } - case diffmatchpatch.DiffInsert: - // Handle inserted text (only show in new version) - if !isOld { - buf.WriteString(addBg.Render(d.Text)) - buf.WriteString(addedLineStyle.Render("")) - } - } - } - - return buf.String() -} - -// pairLines converts a flat list of diff lines to pairs for side-by-side display. +// pairLines converts a flat list of diff lines to pairs for side-by-side display func pairLines(lines []DiffLine) []linePair { var pairs []linePair i := 0 @@ -498,7 +459,7 @@ func pairLines(lines []DiffLine) []linePair { // Syntax Highlighting // ------------------------------------------------------------------------- -// SyntaxHighlight applies syntax highlighting to a string based on the file extension. +// SyntaxHighlight applies syntax highlighting to text based on file extension func SyntaxHighlight(w io.Writer, source, fileName, formatter string, bg lipgloss.TerminalColor) error { // Determine the language lexer to use l := lexers.Match(fileName) @@ -515,21 +476,98 @@ func SyntaxHighlight(w io.Writer, source, fileName, formatter string, bg lipglos if f == nil { f = formatters.Fallback } - - // Get the style - s := styles.Get("dracula") - if s == nil { - s = styles.Fallback - } - + theme := ` + +` + + r := strings.NewReader(theme) + style := chroma.MustNewXMLStyle(r) // Modify the style to use the provided background - s, err := s.Builder().Transform( + s, err := style.Builder().Transform( func(t chroma.StyleEntry) chroma.StyleEntry { r, g, b, _ := bg.RGBA() - ru8 := uint8(r >> 8) - gu8 := uint8(g >> 8) - bu8 := uint8(b >> 8) - t.Background = chroma.NewColour(ru8, gu8, bu8) + t.Background = chroma.NewColour(uint8(r>>8), uint8(g>>8), uint8(b>>8)) return t }, ).Build() @@ -546,7 +584,7 @@ func SyntaxHighlight(w io.Writer, source, fileName, formatter string, bg lipglos return f.Format(w, s, it) } -// highlightLine applies syntax highlighting to a single line. +// highlightLine applies syntax highlighting to a single line func highlightLine(fileName string, line string, bg lipgloss.TerminalColor) string { var buf bytes.Buffer err := SyntaxHighlight(&buf, line, fileName, "terminal16m", bg) @@ -556,7 +594,7 @@ func highlightLine(fileName string, line string, bg lipgloss.TerminalColor) stri return buf.String() } -// createStyles generates the lipgloss styles needed for rendering diffs. +// createStyles generates the lipgloss styles needed for rendering diffs func createStyles(config StyleConfig) (removedLineStyle, addedLineStyle, contextLineStyle, lineNumberStyle lipgloss.Style) { removedLineStyle = lipgloss.NewStyle().Background(config.RemovedLineBg) addedLineStyle = lipgloss.NewStyle().Background(config.AddedLineBg) @@ -566,7 +604,106 @@ func createStyles(config StyleConfig) (removedLineStyle, addedLineStyle, context return } -// renderLeftColumn formats the left side of a side-by-side diff. +// ------------------------------------------------------------------------- +// Rendering Functions +// ------------------------------------------------------------------------- + +// applyHighlighting applies intra-line highlighting to a piece of text +func applyHighlighting(content string, segments []Segment, segmentType LineType, highlightBg lipgloss.Color, +) string { + // Find all ANSI sequences in the content + ansiRegex := regexp.MustCompile(`\x1b(?:[@-Z\\-_]|\[[0-9?]*(?:;[0-9?]*)*[@-~])`) + ansiMatches := ansiRegex.FindAllStringIndex(content, -1) + + // Build a mapping of visible character positions to their actual indices + visibleIdx := 0 + ansiSequences := make(map[int]string) + lastAnsiSeq := "\x1b[0m" // Default reset sequence + + for i := 0; i < len(content); { + isAnsi := false + for _, match := range ansiMatches { + if match[0] == i { + ansiSequences[visibleIdx] = content[match[0]:match[1]] + lastAnsiSeq = content[match[0]:match[1]] + i = match[1] + isAnsi = true + break + } + } + if isAnsi { + continue + } + + // For non-ANSI positions, store the last ANSI sequence + if _, exists := ansiSequences[visibleIdx]; !exists { + ansiSequences[visibleIdx] = lastAnsiSeq + } + visibleIdx++ + i++ + } + + // Apply highlighting + var sb strings.Builder + inSelection := false + currentPos := 0 + + for i := 0; i < len(content); { + // Check if we're at an ANSI sequence + isAnsi := false + for _, match := range ansiMatches { + if match[0] == i { + sb.WriteString(content[match[0]:match[1]]) // Preserve ANSI sequence + i = match[1] + isAnsi = true + break + } + } + if isAnsi { + continue + } + + // Check for segment boundaries + for _, seg := range segments { + if seg.Type == segmentType { + if currentPos == seg.Start { + inSelection = true + } + if currentPos == seg.End { + inSelection = false + } + } + } + + // Get current character + char := string(content[i]) + + if inSelection { + // Get the current styling + currentStyle := ansiSequences[currentPos] + + // Apply background highlight + sb.WriteString("\x1b[48;2;") + r, g, b, _ := highlightBg.RGBA() + sb.WriteString(fmt.Sprintf("%d;%d;%dm", r>>8, g>>8, b>>8)) + sb.WriteString(char) + sb.WriteString("\x1b[49m") // Reset only background + + // Reapply the original ANSI sequence + sb.WriteString(currentStyle) + } else { + // Not in selection, just copy the character + sb.WriteString(char) + } + + currentPos++ + i++ + } + + return sb.String() +} + +// renderLeftColumn formats the left side of a side-by-side diff func renderLeftColumn(fileName string, dl *DiffLine, colWidth int, styles StyleConfig) string { if dl == nil { contextLineStyle := lipgloss.NewStyle().Background(styles.ContextLineBg) @@ -575,9 +712,9 @@ func renderLeftColumn(fileName string, dl *DiffLine, colWidth int, styles StyleC removedLineStyle, _, contextLineStyle, lineNumberStyle := createStyles(styles) + // Determine line style based on line type var marker string var bgStyle lipgloss.Style - switch dl.Kind { case LineRemoved: marker = removedLineStyle.Foreground(styles.RemovedFg).Render("-") @@ -591,18 +728,29 @@ func renderLeftColumn(fileName string, dl *DiffLine, colWidth int, styles StyleC bgStyle = contextLineStyle } + // Format line number lineNum := "" if dl.OldLineNo > 0 { lineNum = fmt.Sprintf("%6d", dl.OldLineNo) } + // Create the line prefix prefix := lineNumberStyle.Render(lineNum + " " + marker) + + // Apply syntax highlighting content := highlightLine(fileName, dl.Content, bgStyle.GetBackground()) + // Apply intra-line highlighting for removed lines + if dl.Kind == LineRemoved && len(dl.Segments) > 0 { + content = applyHighlighting(content, dl.Segments, LineRemoved, styles.RemovedHighlightBg) + } + + // Add a padding space for removed lines if dl.Kind == LineRemoved { content = bgStyle.Render(" ") + content } + // Create the final line and truncate if needed lineText := prefix + content return bgStyle.MaxHeight(1).Width(colWidth).Render( ansi.Truncate( @@ -613,7 +761,7 @@ func renderLeftColumn(fileName string, dl *DiffLine, colWidth int, styles StyleC ) } -// renderRightColumn formats the right side of a side-by-side diff. +// renderRightColumn formats the right side of a side-by-side diff func renderRightColumn(fileName string, dl *DiffLine, colWidth int, styles StyleConfig) string { if dl == nil { contextLineStyle := lipgloss.NewStyle().Background(styles.ContextLineBg) @@ -622,9 +770,9 @@ func renderRightColumn(fileName string, dl *DiffLine, colWidth int, styles Style _, addedLineStyle, contextLineStyle, lineNumberStyle := createStyles(styles) + // Determine line style based on line type var marker string var bgStyle lipgloss.Style - switch dl.Kind { case LineAdded: marker = addedLineStyle.Foreground(styles.AddedFg).Render("+") @@ -638,18 +786,29 @@ func renderRightColumn(fileName string, dl *DiffLine, colWidth int, styles Style bgStyle = contextLineStyle } + // Format line number lineNum := "" if dl.NewLineNo > 0 { lineNum = fmt.Sprintf("%6d", dl.NewLineNo) } + // Create the line prefix prefix := lineNumberStyle.Render(lineNum + " " + marker) + + // Apply syntax highlighting content := highlightLine(fileName, dl.Content, bgStyle.GetBackground()) + // Apply intra-line highlighting for added lines + if dl.Kind == LineAdded && len(dl.Segments) > 0 { + content = applyHighlighting(content, dl.Segments, LineAdded, styles.AddedHighlightBg) + } + + // Add a padding space for added lines if dl.Kind == LineAdded { content = bgStyle.Render(" ") + content } + // Create the final line and truncate if needed lineText := prefix + content return bgStyle.MaxHeight(1).Width(colWidth).Render( ansi.Truncate( @@ -661,10 +820,10 @@ func renderRightColumn(fileName string, dl *DiffLine, colWidth int, styles Style } // ------------------------------------------------------------------------- -// Public API Methods +// Public API // ------------------------------------------------------------------------- -// RenderSideBySideHunk formats a hunk for side-by-side display. +// RenderSideBySideHunk formats a hunk for side-by-side display func RenderSideBySideHunk(fileName string, h Hunk, opts ...SideBySideOption) string { // Apply options to create the configuration config := NewSideBySideConfig(opts...) @@ -692,7 +851,7 @@ func RenderSideBySideHunk(fileName string, h Hunk, opts ...SideBySideOption) str return sb.String() } -// FormatDiff creates a side-by-side formatted view of a diff. +// FormatDiff creates a side-by-side formatted view of a diff func FormatDiff(diffText string, opts ...SideBySideOption) (string, error) { diffResult, err := ParseUnifiedDiff(diffText) if err != nil { @@ -700,11 +859,18 @@ func FormatDiff(diffText string, opts ...SideBySideOption) (string, error) { } var sb strings.Builder - config := NewSideBySideConfig(opts...) + for i, h := range diffResult.Hunks { if i > 0 { - sb.WriteString(lipgloss.NewStyle().Background(config.Style.HunkLineBg).Foreground(config.Style.HunkLineFg).Width(config.TotalWidth).Render(h.Header) + "\n") + // Render hunk header + sb.WriteString( + lipgloss.NewStyle(). + Background(config.Style.HunkLineBg). + Foreground(config.Style.HunkLineFg). + Width(config.TotalWidth). + Render(h.Header) + "\n", + ) } sb.WriteString(RenderSideBySideHunk(diffResult.OldFile, h, opts...)) } @@ -712,14 +878,16 @@ func FormatDiff(diffText string, opts ...SideBySideOption) (string, error) { return sb.String(), nil } -// GenerateDiff creates a unified diff from two file contents. +// GenerateDiff creates a unified diff from two file contents func GenerateDiff(beforeContent, afterContent, fileName string) (string, int, int) { + // Create temporary directory for git operations tempDir, err := os.MkdirTemp("", "git-diff-temp") if err != nil { return "", 0, 0 } defer os.RemoveAll(tempDir) + // Initialize git repo repo, err := git.PlainInit(tempDir, false) if err != nil { return "", 0, 0 @@ -730,6 +898,7 @@ func GenerateDiff(beforeContent, afterContent, fileName string) (string, int, in return "", 0, 0 } + // Write the "before" content and commit it fullPath := filepath.Join(tempDir, fileName) if err = os.MkdirAll(filepath.Dir(fullPath), 0o755); err != nil { return "", 0, 0 @@ -754,7 +923,9 @@ func GenerateDiff(beforeContent, afterContent, fileName string) (string, int, in return "", 0, 0 } + // Write the "after" content and commit it if err = os.WriteFile(fullPath, []byte(afterContent), 0o644); err != nil { + return "", 0, 0 } _, err = wt.Add(fileName) @@ -773,6 +944,7 @@ func GenerateDiff(beforeContent, afterContent, fileName string) (string, int, in return "", 0, 0 } + // Get the diff between the two commits beforeCommitObj, err := repo.CommitObject(beforeCommit) if err != nil { return "", 0, 0 @@ -788,6 +960,7 @@ func GenerateDiff(beforeContent, afterContent, fileName string) (string, int, in return "", 0, 0 } + // Count additions and removals additions := 0 removals := 0 for _, fileStat := range patch.Stats() { -- cgit v1.2.3 From bbfa60c787f2ec459f1689b9a650ddbec9693ed9 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Wed, 16 Apr 2025 20:06:23 +0200 Subject: reimplement agent,provider and add file history --- .opencode.json | 4 - README.md | 34 +- cmd/root.go | 24 +- go.mod | 8 +- go.sum | 14 - internal/app/app.go | 17 +- internal/app/lsp.go | 19 +- internal/config/config.go | 108 +++- internal/db/files.sql.go | 4 +- internal/db/sql/files.sql | 4 +- internal/diff/diff.go | 99 ++- internal/llm/agent/agent-tool.go | 18 +- internal/llm/agent/agent.go | 861 ++++++++++----------------- internal/llm/agent/coder.go | 63 -- internal/llm/agent/mcp-tools.go | 4 +- internal/llm/agent/task.go | 47 -- internal/llm/agent/tools.go | 50 ++ internal/llm/models/anthropic.go | 71 +++ internal/llm/models/models.go | 190 +++--- internal/llm/prompt/coder.go | 28 +- internal/llm/prompt/prompt.go | 19 + internal/llm/prompt/task.go | 5 +- internal/llm/prompt/title.go | 4 +- internal/llm/provider/anthropic.go | 531 +++++++++-------- internal/llm/provider/bedrock.go | 101 ++-- internal/llm/provider/gemini.go | 533 +++++++++++------ internal/llm/provider/openai.go | 401 ++++++++----- internal/llm/provider/provider.go | 169 ++++-- internal/llm/tools/bash.go | 7 +- internal/llm/tools/bash_test.go | 31 - internal/llm/tools/edit.go | 75 ++- internal/llm/tools/edit_test.go | 30 +- internal/llm/tools/file.go | 10 - internal/llm/tools/glob.go | 4 +- internal/llm/tools/grep.go | 4 +- internal/llm/tools/ls.go | 4 +- internal/llm/tools/mocks_test.go | 246 ++++++++ internal/llm/tools/shell/shell.go | 12 +- internal/llm/tools/sourcegraph.go | 2 +- internal/llm/tools/tools.go | 9 +- internal/llm/tools/write.go | 27 +- internal/llm/tools/write_test.go | 22 +- internal/logging/logger.go | 41 +- internal/lsp/client.go | 13 +- internal/lsp/handlers.go | 2 +- internal/lsp/transport.go | 28 +- internal/lsp/watcher/watcher.go | 18 +- internal/message/content.go | 30 +- internal/pubsub/broker.go | 2 +- internal/session/session.go | 15 + internal/tui/components/chat/chat.go | 2 - internal/tui/components/chat/editor.go | 22 +- internal/tui/components/chat/messages.go | 205 ++++++- internal/tui/components/chat/sidebar.go | 176 +++++- internal/tui/components/core/dialog.go | 117 ---- internal/tui/components/core/help.go | 119 ---- internal/tui/components/core/status.go | 90 ++- internal/tui/components/dialog/help.go | 182 ++++++ internal/tui/components/dialog/permission.go | 682 +++++++++++---------- internal/tui/components/dialog/quit.go | 156 +++-- internal/tui/components/logs/details.go | 2 - internal/tui/components/logs/table.go | 22 - internal/tui/components/repl/editor.go | 201 ------- internal/tui/components/repl/messages.go | 513 ---------------- internal/tui/components/repl/sessions.go | 249 -------- internal/tui/layout/overlay.go | 11 +- internal/tui/layout/split.go | 1 + internal/tui/page/chat.go | 32 +- internal/tui/page/init.go | 308 ---------- internal/tui/page/logs.go | 17 + internal/tui/page/repl.go | 21 - internal/tui/tui.go | 277 ++++----- main.go | 7 + 73 files changed, 3595 insertions(+), 3879 deletions(-) delete mode 100644 internal/llm/agent/coder.go delete mode 100644 internal/llm/agent/task.go create mode 100644 internal/llm/agent/tools.go create mode 100644 internal/llm/models/anthropic.go create mode 100644 internal/llm/prompt/prompt.go create mode 100644 internal/llm/tools/mocks_test.go delete mode 100644 internal/tui/components/core/dialog.go delete mode 100644 internal/tui/components/core/help.go create mode 100644 internal/tui/components/dialog/help.go delete mode 100644 internal/tui/components/repl/editor.go delete mode 100644 internal/tui/components/repl/messages.go delete mode 100644 internal/tui/components/repl/sessions.go delete mode 100644 internal/tui/page/init.go delete mode 100644 internal/tui/page/repl.go (limited to 'internal/diff') diff --git a/.opencode.json b/.opencode.json index f63a63dba..b7fc19b52 100644 --- a/.opencode.json +++ b/.opencode.json @@ -1,8 +1,4 @@ { - "model": { - "coder": "claude-3.7-sonnet", - "coderMaxTokens": 20000 - }, "lsp": { "gopls": { "command": "gopls" diff --git a/README.md b/README.md index 23a1906a1..564284c7f 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,14 @@ -# TermAI +# OpenCode > **⚠️ Early Development Notice:** This project is in early development and is not yet ready for production use. Features may change, break, or be incomplete. Use at your own risk. A powerful terminal-based AI assistant for developers, providing intelligent coding assistance directly in your terminal. -[![TermAI Demo](https://asciinema.org/a/dtc4nJyGSZX79HRUmFLY3gmoy.svg)](https://asciinema.org/a/dtc4nJyGSZX79HRUmFLY3gmoy) +[![OpenCode Demo](https://asciinema.org/a/dtc4nJyGSZX79HRUmFLY3gmoy.svg)](https://asciinema.org/a/dtc4nJyGSZX79HRUmFLY3gmoy) ## Overview -TermAI is a Go-based CLI application that brings AI assistance to your terminal. It provides a TUI (Terminal User Interface) for interacting with various AI models to help with coding tasks, debugging, and more. +OpenCode is a Go-based CLI application that brings AI assistance to your terminal. It provides a TUI (Terminal User Interface) for interacting with various AI models to help with coding tasks, debugging, and more. ## Features @@ -23,16 +23,16 @@ TermAI is a Go-based CLI application that brings AI assistance to your terminal. ```bash # Coming soon -go install github.com/kujtimiihoxha/termai@latest +go install github.com/kujtimiihoxha/opencode@latest ``` ## Configuration -TermAI looks for configuration in the following locations: +OpenCode looks for configuration in the following locations: -- `$HOME/.termai.json` -- `$XDG_CONFIG_HOME/termai/.termai.json` -- `./.termai.json` (local directory) +- `$HOME/.opencode.json` +- `$XDG_CONFIG_HOME/opencode/.opencode.json` +- `./.opencode.json` (local directory) You can also use environment variables: @@ -43,11 +43,11 @@ You can also use environment variables: ## Usage ```bash -# Start TermAI -termai +# Start OpenCode +opencode # Start with debug logging -termai -d +opencode -d ``` ### Keyboard Shortcuts @@ -81,7 +81,7 @@ termai -d ## Architecture -TermAI is built with a modular architecture: +OpenCode is built with a modular architecture: - **cmd**: Command-line interface using Cobra - **internal/app**: Core application services @@ -103,22 +103,22 @@ TermAI is built with a modular architecture: ```bash # Clone the repository -git clone https://github.com/kujtimiihoxha/termai.git -cd termai +git clone https://github.com/kujtimiihoxha/opencode.git +cd opencode # Build the diff script first go run cmd/diff/main.go # Build -go build -o termai +go build -o opencode # Run -./termai +./opencode ``` ## Acknowledgments -TermAI builds upon the work of several open source projects and developers: +OpenCode builds upon the work of several open source projects and developers: - [@isaacphi](https://github.com/isaacphi) - LSP client implementation diff --git a/cmd/root.go b/cmd/root.go index a2e63006f..ff71747d5 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -20,7 +20,7 @@ import ( ) var rootCmd = &cobra.Command{ - Use: "termai", + Use: "OpenCode", Short: "A terminal ai assistant", Long: `A terminal ai assistant`, RunE: func(cmd *cobra.Command, args []string) error { @@ -89,12 +89,9 @@ var rootCmd = &cobra.Command{ // Set up message handling for the TUI go func() { defer tuiWg.Done() - defer func() { - if r := recover(); r != nil { - logging.Error("Panic in TUI message handling: %v", r) - attemptTUIRecovery(program) - } - }() + defer logging.RecoverPanic("TUI-message-handler", func() { + attemptTUIRecovery(program) + }) for { select { @@ -153,11 +150,7 @@ func attemptTUIRecovery(program *tea.Program) { func initMCPTools(ctx context.Context, app *app.App) { go func() { - defer func() { - if r := recover(); r != nil { - logging.Error("Panic in MCP goroutine: %v", r) - } - }() + defer logging.RecoverPanic("MCP-goroutine", nil) // Create a context with timeout for the initial MCP tools fetch ctxWithTimeout, cancel := context.WithTimeout(ctx, 30*time.Second) @@ -179,11 +172,7 @@ func setupSubscriber[T any]( wg.Add(1) go func() { defer wg.Done() - defer func() { - if r := recover(); r != nil { - logging.Error("Panic in %s subscription goroutine: %v", name, r) - } - }() + defer logging.RecoverPanic(fmt.Sprintf("subscription-%s", name), nil) for { select { @@ -232,6 +221,7 @@ func setupSubscriptions(app *app.App) (chan tea.Msg, func()) { // Wait with a timeout for all goroutines to complete waitCh := make(chan struct{}) go func() { + defer logging.RecoverPanic("subscription-cleanup", nil) wg.Wait() close(waitCh) }() diff --git a/go.mod b/go.mod index 925a71097..16c88d3a6 100644 --- a/go.mod +++ b/go.mod @@ -23,7 +23,6 @@ require ( github.com/golang-migrate/migrate/v4 v4.18.2 github.com/google/generative-ai-go v0.19.0 github.com/google/uuid v1.6.0 - github.com/kujtimiihoxha/vimtea v0.0.3-0.20250329221256-a250e98498f9 github.com/lrstanley/bubblezone v0.0.0-20250315020633-c249a3fe1231 github.com/mark3labs/mcp-go v0.17.0 github.com/mattn/go-runewidth v0.0.16 @@ -36,7 +35,6 @@ require ( github.com/spf13/cobra v1.9.1 github.com/spf13/viper v1.20.0 github.com/stretchr/testify v1.10.0 - golang.org/x/net v0.39.0 google.golang.org/api v0.215.0 ) @@ -106,7 +104,6 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/sagikazarmark/locafero v0.7.0 // indirect - github.com/sahilm/fuzzy v0.1.1 // indirect github.com/skeema/knownhosts v1.3.1 // indirect github.com/sourcegraph/conc v0.3.0 // indirect github.com/spf13/afero v1.12.0 // indirect @@ -129,11 +126,8 @@ require ( go.opentelemetry.io/otel/trace v1.29.0 // indirect go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect - golang.design/x/clipboard v0.7.0 // indirect golang.org/x/crypto v0.37.0 // indirect - golang.org/x/exp/shiny v0.0.0-20250305212735-054e65f0b394 // indirect - golang.org/x/image v0.14.0 // indirect - golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a // indirect + golang.org/x/net v0.39.0 // indirect golang.org/x/oauth2 v0.25.0 // indirect golang.org/x/sync v0.13.0 // indirect golang.org/x/sys v0.32.0 // indirect diff --git a/go.sum b/go.sum index 9c2c2df8f..4832271f2 100644 --- a/go.sum +++ b/go.sum @@ -180,10 +180,6 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/kujtimiihoxha/vimtea v0.0.3-0.20250329221256-a250e98498f9 h1:xYfCLI8KUwmXDFp1pOpNX+XsQczQw9VbEuju1pQF5/A= -github.com/kujtimiihoxha/vimtea v0.0.3-0.20250329221256-a250e98498f9/go.mod h1:Ye+kIkTmPO5xuqCQ+PPHDTGIViRRoSpSIlcYgma8YlA= -github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= -github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lrstanley/bubblezone v0.0.0-20250315020633-c249a3fe1231 h1:9rjt7AfnrXKNSZhp36A3/4QAZAwGGCGD/p8Bse26zms= @@ -235,8 +231,6 @@ github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7 github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sagikazarmark/locafero v0.7.0 h1:5MqpDsTGNDhY8sGp0Aowyf0qKsPrhewaLSsFaodPcyo= github.com/sagikazarmark/locafero v0.7.0/go.mod h1:2za3Cg5rMaTMoG/2Ulr9AwtFaIppKXTRYnozin4aB5k= -github.com/sahilm/fuzzy v0.1.1 h1:ceu5RHF8DGgoi+/dR5PsECjCDH1BE3Fnmpo7aVXOdRA= -github.com/sahilm/fuzzy v0.1.1/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y= github.com/sebdah/goldie/v2 v2.5.3 h1:9ES/mNN+HNUbNWpVAlrzuZ7jE+Nrczbj8uFRjM7624Y= github.com/sebdah/goldie/v2 v2.5.3/go.mod h1:oZ9fp0+se1eapSRjfYbsV/0Hqhbuu3bJVvKI/NNtssI= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= @@ -302,8 +296,6 @@ go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= -golang.design/x/clipboard v0.7.0 h1:4Je8M/ys9AJumVnl8m+rZnIvstSnYj1fvzqYrU3TXvo= -golang.design/x/clipboard v0.7.0/go.mod h1:PQIvqYO9GP29yINEfsEn5zSQKAz3UgXmZKzDA6dnq2E= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= @@ -314,12 +306,6 @@ golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8= golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY= -golang.org/x/exp/shiny v0.0.0-20250305212735-054e65f0b394 h1:bFYqOIMdeiCEdzPJkLiOoMDzW/v3tjW4AA/RmUZYsL8= -golang.org/x/exp/shiny v0.0.0-20250305212735-054e65f0b394/go.mod h1:ygj7T6vSGhhm/9yTpOQQNvuAUFziTH7RUiH74EoE2C8= -golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4= -golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= -golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a h1:sYbmY3FwUWCBTodZL1S3JUuOvaW6kM2o+clDzzDNBWg= -golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a/go.mod h1:Ede7gF0KGoHlj822RtphAHK1jLdrcuRBZg0sF1Q+SPc= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= diff --git a/internal/app/app.go b/internal/app/app.go index ca23b3c40..1c16ccc11 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -7,6 +7,7 @@ import ( "sync" "time" + "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/db" "github.com/kujtimiihoxha/termai/internal/history" "github.com/kujtimiihoxha/termai/internal/llm/agent" @@ -20,7 +21,7 @@ import ( type App struct { Sessions session.Service Messages message.Service - Files history.Service + History history.Service Permissions permission.Service CoderAgent agent.Service @@ -43,7 +44,7 @@ func New(ctx context.Context, conn *sql.DB) (*App, error) { app := &App{ Sessions: sessions, Messages: messages, - Files: files, + History: files, Permissions: permission.NewPermissionService(), LSPClients: make(map[string]*lsp.Client), } @@ -51,11 +52,17 @@ func New(ctx context.Context, conn *sql.DB) (*App, error) { app.initLSPClients(ctx) var err error - app.CoderAgent, err = agent.NewCoderAgent( - app.Permissions, + app.CoderAgent, err = agent.NewAgent( + config.AgentCoder, app.Sessions, app.Messages, - app.LSPClients, + agent.CoderAgentTools( + app.Permissions, + app.Sessions, + app.Messages, + app.History, + app.LSPClients, + ), ) if err != nil { logging.Error("Failed to create coder agent", err) diff --git a/internal/app/lsp.go b/internal/app/lsp.go index 4e0568f07..4a762f1a1 100644 --- a/internal/app/lsp.go +++ b/internal/app/lsp.go @@ -22,16 +22,17 @@ func (app *App) initLSPClients(ctx context.Context) { // createAndStartLSPClient creates a new LSP client, initializes it, and starts its workspace watcher func (app *App) createAndStartLSPClient(ctx context.Context, name string, command string, args ...string) { // Create a specific context for initialization with a timeout - initCtx, initCancel := context.WithTimeout(context.Background(), 30*time.Second) - defer initCancel() // Create the LSP client - lspClient, err := lsp.NewClient(initCtx, command, args...) + lspClient, err := lsp.NewClient(ctx, command, args...) if err != nil { logging.Error("Failed to create LSP client for", name, err) return + } + initCtx, cancel := context.WithTimeout(ctx, 15*time.Second) + defer cancel() // Initialize with the initialization context _, err = lspClient.InitializeLSPClient(initCtx, config.WorkingDirectory()) if err != nil { @@ -64,14 +65,10 @@ func (app *App) createAndStartLSPClient(ctx context.Context, name string, comman // runWorkspaceWatcher executes the workspace watcher for an LSP client func (app *App) runWorkspaceWatcher(ctx context.Context, name string, workspaceWatcher *watcher.WorkspaceWatcher) { defer app.watcherWG.Done() - defer func() { - if r := recover(); r != nil { - logging.Error("LSP client crashed", "client", name, "panic", r) - - // Try to restart the client - app.restartLSPClient(ctx, name) - } - }() + defer logging.RecoverPanic("LSP-"+name, func() { + // Try to restart the client + app.restartLSPClient(ctx, name) + }) workspaceWatcher.WatchWorkspace(ctx, config.WorkingDirectory()) logging.Info("Workspace watcher stopped", "client", name) diff --git a/internal/config/config.go b/internal/config/config.go index f0afbdd3c..147d6c83a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -31,12 +31,18 @@ type MCPServer struct { Headers map[string]string `json:"headers"` } -// Model defines configuration for different LLM models and their token limits. -type Model struct { - Coder models.ModelID `json:"coder"` - CoderMaxTokens int64 `json:"coderMaxTokens"` - Task models.ModelID `json:"task"` - TaskMaxTokens int64 `json:"taskMaxTokens"` +type AgentName string + +const ( + AgentCoder AgentName = "coder" + AgentTask AgentName = "task" + AgentTitle AgentName = "title" +) + +// Agent defines configuration for different LLM models and their token limits. +type Agent struct { + Model models.ModelID `json:"model"` + MaxTokens int64 `json:"maxTokens"` } // Provider defines configuration for an LLM provider. @@ -65,8 +71,9 @@ type Config struct { MCPServers map[string]MCPServer `json:"mcpServers,omitempty"` Providers map[models.ModelProvider]Provider `json:"providers,omitempty"` LSP map[string]LSPConfig `json:"lsp,omitempty"` - Model Model `json:"model"` + Agents map[AgentName]Agent `json:"agents"` Debug bool `json:"debug,omitempty"` + DebugLSP bool `json:"debugLSP,omitempty"` } // Application constants @@ -118,11 +125,42 @@ func Load(workingDir string, debug bool) (*Config, error) { if cfg.Debug { defaultLevel = slog.LevelDebug } - // Configure logger - logger := slog.New(slog.NewTextHandler(logging.NewWriter(), &slog.HandlerOptions{ - Level: defaultLevel, - })) - slog.SetDefault(logger) + // if we are in debug mode make the writer a file + if cfg.Debug { + loggingFile := fmt.Sprintf("%s/%s", cfg.Data.Directory, "debug.log") + + // if file does not exist create it + if _, err := os.Stat(loggingFile); os.IsNotExist(err) { + if err := os.MkdirAll(cfg.Data.Directory, 0o755); err != nil { + return cfg, fmt.Errorf("failed to create directory: %w", err) + } + if _, err := os.Create(loggingFile); err != nil { + return cfg, fmt.Errorf("failed to create log file: %w", err) + } + } + + sloggingFileWriter, err := os.OpenFile(loggingFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o666) + if err != nil { + return cfg, fmt.Errorf("failed to open log file: %w", err) + } + // Configure logger + logger := slog.New(slog.NewTextHandler(sloggingFileWriter, &slog.HandlerOptions{ + Level: defaultLevel, + })) + slog.SetDefault(logger) + } else { + // Configure logger + logger := slog.New(slog.NewTextHandler(logging.NewWriter(), &slog.HandlerOptions{ + Level: defaultLevel, + })) + slog.SetDefault(logger) + } + + // Override the max tokens for title agent + cfg.Agents[AgentTitle] = Agent{ + Model: cfg.Agents[AgentTitle].Model, + MaxTokens: 80, + } return cfg, nil } @@ -159,44 +197,50 @@ func setProviderDefaults() { // Groq configuration if apiKey := os.Getenv("GROQ_API_KEY"); apiKey != "" { viper.SetDefault("providers.groq.apiKey", apiKey) - viper.SetDefault("model.coder", models.QWENQwq) - viper.SetDefault("model.coderMaxTokens", defaultMaxTokens) - viper.SetDefault("model.task", models.QWENQwq) - viper.SetDefault("model.taskMaxTokens", defaultMaxTokens) + viper.SetDefault("agents.coder.model", models.QWENQwq) + viper.SetDefault("agents.coder.maxTokens", defaultMaxTokens) + viper.SetDefault("agents.task.model", models.QWENQwq) + viper.SetDefault("agents.task.maxTokens", defaultMaxTokens) + viper.SetDefault("agents.title.model", models.QWENQwq) } // Google Gemini configuration if apiKey := os.Getenv("GEMINI_API_KEY"); apiKey != "" { viper.SetDefault("providers.gemini.apiKey", apiKey) - viper.SetDefault("model.coder", models.GRMINI20Flash) - viper.SetDefault("model.coderMaxTokens", defaultMaxTokens) - viper.SetDefault("model.task", models.GRMINI20Flash) - viper.SetDefault("model.taskMaxTokens", defaultMaxTokens) + viper.SetDefault("agents.coder.model", models.GRMINI20Flash) + viper.SetDefault("agents.coder.maxTokens", defaultMaxTokens) + viper.SetDefault("agents.task.model", models.GRMINI20Flash) + viper.SetDefault("agents.task.maxTokens", defaultMaxTokens) + viper.SetDefault("agents.title.model", models.GRMINI20Flash) } // OpenAI configuration if apiKey := os.Getenv("OPENAI_API_KEY"); apiKey != "" { viper.SetDefault("providers.openai.apiKey", apiKey) - viper.SetDefault("model.coder", models.GPT4o) - viper.SetDefault("model.coderMaxTokens", defaultMaxTokens) - viper.SetDefault("model.task", models.GPT4o) - viper.SetDefault("model.taskMaxTokens", defaultMaxTokens) + viper.SetDefault("agents.coder.model", models.GPT4o) + viper.SetDefault("agents.coder.maxTokens", defaultMaxTokens) + viper.SetDefault("agents.task.model", models.GPT4o) + viper.SetDefault("agents.task.maxTokens", defaultMaxTokens) + viper.SetDefault("agents.title.model", models.GPT4o) + } // Anthropic configuration if apiKey := os.Getenv("ANTHROPIC_API_KEY"); apiKey != "" { viper.SetDefault("providers.anthropic.apiKey", apiKey) - viper.SetDefault("model.coder", models.Claude37Sonnet) - viper.SetDefault("model.coderMaxTokens", defaultMaxTokens) - viper.SetDefault("model.task", models.Claude37Sonnet) - viper.SetDefault("model.taskMaxTokens", defaultMaxTokens) + viper.SetDefault("agents.coder.model", models.Claude37Sonnet) + viper.SetDefault("agents.coder.maxTokens", defaultMaxTokens) + viper.SetDefault("agents.task.model", models.Claude37Sonnet) + viper.SetDefault("agents.task.maxTokens", defaultMaxTokens) + viper.SetDefault("agents.title.model", models.Claude37Sonnet) } if hasAWSCredentials() { - viper.SetDefault("model.coder", models.BedrockClaude37Sonnet) - viper.SetDefault("model.coderMaxTokens", defaultMaxTokens) - viper.SetDefault("model.task", models.BedrockClaude37Sonnet) - viper.SetDefault("model.taskMaxTokens", defaultMaxTokens) + viper.SetDefault("agents.coder.model", models.BedrockClaude37Sonnet) + viper.SetDefault("agents.coder.maxTokens", defaultMaxTokens) + viper.SetDefault("agents.task.model", models.BedrockClaude37Sonnet) + viper.SetDefault("agents.task.maxTokens", defaultMaxTokens) + viper.SetDefault("agents.title.model", models.BedrockClaude37Sonnet) } } diff --git a/internal/db/files.sql.go b/internal/db/files.sql.go index b45731098..39def271f 100644 --- a/internal/db/files.sql.go +++ b/internal/db/files.sql.go @@ -97,7 +97,9 @@ func (q *Queries) GetFile(ctx context.Context, id string) (File, error) { const getFileByPathAndSession = `-- name: GetFileByPathAndSession :one SELECT id, session_id, path, content, version, created_at, updated_at FROM files -WHERE path = ? AND session_id = ? LIMIT 1 +WHERE path = ? AND session_id = ? +ORDER BY created_at DESC +LIMIT 1 ` type GetFileByPathAndSessionParams struct { diff --git a/internal/db/sql/files.sql b/internal/db/sql/files.sql index c2e799076..aba2a6111 100644 --- a/internal/db/sql/files.sql +++ b/internal/db/sql/files.sql @@ -6,7 +6,9 @@ WHERE id = ? LIMIT 1; -- name: GetFileByPathAndSession :one SELECT * FROM files -WHERE path = ? AND session_id = ? LIMIT 1; +WHERE path = ? AND session_id = ? +ORDER BY created_at DESC +LIMIT 1; -- name: ListFilesBySession :many SELECT * diff --git a/internal/diff/diff.go b/internal/diff/diff.go index 02d4d7140..829554c7e 100644 --- a/internal/diff/diff.go +++ b/internal/diff/diff.go @@ -19,6 +19,8 @@ import ( "github.com/charmbracelet/x/ansi" "github.com/go-git/go-git/v5" "github.com/go-git/go-git/v5/plumbing/object" + "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/termai/internal/logging" "github.com/sergi/go-diff/diffmatchpatch" ) @@ -77,6 +79,8 @@ type linePair struct { // StyleConfig defines styling for diff rendering type StyleConfig struct { + ShowHeader bool + FileNameFg lipgloss.Color // Background colors RemovedLineBg lipgloss.Color AddedLineBg lipgloss.Color @@ -106,11 +110,13 @@ type StyleOption func(*StyleConfig) func NewStyleConfig(opts ...StyleOption) StyleConfig { // Default color scheme config := StyleConfig{ + ShowHeader: true, + FileNameFg: lipgloss.Color("#fab283"), RemovedLineBg: lipgloss.Color("#3A3030"), AddedLineBg: lipgloss.Color("#303A30"), ContextLineBg: lipgloss.Color("#212121"), - HunkLineBg: lipgloss.Color("#23252D"), - HunkLineFg: lipgloss.Color("#8CA3B4"), + HunkLineBg: lipgloss.Color("#212121"), + HunkLineFg: lipgloss.Color("#a0a0a0"), RemovedFg: lipgloss.Color("#7C4444"), AddedFg: lipgloss.Color("#478247"), LineNumberFg: lipgloss.Color("#888888"), @@ -132,6 +138,10 @@ func NewStyleConfig(opts ...StyleOption) StyleConfig { } // Style option functions +func WithFileNameFg(color lipgloss.Color) StyleOption { + return func(s *StyleConfig) { s.FileNameFg = color } +} + func WithRemovedLineBg(color lipgloss.Color) StyleOption { return func(s *StyleConfig) { s.RemovedLineBg = color } } @@ -190,6 +200,10 @@ func WithHunkLineFg(color lipgloss.Color) StyleOption { return func(s *StyleConfig) { s.HunkLineFg = color } } +func WithShowHeader(show bool) StyleOption { + return func(s *StyleConfig) { s.ShowHeader = show } +} + // ------------------------------------------------------------------------- // Parse Configuration // ------------------------------------------------------------------------- @@ -841,10 +855,12 @@ func RenderSideBySideHunk(fileName string, h Hunk, opts ...SideBySideOption) str // Calculate column width colWidth := config.TotalWidth / 2 + leftWidth := colWidth + rightWidth := config.TotalWidth - colWidth var sb strings.Builder for _, p := range pairs { - leftStr := renderLeftColumn(fileName, p.left, colWidth, config.Style) - rightStr := renderRightColumn(fileName, p.right, colWidth, config.Style) + leftStr := renderLeftColumn(fileName, p.left, leftWidth, config.Style) + rightStr := renderRightColumn(fileName, p.right, rightWidth, config.Style) sb.WriteString(leftStr + rightStr + "\n") } @@ -861,17 +877,50 @@ func FormatDiff(diffText string, opts ...SideBySideOption) (string, error) { var sb strings.Builder config := NewSideBySideConfig(opts...) - for i, h := range diffResult.Hunks { - if i > 0 { - // Render hunk header - sb.WriteString( - lipgloss.NewStyle(). - Background(config.Style.HunkLineBg). - Foreground(config.Style.HunkLineFg). - Width(config.TotalWidth). - Render(h.Header) + "\n", - ) - } + if config.Style.ShowHeader { + removeIcon := lipgloss.NewStyle(). + Background(config.Style.RemovedLineBg). + Foreground(config.Style.RemovedFg). + Render("⏹") + addIcon := lipgloss.NewStyle(). + Background(config.Style.AddedLineBg). + Foreground(config.Style.AddedFg). + Render("⏹") + + fileName := lipgloss.NewStyle(). + Background(config.Style.ContextLineBg). + Foreground(config.Style.FileNameFg). + Render(" " + diffResult.OldFile) + sb.WriteString( + lipgloss.NewStyle(). + Background(config.Style.ContextLineBg). + Padding(0, 1, 0, 1). + Foreground(config.Style.FileNameFg). + BorderStyle(lipgloss.NormalBorder()). + BorderTop(true). + BorderBottom(true). + BorderForeground(config.Style.FileNameFg). + BorderBackground(config.Style.ContextLineBg). + Width(config.TotalWidth). + Render( + lipgloss.JoinHorizontal(lipgloss.Top, + removeIcon, + addIcon, + fileName, + ), + ) + "\n", + ) + } + + for _, h := range diffResult.Hunks { + // Render hunk header + sb.WriteString( + lipgloss.NewStyle(). + Background(config.Style.HunkLineBg). + Foreground(config.Style.HunkLineFg). + Width(config.TotalWidth). + Render(h.Header) + "\n", + ) sb.WriteString(RenderSideBySideHunk(diffResult.OldFile, h, opts...)) } @@ -880,9 +929,15 @@ func FormatDiff(diffText string, opts ...SideBySideOption) (string, error) { // GenerateDiff creates a unified diff from two file contents func GenerateDiff(beforeContent, afterContent, fileName string) (string, int, int) { + // remove the cwd prefix and ensure consistent path format + // this prevents issues with absolute paths in different environments + cwd := config.WorkingDirectory() + fileName = strings.TrimPrefix(fileName, cwd) + fileName = strings.TrimPrefix(fileName, "/") // Create temporary directory for git operations - tempDir, err := os.MkdirTemp("", "git-diff-temp") + tempDir, err := os.MkdirTemp("", fmt.Sprintf("git-diff-%d", time.Now().UnixNano())) if err != nil { + logging.Error("Failed to create temp directory for git diff", "error", err) return "", 0, 0 } defer os.RemoveAll(tempDir) @@ -890,25 +945,30 @@ func GenerateDiff(beforeContent, afterContent, fileName string) (string, int, in // Initialize git repo repo, err := git.PlainInit(tempDir, false) if err != nil { + logging.Error("Failed to initialize git repository", "error", err) return "", 0, 0 } wt, err := repo.Worktree() if err != nil { + logging.Error("Failed to get git worktree", "error", err) return "", 0, 0 } // Write the "before" content and commit it fullPath := filepath.Join(tempDir, fileName) if err = os.MkdirAll(filepath.Dir(fullPath), 0o755); err != nil { + logging.Error("Failed to create directory for file", "error", err) return "", 0, 0 } if err = os.WriteFile(fullPath, []byte(beforeContent), 0o644); err != nil { + logging.Error("Failed to write before content to file", "error", err) return "", 0, 0 } _, err = wt.Add(fileName) if err != nil { + logging.Error("Failed to add file to git", "error", err) return "", 0, 0 } @@ -920,16 +980,19 @@ func GenerateDiff(beforeContent, afterContent, fileName string) (string, int, in }, }) if err != nil { + logging.Error("Failed to commit before content", "error", err) return "", 0, 0 } // Write the "after" content and commit it if err = os.WriteFile(fullPath, []byte(afterContent), 0o644); err != nil { + logging.Error("Failed to write after content to file", "error", err) return "", 0, 0 } _, err = wt.Add(fileName) if err != nil { + logging.Error("Failed to add file to git", "error", err) return "", 0, 0 } @@ -941,22 +1004,26 @@ func GenerateDiff(beforeContent, afterContent, fileName string) (string, int, in }, }) if err != nil { + logging.Error("Failed to commit after content", "error", err) return "", 0, 0 } // Get the diff between the two commits beforeCommitObj, err := repo.CommitObject(beforeCommit) if err != nil { + logging.Error("Failed to get before commit object", "error", err) return "", 0, 0 } afterCommitObj, err := repo.CommitObject(afterCommit) if err != nil { + logging.Error("Failed to get after commit object", "error", err) return "", 0, 0 } patch, err := beforeCommitObj.Patch(afterCommitObj) if err != nil { + logging.Error("Failed to create git diff patch", "error", err) return "", 0, 0 } diff --git a/internal/llm/agent/agent-tool.go b/internal/llm/agent/agent-tool.go index 83160bb64..308412bde 100644 --- a/internal/llm/agent/agent-tool.go +++ b/internal/llm/agent/agent-tool.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" + "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/llm/tools" "github.com/kujtimiihoxha/termai/internal/lsp" "github.com/kujtimiihoxha/termai/internal/message" @@ -53,7 +54,7 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes return tools.ToolResponse{}, fmt.Errorf("session_id and message_id are required") } - agent, err := NewTaskAgent(b.messages, b.sessions, b.lspClients) + agent, err := NewAgent(config.AgentTask, b.sessions, b.messages, TaskAgentTools(b.lspClients)) if err != nil { return tools.ToolResponse{}, fmt.Errorf("error creating agent: %s", err) } @@ -63,21 +64,16 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes return tools.ToolResponse{}, fmt.Errorf("error creating session: %s", err) } - err = agent.Generate(ctx, session.ID, params.Prompt) + done, err := agent.Run(ctx, session.ID, params.Prompt) if err != nil { return tools.ToolResponse{}, fmt.Errorf("error generating agent: %s", err) } - - messages, err := b.messages.List(ctx, session.ID) - if err != nil { - return tools.ToolResponse{}, fmt.Errorf("error listing messages: %s", err) - } - - if len(messages) == 0 { - return tools.NewTextErrorResponse("no response"), nil + result := <-done + if result.Err() != nil { + return tools.ToolResponse{}, fmt.Errorf("error generating agent: %s", result.Err()) } - response := messages[len(messages)-1] + response := result.Response() if response.Role != message.Assistant { return tools.NewTextErrorResponse("no response"), nil } diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 1958111a1..ab2742ec1 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -4,8 +4,6 @@ import ( "context" "errors" "fmt" - "os" - "runtime/debug" "strings" "sync" @@ -16,133 +14,101 @@ import ( "github.com/kujtimiihoxha/termai/internal/llm/tools" "github.com/kujtimiihoxha/termai/internal/logging" "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/termai/internal/permission" "github.com/kujtimiihoxha/termai/internal/session" ) // Common errors var ( - ErrProviderNotEnabled = errors.New("provider is not enabled") - ErrRequestCancelled = errors.New("request cancelled by user") - ErrSessionBusy = errors.New("session is currently processing another request") + ErrRequestCancelled = errors.New("request cancelled by user") + ErrSessionBusy = errors.New("session is currently processing another request") ) -// Service defines the interface for generating responses +type AgentEvent struct { + message message.Message + err error +} + +func (e *AgentEvent) Err() error { + return e.err +} + +func (e *AgentEvent) Response() message.Message { + return e.message +} + type Service interface { - Generate(ctx context.Context, sessionID string, content string) error - Cancel(sessionID string) error + Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error) + Cancel(sessionID string) + IsSessionBusy(sessionID string) bool } type agent struct { - sessions session.Service - messages message.Service - model models.Model - tools []tools.BaseTool - agent provider.Provider - titleGenerator provider.Provider - activeRequests sync.Map // map[sessionID]context.CancelFunc + sessions session.Service + messages message.Service + + tools []tools.BaseTool + provider provider.Provider + + titleProvider provider.Provider + + activeRequests sync.Map } -// NewAgent creates a new agent instance with the given model and tools -func NewAgent(ctx context.Context, sessions session.Service, messages message.Service, model models.Model, tools []tools.BaseTool) (Service, error) { - agentProvider, titleGenerator, err := getAgentProviders(ctx, model) +func NewAgent( + agentName config.AgentName, + sessions session.Service, + messages message.Service, + agentTools []tools.BaseTool, +) (Service, error) { + agentProvider, err := createAgentProvider(agentName) if err != nil { - return nil, fmt.Errorf("failed to initialize providers: %w", err) + return nil, err + } + var titleProvider provider.Provider + // Only generate titles for the coder agent + if agentName == config.AgentCoder { + titleProvider, err = createAgentProvider(config.AgentTitle) + if err != nil { + return nil, err + } } - return &agent{ - model: model, - tools: tools, - sessions: sessions, + agent := &agent{ + provider: agentProvider, messages: messages, - agent: agentProvider, - titleGenerator: titleGenerator, + sessions: sessions, + tools: agentTools, + titleProvider: titleProvider, activeRequests: sync.Map{}, - }, nil + } + + return agent, nil } -// Cancel cancels an active request by session ID -func (a *agent) Cancel(sessionID string) error { +func (a *agent) Cancel(sessionID string) { if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists { if cancel, ok := cancelFunc.(context.CancelFunc); ok { logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID)) cancel() - return nil } } - return errors.New("no active request found for this session") } -// Generate starts the generation process -func (a *agent) Generate(ctx context.Context, sessionID string, content string) error { - // Check if this session already has an active request - if _, busy := a.activeRequests.Load(sessionID); busy { - return ErrSessionBusy - } - - // Create a cancellable context - genCtx, cancel := context.WithCancel(ctx) - - // Store cancel function to allow user cancellation - a.activeRequests.Store(sessionID, cancel) - - // Launch the generation in a goroutine - go func() { - defer func() { - if r := recover(); r != nil { - logging.ErrorPersist(fmt.Sprintf("Panic in Generate: %v", r)) - - // dump stack trace into a file - file, err := os.Create("panic.log") - if err != nil { - logging.ErrorPersist(fmt.Sprintf("Failed to create panic log: %v", err)) - return - } - - defer file.Close() - - stackTrace := debug.Stack() - if _, err := file.Write(stackTrace); err != nil { - logging.ErrorPersist(fmt.Sprintf("Failed to write panic log: %v", err)) - } - - } - }() - defer a.activeRequests.Delete(sessionID) - defer cancel() - - if err := a.generate(genCtx, sessionID, content); err != nil { - if !errors.Is(err, ErrRequestCancelled) && !errors.Is(err, context.Canceled) { - // Log the error (avoid logging cancellations as they're expected) - logging.ErrorPersist(fmt.Sprintf("Generation error for session %s: %v", sessionID, err)) - - // You may want to create an error message in the chat - bgCtx := context.Background() - errorMsg := fmt.Sprintf("Sorry, an error occurred: %v", err) - _, createErr := a.messages.Create(bgCtx, sessionID, message.CreateMessageParams{ - Role: message.System, - Parts: []message.ContentPart{ - message.TextContent{ - Text: errorMsg, - }, - }, - }) - if createErr != nil { - logging.ErrorPersist(fmt.Sprintf("Failed to create error message: %v", createErr)) - } - } - } - }() - - return nil -} - -// IsSessionBusy checks if a session currently has an active request func (a *agent) IsSessionBusy(sessionID string) bool { _, busy := a.activeRequests.Load(sessionID) return busy -} // handleTitleGeneration asynchronously generates a title for new sessions -func (a *agent) handleTitleGeneration(ctx context.Context, sessionID, content string) { - response, err := a.titleGenerator.SendMessages( +} + +func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error { + if a.titleProvider == nil { + return nil + } + session, err := a.sessions.Get(ctx, sessionID) + if err != nil { + return err + } + response, err := a.titleProvider.SendMessages( ctx, []message.Message{ { @@ -154,121 +120,152 @@ func (a *agent) handleTitleGeneration(ctx context.Context, sessionID, content st }, }, }, - nil, + make([]tools.BaseTool, 0), ) if err != nil { - logging.ErrorPersist(fmt.Sprintf("Failed to generate title: %v", err)) - return + return err } - session, err := a.sessions.Get(ctx, sessionID) - if err != nil { - logging.ErrorPersist(fmt.Sprintf("Failed to get session: %v", err)) - return + title := strings.TrimSpace(strings.ReplaceAll(response.Content, "\n", " ")) + if title == "" { + return nil } - if response.Content != "" { - session.Title = strings.TrimSpace(response.Content) - session.Title = strings.ReplaceAll(session.Title, "\n", " ") - if _, err := a.sessions.Save(ctx, session); err != nil { - logging.ErrorPersist(fmt.Sprintf("Failed to save session title: %v", err)) - } + session.Title = title + _, err = a.sessions.Save(ctx, session) + return err +} + +func (a *agent) err(err error) AgentEvent { + return AgentEvent{ + err: err, } } -// TrackUsage updates token usage statistics for the session -func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error { - session, err := a.sessions.Get(ctx, sessionID) - if err != nil { - return fmt.Errorf("failed to get session: %w", err) +func (a *agent) Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error) { + events := make(chan AgentEvent) + if a.IsSessionBusy(sessionID) { + return nil, ErrSessionBusy } - cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) + - model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) + - model.CostPer1MIn/1e6*float64(usage.InputTokens) + - model.CostPer1MOut/1e6*float64(usage.OutputTokens) + genCtx, cancel := context.WithCancel(ctx) + + a.activeRequests.Store(sessionID, cancel) + go func() { + logging.Debug("Request started", "sessionID", sessionID) + defer logging.RecoverPanic("agent.Run", func() { + events <- a.err(fmt.Errorf("panic while running the agent")) + }) - session.Cost += cost - session.CompletionTokens += usage.OutputTokens - session.PromptTokens += usage.InputTokens + result := a.processGeneration(genCtx, sessionID, content) + if result.Err() != nil && !errors.Is(result.Err(), ErrRequestCancelled) && !errors.Is(result.Err(), context.Canceled) { + logging.ErrorPersist(fmt.Sprintf("Generation error for session %s: %v", sessionID, result)) + } + logging.Debug("Request completed", "sessionID", sessionID) + a.activeRequests.Delete(sessionID) + cancel() + events <- result + close(events) + }() + return events, nil +} - _, err = a.sessions.Save(ctx, session) +func (a *agent) processGeneration(ctx context.Context, sessionID, content string) AgentEvent { + // List existing messages; if none, start title generation asynchronously. + msgs, err := a.messages.List(ctx, sessionID) if err != nil { - return fmt.Errorf("failed to save session: %w", err) + return a.err(fmt.Errorf("failed to list messages: %w", err)) + } + if len(msgs) == 0 { + go func() { + defer logging.RecoverPanic("agent.Run", func() { + logging.ErrorPersist("panic while generating title") + }) + titleErr := a.generateTitle(context.Background(), sessionID, content) + if titleErr != nil { + logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr)) + } + }() } - return nil -} -// processEvent handles different types of events during generation -func (a *agent) processEvent( - ctx context.Context, - sessionID string, - assistantMsg *message.Message, - event provider.ProviderEvent, -) error { - select { - case <-ctx.Done(): - return ctx.Err() - default: - // Continue processing + userMsg, err := a.createUserMessage(ctx, sessionID, content) + if err != nil { + return a.err(fmt.Errorf("failed to create user message: %w", err)) } - switch event.Type { - case provider.EventThinkingDelta: - assistantMsg.AppendReasoningContent(event.Content) - return a.messages.Update(ctx, *assistantMsg) - case provider.EventContentDelta: - assistantMsg.AppendContent(event.Content) - return a.messages.Update(ctx, *assistantMsg) - case provider.EventError: - if errors.Is(event.Error, context.Canceled) { - logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID)) - return context.Canceled + // Append the new user message to the conversation history. + msgHistory := append(msgs, userMsg) + for { + // Check for cancellation before each iteration + select { + case <-ctx.Done(): + return a.err(ctx.Err()) + default: + // Continue processing } - logging.ErrorPersist(event.Error.Error()) - return event.Error - case provider.EventWarning: - logging.WarnPersist(event.Info) - case provider.EventInfo: - logging.InfoPersist(event.Info) - case provider.EventComplete: - assistantMsg.SetToolCalls(event.Response.ToolCalls) - assistantMsg.AddFinish(event.Response.FinishReason) - if err := a.messages.Update(ctx, *assistantMsg); err != nil { - return fmt.Errorf("failed to update message: %w", err) + agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory) + if err != nil { + if errors.Is(err, context.Canceled) { + return a.err(ErrRequestCancelled) + } + return a.err(fmt.Errorf("failed to process events: %w", err)) + } + logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults) + if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil { + // We are not done, we need to respond with the tool response + msgHistory = append(msgHistory, agentMessage, *toolResults) + continue + } + return AgentEvent{ + message: agentMessage, } - return a.TrackUsage(ctx, sessionID, a.model, event.Response.Usage) } +} - return nil +func (a *agent) createUserMessage(ctx context.Context, sessionID, content string) (message.Message, error) { + return a.messages.Create(ctx, sessionID, message.CreateMessageParams{ + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: content}, + }, + }) } -// ExecuteTools runs all tool calls sequentially and returns the results -func (a *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) { - toolResults := make([]message.ToolResult, len(toolCalls)) +func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) { + eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools) + + assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{ + Role: message.Assistant, + Parts: []message.ContentPart{}, + Model: a.provider.Model().ID, + }) + if err != nil { + return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err) + } - // Create a child context that can be canceled - ctx, cancel := context.WithCancel(ctx) - defer cancel() + // Add the session and message ID into the context if needed by tools. + ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID) + ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID) - // Check if already canceled before starting any execution - if ctx.Err() != nil { - // Mark all tools as canceled - for i, toolCall := range toolCalls { - toolResults[i] = message.ToolResult{ - ToolCallID: toolCall.ID, - Content: "Tool execution canceled by user", - IsError: true, - } + // Process each event in the stream. + for event := range eventChan { + if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil { + a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled) + return assistantMsg, nil, processErr + } + if ctx.Err() != nil { + a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled) + return assistantMsg, nil, ctx.Err() } - return toolResults, ctx.Err() } + toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls())) + toolCalls := assistantMsg.ToolCalls() for i, toolCall := range toolCalls { - // Check for cancellation before executing each tool select { case <-ctx.Done(): - // Mark this and all remaining tools as canceled + a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled) + // Make all future tool calls cancelled for j := i; j < len(toolCalls); j++ { toolResults[j] = message.ToolResult{ ToolCallID: toolCalls[j].ID, @@ -276,412 +273,180 @@ func (a *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, IsError: true, } } - return toolResults, ctx.Err() + goto out default: // Continue processing - } - - response := "" - isError := false - found := false - - // Find and execute the appropriate tool - for _, tool := range tls { - if tool.Info().Name == toolCall.Name { - found = true - toolResult, toolErr := tool.Run(ctx, tools.ToolCall{ - ID: toolCall.ID, - Name: toolCall.Name, - Input: toolCall.Input, - }) - - if toolErr != nil { - if errors.Is(toolErr, context.Canceled) { - response = "Tool execution canceled by user" - } else { - response = fmt.Sprintf("Error running tool: %s", toolErr) - } - isError = true - } else { - response = toolResult.Content - isError = toolResult.IsError + var tool tools.BaseTool + for _, availableTools := range a.tools { + if availableTools.Info().Name == toolCall.Name { + tool = availableTools } - break } - } - - if !found { - response = fmt.Sprintf("Tool not found: %s", toolCall.Name) - isError = true - } - - toolResults[i] = message.ToolResult{ - ToolCallID: toolCall.ID, - Content: response, - IsError: isError, - } - } - return toolResults, nil -} - -// handleToolExecution processes tool calls and creates tool result messages -func (a *agent) handleToolExecution( - ctx context.Context, - assistantMsg message.Message, -) (*message.Message, error) { - select { - case <-ctx.Done(): - // If cancelled, create tool results that indicate cancellation - if len(assistantMsg.ToolCalls()) > 0 { - toolResults := make([]message.ToolResult, 0, len(assistantMsg.ToolCalls())) - for _, tc := range assistantMsg.ToolCalls() { - toolResults = append(toolResults, message.ToolResult{ - ToolCallID: tc.ID, - Content: "Tool execution canceled by user", + // Tool not found + if tool == nil { + toolResults[i] = message.ToolResult{ + ToolCallID: toolCall.ID, + Content: fmt.Sprintf("Tool not found: %s", toolCall.Name), IsError: true, - }) + } + continue } - // Use background context to ensure the message is created even if original context is cancelled - bgCtx := context.Background() - parts := make([]message.ContentPart, 0) - for _, toolResult := range toolResults { - parts = append(parts, toolResult) - } - msg, err := a.messages.Create(bgCtx, assistantMsg.SessionID, message.CreateMessageParams{ - Role: message.Tool, - Parts: parts, + toolResult, toolErr := tool.Run(ctx, tools.ToolCall{ + ID: toolCall.ID, + Name: toolCall.Name, + Input: toolCall.Input, }) - if err != nil { - return nil, fmt.Errorf("failed to create cancelled tool message: %w", err) - } - return &msg, ctx.Err() - } - return nil, ctx.Err() - default: - // Continue processing - } - - if len(assistantMsg.ToolCalls()) == 0 { - return nil, nil - } - - toolResults, err := a.ExecuteTools(ctx, assistantMsg.ToolCalls(), a.tools) - if err != nil { - // If error is from cancellation, still return the partial results we have - if errors.Is(err, context.Canceled) { - // Use background context to ensure the message is created even if original context is cancelled - bgCtx := context.Background() - parts := make([]message.ContentPart, 0) - for _, toolResult := range toolResults { - parts = append(parts, toolResult) + if toolErr != nil { + if errors.Is(toolErr, permission.ErrorPermissionDenied) { + toolResults[i] = message.ToolResult{ + ToolCallID: toolCall.ID, + Content: "Permission denied", + IsError: true, + } + for j := i + 1; j < len(toolCalls); j++ { + toolResults[j] = message.ToolResult{ + ToolCallID: toolCalls[j].ID, + Content: "Tool execution canceled by user", + IsError: true, + } + } + a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied) + } else { + toolResults[i] = message.ToolResult{ + ToolCallID: toolCall.ID, + Content: toolErr.Error(), + IsError: true, + } + for j := i; j < len(toolCalls); j++ { + toolResults[j] = message.ToolResult{ + ToolCallID: toolCalls[j].ID, + Content: "Previous tool failed", + IsError: true, + } + } + a.finishMessage(ctx, &assistantMsg, message.FinishReasonError) + } + // If permission is denied or an error happens we cancel all the following tools + break } - - msg, createErr := a.messages.Create(bgCtx, assistantMsg.SessionID, message.CreateMessageParams{ - Role: message.Tool, - Parts: parts, - }) - if createErr != nil { - logging.ErrorPersist(fmt.Sprintf("Failed to create tool message after cancellation: %v", createErr)) - return nil, err + toolResults[i] = message.ToolResult{ + ToolCallID: toolCall.ID, + Content: toolResult.Content, + Metadata: toolResult.Metadata, + IsError: toolResult.IsError, } - return &msg, err } - return nil, err } - - parts := make([]message.ContentPart, 0, len(toolResults)) - for _, toolResult := range toolResults { - parts = append(parts, toolResult) +out: + if len(toolResults) == 0 { + return assistantMsg, nil, nil } - - msg, err := a.messages.Create(ctx, assistantMsg.SessionID, message.CreateMessageParams{ + parts := make([]message.ContentPart, 0) + for _, tr := range toolResults { + parts = append(parts, tr) + } + msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{ Role: message.Tool, Parts: parts, }) if err != nil { - return nil, fmt.Errorf("failed to create tool message: %w", err) + return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err) } - return &msg, nil + return assistantMsg, &msg, err } -// generate handles the main generation workflow -func (a *agent) generate(ctx context.Context, sessionID string, content string) error { - ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID) +func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) { + msg.AddFinish(finishReson) + _ = a.messages.Update(ctx, *msg) +} - // Handle context cancellation at any point - if err := ctx.Err(); err != nil { - return ErrRequestCancelled +func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + // Continue processing. } - messages, err := a.messages.List(ctx, sessionID) - if err != nil { - return fmt.Errorf("failed to list messages: %w", err) + switch event.Type { + case provider.EventThinkingDelta: + assistantMsg.AppendReasoningContent(event.Content) + return a.messages.Update(ctx, *assistantMsg) + case provider.EventContentDelta: + assistantMsg.AppendContent(event.Content) + return a.messages.Update(ctx, *assistantMsg) + case provider.EventError: + if errors.Is(event.Error, context.Canceled) { + logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID)) + return context.Canceled + } + logging.ErrorPersist(event.Error.Error()) + return event.Error + case provider.EventComplete: + assistantMsg.SetToolCalls(event.Response.ToolCalls) + assistantMsg.AddFinish(event.Response.FinishReason) + if err := a.messages.Update(ctx, *assistantMsg); err != nil { + return fmt.Errorf("failed to update message: %w", err) + } + return a.TrackUsage(ctx, sessionID, a.provider.Model(), event.Response.Usage) } - if len(messages) == 0 { - titleCtx := context.Background() - go a.handleTitleGeneration(titleCtx, sessionID, content) - } + return nil +} - userMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{ - Role: message.User, - Parts: []message.ContentPart{ - message.TextContent{ - Text: content, - }, - }, - }) +func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error { + sess, err := a.sessions.Get(ctx, sessionID) if err != nil { - return fmt.Errorf("failed to create user message: %w", err) + return fmt.Errorf("failed to get session: %w", err) } - messages = append(messages, userMsg) - - for { - // Check for cancellation before each iteration - select { - case <-ctx.Done(): - return ErrRequestCancelled - default: - // Continue processing - } - - eventChan, err := a.agent.StreamResponse(ctx, messages, a.tools) - if err != nil { - if errors.Is(err, context.Canceled) { - return ErrRequestCancelled - } - return fmt.Errorf("failed to stream response: %w", err) - } - - assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{ - Role: message.Assistant, - Parts: []message.ContentPart{}, - Model: a.model.ID, - }) - if err != nil { - return fmt.Errorf("failed to create assistant message: %w", err) - } - - ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID) - - // Process events from the LLM provider - for event := range eventChan { - if err := a.processEvent(ctx, sessionID, &assistantMsg, event); err != nil { - if errors.Is(err, context.Canceled) { - // Mark as canceled but don't create separate message - assistantMsg.AddFinish("canceled") - _ = a.messages.Update(context.Background(), assistantMsg) - return ErrRequestCancelled - } - assistantMsg.AddFinish("error:" + err.Error()) - _ = a.messages.Update(ctx, assistantMsg) - return fmt.Errorf("event processing error: %w", err) - } - - // Check for cancellation during event processing - select { - case <-ctx.Done(): - // Mark as canceled - assistantMsg.AddFinish("canceled") - _ = a.messages.Update(context.Background(), assistantMsg) - return ErrRequestCancelled - default: - } - } - - // Check for cancellation before tool execution - select { - case <-ctx.Done(): - assistantMsg.AddFinish("canceled_by_user") - _ = a.messages.Update(context.Background(), assistantMsg) - return ErrRequestCancelled - default: - } - - // Execute any tool calls - toolMsg, err := a.handleToolExecution(ctx, assistantMsg) - if err != nil { - if errors.Is(err, context.Canceled) { - assistantMsg.AddFinish("canceled_by_user") - _ = a.messages.Update(context.Background(), assistantMsg) - return ErrRequestCancelled - } - return fmt.Errorf("tool execution error: %w", err) - } - - if err := a.messages.Update(ctx, assistantMsg); err != nil { - return fmt.Errorf("failed to update assistant message: %w", err) - } - - // If no tool calls, we're done - if len(assistantMsg.ToolCalls()) == 0 { - break - } + cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) + + model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) + + model.CostPer1MIn/1e6*float64(usage.InputTokens) + + model.CostPer1MOut/1e6*float64(usage.OutputTokens) - // Add messages for next iteration - messages = append(messages, assistantMsg) - if toolMsg != nil { - messages = append(messages, *toolMsg) - } + sess.Cost += cost + sess.CompletionTokens += usage.OutputTokens + sess.PromptTokens += usage.InputTokens - // Check for cancellation after tool execution - select { - case <-ctx.Done(): - return ErrRequestCancelled - default: - } + _, err = a.sessions.Save(ctx, sess) + if err != nil { + return fmt.Errorf("failed to save session: %w", err) } - return nil } -// getAgentProviders initializes the LLM providers based on the chosen model -func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) { - maxTokens := config.Get().Model.CoderMaxTokens - - providerConfig, ok := config.Get().Providers[model.Provider] - if !ok || providerConfig.Disabled { - return nil, nil, ErrProviderNotEnabled +func createAgentProvider(agentName config.AgentName) (provider.Provider, error) { + cfg := config.Get() + agentConfig, ok := cfg.Agents[agentName] + if !ok { + return nil, fmt.Errorf("agent %s not found", agentName) + } + model, ok := models.SupportedModels[agentConfig.Model] + if !ok { + return nil, fmt.Errorf("model %s not supported", agentConfig.Model) } - var agentProvider provider.Provider - var titleGenerator provider.Provider - var err error - - switch model.Provider { - case models.ProviderOpenAI: - agentProvider, err = provider.NewOpenAIProvider( - provider.WithOpenAISystemMessage( - prompt.CoderOpenAISystemPrompt(), - ), - provider.WithOpenAIMaxTokens(maxTokens), - provider.WithOpenAIModel(model), - provider.WithOpenAIKey(providerConfig.APIKey), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create OpenAI agent provider: %w", err) - } - - titleGenerator, err = provider.NewOpenAIProvider( - provider.WithOpenAISystemMessage( - prompt.TitlePrompt(), - ), - provider.WithOpenAIMaxTokens(80), - provider.WithOpenAIModel(model), - provider.WithOpenAIKey(providerConfig.APIKey), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create OpenAI title generator: %w", err) - } - - case models.ProviderAnthropic: - agentProvider, err = provider.NewAnthropicProvider( - provider.WithAnthropicSystemMessage( - prompt.CoderAnthropicSystemPrompt(), - ), - provider.WithAnthropicMaxTokens(maxTokens), - provider.WithAnthropicKey(providerConfig.APIKey), - provider.WithAnthropicModel(model), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create Anthropic agent provider: %w", err) - } - - titleGenerator, err = provider.NewAnthropicProvider( - provider.WithAnthropicSystemMessage( - prompt.TitlePrompt(), - ), - provider.WithAnthropicMaxTokens(80), - provider.WithAnthropicKey(providerConfig.APIKey), - provider.WithAnthropicModel(model), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create Anthropic title generator: %w", err) - } - - case models.ProviderGemini: - agentProvider, err = provider.NewGeminiProvider( - ctx, - provider.WithGeminiSystemMessage( - prompt.CoderOpenAISystemPrompt(), - ), - provider.WithGeminiMaxTokens(int32(maxTokens)), - provider.WithGeminiKey(providerConfig.APIKey), - provider.WithGeminiModel(model), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create Gemini agent provider: %w", err) - } - - titleGenerator, err = provider.NewGeminiProvider( - ctx, - provider.WithGeminiSystemMessage( - prompt.TitlePrompt(), - ), - provider.WithGeminiMaxTokens(80), - provider.WithGeminiKey(providerConfig.APIKey), - provider.WithGeminiModel(model), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create Gemini title generator: %w", err) - } - - case models.ProviderGROQ: - agentProvider, err = provider.NewOpenAIProvider( - provider.WithOpenAISystemMessage( - prompt.CoderAnthropicSystemPrompt(), - ), - provider.WithOpenAIMaxTokens(maxTokens), - provider.WithOpenAIModel(model), - provider.WithOpenAIKey(providerConfig.APIKey), - provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create GROQ agent provider: %w", err) - } - - titleGenerator, err = provider.NewOpenAIProvider( - provider.WithOpenAISystemMessage( - prompt.TitlePrompt(), - ), - provider.WithOpenAIMaxTokens(80), - provider.WithOpenAIModel(model), - provider.WithOpenAIKey(providerConfig.APIKey), - provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create GROQ title generator: %w", err) - } - - case models.ProviderBedrock: - agentProvider, err = provider.NewBedrockProvider( - provider.WithBedrockSystemMessage( - prompt.CoderAnthropicSystemPrompt(), - ), - provider.WithBedrockMaxTokens(maxTokens), - provider.WithBedrockModel(model), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create Bedrock agent provider: %w", err) - } - - titleGenerator, err = provider.NewBedrockProvider( - provider.WithBedrockSystemMessage( - prompt.TitlePrompt(), - ), - provider.WithBedrockMaxTokens(80), - provider.WithBedrockModel(model), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create Bedrock title generator: %w", err) - } - default: - return nil, nil, fmt.Errorf("unsupported provider: %s", model.Provider) + providerCfg, ok := cfg.Providers[model.Provider] + if !ok { + return nil, fmt.Errorf("provider %s not supported", model.Provider) + } + if providerCfg.Disabled { + return nil, fmt.Errorf("provider %s is not enabled", model.Provider) + } + agentProvider, err := provider.NewProvider( + model.Provider, + provider.WithAPIKey(providerCfg.APIKey), + provider.WithModel(model), + provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)), + provider.WithMaxTokens(agentConfig.MaxTokens), + ) + if err != nil { + return nil, fmt.Errorf("could not create provider: %v", err) } - return agentProvider, titleGenerator, nil + return agentProvider, nil } diff --git a/internal/llm/agent/coder.go b/internal/llm/agent/coder.go deleted file mode 100644 index a3db6b55c..000000000 --- a/internal/llm/agent/coder.go +++ /dev/null @@ -1,63 +0,0 @@ -package agent - -import ( - "context" - "errors" - - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/models" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/message" - "github.com/kujtimiihoxha/termai/internal/permission" - "github.com/kujtimiihoxha/termai/internal/session" -) - -type coderAgent struct { - Service -} - -func NewCoderAgent( - permissions permission.Service, - sessions session.Service, - messages message.Service, - lspClients map[string]*lsp.Client, -) (Service, error) { - model, ok := models.SupportedModels[config.Get().Model.Coder] - if !ok { - return nil, errors.New("model not supported") - } - - ctx := context.Background() - otherTools := GetMcpTools(ctx, permissions) - if len(lspClients) > 0 { - otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients)) - } - agent, err := NewAgent( - ctx, - sessions, - messages, - model, - append( - []tools.BaseTool{ - tools.NewBashTool(permissions), - tools.NewEditTool(lspClients, permissions), - tools.NewFetchTool(permissions), - tools.NewGlobTool(), - tools.NewGrepTool(), - tools.NewLsTool(), - tools.NewSourcegraphTool(), - tools.NewViewTool(lspClients), - tools.NewWriteTool(lspClients, permissions), - NewAgentTool(sessions, messages, lspClients), - }, otherTools..., - ), - ) - if err != nil { - return nil, err - } - - return &coderAgent{ - agent, - }, nil -} diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go index b1c97b512..c7ea4916c 100644 --- a/internal/llm/agent/mcp-tools.go +++ b/internal/llm/agent/mcp-tools.go @@ -46,7 +46,7 @@ func runTool(ctx context.Context, c MCPClient, toolName string, input string) (t initRequest := mcp.InitializeRequest{} initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION initRequest.Params.ClientInfo = mcp.Implementation{ - Name: "termai", + Name: "OpenCode", Version: version.Version, } @@ -135,7 +135,7 @@ func getTools(ctx context.Context, name string, m config.MCPServer, permissions initRequest := mcp.InitializeRequest{} initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION initRequest.Params.ClientInfo = mcp.Implementation{ - Name: "termai", + Name: "OpenCode", Version: version.Version, } diff --git a/internal/llm/agent/task.go b/internal/llm/agent/task.go deleted file mode 100644 index fca1f223f..000000000 --- a/internal/llm/agent/task.go +++ /dev/null @@ -1,47 +0,0 @@ -package agent - -import ( - "context" - "errors" - - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/models" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/message" - "github.com/kujtimiihoxha/termai/internal/session" -) - -type taskAgent struct { - Service -} - -func NewTaskAgent(messages message.Service, sessions session.Service, lspClients map[string]*lsp.Client) (Service, error) { - model, ok := models.SupportedModels[config.Get().Model.Coder] - if !ok { - return nil, errors.New("model not supported") - } - - ctx := context.Background() - - agent, err := NewAgent( - ctx, - sessions, - messages, - model, - []tools.BaseTool{ - tools.NewGlobTool(), - tools.NewGrepTool(), - tools.NewLsTool(), - tools.NewSourcegraphTool(), - tools.NewViewTool(lspClients), - }, - ) - if err != nil { - return nil, err - } - - return &taskAgent{ - agent, - }, nil -} diff --git a/internal/llm/agent/tools.go b/internal/llm/agent/tools.go new file mode 100644 index 000000000..a37f1d65d --- /dev/null +++ b/internal/llm/agent/tools.go @@ -0,0 +1,50 @@ +package agent + +import ( + "context" + + "github.com/kujtimiihoxha/termai/internal/history" + "github.com/kujtimiihoxha/termai/internal/llm/tools" + "github.com/kujtimiihoxha/termai/internal/lsp" + "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/termai/internal/permission" + "github.com/kujtimiihoxha/termai/internal/session" +) + +func CoderAgentTools( + permissions permission.Service, + sessions session.Service, + messages message.Service, + history history.Service, + lspClients map[string]*lsp.Client, +) []tools.BaseTool { + ctx := context.Background() + otherTools := GetMcpTools(ctx, permissions) + if len(lspClients) > 0 { + otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients)) + } + return append( + []tools.BaseTool{ + tools.NewBashTool(permissions), + tools.NewEditTool(lspClients, permissions, history), + tools.NewFetchTool(permissions), + tools.NewGlobTool(), + tools.NewGrepTool(), + tools.NewLsTool(), + tools.NewSourcegraphTool(), + tools.NewViewTool(lspClients), + tools.NewWriteTool(lspClients, permissions, history), + NewAgentTool(sessions, messages, lspClients), + }, otherTools..., + ) +} + +func TaskAgentTools(lspClients map[string]*lsp.Client) []tools.BaseTool { + return []tools.BaseTool{ + tools.NewGlobTool(), + tools.NewGrepTool(), + tools.NewLsTool(), + tools.NewSourcegraphTool(), + tools.NewViewTool(lspClients), + } +} diff --git a/internal/llm/models/anthropic.go b/internal/llm/models/anthropic.go new file mode 100644 index 000000000..48307e6d3 --- /dev/null +++ b/internal/llm/models/anthropic.go @@ -0,0 +1,71 @@ +package models + +const ( + ProviderAnthropic ModelProvider = "anthropic" + + // Models + Claude35Sonnet ModelID = "claude-3.5-sonnet" + Claude3Haiku ModelID = "claude-3-haiku" + Claude37Sonnet ModelID = "claude-3.7-sonnet" + Claude35Haiku ModelID = "claude-3.5-haiku" + Claude3Opus ModelID = "claude-3-opus" +) + +var AnthropicModels = map[ModelID]Model{ + // Anthropic + Claude35Sonnet: { + ID: Claude35Sonnet, + Name: "Claude 3.5 Sonnet", + Provider: ProviderAnthropic, + APIModel: "claude-3-5-sonnet-latest", + CostPer1MIn: 3.0, + CostPer1MInCached: 3.75, + CostPer1MOutCached: 0.30, + CostPer1MOut: 15.0, + ContextWindow: 200000, + }, + Claude3Haiku: { + ID: Claude3Haiku, + Name: "Claude 3 Haiku", + Provider: ProviderAnthropic, + APIModel: "claude-3-haiku-latest", + CostPer1MIn: 0.25, + CostPer1MInCached: 0.30, + CostPer1MOutCached: 0.03, + CostPer1MOut: 1.25, + ContextWindow: 200000, + }, + Claude37Sonnet: { + ID: Claude37Sonnet, + Name: "Claude 3.7 Sonnet", + Provider: ProviderAnthropic, + APIModel: "claude-3-7-sonnet-latest", + CostPer1MIn: 3.0, + CostPer1MInCached: 3.75, + CostPer1MOutCached: 0.30, + CostPer1MOut: 15.0, + ContextWindow: 200000, + }, + Claude35Haiku: { + ID: Claude35Haiku, + Name: "Claude 3.5 Haiku", + Provider: ProviderAnthropic, + APIModel: "claude-3-5-haiku-latest", + CostPer1MIn: 0.80, + CostPer1MInCached: 1.0, + CostPer1MOutCached: 0.08, + CostPer1MOut: 4.0, + ContextWindow: 200000, + }, + Claude3Opus: { + ID: Claude3Opus, + Name: "Claude 3 Opus", + Provider: ProviderAnthropic, + APIModel: "claude-3-opus-latest", + CostPer1MIn: 15.0, + CostPer1MInCached: 18.75, + CostPer1MOutCached: 1.50, + CostPer1MOut: 75.0, + ContextWindow: 200000, + }, +} diff --git a/internal/llm/models/models.go b/internal/llm/models/models.go index 140693237..4d4589bfd 100644 --- a/internal/llm/models/models.go +++ b/internal/llm/models/models.go @@ -1,5 +1,7 @@ package models +import "maps" + type ( ModelID string ModelProvider string @@ -14,15 +16,13 @@ type Model struct { CostPer1MOut float64 `json:"cost_per_1m_out"` CostPer1MInCached float64 `json:"cost_per_1m_in_cached"` CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"` + ContextWindow int64 `json:"context_window"` } // Model IDs const ( - // Anthropic - Claude35Sonnet ModelID = "claude-3.5-sonnet" - Claude3Haiku ModelID = "claude-3-haiku" - Claude37Sonnet ModelID = "claude-3.7-sonnet" // OpenAI + GPT4o ModelID = "gpt-4o" GPT41 ModelID = "gpt-4.1" // GEMINI @@ -37,47 +37,59 @@ const ( ) const ( - ProviderOpenAI ModelProvider = "openai" - ProviderAnthropic ModelProvider = "anthropic" - ProviderBedrock ModelProvider = "bedrock" - ProviderGemini ModelProvider = "gemini" - ProviderGROQ ModelProvider = "groq" + ProviderOpenAI ModelProvider = "openai" + ProviderBedrock ModelProvider = "bedrock" + ProviderGemini ModelProvider = "gemini" + ProviderGROQ ModelProvider = "groq" + + // ForTests + ProviderMock ModelProvider = "__mock" ) var SupportedModels = map[ModelID]Model{ - // Anthropic - Claude35Sonnet: { - ID: Claude35Sonnet, - Name: "Claude 3.5 Sonnet", - Provider: ProviderAnthropic, - APIModel: "claude-3-5-sonnet-latest", - CostPer1MIn: 3.0, - CostPer1MInCached: 3.75, - CostPer1MOutCached: 0.30, - CostPer1MOut: 15.0, - }, - Claude3Haiku: { - ID: Claude3Haiku, - Name: "Claude 3 Haiku", - Provider: ProviderAnthropic, - APIModel: "claude-3-haiku-latest", - CostPer1MIn: 0.80, - CostPer1MInCached: 1, - CostPer1MOutCached: 0.08, - CostPer1MOut: 4, - }, - Claude37Sonnet: { - ID: Claude37Sonnet, - Name: "Claude 3.7 Sonnet", - Provider: ProviderAnthropic, - APIModel: "claude-3-7-sonnet-latest", - CostPer1MIn: 3.0, - CostPer1MInCached: 3.75, - CostPer1MOutCached: 0.30, - CostPer1MOut: 15.0, + // // Anthropic + // Claude35Sonnet: { + // ID: Claude35Sonnet, + // Name: "Claude 3.5 Sonnet", + // Provider: ProviderAnthropic, + // APIModel: "claude-3-5-sonnet-latest", + // CostPer1MIn: 3.0, + // CostPer1MInCached: 3.75, + // CostPer1MOutCached: 0.30, + // CostPer1MOut: 15.0, + // }, + // Claude3Haiku: { + // ID: Claude3Haiku, + // Name: "Claude 3 Haiku", + // Provider: ProviderAnthropic, + // APIModel: "claude-3-haiku-latest", + // CostPer1MIn: 0.80, + // CostPer1MInCached: 1, + // CostPer1MOutCached: 0.08, + // CostPer1MOut: 4, + // }, + // Claude37Sonnet: { + // ID: Claude37Sonnet, + // Name: "Claude 3.7 Sonnet", + // Provider: ProviderAnthropic, + // APIModel: "claude-3-7-sonnet-latest", + // CostPer1MIn: 3.0, + // CostPer1MInCached: 3.75, + // CostPer1MOutCached: 0.30, + // CostPer1MOut: 15.0, + // }, + // + // // OpenAI + GPT4o: { + ID: GPT4o, + Name: "GPT-4o", + Provider: ProviderOpenAI, + APIModel: "gpt-4.1", + CostPer1MIn: 2.00, + CostPer1MInCached: 0.50, + CostPer1MOutCached: 0, + CostPer1MOut: 8.00, }, - - // OpenAI GPT41: { ID: GPT41, Name: "GPT-4.1", @@ -88,51 +100,55 @@ var SupportedModels = map[ModelID]Model{ CostPer1MOutCached: 0, CostPer1MOut: 8.00, }, + // + // // GEMINI + // GEMINI25: { + // ID: GEMINI25, + // Name: "Gemini 2.5 Pro", + // Provider: ProviderGemini, + // APIModel: "gemini-2.5-pro-exp-03-25", + // CostPer1MIn: 0, + // CostPer1MInCached: 0, + // CostPer1MOutCached: 0, + // CostPer1MOut: 0, + // }, + // + // GRMINI20Flash: { + // ID: GRMINI20Flash, + // Name: "Gemini 2.0 Flash", + // Provider: ProviderGemini, + // APIModel: "gemini-2.0-flash", + // CostPer1MIn: 0.1, + // CostPer1MInCached: 0, + // CostPer1MOutCached: 0.025, + // CostPer1MOut: 0.4, + // }, + // + // // GROQ + // QWENQwq: { + // ID: QWENQwq, + // Name: "Qwen Qwq", + // Provider: ProviderGROQ, + // APIModel: "qwen-qwq-32b", + // CostPer1MIn: 0, + // CostPer1MInCached: 0, + // CostPer1MOutCached: 0, + // CostPer1MOut: 0, + // }, + // + // // Bedrock + // BedrockClaude37Sonnet: { + // ID: BedrockClaude37Sonnet, + // Name: "Bedrock: Claude 3.7 Sonnet", + // Provider: ProviderBedrock, + // APIModel: "anthropic.claude-3-7-sonnet-20250219-v1:0", + // CostPer1MIn: 3.0, + // CostPer1MInCached: 3.75, + // CostPer1MOutCached: 0.30, + // CostPer1MOut: 15.0, + // }, +} - // GEMINI - GEMINI25: { - ID: GEMINI25, - Name: "Gemini 2.5 Pro", - Provider: ProviderGemini, - APIModel: "gemini-2.5-pro-exp-03-25", - CostPer1MIn: 0, - CostPer1MInCached: 0, - CostPer1MOutCached: 0, - CostPer1MOut: 0, - }, - - GRMINI20Flash: { - ID: GRMINI20Flash, - Name: "Gemini 2.0 Flash", - Provider: ProviderGemini, - APIModel: "gemini-2.0-flash", - CostPer1MIn: 0.1, - CostPer1MInCached: 0, - CostPer1MOutCached: 0.025, - CostPer1MOut: 0.4, - }, - - // GROQ - QWENQwq: { - ID: QWENQwq, - Name: "Qwen Qwq", - Provider: ProviderGROQ, - APIModel: "qwen-qwq-32b", - CostPer1MIn: 0, - CostPer1MInCached: 0, - CostPer1MOutCached: 0, - CostPer1MOut: 0, - }, - - // Bedrock - BedrockClaude37Sonnet: { - ID: BedrockClaude37Sonnet, - Name: "Bedrock: Claude 3.7 Sonnet", - Provider: ProviderBedrock, - APIModel: "anthropic.claude-3-7-sonnet-20250219-v1:0", - CostPer1MIn: 3.0, - CostPer1MInCached: 3.75, - CostPer1MOutCached: 0.30, - CostPer1MOut: 15.0, - }, +func init() { + maps.Copy(SupportedModels, AnthropicModels) } diff --git a/internal/llm/prompt/coder.go b/internal/llm/prompt/coder.go index 47941f976..7439fd570 100644 --- a/internal/llm/prompt/coder.go +++ b/internal/llm/prompt/coder.go @@ -9,11 +9,22 @@ import ( "time" "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/termai/internal/llm/models" "github.com/kujtimiihoxha/termai/internal/llm/tools" ) -func CoderOpenAISystemPrompt() string { - basePrompt := `You are termAI, an autonomous CLI-based software engineer. Your job is to reduce user effort by proactively reasoning, inferring context, and solving software engineering tasks end-to-end with minimal prompting. +func CoderPrompt(provider models.ModelProvider) string { + basePrompt := baseAnthropicCoderPrompt + switch provider { + case models.ProviderOpenAI: + basePrompt = baseOpenAICoderPrompt + } + envInfo := getEnvironmentInfo() + + return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation()) +} + +const baseOpenAICoderPrompt = `You are termAI, an autonomous CLI-based software engineer. Your job is to reduce user effort by proactively reasoning, inferring context, and solving software engineering tasks end-to-end with minimal prompting. # Your mindset Act like a competent, efficient software engineer who is familiar with large codebases. You should: @@ -65,13 +76,7 @@ assistant: [searches repo for references, returns file paths and lines] Never commit changes unless the user explicitly asks you to.` - envInfo := getEnvironmentInfo() - - return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation()) -} - -func CoderAnthropicSystemPrompt() string { - basePrompt := `You are termAI, an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user. +const baseAnthropicCoderPrompt = `You are termAI, an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user. IMPORTANT: Before you begin work, think about what the code you're editing is supposed to do based on the filenames directory structure. @@ -166,11 +171,6 @@ NEVER commit changes unless the user explicitly asks you to. It is VERY IMPORTAN You MUST answer concisely with fewer than 4 lines of text (not including tool use or code generation), unless user asks for detail.` - envInfo := getEnvironmentInfo() - - return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation()) -} - func getEnvironmentInfo() string { cwd := config.WorkingDirectory() isGit := isGitRepo(cwd) diff --git a/internal/llm/prompt/prompt.go b/internal/llm/prompt/prompt.go new file mode 100644 index 000000000..63fc2df7b --- /dev/null +++ b/internal/llm/prompt/prompt.go @@ -0,0 +1,19 @@ +package prompt + +import ( + "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/termai/internal/llm/models" +) + +func GetAgentPrompt(agentName config.AgentName, provider models.ModelProvider) string { + switch agentName { + case config.AgentCoder: + return CoderPrompt(provider) + case config.AgentTitle: + return TitlePrompt(provider) + case config.AgentTask: + return TaskPrompt(provider) + default: + return "You are a helpful assistant" + } +} diff --git a/internal/llm/prompt/task.go b/internal/llm/prompt/task.go index ee3c707fa..8bf604ad9 100644 --- a/internal/llm/prompt/task.go +++ b/internal/llm/prompt/task.go @@ -2,11 +2,12 @@ package prompt import ( "fmt" + + "github.com/kujtimiihoxha/termai/internal/llm/models" ) -func TaskAgentSystemPrompt() string { +func TaskPrompt(_ models.ModelProvider) string { agentPrompt := `You are an agent for termAI. Given the user's prompt, you should use the tools available to you to answer the user's question. - Notes: 1. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is .", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...". 2. When relevant, share file names and code snippets relevant to the query diff --git a/internal/llm/prompt/title.go b/internal/llm/prompt/title.go index 5c47f4d64..3023a8550 100644 --- a/internal/llm/prompt/title.go +++ b/internal/llm/prompt/title.go @@ -1,6 +1,8 @@ package prompt -func TitlePrompt() string { +import "github.com/kujtimiihoxha/termai/internal/llm/models" + +func TitlePrompt(_ models.ModelProvider) string { return `you will generate a short title based on the first message a user begins a conversation with - ensure it is not more than 50 characters long - the title should be a summary of the user's message diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go index 93c4308ad..c3a4efc49 100644 --- a/internal/llm/provider/anthropic.go +++ b/internal/llm/provider/anthropic.go @@ -12,187 +12,257 @@ import ( "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/bedrock" "github.com/anthropics/anthropic-sdk-go/option" - "github.com/kujtimiihoxha/termai/internal/llm/models" + "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/llm/tools" + "github.com/kujtimiihoxha/termai/internal/logging" "github.com/kujtimiihoxha/termai/internal/message" ) -type anthropicProvider struct { - client anthropic.Client - model models.Model - maxTokens int64 - apiKey string - systemMessage string - useBedrock bool - disableCache bool +type anthropicOptions struct { + useBedrock bool + disableCache bool + shouldThink func(userMessage string) bool } -type AnthropicOption func(*anthropicProvider) +type AnthropicOption func(*anthropicOptions) -func WithAnthropicSystemMessage(message string) AnthropicOption { - return func(a *anthropicProvider) { - a.systemMessage = message - } +type anthropicClient struct { + providerOptions providerClientOptions + options anthropicOptions + client anthropic.Client } -func WithAnthropicMaxTokens(maxTokens int64) AnthropicOption { - return func(a *anthropicProvider) { - a.maxTokens = maxTokens - } -} +type AnthropicClient ProviderClient -func WithAnthropicModel(model models.Model) AnthropicOption { - return func(a *anthropicProvider) { - a.model = model +func newAnthropicClient(opts providerClientOptions) AnthropicClient { + anthropicOpts := anthropicOptions{} + for _, o := range opts.anthropicOptions { + o(&anthropicOpts) } -} -func WithAnthropicKey(apiKey string) AnthropicOption { - return func(a *anthropicProvider) { - a.apiKey = apiKey + anthropicClientOptions := []option.RequestOption{} + if opts.apiKey != "" { + anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey)) } -} - -func WithAnthropicBedrock() AnthropicOption { - return func(a *anthropicProvider) { - a.useBedrock = true + if anthropicOpts.useBedrock { + anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background())) } -} -func WithAnthropicDisableCache() AnthropicOption { - return func(a *anthropicProvider) { - a.disableCache = true + client := anthropic.NewClient(anthropicClientOptions...) + return &anthropicClient{ + providerOptions: opts, + options: anthropicOpts, + client: client, } } -func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) { - provider := &anthropicProvider{ - maxTokens: 1024, - } +func (a *anthropicClient) convertMessages(messages []message.Message) (anthropicMessages []anthropic.MessageParam) { + cachedBlocks := 0 + for _, msg := range messages { + switch msg.Role { + case message.User: + content := anthropic.NewTextBlock(msg.Content().String()) + if cachedBlocks < 2 && !a.options.disableCache { + content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{ + Type: "ephemeral", + } + cachedBlocks++ + } + anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(content)) - for _, opt := range opts { - opt(provider) - } + case message.Assistant: + blocks := []anthropic.ContentBlockParamUnion{} + if msg.Content().String() != "" { + content := anthropic.NewTextBlock(msg.Content().String()) + if cachedBlocks < 2 && !a.options.disableCache { + content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{ + Type: "ephemeral", + } + cachedBlocks++ + } + blocks = append(blocks, content) + } - if provider.systemMessage == "" { - return nil, errors.New("system message is required") - } + for _, toolCall := range msg.ToolCalls() { + var inputMap map[string]any + err := json.Unmarshal([]byte(toolCall.Input), &inputMap) + if err != nil { + continue + } + blocks = append(blocks, anthropic.ContentBlockParamOfRequestToolUseBlock(toolCall.ID, inputMap, toolCall.Name)) + } - anthropicOptions := []option.RequestOption{} + if len(blocks) == 0 { + logging.Warn("There is a message without content, investigate") + // This should never happend but we log this because we might have a bug in our cleanup method + continue + } + anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...)) - if provider.apiKey != "" { - anthropicOptions = append(anthropicOptions, option.WithAPIKey(provider.apiKey)) - } - if provider.useBedrock { - anthropicOptions = append(anthropicOptions, bedrock.WithLoadDefaultConfig(context.Background())) + case message.Tool: + results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults())) + for i, toolResult := range msg.ToolResults() { + results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError) + } + anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...)) + } } - - provider.client = anthropic.NewClient(anthropicOptions...) - return provider, nil + return } -func (a *anthropicProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { - messages = cleanupMessages(messages) - anthropicMessages := a.convertToAnthropicMessages(messages) - anthropicTools := a.convertToAnthropicTools(tools) - - response, err := a.client.Messages.New( - ctx, - anthropic.MessageNewParams{ - Model: anthropic.Model(a.model.APIModel), - MaxTokens: a.maxTokens, - Temperature: anthropic.Float(0), - Messages: anthropicMessages, - Tools: anthropicTools, - System: []anthropic.TextBlockParam{ - { - Text: a.systemMessage, - CacheControl: anthropic.CacheControlEphemeralParam{ - Type: "ephemeral", - }, - }, +func (a *anthropicClient) convertTools(tools []tools.BaseTool) []anthropic.ToolUnionParam { + anthropicTools := make([]anthropic.ToolUnionParam, len(tools)) + + for i, tool := range tools { + info := tool.Info() + toolParam := anthropic.ToolParam{ + Name: info.Name, + Description: anthropic.String(info.Description), + InputSchema: anthropic.ToolInputSchemaParam{ + Properties: info.Parameters, + // TODO: figure out how we can tell claude the required fields? }, - }, - ) - if err != nil { - return nil, err - } + } - content := "" - for _, block := range response.Content { - if text, ok := block.AsAny().(anthropic.TextBlock); ok { - content += text.Text + if i == len(tools)-1 && !a.options.disableCache { + toolParam.CacheControl = anthropic.CacheControlEphemeralParam{ + Type: "ephemeral", + } } - } - toolCalls := a.extractToolCalls(response.Content) - tokenUsage := a.extractTokenUsage(response.Usage) + anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam} + } - return &ProviderResponse{ - Content: content, - ToolCalls: toolCalls, - Usage: tokenUsage, - }, nil + return anthropicTools } -func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) { - messages = cleanupMessages(messages) - anthropicMessages := a.convertToAnthropicMessages(messages) - anthropicTools := a.convertToAnthropicTools(tools) +func (a *anthropicClient) finishReason(reason string) message.FinishReason { + switch reason { + case "end_turn": + return message.FinishReasonEndTurn + case "max_tokens": + return message.FinishReasonMaxTokens + case "tool_use": + return message.FinishReasonToolUse + case "stop_sequence": + return message.FinishReasonEndTurn + default: + return message.FinishReasonUnknown + } +} +func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, tools []anthropic.ToolUnionParam) anthropic.MessageNewParams { var thinkingParam anthropic.ThinkingConfigParamUnion lastMessage := messages[len(messages)-1] + isUser := lastMessage.Role == anthropic.MessageParamRoleUser + messageContent := "" temperature := anthropic.Float(0) - if lastMessage.Role == message.User && strings.Contains(strings.ToLower(lastMessage.Content().String()), "think") { - thinkingParam = anthropic.ThinkingConfigParamUnion{ - OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{ - BudgetTokens: int64(float64(a.maxTokens) * 0.8), - Type: "enabled", - }, + if isUser { + for _, m := range lastMessage.Content { + if m.OfRequestTextBlock != nil && m.OfRequestTextBlock.Text != "" { + messageContent = m.OfRequestTextBlock.Text + } + } + if messageContent != "" && a.options.shouldThink != nil && a.options.shouldThink(messageContent) { + thinkingParam = anthropic.ThinkingConfigParamUnion{ + OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{ + BudgetTokens: int64(float64(a.providerOptions.maxTokens) * 0.8), + Type: "enabled", + }, + } + temperature = anthropic.Float(1) } - temperature = anthropic.Float(1) } - eventChan := make(chan ProviderEvent) + return anthropic.MessageNewParams{ + Model: anthropic.Model(a.providerOptions.model.APIModel), + MaxTokens: a.providerOptions.maxTokens, + Temperature: temperature, + Messages: messages, + Tools: tools, + Thinking: thinkingParam, + System: []anthropic.TextBlockParam{ + { + Text: a.providerOptions.systemMessage, + CacheControl: anthropic.CacheControlEphemeralParam{ + Type: "ephemeral", + }, + }, + }, + } +} - go func() { - defer close(eventChan) +func (a *anthropicClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (resposne *ProviderResponse, err error) { + preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools)) + cfg := config.Get() + if cfg.Debug { + jsonData, _ := json.Marshal(preparedMessages) + logging.Debug("Prepared messages", "messages", string(jsonData)) + } + attempts := 0 + for { + attempts++ + anthropicResponse, err := a.client.Messages.New( + ctx, + preparedMessages, + ) + // If there is an error we are going to see if we can retry the call + if err != nil { + retry, after, retryErr := a.shouldRetry(attempts, err) + if retryErr != nil { + return nil, retryErr + } + if retry { + logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100)) + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(time.Duration(after) * time.Millisecond): + continue + } + } + return nil, retryErr + } - const maxRetries = 8 - attempts := 0 + content := "" + for _, block := range anthropicResponse.Content { + if text, ok := block.AsAny().(anthropic.TextBlock); ok { + content += text.Text + } + } - for { + return &ProviderResponse{ + Content: content, + ToolCalls: a.toolCalls(*anthropicResponse), + Usage: a.usage(*anthropicResponse), + }, nil + } +} +func (a *anthropicClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { + preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools)) + cfg := config.Get() + if cfg.Debug { + jsonData, _ := json.Marshal(preparedMessages) + logging.Debug("Prepared messages", "messages", string(jsonData)) + } + attempts := 0 + eventChan := make(chan ProviderEvent) + go func() { + for { attempts++ - - stream := a.client.Messages.NewStreaming( + anthropicStream := a.client.Messages.NewStreaming( ctx, - anthropic.MessageNewParams{ - Model: anthropic.Model(a.model.APIModel), - MaxTokens: a.maxTokens, - Temperature: temperature, - Messages: anthropicMessages, - Tools: anthropicTools, - Thinking: thinkingParam, - System: []anthropic.TextBlockParam{ - { - Text: a.systemMessage, - CacheControl: anthropic.CacheControlEphemeralParam{ - Type: "ephemeral", - }, - }, - }, - }, + preparedMessages, ) - accumulatedMessage := anthropic.Message{} - for stream.Next() { - event := stream.Current() + for anthropicStream.Next() { + event := anthropicStream.Current() err := accumulatedMessage.Accumulate(event) if err != nil { eventChan <- ProviderEvent{Type: EventError, Error: err} - return // Don't retry on accumulation errors + continue } switch event := event.AsAny().(type) { @@ -211,6 +281,7 @@ func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []messa Content: event.Delta.Text, } } + // TODO: check if we can somehow stream tool calls case anthropic.ContentBlockStopEvent: eventChan <- ProviderEvent{Type: EventContentStop} @@ -223,84 +294,87 @@ func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []messa } } - toolCalls := a.extractToolCalls(accumulatedMessage.Content) - tokenUsage := a.extractTokenUsage(accumulatedMessage.Usage) - eventChan <- ProviderEvent{ Type: EventComplete, Response: &ProviderResponse{ Content: content, - ToolCalls: toolCalls, - Usage: tokenUsage, - FinishReason: string(accumulatedMessage.StopReason), + ToolCalls: a.toolCalls(accumulatedMessage), + Usage: a.usage(accumulatedMessage), + FinishReason: a.finishReason(string(accumulatedMessage.StopReason)), }, } } } - err := stream.Err() + err := anthropicStream.Err() if err == nil || errors.Is(err, io.EOF) { + close(eventChan) return } - - var apierr *anthropic.Error - if !errors.As(err, &apierr) { - eventChan <- ProviderEvent{Type: EventError, Error: err} - return - } - - if apierr.StatusCode != 429 && apierr.StatusCode != 529 { - eventChan <- ProviderEvent{Type: EventError, Error: err} + // If there is an error we are going to see if we can retry the call + retry, after, retryErr := a.shouldRetry(attempts, err) + if retryErr != nil { + eventChan <- ProviderEvent{Type: EventError, Error: retryErr} + close(eventChan) return } - - if attempts > maxRetries { - eventChan <- ProviderEvent{ - Type: EventError, - Error: errors.New("maximum retry attempts reached for rate limit (429)"), - } - return - } - - retryMs := 0 - retryAfterValues := apierr.Response.Header.Values("Retry-After") - if len(retryAfterValues) > 0 { - var retryAfterSec int - if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryAfterSec); err == nil { - retryMs = retryAfterSec * 1000 - eventChan <- ProviderEvent{ - Type: EventWarning, - Info: fmt.Sprintf("[Rate limited: waiting %d seconds as specified by API]", retryAfterSec), + if retry { + logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100)) + select { + case <-ctx.Done(): + // context cancelled + if ctx.Err() != nil { + eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()} } + close(eventChan) + return + case <-time.After(time.Duration(after) * time.Millisecond): + continue } - } else { - eventChan <- ProviderEvent{ - Type: EventWarning, - Info: fmt.Sprintf("[Retrying due to rate limit... attempt %d of %d]", attempts, maxRetries), - } - - backoffMs := 2000 * (1 << (attempts - 1)) - jitterMs := int(float64(backoffMs) * 0.2) - retryMs = backoffMs + jitterMs } - select { - case <-ctx.Done(): + if ctx.Err() != nil { eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()} - return - case <-time.After(time.Duration(retryMs) * time.Millisecond): - continue } + close(eventChan) + return } }() + return eventChan +} - return eventChan, nil +func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, error) { + var apierr *anthropic.Error + if !errors.As(err, &apierr) { + return false, 0, err + } + + if apierr.StatusCode != 429 && apierr.StatusCode != 529 { + return false, 0, err + } + + if attempts > maxRetries { + return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries) + } + + retryMs := 0 + retryAfterValues := apierr.Response.Header.Values("Retry-After") + + backoffMs := 2000 * (1 << (attempts - 1)) + jitterMs := int(float64(backoffMs) * 0.2) + retryMs = backoffMs + jitterMs + if len(retryAfterValues) > 0 { + if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil { + retryMs = retryMs * 1000 + } + } + return true, int64(retryMs), nil } -func (a *anthropicProvider) extractToolCalls(content []anthropic.ContentBlockUnion) []message.ToolCall { +func (a *anthropicClient) toolCalls(msg anthropic.Message) []message.ToolCall { var toolCalls []message.ToolCall - for _, block := range content { + for _, block := range msg.Content { switch variant := block.AsAny().(type) { case anthropic.ToolUseBlock: toolCall := message.ToolCall{ @@ -316,90 +390,33 @@ func (a *anthropicProvider) extractToolCalls(content []anthropic.ContentBlockUni return toolCalls } -func (a *anthropicProvider) extractTokenUsage(usage anthropic.Usage) TokenUsage { +func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage { return TokenUsage{ - InputTokens: usage.InputTokens, - OutputTokens: usage.OutputTokens, - CacheCreationTokens: usage.CacheCreationInputTokens, - CacheReadTokens: usage.CacheReadInputTokens, + InputTokens: msg.Usage.InputTokens, + OutputTokens: msg.Usage.OutputTokens, + CacheCreationTokens: msg.Usage.CacheCreationInputTokens, + CacheReadTokens: msg.Usage.CacheReadInputTokens, } } -func (a *anthropicProvider) convertToAnthropicTools(tools []tools.BaseTool) []anthropic.ToolUnionParam { - anthropicTools := make([]anthropic.ToolUnionParam, len(tools)) - - for i, tool := range tools { - info := tool.Info() - toolParam := anthropic.ToolParam{ - Name: info.Name, - Description: anthropic.String(info.Description), - InputSchema: anthropic.ToolInputSchemaParam{ - Properties: info.Parameters, - }, - } - - if i == len(tools)-1 && !a.disableCache { - toolParam.CacheControl = anthropic.CacheControlEphemeralParam{ - Type: "ephemeral", - } - } - - anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam} +func WithAnthropicBedrock(useBedrock bool) AnthropicOption { + return func(options *anthropicOptions) { + options.useBedrock = useBedrock } - - return anthropicTools } -func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Message) []anthropic.MessageParam { - anthropicMessages := make([]anthropic.MessageParam, 0, len(messages)) - cachedBlocks := 0 - - for _, msg := range messages { - switch msg.Role { - case message.User: - content := anthropic.NewTextBlock(msg.Content().String()) - if cachedBlocks < 2 && !a.disableCache { - content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{ - Type: "ephemeral", - } - cachedBlocks++ - } - anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(content)) - - case message.Assistant: - blocks := []anthropic.ContentBlockParamUnion{} - if msg.Content().String() != "" { - content := anthropic.NewTextBlock(msg.Content().String()) - if cachedBlocks < 2 && !a.disableCache { - content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{ - Type: "ephemeral", - } - cachedBlocks++ - } - blocks = append(blocks, content) - } - - for _, toolCall := range msg.ToolCalls() { - var inputMap map[string]any - err := json.Unmarshal([]byte(toolCall.Input), &inputMap) - if err != nil { - continue - } - blocks = append(blocks, anthropic.ContentBlockParamOfRequestToolUseBlock(toolCall.ID, inputMap, toolCall.Name)) - } +func WithAnthropicDisableCache() AnthropicOption { + return func(options *anthropicOptions) { + options.disableCache = true + } +} - if len(blocks) > 0 { - anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...)) - } +func DefaultShouldThinkFn(s string) bool { + return strings.Contains(strings.ToLower(s), "think") +} - case message.Tool: - results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults())) - for i, toolResult := range msg.ToolResults() { - results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError) - } - anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...)) - } +func WithAnthropicShouldThinkFn(fn func(string) bool) AnthropicOption { + return func(options *anthropicOptions) { + options.shouldThink = fn } - - return anthropicMessages } diff --git a/internal/llm/provider/bedrock.go b/internal/llm/provider/bedrock.go index 677f4676b..d76925ad1 100644 --- a/internal/llm/provider/bedrock.go +++ b/internal/llm/provider/bedrock.go @@ -7,33 +7,29 @@ import ( "os" "strings" - "github.com/kujtimiihoxha/termai/internal/llm/models" "github.com/kujtimiihoxha/termai/internal/llm/tools" "github.com/kujtimiihoxha/termai/internal/message" ) -type bedrockProvider struct { - childProvider Provider - model models.Model - maxTokens int64 - systemMessage string +type bedrockOptions struct { + // Bedrock specific options can be added here } -func (b *bedrockProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { - return b.childProvider.SendMessages(ctx, messages, tools) -} +type BedrockOption func(*bedrockOptions) -func (b *bedrockProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) { - return b.childProvider.StreamResponse(ctx, messages, tools) +type bedrockClient struct { + providerOptions providerClientOptions + options bedrockOptions + childProvider ProviderClient } -func NewBedrockProvider(opts ...BedrockOption) (Provider, error) { - provider := &bedrockProvider{} - for _, opt := range opts { - opt(provider) - } +type BedrockClient ProviderClient + +func newBedrockClient(opts providerClientOptions) BedrockClient { + bedrockOpts := bedrockOptions{} + // Apply bedrock specific options if they are added in the future - // based on the AWS region prefix the model name with, us, eu, ap, sa, etc. + // Get AWS region from environment region := os.Getenv("AWS_REGION") if region == "" { region = os.Getenv("AWS_DEFAULT_REGION") @@ -43,45 +39,62 @@ func NewBedrockProvider(opts ...BedrockOption) (Provider, error) { region = "us-east-1" // default region } if len(region) < 2 { - return nil, errors.New("AWS_REGION or AWS_DEFAULT_REGION environment variable is invalid") + return &bedrockClient{ + providerOptions: opts, + options: bedrockOpts, + childProvider: nil, // Will cause an error when used + } } + + // Prefix the model name with region regionPrefix := region[:2] - provider.model.APIModel = fmt.Sprintf("%s.%s", regionPrefix, provider.model.APIModel) + modelName := opts.model.APIModel + opts.model.APIModel = fmt.Sprintf("%s.%s", regionPrefix, modelName) - if strings.Contains(string(provider.model.APIModel), "anthropic") { - anthropic, err := NewAnthropicProvider( - WithAnthropicModel(provider.model), - WithAnthropicMaxTokens(provider.maxTokens), - WithAnthropicSystemMessage(provider.systemMessage), - WithAnthropicBedrock(), + // Determine which provider to use based on the model + if strings.Contains(string(opts.model.APIModel), "anthropic") { + // Create Anthropic client with Bedrock configuration + anthropicOpts := opts + anthropicOpts.anthropicOptions = append(anthropicOpts.anthropicOptions, + WithAnthropicBedrock(true), WithAnthropicDisableCache(), ) - provider.childProvider = anthropic - if err != nil { - return nil, err + return &bedrockClient{ + providerOptions: opts, + options: bedrockOpts, + childProvider: newAnthropicClient(anthropicOpts), } - } else { - return nil, errors.New("unsupported model for bedrock provider") } - return provider, nil -} - -type BedrockOption func(*bedrockProvider) -func WithBedrockSystemMessage(message string) BedrockOption { - return func(a *bedrockProvider) { - a.systemMessage = message + // Return client with nil childProvider if model is not supported + // This will cause an error when used + return &bedrockClient{ + providerOptions: opts, + options: bedrockOpts, + childProvider: nil, } } -func WithBedrockMaxTokens(maxTokens int64) BedrockOption { - return func(a *bedrockProvider) { - a.maxTokens = maxTokens +func (b *bedrockClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { + if b.childProvider == nil { + return nil, errors.New("unsupported model for bedrock provider") } + return b.childProvider.send(ctx, messages, tools) } -func WithBedrockModel(model models.Model) BedrockOption { - return func(a *bedrockProvider) { - a.model = model +func (b *bedrockClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { + eventChan := make(chan ProviderEvent) + + if b.childProvider == nil { + go func() { + eventChan <- ProviderEvent{ + Type: EventError, + Error: errors.New("unsupported model for bedrock provider"), + } + close(eventChan) + }() + return eventChan } -} + + return b.childProvider.stream(ctx, messages, tools) +} \ No newline at end of file diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go index 2d1db2b64..804baea28 100644 --- a/internal/llm/provider/gemini.go +++ b/internal/llm/provider/gemini.go @@ -4,80 +4,68 @@ import ( "context" "encoding/json" "errors" + "fmt" + "io" + "strings" + "time" "github.com/google/generative-ai-go/genai" "github.com/google/uuid" - "github.com/kujtimiihoxha/termai/internal/llm/models" + "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/llm/tools" + "github.com/kujtimiihoxha/termai/internal/logging" "github.com/kujtimiihoxha/termai/internal/message" "google.golang.org/api/iterator" "google.golang.org/api/option" ) -type geminiProvider struct { - client *genai.Client - model models.Model - maxTokens int32 - apiKey string - systemMessage string +type geminiOptions struct { + disableCache bool } -type GeminiOption func(*geminiProvider) +type GeminiOption func(*geminiOptions) -func NewGeminiProvider(ctx context.Context, opts ...GeminiOption) (Provider, error) { - provider := &geminiProvider{ - maxTokens: 5000, - } +type geminiClient struct { + providerOptions providerClientOptions + options geminiOptions + client *genai.Client +} - for _, opt := range opts { - opt(provider) - } +type GeminiClient ProviderClient - if provider.systemMessage == "" { - return nil, errors.New("system message is required") +func newGeminiClient(opts providerClientOptions) GeminiClient { + geminiOpts := geminiOptions{} + for _, o := range opts.geminiOptions { + o(&geminiOpts) } - client, err := genai.NewClient(ctx, option.WithAPIKey(provider.apiKey)) + client, err := genai.NewClient(context.Background(), option.WithAPIKey(opts.apiKey)) if err != nil { - return nil, err - } - provider.client = client - - return provider, nil -} - -func WithGeminiSystemMessage(message string) GeminiOption { - return func(p *geminiProvider) { - p.systemMessage = message + logging.Error("Failed to create Gemini client", "error", err) + return nil } -} -func WithGeminiMaxTokens(maxTokens int32) GeminiOption { - return func(p *geminiProvider) { - p.maxTokens = maxTokens + return &geminiClient{ + providerOptions: opts, + options: geminiOpts, + client: client, } } -func WithGeminiModel(model models.Model) GeminiOption { - return func(p *geminiProvider) { - p.model = model - } -} - -func WithGeminiKey(apiKey string) GeminiOption { - return func(p *geminiProvider) { - p.apiKey = apiKey - } -} +func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content { + var history []*genai.Content -func (p *geminiProvider) Close() { - if p.client != nil { - p.client.Close() - } -} + // Add system message first + history = append(history, &genai.Content{ + Parts: []genai.Part{genai.Text(g.providerOptions.systemMessage)}, + Role: "user", + }) -func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*genai.Content { - var history []*genai.Content + // Add a system response to acknowledge the system message + history = append(history, &genai.Content{ + Parts: []genai.Part{genai.Text("I'll help you with that.")}, + Role: "model", + }) for _, msg := range messages { switch msg.Role { @@ -86,6 +74,7 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g Parts: []genai.Part{genai.Text(msg.Content().String())}, Role: "user", }) + case message.Assistant: content := &genai.Content{ Role: "model", @@ -107,6 +96,7 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g } history = append(history, content) + case message.Tool: for _, result := range msg.ToolResults() { response := map[string]interface{}{"result": result.Content} @@ -114,10 +104,11 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g if err == nil { response = parsed } + var toolCall message.ToolCall - for _, msg := range messages { - if msg.Role == message.Assistant { - for _, call := range msg.ToolCalls() { + for _, m := range messages { + if m.Role == message.Assistant { + for _, call := range m.ToolCalls() { if call.ID == result.ToolCallID { toolCall = call break @@ -140,186 +131,358 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g return history } -func (p *geminiProvider) extractTokenUsage(resp *genai.GenerateContentResponse) TokenUsage { - if resp == nil || resp.UsageMetadata == nil { - return TokenUsage{} - } +func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool { + geminiTools := make([]*genai.Tool, 0, len(tools)) - return TokenUsage{ - InputTokens: int64(resp.UsageMetadata.PromptTokenCount), - OutputTokens: int64(resp.UsageMetadata.CandidatesTokenCount), - CacheCreationTokens: 0, // Not directly provided by Gemini - CacheReadTokens: int64(resp.UsageMetadata.CachedContentTokenCount), + for _, tool := range tools { + info := tool.Info() + declaration := &genai.FunctionDeclaration{ + Name: info.Name, + Description: info.Description, + Parameters: &genai.Schema{ + Type: genai.TypeObject, + Properties: convertSchemaProperties(info.Parameters), + Required: info.Required, + }, + } + + geminiTools = append(geminiTools, &genai.Tool{ + FunctionDeclarations: []*genai.FunctionDeclaration{declaration}, + }) } + + return geminiTools } -func (p *geminiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { - messages = cleanupMessages(messages) - model := p.client.GenerativeModel(p.model.APIModel) - model.SetMaxOutputTokens(p.maxTokens) +func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason { + reasonStr := reason.String() + switch { + case reasonStr == "STOP": + return message.FinishReasonEndTurn + case reasonStr == "MAX_TOKENS": + return message.FinishReasonMaxTokens + case strings.Contains(reasonStr, "FUNCTION") || strings.Contains(reasonStr, "TOOL"): + return message.FinishReasonToolUse + default: + return message.FinishReasonUnknown + } +} - model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage)) +func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { + model := g.client.GenerativeModel(g.providerOptions.model.APIModel) + model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens)) + // Convert tools if len(tools) > 0 { - declarations := p.convertToolsToGeminiFunctionDeclarations(tools) - for _, declaration := range declarations { - model.Tools = append(model.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{declaration}}) - } + model.Tools = g.convertTools(tools) } - chat := model.StartChat() - chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message + // Convert messages + geminiMessages := g.convertMessages(messages) - lastUserMsg := messages[len(messages)-1] - resp, err := chat.SendMessage(ctx, genai.Text(lastUserMsg.Content().String())) - if err != nil { - return nil, err + cfg := config.Get() + if cfg.Debug { + jsonData, _ := json.Marshal(geminiMessages) + logging.Debug("Prepared messages", "messages", string(jsonData)) } - var content string - var toolCalls []message.ToolCall + attempts := 0 + for { + attempts++ + chat := model.StartChat() + chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message + + lastMsg := geminiMessages[len(geminiMessages)-1] + var lastText string + for _, part := range lastMsg.Parts { + if text, ok := part.(genai.Text); ok { + lastText = string(text) + break + } + } - if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil { - for _, part := range resp.Candidates[0].Content.Parts { - switch p := part.(type) { - case genai.Text: - content = string(p) - case genai.FunctionCall: - id := "call_" + uuid.New().String() - args, _ := json.Marshal(p.Args) - toolCalls = append(toolCalls, message.ToolCall{ - ID: id, - Name: p.Name, - Input: string(args), - Type: "function", - }) + resp, err := chat.SendMessage(ctx, genai.Text(lastText)) + // If there is an error we are going to see if we can retry the call + if err != nil { + retry, after, retryErr := g.shouldRetry(attempts, err) + if retryErr != nil { + return nil, retryErr } + if retry { + logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100)) + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(time.Duration(after) * time.Millisecond): + continue + } + } + return nil, retryErr } - } - tokenUsage := p.extractTokenUsage(resp) + content := "" + var toolCalls []message.ToolCall + + if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil { + for _, part := range resp.Candidates[0].Content.Parts { + switch p := part.(type) { + case genai.Text: + content = string(p) + case genai.FunctionCall: + id := "call_" + uuid.New().String() + args, _ := json.Marshal(p.Args) + toolCalls = append(toolCalls, message.ToolCall{ + ID: id, + Name: p.Name, + Input: string(args), + Type: "function", + }) + } + } + } - return &ProviderResponse{ - Content: content, - ToolCalls: toolCalls, - Usage: tokenUsage, - }, nil + return &ProviderResponse{ + Content: content, + ToolCalls: toolCalls, + Usage: g.usage(resp), + FinishReason: g.finishReason(resp.Candidates[0].FinishReason), + }, nil + } } -func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) { - messages = cleanupMessages(messages) - model := p.client.GenerativeModel(p.model.APIModel) - model.SetMaxOutputTokens(p.maxTokens) - - model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage)) +func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { + model := g.client.GenerativeModel(g.providerOptions.model.APIModel) + model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens)) + // Convert tools if len(tools) > 0 { - declarations := p.convertToolsToGeminiFunctionDeclarations(tools) - for _, declaration := range declarations { - model.Tools = append(model.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{declaration}}) - } + model.Tools = g.convertTools(tools) } - chat := model.StartChat() - chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message + // Convert messages + geminiMessages := g.convertMessages(messages) - lastUserMsg := messages[len(messages)-1] - - iter := chat.SendMessageStream(ctx, genai.Text(lastUserMsg.Content().String())) + cfg := config.Get() + if cfg.Debug { + jsonData, _ := json.Marshal(geminiMessages) + logging.Debug("Prepared messages", "messages", string(jsonData)) + } + attempts := 0 eventChan := make(chan ProviderEvent) go func() { defer close(eventChan) - var finalResp *genai.GenerateContentResponse - currentContent := "" - toolCalls := []message.ToolCall{} - for { - resp, err := iter.Next() - if err == iterator.Done { - break - } - if err != nil { - eventChan <- ProviderEvent{ - Type: EventError, - Error: err, + attempts++ + chat := model.StartChat() + chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message + + lastMsg := geminiMessages[len(geminiMessages)-1] + var lastText string + for _, part := range lastMsg.Parts { + if text, ok := part.(genai.Text); ok { + lastText = string(text) + break } - return } - finalResp = resp + iter := chat.SendMessageStream(ctx, genai.Text(lastText)) - if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil { - for _, part := range resp.Candidates[0].Content.Parts { - switch p := part.(type) { - case genai.Text: - newText := string(p) - eventChan <- ProviderEvent{ - Type: EventContentDelta, - Content: newText, - } - currentContent += newText - case genai.FunctionCall: - id := "call_" + uuid.New().String() - args, _ := json.Marshal(p.Args) - newCall := message.ToolCall{ - ID: id, - Name: p.Name, - Input: string(args), - Type: "function", - } + currentContent := "" + toolCalls := []message.ToolCall{} + var finalResp *genai.GenerateContentResponse - isNew := true - for _, existing := range toolCalls { - if existing.Name == newCall.Name && existing.Input == newCall.Input { - isNew = false - break + eventChan <- ProviderEvent{Type: EventContentStart} + + for { + resp, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + retry, after, retryErr := g.shouldRetry(attempts, err) + if retryErr != nil { + eventChan <- ProviderEvent{Type: EventError, Error: retryErr} + return + } + if retry { + logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100)) + select { + case <-ctx.Done(): + if ctx.Err() != nil { + eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()} } + + return + case <-time.After(time.Duration(after) * time.Millisecond): + break } + } else { + eventChan <- ProviderEvent{Type: EventError, Error: err} + return + } + } + + finalResp = resp + + if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil { + for _, part := range resp.Candidates[0].Content.Parts { + switch p := part.(type) { + case genai.Text: + newText := string(p) + delta := newText[len(currentContent):] + if delta != "" { + eventChan <- ProviderEvent{ + Type: EventContentDelta, + Content: delta, + } + currentContent = newText + } + case genai.FunctionCall: + id := "call_" + uuid.New().String() + args, _ := json.Marshal(p.Args) + newCall := message.ToolCall{ + ID: id, + Name: p.Name, + Input: string(args), + Type: "function", + } - if isNew { - toolCalls = append(toolCalls, newCall) + isNew := true + for _, existing := range toolCalls { + if existing.Name == newCall.Name && existing.Input == newCall.Input { + isNew = false + break + } + } + + if isNew { + toolCalls = append(toolCalls, newCall) + } } } } } - } - tokenUsage := p.extractTokenUsage(finalResp) + eventChan <- ProviderEvent{Type: EventContentStop} - eventChan <- ProviderEvent{ - Type: EventComplete, - Response: &ProviderResponse{ - Content: currentContent, - ToolCalls: toolCalls, - Usage: tokenUsage, - FinishReason: string(finalResp.Candidates[0].FinishReason.String()), - }, + if finalResp != nil { + eventChan <- ProviderEvent{ + Type: EventComplete, + Response: &ProviderResponse{ + Content: currentContent, + ToolCalls: toolCalls, + Usage: g.usage(finalResp), + FinishReason: g.finishReason(finalResp.Candidates[0].FinishReason), + }, + } + return + } + + // If we get here, we need to retry + if attempts > maxRetries { + eventChan <- ProviderEvent{ + Type: EventError, + Error: fmt.Errorf("maximum retry attempts reached: %d retries", maxRetries), + } + return + } + + // Wait before retrying + select { + case <-ctx.Done(): + if ctx.Err() != nil { + eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()} + } + return + case <-time.After(time.Duration(2000*(1<<(attempts-1))) * time.Millisecond): + continue + } } }() - return eventChan, nil + return eventChan } -func (p *geminiProvider) convertToolsToGeminiFunctionDeclarations(tools []tools.BaseTool) []*genai.FunctionDeclaration { - declarations := make([]*genai.FunctionDeclaration, len(tools)) +func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) { + // Check if error is a rate limit error + if attempts > maxRetries { + return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries) + } - for i, tool := range tools { - info := tool.Info() - declarations[i] = &genai.FunctionDeclaration{ - Name: info.Name, - Description: info.Description, - Parameters: &genai.Schema{ - Type: genai.TypeObject, - Properties: convertSchemaProperties(info.Parameters), - Required: info.Required, - }, + // Gemini doesn't have a standard error type we can check against + // So we'll check the error message for rate limit indicators + if errors.Is(err, io.EOF) { + return false, 0, err + } + + errMsg := err.Error() + isRateLimit := false + + // Check for common rate limit error messages + if contains(errMsg, "rate limit", "quota exceeded", "too many requests") { + isRateLimit = true + } + + if !isRateLimit { + return false, 0, err + } + + // Calculate backoff with jitter + backoffMs := 2000 * (1 << (attempts - 1)) + jitterMs := int(float64(backoffMs) * 0.2) + retryMs := backoffMs + jitterMs + + return true, int64(retryMs), nil +} + +func (g *geminiClient) toolCalls(resp *genai.GenerateContentResponse) []message.ToolCall { + var toolCalls []message.ToolCall + + if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil { + for _, part := range resp.Candidates[0].Content.Parts { + if funcCall, ok := part.(genai.FunctionCall); ok { + id := "call_" + uuid.New().String() + args, _ := json.Marshal(funcCall.Args) + toolCalls = append(toolCalls, message.ToolCall{ + ID: id, + Name: funcCall.Name, + Input: string(args), + Type: "function", + }) + } } } - return declarations + return toolCalls +} + +func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage { + if resp == nil || resp.UsageMetadata == nil { + return TokenUsage{} + } + + return TokenUsage{ + InputTokens: int64(resp.UsageMetadata.PromptTokenCount), + OutputTokens: int64(resp.UsageMetadata.CandidatesTokenCount), + CacheCreationTokens: 0, // Not directly provided by Gemini + CacheReadTokens: int64(resp.UsageMetadata.CachedContentTokenCount), + } +} + +func WithGeminiDisableCache() GeminiOption { + return func(options *geminiOptions) { + options.disableCache = true + } +} + +// Helper functions +func parseJsonToMap(jsonStr string) (map[string]interface{}, error) { + var result map[string]interface{} + err := json.Unmarshal([]byte(jsonStr), &result) + return result, err } func convertSchemaProperties(parameters map[string]interface{}) map[string]*genai.Schema { @@ -396,8 +559,12 @@ func mapJSONTypeToGenAI(jsonType string) genai.Type { } } -func parseJsonToMap(jsonStr string) (map[string]interface{}, error) { - var result map[string]interface{} - err := json.Unmarshal([]byte(jsonStr), &result) - return result, err +func contains(s string, substrs ...string) bool { + for _, substr := range substrs { + if strings.Contains(strings.ToLower(s), strings.ToLower(substr)) { + return true + } + } + return false } + diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index dbfde3fa8..9c2ad2012 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -2,89 +2,65 @@ package provider import ( "context" + "encoding/json" "errors" + "fmt" + "io" + "time" - "github.com/kujtimiihoxha/termai/internal/llm/models" + "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/llm/tools" + "github.com/kujtimiihoxha/termai/internal/logging" "github.com/kujtimiihoxha/termai/internal/message" "github.com/openai/openai-go" "github.com/openai/openai-go/option" ) -type openaiProvider struct { - client openai.Client - model models.Model - maxTokens int64 - baseURL string - apiKey string - systemMessage string +type openaiOptions struct { + baseURL string + disableCache bool } -type OpenAIOption func(*openaiProvider) +type OpenAIOption func(*openaiOptions) -func NewOpenAIProvider(opts ...OpenAIOption) (Provider, error) { - provider := &openaiProvider{ - maxTokens: 5000, - } - - for _, opt := range opts { - opt(provider) - } - - clientOpts := []option.RequestOption{ - option.WithAPIKey(provider.apiKey), - } - if provider.baseURL != "" { - clientOpts = append(clientOpts, option.WithBaseURL(provider.baseURL)) - } - - provider.client = openai.NewClient(clientOpts...) - if provider.systemMessage == "" { - return nil, errors.New("system message is required") - } - - return provider, nil +type openaiClient struct { + providerOptions providerClientOptions + options openaiOptions + client openai.Client } -func WithOpenAISystemMessage(message string) OpenAIOption { - return func(p *openaiProvider) { - p.systemMessage = message - } -} +type OpenAIClient ProviderClient -func WithOpenAIMaxTokens(maxTokens int64) OpenAIOption { - return func(p *openaiProvider) { - p.maxTokens = maxTokens +func newOpenAIClient(opts providerClientOptions) OpenAIClient { + openaiOpts := openaiOptions{} + for _, o := range opts.openaiOptions { + o(&openaiOpts) } -} -func WithOpenAIModel(model models.Model) OpenAIOption { - return func(p *openaiProvider) { - p.model = model + openaiClientOptions := []option.RequestOption{} + if opts.apiKey != "" { + openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey)) } -} - -func WithOpenAIBaseURL(baseURL string) OpenAIOption { - return func(p *openaiProvider) { - p.baseURL = baseURL + if openaiOpts.baseURL != "" { + openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(openaiOpts.baseURL)) } -} -func WithOpenAIKey(apiKey string) OpenAIOption { - return func(p *openaiProvider) { - p.apiKey = apiKey + client := openai.NewClient(openaiClientOptions...) + return &openaiClient{ + providerOptions: opts, + options: openaiOpts, + client: client, } } -func (p *openaiProvider) convertToOpenAIMessages(messages []message.Message) []openai.ChatCompletionMessageParamUnion { - var chatMessages []openai.ChatCompletionMessageParamUnion - - chatMessages = append(chatMessages, openai.SystemMessage(p.systemMessage)) +func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) { + // Add system message first + openaiMessages = append(openaiMessages, openai.SystemMessage(o.providerOptions.systemMessage)) for _, msg := range messages { switch msg.Role { case message.User: - chatMessages = append(chatMessages, openai.UserMessage(msg.Content().String())) + openaiMessages = append(openaiMessages, openai.UserMessage(msg.Content().String())) case message.Assistant: assistantMsg := openai.ChatCompletionAssistantMessageParam{ @@ -111,23 +87,23 @@ func (p *openaiProvider) convertToOpenAIMessages(messages []message.Message) []o } } - chatMessages = append(chatMessages, openai.ChatCompletionMessageParamUnion{ + openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{ OfAssistant: &assistantMsg, }) case message.Tool: for _, result := range msg.ToolResults() { - chatMessages = append(chatMessages, + openaiMessages = append(openaiMessages, openai.ToolMessage(result.Content, result.ToolCallID), ) } } } - return chatMessages + return } -func (p *openaiProvider) convertToOpenAITools(tools []tools.BaseTool) []openai.ChatCompletionToolParam { +func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam { openaiTools := make([]openai.ChatCompletionToolParam, len(tools)) for i, tool := range tools { @@ -148,133 +124,238 @@ func (p *openaiProvider) convertToOpenAITools(tools []tools.BaseTool) []openai.C return openaiTools } -func (p *openaiProvider) extractTokenUsage(usage openai.CompletionUsage) TokenUsage { - cachedTokens := int64(0) - - cachedTokens = usage.PromptTokensDetails.CachedTokens - inputTokens := usage.PromptTokens - cachedTokens - - return TokenUsage{ - InputTokens: inputTokens, - OutputTokens: usage.CompletionTokens, - CacheCreationTokens: 0, // OpenAI doesn't provide this directly - CacheReadTokens: cachedTokens, +func (o *openaiClient) finishReason(reason string) message.FinishReason { + switch reason { + case "stop": + return message.FinishReasonEndTurn + case "length": + return message.FinishReasonMaxTokens + case "tool_calls": + return message.FinishReasonToolUse + default: + return message.FinishReasonUnknown } } -func (p *openaiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { - messages = cleanupMessages(messages) - chatMessages := p.convertToOpenAIMessages(messages) - openaiTools := p.convertToOpenAITools(tools) - - params := openai.ChatCompletionNewParams{ - Model: openai.ChatModel(p.model.APIModel), - Messages: chatMessages, - MaxTokens: openai.Int(p.maxTokens), - Tools: openaiTools, - } - - response, err := p.client.Chat.Completions.New(ctx, params) - if err != nil { - return nil, err +func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams { + return openai.ChatCompletionNewParams{ + Model: openai.ChatModel(o.providerOptions.model.APIModel), + Messages: messages, + MaxTokens: openai.Int(o.providerOptions.maxTokens), + Tools: tools, } +} - content := "" - if response.Choices[0].Message.Content != "" { - content = response.Choices[0].Message.Content +func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) { + params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools)) + cfg := config.Get() + if cfg.Debug { + jsonData, _ := json.Marshal(params) + logging.Debug("Prepared messages", "messages", string(jsonData)) } - - var toolCalls []message.ToolCall - if len(response.Choices[0].Message.ToolCalls) > 0 { - toolCalls = make([]message.ToolCall, len(response.Choices[0].Message.ToolCalls)) - for i, call := range response.Choices[0].Message.ToolCalls { - toolCalls[i] = message.ToolCall{ - ID: call.ID, - Name: call.Function.Name, - Input: call.Function.Arguments, - Type: "function", + attempts := 0 + for { + attempts++ + openaiResponse, err := o.client.Chat.Completions.New( + ctx, + params, + ) + // If there is an error we are going to see if we can retry the call + if err != nil { + retry, after, retryErr := o.shouldRetry(attempts, err) + if retryErr != nil { + return nil, retryErr } + if retry { + logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100)) + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(time.Duration(after) * time.Millisecond): + continue + } + } + return nil, retryErr } - } - tokenUsage := p.extractTokenUsage(response.Usage) + content := "" + if openaiResponse.Choices[0].Message.Content != "" { + content = openaiResponse.Choices[0].Message.Content + } - return &ProviderResponse{ - Content: content, - ToolCalls: toolCalls, - Usage: tokenUsage, - }, nil + return &ProviderResponse{ + Content: content, + ToolCalls: o.toolCalls(*openaiResponse), + Usage: o.usage(*openaiResponse), + FinishReason: o.finishReason(string(openaiResponse.Choices[0].FinishReason)), + }, nil + } } -func (p *openaiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) { - messages = cleanupMessages(messages) - chatMessages := p.convertToOpenAIMessages(messages) - openaiTools := p.convertToOpenAITools(tools) - - params := openai.ChatCompletionNewParams{ - Model: openai.ChatModel(p.model.APIModel), - Messages: chatMessages, - MaxTokens: openai.Int(p.maxTokens), - Tools: openaiTools, - StreamOptions: openai.ChatCompletionStreamOptionsParam{ - IncludeUsage: openai.Bool(true), - }, +func (o *openaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { + params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools)) + params.StreamOptions = openai.ChatCompletionStreamOptionsParam{ + IncludeUsage: openai.Bool(true), } - stream := p.client.Chat.Completions.NewStreaming(ctx, params) + cfg := config.Get() + if cfg.Debug { + jsonData, _ := json.Marshal(params) + logging.Debug("Prepared messages", "messages", string(jsonData)) + } + attempts := 0 eventChan := make(chan ProviderEvent) - toolCalls := make([]message.ToolCall, 0) go func() { - defer close(eventChan) - - acc := openai.ChatCompletionAccumulator{} - currentContent := "" - - for stream.Next() { - chunk := stream.Current() - acc.AddChunk(chunk) - - if tool, ok := acc.JustFinishedToolCall(); ok { - toolCalls = append(toolCalls, message.ToolCall{ - ID: tool.Id, - Name: tool.Name, - Input: tool.Arguments, - Type: "function", - }) - } + for { + attempts++ + openaiStream := o.client.Chat.Completions.NewStreaming( + ctx, + params, + ) + + acc := openai.ChatCompletionAccumulator{} + currentContent := "" + toolCalls := make([]message.ToolCall, 0) + + for openaiStream.Next() { + chunk := openaiStream.Current() + acc.AddChunk(chunk) + + if tool, ok := acc.JustFinishedToolCall(); ok { + toolCalls = append(toolCalls, message.ToolCall{ + ID: tool.Id, + Name: tool.Name, + Input: tool.Arguments, + Type: "function", + }) + } - for _, choice := range chunk.Choices { - if choice.Delta.Content != "" { - eventChan <- ProviderEvent{ - Type: EventContentDelta, - Content: choice.Delta.Content, + for _, choice := range chunk.Choices { + if choice.Delta.Content != "" { + eventChan <- ProviderEvent{ + Type: EventContentDelta, + Content: choice.Delta.Content, + } + currentContent += choice.Delta.Content } - currentContent += choice.Delta.Content } } - } - if err := stream.Err(); err != nil { - eventChan <- ProviderEvent{ - Type: EventError, - Error: err, + err := openaiStream.Err() + if err == nil || errors.Is(err, io.EOF) { + // Stream completed successfully + eventChan <- ProviderEvent{ + Type: EventComplete, + Response: &ProviderResponse{ + Content: currentContent, + ToolCalls: toolCalls, + Usage: o.usage(acc.ChatCompletion), + FinishReason: o.finishReason(string(acc.ChatCompletion.Choices[0].FinishReason)), + }, + } + close(eventChan) + return } + + // If there is an error we are going to see if we can retry the call + retry, after, retryErr := o.shouldRetry(attempts, err) + if retryErr != nil { + eventChan <- ProviderEvent{Type: EventError, Error: retryErr} + close(eventChan) + return + } + if retry { + logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100)) + select { + case <-ctx.Done(): + // context cancelled + if ctx.Err() == nil { + eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()} + } + close(eventChan) + return + case <-time.After(time.Duration(after) * time.Millisecond): + continue + } + } + eventChan <- ProviderEvent{Type: EventError, Error: retryErr} + close(eventChan) return } + }() - tokenUsage := p.extractTokenUsage(acc.Usage) + return eventChan +} - eventChan <- ProviderEvent{ - Type: EventComplete, - Response: &ProviderResponse{ - Content: currentContent, - ToolCalls: toolCalls, - Usage: tokenUsage, - }, +func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) { + var apierr *openai.Error + if !errors.As(err, &apierr) { + return false, 0, err + } + + if apierr.StatusCode != 429 && apierr.StatusCode != 500 { + return false, 0, err + } + + if attempts > maxRetries { + return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries) + } + + retryMs := 0 + retryAfterValues := apierr.Response.Header.Values("Retry-After") + + backoffMs := 2000 * (1 << (attempts - 1)) + jitterMs := int(float64(backoffMs) * 0.2) + retryMs = backoffMs + jitterMs + if len(retryAfterValues) > 0 { + if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil { + retryMs = retryMs * 1000 } - }() + } + return true, int64(retryMs), nil +} - return eventChan, nil +func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall { + var toolCalls []message.ToolCall + + if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 { + for _, call := range completion.Choices[0].Message.ToolCalls { + toolCall := message.ToolCall{ + ID: call.ID, + Name: call.Function.Name, + Input: call.Function.Arguments, + Type: "function", + } + toolCalls = append(toolCalls, toolCall) + } + } + + return toolCalls } + +func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage { + cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens + inputTokens := completion.Usage.PromptTokens - cachedTokens + + return TokenUsage{ + InputTokens: inputTokens, + OutputTokens: completion.Usage.CompletionTokens, + CacheCreationTokens: 0, // OpenAI doesn't provide this directly + CacheReadTokens: cachedTokens, + } +} + +func WithOpenAIBaseURL(baseURL string) OpenAIOption { + return func(options *openaiOptions) { + options.baseURL = baseURL + } +} + +func WithOpenAIDisableCache() OpenAIOption { + return func(options *openaiOptions) { + options.disableCache = true + } +} + diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index 34d91f2b7..1a5b3dc8a 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -2,14 +2,17 @@ package provider import ( "context" + "fmt" + "github.com/kujtimiihoxha/termai/internal/llm/models" "github.com/kujtimiihoxha/termai/internal/llm/tools" "github.com/kujtimiihoxha/termai/internal/message" ) -// EventType represents the type of streaming event type EventType string +const maxRetries = 8 + const ( EventContentStart EventType = "content_start" EventContentDelta EventType = "content_delta" @@ -18,7 +21,6 @@ const ( EventComplete EventType = "complete" EventError EventType = "error" EventWarning EventType = "warning" - EventInfo EventType = "info" ) type TokenUsage struct { @@ -32,61 +34,152 @@ type ProviderResponse struct { Content string ToolCalls []message.ToolCall Usage TokenUsage - FinishReason string + FinishReason message.FinishReason } type ProviderEvent struct { - Type EventType + Type EventType + Content string Thinking string - ToolCall *message.ToolCall - Error error Response *ProviderResponse - // Used for giving users info on e.x retry - Info string + Error error } - type Provider interface { SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) - StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) + StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent + + Model() models.Model +} + +type providerClientOptions struct { + apiKey string + model models.Model + maxTokens int64 + systemMessage string + + anthropicOptions []AnthropicOption + openaiOptions []OpenAIOption + geminiOptions []GeminiOption + bedrockOptions []BedrockOption +} + +type ProviderClientOption func(*providerClientOptions) + +type ProviderClient interface { + send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) + stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent +} + +type baseProvider[C ProviderClient] struct { + options providerClientOptions + client C +} + +func NewProvider(providerName models.ModelProvider, opts ...ProviderClientOption) (Provider, error) { + clientOptions := providerClientOptions{} + for _, o := range opts { + o(&clientOptions) + } + switch providerName { + case models.ProviderAnthropic: + return &baseProvider[AnthropicClient]{ + options: clientOptions, + client: newAnthropicClient(clientOptions), + }, nil + case models.ProviderOpenAI: + return &baseProvider[OpenAIClient]{ + options: clientOptions, + client: newOpenAIClient(clientOptions), + }, nil + case models.ProviderGemini: + return &baseProvider[GeminiClient]{ + options: clientOptions, + client: newGeminiClient(clientOptions), + }, nil + case models.ProviderBedrock: + return &baseProvider[BedrockClient]{ + options: clientOptions, + client: newBedrockClient(clientOptions), + }, nil + case models.ProviderMock: + // TODO: implement mock client for test + panic("not implemented") + } + return nil, fmt.Errorf("provider not supported: %s", providerName) } -func cleanupMessages(messages []message.Message) []message.Message { - // First pass: filter out canceled messages - var cleanedMessages []message.Message +func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) { for _, msg := range messages { - if msg.FinishReason() != "canceled" || len(msg.ToolCalls()) > 0 { - // if there are toolCalls this means we want to return it to the LLM telling it that those tools have been - // cancelled - cleanedMessages = append(cleanedMessages, msg) + // The message has no content + if len(msg.Parts) == 0 { + continue } + cleaned = append(cleaned, msg) } + return +} - // Second pass: filter out tool messages without a corresponding tool call - var result []message.Message - toolMessageIDs := make(map[string]bool) +func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { + messages = p.cleanMessages(messages) + return p.client.send(ctx, messages, tools) +} - for _, msg := range cleanedMessages { - if msg.Role == message.Assistant { - for _, toolCall := range msg.ToolCalls() { - toolMessageIDs[toolCall.ID] = true // Mark as referenced - } - } +func (p *baseProvider[C]) Model() models.Model { + return p.options.model +} + +func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { + messages = p.cleanMessages(messages) + return p.client.stream(ctx, messages, tools) +} + +func WithAPIKey(apiKey string) ProviderClientOption { + return func(options *providerClientOptions) { + options.apiKey = apiKey } +} - // Keep only messages that aren't unreferenced tool messages - for _, msg := range cleanedMessages { - if msg.Role == message.Tool { - for _, toolCall := range msg.ToolResults() { - if referenced, exists := toolMessageIDs[toolCall.ToolCallID]; exists && referenced { - result = append(result, msg) - } - } - } else { - result = append(result, msg) - } +func WithModel(model models.Model) ProviderClientOption { + return func(options *providerClientOptions) { + options.model = model + } +} + +func WithMaxTokens(maxTokens int64) ProviderClientOption { + return func(options *providerClientOptions) { + options.maxTokens = maxTokens + } +} + +func WithSystemMessage(systemMessage string) ProviderClientOption { + return func(options *providerClientOptions) { + options.systemMessage = systemMessage + } +} + +func WithAnthropicOptions(anthropicOptions ...AnthropicOption) ProviderClientOption { + return func(options *providerClientOptions) { + options.anthropicOptions = anthropicOptions + } +} + +func WithOpenAIOptions(openaiOptions ...OpenAIOption) ProviderClientOption { + return func(options *providerClientOptions) { + options.openaiOptions = openaiOptions + } +} + +func WithGeminiOptions(geminiOptions ...GeminiOption) ProviderClientOption { + return func(options *providerClientOptions) { + options.geminiOptions = geminiOptions + } +} + +func WithBedrockOptions(bedrockOptions ...BedrockOption) ProviderClientOption { + return func(options *providerClientOptions) { + options.bedrockOptions = bedrockOptions } - return result } diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index 0cea20878..c7c970e5a 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -23,7 +23,8 @@ type BashPermissionsParams struct { } type BashResponseMetadata struct { - Took int64 `json:"took"` + StartTime int64 `json:"start_time"` + EndTime int64 `json:"end_time"` } type bashTool struct { permissions permission.Service @@ -282,7 +283,6 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) if err != nil { return ToolResponse{}, fmt.Errorf("error executing command: %w", err) } - took := time.Since(startTime).Milliseconds() stdout = truncateOutput(stdout) stderr = truncateOutput(stderr) @@ -311,7 +311,8 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) } metadata := BashResponseMetadata{ - Took: took, + StartTime: startTime.UnixMilli(), + EndTime: time.Now().UnixMilli(), } if stdout == "" { return WithResponseMetadata(NewTextResponse("no output"), metadata), nil diff --git a/internal/llm/tools/bash_test.go b/internal/llm/tools/bash_test.go index 97be3683a..dafb0ccc5 100644 --- a/internal/llm/tools/bash_test.go +++ b/internal/llm/tools/bash_test.go @@ -8,8 +8,6 @@ import ( "testing" "time" - "github.com/kujtimiihoxha/termai/internal/permission" - "github.com/kujtimiihoxha/termai/internal/pubsub" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -340,32 +338,3 @@ func TestCountLines(t *testing.T) { }) } } - -// Mock permission service for testing -type mockPermissionService struct { - *pubsub.Broker[permission.PermissionRequest] - allow bool -} - -func (m *mockPermissionService) GrantPersistant(permission permission.PermissionRequest) { - // Not needed for tests -} - -func (m *mockPermissionService) Grant(permission permission.PermissionRequest) { - // Not needed for tests -} - -func (m *mockPermissionService) Deny(permission permission.PermissionRequest) { - // Not needed for tests -} - -func (m *mockPermissionService) Request(opts permission.CreatePermissionRequest) bool { - return m.allow -} - -func newMockPermissionService(allow bool) permission.Service { - return &mockPermissionService{ - Broker: pubsub.NewBroker[permission.PermissionRequest](), - allow: allow, - } -} diff --git a/internal/llm/tools/edit.go b/internal/llm/tools/edit.go index 08d6d446c..148e7aba7 100644 --- a/internal/llm/tools/edit.go +++ b/internal/llm/tools/edit.go @@ -11,6 +11,7 @@ import ( "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/diff" + "github.com/kujtimiihoxha/termai/internal/history" "github.com/kujtimiihoxha/termai/internal/lsp" "github.com/kujtimiihoxha/termai/internal/permission" ) @@ -35,6 +36,7 @@ type EditResponseMetadata struct { type editTool struct { lspClients map[string]*lsp.Client permissions permission.Service + files history.Service } const ( @@ -88,10 +90,11 @@ When making edits: Remember: when making multiple file edits in a row to the same file, you should prefer to send all edits in a single message with multiple calls to this tool, rather than multiple messages with a single call each.` ) -func NewEditTool(lspClients map[string]*lsp.Client, permissions permission.Service) BaseTool { +func NewEditTool(lspClients map[string]*lsp.Client, permissions permission.Service, files history.Service) BaseTool { return &editTool{ lspClients: lspClients, permissions: permissions, + files: files, } } @@ -153,6 +156,11 @@ func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) if err != nil { return response, nil } + if response.IsError { + // Return early if there was an error during content replacement + // This prevents unnecessary LSP diagnostics processing + return response, nil + } waitForLspDiagnostics(ctx, params.FilePath, e.lspClients) text := fmt.Sprintf("\n%s\n\n", response.Content) @@ -208,6 +216,20 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string) return ToolResponse{}, fmt.Errorf("failed to write file: %w", err) } + // File can't be in the history so we create a new file history + _, err = e.files.Create(ctx, sessionID, filePath, "") + if err != nil { + // Log error but don't fail the operation + return ToolResponse{}, fmt.Errorf("error creating file history: %w", err) + } + + // Add the new content to the file history + _, err = e.files.CreateVersion(ctx, sessionID, filePath, content) + if err != nil { + // Log error but don't fail the operation + fmt.Printf("Error creating file history version: %v\n", err) + } + recordFileWrite(filePath) recordFileRead(filePath) @@ -298,6 +320,29 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string if err != nil { return ToolResponse{}, fmt.Errorf("failed to write file: %w", err) } + + // Check if file exists in history + file, err := e.files.GetByPathAndSession(ctx, filePath, sessionID) + if err != nil { + _, err = e.files.Create(ctx, sessionID, filePath, oldContent) + if err != nil { + // Log error but don't fail the operation + return ToolResponse{}, fmt.Errorf("error creating file history: %w", err) + } + } + if file.Content != oldContent { + // User Manually changed the content store an intermediate version + _, err = e.files.CreateVersion(ctx, sessionID, filePath, oldContent) + if err != nil { + fmt.Printf("Error creating file history version: %v\n", err) + } + } + // Store the new version + _, err = e.files.CreateVersion(ctx, sessionID, filePath, "") + if err != nil { + fmt.Printf("Error creating file history version: %v\n", err) + } + recordFileWrite(filePath) recordFileRead(filePath) @@ -356,6 +401,9 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS newContent := oldContent[:index] + newString + oldContent[index+len(oldString):] + if oldContent == newContent { + return NewTextErrorResponse("new content is the same as old content. No changes made."), nil + } sessionID, messageID := GetContextValues(ctx) if sessionID == "" || messageID == "" { @@ -374,8 +422,7 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS Description: fmt.Sprintf("Replace content in file %s", filePath), Params: EditPermissionsParams{ FilePath: filePath, - - Diff: diff, + Diff: diff, }, }, ) @@ -388,6 +435,28 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS return ToolResponse{}, fmt.Errorf("failed to write file: %w", err) } + // Check if file exists in history + file, err := e.files.GetByPathAndSession(ctx, filePath, sessionID) + if err != nil { + _, err = e.files.Create(ctx, sessionID, filePath, oldContent) + if err != nil { + // Log error but don't fail the operation + return ToolResponse{}, fmt.Errorf("error creating file history: %w", err) + } + } + if file.Content != oldContent { + // User Manually changed the content store an intermediate version + _, err = e.files.CreateVersion(ctx, sessionID, filePath, oldContent) + if err != nil { + fmt.Printf("Error creating file history version: %v\n", err) + } + } + // Store the new version + _, err = e.files.CreateVersion(ctx, sessionID, filePath, newContent) + if err != nil { + fmt.Printf("Error creating file history version: %v\n", err) + } + recordFileWrite(filePath) recordFileRead(filePath) diff --git a/internal/llm/tools/edit_test.go b/internal/llm/tools/edit_test.go index 48a34ed75..0971775dd 100644 --- a/internal/llm/tools/edit_test.go +++ b/internal/llm/tools/edit_test.go @@ -14,7 +14,7 @@ import ( ) func TestEditTool_Info(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) info := tool.Info() assert.Equal(t, EditToolName, info.Name) @@ -34,7 +34,7 @@ func TestEditTool_Run(t *testing.T) { defer os.RemoveAll(tempDir) t.Run("creates a new file successfully", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) filePath := filepath.Join(tempDir, "new_file.txt") content := "This is a test content" @@ -64,7 +64,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("creates file with nested directories", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) filePath := filepath.Join(tempDir, "nested/dirs/new_file.txt") content := "Content in nested directory" @@ -94,7 +94,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("fails to create file that already exists", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file first filePath := filepath.Join(tempDir, "existing_file.txt") @@ -123,7 +123,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("fails to create file when path is a directory", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a directory dirPath := filepath.Join(tempDir, "test_dir") @@ -151,7 +151,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("replaces content successfully", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file first filePath := filepath.Join(tempDir, "replace_content.txt") @@ -191,7 +191,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("deletes content successfully", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file first filePath := filepath.Join(tempDir, "delete_content.txt") @@ -230,7 +230,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles invalid parameters", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) call := ToolCall{ Name: EditToolName, @@ -243,7 +243,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles missing file_path", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) params := EditParams{ FilePath: "", @@ -265,7 +265,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles file not found", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) filePath := filepath.Join(tempDir, "non_existent_file.txt") params := EditParams{ @@ -288,7 +288,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles old_string not found in file", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file first filePath := filepath.Join(tempDir, "content_not_found.txt") @@ -320,7 +320,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles multiple occurrences of old_string", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file with duplicate content filePath := filepath.Join(tempDir, "duplicate_content.txt") @@ -352,7 +352,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles file modified since last read", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file filePath := filepath.Join(tempDir, "modified_file.txt") @@ -394,7 +394,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles file not read before editing", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file filePath := filepath.Join(tempDir, "not_read_file.txt") @@ -423,7 +423,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles permission denied", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(false)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(false), newMockFileHistoryService()) // Create a file filePath := filepath.Join(tempDir, "permission_denied.txt") diff --git a/internal/llm/tools/file.go b/internal/llm/tools/file.go index 9c9707c9c..7f34fdc1f 100644 --- a/internal/llm/tools/file.go +++ b/internal/llm/tools/file.go @@ -3,8 +3,6 @@ package tools import ( "sync" "time" - - "github.com/kujtimiihoxha/termai/internal/config" ) // File record to track when files were read/written @@ -19,14 +17,6 @@ var ( fileRecordMutex sync.RWMutex ) -func removeWorkingDirectoryPrefix(path string) string { - wd := config.WorkingDirectory() - if len(path) > len(wd) && path[:len(wd)] == wd { - return path[len(wd)+1:] - } - return path -} - func recordFileRead(path string) { fileRecordMutex.Lock() defer fileRecordMutex.Unlock() diff --git a/internal/llm/tools/glob.go b/internal/llm/tools/glob.go index bdfc23b4a..7b4fb1187 100644 --- a/internal/llm/tools/glob.go +++ b/internal/llm/tools/glob.go @@ -63,7 +63,7 @@ type GlobParams struct { Path string `json:"path"` } -type GlobMetadata struct { +type GlobResponseMetadata struct { NumberOfFiles int `json:"number_of_files"` Truncated bool `json:"truncated"` } @@ -124,7 +124,7 @@ func (g *globTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) return WithResponseMetadata( NewTextResponse(output), - GlobMetadata{ + GlobResponseMetadata{ NumberOfFiles: len(files), Truncated: truncated, }, diff --git a/internal/llm/tools/grep.go b/internal/llm/tools/grep.go index 7e52821d0..19333f50b 100644 --- a/internal/llm/tools/grep.go +++ b/internal/llm/tools/grep.go @@ -27,7 +27,7 @@ type grepMatch struct { modTime time.Time } -type GrepMetadata struct { +type GrepResponseMetadata struct { NumberOfMatches int `json:"number_of_matches"` Truncated bool `json:"truncated"` } @@ -134,7 +134,7 @@ func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) return WithResponseMetadata( NewTextResponse(output), - GrepMetadata{ + GrepResponseMetadata{ NumberOfMatches: len(matches), Truncated: truncated, }, diff --git a/internal/llm/tools/ls.go b/internal/llm/tools/ls.go index a679f261b..a63bf0eeb 100644 --- a/internal/llm/tools/ls.go +++ b/internal/llm/tools/ls.go @@ -23,7 +23,7 @@ type TreeNode struct { Children []*TreeNode `json:"children,omitempty"` } -type LSMetadata struct { +type LSResponseMetadata struct { NumberOfFiles int `json:"number_of_files"` Truncated bool `json:"truncated"` } @@ -121,7 +121,7 @@ func (l *lsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { return WithResponseMetadata( NewTextResponse(output), - LSMetadata{ + LSResponseMetadata{ NumberOfFiles: len(files), Truncated: truncated, }, diff --git a/internal/llm/tools/mocks_test.go b/internal/llm/tools/mocks_test.go new file mode 100644 index 000000000..321f09ac1 --- /dev/null +++ b/internal/llm/tools/mocks_test.go @@ -0,0 +1,246 @@ +package tools + +import ( + "context" + "fmt" + "sort" + "strconv" + "strings" + "time" + + "github.com/google/uuid" + "github.com/kujtimiihoxha/termai/internal/history" + "github.com/kujtimiihoxha/termai/internal/permission" + "github.com/kujtimiihoxha/termai/internal/pubsub" +) + +// Mock permission service for testing +type mockPermissionService struct { + *pubsub.Broker[permission.PermissionRequest] + allow bool +} + +func (m *mockPermissionService) GrantPersistant(permission permission.PermissionRequest) { + // Not needed for tests +} + +func (m *mockPermissionService) Grant(permission permission.PermissionRequest) { + // Not needed for tests +} + +func (m *mockPermissionService) Deny(permission permission.PermissionRequest) { + // Not needed for tests +} + +func (m *mockPermissionService) Request(opts permission.CreatePermissionRequest) bool { + return m.allow +} + +func newMockPermissionService(allow bool) permission.Service { + return &mockPermissionService{ + Broker: pubsub.NewBroker[permission.PermissionRequest](), + allow: allow, + } +} + +type mockFileHistoryService struct { + *pubsub.Broker[history.File] + files map[string]history.File // ID -> File + timeNow func() int64 +} + +// Create implements history.Service. +func (m *mockFileHistoryService) Create(ctx context.Context, sessionID string, path string, content string) (history.File, error) { + return m.createWithVersion(ctx, sessionID, path, content, history.InitialVersion) +} + +// CreateVersion implements history.Service. +func (m *mockFileHistoryService) CreateVersion(ctx context.Context, sessionID string, path string, content string) (history.File, error) { + var files []history.File + for _, file := range m.files { + if file.Path == path { + files = append(files, file) + } + } + + if len(files) == 0 { + // No previous versions, create initial + return m.Create(ctx, sessionID, path, content) + } + + // Sort files by CreatedAt in descending order + sort.Slice(files, func(i, j int) bool { + return files[i].CreatedAt > files[j].CreatedAt + }) + + // Get the latest version + latestFile := files[0] + latestVersion := latestFile.Version + + // Generate the next version + var nextVersion string + if latestVersion == history.InitialVersion { + nextVersion = "v1" + } else if strings.HasPrefix(latestVersion, "v") { + versionNum, err := strconv.Atoi(latestVersion[1:]) + if err != nil { + // If we can't parse the version, just use a timestamp-based version + nextVersion = fmt.Sprintf("v%d", latestFile.CreatedAt) + } else { + nextVersion = fmt.Sprintf("v%d", versionNum+1) + } + } else { + // If the version format is unexpected, use a timestamp-based version + nextVersion = fmt.Sprintf("v%d", latestFile.CreatedAt) + } + + return m.createWithVersion(ctx, sessionID, path, content, nextVersion) +} + +func (m *mockFileHistoryService) createWithVersion(_ context.Context, sessionID, path, content, version string) (history.File, error) { + now := m.timeNow() + file := history.File{ + ID: uuid.New().String(), + SessionID: sessionID, + Path: path, + Content: content, + Version: version, + CreatedAt: now, + UpdatedAt: now, + } + + m.files[file.ID] = file + m.Publish(pubsub.CreatedEvent, file) + return file, nil +} + +// Delete implements history.Service. +func (m *mockFileHistoryService) Delete(ctx context.Context, id string) error { + file, ok := m.files[id] + if !ok { + return fmt.Errorf("file not found: %s", id) + } + + delete(m.files, id) + m.Publish(pubsub.DeletedEvent, file) + return nil +} + +// DeleteSessionFiles implements history.Service. +func (m *mockFileHistoryService) DeleteSessionFiles(ctx context.Context, sessionID string) error { + files, err := m.ListBySession(ctx, sessionID) + if err != nil { + return err + } + + for _, file := range files { + err = m.Delete(ctx, file.ID) + if err != nil { + return err + } + } + + return nil +} + +// Get implements history.Service. +func (m *mockFileHistoryService) Get(ctx context.Context, id string) (history.File, error) { + file, ok := m.files[id] + if !ok { + return history.File{}, fmt.Errorf("file not found: %s", id) + } + return file, nil +} + +// GetByPathAndSession implements history.Service. +func (m *mockFileHistoryService) GetByPathAndSession(ctx context.Context, path string, sessionID string) (history.File, error) { + var latestFile history.File + var found bool + var latestTime int64 + + for _, file := range m.files { + if file.Path == path && file.SessionID == sessionID { + if !found || file.CreatedAt > latestTime { + latestFile = file + latestTime = file.CreatedAt + found = true + } + } + } + + if !found { + return history.File{}, fmt.Errorf("file not found: %s for session %s", path, sessionID) + } + return latestFile, nil +} + +// ListBySession implements history.Service. +func (m *mockFileHistoryService) ListBySession(ctx context.Context, sessionID string) ([]history.File, error) { + var files []history.File + for _, file := range m.files { + if file.SessionID == sessionID { + files = append(files, file) + } + } + + // Sort by CreatedAt in descending order + sort.Slice(files, func(i, j int) bool { + return files[i].CreatedAt > files[j].CreatedAt + }) + + return files, nil +} + +// ListLatestSessionFiles implements history.Service. +func (m *mockFileHistoryService) ListLatestSessionFiles(ctx context.Context, sessionID string) ([]history.File, error) { + // Map to track the latest file for each path + latestFiles := make(map[string]history.File) + + for _, file := range m.files { + if file.SessionID == sessionID { + existing, ok := latestFiles[file.Path] + if !ok || file.CreatedAt > existing.CreatedAt { + latestFiles[file.Path] = file + } + } + } + + // Convert map to slice + var result []history.File + for _, file := range latestFiles { + result = append(result, file) + } + + // Sort by CreatedAt in descending order + sort.Slice(result, func(i, j int) bool { + return result[i].CreatedAt > result[j].CreatedAt + }) + + return result, nil +} + +// Subscribe implements history.Service. +func (m *mockFileHistoryService) Subscribe(ctx context.Context) <-chan pubsub.Event[history.File] { + return m.Broker.Subscribe(ctx) +} + +// Update implements history.Service. +func (m *mockFileHistoryService) Update(ctx context.Context, file history.File) (history.File, error) { + _, ok := m.files[file.ID] + if !ok { + return history.File{}, fmt.Errorf("file not found: %s", file.ID) + } + + file.UpdatedAt = m.timeNow() + m.files[file.ID] = file + m.Publish(pubsub.UpdatedEvent, file) + return file, nil +} + +func newMockFileHistoryService() history.Service { + return &mockFileHistoryService{ + Broker: pubsub.NewBroker[history.File](), + files: make(map[string]history.File), + timeNow: func() int64 { return time.Now().Unix() }, + } +} diff --git a/internal/llm/tools/shell/shell.go b/internal/llm/tools/shell/shell.go index 64592f67d..4a776478a 100644 --- a/internal/llm/tools/shell/shell.go +++ b/internal/llm/tools/shell/shell.go @@ -83,11 +83,21 @@ func newPersistentShell(cwd string) *PersistentShell { commandQueue: make(chan *commandExecution, 10), } - go shell.processCommands() + go func() { + defer func() { + if r := recover(); r != nil { + fmt.Fprintf(os.Stderr, "Panic in shell command processor: %v\n", r) + shell.isAlive = false + close(shell.commandQueue) + } + }() + shell.processCommands() + }() go func() { err := cmd.Wait() if err != nil { + // Log the error if needed } shell.isAlive = false close(shell.commandQueue) diff --git a/internal/llm/tools/sourcegraph.go b/internal/llm/tools/sourcegraph.go index 17bc610ea..a6f2c8afb 100644 --- a/internal/llm/tools/sourcegraph.go +++ b/internal/llm/tools/sourcegraph.go @@ -18,7 +18,7 @@ type SourcegraphParams struct { Timeout int `json:"timeout,omitempty"` } -type SourcegraphMetadata struct { +type SourcegraphResponseMetadata struct { NumberOfMatches int `json:"number_of_matches"` Truncated bool `json:"truncated"` } diff --git a/internal/llm/tools/tools.go b/internal/llm/tools/tools.go index 07afe1363..bf0f8df0b 100644 --- a/internal/llm/tools/tools.go +++ b/internal/llm/tools/tools.go @@ -14,12 +14,17 @@ type ToolInfo struct { type toolResponseType string +type ( + sessionIDContextKey string + messageIDContextKey string +) + const ( ToolResponseTypeText toolResponseType = "text" ToolResponseTypeImage toolResponseType = "image" - SessionIDContextKey = "session_id" - MessageIDContextKey = "message_id" + SessionIDContextKey sessionIDContextKey = "session_id" + MessageIDContextKey messageIDContextKey = "message_id" ) type ToolResponse struct { diff --git a/internal/llm/tools/write.go b/internal/llm/tools/write.go index 889561d2a..bb49381fd 100644 --- a/internal/llm/tools/write.go +++ b/internal/llm/tools/write.go @@ -10,6 +10,7 @@ import ( "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/diff" + "github.com/kujtimiihoxha/termai/internal/history" "github.com/kujtimiihoxha/termai/internal/lsp" "github.com/kujtimiihoxha/termai/internal/permission" ) @@ -27,6 +28,7 @@ type WritePermissionsParams struct { type writeTool struct { lspClients map[string]*lsp.Client permissions permission.Service + files history.Service } type WriteResponseMetadata struct { @@ -67,10 +69,11 @@ TIPS: - Always include descriptive comments when making changes to existing code` ) -func NewWriteTool(lspClients map[string]*lsp.Client, permissions permission.Service) BaseTool { +func NewWriteTool(lspClients map[string]*lsp.Client, permissions permission.Service, files history.Service) BaseTool { return &writeTool{ lspClients: lspClients, permissions: permissions, + files: files, } } @@ -176,6 +179,28 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error return ToolResponse{}, fmt.Errorf("error writing file: %w", err) } + // Check if file exists in history + file, err := w.files.GetByPathAndSession(ctx, filePath, sessionID) + if err != nil { + _, err = w.files.Create(ctx, sessionID, filePath, oldContent) + if err != nil { + // Log error but don't fail the operation + return ToolResponse{}, fmt.Errorf("error creating file history: %w", err) + } + } + if file.Content != oldContent { + // User Manually changed the content store an intermediate version + _, err = w.files.CreateVersion(ctx, sessionID, filePath, oldContent) + if err != nil { + fmt.Printf("Error creating file history version: %v\n", err) + } + } + // Store the new version + _, err = w.files.CreateVersion(ctx, sessionID, filePath, params.Content) + if err != nil { + fmt.Printf("Error creating file history version: %v\n", err) + } + recordFileWrite(filePath) recordFileRead(filePath) waitForLspDiagnostics(ctx, filePath, w.lspClients) diff --git a/internal/llm/tools/write_test.go b/internal/llm/tools/write_test.go index 50dafc14f..2264f36fb 100644 --- a/internal/llm/tools/write_test.go +++ b/internal/llm/tools/write_test.go @@ -14,7 +14,7 @@ import ( ) func TestWriteTool_Info(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) info := tool.Info() assert.Equal(t, WriteToolName, info.Name) @@ -32,7 +32,7 @@ func TestWriteTool_Run(t *testing.T) { defer os.RemoveAll(tempDir) t.Run("creates a new file successfully", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) filePath := filepath.Join(tempDir, "new_file.txt") content := "This is a test content" @@ -61,7 +61,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("creates file with nested directories", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) filePath := filepath.Join(tempDir, "nested/dirs/new_file.txt") content := "Content in nested directory" @@ -90,7 +90,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("updates existing file", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file first filePath := filepath.Join(tempDir, "existing_file.txt") @@ -127,7 +127,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("handles invalid parameters", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) call := ToolCall{ Name: WriteToolName, @@ -140,7 +140,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("handles missing file_path", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) params := WriteParams{ FilePath: "", @@ -161,7 +161,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("handles missing content", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) params := WriteParams{ FilePath: filepath.Join(tempDir, "file.txt"), @@ -182,7 +182,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("handles writing to a directory path", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a directory dirPath := filepath.Join(tempDir, "test_dir") @@ -208,7 +208,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("handles permission denied", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(false)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(false), newMockFileHistoryService()) filePath := filepath.Join(tempDir, "permission_denied.txt") params := WriteParams{ @@ -234,7 +234,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("detects file modified since last read", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file filePath := filepath.Join(tempDir, "modified_file.txt") @@ -275,7 +275,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("skips writing when content is identical", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file filePath := filepath.Join(tempDir, "identical_content.txt") diff --git a/internal/logging/logger.go b/internal/logging/logger.go index b06391472..7ae2e7b87 100644 --- a/internal/logging/logger.go +++ b/internal/logging/logger.go @@ -1,6 +1,12 @@ package logging -import "log/slog" +import ( + "fmt" + "log/slog" + "os" + "runtime/debug" + "time" +) func Info(msg string, args ...any) { slog.Info(msg, args...) @@ -37,3 +43,36 @@ func ErrorPersist(msg string, args ...any) { args = append(args, persistKeyArg, true) slog.Error(msg, args...) } + +// RecoverPanic is a common function to handle panics gracefully. +// It logs the error, creates a panic log file with stack trace, +// and executes an optional cleanup function before returning. +func RecoverPanic(name string, cleanup func()) { + if r := recover(); r != nil { + // Log the panic + ErrorPersist(fmt.Sprintf("Panic in %s: %v", name, r)) + + // Create a timestamped panic log file + timestamp := time.Now().Format("20060102-150405") + filename := fmt.Sprintf("opencode-panic-%s-%s.log", name, timestamp) + + file, err := os.Create(filename) + if err != nil { + ErrorPersist(fmt.Sprintf("Failed to create panic log: %v", err)) + } else { + defer file.Close() + + // Write panic information and stack trace + fmt.Fprintf(file, "Panic in %s: %v\n\n", name, r) + fmt.Fprintf(file, "Time: %s\n\n", time.Now().Format(time.RFC3339)) + fmt.Fprintf(file, "Stack Trace:\n%s\n", debug.Stack()) + + InfoPersist(fmt.Sprintf("Panic details written to %s", filename)) + } + + // Execute cleanup function if provided + if cleanup != nil { + cleanup() + } + } +} diff --git a/internal/lsp/client.go b/internal/lsp/client.go index e2eedc4fc..0f03e7fcb 100644 --- a/internal/lsp/client.go +++ b/internal/lsp/client.go @@ -97,7 +97,12 @@ func NewClient(ctx context.Context, command string, args ...string) (*Client, er }() // Start message handling loop - go client.handleMessages() + go func() { + defer logging.RecoverPanic("LSP-message-handler", func() { + logging.ErrorPersist("LSP message handler crashed, LSP functionality may be impaired") + }) + client.handleMessages() + }() return client, nil } @@ -374,7 +379,7 @@ func (c *Client) CloseFile(ctx context.Context, filepath string) error { }, } - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Closing file", "file", filepath) } if err := c.Notify(ctx, "textDocument/didClose", params); err != nil { @@ -413,12 +418,12 @@ func (c *Client) CloseAllFiles(ctx context.Context) { // Then close them all for _, filePath := range filesToClose { err := c.CloseFile(ctx, filePath) - if err != nil && cnf.Debug { + if err != nil && cnf.DebugLSP { logging.Warn("Error closing file", "file", filePath, "error", err) } } - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Closed all files", "files", filesToClose) } } diff --git a/internal/lsp/handlers.go b/internal/lsp/handlers.go index 4913c743d..c3088d685 100644 --- a/internal/lsp/handlers.go +++ b/internal/lsp/handlers.go @@ -88,7 +88,7 @@ func HandleServerMessage(params json.RawMessage) { Message string `json:"message"` } if err := json.Unmarshal(params, &msg); err == nil { - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Server message", "type", msg.Type, "message", msg.Message) } } diff --git a/internal/lsp/transport.go b/internal/lsp/transport.go index 4185966f3..89255fd78 100644 --- a/internal/lsp/transport.go +++ b/internal/lsp/transport.go @@ -20,7 +20,7 @@ func WriteMessage(w io.Writer, msg *Message) error { } cnf := config.Get() - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Sending message to server", "method", msg.Method, "id", msg.ID) } @@ -49,7 +49,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) { } line = strings.TrimSpace(line) - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Received header", "line", line) } @@ -65,7 +65,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) { } } - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Content-Length", "length", contentLength) } @@ -76,7 +76,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) { return nil, fmt.Errorf("failed to read content: %w", err) } - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Received content", "content", string(content)) } @@ -95,7 +95,7 @@ func (c *Client) handleMessages() { for { msg, err := ReadMessage(c.stdout) if err != nil { - if cnf.Debug { + if cnf.DebugLSP { logging.Error("Error reading message", "error", err) } return @@ -103,7 +103,7 @@ func (c *Client) handleMessages() { // Handle server->client request (has both Method and ID) if msg.Method != "" && msg.ID != 0 { - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Received request from server", "method", msg.Method, "id", msg.ID) } @@ -157,11 +157,11 @@ func (c *Client) handleMessages() { c.notificationMu.RUnlock() if ok { - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Handling notification", "method", msg.Method) } go handler(msg.Params) - } else if cnf.Debug { + } else if cnf.DebugLSP { logging.Debug("No handler for notification", "method", msg.Method) } continue @@ -174,12 +174,12 @@ func (c *Client) handleMessages() { c.handlersMu.RUnlock() if ok { - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Received response for request", "id", msg.ID) } ch <- msg close(ch) - } else if cnf.Debug { + } else if cnf.DebugLSP { logging.Debug("No handler for response", "id", msg.ID) } } @@ -191,7 +191,7 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any cnf := config.Get() id := c.nextID.Add(1) - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Making call", "method", method, "id", id) } @@ -217,14 +217,14 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any return fmt.Errorf("failed to send request: %w", err) } - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Request sent", "method", method, "id", id) } // Wait for response resp := <-ch - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Received response", "id", id) } @@ -250,7 +250,7 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any // Notify sends a notification (a request without an ID that doesn't expect a response) func (c *Client) Notify(ctx context.Context, method string, params any) error { cnf := config.Get() - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Sending notification", "method", method) } diff --git a/internal/lsp/watcher/watcher.go b/internal/lsp/watcher/watcher.go index b5ef15710..156f38e1a 100644 --- a/internal/lsp/watcher/watcher.go +++ b/internal/lsp/watcher/watcher.go @@ -50,7 +50,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc w.registrations = append(w.registrations, watchers...) // Print detailed registration information for debugging - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Adding file watcher registrations", "id", id, "watchers", len(watchers), @@ -116,7 +116,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc // Skip directories that should be excluded if d.IsDir() { if path != w.workspacePath && shouldExcludeDir(path) { - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Skipping excluded directory", "path", path) } return filepath.SkipDir @@ -136,7 +136,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc }) elapsedTime := time.Since(startTime) - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Workspace scan complete", "filesOpened", filesOpened, "elapsedTime", elapsedTime.Seconds(), @@ -144,7 +144,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc ) } - if err != nil && cnf.Debug { + if err != nil && cnf.DebugLSP { logging.Debug("Error scanning workspace for files to open", "error", err) } }() @@ -175,7 +175,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str // Skip excluded directories (except workspace root) if d.IsDir() && path != workspacePath { if shouldExcludeDir(path) { - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Skipping excluded directory", "path", path) } return filepath.SkipDir @@ -228,7 +228,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str } // Debug logging - if cnf.Debug { + if cnf.DebugLSP { matched, kind := w.isPathWatched(event.Name) logging.Debug("File event", "path", event.Name, @@ -491,7 +491,7 @@ func (w *WorkspaceWatcher) handleFileEvent(ctx context.Context, uri string, chan // notifyFileEvent sends a didChangeWatchedFiles notification for a file event func (w *WorkspaceWatcher) notifyFileEvent(ctx context.Context, uri string, changeType protocol.FileChangeType) error { cnf := config.Get() - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Notifying file event", "uri", uri, "changeType", changeType, @@ -615,7 +615,7 @@ func shouldExcludeFile(filePath string) bool { // Skip large files if info.Size() > maxFileSize { - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Skipping large file", "path", filePath, "size", info.Size(), @@ -648,7 +648,7 @@ func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) { // Check if this path should be watched according to server registrations if watched, _ := w.isPathWatched(path); watched { // Don't need to check if it's already open - the client.OpenFile handles that - if err := w.client.OpenFile(ctx, path); err != nil && cnf.Debug { + if err := w.client.OpenFile(ctx, path); err != nil && cnf.DebugLSP { logging.Error("Error opening file", "path", path, "error", err) } } diff --git a/internal/message/content.go b/internal/message/content.go index 422c04f52..f9e76b11c 100644 --- a/internal/message/content.go +++ b/internal/message/content.go @@ -2,6 +2,7 @@ package message import ( "encoding/base64" + "slices" "time" "github.com/kujtimiihoxha/termai/internal/llm/models" @@ -16,6 +17,20 @@ const ( Tool MessageRole = "tool" ) +type FinishReason string + +const ( + FinishReasonEndTurn FinishReason = "end_turn" + FinishReasonMaxTokens FinishReason = "max_tokens" + FinishReasonToolUse FinishReason = "tool_use" + FinishReasonCanceled FinishReason = "canceled" + FinishReasonError FinishReason = "error" + FinishReasonPermissionDenied FinishReason = "permission_denied" + + // Should never happen + FinishReasonUnknown FinishReason = "unknown" +) + type ContentPart interface { isPart() } @@ -83,8 +98,8 @@ type ToolResult struct { func (ToolResult) isPart() {} type Finish struct { - Reason string `json:"reason"` - Time int64 `json:"time"` + Reason FinishReason `json:"reason"` + Time int64 `json:"time"` } func (Finish) isPart() {} @@ -176,7 +191,7 @@ func (m *Message) FinishPart() *Finish { return nil } -func (m *Message) FinishReason() string { +func (m *Message) FinishReason() FinishReason { for _, part := range m.Parts { if c, ok := part.(Finish); ok { return c.Reason @@ -246,7 +261,14 @@ func (m *Message) SetToolResults(tr []ToolResult) { } } -func (m *Message) AddFinish(reason string) { +func (m *Message) AddFinish(reason FinishReason) { + // remove any existing finish part + for i, part := range m.Parts { + if _, ok := part.(Finish); ok { + m.Parts = slices.Delete(m.Parts, i, i+1) + break + } + } m.Parts = append(m.Parts, Finish{Reason: reason, Time: time.Now().Unix()}) } diff --git a/internal/pubsub/broker.go b/internal/pubsub/broker.go index 633a6d57f..3e70ae095 100644 --- a/internal/pubsub/broker.go +++ b/internal/pubsub/broker.go @@ -5,7 +5,7 @@ import ( "sync" ) -const bufferSize = 1024 * 1024 +const bufferSize = 1024 type Logger interface { Debug(msg string, args ...any) diff --git a/internal/session/session.go b/internal/session/session.go index 9a16224c3..019019df4 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -24,6 +24,7 @@ type Session struct { type Service interface { pubsub.Suscriber[Session] Create(ctx context.Context, title string) (Session, error) + CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) Get(ctx context.Context, id string) (Session, error) List(ctx context.Context) ([]Session, error) @@ -63,6 +64,20 @@ func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessi return session, nil } +func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) { + dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{ + ID: "title-" + parentSessionID, + ParentSessionID: sql.NullString{String: parentSessionID, Valid: true}, + Title: "Generate a title", + }) + if err != nil { + return Session{}, err + } + session := s.fromDBItem(dbSession) + s.Publish(pubsub.CreatedEvent, session) + return session, nil +} + func (s *service) Delete(ctx context.Context, id string) error { session, err := s.Get(ctx, id) if err != nil { diff --git a/internal/tui/components/chat/chat.go b/internal/tui/components/chat/chat.go index e893ec2f5..e98001efa 100644 --- a/internal/tui/components/chat/chat.go +++ b/internal/tui/components/chat/chat.go @@ -19,8 +19,6 @@ type SessionSelectedMsg = session.Session type SessionClearedMsg struct{} -type AgentWorkingMsg bool - type EditorFocusMsg bool func lspsConfigured(width int) string { diff --git a/internal/tui/components/chat/editor.go b/internal/tui/components/chat/editor.go index e87f1ffae..e2f4da9e2 100644 --- a/internal/tui/components/chat/editor.go +++ b/internal/tui/components/chat/editor.go @@ -5,14 +5,17 @@ import ( "github.com/charmbracelet/bubbles/textarea" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" + "github.com/kujtimiihoxha/termai/internal/app" + "github.com/kujtimiihoxha/termai/internal/session" "github.com/kujtimiihoxha/termai/internal/tui/layout" "github.com/kujtimiihoxha/termai/internal/tui/styles" "github.com/kujtimiihoxha/termai/internal/tui/util" ) type editorCmp struct { - textarea textarea.Model - agentWorking bool + app *app.App + session session.Session + textarea textarea.Model } type focusedEditorKeyMaps struct { @@ -32,7 +35,7 @@ var focusedKeyMaps = focusedEditorKeyMaps{ ), Blur: key.NewBinding( key.WithKeys("esc"), - key.WithHelp("esc", "blur editor"), + key.WithHelp("esc", "focus messages"), ), } @@ -52,7 +55,7 @@ func (m *editorCmp) Init() tea.Cmd { } func (m *editorCmp) send() tea.Cmd { - if m.agentWorking { + if m.app.CoderAgent.IsSessionBusy(m.session.ID) { return util.ReportWarn("Agent is working, please wait...") } @@ -66,7 +69,6 @@ func (m *editorCmp) send() tea.Cmd { util.CmdHandler(SendMsg{ Text: value, }), - util.CmdHandler(AgentWorkingMsg(true)), util.CmdHandler(EditorFocusMsg(false)), ) } @@ -74,8 +76,11 @@ func (m *editorCmp) send() tea.Cmd { func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { var cmd tea.Cmd switch msg := msg.(type) { - case AgentWorkingMsg: - m.agentWorking = bool(msg) + case SessionSelectedMsg: + if msg.ID != m.session.ID { + m.session = msg + } + return m, nil case tea.KeyMsg: // if the key does not match any binding, return if m.textarea.Focused() && key.Matches(msg, focusedKeyMaps.Send) { @@ -122,7 +127,7 @@ func (m *editorCmp) BindingKeys() []key.Binding { return bindings } -func NewEditorCmp() tea.Model { +func NewEditorCmp(app *app.App) tea.Model { ti := textarea.New() ti.Prompt = " " ti.ShowLineNumbers = false @@ -138,6 +143,7 @@ func NewEditorCmp() tea.Model { ti.CharLimit = -1 ti.Focus() return &editorCmp{ + app: app, textarea: ti, } } diff --git a/internal/tui/components/chat/messages.go b/internal/tui/components/chat/messages.go index dc21fca29..26a98970e 100644 --- a/internal/tui/components/chat/messages.go +++ b/internal/tui/components/chat/messages.go @@ -6,7 +6,9 @@ import ( "fmt" "math" "strings" + "time" + "github.com/charmbracelet/bubbles/key" "github.com/charmbracelet/bubbles/spinner" "github.com/charmbracelet/bubbles/viewport" tea "github.com/charmbracelet/bubbletea" @@ -17,9 +19,11 @@ import ( "github.com/kujtimiihoxha/termai/internal/llm/agent" "github.com/kujtimiihoxha/termai/internal/llm/models" "github.com/kujtimiihoxha/termai/internal/llm/tools" + "github.com/kujtimiihoxha/termai/internal/logging" "github.com/kujtimiihoxha/termai/internal/message" "github.com/kujtimiihoxha/termai/internal/pubsub" "github.com/kujtimiihoxha/termai/internal/session" + "github.com/kujtimiihoxha/termai/internal/tui/layout" "github.com/kujtimiihoxha/termai/internal/tui/styles" "github.com/kujtimiihoxha/termai/internal/tui/util" ) @@ -32,6 +36,9 @@ const ( toolMessageType ) +// messagesTickMsg is a message sent by the timer to refresh messages +type messagesTickMsg time.Time + type uiMessage struct { ID string messageType uiMessageType @@ -52,24 +59,34 @@ type messagesCmp struct { renderer *glamour.TermRenderer focusRenderer *glamour.TermRenderer cachedContent map[string]string - agentWorking bool spinner spinner.Model needsRerender bool - lastViewport string } func (m *messagesCmp) Init() tea.Cmd { - return tea.Batch(m.viewport.Init()) + return tea.Batch(m.viewport.Init(), m.spinner.Tick, m.tickMessages()) +} + +func (m *messagesCmp) tickMessages() tea.Cmd { + return tea.Tick(time.Second, func(t time.Time) tea.Msg { + return messagesTickMsg(t) + }) } func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { var cmds []tea.Cmd switch msg := msg.(type) { - case AgentWorkingMsg: - m.agentWorking = bool(msg) - if m.agentWorking { - cmds = append(cmds, m.spinner.Tick) + case messagesTickMsg: + // Refresh messages if we have an active session + if m.session.ID != "" { + messages, err := m.app.Messages.List(context.Background(), m.session.ID) + if err == nil { + m.messages = messages + m.needsRerender = true + } } + // Continue ticking + cmds = append(cmds, m.tickMessages()) case EditorFocusMsg: m.writingMode = bool(msg) case SessionSelectedMsg: @@ -84,6 +101,7 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.messages = make([]message.Message, 0) m.currentMsgID = "" m.needsRerender = true + m.cachedContent = make(map[string]string) return m, nil case tea.KeyMsg: @@ -104,6 +122,12 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } if !messageExists { + // If we have messages, ensure the previous last message is not cached + if len(m.messages) > 0 { + lastMsgID := m.messages[len(m.messages)-1].ID + delete(m.cachedContent, lastMsgID) + } + m.messages = append(m.messages, msg.Payload) delete(m.cachedContent, m.currentMsgID) m.currentMsgID = msg.Payload.ID @@ -112,36 +136,40 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } for _, v := range m.messages { for _, c := range v.ToolCalls() { - // the message is being added to the session of a tool called if c.ID == msg.Payload.SessionID { m.needsRerender = true } } } } else if msg.Type == pubsub.UpdatedEvent && msg.Payload.SessionID == m.session.ID { + logging.Debug("Message", "finish", msg.Payload.FinishReason()) for i, v := range m.messages { if v.ID == msg.Payload.ID { - if !m.messages[i].IsFinished() && msg.Payload.IsFinished() && msg.Payload.FinishReason() == "end_turn" || msg.Payload.FinishReason() == "canceled" { - cmds = append(cmds, util.CmdHandler(AgentWorkingMsg(false))) - } m.messages[i] = msg.Payload delete(m.cachedContent, msg.Payload.ID) + + // If this is the last message, ensure it's not cached + if i == len(m.messages)-1 { + delete(m.cachedContent, msg.Payload.ID) + } + m.needsRerender = true break } } } } - if m.agentWorking { - u, cmd := m.spinner.Update(msg) - m.spinner = u - cmds = append(cmds, cmd) - } + oldPos := m.viewport.YPosition u, cmd := m.viewport.Update(msg) m.viewport = u m.needsRerender = m.needsRerender || m.viewport.YPosition != oldPos cmds = append(cmds, cmd) + + spinner, cmd := m.spinner.Update(msg) + m.spinner = spinner + cmds = append(cmds, cmd) + if m.needsRerender { m.renderView() if len(m.messages) > 0 { @@ -157,10 +185,21 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m, tea.Batch(cmds...) } +func (m *messagesCmp) IsAgentWorking() bool { + return m.app.CoderAgent.IsSessionBusy(m.session.ID) +} + func (m *messagesCmp) renderSimpleMessage(msg message.Message, info ...string) string { - if v, ok := m.cachedContent[msg.ID]; ok { - return v + // Check if this is the last message in the list + isLastMessage := len(m.messages) > 0 && m.messages[len(m.messages)-1].ID == msg.ID + + // Only use cache for non-last messages + if !isLastMessage { + if v, ok := m.cachedContent[msg.ID]; ok { + return v + } } + style := styles.BaseStyle. Width(m.width). BorderLeft(true). @@ -191,7 +230,12 @@ func (m *messagesCmp) renderSimpleMessage(msg message.Message, info ...string) s parts..., ), ) - m.cachedContent[msg.ID] = rendered + + // Only cache if it's not the last message + if !isLastMessage { + m.cachedContent[msg.ID] = rendered + } + return rendered } @@ -207,32 +251,71 @@ func formatTimeDifference(unixTime1, unixTime2 int64) string { return fmt.Sprintf("%dm%ds", minutes, seconds) } +func (m *messagesCmp) findToolResponse(callID string) *message.ToolResult { + for _, v := range m.messages { + for _, c := range v.ToolResults() { + if c.ToolCallID == callID { + return &c + } + } + } + return nil +} + func (m *messagesCmp) renderToolCall(toolCall message.ToolCall, isNested bool) string { key := "" value := "" + result := styles.BaseStyle.Foreground(styles.PrimaryColor).Render(m.spinner.View() + " waiting for response...") + + response := m.findToolResponse(toolCall.ID) + if response != nil && response.IsError { + // Clean up error message for display by removing newlines + // This ensures error messages display properly in the UI + errMsg := strings.ReplaceAll(response.Content, "\n", " ") + result = styles.BaseStyle.Foreground(styles.Error).Render(ansi.Truncate(errMsg, 40, "...")) + } else if response != nil { + result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render("Done") + } switch toolCall.Name { // TODO: add result data to the tools case agent.AgentToolName: key = "Task" var params agent.AgentParams json.Unmarshal([]byte(toolCall.Input), ¶ms) - value = params.Prompt - // TODO: handle nested calls + value = strings.ReplaceAll(params.Prompt, "\n", " ") + if response != nil && !response.IsError { + firstRow := strings.ReplaceAll(response.Content, "\n", " ") + result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(ansi.Truncate(firstRow, 40, "...")) + } case tools.BashToolName: key = "Bash" var params tools.BashParams json.Unmarshal([]byte(toolCall.Input), ¶ms) value = params.Command + if response != nil && !response.IsError { + metadata := tools.BashResponseMetadata{} + json.Unmarshal([]byte(response.Metadata), &metadata) + result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("Took %s", formatTimeDifference(metadata.StartTime, metadata.EndTime))) + } + case tools.EditToolName: key = "Edit" var params tools.EditParams json.Unmarshal([]byte(toolCall.Input), ¶ms) value = params.FilePath + if response != nil && !response.IsError { + metadata := tools.EditResponseMetadata{} + json.Unmarshal([]byte(response.Metadata), &metadata) + result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d Additions %d Removals", metadata.Additions, metadata.Removals)) + } case tools.FetchToolName: key = "Fetch" var params tools.FetchParams json.Unmarshal([]byte(toolCall.Input), ¶ms) value = params.URL + if response != nil && !response.IsError { + result = styles.BaseStyle.Foreground(styles.Error).Render(response.Content) + } case tools.GlobToolName: key = "Glob" var params tools.GlobParams @@ -241,6 +324,15 @@ func (m *messagesCmp) renderToolCall(toolCall message.ToolCall, isNested bool) s params.Path = "." } value = fmt.Sprintf("%s (%s)", params.Pattern, params.Path) + if response != nil && !response.IsError { + metadata := tools.GlobResponseMetadata{} + json.Unmarshal([]byte(response.Metadata), &metadata) + if metadata.Truncated { + result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d files found (truncated)", metadata.NumberOfFiles)) + } else { + result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d files found", metadata.NumberOfFiles)) + } + } case tools.GrepToolName: key = "Grep" var params tools.GrepParams @@ -249,19 +341,46 @@ func (m *messagesCmp) renderToolCall(toolCall message.ToolCall, isNested bool) s params.Path = "." } value = fmt.Sprintf("%s (%s)", params.Pattern, params.Path) + if response != nil && !response.IsError { + metadata := tools.GrepResponseMetadata{} + json.Unmarshal([]byte(response.Metadata), &metadata) + if metadata.Truncated { + result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d files found (truncated)", metadata.NumberOfMatches)) + } else { + result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d files found", metadata.NumberOfMatches)) + } + } case tools.LSToolName: - key = "Ls" + key = "ls" var params tools.LSParams json.Unmarshal([]byte(toolCall.Input), ¶ms) if params.Path == "" { params.Path = "." } value = params.Path + if response != nil && !response.IsError { + metadata := tools.LSResponseMetadata{} + json.Unmarshal([]byte(response.Metadata), &metadata) + if metadata.Truncated { + result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d files found (truncated)", metadata.NumberOfFiles)) + } else { + result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d files found", metadata.NumberOfFiles)) + } + } case tools.SourcegraphToolName: key = "Sourcegraph" var params tools.SourcegraphParams json.Unmarshal([]byte(toolCall.Input), ¶ms) value = params.Query + if response != nil && !response.IsError { + metadata := tools.SourcegraphResponseMetadata{} + json.Unmarshal([]byte(response.Metadata), &metadata) + if metadata.Truncated { + result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d matches found (truncated)", metadata.NumberOfMatches)) + } else { + result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d matches found", metadata.NumberOfMatches)) + } + } case tools.ViewToolName: key = "View" var params tools.ViewParams @@ -272,6 +391,12 @@ func (m *messagesCmp) renderToolCall(toolCall message.ToolCall, isNested bool) s var params tools.WriteParams json.Unmarshal([]byte(toolCall.Input), ¶ms) value = params.FilePath + if response != nil && !response.IsError { + metadata := tools.WriteResponseMetadata{} + json.Unmarshal([]byte(response.Metadata), &metadata) + + result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d Additions %d Removals", metadata.Additions, metadata.Removals)) + } default: key = toolCall.Name var params map[string]any @@ -300,14 +425,15 @@ func (m *messagesCmp) renderToolCall(toolCall message.ToolCall, isNested bool) s ) if !isNested { value = valyeStyle. - Width(m.width - lipgloss.Width(keyValye) - 2). Render( ansi.Truncate( - value, - m.width-lipgloss.Width(keyValye)-2, + value+" ", + m.width-lipgloss.Width(keyValye)-2-lipgloss.Width(result), "...", ), ) + value += result + } else { keyValye = keyStyle.Render( fmt.Sprintf(" └ %s: ", key), @@ -409,6 +535,27 @@ func (m *messagesCmp) renderView() { m.uiMessages = make([]uiMessage, 0) pos := 0 + // If we have messages, ensure the last message is not cached + // This ensures we always render the latest content for the most recent message + // which may be actively updating (e.g., during generation) + if len(m.messages) > 0 { + lastMsgID := m.messages[len(m.messages)-1].ID + delete(m.cachedContent, lastMsgID) + } + + // Limit cache to 10 messages + if len(m.cachedContent) > 15 { + // Create a list of keys to delete (oldest messages first) + keys := make([]string, 0, len(m.cachedContent)) + for k := range m.cachedContent { + keys = append(keys, k) + } + // Delete oldest messages until we have 10 or fewer + for i := 0; i < len(keys)-15; i++ { + delete(m.cachedContent, keys[i]) + } + } + for _, v := range m.messages { switch v.Role { case message.User: @@ -487,7 +634,7 @@ func (m *messagesCmp) View() string { func (m *messagesCmp) help() string { text := "" - if m.agentWorking { + if m.IsAgentWorking() { text += styles.BaseStyle.Foreground(styles.PrimaryColor).Bold(true).Render( fmt.Sprintf("%s %s ", m.spinner.View(), "Generating..."), ) @@ -562,9 +709,15 @@ func (m *messagesCmp) SetSession(session session.Session) tea.Cmd { m.messages = messages m.currentMsgID = m.messages[len(m.messages)-1].ID m.needsRerender = true + m.cachedContent = make(map[string]string) return nil } +func (m *messagesCmp) BindingKeys() []key.Binding { + bindings := layout.KeyMapToSlice(m.viewport.KeyMap) + return bindings +} + func NewMessagesCmp(app *app.App) tea.Model { focusRenderer, _ := glamour.NewTermRenderer( glamour.WithStyles(styles.MarkdownTheme(true)), diff --git a/internal/tui/components/chat/sidebar.go b/internal/tui/components/chat/sidebar.go index 51192cf9a..b90269d1a 100644 --- a/internal/tui/components/chat/sidebar.go +++ b/internal/tui/components/chat/sidebar.go @@ -1,10 +1,15 @@ package chat import ( + "context" "fmt" + "strings" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" + "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/termai/internal/diff" + "github.com/kujtimiihoxha/termai/internal/history" "github.com/kujtimiihoxha/termai/internal/pubsub" "github.com/kujtimiihoxha/termai/internal/session" "github.com/kujtimiihoxha/termai/internal/tui/styles" @@ -13,9 +18,33 @@ import ( type sidebarCmp struct { width, height int session session.Session + history history.Service + modFiles map[string]struct { + additions int + removals int + } } func (m *sidebarCmp) Init() tea.Cmd { + if m.history != nil { + ctx := context.Background() + // Subscribe to file events + filesCh := m.history.Subscribe(ctx) + + // Initialize the modified files map + m.modFiles = make(map[string]struct { + additions int + removals int + }) + + // Load initial files and calculate diffs + m.loadModifiedFiles(ctx) + + // Return a command that will send file events to the Update method + return func() tea.Msg { + return <-filesCh + } + } return nil } @@ -27,6 +56,13 @@ func (m *sidebarCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.session = msg.Payload } } + case pubsub.Event[history.File]: + if msg.Payload.SessionID == m.session.ID { + // When a file changes, reload all modified files + // This ensures we have the complete and accurate list + ctx := context.Background() + m.loadModifiedFiles(ctx) + } } return m, nil } @@ -86,18 +122,28 @@ func (m *sidebarCmp) modifiedFile(filePath string, additions, removals int) stri func (m *sidebarCmp) modifiedFiles() string { modifiedFiles := styles.BaseStyle.Width(m.width).Foreground(styles.PrimaryColor).Bold(true).Render("Modified Files:") - files := []struct { - path string - additions int - removals int - }{ - {"file1.txt", 10, 5}, - {"file2.txt", 20, 0}, - {"file3.txt", 0, 15}, + + // If no modified files, show a placeholder message + if m.modFiles == nil || len(m.modFiles) == 0 { + message := "No modified files" + remainingWidth := m.width - lipgloss.Width(modifiedFiles) + if remainingWidth > 0 { + message += strings.Repeat(" ", remainingWidth) + } + return styles.BaseStyle. + Width(m.width). + Render( + lipgloss.JoinVertical( + lipgloss.Top, + modifiedFiles, + styles.BaseStyle.Foreground(styles.ForgroundDim).Render(message), + ), + ) } + var fileViews []string - for _, file := range files { - fileViews = append(fileViews, m.modifiedFile(file.path, file.additions, file.removals)) + for path, stats := range m.modFiles { + fileViews = append(fileViews, m.modifiedFile(path, stats.additions, stats.removals)) } return styles.BaseStyle. @@ -123,8 +169,116 @@ func (m *sidebarCmp) GetSize() (int, int) { return m.width, m.height } -func NewSidebarCmp(session session.Session) tea.Model { +func NewSidebarCmp(session session.Session, history history.Service) tea.Model { return &sidebarCmp{ session: session, + history: history, + } +} + +func (m *sidebarCmp) loadModifiedFiles(ctx context.Context) { + if m.history == nil || m.session.ID == "" { + return + } + + // Get all latest files for this session + latestFiles, err := m.history.ListLatestSessionFiles(ctx, m.session.ID) + if err != nil { + return + } + + // Get all files for this session (to find initial versions) + allFiles, err := m.history.ListBySession(ctx, m.session.ID) + if err != nil { + return + } + + // Process each latest file + for _, file := range latestFiles { + // Skip if this is the initial version (no changes to show) + if file.Version == history.InitialVersion { + continue + } + + // Find the initial version for this specific file + var initialVersion history.File + for _, v := range allFiles { + if v.Path == file.Path && v.Version == history.InitialVersion { + initialVersion = v + break + } + } + + // Skip if we can't find the initial version + if initialVersion.ID == "" { + continue + } + + // Calculate diff between initial and latest version + _, additions, removals := diff.GenerateDiff(initialVersion.Content, file.Content, file.Path) + + // Only add to modified files if there are changes + if additions > 0 || removals > 0 { + // Remove working directory prefix from file path + displayPath := file.Path + workingDir := config.WorkingDirectory() + displayPath = strings.TrimPrefix(displayPath, workingDir) + displayPath = strings.TrimPrefix(displayPath, "/") + + m.modFiles[displayPath] = struct { + additions int + removals int + }{ + additions: additions, + removals: removals, + } + } + } +} + +func (m *sidebarCmp) processFileChanges(ctx context.Context, file history.File) { + // Skip if not the latest version + if file.Version == history.InitialVersion { + return + } + + // Get all versions of this file + fileVersions, err := m.history.ListBySession(ctx, m.session.ID) + if err != nil { + return + } + + // Find the initial version + var initialVersion history.File + for _, v := range fileVersions { + if v.Path == file.Path && v.Version == history.InitialVersion { + initialVersion = v + break + } + } + + // Skip if we can't find the initial version + if initialVersion.ID == "" { + return + } + + // Calculate diff between initial and latest version + _, additions, removals := diff.GenerateDiff(initialVersion.Content, file.Content, file.Path) + + // Only add to modified files if there are changes + if additions > 0 || removals > 0 { + // Remove working directory prefix from file path + displayPath := file.Path + workingDir := config.WorkingDirectory() + displayPath = strings.TrimPrefix(displayPath, workingDir) + displayPath = strings.TrimPrefix(displayPath, "/") + + m.modFiles[displayPath] = struct { + additions int + removals int + }{ + additions: additions, + removals: removals, + } } } diff --git a/internal/tui/components/core/dialog.go b/internal/tui/components/core/dialog.go deleted file mode 100644 index a8fef2e86..000000000 --- a/internal/tui/components/core/dialog.go +++ /dev/null @@ -1,117 +0,0 @@ -package core - -import ( - "github.com/charmbracelet/bubbles/key" - tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/tui/layout" - "github.com/kujtimiihoxha/termai/internal/tui/util" -) - -type SizeableModel interface { - tea.Model - layout.Sizeable -} - -type DialogMsg struct { - Content SizeableModel - WidthRatio float64 - HeightRatio float64 - - MinWidth int - MinHeight int -} - -type DialogCloseMsg struct{} - -type KeyBindings struct { - Return key.Binding -} - -var keys = KeyBindings{ - Return: key.NewBinding( - key.WithKeys("esc"), - key.WithHelp("esc", "close"), - ), -} - -type DialogCmp interface { - tea.Model - layout.Bindings -} - -type dialogCmp struct { - content SizeableModel - screenWidth int - screenHeight int - - widthRatio float64 - heightRatio float64 - - minWidth int - minHeight int - - width int - height int -} - -func (d *dialogCmp) Init() tea.Cmd { - return nil -} - -func (d *dialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { - switch msg := msg.(type) { - case tea.WindowSizeMsg: - d.screenWidth = msg.Width - d.screenHeight = msg.Height - d.width = max(int(float64(d.screenWidth)*d.widthRatio), d.minWidth) - d.height = max(int(float64(d.screenHeight)*d.heightRatio), d.minHeight) - if d.content != nil { - d.content.SetSize(d.width, d.height) - } - return d, nil - case DialogMsg: - d.content = msg.Content - d.widthRatio = msg.WidthRatio - d.heightRatio = msg.HeightRatio - d.minWidth = msg.MinWidth - d.minHeight = msg.MinHeight - d.width = max(int(float64(d.screenWidth)*d.widthRatio), d.minWidth) - d.height = max(int(float64(d.screenHeight)*d.heightRatio), d.minHeight) - if d.content != nil { - d.content.SetSize(d.width, d.height) - } - case DialogCloseMsg: - d.content = nil - return d, nil - case tea.KeyMsg: - if key.Matches(msg, keys.Return) { - return d, util.CmdHandler(DialogCloseMsg{}) - } - } - if d.content != nil { - u, cmd := d.content.Update(msg) - d.content = u.(SizeableModel) - return d, cmd - } - return d, nil -} - -func (d *dialogCmp) BindingKeys() []key.Binding { - bindings := []key.Binding{keys.Return} - if d.content == nil { - return bindings - } - if c, ok := d.content.(layout.Bindings); ok { - return append(bindings, c.BindingKeys()...) - } - return bindings -} - -func (d *dialogCmp) View() string { - return lipgloss.NewStyle().Width(d.width).Height(d.height).Render(d.content.View()) -} - -func NewDialogCmp() DialogCmp { - return &dialogCmp{} -} diff --git a/internal/tui/components/core/help.go b/internal/tui/components/core/help.go deleted file mode 100644 index 4ef857c78..000000000 --- a/internal/tui/components/core/help.go +++ /dev/null @@ -1,119 +0,0 @@ -package core - -import ( - "strings" - - "github.com/charmbracelet/bubbles/key" - tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/tui/styles" -) - -type HelpCmp interface { - tea.Model - SetBindings(bindings []key.Binding) - Height() int -} - -const ( - helpWidgetHeight = 12 -) - -type helpCmp struct { - width int - bindings []key.Binding -} - -func (h *helpCmp) Init() tea.Cmd { - return nil -} - -func (h *helpCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { - switch msg := msg.(type) { - case tea.WindowSizeMsg: - h.width = msg.Width - } - return h, nil -} - -func (h *helpCmp) View() string { - helpKeyStyle := styles.Bold.Foreground(styles.Rosewater).Margin(0, 1, 0, 0) - helpDescStyle := styles.Regular.Foreground(styles.Flamingo) - // Compile list of bindings to render - bindings := removeDuplicateBindings(h.bindings) - // Enumerate through each group of bindings, populating a series of - // pairs of columns, one for keys, one for descriptions - var ( - pairs []string - width int - rows = helpWidgetHeight - 2 - ) - for i := 0; i < len(bindings); i += rows { - var ( - keys []string - descs []string - ) - for j := i; j < min(i+rows, len(bindings)); j++ { - keys = append(keys, helpKeyStyle.Render(bindings[j].Help().Key)) - descs = append(descs, helpDescStyle.Render(bindings[j].Help().Desc)) - } - // Render pair of columns; beyond the first pair, render a three space - // left margin, in order to visually separate the pairs. - var cols []string - if len(pairs) > 0 { - cols = []string{" "} - } - cols = append(cols, - strings.Join(keys, "\n"), - strings.Join(descs, "\n"), - ) - - pair := lipgloss.JoinHorizontal(lipgloss.Top, cols...) - // check whether it exceeds the maximum width avail (the width of the - // terminal, subtracting 2 for the borders). - width += lipgloss.Width(pair) - if width > h.width-2 { - break - } - pairs = append(pairs, pair) - } - - // Join pairs of columns and enclose in a border - content := lipgloss.JoinHorizontal(lipgloss.Top, pairs...) - return styles.DoubleBorder.Height(rows).PaddingLeft(1).Width(h.width - 2).Render(content) -} - -func removeDuplicateBindings(bindings []key.Binding) []key.Binding { - seen := make(map[string]struct{}) - result := make([]key.Binding, 0, len(bindings)) - - // Process bindings in reverse order - for i := len(bindings) - 1; i >= 0; i-- { - b := bindings[i] - k := strings.Join(b.Keys(), " ") - if _, ok := seen[k]; ok { - // duplicate, skip - continue - } - seen[k] = struct{}{} - // Add to the beginning of result to maintain original order - result = append([]key.Binding{b}, result...) - } - - return result -} - -func (h *helpCmp) SetBindings(bindings []key.Binding) { - h.bindings = bindings -} - -func (h helpCmp) Height() int { - return helpWidgetHeight -} - -func NewHelpCmp() HelpCmp { - return &helpCmp{ - width: 0, - bindings: make([]key.Binding, 0), - } -} diff --git a/internal/tui/components/core/status.go b/internal/tui/components/core/status.go index 93ba34507..089dffa2c 100644 --- a/internal/tui/components/core/status.go +++ b/internal/tui/components/core/status.go @@ -1,21 +1,25 @@ package core import ( + "fmt" + "strings" "time" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/llm/models" + "github.com/kujtimiihoxha/termai/internal/lsp" + "github.com/kujtimiihoxha/termai/internal/lsp/protocol" "github.com/kujtimiihoxha/termai/internal/tui/styles" "github.com/kujtimiihoxha/termai/internal/tui/util" - "github.com/kujtimiihoxha/termai/internal/version" ) type statusCmp struct { info util.InfoMsg width int messageTTL time.Duration + lspClients map[string]*lsp.Client } // clearMessageCmd is a command that clears status messages after a timeout @@ -47,20 +51,18 @@ func (m statusCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m, nil } -var ( - versionWidget = styles.Padded.Background(styles.DarkGrey).Foreground(styles.Text).Render(version.Version) - helpWidget = styles.Padded.Background(styles.Grey).Foreground(styles.Text).Render("? help") -) +var helpWidget = styles.Padded.Background(styles.ForgroundMid).Foreground(styles.BackgroundDarker).Bold(true).Render("ctrl+? help") func (m statusCmp) View() string { - status := styles.Padded.Background(styles.Grey).Foreground(styles.Text).Render("? help") + status := helpWidget + diagnostics := styles.Padded.Background(styles.BackgroundDarker).Render(m.projectDiagnostics()) if m.info.Msg != "" { infoStyle := styles.Padded. Foreground(styles.Base). - Width(m.availableFooterMsgWidth()) + Width(m.availableFooterMsgWidth(diagnostics)) switch m.info.Type { case util.InfoTypeInfo: - infoStyle = infoStyle.Background(styles.Blue) + infoStyle = infoStyle.Background(styles.BorderColor) case util.InfoTypeWarn: infoStyle = infoStyle.Background(styles.Peach) case util.InfoTypeError: @@ -68,7 +70,7 @@ func (m statusCmp) View() string { } // Truncate message if it's longer than available width msg := m.info.Msg - availWidth := m.availableFooterMsgWidth() - 10 + availWidth := m.availableFooterMsgWidth(diagnostics) - 10 if len(msg) > availWidth && availWidth > 0 { msg = msg[:availWidth] + "..." } @@ -76,27 +78,81 @@ func (m statusCmp) View() string { } else { status += styles.Padded. Foreground(styles.Base). - Background(styles.LightGrey). - Width(m.availableFooterMsgWidth()). + Background(styles.BackgroundDim). + Width(m.availableFooterMsgWidth(diagnostics)). Render("") } + status += diagnostics status += m.model() - status += versionWidget return status } -func (m statusCmp) availableFooterMsgWidth() int { - // -2 to accommodate padding - return max(0, m.width-lipgloss.Width(helpWidget)-lipgloss.Width(versionWidget)-lipgloss.Width(m.model())) +func (m *statusCmp) projectDiagnostics() string { + errorDiagnostics := []protocol.Diagnostic{} + warnDiagnostics := []protocol.Diagnostic{} + hintDiagnostics := []protocol.Diagnostic{} + infoDiagnostics := []protocol.Diagnostic{} + for _, client := range m.lspClients { + for _, d := range client.GetDiagnostics() { + for _, diag := range d { + switch diag.Severity { + case protocol.SeverityError: + errorDiagnostics = append(errorDiagnostics, diag) + case protocol.SeverityWarning: + warnDiagnostics = append(warnDiagnostics, diag) + case protocol.SeverityHint: + hintDiagnostics = append(hintDiagnostics, diag) + case protocol.SeverityInformation: + infoDiagnostics = append(infoDiagnostics, diag) + } + } + } + } + + if len(errorDiagnostics) == 0 && len(warnDiagnostics) == 0 && len(hintDiagnostics) == 0 && len(infoDiagnostics) == 0 { + return "No diagnostics" + } + + diagnostics := []string{} + + if len(errorDiagnostics) > 0 { + errStr := lipgloss.NewStyle().Foreground(styles.Error).Render(fmt.Sprintf("%s %d", styles.ErrorIcon, len(errorDiagnostics))) + diagnostics = append(diagnostics, errStr) + } + if len(warnDiagnostics) > 0 { + warnStr := lipgloss.NewStyle().Foreground(styles.Warning).Render(fmt.Sprintf("%s %d", styles.WarningIcon, len(warnDiagnostics))) + diagnostics = append(diagnostics, warnStr) + } + if len(hintDiagnostics) > 0 { + hintStr := lipgloss.NewStyle().Foreground(styles.Text).Render(fmt.Sprintf("%s %d", styles.HintIcon, len(hintDiagnostics))) + diagnostics = append(diagnostics, hintStr) + } + if len(infoDiagnostics) > 0 { + infoStr := lipgloss.NewStyle().Foreground(styles.Peach).Render(fmt.Sprintf("%s %d", styles.InfoIcon, len(infoDiagnostics))) + diagnostics = append(diagnostics, infoStr) + } + + return strings.Join(diagnostics, " ") +} + +func (m statusCmp) availableFooterMsgWidth(diagnostics string) int { + return max(0, m.width-lipgloss.Width(helpWidget)-lipgloss.Width(m.model())-lipgloss.Width(diagnostics)) } func (m statusCmp) model() string { - model := models.SupportedModels[config.Get().Model.Coder] + cfg := config.Get() + + coder, ok := cfg.Agents[config.AgentCoder] + if !ok { + return "Unknown" + } + model := models.SupportedModels[coder.Model] return styles.Padded.Background(styles.Grey).Foreground(styles.Text).Render(model.Name) } -func NewStatusCmp() tea.Model { +func NewStatusCmp(lspClients map[string]*lsp.Client) tea.Model { return &statusCmp{ messageTTL: 10 * time.Second, + lspClients: lspClients, } } diff --git a/internal/tui/components/dialog/help.go b/internal/tui/components/dialog/help.go new file mode 100644 index 000000000..1d3c2b077 --- /dev/null +++ b/internal/tui/components/dialog/help.go @@ -0,0 +1,182 @@ +package dialog + +import ( + "strings" + + "github.com/charmbracelet/bubbles/key" + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" + "github.com/kujtimiihoxha/termai/internal/tui/styles" +) + +type helpCmp struct { + width int + height int + keys []key.Binding +} + +func (h *helpCmp) Init() tea.Cmd { + return nil +} + +func (h *helpCmp) SetBindings(k []key.Binding) { + h.keys = k +} + +func (h *helpCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case tea.WindowSizeMsg: + h.width = 80 + h.height = msg.Height + } + return h, nil +} + +func removeDuplicateBindings(bindings []key.Binding) []key.Binding { + seen := make(map[string]struct{}) + result := make([]key.Binding, 0, len(bindings)) + + // Process bindings in reverse order + for i := len(bindings) - 1; i >= 0; i-- { + b := bindings[i] + k := strings.Join(b.Keys(), " ") + if _, ok := seen[k]; ok { + // duplicate, skip + continue + } + seen[k] = struct{}{} + // Add to the beginning of result to maintain original order + result = append([]key.Binding{b}, result...) + } + + return result +} + +func (h *helpCmp) render() string { + helpKeyStyle := styles.Bold.Background(styles.Background).Foreground(styles.Forground).Padding(0, 1, 0, 0) + helpDescStyle := styles.Regular.Background(styles.Background).Foreground(styles.ForgroundMid) + // Compile list of bindings to render + bindings := removeDuplicateBindings(h.keys) + // Enumerate through each group of bindings, populating a series of + // pairs of columns, one for keys, one for descriptions + var ( + pairs []string + width int + rows = 12 - 2 + ) + for i := 0; i < len(bindings); i += rows { + var ( + keys []string + descs []string + ) + for j := i; j < min(i+rows, len(bindings)); j++ { + keys = append(keys, helpKeyStyle.Render(bindings[j].Help().Key)) + descs = append(descs, helpDescStyle.Render(bindings[j].Help().Desc)) + } + // Render pair of columns; beyond the first pair, render a three space + // left margin, in order to visually separate the pairs. + var cols []string + if len(pairs) > 0 { + cols = []string{styles.BaseStyle.Render(" ")} + } + + maxDescWidth := 0 + for _, desc := range descs { + if maxDescWidth < lipgloss.Width(desc) { + maxDescWidth = lipgloss.Width(desc) + } + } + for i := range descs { + remainingWidth := maxDescWidth - lipgloss.Width(descs[i]) + if remainingWidth > 0 { + descs[i] = descs[i] + styles.BaseStyle.Render(strings.Repeat(" ", remainingWidth)) + } + } + maxKeyWidth := 0 + for _, key := range keys { + if maxKeyWidth < lipgloss.Width(key) { + maxKeyWidth = lipgloss.Width(key) + } + } + for i := range keys { + remainingWidth := maxKeyWidth - lipgloss.Width(keys[i]) + if remainingWidth > 0 { + keys[i] = keys[i] + styles.BaseStyle.Render(strings.Repeat(" ", remainingWidth)) + } + } + + cols = append(cols, + strings.Join(keys, "\n"), + strings.Join(descs, "\n"), + ) + + pair := styles.BaseStyle.Render(lipgloss.JoinHorizontal(lipgloss.Top, cols...)) + // check whether it exceeds the maximum width avail (the width of the + // terminal, subtracting 2 for the borders). + width += lipgloss.Width(pair) + if width > h.width-2 { + break + } + pairs = append(pairs, pair) + } + + // https://github.com/charmbracelet/lipgloss/issues/209 + if len(pairs) > 1 { + prefix := pairs[:len(pairs)-1] + lastPair := pairs[len(pairs)-1] + prefix = append(prefix, lipgloss.Place( + lipgloss.Width(lastPair), // width + lipgloss.Height(prefix[0]), // height + lipgloss.Left, // x + lipgloss.Top, // y + lastPair, // content + lipgloss.WithWhitespaceBackground(styles.Background), // background + )) + content := styles.BaseStyle.Width(h.width).Render( + lipgloss.JoinHorizontal( + lipgloss.Top, + prefix..., + ), + ) + return content + } + // Join pairs of columns and enclose in a border + content := styles.BaseStyle.Width(h.width).Render( + lipgloss.JoinHorizontal( + lipgloss.Top, + pairs..., + ), + ) + return content +} + +func (h *helpCmp) View() string { + content := h.render() + header := styles.BaseStyle. + Bold(true). + Width(lipgloss.Width(content)). + Foreground(styles.PrimaryColor). + Render("Keyboard Shortcuts") + + return styles.BaseStyle.Padding(1). + Border(lipgloss.RoundedBorder()). + BorderForeground(styles.ForgroundDim). + Width(h.width). + BorderBackground(styles.Background). + Render( + lipgloss.JoinVertical(lipgloss.Center, + header, + styles.BaseStyle.Render(strings.Repeat(" ", lipgloss.Width(header))), + content, + ), + ) +} + +type HelpCmp interface { + tea.Model + SetBindings([]key.Binding) +} + +func NewHelpCmp() HelpCmp { + return &helpCmp{} +} diff --git a/internal/tui/components/dialog/permission.go b/internal/tui/components/dialog/permission.go index d147f89cd..9c55effde 100644 --- a/internal/tui/components/dialog/permission.go +++ b/internal/tui/components/dialog/permission.go @@ -12,12 +12,9 @@ import ( "github.com/kujtimiihoxha/termai/internal/diff" "github.com/kujtimiihoxha/termai/internal/llm/tools" "github.com/kujtimiihoxha/termai/internal/permission" - "github.com/kujtimiihoxha/termai/internal/tui/components/core" "github.com/kujtimiihoxha/termai/internal/tui/layout" "github.com/kujtimiihoxha/termai/internal/tui/styles" "github.com/kujtimiihoxha/termai/internal/tui/util" - - "github.com/charmbracelet/huh" ) type PermissionAction string @@ -35,69 +32,64 @@ type PermissionResponseMsg struct { Action PermissionAction } -// PermissionDialog interface for permission dialog component -type PermissionDialog interface { +// PermissionDialogCmp interface for permission dialog component +type PermissionDialogCmp interface { tea.Model - layout.Sizeable layout.Bindings + SetPermissions(permission permission.PermissionRequest) } -type keyMap struct { - ChangeFocus key.Binding +type permissionsMapping struct { + LeftRight key.Binding + EnterSpace key.Binding + Allow key.Binding + AllowSession key.Binding + Deny key.Binding + Tab key.Binding } -var keyMapValue = keyMap{ - ChangeFocus: key.NewBinding( +var permissionsKeys = permissionsMapping{ + LeftRight: key.NewBinding( + key.WithKeys("left", "right"), + key.WithHelp("←/→", "switch options"), + ), + EnterSpace: key.NewBinding( + key.WithKeys("enter", " "), + key.WithHelp("enter/space", "confirm"), + ), + Allow: key.NewBinding( + key.WithKeys("a"), + key.WithHelp("a", "allow"), + ), + AllowSession: key.NewBinding( + key.WithKeys("A"), + key.WithHelp("A", "allow for session"), + ), + Deny: key.NewBinding( + key.WithKeys("d"), + key.WithHelp("d", "deny"), + ), + Tab: key.NewBinding( key.WithKeys("tab"), - key.WithHelp("tab", "change focus"), + key.WithHelp("tab", "switch options"), ), } // permissionDialogCmp is the implementation of PermissionDialog type permissionDialogCmp struct { - form *huh.Form width int height int permission permission.PermissionRequest windowSize tea.WindowSizeMsg - r *glamour.TermRenderer contentViewPort viewport.Model - isViewportFocus bool - selectOption *huh.Select[string] -} + selectedOption int // 0: Allow, 1: Allow for session, 2: Deny -// formatDiff formats a diff string with colors for additions and deletions -func formatDiff(diffText string) string { - lines := strings.Split(diffText, "\n") - var formattedLines []string - - // Define styles for different line types - addStyle := lipgloss.NewStyle().Foreground(styles.Green) - removeStyle := lipgloss.NewStyle().Foreground(styles.Red) - headerStyle := lipgloss.NewStyle().Bold(true).Foreground(styles.Blue) - contextStyle := lipgloss.NewStyle().Foreground(styles.SubText0) - - // Process each line - for _, line := range lines { - if strings.HasPrefix(line, "+") { - formattedLines = append(formattedLines, addStyle.Render(line)) - } else if strings.HasPrefix(line, "-") { - formattedLines = append(formattedLines, removeStyle.Render(line)) - } else if strings.HasPrefix(line, "Changes:") || strings.HasPrefix(line, " ...") { - formattedLines = append(formattedLines, headerStyle.Render(line)) - } else if strings.HasPrefix(line, " ") { - formattedLines = append(formattedLines, contextStyle.Render(line)) - } else { - formattedLines = append(formattedLines, line) - } - } - - // Join all formatted lines - return strings.Join(formattedLines, "\n") + diffCache map[string]string + markdownCache map[string]string } func (p *permissionDialogCmp) Init() tea.Cmd { - return nil + return p.contentViewPort.Init() } func (p *permissionDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { @@ -106,373 +98,363 @@ func (p *permissionDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { switch msg := msg.(type) { case tea.WindowSizeMsg: p.windowSize = msg + p.SetSize() + p.markdownCache = make(map[string]string) + p.diffCache = make(map[string]string) case tea.KeyMsg: - if key.Matches(msg, keyMapValue.ChangeFocus) { - p.isViewportFocus = !p.isViewportFocus - if p.isViewportFocus { - p.selectOption.Blur() - // Add a visual indicator for focus change - cmds = append(cmds, tea.Batch( - util.ReportInfo("Viewing content - use arrow keys to scroll"), - )) - } else { - p.selectOption.Focus() - // Add a visual indicator for focus change - cmds = append(cmds, tea.Batch( - util.CmdHandler(util.ReportInfo("Select an action")), - )) - } - return p, tea.Batch(cmds...) - } - } - - if p.isViewportFocus { - viewPort, cmd := p.contentViewPort.Update(msg) - p.contentViewPort = viewPort - cmds = append(cmds, cmd) - } else { - form, cmd := p.form.Update(msg) - if f, ok := form.(*huh.Form); ok { - p.form = f + switch { + case key.Matches(msg, permissionsKeys.LeftRight) || key.Matches(msg, permissionsKeys.Tab): + // Change selected option + p.selectedOption = (p.selectedOption + 1) % 3 + return p, nil + case key.Matches(msg, permissionsKeys.EnterSpace): + // Select current option + return p, p.selectCurrentOption() + case key.Matches(msg, permissionsKeys.Allow): + // Select Allow + return p, util.CmdHandler(PermissionResponseMsg{Action: PermissionAllow, Permission: p.permission}) + case key.Matches(msg, permissionsKeys.AllowSession): + // Select Allow for session + return p, util.CmdHandler(PermissionResponseMsg{Action: PermissionAllowForSession, Permission: p.permission}) + case key.Matches(msg, permissionsKeys.Deny): + // Select Deny + return p, util.CmdHandler(PermissionResponseMsg{Action: PermissionDeny, Permission: p.permission}) + default: + // Pass other keys to viewport + viewPort, cmd := p.contentViewPort.Update(msg) + p.contentViewPort = viewPort cmds = append(cmds, cmd) } - - if p.form.State == huh.StateCompleted { - // Get the selected action - action := p.form.GetString("action") - - // Close the dialog and return the response - return p, tea.Batch( - util.CmdHandler(core.DialogCloseMsg{}), - util.CmdHandler(PermissionResponseMsg{Action: PermissionAction(action), Permission: p.permission}), - ) - } } + return p, tea.Batch(cmds...) } -func (p *permissionDialogCmp) render() string { - keyStyle := lipgloss.NewStyle().Bold(true).Foreground(styles.Rosewater) - valueStyle := lipgloss.NewStyle().Foreground(styles.Peach) +func (p *permissionDialogCmp) selectCurrentOption() tea.Cmd { + var action PermissionAction - form := p.form.View() - - headerParts := []string{ - lipgloss.JoinHorizontal(lipgloss.Left, keyStyle.Render("Tool:"), " ", valueStyle.Render(p.permission.ToolName)), - " ", - lipgloss.JoinHorizontal(lipgloss.Left, keyStyle.Render("Path:"), " ", valueStyle.Render(p.permission.Path)), - " ", + switch p.selectedOption { + case 0: + action = PermissionAllow + case 1: + action = PermissionAllowForSession + case 2: + action = PermissionDeny } - // Create the header content first so it can be used in all cases - headerContent := lipgloss.NewStyle().Padding(0, 1).Render(lipgloss.JoinVertical(lipgloss.Left, headerParts...)) - - r, _ := glamour.NewTermRenderer( - glamour.WithStyles(styles.CatppuccinMarkdownStyle()), - glamour.WithWordWrap(p.width-10), - glamour.WithEmoji(), - ) - - // Handle different tool types - switch p.permission.ToolName { - case tools.BashToolName: - pr := p.permission.Params.(tools.BashPermissionsParams) - headerParts = append(headerParts, keyStyle.Render("Command:")) - content := fmt.Sprintf("```bash\n%s\n```", pr.Command) - - renderedContent, _ := r.Render(content) - p.contentViewPort.Width = p.width - 2 - 2 - - // Calculate content height dynamically based on content - contentLines := len(strings.Split(renderedContent, "\n")) - // Set a reasonable min/max for the viewport height - minContentHeight := 3 - maxContentHeight := p.height - lipgloss.Height(headerContent) - lipgloss.Height(form) - 2 - 2 - 1 - - // Add some padding to the content lines - contentHeight := contentLines + 2 - contentHeight = max(contentHeight, minContentHeight) - contentHeight = min(contentHeight, maxContentHeight) - p.contentViewPort.Height = contentHeight - - p.contentViewPort.SetContent(renderedContent) + return util.CmdHandler(PermissionResponseMsg{Action: action, Permission: p.permission}) +} - // Style the viewport - var contentBorder lipgloss.Border - var borderColor lipgloss.TerminalColor +func (p *permissionDialogCmp) renderButtons() string { + allowStyle := styles.BaseStyle + allowSessionStyle := styles.BaseStyle + denyStyle := styles.BaseStyle + spacerStyle := styles.BaseStyle.Background(styles.Background) + + // Style the selected button + switch p.selectedOption { + case 0: + allowStyle = allowStyle.Background(styles.PrimaryColor).Foreground(styles.Background) + allowSessionStyle = allowSessionStyle.Background(styles.Background).Foreground(styles.PrimaryColor) + denyStyle = denyStyle.Background(styles.Background).Foreground(styles.PrimaryColor) + case 1: + allowStyle = allowStyle.Background(styles.Background).Foreground(styles.PrimaryColor) + allowSessionStyle = allowSessionStyle.Background(styles.PrimaryColor).Foreground(styles.Background) + denyStyle = denyStyle.Background(styles.Background).Foreground(styles.PrimaryColor) + case 2: + allowStyle = allowStyle.Background(styles.Background).Foreground(styles.PrimaryColor) + allowSessionStyle = allowSessionStyle.Background(styles.Background).Foreground(styles.PrimaryColor) + denyStyle = denyStyle.Background(styles.PrimaryColor).Foreground(styles.Background) + } - if p.isViewportFocus { - contentBorder = lipgloss.DoubleBorder() - borderColor = styles.Blue - } else { - contentBorder = lipgloss.RoundedBorder() - borderColor = styles.Flamingo - } + allowButton := allowStyle.Padding(0, 1).Render("Allow (a)") + allowSessionButton := allowSessionStyle.Padding(0, 1).Render("Allow for session (A)") + denyButton := denyStyle.Padding(0, 1).Render("Deny (d)") + + content := lipgloss.JoinHorizontal( + lipgloss.Left, + allowButton, + spacerStyle.Render(" "), + allowSessionButton, + spacerStyle.Render(" "), + denyButton, + spacerStyle.Render(" "), + ) - contentStyle := lipgloss.NewStyle(). - MarginTop(1). - Padding(0, 1). - Border(contentBorder). - BorderForeground(borderColor) + remainingWidth := p.width - lipgloss.Width(content) + if remainingWidth > 0 { + content = spacerStyle.Render(strings.Repeat(" ", remainingWidth)) + content + } + return content +} - if p.isViewportFocus { - contentStyle = contentStyle.BorderBackground(styles.Surface0) - } +func (p *permissionDialogCmp) renderHeader() string { + toolKey := styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render("Tool") + toolValue := styles.BaseStyle. + Foreground(styles.Forground). + Width(p.width - lipgloss.Width(toolKey)). + Render(fmt.Sprintf(": %s", p.permission.ToolName)) - contentFinal := contentStyle.Render(p.contentViewPort.View()) + pathKey := styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render("Path") + pathValue := styles.BaseStyle. + Foreground(styles.Forground). + Width(p.width - lipgloss.Width(pathKey)). + Render(fmt.Sprintf(": %s", p.permission.Path)) - return lipgloss.JoinVertical( - lipgloss.Top, - headerContent, - contentFinal, - form, - ) + headerParts := []string{ + lipgloss.JoinHorizontal( + lipgloss.Left, + toolKey, + toolValue, + ), + styles.BaseStyle.Render(strings.Repeat(" ", p.width)), + lipgloss.JoinHorizontal( + lipgloss.Left, + pathKey, + pathValue, + ), + styles.BaseStyle.Render(strings.Repeat(" ", p.width)), + } + // Add tool-specific header information + switch p.permission.ToolName { + case tools.BashToolName: + headerParts = append(headerParts, styles.BaseStyle.Foreground(styles.ForgroundDim).Width(p.width).Bold(true).Render("Command")) case tools.EditToolName: - pr := p.permission.Params.(tools.EditPermissionsParams) - headerParts = append(headerParts, keyStyle.Render("Update")) - // Recreate header content with the updated headerParts - headerContent = lipgloss.NewStyle().Padding(0, 1).Render(lipgloss.JoinVertical(lipgloss.Left, headerParts...)) - - // Format the diff with colors - - // Set up viewport for the diff content - p.contentViewPort.Width = p.width - 2 - 2 - - // Calculate content height dynamically based on window size - maxContentHeight := p.height - lipgloss.Height(headerContent) - lipgloss.Height(form) - 2 - 2 - 1 - p.contentViewPort.Height = maxContentHeight - diff, err := diff.FormatDiff(pr.Diff, diff.WithTotalWidth(p.contentViewPort.Width)) - if err != nil { - diff = fmt.Sprintf("Error formatting diff: %v", err) - } - p.contentViewPort.SetContent(diff) + headerParts = append(headerParts, styles.BaseStyle.Foreground(styles.ForgroundDim).Width(p.width).Bold(true).Render("Diff")) + case tools.WriteToolName: + headerParts = append(headerParts, styles.BaseStyle.Foreground(styles.ForgroundDim).Width(p.width).Bold(true).Render("Diff")) + case tools.FetchToolName: + headerParts = append(headerParts, styles.BaseStyle.Foreground(styles.ForgroundDim).Width(p.width).Bold(true).Render("URL")) + } - // Style the viewport - var contentBorder lipgloss.Border - var borderColor lipgloss.TerminalColor + return lipgloss.NewStyle().Render(lipgloss.JoinVertical(lipgloss.Left, headerParts...)) +} - if p.isViewportFocus { - contentBorder = lipgloss.DoubleBorder() - borderColor = styles.Blue - } else { - contentBorder = lipgloss.RoundedBorder() - borderColor = styles.Flamingo - } +func (p *permissionDialogCmp) renderBashContent() string { + if pr, ok := p.permission.Params.(tools.BashPermissionsParams); ok { + content := fmt.Sprintf("```bash\n%s\n```", pr.Command) - contentStyle := lipgloss.NewStyle(). - MarginTop(1). - Padding(0, 1). - Border(contentBorder). - BorderForeground(borderColor) + // Use the cache for markdown rendering + renderedContent := p.GetOrSetMarkdown(p.permission.ID, func() (string, error) { + r, _ := glamour.NewTermRenderer( + glamour.WithStyles(styles.MarkdownTheme(true)), + glamour.WithWordWrap(p.width-10), + ) + s, err := r.Render(content) + return styles.ForceReplaceBackgroundWithLipgloss(s, styles.Background), err + }) + + finalContent := styles.BaseStyle. + Width(p.contentViewPort.Width). + Render(renderedContent) + p.contentViewPort.SetContent(finalContent) + return p.styleViewport() + } + return "" +} - if p.isViewportFocus { - contentStyle = contentStyle.BorderBackground(styles.Surface0) - } +func (p *permissionDialogCmp) renderEditContent() string { + if pr, ok := p.permission.Params.(tools.EditPermissionsParams); ok { + diff := p.GetOrSetDiff(p.permission.ID, func() (string, error) { + return diff.FormatDiff(pr.Diff, diff.WithTotalWidth(p.contentViewPort.Width)) + }) - contentFinal := contentStyle.Render(p.contentViewPort.View()) + p.contentViewPort.SetContent(diff) + return p.styleViewport() + } + return "" +} - return lipgloss.JoinVertical( - lipgloss.Top, - headerContent, - contentFinal, - form, - ) +func (p *permissionDialogCmp) renderWriteContent() string { + if pr, ok := p.permission.Params.(tools.WritePermissionsParams); ok { + // Use the cache for diff rendering + diff := p.GetOrSetDiff(p.permission.ID, func() (string, error) { + return diff.FormatDiff(pr.Diff, diff.WithTotalWidth(p.contentViewPort.Width)) + }) - case tools.WriteToolName: - pr := p.permission.Params.(tools.WritePermissionsParams) - headerParts = append(headerParts, keyStyle.Render("Content")) - // Recreate header content with the updated headerParts - headerContent = lipgloss.NewStyle().Padding(0, 1).Render(lipgloss.JoinVertical(lipgloss.Left, headerParts...)) - - // Set up viewport for the content - p.contentViewPort.Width = p.width - 2 - 2 - - // Calculate content height dynamically based on window size - maxContentHeight := p.height - lipgloss.Height(headerContent) - lipgloss.Height(form) - 2 - 2 - 1 - p.contentViewPort.Height = maxContentHeight - diff, err := diff.FormatDiff(pr.Diff, diff.WithTotalWidth(p.contentViewPort.Width)) - if err != nil { - diff = fmt.Sprintf("Error formatting diff: %v", err) - } p.contentViewPort.SetContent(diff) + return p.styleViewport() + } + return "" +} - // Style the viewport - var contentBorder lipgloss.Border - var borderColor lipgloss.TerminalColor +func (p *permissionDialogCmp) renderFetchContent() string { + if pr, ok := p.permission.Params.(tools.FetchPermissionsParams); ok { + content := fmt.Sprintf("```bash\n%s\n```", pr.URL) - if p.isViewportFocus { - contentBorder = lipgloss.DoubleBorder() - borderColor = styles.Blue - } else { - contentBorder = lipgloss.RoundedBorder() - borderColor = styles.Flamingo - } - - contentStyle := lipgloss.NewStyle(). - MarginTop(1). - Padding(0, 1). - Border(contentBorder). - BorderForeground(borderColor) + // Use the cache for markdown rendering + renderedContent := p.GetOrSetMarkdown(p.permission.ID, func() (string, error) { + r, _ := glamour.NewTermRenderer( + glamour.WithStyles(styles.MarkdownTheme(true)), + glamour.WithWordWrap(p.width-10), + ) + s, err := r.Render(content) + return styles.ForceReplaceBackgroundWithLipgloss(s, styles.Background), err + }) - if p.isViewportFocus { - contentStyle = contentStyle.BorderBackground(styles.Surface0) - } + p.contentViewPort.SetContent(renderedContent) + return p.styleViewport() + } + return "" +} - contentFinal := contentStyle.Render(p.contentViewPort.View()) +func (p *permissionDialogCmp) renderDefaultContent() string { + content := p.permission.Description - return lipgloss.JoinVertical( - lipgloss.Top, - headerContent, - contentFinal, - form, + // Use the cache for markdown rendering + renderedContent := p.GetOrSetMarkdown(p.permission.ID, func() (string, error) { + r, _ := glamour.NewTermRenderer( + glamour.WithStyles(styles.CatppuccinMarkdownStyle()), + glamour.WithWordWrap(p.width-10), ) + s, err := r.Render(content) + return styles.ForceReplaceBackgroundWithLipgloss(s, styles.Background), err + }) - case tools.FetchToolName: - pr := p.permission.Params.(tools.FetchPermissionsParams) - headerParts = append(headerParts, keyStyle.Render("URL: "+pr.URL)) - content := p.permission.Description + p.contentViewPort.SetContent(renderedContent) - renderedContent, _ := r.Render(content) - p.contentViewPort.Width = p.width - 2 - 2 - p.contentViewPort.Height = p.height - lipgloss.Height(headerContent) - lipgloss.Height(form) - 2 - 2 - 1 - p.contentViewPort.SetContent(renderedContent) + if renderedContent == "" { + return "" + } - // Style the viewport - contentStyle := lipgloss.NewStyle(). - MarginTop(1). - Padding(0, 1). - Border(lipgloss.RoundedBorder()). - BorderForeground(styles.Flamingo) + return p.styleViewport() +} - contentFinal := contentStyle.Render(p.contentViewPort.View()) - if renderedContent == "" { - contentFinal = "" - } +func (p *permissionDialogCmp) styleViewport() string { + contentStyle := lipgloss.NewStyle(). + Background(styles.Background) - return lipgloss.JoinVertical( - lipgloss.Top, - headerContent, - contentFinal, - form, - ) + return contentStyle.Render(p.contentViewPort.View()) +} +func (p *permissionDialogCmp) render() string { + title := styles.BaseStyle. + Bold(true). + Width(p.width - 4). + Foreground(styles.PrimaryColor). + Render("Permission Required") + // Render header + headerContent := p.renderHeader() + // Render buttons + buttons := p.renderButtons() + + // Calculate content height dynamically based on window size + p.contentViewPort.Height = p.height - lipgloss.Height(headerContent) - lipgloss.Height(buttons) - 2 - lipgloss.Height(title) + p.contentViewPort.Width = p.width - 4 + + // Render content based on tool type + var contentFinal string + switch p.permission.ToolName { + case tools.BashToolName: + contentFinal = p.renderBashContent() + case tools.EditToolName: + contentFinal = p.renderEditContent() + case tools.WriteToolName: + contentFinal = p.renderWriteContent() + case tools.FetchToolName: + contentFinal = p.renderFetchContent() default: - content := p.permission.Description - - renderedContent, _ := r.Render(content) - p.contentViewPort.Width = p.width - 2 - 2 - p.contentViewPort.Height = p.height - lipgloss.Height(headerContent) - lipgloss.Height(form) - 2 - 2 - 1 - p.contentViewPort.SetContent(renderedContent) - - // Style the viewport - contentStyle := lipgloss.NewStyle(). - MarginTop(1). - Padding(0, 1). - Border(lipgloss.RoundedBorder()). - BorderForeground(styles.Flamingo) + contentFinal = p.renderDefaultContent() + } - contentFinal := contentStyle.Render(p.contentViewPort.View()) - if renderedContent == "" { - contentFinal = "" - } + content := lipgloss.JoinVertical( + lipgloss.Top, + title, + styles.BaseStyle.Render(strings.Repeat(" ", lipgloss.Width(title))), + headerContent, + contentFinal, + buttons, + ) - return lipgloss.JoinVertical( - lipgloss.Top, - headerContent, - contentFinal, - form, + return styles.BaseStyle. + Padding(1, 0, 0, 1). + Border(lipgloss.RoundedBorder()). + BorderBackground(styles.Background). + BorderForeground(styles.ForgroundDim). + Width(p.width). + Height(p.height). + Render( + content, ) - } } func (p *permissionDialogCmp) View() string { return p.render() } -func (p *permissionDialogCmp) GetSize() (int, int) { - return p.width, p.height +func (p *permissionDialogCmp) BindingKeys() []key.Binding { + return layout.KeyMapToSlice(helpKeys) } -func (p *permissionDialogCmp) SetSize(width int, height int) { - p.width = width - p.height = height - p.form = p.form.WithWidth(width) +func (p *permissionDialogCmp) SetSize() { + if p.permission.ID == "" { + return + } + switch p.permission.ToolName { + case tools.BashToolName: + p.width = int(float64(p.windowSize.Width) * 0.4) + p.height = int(float64(p.windowSize.Height) * 0.3) + case tools.EditToolName: + p.width = int(float64(p.windowSize.Width) * 0.8) + p.height = int(float64(p.windowSize.Height) * 0.8) + case tools.WriteToolName: + p.width = int(float64(p.windowSize.Width) * 0.8) + p.height = int(float64(p.windowSize.Height) * 0.8) + case tools.FetchToolName: + p.width = int(float64(p.windowSize.Width) * 0.4) + p.height = int(float64(p.windowSize.Height) * 0.3) + default: + p.width = int(float64(p.windowSize.Width) * 0.7) + p.height = int(float64(p.windowSize.Height) * 0.5) + } } -func (p *permissionDialogCmp) BindingKeys() []key.Binding { - return p.form.KeyBinds() +func (p *permissionDialogCmp) SetPermissions(permission permission.PermissionRequest) { + p.permission = permission + p.SetSize() } -func newPermissionDialogCmp(permission permission.PermissionRequest) PermissionDialog { - // Create a note field for displaying the content +// Helper to get or set cached diff content +func (c *permissionDialogCmp) GetOrSetDiff(key string, generator func() (string, error)) string { + if cached, ok := c.diffCache[key]; ok { + return cached + } - // Create select field for the permission options - selectOption := huh.NewSelect[string](). - Key("action"). - Options( - huh.NewOption("Allow", string(PermissionAllow)), - huh.NewOption("Allow for this session", string(PermissionAllowForSession)), - huh.NewOption("Deny", string(PermissionDeny)), - ). - Title("Select an action") + content, err := generator() + if err != nil { + return fmt.Sprintf("Error formatting diff: %v", err) + } - // Apply theme - theme := styles.HuhTheme() + c.diffCache[key] = content - // Setup form width and height - form := huh.NewForm(huh.NewGroup(selectOption)). - WithShowHelp(false). - WithTheme(theme). - WithShowErrors(false) + return content +} - // Focus the form for immediate interaction - selectOption.Focus() +// Helper to get or set cached markdown content +func (c *permissionDialogCmp) GetOrSetMarkdown(key string, generator func() (string, error)) string { + if cached, ok := c.markdownCache[key]; ok { + return cached + } - return &permissionDialogCmp{ - permission: permission, - form: form, - selectOption: selectOption, + content, err := generator() + if err != nil { + return fmt.Sprintf("Error rendering markdown: %v", err) } -} -// NewPermissionDialogCmd creates a new permission dialog command -func NewPermissionDialogCmd(permission permission.PermissionRequest) tea.Cmd { - permDialog := newPermissionDialogCmp(permission) - - // Create the dialog layout - dialogPane := layout.NewSinglePane( - permDialog.(*permissionDialogCmp), - layout.WithSinglePaneBordered(true), - layout.WithSinglePaneFocusable(true), - layout.WithSinglePaneActiveColor(styles.Warning), - layout.WithSinglePaneBorderText(map[layout.BorderPosition]string{ - layout.TopMiddleBorder: " Permission Required ", - }), - ) + c.markdownCache[key] = content - // Focus the dialog - dialogPane.Focus() - widthRatio := 0.7 - heightRatio := 0.6 - minWidth := 100 - minHeight := 30 + return content +} - // Make the dialog size more appropriate for different tools - switch permission.ToolName { - case tools.BashToolName: - // For bash commands, use a more compact dialog - widthRatio = 0.7 - heightRatio = 0.4 // Reduced from 0.5 - minWidth = 100 - minHeight = 20 // Reduced from 30 +func NewPermissionDialogCmp() PermissionDialogCmp { + // Create viewport for content + contentViewport := viewport.New(0, 0) + + return &permissionDialogCmp{ + contentViewPort: contentViewport, + selectedOption: 0, // Default to "Allow" + diffCache: make(map[string]string), + markdownCache: make(map[string]string), } - // Return the dialog command - return util.CmdHandler(core.DialogMsg{ - Content: dialogPane, - WidthRatio: widthRatio, - HeightRatio: heightRatio, - MinWidth: minWidth, - MinHeight: minHeight, - }) } diff --git a/internal/tui/components/dialog/quit.go b/internal/tui/components/dialog/quit.go index 60c1fc0d2..10d9ba8a2 100644 --- a/internal/tui/components/dialog/quit.go +++ b/internal/tui/components/dialog/quit.go @@ -1,28 +1,58 @@ package dialog import ( + "strings" + "github.com/charmbracelet/bubbles/key" tea "github.com/charmbracelet/bubbletea" - "github.com/kujtimiihoxha/termai/internal/tui/components/core" + "github.com/charmbracelet/lipgloss" "github.com/kujtimiihoxha/termai/internal/tui/layout" "github.com/kujtimiihoxha/termai/internal/tui/styles" "github.com/kujtimiihoxha/termai/internal/tui/util" - - "github.com/charmbracelet/huh" ) const question = "Are you sure you want to quit?" +type CloseQuitMsg struct{} + type QuitDialog interface { tea.Model - layout.Sizeable layout.Bindings } type quitDialogCmp struct { - form *huh.Form - width int - height int + selectedNo bool +} + +type helpMapping struct { + LeftRight key.Binding + EnterSpace key.Binding + Yes key.Binding + No key.Binding + Tab key.Binding +} + +var helpKeys = helpMapping{ + LeftRight: key.NewBinding( + key.WithKeys("left", "right"), + key.WithHelp("←/→", "switch options"), + ), + EnterSpace: key.NewBinding( + key.WithKeys("enter", " "), + key.WithHelp("enter/space", "confirm"), + ), + Yes: key.NewBinding( + key.WithKeys("y", "Y"), + key.WithHelp("y/Y", "yes"), + ), + No: key.NewBinding( + key.WithKeys("n", "N"), + key.WithHelp("n/N", "no"), + ), + Tab: key.NewBinding( + key.WithKeys("tab"), + key.WithHelp("tab", "switch options"), + ), } func (q *quitDialogCmp) Init() tea.Cmd { @@ -30,77 +60,73 @@ func (q *quitDialogCmp) Init() tea.Cmd { } func (q *quitDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { - var cmds []tea.Cmd - form, cmd := q.form.Update(msg) - if f, ok := form.(*huh.Form); ok { - q.form = f - cmds = append(cmds, cmd) - } - - if q.form.State == huh.StateCompleted { - v := q.form.GetBool("quit") - if v { + switch msg := msg.(type) { + case tea.KeyMsg: + switch { + case key.Matches(msg, helpKeys.LeftRight) || key.Matches(msg, helpKeys.Tab): + q.selectedNo = !q.selectedNo + return q, nil + case key.Matches(msg, helpKeys.EnterSpace): + if !q.selectedNo { + return q, tea.Quit + } + return q, util.CmdHandler(CloseQuitMsg{}) + case key.Matches(msg, helpKeys.Yes): return q, tea.Quit + case key.Matches(msg, helpKeys.No): + return q, util.CmdHandler(CloseQuitMsg{}) } - cmds = append(cmds, util.CmdHandler(core.DialogCloseMsg{})) } - - return q, tea.Batch(cmds...) + return q, nil } func (q *quitDialogCmp) View() string { - return q.form.View() -} + yesStyle := styles.BaseStyle + noStyle := styles.BaseStyle + spacerStyle := styles.BaseStyle.Background(styles.Background) + + if q.selectedNo { + noStyle = noStyle.Background(styles.PrimaryColor).Foreground(styles.Background) + yesStyle = yesStyle.Background(styles.Background).Foreground(styles.PrimaryColor) + } else { + yesStyle = yesStyle.Background(styles.PrimaryColor).Foreground(styles.Background) + noStyle = noStyle.Background(styles.Background).Foreground(styles.PrimaryColor) + } -func (q *quitDialogCmp) GetSize() (int, int) { - return q.width, q.height -} + yesButton := yesStyle.Padding(0, 1).Render("Yes") + noButton := noStyle.Padding(0, 1).Render("No") + + buttons := lipgloss.JoinHorizontal(lipgloss.Left, yesButton, spacerStyle.Render(" "), noButton) + + width := lipgloss.Width(question) + remainingWidth := width - lipgloss.Width(buttons) + if remainingWidth > 0 { + buttons = spacerStyle.Render(strings.Repeat(" ", remainingWidth)) + buttons + } -func (q *quitDialogCmp) SetSize(width int, height int) { - q.width = width - q.height = height - q.form = q.form.WithWidth(width).WithHeight(height) + content := styles.BaseStyle.Render( + lipgloss.JoinVertical( + lipgloss.Center, + question, + "", + buttons, + ), + ) + + return styles.BaseStyle.Padding(1, 2). + Border(lipgloss.RoundedBorder()). + BorderBackground(styles.Background). + BorderForeground(styles.ForgroundDim). + Width(lipgloss.Width(content) + 4). + Render(content) } func (q *quitDialogCmp) BindingKeys() []key.Binding { - return q.form.KeyBinds() + return layout.KeyMapToSlice(helpKeys) } -func newQuitDialogCmp() QuitDialog { - confirm := huh.NewConfirm(). - Title(question). - Affirmative("Yes!"). - Key("quit"). - Negative("No.") - - theme := styles.HuhTheme() - theme.Focused.FocusedButton = theme.Focused.FocusedButton.Background(styles.Warning) - theme.Blurred.FocusedButton = theme.Blurred.FocusedButton.Background(styles.Warning) - form := huh.NewForm(huh.NewGroup(confirm)). - WithShowHelp(false). - WithWidth(0). - WithHeight(0). - WithTheme(theme). - WithShowErrors(false) - confirm.Focus() +func NewQuitCmp() QuitDialog { return &quitDialogCmp{ - form: form, + selectedNo: true, } } - -func NewQuitDialogCmd() tea.Cmd { - content := layout.NewSinglePane( - newQuitDialogCmp().(*quitDialogCmp), - layout.WithSinglePaneBordered(true), - layout.WithSinglePaneFocusable(true), - layout.WithSinglePaneActiveColor(styles.Warning), - ) - content.Focus() - return util.CmdHandler(core.DialogMsg{ - Content: content, - WidthRatio: 0.2, - HeightRatio: 0.1, - MinWidth: 40, - MinHeight: 5, - }) -} diff --git a/internal/tui/components/logs/details.go b/internal/tui/components/logs/details.go index dbace5508..18eb1a526 100644 --- a/internal/tui/components/logs/details.go +++ b/internal/tui/components/logs/details.go @@ -16,10 +16,8 @@ import ( type DetailComponent interface { tea.Model - layout.Focusable layout.Sizeable layout.Bindings - layout.Bordered } type detailCmp struct { diff --git a/internal/tui/components/logs/table.go b/internal/tui/components/logs/table.go index 9500059b1..6e8eb58b1 100644 --- a/internal/tui/components/logs/table.go +++ b/internal/tui/components/logs/table.go @@ -16,22 +16,14 @@ import ( type TableComponent interface { tea.Model - layout.Focusable layout.Sizeable layout.Bindings - layout.Bordered } type tableCmp struct { table table.Model } -func (i *tableCmp) BorderText() map[layout.BorderPosition]string { - return map[layout.BorderPosition]string{ - layout.TopLeftBorder: "Logs", - } -} - type selectedLogMsg logging.LogMessage func (i *tableCmp) Init() tea.Cmd { @@ -74,20 +66,6 @@ func (i *tableCmp) View() string { return i.table.View() } -func (i *tableCmp) Blur() tea.Cmd { - i.table.Blur() - return nil -} - -func (i *tableCmp) Focus() tea.Cmd { - i.table.Focus() - return nil -} - -func (i *tableCmp) IsFocused() bool { - return i.table.Focused() -} - func (i *tableCmp) GetSize() (int, int) { return i.table.Width(), i.table.Height() } diff --git a/internal/tui/components/repl/editor.go b/internal/tui/components/repl/editor.go deleted file mode 100644 index b659775e0..000000000 --- a/internal/tui/components/repl/editor.go +++ /dev/null @@ -1,201 +0,0 @@ -package repl - -import ( - "strings" - - "github.com/charmbracelet/bubbles/key" - tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/app" - "github.com/kujtimiihoxha/termai/internal/tui/layout" - "github.com/kujtimiihoxha/termai/internal/tui/styles" - "github.com/kujtimiihoxha/termai/internal/tui/util" - "github.com/kujtimiihoxha/vimtea" - "golang.org/x/net/context" -) - -type EditorCmp interface { - tea.Model - layout.Focusable - layout.Sizeable - layout.Bordered - layout.Bindings -} - -type editorCmp struct { - app *app.App - editor vimtea.Editor - editorMode vimtea.EditorMode - sessionID string - focused bool - width int - height int - cancelMessage context.CancelFunc -} - -type editorKeyMap struct { - SendMessage key.Binding - SendMessageI key.Binding - CancelMessage key.Binding - InsertMode key.Binding - NormaMode key.Binding - VisualMode key.Binding - VisualLineMode key.Binding -} - -var editorKeyMapValue = editorKeyMap{ - SendMessage: key.NewBinding( - key.WithKeys("enter"), - key.WithHelp("enter", "send message normal mode"), - ), - SendMessageI: key.NewBinding( - key.WithKeys("ctrl+s"), - key.WithHelp("ctrl+s", "send message insert mode"), - ), - CancelMessage: key.NewBinding( - key.WithKeys("ctrl+x"), - key.WithHelp("ctrl+x", "cancel current message"), - ), - InsertMode: key.NewBinding( - key.WithKeys("i"), - key.WithHelp("i", "insert mode"), - ), - NormaMode: key.NewBinding( - key.WithKeys("esc"), - key.WithHelp("esc", "normal mode"), - ), - VisualMode: key.NewBinding( - key.WithKeys("v"), - key.WithHelp("v", "visual mode"), - ), - VisualLineMode: key.NewBinding( - key.WithKeys("V"), - key.WithHelp("V", "visual line mode"), - ), -} - -func (m *editorCmp) Init() tea.Cmd { - return m.editor.Init() -} - -func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { - switch msg := msg.(type) { - case vimtea.EditorModeMsg: - m.editorMode = msg.Mode - case SelectedSessionMsg: - if msg.SessionID != m.sessionID { - m.sessionID = msg.SessionID - } - } - if m.IsFocused() { - switch msg := msg.(type) { - case tea.KeyMsg: - switch { - case key.Matches(msg, editorKeyMapValue.SendMessage): - if m.editorMode == vimtea.ModeNormal { - return m, m.Send() - } - case key.Matches(msg, editorKeyMapValue.SendMessageI): - if m.editorMode == vimtea.ModeInsert { - return m, m.Send() - } - case key.Matches(msg, editorKeyMapValue.CancelMessage): - return m, m.Cancel() - } - } - u, cmd := m.editor.Update(msg) - m.editor = u.(vimtea.Editor) - return m, cmd - } - return m, nil -} - -func (m *editorCmp) Blur() tea.Cmd { - m.focused = false - return nil -} - -func (m *editorCmp) BorderText() map[layout.BorderPosition]string { - title := "New Message" - if m.focused { - title = lipgloss.NewStyle().Foreground(styles.Primary).Render(title) - } - return map[layout.BorderPosition]string{ - layout.BottomLeftBorder: title, - } -} - -func (m *editorCmp) Focus() tea.Cmd { - m.focused = true - return m.editor.Tick() -} - -func (m *editorCmp) GetSize() (int, int) { - return m.width, m.height -} - -func (m *editorCmp) IsFocused() bool { - return m.focused -} - -func (m *editorCmp) SetSize(width int, height int) { - m.width = width - m.height = height - m.editor.SetSize(width, height) -} - -func (m *editorCmp) Cancel() tea.Cmd { - if m.cancelMessage == nil { - return util.ReportWarn("No message to cancel") - } - - m.cancelMessage() - m.cancelMessage = nil - return util.ReportWarn("Message cancelled") -} - -func (m *editorCmp) Send() tea.Cmd { - if m.cancelMessage != nil { - return util.ReportWarn("Assistant is still working on the previous message") - } - - messages, err := m.app.Messages.List(context.Background(), m.sessionID) - if err != nil { - return util.ReportError(err) - } - if hasUnfinishedMessages(messages) { - return util.ReportWarn("Assistant is still working on the previous message") - } - - content := strings.Join(m.editor.GetBuffer().Lines(), "\n") - if len(content) == 0 { - return util.ReportWarn("Message is empty") - } - ctx, cancel := context.WithCancel(context.Background()) - m.cancelMessage = cancel - go func() { - defer cancel() - m.app.CoderAgent.Generate(ctx, m.sessionID, content) - m.cancelMessage = nil - }() - - return m.editor.Reset() -} - -func (m *editorCmp) View() string { - return m.editor.View() -} - -func (m *editorCmp) BindingKeys() []key.Binding { - return layout.KeyMapToSlice(editorKeyMapValue) -} - -func NewEditorCmp(app *app.App) EditorCmp { - editor := vimtea.NewEditor( - vimtea.WithFileName("message.md"), - ) - return &editorCmp{ - app: app, - editor: editor, - } -} diff --git a/internal/tui/components/repl/messages.go b/internal/tui/components/repl/messages.go deleted file mode 100644 index 260be220e..000000000 --- a/internal/tui/components/repl/messages.go +++ /dev/null @@ -1,513 +0,0 @@ -package repl - -import ( - "context" - "encoding/json" - "fmt" - "sort" - "strings" - - "github.com/charmbracelet/bubbles/key" - "github.com/charmbracelet/bubbles/viewport" - tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/glamour" - "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/app" - "github.com/kujtimiihoxha/termai/internal/llm/agent" - "github.com/kujtimiihoxha/termai/internal/lsp/protocol" - "github.com/kujtimiihoxha/termai/internal/message" - "github.com/kujtimiihoxha/termai/internal/pubsub" - "github.com/kujtimiihoxha/termai/internal/session" - "github.com/kujtimiihoxha/termai/internal/tui/layout" - "github.com/kujtimiihoxha/termai/internal/tui/styles" -) - -type MessagesCmp interface { - tea.Model - layout.Focusable - layout.Bordered - layout.Sizeable - layout.Bindings -} - -type messagesCmp struct { - app *app.App - messages []message.Message - selectedMsgIdx int // Index of the selected message - session session.Session - viewport viewport.Model - mdRenderer *glamour.TermRenderer - width int - height int - focused bool - cachedView string -} - -func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { - switch msg := msg.(type) { - case pubsub.Event[message.Message]: - if msg.Type == pubsub.CreatedEvent { - if msg.Payload.SessionID == m.session.ID { - m.messages = append(m.messages, msg.Payload) - m.renderView() - m.viewport.GotoBottom() - } - for _, v := range m.messages { - for _, c := range v.ToolCalls() { - // the message is being added to the session of a tool called - if c.ID == msg.Payload.SessionID { - m.renderView() - m.viewport.GotoBottom() - } - } - } - } else if msg.Type == pubsub.UpdatedEvent && msg.Payload.SessionID == m.session.ID { - for i, v := range m.messages { - if v.ID == msg.Payload.ID { - m.messages[i] = msg.Payload - m.renderView() - if i == len(m.messages)-1 { - m.viewport.GotoBottom() - } - break - } - } - } - case pubsub.Event[session.Session]: - if msg.Type == pubsub.UpdatedEvent && m.session.ID == msg.Payload.ID { - m.session = msg.Payload - } - case SelectedSessionMsg: - m.session, _ = m.app.Sessions.Get(context.Background(), msg.SessionID) - m.messages, _ = m.app.Messages.List(context.Background(), m.session.ID) - m.renderView() - m.viewport.GotoBottom() - } - if m.focused { - u, cmd := m.viewport.Update(msg) - m.viewport = u - return m, cmd - } - return m, nil -} - -func borderColor(role message.MessageRole) lipgloss.TerminalColor { - switch role { - case message.Assistant: - return styles.Mauve - case message.User: - return styles.Rosewater - } - return styles.Blue -} - -func borderText(msgRole message.MessageRole, currentMessage int) map[layout.BorderPosition]string { - role := "" - icon := "" - switch msgRole { - case message.Assistant: - role = "Assistant" - icon = styles.BotIcon - case message.User: - role = "User" - icon = styles.UserIcon - } - return map[layout.BorderPosition]string{ - layout.TopLeftBorder: lipgloss.NewStyle(). - Padding(0, 1). - Bold(true). - Foreground(styles.Crust). - Background(borderColor(msgRole)). - Render(fmt.Sprintf("%s %s ", role, icon)), - layout.TopRightBorder: lipgloss.NewStyle(). - Padding(0, 1). - Bold(true). - Foreground(styles.Crust). - Background(borderColor(msgRole)). - Render(fmt.Sprintf("#%d ", currentMessage)), - } -} - -func hasUnfinishedMessages(messages []message.Message) bool { - if len(messages) == 0 { - return false - } - for _, msg := range messages { - if !msg.IsFinished() { - return true - } - } - return false -} - -func (m *messagesCmp) renderMessageWithToolCall(content string, tools []message.ToolCall, futureMessages []message.Message) string { - allParts := []string{content} - - leftPaddingValue := 4 - connectorStyle := lipgloss.NewStyle(). - Foreground(styles.Peach). - Bold(true) - - toolCallStyle := lipgloss.NewStyle(). - Border(lipgloss.RoundedBorder()). - BorderForeground(styles.Peach). - Width(m.width-leftPaddingValue-5). - Padding(0, 1) - - toolResultStyle := lipgloss.NewStyle(). - Border(lipgloss.RoundedBorder()). - BorderForeground(styles.Green). - Width(m.width-leftPaddingValue-5). - Padding(0, 1) - - leftPadding := lipgloss.NewStyle().Padding(0, 0, 0, leftPaddingValue) - - runningStyle := lipgloss.NewStyle(). - Foreground(styles.Peach). - Bold(true) - - renderTool := func(toolCall message.ToolCall) string { - toolHeader := lipgloss.NewStyle(). - Bold(true). - Foreground(styles.Blue). - Render(fmt.Sprintf("%s %s", styles.ToolIcon, toolCall.Name)) - - var paramLines []string - var args map[string]interface{} - var paramOrder []string - - json.Unmarshal([]byte(toolCall.Input), &args) - - for key := range args { - paramOrder = append(paramOrder, key) - } - sort.Strings(paramOrder) - - for _, name := range paramOrder { - value := args[name] - paramName := lipgloss.NewStyle(). - Foreground(styles.Peach). - Bold(true). - Render(name) - - truncate := m.width - leftPaddingValue*2 - 10 - if len(fmt.Sprintf("%v", value)) > truncate { - value = fmt.Sprintf("%v", value)[:truncate] + lipgloss.NewStyle().Foreground(styles.Blue).Render("... (truncated)") - } - paramValue := fmt.Sprintf("%v", value) - paramLines = append(paramLines, fmt.Sprintf(" %s: %s", paramName, paramValue)) - } - - paramBlock := lipgloss.JoinVertical(lipgloss.Left, paramLines...) - - toolContent := lipgloss.JoinVertical(lipgloss.Left, toolHeader, paramBlock) - return toolCallStyle.Render(toolContent) - } - - findToolResult := func(toolCallID string, messages []message.Message) *message.ToolResult { - for _, msg := range messages { - if msg.Role == message.Tool { - for _, result := range msg.ToolResults() { - if result.ToolCallID == toolCallID { - return &result - } - } - } - } - return nil - } - - renderToolResult := func(result message.ToolResult) string { - resultHeader := lipgloss.NewStyle(). - Bold(true). - Foreground(styles.Green). - Render(fmt.Sprintf("%s Result", styles.CheckIcon)) - - // Use the same style for both header and border if it's an error - borderColor := styles.Green - if result.IsError { - resultHeader = lipgloss.NewStyle(). - Bold(true). - Foreground(styles.Red). - Render(fmt.Sprintf("%s Error", styles.ErrorIcon)) - borderColor = styles.Red - } - - truncate := 200 - content := result.Content - if len(content) > truncate { - content = content[:truncate] + lipgloss.NewStyle().Foreground(styles.Blue).Render("... (truncated)") - } - - resultContent := lipgloss.JoinVertical(lipgloss.Left, resultHeader, content) - return toolResultStyle.BorderForeground(borderColor).Render(resultContent) - } - - connector := connectorStyle.Render("└─> Tool Calls:") - allParts = append(allParts, connector) - - for _, toolCall := range tools { - toolOutput := renderTool(toolCall) - allParts = append(allParts, leftPadding.Render(toolOutput)) - - result := findToolResult(toolCall.ID, futureMessages) - if result != nil { - - resultOutput := renderToolResult(*result) - allParts = append(allParts, leftPadding.Render(resultOutput)) - - } else if toolCall.Name == agent.AgentToolName { - - runningIndicator := runningStyle.Render(fmt.Sprintf("%s Running...", styles.SpinnerIcon)) - allParts = append(allParts, leftPadding.Render(runningIndicator)) - taskSessionMessages, _ := m.app.Messages.List(context.Background(), toolCall.ID) - for _, msg := range taskSessionMessages { - if msg.Role == message.Assistant { - for _, toolCall := range msg.ToolCalls() { - toolHeader := lipgloss.NewStyle(). - Bold(true). - Foreground(styles.Blue). - Render(fmt.Sprintf("%s %s", styles.ToolIcon, toolCall.Name)) - - var paramLines []string - var args map[string]interface{} - var paramOrder []string - - json.Unmarshal([]byte(toolCall.Input), &args) - - for key := range args { - paramOrder = append(paramOrder, key) - } - sort.Strings(paramOrder) - - for _, name := range paramOrder { - value := args[name] - paramName := lipgloss.NewStyle(). - Foreground(styles.Peach). - Bold(true). - Render(name) - - truncate := 50 - if len(fmt.Sprintf("%v", value)) > truncate { - value = fmt.Sprintf("%v", value)[:truncate] + lipgloss.NewStyle().Foreground(styles.Blue).Render("... (truncated)") - } - paramValue := fmt.Sprintf("%v", value) - paramLines = append(paramLines, fmt.Sprintf(" %s: %s", paramName, paramValue)) - } - - paramBlock := lipgloss.JoinVertical(lipgloss.Left, paramLines...) - toolContent := lipgloss.JoinVertical(lipgloss.Left, toolHeader, paramBlock) - toolOutput := toolCallStyle.BorderForeground(styles.Teal).MaxWidth(m.width - leftPaddingValue*2 - 2).Render(toolContent) - allParts = append(allParts, lipgloss.NewStyle().Padding(0, 0, 0, leftPaddingValue*2).Render(toolOutput)) - } - } - } - - } else { - runningIndicator := runningStyle.Render(fmt.Sprintf("%s Running...", styles.SpinnerIcon)) - allParts = append(allParts, " "+runningIndicator) - } - } - - for _, msg := range futureMessages { - if msg.Content().String() != "" || msg.FinishReason() == "canceled" { - break - } - - for _, toolCall := range msg.ToolCalls() { - toolOutput := renderTool(toolCall) - allParts = append(allParts, " "+strings.ReplaceAll(toolOutput, "\n", "\n ")) - - result := findToolResult(toolCall.ID, futureMessages) - if result != nil { - resultOutput := renderToolResult(*result) - allParts = append(allParts, " "+strings.ReplaceAll(resultOutput, "\n", "\n ")) - } else { - runningIndicator := runningStyle.Render(fmt.Sprintf("%s Running...", styles.SpinnerIcon)) - allParts = append(allParts, " "+runningIndicator) - } - } - } - - return lipgloss.JoinVertical(lipgloss.Left, allParts...) -} - -func (m *messagesCmp) renderView() { - stringMessages := make([]string, 0) - r, _ := glamour.NewTermRenderer( - glamour.WithStyles(styles.CatppuccinMarkdownStyle()), - glamour.WithWordWrap(m.width-20), - glamour.WithEmoji(), - ) - textStyle := lipgloss.NewStyle().Width(m.width - 4) - currentMessage := 1 - displayedMsgCount := 0 // Track the actual displayed messages count - - prevMessageWasUser := false - for inx, msg := range m.messages { - content := msg.Content().String() - if content != "" || prevMessageWasUser || msg.FinishReason() == "canceled" { - if msg.ReasoningContent().String() != "" && content == "" { - content = msg.ReasoningContent().String() - } else if content == "" { - content = "..." - } - if msg.FinishReason() == "canceled" { - content, _ = r.Render(content) - content += lipgloss.NewStyle().Padding(1, 0, 0, 1).Foreground(styles.Error).Render(styles.ErrorIcon + " Canceled") - } else { - content, _ = r.Render(content) - } - - isSelected := inx == m.selectedMsgIdx - - border := lipgloss.DoubleBorder() - activeColor := borderColor(msg.Role) - - if isSelected { - activeColor = styles.Primary // Use primary color for selected message - } - - content = layout.Borderize( - textStyle.Render(content), - layout.BorderOptions{ - InactiveBorder: border, - ActiveBorder: border, - ActiveColor: activeColor, - InactiveColor: borderColor(msg.Role), - EmbeddedText: borderText(msg.Role, currentMessage), - }, - ) - if len(msg.ToolCalls()) > 0 { - content = m.renderMessageWithToolCall(content, msg.ToolCalls(), m.messages[inx+1:]) - } - stringMessages = append(stringMessages, content) - currentMessage++ - displayedMsgCount++ - } - if msg.Role == message.User && msg.Content().String() != "" { - prevMessageWasUser = true - } else { - prevMessageWasUser = false - } - } - m.viewport.SetContent(lipgloss.JoinVertical(lipgloss.Top, stringMessages...)) -} - -func (m *messagesCmp) View() string { - return lipgloss.NewStyle().Padding(1).Render(m.viewport.View()) -} - -func (m *messagesCmp) BindingKeys() []key.Binding { - keys := layout.KeyMapToSlice(m.viewport.KeyMap) - - return keys -} - -func (m *messagesCmp) Blur() tea.Cmd { - m.focused = false - return nil -} - -func (m *messagesCmp) projectDiagnostics() string { - errorDiagnostics := []protocol.Diagnostic{} - warnDiagnostics := []protocol.Diagnostic{} - hintDiagnostics := []protocol.Diagnostic{} - infoDiagnostics := []protocol.Diagnostic{} - for _, client := range m.app.LSPClients { - for _, d := range client.GetDiagnostics() { - for _, diag := range d { - switch diag.Severity { - case protocol.SeverityError: - errorDiagnostics = append(errorDiagnostics, diag) - case protocol.SeverityWarning: - warnDiagnostics = append(warnDiagnostics, diag) - case protocol.SeverityHint: - hintDiagnostics = append(hintDiagnostics, diag) - case protocol.SeverityInformation: - infoDiagnostics = append(infoDiagnostics, diag) - } - } - } - } - - if len(errorDiagnostics) == 0 && len(warnDiagnostics) == 0 && len(hintDiagnostics) == 0 && len(infoDiagnostics) == 0 { - return "No diagnostics" - } - - diagnostics := []string{} - - if len(errorDiagnostics) > 0 { - errStr := lipgloss.NewStyle().Foreground(styles.Error).Render(fmt.Sprintf("%s %d", styles.ErrorIcon, len(errorDiagnostics))) - diagnostics = append(diagnostics, errStr) - } - if len(warnDiagnostics) > 0 { - warnStr := lipgloss.NewStyle().Foreground(styles.Warning).Render(fmt.Sprintf("%s %d", styles.WarningIcon, len(warnDiagnostics))) - diagnostics = append(diagnostics, warnStr) - } - if len(hintDiagnostics) > 0 { - hintStr := lipgloss.NewStyle().Foreground(styles.Text).Render(fmt.Sprintf("%s %d", styles.HintIcon, len(hintDiagnostics))) - diagnostics = append(diagnostics, hintStr) - } - if len(infoDiagnostics) > 0 { - infoStr := lipgloss.NewStyle().Foreground(styles.Peach).Render(fmt.Sprintf("%s %d", styles.InfoIcon, len(infoDiagnostics))) - diagnostics = append(diagnostics, infoStr) - } - - return strings.Join(diagnostics, " ") -} - -func (m *messagesCmp) BorderText() map[layout.BorderPosition]string { - title := m.session.Title - titleWidth := m.width / 2 - if len(title) > titleWidth { - title = title[:titleWidth] + "..." - } - if m.focused { - title = lipgloss.NewStyle().Foreground(styles.Primary).Render(title) - } - borderTest := map[layout.BorderPosition]string{ - layout.TopLeftBorder: title, - layout.BottomRightBorder: m.projectDiagnostics(), - } - if hasUnfinishedMessages(m.messages) { - borderTest[layout.BottomLeftBorder] = lipgloss.NewStyle().Foreground(styles.Peach).Render("Thinking...") - } else { - borderTest[layout.BottomLeftBorder] = lipgloss.NewStyle().Foreground(styles.Text).Render("Sleeping " + styles.SleepIcon + " ") - } - - return borderTest -} - -func (m *messagesCmp) Focus() tea.Cmd { - m.focused = true - return nil -} - -func (m *messagesCmp) GetSize() (int, int) { - return m.width, m.height -} - -func (m *messagesCmp) IsFocused() bool { - return m.focused -} - -func (m *messagesCmp) SetSize(width int, height int) { - m.width = width - m.height = height - m.viewport.Width = width - 2 // padding - m.viewport.Height = height - 2 // padding - m.renderView() -} - -func (m *messagesCmp) Init() tea.Cmd { - return nil -} - -func NewMessagesCmp(app *app.App) MessagesCmp { - return &messagesCmp{ - app: app, - messages: []message.Message{}, - viewport: viewport.New(0, 0), - } -} diff --git a/internal/tui/components/repl/sessions.go b/internal/tui/components/repl/sessions.go deleted file mode 100644 index c83c40367..000000000 --- a/internal/tui/components/repl/sessions.go +++ /dev/null @@ -1,249 +0,0 @@ -package repl - -import ( - "context" - "fmt" - "strings" - - "github.com/charmbracelet/bubbles/key" - "github.com/charmbracelet/bubbles/list" - tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/app" - "github.com/kujtimiihoxha/termai/internal/pubsub" - "github.com/kujtimiihoxha/termai/internal/session" - "github.com/kujtimiihoxha/termai/internal/tui/layout" - "github.com/kujtimiihoxha/termai/internal/tui/styles" - "github.com/kujtimiihoxha/termai/internal/tui/util" -) - -type SessionsCmp interface { - tea.Model - layout.Sizeable - layout.Focusable - layout.Bordered - layout.Bindings -} -type sessionsCmp struct { - app *app.App - list list.Model - focused bool -} - -type listItem struct { - id, title, desc string -} - -func (i listItem) Title() string { return i.title } -func (i listItem) Description() string { return i.desc } -func (i listItem) FilterValue() string { return i.title } - -type InsertSessionsMsg struct { - sessions []session.Session -} - -type SelectedSessionMsg struct { - SessionID string -} - -type sessionsKeyMap struct { - Select key.Binding -} - -var sessionKeyMapValue = sessionsKeyMap{ - Select: key.NewBinding( - key.WithKeys("enter", " "), - key.WithHelp("enter/space", "select session"), - ), -} - -func (i *sessionsCmp) Init() tea.Cmd { - existing, err := i.app.Sessions.List(context.Background()) - if err != nil { - return util.ReportError(err) - } - if len(existing) == 0 || existing[0].MessageCount > 0 { - newSession, err := i.app.Sessions.Create( - context.Background(), - "New Session", - ) - if err != nil { - return util.ReportError(err) - } - existing = append([]session.Session{newSession}, existing...) - } - return tea.Batch( - util.CmdHandler(InsertSessionsMsg{existing}), - util.CmdHandler(SelectedSessionMsg{existing[0].ID}), - ) -} - -func (i *sessionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { - switch msg := msg.(type) { - case InsertSessionsMsg: - items := make([]list.Item, len(msg.sessions)) - for i, s := range msg.sessions { - items[i] = listItem{ - id: s.ID, - title: s.Title, - desc: formatTokensAndCost(s.PromptTokens+s.CompletionTokens, s.Cost), - } - } - return i, i.list.SetItems(items) - case pubsub.Event[session.Session]: - if msg.Type == pubsub.CreatedEvent && msg.Payload.ParentSessionID == "" { - // Check if the session is already in the list - items := i.list.Items() - for _, item := range items { - s := item.(listItem) - if s.id == msg.Payload.ID { - return i, nil - } - } - // insert the new session at the top of the list - items = append([]list.Item{listItem{ - id: msg.Payload.ID, - title: msg.Payload.Title, - desc: formatTokensAndCost(msg.Payload.PromptTokens+msg.Payload.CompletionTokens, msg.Payload.Cost), - }}, items...) - return i, i.list.SetItems(items) - } else if msg.Type == pubsub.UpdatedEvent { - // update the session in the list - items := i.list.Items() - for idx, item := range items { - s := item.(listItem) - if s.id == msg.Payload.ID { - s.title = msg.Payload.Title - s.desc = formatTokensAndCost(msg.Payload.PromptTokens+msg.Payload.CompletionTokens, msg.Payload.Cost) - items[idx] = s - break - } - } - return i, i.list.SetItems(items) - } - - case tea.KeyMsg: - switch { - case key.Matches(msg, sessionKeyMapValue.Select): - selected := i.list.SelectedItem() - if selected == nil { - return i, nil - } - return i, util.CmdHandler(SelectedSessionMsg{selected.(listItem).id}) - } - } - if i.focused { - u, cmd := i.list.Update(msg) - i.list = u - return i, cmd - } - return i, nil -} - -func (i *sessionsCmp) View() string { - return i.list.View() -} - -func (i *sessionsCmp) Blur() tea.Cmd { - i.focused = false - return nil -} - -func (i *sessionsCmp) Focus() tea.Cmd { - i.focused = true - return nil -} - -func (i *sessionsCmp) GetSize() (int, int) { - return i.list.Width(), i.list.Height() -} - -func (i *sessionsCmp) IsFocused() bool { - return i.focused -} - -func (i *sessionsCmp) SetSize(width int, height int) { - i.list.SetSize(width, height) -} - -func (i *sessionsCmp) BorderText() map[layout.BorderPosition]string { - totalCount := len(i.list.Items()) - itemsPerPage := i.list.Paginator.PerPage - currentPage := i.list.Paginator.Page - - current := min(currentPage*itemsPerPage+itemsPerPage, totalCount) - - pageInfo := fmt.Sprintf( - "%d-%d of %d", - currentPage*itemsPerPage+1, - current, - totalCount, - ) - - title := "Sessions" - if i.focused { - title = lipgloss.NewStyle().Foreground(styles.Primary).Render(title) - } - return map[layout.BorderPosition]string{ - layout.TopMiddleBorder: title, - layout.BottomMiddleBorder: pageInfo, - } -} - -func (i *sessionsCmp) BindingKeys() []key.Binding { - return append(layout.KeyMapToSlice(i.list.KeyMap), sessionKeyMapValue.Select) -} - -func formatTokensAndCost(tokens int64, cost float64) string { - // Format tokens in human-readable format (e.g., 110K, 1.2M) - var formattedTokens string - switch { - case tokens >= 1_000_000: - formattedTokens = fmt.Sprintf("%.1fM", float64(tokens)/1_000_000) - case tokens >= 1_000: - formattedTokens = fmt.Sprintf("%.1fK", float64(tokens)/1_000) - default: - formattedTokens = fmt.Sprintf("%d", tokens) - } - - // Remove .0 suffix if present - if strings.HasSuffix(formattedTokens, ".0K") { - formattedTokens = strings.Replace(formattedTokens, ".0K", "K", 1) - } - if strings.HasSuffix(formattedTokens, ".0M") { - formattedTokens = strings.Replace(formattedTokens, ".0M", "M", 1) - } - - // Format cost with $ symbol and 2 decimal places - formattedCost := fmt.Sprintf("$%.2f", cost) - - return fmt.Sprintf("Tokens: %s, Cost: %s", formattedTokens, formattedCost) -} - -func NewSessionsCmp(app *app.App) SessionsCmp { - listDelegate := list.NewDefaultDelegate() - defaultItemStyle := list.NewDefaultItemStyles() - defaultItemStyle.SelectedTitle = defaultItemStyle.SelectedTitle.BorderForeground(styles.Secondary).Foreground(styles.Primary) - defaultItemStyle.SelectedDesc = defaultItemStyle.SelectedDesc.BorderForeground(styles.Secondary).Foreground(styles.Primary) - - defaultStyle := list.DefaultStyles() - defaultStyle.FilterPrompt = defaultStyle.FilterPrompt.Foreground(styles.Secondary) - defaultStyle.FilterCursor = defaultStyle.FilterCursor.Foreground(styles.Flamingo) - - listDelegate.Styles = defaultItemStyle - - listComponent := list.New([]list.Item{}, listDelegate, 0, 0) - listComponent.FilterInput.PromptStyle = defaultStyle.FilterPrompt - listComponent.FilterInput.Cursor.Style = defaultStyle.FilterCursor - listComponent.SetShowTitle(false) - listComponent.SetShowPagination(false) - listComponent.SetShowHelp(false) - listComponent.SetShowStatusBar(false) - listComponent.DisableQuitKeybindings() - - return &sessionsCmp{ - app: app, - list: listComponent, - focused: false, - } -} diff --git a/internal/tui/layout/overlay.go b/internal/tui/layout/overlay.go index 22f9e00fe..4a1bcf661 100644 --- a/internal/tui/layout/overlay.go +++ b/internal/tui/layout/overlay.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/charmbracelet/lipgloss" + "github.com/kujtimiihoxha/termai/internal/tui/styles" "github.com/kujtimiihoxha/termai/internal/tui/util" "github.com/mattn/go-runewidth" "github.com/muesli/ansi" @@ -45,13 +46,15 @@ func PlaceOverlay( if shadow { var shadowbg string = "" shadowchar := lipgloss.NewStyle(). - Foreground(lipgloss.Color("#333333")). + Background(styles.BackgroundDarker). + Foreground(styles.Background). Render("░") + bgchar := styles.BaseStyle.Render(" ") for i := 0; i <= fgHeight; i++ { if i == 0 { - shadowbg += " " + strings.Repeat(" ", fgWidth) + "\n" + shadowbg += bgchar + strings.Repeat(bgchar, fgWidth) + "\n" } else { - shadowbg += " " + strings.Repeat(shadowchar, fgWidth) + "\n" + shadowbg += bgchar + strings.Repeat(shadowchar, fgWidth) + "\n" } } @@ -159,8 +162,6 @@ func max(a, b int) int { return b } - - type whitespace struct { style termenv.Style chars string diff --git a/internal/tui/layout/split.go b/internal/tui/layout/split.go index 0ed85dd6f..6482fc74c 100644 --- a/internal/tui/layout/split.go +++ b/internal/tui/layout/split.go @@ -10,6 +10,7 @@ import ( type SplitPaneLayout interface { tea.Model Sizeable + Bindings SetLeftPanel(panel Container) SetRightPanel(panel Container) SetBottomPanel(panel Container) diff --git a/internal/tui/page/chat.go b/internal/tui/page/chat.go index 439c89e1f..cebc0e461 100644 --- a/internal/tui/page/chat.go +++ b/internal/tui/page/chat.go @@ -37,7 +37,6 @@ var keyMap = ChatKeyMap{ } func (p *chatPage) Init() tea.Cmd { - // TODO: remove cmds := []tea.Cmd{ p.layout.Init(), } @@ -48,9 +47,7 @@ func (p *chatPage) Init() tea.Cmd { cmd := p.setSidebar() cmds = append(cmds, util.CmdHandler(chat.SessionSelectedMsg(p.session)), cmd) } - return tea.Batch( - cmds..., - ) + return tea.Batch(cmds...) } func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { @@ -68,6 +65,13 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { p.session = session.Session{} p.clearSidebar() return p, util.CmdHandler(chat.SessionClearedMsg{}) + case key.Matches(msg, keyMap.Cancel): + if p.session.ID != "" { + // Cancel the current session's generation process + // This allows users to interrupt long-running operations + p.app.CoderAgent.Cancel(p.session.ID) + return p, nil + } } } u, cmd := p.layout.Update(msg) @@ -80,7 +84,7 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { func (p *chatPage) setSidebar() tea.Cmd { sidebarContainer := layout.NewContainer( - chat.NewSidebarCmp(p.session), + chat.NewSidebarCmp(p.session, p.app.History), layout.WithPadding(1, 1, 1, 1), ) p.layout.SetRightPanel(sidebarContainer) @@ -111,14 +115,28 @@ func (p *chatPage) sendMessage(text string) tea.Cmd { cmds = append(cmds, util.CmdHandler(chat.SessionSelectedMsg(session))) } - p.app.CoderAgent.Generate(context.Background(), p.session.ID, text) + p.app.CoderAgent.Run(context.Background(), p.session.ID, text) return tea.Batch(cmds...) } +func (p *chatPage) SetSize(width, height int) { + p.layout.SetSize(width, height) +} + +func (p *chatPage) GetSize() (int, int) { + return p.layout.GetSize() +} + func (p *chatPage) View() string { return p.layout.View() } +func (p *chatPage) BindingKeys() []key.Binding { + bindings := layout.KeyMapToSlice(keyMap) + bindings = append(bindings, p.layout.BindingKeys()...) + return bindings +} + func NewChatPage(app *app.App) tea.Model { messagesContainer := layout.NewContainer( chat.NewMessagesCmp(app), @@ -126,7 +144,7 @@ func NewChatPage(app *app.App) tea.Model { ) editorContainer := layout.NewContainer( - chat.NewEditorCmp(), + chat.NewEditorCmp(app), layout.WithBorder(true, false, false, false), ) return &chatPage{ diff --git a/internal/tui/page/init.go b/internal/tui/page/init.go deleted file mode 100644 index 0a5c6f82a..000000000 --- a/internal/tui/page/init.go +++ /dev/null @@ -1,308 +0,0 @@ -package page - -import ( - "fmt" - "os" - "path/filepath" - "strconv" - - "github.com/charmbracelet/bubbles/key" - tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/huh" - "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/llm/models" - "github.com/kujtimiihoxha/termai/internal/tui/layout" - "github.com/kujtimiihoxha/termai/internal/tui/styles" - "github.com/kujtimiihoxha/termai/internal/tui/util" - "github.com/spf13/viper" -) - -var InitPage PageID = "init" - -type configSaved struct{} - -type initPage struct { - form *huh.Form - width int - height int - saved bool - errorMsg string - statusMsg string - modelOpts []huh.Option[string] - bigModel string - smallModel string - openAIKey string - anthropicKey string - groqKey string - maxTokens string - dataDir string - agent string -} - -func (i *initPage) Init() tea.Cmd { - return i.form.Init() -} - -func (i *initPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { - var cmds []tea.Cmd - - switch msg := msg.(type) { - case tea.WindowSizeMsg: - i.width = msg.Width - 4 // Account for border - i.height = msg.Height - 4 - i.form = i.form.WithWidth(i.width).WithHeight(i.height) - return i, nil - - case configSaved: - i.saved = true - i.statusMsg = "Configuration saved successfully. Press any key to continue." - return i, nil - } - - if i.saved { - switch msg.(type) { - case tea.KeyMsg: - return i, util.CmdHandler(PageChangeMsg{ID: ReplPage}) - } - return i, nil - } - - // Process the form - form, cmd := i.form.Update(msg) - if f, ok := form.(*huh.Form); ok { - i.form = f - cmds = append(cmds, cmd) - } - - if i.form.State == huh.StateCompleted { - // Save configuration to file - configPath := filepath.Join(os.Getenv("HOME"), ".termai.yaml") - maxTokens, _ := strconv.Atoi(i.maxTokens) - config := map[string]any{ - "models": map[string]string{ - "big": i.bigModel, - "small": i.smallModel, - }, - "providers": map[string]any{ - "openai": map[string]string{ - "key": i.openAIKey, - }, - "anthropic": map[string]string{ - "key": i.anthropicKey, - }, - "groq": map[string]string{ - "key": i.groqKey, - }, - "common": map[string]int{ - "max_tokens": maxTokens, - }, - }, - "data": map[string]string{ - "dir": i.dataDir, - }, - "agents": map[string]string{ - "default": i.agent, - }, - "log": map[string]string{ - "level": "info", - }, - } - - // Write config to viper - for k, v := range config { - viper.Set(k, v) - } - - // Save configuration - err := viper.WriteConfigAs(configPath) - if err != nil { - i.errorMsg = fmt.Sprintf("Failed to save configuration: %s", err) - return i, nil - } - - // Return to main page - return i, util.CmdHandler(configSaved{}) - } - - return i, tea.Batch(cmds...) -} - -func (i *initPage) View() string { - if i.saved { - return lipgloss.NewStyle(). - Width(i.width). - Height(i.height). - Align(lipgloss.Center, lipgloss.Center). - Render(lipgloss.JoinVertical( - lipgloss.Center, - lipgloss.NewStyle().Foreground(styles.Green).Render("✓ Configuration Saved"), - "", - lipgloss.NewStyle().Foreground(styles.Blue).Render(i.statusMsg), - )) - } - - view := i.form.View() - if i.errorMsg != "" { - errorBox := lipgloss.NewStyle(). - Padding(1). - Border(lipgloss.RoundedBorder()). - BorderForeground(styles.Red). - Width(i.width - 4). - Render(i.errorMsg) - view = lipgloss.JoinVertical(lipgloss.Left, errorBox, view) - } - return view -} - -func (i *initPage) GetSize() (int, int) { - return i.width, i.height -} - -func (i *initPage) SetSize(width int, height int) { - i.width = width - i.height = height - i.form = i.form.WithWidth(width).WithHeight(height) -} - -func (i *initPage) BindingKeys() []key.Binding { - if i.saved { - return []key.Binding{ - key.NewBinding( - key.WithKeys("enter", "space", "esc"), - key.WithHelp("any key", "continue"), - ), - } - } - return i.form.KeyBinds() -} - -func NewInitPage() tea.Model { - // Create model options - var modelOpts []huh.Option[string] - for id, model := range models.SupportedModels { - modelOpts = append(modelOpts, huh.NewOption(model.Name, string(id))) - } - - // Create agent options - agentOpts := []huh.Option[string]{ - huh.NewOption("Coder", "coder"), - huh.NewOption("Assistant", "assistant"), - } - - // Init page with form - initModel := &initPage{ - modelOpts: modelOpts, - bigModel: string(models.Claude37Sonnet), - smallModel: string(models.Claude37Sonnet), - maxTokens: "4000", - dataDir: ".termai", - agent: "coder", - } - - // API Keys group - apiKeysGroup := huh.NewGroup( - huh.NewNote(). - Title("API Keys"). - Description("You need to provide at least one API key to use termai"), - - huh.NewInput(). - Title("OpenAI API Key"). - Placeholder("sk-..."). - Key("openai_key"). - Value(&initModel.openAIKey), - - huh.NewInput(). - Title("Anthropic API Key"). - Placeholder("sk-ant-..."). - Key("anthropic_key"). - Value(&initModel.anthropicKey), - - huh.NewInput(). - Title("Groq API Key"). - Placeholder("gsk_..."). - Key("groq_key"). - Value(&initModel.groqKey), - ) - - // Model configuration group - modelsGroup := huh.NewGroup( - huh.NewNote(). - Title("Model Configuration"). - Description("Select which models to use"), - - huh.NewSelect[string](). - Title("Big Model"). - Options(modelOpts...). - Key("big_model"). - Value(&initModel.bigModel), - - huh.NewSelect[string](). - Title("Small Model"). - Options(modelOpts...). - Key("small_model"). - Value(&initModel.smallModel), - - huh.NewInput(). - Title("Max Tokens"). - Placeholder("4000"). - Key("max_tokens"). - CharLimit(5). - Validate(func(s string) error { - var n int - _, err := fmt.Sscanf(s, "%d", &n) - if err != nil || n <= 0 { - return fmt.Errorf("must be a positive number") - } - initModel.maxTokens = s - return nil - }). - Value(&initModel.maxTokens), - ) - - // General settings group - generalGroup := huh.NewGroup( - huh.NewNote(). - Title("General Settings"). - Description("Configure general termai settings"), - - huh.NewInput(). - Title("Data Directory"). - Placeholder(".termai"). - Key("data_dir"). - Value(&initModel.dataDir), - - huh.NewSelect[string](). - Title("Default Agent"). - Options(agentOpts...). - Key("agent"). - Value(&initModel.agent), - - huh.NewConfirm(). - Title("Save Configuration"). - Affirmative("Save"). - Negative("Cancel"), - ) - - // Create form with theme - form := huh.NewForm( - apiKeysGroup, - modelsGroup, - generalGroup, - ).WithTheme(styles.HuhTheme()). - WithShowHelp(true). - WithShowErrors(true) - - // Set the form in the model - initModel.form = form - - return layout.NewSinglePane( - initModel, - layout.WithSinglePaneFocusable(true), - layout.WithSinglePaneBordered(true), - layout.WithSinglePaneBorderText( - map[layout.BorderPosition]string{ - layout.TopMiddleBorder: "Welcome to termai - Initial Setup", - }, - ), - ) -} diff --git a/internal/tui/page/logs.go b/internal/tui/page/logs.go index 12afaf6aa..d1e557eab 100644 --- a/internal/tui/page/logs.go +++ b/internal/tui/page/logs.go @@ -8,6 +8,23 @@ import ( var LogsPage PageID = "logs" +type logsPage struct { + table logs.TableComponent + details logs.DetailComponent +} + +func (p *logsPage) Init() tea.Cmd { + return nil +} + +func (p *logsPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + return p, nil +} + +func (p *logsPage) View() string { + return p.table.View() + "\n" + p.details.View() +} + func NewLogsPage() tea.Model { return layout.NewBentoLayout( layout.BentoPanes{ diff --git a/internal/tui/page/repl.go b/internal/tui/page/repl.go deleted file mode 100644 index 47a924b7b..000000000 --- a/internal/tui/page/repl.go +++ /dev/null @@ -1,21 +0,0 @@ -package page - -import ( - tea "github.com/charmbracelet/bubbletea" - "github.com/kujtimiihoxha/termai/internal/app" - "github.com/kujtimiihoxha/termai/internal/tui/components/repl" - "github.com/kujtimiihoxha/termai/internal/tui/layout" -) - -var ReplPage PageID = "repl" - -func NewReplPage(app *app.App) tea.Model { - return layout.NewBentoLayout( - layout.BentoPanes{ - layout.BentoLeftPane: repl.NewSessionsCmp(app), - layout.BentoRightTopPane: repl.NewMessagesCmp(app), - layout.BentoRightBottomPane: repl.NewEditorCmp(app), - }, - layout.WithBentoLayoutCurrentPane(layout.BentoRightBottomPane), - ) -} diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 1b1a1ed50..dff7ad63d 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -1,8 +1,6 @@ package tui import ( - "context" - "github.com/charmbracelet/bubbles/key" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" @@ -12,47 +10,41 @@ import ( "github.com/kujtimiihoxha/termai/internal/pubsub" "github.com/kujtimiihoxha/termai/internal/tui/components/core" "github.com/kujtimiihoxha/termai/internal/tui/components/dialog" - "github.com/kujtimiihoxha/termai/internal/tui/components/repl" "github.com/kujtimiihoxha/termai/internal/tui/layout" "github.com/kujtimiihoxha/termai/internal/tui/page" "github.com/kujtimiihoxha/termai/internal/tui/util" - "github.com/kujtimiihoxha/vimtea" ) type keyMap struct { - Logs key.Binding - Return key.Binding - Back key.Binding - Quit key.Binding - Help key.Binding + Logs key.Binding + Quit key.Binding + Help key.Binding } var keys = keyMap{ Logs: key.NewBinding( - key.WithKeys("L"), - key.WithHelp("L", "logs"), - ), - Return: key.NewBinding( - key.WithKeys("esc"), - key.WithHelp("esc", "close"), - ), - Back: key.NewBinding( - key.WithKeys("backspace"), - key.WithHelp("backspace", "back"), + key.WithKeys("ctrl+l"), + key.WithHelp("ctrl+L", "logs"), ), + Quit: key.NewBinding( - key.WithKeys("ctrl+c", "q"), - key.WithHelp("ctrl+c/q", "quit"), + key.WithKeys("ctrl+c"), + key.WithHelp("ctrl+c", "quit"), ), Help: key.NewBinding( - key.WithKeys("?"), - key.WithHelp("?", "toggle help"), + key.WithKeys("ctrl+_"), + key.WithHelp("ctrl+?", "toggle help"), ), } -var replKeyMap = key.NewBinding( - key.WithKeys("N"), - key.WithHelp("N", "new session"), +var returnKey = key.NewBinding( + key.WithKeys("esc"), + key.WithHelp("esc", "close"), +) + +var logsKeyReturnKey = key.NewBinding( + key.WithKeys("backspace"), + key.WithHelp("backspace", "go back"), ) type appModel struct { @@ -62,18 +54,30 @@ type appModel struct { pages map[page.PageID]tea.Model loadedPages map[page.PageID]bool status tea.Model - help core.HelpCmp - dialog core.DialogCmp app *app.App - dialogVisible bool - editorMode vimtea.EditorMode - showHelp bool + + showPermissions bool + permissions dialog.PermissionDialogCmp + + showHelp bool + help dialog.HelpCmp + + showQuit bool + quit dialog.QuitDialog } func (a appModel) Init() tea.Cmd { + var cmds []tea.Cmd cmd := a.pages[a.currentPage].Init() a.loadedPages[a.currentPage] = true - return cmd + cmds = append(cmds, cmd) + cmd = a.status.Init() + cmds = append(cmds, cmd) + cmd = a.quit.Init() + cmds = append(cmds, cmd) + cmd = a.help.Init() + cmds = append(cmds, cmd) + return tea.Batch(cmds...) } func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { @@ -81,22 +85,20 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { var cmd tea.Cmd switch msg := msg.(type) { case tea.WindowSizeMsg: - var cmds []tea.Cmd msg.Height -= 1 // Make space for the status bar a.width, a.height = msg.Width, msg.Height a.status, _ = a.status.Update(msg) - - uh, _ := a.help.Update(msg) - a.help = uh.(core.HelpCmp) - - p, cmd := a.pages[a.currentPage].Update(msg) + a.pages[a.currentPage], cmd = a.pages[a.currentPage].Update(msg) cmds = append(cmds, cmd) - a.pages[a.currentPage] = p - d, cmd := a.dialog.Update(msg) - cmds = append(cmds, cmd) - a.dialog = d.(core.DialogCmp) + prm, permCmd := a.permissions.Update(msg) + a.permissions = prm.(dialog.PermissionDialogCmp) + cmds = append(cmds, permCmd) + + help, helpCmd := a.help.Update(msg) + a.help = help.(dialog.HelpCmp) + cmds = append(cmds, helpCmd) return a, tea.Batch(cmds...) @@ -141,7 +143,9 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { // Permission case pubsub.Event[permission.PermissionRequest]: - return a, dialog.NewPermissionDialogCmd(msg.Payload) + a.showPermissions = true + a.permissions.SetPermissions(msg.Payload) + return a, nil case dialog.PermissionResponseMsg: switch msg.Action { case dialog.PermissionAllow: @@ -151,91 +155,71 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case dialog.PermissionDeny: a.app.Permissions.Deny(msg.Permission) } - - // Dialog - case core.DialogMsg: - d, cmd := a.dialog.Update(msg) - a.dialog = d.(core.DialogCmp) - a.dialogVisible = true - return a, cmd - case core.DialogCloseMsg: - d, cmd := a.dialog.Update(msg) - a.dialog = d.(core.DialogCmp) - a.dialogVisible = false - return a, cmd - - // Editor - case vimtea.EditorModeMsg: - a.editorMode = msg.Mode + a.showPermissions = false + return a, nil case page.PageChangeMsg: return a, a.moveToPage(msg.ID) + + case dialog.CloseQuitMsg: + a.showQuit = false + return a, nil + case tea.KeyMsg: - if a.editorMode == vimtea.ModeNormal { - switch { - case key.Matches(msg, keys.Quit): - return a, dialog.NewQuitDialogCmd() - case key.Matches(msg, keys.Back): - if a.previousPage != "" { - return a, a.moveToPage(a.previousPage) - } - case key.Matches(msg, keys.Return): - if a.showHelp { - a.ToggleHelp() - return a, nil - } - case key.Matches(msg, replKeyMap): - if a.currentPage == page.ReplPage { - sessions, err := a.app.Sessions.List(context.Background()) - if err != nil { - return a, util.CmdHandler(util.ReportError(err)) - } - lastSession := sessions[0] - if lastSession.MessageCount == 0 { - return a, util.CmdHandler(repl.SelectedSessionMsg{SessionID: lastSession.ID}) - } - s, err := a.app.Sessions.Create(context.Background(), "New Session") - if err != nil { - return a, util.CmdHandler(util.ReportError(err)) - } - return a, util.CmdHandler(repl.SelectedSessionMsg{SessionID: s.ID}) - } - // case key.Matches(msg, keys.Logs): - // return a, a.moveToPage(page.LogsPage) - case msg.String() == "O": - return a, a.moveToPage(page.ReplPage) - case key.Matches(msg, keys.Help): - a.ToggleHelp() + switch { + case key.Matches(msg, keys.Quit): + a.showQuit = !a.showQuit + if a.showHelp { + a.showHelp = false + } + return a, nil + case key.Matches(msg, logsKeyReturnKey): + if a.currentPage == page.LogsPage { + return a, a.moveToPage(page.ChatPage) + } + case key.Matches(msg, returnKey): + if a.showQuit { + a.showQuit = !a.showQuit + return a, nil + } + if a.showHelp { + a.showHelp = !a.showHelp + return a, nil + } + case key.Matches(msg, keys.Logs): + return a, a.moveToPage(page.LogsPage) + case key.Matches(msg, keys.Help): + if a.showQuit { return a, nil } + a.showHelp = !a.showHelp + return a, nil } } - if a.dialogVisible { - d, cmd := a.dialog.Update(msg) - a.dialog = d.(core.DialogCmp) - cmds = append(cmds, cmd) - return a, tea.Batch(cmds...) + if a.showQuit { + q, quitCmd := a.quit.Update(msg) + a.quit = q.(dialog.QuitDialog) + cmds = append(cmds, quitCmd) + // Only block key messages send all other messages down + if _, ok := msg.(tea.KeyMsg); ok { + return a, tea.Batch(cmds...) + } + } + if a.showPermissions { + d, permissionsCmd := a.permissions.Update(msg) + a.permissions = d.(dialog.PermissionDialogCmp) + cmds = append(cmds, permissionsCmd) + // Only block key messages send all other messages down + if _, ok := msg.(tea.KeyMsg); ok { + return a, tea.Batch(cmds...) + } } a.pages[a.currentPage], cmd = a.pages[a.currentPage].Update(msg) cmds = append(cmds, cmd) return a, tea.Batch(cmds...) } -func (a *appModel) ToggleHelp() { - if a.showHelp { - a.showHelp = false - a.height += a.help.Height() - } else { - a.showHelp = true - a.height -= a.help.Height() - } - - if sizable, ok := a.pages[a.currentPage].(layout.Sizeable); ok { - sizable.SetSize(a.width, a.height) - } -} - func (a *appModel) moveToPage(pageID page.PageID) tea.Cmd { var cmd tea.Cmd if _, ok := a.loadedPages[pageID]; !ok { @@ -256,27 +240,55 @@ func (a appModel) View() string { a.pages[a.currentPage].View(), } + components = append(components, a.status.View()) + + appView := lipgloss.JoinVertical(lipgloss.Top, components...) + + if a.showPermissions { + overlay := a.permissions.View() + row := lipgloss.Height(appView) / 2 + row -= lipgloss.Height(overlay) / 2 + col := lipgloss.Width(appView) / 2 + col -= lipgloss.Width(overlay) / 2 + appView = layout.PlaceOverlay( + col, + row, + overlay, + appView, + true, + ) + } + if a.showHelp { bindings := layout.KeyMapToSlice(keys) if p, ok := a.pages[a.currentPage].(layout.Bindings); ok { bindings = append(bindings, p.BindingKeys()...) } - if a.dialogVisible { - bindings = append(bindings, a.dialog.BindingKeys()...) + if a.showPermissions { + bindings = append(bindings, a.permissions.BindingKeys()...) } - if a.currentPage == page.ReplPage { - bindings = append(bindings, replKeyMap) + if a.currentPage == page.LogsPage { + bindings = append(bindings, logsKeyReturnKey) } - a.help.SetBindings(bindings) - components = append(components, a.help.View()) - } - components = append(components, a.status.View()) + a.help.SetBindings(bindings) - appView := lipgloss.JoinVertical(lipgloss.Top, components...) + overlay := a.help.View() + row := lipgloss.Height(appView) / 2 + row -= lipgloss.Height(overlay) / 2 + col := lipgloss.Width(appView) / 2 + col -= lipgloss.Width(overlay) / 2 + appView = layout.PlaceOverlay( + col, + row, + overlay, + appView, + true, + ) + } - if a.dialogVisible { - overlay := a.dialog.View() + if a.showQuit { + overlay := a.quit.View() row := lipgloss.Height(appView) / 2 row -= lipgloss.Height(overlay) / 2 col := lipgloss.Width(appView) / 2 @@ -289,30 +301,23 @@ func (a appModel) View() string { true, ) } + return appView } func New(app *app.App) tea.Model { - // homedir, _ := os.UserHomeDir() - // configPath := filepath.Join(homedir, ".termai.yaml") - // startPage := page.ChatPage - // if _, err := os.Stat(configPath); os.IsNotExist(err) { - // startPage = page.InitPage - // } - return &appModel{ currentPage: startPage, loadedPages: make(map[page.PageID]bool), - status: core.NewStatusCmp(), - help: core.NewHelpCmp(), - dialog: core.NewDialogCmp(), + status: core.NewStatusCmp(app.LSPClients), + help: dialog.NewHelpCmp(), + quit: dialog.NewQuitCmp(), + permissions: dialog.NewPermissionDialogCmp(), app: app, pages: map[page.PageID]tea.Model{ page.ChatPage: page.NewChatPage(app), page.LogsPage: page.NewLogsPage(), - page.InitPage: page.NewInitPage(), - page.ReplPage: page.NewReplPage(app), }, } } diff --git a/main.go b/main.go index 4bc8a22f0..2e6954646 100644 --- a/main.go +++ b/main.go @@ -2,8 +2,15 @@ package main import ( "github.com/kujtimiihoxha/termai/cmd" + "github.com/kujtimiihoxha/termai/internal/logging" ) func main() { + // Set up panic recovery for the main function + defer logging.RecoverPanic("main", func() { + // Perform any necessary cleanup before exit + logging.ErrorPersist("Application terminated due to unhandled panic") + }) + cmd.Execute() } -- cgit v1.2.3 From cc07f7a186995f428436bc1adc66a264a95171a4 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Wed, 16 Apr 2025 21:48:29 +0200 Subject: rename to opencode --- .opencode.json | 11 ++++ cmd/root.go | 14 ++-- go.mod | 2 +- internal/app/app.go | 18 +++--- internal/app/lsp.go | 8 +-- internal/config/config.go | 4 +- internal/db/connect.go | 6 +- internal/diff/diff.go | 4 +- internal/history/file.go | 4 +- internal/llm/agent/agent-tool.go | 10 +-- internal/llm/agent/agent.go | 18 +++--- internal/llm/agent/mcp-tools.go | 10 +-- internal/llm/agent/tools.go | 12 ++-- internal/llm/prompt/coder.go | 97 ++++++++++++---------------- internal/llm/prompt/prompt.go | 4 +- internal/llm/prompt/task.go | 4 +- internal/llm/prompt/title.go | 2 +- internal/llm/provider/anthropic.go | 8 +-- internal/llm/provider/bedrock.go | 4 +- internal/llm/provider/gemini.go | 8 +-- internal/llm/provider/openai.go | 8 +-- internal/llm/provider/provider.go | 6 +- internal/llm/tools/bash.go | 16 ++--- internal/llm/tools/diagnostics.go | 4 +- internal/llm/tools/edit.go | 10 +-- internal/llm/tools/edit_test.go | 2 +- internal/llm/tools/fetch.go | 6 +- internal/llm/tools/glob.go | 2 +- internal/llm/tools/grep.go | 2 +- internal/llm/tools/ls.go | 2 +- internal/llm/tools/mocks_test.go | 6 +- internal/llm/tools/shell/shell.go | 8 +-- internal/llm/tools/sourcegraph.go | 2 +- internal/llm/tools/view.go | 4 +- internal/llm/tools/write.go | 10 +-- internal/llm/tools/write_test.go | 2 +- internal/logging/writer.go | 2 +- internal/lsp/client.go | 6 +- internal/lsp/handlers.go | 8 +-- internal/lsp/language.go | 2 +- internal/lsp/methods.go | 2 +- internal/lsp/transport.go | 4 +- internal/lsp/util/edit.go | 2 +- internal/lsp/watcher/watcher.go | 8 +-- internal/message/content.go | 2 +- internal/message/message.go | 6 +- internal/permission/permission.go | 2 +- internal/session/session.go | 4 +- internal/tui/components/chat/chat.go | 8 +-- internal/tui/components/chat/editor.go | 10 +-- internal/tui/components/chat/messages.go | 22 +++---- internal/tui/components/chat/sidebar.go | 12 ++-- internal/tui/components/core/status.go | 12 ++-- internal/tui/components/dialog/help.go | 2 +- internal/tui/components/dialog/permission.go | 12 ++-- internal/tui/components/dialog/quit.go | 6 +- internal/tui/components/logs/details.go | 6 +- internal/tui/components/logs/table.go | 10 +-- internal/tui/layout/border.go | 2 +- internal/tui/layout/container.go | 2 +- internal/tui/layout/overlay.go | 4 +- internal/tui/layout/split.go | 2 +- internal/tui/page/chat.go | 10 +-- internal/tui/page/logs.go | 4 +- internal/tui/tui.go | 18 +++--- main.go | 4 +- 66 files changed, 266 insertions(+), 266 deletions(-) (limited to 'internal/diff') diff --git a/.opencode.json b/.opencode.json index b7fc19b52..4b2944f86 100644 --- a/.opencode.json +++ b/.opencode.json @@ -3,5 +3,16 @@ "gopls": { "command": "gopls" } + }, + "agents": { + "coder": { + "model": "gpt-4.1" + }, + "task": { + "model": "gpt-4.1" + }, + "title": { + "model": "gpt-4.1" + } } } diff --git a/cmd/root.go b/cmd/root.go index ff71747d5..f506e9940 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -8,13 +8,13 @@ import ( "time" tea "github.com/charmbracelet/bubbletea" - "github.com/kujtimiihoxha/termai/internal/app" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/db" - "github.com/kujtimiihoxha/termai/internal/llm/agent" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/pubsub" - "github.com/kujtimiihoxha/termai/internal/tui" + "github.com/kujtimiihoxha/opencode/internal/app" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/db" + "github.com/kujtimiihoxha/opencode/internal/llm/agent" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/pubsub" + "github.com/kujtimiihoxha/opencode/internal/tui" zone "github.com/lrstanley/bubblezone" "github.com/spf13/cobra" ) diff --git a/go.mod b/go.mod index 16c88d3a6..822e70dbd 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/kujtimiihoxha/termai +module github.com/kujtimiihoxha/opencode go 1.24.0 diff --git a/internal/app/app.go b/internal/app/app.go index 1c16ccc11..748fdaa7f 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -7,15 +7,15 @@ import ( "sync" "time" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/db" - "github.com/kujtimiihoxha/termai/internal/history" - "github.com/kujtimiihoxha/termai/internal/llm/agent" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/message" - "github.com/kujtimiihoxha/termai/internal/permission" - "github.com/kujtimiihoxha/termai/internal/session" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/db" + "github.com/kujtimiihoxha/opencode/internal/history" + "github.com/kujtimiihoxha/opencode/internal/llm/agent" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/lsp" + "github.com/kujtimiihoxha/opencode/internal/message" + "github.com/kujtimiihoxha/opencode/internal/permission" + "github.com/kujtimiihoxha/opencode/internal/session" ) type App struct { diff --git a/internal/app/lsp.go b/internal/app/lsp.go index 4a762f1a1..d8a35c8b3 100644 --- a/internal/app/lsp.go +++ b/internal/app/lsp.go @@ -4,10 +4,10 @@ import ( "context" "time" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/lsp/watcher" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/lsp" + "github.com/kujtimiihoxha/opencode/internal/lsp/watcher" ) func (app *App) initLSPClients(ctx context.Context) { diff --git a/internal/config/config.go b/internal/config/config.go index 147d6c83a..20a8bac97 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -7,8 +7,8 @@ import ( "os" "strings" - "github.com/kujtimiihoxha/termai/internal/llm/models" - "github.com/kujtimiihoxha/termai/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/llm/models" + "github.com/kujtimiihoxha/opencode/internal/logging" "github.com/spf13/viper" ) diff --git a/internal/db/connect.go b/internal/db/connect.go index 8bba9cad8..e850bc8d0 100644 --- a/internal/db/connect.go +++ b/internal/db/connect.go @@ -12,8 +12,8 @@ import ( "github.com/golang-migrate/migrate/v4/database/sqlite3" _ "github.com/mattn/go-sqlite3" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/logging" ) func Connect() (*sql.DB, error) { @@ -24,7 +24,7 @@ func Connect() (*sql.DB, error) { if err := os.MkdirAll(dataDir, 0o700); err != nil { return nil, fmt.Errorf("failed to create data directory: %w", err) } - dbPath := filepath.Join(dataDir, "termai.db") + dbPath := filepath.Join(dataDir, "opencode.db") // Open the SQLite database db, err := sql.Open("sqlite3", dbPath) if err != nil { diff --git a/internal/diff/diff.go b/internal/diff/diff.go index 829554c7e..f48079c9c 100644 --- a/internal/diff/diff.go +++ b/internal/diff/diff.go @@ -19,8 +19,8 @@ import ( "github.com/charmbracelet/x/ansi" "github.com/go-git/go-git/v5" "github.com/go-git/go-git/v5/plumbing/object" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/logging" "github.com/sergi/go-diff/diffmatchpatch" ) diff --git a/internal/history/file.go b/internal/history/file.go index 82017d4cf..1e8bc50bb 100644 --- a/internal/history/file.go +++ b/internal/history/file.go @@ -7,8 +7,8 @@ import ( "strings" "github.com/google/uuid" - "github.com/kujtimiihoxha/termai/internal/db" - "github.com/kujtimiihoxha/termai/internal/pubsub" + "github.com/kujtimiihoxha/opencode/internal/db" + "github.com/kujtimiihoxha/opencode/internal/pubsub" ) const ( diff --git a/internal/llm/agent/agent-tool.go b/internal/llm/agent/agent-tool.go index 308412bde..be6e09a9b 100644 --- a/internal/llm/agent/agent-tool.go +++ b/internal/llm/agent/agent-tool.go @@ -5,11 +5,11 @@ import ( "encoding/json" "fmt" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/message" - "github.com/kujtimiihoxha/termai/internal/session" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/llm/tools" + "github.com/kujtimiihoxha/opencode/internal/lsp" + "github.com/kujtimiihoxha/opencode/internal/message" + "github.com/kujtimiihoxha/opencode/internal/session" ) type agentTool struct { diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index ab2742ec1..a5dadb89d 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -7,15 +7,15 @@ import ( "strings" "sync" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/models" - "github.com/kujtimiihoxha/termai/internal/llm/prompt" - "github.com/kujtimiihoxha/termai/internal/llm/provider" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/message" - "github.com/kujtimiihoxha/termai/internal/permission" - "github.com/kujtimiihoxha/termai/internal/session" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/llm/models" + "github.com/kujtimiihoxha/opencode/internal/llm/prompt" + "github.com/kujtimiihoxha/opencode/internal/llm/provider" + "github.com/kujtimiihoxha/opencode/internal/llm/tools" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/message" + "github.com/kujtimiihoxha/opencode/internal/permission" + "github.com/kujtimiihoxha/opencode/internal/session" ) // Common errors diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go index c7ea4916c..16dddc1ba 100644 --- a/internal/llm/agent/mcp-tools.go +++ b/internal/llm/agent/mcp-tools.go @@ -5,11 +5,11 @@ import ( "encoding/json" "fmt" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/permission" - "github.com/kujtimiihoxha/termai/internal/version" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/llm/tools" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/permission" + "github.com/kujtimiihoxha/opencode/internal/version" "github.com/mark3labs/mcp-go/client" "github.com/mark3labs/mcp-go/mcp" diff --git a/internal/llm/agent/tools.go b/internal/llm/agent/tools.go index a37f1d65d..409d14273 100644 --- a/internal/llm/agent/tools.go +++ b/internal/llm/agent/tools.go @@ -3,12 +3,12 @@ package agent import ( "context" - "github.com/kujtimiihoxha/termai/internal/history" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/message" - "github.com/kujtimiihoxha/termai/internal/permission" - "github.com/kujtimiihoxha/termai/internal/session" + "github.com/kujtimiihoxha/opencode/internal/history" + "github.com/kujtimiihoxha/opencode/internal/llm/tools" + "github.com/kujtimiihoxha/opencode/internal/lsp" + "github.com/kujtimiihoxha/opencode/internal/message" + "github.com/kujtimiihoxha/opencode/internal/permission" + "github.com/kujtimiihoxha/opencode/internal/session" ) func CoderAgentTools( diff --git a/internal/llm/prompt/coder.go b/internal/llm/prompt/coder.go index 7439fd570..3a06911da 100644 --- a/internal/llm/prompt/coder.go +++ b/internal/llm/prompt/coder.go @@ -8,9 +8,9 @@ import ( "runtime" "time" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/models" - "github.com/kujtimiihoxha/termai/internal/llm/tools" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/llm/models" + "github.com/kujtimiihoxha/opencode/internal/llm/tools" ) func CoderPrompt(provider models.ModelProvider) string { @@ -24,69 +24,58 @@ func CoderPrompt(provider models.ModelProvider) string { return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation()) } -const baseOpenAICoderPrompt = `You are termAI, an autonomous CLI-based software engineer. Your job is to reduce user effort by proactively reasoning, inferring context, and solving software engineering tasks end-to-end with minimal prompting. - -# Your mindset -Act like a competent, efficient software engineer who is familiar with large codebases. You should: -- Think critically about user requests. -- Proactively search the codebase for related information. -- Infer likely commands, tools, or conventions. -- Write and edit code with minimal user input. -- Anticipate next steps (tests, lints, etc.), but never commit unless explicitly told. - -# Context awareness -- Before acting, infer the purpose of a file from its name, directory, and neighboring files. -- If a file or function appears malicious, refuse to interact with it or discuss it. -- If a termai.md file exists, auto-load it as memory. Offer to update it only if new useful info appears (commands, preferences, structure). - -# CLI communication -- Use GitHub-flavored markdown in monospace font. -- Be concise. Never add preambles or postambles unless asked. Max 4 lines per response. -- Never explain your code unless asked. Do not narrate actions. -- Avoid unnecessary questions. Infer, search, act. - -# Behavior guidelines -- Follow project conventions: naming, formatting, libraries, frameworks. -- Before using any library or framework, confirm it’s already used. -- Always look at the surrounding code to match existing style. -- Do not add comments unless the code is complex or the user asks. - -# Autonomy rules -You are allowed and expected to: -- Search for commands, tools, or config files before asking the user. -- Run multiple search tool calls concurrently to gather relevant context. -- Choose test, lint, and typecheck commands based on package files or scripts. -- Offer to store these commands in termai.md if not already present. - -# Example behavior -user: write tests for new feature -assistant: [searches for existing test patterns, finds appropriate location, generates test code using existing style, optionally asks to add test command to termai.md] +const baseOpenAICoderPrompt = ` +You are **OpenCode**, an autonomous CLI assistant for software‑engineering tasks. + +### ── INTERNAL REFLECTION ── +• Silently think step‑by‑step about the user request, directory layout, and tool calls (never reveal this). +• Formulate a plan, then execute without further approval unless a blocker triggers the Ask‑Only‑If rules. + +### ── PUBLIC RESPONSE RULES ── +• Visible reply ≤ 4 lines; no fluff, preamble, or postamble. +• Use GitHub‑flavored Markdown. +• When running a non‑trivial shell command, add ≤ 1 brief purpose sentence. + +### ── CONTEXT & MEMORY ── +• Infer file intent from directory structure before editing. +• Auto‑load 'OpenCode.md'; ask once before writing new reusable commands or style notes. -user: how do I typecheck this codebase? -assistant: [searches for known commands, infers package manager, checks for scripts or config files] -tsc --noEmit +### ── AUTONOMY PRIORITY ── +**Ask‑Only‑If Decision Tree:** +1. **Safety risk?** (e.g., destructive command, secret exposure) → ask. +2. **Critical unknown?** (no docs/tests; cannot infer) → ask. +3. **Tool failure after two self‑attempts?** → ask. +Otherwise, proceed autonomously. -user: is X function used anywhere else? -assistant: [searches repo for references, returns file paths and lines] +### ── SAFETY & STYLE ── +• Mimic existing code style; verify libraries exist before import. +• Never commit unless explicitly told. +• After edits, run lint & type‑check (ask for commands once, then offer to store in 'OpenCode.md'). +• Protect secrets; follow standard security practices :contentReference[oaicite:2]{index=2}. -# Tool usage -- Use parallel calls when possible. -- Use file search and content tools before asking the user. -- Do not ask the user for information unless it cannot be determined via tools. +### ── TOOL USAGE ── +• Batch independent Agent search/file calls in one block for efficiency :contentReference[oaicite:3]{index=3}. +• Communicate with the user only via visible text; do not expose tool output or internal reasoning. -Never commit changes unless the user explicitly asks you to.` +### ── EXAMPLES ── +user: list files +assistant: ls + +user: write tests for new feature +assistant: [searches & edits autonomously, no extra chit‑chat] +` -const baseAnthropicCoderPrompt = `You are termAI, an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user. +const baseAnthropicCoderPrompt = `You are OpenCode, an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user. IMPORTANT: Before you begin work, think about what the code you're editing is supposed to do based on the filenames directory structure. # Memory -If the current working directory contains a file called termai.md, it will be automatically added to your context. This file serves multiple purposes: +If the current working directory contains a file called OpenCode.md, it will be automatically added to your context. This file serves multiple purposes: 1. Storing frequently used bash commands (build, test, lint, etc.) so you can use them without searching each time 2. Recording the user's code style preferences (naming conventions, preferred libraries, etc.) 3. Maintaining useful information about the codebase structure and organization -When you spend time searching for commands to typecheck, lint, build, or test, you should ask the user if it's okay to add those commands to termai.md. Similarly, when learning about code style preferences or important codebase information, ask if it's okay to add that to termai.md so you can remember it for next time. +When you spend time searching for commands to typecheck, lint, build, or test, you should ask the user if it's okay to add those commands to OpenCode.md. Similarly, when learning about code style preferences or important codebase information, ask if it's okay to add that to OpenCode.md so you can remember it for next time. # Tone and style You should be concise, direct, and to the point. When you run a non-trivial bash command, you should explain what the command does and why you are running it, to make sure the user understands what you are doing (this is especially important when you are running a command that will make changes to the user's system). @@ -161,7 +150,7 @@ The user will primarily request you perform software engineering tasks. This inc 1. Use the available search tools to understand the codebase and the user's query. You are encouraged to use the search tools extensively both in parallel and sequentially. 2. Implement the solution using all tools available to you 3. Verify the solution if possible with tests. NEVER assume specific test framework or test script. Check the README or search codebase to determine the testing approach. -4. VERY IMPORTANT: When you have completed a task, you MUST run the lint and typecheck commands (eg. npm run lint, npm run typecheck, ruff, etc.) if they were provided to you to ensure your code is correct. If you are unable to find the correct command, ask the user for the command to run and if they supply it, proactively suggest writing it to termai.md so that you will know to run it next time. +4. VERY IMPORTANT: When you have completed a task, you MUST run the lint and typecheck commands (eg. npm run lint, npm run typecheck, ruff, etc.) if they were provided to you to ensure your code is correct. If you are unable to find the correct command, ask the user for the command to run and if they supply it, proactively suggest writing it to opencode.md so that you will know to run it next time. NEVER commit changes unless the user explicitly asks you to. It is VERY IMPORTANT to only commit when explicitly asked, otherwise the user will feel that you are being too proactive. diff --git a/internal/llm/prompt/prompt.go b/internal/llm/prompt/prompt.go index 63fc2df7b..cdc3560ce 100644 --- a/internal/llm/prompt/prompt.go +++ b/internal/llm/prompt/prompt.go @@ -1,8 +1,8 @@ package prompt import ( - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/models" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/llm/models" ) func GetAgentPrompt(agentName config.AgentName, provider models.ModelProvider) string { diff --git a/internal/llm/prompt/task.go b/internal/llm/prompt/task.go index 8bf604ad9..88cd1a0f4 100644 --- a/internal/llm/prompt/task.go +++ b/internal/llm/prompt/task.go @@ -3,11 +3,11 @@ package prompt import ( "fmt" - "github.com/kujtimiihoxha/termai/internal/llm/models" + "github.com/kujtimiihoxha/opencode/internal/llm/models" ) func TaskPrompt(_ models.ModelProvider) string { - agentPrompt := `You are an agent for termAI. Given the user's prompt, you should use the tools available to you to answer the user's question. + agentPrompt := `You are an agent for OpenCode. Given the user's prompt, you should use the tools available to you to answer the user's question. Notes: 1. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is .", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...". 2. When relevant, share file names and code snippets relevant to the query diff --git a/internal/llm/prompt/title.go b/internal/llm/prompt/title.go index 3023a8550..6e5289b24 100644 --- a/internal/llm/prompt/title.go +++ b/internal/llm/prompt/title.go @@ -1,6 +1,6 @@ package prompt -import "github.com/kujtimiihoxha/termai/internal/llm/models" +import "github.com/kujtimiihoxha/opencode/internal/llm/models" func TitlePrompt(_ models.ModelProvider) string { return `you will generate a short title based on the first message a user begins a conversation with diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go index c3a4efc49..7bbc02103 100644 --- a/internal/llm/provider/anthropic.go +++ b/internal/llm/provider/anthropic.go @@ -12,10 +12,10 @@ import ( "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/bedrock" "github.com/anthropics/anthropic-sdk-go/option" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/llm/tools" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/message" ) type anthropicOptions struct { diff --git a/internal/llm/provider/bedrock.go b/internal/llm/provider/bedrock.go index d76925ad1..9415b30fe 100644 --- a/internal/llm/provider/bedrock.go +++ b/internal/llm/provider/bedrock.go @@ -7,8 +7,8 @@ import ( "os" "strings" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/opencode/internal/llm/tools" + "github.com/kujtimiihoxha/opencode/internal/message" ) type bedrockOptions struct { diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go index 804baea28..384bff900 100644 --- a/internal/llm/provider/gemini.go +++ b/internal/llm/provider/gemini.go @@ -11,10 +11,10 @@ import ( "github.com/google/generative-ai-go/genai" "github.com/google/uuid" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/llm/tools" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/message" "google.golang.org/api/iterator" "google.golang.org/api/option" ) diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index 9c2ad2012..13ce934f2 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -8,10 +8,10 @@ import ( "io" "time" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/llm/tools" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/message" "github.com/openai/openai-go" "github.com/openai/openai-go/option" ) diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index 1a5b3dc8a..e04bee71b 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -4,9 +4,9 @@ import ( "context" "fmt" - "github.com/kujtimiihoxha/termai/internal/llm/models" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/opencode/internal/llm/models" + "github.com/kujtimiihoxha/opencode/internal/llm/tools" + "github.com/kujtimiihoxha/opencode/internal/message" ) type EventType string diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index c7c970e5a..18533b761 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -7,9 +7,9 @@ import ( "strings" "time" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/tools/shell" - "github.com/kujtimiihoxha/termai/internal/permission" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/llm/tools/shell" + "github.com/kujtimiihoxha/opencode/internal/permission" ) type BashParams struct { @@ -122,16 +122,16 @@ When the user asks you to create a new git commit, follow these steps carefully: 4. Create the commit with a message ending with: -🤖 Generated with termai -Co-Authored-By: termai +🤖 Generated with opencode +Co-Authored-By: opencode - In order to ensure good formatting, ALWAYS pass the commit message via a HEREDOC, a la this example: git commit -m "$(cat <<'EOF' Commit message here. - 🤖 Generated with termai - Co-Authored-By: termai + 🤖 Generated with opencode + Co-Authored-By: opencode EOF )" @@ -193,7 +193,7 @@ gh pr create --title "the pr title" --body "$(cat <<'EOF' ## Test plan [Checklist of TODOs for testing the pull request...] -🤖 Generated with termai +🤖 Generated with opencode EOF )" diff --git a/internal/llm/tools/diagnostics.go b/internal/llm/tools/diagnostics.go index b7b2bb8ba..82989c774 100644 --- a/internal/llm/tools/diagnostics.go +++ b/internal/llm/tools/diagnostics.go @@ -9,8 +9,8 @@ import ( "strings" "time" - "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/lsp/protocol" + "github.com/kujtimiihoxha/opencode/internal/lsp" + "github.com/kujtimiihoxha/opencode/internal/lsp/protocol" ) type DiagnosticsParams struct { diff --git a/internal/llm/tools/edit.go b/internal/llm/tools/edit.go index 148e7aba7..6a1616010 100644 --- a/internal/llm/tools/edit.go +++ b/internal/llm/tools/edit.go @@ -9,11 +9,11 @@ import ( "strings" "time" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/diff" - "github.com/kujtimiihoxha/termai/internal/history" - "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/permission" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/diff" + "github.com/kujtimiihoxha/opencode/internal/history" + "github.com/kujtimiihoxha/opencode/internal/lsp" + "github.com/kujtimiihoxha/opencode/internal/permission" ) type EditParams struct { diff --git a/internal/llm/tools/edit_test.go b/internal/llm/tools/edit_test.go index 0971775dd..1b58a0d7d 100644 --- a/internal/llm/tools/edit_test.go +++ b/internal/llm/tools/edit_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - "github.com/kujtimiihoxha/termai/internal/lsp" + "github.com/kujtimiihoxha/opencode/internal/lsp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/internal/llm/tools/fetch.go b/internal/llm/tools/fetch.go index 91bcb36a0..827755863 100644 --- a/internal/llm/tools/fetch.go +++ b/internal/llm/tools/fetch.go @@ -11,8 +11,8 @@ import ( md "github.com/JohannesKaufmann/html-to-markdown" "github.com/PuerkitoBio/goquery" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/permission" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/permission" ) type FetchParams struct { @@ -146,7 +146,7 @@ func (t *fetchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error return ToolResponse{}, fmt.Errorf("failed to create request: %w", err) } - req.Header.Set("User-Agent", "termai/1.0") + req.Header.Set("User-Agent", "opencode/1.0") resp, err := client.Do(req) if err != nil { diff --git a/internal/llm/tools/glob.go b/internal/llm/tools/glob.go index 7b4fb1187..40262ce2b 100644 --- a/internal/llm/tools/glob.go +++ b/internal/llm/tools/glob.go @@ -12,7 +12,7 @@ import ( "time" "github.com/bmatcuk/doublestar/v4" - "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/opencode/internal/config" ) const ( diff --git a/internal/llm/tools/grep.go b/internal/llm/tools/grep.go index 19333f50b..3436dd7eb 100644 --- a/internal/llm/tools/grep.go +++ b/internal/llm/tools/grep.go @@ -13,7 +13,7 @@ import ( "strings" "time" - "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/opencode/internal/config" ) type GrepParams struct { diff --git a/internal/llm/tools/ls.go b/internal/llm/tools/ls.go index a63bf0eeb..05f300c0e 100644 --- a/internal/llm/tools/ls.go +++ b/internal/llm/tools/ls.go @@ -8,7 +8,7 @@ import ( "path/filepath" "strings" - "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/opencode/internal/config" ) type LSParams struct { diff --git a/internal/llm/tools/mocks_test.go b/internal/llm/tools/mocks_test.go index 321f09ac1..81993160c 100644 --- a/internal/llm/tools/mocks_test.go +++ b/internal/llm/tools/mocks_test.go @@ -9,9 +9,9 @@ import ( "time" "github.com/google/uuid" - "github.com/kujtimiihoxha/termai/internal/history" - "github.com/kujtimiihoxha/termai/internal/permission" - "github.com/kujtimiihoxha/termai/internal/pubsub" + "github.com/kujtimiihoxha/opencode/internal/history" + "github.com/kujtimiihoxha/opencode/internal/permission" + "github.com/kujtimiihoxha/opencode/internal/pubsub" ) // Mock permission service for testing diff --git a/internal/llm/tools/shell/shell.go b/internal/llm/tools/shell/shell.go index 4a776478a..e25bdf3ea 100644 --- a/internal/llm/tools/shell/shell.go +++ b/internal/llm/tools/shell/shell.go @@ -126,10 +126,10 @@ func (s *PersistentShell) execCommand(command string, timeout time.Duration, ctx } tempDir := os.TempDir() - stdoutFile := filepath.Join(tempDir, fmt.Sprintf("termai-stdout-%d", time.Now().UnixNano())) - stderrFile := filepath.Join(tempDir, fmt.Sprintf("termai-stderr-%d", time.Now().UnixNano())) - statusFile := filepath.Join(tempDir, fmt.Sprintf("termai-status-%d", time.Now().UnixNano())) - cwdFile := filepath.Join(tempDir, fmt.Sprintf("termai-cwd-%d", time.Now().UnixNano())) + stdoutFile := filepath.Join(tempDir, fmt.Sprintf("opencode-stdout-%d", time.Now().UnixNano())) + stderrFile := filepath.Join(tempDir, fmt.Sprintf("opencode-stderr-%d", time.Now().UnixNano())) + statusFile := filepath.Join(tempDir, fmt.Sprintf("opencode-status-%d", time.Now().UnixNano())) + cwdFile := filepath.Join(tempDir, fmt.Sprintf("opencode-cwd-%d", time.Now().UnixNano())) defer func() { os.Remove(stdoutFile) diff --git a/internal/llm/tools/sourcegraph.go b/internal/llm/tools/sourcegraph.go index a6f2c8afb..0d38c975f 100644 --- a/internal/llm/tools/sourcegraph.go +++ b/internal/llm/tools/sourcegraph.go @@ -218,7 +218,7 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse, } req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", "termai/1.0") + req.Header.Set("User-Agent", "opencode/1.0") resp, err := client.Do(req) if err != nil { diff --git a/internal/llm/tools/view.go b/internal/llm/tools/view.go index 7450a84bf..3fa4ca116 100644 --- a/internal/llm/tools/view.go +++ b/internal/llm/tools/view.go @@ -10,8 +10,8 @@ import ( "path/filepath" "strings" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/lsp" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/lsp" ) type ViewParams struct { diff --git a/internal/llm/tools/write.go b/internal/llm/tools/write.go index bb49381fd..261865c39 100644 --- a/internal/llm/tools/write.go +++ b/internal/llm/tools/write.go @@ -8,11 +8,11 @@ import ( "path/filepath" "time" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/diff" - "github.com/kujtimiihoxha/termai/internal/history" - "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/permission" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/diff" + "github.com/kujtimiihoxha/opencode/internal/history" + "github.com/kujtimiihoxha/opencode/internal/lsp" + "github.com/kujtimiihoxha/opencode/internal/permission" ) type WriteParams struct { diff --git a/internal/llm/tools/write_test.go b/internal/llm/tools/write_test.go index 2264f36fb..b5ecb3fda 100644 --- a/internal/llm/tools/write_test.go +++ b/internal/llm/tools/write_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - "github.com/kujtimiihoxha/termai/internal/lsp" + "github.com/kujtimiihoxha/opencode/internal/lsp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/internal/logging/writer.go b/internal/logging/writer.go index 9fe469c5e..1dc07e853 100644 --- a/internal/logging/writer.go +++ b/internal/logging/writer.go @@ -9,7 +9,7 @@ import ( "time" "github.com/go-logfmt/logfmt" - "github.com/kujtimiihoxha/termai/internal/pubsub" + "github.com/kujtimiihoxha/opencode/internal/pubsub" ) const ( diff --git a/internal/lsp/client.go b/internal/lsp/client.go index 0f03e7fcb..dad07f3c0 100644 --- a/internal/lsp/client.go +++ b/internal/lsp/client.go @@ -13,9 +13,9 @@ import ( "sync/atomic" "time" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/lsp/protocol" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/lsp/protocol" ) type Client struct { diff --git a/internal/lsp/handlers.go b/internal/lsp/handlers.go index c3088d685..7a11286e6 100644 --- a/internal/lsp/handlers.go +++ b/internal/lsp/handlers.go @@ -3,10 +3,10 @@ package lsp import ( "encoding/json" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/lsp/protocol" - "github.com/kujtimiihoxha/termai/internal/lsp/util" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/lsp/protocol" + "github.com/kujtimiihoxha/opencode/internal/lsp/util" ) // Requests diff --git a/internal/lsp/language.go b/internal/lsp/language.go index 2e276c464..65ccd54f3 100644 --- a/internal/lsp/language.go +++ b/internal/lsp/language.go @@ -4,7 +4,7 @@ import ( "path/filepath" "strings" - "github.com/kujtimiihoxha/termai/internal/lsp/protocol" + "github.com/kujtimiihoxha/opencode/internal/lsp/protocol" ) func DetectLanguageID(uri string) protocol.LanguageKind { diff --git a/internal/lsp/methods.go b/internal/lsp/methods.go index 079b3bfe3..ab33d7e1b 100644 --- a/internal/lsp/methods.go +++ b/internal/lsp/methods.go @@ -4,7 +4,7 @@ package lsp import ( "context" - "github.com/kujtimiihoxha/termai/internal/lsp/protocol" + "github.com/kujtimiihoxha/opencode/internal/lsp/protocol" ) // Implementation sends a textDocument/implementation request to the LSP server. diff --git a/internal/lsp/transport.go b/internal/lsp/transport.go index 89255fd78..fe59b0fbb 100644 --- a/internal/lsp/transport.go +++ b/internal/lsp/transport.go @@ -8,8 +8,8 @@ import ( "io" "strings" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/logging" ) // Write writes an LSP message to the given writer diff --git a/internal/lsp/util/edit.go b/internal/lsp/util/edit.go index 3b94fb39f..52f03ee77 100644 --- a/internal/lsp/util/edit.go +++ b/internal/lsp/util/edit.go @@ -7,7 +7,7 @@ import ( "sort" "strings" - "github.com/kujtimiihoxha/termai/internal/lsp/protocol" + "github.com/kujtimiihoxha/opencode/internal/lsp/protocol" ) func applyTextEdits(uri protocol.DocumentUri, edits []protocol.TextEdit) error { diff --git a/internal/lsp/watcher/watcher.go b/internal/lsp/watcher/watcher.go index 156f38e1a..595c78db9 100644 --- a/internal/lsp/watcher/watcher.go +++ b/internal/lsp/watcher/watcher.go @@ -10,10 +10,10 @@ import ( "time" "github.com/fsnotify/fsnotify" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/lsp/protocol" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/lsp" + "github.com/kujtimiihoxha/opencode/internal/lsp/protocol" ) // WorkspaceWatcher manages LSP file watching diff --git a/internal/message/content.go b/internal/message/content.go index f9e76b11c..f52449f4a 100644 --- a/internal/message/content.go +++ b/internal/message/content.go @@ -5,7 +5,7 @@ import ( "slices" "time" - "github.com/kujtimiihoxha/termai/internal/llm/models" + "github.com/kujtimiihoxha/opencode/internal/llm/models" ) type MessageRole string diff --git a/internal/message/message.go b/internal/message/message.go index 2871780a7..f165fcfc7 100644 --- a/internal/message/message.go +++ b/internal/message/message.go @@ -7,9 +7,9 @@ import ( "fmt" "github.com/google/uuid" - "github.com/kujtimiihoxha/termai/internal/db" - "github.com/kujtimiihoxha/termai/internal/llm/models" - "github.com/kujtimiihoxha/termai/internal/pubsub" + "github.com/kujtimiihoxha/opencode/internal/db" + "github.com/kujtimiihoxha/opencode/internal/llm/models" + "github.com/kujtimiihoxha/opencode/internal/pubsub" ) type CreateMessageParams struct { diff --git a/internal/permission/permission.go b/internal/permission/permission.go index 8aa280906..4cb379dea 100644 --- a/internal/permission/permission.go +++ b/internal/permission/permission.go @@ -6,7 +6,7 @@ import ( "time" "github.com/google/uuid" - "github.com/kujtimiihoxha/termai/internal/pubsub" + "github.com/kujtimiihoxha/opencode/internal/pubsub" ) var ErrorPermissionDenied = errors.New("permission denied") diff --git a/internal/session/session.go b/internal/session/session.go index 019019df4..280da1ff0 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -5,8 +5,8 @@ import ( "database/sql" "github.com/google/uuid" - "github.com/kujtimiihoxha/termai/internal/db" - "github.com/kujtimiihoxha/termai/internal/pubsub" + "github.com/kujtimiihoxha/opencode/internal/db" + "github.com/kujtimiihoxha/opencode/internal/pubsub" ) type Session struct { diff --git a/internal/tui/components/chat/chat.go b/internal/tui/components/chat/chat.go index e98001efa..52ff4c8bf 100644 --- a/internal/tui/components/chat/chat.go +++ b/internal/tui/components/chat/chat.go @@ -5,10 +5,10 @@ import ( "github.com/charmbracelet/lipgloss" "github.com/charmbracelet/x/ansi" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/session" - "github.com/kujtimiihoxha/termai/internal/tui/styles" - "github.com/kujtimiihoxha/termai/internal/version" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/session" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/version" ) type SendMsg struct { diff --git a/internal/tui/components/chat/editor.go b/internal/tui/components/chat/editor.go index e2f4da9e2..4d6ef5ca0 100644 --- a/internal/tui/components/chat/editor.go +++ b/internal/tui/components/chat/editor.go @@ -5,11 +5,11 @@ import ( "github.com/charmbracelet/bubbles/textarea" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/app" - "github.com/kujtimiihoxha/termai/internal/session" - "github.com/kujtimiihoxha/termai/internal/tui/layout" - "github.com/kujtimiihoxha/termai/internal/tui/styles" - "github.com/kujtimiihoxha/termai/internal/tui/util" + "github.com/kujtimiihoxha/opencode/internal/app" + "github.com/kujtimiihoxha/opencode/internal/session" + "github.com/kujtimiihoxha/opencode/internal/tui/layout" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/tui/util" ) type editorCmp struct { diff --git a/internal/tui/components/chat/messages.go b/internal/tui/components/chat/messages.go index 26a98970e..c2ce7d88b 100644 --- a/internal/tui/components/chat/messages.go +++ b/internal/tui/components/chat/messages.go @@ -15,17 +15,17 @@ import ( "github.com/charmbracelet/glamour" "github.com/charmbracelet/lipgloss" "github.com/charmbracelet/x/ansi" - "github.com/kujtimiihoxha/termai/internal/app" - "github.com/kujtimiihoxha/termai/internal/llm/agent" - "github.com/kujtimiihoxha/termai/internal/llm/models" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/message" - "github.com/kujtimiihoxha/termai/internal/pubsub" - "github.com/kujtimiihoxha/termai/internal/session" - "github.com/kujtimiihoxha/termai/internal/tui/layout" - "github.com/kujtimiihoxha/termai/internal/tui/styles" - "github.com/kujtimiihoxha/termai/internal/tui/util" + "github.com/kujtimiihoxha/opencode/internal/app" + "github.com/kujtimiihoxha/opencode/internal/llm/agent" + "github.com/kujtimiihoxha/opencode/internal/llm/models" + "github.com/kujtimiihoxha/opencode/internal/llm/tools" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/message" + "github.com/kujtimiihoxha/opencode/internal/pubsub" + "github.com/kujtimiihoxha/opencode/internal/session" + "github.com/kujtimiihoxha/opencode/internal/tui/layout" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/tui/util" ) type uiMessageType int diff --git a/internal/tui/components/chat/sidebar.go b/internal/tui/components/chat/sidebar.go index b90269d1a..54b39f4a1 100644 --- a/internal/tui/components/chat/sidebar.go +++ b/internal/tui/components/chat/sidebar.go @@ -7,12 +7,12 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/diff" - "github.com/kujtimiihoxha/termai/internal/history" - "github.com/kujtimiihoxha/termai/internal/pubsub" - "github.com/kujtimiihoxha/termai/internal/session" - "github.com/kujtimiihoxha/termai/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/diff" + "github.com/kujtimiihoxha/opencode/internal/history" + "github.com/kujtimiihoxha/opencode/internal/pubsub" + "github.com/kujtimiihoxha/opencode/internal/session" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" ) type sidebarCmp struct { diff --git a/internal/tui/components/core/status.go b/internal/tui/components/core/status.go index 089dffa2c..411cac1c5 100644 --- a/internal/tui/components/core/status.go +++ b/internal/tui/components/core/status.go @@ -7,12 +7,12 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/models" - "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/lsp/protocol" - "github.com/kujtimiihoxha/termai/internal/tui/styles" - "github.com/kujtimiihoxha/termai/internal/tui/util" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/llm/models" + "github.com/kujtimiihoxha/opencode/internal/lsp" + "github.com/kujtimiihoxha/opencode/internal/lsp/protocol" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/tui/util" ) type statusCmp struct { diff --git a/internal/tui/components/dialog/help.go b/internal/tui/components/dialog/help.go index 1d3c2b077..6242017f1 100644 --- a/internal/tui/components/dialog/help.go +++ b/internal/tui/components/dialog/help.go @@ -6,7 +6,7 @@ import ( "github.com/charmbracelet/bubbles/key" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" ) type helpCmp struct { diff --git a/internal/tui/components/dialog/permission.go b/internal/tui/components/dialog/permission.go index 9c55effde..200a7970d 100644 --- a/internal/tui/components/dialog/permission.go +++ b/internal/tui/components/dialog/permission.go @@ -9,12 +9,12 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/glamour" "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/diff" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/permission" - "github.com/kujtimiihoxha/termai/internal/tui/layout" - "github.com/kujtimiihoxha/termai/internal/tui/styles" - "github.com/kujtimiihoxha/termai/internal/tui/util" + "github.com/kujtimiihoxha/opencode/internal/diff" + "github.com/kujtimiihoxha/opencode/internal/llm/tools" + "github.com/kujtimiihoxha/opencode/internal/permission" + "github.com/kujtimiihoxha/opencode/internal/tui/layout" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/tui/util" ) type PermissionAction string diff --git a/internal/tui/components/dialog/quit.go b/internal/tui/components/dialog/quit.go index 10d9ba8a2..5bbe6696c 100644 --- a/internal/tui/components/dialog/quit.go +++ b/internal/tui/components/dialog/quit.go @@ -6,9 +6,9 @@ import ( "github.com/charmbracelet/bubbles/key" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/tui/layout" - "github.com/kujtimiihoxha/termai/internal/tui/styles" - "github.com/kujtimiihoxha/termai/internal/tui/util" + "github.com/kujtimiihoxha/opencode/internal/tui/layout" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/tui/util" ) const question = "Are you sure you want to quit?" diff --git a/internal/tui/components/logs/details.go b/internal/tui/components/logs/details.go index 18eb1a526..3a8f17999 100644 --- a/internal/tui/components/logs/details.go +++ b/internal/tui/components/logs/details.go @@ -9,9 +9,9 @@ import ( "github.com/charmbracelet/bubbles/viewport" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/tui/layout" - "github.com/kujtimiihoxha/termai/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/tui/layout" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" ) type DetailComponent interface { diff --git a/internal/tui/components/logs/table.go b/internal/tui/components/logs/table.go index 6e8eb58b1..dc6184e3d 100644 --- a/internal/tui/components/logs/table.go +++ b/internal/tui/components/logs/table.go @@ -7,11 +7,11 @@ import ( "github.com/charmbracelet/bubbles/key" "github.com/charmbracelet/bubbles/table" tea "github.com/charmbracelet/bubbletea" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/pubsub" - "github.com/kujtimiihoxha/termai/internal/tui/layout" - "github.com/kujtimiihoxha/termai/internal/tui/styles" - "github.com/kujtimiihoxha/termai/internal/tui/util" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/pubsub" + "github.com/kujtimiihoxha/opencode/internal/tui/layout" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/tui/util" ) type TableComponent interface { diff --git a/internal/tui/layout/border.go b/internal/tui/layout/border.go index 8fe5c430c..ea9f5e0bc 100644 --- a/internal/tui/layout/border.go +++ b/internal/tui/layout/border.go @@ -5,7 +5,7 @@ import ( "strings" "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" ) type BorderPosition int diff --git a/internal/tui/layout/container.go b/internal/tui/layout/container.go index db07d49fb..603699955 100644 --- a/internal/tui/layout/container.go +++ b/internal/tui/layout/container.go @@ -4,7 +4,7 @@ import ( "github.com/charmbracelet/bubbles/key" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" ) type Container interface { diff --git a/internal/tui/layout/overlay.go b/internal/tui/layout/overlay.go index 4a1bcf661..4c05e8462 100644 --- a/internal/tui/layout/overlay.go +++ b/internal/tui/layout/overlay.go @@ -5,8 +5,8 @@ import ( "strings" "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/tui/styles" - "github.com/kujtimiihoxha/termai/internal/tui/util" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/tui/util" "github.com/mattn/go-runewidth" "github.com/muesli/ansi" "github.com/muesli/reflow/truncate" diff --git a/internal/tui/layout/split.go b/internal/tui/layout/split.go index 6482fc74c..bfb616a53 100644 --- a/internal/tui/layout/split.go +++ b/internal/tui/layout/split.go @@ -4,7 +4,7 @@ import ( "github.com/charmbracelet/bubbles/key" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" ) type SplitPaneLayout interface { diff --git a/internal/tui/page/chat.go b/internal/tui/page/chat.go index cebc0e461..c268e677f 100644 --- a/internal/tui/page/chat.go +++ b/internal/tui/page/chat.go @@ -5,11 +5,11 @@ import ( "github.com/charmbracelet/bubbles/key" tea "github.com/charmbracelet/bubbletea" - "github.com/kujtimiihoxha/termai/internal/app" - "github.com/kujtimiihoxha/termai/internal/session" - "github.com/kujtimiihoxha/termai/internal/tui/components/chat" - "github.com/kujtimiihoxha/termai/internal/tui/layout" - "github.com/kujtimiihoxha/termai/internal/tui/util" + "github.com/kujtimiihoxha/opencode/internal/app" + "github.com/kujtimiihoxha/opencode/internal/session" + "github.com/kujtimiihoxha/opencode/internal/tui/components/chat" + "github.com/kujtimiihoxha/opencode/internal/tui/layout" + "github.com/kujtimiihoxha/opencode/internal/tui/util" ) var ChatPage PageID = "chat" diff --git a/internal/tui/page/logs.go b/internal/tui/page/logs.go index d1e557eab..c77a033f4 100644 --- a/internal/tui/page/logs.go +++ b/internal/tui/page/logs.go @@ -2,8 +2,8 @@ package page import ( tea "github.com/charmbracelet/bubbletea" - "github.com/kujtimiihoxha/termai/internal/tui/components/logs" - "github.com/kujtimiihoxha/termai/internal/tui/layout" + "github.com/kujtimiihoxha/opencode/internal/tui/components/logs" + "github.com/kujtimiihoxha/opencode/internal/tui/layout" ) var LogsPage PageID = "logs" diff --git a/internal/tui/tui.go b/internal/tui/tui.go index dff7ad63d..657de6b6e 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -4,15 +4,15 @@ import ( "github.com/charmbracelet/bubbles/key" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/app" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/permission" - "github.com/kujtimiihoxha/termai/internal/pubsub" - "github.com/kujtimiihoxha/termai/internal/tui/components/core" - "github.com/kujtimiihoxha/termai/internal/tui/components/dialog" - "github.com/kujtimiihoxha/termai/internal/tui/layout" - "github.com/kujtimiihoxha/termai/internal/tui/page" - "github.com/kujtimiihoxha/termai/internal/tui/util" + "github.com/kujtimiihoxha/opencode/internal/app" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/permission" + "github.com/kujtimiihoxha/opencode/internal/pubsub" + "github.com/kujtimiihoxha/opencode/internal/tui/components/core" + "github.com/kujtimiihoxha/opencode/internal/tui/components/dialog" + "github.com/kujtimiihoxha/opencode/internal/tui/layout" + "github.com/kujtimiihoxha/opencode/internal/tui/page" + "github.com/kujtimiihoxha/opencode/internal/tui/util" ) type keyMap struct { diff --git a/main.go b/main.go index 2e6954646..06578c7ef 100644 --- a/main.go +++ b/main.go @@ -1,8 +1,8 @@ package main import ( - "github.com/kujtimiihoxha/termai/cmd" - "github.com/kujtimiihoxha/termai/internal/logging" + "github.com/kujtimiihoxha/opencode/cmd" + "github.com/kujtimiihoxha/opencode/internal/logging" ) func main() { -- cgit v1.2.3 From 333ea6ec4b2abfc2c1a9c3f6b0918ca5d296347f Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 18 Apr 2025 20:17:38 +0200 Subject: implement patch, update ui, improve rendering --- README.md | 301 +++++++++-- internal/config/config.go | 16 +- internal/diff/diff.go | 28 +- internal/diff/patch.go | 739 ++++++++++++++++++++++++++ internal/history/file.go | 21 +- internal/llm/agent/agent.go | 53 +- internal/llm/agent/tools.go | 3 +- internal/llm/models/anthropic.go | 6 + internal/llm/models/models.go | 64 +-- internal/llm/models/openai.go | 169 ++++++ internal/llm/prompt/coder.go | 83 +-- internal/llm/provider/openai.go | 49 +- internal/llm/tools/glob.go | 38 +- internal/llm/tools/grep.go | 37 +- internal/llm/tools/patch.go | 450 +++++++++------- internal/llm/tools/view.go | 13 +- internal/tui/components/chat/editor.go | 58 ++- internal/tui/components/chat/list.go | 463 +++++++++++++++++ internal/tui/components/chat/message.go | 561 ++++++++++++++++++++ internal/tui/components/chat/messages.go | 742 --------------------------- internal/tui/components/chat/sidebar.go | 94 +++- internal/tui/components/core/status.go | 24 +- internal/tui/components/dialog/permission.go | 16 +- internal/tui/components/dialog/session.go | 224 ++++++++ internal/tui/components/logs/details.go | 16 +- internal/tui/components/logs/table.go | 3 +- internal/tui/layout/bento.go | 392 -------------- internal/tui/layout/border.go | 121 ----- internal/tui/layout/container.go | 5 +- internal/tui/layout/grid.go | 254 --------- internal/tui/layout/layout.go | 6 +- internal/tui/layout/single.go | 189 ------- internal/tui/layout/split.go | 37 +- internal/tui/page/chat.go | 31 +- internal/tui/page/logs.go | 13 +- internal/tui/styles/background.go | 114 ++-- internal/tui/styles/icons.go | 14 +- internal/tui/tui.go | 111 +++- 38 files changed, 3304 insertions(+), 2254 deletions(-) create mode 100644 internal/diff/patch.go create mode 100644 internal/llm/models/openai.go create mode 100644 internal/tui/components/chat/list.go create mode 100644 internal/tui/components/chat/message.go delete mode 100644 internal/tui/components/chat/messages.go create mode 100644 internal/tui/components/dialog/session.go delete mode 100644 internal/tui/layout/bento.go delete mode 100644 internal/tui/layout/border.go delete mode 100644 internal/tui/layout/grid.go delete mode 100644 internal/tui/layout/single.go (limited to 'internal/diff') diff --git a/README.md b/README.md index 564284c7f..ef55b6929 100644 --- a/README.md +++ b/README.md @@ -4,8 +4,6 @@ A powerful terminal-based AI assistant for developers, providing intelligent coding assistance directly in your terminal. -[![OpenCode Demo](https://asciinema.org/a/dtc4nJyGSZX79HRUmFLY3gmoy.svg)](https://asciinema.org/a/dtc4nJyGSZX79HRUmFLY3gmoy) - ## Overview OpenCode is a Go-based CLI application that brings AI assistance to your terminal. It provides a TUI (Terminal User Interface) for interacting with various AI models to help with coding tasks, debugging, and more. @@ -13,11 +11,13 @@ OpenCode is a Go-based CLI application that brings AI assistance to your termina ## Features - **Interactive TUI**: Built with [Bubble Tea](https://github.com/charmbracelet/bubbletea) for a smooth terminal experience -- **Multiple AI Providers**: Support for OpenAI, Anthropic Claude, and Google Gemini models +- **Multiple AI Providers**: Support for OpenAI, Anthropic Claude, Google Gemini, AWS Bedrock, and Groq - **Session Management**: Save and manage multiple conversation sessions - **Tool Integration**: AI can execute commands, search files, and modify code -- **Vim-like Editor**: Integrated editor with Vim keybindings for text input +- **Vim-like Editor**: Integrated editor with text input capabilities - **Persistent Storage**: SQLite database for storing conversations and sessions +- **LSP Integration**: Language Server Protocol support for code intelligence +- **File Change Tracking**: Track and visualize file changes during sessions ## Installation @@ -34,11 +34,107 @@ OpenCode looks for configuration in the following locations: - `$XDG_CONFIG_HOME/opencode/.opencode.json` - `./.opencode.json` (local directory) -You can also use environment variables: +### Environment Variables + +You can configure OpenCode using environment variables: + +| Environment Variable | Purpose | +| ----------------------- | ------------------------ | +| `ANTHROPIC_API_KEY` | For Claude models | +| `OPENAI_API_KEY` | For OpenAI models | +| `GEMINI_API_KEY` | For Google Gemini models | +| `GROQ_API_KEY` | For Groq models | +| `AWS_ACCESS_KEY_ID` | For AWS Bedrock (Claude) | +| `AWS_SECRET_ACCESS_KEY` | For AWS Bedrock (Claude) | +| `AWS_REGION` | For AWS Bedrock (Claude) | + +### Configuration File Structure + +```json +{ + "data": { + "directory": ".opencode" + }, + "providers": { + "openai": { + "apiKey": "your-api-key", + "disabled": false + }, + "anthropic": { + "apiKey": "your-api-key", + "disabled": false + } + }, + "agents": { + "coder": { + "model": "claude-3.7-sonnet", + "maxTokens": 5000 + }, + "task": { + "model": "claude-3.7-sonnet", + "maxTokens": 5000 + }, + "title": { + "model": "claude-3.7-sonnet", + "maxTokens": 80 + } + }, + "mcpServers": { + "example": { + "type": "stdio", + "command": "path/to/mcp-server", + "env": [], + "args": [] + } + }, + "lsp": { + "go": { + "disabled": false, + "command": "gopls" + } + }, + "debug": false, + "debugLSP": false +} +``` -- `ANTHROPIC_API_KEY`: For Claude models -- `OPENAI_API_KEY`: For OpenAI models -- `GEMINI_API_KEY`: For Google Gemini models +## Supported AI Models + +### OpenAI Models + +| Model ID | Name | Context Window | +| ----------------- | --------------- | ---------------- | +| `gpt-4.1` | GPT 4.1 | 1,047,576 tokens | +| `gpt-4.1-mini` | GPT 4.1 Mini | 200,000 tokens | +| `gpt-4.1-nano` | GPT 4.1 Nano | 1,047,576 tokens | +| `gpt-4.5-preview` | GPT 4.5 Preview | 128,000 tokens | +| `gpt-4o` | GPT-4o | 128,000 tokens | +| `gpt-4o-mini` | GPT-4o Mini | 128,000 tokens | +| `o1` | O1 | 200,000 tokens | +| `o1-pro` | O1 Pro | 200,000 tokens | +| `o1-mini` | O1 Mini | 128,000 tokens | +| `o3` | O3 | 200,000 tokens | +| `o3-mini` | O3 Mini | 200,000 tokens | +| `o4-mini` | O4 Mini | 128,000 tokens | + +### Anthropic Models + +| Model ID | Name | Context Window | +| ------------------- | ----------------- | -------------- | +| `claude-3.5-sonnet` | Claude 3.5 Sonnet | 200,000 tokens | +| `claude-3-haiku` | Claude 3 Haiku | 200,000 tokens | +| `claude-3.7-sonnet` | Claude 3.7 Sonnet | 200,000 tokens | +| `claude-3.5-haiku` | Claude 3.5 Haiku | 200,000 tokens | +| `claude-3-opus` | Claude 3 Opus | 200,000 tokens | + +### Other Models + +| Model ID | Provider | Name | Context Window | +| --------------------------- | ----------- | ----------------- | -------------- | +| `gemini-2.5` | Google | Gemini 2.5 Pro | - | +| `gemini-2.0-flash` | Google | Gemini 2.0 Flash | - | +| `qwen-qwq` | Groq | Qwen Qwq | - | +| `bedrock.claude-3.7-sonnet` | AWS Bedrock | Claude 3.7 Sonnet | - | ## Usage @@ -48,36 +144,78 @@ opencode # Start with debug logging opencode -d + +# Start with a specific working directory +opencode -c /path/to/project ``` -### Keyboard Shortcuts +## Command-line Flags + +| Flag | Short | Description | +| --------- | ----- | ----------------------------- | +| `--help` | `-h` | Display help information | +| `--debug` | `-d` | Enable debug mode | +| `--cwd` | `-c` | Set current working directory | + +## Keyboard Shortcuts + +### Global Shortcuts + +| Shortcut | Action | +| -------- | ------------------------------------------------------- | +| `Ctrl+C` | Quit application | +| `Ctrl+?` | Toggle help dialog | +| `Ctrl+L` | View logs | +| `Esc` | Close current overlay/dialog or return to previous mode | + +### Chat Page Shortcuts + +| Shortcut | Action | +| -------- | --------------------------------------- | +| `Ctrl+N` | Create new session | +| `Ctrl+X` | Cancel current operation/generation | +| `i` | Focus editor (when not in writing mode) | +| `Esc` | Exit writing mode and focus messages | + +### Editor Shortcuts -#### Global Shortcuts +| Shortcut | Action | +| ------------------- | ----------------------------------------- | +| `Ctrl+S` | Send message (when editor is focused) | +| `Enter` or `Ctrl+S` | Send message (when editor is not focused) | +| `Esc` | Blur editor and focus messages | -- `?`: Toggle help panel -- `Ctrl+C` or `q`: Quit application -- `L`: View logs -- `Backspace`: Go back to previous page -- `Esc`: Close current view/dialog or return to normal mode +### Logs Page Shortcuts -#### Session Management +| Shortcut | Action | +| ----------- | ------------------- | +| `Backspace` | Return to chat page | -- `N`: Create new session -- `Enter` or `Space`: Select session (in sessions list) +## AI Assistant Tools -#### Editor Shortcuts (Vim-like) +OpenCode's AI assistant has access to various tools to help with coding tasks: -- `i`: Enter insert mode -- `Esc`: Enter normal mode -- `v`: Enter visual mode -- `V`: Enter visual line mode -- `Enter`: Send message (in normal mode) -- `Ctrl+S`: Send message (in insert mode) +### File and Code Tools -#### Navigation +| Tool | Description | Parameters | +| ------------- | --------------------------- | ---------------------------------------------------------------------------------------- | +| `glob` | Find files by pattern | `pattern` (required), `path` (optional) | +| `grep` | Search file contents | `pattern` (required), `path` (optional), `include` (optional), `literal_text` (optional) | +| `ls` | List directory contents | `path` (optional), `ignore` (optional array of patterns) | +| `view` | View file contents | `file_path` (required), `offset` (optional), `limit` (optional) | +| `write` | Write to files | `file_path` (required), `content` (required) | +| `edit` | Edit files | Various parameters for file editing | +| `patch` | Apply patches to files | `file_path` (required), `diff` (required) | +| `diagnostics` | Get diagnostics information | `file_path` (optional) | -- Arrow keys: Navigate through lists and content -- Page Up/Down: Scroll through content +### Other Tools + +| Tool | Description | Parameters | +| ------------- | -------------------------------------- | ----------------------------------------------------------------------------------------- | +| `bash` | Execute shell commands | `command` (required), `timeout` (optional) | +| `fetch` | Fetch data from URLs | `url` (required), `format` (required), `timeout` (optional) | +| `sourcegraph` | Search code across public repositories | `query` (required), `count` (optional), `context_window` (optional), `timeout` (optional) | +| `agent` | Run sub-tasks with the AI agent | `prompt` (required) | ## Architecture @@ -92,6 +230,101 @@ OpenCode is built with a modular architecture: - **internal/logging**: Logging infrastructure - **internal/message**: Message handling - **internal/session**: Session management +- **internal/lsp**: Language Server Protocol integration + +## MCP (Model Context Protocol) + +OpenCode implements the Model Context Protocol (MCP) to extend its capabilities through external tools. MCP provides a standardized way for the AI assistant to interact with external services and tools. + +### MCP Features + +- **External Tool Integration**: Connect to external tools and services via a standardized protocol +- **Tool Discovery**: Automatically discover available tools from MCP servers +- **Multiple Connection Types**: + - **Stdio**: Communicate with tools via standard input/output + - **SSE**: Communicate with tools via Server-Sent Events +- **Security**: Permission system for controlling access to MCP tools + +### Configuring MCP Servers + +MCP servers are defined in the configuration file under the `mcpServers` section: + +```json +{ + "mcpServers": { + "example": { + "type": "stdio", + "command": "path/to/mcp-server", + "env": [], + "args": [] + }, + "web-example": { + "type": "sse", + "url": "https://example.com/mcp", + "headers": { + "Authorization": "Bearer token" + } + } + } +} +``` + +### MCP Tool Usage + +Once configured, MCP tools are automatically available to the AI assistant alongside built-in tools. They follow the same permission model as other tools, requiring user approval before execution. + +## LSP (Language Server Protocol) + +OpenCode integrates with Language Server Protocol to provide rich code intelligence features across multiple programming languages. + +### LSP Features + +- **Multi-language Support**: Connect to language servers for different programming languages +- **Code Intelligence**: Get diagnostics, completions, and navigation assistance +- **File Watching**: Automatically notify language servers of file changes +- **Diagnostics**: Display errors, warnings, and hints in your code + +### Supported LSP Features + +| Feature | Description | +| ----------------- | ----------------------------------- | +| Diagnostics | Error checking and linting | +| Completions | Code suggestions and autocompletion | +| Hover | Documentation on hover | +| Definition | Go to definition | +| References | Find all references | +| Document Symbols | Navigate symbols in current file | +| Workspace Symbols | Search symbols across workspace | +| Formatting | Code formatting | +| Code Actions | Quick fixes and refactorings | + +### Configuring LSP + +Language servers are configured in the configuration file under the `lsp` section: + +```json +{ + "lsp": { + "go": { + "disabled": false, + "command": "gopls" + }, + "typescript": { + "disabled": false, + "command": "typescript-language-server", + "args": ["--stdio"] + } + } +} +``` + +### LSP Integration with AI + +The AI assistant can access LSP features through the `diagnostics` tool, allowing it to: + +- Check for errors in your code +- Suggest fixes based on diagnostics +- Provide intelligent code assistance ## Development @@ -124,8 +357,16 @@ OpenCode builds upon the work of several open source projects and developers: ## License -[License information coming soon] +OpenCode is licensed under the MIT License. See the [LICENSE](LICENSE) file for details. ## Contributing -[Contribution guidelines coming soon] +Contributions are welcome! Here's how you can contribute: + +1. Fork the repository +2. Create a feature branch (`git checkout -b feature/amazing-feature`) +3. Commit your changes (`git commit -m 'Add some amazing feature'`) +4. Push to the branch (`git push origin feature/amazing-feature`) +5. Open a Pull Request + +Please make sure to update tests as appropriate and follow the existing code style. diff --git a/internal/config/config.go b/internal/config/config.go index 5b6d51efa..0cb727158 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -41,8 +41,9 @@ const ( // Agent defines configuration for different LLM models and their token limits. type Agent struct { - Model models.ModelID `json:"model"` - MaxTokens int64 `json:"maxTokens"` + Model models.ModelID `json:"model"` + MaxTokens int64 `json:"maxTokens"` + ReasoningEffort string `json:"reasoningEffort"` // For openai models low,medium,heigh } // Provider defines configuration for an LLM provider. @@ -80,7 +81,6 @@ type Config struct { const ( defaultDataDirectory = ".opencode" defaultLogLevel = "info" - defaultMaxTokens = int64(5000) appName = "opencode" ) @@ -202,9 +202,7 @@ func setProviderDefaults() { if apiKey := os.Getenv("GROQ_API_KEY"); apiKey != "" { viper.SetDefault("providers.groq.apiKey", apiKey) viper.SetDefault("agents.coder.model", models.QWENQwq) - viper.SetDefault("agents.coder.maxTokens", defaultMaxTokens) viper.SetDefault("agents.task.model", models.QWENQwq) - viper.SetDefault("agents.task.maxTokens", defaultMaxTokens) viper.SetDefault("agents.title.model", models.QWENQwq) } @@ -212,9 +210,7 @@ func setProviderDefaults() { if apiKey := os.Getenv("GEMINI_API_KEY"); apiKey != "" { viper.SetDefault("providers.gemini.apiKey", apiKey) viper.SetDefault("agents.coder.model", models.GRMINI20Flash) - viper.SetDefault("agents.coder.maxTokens", defaultMaxTokens) viper.SetDefault("agents.task.model", models.GRMINI20Flash) - viper.SetDefault("agents.task.maxTokens", defaultMaxTokens) viper.SetDefault("agents.title.model", models.GRMINI20Flash) } @@ -222,9 +218,7 @@ func setProviderDefaults() { if apiKey := os.Getenv("OPENAI_API_KEY"); apiKey != "" { viper.SetDefault("providers.openai.apiKey", apiKey) viper.SetDefault("agents.coder.model", models.GPT4o) - viper.SetDefault("agents.coder.maxTokens", defaultMaxTokens) viper.SetDefault("agents.task.model", models.GPT4o) - viper.SetDefault("agents.task.maxTokens", defaultMaxTokens) viper.SetDefault("agents.title.model", models.GPT4o) } @@ -233,17 +227,13 @@ func setProviderDefaults() { if apiKey := os.Getenv("ANTHROPIC_API_KEY"); apiKey != "" { viper.SetDefault("providers.anthropic.apiKey", apiKey) viper.SetDefault("agents.coder.model", models.Claude37Sonnet) - viper.SetDefault("agents.coder.maxTokens", defaultMaxTokens) viper.SetDefault("agents.task.model", models.Claude37Sonnet) - viper.SetDefault("agents.task.maxTokens", defaultMaxTokens) viper.SetDefault("agents.title.model", models.Claude37Sonnet) } if hasAWSCredentials() { viper.SetDefault("agents.coder.model", models.BedrockClaude37Sonnet) - viper.SetDefault("agents.coder.maxTokens", defaultMaxTokens) viper.SetDefault("agents.task.model", models.BedrockClaude37Sonnet) - viper.SetDefault("agents.task.maxTokens", defaultMaxTokens) viper.SetDefault("agents.title.model", models.BedrockClaude37Sonnet) } } diff --git a/internal/diff/diff.go b/internal/diff/diff.go index f48079c9c..7b48de25f 100644 --- a/internal/diff/diff.go +++ b/internal/diff/diff.go @@ -79,8 +79,9 @@ type linePair struct { // StyleConfig defines styling for diff rendering type StyleConfig struct { - ShowHeader bool - FileNameFg lipgloss.Color + ShowHeader bool + ShowHunkHeader bool + FileNameFg lipgloss.Color // Background colors RemovedLineBg lipgloss.Color AddedLineBg lipgloss.Color @@ -111,7 +112,8 @@ func NewStyleConfig(opts ...StyleOption) StyleConfig { // Default color scheme config := StyleConfig{ ShowHeader: true, - FileNameFg: lipgloss.Color("#fab283"), + ShowHunkHeader: true, + FileNameFg: lipgloss.Color("#a0a0a0"), RemovedLineBg: lipgloss.Color("#3A3030"), AddedLineBg: lipgloss.Color("#303A30"), ContextLineBg: lipgloss.Color("#212121"), @@ -204,6 +206,10 @@ func WithShowHeader(show bool) StyleOption { return func(s *StyleConfig) { s.ShowHeader = show } } +func WithShowHunkHeader(show bool) StyleOption { + return func(s *StyleConfig) { s.ShowHunkHeader = show } +} + // ------------------------------------------------------------------------- // Parse Configuration // ------------------------------------------------------------------------- @@ -914,13 +920,15 @@ func FormatDiff(diffText string, opts ...SideBySideOption) (string, error) { for _, h := range diffResult.Hunks { // Render hunk header - sb.WriteString( - lipgloss.NewStyle(). - Background(config.Style.HunkLineBg). - Foreground(config.Style.HunkLineFg). - Width(config.TotalWidth). - Render(h.Header) + "\n", - ) + if config.Style.ShowHunkHeader { + sb.WriteString( + lipgloss.NewStyle(). + Background(config.Style.HunkLineBg). + Foreground(config.Style.HunkLineFg). + Width(config.TotalWidth). + Render(h.Header) + "\n", + ) + } sb.WriteString(RenderSideBySideHunk(diffResult.OldFile, h, opts...)) } diff --git a/internal/diff/patch.go b/internal/diff/patch.go new file mode 100644 index 000000000..aab0f956d --- /dev/null +++ b/internal/diff/patch.go @@ -0,0 +1,739 @@ +package diff + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "strings" +) + +type ActionType string + +const ( + ActionAdd ActionType = "add" + ActionDelete ActionType = "delete" + ActionUpdate ActionType = "update" +) + +type FileChange struct { + Type ActionType + OldContent *string + NewContent *string + MovePath *string +} + +type Commit struct { + Changes map[string]FileChange +} + +type Chunk struct { + OrigIndex int // line index of the first line in the original file + DelLines []string // lines to delete + InsLines []string // lines to insert +} + +type PatchAction struct { + Type ActionType + NewFile *string + Chunks []Chunk + MovePath *string +} + +type Patch struct { + Actions map[string]PatchAction +} + +type DiffError struct { + message string +} + +func (e DiffError) Error() string { + return e.message +} + +// Helper functions for error handling +func NewDiffError(message string) DiffError { + return DiffError{message: message} +} + +func fileError(action, reason, path string) DiffError { + return NewDiffError(fmt.Sprintf("%s File Error: %s: %s", action, reason, path)) +} + +func contextError(index int, context string, isEOF bool) DiffError { + prefix := "Invalid Context" + if isEOF { + prefix = "Invalid EOF Context" + } + return NewDiffError(fmt.Sprintf("%s %d:\n%s", prefix, index, context)) +} + +type Parser struct { + currentFiles map[string]string + lines []string + index int + patch Patch + fuzz int +} + +func NewParser(currentFiles map[string]string, lines []string) *Parser { + return &Parser{ + currentFiles: currentFiles, + lines: lines, + index: 0, + patch: Patch{Actions: make(map[string]PatchAction, len(currentFiles))}, + fuzz: 0, + } +} + +func (p *Parser) isDone(prefixes []string) bool { + if p.index >= len(p.lines) { + return true + } + if prefixes != nil { + for _, prefix := range prefixes { + if strings.HasPrefix(p.lines[p.index], prefix) { + return true + } + } + } + return false +} + +func (p *Parser) startsWith(prefix any) bool { + var prefixes []string + switch v := prefix.(type) { + case string: + prefixes = []string{v} + case []string: + prefixes = v + } + + for _, pfx := range prefixes { + if strings.HasPrefix(p.lines[p.index], pfx) { + return true + } + } + return false +} + +func (p *Parser) readStr(prefix string, returnEverything bool) string { + if p.index >= len(p.lines) { + return "" // Changed from panic to return empty string for safer operation + } + if strings.HasPrefix(p.lines[p.index], prefix) { + var text string + if returnEverything { + text = p.lines[p.index] + } else { + text = p.lines[p.index][len(prefix):] + } + p.index++ + return text + } + return "" +} + +func (p *Parser) Parse() error { + endPatchPrefixes := []string{"*** End Patch"} + + for !p.isDone(endPatchPrefixes) { + path := p.readStr("*** Update File: ", false) + if path != "" { + if _, exists := p.patch.Actions[path]; exists { + return fileError("Update", "Duplicate Path", path) + } + moveTo := p.readStr("*** Move to: ", false) + if _, exists := p.currentFiles[path]; !exists { + return fileError("Update", "Missing File", path) + } + text := p.currentFiles[path] + action, err := p.parseUpdateFile(text) + if err != nil { + return err + } + if moveTo != "" { + action.MovePath = &moveTo + } + p.patch.Actions[path] = action + continue + } + + path = p.readStr("*** Delete File: ", false) + if path != "" { + if _, exists := p.patch.Actions[path]; exists { + return fileError("Delete", "Duplicate Path", path) + } + if _, exists := p.currentFiles[path]; !exists { + return fileError("Delete", "Missing File", path) + } + p.patch.Actions[path] = PatchAction{Type: ActionDelete, Chunks: []Chunk{}} + continue + } + + path = p.readStr("*** Add File: ", false) + if path != "" { + if _, exists := p.patch.Actions[path]; exists { + return fileError("Add", "Duplicate Path", path) + } + if _, exists := p.currentFiles[path]; exists { + return fileError("Add", "File already exists", path) + } + action, err := p.parseAddFile() + if err != nil { + return err + } + p.patch.Actions[path] = action + continue + } + + return NewDiffError(fmt.Sprintf("Unknown Line: %s", p.lines[p.index])) + } + + if !p.startsWith("*** End Patch") { + return NewDiffError("Missing End Patch") + } + p.index++ + + return nil +} + +func (p *Parser) parseUpdateFile(text string) (PatchAction, error) { + action := PatchAction{Type: ActionUpdate, Chunks: []Chunk{}} + fileLines := strings.Split(text, "\n") + index := 0 + + endPrefixes := []string{ + "*** End Patch", + "*** Update File:", + "*** Delete File:", + "*** Add File:", + "*** End of File", + } + + for !p.isDone(endPrefixes) { + defStr := p.readStr("@@ ", false) + sectionStr := "" + if defStr == "" && p.index < len(p.lines) && p.lines[p.index] == "@@" { + sectionStr = p.lines[p.index] + p.index++ + } + if !(defStr != "" || sectionStr != "" || index == 0) { + return action, NewDiffError(fmt.Sprintf("Invalid Line:\n%s", p.lines[p.index])) + } + if strings.TrimSpace(defStr) != "" { + found := false + for i := range fileLines[:index] { + if fileLines[i] == defStr { + found = true + break + } + } + + if !found { + for i := index; i < len(fileLines); i++ { + if fileLines[i] == defStr { + index = i + 1 + found = true + break + } + } + } + + if !found { + for i := range fileLines[:index] { + if strings.TrimSpace(fileLines[i]) == strings.TrimSpace(defStr) { + found = true + break + } + } + } + + if !found { + for i := index; i < len(fileLines); i++ { + if strings.TrimSpace(fileLines[i]) == strings.TrimSpace(defStr) { + index = i + 1 + p.fuzz++ + found = true + break + } + } + } + } + + nextChunkContext, chunks, endPatchIndex, eof := peekNextSection(p.lines, p.index) + newIndex, fuzz := findContext(fileLines, nextChunkContext, index, eof) + if newIndex == -1 { + ctxText := strings.Join(nextChunkContext, "\n") + return action, contextError(index, ctxText, eof) + } + p.fuzz += fuzz + + for _, ch := range chunks { + ch.OrigIndex += newIndex + action.Chunks = append(action.Chunks, ch) + } + index = newIndex + len(nextChunkContext) + p.index = endPatchIndex + } + return action, nil +} + +func (p *Parser) parseAddFile() (PatchAction, error) { + lines := make([]string, 0, 16) // Preallocate space for better performance + endPrefixes := []string{ + "*** End Patch", + "*** Update File:", + "*** Delete File:", + "*** Add File:", + } + + for !p.isDone(endPrefixes) { + s := p.readStr("", true) + if !strings.HasPrefix(s, "+") { + return PatchAction{}, NewDiffError(fmt.Sprintf("Invalid Add File Line: %s", s)) + } + lines = append(lines, s[1:]) + } + + newFile := strings.Join(lines, "\n") + return PatchAction{ + Type: ActionAdd, + NewFile: &newFile, + Chunks: []Chunk{}, + }, nil +} + +// Refactored to use a matcher function for each comparison type +func findContextCore(lines []string, context []string, start int) (int, int) { + if len(context) == 0 { + return start, 0 + } + + // Try exact match + if idx, fuzz := tryFindMatch(lines, context, start, func(a, b string) bool { + return a == b + }); idx >= 0 { + return idx, fuzz + } + + // Try trimming right whitespace + if idx, fuzz := tryFindMatch(lines, context, start, func(a, b string) bool { + return strings.TrimRight(a, " \t") == strings.TrimRight(b, " \t") + }); idx >= 0 { + return idx, fuzz + } + + // Try trimming all whitespace + if idx, fuzz := tryFindMatch(lines, context, start, func(a, b string) bool { + return strings.TrimSpace(a) == strings.TrimSpace(b) + }); idx >= 0 { + return idx, fuzz + } + + return -1, 0 +} + +// Helper function to DRY up the match logic +func tryFindMatch(lines []string, context []string, start int, + compareFunc func(string, string) bool, +) (int, int) { + for i := start; i < len(lines); i++ { + if i+len(context) <= len(lines) { + match := true + for j := range context { + if !compareFunc(lines[i+j], context[j]) { + match = false + break + } + } + if match { + // Return fuzz level: 0 for exact, 1 for trimRight, 100 for trimSpace + var fuzz int + if compareFunc("a ", "a") && !compareFunc("a", "b") { + fuzz = 1 + } else if compareFunc("a ", "a") { + fuzz = 100 + } + return i, fuzz + } + } + } + return -1, 0 +} + +func findContext(lines []string, context []string, start int, eof bool) (int, int) { + if eof { + newIndex, fuzz := findContextCore(lines, context, len(lines)-len(context)) + if newIndex != -1 { + return newIndex, fuzz + } + newIndex, fuzz = findContextCore(lines, context, start) + return newIndex, fuzz + 10000 + } + return findContextCore(lines, context, start) +} + +func peekNextSection(lines []string, initialIndex int) ([]string, []Chunk, int, bool) { + index := initialIndex + old := make([]string, 0, 32) // Preallocate for better performance + delLines := make([]string, 0, 8) + insLines := make([]string, 0, 8) + chunks := make([]Chunk, 0, 4) + mode := "keep" + + // End conditions for the section + endSectionConditions := func(s string) bool { + return strings.HasPrefix(s, "@@") || + strings.HasPrefix(s, "*** End Patch") || + strings.HasPrefix(s, "*** Update File:") || + strings.HasPrefix(s, "*** Delete File:") || + strings.HasPrefix(s, "*** Add File:") || + strings.HasPrefix(s, "*** End of File") || + s == "***" || + strings.HasPrefix(s, "***") + } + + for index < len(lines) { + s := lines[index] + if endSectionConditions(s) { + break + } + index++ + lastMode := mode + line := s + + if len(line) > 0 { + switch line[0] { + case '+': + mode = "add" + case '-': + mode = "delete" + case ' ': + mode = "keep" + default: + mode = "keep" + line = " " + line + } + } else { + mode = "keep" + line = " " + } + + line = line[1:] + if mode == "keep" && lastMode != mode { + if len(insLines) > 0 || len(delLines) > 0 { + chunks = append(chunks, Chunk{ + OrigIndex: len(old) - len(delLines), + DelLines: delLines, + InsLines: insLines, + }) + } + delLines = make([]string, 0, 8) + insLines = make([]string, 0, 8) + } + if mode == "delete" { + delLines = append(delLines, line) + old = append(old, line) + } else if mode == "add" { + insLines = append(insLines, line) + } else { + old = append(old, line) + } + } + + if len(insLines) > 0 || len(delLines) > 0 { + chunks = append(chunks, Chunk{ + OrigIndex: len(old) - len(delLines), + DelLines: delLines, + InsLines: insLines, + }) + } + + if index < len(lines) && lines[index] == "*** End of File" { + index++ + return old, chunks, index, true + } + return old, chunks, index, false +} + +func TextToPatch(text string, orig map[string]string) (Patch, int, error) { + text = strings.TrimSpace(text) + lines := strings.Split(text, "\n") + if len(lines) < 2 || !strings.HasPrefix(lines[0], "*** Begin Patch") || lines[len(lines)-1] != "*** End Patch" { + return Patch{}, 0, NewDiffError("Invalid patch text") + } + parser := NewParser(orig, lines) + parser.index = 1 + if err := parser.Parse(); err != nil { + return Patch{}, 0, err + } + return parser.patch, parser.fuzz, nil +} + +func IdentifyFilesNeeded(text string) []string { + text = strings.TrimSpace(text) + lines := strings.Split(text, "\n") + result := make(map[string]bool) + + for _, line := range lines { + if strings.HasPrefix(line, "*** Update File: ") { + result[line[len("*** Update File: "):]] = true + } + if strings.HasPrefix(line, "*** Delete File: ") { + result[line[len("*** Delete File: "):]] = true + } + } + + files := make([]string, 0, len(result)) + for file := range result { + files = append(files, file) + } + return files +} + +func IdentifyFilesAdded(text string) []string { + text = strings.TrimSpace(text) + lines := strings.Split(text, "\n") + result := make(map[string]bool) + + for _, line := range lines { + if strings.HasPrefix(line, "*** Add File: ") { + result[line[len("*** Add File: "):]] = true + } + } + + files := make([]string, 0, len(result)) + for file := range result { + files = append(files, file) + } + return files +} + +func getUpdatedFile(text string, action PatchAction, path string) (string, error) { + if action.Type != ActionUpdate { + return "", errors.New("Expected UPDATE action") + } + origLines := strings.Split(text, "\n") + destLines := make([]string, 0, len(origLines)) // Preallocate with capacity + origIndex := 0 + + for _, chunk := range action.Chunks { + if chunk.OrigIndex > len(origLines) { + return "", NewDiffError(fmt.Sprintf("%s: chunk.orig_index %d > len(lines) %d", path, chunk.OrigIndex, len(origLines))) + } + if origIndex > chunk.OrigIndex { + return "", NewDiffError(fmt.Sprintf("%s: orig_index %d > chunk.orig_index %d", path, origIndex, chunk.OrigIndex)) + } + destLines = append(destLines, origLines[origIndex:chunk.OrigIndex]...) + delta := chunk.OrigIndex - origIndex + origIndex += delta + + if len(chunk.InsLines) > 0 { + destLines = append(destLines, chunk.InsLines...) + } + origIndex += len(chunk.DelLines) + } + + destLines = append(destLines, origLines[origIndex:]...) + return strings.Join(destLines, "\n"), nil +} + +func PatchToCommit(patch Patch, orig map[string]string) (Commit, error) { + commit := Commit{Changes: make(map[string]FileChange, len(patch.Actions))} + for pathKey, action := range patch.Actions { + if action.Type == ActionDelete { + oldContent := orig[pathKey] + commit.Changes[pathKey] = FileChange{ + Type: ActionDelete, + OldContent: &oldContent, + } + } else if action.Type == ActionAdd { + commit.Changes[pathKey] = FileChange{ + Type: ActionAdd, + NewContent: action.NewFile, + } + } else if action.Type == ActionUpdate { + newContent, err := getUpdatedFile(orig[pathKey], action, pathKey) + if err != nil { + return Commit{}, err + } + oldContent := orig[pathKey] + fileChange := FileChange{ + Type: ActionUpdate, + OldContent: &oldContent, + NewContent: &newContent, + } + if action.MovePath != nil { + fileChange.MovePath = action.MovePath + } + commit.Changes[pathKey] = fileChange + } + } + return commit, nil +} + +func AssembleChanges(orig map[string]string, updatedFiles map[string]string) Commit { + commit := Commit{Changes: make(map[string]FileChange, len(updatedFiles))} + for p, newContent := range updatedFiles { + oldContent, exists := orig[p] + if exists && oldContent == newContent { + continue + } + + if exists && newContent != "" { + commit.Changes[p] = FileChange{ + Type: ActionUpdate, + OldContent: &oldContent, + NewContent: &newContent, + } + } else if newContent != "" { + commit.Changes[p] = FileChange{ + Type: ActionAdd, + NewContent: &newContent, + } + } else if exists { + commit.Changes[p] = FileChange{ + Type: ActionDelete, + OldContent: &oldContent, + } + } else { + return commit // Changed from panic to simply return current commit + } + } + return commit +} + +func LoadFiles(paths []string, openFn func(string) (string, error)) (map[string]string, error) { + orig := make(map[string]string, len(paths)) + for _, p := range paths { + content, err := openFn(p) + if err != nil { + return nil, fileError("Open", "File not found", p) + } + orig[p] = content + } + return orig, nil +} + +func ApplyCommit(commit Commit, writeFn func(string, string) error, removeFn func(string) error) error { + for p, change := range commit.Changes { + if change.Type == ActionDelete { + if err := removeFn(p); err != nil { + return err + } + } else if change.Type == ActionAdd { + if change.NewContent == nil { + return NewDiffError(fmt.Sprintf("Add action for %s has nil new_content", p)) + } + if err := writeFn(p, *change.NewContent); err != nil { + return err + } + } else if change.Type == ActionUpdate { + if change.NewContent == nil { + return NewDiffError(fmt.Sprintf("Update action for %s has nil new_content", p)) + } + if change.MovePath != nil { + if err := writeFn(*change.MovePath, *change.NewContent); err != nil { + return err + } + if err := removeFn(p); err != nil { + return err + } + } else { + if err := writeFn(p, *change.NewContent); err != nil { + return err + } + } + } + } + return nil +} + +func ProcessPatch(text string, openFn func(string) (string, error), writeFn func(string, string) error, removeFn func(string) error) (string, error) { + if !strings.HasPrefix(text, "*** Begin Patch") { + return "", NewDiffError("Patch must start with *** Begin Patch") + } + paths := IdentifyFilesNeeded(text) + orig, err := LoadFiles(paths, openFn) + if err != nil { + return "", err + } + + patch, fuzz, err := TextToPatch(text, orig) + if err != nil { + return "", err + } + + if fuzz > 0 { + return "", NewDiffError(fmt.Sprintf("Patch contains fuzzy matches (fuzz level: %d)", fuzz)) + } + + commit, err := PatchToCommit(patch, orig) + if err != nil { + return "", err + } + + if err := ApplyCommit(commit, writeFn, removeFn); err != nil { + return "", err + } + + return "Patch applied successfully", nil +} + +func OpenFile(p string) (string, error) { + data, err := os.ReadFile(p) + if err != nil { + return "", err + } + return string(data), nil +} + +func WriteFile(p string, content string) error { + if filepath.IsAbs(p) { + return NewDiffError("We do not support absolute paths.") + } + + dir := filepath.Dir(p) + if dir != "." { + if err := os.MkdirAll(dir, 0o755); err != nil { + return err + } + } + + return os.WriteFile(p, []byte(content), 0o644) +} + +func RemoveFile(p string) error { + return os.Remove(p) +} + +func ValidatePatch(patchText string, files map[string]string) (bool, string, error) { + if !strings.HasPrefix(patchText, "*** Begin Patch") { + return false, "Patch must start with *** Begin Patch", nil + } + + neededFiles := IdentifyFilesNeeded(patchText) + for _, filePath := range neededFiles { + if _, exists := files[filePath]; !exists { + return false, fmt.Sprintf("File not found: %s", filePath), nil + } + } + + patch, fuzz, err := TextToPatch(patchText, files) + if err != nil { + return false, err.Error(), nil + } + + if fuzz > 0 { + return false, fmt.Sprintf("Patch contains fuzzy matches (fuzz level: %d)", fuzz), nil + } + + _, err = PatchToCommit(patch, files) + if err != nil { + return false, err.Error(), nil + } + + return true, "Patch is valid", nil +} diff --git a/internal/history/file.go b/internal/history/file.go index 8453ac272..7e206a2d9 100644 --- a/internal/history/file.go +++ b/internal/history/file.go @@ -50,6 +50,7 @@ func NewService(q *db.Queries, db *sql.DB) Service { return &service{ Broker: pubsub.NewBroker[File](), q: q, + db: db, } } @@ -100,30 +101,30 @@ func (s *service) createWithVersion(ctx context.Context, sessionID, path, conten var err error // Retry loop for transaction conflicts - for attempt := 0; attempt < maxRetries; attempt++ { + for attempt := range maxRetries { // Start a transaction - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return File{}, fmt.Errorf("failed to begin transaction: %w", err) + tx, txErr := s.db.Begin() + if txErr != nil { + return File{}, fmt.Errorf("failed to begin transaction: %w", txErr) } // Create a new queries instance with the transaction qtx := s.q.WithTx(tx) // Try to create the file within the transaction - dbFile, err := qtx.CreateFile(ctx, db.CreateFileParams{ + dbFile, txErr := qtx.CreateFile(ctx, db.CreateFileParams{ ID: uuid.New().String(), SessionID: sessionID, Path: path, Content: content, Version: version, }) - if err != nil { + if txErr != nil { // Rollback the transaction tx.Rollback() // Check if this is a uniqueness constraint violation - if strings.Contains(err.Error(), "UNIQUE constraint failed") { + if strings.Contains(txErr.Error(), "UNIQUE constraint failed") { if attempt < maxRetries-1 { // If we have retries left, generate a new version and try again if strings.HasPrefix(version, "v") { @@ -138,12 +139,12 @@ func (s *service) createWithVersion(ctx context.Context, sessionID, path, conten continue } } - return File{}, err + return File{}, txErr } // Commit the transaction - if err = tx.Commit(); err != nil { - return File{}, fmt.Errorf("failed to commit transaction: %w", err) + if txErr = tx.Commit(); txErr != nil { + return File{}, fmt.Errorf("failed to commit transaction: %w", txErr) } file = s.fromDBItem(dbFile) diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index a5dadb89d..5e9785991 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -41,6 +41,7 @@ type Service interface { Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error) Cancel(sessionID string) IsSessionBusy(sessionID string) bool + IsBusy() bool } type agent struct { @@ -95,6 +96,20 @@ func (a *agent) Cancel(sessionID string) { } } +func (a *agent) IsBusy() bool { + busy := false + a.activeRequests.Range(func(key, value interface{}) bool { + if cancelFunc, ok := value.(context.CancelFunc); ok { + if cancelFunc != nil { + busy = true + return false // Stop iterating + } + } + return true // Continue iterating + }) + return busy +} + func (a *agent) IsSessionBusy(sessionID string) bool { _, busy := a.activeRequests.Load(sessionID) return busy @@ -313,23 +328,8 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg } } a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied) - } else { - toolResults[i] = message.ToolResult{ - ToolCallID: toolCall.ID, - Content: toolErr.Error(), - IsError: true, - } - for j := i; j < len(toolCalls); j++ { - toolResults[j] = message.ToolResult{ - ToolCallID: toolCalls[j].ID, - Content: "Previous tool failed", - IsError: true, - } - } - a.finishMessage(ctx, &assistantMsg, message.FinishReasonError) + break } - // If permission is denied or an error happens we cancel all the following tools - break } toolResults[i] = message.ToolResult{ ToolCallID: toolCall.ID, @@ -437,12 +437,27 @@ func createAgentProvider(agentName config.AgentName) (provider.Provider, error) if providerCfg.Disabled { return nil, fmt.Errorf("provider %s is not enabled", model.Provider) } - agentProvider, err := provider.NewProvider( - model.Provider, + maxTokens := model.DefaultMaxTokens + if agentConfig.MaxTokens > 0 { + maxTokens = agentConfig.MaxTokens + } + opts := []provider.ProviderClientOption{ provider.WithAPIKey(providerCfg.APIKey), provider.WithModel(model), provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)), - provider.WithMaxTokens(agentConfig.MaxTokens), + provider.WithMaxTokens(maxTokens), + } + if model.Provider == models.ProviderOpenAI && model.CanReason { + opts = append( + opts, + provider.WithOpenAIOptions( + provider.WithReasoningEffort(agentConfig.ReasoningEffort), + ), + ) + } + agentProvider, err := provider.NewProvider( + model.Provider, + opts..., ) if err != nil { return nil, fmt.Errorf("could not create provider: %v", err) diff --git a/internal/llm/agent/tools.go b/internal/llm/agent/tools.go index 9120809ff..b2e6816d5 100644 --- a/internal/llm/agent/tools.go +++ b/internal/llm/agent/tools.go @@ -31,10 +31,9 @@ func CoderAgentTools( tools.NewGlobTool(), tools.NewGrepTool(), tools.NewLsTool(), - // TODO: see if we want to use this tool - // tools.NewPatchTool(lspClients, permissions, history), tools.NewSourcegraphTool(), tools.NewViewTool(lspClients), + tools.NewPatchTool(lspClients, permissions, history), tools.NewWriteTool(lspClients, permissions, history), NewAgentTool(sessions, messages, lspClients), }, otherTools..., diff --git a/internal/llm/models/anthropic.go b/internal/llm/models/anthropic.go index 48307e6d3..87e9b4c89 100644 --- a/internal/llm/models/anthropic.go +++ b/internal/llm/models/anthropic.go @@ -23,6 +23,7 @@ var AnthropicModels = map[ModelID]Model{ CostPer1MOutCached: 0.30, CostPer1MOut: 15.0, ContextWindow: 200000, + DefaultMaxTokens: 5000, }, Claude3Haiku: { ID: Claude3Haiku, @@ -34,6 +35,7 @@ var AnthropicModels = map[ModelID]Model{ CostPer1MOutCached: 0.03, CostPer1MOut: 1.25, ContextWindow: 200000, + DefaultMaxTokens: 5000, }, Claude37Sonnet: { ID: Claude37Sonnet, @@ -45,6 +47,8 @@ var AnthropicModels = map[ModelID]Model{ CostPer1MOutCached: 0.30, CostPer1MOut: 15.0, ContextWindow: 200000, + DefaultMaxTokens: 50000, + CanReason: true, }, Claude35Haiku: { ID: Claude35Haiku, @@ -56,6 +60,7 @@ var AnthropicModels = map[ModelID]Model{ CostPer1MOutCached: 0.08, CostPer1MOut: 4.0, ContextWindow: 200000, + DefaultMaxTokens: 4096, }, Claude3Opus: { ID: Claude3Opus, @@ -67,5 +72,6 @@ var AnthropicModels = map[ModelID]Model{ CostPer1MOutCached: 1.50, CostPer1MOut: 75.0, ContextWindow: 200000, + DefaultMaxTokens: 4096, }, } diff --git a/internal/llm/models/models.go b/internal/llm/models/models.go index 4d4589bfd..bbce6130e 100644 --- a/internal/llm/models/models.go +++ b/internal/llm/models/models.go @@ -17,15 +17,12 @@ type Model struct { CostPer1MInCached float64 `json:"cost_per_1m_in_cached"` CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"` ContextWindow int64 `json:"context_window"` + DefaultMaxTokens int64 `json:"default_max_tokens"` + CanReason bool `json:"can_reason"` } // Model IDs -const ( - // OpenAI - GPT4o ModelID = "gpt-4o" - GPT41 ModelID = "gpt-4.1" - - // GEMINI +const ( // GEMINI GEMINI25 ModelID = "gemini-2.5" GRMINI20Flash ModelID = "gemini-2.0-flash" @@ -37,7 +34,6 @@ const ( ) const ( - ProviderOpenAI ModelProvider = "openai" ProviderBedrock ModelProvider = "bedrock" ProviderGemini ModelProvider = "gemini" ProviderGROQ ModelProvider = "groq" @@ -47,59 +43,6 @@ const ( ) var SupportedModels = map[ModelID]Model{ - // // Anthropic - // Claude35Sonnet: { - // ID: Claude35Sonnet, - // Name: "Claude 3.5 Sonnet", - // Provider: ProviderAnthropic, - // APIModel: "claude-3-5-sonnet-latest", - // CostPer1MIn: 3.0, - // CostPer1MInCached: 3.75, - // CostPer1MOutCached: 0.30, - // CostPer1MOut: 15.0, - // }, - // Claude3Haiku: { - // ID: Claude3Haiku, - // Name: "Claude 3 Haiku", - // Provider: ProviderAnthropic, - // APIModel: "claude-3-haiku-latest", - // CostPer1MIn: 0.80, - // CostPer1MInCached: 1, - // CostPer1MOutCached: 0.08, - // CostPer1MOut: 4, - // }, - // Claude37Sonnet: { - // ID: Claude37Sonnet, - // Name: "Claude 3.7 Sonnet", - // Provider: ProviderAnthropic, - // APIModel: "claude-3-7-sonnet-latest", - // CostPer1MIn: 3.0, - // CostPer1MInCached: 3.75, - // CostPer1MOutCached: 0.30, - // CostPer1MOut: 15.0, - // }, - // - // // OpenAI - GPT4o: { - ID: GPT4o, - Name: "GPT-4o", - Provider: ProviderOpenAI, - APIModel: "gpt-4.1", - CostPer1MIn: 2.00, - CostPer1MInCached: 0.50, - CostPer1MOutCached: 0, - CostPer1MOut: 8.00, - }, - GPT41: { - ID: GPT41, - Name: "GPT-4.1", - Provider: ProviderOpenAI, - APIModel: "gpt-4.1", - CostPer1MIn: 2.00, - CostPer1MInCached: 0.50, - CostPer1MOutCached: 0, - CostPer1MOut: 8.00, - }, // // // GEMINI // GEMINI25: { @@ -151,4 +94,5 @@ var SupportedModels = map[ModelID]Model{ func init() { maps.Copy(SupportedModels, AnthropicModels) + maps.Copy(SupportedModels, OpenAIModels) } diff --git a/internal/llm/models/openai.go b/internal/llm/models/openai.go new file mode 100644 index 000000000..f0cbb298c --- /dev/null +++ b/internal/llm/models/openai.go @@ -0,0 +1,169 @@ +package models + +const ( + ProviderOpenAI ModelProvider = "openai" + + GPT41 ModelID = "gpt-4.1" + GPT41Mini ModelID = "gpt-4.1-mini" + GPT41Nano ModelID = "gpt-4.1-nano" + GPT45Preview ModelID = "gpt-4.5-preview" + GPT4o ModelID = "gpt-4o" + GPT4oMini ModelID = "gpt-4o-mini" + O1 ModelID = "o1" + O1Pro ModelID = "o1-pro" + O1Mini ModelID = "o1-mini" + O3 ModelID = "o3" + O3Mini ModelID = "o3-mini" + O4Mini ModelID = "o4-mini" +) + +var OpenAIModels = map[ModelID]Model{ + GPT41: { + ID: GPT41, + Name: "GPT 4.1", + Provider: ProviderOpenAI, + APIModel: "gpt-4.1", + CostPer1MIn: 2.00, + CostPer1MInCached: 0.50, + CostPer1MOutCached: 0.0, + CostPer1MOut: 8.00, + ContextWindow: 1_047_576, + DefaultMaxTokens: 20000, + }, + GPT41Mini: { + ID: GPT41Mini, + Name: "GPT 4.1 mini", + Provider: ProviderOpenAI, + APIModel: "gpt-4.1", + CostPer1MIn: 0.40, + CostPer1MInCached: 0.10, + CostPer1MOutCached: 0.0, + CostPer1MOut: 1.60, + ContextWindow: 200_000, + DefaultMaxTokens: 20000, + }, + GPT41Nano: { + ID: GPT41Nano, + Name: "GPT 4.1 nano", + Provider: ProviderOpenAI, + APIModel: "gpt-4.1-nano", + CostPer1MIn: 0.10, + CostPer1MInCached: 0.025, + CostPer1MOutCached: 0.0, + CostPer1MOut: 0.40, + ContextWindow: 1_047_576, + DefaultMaxTokens: 20000, + }, + GPT45Preview: { + ID: GPT45Preview, + Name: "GPT 4.5 preview", + Provider: ProviderOpenAI, + APIModel: "gpt-4.5-preview", + CostPer1MIn: 75.00, + CostPer1MInCached: 37.50, + CostPer1MOutCached: 0.0, + CostPer1MOut: 150.00, + ContextWindow: 128_000, + DefaultMaxTokens: 15000, + }, + GPT4o: { + ID: GPT4o, + Name: "GPT 4o", + Provider: ProviderOpenAI, + APIModel: "gpt-4o", + CostPer1MIn: 2.50, + CostPer1MInCached: 1.25, + CostPer1MOutCached: 0.0, + CostPer1MOut: 10.00, + ContextWindow: 128_000, + DefaultMaxTokens: 4096, + }, + GPT4oMini: { + ID: GPT4oMini, + Name: "GPT 4o mini", + Provider: ProviderOpenAI, + APIModel: "gpt-4o-mini", + CostPer1MIn: 0.15, + CostPer1MInCached: 0.075, + CostPer1MOutCached: 0.0, + CostPer1MOut: 0.60, + ContextWindow: 128_000, + }, + O1: { + ID: O1, + Name: "O1", + Provider: ProviderOpenAI, + APIModel: "o1", + CostPer1MIn: 15.00, + CostPer1MInCached: 7.50, + CostPer1MOutCached: 0.0, + CostPer1MOut: 60.00, + ContextWindow: 200_000, + DefaultMaxTokens: 50000, + CanReason: true, + }, + O1Pro: { + ID: O1Pro, + Name: "o1 pro", + Provider: ProviderOpenAI, + APIModel: "o1-pro", + CostPer1MIn: 150.00, + CostPer1MInCached: 0.0, + CostPer1MOutCached: 0.0, + CostPer1MOut: 600.00, + ContextWindow: 200_000, + DefaultMaxTokens: 50000, + CanReason: true, + }, + O1Mini: { + ID: O1Mini, + Name: "o1 mini", + Provider: ProviderOpenAI, + APIModel: "o1-mini", + CostPer1MIn: 1.10, + CostPer1MInCached: 0.55, + CostPer1MOutCached: 0.0, + CostPer1MOut: 4.40, + ContextWindow: 128_000, + DefaultMaxTokens: 50000, + CanReason: true, + }, + O3: { + ID: O3, + Name: "o3", + Provider: ProviderOpenAI, + APIModel: "o3", + CostPer1MIn: 10.00, + CostPer1MInCached: 2.50, + CostPer1MOutCached: 0.0, + CostPer1MOut: 40.00, + ContextWindow: 200_000, + CanReason: true, + }, + O3Mini: { + ID: O3Mini, + Name: "o3 mini", + Provider: ProviderOpenAI, + APIModel: "o3-mini", + CostPer1MIn: 1.10, + CostPer1MInCached: 0.55, + CostPer1MOutCached: 0.0, + CostPer1MOut: 4.40, + ContextWindow: 200_000, + DefaultMaxTokens: 50000, + CanReason: true, + }, + O4Mini: { + ID: O4Mini, + Name: "o4 mini", + Provider: ProviderOpenAI, + APIModel: "o4-mini", + CostPer1MIn: 1.10, + CostPer1MInCached: 0.275, + CostPer1MOutCached: 0.0, + CostPer1MOut: 4.40, + ContextWindow: 128_000, + DefaultMaxTokens: 50000, + CanReason: true, + }, +} diff --git a/internal/llm/prompt/coder.go b/internal/llm/prompt/coder.go index 3a06911da..d7ca7b2fd 100644 --- a/internal/llm/prompt/coder.go +++ b/internal/llm/prompt/coder.go @@ -25,44 +25,49 @@ func CoderPrompt(provider models.ModelProvider) string { } const baseOpenAICoderPrompt = ` -You are **OpenCode**, an autonomous CLI assistant for software‑engineering tasks. - -### ── INTERNAL REFLECTION ── -• Silently think step‑by‑step about the user request, directory layout, and tool calls (never reveal this). -• Formulate a plan, then execute without further approval unless a blocker triggers the Ask‑Only‑If rules. - -### ── PUBLIC RESPONSE RULES ── -• Visible reply ≤ 4 lines; no fluff, preamble, or postamble. -• Use GitHub‑flavored Markdown. -• When running a non‑trivial shell command, add ≤ 1 brief purpose sentence. - -### ── CONTEXT & MEMORY ── -• Infer file intent from directory structure before editing. -• Auto‑load 'OpenCode.md'; ask once before writing new reusable commands or style notes. - -### ── AUTONOMY PRIORITY ── -**Ask‑Only‑If Decision Tree:** -1. **Safety risk?** (e.g., destructive command, secret exposure) → ask. -2. **Critical unknown?** (no docs/tests; cannot infer) → ask. -3. **Tool failure after two self‑attempts?** → ask. -Otherwise, proceed autonomously. - -### ── SAFETY & STYLE ── -• Mimic existing code style; verify libraries exist before import. -• Never commit unless explicitly told. -• After edits, run lint & type‑check (ask for commands once, then offer to store in 'OpenCode.md'). -• Protect secrets; follow standard security practices :contentReference[oaicite:2]{index=2}. - -### ── TOOL USAGE ── -• Batch independent Agent search/file calls in one block for efficiency :contentReference[oaicite:3]{index=3}. -• Communicate with the user only via visible text; do not expose tool output or internal reasoning. - -### ── EXAMPLES ── -user: list files -assistant: ls - -user: write tests for new feature -assistant: [searches & edits autonomously, no extra chit‑chat] +You are operating as and within the OpenCode CLI, a terminal-based agentic coding assistant built by OpenAI. It wraps OpenAI models to enable natural language interaction with a local codebase. You are expected to be precise, safe, and helpful. + +You can: +- Receive user prompts, project context, and files. +- Stream responses and emit function calls (e.g., shell commands, code edits). +- Apply patches, run commands, and manage user approvals based on policy. +- Work inside a sandboxed, git-backed workspace with rollback support. +- Log telemetry so sessions can be replayed or inspected later. +- More details on your functionality are available at "opencode --help" + + +You are an agent - please keep going until the user's query is completely resolved, before ending your turn and yielding back to the user. Only terminate your turn when you are sure that the problem is solved. If you are not sure about file content or codebase structure pertaining to the user's request, use your tools to read files and gather the relevant information: do NOT guess or make up an answer. + +Please resolve the user's task by editing and testing the code files in your current code execution session. You are a deployed coding agent. Your session allows for you to modify and run code. The repo(s) are already cloned in your working directory, and you must fully solve the problem for your answer to be considered correct. + +You MUST adhere to the following criteria when executing the task: +- Working on the repo(s) in the current environment is allowed, even if they are proprietary. +- Analyzing code for vulnerabilities is allowed. +- Showing user code and tool call details is allowed. +- User instructions may overwrite the *CODING GUIDELINES* section in this developer message. +- If completing the user's task requires writing or modifying files: + - Your code and final answer should follow these *CODING GUIDELINES*: + - Fix the problem at the root cause rather than applying surface-level patches, when possible. + - Avoid unneeded complexity in your solution. + - Ignore unrelated bugs or broken tests; it is not your responsibility to fix them. + - Update documentation as necessary. + - Keep changes consistent with the style of the existing codebase. Changes should be minimal and focused on the task. + - Use "git log" and "git blame" to search the history of the codebase if additional context is required; internet access is disabled. + - NEVER add copyright or license headers unless specifically requested. + - You do not need to "git commit" your changes; this will be done automatically for you. + - Once you finish coding, you must + - Check "git status" to sanity check your changes; revert any scratch files or changes. + - Remove all inline comments you added as much as possible, even if they look normal. Check using "git diff". Inline comments must be generally avoided, unless active maintainers of the repo, after long careful study of the code and the issue, will still misinterpret the code without the comments. + - Check if you accidentally add copyright or license headers. If so, remove them. + - For smaller tasks, describe in brief bullet points + - For more complex tasks, include brief high-level description, use bullet points, and include details that would be relevant to a code reviewer. +- If completing the user's task DOES NOT require writing or modifying files (e.g., the user asks a question about the code base): + - Respond in a friendly tune as a remote teammate, who is knowledgeable, capable and eager to help with coding. +- When your task involves writing or modifying files: + - Do NOT tell the user to "save the file" or "copy the code into a file" if you already created or modified the file using "apply_patch". Instead, reference the file as already saved. + - Do NOT show the full contents of large files you have already written, unless the user explicitly asks for them. +- When doing things with paths, always use use the full path, if the working directory is /abc/xyz and you want to edit the file abc.go in the working dir refer to it as /abc/xyz/abc.go. +- If you send a path not including the working dir, the working dir will be prepended to it. ` const baseAnthropicCoderPrompt = `You are OpenCode, an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user. @@ -125,7 +130,7 @@ assistant: src/foo.c user: write tests for new feature -assistant: [uses grep and glob search tools to find where similar tests are defined, uses concurrent read file tool use blocks in one tool call to read relevant files at the same time, uses edit file tool to write new tests] +assistant: [uses grep and glob search tools to find where similar tests are defined, uses concurrent read file tool use blocks in one tool call to read relevant files at the same time, uses edit/patch file tool to write new tests] # Proactiveness diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index 13ce934f2..6c6f74988 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -14,11 +14,13 @@ import ( "github.com/kujtimiihoxha/opencode/internal/message" "github.com/openai/openai-go" "github.com/openai/openai-go/option" + "github.com/openai/openai-go/shared" ) type openaiOptions struct { - baseURL string - disableCache bool + baseURL string + disableCache bool + reasoningEffort string } type OpenAIOption func(*openaiOptions) @@ -32,7 +34,9 @@ type openaiClient struct { type OpenAIClient ProviderClient func newOpenAIClient(opts providerClientOptions) OpenAIClient { - openaiOpts := openaiOptions{} + openaiOpts := openaiOptions{ + reasoningEffort: "medium", + } for _, o := range opts.openaiOptions { o(&openaiOpts) } @@ -138,12 +142,29 @@ func (o *openaiClient) finishReason(reason string) message.FinishReason { } func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams { - return openai.ChatCompletionNewParams{ - Model: openai.ChatModel(o.providerOptions.model.APIModel), - Messages: messages, - MaxTokens: openai.Int(o.providerOptions.maxTokens), - Tools: tools, + params := openai.ChatCompletionNewParams{ + Model: openai.ChatModel(o.providerOptions.model.APIModel), + Messages: messages, + Tools: tools, } + + if o.providerOptions.model.CanReason == true { + params.MaxCompletionTokens = openai.Int(o.providerOptions.maxTokens) + switch o.options.reasoningEffort { + case "low": + params.ReasoningEffort = shared.ReasoningEffortLow + case "medium": + params.ReasoningEffort = shared.ReasoningEffortMedium + case "high": + params.ReasoningEffort = shared.ReasoningEffortHigh + default: + params.ReasoningEffort = shared.ReasoningEffortMedium + } + } else { + params.MaxTokens = openai.Int(o.providerOptions.maxTokens) + } + + return params } func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) { @@ -359,3 +380,15 @@ func WithOpenAIDisableCache() OpenAIOption { } } +func WithReasoningEffort(effort string) OpenAIOption { + return func(options *openaiOptions) { + defaultReasoningEffort := "medium" + switch effort { + case "low", "medium", "high": + defaultReasoningEffort = effort + default: + logging.Warn("Invalid reasoning effort, using default: medium") + } + options.reasoningEffort = defaultReasoningEffort + } +} diff --git a/internal/llm/tools/glob.go b/internal/llm/tools/glob.go index 40262ce2b..e3c7b7b61 100644 --- a/internal/llm/tools/glob.go +++ b/internal/llm/tools/glob.go @@ -192,6 +192,42 @@ func globFiles(pattern, searchPath string, limit int) ([]string, bool, error) { } func skipHidden(path string) bool { + // Check for hidden files (starting with a dot) base := filepath.Base(path) - return base != "." && strings.HasPrefix(base, ".") + if base != "." && strings.HasPrefix(base, ".") { + return true + } + + // List of commonly ignored directories in development projects + commonIgnoredDirs := map[string]bool{ + "node_modules": true, + "vendor": true, + "dist": true, + "build": true, + "target": true, + ".git": true, + ".idea": true, + ".vscode": true, + "__pycache__": true, + "bin": true, + "obj": true, + "out": true, + "coverage": true, + "tmp": true, + "temp": true, + "logs": true, + "generated": true, + "bower_components": true, + "jspm_packages": true, + } + + // Check if any path component is in our ignore list + parts := strings.SplitSeq(path, string(os.PathSeparator)) + for part := range parts { + if commonIgnoredDirs[part] { + return true + } + } + + return false } diff --git a/internal/llm/tools/grep.go b/internal/llm/tools/grep.go index 3436dd7eb..086a5e686 100644 --- a/internal/llm/tools/grep.go +++ b/internal/llm/tools/grep.go @@ -17,9 +17,10 @@ import ( ) type GrepParams struct { - Pattern string `json:"pattern"` - Path string `json:"path"` - Include string `json:"include"` + Pattern string `json:"pattern"` + Path string `json:"path"` + Include string `json:"include"` + LiteralText bool `json:"literal_text"` } type grepMatch struct { @@ -45,11 +46,12 @@ WHEN TO USE THIS TOOL: HOW TO USE: - Provide a regex pattern to search for within file contents +- Set literal_text=true if you want to search for the exact text with special characters (recommended for non-regex users) - Optionally specify a starting directory (defaults to current working directory) - Optionally provide an include pattern to filter which files to search - Results are sorted with most recently modified files first -REGEX PATTERN SYNTAX: +REGEX PATTERN SYNTAX (when literal_text=false): - Supports standard regular expression syntax - 'function' searches for the literal text "function" - 'log\..*Error' finds text starting with "log." and ending with "Error" @@ -69,7 +71,8 @@ LIMITATIONS: TIPS: - For faster, more targeted searches, first use Glob to find relevant files, then use Grep - When doing iterative exploration that may require multiple rounds of searching, consider using the Agent tool instead -- Always check if results are truncated and refine your search pattern if needed` +- Always check if results are truncated and refine your search pattern if needed +- Use literal_text=true when searching for exact text containing special characters like dots, parentheses, etc.` ) func NewGrepTool() BaseTool { @@ -93,11 +96,27 @@ func (g *grepTool) Info() ToolInfo { "type": "string", "description": "File pattern to include in the search (e.g. \"*.js\", \"*.{ts,tsx}\")", }, + "literal_text": map[string]any{ + "type": "boolean", + "description": "If true, the pattern will be treated as literal text with special regex characters escaped. Default is false.", + }, }, Required: []string{"pattern"}, } } +// escapeRegexPattern escapes special regex characters so they're treated as literal characters +func escapeRegexPattern(pattern string) string { + specialChars := []string{"\\", ".", "+", "*", "?", "(", ")", "[", "]", "{", "}", "^", "$", "|"} + escaped := pattern + + for _, char := range specialChars { + escaped = strings.ReplaceAll(escaped, char, "\\"+char) + } + + return escaped +} + func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { var params GrepParams if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil { @@ -108,12 +127,18 @@ func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) return NewTextErrorResponse("pattern is required"), nil } + // If literal_text is true, escape the pattern + searchPattern := params.Pattern + if params.LiteralText { + searchPattern = escapeRegexPattern(params.Pattern) + } + searchPath := params.Path if searchPath == "" { searchPath = config.WorkingDirectory() } - matches, truncated, err := searchFiles(params.Pattern, searchPath, params.Include, 100) + matches, truncated, err := searchFiles(searchPattern, searchPath, params.Include, 100) if err != nil { return ToolResponse{}, fmt.Errorf("error searching files: %w", err) } diff --git a/internal/llm/tools/patch.go b/internal/llm/tools/patch.go index 12060d72a..0f879462c 100644 --- a/internal/llm/tools/patch.go +++ b/internal/llm/tools/patch.go @@ -6,7 +6,6 @@ import ( "fmt" "os" "path/filepath" - "strings" "time" "github.com/kujtimiihoxha/opencode/internal/config" @@ -17,19 +16,13 @@ import ( ) type PatchParams struct { - FilePath string `json:"file_path"` - Patch string `json:"patch"` -} - -type PatchPermissionsParams struct { - FilePath string `json:"file_path"` - Diff string `json:"diff"` + PatchText string `json:"patch_text"` } type PatchResponseMetadata struct { - Diff string `json:"diff"` - Additions int `json:"additions"` - Removals int `json:"removals"` + FilesChanged []string `json:"files_changed"` + Additions int `json:"additions"` + Removals int `json:"removals"` } type patchTool struct { @@ -39,47 +32,35 @@ type patchTool struct { } const ( - // TODO: test if this works as expected PatchToolName = "patch" - patchDescription = `Applies a patch to a file. This tool is similar to the edit tool but accepts a unified diff patch instead of old/new strings. + patchDescription = `Applies a patch to multiple files in one operation. This tool is useful for making coordinated changes across multiple files. + +The patch text must follow this format: +*** Begin Patch +*** Update File: /path/to/file +@@ Context line (unique within the file) + Line to keep +-Line to remove ++Line to add + Line to keep +*** Add File: /path/to/new/file ++Content of the new file ++More content +*** Delete File: /path/to/file/to/delete +*** End Patch Before using this tool: - -1. Use the FileRead tool to understand the file's contents and context - -2. Verify the directory path is correct: - - Use the LS tool to verify the parent directory exists and is the correct location - -To apply a patch, provide the following: -1. file_path: The absolute path to the file to modify (must be absolute, not relative) -2. patch: A unified diff patch to apply to the file - -The tool will apply the patch to the specified file. The patch must be in unified diff format. +1. Use the FileRead tool to understand the files' contents and context +2. Verify all file paths are correct (use the LS tool) CRITICAL REQUIREMENTS FOR USING THIS TOOL: -1. PATCH FORMAT: The patch must be in unified diff format, which includes: - - File headers (--- a/file_path, +++ b/file_path) - - Hunk headers (@@ -start,count +start,count @@) - - Added lines (prefixed with +) - - Removed lines (prefixed with -) - -2. CONTEXT: The patch must include sufficient context around the changes to ensure it applies correctly. - -3. VERIFICATION: Before using this tool: - - Ensure the patch applies cleanly to the current state of the file - - Check that the file exists and you have read it first - -WARNING: If you do not follow these requirements: - - The tool will fail if the patch doesn't apply cleanly - - You may change the wrong parts of the file if the context is insufficient - -When applying patches: - - Ensure the patch results in idiomatic, correct code - - Do not leave the code in a broken state - - Always use absolute file paths (starting with /) +1. UNIQUENESS: Context lines MUST uniquely identify the specific sections you want to change +2. PRECISION: All whitespace, indentation, and surrounding code must match exactly +3. VALIDATION: Ensure edits result in idiomatic, correct code +4. PATHS: Always use absolute file paths (starting with /) -Remember: patches are a powerful way to make multiple related changes at once, but they require careful preparation.` +The tool will apply all changes in a single atomic operation.` ) func NewPatchTool(lspClients map[string]*lsp.Client, permissions permission.Service, files history.Service) BaseTool { @@ -95,16 +76,12 @@ func (p *patchTool) Info() ToolInfo { Name: PatchToolName, Description: patchDescription, Parameters: map[string]any{ - "file_path": map[string]any{ + "patch_text": map[string]any{ "type": "string", - "description": "The absolute path to the file to modify", - }, - "patch": map[string]any{ - "type": "string", - "description": "The unified diff patch to apply", + "description": "The full patch text that describes all changes to be made", }, }, - Required: []string{"file_path", "patch"}, + Required: []string{"patch_text"}, } } @@ -114,187 +91,278 @@ func (p *patchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error return NewTextErrorResponse("invalid parameters"), nil } - if params.FilePath == "" { - return NewTextErrorResponse("file_path is required"), nil + if params.PatchText == "" { + return NewTextErrorResponse("patch_text is required"), nil } - if params.Patch == "" { - return NewTextErrorResponse("patch is required"), nil - } + // Identify all files needed for the patch and verify they've been read + filesToRead := diff.IdentifyFilesNeeded(params.PatchText) + for _, filePath := range filesToRead { + absPath := filePath + if !filepath.IsAbs(absPath) { + wd := config.WorkingDirectory() + absPath = filepath.Join(wd, absPath) + } - if !filepath.IsAbs(params.FilePath) { - wd := config.WorkingDirectory() - params.FilePath = filepath.Join(wd, params.FilePath) - } + if getLastReadTime(absPath).IsZero() { + return NewTextErrorResponse(fmt.Sprintf("you must read the file %s before patching it. Use the FileRead tool first", filePath)), nil + } - // Check if file exists - fileInfo, err := os.Stat(params.FilePath) - if err != nil { - if os.IsNotExist(err) { - return NewTextErrorResponse(fmt.Sprintf("file not found: %s", params.FilePath)), nil + fileInfo, err := os.Stat(absPath) + if err != nil { + if os.IsNotExist(err) { + return NewTextErrorResponse(fmt.Sprintf("file not found: %s", absPath)), nil + } + return ToolResponse{}, fmt.Errorf("failed to access file: %w", err) } - return ToolResponse{}, fmt.Errorf("failed to access file: %w", err) - } - if fileInfo.IsDir() { - return NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", params.FilePath)), nil - } + if fileInfo.IsDir() { + return NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", absPath)), nil + } - if getLastReadTime(params.FilePath).IsZero() { - return NewTextErrorResponse("you must read the file before patching it. Use the View tool first"), nil + modTime := fileInfo.ModTime() + lastRead := getLastReadTime(absPath) + if modTime.After(lastRead) { + return NewTextErrorResponse( + fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)", + absPath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339), + )), nil + } } - modTime := fileInfo.ModTime() - lastRead := getLastReadTime(params.FilePath) - if modTime.After(lastRead) { - return NewTextErrorResponse( - fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)", - params.FilePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339), - )), nil - } + // Check for new files to ensure they don't already exist + filesToAdd := diff.IdentifyFilesAdded(params.PatchText) + for _, filePath := range filesToAdd { + absPath := filePath + if !filepath.IsAbs(absPath) { + wd := config.WorkingDirectory() + absPath = filepath.Join(wd, absPath) + } - // Read the current file content - content, err := os.ReadFile(params.FilePath) - if err != nil { - return ToolResponse{}, fmt.Errorf("failed to read file: %w", err) + _, err := os.Stat(absPath) + if err == nil { + return NewTextErrorResponse(fmt.Sprintf("file already exists and cannot be added: %s", absPath)), nil + } else if !os.IsNotExist(err) { + return ToolResponse{}, fmt.Errorf("failed to check file: %w", err) + } } - oldContent := string(content) + // Load all required files + currentFiles := make(map[string]string) + for _, filePath := range filesToRead { + absPath := filePath + if !filepath.IsAbs(absPath) { + wd := config.WorkingDirectory() + absPath = filepath.Join(wd, absPath) + } - // Parse and apply the patch - diffResult, err := diff.ParseUnifiedDiff(params.Patch) - if err != nil { - return NewTextErrorResponse(fmt.Sprintf("failed to parse patch: %v", err)), nil + content, err := os.ReadFile(absPath) + if err != nil { + return ToolResponse{}, fmt.Errorf("failed to read file %s: %w", absPath, err) + } + currentFiles[filePath] = string(content) } - // Apply the patch to get the new content - newContent, err := applyPatch(oldContent, diffResult) + // Process the patch + patch, fuzz, err := diff.TextToPatch(params.PatchText, currentFiles) if err != nil { - return NewTextErrorResponse(fmt.Sprintf("failed to apply patch: %v", err)), nil + return NewTextErrorResponse(fmt.Sprintf("failed to parse patch: %s", err)), nil } - if oldContent == newContent { - return NewTextErrorResponse("patch did not result in any changes to the file"), nil + if fuzz > 0 { + return NewTextErrorResponse(fmt.Sprintf("patch contains fuzzy matches (fuzz level: %d). Please make your context lines more precise", fuzz)), nil } + // Convert patch to commit + commit, err := diff.PatchToCommit(patch, currentFiles) + if err != nil { + return NewTextErrorResponse(fmt.Sprintf("failed to create commit from patch: %s", err)), nil + } + + // Get session ID and message ID sessionID, messageID := GetContextValues(ctx) if sessionID == "" || messageID == "" { - return ToolResponse{}, fmt.Errorf("session ID and message ID are required for patching a file") + return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a patch") } - // Generate a diff for permission request and metadata - diffText, additions, removals := diff.GenerateDiff( - oldContent, - newContent, - params.FilePath, - ) - - // Request permission to apply the patch - p.permissions.Request( - permission.CreatePermissionRequest{ - Path: filepath.Dir(params.FilePath), - ToolName: PatchToolName, - Action: "patch", - Description: fmt.Sprintf("Apply patch to file %s", params.FilePath), - Params: PatchPermissionsParams{ - FilePath: params.FilePath, - Diff: diffText, - }, - }, - ) - - // Write the new content to the file - err = os.WriteFile(params.FilePath, []byte(newContent), 0o644) - if err != nil { - return ToolResponse{}, fmt.Errorf("failed to write file: %w", err) + // Request permission for all changes + for path, change := range commit.Changes { + switch change.Type { + case diff.ActionAdd: + dir := filepath.Dir(path) + patchDiff, _, _ := diff.GenerateDiff("", *change.NewContent, path) + p := p.permissions.Request( + permission.CreatePermissionRequest{ + Path: dir, + ToolName: PatchToolName, + Action: "create", + Description: fmt.Sprintf("Create file %s", path), + Params: EditPermissionsParams{ + FilePath: path, + Diff: patchDiff, + }, + }, + ) + if !p { + return ToolResponse{}, permission.ErrorPermissionDenied + } + case diff.ActionUpdate: + currentContent := "" + if change.OldContent != nil { + currentContent = *change.OldContent + } + newContent := "" + if change.NewContent != nil { + newContent = *change.NewContent + } + patchDiff, _, _ := diff.GenerateDiff(currentContent, newContent, path) + dir := filepath.Dir(path) + p := p.permissions.Request( + permission.CreatePermissionRequest{ + Path: dir, + ToolName: PatchToolName, + Action: "update", + Description: fmt.Sprintf("Update file %s", path), + Params: EditPermissionsParams{ + FilePath: path, + Diff: patchDiff, + }, + }, + ) + if !p { + return ToolResponse{}, permission.ErrorPermissionDenied + } + case diff.ActionDelete: + dir := filepath.Dir(path) + patchDiff, _, _ := diff.GenerateDiff(*change.OldContent, "", path) + p := p.permissions.Request( + permission.CreatePermissionRequest{ + Path: dir, + ToolName: PatchToolName, + Action: "delete", + Description: fmt.Sprintf("Delete file %s", path), + Params: EditPermissionsParams{ + FilePath: path, + Diff: patchDiff, + }, + }, + ) + if !p { + return ToolResponse{}, permission.ErrorPermissionDenied + } + } } - // Update file history - file, err := p.files.GetByPathAndSession(ctx, params.FilePath, sessionID) - if err != nil { - _, err = p.files.Create(ctx, sessionID, params.FilePath, oldContent) - if err != nil { - return ToolResponse{}, fmt.Errorf("error creating file history: %w", err) + // Apply the changes to the filesystem + err = diff.ApplyCommit(commit, func(path string, content string) error { + absPath := path + if !filepath.IsAbs(absPath) { + wd := config.WorkingDirectory() + absPath = filepath.Join(wd, absPath) } - } - if file.Content != oldContent { - // User manually changed the content, store an intermediate version - _, err = p.files.CreateVersion(ctx, sessionID, params.FilePath, oldContent) - if err != nil { - fmt.Printf("Error creating file history version: %v\n", err) + + // Create parent directories if needed + dir := filepath.Dir(absPath) + if err := os.MkdirAll(dir, 0o755); err != nil { + return fmt.Errorf("failed to create parent directories for %s: %w", absPath, err) } - } - // Store the new version - _, err = p.files.CreateVersion(ctx, sessionID, params.FilePath, newContent) + + return os.WriteFile(absPath, []byte(content), 0o644) + }, func(path string) error { + absPath := path + if !filepath.IsAbs(absPath) { + wd := config.WorkingDirectory() + absPath = filepath.Join(wd, absPath) + } + return os.Remove(absPath) + }) if err != nil { - fmt.Printf("Error creating file history version: %v\n", err) + return NewTextErrorResponse(fmt.Sprintf("failed to apply patch: %s", err)), nil } - recordFileWrite(params.FilePath) - recordFileRead(params.FilePath) + // Update file history for all modified files + changedFiles := []string{} + totalAdditions := 0 + totalRemovals := 0 - // Wait for LSP diagnostics and include them in the response - waitForLspDiagnostics(ctx, params.FilePath, p.lspClients) - text := fmt.Sprintf("\nPatch applied to file: %s\n\n", params.FilePath) - text += getDiagnostics(params.FilePath, p.lspClients) + for path, change := range commit.Changes { + absPath := path + if !filepath.IsAbs(absPath) { + wd := config.WorkingDirectory() + absPath = filepath.Join(wd, absPath) + } + changedFiles = append(changedFiles, absPath) - return WithResponseMetadata( - NewTextResponse(text), - PatchResponseMetadata{ - Diff: diffText, - Additions: additions, - Removals: removals, - }), nil -} + oldContent := "" + if change.OldContent != nil { + oldContent = *change.OldContent + } -// applyPatch applies a parsed diff to a string and returns the resulting content -func applyPatch(content string, diffResult diff.DiffResult) (string, error) { - lines := strings.Split(content, "\n") + newContent := "" + if change.NewContent != nil { + newContent = *change.NewContent + } - // Process each hunk in the diff - for _, hunk := range diffResult.Hunks { - // Parse the hunk header to get line numbers - var oldStart, oldCount, newStart, newCount int - _, err := fmt.Sscanf(hunk.Header, "@@ -%d,%d +%d,%d @@", &oldStart, &oldCount, &newStart, &newCount) - if err != nil { - // Try alternative format with single line counts - _, err = fmt.Sscanf(hunk.Header, "@@ -%d +%d @@", &oldStart, &newStart) + // Calculate diff statistics + _, additions, removals := diff.GenerateDiff(oldContent, newContent, path) + totalAdditions += additions + totalRemovals += removals + + // Update history + file, err := p.files.GetByPathAndSession(ctx, absPath, sessionID) + if err != nil && change.Type != diff.ActionAdd { + // If not adding a file, create history entry for existing file + _, err = p.files.Create(ctx, sessionID, absPath, oldContent) if err != nil { - return "", fmt.Errorf("invalid hunk header format: %s", hunk.Header) + fmt.Printf("Error creating file history: %v\n", err) } - oldCount = 1 - newCount = 1 } - // Adjust for 0-based array indexing - oldStart-- - newStart-- - - // Apply the changes - newLines := make([]string, 0) - newLines = append(newLines, lines[:oldStart]...) - - // Process the hunk lines in order - currentOldLine := oldStart - for _, line := range hunk.Lines { - switch line.Kind { - case diff.LineContext: - newLines = append(newLines, line.Content) - currentOldLine++ - case diff.LineRemoved: - // Skip this line in the output (it's being removed) - currentOldLine++ - case diff.LineAdded: - // Add the new line - newLines = append(newLines, line.Content) + if err == nil && change.Type != diff.ActionAdd && file.Content != oldContent { + // User manually changed content, store intermediate version + _, err = p.files.CreateVersion(ctx, sessionID, absPath, oldContent) + if err != nil { + fmt.Printf("Error creating file history version: %v\n", err) } } - // Append the rest of the file - newLines = append(newLines, lines[currentOldLine:]...) - lines = newLines + // Store new version + if change.Type == diff.ActionDelete { + _, err = p.files.CreateVersion(ctx, sessionID, absPath, "") + } else { + _, err = p.files.CreateVersion(ctx, sessionID, absPath, newContent) + } + if err != nil { + fmt.Printf("Error creating file history version: %v\n", err) + } + + // Record file operations + recordFileWrite(absPath) + recordFileRead(absPath) } - return strings.Join(lines, "\n"), nil -} + // Run LSP diagnostics on all changed files + for _, filePath := range changedFiles { + waitForLspDiagnostics(ctx, filePath, p.lspClients) + } + result := fmt.Sprintf("Patch applied successfully. %d files changed, %d additions, %d removals", + len(changedFiles), totalAdditions, totalRemovals) + + diagnosticsText := "" + for _, filePath := range changedFiles { + diagnosticsText += getDiagnostics(filePath, p.lspClients) + } + + if diagnosticsText != "" { + result += "\n\nDiagnostics:\n" + diagnosticsText + } + + return WithResponseMetadata( + NewTextResponse(result), + PatchResponseMetadata{ + FilesChanged: changedFiles, + Additions: totalAdditions, + Removals: totalRemovals, + }), nil +} diff --git a/internal/llm/tools/view.go b/internal/llm/tools/view.go index 3fa4ca116..dc02b34f3 100644 --- a/internal/llm/tools/view.go +++ b/internal/llm/tools/view.go @@ -24,6 +24,11 @@ type viewTool struct { lspClients map[string]*lsp.Client } +type ViewResponseMetadata struct { + FilePath string `json:"file_path"` + Content string `json:"content"` +} + const ( ViewToolName = "view" MaxReadSize = 250 * 1024 @@ -180,7 +185,13 @@ func (v *viewTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) output += "\n\n" output += getDiagnostics(filePath, v.lspClients) recordFileRead(filePath) - return NewTextResponse(output), nil + return WithResponseMetadata( + NewTextResponse(output), + ViewResponseMetadata{ + FilePath: filePath, + Content: content, + }, + ), nil } func addLineNumbers(content string, startLine int) string { diff --git a/internal/tui/components/chat/editor.go b/internal/tui/components/chat/editor.go index ded0639bb..537ef392c 100644 --- a/internal/tui/components/chat/editor.go +++ b/internal/tui/components/chat/editor.go @@ -1,6 +1,9 @@ package chat import ( + "os" + "os/exec" + "github.com/charmbracelet/bubbles/key" "github.com/charmbracelet/bubbles/textarea" tea "github.com/charmbracelet/bubbletea" @@ -19,13 +22,15 @@ type editorCmp struct { } type focusedEditorKeyMaps struct { - Send key.Binding - Blur key.Binding + Send key.Binding + OpenEditor key.Binding + Blur key.Binding } type bluredEditorKeyMaps struct { - Send key.Binding - Focus key.Binding + Send key.Binding + Focus key.Binding + OpenEditor key.Binding } var focusedKeyMaps = focusedEditorKeyMaps{ @@ -37,6 +42,10 @@ var focusedKeyMaps = focusedEditorKeyMaps{ key.WithKeys("esc"), key.WithHelp("esc", "focus messages"), ), + OpenEditor: key.NewBinding( + key.WithKeys("ctrl+e"), + key.WithHelp("ctrl+e", "open editor"), + ), } var bluredKeyMaps = bluredEditorKeyMaps{ @@ -48,6 +57,40 @@ var bluredKeyMaps = bluredEditorKeyMaps{ key.WithKeys("i"), key.WithHelp("i", "focus editor"), ), + OpenEditor: key.NewBinding( + key.WithKeys("ctrl+e"), + key.WithHelp("ctrl+e", "open editor"), + ), +} + +func openEditor() tea.Cmd { + editor := os.Getenv("EDITOR") + if editor == "" { + editor = "nvim" + } + + tmpfile, err := os.CreateTemp("", "msg_*.md") + if err != nil { + return util.ReportError(err) + } + tmpfile.Close() + c := exec.Command(editor, tmpfile.Name()) //nolint:gosec + c.Stdin = os.Stdin + c.Stdout = os.Stdout + c.Stderr = os.Stderr + return tea.ExecProcess(c, func(err error) tea.Msg { + if err != nil { + return util.ReportError(err) + } + content, err := os.ReadFile(tmpfile.Name()) + if err != nil { + return util.ReportError(err) + } + os.Remove(tmpfile.Name()) + return SendMsg{ + Text: string(content), + } + }) } func (m *editorCmp) Init() tea.Cmd { @@ -82,6 +125,10 @@ func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } return m, nil case tea.KeyMsg: + if key.Matches(msg, focusedKeyMaps.OpenEditor) { + m.textarea.Blur() + return m, openEditor() + } // if the key does not match any binding, return if m.textarea.Focused() && key.Matches(msg, focusedKeyMaps.Send) { return m, m.send() @@ -108,9 +155,10 @@ func (m *editorCmp) View() string { return lipgloss.JoinHorizontal(lipgloss.Top, style.Render(">"), m.textarea.View()) } -func (m *editorCmp) SetSize(width, height int) { +func (m *editorCmp) SetSize(width, height int) tea.Cmd { m.textarea.SetWidth(width - 3) // account for the prompt and padding right m.textarea.SetHeight(height) + return nil } func (m *editorCmp) GetSize() (int, int) { diff --git a/internal/tui/components/chat/list.go b/internal/tui/components/chat/list.go new file mode 100644 index 000000000..f95b53731 --- /dev/null +++ b/internal/tui/components/chat/list.go @@ -0,0 +1,463 @@ +package chat + +import ( + "context" + "fmt" + "math" + "sync" + "time" + + "github.com/charmbracelet/bubbles/key" + "github.com/charmbracelet/bubbles/spinner" + "github.com/charmbracelet/bubbles/viewport" + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" + "github.com/kujtimiihoxha/opencode/internal/app" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/message" + "github.com/kujtimiihoxha/opencode/internal/pubsub" + "github.com/kujtimiihoxha/opencode/internal/session" + "github.com/kujtimiihoxha/opencode/internal/tui/layout" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/tui/util" +) + +type messagesCmp struct { + app *app.App + width, height int + writingMode bool + viewport viewport.Model + session session.Session + messages []message.Message + uiMessages []uiMessage + currentMsgID string + mutex sync.Mutex + cachedContent map[string][]uiMessage + spinner spinner.Model + rendering bool +} +type renderFinishedMsg struct{} + +func (m *messagesCmp) Init() tea.Cmd { + return tea.Batch(m.viewport.Init()) +} + +func (m *messagesCmp) preloadSessions() tea.Cmd { + return func() tea.Msg { + sessions, err := m.app.Sessions.List(context.Background()) + if err != nil { + return util.ReportError(err)() + } + if len(sessions) == 0 { + return nil + } + if len(sessions) > 20 { + sessions = sessions[:20] + } + for _, s := range sessions { + messages, err := m.app.Messages.List(context.Background(), s.ID) + if err != nil { + return util.ReportError(err)() + } + if len(messages) == 0 { + continue + } + m.cacheSessionMessages(messages, m.width) + + } + logging.Debug("preloaded sessions") + + return nil + } +} + +func (m *messagesCmp) cacheSessionMessages(messages []message.Message, width int) { + m.mutex.Lock() + defer m.mutex.Unlock() + pos := 0 + if m.width == 0 { + return + } + for inx, msg := range messages { + switch msg.Role { + case message.User: + userMsg := renderUserMessage( + msg, + false, + width, + pos, + ) + m.cachedContent[msg.ID] = []uiMessage{userMsg} + pos += userMsg.height + 1 // + 1 for spacing + case message.Assistant: + assistantMessages := renderAssistantMessage( + msg, + inx, + messages, + m.app.Messages, + "", + width, + pos, + ) + for _, msg := range assistantMessages { + pos += msg.height + 1 // + 1 for spacing + } + m.cachedContent[msg.ID] = assistantMessages + } + } +} + +func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + var cmds []tea.Cmd + switch msg := msg.(type) { + case EditorFocusMsg: + m.writingMode = bool(msg) + case SessionSelectedMsg: + if msg.ID != m.session.ID { + cmd := m.SetSession(msg) + return m, cmd + } + return m, nil + case SessionClearedMsg: + m.session = session.Session{} + m.messages = make([]message.Message, 0) + m.currentMsgID = "" + m.rendering = false + return m, nil + + case renderFinishedMsg: + m.rendering = false + m.viewport.GotoBottom() + case tea.KeyMsg: + if m.writingMode { + return m, nil + } + case pubsub.Event[message.Message]: + needsRerender := false + if msg.Type == pubsub.CreatedEvent { + if msg.Payload.SessionID == m.session.ID { + + messageExists := false + for _, v := range m.messages { + if v.ID == msg.Payload.ID { + messageExists = true + break + } + } + + if !messageExists { + if len(m.messages) > 0 { + lastMsgID := m.messages[len(m.messages)-1].ID + delete(m.cachedContent, lastMsgID) + } + + m.messages = append(m.messages, msg.Payload) + delete(m.cachedContent, m.currentMsgID) + m.currentMsgID = msg.Payload.ID + needsRerender = true + } + } + // There are tool calls from the child task + for _, v := range m.messages { + for _, c := range v.ToolCalls() { + if c.ID == msg.Payload.SessionID { + delete(m.cachedContent, v.ID) + needsRerender = true + } + } + } + } else if msg.Type == pubsub.UpdatedEvent && msg.Payload.SessionID == m.session.ID { + for i, v := range m.messages { + if v.ID == msg.Payload.ID { + m.messages[i] = msg.Payload + delete(m.cachedContent, msg.Payload.ID) + needsRerender = true + break + } + } + } + if needsRerender { + m.renderView() + if len(m.messages) > 0 { + if (msg.Type == pubsub.CreatedEvent) || + (msg.Type == pubsub.UpdatedEvent && msg.Payload.ID == m.messages[len(m.messages)-1].ID) { + m.viewport.GotoBottom() + } + } + } + } + + u, cmd := m.viewport.Update(msg) + m.viewport = u + cmds = append(cmds, cmd) + + spinner, cmd := m.spinner.Update(msg) + m.spinner = spinner + cmds = append(cmds, cmd) + return m, tea.Batch(cmds...) +} + +func (m *messagesCmp) IsAgentWorking() bool { + return m.app.CoderAgent.IsSessionBusy(m.session.ID) +} + +func formatTimeDifference(unixTime1, unixTime2 int64) string { + diffSeconds := float64(math.Abs(float64(unixTime2 - unixTime1))) + + if diffSeconds < 60 { + return fmt.Sprintf("%.1fs", diffSeconds) + } + + minutes := int(diffSeconds / 60) + seconds := int(diffSeconds) % 60 + return fmt.Sprintf("%dm%ds", minutes, seconds) +} + +func (m *messagesCmp) renderView() { + m.uiMessages = make([]uiMessage, 0) + pos := 0 + + if m.width == 0 { + return + } + for inx, msg := range m.messages { + switch msg.Role { + case message.User: + if messages, ok := m.cachedContent[msg.ID]; ok { + m.uiMessages = append(m.uiMessages, messages...) + continue + } + userMsg := renderUserMessage( + msg, + msg.ID == m.currentMsgID, + m.width, + pos, + ) + m.uiMessages = append(m.uiMessages, userMsg) + m.cachedContent[msg.ID] = []uiMessage{userMsg} + pos += userMsg.height + 1 // + 1 for spacing + case message.Assistant: + if messages, ok := m.cachedContent[msg.ID]; ok { + m.uiMessages = append(m.uiMessages, messages...) + continue + } + assistantMessages := renderAssistantMessage( + msg, + inx, + m.messages, + m.app.Messages, + m.currentMsgID, + m.width, + pos, + ) + for _, msg := range assistantMessages { + m.uiMessages = append(m.uiMessages, msg) + pos += msg.height + 1 // + 1 for spacing + } + m.cachedContent[msg.ID] = assistantMessages + } + } + + messages := make([]string, 0) + for _, v := range m.uiMessages { + messages = append(messages, v.content, + styles.BaseStyle. + Width(m.width). + Render( + "", + ), + ) + } + m.viewport.SetContent( + styles.BaseStyle. + Width(m.width). + Render( + lipgloss.JoinVertical( + lipgloss.Top, + messages..., + ), + ), + ) +} + +func (m *messagesCmp) View() string { + if m.rendering { + return styles.BaseStyle. + Width(m.width). + Render( + lipgloss.JoinVertical( + lipgloss.Top, + "Loading...", + m.working(), + m.help(), + ), + ) + } + if len(m.messages) == 0 { + content := styles.BaseStyle. + Width(m.width). + Height(m.height - 1). + Render( + m.initialScreen(), + ) + + return styles.BaseStyle. + Width(m.width). + Render( + lipgloss.JoinVertical( + lipgloss.Top, + content, + "", + m.help(), + ), + ) + } + + return styles.BaseStyle. + Width(m.width). + Render( + lipgloss.JoinVertical( + lipgloss.Top, + m.viewport.View(), + m.working(), + m.help(), + ), + ) +} + +func hasToolsWithoutResponse(messages []message.Message) bool { + toolCalls := make([]message.ToolCall, 0) + toolResults := make([]message.ToolResult, 0) + for _, m := range messages { + toolCalls = append(toolCalls, m.ToolCalls()...) + toolResults = append(toolResults, m.ToolResults()...) + } + + for _, v := range toolCalls { + found := false + for _, r := range toolResults { + if v.ID == r.ToolCallID { + found = true + break + } + } + if !found { + return true + } + } + + return false +} + +func (m *messagesCmp) working() string { + text := "" + if m.IsAgentWorking() { + task := "Thinking..." + lastMessage := m.messages[len(m.messages)-1] + if hasToolsWithoutResponse(m.messages) { + task = "Waiting for tool response..." + } else if !lastMessage.IsFinished() { + lastUpdate := lastMessage.UpdatedAt + currentTime := time.Now().Unix() + if lastMessage.Content().String() != "" && lastUpdate != 0 && currentTime-lastUpdate > 5 { + task = "Building tool call..." + } else if lastMessage.Content().String() == "" { + task = "Generating..." + } + task = "" + } + if task != "" { + text += styles.BaseStyle.Width(m.width).Foreground(styles.PrimaryColor).Bold(true).Render( + fmt.Sprintf("%s %s ", m.spinner.View(), task), + ) + } + } + return text +} + +func (m *messagesCmp) help() string { + text := "" + + if m.writingMode { + text += lipgloss.JoinHorizontal( + lipgloss.Left, + styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render("press "), + styles.BaseStyle.Foreground(styles.Forground).Bold(true).Render("esc"), + styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render(" to exit writing mode"), + ) + } else { + text += lipgloss.JoinHorizontal( + lipgloss.Left, + styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render("press "), + styles.BaseStyle.Foreground(styles.Forground).Bold(true).Render("i"), + styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render(" to start writing"), + ) + } + + return styles.BaseStyle. + Width(m.width). + Render(text) +} + +func (m *messagesCmp) initialScreen() string { + return styles.BaseStyle.Width(m.width).Render( + lipgloss.JoinVertical( + lipgloss.Top, + header(m.width), + "", + lspsConfigured(m.width), + ), + ) +} + +func (m *messagesCmp) SetSize(width, height int) tea.Cmd { + if m.width == width && m.height == height { + return nil + } + m.width = width + m.height = height + m.viewport.Width = width + m.viewport.Height = height - 2 + m.renderView() + return m.preloadSessions() +} + +func (m *messagesCmp) GetSize() (int, int) { + return m.width, m.height +} + +func (m *messagesCmp) SetSession(session session.Session) tea.Cmd { + if m.session.ID == session.ID { + return nil + } + m.rendering = true + return func() tea.Msg { + m.session = session + messages, err := m.app.Messages.List(context.Background(), session.ID) + if err != nil { + return util.ReportError(err) + } + m.messages = messages + m.currentMsgID = m.messages[len(m.messages)-1].ID + delete(m.cachedContent, m.currentMsgID) + m.renderView() + return renderFinishedMsg{} + } +} + +func (m *messagesCmp) BindingKeys() []key.Binding { + bindings := layout.KeyMapToSlice(m.viewport.KeyMap) + return bindings +} + +func NewMessagesCmp(app *app.App) tea.Model { + s := spinner.New() + s.Spinner = spinner.Pulse + return &messagesCmp{ + app: app, + writingMode: true, + cachedContent: make(map[string][]uiMessage), + viewport: viewport.New(0, 0), + spinner: s, + } +} diff --git a/internal/tui/components/chat/message.go b/internal/tui/components/chat/message.go new file mode 100644 index 000000000..be6c7ce50 --- /dev/null +++ b/internal/tui/components/chat/message.go @@ -0,0 +1,561 @@ +package chat + +import ( + "context" + "encoding/json" + "fmt" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/charmbracelet/glamour" + "github.com/charmbracelet/lipgloss" + "github.com/charmbracelet/x/ansi" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/diff" + "github.com/kujtimiihoxha/opencode/internal/llm/agent" + "github.com/kujtimiihoxha/opencode/internal/llm/models" + "github.com/kujtimiihoxha/opencode/internal/llm/tools" + "github.com/kujtimiihoxha/opencode/internal/message" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" +) + +type uiMessageType int + +const ( + userMessageType uiMessageType = iota + assistantMessageType + toolMessageType + + maxResultHeight = 15 +) + +var diffStyle = diff.NewStyleConfig(diff.WithShowHeader(false), diff.WithShowHunkHeader(false)) + +type uiMessage struct { + ID string + messageType uiMessageType + position int + height int + content string +} + +type renderCache struct { + mutex sync.Mutex + cache map[string][]uiMessage +} + +func toMarkdown(content string, focused bool, width int) string { + r, _ := glamour.NewTermRenderer( + glamour.WithStyles(styles.MarkdownTheme(false)), + glamour.WithWordWrap(width), + ) + if focused { + r, _ = glamour.NewTermRenderer( + glamour.WithStyles(styles.MarkdownTheme(true)), + glamour.WithWordWrap(width), + ) + } + rendered, _ := r.Render(content) + return rendered +} + +func renderMessage(msg string, isUser bool, isFocused bool, width int, info ...string) string { + style := styles.BaseStyle. + Width(width - 1). + BorderLeft(true). + Foreground(styles.ForgroundDim). + BorderForeground(styles.PrimaryColor). + BorderStyle(lipgloss.ThickBorder()) + if isUser { + style = style. + BorderForeground(styles.Blue) + } + parts := []string{ + styles.ForceReplaceBackgroundWithLipgloss(toMarkdown(msg, isFocused, width), styles.Background), + } + + // remove newline at the end + parts[0] = strings.TrimSuffix(parts[0], "\n") + if len(info) > 0 { + parts = append(parts, info...) + } + rendered := style.Render( + lipgloss.JoinVertical( + lipgloss.Left, + parts..., + ), + ) + + return rendered +} + +func renderUserMessage(msg message.Message, isFocused bool, width int, position int) uiMessage { + content := renderMessage(msg.Content().String(), true, isFocused, width) + userMsg := uiMessage{ + ID: msg.ID, + messageType: userMessageType, + position: position, + height: lipgloss.Height(content), + content: content, + } + return userMsg +} + +// Returns multiple uiMessages because of the tool calls +func renderAssistantMessage( + msg message.Message, + msgIndex int, + allMessages []message.Message, // we need this to get tool results and the user message + messagesService message.Service, // We need this to get the task tool messages + focusedUIMessageId string, + width int, + position int, +) []uiMessage { + // find the user message that is before this assistant message + var userMsg message.Message + for i := msgIndex - 1; i >= 0; i-- { + msg := allMessages[i] + if msg.Role == message.User { + userMsg = allMessages[i] + break + } + } + + messages := []uiMessage{} + content := msg.Content().String() + finished := msg.IsFinished() + finishData := msg.FinishPart() + info := []string{} + + // Add finish info if available + if finished { + switch finishData.Reason { + case message.FinishReasonEndTurn: + took := formatTimeDifference(userMsg.CreatedAt, finishData.Time) + info = append(info, styles.BaseStyle.Width(width-1).Foreground(styles.ForgroundDim).Render( + fmt.Sprintf(" %s (%s)", models.SupportedModels[msg.Model].Name, took), + )) + case message.FinishReasonCanceled: + info = append(info, styles.BaseStyle.Width(width-1).Foreground(styles.ForgroundDim).Render( + fmt.Sprintf(" %s (%s)", models.SupportedModels[msg.Model].Name, "canceled"), + )) + case message.FinishReasonError: + info = append(info, styles.BaseStyle.Width(width-1).Foreground(styles.ForgroundDim).Render( + fmt.Sprintf(" %s (%s)", models.SupportedModels[msg.Model].Name, "error"), + )) + case message.FinishReasonPermissionDenied: + info = append(info, styles.BaseStyle.Width(width-1).Foreground(styles.ForgroundDim).Render( + fmt.Sprintf(" %s (%s)", models.SupportedModels[msg.Model].Name, "permission denied"), + )) + } + } + if content != "" { + content = renderMessage(content, false, msg.ID == focusedUIMessageId, width, info...) + messages = append(messages, uiMessage{ + ID: msg.ID, + messageType: assistantMessageType, + position: position, + height: lipgloss.Height(content), + content: content, + }) + position += messages[0].height + position++ // for the space + } + + for i, toolCall := range msg.ToolCalls() { + toolCallContent := renderToolMessage( + toolCall, + allMessages, + messagesService, + focusedUIMessageId, + false, + width, + i+1, + ) + messages = append(messages, toolCallContent) + position += toolCallContent.height + position++ // for the space + } + return messages +} + +func findToolResponse(toolCallID string, futureMessages []message.Message) *message.ToolResult { + for _, msg := range futureMessages { + for _, result := range msg.ToolResults() { + if result.ToolCallID == toolCallID { + return &result + } + } + } + return nil +} + +func toolName(name string) string { + switch name { + case agent.AgentToolName: + return "Task" + case tools.BashToolName: + return "Bash" + case tools.EditToolName: + return "Edit" + case tools.FetchToolName: + return "Fetch" + case tools.GlobToolName: + return "Glob" + case tools.GrepToolName: + return "Grep" + case tools.LSToolName: + return "List" + case tools.SourcegraphToolName: + return "Sourcegraph" + case tools.ViewToolName: + return "View" + case tools.WriteToolName: + return "Write" + } + return name +} + +// renders params, params[0] (params[1]=params[2] ....) +func renderParams(paramsWidth int, params ...string) string { + if len(params) == 0 { + return "" + } + mainParam := params[0] + if len(mainParam) > paramsWidth { + mainParam = mainParam[:paramsWidth-3] + "..." + } + + if len(params) == 1 { + return mainParam + } + otherParams := params[1:] + // create pairs of key/value + // if odd number of params, the last one is a key without value + if len(otherParams)%2 != 0 { + otherParams = append(otherParams, "") + } + parts := make([]string, 0, len(otherParams)/2) + for i := 0; i < len(otherParams); i += 2 { + key := otherParams[i] + value := otherParams[i+1] + if value == "" { + continue + } + parts = append(parts, fmt.Sprintf("%s=%s", key, value)) + } + + partsRendered := strings.Join(parts, ", ") + remainingWidth := paramsWidth - lipgloss.Width(partsRendered) - 5 // for the space + if remainingWidth < 30 { + // No space for the params, just show the main + return mainParam + } + + if len(parts) > 0 { + mainParam = fmt.Sprintf("%s (%s)", mainParam, strings.Join(parts, ", ")) + } + + return ansi.Truncate(mainParam, paramsWidth, "...") +} + +func removeWorkingDirPrefix(path string) string { + wd := config.WorkingDirectory() + if strings.HasPrefix(path, wd) { + path = strings.TrimPrefix(path, wd) + } + if strings.HasPrefix(path, "/") { + path = strings.TrimPrefix(path, "/") + } + if strings.HasPrefix(path, "./") { + path = strings.TrimPrefix(path, "./") + } + if strings.HasPrefix(path, "../") { + path = strings.TrimPrefix(path, "../") + } + return path +} + +func renderToolParams(paramWidth int, toolCall message.ToolCall) string { + params := "" + switch toolCall.Name { + case agent.AgentToolName: + var params agent.AgentParams + json.Unmarshal([]byte(toolCall.Input), ¶ms) + prompt := strings.ReplaceAll(params.Prompt, "\n", " ") + return renderParams(paramWidth, prompt) + case tools.BashToolName: + var params tools.BashParams + json.Unmarshal([]byte(toolCall.Input), ¶ms) + command := strings.ReplaceAll(params.Command, "\n", " ") + return renderParams(paramWidth, command) + case tools.EditToolName: + var params tools.EditParams + json.Unmarshal([]byte(toolCall.Input), ¶ms) + filePath := removeWorkingDirPrefix(params.FilePath) + return renderParams(paramWidth, filePath) + case tools.FetchToolName: + var params tools.FetchParams + json.Unmarshal([]byte(toolCall.Input), ¶ms) + url := params.URL + toolParams := []string{ + url, + } + if params.Format != "" { + toolParams = append(toolParams, "format", params.Format) + } + if params.Timeout != 0 { + toolParams = append(toolParams, "timeout", (time.Duration(params.Timeout) * time.Second).String()) + } + return renderParams(paramWidth, toolParams...) + case tools.GlobToolName: + var params tools.GlobParams + json.Unmarshal([]byte(toolCall.Input), ¶ms) + pattern := params.Pattern + toolParams := []string{ + pattern, + } + if params.Path != "" { + toolParams = append(toolParams, "path", params.Path) + } + return renderParams(paramWidth, toolParams...) + case tools.GrepToolName: + var params tools.GrepParams + json.Unmarshal([]byte(toolCall.Input), ¶ms) + pattern := params.Pattern + toolParams := []string{ + pattern, + } + if params.Path != "" { + toolParams = append(toolParams, "path", params.Path) + } + if params.Include != "" { + toolParams = append(toolParams, "include", params.Include) + } + if params.LiteralText { + toolParams = append(toolParams, "literal", "true") + } + return renderParams(paramWidth, toolParams...) + case tools.LSToolName: + var params tools.LSParams + json.Unmarshal([]byte(toolCall.Input), ¶ms) + path := params.Path + if path == "" { + path = "." + } + return renderParams(paramWidth, path) + case tools.SourcegraphToolName: + var params tools.SourcegraphParams + json.Unmarshal([]byte(toolCall.Input), ¶ms) + return renderParams(paramWidth, params.Query) + case tools.ViewToolName: + var params tools.ViewParams + json.Unmarshal([]byte(toolCall.Input), ¶ms) + filePath := removeWorkingDirPrefix(params.FilePath) + toolParams := []string{ + filePath, + } + if params.Limit != 0 { + toolParams = append(toolParams, "limit", fmt.Sprintf("%d", params.Limit)) + } + if params.Offset != 0 { + toolParams = append(toolParams, "offset", fmt.Sprintf("%d", params.Offset)) + } + return renderParams(paramWidth, toolParams...) + case tools.WriteToolName: + var params tools.WriteParams + json.Unmarshal([]byte(toolCall.Input), ¶ms) + filePath := removeWorkingDirPrefix(params.FilePath) + return renderParams(paramWidth, filePath) + default: + input := strings.ReplaceAll(toolCall.Input, "\n", " ") + params = renderParams(paramWidth, input) + } + return params +} + +func truncateHeight(content string, height int) string { + lines := strings.Split(content, "\n") + if len(lines) > height { + return strings.Join(lines[:height], "\n") + } + return content +} + +func renderToolResponse(toolCall message.ToolCall, response message.ToolResult, width int) string { + if response.IsError { + errContent := fmt.Sprintf("Error: %s", strings.ReplaceAll(response.Content, "\n", " ")) + errContent = ansi.Truncate(errContent, width-1, "...") + return styles.BaseStyle. + Foreground(styles.Error). + Render(errContent) + } + resultContent := truncateHeight(response.Content, maxResultHeight) + switch toolCall.Name { + case agent.AgentToolName: + return styles.ForceReplaceBackgroundWithLipgloss( + toMarkdown(resultContent, false, width), + styles.Background, + ) + case tools.BashToolName: + resultContent = fmt.Sprintf("```bash\n%s\n```", resultContent) + return styles.ForceReplaceBackgroundWithLipgloss( + toMarkdown(resultContent, true, width), + styles.Background, + ) + case tools.EditToolName: + metadata := tools.EditResponseMetadata{} + json.Unmarshal([]byte(response.Metadata), &metadata) + truncDiff := truncateHeight(metadata.Diff, maxResultHeight) + formattedDiff, _ := diff.FormatDiff(truncDiff, diff.WithTotalWidth(width), diff.WithStyle(diffStyle)) + return formattedDiff + case tools.FetchToolName: + var params tools.FetchParams + json.Unmarshal([]byte(toolCall.Input), ¶ms) + mdFormat := "markdown" + switch params.Format { + case "text": + mdFormat = "text" + case "html": + mdFormat = "html" + } + resultContent = fmt.Sprintf("```%s\n%s\n```", mdFormat, resultContent) + return styles.ForceReplaceBackgroundWithLipgloss( + toMarkdown(resultContent, true, width), + styles.Background, + ) + case tools.GlobToolName: + return styles.BaseStyle.Width(width).Foreground(styles.ForgroundMid).Render(resultContent) + case tools.GrepToolName: + return styles.BaseStyle.Width(width).Foreground(styles.ForgroundMid).Render(resultContent) + case tools.LSToolName: + return styles.BaseStyle.Width(width).Foreground(styles.ForgroundMid).Render(resultContent) + case tools.SourcegraphToolName: + return styles.BaseStyle.Width(width).Foreground(styles.ForgroundMid).Render(resultContent) + case tools.ViewToolName: + metadata := tools.ViewResponseMetadata{} + json.Unmarshal([]byte(response.Metadata), &metadata) + ext := filepath.Ext(metadata.FilePath) + if ext == "" { + ext = "" + } else { + ext = strings.ToLower(ext[1:]) + } + resultContent = fmt.Sprintf("```%s\n%s\n```", ext, truncateHeight(metadata.Content, maxResultHeight)) + return styles.ForceReplaceBackgroundWithLipgloss( + toMarkdown(resultContent, true, width), + styles.Background, + ) + case tools.WriteToolName: + params := tools.WriteParams{} + json.Unmarshal([]byte(toolCall.Input), ¶ms) + metadata := tools.WriteResponseMetadata{} + json.Unmarshal([]byte(response.Metadata), &metadata) + ext := filepath.Ext(params.FilePath) + if ext == "" { + ext = "" + } else { + ext = strings.ToLower(ext[1:]) + } + resultContent = fmt.Sprintf("```%s\n%s\n```", ext, truncateHeight(params.Content, maxResultHeight)) + return styles.ForceReplaceBackgroundWithLipgloss( + toMarkdown(resultContent, true, width), + styles.Background, + ) + default: + resultContent = fmt.Sprintf("```text\n%s\n```", resultContent) + return styles.ForceReplaceBackgroundWithLipgloss( + toMarkdown(resultContent, true, width), + styles.Background, + ) + } +} + +func renderToolMessage( + toolCall message.ToolCall, + allMessages []message.Message, + messagesService message.Service, + focusedUIMessageId string, + nested bool, + width int, + position int, +) uiMessage { + if nested { + width = width - 3 + } + response := findToolResponse(toolCall.ID, allMessages) + toolName := styles.BaseStyle.Foreground(styles.ForgroundDim).Render(fmt.Sprintf("%s: ", toolName(toolCall.Name))) + params := renderToolParams(width-2-lipgloss.Width(toolName), toolCall) + responseContent := "" + if response != nil { + responseContent = renderToolResponse(toolCall, *response, width-2) + responseContent = strings.TrimSuffix(responseContent, "\n") + } else { + responseContent = styles.BaseStyle. + Italic(true). + Width(width - 2). + Foreground(styles.ForgroundDim). + Render("Waiting for response...") + } + style := styles.BaseStyle. + Width(width - 1). + BorderLeft(true). + BorderStyle(lipgloss.ThickBorder()). + PaddingLeft(1). + BorderForeground(styles.ForgroundDim) + + parts := []string{} + if !nested { + params := styles.BaseStyle. + Width(width - 2 - lipgloss.Width(toolName)). + Foreground(styles.ForgroundDim). + Render(params) + + parts = append(parts, lipgloss.JoinHorizontal(lipgloss.Left, toolName, params)) + } else { + prefix := styles.BaseStyle. + Foreground(styles.ForgroundDim). + Render(" └ ") + params := styles.BaseStyle. + Width(width - 2 - lipgloss.Width(toolName)). + Foreground(styles.ForgroundMid). + Render(params) + parts = append(parts, lipgloss.JoinHorizontal(lipgloss.Left, prefix, toolName, params)) + } + if toolCall.Name == agent.AgentToolName { + taskMessages, _ := messagesService.List(context.Background(), toolCall.ID) + toolCalls := []message.ToolCall{} + for _, v := range taskMessages { + toolCalls = append(toolCalls, v.ToolCalls()...) + } + for _, call := range toolCalls { + rendered := renderToolMessage(call, []message.Message{}, messagesService, focusedUIMessageId, true, width, 0) + parts = append(parts, rendered.content) + } + } + if responseContent != "" && !nested { + parts = append(parts, responseContent) + } + + content := style.Render( + lipgloss.JoinVertical( + lipgloss.Left, + parts..., + ), + ) + if nested { + content = lipgloss.JoinVertical( + lipgloss.Left, + parts..., + ) + } + toolMsg := uiMessage{ + messageType: toolMessageType, + position: position, + height: lipgloss.Height(content), + content: content, + } + return toolMsg +} diff --git a/internal/tui/components/chat/messages.go b/internal/tui/components/chat/messages.go deleted file mode 100644 index c2ce7d88b..000000000 --- a/internal/tui/components/chat/messages.go +++ /dev/null @@ -1,742 +0,0 @@ -package chat - -import ( - "context" - "encoding/json" - "fmt" - "math" - "strings" - "time" - - "github.com/charmbracelet/bubbles/key" - "github.com/charmbracelet/bubbles/spinner" - "github.com/charmbracelet/bubbles/viewport" - tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/glamour" - "github.com/charmbracelet/lipgloss" - "github.com/charmbracelet/x/ansi" - "github.com/kujtimiihoxha/opencode/internal/app" - "github.com/kujtimiihoxha/opencode/internal/llm/agent" - "github.com/kujtimiihoxha/opencode/internal/llm/models" - "github.com/kujtimiihoxha/opencode/internal/llm/tools" - "github.com/kujtimiihoxha/opencode/internal/logging" - "github.com/kujtimiihoxha/opencode/internal/message" - "github.com/kujtimiihoxha/opencode/internal/pubsub" - "github.com/kujtimiihoxha/opencode/internal/session" - "github.com/kujtimiihoxha/opencode/internal/tui/layout" - "github.com/kujtimiihoxha/opencode/internal/tui/styles" - "github.com/kujtimiihoxha/opencode/internal/tui/util" -) - -type uiMessageType int - -const ( - userMessageType uiMessageType = iota - assistantMessageType - toolMessageType -) - -// messagesTickMsg is a message sent by the timer to refresh messages -type messagesTickMsg time.Time - -type uiMessage struct { - ID string - messageType uiMessageType - position int - height int - content string -} - -type messagesCmp struct { - app *app.App - width, height int - writingMode bool - viewport viewport.Model - session session.Session - messages []message.Message - uiMessages []uiMessage - currentMsgID string - renderer *glamour.TermRenderer - focusRenderer *glamour.TermRenderer - cachedContent map[string]string - spinner spinner.Model - needsRerender bool -} - -func (m *messagesCmp) Init() tea.Cmd { - return tea.Batch(m.viewport.Init(), m.spinner.Tick, m.tickMessages()) -} - -func (m *messagesCmp) tickMessages() tea.Cmd { - return tea.Tick(time.Second, func(t time.Time) tea.Msg { - return messagesTickMsg(t) - }) -} - -func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { - var cmds []tea.Cmd - switch msg := msg.(type) { - case messagesTickMsg: - // Refresh messages if we have an active session - if m.session.ID != "" { - messages, err := m.app.Messages.List(context.Background(), m.session.ID) - if err == nil { - m.messages = messages - m.needsRerender = true - } - } - // Continue ticking - cmds = append(cmds, m.tickMessages()) - case EditorFocusMsg: - m.writingMode = bool(msg) - case SessionSelectedMsg: - if msg.ID != m.session.ID { - cmd := m.SetSession(msg) - m.needsRerender = true - return m, cmd - } - return m, nil - case SessionClearedMsg: - m.session = session.Session{} - m.messages = make([]message.Message, 0) - m.currentMsgID = "" - m.needsRerender = true - m.cachedContent = make(map[string]string) - return m, nil - - case tea.KeyMsg: - if m.writingMode { - return m, nil - } - case pubsub.Event[message.Message]: - if msg.Type == pubsub.CreatedEvent { - if msg.Payload.SessionID == m.session.ID { - // check if message exists - - messageExists := false - for _, v := range m.messages { - if v.ID == msg.Payload.ID { - messageExists = true - break - } - } - - if !messageExists { - // If we have messages, ensure the previous last message is not cached - if len(m.messages) > 0 { - lastMsgID := m.messages[len(m.messages)-1].ID - delete(m.cachedContent, lastMsgID) - } - - m.messages = append(m.messages, msg.Payload) - delete(m.cachedContent, m.currentMsgID) - m.currentMsgID = msg.Payload.ID - m.needsRerender = true - } - } - for _, v := range m.messages { - for _, c := range v.ToolCalls() { - if c.ID == msg.Payload.SessionID { - m.needsRerender = true - } - } - } - } else if msg.Type == pubsub.UpdatedEvent && msg.Payload.SessionID == m.session.ID { - logging.Debug("Message", "finish", msg.Payload.FinishReason()) - for i, v := range m.messages { - if v.ID == msg.Payload.ID { - m.messages[i] = msg.Payload - delete(m.cachedContent, msg.Payload.ID) - - // If this is the last message, ensure it's not cached - if i == len(m.messages)-1 { - delete(m.cachedContent, msg.Payload.ID) - } - - m.needsRerender = true - break - } - } - } - } - - oldPos := m.viewport.YPosition - u, cmd := m.viewport.Update(msg) - m.viewport = u - m.needsRerender = m.needsRerender || m.viewport.YPosition != oldPos - cmds = append(cmds, cmd) - - spinner, cmd := m.spinner.Update(msg) - m.spinner = spinner - cmds = append(cmds, cmd) - - if m.needsRerender { - m.renderView() - if len(m.messages) > 0 { - if msg, ok := msg.(pubsub.Event[message.Message]); ok { - if (msg.Type == pubsub.CreatedEvent) || - (msg.Type == pubsub.UpdatedEvent && msg.Payload.ID == m.messages[len(m.messages)-1].ID) { - m.viewport.GotoBottom() - } - } - } - m.needsRerender = false - } - return m, tea.Batch(cmds...) -} - -func (m *messagesCmp) IsAgentWorking() bool { - return m.app.CoderAgent.IsSessionBusy(m.session.ID) -} - -func (m *messagesCmp) renderSimpleMessage(msg message.Message, info ...string) string { - // Check if this is the last message in the list - isLastMessage := len(m.messages) > 0 && m.messages[len(m.messages)-1].ID == msg.ID - - // Only use cache for non-last messages - if !isLastMessage { - if v, ok := m.cachedContent[msg.ID]; ok { - return v - } - } - - style := styles.BaseStyle. - Width(m.width). - BorderLeft(true). - Foreground(styles.ForgroundDim). - BorderForeground(styles.ForgroundDim). - BorderStyle(lipgloss.ThickBorder()) - - renderer := m.renderer - if msg.ID == m.currentMsgID { - style = style. - Foreground(styles.Forground). - BorderForeground(styles.Blue). - BorderStyle(lipgloss.ThickBorder()) - renderer = m.focusRenderer - } - c, _ := renderer.Render(msg.Content().String()) - parts := []string{ - styles.ForceReplaceBackgroundWithLipgloss(c, styles.Background), - } - // remove newline at the end - parts[0] = strings.TrimSuffix(parts[0], "\n") - if len(info) > 0 { - parts = append(parts, info...) - } - rendered := style.Render( - lipgloss.JoinVertical( - lipgloss.Left, - parts..., - ), - ) - - // Only cache if it's not the last message - if !isLastMessage { - m.cachedContent[msg.ID] = rendered - } - - return rendered -} - -func formatTimeDifference(unixTime1, unixTime2 int64) string { - diffSeconds := float64(math.Abs(float64(unixTime2 - unixTime1))) - - if diffSeconds < 60 { - return fmt.Sprintf("%.1fs", diffSeconds) - } - - minutes := int(diffSeconds / 60) - seconds := int(diffSeconds) % 60 - return fmt.Sprintf("%dm%ds", minutes, seconds) -} - -func (m *messagesCmp) findToolResponse(callID string) *message.ToolResult { - for _, v := range m.messages { - for _, c := range v.ToolResults() { - if c.ToolCallID == callID { - return &c - } - } - } - return nil -} - -func (m *messagesCmp) renderToolCall(toolCall message.ToolCall, isNested bool) string { - key := "" - value := "" - result := styles.BaseStyle.Foreground(styles.PrimaryColor).Render(m.spinner.View() + " waiting for response...") - - response := m.findToolResponse(toolCall.ID) - if response != nil && response.IsError { - // Clean up error message for display by removing newlines - // This ensures error messages display properly in the UI - errMsg := strings.ReplaceAll(response.Content, "\n", " ") - result = styles.BaseStyle.Foreground(styles.Error).Render(ansi.Truncate(errMsg, 40, "...")) - } else if response != nil { - result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render("Done") - } - switch toolCall.Name { - // TODO: add result data to the tools - case agent.AgentToolName: - key = "Task" - var params agent.AgentParams - json.Unmarshal([]byte(toolCall.Input), ¶ms) - value = strings.ReplaceAll(params.Prompt, "\n", " ") - if response != nil && !response.IsError { - firstRow := strings.ReplaceAll(response.Content, "\n", " ") - result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(ansi.Truncate(firstRow, 40, "...")) - } - case tools.BashToolName: - key = "Bash" - var params tools.BashParams - json.Unmarshal([]byte(toolCall.Input), ¶ms) - value = params.Command - if response != nil && !response.IsError { - metadata := tools.BashResponseMetadata{} - json.Unmarshal([]byte(response.Metadata), &metadata) - result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("Took %s", formatTimeDifference(metadata.StartTime, metadata.EndTime))) - } - - case tools.EditToolName: - key = "Edit" - var params tools.EditParams - json.Unmarshal([]byte(toolCall.Input), ¶ms) - value = params.FilePath - if response != nil && !response.IsError { - metadata := tools.EditResponseMetadata{} - json.Unmarshal([]byte(response.Metadata), &metadata) - result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d Additions %d Removals", metadata.Additions, metadata.Removals)) - } - case tools.FetchToolName: - key = "Fetch" - var params tools.FetchParams - json.Unmarshal([]byte(toolCall.Input), ¶ms) - value = params.URL - if response != nil && !response.IsError { - result = styles.BaseStyle.Foreground(styles.Error).Render(response.Content) - } - case tools.GlobToolName: - key = "Glob" - var params tools.GlobParams - json.Unmarshal([]byte(toolCall.Input), ¶ms) - if params.Path == "" { - params.Path = "." - } - value = fmt.Sprintf("%s (%s)", params.Pattern, params.Path) - if response != nil && !response.IsError { - metadata := tools.GlobResponseMetadata{} - json.Unmarshal([]byte(response.Metadata), &metadata) - if metadata.Truncated { - result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d files found (truncated)", metadata.NumberOfFiles)) - } else { - result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d files found", metadata.NumberOfFiles)) - } - } - case tools.GrepToolName: - key = "Grep" - var params tools.GrepParams - json.Unmarshal([]byte(toolCall.Input), ¶ms) - if params.Path == "" { - params.Path = "." - } - value = fmt.Sprintf("%s (%s)", params.Pattern, params.Path) - if response != nil && !response.IsError { - metadata := tools.GrepResponseMetadata{} - json.Unmarshal([]byte(response.Metadata), &metadata) - if metadata.Truncated { - result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d files found (truncated)", metadata.NumberOfMatches)) - } else { - result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d files found", metadata.NumberOfMatches)) - } - } - case tools.LSToolName: - key = "ls" - var params tools.LSParams - json.Unmarshal([]byte(toolCall.Input), ¶ms) - if params.Path == "" { - params.Path = "." - } - value = params.Path - if response != nil && !response.IsError { - metadata := tools.LSResponseMetadata{} - json.Unmarshal([]byte(response.Metadata), &metadata) - if metadata.Truncated { - result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d files found (truncated)", metadata.NumberOfFiles)) - } else { - result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d files found", metadata.NumberOfFiles)) - } - } - case tools.SourcegraphToolName: - key = "Sourcegraph" - var params tools.SourcegraphParams - json.Unmarshal([]byte(toolCall.Input), ¶ms) - value = params.Query - if response != nil && !response.IsError { - metadata := tools.SourcegraphResponseMetadata{} - json.Unmarshal([]byte(response.Metadata), &metadata) - if metadata.Truncated { - result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d matches found (truncated)", metadata.NumberOfMatches)) - } else { - result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d matches found", metadata.NumberOfMatches)) - } - } - case tools.ViewToolName: - key = "View" - var params tools.ViewParams - json.Unmarshal([]byte(toolCall.Input), ¶ms) - value = params.FilePath - case tools.WriteToolName: - key = "Write" - var params tools.WriteParams - json.Unmarshal([]byte(toolCall.Input), ¶ms) - value = params.FilePath - if response != nil && !response.IsError { - metadata := tools.WriteResponseMetadata{} - json.Unmarshal([]byte(response.Metadata), &metadata) - - result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d Additions %d Removals", metadata.Additions, metadata.Removals)) - } - default: - key = toolCall.Name - var params map[string]any - json.Unmarshal([]byte(toolCall.Input), ¶ms) - jsonData, _ := json.Marshal(params) - value = string(jsonData) - } - - style := styles.BaseStyle. - Width(m.width). - BorderLeft(true). - BorderStyle(lipgloss.ThickBorder()). - PaddingLeft(1). - BorderForeground(styles.Yellow) - - keyStyle := styles.BaseStyle. - Foreground(styles.ForgroundDim) - valyeStyle := styles.BaseStyle. - Foreground(styles.Forground) - - if isNested { - valyeStyle = valyeStyle.Foreground(styles.ForgroundMid) - } - keyValye := keyStyle.Render( - fmt.Sprintf("%s: ", key), - ) - if !isNested { - value = valyeStyle. - Render( - ansi.Truncate( - value+" ", - m.width-lipgloss.Width(keyValye)-2-lipgloss.Width(result), - "...", - ), - ) - value += result - - } else { - keyValye = keyStyle.Render( - fmt.Sprintf(" └ %s: ", key), - ) - value = valyeStyle. - Width(m.width - lipgloss.Width(keyValye) - 2). - Render( - ansi.Truncate( - value, - m.width-lipgloss.Width(keyValye)-2, - "...", - ), - ) - } - - innerToolCalls := make([]string, 0) - if toolCall.Name == agent.AgentToolName { - messages, _ := m.app.Messages.List(context.Background(), toolCall.ID) - toolCalls := make([]message.ToolCall, 0) - for _, v := range messages { - toolCalls = append(toolCalls, v.ToolCalls()...) - } - for _, v := range toolCalls { - call := m.renderToolCall(v, true) - innerToolCalls = append(innerToolCalls, call) - } - } - - if isNested { - return lipgloss.JoinHorizontal( - lipgloss.Left, - keyValye, - value, - ) - } - callContent := lipgloss.JoinHorizontal( - lipgloss.Left, - keyValye, - value, - ) - callContent = strings.ReplaceAll(callContent, "\n", "") - if len(innerToolCalls) > 0 { - callContent = lipgloss.JoinVertical( - lipgloss.Left, - callContent, - lipgloss.JoinVertical( - lipgloss.Left, - innerToolCalls..., - ), - ) - } - return style.Render(callContent) -} - -func (m *messagesCmp) renderAssistantMessage(msg message.Message) []uiMessage { - // find the user message that is before this assistant message - var userMsg message.Message - for i := len(m.messages) - 1; i >= 0; i-- { - if m.messages[i].Role == message.User { - userMsg = m.messages[i] - break - } - } - messages := make([]uiMessage, 0) - if msg.Content().String() != "" { - info := make([]string, 0) - if msg.IsFinished() && msg.FinishReason() == "end_turn" { - finish := msg.FinishPart() - took := formatTimeDifference(userMsg.CreatedAt, finish.Time) - - info = append(info, styles.BaseStyle.Width(m.width-1).Foreground(styles.ForgroundDim).Render( - fmt.Sprintf(" %s (%s)", models.SupportedModels[msg.Model].Name, took), - )) - } - content := m.renderSimpleMessage(msg, info...) - messages = append(messages, uiMessage{ - messageType: assistantMessageType, - position: 0, // gets updated in renderView - height: lipgloss.Height(content), - content: content, - }) - } - for _, v := range msg.ToolCalls() { - content := m.renderToolCall(v, false) - messages = append(messages, - uiMessage{ - messageType: toolMessageType, - position: 0, // gets updated in renderView - height: lipgloss.Height(content), - content: content, - }, - ) - } - - return messages -} - -func (m *messagesCmp) renderView() { - m.uiMessages = make([]uiMessage, 0) - pos := 0 - - // If we have messages, ensure the last message is not cached - // This ensures we always render the latest content for the most recent message - // which may be actively updating (e.g., during generation) - if len(m.messages) > 0 { - lastMsgID := m.messages[len(m.messages)-1].ID - delete(m.cachedContent, lastMsgID) - } - - // Limit cache to 10 messages - if len(m.cachedContent) > 15 { - // Create a list of keys to delete (oldest messages first) - keys := make([]string, 0, len(m.cachedContent)) - for k := range m.cachedContent { - keys = append(keys, k) - } - // Delete oldest messages until we have 10 or fewer - for i := 0; i < len(keys)-15; i++ { - delete(m.cachedContent, keys[i]) - } - } - - for _, v := range m.messages { - switch v.Role { - case message.User: - content := m.renderSimpleMessage(v) - m.uiMessages = append(m.uiMessages, uiMessage{ - messageType: userMessageType, - position: pos, - height: lipgloss.Height(content), - content: content, - }) - pos += lipgloss.Height(content) + 1 // + 1 for spacing - case message.Assistant: - assistantMessages := m.renderAssistantMessage(v) - for _, msg := range assistantMessages { - msg.position = pos - m.uiMessages = append(m.uiMessages, msg) - pos += msg.height + 1 // + 1 for spacing - } - - } - } - - messages := make([]string, 0) - for _, v := range m.uiMessages { - messages = append(messages, v.content, - styles.BaseStyle. - Width(m.width). - Render( - "", - ), - ) - } - m.viewport.SetContent( - styles.BaseStyle. - Width(m.width). - Render( - lipgloss.JoinVertical( - lipgloss.Top, - messages..., - ), - ), - ) -} - -func (m *messagesCmp) View() string { - if len(m.messages) == 0 { - content := styles.BaseStyle. - Width(m.width). - Height(m.height - 1). - Render( - m.initialScreen(), - ) - - return styles.BaseStyle. - Width(m.width). - Render( - lipgloss.JoinVertical( - lipgloss.Top, - content, - m.help(), - ), - ) - } - - return styles.BaseStyle. - Width(m.width). - Render( - lipgloss.JoinVertical( - lipgloss.Top, - m.viewport.View(), - m.help(), - ), - ) -} - -func (m *messagesCmp) help() string { - text := "" - - if m.IsAgentWorking() { - text += styles.BaseStyle.Foreground(styles.PrimaryColor).Bold(true).Render( - fmt.Sprintf("%s %s ", m.spinner.View(), "Generating..."), - ) - } - if m.writingMode { - text += lipgloss.JoinHorizontal( - lipgloss.Left, - styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render("press "), - styles.BaseStyle.Foreground(styles.Forground).Bold(true).Render("esc"), - styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render(" to exit writing mode"), - ) - } else { - text += lipgloss.JoinHorizontal( - lipgloss.Left, - styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render("press "), - styles.BaseStyle.Foreground(styles.Forground).Bold(true).Render("i"), - styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render(" to start writing"), - ) - } - - return styles.BaseStyle. - Width(m.width). - Render(text) -} - -func (m *messagesCmp) initialScreen() string { - return styles.BaseStyle.Width(m.width).Render( - lipgloss.JoinVertical( - lipgloss.Top, - header(m.width), - "", - lspsConfigured(m.width), - ), - ) -} - -func (m *messagesCmp) SetSize(width, height int) { - m.width = width - m.height = height - m.viewport.Width = width - m.viewport.Height = height - 1 - focusRenderer, _ := glamour.NewTermRenderer( - glamour.WithStyles(styles.MarkdownTheme(true)), - glamour.WithWordWrap(width-1), - ) - renderer, _ := glamour.NewTermRenderer( - glamour.WithStyles(styles.MarkdownTheme(false)), - glamour.WithWordWrap(width-1), - ) - m.focusRenderer = focusRenderer - // clear the cached content - for k := range m.cachedContent { - delete(m.cachedContent, k) - } - m.renderer = renderer - if len(m.messages) > 0 { - m.renderView() - m.viewport.GotoBottom() - } -} - -func (m *messagesCmp) GetSize() (int, int) { - return m.width, m.height -} - -func (m *messagesCmp) SetSession(session session.Session) tea.Cmd { - m.session = session - messages, err := m.app.Messages.List(context.Background(), session.ID) - if err != nil { - return util.ReportError(err) - } - m.messages = messages - m.currentMsgID = m.messages[len(m.messages)-1].ID - m.needsRerender = true - m.cachedContent = make(map[string]string) - return nil -} - -func (m *messagesCmp) BindingKeys() []key.Binding { - bindings := layout.KeyMapToSlice(m.viewport.KeyMap) - return bindings -} - -func NewMessagesCmp(app *app.App) tea.Model { - focusRenderer, _ := glamour.NewTermRenderer( - glamour.WithStyles(styles.MarkdownTheme(true)), - glamour.WithWordWrap(80), - ) - renderer, _ := glamour.NewTermRenderer( - glamour.WithStyles(styles.MarkdownTheme(false)), - glamour.WithWordWrap(80), - ) - - s := spinner.New() - s.Spinner = spinner.Pulse - return &messagesCmp{ - app: app, - writingMode: true, - cachedContent: make(map[string]string), - viewport: viewport.New(0, 0), - focusRenderer: focusRenderer, - renderer: renderer, - spinner: s, - } -} diff --git a/internal/tui/components/chat/sidebar.go b/internal/tui/components/chat/sidebar.go index 5a275c0cf..d330e592b 100644 --- a/internal/tui/components/chat/sidebar.go +++ b/internal/tui/components/chat/sidebar.go @@ -51,6 +51,12 @@ func (m *sidebarCmp) Init() tea.Cmd { func (m *sidebarCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { switch msg := msg.(type) { + case SessionSelectedMsg: + if msg.ID != m.session.ID { + m.session = msg + ctx := context.Background() + m.loadModifiedFiles(ctx) + } case pubsub.Event[session.Session]: if msg.Type == pubsub.UpdatedEvent { if m.session.ID == msg.Payload.ID { @@ -59,10 +65,16 @@ func (m *sidebarCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } case pubsub.Event[history.File]: if msg.Payload.SessionID == m.session.ID { - // When a file changes, reload all modified files - // This ensures we have the complete and accurate list + // Process the individual file change instead of reloading all files ctx := context.Background() - m.loadModifiedFiles(ctx) + m.processFileChanges(ctx, msg.Payload) + + // Return a command to continue receiving events + return m, func() tea.Msg { + ctx := context.Background() + filesCh := m.history.Subscribe(ctx) + return <-filesCh + } } } return m, nil @@ -71,6 +83,8 @@ func (m *sidebarCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { func (m *sidebarCmp) View() string { return styles.BaseStyle. Width(m.width). + PaddingLeft(4). + PaddingRight(2). Height(m.height - 1). Render( lipgloss.JoinVertical( @@ -79,9 +93,9 @@ func (m *sidebarCmp) View() string { " ", m.sessionSection(), " ", - m.modifiedFiles(), - " ", lspsConfigured(m.width), + " ", + m.modifiedFiles(), ), ) } @@ -170,9 +184,10 @@ func (m *sidebarCmp) modifiedFiles() string { ) } -func (m *sidebarCmp) SetSize(width, height int) { +func (m *sidebarCmp) SetSize(width, height int) tea.Cmd { m.width = width m.height = height + return nil } func (m *sidebarCmp) GetSize() (int, int) { @@ -203,6 +218,12 @@ func (m *sidebarCmp) loadModifiedFiles(ctx context.Context) { return } + // Clear the existing map to rebuild it + m.modFiles = make(map[string]struct { + additions int + removals int + }) + // Process each latest file for _, file := range latestFiles { // Skip if this is the initial version (no changes to show) @@ -250,28 +271,23 @@ func (m *sidebarCmp) loadModifiedFiles(ctx context.Context) { } func (m *sidebarCmp) processFileChanges(ctx context.Context, file history.File) { - // Skip if not the latest version + // Skip if this is the initial version (no changes to show) if file.Version == history.InitialVersion { return } - // Get all versions of this file - fileVersions, err := m.history.ListBySession(ctx, m.session.ID) - if err != nil { + // Find the initial version for this file + initialVersion, err := m.findInitialVersion(ctx, file.Path) + if err != nil || initialVersion.ID == "" { return } - // Find the initial version - var initialVersion history.File - for _, v := range fileVersions { - if v.Path == file.Path && v.Version == history.InitialVersion { - initialVersion = v - break - } - } - - // Skip if we can't find the initial version - if initialVersion.ID == "" { + // Skip if content hasn't changed + if initialVersion.Content == file.Content { + // If this file was previously modified but now matches the initial version, + // remove it from the modified files list + displayPath := getDisplayPath(file.Path) + delete(m.modFiles, displayPath) return } @@ -280,12 +296,7 @@ func (m *sidebarCmp) processFileChanges(ctx context.Context, file history.File) // Only add to modified files if there are changes if additions > 0 || removals > 0 { - // Remove working directory prefix from file path - displayPath := file.Path - workingDir := config.WorkingDirectory() - displayPath = strings.TrimPrefix(displayPath, workingDir) - displayPath = strings.TrimPrefix(displayPath, "/") - + displayPath := getDisplayPath(file.Path) m.modFiles[displayPath] = struct { additions int removals int @@ -293,5 +304,34 @@ func (m *sidebarCmp) processFileChanges(ctx context.Context, file history.File) additions: additions, removals: removals, } + } else { + // If no changes, remove from modified files + displayPath := getDisplayPath(file.Path) + delete(m.modFiles, displayPath) + } +} + +// Helper function to find the initial version of a file +func (m *sidebarCmp) findInitialVersion(ctx context.Context, path string) (history.File, error) { + // Get all versions of this file for the session + fileVersions, err := m.history.ListBySession(ctx, m.session.ID) + if err != nil { + return history.File{}, err } + + // Find the initial version + for _, v := range fileVersions { + if v.Path == path && v.Version == history.InitialVersion { + return v, nil + } + } + + return history.File{}, fmt.Errorf("initial version not found") +} + +// Helper function to get the display path for a file +func getDisplayPath(path string) string { + workingDir := config.WorkingDirectory() + displayPath := strings.TrimPrefix(path, workingDir) + return strings.TrimPrefix(displayPath, "/") } diff --git a/internal/tui/components/core/status.go b/internal/tui/components/core/status.go index e76ecde84..01c535869 100644 --- a/internal/tui/components/core/status.go +++ b/internal/tui/components/core/status.go @@ -166,19 +166,31 @@ func (m *statusCmp) projectDiagnostics() string { diagnostics := []string{} if len(errorDiagnostics) > 0 { - errStr := lipgloss.NewStyle().Foreground(styles.Error).Render(fmt.Sprintf("%s %d", styles.ErrorIcon, len(errorDiagnostics))) + errStr := lipgloss.NewStyle(). + Background(styles.BackgroundDarker). + Foreground(styles.Error). + Render(fmt.Sprintf("%s %d", styles.ErrorIcon, len(errorDiagnostics))) diagnostics = append(diagnostics, errStr) } if len(warnDiagnostics) > 0 { - warnStr := lipgloss.NewStyle().Foreground(styles.Warning).Render(fmt.Sprintf("%s %d", styles.WarningIcon, len(warnDiagnostics))) + warnStr := lipgloss.NewStyle(). + Background(styles.BackgroundDarker). + Foreground(styles.Warning). + Render(fmt.Sprintf("%s %d", styles.WarningIcon, len(warnDiagnostics))) diagnostics = append(diagnostics, warnStr) } if len(hintDiagnostics) > 0 { - hintStr := lipgloss.NewStyle().Foreground(styles.Text).Render(fmt.Sprintf("%s %d", styles.HintIcon, len(hintDiagnostics))) + hintStr := lipgloss.NewStyle(). + Background(styles.BackgroundDarker). + Foreground(styles.Text). + Render(fmt.Sprintf("%s %d", styles.HintIcon, len(hintDiagnostics))) diagnostics = append(diagnostics, hintStr) } if len(infoDiagnostics) > 0 { - infoStr := lipgloss.NewStyle().Foreground(styles.Peach).Render(fmt.Sprintf("%s %d", styles.InfoIcon, len(infoDiagnostics))) + infoStr := lipgloss.NewStyle(). + Background(styles.BackgroundDarker). + Foreground(styles.Peach). + Render(fmt.Sprintf("%s %d", styles.InfoIcon, len(infoDiagnostics))) diagnostics = append(diagnostics, infoStr) } @@ -187,10 +199,12 @@ func (m *statusCmp) projectDiagnostics() string { func (m statusCmp) availableFooterMsgWidth(diagnostics string) int { tokens := "" + tokensWidth := 0 if m.session.ID != "" { tokens = formatTokensAndCost(m.session.PromptTokens+m.session.CompletionTokens, m.session.Cost) + tokensWidth = lipgloss.Width(tokens) + 2 } - return max(0, m.width-lipgloss.Width(helpWidget)-lipgloss.Width(m.model())-lipgloss.Width(diagnostics)-lipgloss.Width(tokens)) + return max(0, m.width-lipgloss.Width(helpWidget)-lipgloss.Width(m.model())-lipgloss.Width(diagnostics)-tokensWidth) } func (m statusCmp) model() string { diff --git a/internal/tui/components/dialog/permission.go b/internal/tui/components/dialog/permission.go index 295884432..f83472e68 100644 --- a/internal/tui/components/dialog/permission.go +++ b/internal/tui/components/dialog/permission.go @@ -36,7 +36,7 @@ type PermissionResponseMsg struct { type PermissionDialogCmp interface { tea.Model layout.Bindings - SetPermissions(permission permission.PermissionRequest) + SetPermissions(permission permission.PermissionRequest) tea.Cmd } type permissionsMapping struct { @@ -98,7 +98,8 @@ func (p *permissionDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { switch msg := msg.(type) { case tea.WindowSizeMsg: p.windowSize = msg - p.SetSize() + cmd := p.SetSize() + cmds = append(cmds, cmd) p.markdownCache = make(map[string]string) p.diffCache = make(map[string]string) case tea.KeyMsg: @@ -267,7 +268,7 @@ func (p *permissionDialogCmp) renderEditContent() string { } func (p *permissionDialogCmp) renderPatchContent() string { - if pr, ok := p.permission.Params.(tools.PatchPermissionsParams); ok { + if pr, ok := p.permission.Params.(tools.EditPermissionsParams); ok { diff := p.GetOrSetDiff(p.permission.ID, func() (string, error) { return diff.FormatDiff(pr.Diff, diff.WithTotalWidth(p.contentViewPort.Width)) }) @@ -401,9 +402,9 @@ func (p *permissionDialogCmp) BindingKeys() []key.Binding { return layout.KeyMapToSlice(helpKeys) } -func (p *permissionDialogCmp) SetSize() { +func (p *permissionDialogCmp) SetSize() tea.Cmd { if p.permission.ID == "" { - return + return nil } switch p.permission.ToolName { case tools.BashToolName: @@ -422,11 +423,12 @@ func (p *permissionDialogCmp) SetSize() { p.width = int(float64(p.windowSize.Width) * 0.7) p.height = int(float64(p.windowSize.Height) * 0.5) } + return nil } -func (p *permissionDialogCmp) SetPermissions(permission permission.PermissionRequest) { +func (p *permissionDialogCmp) SetPermissions(permission permission.PermissionRequest) tea.Cmd { p.permission = permission - p.SetSize() + return p.SetSize() } // Helper to get or set cached diff content diff --git a/internal/tui/components/dialog/session.go b/internal/tui/components/dialog/session.go new file mode 100644 index 000000000..d8c859c49 --- /dev/null +++ b/internal/tui/components/dialog/session.go @@ -0,0 +1,224 @@ +package dialog + +import ( + "github.com/charmbracelet/bubbles/key" + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" + "github.com/kujtimiihoxha/opencode/internal/session" + "github.com/kujtimiihoxha/opencode/internal/tui/layout" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/tui/util" +) + +// SessionSelectedMsg is sent when a session is selected +type SessionSelectedMsg struct { + Session session.Session +} + +// CloseSessionDialogMsg is sent when the session dialog is closed +type CloseSessionDialogMsg struct{} + +// SessionDialog interface for the session switching dialog +type SessionDialog interface { + tea.Model + layout.Bindings + SetSessions(sessions []session.Session) + SetSelectedSession(sessionID string) +} + +type sessionDialogCmp struct { + sessions []session.Session + selectedIdx int + width int + height int + selectedSessionID string +} + +type sessionKeyMap struct { + Up key.Binding + Down key.Binding + Enter key.Binding + Escape key.Binding + J key.Binding + K key.Binding +} + +var sessionKeys = sessionKeyMap{ + Up: key.NewBinding( + key.WithKeys("up"), + key.WithHelp("↑", "previous session"), + ), + Down: key.NewBinding( + key.WithKeys("down"), + key.WithHelp("↓", "next session"), + ), + Enter: key.NewBinding( + key.WithKeys("enter"), + key.WithHelp("enter", "select session"), + ), + Escape: key.NewBinding( + key.WithKeys("esc"), + key.WithHelp("esc", "close"), + ), + J: key.NewBinding( + key.WithKeys("j"), + key.WithHelp("j", "next session"), + ), + K: key.NewBinding( + key.WithKeys("k"), + key.WithHelp("k", "previous session"), + ), +} + +func (s *sessionDialogCmp) Init() tea.Cmd { + return nil +} + +func (s *sessionDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case tea.KeyMsg: + switch { + case key.Matches(msg, sessionKeys.Up) || key.Matches(msg, sessionKeys.K): + if s.selectedIdx > 0 { + s.selectedIdx-- + } + return s, nil + case key.Matches(msg, sessionKeys.Down) || key.Matches(msg, sessionKeys.J): + if s.selectedIdx < len(s.sessions)-1 { + s.selectedIdx++ + } + return s, nil + case key.Matches(msg, sessionKeys.Enter): + if len(s.sessions) > 0 { + return s, util.CmdHandler(SessionSelectedMsg{ + Session: s.sessions[s.selectedIdx], + }) + } + case key.Matches(msg, sessionKeys.Escape): + return s, util.CmdHandler(CloseSessionDialogMsg{}) + } + case tea.WindowSizeMsg: + s.width = msg.Width + s.height = msg.Height + } + return s, nil +} + +func (s *sessionDialogCmp) View() string { + if len(s.sessions) == 0 { + return styles.BaseStyle.Padding(1, 2). + Border(lipgloss.RoundedBorder()). + BorderBackground(styles.Background). + BorderForeground(styles.ForgroundDim). + Width(40). + Render("No sessions available") + } + + // Calculate max width needed for session titles + maxWidth := 40 // Minimum width + for _, sess := range s.sessions { + if len(sess.Title) > maxWidth-4 { // Account for padding + maxWidth = len(sess.Title) + 4 + } + } + + // Limit height to avoid taking up too much screen space + maxVisibleSessions := min(10, len(s.sessions)) + + // Build the session list + sessionItems := make([]string, 0, maxVisibleSessions) + startIdx := 0 + + // If we have more sessions than can be displayed, adjust the start index + if len(s.sessions) > maxVisibleSessions { + // Center the selected item when possible + halfVisible := maxVisibleSessions / 2 + if s.selectedIdx >= halfVisible && s.selectedIdx < len(s.sessions)-halfVisible { + startIdx = s.selectedIdx - halfVisible + } else if s.selectedIdx >= len(s.sessions)-halfVisible { + startIdx = len(s.sessions) - maxVisibleSessions + } + } + + endIdx := min(startIdx+maxVisibleSessions, len(s.sessions)) + + for i := startIdx; i < endIdx; i++ { + sess := s.sessions[i] + itemStyle := styles.BaseStyle.Width(maxWidth) + + if i == s.selectedIdx { + itemStyle = itemStyle. + Background(styles.PrimaryColor). + Foreground(styles.Background). + Bold(true) + } + + sessionItems = append(sessionItems, itemStyle.Padding(0, 1).Render(sess.Title)) + } + + title := styles.BaseStyle. + Foreground(styles.PrimaryColor). + Bold(true). + Padding(0, 1). + Render("Switch Session") + + content := lipgloss.JoinVertical( + lipgloss.Left, + title, + styles.BaseStyle.Render(""), + lipgloss.JoinVertical(lipgloss.Left, sessionItems...), + styles.BaseStyle.Render(""), + styles.BaseStyle.Foreground(styles.ForgroundDim).Render("↑/k: up ↓/j: down enter: select esc: cancel"), + ) + + return styles.BaseStyle.Padding(1, 2). + Border(lipgloss.RoundedBorder()). + BorderBackground(styles.Background). + BorderForeground(styles.ForgroundDim). + Width(lipgloss.Width(content) + 4). + Render(content) +} + +func (s *sessionDialogCmp) BindingKeys() []key.Binding { + return layout.KeyMapToSlice(sessionKeys) +} + +func (s *sessionDialogCmp) SetSessions(sessions []session.Session) { + s.sessions = sessions + + // If we have a selected session ID, find its index + if s.selectedSessionID != "" { + for i, sess := range sessions { + if sess.ID == s.selectedSessionID { + s.selectedIdx = i + return + } + } + } + + // Default to first session if selected not found + s.selectedIdx = 0 +} + +func (s *sessionDialogCmp) SetSelectedSession(sessionID string) { + s.selectedSessionID = sessionID + + // Update the selected index if sessions are already loaded + if len(s.sessions) > 0 { + for i, sess := range s.sessions { + if sess.ID == sessionID { + s.selectedIdx = i + return + } + } + } +} + +// NewSessionDialogCmp creates a new session switching dialog +func NewSessionDialogCmp() SessionDialog { + return &sessionDialogCmp{ + sessions: []session.Session{}, + selectedIdx: 0, + selectedSessionID: "", + } +} \ No newline at end of file diff --git a/internal/tui/components/logs/details.go b/internal/tui/components/logs/details.go index 7c74da104..fa49adbbb 100644 --- a/internal/tui/components/logs/details.go +++ b/internal/tui/components/logs/details.go @@ -119,27 +119,17 @@ func (i *detailCmp) GetSize() (int, int) { return i.width, i.height } -func (i *detailCmp) SetSize(width int, height int) { +func (i *detailCmp) SetSize(width int, height int) tea.Cmd { i.width = width i.height = height i.viewport.Width = i.width i.viewport.Height = i.height i.updateContent() + return nil } func (i *detailCmp) BindingKeys() []key.Binding { - return []key.Binding{ - i.viewport.KeyMap.PageDown, - i.viewport.KeyMap.PageUp, - i.viewport.KeyMap.HalfPageDown, - i.viewport.KeyMap.HalfPageUp, - } -} - -func (i *detailCmp) BorderText() map[layout.BorderPosition]string { - return map[layout.BorderPosition]string{ - layout.TopLeftBorder: "Log Details", - } + return layout.KeyMapToSlice(i.viewport.KeyMap) } func NewLogsDetails() DetailComponent { diff --git a/internal/tui/components/logs/table.go b/internal/tui/components/logs/table.go index 2d0f9c533..245714d0d 100644 --- a/internal/tui/components/logs/table.go +++ b/internal/tui/components/logs/table.go @@ -68,7 +68,7 @@ func (i *tableCmp) GetSize() (int, int) { return i.table.Width(), i.table.Height() } -func (i *tableCmp) SetSize(width int, height int) { +func (i *tableCmp) SetSize(width int, height int) tea.Cmd { i.table.SetWidth(width) i.table.SetHeight(height) cloumns := i.table.Columns() @@ -77,6 +77,7 @@ func (i *tableCmp) SetSize(width int, height int) { cloumns[i] = col } i.table.SetColumns(cloumns) + return nil } func (i *tableCmp) BindingKeys() []key.Binding { diff --git a/internal/tui/layout/bento.go b/internal/tui/layout/bento.go deleted file mode 100644 index c47c4e090..000000000 --- a/internal/tui/layout/bento.go +++ /dev/null @@ -1,392 +0,0 @@ -package layout - -import ( - "github.com/charmbracelet/bubbles/key" - tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/lipgloss" -) - -type paneID string - -const ( - BentoLeftPane paneID = "left" - BentoRightTopPane paneID = "right-top" - BentoRightBottomPane paneID = "right-bottom" -) - -type BentoPanes map[paneID]tea.Model - -const ( - defaultLeftWidthRatio = 0.2 - defaultRightTopHeightRatio = 0.85 - - minLeftWidth = 10 - minRightBottomHeight = 10 -) - -type BentoLayout interface { - tea.Model - Sizeable - Bindings -} - -type BentoKeyBindings struct { - SwitchPane key.Binding - SwitchPaneBack key.Binding - HideCurrentPane key.Binding - ShowAllPanes key.Binding -} - -var defaultBentoKeyBindings = BentoKeyBindings{ - SwitchPane: key.NewBinding( - key.WithKeys("tab"), - key.WithHelp("tab", "switch pane"), - ), - SwitchPaneBack: key.NewBinding( - key.WithKeys("shift+tab"), - key.WithHelp("shift+tab", "switch pane back"), - ), - HideCurrentPane: key.NewBinding( - key.WithKeys("X"), - key.WithHelp("X", "hide current pane"), - ), - ShowAllPanes: key.NewBinding( - key.WithKeys("R"), - key.WithHelp("R", "show all panes"), - ), -} - -type bentoLayout struct { - width int - height int - - leftWidthRatio float64 - rightTopHeightRatio float64 - - currentPane paneID - panes map[paneID]SinglePaneLayout - hiddenPanes map[paneID]bool -} - -func (b *bentoLayout) GetSize() (int, int) { - return b.width, b.height -} - -func (b *bentoLayout) Init() tea.Cmd { - var cmds []tea.Cmd - for _, pane := range b.panes { - cmd := pane.Init() - if cmd != nil { - cmds = append(cmds, cmd) - } - } - return tea.Batch(cmds...) -} - -func (b *bentoLayout) Update(msg tea.Msg) (tea.Model, tea.Cmd) { - switch msg := msg.(type) { - case tea.WindowSizeMsg: - b.SetSize(msg.Width, msg.Height) - return b, nil - case tea.KeyMsg: - switch { - case key.Matches(msg, defaultBentoKeyBindings.SwitchPane): - return b, b.SwitchPane(false) - case key.Matches(msg, defaultBentoKeyBindings.SwitchPaneBack): - return b, b.SwitchPane(true) - case key.Matches(msg, defaultBentoKeyBindings.HideCurrentPane): - return b, b.HidePane(b.currentPane) - case key.Matches(msg, defaultBentoKeyBindings.ShowAllPanes): - for id := range b.hiddenPanes { - delete(b.hiddenPanes, id) - } - b.SetSize(b.width, b.height) - return b, nil - } - } - - var cmds []tea.Cmd - for id, pane := range b.panes { - u, cmd := pane.Update(msg) - b.panes[id] = u.(SinglePaneLayout) - if cmd != nil { - cmds = append(cmds, cmd) - } - } - return b, tea.Batch(cmds...) -} - -func (b *bentoLayout) View() string { - if b.width <= 0 || b.height <= 0 { - return "" - } - - for id, pane := range b.panes { - if b.currentPane == id { - pane.Focus() - } else { - pane.Blur() - } - } - - leftVisible := false - rightTopVisible := false - rightBottomVisible := false - - var leftPane, rightTopPane, rightBottomPane string - - if pane, ok := b.panes[BentoLeftPane]; ok && !b.hiddenPanes[BentoLeftPane] { - leftPane = pane.View() - leftVisible = true - } - - if pane, ok := b.panes[BentoRightTopPane]; ok && !b.hiddenPanes[BentoRightTopPane] { - rightTopPane = pane.View() - rightTopVisible = true - } - - if pane, ok := b.panes[BentoRightBottomPane]; ok && !b.hiddenPanes[BentoRightBottomPane] { - rightBottomPane = pane.View() - rightBottomVisible = true - } - - if leftVisible { - if rightTopVisible || rightBottomVisible { - rightSection := "" - if rightTopVisible && rightBottomVisible { - rightSection = lipgloss.JoinVertical(lipgloss.Top, rightTopPane, rightBottomPane) - } else if rightTopVisible { - rightSection = rightTopPane - } else { - rightSection = rightBottomPane - } - return lipgloss.NewStyle().Width(b.width).Height(b.height).Render( - lipgloss.JoinHorizontal(lipgloss.Left, leftPane, rightSection), - ) - } else { - return lipgloss.NewStyle().Width(b.width).Height(b.height).Render(leftPane) - } - } else if rightTopVisible || rightBottomVisible { - if rightTopVisible && rightBottomVisible { - return lipgloss.NewStyle().Width(b.width).Height(b.height).Render( - lipgloss.JoinVertical(lipgloss.Top, rightTopPane, rightBottomPane), - ) - } else if rightTopVisible { - return lipgloss.NewStyle().Width(b.width).Height(b.height).Render(rightTopPane) - } else { - return lipgloss.NewStyle().Width(b.width).Height(b.height).Render(rightBottomPane) - } - } - return "" -} - -func (b *bentoLayout) SetSize(width int, height int) { - if width < 0 || height < 0 { - return - } - b.width = width - b.height = height - - leftExists := false - rightTopExists := false - rightBottomExists := false - - if _, ok := b.panes[BentoLeftPane]; ok && !b.hiddenPanes[BentoLeftPane] { - leftExists = true - } - if _, ok := b.panes[BentoRightTopPane]; ok && !b.hiddenPanes[BentoRightTopPane] { - rightTopExists = true - } - if _, ok := b.panes[BentoRightBottomPane]; ok && !b.hiddenPanes[BentoRightBottomPane] { - rightBottomExists = true - } - - leftWidth := 0 - rightWidth := 0 - rightTopHeight := 0 - rightBottomHeight := 0 - - if leftExists && (rightTopExists || rightBottomExists) { - leftWidth = int(float64(width) * b.leftWidthRatio) - if leftWidth < minLeftWidth && width >= minLeftWidth { - leftWidth = minLeftWidth - } - rightWidth = width - leftWidth - - if rightTopExists && rightBottomExists { - rightTopHeight = int(float64(height) * b.rightTopHeightRatio) - rightBottomHeight = height - rightTopHeight - - if rightBottomHeight < minRightBottomHeight && height >= minRightBottomHeight { - rightBottomHeight = minRightBottomHeight - rightTopHeight = height - rightBottomHeight - } - } else if rightTopExists { - rightTopHeight = height - } else if rightBottomExists { - rightBottomHeight = height - } - } else if leftExists { - leftWidth = width - } else if rightTopExists || rightBottomExists { - rightWidth = width - - if rightTopExists && rightBottomExists { - rightTopHeight = int(float64(height) * b.rightTopHeightRatio) - rightBottomHeight = height - rightTopHeight - - if rightBottomHeight < minRightBottomHeight && height >= minRightBottomHeight { - rightBottomHeight = minRightBottomHeight - rightTopHeight = height - rightBottomHeight - } - } else if rightTopExists { - rightTopHeight = height - } else if rightBottomExists { - rightBottomHeight = height - } - } - - if pane, ok := b.panes[BentoLeftPane]; ok && !b.hiddenPanes[BentoLeftPane] { - pane.SetSize(leftWidth, height) - } - if pane, ok := b.panes[BentoRightTopPane]; ok && !b.hiddenPanes[BentoRightTopPane] { - pane.SetSize(rightWidth, rightTopHeight) - } - if pane, ok := b.panes[BentoRightBottomPane]; ok && !b.hiddenPanes[BentoRightBottomPane] { - pane.SetSize(rightWidth, rightBottomHeight) - } -} - -func (b *bentoLayout) HidePane(pane paneID) tea.Cmd { - if len(b.panes)-len(b.hiddenPanes) == 1 { - return nil - } - if _, ok := b.panes[pane]; ok { - b.hiddenPanes[pane] = true - } - b.SetSize(b.width, b.height) - return b.SwitchPane(false) -} - -func (b *bentoLayout) SwitchPane(back bool) tea.Cmd { - orderForward := []paneID{BentoLeftPane, BentoRightTopPane, BentoRightBottomPane} - orderBackward := []paneID{BentoLeftPane, BentoRightBottomPane, BentoRightTopPane} - - order := orderForward - if back { - order = orderBackward - } - - currentIdx := -1 - for i, id := range order { - if id == b.currentPane { - currentIdx = i - break - } - } - - if currentIdx == -1 { - for _, id := range order { - if _, exists := b.panes[id]; exists { - if _, hidden := b.hiddenPanes[id]; !hidden { - b.currentPane = id - break - } - } - } - } else { - startIdx := currentIdx - for { - currentIdx = (currentIdx + 1) % len(order) - - nextID := order[currentIdx] - if _, exists := b.panes[nextID]; exists { - if _, hidden := b.hiddenPanes[nextID]; !hidden { - b.currentPane = nextID - break - } - } - - if currentIdx == startIdx { - break - } - } - } - - var cmds []tea.Cmd - for id, pane := range b.panes { - if _, ok := b.hiddenPanes[id]; ok { - continue - } - if id == b.currentPane { - cmds = append(cmds, pane.Focus()) - } else { - cmds = append(cmds, pane.Blur()) - } - } - - return tea.Batch(cmds...) -} - -func (s *bentoLayout) BindingKeys() []key.Binding { - bindings := KeyMapToSlice(defaultBentoKeyBindings) - if b, ok := s.panes[s.currentPane].(Bindings); ok { - bindings = append(bindings, b.BindingKeys()...) - } - return bindings -} - -type BentoLayoutOption func(*bentoLayout) - -func NewBentoLayout(panes BentoPanes, opts ...BentoLayoutOption) BentoLayout { - p := make(map[paneID]SinglePaneLayout, len(panes)) - for id, pane := range panes { - if sp, ok := pane.(SinglePaneLayout); !ok { - p[id] = NewSinglePane( - pane, - WithSinglePaneFocusable(true), - WithSinglePaneBordered(true), - ) - } else { - p[id] = sp - } - } - if len(p) == 0 { - panic("no panes provided for BentoLayout") - } - layout := &bentoLayout{ - panes: p, - hiddenPanes: make(map[paneID]bool), - currentPane: BentoLeftPane, - leftWidthRatio: defaultLeftWidthRatio, - rightTopHeightRatio: defaultRightTopHeightRatio, - } - - for _, opt := range opts { - opt(layout) - } - - return layout -} - -func WithBentoLayoutLeftWidthRatio(ratio float64) BentoLayoutOption { - return func(b *bentoLayout) { - if ratio > 0 && ratio < 1 { - b.leftWidthRatio = ratio - } - } -} - -func WithBentoLayoutRightTopHeightRatio(ratio float64) BentoLayoutOption { - return func(b *bentoLayout) { - if ratio > 0 && ratio < 1 { - b.rightTopHeightRatio = ratio - } - } -} - -func WithBentoLayoutCurrentPane(pane paneID) BentoLayoutOption { - return func(b *bentoLayout) { - b.currentPane = pane - } -} diff --git a/internal/tui/layout/border.go b/internal/tui/layout/border.go deleted file mode 100644 index ea9f5e0bc..000000000 --- a/internal/tui/layout/border.go +++ /dev/null @@ -1,121 +0,0 @@ -package layout - -import ( - "fmt" - "strings" - - "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/opencode/internal/tui/styles" -) - -type BorderPosition int - -const ( - TopLeftBorder BorderPosition = iota - TopMiddleBorder - TopRightBorder - BottomLeftBorder - BottomMiddleBorder - BottomRightBorder -) - -var ( - ActiveBorder = styles.Blue - InactivePreviewBorder = styles.Grey -) - -type BorderOptions struct { - Active bool - EmbeddedText map[BorderPosition]string - ActiveColor lipgloss.TerminalColor - InactiveColor lipgloss.TerminalColor - ActiveBorder lipgloss.Border - InactiveBorder lipgloss.Border -} - -func Borderize(content string, opts BorderOptions) string { - if opts.EmbeddedText == nil { - opts.EmbeddedText = make(map[BorderPosition]string) - } - if opts.ActiveColor == nil { - opts.ActiveColor = ActiveBorder - } - if opts.InactiveColor == nil { - opts.InactiveColor = InactivePreviewBorder - } - if opts.ActiveBorder == (lipgloss.Border{}) { - opts.ActiveBorder = lipgloss.ThickBorder() - } - if opts.InactiveBorder == (lipgloss.Border{}) { - opts.InactiveBorder = lipgloss.NormalBorder() - } - - var ( - thickness = map[bool]lipgloss.Border{ - true: opts.ActiveBorder, - false: opts.InactiveBorder, - } - color = map[bool]lipgloss.TerminalColor{ - true: opts.ActiveColor, - false: opts.InactiveColor, - } - border = thickness[opts.Active] - style = lipgloss.NewStyle().Foreground(color[opts.Active]) - width = lipgloss.Width(content) - ) - - encloseInSquareBrackets := func(text string) string { - if text != "" { - return fmt.Sprintf("%s%s%s", - style.Render(border.TopRight), - text, - style.Render(border.TopLeft), - ) - } - return text - } - buildHorizontalBorder := func(leftText, middleText, rightText, leftCorner, inbetween, rightCorner string) string { - leftText = encloseInSquareBrackets(leftText) - middleText = encloseInSquareBrackets(middleText) - rightText = encloseInSquareBrackets(rightText) - // Calculate length of border between embedded texts - remaining := max(0, width-lipgloss.Width(leftText)-lipgloss.Width(middleText)-lipgloss.Width(rightText)) - leftBorderLen := max(0, (width/2)-lipgloss.Width(leftText)-(lipgloss.Width(middleText)/2)) - rightBorderLen := max(0, remaining-leftBorderLen) - // Then construct border string - s := leftText + - style.Render(strings.Repeat(inbetween, leftBorderLen)) + - middleText + - style.Render(strings.Repeat(inbetween, rightBorderLen)) + - rightText - // Make it fit in the space available between the two corners. - s = lipgloss.NewStyle(). - Inline(true). - MaxWidth(width). - Render(s) - // Add the corners - return style.Render(leftCorner) + s + style.Render(rightCorner) - } - // Stack top border, content and horizontal borders, and bottom border. - return strings.Join([]string{ - buildHorizontalBorder( - opts.EmbeddedText[TopLeftBorder], - opts.EmbeddedText[TopMiddleBorder], - opts.EmbeddedText[TopRightBorder], - border.TopLeft, - border.Top, - border.TopRight, - ), - lipgloss.NewStyle(). - BorderForeground(color[opts.Active]). - Border(border, false, true, false, true).Render(content), - buildHorizontalBorder( - opts.EmbeddedText[BottomLeftBorder], - opts.EmbeddedText[BottomMiddleBorder], - opts.EmbeddedText[BottomRightBorder], - border.BottomLeft, - border.Bottom, - border.BottomRight, - ), - }, "\n") -} diff --git a/internal/tui/layout/container.go b/internal/tui/layout/container.go index c86d954ea..fdb9ab403 100644 --- a/internal/tui/layout/container.go +++ b/internal/tui/layout/container.go @@ -86,7 +86,7 @@ func (c *container) View() string { return style.Render(c.content.View()) } -func (c *container) SetSize(width, height int) { +func (c *container) SetSize(width, height int) tea.Cmd { c.width = width c.height = height @@ -113,8 +113,9 @@ func (c *container) SetSize(width, height int) { // Set content size with adjusted dimensions contentWidth := max(0, width-horizontalSpace) contentHeight := max(0, height-verticalSpace) - sizeable.SetSize(contentWidth, contentHeight) + return sizeable.SetSize(contentWidth, contentHeight) } + return nil } func (c *container) GetSize() (int, int) { diff --git a/internal/tui/layout/grid.go b/internal/tui/layout/grid.go deleted file mode 100644 index 6be493e2c..000000000 --- a/internal/tui/layout/grid.go +++ /dev/null @@ -1,254 +0,0 @@ -package layout - -import ( - "github.com/charmbracelet/bubbles/key" - tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/lipgloss" -) - -type GridLayout interface { - tea.Model - Sizeable - Bindings - Panes() [][]tea.Model -} - -type gridLayout struct { - width int - height int - - rows int - columns int - - panes [][]tea.Model - - gap int - bordered bool - focusable bool - - currentRow int - currentColumn int - - activeColor lipgloss.TerminalColor -} - -type GridOption func(*gridLayout) - -func (g *gridLayout) Init() tea.Cmd { - var cmds []tea.Cmd - for i := range g.panes { - for j := range g.panes[i] { - if g.panes[i][j] != nil { - cmds = append(cmds, g.panes[i][j].Init()) - } - } - } - return tea.Batch(cmds...) -} - -func (g *gridLayout) Update(msg tea.Msg) (tea.Model, tea.Cmd) { - var cmds []tea.Cmd - - switch msg := msg.(type) { - case tea.WindowSizeMsg: - g.SetSize(msg.Width, msg.Height) - return g, nil - case tea.KeyMsg: - if key.Matches(msg, g.nextPaneBinding()) { - return g.focusNextPane() - } - } - - // Update all panes - for i := range g.panes { - for j := range g.panes[i] { - if g.panes[i][j] != nil { - var cmd tea.Cmd - g.panes[i][j], cmd = g.panes[i][j].Update(msg) - if cmd != nil { - cmds = append(cmds, cmd) - } - } - } - } - - return g, tea.Batch(cmds...) -} - -func (g *gridLayout) focusNextPane() (tea.Model, tea.Cmd) { - if !g.focusable { - return g, nil - } - - var cmds []tea.Cmd - - // Blur current pane - if g.currentRow < len(g.panes) && g.currentColumn < len(g.panes[g.currentRow]) { - if currentPane, ok := g.panes[g.currentRow][g.currentColumn].(Focusable); ok { - cmds = append(cmds, currentPane.Blur()) - } - } - - // Find next valid pane - g.currentColumn++ - if g.currentColumn >= len(g.panes[g.currentRow]) { - g.currentColumn = 0 - g.currentRow++ - if g.currentRow >= len(g.panes) { - g.currentRow = 0 - } - } - - // Focus next pane - if g.currentRow < len(g.panes) && g.currentColumn < len(g.panes[g.currentRow]) { - if nextPane, ok := g.panes[g.currentRow][g.currentColumn].(Focusable); ok { - cmds = append(cmds, nextPane.Focus()) - } - } - - return g, tea.Batch(cmds...) -} - -func (g *gridLayout) nextPaneBinding() key.Binding { - return key.NewBinding( - key.WithKeys("tab"), - key.WithHelp("tab", "next pane"), - ) -} - -func (g *gridLayout) View() string { - if len(g.panes) == 0 { - return "" - } - - // Calculate dimensions for each cell - cellWidth := (g.width - (g.columns-1)*g.gap) / g.columns - cellHeight := (g.height - (g.rows-1)*g.gap) / g.rows - - // Render each row - rows := make([]string, g.rows) - for i := range g.rows { - // Render each column in this row - cols := make([]string, len(g.panes[i])) - for j := range g.panes[i] { - if g.panes[i][j] == nil { - cols[j] = "" - continue - } - - // Set size for each pane - if sizable, ok := g.panes[i][j].(Sizeable); ok { - effectiveWidth, effectiveHeight := cellWidth, cellHeight - if g.bordered { - effectiveWidth -= 2 - effectiveHeight -= 2 - } - sizable.SetSize(effectiveWidth, effectiveHeight) - } - - // Render the pane - content := g.panes[i][j].View() - - // Apply border if needed - if g.bordered { - isFocused := false - if focusable, ok := g.panes[i][j].(Focusable); ok { - isFocused = focusable.IsFocused() - } - - borderText := map[BorderPosition]string{} - if bordered, ok := g.panes[i][j].(Bordered); ok { - borderText = bordered.BorderText() - } - - content = Borderize(content, BorderOptions{ - Active: isFocused, - EmbeddedText: borderText, - }) - } - - cols[j] = content - } - - // Join columns with gap - rows[i] = lipgloss.JoinHorizontal(lipgloss.Top, cols...) - } - - // Join rows with gap - return lipgloss.JoinVertical(lipgloss.Left, rows...) -} - -func (g *gridLayout) SetSize(width, height int) { - g.width = width - g.height = height -} - -func (g *gridLayout) GetSize() (int, int) { - return g.width, g.height -} - -func (g *gridLayout) BindingKeys() []key.Binding { - var bindings []key.Binding - bindings = append(bindings, g.nextPaneBinding()) - - // Collect bindings from all panes - for i := range g.panes { - for j := range g.panes[i] { - if g.panes[i][j] != nil { - if bindable, ok := g.panes[i][j].(Bindings); ok { - bindings = append(bindings, bindable.BindingKeys()...) - } - } - } - } - - return bindings -} - -func (g *gridLayout) Panes() [][]tea.Model { - return g.panes -} - -// NewGridLayout creates a new grid layout with the given number of rows and columns -func NewGridLayout(rows, cols int, panes [][]tea.Model, opts ...GridOption) GridLayout { - grid := &gridLayout{ - rows: rows, - columns: cols, - panes: panes, - gap: 1, - } - - for _, opt := range opts { - opt(grid) - } - - return grid -} - -// WithGridGap sets the gap between cells -func WithGridGap(gap int) GridOption { - return func(g *gridLayout) { - g.gap = gap - } -} - -// WithGridBordered sets whether cells should have borders -func WithGridBordered(bordered bool) GridOption { - return func(g *gridLayout) { - g.bordered = bordered - } -} - -// WithGridFocusable sets whether the grid supports focus navigation -func WithGridFocusable(focusable bool) GridOption { - return func(g *gridLayout) { - g.focusable = focusable - } -} - -// WithGridActiveColor sets the active border color -func WithGridActiveColor(color lipgloss.TerminalColor) GridOption { - return func(g *gridLayout) { - g.activeColor = color - } -} diff --git a/internal/tui/layout/layout.go b/internal/tui/layout/layout.go index 2f17c4a0e..495a3fbc5 100644 --- a/internal/tui/layout/layout.go +++ b/internal/tui/layout/layout.go @@ -13,12 +13,8 @@ type Focusable interface { IsFocused() bool } -type Bordered interface { - BorderText() map[BorderPosition]string -} - type Sizeable interface { - SetSize(width, height int) + SetSize(width, height int) tea.Cmd GetSize() (int, int) } diff --git a/internal/tui/layout/single.go b/internal/tui/layout/single.go deleted file mode 100644 index c77fa0d78..000000000 --- a/internal/tui/layout/single.go +++ /dev/null @@ -1,189 +0,0 @@ -package layout - -import ( - "github.com/charmbracelet/bubbles/key" - tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/lipgloss" -) - -type SinglePaneLayout interface { - tea.Model - Focusable - Sizeable - Bindings - Pane() tea.Model -} - -type singlePaneLayout struct { - width int - height int - - focusable bool - focused bool - - bordered bool - borderText map[BorderPosition]string - - content tea.Model - - padding []int - - activeColor lipgloss.TerminalColor -} - -type SinglePaneOption func(*singlePaneLayout) - -func (s *singlePaneLayout) Init() tea.Cmd { - return s.content.Init() -} - -func (s *singlePaneLayout) Update(msg tea.Msg) (tea.Model, tea.Cmd) { - switch msg := msg.(type) { - case tea.WindowSizeMsg: - s.SetSize(msg.Width, msg.Height) - return s, nil - } - u, cmd := s.content.Update(msg) - s.content = u - return s, cmd -} - -func (s *singlePaneLayout) View() string { - style := lipgloss.NewStyle().Width(s.width).Height(s.height) - if s.bordered { - style = style.Width(s.width - 2).Height(s.height - 2) - } - if s.padding != nil { - style = style.Padding(s.padding...) - } - content := style.Render(s.content.View()) - if s.bordered { - if s.borderText == nil { - s.borderText = map[BorderPosition]string{} - } - if bordered, ok := s.content.(Bordered); ok { - s.borderText = bordered.BorderText() - } - return Borderize(content, BorderOptions{ - Active: s.focused, - EmbeddedText: s.borderText, - }) - } - return content -} - -func (s *singlePaneLayout) Blur() tea.Cmd { - if s.focusable { - s.focused = false - } - if blurable, ok := s.content.(Focusable); ok { - return blurable.Blur() - } - return nil -} - -func (s *singlePaneLayout) Focus() tea.Cmd { - if s.focusable { - s.focused = true - } - if focusable, ok := s.content.(Focusable); ok { - return focusable.Focus() - } - return nil -} - -func (s *singlePaneLayout) SetSize(width, height int) { - s.width = width - s.height = height - childWidth, childHeight := s.width, s.height - if s.bordered { - childWidth -= 2 - childHeight -= 2 - } - if s.padding != nil { - if len(s.padding) == 1 { - childWidth -= s.padding[0] * 2 - childHeight -= s.padding[0] * 2 - } else if len(s.padding) == 2 { - childWidth -= s.padding[0] * 2 - childHeight -= s.padding[1] * 2 - } else if len(s.padding) == 3 { - childWidth -= s.padding[0] * 2 - childHeight -= s.padding[1] + s.padding[2] - } else if len(s.padding) == 4 { - childWidth -= s.padding[0] + s.padding[2] - childHeight -= s.padding[1] + s.padding[3] - } - } - if s.content != nil { - if c, ok := s.content.(Sizeable); ok { - c.SetSize(childWidth, childHeight) - } - } -} - -func (s *singlePaneLayout) IsFocused() bool { - return s.focused -} - -func (s *singlePaneLayout) GetSize() (int, int) { - return s.width, s.height -} - -func (s *singlePaneLayout) BindingKeys() []key.Binding { - if b, ok := s.content.(Bindings); ok { - return b.BindingKeys() - } - return []key.Binding{} -} - -func (s *singlePaneLayout) Pane() tea.Model { - return s.content -} - -func NewSinglePane(content tea.Model, opts ...SinglePaneOption) SinglePaneLayout { - layout := &singlePaneLayout{ - content: content, - } - for _, opt := range opts { - opt(layout) - } - return layout -} - -func WithSinglePaneSize(width, height int) SinglePaneOption { - return func(opts *singlePaneLayout) { - opts.width = width - opts.height = height - } -} - -func WithSinglePaneFocusable(focusable bool) SinglePaneOption { - return func(opts *singlePaneLayout) { - opts.focusable = focusable - } -} - -func WithSinglePaneBordered(bordered bool) SinglePaneOption { - return func(opts *singlePaneLayout) { - opts.bordered = bordered - } -} - -func WithSinglePaneBorderText(borderText map[BorderPosition]string) SinglePaneOption { - return func(opts *singlePaneLayout) { - opts.borderText = borderText - } -} - -func WithSinglePanePadding(padding ...int) SinglePaneOption { - return func(opts *singlePaneLayout) { - opts.padding = padding - } -} - -func WithSinglePaneActiveColor(color lipgloss.TerminalColor) SinglePaneOption { - return func(opts *singlePaneLayout) { - opts.activeColor = color - } -} diff --git a/internal/tui/layout/split.go b/internal/tui/layout/split.go index bfb616a53..a41df6ab8 100644 --- a/internal/tui/layout/split.go +++ b/internal/tui/layout/split.go @@ -11,9 +11,9 @@ type SplitPaneLayout interface { tea.Model Sizeable Bindings - SetLeftPanel(panel Container) - SetRightPanel(panel Container) - SetBottomPanel(panel Container) + SetLeftPanel(panel Container) tea.Cmd + SetRightPanel(panel Container) tea.Cmd + SetBottomPanel(panel Container) tea.Cmd } type splitPaneLayout struct { @@ -53,8 +53,7 @@ func (s *splitPaneLayout) Update(msg tea.Msg) (tea.Model, tea.Cmd) { var cmds []tea.Cmd switch msg := msg.(type) { case tea.WindowSizeMsg: - s.SetSize(msg.Width, msg.Height) - return s, nil + return s, s.SetSize(msg.Width, msg.Height) } if s.rightPanel != nil { @@ -122,7 +121,7 @@ func (s *splitPaneLayout) View() string { return finalView } -func (s *splitPaneLayout) SetSize(width, height int) { +func (s *splitPaneLayout) SetSize(width, height int) tea.Cmd { s.width = width s.height = height @@ -147,42 +146,50 @@ func (s *splitPaneLayout) SetSize(width, height int) { rightWidth = width } + var cmds []tea.Cmd if s.leftPanel != nil { - s.leftPanel.SetSize(leftWidth, topHeight) + cmd := s.leftPanel.SetSize(leftWidth, topHeight) + cmds = append(cmds, cmd) } if s.rightPanel != nil { - s.rightPanel.SetSize(rightWidth, topHeight) + cmd := s.rightPanel.SetSize(rightWidth, topHeight) + cmds = append(cmds, cmd) } if s.bottomPanel != nil { - s.bottomPanel.SetSize(width, bottomHeight) + cmd := s.bottomPanel.SetSize(width, bottomHeight) + cmds = append(cmds, cmd) } + return tea.Batch(cmds...) } func (s *splitPaneLayout) GetSize() (int, int) { return s.width, s.height } -func (s *splitPaneLayout) SetLeftPanel(panel Container) { +func (s *splitPaneLayout) SetLeftPanel(panel Container) tea.Cmd { s.leftPanel = panel if s.width > 0 && s.height > 0 { - s.SetSize(s.width, s.height) + return s.SetSize(s.width, s.height) } + return nil } -func (s *splitPaneLayout) SetRightPanel(panel Container) { +func (s *splitPaneLayout) SetRightPanel(panel Container) tea.Cmd { s.rightPanel = panel if s.width > 0 && s.height > 0 { - s.SetSize(s.width, s.height) + return s.SetSize(s.width, s.height) } + return nil } -func (s *splitPaneLayout) SetBottomPanel(panel Container) { +func (s *splitPaneLayout) SetBottomPanel(panel Container) tea.Cmd { s.bottomPanel = panel if s.width > 0 && s.height > 0 { - s.SetSize(s.width, s.height) + return s.SetSize(s.width, s.height) } + return nil } func (s *splitPaneLayout) BindingKeys() []key.Binding { diff --git a/internal/tui/page/chat.go b/internal/tui/page/chat.go index 632e10764..b99dc3dfe 100644 --- a/internal/tui/page/chat.go +++ b/internal/tui/page/chat.go @@ -54,9 +54,11 @@ func (p *chatPage) Init() tea.Cmd { } func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + var cmds []tea.Cmd switch msg := msg.(type) { case tea.WindowSizeMsg: - p.layout.SetSize(msg.Width, msg.Height) + cmd := p.layout.SetSize(msg.Width, msg.Height) + cmds = append(cmds, cmd) case chat.SendMsg: cmd := p.sendMessage(msg.Text) if cmd != nil { @@ -68,8 +70,10 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { switch { case key.Matches(msg, keyMap.NewSession): p.session = session.Session{} - p.clearSidebar() - return p, util.CmdHandler(chat.SessionClearedMsg{}) + return p, tea.Batch( + p.clearSidebar(), + util.CmdHandler(chat.SessionClearedMsg{}), + ) case key.Matches(msg, keyMap.Cancel): if p.session.ID != "" { // Cancel the current session's generation process @@ -80,11 +84,9 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } } u, cmd := p.layout.Update(msg) + cmds = append(cmds, cmd) p.layout = u.(layout.SplitPaneLayout) - if cmd != nil { - return p, cmd - } - return p, nil + return p, tea.Batch(cmds...) } func (p *chatPage) setSidebar() tea.Cmd { @@ -92,16 +94,11 @@ func (p *chatPage) setSidebar() tea.Cmd { chat.NewSidebarCmp(p.session, p.app.History), layout.WithPadding(1, 1, 1, 1), ) - p.layout.SetRightPanel(sidebarContainer) - width, height := p.layout.GetSize() - p.layout.SetSize(width, height) - return sidebarContainer.Init() + return tea.Batch(p.layout.SetRightPanel(sidebarContainer), sidebarContainer.Init()) } -func (p *chatPage) clearSidebar() { - p.layout.SetRightPanel(nil) - width, height := p.layout.GetSize() - p.layout.SetSize(width, height) +func (p *chatPage) clearSidebar() tea.Cmd { + return p.layout.SetRightPanel(nil) } func (p *chatPage) sendMessage(text string) tea.Cmd { @@ -124,8 +121,8 @@ func (p *chatPage) sendMessage(text string) tea.Cmd { return tea.Batch(cmds...) } -func (p *chatPage) SetSize(width, height int) { - p.layout.SetSize(width, height) +func (p *chatPage) SetSize(width, height int) tea.Cmd { + return p.layout.SetSize(width, height) } func (p *chatPage) GetSize() (int, int) { diff --git a/internal/tui/page/logs.go b/internal/tui/page/logs.go index 0efc69e6e..f0d35fb7b 100644 --- a/internal/tui/page/logs.go +++ b/internal/tui/page/logs.go @@ -23,15 +23,14 @@ type logsPage struct { } func (p *logsPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + var cmds []tea.Cmd switch msg := msg.(type) { case tea.WindowSizeMsg: p.width = msg.Width p.height = msg.Height - p.table.SetSize(msg.Width, msg.Height/2) - p.details.SetSize(msg.Width, msg.Height/2) + return p, p.SetSize(msg.Width, msg.Height) } - var cmds []tea.Cmd table, cmd := p.table.Update(msg) cmds = append(cmds, cmd) p.table = table.(layout.Container) @@ -60,11 +59,13 @@ func (p *logsPage) GetSize() (int, int) { } // SetSize implements LogPage. -func (p *logsPage) SetSize(width int, height int) { +func (p *logsPage) SetSize(width int, height int) tea.Cmd { p.width = width p.height = height - p.table.SetSize(width, height/2) - p.details.SetSize(width, height/2) + return tea.Batch( + p.table.SetSize(width, height/2), + p.details.SetSize(width, height/2), + ) } func (p *logsPage) Init() tea.Cmd { diff --git a/internal/tui/styles/background.go b/internal/tui/styles/background.go index bf6cbc105..2fbb34efb 100644 --- a/internal/tui/styles/background.go +++ b/internal/tui/styles/background.go @@ -3,7 +3,6 @@ package styles import ( "fmt" "regexp" - "strconv" "strings" "github.com/charmbracelet/lipgloss" @@ -25,57 +24,100 @@ func getColorRGB(c lipgloss.TerminalColor) (uint8, uint8, uint8) { return uint8(r >> 8), uint8(g >> 8), uint8(b >> 8) } +// ForceReplaceBackgroundWithLipgloss replaces any ANSI background color codes +// in `input` with a single 24‑bit background (48;2;R;G;B). func ForceReplaceBackgroundWithLipgloss(input string, newBgColor lipgloss.TerminalColor) string { + // Precompute our new-bg sequence once r, g, b := getColorRGB(newBgColor) - newBg := fmt.Sprintf("48;2;%d;%d;%d", r, g, b) return ansiEscape.ReplaceAllStringFunc(input, func(seq string) string { - // Extract content between "\x1b[" and "m" - content := seq[2 : len(seq)-1] - tokens := strings.Split(content, ";") - var newTokens []string - - // Skip background color tokens - for i := 0; i < len(tokens); i++ { - if tokens[i] == "" { - continue - } + const ( + escPrefixLen = 2 // "\x1b[" + escSuffixLen = 1 // "m" + ) + + raw := seq + start := escPrefixLen + end := len(raw) - escSuffixLen - val, err := strconv.Atoi(tokens[i]) - if err != nil { - newTokens = append(newTokens, tokens[i]) - continue + var sb strings.Builder + // reserve enough space: original content minus bg codes + our newBg + sb.Grow((end - start) + len(newBg) + 2) + + // scan from start..end, token by token + for i := start; i < end; { + // find the next ';' or end + j := i + for j < end && raw[j] != ';' { + j++ } + token := raw[i:j] - // Skip background color tokens - if val == 48 { - // Skip "48;5;N" or "48;2;R;G;B" sequences - if i+1 < len(tokens) { - if nextVal, err := strconv.Atoi(tokens[i+1]); err == nil { - switch nextVal { - case 5: - i += 2 // Skip "5" and color index - case 2: - i += 4 // Skip "2" and RGB components + // fast‑path: skip "48;5;N" or "48;2;R;G;B" + if len(token) == 2 && token[0] == '4' && token[1] == '8' { + k := j + 1 + if k < end { + // find next token + l := k + for l < end && raw[l] != ';' { + l++ + } + next := raw[k:l] + if next == "5" { + // skip "48;5;N" + m := l + 1 + for m < end && raw[m] != ';' { + m++ + } + i = m + 1 + continue + } else if next == "2" { + // skip "48;2;R;G;B" + m := l + 1 + for count := 0; count < 3 && m < end; count++ { + for m < end && raw[m] != ';' { + m++ + } + m++ } + i = m + continue } } - } else if (val < 40 || val > 47) && (val < 100 || val > 107) && val != 49 { - // Keep non-background tokens - newTokens = append(newTokens, tokens[i]) } - } - // Add new background if provided - if newBg != "" { - newTokens = append(newTokens, strings.Split(newBg, ";")...) + // decide whether to keep this token + // manually parse ASCII digits to int + isNum := true + val := 0 + for p := i; p < j; p++ { + c := raw[p] + if c < '0' || c > '9' { + isNum = false + break + } + val = val*10 + int(c-'0') + } + keep := !isNum || + ((val < 40 || val > 47) && (val < 100 || val > 107) && val != 49) + + if keep { + if sb.Len() > 0 { + sb.WriteByte(';') + } + sb.WriteString(token) + } + // advance past this token (and the semicolon) + i = j + 1 } - if len(newTokens) == 0 { - return "" + // append our new background + if sb.Len() > 0 { + sb.WriteByte(';') } + sb.WriteString(newBg) - return "\x1b[" + strings.Join(newTokens, ";") + "m" + return "\x1b[" + sb.String() + "m" }) } diff --git a/internal/tui/styles/icons.go b/internal/tui/styles/icons.go index aa0df1e31..dd5f4dc51 100644 --- a/internal/tui/styles/icons.go +++ b/internal/tui/styles/icons.go @@ -2,19 +2,11 @@ package styles const ( OpenCodeIcon string = "⌬" - SessionsIcon string = "󰧑" - ChatIcon string = "󰭹" - - BotIcon string = "󰚩" - ToolIcon string = "" - UserIcon string = "" CheckIcon string = "✓" - ErrorIcon string = "" - WarningIcon string = "" + ErrorIcon string = "✖" + WarningIcon string = "⚠" InfoIcon string = "" - HintIcon string = "" + HintIcon string = "i" SpinnerIcon string = "..." - BugIcon string = "" - SleepIcon string = "󰒲" ) diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 840ad4905..f3a7298cf 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -1,6 +1,8 @@ package tui import ( + "context" + "github.com/charmbracelet/bubbles/key" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" @@ -8,6 +10,7 @@ import ( "github.com/kujtimiihoxha/opencode/internal/logging" "github.com/kujtimiihoxha/opencode/internal/permission" "github.com/kujtimiihoxha/opencode/internal/pubsub" + "github.com/kujtimiihoxha/opencode/internal/tui/components/chat" "github.com/kujtimiihoxha/opencode/internal/tui/components/core" "github.com/kujtimiihoxha/opencode/internal/tui/components/dialog" "github.com/kujtimiihoxha/opencode/internal/tui/layout" @@ -16,9 +19,10 @@ import ( ) type keyMap struct { - Logs key.Binding - Quit key.Binding - Help key.Binding + Logs key.Binding + Quit key.Binding + Help key.Binding + SwitchSession key.Binding } var keys = keyMap{ @@ -35,6 +39,10 @@ var keys = keyMap{ key.WithKeys("ctrl+_"), key.WithHelp("ctrl+?", "toggle help"), ), + SwitchSession: key.NewBinding( + key.WithKeys("ctrl+a"), + key.WithHelp("ctrl+a", "switch session"), + ), } var returnKey = key.NewBinding( @@ -64,6 +72,9 @@ type appModel struct { showQuit bool quit dialog.QuitDialog + + showSessionDialog bool + sessionDialog dialog.SessionDialog } func (a appModel) Init() tea.Cmd { @@ -77,6 +88,8 @@ func (a appModel) Init() tea.Cmd { cmds = append(cmds, cmd) cmd = a.help.Init() cmds = append(cmds, cmd) + cmd = a.sessionDialog.Init() + cmds = append(cmds, cmd) return tea.Batch(cmds...) } @@ -100,6 +113,10 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { a.help = help.(dialog.HelpCmp) cmds = append(cmds, helpCmd) + session, sessionCmd := a.sessionDialog.Update(msg) + a.sessionDialog = session.(dialog.SessionDialog) + cmds = append(cmds, sessionCmd) + return a, tea.Batch(cmds...) // Status @@ -144,8 +161,7 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { // Permission case pubsub.Event[permission.PermissionRequest]: a.showPermissions = true - a.permissions.SetPermissions(msg.Payload) - return a, nil + return a, a.permissions.SetPermissions(msg.Payload) case dialog.PermissionResponseMsg: switch msg.Action { case dialog.PermissionAllow: @@ -165,6 +181,19 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { a.showQuit = false return a, nil + case dialog.CloseSessionDialogMsg: + a.showSessionDialog = false + return a, nil + + case chat.SessionSelectedMsg: + a.sessionDialog.SetSelectedSession(msg.ID) + case dialog.SessionSelectedMsg: + a.showSessionDialog = false + if a.currentPage == page.ChatPage { + return a, util.CmdHandler(chat.SessionSelectedMsg(msg.Session)) + } + return a, nil + case tea.KeyMsg: switch { case key.Matches(msg, keys.Quit): @@ -172,6 +201,24 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if a.showHelp { a.showHelp = false } + if a.showSessionDialog { + a.showSessionDialog = false + } + return a, nil + case key.Matches(msg, keys.SwitchSession): + if a.currentPage == page.ChatPage && !a.showQuit && !a.showPermissions { + // Load sessions and show the dialog + sessions, err := a.app.Sessions.List(context.Background()) + if err != nil { + return a, util.ReportError(err) + } + if len(sessions) == 0 { + return a, util.ReportWarn("No sessions available") + } + a.sessionDialog.SetSessions(sessions) + a.showSessionDialog = true + return a, nil + } return a, nil case key.Matches(msg, logsKeyReturnKey): if a.currentPage == page.LogsPage { @@ -216,6 +263,16 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } } + if a.showSessionDialog { + d, sessionCmd := a.sessionDialog.Update(msg) + a.sessionDialog = d.(dialog.SessionDialog) + cmds = append(cmds, sessionCmd) + // Only block key messages send all other messages down + if _, ok := msg.(tea.KeyMsg); ok { + return a, tea.Batch(cmds...) + } + } + a.status, _ = a.status.Update(msg) a.pages[a.currentPage], cmd = a.pages[a.currentPage].Update(msg) cmds = append(cmds, cmd) @@ -223,18 +280,24 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } func (a *appModel) moveToPage(pageID page.PageID) tea.Cmd { - var cmd tea.Cmd + if a.app.CoderAgent.IsBusy() { + // For now we don't move to any page if the agent is busy + return util.ReportWarn("Agent is busy, please wait...") + } + var cmds []tea.Cmd if _, ok := a.loadedPages[pageID]; !ok { - cmd = a.pages[pageID].Init() + cmd := a.pages[pageID].Init() + cmds = append(cmds, cmd) a.loadedPages[pageID] = true } a.previousPage = a.currentPage a.currentPage = pageID if sizable, ok := a.pages[a.currentPage].(layout.Sizeable); ok { - sizable.SetSize(a.width, a.height) + cmd := sizable.SetSize(a.width, a.height) + cmds = append(cmds, cmd) } - return cmd + return tea.Batch(cmds...) } func (a appModel) View() string { @@ -304,19 +367,35 @@ func (a appModel) View() string { ) } + if a.showSessionDialog { + overlay := a.sessionDialog.View() + row := lipgloss.Height(appView) / 2 + row -= lipgloss.Height(overlay) / 2 + col := lipgloss.Width(appView) / 2 + col -= lipgloss.Width(overlay) / 2 + appView = layout.PlaceOverlay( + col, + row, + overlay, + appView, + true, + ) + } + return appView } func New(app *app.App) tea.Model { startPage := page.ChatPage return &appModel{ - currentPage: startPage, - loadedPages: make(map[page.PageID]bool), - status: core.NewStatusCmp(app.LSPClients), - help: dialog.NewHelpCmp(), - quit: dialog.NewQuitCmp(), - permissions: dialog.NewPermissionDialogCmp(), - app: app, + currentPage: startPage, + loadedPages: make(map[page.PageID]bool), + status: core.NewStatusCmp(app.LSPClients), + help: dialog.NewHelpCmp(), + quit: dialog.NewQuitCmp(), + sessionDialog: dialog.NewSessionDialogCmp(), + permissions: dialog.NewPermissionDialogCmp(), + app: app, pages: map[page.PageID]tea.Model{ page.ChatPage: page.NewChatPage(app), page.LogsPage: page.NewLogsPage(), -- cgit v1.2.3 From 72afeb9f54cee8e248093a52ac0779441c79aea3 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 18 Apr 2025 21:24:35 +0200 Subject: small fixes --- internal/diff/patch.go | 33 ++++++++++--------- internal/llm/agent/agent.go | 2 ++ internal/llm/tools/edit.go | 6 ++-- internal/tui/components/chat/editor.go | 9 +++-- internal/tui/components/chat/list.go | 49 ++++++++++++++++++++-------- internal/tui/components/chat/message.go | 1 + internal/tui/components/dialog/permission.go | 22 +++++++------ internal/tui/components/dialog/session.go | 48 ++++++++++++++------------- internal/tui/page/chat.go | 7 ---- internal/tui/tui.go | 4 ++- 10 files changed, 106 insertions(+), 75 deletions(-) (limited to 'internal/diff') diff --git a/internal/diff/patch.go b/internal/diff/patch.go index aab0f956d..49242f7ef 100644 --- a/internal/diff/patch.go +++ b/internal/diff/patch.go @@ -91,11 +91,9 @@ func (p *Parser) isDone(prefixes []string) bool { if p.index >= len(p.lines) { return true } - if prefixes != nil { - for _, prefix := range prefixes { - if strings.HasPrefix(p.lines[p.index], prefix) { - return true - } + for _, prefix := range prefixes { + if strings.HasPrefix(p.lines[p.index], prefix) { + return true } } return false @@ -219,7 +217,7 @@ func (p *Parser) parseUpdateFile(text string) (PatchAction, error) { sectionStr = p.lines[p.index] p.index++ } - if !(defStr != "" || sectionStr != "" || index == 0) { + if defStr == "" && sectionStr == "" && index != 0 { return action, NewDiffError(fmt.Sprintf("Invalid Line:\n%s", p.lines[p.index])) } if strings.TrimSpace(defStr) != "" { @@ -433,12 +431,13 @@ func peekNextSection(lines []string, initialIndex int) ([]string, []Chunk, int, delLines = make([]string, 0, 8) insLines = make([]string, 0, 8) } - if mode == "delete" { + switch mode { + case "delete": delLines = append(delLines, line) old = append(old, line) - } else if mode == "add" { + case "add": insLines = append(insLines, line) - } else { + default: old = append(old, line) } } @@ -513,7 +512,7 @@ func IdentifyFilesAdded(text string) []string { func getUpdatedFile(text string, action PatchAction, path string) (string, error) { if action.Type != ActionUpdate { - return "", errors.New("Expected UPDATE action") + return "", errors.New("expected UPDATE action") } origLines := strings.Split(text, "\n") destLines := make([]string, 0, len(origLines)) // Preallocate with capacity @@ -543,18 +542,19 @@ func getUpdatedFile(text string, action PatchAction, path string) (string, error func PatchToCommit(patch Patch, orig map[string]string) (Commit, error) { commit := Commit{Changes: make(map[string]FileChange, len(patch.Actions))} for pathKey, action := range patch.Actions { - if action.Type == ActionDelete { + switch action.Type { + case ActionDelete: oldContent := orig[pathKey] commit.Changes[pathKey] = FileChange{ Type: ActionDelete, OldContent: &oldContent, } - } else if action.Type == ActionAdd { + case ActionAdd: commit.Changes[pathKey] = FileChange{ Type: ActionAdd, NewContent: action.NewFile, } - } else if action.Type == ActionUpdate { + case ActionUpdate: newContent, err := getUpdatedFile(orig[pathKey], action, pathKey) if err != nil { return Commit{}, err @@ -619,18 +619,19 @@ func LoadFiles(paths []string, openFn func(string) (string, error)) (map[string] func ApplyCommit(commit Commit, writeFn func(string, string) error, removeFn func(string) error) error { for p, change := range commit.Changes { - if change.Type == ActionDelete { + switch change.Type { + case ActionDelete: if err := removeFn(p); err != nil { return err } - } else if change.Type == ActionAdd { + case ActionAdd: if change.NewContent == nil { return NewDiffError(fmt.Sprintf("Add action for %s has nil new_content", p)) } if err := writeFn(p, *change.NewContent); err != nil { return err } - } else if change.Type == ActionUpdate { + case ActionUpdate: if change.NewContent == nil { return NewDiffError(fmt.Sprintf("Update action for %s has nil new_content", p)) } diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 5e9785991..7542d9adf 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -221,6 +221,8 @@ func (a *agent) processGeneration(ctx context.Context, sessionID, content string agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory) if err != nil { if errors.Is(err, context.Canceled) { + agentMessage.AddFinish(message.FinishReasonCanceled) + a.messages.Update(context.Background(), agentMessage) return a.err(ErrRequestCancelled) } return a.err(fmt.Errorf("failed to process events: %w", err)) diff --git a/internal/llm/tools/edit.go b/internal/llm/tools/edit.go index 6a1616010..83cec5dba 100644 --- a/internal/llm/tools/edit.go +++ b/internal/llm/tools/edit.go @@ -141,20 +141,20 @@ func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) if params.OldString == "" { response, err = e.createNewFile(ctx, params.FilePath, params.NewString) if err != nil { - return response, nil + return response, err } } if params.NewString == "" { response, err = e.deleteContent(ctx, params.FilePath, params.OldString) if err != nil { - return response, nil + return response, err } } response, err = e.replaceContent(ctx, params.FilePath, params.OldString, params.NewString) if err != nil { - return response, nil + return response, err } if response.IsError { // Return early if there was an error during content replacement diff --git a/internal/tui/components/chat/editor.go b/internal/tui/components/chat/editor.go index 537ef392c..963fbbdbf 100644 --- a/internal/tui/components/chat/editor.go +++ b/internal/tui/components/chat/editor.go @@ -21,6 +21,8 @@ type editorCmp struct { textarea textarea.Model } +type FocusEditorMsg bool + type focusedEditorKeyMaps struct { Send key.Binding OpenEditor key.Binding @@ -112,7 +114,6 @@ func (m *editorCmp) send() tea.Cmd { util.CmdHandler(SendMsg{ Text: value, }), - util.CmdHandler(EditorFocusMsg(false)), ) } @@ -124,9 +125,13 @@ func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.session = msg } return m, nil + case FocusEditorMsg: + if msg { + m.textarea.Focus() + return m, tea.Batch(textarea.Blink, util.CmdHandler(EditorFocusMsg(true))) + } case tea.KeyMsg: if key.Matches(msg, focusedKeyMaps.OpenEditor) { - m.textarea.Blur() return m, openEditor() } // if the key does not match any binding, return diff --git a/internal/tui/components/chat/list.go b/internal/tui/components/chat/list.go index f95b53731..b7703e2cc 100644 --- a/internal/tui/components/chat/list.go +++ b/internal/tui/components/chat/list.go @@ -22,6 +22,10 @@ import ( "github.com/kujtimiihoxha/opencode/internal/tui/util" ) +type cacheItem struct { + width int + content []uiMessage +} type messagesCmp struct { app *app.App width, height int @@ -32,8 +36,9 @@ type messagesCmp struct { uiMessages []uiMessage currentMsgID string mutex sync.Mutex - cachedContent map[string][]uiMessage + cachedContent map[string]cacheItem spinner spinner.Model + lastUpdate time.Time rendering bool } type renderFinishedMsg struct{} @@ -44,6 +49,8 @@ func (m *messagesCmp) Init() tea.Cmd { func (m *messagesCmp) preloadSessions() tea.Cmd { return func() tea.Msg { + m.mutex.Lock() + defer m.mutex.Unlock() sessions, err := m.app.Sessions.List(context.Background()) if err != nil { return util.ReportError(err)() @@ -67,13 +74,13 @@ func (m *messagesCmp) preloadSessions() tea.Cmd { } logging.Debug("preloaded sessions") - return nil + return func() tea.Msg { + return renderFinishedMsg{} + } } } func (m *messagesCmp) cacheSessionMessages(messages []message.Message, width int) { - m.mutex.Lock() - defer m.mutex.Unlock() pos := 0 if m.width == 0 { return @@ -87,7 +94,10 @@ func (m *messagesCmp) cacheSessionMessages(messages []message.Message, width int width, pos, ) - m.cachedContent[msg.ID] = []uiMessage{userMsg} + m.cachedContent[msg.ID] = cacheItem{ + width: width, + content: []uiMessage{userMsg}, + } pos += userMsg.height + 1 // + 1 for spacing case message.Assistant: assistantMessages := renderAssistantMessage( @@ -102,7 +112,10 @@ func (m *messagesCmp) cacheSessionMessages(messages []message.Message, width int for _, msg := range assistantMessages { pos += msg.height + 1 // + 1 for spacing } - m.cachedContent[msg.ID] = assistantMessages + m.cachedContent[msg.ID] = cacheItem{ + width: width, + content: assistantMessages, + } } } } @@ -223,8 +236,8 @@ func (m *messagesCmp) renderView() { for inx, msg := range m.messages { switch msg.Role { case message.User: - if messages, ok := m.cachedContent[msg.ID]; ok { - m.uiMessages = append(m.uiMessages, messages...) + if cache, ok := m.cachedContent[msg.ID]; ok && cache.width == m.width { + m.uiMessages = append(m.uiMessages, cache.content...) continue } userMsg := renderUserMessage( @@ -234,11 +247,14 @@ func (m *messagesCmp) renderView() { pos, ) m.uiMessages = append(m.uiMessages, userMsg) - m.cachedContent[msg.ID] = []uiMessage{userMsg} + m.cachedContent[msg.ID] = cacheItem{ + width: m.width, + content: []uiMessage{userMsg}, + } pos += userMsg.height + 1 // + 1 for spacing case message.Assistant: - if messages, ok := m.cachedContent[msg.ID]; ok { - m.uiMessages = append(m.uiMessages, messages...) + if cache, ok := m.cachedContent[msg.ID]; ok && cache.width == m.width { + m.uiMessages = append(m.uiMessages, cache.content...) continue } assistantMessages := renderAssistantMessage( @@ -254,7 +270,10 @@ func (m *messagesCmp) renderView() { m.uiMessages = append(m.uiMessages, msg) pos += msg.height + 1 // + 1 for spacing } - m.cachedContent[msg.ID] = assistantMessages + m.cachedContent[msg.ID] = cacheItem{ + width: m.width, + content: assistantMessages, + } } } @@ -418,6 +437,10 @@ func (m *messagesCmp) SetSize(width, height int) tea.Cmd { m.height = height m.viewport.Width = width m.viewport.Height = height - 2 + for _, msg := range m.messages { + delete(m.cachedContent, msg.ID) + } + m.uiMessages = make([]uiMessage, 0) m.renderView() return m.preloadSessions() } @@ -456,7 +479,7 @@ func NewMessagesCmp(app *app.App) tea.Model { return &messagesCmp{ app: app, writingMode: true, - cachedContent: make(map[string][]uiMessage), + cachedContent: make(map[string]cacheItem), viewport: viewport.New(0, 0), spinner: s, } diff --git a/internal/tui/components/chat/message.go b/internal/tui/components/chat/message.go index be6c7ce50..7a840b4ec 100644 --- a/internal/tui/components/chat/message.go +++ b/internal/tui/components/chat/message.go @@ -389,6 +389,7 @@ func renderToolResponse(toolCall message.ToolCall, response message.ToolResult, errContent := fmt.Sprintf("Error: %s", strings.ReplaceAll(response.Content, "\n", " ")) errContent = ansi.Truncate(errContent, width-1, "...") return styles.BaseStyle. + Width(width). Foreground(styles.Error). Render(errContent) } diff --git a/internal/tui/components/dialog/permission.go b/internal/tui/components/dialog/permission.go index f83472e68..1f8df21a0 100644 --- a/internal/tui/components/dialog/permission.go +++ b/internal/tui/components/dialog/permission.go @@ -40,7 +40,8 @@ type PermissionDialogCmp interface { } type permissionsMapping struct { - LeftRight key.Binding + Left key.Binding + Right key.Binding EnterSpace key.Binding Allow key.Binding AllowSession key.Binding @@ -49,9 +50,13 @@ type permissionsMapping struct { } var permissionsKeys = permissionsMapping{ - LeftRight: key.NewBinding( - key.WithKeys("left", "right"), - key.WithHelp("←/→", "switch options"), + Left: key.NewBinding( + key.WithKeys("left"), + key.WithHelp("←", "switch options"), + ), + Right: key.NewBinding( + key.WithKeys("right"), + key.WithHelp("→", "switch options"), ), EnterSpace: key.NewBinding( key.WithKeys("enter", " "), @@ -104,21 +109,18 @@ func (p *permissionDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { p.diffCache = make(map[string]string) case tea.KeyMsg: switch { - case key.Matches(msg, permissionsKeys.LeftRight) || key.Matches(msg, permissionsKeys.Tab): - // Change selected option + case key.Matches(msg, permissionsKeys.Right) || key.Matches(msg, permissionsKeys.Tab): p.selectedOption = (p.selectedOption + 1) % 3 return p, nil + case key.Matches(msg, permissionsKeys.Left): + p.selectedOption = (p.selectedOption + 2) % 3 case key.Matches(msg, permissionsKeys.EnterSpace): - // Select current option return p, p.selectCurrentOption() case key.Matches(msg, permissionsKeys.Allow): - // Select Allow return p, util.CmdHandler(PermissionResponseMsg{Action: PermissionAllow, Permission: p.permission}) case key.Matches(msg, permissionsKeys.AllowSession): - // Select Allow for session return p, util.CmdHandler(PermissionResponseMsg{Action: PermissionAllowForSession, Permission: p.permission}) case key.Matches(msg, permissionsKeys.Deny): - // Select Deny return p, util.CmdHandler(PermissionResponseMsg{Action: PermissionDeny, Permission: p.permission}) default: // Pass other keys to viewport diff --git a/internal/tui/components/dialog/session.go b/internal/tui/components/dialog/session.go index d8c859c49..060875f91 100644 --- a/internal/tui/components/dialog/session.go +++ b/internal/tui/components/dialog/session.go @@ -27,20 +27,20 @@ type SessionDialog interface { } type sessionDialogCmp struct { - sessions []session.Session - selectedIdx int - width int - height int + sessions []session.Session + selectedIdx int + width int + height int selectedSessionID string } type sessionKeyMap struct { - Up key.Binding - Down key.Binding - Enter key.Binding - Escape key.Binding - J key.Binding - K key.Binding + Up key.Binding + Down key.Binding + Enter key.Binding + Escape key.Binding + J key.Binding + K key.Binding } var sessionKeys = sessionKeyMap{ @@ -128,7 +128,7 @@ func (s *sessionDialogCmp) View() string { // Build the session list sessionItems := make([]string, 0, maxVisibleSessions) startIdx := 0 - + // If we have more sessions than can be displayed, adjust the start index if len(s.sessions) > maxVisibleSessions { // Center the selected item when possible @@ -145,30 +145,31 @@ func (s *sessionDialogCmp) View() string { for i := startIdx; i < endIdx; i++ { sess := s.sessions[i] itemStyle := styles.BaseStyle.Width(maxWidth) - + if i == s.selectedIdx { itemStyle = itemStyle. Background(styles.PrimaryColor). Foreground(styles.Background). Bold(true) } - + sessionItems = append(sessionItems, itemStyle.Padding(0, 1).Render(sess.Title)) } title := styles.BaseStyle. Foreground(styles.PrimaryColor). Bold(true). + Width(maxWidth). Padding(0, 1). Render("Switch Session") content := lipgloss.JoinVertical( lipgloss.Left, title, - styles.BaseStyle.Render(""), - lipgloss.JoinVertical(lipgloss.Left, sessionItems...), - styles.BaseStyle.Render(""), - styles.BaseStyle.Foreground(styles.ForgroundDim).Render("↑/k: up ↓/j: down enter: select esc: cancel"), + styles.BaseStyle.Width(maxWidth).Render(""), + styles.BaseStyle.Width(maxWidth).Render(lipgloss.JoinVertical(lipgloss.Left, sessionItems...)), + styles.BaseStyle.Width(maxWidth).Render(""), + styles.BaseStyle.Width(maxWidth).Padding(0, 1).Foreground(styles.ForgroundDim).Render("↑/k: up ↓/j: down enter: select esc: cancel"), ) return styles.BaseStyle.Padding(1, 2). @@ -185,7 +186,7 @@ func (s *sessionDialogCmp) BindingKeys() []key.Binding { func (s *sessionDialogCmp) SetSessions(sessions []session.Session) { s.sessions = sessions - + // If we have a selected session ID, find its index if s.selectedSessionID != "" { for i, sess := range sessions { @@ -195,14 +196,14 @@ func (s *sessionDialogCmp) SetSessions(sessions []session.Session) { } } } - + // Default to first session if selected not found s.selectedIdx = 0 } func (s *sessionDialogCmp) SetSelectedSession(sessionID string) { s.selectedSessionID = sessionID - + // Update the selected index if sessions are already loaded if len(s.sessions) > 0 { for i, sess := range s.sessions { @@ -217,8 +218,9 @@ func (s *sessionDialogCmp) SetSelectedSession(sessionID string) { // NewSessionDialogCmp creates a new session switching dialog func NewSessionDialogCmp() SessionDialog { return &sessionDialogCmp{ - sessions: []session.Session{}, - selectedIdx: 0, + sessions: []session.Session{}, + selectedIdx: 0, selectedSessionID: "", } -} \ No newline at end of file +} + diff --git a/internal/tui/page/chat.go b/internal/tui/page/chat.go index b99dc3dfe..ef826e9a3 100644 --- a/internal/tui/page/chat.go +++ b/internal/tui/page/chat.go @@ -43,13 +43,6 @@ func (p *chatPage) Init() tea.Cmd { cmds := []tea.Cmd{ p.layout.Init(), } - - sessions, _ := p.app.Sessions.List(context.Background()) - if len(sessions) > 0 { - p.session = sessions[0] - cmd := p.setSidebar() - cmds = append(cmds, util.CmdHandler(chat.SessionSelectedMsg(p.session)), cmd) - } return tea.Batch(cmds...) } diff --git a/internal/tui/tui.go b/internal/tui/tui.go index f3a7298cf..2a9ed0d70 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -163,6 +163,7 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { a.showPermissions = true return a, a.permissions.SetPermissions(msg.Payload) case dialog.PermissionResponseMsg: + var cmd tea.Cmd switch msg.Action { case dialog.PermissionAllow: a.app.Permissions.Grant(msg.Permission) @@ -170,9 +171,10 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { a.app.Permissions.GrantPersistant(msg.Permission) case dialog.PermissionDeny: a.app.Permissions.Deny(msg.Permission) + cmd = util.CmdHandler(chat.FocusEditorMsg(true)) } a.showPermissions = false - return a, nil + return a, cmd case page.PageChangeMsg: return a, a.moveToPage(msg.ID) -- cgit v1.2.3