From bd2ec29b65e430f83f430db5fdc424c7d631989d Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Sat, 12 Apr 2025 18:45:36 +0200 Subject: add initial git support --- cmd/diff/main.go | 102 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ cmd/git/main.go | 4 +++ cmd/root.go | 6 ++++ 3 files changed, 112 insertions(+) create mode 100644 cmd/diff/main.go create mode 100644 cmd/git/main.go (limited to 'cmd') diff --git a/cmd/diff/main.go b/cmd/diff/main.go new file mode 100644 index 000000000..da93e4660 --- /dev/null +++ b/cmd/diff/main.go @@ -0,0 +1,102 @@ +package main + +import ( + "fmt" + "io" + "os" + "os/exec" + "path/filepath" +) + +func main() { + // Create a temporary directory + tempDir, err := os.MkdirTemp("", "git-split-diffs") + if err != nil { + fmt.Printf("Error creating temp directory: %v\n", err) + os.Exit(1) + } + defer func() { + fmt.Printf("Cleaning up temporary directory: %s\n", tempDir) + os.RemoveAll(tempDir) + }() + fmt.Printf("Created temporary directory: %s\n", tempDir) + + // Clone the repository with minimum depth + fmt.Println("Cloning git-split-diffs repository with minimum depth...") + cmd := exec.Command("git", "clone", "--depth=1", "https://github.com/kujtimiihoxha/git-split-diffs", tempDir) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + fmt.Printf("Error cloning repository: %v\n", err) + os.Exit(1) + } + + // Run npm install + fmt.Println("Running npm install...") + cmdNpmInstall := exec.Command("npm", "install") + cmdNpmInstall.Dir = tempDir + cmdNpmInstall.Stdout = os.Stdout + cmdNpmInstall.Stderr = os.Stderr + if err := cmdNpmInstall.Run(); err != nil { + fmt.Printf("Error running npm install: %v\n", err) + os.Exit(1) + } + + // Run npm run build + fmt.Println("Running npm run build...") + cmdNpmBuild := exec.Command("npm", "run", "build") + cmdNpmBuild.Dir = tempDir + cmdNpmBuild.Stdout = os.Stdout + cmdNpmBuild.Stderr = os.Stderr + if err := cmdNpmBuild.Run(); err != nil { + fmt.Printf("Error running npm run build: %v\n", err) + os.Exit(1) + } + + destDir := filepath.Join(".", "internal", "assets", "diff") + destFile := filepath.Join(destDir, "index.mjs") + + // Make sure the destination directory exists + if err := os.MkdirAll(destDir, 0o755); err != nil { + fmt.Printf("Error creating destination directory: %v\n", err) + os.Exit(1) + } + + // Copy the file + srcFile := filepath.Join(tempDir, "build", "index.mjs") + fmt.Printf("Copying %s to %s\n", srcFile, destFile) + if err := copyFile(srcFile, destFile); err != nil { + fmt.Printf("Error copying file: %v\n", err) + os.Exit(1) + } + + fmt.Println("Successfully completed the process!") +} + +// copyFile copies a file from src to dst +func copyFile(src, dst string) error { + sourceFile, err := os.Open(src) + if err != nil { + return err + } + defer sourceFile.Close() + + destFile, err := os.Create(dst) + if err != nil { + return err + } + defer destFile.Close() + + _, err = io.Copy(destFile, sourceFile) + if err != nil { + return err + } + + // Make sure the file is written to disk + err = destFile.Sync() + if err != nil { + return err + } + + return nil +} diff --git a/cmd/git/main.go b/cmd/git/main.go new file mode 100644 index 000000000..da29a2cad --- /dev/null +++ b/cmd/git/main.go @@ -0,0 +1,4 @@ +package main + +func main() { +} diff --git a/cmd/root.go b/cmd/root.go index bdab53e14..d846a14c2 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -8,6 +8,7 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/kujtimiihoxha/termai/internal/app" + "github.com/kujtimiihoxha/termai/internal/assets" "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/db" "github.com/kujtimiihoxha/termai/internal/llm/agent" @@ -28,6 +29,9 @@ var rootCmd = &cobra.Command{ } debug, _ := cmd.Flags().GetBool("debug") err := config.Load(debug) + if err != nil { + return err + } cfg := config.Get() defaultLevel := slog.LevelInfo if cfg.Debug { @@ -38,9 +42,11 @@ var rootCmd = &cobra.Command{ })) slog.SetDefault(logger) + err = assets.WriteAssets() if err != nil { return err } + conn, err := db.Connect() if err != nil { return err -- cgit v1.2.3 From 3ad983db0f2c08826d56cb5de274d706c95b3353 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Sun, 13 Apr 2025 13:17:17 +0200 Subject: cleanup app, config and root --- .gitignore | 2 +- .opencode.json | 11 ++ .termai.json | 11 -- cmd/git/main.go | 4 - cmd/root.go | 253 +++++++++++++++++++++++-------- internal/app/app.go | 76 ++++++++++ internal/app/lsp.go | 108 +++++++++++++ internal/app/services.go | 64 -------- internal/config/config.go | 20 +-- internal/history/file.go | 73 +++++---- internal/llm/agent/agent-tool.go | 10 +- internal/llm/agent/agent.go | 53 +++---- internal/llm/agent/coder.go | 5 +- internal/llm/agent/task.go | 3 +- internal/message/message.go | 46 +++--- internal/session/session.go | 44 +++--- internal/tui/components/chat/messages.go | 5 +- internal/tui/components/repl/editor.go | 4 +- internal/tui/components/repl/messages.go | 7 +- internal/tui/components/repl/sessions.go | 4 +- internal/tui/page/chat.go | 8 +- internal/tui/tui.go | 6 +- 22 files changed, 525 insertions(+), 292 deletions(-) create mode 100644 .opencode.json delete mode 100644 .termai.json delete mode 100644 cmd/git/main.go create mode 100644 internal/app/app.go create mode 100644 internal/app/lsp.go delete mode 100644 internal/app/services.go (limited to 'cmd') diff --git a/.gitignore b/.gitignore index 388f8b2ca..0ef6e2aef 100644 --- a/.gitignore +++ b/.gitignore @@ -41,6 +41,6 @@ debug.log .env .env.local -.termai +.opencode internal/assets/diff/index.mjs diff --git a/.opencode.json b/.opencode.json new file mode 100644 index 000000000..f63a63dba --- /dev/null +++ b/.opencode.json @@ -0,0 +1,11 @@ +{ + "model": { + "coder": "claude-3.7-sonnet", + "coderMaxTokens": 20000 + }, + "lsp": { + "gopls": { + "command": "gopls" + } + } +} diff --git a/.termai.json b/.termai.json deleted file mode 100644 index f63a63dba..000000000 --- a/.termai.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "model": { - "coder": "claude-3.7-sonnet", - "coderMaxTokens": 20000 - }, - "lsp": { - "gopls": { - "command": "gopls" - } - } -} diff --git a/cmd/git/main.go b/cmd/git/main.go deleted file mode 100644 index da29a2cad..000000000 --- a/cmd/git/main.go +++ /dev/null @@ -1,4 +0,0 @@ -package main - -func main() { -} diff --git a/cmd/root.go b/cmd/root.go index d846a14c2..092606de7 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -2,9 +2,10 @@ package cmd import ( "context" - "log/slog" + "fmt" "os" "sync" + "time" tea "github.com/charmbracelet/bubbletea" "github.com/kujtimiihoxha/termai/internal/app" @@ -13,6 +14,7 @@ import ( "github.com/kujtimiihoxha/termai/internal/db" "github.com/kujtimiihoxha/termai/internal/llm/agent" "github.com/kujtimiihoxha/termai/internal/logging" + "github.com/kujtimiihoxha/termai/internal/pubsub" "github.com/kujtimiihoxha/termai/internal/tui" zone "github.com/lrstanley/bubblezone" "github.com/spf13/cobra" @@ -23,111 +25,229 @@ var rootCmd = &cobra.Command{ Short: "A terminal ai assistant", Long: `A terminal ai assistant`, RunE: func(cmd *cobra.Command, args []string) error { + // If the help flag is set, show the help message if cmd.Flag("help").Changed { cmd.Help() return nil } + + // Load the config debug, _ := cmd.Flags().GetBool("debug") - err := config.Load(debug) + cwd, _ := cmd.Flags().GetString("cwd") + if cwd != "" { + err := os.Chdir(cwd) + if err != nil { + return fmt.Errorf("failed to change directory: %v", err) + } + } + if cwd == "" { + c, err := os.Getwd() + if err != nil { + return fmt.Errorf("failed to get current working directory: %v", err) + } + cwd = c + } + _, err := config.Load(cwd, debug) if err != nil { return err } - cfg := config.Get() - defaultLevel := slog.LevelInfo - if cfg.Debug { - defaultLevel = slog.LevelDebug - } - logger := slog.New(slog.NewTextHandler(logging.NewWriter(), &slog.HandlerOptions{ - Level: defaultLevel, - })) - slog.SetDefault(logger) err = assets.WriteAssets() if err != nil { - return err + logging.Error("Error writing assets: %v", err) } + // Connect DB, this will also run migrations conn, err := db.Connect() if err != nil { return err } - ctx := context.Background() + + // Create main context for the application + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() app := app.New(ctx, conn) - logging.Info("Starting termai...") + + // Set up the TUI zone.NewGlobal() - tui := tea.NewProgram( + program := tea.NewProgram( tui.New(app), tea.WithAltScreen(), tea.WithMouseCellMotion(), ) - logging.Info("Setting up subscriptions...") - ch, unsub := setupSubscriptions(app) - defer unsub() + // Initialize MCP tools in the background + initMCPTools(ctx, app) + + // Setup the subscriptions, this will send services events to the TUI + ch, cancelSubs := setupSubscriptions(app) + + // Create a context for the TUI message handler + tuiCtx, tuiCancel := context.WithCancel(ctx) + var tuiWg sync.WaitGroup + tuiWg.Add(1) + + // Set up message handling for the TUI go func() { - // Set this up once - agent.GetMcpTools(ctx, app.Permissions) - for msg := range ch { - tui.Send(msg) + defer tuiWg.Done() + defer func() { + if r := recover(); r != nil { + logging.Error("Panic in TUI message handling: %v", r) + attemptTUIRecovery(program) + } + }() + + for { + select { + case <-tuiCtx.Done(): + logging.Info("TUI message handler shutting down") + return + case msg, ok := <-ch: + if !ok { + logging.Info("TUI message channel closed") + return + } + program.Send(msg) + } } }() - if _, err := tui.Run(); err != nil { - return err + + // Cleanup function for when the program exits + cleanup := func() { + // Shutdown the app + app.Shutdown() + + // Cancel subscriptions first + cancelSubs() + + // Then cancel TUI message handler + tuiCancel() + + // Wait for TUI message handler to finish + tuiWg.Wait() + + logging.Info("All goroutines cleaned up") + } + + // Run the TUI + result, err := program.Run() + cleanup() + + if err != nil { + logging.Error("TUI error: %v", err) + return fmt.Errorf("TUI error: %v", err) } + + logging.Info("TUI exited with result: %v", result) return nil }, } -func setupSubscriptions(app *app.App) (chan tea.Msg, func()) { - ch := make(chan tea.Msg) - wg := sync.WaitGroup{} - ctx, cancel := context.WithCancel(app.Context) - { - sub := logging.Subscribe(ctx) - wg.Add(1) - go func() { - for ev := range sub { - ch <- ev +// attemptTUIRecovery tries to recover the TUI after a panic +func attemptTUIRecovery(program *tea.Program) { + logging.Info("Attempting to recover TUI after panic") + + // We could try to restart the TUI or gracefully exit + // For now, we'll just quit the program to avoid further issues + program.Quit() +} + +func initMCPTools(ctx context.Context, app *app.App) { + go func() { + defer func() { + if r := recover(); r != nil { + logging.Error("Panic in MCP goroutine: %v", r) } - wg.Done() }() - } - { - sub := app.Sessions.Subscribe(ctx) - wg.Add(1) - go func() { - for ev := range sub { - ch <- ev + + // Create a context with timeout for the initial MCP tools fetch + ctxWithTimeout, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + // Set this up once with proper error handling + agent.GetMcpTools(ctxWithTimeout, app.Permissions) + logging.Info("MCP message handling goroutine exiting") + }() +} + +func setupSubscriber[T any]( + ctx context.Context, + wg *sync.WaitGroup, + name string, + subscriber func(context.Context) <-chan pubsub.Event[T], + outputCh chan<- tea.Msg, +) { + wg.Add(1) + go func() { + defer wg.Done() + defer func() { + if r := recover(); r != nil { + logging.Error("Panic in %s subscription goroutine: %v", name, r) } - wg.Done() }() - } - { - sub := app.Messages.Subscribe(ctx) - wg.Add(1) - go func() { - for ev := range sub { - ch <- ev + + for { + select { + case event, ok := <-subscriber(ctx): + if !ok { + logging.Info("%s subscription channel closed", name) + return + } + + // Convert generic event to tea.Msg if needed + var msg tea.Msg = event + + // Non-blocking send with timeout to prevent deadlocks + select { + case outputCh <- msg: + case <-time.After(500 * time.Millisecond): + logging.Warn("%s message dropped due to slow consumer", name) + case <-ctx.Done(): + logging.Info("%s subscription cancelled", name) + return + } + case <-ctx.Done(): + logging.Info("%s subscription cancelled", name) + return } - wg.Done() - }() - } - { - sub := app.Permissions.Subscribe(ctx) - wg.Add(1) + } + }() +} + +func setupSubscriptions(app *app.App) (chan tea.Msg, func()) { + ch := make(chan tea.Msg, 100) + // Add a buffer to prevent blocking + wg := sync.WaitGroup{} + ctx, cancel := context.WithCancel(context.Background()) + // Setup each subscription using the helper + setupSubscriber(ctx, &wg, "logging", logging.Subscribe, ch) + setupSubscriber(ctx, &wg, "sessions", app.Sessions.Subscribe, ch) + setupSubscriber(ctx, &wg, "messages", app.Messages.Subscribe, ch) + setupSubscriber(ctx, &wg, "permissions", app.Permissions.Subscribe, ch) + + // Return channel and a cleanup function + cleanupFunc := func() { + logging.Info("Cancelling all subscriptions") + cancel() // Signal all goroutines to stop + + // Wait with a timeout for all goroutines to complete + waitCh := make(chan struct{}) go func() { - for ev := range sub { - ch <- ev - } - wg.Done() + wg.Wait() + close(waitCh) }() + + select { + case <-waitCh: + logging.Info("All subscription goroutines completed successfully") + case <-time.After(5 * time.Second): + logging.Warn("Timed out waiting for some subscription goroutines to complete") + } + + close(ch) // Safe to close after all writers are done or timed out } - return ch, func() { - cancel() - wg.Wait() - close(ch) - } + return ch, cleanupFunc } func Execute() { @@ -139,5 +259,6 @@ func Execute() { func init() { rootCmd.Flags().BoolP("help", "h", false, "Help") - rootCmd.Flags().BoolP("debug", "d", false, "Help") + rootCmd.Flags().BoolP("debug", "d", false, "Debug") + rootCmd.Flags().StringP("cwd", "c", "", "Current working directory") } diff --git a/internal/app/app.go b/internal/app/app.go new file mode 100644 index 000000000..fa4a6ee90 --- /dev/null +++ b/internal/app/app.go @@ -0,0 +1,76 @@ +package app + +import ( + "context" + "database/sql" + "maps" + "sync" + "time" + + "github.com/kujtimiihoxha/termai/internal/db" + "github.com/kujtimiihoxha/termai/internal/history" + "github.com/kujtimiihoxha/termai/internal/logging" + "github.com/kujtimiihoxha/termai/internal/lsp" + "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/termai/internal/permission" + "github.com/kujtimiihoxha/termai/internal/session" +) + +type App struct { + Sessions session.Service + Messages message.Service + Files history.Service + Permissions permission.Service + + LSPClients map[string]*lsp.Client + + clientsMutex sync.RWMutex + + watcherCancelFuncs []context.CancelFunc + cancelFuncsMutex sync.Mutex + watcherWG sync.WaitGroup +} + +func New(ctx context.Context, conn *sql.DB) *App { + q := db.New(conn) + sessions := session.NewService(q) + messages := message.NewService(q) + files := history.NewService(q) + + app := &App{ + Sessions: sessions, + Messages: messages, + Files: files, + Permissions: permission.NewPermissionService(), + LSPClients: make(map[string]*lsp.Client), + } + + app.initLSPClients(ctx) + + return app +} + +// Shutdown performs a clean shutdown of the application +func (app *App) Shutdown() { + // Cancel all watcher goroutines + app.cancelFuncsMutex.Lock() + for _, cancel := range app.watcherCancelFuncs { + cancel() + } + app.cancelFuncsMutex.Unlock() + app.watcherWG.Wait() + + // Perform additional cleanup for LSP clients + app.clientsMutex.RLock() + clients := make(map[string]*lsp.Client, len(app.LSPClients)) + maps.Copy(clients, app.LSPClients) + app.clientsMutex.RUnlock() + + for name, client := range clients { + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + if err := client.Shutdown(shutdownCtx); err != nil { + logging.Error("Failed to shutdown LSP client", "name", name, "error", err) + } + cancel() + } +} diff --git a/internal/app/lsp.go b/internal/app/lsp.go new file mode 100644 index 000000000..4e0568f07 --- /dev/null +++ b/internal/app/lsp.go @@ -0,0 +1,108 @@ +package app + +import ( + "context" + "time" + + "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/termai/internal/logging" + "github.com/kujtimiihoxha/termai/internal/lsp" + "github.com/kujtimiihoxha/termai/internal/lsp/watcher" +) + +func (app *App) initLSPClients(ctx context.Context) { + cfg := config.Get() + + // Initialize LSP clients + for name, clientConfig := range cfg.LSP { + app.createAndStartLSPClient(ctx, name, clientConfig.Command, clientConfig.Args...) + } +} + +// createAndStartLSPClient creates a new LSP client, initializes it, and starts its workspace watcher +func (app *App) createAndStartLSPClient(ctx context.Context, name string, command string, args ...string) { + // Create a specific context for initialization with a timeout + initCtx, initCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer initCancel() + + // Create the LSP client + lspClient, err := lsp.NewClient(initCtx, command, args...) + if err != nil { + logging.Error("Failed to create LSP client for", name, err) + return + } + + // Initialize with the initialization context + _, err = lspClient.InitializeLSPClient(initCtx, config.WorkingDirectory()) + if err != nil { + logging.Error("Initialize failed", "name", name, "error", err) + // Clean up the client to prevent resource leaks + lspClient.Close() + return + } + + // Create a child context that can be canceled when the app is shutting down + watchCtx, cancelFunc := context.WithCancel(ctx) + workspaceWatcher := watcher.NewWorkspaceWatcher(lspClient) + + // Store the cancel function to be called during cleanup + app.cancelFuncsMutex.Lock() + app.watcherCancelFuncs = append(app.watcherCancelFuncs, cancelFunc) + app.cancelFuncsMutex.Unlock() + + // Add the watcher to a WaitGroup to track active goroutines + app.watcherWG.Add(1) + + // Add to map with mutex protection before starting goroutine + app.clientsMutex.Lock() + app.LSPClients[name] = lspClient + app.clientsMutex.Unlock() + + go app.runWorkspaceWatcher(watchCtx, name, workspaceWatcher) +} + +// runWorkspaceWatcher executes the workspace watcher for an LSP client +func (app *App) runWorkspaceWatcher(ctx context.Context, name string, workspaceWatcher *watcher.WorkspaceWatcher) { + defer app.watcherWG.Done() + defer func() { + if r := recover(); r != nil { + logging.Error("LSP client crashed", "client", name, "panic", r) + + // Try to restart the client + app.restartLSPClient(ctx, name) + } + }() + + workspaceWatcher.WatchWorkspace(ctx, config.WorkingDirectory()) + logging.Info("Workspace watcher stopped", "client", name) +} + +// restartLSPClient attempts to restart a crashed or failed LSP client +func (app *App) restartLSPClient(ctx context.Context, name string) { + // Get the original configuration + cfg := config.Get() + clientConfig, exists := cfg.LSP[name] + if !exists { + logging.Error("Cannot restart client, configuration not found", "client", name) + return + } + + // Clean up the old client if it exists + app.clientsMutex.Lock() + oldClient, exists := app.LSPClients[name] + if exists { + delete(app.LSPClients, name) // Remove from map before potentially slow shutdown + } + app.clientsMutex.Unlock() + + if exists && oldClient != nil { + // Try to shut it down gracefully, but don't block on errors + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + _ = oldClient.Shutdown(shutdownCtx) + cancel() + } + + // Create a new client using the shared function + app.createAndStartLSPClient(ctx, name, clientConfig.Command, clientConfig.Args...) + logging.Info("Successfully restarted LSP client", "client", name) +} diff --git a/internal/app/services.go b/internal/app/services.go deleted file mode 100644 index 6ecdef03c..000000000 --- a/internal/app/services.go +++ /dev/null @@ -1,64 +0,0 @@ -package app - -import ( - "context" - "database/sql" - - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/db" - "github.com/kujtimiihoxha/termai/internal/history" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/lsp/watcher" - "github.com/kujtimiihoxha/termai/internal/message" - "github.com/kujtimiihoxha/termai/internal/permission" - "github.com/kujtimiihoxha/termai/internal/session" -) - -type App struct { - Context context.Context - - Sessions session.Service - Messages message.Service - Files history.Service - Permissions permission.Service - - LSPClients map[string]*lsp.Client -} - -func New(ctx context.Context, conn *sql.DB) *App { - cfg := config.Get() - logging.Info("Debug mode enabled") - - q := db.New(conn) - sessions := session.NewService(ctx, q) - messages := message.NewService(ctx, q) - files := history.NewService(ctx, q) - - app := &App{ - Context: ctx, - Sessions: sessions, - Messages: messages, - Files: files, - Permissions: permission.NewPermissionService(), - LSPClients: make(map[string]*lsp.Client), - } - - for name, client := range cfg.LSP { - lspClient, err := lsp.NewClient(ctx, client.Command, client.Args...) - workspaceWatcher := watcher.NewWorkspaceWatcher(lspClient) - if err != nil { - logging.Error("Failed to create LSP client for", name, err) - continue - } - - _, err = lspClient.InitializeLSPClient(ctx, config.WorkingDirectory()) - if err != nil { - logging.Error("Initialize failed", "error", err) - continue - } - go workspaceWatcher.WatchWorkspace(ctx, config.WorkingDirectory()) - app.LSPClients[name] = lspClient - } - return app -} diff --git a/internal/config/config.go b/internal/config/config.go index 6f757b3f4..1f3091ff3 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -83,9 +83,9 @@ var cfg *Config // Load initializes the configuration from environment variables and config files. // If debug is true, debug mode is enabled and log level is set to debug. // It returns an error if configuration loading fails. -func Load(workingDir string, debug bool) error { +func Load(workingDir string, debug bool) (*Config, error) { if cfg != nil { - return nil + return cfg, nil } cfg = &Config{ @@ -101,7 +101,7 @@ func Load(workingDir string, debug bool) error { // Read global config if err := readConfig(viper.ReadInConfig()); err != nil { - return err + return cfg, err } // Load and merge local config @@ -109,7 +109,7 @@ func Load(workingDir string, debug bool) error { // Apply configuration to the struct if err := viper.Unmarshal(cfg); err != nil { - return err + return cfg, fmt.Errorf("failed to unmarshal config: %w", err) } applyDefaultValues() @@ -123,7 +123,7 @@ func Load(workingDir string, debug bool) error { Level: defaultLevel, })) slog.SetDefault(logger) - return nil + return cfg, nil } // configureViper sets up viper's configuration paths and environment variables. @@ -237,7 +237,7 @@ func readConfig(err error) error { return nil } - return err + return fmt.Errorf("failed to read config: %w", err) } // mergeLocalConfig loads and merges configuration from the local directory. @@ -264,14 +264,6 @@ func applyDefaultValues() { } } -// setWorkingDirectory stores the current working directory in the configuration. -func setWorkingDirectory() { - workdir, err := os.Getwd() - if err == nil { - viper.Set("wd", workdir) - } -} - // Get returns the current configuration. // It's safe to call this function multiple times. func Get() *Config { diff --git a/internal/history/file.go b/internal/history/file.go index 25953b273..82017d4cf 100644 --- a/internal/history/file.go +++ b/internal/history/file.go @@ -27,45 +27,43 @@ type File struct { type Service interface { pubsub.Suscriber[File] - Create(sessionID, path, content string) (File, error) - CreateVersion(sessionID, path, content string) (File, error) - Get(id string) (File, error) - GetByPathAndSession(path, sessionID string) (File, error) - ListBySession(sessionID string) ([]File, error) - ListLatestSessionFiles(sessionID string) ([]File, error) - Update(file File) (File, error) - Delete(id string) error - DeleteSessionFiles(sessionID string) error + Create(ctx context.Context, sessionID, path, content string) (File, error) + CreateVersion(ctx context.Context, sessionID, path, content string) (File, error) + Get(ctx context.Context, id string) (File, error) + GetByPathAndSession(ctx context.Context, path, sessionID string) (File, error) + ListBySession(ctx context.Context, sessionID string) ([]File, error) + ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error) + Update(ctx context.Context, file File) (File, error) + Delete(ctx context.Context, id string) error + DeleteSessionFiles(ctx context.Context, sessionID string) error } type service struct { *pubsub.Broker[File] - q db.Querier - ctx context.Context + q db.Querier } -func NewService(ctx context.Context, q db.Querier) Service { +func NewService(q db.Querier) Service { return &service{ Broker: pubsub.NewBroker[File](), q: q, - ctx: ctx, } } -func (s *service) Create(sessionID, path, content string) (File, error) { - return s.createWithVersion(sessionID, path, content, InitialVersion) +func (s *service) Create(ctx context.Context, sessionID, path, content string) (File, error) { + return s.createWithVersion(ctx, sessionID, path, content, InitialVersion) } -func (s *service) CreateVersion(sessionID, path, content string) (File, error) { +func (s *service) CreateVersion(ctx context.Context, sessionID, path, content string) (File, error) { // Get the latest version for this path - files, err := s.q.ListFilesByPath(s.ctx, path) + files, err := s.q.ListFilesByPath(ctx, path) if err != nil { return File{}, err } if len(files) == 0 { // No previous versions, create initial - return s.Create(sessionID, path, content) + return s.Create(ctx, sessionID, path, content) } // Get the latest version @@ -89,11 +87,11 @@ func (s *service) CreateVersion(sessionID, path, content string) (File, error) { nextVersion = fmt.Sprintf("v%d", latestFile.CreatedAt) } - return s.createWithVersion(sessionID, path, content, nextVersion) + return s.createWithVersion(ctx, sessionID, path, content, nextVersion) } -func (s *service) createWithVersion(sessionID, path, content, version string) (File, error) { - dbFile, err := s.q.CreateFile(s.ctx, db.CreateFileParams{ +func (s *service) createWithVersion(ctx context.Context, sessionID, path, content, version string) (File, error) { + dbFile, err := s.q.CreateFile(ctx, db.CreateFileParams{ ID: uuid.New().String(), SessionID: sessionID, Path: path, @@ -108,16 +106,16 @@ func (s *service) createWithVersion(sessionID, path, content, version string) (F return file, nil } -func (s *service) Get(id string) (File, error) { - dbFile, err := s.q.GetFile(s.ctx, id) +func (s *service) Get(ctx context.Context, id string) (File, error) { + dbFile, err := s.q.GetFile(ctx, id) if err != nil { return File{}, err } return s.fromDBItem(dbFile), nil } -func (s *service) GetByPathAndSession(path, sessionID string) (File, error) { - dbFile, err := s.q.GetFileByPathAndSession(s.ctx, db.GetFileByPathAndSessionParams{ +func (s *service) GetByPathAndSession(ctx context.Context, path, sessionID string) (File, error) { + dbFile, err := s.q.GetFileByPathAndSession(ctx, db.GetFileByPathAndSessionParams{ Path: path, SessionID: sessionID, }) @@ -127,8 +125,8 @@ func (s *service) GetByPathAndSession(path, sessionID string) (File, error) { return s.fromDBItem(dbFile), nil } -func (s *service) ListBySession(sessionID string) ([]File, error) { - dbFiles, err := s.q.ListFilesBySession(s.ctx, sessionID) +func (s *service) ListBySession(ctx context.Context, sessionID string) ([]File, error) { + dbFiles, err := s.q.ListFilesBySession(ctx, sessionID) if err != nil { return nil, err } @@ -139,8 +137,8 @@ func (s *service) ListBySession(sessionID string) ([]File, error) { return files, nil } -func (s *service) ListLatestSessionFiles(sessionID string) ([]File, error) { - dbFiles, err := s.q.ListLatestSessionFiles(s.ctx, sessionID) +func (s *service) ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error) { + dbFiles, err := s.q.ListLatestSessionFiles(ctx, sessionID) if err != nil { return nil, err } @@ -151,8 +149,8 @@ func (s *service) ListLatestSessionFiles(sessionID string) ([]File, error) { return files, nil } -func (s *service) Update(file File) (File, error) { - dbFile, err := s.q.UpdateFile(s.ctx, db.UpdateFileParams{ +func (s *service) Update(ctx context.Context, file File) (File, error) { + dbFile, err := s.q.UpdateFile(ctx, db.UpdateFileParams{ ID: file.ID, Content: file.Content, Version: file.Version, @@ -165,12 +163,12 @@ func (s *service) Update(file File) (File, error) { return updatedFile, nil } -func (s *service) Delete(id string) error { - file, err := s.Get(id) +func (s *service) Delete(ctx context.Context, id string) error { + file, err := s.Get(ctx, id) if err != nil { return err } - err = s.q.DeleteFile(s.ctx, id) + err = s.q.DeleteFile(ctx, id) if err != nil { return err } @@ -178,13 +176,13 @@ func (s *service) Delete(id string) error { return nil } -func (s *service) DeleteSessionFiles(sessionID string) error { - files, err := s.ListBySession(sessionID) +func (s *service) DeleteSessionFiles(ctx context.Context, sessionID string) error { + files, err := s.ListBySession(ctx, sessionID) if err != nil { return err } for _, file := range files { - err = s.Delete(file.ID) + err = s.Delete(ctx, file.ID) if err != nil { return err } @@ -203,4 +201,3 @@ func (s *service) fromDBItem(item db.File) File { UpdatedAt: item.UpdatedAt, } } - diff --git a/internal/llm/agent/agent-tool.go b/internal/llm/agent/agent-tool.go index deb6aed60..91c46da8b 100644 --- a/internal/llm/agent/agent-tool.go +++ b/internal/llm/agent/agent-tool.go @@ -51,7 +51,7 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes return tools.NewTextErrorResponse(fmt.Sprintf("error creating agent: %s", err)), nil } - session, err := b.app.Sessions.CreateTaskSession(call.ID, b.parentSessionID, "New Agent Session") + session, err := b.app.Sessions.CreateTaskSession(ctx, call.ID, b.parentSessionID, "New Agent Session") if err != nil { return tools.NewTextErrorResponse(fmt.Sprintf("error creating session: %s", err)), nil } @@ -61,7 +61,7 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes return tools.NewTextErrorResponse(fmt.Sprintf("error generating agent: %s", err)), nil } - messages, err := b.app.Messages.List(session.ID) + messages, err := b.app.Messages.List(ctx, session.ID) if err != nil { return tools.NewTextErrorResponse(fmt.Sprintf("error listing messages: %s", err)), nil } @@ -74,11 +74,11 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes return tools.NewTextErrorResponse("no assistant message found"), nil } - updatedSession, err := b.app.Sessions.Get(session.ID) + updatedSession, err := b.app.Sessions.Get(ctx, session.ID) if err != nil { return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil } - parentSession, err := b.app.Sessions.Get(b.parentSessionID) + parentSession, err := b.app.Sessions.Get(ctx, b.parentSessionID) if err != nil { return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil } @@ -87,7 +87,7 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes parentSession.PromptTokens += updatedSession.PromptTokens parentSession.CompletionTokens += updatedSession.CompletionTokens - _, err = b.app.Sessions.Save(parentSession) + _, err = b.app.Sessions.Save(ctx, parentSession) if err != nil { return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil } diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 89de627f7..b7c736e6c 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -48,7 +48,7 @@ func (c *agent) handleTitleGeneration(ctx context.Context, sessionID, content st return } - session, err := c.Sessions.Get(sessionID) + session, err := c.Sessions.Get(ctx, sessionID) if err != nil { return } @@ -56,12 +56,12 @@ func (c *agent) handleTitleGeneration(ctx context.Context, sessionID, content st session.Title = response.Content session.Title = strings.TrimSpace(session.Title) session.Title = strings.ReplaceAll(session.Title, "\n", " ") - c.Sessions.Save(session) + c.Sessions.Save(ctx, session) } } -func (c *agent) TrackUsage(sessionID string, model models.Model, usage provider.TokenUsage) error { - session, err := c.Sessions.Get(sessionID) +func (c *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error { + session, err := c.Sessions.Get(ctx, sessionID) if err != nil { return err } @@ -75,11 +75,12 @@ func (c *agent) TrackUsage(sessionID string, model models.Model, usage provider. session.CompletionTokens += usage.OutputTokens session.PromptTokens += usage.InputTokens - _, err = c.Sessions.Save(session) + _, err = c.Sessions.Save(ctx, session) return err } func (c *agent) processEvent( + ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent, @@ -87,10 +88,10 @@ func (c *agent) processEvent( switch event.Type { case provider.EventThinkingDelta: assistantMsg.AppendReasoningContent(event.Content) - return c.Messages.Update(*assistantMsg) + return c.Messages.Update(ctx, *assistantMsg) case provider.EventContentDelta: assistantMsg.AppendContent(event.Content) - return c.Messages.Update(*assistantMsg) + return c.Messages.Update(ctx, *assistantMsg) case provider.EventError: if errors.Is(event.Error, context.Canceled) { return nil @@ -105,11 +106,11 @@ func (c *agent) processEvent( case provider.EventComplete: assistantMsg.SetToolCalls(event.Response.ToolCalls) assistantMsg.AddFinish(event.Response.FinishReason) - err := c.Messages.Update(*assistantMsg) + err := c.Messages.Update(ctx, *assistantMsg) if err != nil { return err } - return c.TrackUsage(sessionID, c.model, event.Response.Usage) + return c.TrackUsage(ctx, sessionID, c.model, event.Response.Usage) } return nil @@ -237,7 +238,7 @@ func (c *agent) handleToolExecution( for _, toolResult := range toolResults { parts = append(parts, toolResult) } - msg, err := c.Messages.Create(assistantMsg.SessionID, message.CreateMessageParams{ + msg, err := c.Messages.Create(ctx, assistantMsg.SessionID, message.CreateMessageParams{ Role: message.Tool, Parts: parts, }) @@ -247,7 +248,7 @@ func (c *agent) handleToolExecution( func (c *agent) generate(ctx context.Context, sessionID string, content string) error { ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID) - messages, err := c.Messages.List(sessionID) + messages, err := c.Messages.List(ctx, sessionID) if err != nil { return err } @@ -256,7 +257,7 @@ func (c *agent) generate(ctx context.Context, sessionID string, content string) go c.handleTitleGeneration(ctx, sessionID, content) } - userMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{ + userMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{ Role: message.User, Parts: []message.ContentPart{ message.TextContent{ @@ -272,7 +273,7 @@ func (c *agent) generate(ctx context.Context, sessionID string, content string) for { select { case <-ctx.Done(): - assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{ + assistantMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{ Role: message.Assistant, Parts: []message.ContentPart{}, }) @@ -280,7 +281,7 @@ func (c *agent) generate(ctx context.Context, sessionID string, content string) return err } assistantMsg.AddFinish("canceled") - c.Messages.Update(assistantMsg) + c.Messages.Update(ctx, assistantMsg) return context.Canceled default: // Continue processing @@ -289,7 +290,7 @@ func (c *agent) generate(ctx context.Context, sessionID string, content string) eventChan, err := c.agent.StreamResponse(ctx, messages, c.tools) if err != nil { if errors.Is(err, context.Canceled) { - assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{ + assistantMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{ Role: message.Assistant, Parts: []message.ContentPart{}, }) @@ -297,13 +298,13 @@ func (c *agent) generate(ctx context.Context, sessionID string, content string) return err } assistantMsg.AddFinish("canceled") - c.Messages.Update(assistantMsg) + c.Messages.Update(ctx, assistantMsg) return context.Canceled } return err } - assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{ + assistantMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{ Role: message.Assistant, Parts: []message.ContentPart{}, Model: c.model.ID, @@ -314,22 +315,22 @@ func (c *agent) generate(ctx context.Context, sessionID string, content string) ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID) for event := range eventChan { - err = c.processEvent(sessionID, &assistantMsg, event) + err = c.processEvent(ctx, sessionID, &assistantMsg, event) if err != nil { if errors.Is(err, context.Canceled) { assistantMsg.AddFinish("canceled") - c.Messages.Update(assistantMsg) + c.Messages.Update(ctx, assistantMsg) return context.Canceled } assistantMsg.AddFinish("error:" + err.Error()) - c.Messages.Update(assistantMsg) + c.Messages.Update(ctx, assistantMsg) return err } select { case <-ctx.Done(): assistantMsg.AddFinish("canceled") - c.Messages.Update(assistantMsg) + c.Messages.Update(ctx, assistantMsg) return context.Canceled default: } @@ -339,7 +340,7 @@ func (c *agent) generate(ctx context.Context, sessionID string, content string) select { case <-ctx.Done(): assistantMsg.AddFinish("canceled") - c.Messages.Update(assistantMsg) + c.Messages.Update(ctx, assistantMsg) return context.Canceled default: // Continue processing @@ -349,13 +350,13 @@ func (c *agent) generate(ctx context.Context, sessionID string, content string) if err != nil { if errors.Is(err, context.Canceled) { assistantMsg.AddFinish("canceled") - c.Messages.Update(assistantMsg) + c.Messages.Update(ctx, assistantMsg) return context.Canceled } return err } - c.Messages.Update(assistantMsg) + c.Messages.Update(ctx, assistantMsg) if len(assistantMsg.ToolCalls()) == 0 { break @@ -370,7 +371,7 @@ func (c *agent) generate(ctx context.Context, sessionID string, content string) select { case <-ctx.Done(): assistantMsg.AddFinish("canceled") - c.Messages.Update(assistantMsg) + c.Messages.Update(ctx, assistantMsg) return context.Canceled default: // Continue processing @@ -383,7 +384,7 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid maxTokens := config.Get().Model.CoderMaxTokens providerConfig, ok := config.Get().Providers[model.Provider] - if !ok || !providerConfig.Enabled { + if !ok || providerConfig.Disabled { return nil, nil, errors.New("provider is not enabled") } var agentProvider provider.Provider diff --git a/internal/llm/agent/coder.go b/internal/llm/agent/coder.go index 5deff05a8..f8e1c40a0 100644 --- a/internal/llm/agent/coder.go +++ b/internal/llm/agent/coder.go @@ -40,12 +40,13 @@ func NewCoderAgent(app *app.App) (Agent, error) { return nil, errors.New("model not supported") } - agentProvider, titleGenerator, err := getAgentProviders(app.Context, model) + ctx := context.Background() + agentProvider, titleGenerator, err := getAgentProviders(ctx, model) if err != nil { return nil, err } - otherTools := GetMcpTools(app.Context, app.Permissions) + otherTools := GetMcpTools(ctx, app.Permissions) if len(app.LSPClients) > 0 { otherTools = append(otherTools, tools.NewDiagnosticsTool(app.LSPClients)) } diff --git a/internal/llm/agent/task.go b/internal/llm/agent/task.go index 034e93460..c196cb107 100644 --- a/internal/llm/agent/task.go +++ b/internal/llm/agent/task.go @@ -24,7 +24,8 @@ func NewTaskAgent(app *app.App) (Agent, error) { return nil, errors.New("model not supported") } - agentProvider, titleGenerator, err := getAgentProviders(app.Context, model) + ctx := context.Background() + agentProvider, titleGenerator, err := getAgentProviders(ctx, model) if err != nil { return nil, err } diff --git a/internal/message/message.go b/internal/message/message.go index 06dae13a5..2871780a7 100644 --- a/internal/message/message.go +++ b/internal/message/message.go @@ -20,34 +20,32 @@ type CreateMessageParams struct { type Service interface { pubsub.Suscriber[Message] - Create(sessionID string, params CreateMessageParams) (Message, error) - Update(message Message) error - Get(id string) (Message, error) - List(sessionID string) ([]Message, error) - Delete(id string) error - DeleteSessionMessages(sessionID string) error + Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) + Update(ctx context.Context, message Message) error + Get(ctx context.Context, id string) (Message, error) + List(ctx context.Context, sessionID string) ([]Message, error) + Delete(ctx context.Context, id string) error + DeleteSessionMessages(ctx context.Context, sessionID string) error } type service struct { *pubsub.Broker[Message] - q db.Querier - ctx context.Context + q db.Querier } -func NewService(ctx context.Context, q db.Querier) Service { +func NewService(q db.Querier) Service { return &service{ Broker: pubsub.NewBroker[Message](), q: q, - ctx: ctx, } } -func (s *service) Delete(id string) error { - message, err := s.Get(id) +func (s *service) Delete(ctx context.Context, id string) error { + message, err := s.Get(ctx, id) if err != nil { return err } - err = s.q.DeleteMessage(s.ctx, message.ID) + err = s.q.DeleteMessage(ctx, message.ID) if err != nil { return err } @@ -55,7 +53,7 @@ func (s *service) Delete(id string) error { return nil } -func (s *service) Create(sessionID string, params CreateMessageParams) (Message, error) { +func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) { if params.Role != Assistant { params.Parts = append(params.Parts, Finish{ Reason: "stop", @@ -66,7 +64,7 @@ func (s *service) Create(sessionID string, params CreateMessageParams) (Message, return Message{}, err } - dbMessage, err := s.q.CreateMessage(s.ctx, db.CreateMessageParams{ + dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{ ID: uuid.New().String(), SessionID: sessionID, Role: string(params.Role), @@ -84,14 +82,14 @@ func (s *service) Create(sessionID string, params CreateMessageParams) (Message, return message, nil } -func (s *service) DeleteSessionMessages(sessionID string) error { - messages, err := s.List(sessionID) +func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error { + messages, err := s.List(ctx, sessionID) if err != nil { return err } for _, message := range messages { if message.SessionID == sessionID { - err = s.Delete(message.ID) + err = s.Delete(ctx, message.ID) if err != nil { return err } @@ -100,7 +98,7 @@ func (s *service) DeleteSessionMessages(sessionID string) error { return nil } -func (s *service) Update(message Message) error { +func (s *service) Update(ctx context.Context, message Message) error { parts, err := marshallParts(message.Parts) if err != nil { return err @@ -110,7 +108,7 @@ func (s *service) Update(message Message) error { finishedAt.Int64 = f.Time finishedAt.Valid = true } - err = s.q.UpdateMessage(s.ctx, db.UpdateMessageParams{ + err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{ ID: message.ID, Parts: string(parts), FinishedAt: finishedAt, @@ -122,16 +120,16 @@ func (s *service) Update(message Message) error { return nil } -func (s *service) Get(id string) (Message, error) { - dbMessage, err := s.q.GetMessage(s.ctx, id) +func (s *service) Get(ctx context.Context, id string) (Message, error) { + dbMessage, err := s.q.GetMessage(ctx, id) if err != nil { return Message{}, err } return s.fromDBItem(dbMessage) } -func (s *service) List(sessionID string) ([]Message, error) { - dbMessages, err := s.q.ListMessagesBySession(s.ctx, sessionID) +func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) { + dbMessages, err := s.q.ListMessagesBySession(ctx, sessionID) if err != nil { return nil, err } diff --git a/internal/session/session.go b/internal/session/session.go index 13f420b7c..9a16224c3 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -23,22 +23,21 @@ type Session struct { type Service interface { pubsub.Suscriber[Session] - Create(title string) (Session, error) - CreateTaskSession(toolCallID, parentSessionID, title string) (Session, error) - Get(id string) (Session, error) - List() ([]Session, error) - Save(session Session) (Session, error) - Delete(id string) error + Create(ctx context.Context, title string) (Session, error) + CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) + Get(ctx context.Context, id string) (Session, error) + List(ctx context.Context) ([]Session, error) + Save(ctx context.Context, session Session) (Session, error) + Delete(ctx context.Context, id string) error } type service struct { *pubsub.Broker[Session] - q db.Querier - ctx context.Context + q db.Querier } -func (s *service) Create(title string) (Session, error) { - dbSession, err := s.q.CreateSession(s.ctx, db.CreateSessionParams{ +func (s *service) Create(ctx context.Context, title string) (Session, error) { + dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{ ID: uuid.New().String(), Title: title, }) @@ -50,8 +49,8 @@ func (s *service) Create(title string) (Session, error) { return session, nil } -func (s *service) CreateTaskSession(toolCallID, parentSessionID, title string) (Session, error) { - dbSession, err := s.q.CreateSession(s.ctx, db.CreateSessionParams{ +func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) { + dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{ ID: toolCallID, ParentSessionID: sql.NullString{String: parentSessionID, Valid: true}, Title: title, @@ -64,12 +63,12 @@ func (s *service) CreateTaskSession(toolCallID, parentSessionID, title string) ( return session, nil } -func (s *service) Delete(id string) error { - session, err := s.Get(id) +func (s *service) Delete(ctx context.Context, id string) error { + session, err := s.Get(ctx, id) if err != nil { return err } - err = s.q.DeleteSession(s.ctx, session.ID) + err = s.q.DeleteSession(ctx, session.ID) if err != nil { return err } @@ -77,16 +76,16 @@ func (s *service) Delete(id string) error { return nil } -func (s *service) Get(id string) (Session, error) { - dbSession, err := s.q.GetSessionByID(s.ctx, id) +func (s *service) Get(ctx context.Context, id string) (Session, error) { + dbSession, err := s.q.GetSessionByID(ctx, id) if err != nil { return Session{}, err } return s.fromDBItem(dbSession), nil } -func (s *service) Save(session Session) (Session, error) { - dbSession, err := s.q.UpdateSession(s.ctx, db.UpdateSessionParams{ +func (s *service) Save(ctx context.Context, session Session) (Session, error) { + dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{ ID: session.ID, Title: session.Title, PromptTokens: session.PromptTokens, @@ -101,8 +100,8 @@ func (s *service) Save(session Session) (Session, error) { return session, nil } -func (s *service) List() ([]Session, error) { - dbSessions, err := s.q.ListSessions(s.ctx) +func (s *service) List(ctx context.Context) ([]Session, error) { + dbSessions, err := s.q.ListSessions(ctx) if err != nil { return nil, err } @@ -127,11 +126,10 @@ func (s service) fromDBItem(item db.Session) Session { } } -func NewService(ctx context.Context, q db.Querier) Service { +func NewService(q db.Querier) Service { broker := pubsub.NewBroker[Session]() return &service{ broker, q, - ctx, } } diff --git a/internal/tui/components/chat/messages.go b/internal/tui/components/chat/messages.go index b5a361392..dc21fca29 100644 --- a/internal/tui/components/chat/messages.go +++ b/internal/tui/components/chat/messages.go @@ -1,6 +1,7 @@ package chat import ( + "context" "encoding/json" "fmt" "math" @@ -324,7 +325,7 @@ func (m *messagesCmp) renderToolCall(toolCall message.ToolCall, isNested bool) s innerToolCalls := make([]string, 0) if toolCall.Name == agent.AgentToolName { - messages, _ := m.app.Messages.List(toolCall.ID) + messages, _ := m.app.Messages.List(context.Background(), toolCall.ID) toolCalls := make([]message.ToolCall, 0) for _, v := range messages { toolCalls = append(toolCalls, v.ToolCalls()...) @@ -554,7 +555,7 @@ func (m *messagesCmp) GetSize() (int, int) { func (m *messagesCmp) SetSession(session session.Session) tea.Cmd { m.session = session - messages, err := m.app.Messages.List(session.ID) + messages, err := m.app.Messages.List(context.Background(), session.ID) if err != nil { return util.ReportError(err) } diff --git a/internal/tui/components/repl/editor.go b/internal/tui/components/repl/editor.go index e9493129d..b1e39e655 100644 --- a/internal/tui/components/repl/editor.go +++ b/internal/tui/components/repl/editor.go @@ -160,7 +160,7 @@ func (m *editorCmp) Send() tea.Cmd { return util.ReportWarn("Assistant is still working on the previous message") } - messages, err := m.app.Messages.List(m.sessionID) + messages, err := m.app.Messages.List(context.Background(), m.sessionID) if err != nil { return util.ReportError(err) } @@ -177,7 +177,7 @@ func (m *editorCmp) Send() tea.Cmd { if len(content) == 0 { return util.ReportWarn("Message is empty") } - ctx, cancel := context.WithCancel(m.app.Context) + ctx, cancel := context.WithCancel(context.Background()) m.cancelMessage = cancel go func() { defer cancel() diff --git a/internal/tui/components/repl/messages.go b/internal/tui/components/repl/messages.go index 57a55c579..260be220e 100644 --- a/internal/tui/components/repl/messages.go +++ b/internal/tui/components/repl/messages.go @@ -1,6 +1,7 @@ package repl import ( + "context" "encoding/json" "fmt" "sort" @@ -77,8 +78,8 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.session = msg.Payload } case SelectedSessionMsg: - m.session, _ = m.app.Sessions.Get(msg.SessionID) - m.messages, _ = m.app.Messages.List(m.session.ID) + m.session, _ = m.app.Sessions.Get(context.Background(), msg.SessionID) + m.messages, _ = m.app.Messages.List(context.Background(), m.session.ID) m.renderView() m.viewport.GotoBottom() } @@ -259,7 +260,7 @@ func (m *messagesCmp) renderMessageWithToolCall(content string, tools []message. runningIndicator := runningStyle.Render(fmt.Sprintf("%s Running...", styles.SpinnerIcon)) allParts = append(allParts, leftPadding.Render(runningIndicator)) - taskSessionMessages, _ := m.app.Messages.List(toolCall.ID) + taskSessionMessages, _ := m.app.Messages.List(context.Background(), toolCall.ID) for _, msg := range taskSessionMessages { if msg.Role == message.Assistant { for _, toolCall := range msg.ToolCalls() { diff --git a/internal/tui/components/repl/sessions.go b/internal/tui/components/repl/sessions.go index 093337b18..c83c40367 100644 --- a/internal/tui/components/repl/sessions.go +++ b/internal/tui/components/repl/sessions.go @@ -1,6 +1,7 @@ package repl import ( + "context" "fmt" "strings" @@ -57,12 +58,13 @@ var sessionKeyMapValue = sessionsKeyMap{ } func (i *sessionsCmp) Init() tea.Cmd { - existing, err := i.app.Sessions.List() + existing, err := i.app.Sessions.List(context.Background()) if err != nil { return util.ReportError(err) } if len(existing) == 0 || existing[0].MessageCount > 0 { newSession, err := i.app.Sessions.Create( + context.Background(), "New Session", ) if err != nil { diff --git a/internal/tui/page/chat.go b/internal/tui/page/chat.go index a7a51bb84..9b9924909 100644 --- a/internal/tui/page/chat.go +++ b/internal/tui/page/chat.go @@ -1,6 +1,8 @@ package page import ( + "context" + "github.com/charmbracelet/bubbles/key" tea "github.com/charmbracelet/bubbletea" "github.com/kujtimiihoxha/termai/internal/app" @@ -36,7 +38,7 @@ func (p *chatPage) Init() tea.Cmd { p.layout.Init(), } - sessions, _ := p.app.Sessions.List() + sessions, _ := p.app.Sessions.List(context.Background()) if len(sessions) > 0 { p.session = sessions[0] cmd := p.setSidebar() @@ -92,7 +94,7 @@ func (p *chatPage) clearSidebar() { func (p *chatPage) sendMessage(text string) tea.Cmd { var cmds []tea.Cmd if p.session.ID == "" { - session, err := p.app.Sessions.Create("New Session") + session, err := p.app.Sessions.Create(context.Background(), "New Session") if err != nil { return util.ReportError(err) } @@ -110,7 +112,7 @@ func (p *chatPage) sendMessage(text string) tea.Cmd { return util.ReportError(err) } go func() { - a.Generate(p.app.Context, p.session.ID, text) + a.Generate(context.Background(), p.session.ID, text) }() return tea.Batch(cmds...) diff --git a/internal/tui/tui.go b/internal/tui/tui.go index db9ac9ff6..1b1a1ed50 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -1,6 +1,8 @@ package tui import ( + "context" + "github.com/charmbracelet/bubbles/key" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" @@ -184,7 +186,7 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } case key.Matches(msg, replKeyMap): if a.currentPage == page.ReplPage { - sessions, err := a.app.Sessions.List() + sessions, err := a.app.Sessions.List(context.Background()) if err != nil { return a, util.CmdHandler(util.ReportError(err)) } @@ -192,7 +194,7 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if lastSession.MessageCount == 0 { return a, util.CmdHandler(repl.SelectedSessionMsg{SessionID: lastSession.ID}) } - s, err := a.app.Sessions.Create("New Session") + s, err := a.app.Sessions.Create(context.Background(), "New Session") if err != nil { return a, util.CmdHandler(util.ReportError(err)) } -- cgit v1.2.3 From cdc5f209dccdc980714f2ca1aeb52133d6e93cce Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Sun, 13 Apr 2025 14:37:05 +0200 Subject: cleanup diff, cleanup agent --- README.md | 14 +- cmd/diff/main.go | 102 ------- cmd/root.go | 12 +- go.mod | 2 +- internal/app/app.go | 20 +- internal/assets/diff/themes/dark.json | 73 ----- internal/assets/embed.go | 6 - internal/assets/write.go | 60 ---- internal/git/diff.go | 35 ++- internal/llm/agent/agent-tool.go | 34 ++- internal/llm/agent/agent.go | 522 +++++++++++++++++++++------------ internal/llm/agent/coder.go | 83 +++--- internal/llm/agent/task.go | 7 +- internal/llm/provider/provider.go | 4 +- internal/llm/tools/edit.go | 7 +- internal/llm/tools/tools.go | 2 +- internal/llm/tools/write.go | 2 +- internal/tui/components/repl/editor.go | 8 +- internal/tui/page/chat.go | 15 +- 19 files changed, 456 insertions(+), 552 deletions(-) delete mode 100644 cmd/diff/main.go delete mode 100644 internal/assets/diff/themes/dark.json delete mode 100644 internal/assets/embed.go delete mode 100644 internal/assets/write.go (limited to 'cmd') diff --git a/README.md b/README.md index ebef72cad..23a1906a1 100644 --- a/README.md +++ b/README.md @@ -53,6 +53,7 @@ termai -d ### Keyboard Shortcuts #### Global Shortcuts + - `?`: Toggle help panel - `Ctrl+C` or `q`: Quit application - `L`: View logs @@ -60,10 +61,12 @@ termai -d - `Esc`: Close current view/dialog or return to normal mode #### Session Management + - `N`: Create new session - `Enter` or `Space`: Select session (in sessions list) #### Editor Shortcuts (Vim-like) + - `i`: Enter insert mode - `Esc`: Enter normal mode - `v`: Enter visual mode @@ -72,6 +75,7 @@ termai -d - `Ctrl+S`: Send message (in insert mode) #### Navigation + - Arrow keys: Navigate through lists and content - Page Up/Down: Scroll through content @@ -112,16 +116,6 @@ go build -o termai ./termai ``` -### Important: Building the Diff Script - -Before building or running the application, you must first build the diff script by running: - -```bash -go run cmd/diff/main.go -``` - -This command generates the necessary JavaScript file (`index.mjs`) used by the diff functionality in the application. - ## Acknowledgments TermAI builds upon the work of several open source projects and developers: diff --git a/cmd/diff/main.go b/cmd/diff/main.go deleted file mode 100644 index da93e4660..000000000 --- a/cmd/diff/main.go +++ /dev/null @@ -1,102 +0,0 @@ -package main - -import ( - "fmt" - "io" - "os" - "os/exec" - "path/filepath" -) - -func main() { - // Create a temporary directory - tempDir, err := os.MkdirTemp("", "git-split-diffs") - if err != nil { - fmt.Printf("Error creating temp directory: %v\n", err) - os.Exit(1) - } - defer func() { - fmt.Printf("Cleaning up temporary directory: %s\n", tempDir) - os.RemoveAll(tempDir) - }() - fmt.Printf("Created temporary directory: %s\n", tempDir) - - // Clone the repository with minimum depth - fmt.Println("Cloning git-split-diffs repository with minimum depth...") - cmd := exec.Command("git", "clone", "--depth=1", "https://github.com/kujtimiihoxha/git-split-diffs", tempDir) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - if err := cmd.Run(); err != nil { - fmt.Printf("Error cloning repository: %v\n", err) - os.Exit(1) - } - - // Run npm install - fmt.Println("Running npm install...") - cmdNpmInstall := exec.Command("npm", "install") - cmdNpmInstall.Dir = tempDir - cmdNpmInstall.Stdout = os.Stdout - cmdNpmInstall.Stderr = os.Stderr - if err := cmdNpmInstall.Run(); err != nil { - fmt.Printf("Error running npm install: %v\n", err) - os.Exit(1) - } - - // Run npm run build - fmt.Println("Running npm run build...") - cmdNpmBuild := exec.Command("npm", "run", "build") - cmdNpmBuild.Dir = tempDir - cmdNpmBuild.Stdout = os.Stdout - cmdNpmBuild.Stderr = os.Stderr - if err := cmdNpmBuild.Run(); err != nil { - fmt.Printf("Error running npm run build: %v\n", err) - os.Exit(1) - } - - destDir := filepath.Join(".", "internal", "assets", "diff") - destFile := filepath.Join(destDir, "index.mjs") - - // Make sure the destination directory exists - if err := os.MkdirAll(destDir, 0o755); err != nil { - fmt.Printf("Error creating destination directory: %v\n", err) - os.Exit(1) - } - - // Copy the file - srcFile := filepath.Join(tempDir, "build", "index.mjs") - fmt.Printf("Copying %s to %s\n", srcFile, destFile) - if err := copyFile(srcFile, destFile); err != nil { - fmt.Printf("Error copying file: %v\n", err) - os.Exit(1) - } - - fmt.Println("Successfully completed the process!") -} - -// copyFile copies a file from src to dst -func copyFile(src, dst string) error { - sourceFile, err := os.Open(src) - if err != nil { - return err - } - defer sourceFile.Close() - - destFile, err := os.Create(dst) - if err != nil { - return err - } - defer destFile.Close() - - _, err = io.Copy(destFile, sourceFile) - if err != nil { - return err - } - - // Make sure the file is written to disk - err = destFile.Sync() - if err != nil { - return err - } - - return nil -} diff --git a/cmd/root.go b/cmd/root.go index 092606de7..a2e63006f 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -9,7 +9,6 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/kujtimiihoxha/termai/internal/app" - "github.com/kujtimiihoxha/termai/internal/assets" "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/db" "github.com/kujtimiihoxha/termai/internal/llm/agent" @@ -52,11 +51,6 @@ var rootCmd = &cobra.Command{ return err } - err = assets.WriteAssets() - if err != nil { - logging.Error("Error writing assets: %v", err) - } - // Connect DB, this will also run migrations conn, err := db.Connect() if err != nil { @@ -67,7 +61,11 @@ var rootCmd = &cobra.Command{ ctx, cancel := context.WithCancel(context.Background()) defer cancel() - app := app.New(ctx, conn) + app, err := app.New(ctx, conn) + if err != nil { + logging.Error("Failed to create app: %v", err) + return err + } // Set up the TUI zone.NewGlobal() diff --git a/go.mod b/go.mod index 617dad3a1..e3dc2bd96 100644 --- a/go.mod +++ b/go.mod @@ -31,7 +31,6 @@ require ( github.com/muesli/reflow v0.3.0 github.com/muesli/termenv v0.16.0 github.com/openai/openai-go v0.1.0-beta.2 - github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 github.com/spf13/cobra v1.9.1 github.com/spf13/viper v1.20.0 github.com/stretchr/testify v1.10.0 @@ -107,6 +106,7 @@ require ( github.com/rivo/uniseg v0.4.7 // indirect github.com/sagikazarmark/locafero v0.7.0 // indirect github.com/sahilm/fuzzy v0.1.1 // indirect + github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect github.com/skeema/knownhosts v1.3.1 // indirect github.com/sourcegraph/conc v0.3.0 // indirect github.com/spf13/afero v1.12.0 // indirect diff --git a/internal/app/app.go b/internal/app/app.go index fa4a6ee90..9f575cac3 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -9,6 +9,7 @@ import ( "github.com/kujtimiihoxha/termai/internal/db" "github.com/kujtimiihoxha/termai/internal/history" + "github.com/kujtimiihoxha/termai/internal/llm/agent" "github.com/kujtimiihoxha/termai/internal/logging" "github.com/kujtimiihoxha/termai/internal/lsp" "github.com/kujtimiihoxha/termai/internal/message" @@ -22,6 +23,8 @@ type App struct { Files history.Service Permissions permission.Service + CoderAgent agent.Service + LSPClients map[string]*lsp.Client clientsMutex sync.RWMutex @@ -31,7 +34,7 @@ type App struct { watcherWG sync.WaitGroup } -func New(ctx context.Context, conn *sql.DB) *App { +func New(ctx context.Context, conn *sql.DB) (*App, error) { q := db.New(conn) sessions := session.NewService(q) messages := message.NewService(q) @@ -45,9 +48,22 @@ func New(ctx context.Context, conn *sql.DB) *App { LSPClients: make(map[string]*lsp.Client), } + var err error + app.CoderAgent, err = agent.NewCoderAgent( + + app.Permissions, + app.Sessions, + app.Messages, + app.LSPClients, + ) + if err != nil { + logging.Error("Failed to create coder agent", err) + return nil, err + } + app.initLSPClients(ctx) - return app + return app, nil } // Shutdown performs a clean shutdown of the application diff --git a/internal/assets/diff/themes/dark.json b/internal/assets/diff/themes/dark.json deleted file mode 100644 index 05c18e08c..000000000 --- a/internal/assets/diff/themes/dark.json +++ /dev/null @@ -1,73 +0,0 @@ -{ - "SYNTAX_HIGHLIGHTING_THEME": "dark-plus", - "DEFAULT_COLOR": { - "color": "#ffffff", - "backgroundColor": "#212121" - }, - "COMMIT_HEADER_COLOR": { - "color": "#cccccc" - }, - "COMMIT_HEADER_LABEL_COLOR": { - "color": "#00000022" - }, - "COMMIT_SHA_COLOR": { - "color": "#00eeaa" - }, - "COMMIT_AUTHOR_COLOR": { - "color": "#00aaee" - }, - "COMMIT_DATE_COLOR": { - "color": "#cccccc" - }, - "COMMIT_MESSAGE_COLOR": { - "color": "#cccccc" - }, - "COMMIT_TITLE_COLOR": { - "modifiers": [ - "bold" - ] - }, - "FILE_NAME_COLOR": { - "color": "#ffdd99" - }, - "BORDER_COLOR": { - "color": "#ffdd9966", - "modifiers": [ - "dim" - ] - }, - "HUNK_HEADER_COLOR": { - "modifiers": [ - "dim" - ] - }, - "DELETED_WORD_COLOR": { - "color": "#ffcccc", - "backgroundColor": "#ff000033" - }, - "INSERTED_WORD_COLOR": { - "color": "#ccffcc", - "backgroundColor": "#00ff0033" - }, - "DELETED_LINE_NO_COLOR": { - "color": "#00000022", - "backgroundColor": "#00000022" - }, - "INSERTED_LINE_NO_COLOR": { - "color": "#00000022", - "backgroundColor": "#00000022" - }, - "UNMODIFIED_LINE_NO_COLOR": { - "color": "#666666" - }, - "DELETED_LINE_COLOR": { - "color": "#cc6666", - "backgroundColor": "#3a3030" - }, - "INSERTED_LINE_COLOR": { - "color": "#66cc66", - "backgroundColor": "#303a30" - }, - "UNMODIFIED_LINE_COLOR": {}, - "MISSING_LINE_COLOR": {} -} diff --git a/internal/assets/embed.go b/internal/assets/embed.go deleted file mode 100644 index 9e1316d08..000000000 --- a/internal/assets/embed.go +++ /dev/null @@ -1,6 +0,0 @@ -package assets - -import "embed" - -//go:embed diff -var FS embed.FS diff --git a/internal/assets/write.go b/internal/assets/write.go deleted file mode 100644 index 602b589ce..000000000 --- a/internal/assets/write.go +++ /dev/null @@ -1,60 +0,0 @@ -package assets - -import ( - "os" - "path/filepath" - - "github.com/kujtimiihoxha/termai/internal/config" -) - -func WriteAssets() error { - appCfg := config.Get() - appWd := config.WorkingDirectory() - scriptDir := filepath.Join( - appWd, - appCfg.Data.Directory, - "diff", - ) - scriptPath := filepath.Join(scriptDir, "index.mjs") - // Before, run the script in cmd/diff/main.go to build this file - if _, err := os.Stat(scriptPath); err != nil { - scriptData, err := FS.ReadFile("diff/index.mjs") - if err != nil { - return err - } - - err = os.MkdirAll(scriptDir, 0o755) - if err != nil { - return err - } - err = os.WriteFile(scriptPath, scriptData, 0o755) - if err != nil { - return err - } - } - - themeDir := filepath.Join( - appWd, - appCfg.Data.Directory, - "themes", - ) - - themePath := filepath.Join(themeDir, "dark.json") - - if _, err := os.Stat(themePath); err != nil { - themeData, err := FS.ReadFile("diff/themes/dark.json") - if err != nil { - return err - } - - err = os.MkdirAll(themeDir, 0o755) - if err != nil { - return err - } - err = os.WriteFile(themePath, themeData, 0o755) - if err != nil { - return err - } - } - return nil -} diff --git a/internal/git/diff.go b/internal/git/diff.go index d87956f01..2ab139642 100644 --- a/internal/git/diff.go +++ b/internal/git/diff.go @@ -11,7 +11,6 @@ import ( "github.com/go-git/go-git/v5" "github.com/go-git/go-git/v5/plumbing/object" - "github.com/kujtimiihoxha/termai/internal/config" ) type DiffStats struct { @@ -197,32 +196,32 @@ func isSplitDiffsAvailable() bool { } func formatWithSplitDiffs(diffText string, width int) (string, error) { - var cmd *exec.Cmd + args := []string{ + "--color", + } - appCfg := config.Get() - appWd := config.WorkingDirectory() - script := filepath.Join( - appWd, - appCfg.Data.Directory, - "diff", - "index.mjs", - ) + var diffCmd *exec.Cmd - cmd = exec.Command("node", script, "--color") + if _, err := exec.LookPath("git-split-diffs-opencode"); err == nil { + fullArgs := append([]string{"git-split-diffs-opencode"}, args...) + diffCmd = exec.Command(fullArgs[0], fullArgs[1:]...) + } else { + npxArgs := append([]string{"git-split-diffs-opencode"}, args...) + diffCmd = exec.Command("npx", npxArgs...) + } - cmd.Env = append(os.Environ(), fmt.Sprintf("COLUMNS=%d", width)) + diffCmd.Env = append(os.Environ(), fmt.Sprintf("DIFF_COLUMNS=%d", width)) - cmd.Stdin = strings.NewReader(diffText) + diffCmd.Stdin = strings.NewReader(diffText) var out bytes.Buffer - cmd.Stdout = &out + diffCmd.Stdout = &out var stderr bytes.Buffer - cmd.Stderr = &stderr + diffCmd.Stderr = &stderr - err := cmd.Run() - if err != nil { - return "", fmt.Errorf("git-split-diffs error: %v, stderr: %s", err, stderr.String()) + if err := diffCmd.Run(); err != nil { + return "", fmt.Errorf("git-split-diffs-opencode error: %w, stderr: %s", err, stderr.String()) } return out.String(), nil diff --git a/internal/llm/agent/agent-tool.go b/internal/llm/agent/agent-tool.go index 91c46da8b..a9c6f93a7 100644 --- a/internal/llm/agent/agent-tool.go +++ b/internal/llm/agent/agent-tool.go @@ -5,14 +5,16 @@ import ( "encoding/json" "fmt" - "github.com/kujtimiihoxha/termai/internal/app" "github.com/kujtimiihoxha/termai/internal/llm/tools" + "github.com/kujtimiihoxha/termai/internal/lsp" "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/termai/internal/session" ) type agentTool struct { - parentSessionID string - app *app.App + sessions session.Service + messages message.Service + lspClients map[string]*lsp.Client } const ( @@ -46,12 +48,17 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes return tools.NewTextErrorResponse("prompt is required"), nil } - agent, err := NewTaskAgent(b.app) + sessionID, messageID := tools.GetContextValues(ctx) + if sessionID == "" || messageID == "" { + return tools.NewTextErrorResponse("session ID and message ID are required"), nil + } + + agent, err := NewTaskAgent(b.lspClients) if err != nil { return tools.NewTextErrorResponse(fmt.Sprintf("error creating agent: %s", err)), nil } - session, err := b.app.Sessions.CreateTaskSession(ctx, call.ID, b.parentSessionID, "New Agent Session") + session, err := b.sessions.CreateTaskSession(ctx, call.ID, sessionID, "New Agent Session") if err != nil { return tools.NewTextErrorResponse(fmt.Sprintf("error creating session: %s", err)), nil } @@ -61,7 +68,7 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes return tools.NewTextErrorResponse(fmt.Sprintf("error generating agent: %s", err)), nil } - messages, err := b.app.Messages.List(ctx, session.ID) + messages, err := b.messages.List(ctx, session.ID) if err != nil { return tools.NewTextErrorResponse(fmt.Sprintf("error listing messages: %s", err)), nil } @@ -74,11 +81,11 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes return tools.NewTextErrorResponse("no assistant message found"), nil } - updatedSession, err := b.app.Sessions.Get(ctx, session.ID) + updatedSession, err := b.sessions.Get(ctx, session.ID) if err != nil { return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil } - parentSession, err := b.app.Sessions.Get(ctx, b.parentSessionID) + parentSession, err := b.sessions.Get(ctx, sessionID) if err != nil { return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil } @@ -87,16 +94,19 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes parentSession.PromptTokens += updatedSession.PromptTokens parentSession.CompletionTokens += updatedSession.CompletionTokens - _, err = b.app.Sessions.Save(ctx, parentSession) + _, err = b.sessions.Save(ctx, parentSession) if err != nil { return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil } return tools.NewTextResponse(response.Content().String()), nil } -func NewAgentTool(parentSessionID string, app *app.App) tools.BaseTool { +func NewAgentTool( + Sessions session.Service, + Messages message.Service, +) tools.BaseTool { return &agentTool{ - parentSessionID: parentSessionID, - app: app, + sessions: Sessions, + messages: Messages, } } diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index b7c736e6c..997004e12 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -7,7 +7,6 @@ import ( "strings" "sync" - "github.com/kujtimiihoxha/termai/internal/app" "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/llm/models" "github.com/kujtimiihoxha/termai/internal/llm/prompt" @@ -15,22 +14,118 @@ import ( "github.com/kujtimiihoxha/termai/internal/llm/tools" "github.com/kujtimiihoxha/termai/internal/logging" "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/termai/internal/session" ) -type Agent interface { +// Common errors +var ( + ErrProviderNotEnabled = errors.New("provider is not enabled") + ErrRequestCancelled = errors.New("request cancelled by user") + ErrSessionBusy = errors.New("session is currently processing another request") +) + +// Service defines the interface for generating responses +type Service interface { Generate(ctx context.Context, sessionID string, content string) error + Cancel(sessionID string) error } type agent struct { - *app.App + sessions session.Service + messages message.Service model models.Model tools []tools.BaseTool agent provider.Provider titleGenerator provider.Provider + activeRequests sync.Map // map[sessionID]context.CancelFunc +} + +// NewAgent creates a new agent instance with the given model and tools +func NewAgent(ctx context.Context, sessions session.Service, messages message.Service, model models.Model, tools []tools.BaseTool) (Service, error) { + agentProvider, titleGenerator, err := getAgentProviders(ctx, model) + if err != nil { + return nil, fmt.Errorf("failed to initialize providers: %w", err) + } + + return &agent{ + model: model, + tools: tools, + sessions: sessions, + messages: messages, + agent: agentProvider, + titleGenerator: titleGenerator, + activeRequests: sync.Map{}, + }, nil +} + +// Cancel cancels an active request by session ID +func (a *agent) Cancel(sessionID string) error { + if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists { + if cancel, ok := cancelFunc.(context.CancelFunc); ok { + logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID)) + cancel() + return nil + } + } + return errors.New("no active request found for this session") } -func (c *agent) handleTitleGeneration(ctx context.Context, sessionID, content string) { - response, err := c.titleGenerator.SendMessages( +// Generate starts the generation process +func (a *agent) Generate(ctx context.Context, sessionID string, content string) error { + // Check if this session already has an active request + if _, busy := a.activeRequests.Load(sessionID); busy { + return ErrSessionBusy + } + + // Create a cancellable context + genCtx, cancel := context.WithCancel(ctx) + + // Store cancel function to allow user cancellation + a.activeRequests.Store(sessionID, cancel) + + // Launch the generation in a goroutine + go func() { + defer func() { + if r := recover(); r != nil { + logging.ErrorPersist(fmt.Sprintf("Panic in Generate: %v", r)) + } + }() + defer a.activeRequests.Delete(sessionID) + defer cancel() + + if err := a.generate(genCtx, sessionID, content); err != nil { + if !errors.Is(err, ErrRequestCancelled) && !errors.Is(err, context.Canceled) { + // Log the error (avoid logging cancellations as they're expected) + logging.ErrorPersist(fmt.Sprintf("Generation error for session %s: %v", sessionID, err)) + + // You may want to create an error message in the chat + bgCtx := context.Background() + errorMsg := fmt.Sprintf("Sorry, an error occurred: %v", err) + _, createErr := a.messages.Create(bgCtx, sessionID, message.CreateMessageParams{ + Role: message.System, + Parts: []message.ContentPart{ + message.TextContent{ + Text: errorMsg, + }, + }, + }) + if createErr != nil { + logging.ErrorPersist(fmt.Sprintf("Failed to create error message: %v", createErr)) + } + } + } + }() + + return nil +} + +// IsSessionBusy checks if a session currently has an active request +func (a *agent) IsSessionBusy(sessionID string) bool { + _, busy := a.activeRequests.Load(sessionID) + return busy +} // handleTitleGeneration asynchronously generates a title for new sessions +func (a *agent) handleTitleGeneration(ctx context.Context, sessionID, content string) { + response, err := a.titleGenerator.SendMessages( ctx, []message.Message{ { @@ -45,25 +140,30 @@ func (c *agent) handleTitleGeneration(ctx context.Context, sessionID, content st nil, ) if err != nil { + logging.ErrorPersist(fmt.Sprintf("Failed to generate title: %v", err)) return } - session, err := c.Sessions.Get(ctx, sessionID) + session, err := a.sessions.Get(ctx, sessionID) if err != nil { + logging.ErrorPersist(fmt.Sprintf("Failed to get session: %v", err)) return } + if response.Content != "" { - session.Title = response.Content - session.Title = strings.TrimSpace(session.Title) + session.Title = strings.TrimSpace(response.Content) session.Title = strings.ReplaceAll(session.Title, "\n", " ") - c.Sessions.Save(ctx, session) + if _, err := a.sessions.Save(ctx, session); err != nil { + logging.ErrorPersist(fmt.Sprintf("Failed to save session title: %v", err)) + } } } -func (c *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error { - session, err := c.Sessions.Get(ctx, sessionID) +// TrackUsage updates token usage statistics for the session +func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error { + session, err := a.sessions.Get(ctx, sessionID) if err != nil { - return err + return fmt.Errorf("failed to get session: %w", err) } cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) + @@ -75,189 +175,241 @@ func (c *agent) TrackUsage(ctx context.Context, sessionID string, model models.M session.CompletionTokens += usage.OutputTokens session.PromptTokens += usage.InputTokens - _, err = c.Sessions.Save(ctx, session) - return err + _, err = a.sessions.Save(ctx, session) + if err != nil { + return fmt.Errorf("failed to save session: %w", err) + } + return nil } -func (c *agent) processEvent( +// processEvent handles different types of events during generation +func (a *agent) processEvent( ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent, ) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + // Continue processing + } + switch event.Type { case provider.EventThinkingDelta: assistantMsg.AppendReasoningContent(event.Content) - return c.Messages.Update(ctx, *assistantMsg) + return a.messages.Update(ctx, *assistantMsg) case provider.EventContentDelta: assistantMsg.AppendContent(event.Content) - return c.Messages.Update(ctx, *assistantMsg) + return a.messages.Update(ctx, *assistantMsg) case provider.EventError: if errors.Is(event.Error, context.Canceled) { - return nil + logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID)) + return context.Canceled } logging.ErrorPersist(event.Error.Error()) return event.Error case provider.EventWarning: logging.WarnPersist(event.Info) - return nil case provider.EventInfo: logging.InfoPersist(event.Info) case provider.EventComplete: assistantMsg.SetToolCalls(event.Response.ToolCalls) assistantMsg.AddFinish(event.Response.FinishReason) - err := c.Messages.Update(ctx, *assistantMsg) - if err != nil { - return err + if err := a.messages.Update(ctx, *assistantMsg); err != nil { + return fmt.Errorf("failed to update message: %w", err) } - return c.TrackUsage(ctx, sessionID, c.model, event.Response.Usage) + return a.TrackUsage(ctx, sessionID, a.model, event.Response.Usage) } return nil } -func (c *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) { - var wg sync.WaitGroup +// ExecuteTools runs all tool calls sequentially and returns the results +func (a *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) { toolResults := make([]message.ToolResult, len(toolCalls)) - mutex := &sync.Mutex{} - errChan := make(chan error, 1) // Create a child context that can be canceled ctx, cancel := context.WithCancel(ctx) defer cancel() - for i, tc := range toolCalls { - wg.Add(1) - go func(index int, toolCall message.ToolCall) { - defer wg.Done() + // Check if already canceled before starting any execution + if ctx.Err() != nil { + // Mark all tools as canceled + for i, toolCall := range toolCalls { + toolResults[i] = message.ToolResult{ + ToolCallID: toolCall.ID, + Content: "Tool execution canceled by user", + IsError: true, + } + } + return toolResults, ctx.Err() + } - // Check if context is already canceled - select { - case <-ctx.Done(): - mutex.Lock() - toolResults[index] = message.ToolResult{ - ToolCallID: toolCall.ID, - Content: "Tool execution canceled", + for i, toolCall := range toolCalls { + // Check for cancellation before executing each tool + select { + case <-ctx.Done(): + // Mark this and all remaining tools as canceled + for j := i; j < len(toolCalls); j++ { + toolResults[j] = message.ToolResult{ + ToolCallID: toolCalls[j].ID, + Content: "Tool execution canceled by user", IsError: true, } - mutex.Unlock() - - // Send cancellation error to error channel if it's empty - select { - case errChan <- ctx.Err(): - default: - } - return - default: } + return toolResults, ctx.Err() + default: + // Continue processing + } - response := "" - isError := false - found := false - - for _, tool := range tls { - if tool.Info().Name == toolCall.Name { - found = true - toolResult, toolErr := tool.Run(ctx, tools.ToolCall{ - ID: toolCall.ID, - Name: toolCall.Name, - Input: toolCall.Input, - }) - - if toolErr != nil { - if errors.Is(toolErr, context.Canceled) { - response = "Tool execution canceled" - - // Send cancellation error to error channel if it's empty - select { - case errChan <- ctx.Err(): - default: - } - } else { - response = fmt.Sprintf("error running tool: %s", toolErr) - } - isError = true + response := "" + isError := false + found := false + + // Find and execute the appropriate tool + for _, tool := range tls { + if tool.Info().Name == toolCall.Name { + found = true + toolResult, toolErr := tool.Run(ctx, tools.ToolCall{ + ID: toolCall.ID, + Name: toolCall.Name, + Input: toolCall.Input, + }) + + if toolErr != nil { + if errors.Is(toolErr, context.Canceled) { + response = "Tool execution canceled by user" } else { - response = toolResult.Content - isError = toolResult.IsError + response = fmt.Sprintf("Error running tool: %s", toolErr) } - break + isError = true + } else { + response = toolResult.Content + isError = toolResult.IsError } + break } + } - if !found { - response = fmt.Sprintf("tool not found: %s", toolCall.Name) - isError = true - } - - mutex.Lock() - defer mutex.Unlock() - - toolResults[index] = message.ToolResult{ - ToolCallID: toolCall.ID, - Content: response, - IsError: isError, - } - }(i, tc) - } - - // Wait for all goroutines to finish or context to be canceled - done := make(chan struct{}) - go func() { - wg.Wait() - close(done) - }() + if !found { + response = fmt.Sprintf("Tool not found: %s", toolCall.Name) + isError = true + } - select { - case <-done: - // All tools completed successfully - case err := <-errChan: - // One of the tools encountered a cancellation - return toolResults, err - case <-ctx.Done(): - // Context was canceled externally - return toolResults, ctx.Err() + toolResults[i] = message.ToolResult{ + ToolCallID: toolCall.ID, + Content: response, + IsError: isError, + } } return toolResults, nil } -func (c *agent) handleToolExecution( +// handleToolExecution processes tool calls and creates tool result messages +func (a *agent) handleToolExecution( ctx context.Context, assistantMsg message.Message, ) (*message.Message, error) { + select { + case <-ctx.Done(): + // If cancelled, create tool results that indicate cancellation + if len(assistantMsg.ToolCalls()) > 0 { + toolResults := make([]message.ToolResult, 0, len(assistantMsg.ToolCalls())) + for _, tc := range assistantMsg.ToolCalls() { + toolResults = append(toolResults, message.ToolResult{ + ToolCallID: tc.ID, + Content: "Tool execution canceled by user", + IsError: true, + }) + } + + // Use background context to ensure the message is created even if original context is cancelled + bgCtx := context.Background() + parts := make([]message.ContentPart, 0) + for _, toolResult := range toolResults { + parts = append(parts, toolResult) + } + msg, err := a.messages.Create(bgCtx, assistantMsg.SessionID, message.CreateMessageParams{ + Role: message.Tool, + Parts: parts, + }) + if err != nil { + return nil, fmt.Errorf("failed to create cancelled tool message: %w", err) + } + return &msg, ctx.Err() + } + return nil, ctx.Err() + default: + // Continue processing + } + if len(assistantMsg.ToolCalls()) == 0 { return nil, nil } - toolResults, err := c.ExecuteTools(ctx, assistantMsg.ToolCalls(), c.tools) + toolResults, err := a.ExecuteTools(ctx, assistantMsg.ToolCalls(), a.tools) if err != nil { + // If error is from cancellation, still return the partial results we have + if errors.Is(err, context.Canceled) { + // Use background context to ensure the message is created even if original context is cancelled + bgCtx := context.Background() + parts := make([]message.ContentPart, 0) + for _, toolResult := range toolResults { + parts = append(parts, toolResult) + } + + msg, createErr := a.messages.Create(bgCtx, assistantMsg.SessionID, message.CreateMessageParams{ + Role: message.Tool, + Parts: parts, + }) + if createErr != nil { + logging.ErrorPersist(fmt.Sprintf("Failed to create tool message after cancellation: %v", createErr)) + return nil, err + } + return &msg, err + } return nil, err } - parts := make([]message.ContentPart, 0) + + parts := make([]message.ContentPart, 0, len(toolResults)) for _, toolResult := range toolResults { parts = append(parts, toolResult) } - msg, err := c.Messages.Create(ctx, assistantMsg.SessionID, message.CreateMessageParams{ + + msg, err := a.messages.Create(ctx, assistantMsg.SessionID, message.CreateMessageParams{ Role: message.Tool, Parts: parts, }) + if err != nil { + return nil, fmt.Errorf("failed to create tool message: %w", err) + } - return &msg, err + return &msg, nil } -func (c *agent) generate(ctx context.Context, sessionID string, content string) error { +// generate handles the main generation workflow +func (a *agent) generate(ctx context.Context, sessionID string, content string) error { ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID) - messages, err := c.Messages.List(ctx, sessionID) + + // Handle context cancellation at any point + if err := ctx.Err(); err != nil { + return ErrRequestCancelled + } + + messages, err := a.messages.List(ctx, sessionID) if err != nil { - return err + return fmt.Errorf("failed to list messages: %w", err) } if len(messages) == 0 { - go c.handleTitleGeneration(ctx, sessionID, content) + titleCtx := context.Background() + go a.handleTitleGeneration(titleCtx, sessionID, content) } - userMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{ + userMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{ Role: message.User, Parts: []message.ContentPart{ message.TextContent{ @@ -266,133 +418,125 @@ func (c *agent) generate(ctx context.Context, sessionID string, content string) }, }) if err != nil { - return err + return fmt.Errorf("failed to create user message: %w", err) } messages = append(messages, userMsg) + for { + // Check for cancellation before each iteration select { case <-ctx.Done(): - assistantMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{ - Role: message.Assistant, - Parts: []message.ContentPart{}, - }) - if err != nil { - return err - } - assistantMsg.AddFinish("canceled") - c.Messages.Update(ctx, assistantMsg) - return context.Canceled + return ErrRequestCancelled default: // Continue processing } - eventChan, err := c.agent.StreamResponse(ctx, messages, c.tools) + eventChan, err := a.agent.StreamResponse(ctx, messages, a.tools) if err != nil { if errors.Is(err, context.Canceled) { - assistantMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{ - Role: message.Assistant, - Parts: []message.ContentPart{}, - }) - if err != nil { - return err - } - assistantMsg.AddFinish("canceled") - c.Messages.Update(ctx, assistantMsg) - return context.Canceled + return ErrRequestCancelled } - return err + return fmt.Errorf("failed to stream response: %w", err) } - assistantMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{ + assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{ Role: message.Assistant, Parts: []message.ContentPart{}, - Model: c.model.ID, + Model: a.model.ID, }) if err != nil { - return err + return fmt.Errorf("failed to create assistant message: %w", err) } ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID) + + // Process events from the LLM provider for event := range eventChan { - err = c.processEvent(ctx, sessionID, &assistantMsg, event) - if err != nil { + if err := a.processEvent(ctx, sessionID, &assistantMsg, event); err != nil { if errors.Is(err, context.Canceled) { + // Mark as canceled but don't create separate message assistantMsg.AddFinish("canceled") - c.Messages.Update(ctx, assistantMsg) - return context.Canceled + _ = a.messages.Update(context.Background(), assistantMsg) + return ErrRequestCancelled } assistantMsg.AddFinish("error:" + err.Error()) - c.Messages.Update(ctx, assistantMsg) - return err + _ = a.messages.Update(ctx, assistantMsg) + return fmt.Errorf("event processing error: %w", err) } + // Check for cancellation during event processing select { case <-ctx.Done(): + // Mark as canceled assistantMsg.AddFinish("canceled") - c.Messages.Update(ctx, assistantMsg) - return context.Canceled + _ = a.messages.Update(context.Background(), assistantMsg) + return ErrRequestCancelled default: } } - // Check for context cancellation before tool execution + // Check for cancellation before tool execution select { case <-ctx.Done(): - assistantMsg.AddFinish("canceled") - c.Messages.Update(ctx, assistantMsg) - return context.Canceled + assistantMsg.AddFinish("canceled_by_user") + _ = a.messages.Update(context.Background(), assistantMsg) + return ErrRequestCancelled default: - // Continue processing } - msg, err := c.handleToolExecution(ctx, assistantMsg) + // Execute any tool calls + toolMsg, err := a.handleToolExecution(ctx, assistantMsg) if err != nil { if errors.Is(err, context.Canceled) { - assistantMsg.AddFinish("canceled") - c.Messages.Update(ctx, assistantMsg) - return context.Canceled + assistantMsg.AddFinish("canceled_by_user") + _ = a.messages.Update(context.Background(), assistantMsg) + return ErrRequestCancelled } - return err + return fmt.Errorf("tool execution error: %w", err) } - c.Messages.Update(ctx, assistantMsg) + if err := a.messages.Update(ctx, assistantMsg); err != nil { + return fmt.Errorf("failed to update assistant message: %w", err) + } + // If no tool calls, we're done if len(assistantMsg.ToolCalls()) == 0 { break } + // Add messages for next iteration messages = append(messages, assistantMsg) - if msg != nil { - messages = append(messages, *msg) + if toolMsg != nil { + messages = append(messages, *toolMsg) } - // Check for context cancellation after tool execution + // Check for cancellation after tool execution select { case <-ctx.Done(): - assistantMsg.AddFinish("canceled") - c.Messages.Update(ctx, assistantMsg) - return context.Canceled + return ErrRequestCancelled default: - // Continue processing } } + return nil } +// getAgentProviders initializes the LLM providers based on the chosen model func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) { maxTokens := config.Get().Model.CoderMaxTokens providerConfig, ok := config.Get().Providers[model.Provider] if !ok || providerConfig.Disabled { - return nil, nil, errors.New("provider is not enabled") + return nil, nil, ErrProviderNotEnabled } + var agentProvider provider.Provider var titleGenerator provider.Provider + var err error switch model.Provider { case models.ProviderOpenAI: - var err error agentProvider, err = provider.NewOpenAIProvider( provider.WithOpenAISystemMessage( prompt.CoderOpenAISystemPrompt(), @@ -402,8 +546,9 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid provider.WithOpenAIKey(providerConfig.APIKey), ) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("failed to create OpenAI agent provider: %w", err) } + titleGenerator, err = provider.NewOpenAIProvider( provider.WithOpenAISystemMessage( prompt.TitlePrompt(), @@ -413,10 +558,10 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid provider.WithOpenAIKey(providerConfig.APIKey), ) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("failed to create OpenAI title generator: %w", err) } + case models.ProviderAnthropic: - var err error agentProvider, err = provider.NewAnthropicProvider( provider.WithAnthropicSystemMessage( prompt.CoderAnthropicSystemPrompt(), @@ -426,8 +571,9 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid provider.WithAnthropicModel(model), ) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("failed to create Anthropic agent provider: %w", err) } + titleGenerator, err = provider.NewAnthropicProvider( provider.WithAnthropicSystemMessage( prompt.TitlePrompt(), @@ -437,11 +583,10 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid provider.WithAnthropicModel(model), ) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("failed to create Anthropic title generator: %w", err) } case models.ProviderGemini: - var err error agentProvider, err = provider.NewGeminiProvider( ctx, provider.WithGeminiSystemMessage( @@ -452,8 +597,9 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid provider.WithGeminiModel(model), ) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("failed to create Gemini agent provider: %w", err) } + titleGenerator, err = provider.NewGeminiProvider( ctx, provider.WithGeminiSystemMessage( @@ -464,10 +610,10 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid provider.WithGeminiModel(model), ) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("failed to create Gemini title generator: %w", err) } + case models.ProviderGROQ: - var err error agentProvider, err = provider.NewOpenAIProvider( provider.WithOpenAISystemMessage( prompt.CoderAnthropicSystemPrompt(), @@ -478,8 +624,9 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"), ) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("failed to create GROQ agent provider: %w", err) } + titleGenerator, err = provider.NewOpenAIProvider( provider.WithOpenAISystemMessage( prompt.TitlePrompt(), @@ -490,11 +637,10 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"), ) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("failed to create GROQ title generator: %w", err) } case models.ProviderBedrock: - var err error agentProvider, err = provider.NewBedrockProvider( provider.WithBedrockSystemMessage( prompt.CoderAnthropicSystemPrompt(), @@ -503,19 +649,21 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid provider.WithBedrockModel(model), ) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("failed to create Bedrock agent provider: %w", err) } + titleGenerator, err = provider.NewBedrockProvider( provider.WithBedrockSystemMessage( prompt.TitlePrompt(), ), - provider.WithBedrockMaxTokens(maxTokens), + provider.WithBedrockMaxTokens(80), provider.WithBedrockModel(model), ) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("failed to create Bedrock title generator: %w", err) } - + default: + return nil, nil, fmt.Errorf("unsupported provider: %s", model.Provider) } return agentProvider, titleGenerator, nil diff --git a/internal/llm/agent/coder.go b/internal/llm/agent/coder.go index f8e1c40a0..8eea57041 100644 --- a/internal/llm/agent/coder.go +++ b/internal/llm/agent/coder.go @@ -4,71 +4,60 @@ import ( "context" "errors" - "github.com/kujtimiihoxha/termai/internal/app" "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/llm/models" "github.com/kujtimiihoxha/termai/internal/llm/tools" + "github.com/kujtimiihoxha/termai/internal/lsp" + "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/termai/internal/permission" + "github.com/kujtimiihoxha/termai/internal/session" ) type coderAgent struct { - *agent + Service } -func (c *coderAgent) setAgentTool(sessionID string) { - inx := -1 - for i, tool := range c.tools { - if tool.Info().Name == AgentToolName { - inx = i - break - } - } - if inx == -1 { - c.tools = append(c.tools, NewAgentTool(sessionID, c.App)) - } else { - c.tools[inx] = NewAgentTool(sessionID, c.App) - } -} - -func (c *coderAgent) Generate(ctx context.Context, sessionID string, content string) error { - c.setAgentTool(sessionID) - return c.generate(ctx, sessionID, content) -} - -func NewCoderAgent(app *app.App) (Agent, error) { +func NewCoderAgent( + permissions permission.Service, + sessions session.Service, + messages message.Service, + lspClients map[string]*lsp.Client, +) (Service, error) { model, ok := models.SupportedModels[config.Get().Model.Coder] if !ok { return nil, errors.New("model not supported") } ctx := context.Background() - agentProvider, titleGenerator, err := getAgentProviders(ctx, model) + otherTools := GetMcpTools(ctx, permissions) + if len(lspClients) > 0 { + otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients)) + } + agent, err := NewAgent( + ctx, + sessions, + messages, + model, + append( + []tools.BaseTool{ + tools.NewBashTool(permissions), + tools.NewEditTool(lspClients, permissions), + tools.NewFetchTool(permissions), + tools.NewGlobTool(), + tools.NewGrepTool(), + tools.NewLsTool(), + tools.NewSourcegraphTool(), + tools.NewViewTool(lspClients), + tools.NewWriteTool(lspClients, permissions), + NewAgentTool(sessions, messages), + }, otherTools..., + ), + ) if err != nil { return nil, err } - otherTools := GetMcpTools(ctx, app.Permissions) - if len(app.LSPClients) > 0 { - otherTools = append(otherTools, tools.NewDiagnosticsTool(app.LSPClients)) - } return &coderAgent{ - agent: &agent{ - App: app, - tools: append( - []tools.BaseTool{ - tools.NewBashTool(app.Permissions), - tools.NewEditTool(app.LSPClients, app.Permissions), - tools.NewFetchTool(app.Permissions), - tools.NewGlobTool(), - tools.NewGrepTool(), - tools.NewLsTool(), - tools.NewSourcegraphTool(), - tools.NewViewTool(app.LSPClients), - tools.NewWriteTool(app.LSPClients, app.Permissions), - }, otherTools..., - ), - model: model, - agent: agentProvider, - titleGenerator: titleGenerator, - }, + agent, }, nil } diff --git a/internal/llm/agent/task.go b/internal/llm/agent/task.go index c196cb107..0a072044c 100644 --- a/internal/llm/agent/task.go +++ b/internal/llm/agent/task.go @@ -4,10 +4,10 @@ import ( "context" "errors" - "github.com/kujtimiihoxha/termai/internal/app" "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/llm/models" "github.com/kujtimiihoxha/termai/internal/llm/tools" + "github.com/kujtimiihoxha/termai/internal/lsp" ) type taskAgent struct { @@ -18,7 +18,7 @@ func (c *taskAgent) Generate(ctx context.Context, sessionID string, content stri return c.generate(ctx, sessionID, content) } -func NewTaskAgent(app *app.App) (Agent, error) { +func NewTaskAgent(lspClients map[string]*lsp.Client) (Service, error) { model, ok := models.SupportedModels[config.Get().Model.Coder] if !ok { return nil, errors.New("model not supported") @@ -31,13 +31,12 @@ func NewTaskAgent(app *app.App) (Agent, error) { } return &taskAgent{ agent: &agent{ - App: app, tools: []tools.BaseTool{ tools.NewGlobTool(), tools.NewGrepTool(), tools.NewLsTool(), tools.NewSourcegraphTool(), - tools.NewViewTool(app.LSPClients), + tools.NewViewTool(lspClients), }, model: model, agent: agentProvider, diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index 938a8c0ad..34d91f2b7 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -57,7 +57,9 @@ func cleanupMessages(messages []message.Message) []message.Message { // First pass: filter out canceled messages var cleanedMessages []message.Message for _, msg := range messages { - if msg.FinishReason() != "canceled" { + if msg.FinishReason() != "canceled" || len(msg.ToolCalls()) > 0 { + // if there are toolCalls this means we want to return it to the LLM telling it that those tools have been + // cancelled cleanedMessages = append(cleanedMessages, msg) } } diff --git a/internal/llm/tools/edit.go b/internal/llm/tools/edit.go index c9a0be079..647b8d35f 100644 --- a/internal/llm/tools/edit.go +++ b/internal/llm/tools/edit.go @@ -190,7 +190,7 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string) return er, fmt.Errorf("failed to create parent directories: %w", err) } - sessionID, messageID := getContextValues(ctx) + sessionID, messageID := GetContextValues(ctx) if sessionID == "" || messageID == "" { return er, fmt.Errorf("session ID and message ID are required for creating a new file") } @@ -277,7 +277,7 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string newContent := oldContent[:index] + oldContent[index+len(oldString):] - sessionID, messageID := getContextValues(ctx) + sessionID, messageID := GetContextValues(ctx) if sessionID == "" || messageID == "" { return er, fmt.Errorf("session ID and message ID are required for creating a new file") @@ -365,7 +365,7 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS newContent := oldContent[:index] + newString + oldContent[index+len(oldString):] - sessionID, messageID := getContextValues(ctx) + sessionID, messageID := GetContextValues(ctx) if sessionID == "" || messageID == "" { return er, fmt.Errorf("session ID and message ID are required for creating a new file") @@ -409,4 +409,3 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS return er, nil } - diff --git a/internal/llm/tools/tools.go b/internal/llm/tools/tools.go index 473b787bb..07afe1363 100644 --- a/internal/llm/tools/tools.go +++ b/internal/llm/tools/tools.go @@ -66,7 +66,7 @@ type BaseTool interface { Run(ctx context.Context, params ToolCall) (ToolResponse, error) } -func getContextValues(ctx context.Context) (string, string) { +func GetContextValues(ctx context.Context) (string, string) { sessionID := ctx.Value(SessionIDContextKey) messageID := ctx.Value(MessageIDContextKey) if sessionID == nil { diff --git a/internal/llm/tools/write.go b/internal/llm/tools/write.go index 27c98bb9d..1b087c193 100644 --- a/internal/llm/tools/write.go +++ b/internal/llm/tools/write.go @@ -144,7 +144,7 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error } } - sessionID, messageID := getContextValues(ctx) + sessionID, messageID := GetContextValues(ctx) if sessionID == "" || messageID == "" { return NewTextErrorResponse("session ID or message ID is missing"), nil } diff --git a/internal/tui/components/repl/editor.go b/internal/tui/components/repl/editor.go index b1e39e655..b659775e0 100644 --- a/internal/tui/components/repl/editor.go +++ b/internal/tui/components/repl/editor.go @@ -7,7 +7,6 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" "github.com/kujtimiihoxha/termai/internal/app" - "github.com/kujtimiihoxha/termai/internal/llm/agent" "github.com/kujtimiihoxha/termai/internal/tui/layout" "github.com/kujtimiihoxha/termai/internal/tui/styles" "github.com/kujtimiihoxha/termai/internal/tui/util" @@ -168,11 +167,6 @@ func (m *editorCmp) Send() tea.Cmd { return util.ReportWarn("Assistant is still working on the previous message") } - a, err := agent.NewCoderAgent(m.app) - if err != nil { - return util.ReportError(err) - } - content := strings.Join(m.editor.GetBuffer().Lines(), "\n") if len(content) == 0 { return util.ReportWarn("Message is empty") @@ -181,7 +175,7 @@ func (m *editorCmp) Send() tea.Cmd { m.cancelMessage = cancel go func() { defer cancel() - a.Generate(ctx, m.sessionID, content) + m.app.CoderAgent.Generate(ctx, m.sessionID, content) m.cancelMessage = nil }() diff --git a/internal/tui/page/chat.go b/internal/tui/page/chat.go index 9b9924909..439c89e1f 100644 --- a/internal/tui/page/chat.go +++ b/internal/tui/page/chat.go @@ -6,7 +6,6 @@ import ( "github.com/charmbracelet/bubbles/key" tea "github.com/charmbracelet/bubbletea" "github.com/kujtimiihoxha/termai/internal/app" - "github.com/kujtimiihoxha/termai/internal/llm/agent" "github.com/kujtimiihoxha/termai/internal/session" "github.com/kujtimiihoxha/termai/internal/tui/components/chat" "github.com/kujtimiihoxha/termai/internal/tui/layout" @@ -23,6 +22,7 @@ type chatPage struct { type ChatKeyMap struct { NewSession key.Binding + Cancel key.Binding } var keyMap = ChatKeyMap{ @@ -30,6 +30,10 @@ var keyMap = ChatKeyMap{ key.WithKeys("ctrl+n"), key.WithHelp("ctrl+n", "new session"), ), + Cancel: key.NewBinding( + key.WithKeys("ctrl+x"), + key.WithHelp("ctrl+x", "cancel"), + ), } func (p *chatPage) Init() tea.Cmd { @@ -106,15 +110,8 @@ func (p *chatPage) sendMessage(text string) tea.Cmd { } cmds = append(cmds, util.CmdHandler(chat.SessionSelectedMsg(session))) } - // TODO: move this to a service - a, err := agent.NewCoderAgent(p.app) - if err != nil { - return util.ReportError(err) - } - go func() { - a.Generate(context.Background(), p.session.ID, text) - }() + p.app.CoderAgent.Generate(context.Background(), p.session.ID, text) return tea.Batch(cmds...) } -- cgit v1.2.3 From bbfa60c787f2ec459f1689b9a650ddbec9693ed9 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Wed, 16 Apr 2025 20:06:23 +0200 Subject: reimplement agent,provider and add file history --- .opencode.json | 4 - README.md | 34 +- cmd/root.go | 24 +- go.mod | 8 +- go.sum | 14 - internal/app/app.go | 17 +- internal/app/lsp.go | 19 +- internal/config/config.go | 108 +++- internal/db/files.sql.go | 4 +- internal/db/sql/files.sql | 4 +- internal/diff/diff.go | 99 ++- internal/llm/agent/agent-tool.go | 18 +- internal/llm/agent/agent.go | 861 ++++++++++----------------- internal/llm/agent/coder.go | 63 -- internal/llm/agent/mcp-tools.go | 4 +- internal/llm/agent/task.go | 47 -- internal/llm/agent/tools.go | 50 ++ internal/llm/models/anthropic.go | 71 +++ internal/llm/models/models.go | 190 +++--- internal/llm/prompt/coder.go | 28 +- internal/llm/prompt/prompt.go | 19 + internal/llm/prompt/task.go | 5 +- internal/llm/prompt/title.go | 4 +- internal/llm/provider/anthropic.go | 531 +++++++++-------- internal/llm/provider/bedrock.go | 101 ++-- internal/llm/provider/gemini.go | 533 +++++++++++------ internal/llm/provider/openai.go | 401 ++++++++----- internal/llm/provider/provider.go | 169 ++++-- internal/llm/tools/bash.go | 7 +- internal/llm/tools/bash_test.go | 31 - internal/llm/tools/edit.go | 75 ++- internal/llm/tools/edit_test.go | 30 +- internal/llm/tools/file.go | 10 - internal/llm/tools/glob.go | 4 +- internal/llm/tools/grep.go | 4 +- internal/llm/tools/ls.go | 4 +- internal/llm/tools/mocks_test.go | 246 ++++++++ internal/llm/tools/shell/shell.go | 12 +- internal/llm/tools/sourcegraph.go | 2 +- internal/llm/tools/tools.go | 9 +- internal/llm/tools/write.go | 27 +- internal/llm/tools/write_test.go | 22 +- internal/logging/logger.go | 41 +- internal/lsp/client.go | 13 +- internal/lsp/handlers.go | 2 +- internal/lsp/transport.go | 28 +- internal/lsp/watcher/watcher.go | 18 +- internal/message/content.go | 30 +- internal/pubsub/broker.go | 2 +- internal/session/session.go | 15 + internal/tui/components/chat/chat.go | 2 - internal/tui/components/chat/editor.go | 22 +- internal/tui/components/chat/messages.go | 205 ++++++- internal/tui/components/chat/sidebar.go | 176 +++++- internal/tui/components/core/dialog.go | 117 ---- internal/tui/components/core/help.go | 119 ---- internal/tui/components/core/status.go | 90 ++- internal/tui/components/dialog/help.go | 182 ++++++ internal/tui/components/dialog/permission.go | 682 +++++++++++---------- internal/tui/components/dialog/quit.go | 156 +++-- internal/tui/components/logs/details.go | 2 - internal/tui/components/logs/table.go | 22 - internal/tui/components/repl/editor.go | 201 ------- internal/tui/components/repl/messages.go | 513 ---------------- internal/tui/components/repl/sessions.go | 249 -------- internal/tui/layout/overlay.go | 11 +- internal/tui/layout/split.go | 1 + internal/tui/page/chat.go | 32 +- internal/tui/page/init.go | 308 ---------- internal/tui/page/logs.go | 17 + internal/tui/page/repl.go | 21 - internal/tui/tui.go | 277 ++++----- main.go | 7 + 73 files changed, 3595 insertions(+), 3879 deletions(-) delete mode 100644 internal/llm/agent/coder.go delete mode 100644 internal/llm/agent/task.go create mode 100644 internal/llm/agent/tools.go create mode 100644 internal/llm/models/anthropic.go create mode 100644 internal/llm/prompt/prompt.go create mode 100644 internal/llm/tools/mocks_test.go delete mode 100644 internal/tui/components/core/dialog.go delete mode 100644 internal/tui/components/core/help.go create mode 100644 internal/tui/components/dialog/help.go delete mode 100644 internal/tui/components/repl/editor.go delete mode 100644 internal/tui/components/repl/messages.go delete mode 100644 internal/tui/components/repl/sessions.go delete mode 100644 internal/tui/page/init.go delete mode 100644 internal/tui/page/repl.go (limited to 'cmd') diff --git a/.opencode.json b/.opencode.json index f63a63dba..b7fc19b52 100644 --- a/.opencode.json +++ b/.opencode.json @@ -1,8 +1,4 @@ { - "model": { - "coder": "claude-3.7-sonnet", - "coderMaxTokens": 20000 - }, "lsp": { "gopls": { "command": "gopls" diff --git a/README.md b/README.md index 23a1906a1..564284c7f 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,14 @@ -# TermAI +# OpenCode > **⚠️ Early Development Notice:** This project is in early development and is not yet ready for production use. Features may change, break, or be incomplete. Use at your own risk. A powerful terminal-based AI assistant for developers, providing intelligent coding assistance directly in your terminal. -[![TermAI Demo](https://asciinema.org/a/dtc4nJyGSZX79HRUmFLY3gmoy.svg)](https://asciinema.org/a/dtc4nJyGSZX79HRUmFLY3gmoy) +[![OpenCode Demo](https://asciinema.org/a/dtc4nJyGSZX79HRUmFLY3gmoy.svg)](https://asciinema.org/a/dtc4nJyGSZX79HRUmFLY3gmoy) ## Overview -TermAI is a Go-based CLI application that brings AI assistance to your terminal. It provides a TUI (Terminal User Interface) for interacting with various AI models to help with coding tasks, debugging, and more. +OpenCode is a Go-based CLI application that brings AI assistance to your terminal. It provides a TUI (Terminal User Interface) for interacting with various AI models to help with coding tasks, debugging, and more. ## Features @@ -23,16 +23,16 @@ TermAI is a Go-based CLI application that brings AI assistance to your terminal. ```bash # Coming soon -go install github.com/kujtimiihoxha/termai@latest +go install github.com/kujtimiihoxha/opencode@latest ``` ## Configuration -TermAI looks for configuration in the following locations: +OpenCode looks for configuration in the following locations: -- `$HOME/.termai.json` -- `$XDG_CONFIG_HOME/termai/.termai.json` -- `./.termai.json` (local directory) +- `$HOME/.opencode.json` +- `$XDG_CONFIG_HOME/opencode/.opencode.json` +- `./.opencode.json` (local directory) You can also use environment variables: @@ -43,11 +43,11 @@ You can also use environment variables: ## Usage ```bash -# Start TermAI -termai +# Start OpenCode +opencode # Start with debug logging -termai -d +opencode -d ``` ### Keyboard Shortcuts @@ -81,7 +81,7 @@ termai -d ## Architecture -TermAI is built with a modular architecture: +OpenCode is built with a modular architecture: - **cmd**: Command-line interface using Cobra - **internal/app**: Core application services @@ -103,22 +103,22 @@ TermAI is built with a modular architecture: ```bash # Clone the repository -git clone https://github.com/kujtimiihoxha/termai.git -cd termai +git clone https://github.com/kujtimiihoxha/opencode.git +cd opencode # Build the diff script first go run cmd/diff/main.go # Build -go build -o termai +go build -o opencode # Run -./termai +./opencode ``` ## Acknowledgments -TermAI builds upon the work of several open source projects and developers: +OpenCode builds upon the work of several open source projects and developers: - [@isaacphi](https://github.com/isaacphi) - LSP client implementation diff --git a/cmd/root.go b/cmd/root.go index a2e63006f..ff71747d5 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -20,7 +20,7 @@ import ( ) var rootCmd = &cobra.Command{ - Use: "termai", + Use: "OpenCode", Short: "A terminal ai assistant", Long: `A terminal ai assistant`, RunE: func(cmd *cobra.Command, args []string) error { @@ -89,12 +89,9 @@ var rootCmd = &cobra.Command{ // Set up message handling for the TUI go func() { defer tuiWg.Done() - defer func() { - if r := recover(); r != nil { - logging.Error("Panic in TUI message handling: %v", r) - attemptTUIRecovery(program) - } - }() + defer logging.RecoverPanic("TUI-message-handler", func() { + attemptTUIRecovery(program) + }) for { select { @@ -153,11 +150,7 @@ func attemptTUIRecovery(program *tea.Program) { func initMCPTools(ctx context.Context, app *app.App) { go func() { - defer func() { - if r := recover(); r != nil { - logging.Error("Panic in MCP goroutine: %v", r) - } - }() + defer logging.RecoverPanic("MCP-goroutine", nil) // Create a context with timeout for the initial MCP tools fetch ctxWithTimeout, cancel := context.WithTimeout(ctx, 30*time.Second) @@ -179,11 +172,7 @@ func setupSubscriber[T any]( wg.Add(1) go func() { defer wg.Done() - defer func() { - if r := recover(); r != nil { - logging.Error("Panic in %s subscription goroutine: %v", name, r) - } - }() + defer logging.RecoverPanic(fmt.Sprintf("subscription-%s", name), nil) for { select { @@ -232,6 +221,7 @@ func setupSubscriptions(app *app.App) (chan tea.Msg, func()) { // Wait with a timeout for all goroutines to complete waitCh := make(chan struct{}) go func() { + defer logging.RecoverPanic("subscription-cleanup", nil) wg.Wait() close(waitCh) }() diff --git a/go.mod b/go.mod index 925a71097..16c88d3a6 100644 --- a/go.mod +++ b/go.mod @@ -23,7 +23,6 @@ require ( github.com/golang-migrate/migrate/v4 v4.18.2 github.com/google/generative-ai-go v0.19.0 github.com/google/uuid v1.6.0 - github.com/kujtimiihoxha/vimtea v0.0.3-0.20250329221256-a250e98498f9 github.com/lrstanley/bubblezone v0.0.0-20250315020633-c249a3fe1231 github.com/mark3labs/mcp-go v0.17.0 github.com/mattn/go-runewidth v0.0.16 @@ -36,7 +35,6 @@ require ( github.com/spf13/cobra v1.9.1 github.com/spf13/viper v1.20.0 github.com/stretchr/testify v1.10.0 - golang.org/x/net v0.39.0 google.golang.org/api v0.215.0 ) @@ -106,7 +104,6 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/sagikazarmark/locafero v0.7.0 // indirect - github.com/sahilm/fuzzy v0.1.1 // indirect github.com/skeema/knownhosts v1.3.1 // indirect github.com/sourcegraph/conc v0.3.0 // indirect github.com/spf13/afero v1.12.0 // indirect @@ -129,11 +126,8 @@ require ( go.opentelemetry.io/otel/trace v1.29.0 // indirect go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect - golang.design/x/clipboard v0.7.0 // indirect golang.org/x/crypto v0.37.0 // indirect - golang.org/x/exp/shiny v0.0.0-20250305212735-054e65f0b394 // indirect - golang.org/x/image v0.14.0 // indirect - golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a // indirect + golang.org/x/net v0.39.0 // indirect golang.org/x/oauth2 v0.25.0 // indirect golang.org/x/sync v0.13.0 // indirect golang.org/x/sys v0.32.0 // indirect diff --git a/go.sum b/go.sum index 9c2c2df8f..4832271f2 100644 --- a/go.sum +++ b/go.sum @@ -180,10 +180,6 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/kujtimiihoxha/vimtea v0.0.3-0.20250329221256-a250e98498f9 h1:xYfCLI8KUwmXDFp1pOpNX+XsQczQw9VbEuju1pQF5/A= -github.com/kujtimiihoxha/vimtea v0.0.3-0.20250329221256-a250e98498f9/go.mod h1:Ye+kIkTmPO5xuqCQ+PPHDTGIViRRoSpSIlcYgma8YlA= -github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= -github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lrstanley/bubblezone v0.0.0-20250315020633-c249a3fe1231 h1:9rjt7AfnrXKNSZhp36A3/4QAZAwGGCGD/p8Bse26zms= @@ -235,8 +231,6 @@ github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7 github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sagikazarmark/locafero v0.7.0 h1:5MqpDsTGNDhY8sGp0Aowyf0qKsPrhewaLSsFaodPcyo= github.com/sagikazarmark/locafero v0.7.0/go.mod h1:2za3Cg5rMaTMoG/2Ulr9AwtFaIppKXTRYnozin4aB5k= -github.com/sahilm/fuzzy v0.1.1 h1:ceu5RHF8DGgoi+/dR5PsECjCDH1BE3Fnmpo7aVXOdRA= -github.com/sahilm/fuzzy v0.1.1/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y= github.com/sebdah/goldie/v2 v2.5.3 h1:9ES/mNN+HNUbNWpVAlrzuZ7jE+Nrczbj8uFRjM7624Y= github.com/sebdah/goldie/v2 v2.5.3/go.mod h1:oZ9fp0+se1eapSRjfYbsV/0Hqhbuu3bJVvKI/NNtssI= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= @@ -302,8 +296,6 @@ go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= -golang.design/x/clipboard v0.7.0 h1:4Je8M/ys9AJumVnl8m+rZnIvstSnYj1fvzqYrU3TXvo= -golang.design/x/clipboard v0.7.0/go.mod h1:PQIvqYO9GP29yINEfsEn5zSQKAz3UgXmZKzDA6dnq2E= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= @@ -314,12 +306,6 @@ golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8= golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY= -golang.org/x/exp/shiny v0.0.0-20250305212735-054e65f0b394 h1:bFYqOIMdeiCEdzPJkLiOoMDzW/v3tjW4AA/RmUZYsL8= -golang.org/x/exp/shiny v0.0.0-20250305212735-054e65f0b394/go.mod h1:ygj7T6vSGhhm/9yTpOQQNvuAUFziTH7RUiH74EoE2C8= -golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4= -golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= -golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a h1:sYbmY3FwUWCBTodZL1S3JUuOvaW6kM2o+clDzzDNBWg= -golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a/go.mod h1:Ede7gF0KGoHlj822RtphAHK1jLdrcuRBZg0sF1Q+SPc= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= diff --git a/internal/app/app.go b/internal/app/app.go index ca23b3c40..1c16ccc11 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -7,6 +7,7 @@ import ( "sync" "time" + "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/db" "github.com/kujtimiihoxha/termai/internal/history" "github.com/kujtimiihoxha/termai/internal/llm/agent" @@ -20,7 +21,7 @@ import ( type App struct { Sessions session.Service Messages message.Service - Files history.Service + History history.Service Permissions permission.Service CoderAgent agent.Service @@ -43,7 +44,7 @@ func New(ctx context.Context, conn *sql.DB) (*App, error) { app := &App{ Sessions: sessions, Messages: messages, - Files: files, + History: files, Permissions: permission.NewPermissionService(), LSPClients: make(map[string]*lsp.Client), } @@ -51,11 +52,17 @@ func New(ctx context.Context, conn *sql.DB) (*App, error) { app.initLSPClients(ctx) var err error - app.CoderAgent, err = agent.NewCoderAgent( - app.Permissions, + app.CoderAgent, err = agent.NewAgent( + config.AgentCoder, app.Sessions, app.Messages, - app.LSPClients, + agent.CoderAgentTools( + app.Permissions, + app.Sessions, + app.Messages, + app.History, + app.LSPClients, + ), ) if err != nil { logging.Error("Failed to create coder agent", err) diff --git a/internal/app/lsp.go b/internal/app/lsp.go index 4e0568f07..4a762f1a1 100644 --- a/internal/app/lsp.go +++ b/internal/app/lsp.go @@ -22,16 +22,17 @@ func (app *App) initLSPClients(ctx context.Context) { // createAndStartLSPClient creates a new LSP client, initializes it, and starts its workspace watcher func (app *App) createAndStartLSPClient(ctx context.Context, name string, command string, args ...string) { // Create a specific context for initialization with a timeout - initCtx, initCancel := context.WithTimeout(context.Background(), 30*time.Second) - defer initCancel() // Create the LSP client - lspClient, err := lsp.NewClient(initCtx, command, args...) + lspClient, err := lsp.NewClient(ctx, command, args...) if err != nil { logging.Error("Failed to create LSP client for", name, err) return + } + initCtx, cancel := context.WithTimeout(ctx, 15*time.Second) + defer cancel() // Initialize with the initialization context _, err = lspClient.InitializeLSPClient(initCtx, config.WorkingDirectory()) if err != nil { @@ -64,14 +65,10 @@ func (app *App) createAndStartLSPClient(ctx context.Context, name string, comman // runWorkspaceWatcher executes the workspace watcher for an LSP client func (app *App) runWorkspaceWatcher(ctx context.Context, name string, workspaceWatcher *watcher.WorkspaceWatcher) { defer app.watcherWG.Done() - defer func() { - if r := recover(); r != nil { - logging.Error("LSP client crashed", "client", name, "panic", r) - - // Try to restart the client - app.restartLSPClient(ctx, name) - } - }() + defer logging.RecoverPanic("LSP-"+name, func() { + // Try to restart the client + app.restartLSPClient(ctx, name) + }) workspaceWatcher.WatchWorkspace(ctx, config.WorkingDirectory()) logging.Info("Workspace watcher stopped", "client", name) diff --git a/internal/config/config.go b/internal/config/config.go index f0afbdd3c..147d6c83a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -31,12 +31,18 @@ type MCPServer struct { Headers map[string]string `json:"headers"` } -// Model defines configuration for different LLM models and their token limits. -type Model struct { - Coder models.ModelID `json:"coder"` - CoderMaxTokens int64 `json:"coderMaxTokens"` - Task models.ModelID `json:"task"` - TaskMaxTokens int64 `json:"taskMaxTokens"` +type AgentName string + +const ( + AgentCoder AgentName = "coder" + AgentTask AgentName = "task" + AgentTitle AgentName = "title" +) + +// Agent defines configuration for different LLM models and their token limits. +type Agent struct { + Model models.ModelID `json:"model"` + MaxTokens int64 `json:"maxTokens"` } // Provider defines configuration for an LLM provider. @@ -65,8 +71,9 @@ type Config struct { MCPServers map[string]MCPServer `json:"mcpServers,omitempty"` Providers map[models.ModelProvider]Provider `json:"providers,omitempty"` LSP map[string]LSPConfig `json:"lsp,omitempty"` - Model Model `json:"model"` + Agents map[AgentName]Agent `json:"agents"` Debug bool `json:"debug,omitempty"` + DebugLSP bool `json:"debugLSP,omitempty"` } // Application constants @@ -118,11 +125,42 @@ func Load(workingDir string, debug bool) (*Config, error) { if cfg.Debug { defaultLevel = slog.LevelDebug } - // Configure logger - logger := slog.New(slog.NewTextHandler(logging.NewWriter(), &slog.HandlerOptions{ - Level: defaultLevel, - })) - slog.SetDefault(logger) + // if we are in debug mode make the writer a file + if cfg.Debug { + loggingFile := fmt.Sprintf("%s/%s", cfg.Data.Directory, "debug.log") + + // if file does not exist create it + if _, err := os.Stat(loggingFile); os.IsNotExist(err) { + if err := os.MkdirAll(cfg.Data.Directory, 0o755); err != nil { + return cfg, fmt.Errorf("failed to create directory: %w", err) + } + if _, err := os.Create(loggingFile); err != nil { + return cfg, fmt.Errorf("failed to create log file: %w", err) + } + } + + sloggingFileWriter, err := os.OpenFile(loggingFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o666) + if err != nil { + return cfg, fmt.Errorf("failed to open log file: %w", err) + } + // Configure logger + logger := slog.New(slog.NewTextHandler(sloggingFileWriter, &slog.HandlerOptions{ + Level: defaultLevel, + })) + slog.SetDefault(logger) + } else { + // Configure logger + logger := slog.New(slog.NewTextHandler(logging.NewWriter(), &slog.HandlerOptions{ + Level: defaultLevel, + })) + slog.SetDefault(logger) + } + + // Override the max tokens for title agent + cfg.Agents[AgentTitle] = Agent{ + Model: cfg.Agents[AgentTitle].Model, + MaxTokens: 80, + } return cfg, nil } @@ -159,44 +197,50 @@ func setProviderDefaults() { // Groq configuration if apiKey := os.Getenv("GROQ_API_KEY"); apiKey != "" { viper.SetDefault("providers.groq.apiKey", apiKey) - viper.SetDefault("model.coder", models.QWENQwq) - viper.SetDefault("model.coderMaxTokens", defaultMaxTokens) - viper.SetDefault("model.task", models.QWENQwq) - viper.SetDefault("model.taskMaxTokens", defaultMaxTokens) + viper.SetDefault("agents.coder.model", models.QWENQwq) + viper.SetDefault("agents.coder.maxTokens", defaultMaxTokens) + viper.SetDefault("agents.task.model", models.QWENQwq) + viper.SetDefault("agents.task.maxTokens", defaultMaxTokens) + viper.SetDefault("agents.title.model", models.QWENQwq) } // Google Gemini configuration if apiKey := os.Getenv("GEMINI_API_KEY"); apiKey != "" { viper.SetDefault("providers.gemini.apiKey", apiKey) - viper.SetDefault("model.coder", models.GRMINI20Flash) - viper.SetDefault("model.coderMaxTokens", defaultMaxTokens) - viper.SetDefault("model.task", models.GRMINI20Flash) - viper.SetDefault("model.taskMaxTokens", defaultMaxTokens) + viper.SetDefault("agents.coder.model", models.GRMINI20Flash) + viper.SetDefault("agents.coder.maxTokens", defaultMaxTokens) + viper.SetDefault("agents.task.model", models.GRMINI20Flash) + viper.SetDefault("agents.task.maxTokens", defaultMaxTokens) + viper.SetDefault("agents.title.model", models.GRMINI20Flash) } // OpenAI configuration if apiKey := os.Getenv("OPENAI_API_KEY"); apiKey != "" { viper.SetDefault("providers.openai.apiKey", apiKey) - viper.SetDefault("model.coder", models.GPT4o) - viper.SetDefault("model.coderMaxTokens", defaultMaxTokens) - viper.SetDefault("model.task", models.GPT4o) - viper.SetDefault("model.taskMaxTokens", defaultMaxTokens) + viper.SetDefault("agents.coder.model", models.GPT4o) + viper.SetDefault("agents.coder.maxTokens", defaultMaxTokens) + viper.SetDefault("agents.task.model", models.GPT4o) + viper.SetDefault("agents.task.maxTokens", defaultMaxTokens) + viper.SetDefault("agents.title.model", models.GPT4o) + } // Anthropic configuration if apiKey := os.Getenv("ANTHROPIC_API_KEY"); apiKey != "" { viper.SetDefault("providers.anthropic.apiKey", apiKey) - viper.SetDefault("model.coder", models.Claude37Sonnet) - viper.SetDefault("model.coderMaxTokens", defaultMaxTokens) - viper.SetDefault("model.task", models.Claude37Sonnet) - viper.SetDefault("model.taskMaxTokens", defaultMaxTokens) + viper.SetDefault("agents.coder.model", models.Claude37Sonnet) + viper.SetDefault("agents.coder.maxTokens", defaultMaxTokens) + viper.SetDefault("agents.task.model", models.Claude37Sonnet) + viper.SetDefault("agents.task.maxTokens", defaultMaxTokens) + viper.SetDefault("agents.title.model", models.Claude37Sonnet) } if hasAWSCredentials() { - viper.SetDefault("model.coder", models.BedrockClaude37Sonnet) - viper.SetDefault("model.coderMaxTokens", defaultMaxTokens) - viper.SetDefault("model.task", models.BedrockClaude37Sonnet) - viper.SetDefault("model.taskMaxTokens", defaultMaxTokens) + viper.SetDefault("agents.coder.model", models.BedrockClaude37Sonnet) + viper.SetDefault("agents.coder.maxTokens", defaultMaxTokens) + viper.SetDefault("agents.task.model", models.BedrockClaude37Sonnet) + viper.SetDefault("agents.task.maxTokens", defaultMaxTokens) + viper.SetDefault("agents.title.model", models.BedrockClaude37Sonnet) } } diff --git a/internal/db/files.sql.go b/internal/db/files.sql.go index b45731098..39def271f 100644 --- a/internal/db/files.sql.go +++ b/internal/db/files.sql.go @@ -97,7 +97,9 @@ func (q *Queries) GetFile(ctx context.Context, id string) (File, error) { const getFileByPathAndSession = `-- name: GetFileByPathAndSession :one SELECT id, session_id, path, content, version, created_at, updated_at FROM files -WHERE path = ? AND session_id = ? LIMIT 1 +WHERE path = ? AND session_id = ? +ORDER BY created_at DESC +LIMIT 1 ` type GetFileByPathAndSessionParams struct { diff --git a/internal/db/sql/files.sql b/internal/db/sql/files.sql index c2e799076..aba2a6111 100644 --- a/internal/db/sql/files.sql +++ b/internal/db/sql/files.sql @@ -6,7 +6,9 @@ WHERE id = ? LIMIT 1; -- name: GetFileByPathAndSession :one SELECT * FROM files -WHERE path = ? AND session_id = ? LIMIT 1; +WHERE path = ? AND session_id = ? +ORDER BY created_at DESC +LIMIT 1; -- name: ListFilesBySession :many SELECT * diff --git a/internal/diff/diff.go b/internal/diff/diff.go index 02d4d7140..829554c7e 100644 --- a/internal/diff/diff.go +++ b/internal/diff/diff.go @@ -19,6 +19,8 @@ import ( "github.com/charmbracelet/x/ansi" "github.com/go-git/go-git/v5" "github.com/go-git/go-git/v5/plumbing/object" + "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/termai/internal/logging" "github.com/sergi/go-diff/diffmatchpatch" ) @@ -77,6 +79,8 @@ type linePair struct { // StyleConfig defines styling for diff rendering type StyleConfig struct { + ShowHeader bool + FileNameFg lipgloss.Color // Background colors RemovedLineBg lipgloss.Color AddedLineBg lipgloss.Color @@ -106,11 +110,13 @@ type StyleOption func(*StyleConfig) func NewStyleConfig(opts ...StyleOption) StyleConfig { // Default color scheme config := StyleConfig{ + ShowHeader: true, + FileNameFg: lipgloss.Color("#fab283"), RemovedLineBg: lipgloss.Color("#3A3030"), AddedLineBg: lipgloss.Color("#303A30"), ContextLineBg: lipgloss.Color("#212121"), - HunkLineBg: lipgloss.Color("#23252D"), - HunkLineFg: lipgloss.Color("#8CA3B4"), + HunkLineBg: lipgloss.Color("#212121"), + HunkLineFg: lipgloss.Color("#a0a0a0"), RemovedFg: lipgloss.Color("#7C4444"), AddedFg: lipgloss.Color("#478247"), LineNumberFg: lipgloss.Color("#888888"), @@ -132,6 +138,10 @@ func NewStyleConfig(opts ...StyleOption) StyleConfig { } // Style option functions +func WithFileNameFg(color lipgloss.Color) StyleOption { + return func(s *StyleConfig) { s.FileNameFg = color } +} + func WithRemovedLineBg(color lipgloss.Color) StyleOption { return func(s *StyleConfig) { s.RemovedLineBg = color } } @@ -190,6 +200,10 @@ func WithHunkLineFg(color lipgloss.Color) StyleOption { return func(s *StyleConfig) { s.HunkLineFg = color } } +func WithShowHeader(show bool) StyleOption { + return func(s *StyleConfig) { s.ShowHeader = show } +} + // ------------------------------------------------------------------------- // Parse Configuration // ------------------------------------------------------------------------- @@ -841,10 +855,12 @@ func RenderSideBySideHunk(fileName string, h Hunk, opts ...SideBySideOption) str // Calculate column width colWidth := config.TotalWidth / 2 + leftWidth := colWidth + rightWidth := config.TotalWidth - colWidth var sb strings.Builder for _, p := range pairs { - leftStr := renderLeftColumn(fileName, p.left, colWidth, config.Style) - rightStr := renderRightColumn(fileName, p.right, colWidth, config.Style) + leftStr := renderLeftColumn(fileName, p.left, leftWidth, config.Style) + rightStr := renderRightColumn(fileName, p.right, rightWidth, config.Style) sb.WriteString(leftStr + rightStr + "\n") } @@ -861,17 +877,50 @@ func FormatDiff(diffText string, opts ...SideBySideOption) (string, error) { var sb strings.Builder config := NewSideBySideConfig(opts...) - for i, h := range diffResult.Hunks { - if i > 0 { - // Render hunk header - sb.WriteString( - lipgloss.NewStyle(). - Background(config.Style.HunkLineBg). - Foreground(config.Style.HunkLineFg). - Width(config.TotalWidth). - Render(h.Header) + "\n", - ) - } + if config.Style.ShowHeader { + removeIcon := lipgloss.NewStyle(). + Background(config.Style.RemovedLineBg). + Foreground(config.Style.RemovedFg). + Render("⏹") + addIcon := lipgloss.NewStyle(). + Background(config.Style.AddedLineBg). + Foreground(config.Style.AddedFg). + Render("⏹") + + fileName := lipgloss.NewStyle(). + Background(config.Style.ContextLineBg). + Foreground(config.Style.FileNameFg). + Render(" " + diffResult.OldFile) + sb.WriteString( + lipgloss.NewStyle(). + Background(config.Style.ContextLineBg). + Padding(0, 1, 0, 1). + Foreground(config.Style.FileNameFg). + BorderStyle(lipgloss.NormalBorder()). + BorderTop(true). + BorderBottom(true). + BorderForeground(config.Style.FileNameFg). + BorderBackground(config.Style.ContextLineBg). + Width(config.TotalWidth). + Render( + lipgloss.JoinHorizontal(lipgloss.Top, + removeIcon, + addIcon, + fileName, + ), + ) + "\n", + ) + } + + for _, h := range diffResult.Hunks { + // Render hunk header + sb.WriteString( + lipgloss.NewStyle(). + Background(config.Style.HunkLineBg). + Foreground(config.Style.HunkLineFg). + Width(config.TotalWidth). + Render(h.Header) + "\n", + ) sb.WriteString(RenderSideBySideHunk(diffResult.OldFile, h, opts...)) } @@ -880,9 +929,15 @@ func FormatDiff(diffText string, opts ...SideBySideOption) (string, error) { // GenerateDiff creates a unified diff from two file contents func GenerateDiff(beforeContent, afterContent, fileName string) (string, int, int) { + // remove the cwd prefix and ensure consistent path format + // this prevents issues with absolute paths in different environments + cwd := config.WorkingDirectory() + fileName = strings.TrimPrefix(fileName, cwd) + fileName = strings.TrimPrefix(fileName, "/") // Create temporary directory for git operations - tempDir, err := os.MkdirTemp("", "git-diff-temp") + tempDir, err := os.MkdirTemp("", fmt.Sprintf("git-diff-%d", time.Now().UnixNano())) if err != nil { + logging.Error("Failed to create temp directory for git diff", "error", err) return "", 0, 0 } defer os.RemoveAll(tempDir) @@ -890,25 +945,30 @@ func GenerateDiff(beforeContent, afterContent, fileName string) (string, int, in // Initialize git repo repo, err := git.PlainInit(tempDir, false) if err != nil { + logging.Error("Failed to initialize git repository", "error", err) return "", 0, 0 } wt, err := repo.Worktree() if err != nil { + logging.Error("Failed to get git worktree", "error", err) return "", 0, 0 } // Write the "before" content and commit it fullPath := filepath.Join(tempDir, fileName) if err = os.MkdirAll(filepath.Dir(fullPath), 0o755); err != nil { + logging.Error("Failed to create directory for file", "error", err) return "", 0, 0 } if err = os.WriteFile(fullPath, []byte(beforeContent), 0o644); err != nil { + logging.Error("Failed to write before content to file", "error", err) return "", 0, 0 } _, err = wt.Add(fileName) if err != nil { + logging.Error("Failed to add file to git", "error", err) return "", 0, 0 } @@ -920,16 +980,19 @@ func GenerateDiff(beforeContent, afterContent, fileName string) (string, int, in }, }) if err != nil { + logging.Error("Failed to commit before content", "error", err) return "", 0, 0 } // Write the "after" content and commit it if err = os.WriteFile(fullPath, []byte(afterContent), 0o644); err != nil { + logging.Error("Failed to write after content to file", "error", err) return "", 0, 0 } _, err = wt.Add(fileName) if err != nil { + logging.Error("Failed to add file to git", "error", err) return "", 0, 0 } @@ -941,22 +1004,26 @@ func GenerateDiff(beforeContent, afterContent, fileName string) (string, int, in }, }) if err != nil { + logging.Error("Failed to commit after content", "error", err) return "", 0, 0 } // Get the diff between the two commits beforeCommitObj, err := repo.CommitObject(beforeCommit) if err != nil { + logging.Error("Failed to get before commit object", "error", err) return "", 0, 0 } afterCommitObj, err := repo.CommitObject(afterCommit) if err != nil { + logging.Error("Failed to get after commit object", "error", err) return "", 0, 0 } patch, err := beforeCommitObj.Patch(afterCommitObj) if err != nil { + logging.Error("Failed to create git diff patch", "error", err) return "", 0, 0 } diff --git a/internal/llm/agent/agent-tool.go b/internal/llm/agent/agent-tool.go index 83160bb64..308412bde 100644 --- a/internal/llm/agent/agent-tool.go +++ b/internal/llm/agent/agent-tool.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" + "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/llm/tools" "github.com/kujtimiihoxha/termai/internal/lsp" "github.com/kujtimiihoxha/termai/internal/message" @@ -53,7 +54,7 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes return tools.ToolResponse{}, fmt.Errorf("session_id and message_id are required") } - agent, err := NewTaskAgent(b.messages, b.sessions, b.lspClients) + agent, err := NewAgent(config.AgentTask, b.sessions, b.messages, TaskAgentTools(b.lspClients)) if err != nil { return tools.ToolResponse{}, fmt.Errorf("error creating agent: %s", err) } @@ -63,21 +64,16 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes return tools.ToolResponse{}, fmt.Errorf("error creating session: %s", err) } - err = agent.Generate(ctx, session.ID, params.Prompt) + done, err := agent.Run(ctx, session.ID, params.Prompt) if err != nil { return tools.ToolResponse{}, fmt.Errorf("error generating agent: %s", err) } - - messages, err := b.messages.List(ctx, session.ID) - if err != nil { - return tools.ToolResponse{}, fmt.Errorf("error listing messages: %s", err) - } - - if len(messages) == 0 { - return tools.NewTextErrorResponse("no response"), nil + result := <-done + if result.Err() != nil { + return tools.ToolResponse{}, fmt.Errorf("error generating agent: %s", result.Err()) } - response := messages[len(messages)-1] + response := result.Response() if response.Role != message.Assistant { return tools.NewTextErrorResponse("no response"), nil } diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 1958111a1..ab2742ec1 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -4,8 +4,6 @@ import ( "context" "errors" "fmt" - "os" - "runtime/debug" "strings" "sync" @@ -16,133 +14,101 @@ import ( "github.com/kujtimiihoxha/termai/internal/llm/tools" "github.com/kujtimiihoxha/termai/internal/logging" "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/termai/internal/permission" "github.com/kujtimiihoxha/termai/internal/session" ) // Common errors var ( - ErrProviderNotEnabled = errors.New("provider is not enabled") - ErrRequestCancelled = errors.New("request cancelled by user") - ErrSessionBusy = errors.New("session is currently processing another request") + ErrRequestCancelled = errors.New("request cancelled by user") + ErrSessionBusy = errors.New("session is currently processing another request") ) -// Service defines the interface for generating responses +type AgentEvent struct { + message message.Message + err error +} + +func (e *AgentEvent) Err() error { + return e.err +} + +func (e *AgentEvent) Response() message.Message { + return e.message +} + type Service interface { - Generate(ctx context.Context, sessionID string, content string) error - Cancel(sessionID string) error + Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error) + Cancel(sessionID string) + IsSessionBusy(sessionID string) bool } type agent struct { - sessions session.Service - messages message.Service - model models.Model - tools []tools.BaseTool - agent provider.Provider - titleGenerator provider.Provider - activeRequests sync.Map // map[sessionID]context.CancelFunc + sessions session.Service + messages message.Service + + tools []tools.BaseTool + provider provider.Provider + + titleProvider provider.Provider + + activeRequests sync.Map } -// NewAgent creates a new agent instance with the given model and tools -func NewAgent(ctx context.Context, sessions session.Service, messages message.Service, model models.Model, tools []tools.BaseTool) (Service, error) { - agentProvider, titleGenerator, err := getAgentProviders(ctx, model) +func NewAgent( + agentName config.AgentName, + sessions session.Service, + messages message.Service, + agentTools []tools.BaseTool, +) (Service, error) { + agentProvider, err := createAgentProvider(agentName) if err != nil { - return nil, fmt.Errorf("failed to initialize providers: %w", err) + return nil, err + } + var titleProvider provider.Provider + // Only generate titles for the coder agent + if agentName == config.AgentCoder { + titleProvider, err = createAgentProvider(config.AgentTitle) + if err != nil { + return nil, err + } } - return &agent{ - model: model, - tools: tools, - sessions: sessions, + agent := &agent{ + provider: agentProvider, messages: messages, - agent: agentProvider, - titleGenerator: titleGenerator, + sessions: sessions, + tools: agentTools, + titleProvider: titleProvider, activeRequests: sync.Map{}, - }, nil + } + + return agent, nil } -// Cancel cancels an active request by session ID -func (a *agent) Cancel(sessionID string) error { +func (a *agent) Cancel(sessionID string) { if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists { if cancel, ok := cancelFunc.(context.CancelFunc); ok { logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID)) cancel() - return nil } } - return errors.New("no active request found for this session") } -// Generate starts the generation process -func (a *agent) Generate(ctx context.Context, sessionID string, content string) error { - // Check if this session already has an active request - if _, busy := a.activeRequests.Load(sessionID); busy { - return ErrSessionBusy - } - - // Create a cancellable context - genCtx, cancel := context.WithCancel(ctx) - - // Store cancel function to allow user cancellation - a.activeRequests.Store(sessionID, cancel) - - // Launch the generation in a goroutine - go func() { - defer func() { - if r := recover(); r != nil { - logging.ErrorPersist(fmt.Sprintf("Panic in Generate: %v", r)) - - // dump stack trace into a file - file, err := os.Create("panic.log") - if err != nil { - logging.ErrorPersist(fmt.Sprintf("Failed to create panic log: %v", err)) - return - } - - defer file.Close() - - stackTrace := debug.Stack() - if _, err := file.Write(stackTrace); err != nil { - logging.ErrorPersist(fmt.Sprintf("Failed to write panic log: %v", err)) - } - - } - }() - defer a.activeRequests.Delete(sessionID) - defer cancel() - - if err := a.generate(genCtx, sessionID, content); err != nil { - if !errors.Is(err, ErrRequestCancelled) && !errors.Is(err, context.Canceled) { - // Log the error (avoid logging cancellations as they're expected) - logging.ErrorPersist(fmt.Sprintf("Generation error for session %s: %v", sessionID, err)) - - // You may want to create an error message in the chat - bgCtx := context.Background() - errorMsg := fmt.Sprintf("Sorry, an error occurred: %v", err) - _, createErr := a.messages.Create(bgCtx, sessionID, message.CreateMessageParams{ - Role: message.System, - Parts: []message.ContentPart{ - message.TextContent{ - Text: errorMsg, - }, - }, - }) - if createErr != nil { - logging.ErrorPersist(fmt.Sprintf("Failed to create error message: %v", createErr)) - } - } - } - }() - - return nil -} - -// IsSessionBusy checks if a session currently has an active request func (a *agent) IsSessionBusy(sessionID string) bool { _, busy := a.activeRequests.Load(sessionID) return busy -} // handleTitleGeneration asynchronously generates a title for new sessions -func (a *agent) handleTitleGeneration(ctx context.Context, sessionID, content string) { - response, err := a.titleGenerator.SendMessages( +} + +func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error { + if a.titleProvider == nil { + return nil + } + session, err := a.sessions.Get(ctx, sessionID) + if err != nil { + return err + } + response, err := a.titleProvider.SendMessages( ctx, []message.Message{ { @@ -154,121 +120,152 @@ func (a *agent) handleTitleGeneration(ctx context.Context, sessionID, content st }, }, }, - nil, + make([]tools.BaseTool, 0), ) if err != nil { - logging.ErrorPersist(fmt.Sprintf("Failed to generate title: %v", err)) - return + return err } - session, err := a.sessions.Get(ctx, sessionID) - if err != nil { - logging.ErrorPersist(fmt.Sprintf("Failed to get session: %v", err)) - return + title := strings.TrimSpace(strings.ReplaceAll(response.Content, "\n", " ")) + if title == "" { + return nil } - if response.Content != "" { - session.Title = strings.TrimSpace(response.Content) - session.Title = strings.ReplaceAll(session.Title, "\n", " ") - if _, err := a.sessions.Save(ctx, session); err != nil { - logging.ErrorPersist(fmt.Sprintf("Failed to save session title: %v", err)) - } + session.Title = title + _, err = a.sessions.Save(ctx, session) + return err +} + +func (a *agent) err(err error) AgentEvent { + return AgentEvent{ + err: err, } } -// TrackUsage updates token usage statistics for the session -func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error { - session, err := a.sessions.Get(ctx, sessionID) - if err != nil { - return fmt.Errorf("failed to get session: %w", err) +func (a *agent) Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error) { + events := make(chan AgentEvent) + if a.IsSessionBusy(sessionID) { + return nil, ErrSessionBusy } - cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) + - model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) + - model.CostPer1MIn/1e6*float64(usage.InputTokens) + - model.CostPer1MOut/1e6*float64(usage.OutputTokens) + genCtx, cancel := context.WithCancel(ctx) + + a.activeRequests.Store(sessionID, cancel) + go func() { + logging.Debug("Request started", "sessionID", sessionID) + defer logging.RecoverPanic("agent.Run", func() { + events <- a.err(fmt.Errorf("panic while running the agent")) + }) - session.Cost += cost - session.CompletionTokens += usage.OutputTokens - session.PromptTokens += usage.InputTokens + result := a.processGeneration(genCtx, sessionID, content) + if result.Err() != nil && !errors.Is(result.Err(), ErrRequestCancelled) && !errors.Is(result.Err(), context.Canceled) { + logging.ErrorPersist(fmt.Sprintf("Generation error for session %s: %v", sessionID, result)) + } + logging.Debug("Request completed", "sessionID", sessionID) + a.activeRequests.Delete(sessionID) + cancel() + events <- result + close(events) + }() + return events, nil +} - _, err = a.sessions.Save(ctx, session) +func (a *agent) processGeneration(ctx context.Context, sessionID, content string) AgentEvent { + // List existing messages; if none, start title generation asynchronously. + msgs, err := a.messages.List(ctx, sessionID) if err != nil { - return fmt.Errorf("failed to save session: %w", err) + return a.err(fmt.Errorf("failed to list messages: %w", err)) + } + if len(msgs) == 0 { + go func() { + defer logging.RecoverPanic("agent.Run", func() { + logging.ErrorPersist("panic while generating title") + }) + titleErr := a.generateTitle(context.Background(), sessionID, content) + if titleErr != nil { + logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr)) + } + }() } - return nil -} -// processEvent handles different types of events during generation -func (a *agent) processEvent( - ctx context.Context, - sessionID string, - assistantMsg *message.Message, - event provider.ProviderEvent, -) error { - select { - case <-ctx.Done(): - return ctx.Err() - default: - // Continue processing + userMsg, err := a.createUserMessage(ctx, sessionID, content) + if err != nil { + return a.err(fmt.Errorf("failed to create user message: %w", err)) } - switch event.Type { - case provider.EventThinkingDelta: - assistantMsg.AppendReasoningContent(event.Content) - return a.messages.Update(ctx, *assistantMsg) - case provider.EventContentDelta: - assistantMsg.AppendContent(event.Content) - return a.messages.Update(ctx, *assistantMsg) - case provider.EventError: - if errors.Is(event.Error, context.Canceled) { - logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID)) - return context.Canceled + // Append the new user message to the conversation history. + msgHistory := append(msgs, userMsg) + for { + // Check for cancellation before each iteration + select { + case <-ctx.Done(): + return a.err(ctx.Err()) + default: + // Continue processing } - logging.ErrorPersist(event.Error.Error()) - return event.Error - case provider.EventWarning: - logging.WarnPersist(event.Info) - case provider.EventInfo: - logging.InfoPersist(event.Info) - case provider.EventComplete: - assistantMsg.SetToolCalls(event.Response.ToolCalls) - assistantMsg.AddFinish(event.Response.FinishReason) - if err := a.messages.Update(ctx, *assistantMsg); err != nil { - return fmt.Errorf("failed to update message: %w", err) + agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory) + if err != nil { + if errors.Is(err, context.Canceled) { + return a.err(ErrRequestCancelled) + } + return a.err(fmt.Errorf("failed to process events: %w", err)) + } + logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults) + if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil { + // We are not done, we need to respond with the tool response + msgHistory = append(msgHistory, agentMessage, *toolResults) + continue + } + return AgentEvent{ + message: agentMessage, } - return a.TrackUsage(ctx, sessionID, a.model, event.Response.Usage) } +} - return nil +func (a *agent) createUserMessage(ctx context.Context, sessionID, content string) (message.Message, error) { + return a.messages.Create(ctx, sessionID, message.CreateMessageParams{ + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: content}, + }, + }) } -// ExecuteTools runs all tool calls sequentially and returns the results -func (a *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) { - toolResults := make([]message.ToolResult, len(toolCalls)) +func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) { + eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools) + + assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{ + Role: message.Assistant, + Parts: []message.ContentPart{}, + Model: a.provider.Model().ID, + }) + if err != nil { + return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err) + } - // Create a child context that can be canceled - ctx, cancel := context.WithCancel(ctx) - defer cancel() + // Add the session and message ID into the context if needed by tools. + ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID) + ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID) - // Check if already canceled before starting any execution - if ctx.Err() != nil { - // Mark all tools as canceled - for i, toolCall := range toolCalls { - toolResults[i] = message.ToolResult{ - ToolCallID: toolCall.ID, - Content: "Tool execution canceled by user", - IsError: true, - } + // Process each event in the stream. + for event := range eventChan { + if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil { + a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled) + return assistantMsg, nil, processErr + } + if ctx.Err() != nil { + a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled) + return assistantMsg, nil, ctx.Err() } - return toolResults, ctx.Err() } + toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls())) + toolCalls := assistantMsg.ToolCalls() for i, toolCall := range toolCalls { - // Check for cancellation before executing each tool select { case <-ctx.Done(): - // Mark this and all remaining tools as canceled + a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled) + // Make all future tool calls cancelled for j := i; j < len(toolCalls); j++ { toolResults[j] = message.ToolResult{ ToolCallID: toolCalls[j].ID, @@ -276,412 +273,180 @@ func (a *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, IsError: true, } } - return toolResults, ctx.Err() + goto out default: // Continue processing - } - - response := "" - isError := false - found := false - - // Find and execute the appropriate tool - for _, tool := range tls { - if tool.Info().Name == toolCall.Name { - found = true - toolResult, toolErr := tool.Run(ctx, tools.ToolCall{ - ID: toolCall.ID, - Name: toolCall.Name, - Input: toolCall.Input, - }) - - if toolErr != nil { - if errors.Is(toolErr, context.Canceled) { - response = "Tool execution canceled by user" - } else { - response = fmt.Sprintf("Error running tool: %s", toolErr) - } - isError = true - } else { - response = toolResult.Content - isError = toolResult.IsError + var tool tools.BaseTool + for _, availableTools := range a.tools { + if availableTools.Info().Name == toolCall.Name { + tool = availableTools } - break } - } - - if !found { - response = fmt.Sprintf("Tool not found: %s", toolCall.Name) - isError = true - } - - toolResults[i] = message.ToolResult{ - ToolCallID: toolCall.ID, - Content: response, - IsError: isError, - } - } - return toolResults, nil -} - -// handleToolExecution processes tool calls and creates tool result messages -func (a *agent) handleToolExecution( - ctx context.Context, - assistantMsg message.Message, -) (*message.Message, error) { - select { - case <-ctx.Done(): - // If cancelled, create tool results that indicate cancellation - if len(assistantMsg.ToolCalls()) > 0 { - toolResults := make([]message.ToolResult, 0, len(assistantMsg.ToolCalls())) - for _, tc := range assistantMsg.ToolCalls() { - toolResults = append(toolResults, message.ToolResult{ - ToolCallID: tc.ID, - Content: "Tool execution canceled by user", + // Tool not found + if tool == nil { + toolResults[i] = message.ToolResult{ + ToolCallID: toolCall.ID, + Content: fmt.Sprintf("Tool not found: %s", toolCall.Name), IsError: true, - }) + } + continue } - // Use background context to ensure the message is created even if original context is cancelled - bgCtx := context.Background() - parts := make([]message.ContentPart, 0) - for _, toolResult := range toolResults { - parts = append(parts, toolResult) - } - msg, err := a.messages.Create(bgCtx, assistantMsg.SessionID, message.CreateMessageParams{ - Role: message.Tool, - Parts: parts, + toolResult, toolErr := tool.Run(ctx, tools.ToolCall{ + ID: toolCall.ID, + Name: toolCall.Name, + Input: toolCall.Input, }) - if err != nil { - return nil, fmt.Errorf("failed to create cancelled tool message: %w", err) - } - return &msg, ctx.Err() - } - return nil, ctx.Err() - default: - // Continue processing - } - - if len(assistantMsg.ToolCalls()) == 0 { - return nil, nil - } - - toolResults, err := a.ExecuteTools(ctx, assistantMsg.ToolCalls(), a.tools) - if err != nil { - // If error is from cancellation, still return the partial results we have - if errors.Is(err, context.Canceled) { - // Use background context to ensure the message is created even if original context is cancelled - bgCtx := context.Background() - parts := make([]message.ContentPart, 0) - for _, toolResult := range toolResults { - parts = append(parts, toolResult) + if toolErr != nil { + if errors.Is(toolErr, permission.ErrorPermissionDenied) { + toolResults[i] = message.ToolResult{ + ToolCallID: toolCall.ID, + Content: "Permission denied", + IsError: true, + } + for j := i + 1; j < len(toolCalls); j++ { + toolResults[j] = message.ToolResult{ + ToolCallID: toolCalls[j].ID, + Content: "Tool execution canceled by user", + IsError: true, + } + } + a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied) + } else { + toolResults[i] = message.ToolResult{ + ToolCallID: toolCall.ID, + Content: toolErr.Error(), + IsError: true, + } + for j := i; j < len(toolCalls); j++ { + toolResults[j] = message.ToolResult{ + ToolCallID: toolCalls[j].ID, + Content: "Previous tool failed", + IsError: true, + } + } + a.finishMessage(ctx, &assistantMsg, message.FinishReasonError) + } + // If permission is denied or an error happens we cancel all the following tools + break } - - msg, createErr := a.messages.Create(bgCtx, assistantMsg.SessionID, message.CreateMessageParams{ - Role: message.Tool, - Parts: parts, - }) - if createErr != nil { - logging.ErrorPersist(fmt.Sprintf("Failed to create tool message after cancellation: %v", createErr)) - return nil, err + toolResults[i] = message.ToolResult{ + ToolCallID: toolCall.ID, + Content: toolResult.Content, + Metadata: toolResult.Metadata, + IsError: toolResult.IsError, } - return &msg, err } - return nil, err } - - parts := make([]message.ContentPart, 0, len(toolResults)) - for _, toolResult := range toolResults { - parts = append(parts, toolResult) +out: + if len(toolResults) == 0 { + return assistantMsg, nil, nil } - - msg, err := a.messages.Create(ctx, assistantMsg.SessionID, message.CreateMessageParams{ + parts := make([]message.ContentPart, 0) + for _, tr := range toolResults { + parts = append(parts, tr) + } + msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{ Role: message.Tool, Parts: parts, }) if err != nil { - return nil, fmt.Errorf("failed to create tool message: %w", err) + return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err) } - return &msg, nil + return assistantMsg, &msg, err } -// generate handles the main generation workflow -func (a *agent) generate(ctx context.Context, sessionID string, content string) error { - ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID) +func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) { + msg.AddFinish(finishReson) + _ = a.messages.Update(ctx, *msg) +} - // Handle context cancellation at any point - if err := ctx.Err(); err != nil { - return ErrRequestCancelled +func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + // Continue processing. } - messages, err := a.messages.List(ctx, sessionID) - if err != nil { - return fmt.Errorf("failed to list messages: %w", err) + switch event.Type { + case provider.EventThinkingDelta: + assistantMsg.AppendReasoningContent(event.Content) + return a.messages.Update(ctx, *assistantMsg) + case provider.EventContentDelta: + assistantMsg.AppendContent(event.Content) + return a.messages.Update(ctx, *assistantMsg) + case provider.EventError: + if errors.Is(event.Error, context.Canceled) { + logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID)) + return context.Canceled + } + logging.ErrorPersist(event.Error.Error()) + return event.Error + case provider.EventComplete: + assistantMsg.SetToolCalls(event.Response.ToolCalls) + assistantMsg.AddFinish(event.Response.FinishReason) + if err := a.messages.Update(ctx, *assistantMsg); err != nil { + return fmt.Errorf("failed to update message: %w", err) + } + return a.TrackUsage(ctx, sessionID, a.provider.Model(), event.Response.Usage) } - if len(messages) == 0 { - titleCtx := context.Background() - go a.handleTitleGeneration(titleCtx, sessionID, content) - } + return nil +} - userMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{ - Role: message.User, - Parts: []message.ContentPart{ - message.TextContent{ - Text: content, - }, - }, - }) +func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error { + sess, err := a.sessions.Get(ctx, sessionID) if err != nil { - return fmt.Errorf("failed to create user message: %w", err) + return fmt.Errorf("failed to get session: %w", err) } - messages = append(messages, userMsg) - - for { - // Check for cancellation before each iteration - select { - case <-ctx.Done(): - return ErrRequestCancelled - default: - // Continue processing - } - - eventChan, err := a.agent.StreamResponse(ctx, messages, a.tools) - if err != nil { - if errors.Is(err, context.Canceled) { - return ErrRequestCancelled - } - return fmt.Errorf("failed to stream response: %w", err) - } - - assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{ - Role: message.Assistant, - Parts: []message.ContentPart{}, - Model: a.model.ID, - }) - if err != nil { - return fmt.Errorf("failed to create assistant message: %w", err) - } - - ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID) - - // Process events from the LLM provider - for event := range eventChan { - if err := a.processEvent(ctx, sessionID, &assistantMsg, event); err != nil { - if errors.Is(err, context.Canceled) { - // Mark as canceled but don't create separate message - assistantMsg.AddFinish("canceled") - _ = a.messages.Update(context.Background(), assistantMsg) - return ErrRequestCancelled - } - assistantMsg.AddFinish("error:" + err.Error()) - _ = a.messages.Update(ctx, assistantMsg) - return fmt.Errorf("event processing error: %w", err) - } - - // Check for cancellation during event processing - select { - case <-ctx.Done(): - // Mark as canceled - assistantMsg.AddFinish("canceled") - _ = a.messages.Update(context.Background(), assistantMsg) - return ErrRequestCancelled - default: - } - } - - // Check for cancellation before tool execution - select { - case <-ctx.Done(): - assistantMsg.AddFinish("canceled_by_user") - _ = a.messages.Update(context.Background(), assistantMsg) - return ErrRequestCancelled - default: - } - - // Execute any tool calls - toolMsg, err := a.handleToolExecution(ctx, assistantMsg) - if err != nil { - if errors.Is(err, context.Canceled) { - assistantMsg.AddFinish("canceled_by_user") - _ = a.messages.Update(context.Background(), assistantMsg) - return ErrRequestCancelled - } - return fmt.Errorf("tool execution error: %w", err) - } - - if err := a.messages.Update(ctx, assistantMsg); err != nil { - return fmt.Errorf("failed to update assistant message: %w", err) - } - - // If no tool calls, we're done - if len(assistantMsg.ToolCalls()) == 0 { - break - } + cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) + + model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) + + model.CostPer1MIn/1e6*float64(usage.InputTokens) + + model.CostPer1MOut/1e6*float64(usage.OutputTokens) - // Add messages for next iteration - messages = append(messages, assistantMsg) - if toolMsg != nil { - messages = append(messages, *toolMsg) - } + sess.Cost += cost + sess.CompletionTokens += usage.OutputTokens + sess.PromptTokens += usage.InputTokens - // Check for cancellation after tool execution - select { - case <-ctx.Done(): - return ErrRequestCancelled - default: - } + _, err = a.sessions.Save(ctx, sess) + if err != nil { + return fmt.Errorf("failed to save session: %w", err) } - return nil } -// getAgentProviders initializes the LLM providers based on the chosen model -func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) { - maxTokens := config.Get().Model.CoderMaxTokens - - providerConfig, ok := config.Get().Providers[model.Provider] - if !ok || providerConfig.Disabled { - return nil, nil, ErrProviderNotEnabled +func createAgentProvider(agentName config.AgentName) (provider.Provider, error) { + cfg := config.Get() + agentConfig, ok := cfg.Agents[agentName] + if !ok { + return nil, fmt.Errorf("agent %s not found", agentName) + } + model, ok := models.SupportedModels[agentConfig.Model] + if !ok { + return nil, fmt.Errorf("model %s not supported", agentConfig.Model) } - var agentProvider provider.Provider - var titleGenerator provider.Provider - var err error - - switch model.Provider { - case models.ProviderOpenAI: - agentProvider, err = provider.NewOpenAIProvider( - provider.WithOpenAISystemMessage( - prompt.CoderOpenAISystemPrompt(), - ), - provider.WithOpenAIMaxTokens(maxTokens), - provider.WithOpenAIModel(model), - provider.WithOpenAIKey(providerConfig.APIKey), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create OpenAI agent provider: %w", err) - } - - titleGenerator, err = provider.NewOpenAIProvider( - provider.WithOpenAISystemMessage( - prompt.TitlePrompt(), - ), - provider.WithOpenAIMaxTokens(80), - provider.WithOpenAIModel(model), - provider.WithOpenAIKey(providerConfig.APIKey), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create OpenAI title generator: %w", err) - } - - case models.ProviderAnthropic: - agentProvider, err = provider.NewAnthropicProvider( - provider.WithAnthropicSystemMessage( - prompt.CoderAnthropicSystemPrompt(), - ), - provider.WithAnthropicMaxTokens(maxTokens), - provider.WithAnthropicKey(providerConfig.APIKey), - provider.WithAnthropicModel(model), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create Anthropic agent provider: %w", err) - } - - titleGenerator, err = provider.NewAnthropicProvider( - provider.WithAnthropicSystemMessage( - prompt.TitlePrompt(), - ), - provider.WithAnthropicMaxTokens(80), - provider.WithAnthropicKey(providerConfig.APIKey), - provider.WithAnthropicModel(model), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create Anthropic title generator: %w", err) - } - - case models.ProviderGemini: - agentProvider, err = provider.NewGeminiProvider( - ctx, - provider.WithGeminiSystemMessage( - prompt.CoderOpenAISystemPrompt(), - ), - provider.WithGeminiMaxTokens(int32(maxTokens)), - provider.WithGeminiKey(providerConfig.APIKey), - provider.WithGeminiModel(model), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create Gemini agent provider: %w", err) - } - - titleGenerator, err = provider.NewGeminiProvider( - ctx, - provider.WithGeminiSystemMessage( - prompt.TitlePrompt(), - ), - provider.WithGeminiMaxTokens(80), - provider.WithGeminiKey(providerConfig.APIKey), - provider.WithGeminiModel(model), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create Gemini title generator: %w", err) - } - - case models.ProviderGROQ: - agentProvider, err = provider.NewOpenAIProvider( - provider.WithOpenAISystemMessage( - prompt.CoderAnthropicSystemPrompt(), - ), - provider.WithOpenAIMaxTokens(maxTokens), - provider.WithOpenAIModel(model), - provider.WithOpenAIKey(providerConfig.APIKey), - provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create GROQ agent provider: %w", err) - } - - titleGenerator, err = provider.NewOpenAIProvider( - provider.WithOpenAISystemMessage( - prompt.TitlePrompt(), - ), - provider.WithOpenAIMaxTokens(80), - provider.WithOpenAIModel(model), - provider.WithOpenAIKey(providerConfig.APIKey), - provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create GROQ title generator: %w", err) - } - - case models.ProviderBedrock: - agentProvider, err = provider.NewBedrockProvider( - provider.WithBedrockSystemMessage( - prompt.CoderAnthropicSystemPrompt(), - ), - provider.WithBedrockMaxTokens(maxTokens), - provider.WithBedrockModel(model), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create Bedrock agent provider: %w", err) - } - - titleGenerator, err = provider.NewBedrockProvider( - provider.WithBedrockSystemMessage( - prompt.TitlePrompt(), - ), - provider.WithBedrockMaxTokens(80), - provider.WithBedrockModel(model), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create Bedrock title generator: %w", err) - } - default: - return nil, nil, fmt.Errorf("unsupported provider: %s", model.Provider) + providerCfg, ok := cfg.Providers[model.Provider] + if !ok { + return nil, fmt.Errorf("provider %s not supported", model.Provider) + } + if providerCfg.Disabled { + return nil, fmt.Errorf("provider %s is not enabled", model.Provider) + } + agentProvider, err := provider.NewProvider( + model.Provider, + provider.WithAPIKey(providerCfg.APIKey), + provider.WithModel(model), + provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)), + provider.WithMaxTokens(agentConfig.MaxTokens), + ) + if err != nil { + return nil, fmt.Errorf("could not create provider: %v", err) } - return agentProvider, titleGenerator, nil + return agentProvider, nil } diff --git a/internal/llm/agent/coder.go b/internal/llm/agent/coder.go deleted file mode 100644 index a3db6b55c..000000000 --- a/internal/llm/agent/coder.go +++ /dev/null @@ -1,63 +0,0 @@ -package agent - -import ( - "context" - "errors" - - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/models" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/message" - "github.com/kujtimiihoxha/termai/internal/permission" - "github.com/kujtimiihoxha/termai/internal/session" -) - -type coderAgent struct { - Service -} - -func NewCoderAgent( - permissions permission.Service, - sessions session.Service, - messages message.Service, - lspClients map[string]*lsp.Client, -) (Service, error) { - model, ok := models.SupportedModels[config.Get().Model.Coder] - if !ok { - return nil, errors.New("model not supported") - } - - ctx := context.Background() - otherTools := GetMcpTools(ctx, permissions) - if len(lspClients) > 0 { - otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients)) - } - agent, err := NewAgent( - ctx, - sessions, - messages, - model, - append( - []tools.BaseTool{ - tools.NewBashTool(permissions), - tools.NewEditTool(lspClients, permissions), - tools.NewFetchTool(permissions), - tools.NewGlobTool(), - tools.NewGrepTool(), - tools.NewLsTool(), - tools.NewSourcegraphTool(), - tools.NewViewTool(lspClients), - tools.NewWriteTool(lspClients, permissions), - NewAgentTool(sessions, messages, lspClients), - }, otherTools..., - ), - ) - if err != nil { - return nil, err - } - - return &coderAgent{ - agent, - }, nil -} diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go index b1c97b512..c7ea4916c 100644 --- a/internal/llm/agent/mcp-tools.go +++ b/internal/llm/agent/mcp-tools.go @@ -46,7 +46,7 @@ func runTool(ctx context.Context, c MCPClient, toolName string, input string) (t initRequest := mcp.InitializeRequest{} initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION initRequest.Params.ClientInfo = mcp.Implementation{ - Name: "termai", + Name: "OpenCode", Version: version.Version, } @@ -135,7 +135,7 @@ func getTools(ctx context.Context, name string, m config.MCPServer, permissions initRequest := mcp.InitializeRequest{} initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION initRequest.Params.ClientInfo = mcp.Implementation{ - Name: "termai", + Name: "OpenCode", Version: version.Version, } diff --git a/internal/llm/agent/task.go b/internal/llm/agent/task.go deleted file mode 100644 index fca1f223f..000000000 --- a/internal/llm/agent/task.go +++ /dev/null @@ -1,47 +0,0 @@ -package agent - -import ( - "context" - "errors" - - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/models" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/message" - "github.com/kujtimiihoxha/termai/internal/session" -) - -type taskAgent struct { - Service -} - -func NewTaskAgent(messages message.Service, sessions session.Service, lspClients map[string]*lsp.Client) (Service, error) { - model, ok := models.SupportedModels[config.Get().Model.Coder] - if !ok { - return nil, errors.New("model not supported") - } - - ctx := context.Background() - - agent, err := NewAgent( - ctx, - sessions, - messages, - model, - []tools.BaseTool{ - tools.NewGlobTool(), - tools.NewGrepTool(), - tools.NewLsTool(), - tools.NewSourcegraphTool(), - tools.NewViewTool(lspClients), - }, - ) - if err != nil { - return nil, err - } - - return &taskAgent{ - agent, - }, nil -} diff --git a/internal/llm/agent/tools.go b/internal/llm/agent/tools.go new file mode 100644 index 000000000..a37f1d65d --- /dev/null +++ b/internal/llm/agent/tools.go @@ -0,0 +1,50 @@ +package agent + +import ( + "context" + + "github.com/kujtimiihoxha/termai/internal/history" + "github.com/kujtimiihoxha/termai/internal/llm/tools" + "github.com/kujtimiihoxha/termai/internal/lsp" + "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/termai/internal/permission" + "github.com/kujtimiihoxha/termai/internal/session" +) + +func CoderAgentTools( + permissions permission.Service, + sessions session.Service, + messages message.Service, + history history.Service, + lspClients map[string]*lsp.Client, +) []tools.BaseTool { + ctx := context.Background() + otherTools := GetMcpTools(ctx, permissions) + if len(lspClients) > 0 { + otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients)) + } + return append( + []tools.BaseTool{ + tools.NewBashTool(permissions), + tools.NewEditTool(lspClients, permissions, history), + tools.NewFetchTool(permissions), + tools.NewGlobTool(), + tools.NewGrepTool(), + tools.NewLsTool(), + tools.NewSourcegraphTool(), + tools.NewViewTool(lspClients), + tools.NewWriteTool(lspClients, permissions, history), + NewAgentTool(sessions, messages, lspClients), + }, otherTools..., + ) +} + +func TaskAgentTools(lspClients map[string]*lsp.Client) []tools.BaseTool { + return []tools.BaseTool{ + tools.NewGlobTool(), + tools.NewGrepTool(), + tools.NewLsTool(), + tools.NewSourcegraphTool(), + tools.NewViewTool(lspClients), + } +} diff --git a/internal/llm/models/anthropic.go b/internal/llm/models/anthropic.go new file mode 100644 index 000000000..48307e6d3 --- /dev/null +++ b/internal/llm/models/anthropic.go @@ -0,0 +1,71 @@ +package models + +const ( + ProviderAnthropic ModelProvider = "anthropic" + + // Models + Claude35Sonnet ModelID = "claude-3.5-sonnet" + Claude3Haiku ModelID = "claude-3-haiku" + Claude37Sonnet ModelID = "claude-3.7-sonnet" + Claude35Haiku ModelID = "claude-3.5-haiku" + Claude3Opus ModelID = "claude-3-opus" +) + +var AnthropicModels = map[ModelID]Model{ + // Anthropic + Claude35Sonnet: { + ID: Claude35Sonnet, + Name: "Claude 3.5 Sonnet", + Provider: ProviderAnthropic, + APIModel: "claude-3-5-sonnet-latest", + CostPer1MIn: 3.0, + CostPer1MInCached: 3.75, + CostPer1MOutCached: 0.30, + CostPer1MOut: 15.0, + ContextWindow: 200000, + }, + Claude3Haiku: { + ID: Claude3Haiku, + Name: "Claude 3 Haiku", + Provider: ProviderAnthropic, + APIModel: "claude-3-haiku-latest", + CostPer1MIn: 0.25, + CostPer1MInCached: 0.30, + CostPer1MOutCached: 0.03, + CostPer1MOut: 1.25, + ContextWindow: 200000, + }, + Claude37Sonnet: { + ID: Claude37Sonnet, + Name: "Claude 3.7 Sonnet", + Provider: ProviderAnthropic, + APIModel: "claude-3-7-sonnet-latest", + CostPer1MIn: 3.0, + CostPer1MInCached: 3.75, + CostPer1MOutCached: 0.30, + CostPer1MOut: 15.0, + ContextWindow: 200000, + }, + Claude35Haiku: { + ID: Claude35Haiku, + Name: "Claude 3.5 Haiku", + Provider: ProviderAnthropic, + APIModel: "claude-3-5-haiku-latest", + CostPer1MIn: 0.80, + CostPer1MInCached: 1.0, + CostPer1MOutCached: 0.08, + CostPer1MOut: 4.0, + ContextWindow: 200000, + }, + Claude3Opus: { + ID: Claude3Opus, + Name: "Claude 3 Opus", + Provider: ProviderAnthropic, + APIModel: "claude-3-opus-latest", + CostPer1MIn: 15.0, + CostPer1MInCached: 18.75, + CostPer1MOutCached: 1.50, + CostPer1MOut: 75.0, + ContextWindow: 200000, + }, +} diff --git a/internal/llm/models/models.go b/internal/llm/models/models.go index 140693237..4d4589bfd 100644 --- a/internal/llm/models/models.go +++ b/internal/llm/models/models.go @@ -1,5 +1,7 @@ package models +import "maps" + type ( ModelID string ModelProvider string @@ -14,15 +16,13 @@ type Model struct { CostPer1MOut float64 `json:"cost_per_1m_out"` CostPer1MInCached float64 `json:"cost_per_1m_in_cached"` CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"` + ContextWindow int64 `json:"context_window"` } // Model IDs const ( - // Anthropic - Claude35Sonnet ModelID = "claude-3.5-sonnet" - Claude3Haiku ModelID = "claude-3-haiku" - Claude37Sonnet ModelID = "claude-3.7-sonnet" // OpenAI + GPT4o ModelID = "gpt-4o" GPT41 ModelID = "gpt-4.1" // GEMINI @@ -37,47 +37,59 @@ const ( ) const ( - ProviderOpenAI ModelProvider = "openai" - ProviderAnthropic ModelProvider = "anthropic" - ProviderBedrock ModelProvider = "bedrock" - ProviderGemini ModelProvider = "gemini" - ProviderGROQ ModelProvider = "groq" + ProviderOpenAI ModelProvider = "openai" + ProviderBedrock ModelProvider = "bedrock" + ProviderGemini ModelProvider = "gemini" + ProviderGROQ ModelProvider = "groq" + + // ForTests + ProviderMock ModelProvider = "__mock" ) var SupportedModels = map[ModelID]Model{ - // Anthropic - Claude35Sonnet: { - ID: Claude35Sonnet, - Name: "Claude 3.5 Sonnet", - Provider: ProviderAnthropic, - APIModel: "claude-3-5-sonnet-latest", - CostPer1MIn: 3.0, - CostPer1MInCached: 3.75, - CostPer1MOutCached: 0.30, - CostPer1MOut: 15.0, - }, - Claude3Haiku: { - ID: Claude3Haiku, - Name: "Claude 3 Haiku", - Provider: ProviderAnthropic, - APIModel: "claude-3-haiku-latest", - CostPer1MIn: 0.80, - CostPer1MInCached: 1, - CostPer1MOutCached: 0.08, - CostPer1MOut: 4, - }, - Claude37Sonnet: { - ID: Claude37Sonnet, - Name: "Claude 3.7 Sonnet", - Provider: ProviderAnthropic, - APIModel: "claude-3-7-sonnet-latest", - CostPer1MIn: 3.0, - CostPer1MInCached: 3.75, - CostPer1MOutCached: 0.30, - CostPer1MOut: 15.0, + // // Anthropic + // Claude35Sonnet: { + // ID: Claude35Sonnet, + // Name: "Claude 3.5 Sonnet", + // Provider: ProviderAnthropic, + // APIModel: "claude-3-5-sonnet-latest", + // CostPer1MIn: 3.0, + // CostPer1MInCached: 3.75, + // CostPer1MOutCached: 0.30, + // CostPer1MOut: 15.0, + // }, + // Claude3Haiku: { + // ID: Claude3Haiku, + // Name: "Claude 3 Haiku", + // Provider: ProviderAnthropic, + // APIModel: "claude-3-haiku-latest", + // CostPer1MIn: 0.80, + // CostPer1MInCached: 1, + // CostPer1MOutCached: 0.08, + // CostPer1MOut: 4, + // }, + // Claude37Sonnet: { + // ID: Claude37Sonnet, + // Name: "Claude 3.7 Sonnet", + // Provider: ProviderAnthropic, + // APIModel: "claude-3-7-sonnet-latest", + // CostPer1MIn: 3.0, + // CostPer1MInCached: 3.75, + // CostPer1MOutCached: 0.30, + // CostPer1MOut: 15.0, + // }, + // + // // OpenAI + GPT4o: { + ID: GPT4o, + Name: "GPT-4o", + Provider: ProviderOpenAI, + APIModel: "gpt-4.1", + CostPer1MIn: 2.00, + CostPer1MInCached: 0.50, + CostPer1MOutCached: 0, + CostPer1MOut: 8.00, }, - - // OpenAI GPT41: { ID: GPT41, Name: "GPT-4.1", @@ -88,51 +100,55 @@ var SupportedModels = map[ModelID]Model{ CostPer1MOutCached: 0, CostPer1MOut: 8.00, }, + // + // // GEMINI + // GEMINI25: { + // ID: GEMINI25, + // Name: "Gemini 2.5 Pro", + // Provider: ProviderGemini, + // APIModel: "gemini-2.5-pro-exp-03-25", + // CostPer1MIn: 0, + // CostPer1MInCached: 0, + // CostPer1MOutCached: 0, + // CostPer1MOut: 0, + // }, + // + // GRMINI20Flash: { + // ID: GRMINI20Flash, + // Name: "Gemini 2.0 Flash", + // Provider: ProviderGemini, + // APIModel: "gemini-2.0-flash", + // CostPer1MIn: 0.1, + // CostPer1MInCached: 0, + // CostPer1MOutCached: 0.025, + // CostPer1MOut: 0.4, + // }, + // + // // GROQ + // QWENQwq: { + // ID: QWENQwq, + // Name: "Qwen Qwq", + // Provider: ProviderGROQ, + // APIModel: "qwen-qwq-32b", + // CostPer1MIn: 0, + // CostPer1MInCached: 0, + // CostPer1MOutCached: 0, + // CostPer1MOut: 0, + // }, + // + // // Bedrock + // BedrockClaude37Sonnet: { + // ID: BedrockClaude37Sonnet, + // Name: "Bedrock: Claude 3.7 Sonnet", + // Provider: ProviderBedrock, + // APIModel: "anthropic.claude-3-7-sonnet-20250219-v1:0", + // CostPer1MIn: 3.0, + // CostPer1MInCached: 3.75, + // CostPer1MOutCached: 0.30, + // CostPer1MOut: 15.0, + // }, +} - // GEMINI - GEMINI25: { - ID: GEMINI25, - Name: "Gemini 2.5 Pro", - Provider: ProviderGemini, - APIModel: "gemini-2.5-pro-exp-03-25", - CostPer1MIn: 0, - CostPer1MInCached: 0, - CostPer1MOutCached: 0, - CostPer1MOut: 0, - }, - - GRMINI20Flash: { - ID: GRMINI20Flash, - Name: "Gemini 2.0 Flash", - Provider: ProviderGemini, - APIModel: "gemini-2.0-flash", - CostPer1MIn: 0.1, - CostPer1MInCached: 0, - CostPer1MOutCached: 0.025, - CostPer1MOut: 0.4, - }, - - // GROQ - QWENQwq: { - ID: QWENQwq, - Name: "Qwen Qwq", - Provider: ProviderGROQ, - APIModel: "qwen-qwq-32b", - CostPer1MIn: 0, - CostPer1MInCached: 0, - CostPer1MOutCached: 0, - CostPer1MOut: 0, - }, - - // Bedrock - BedrockClaude37Sonnet: { - ID: BedrockClaude37Sonnet, - Name: "Bedrock: Claude 3.7 Sonnet", - Provider: ProviderBedrock, - APIModel: "anthropic.claude-3-7-sonnet-20250219-v1:0", - CostPer1MIn: 3.0, - CostPer1MInCached: 3.75, - CostPer1MOutCached: 0.30, - CostPer1MOut: 15.0, - }, +func init() { + maps.Copy(SupportedModels, AnthropicModels) } diff --git a/internal/llm/prompt/coder.go b/internal/llm/prompt/coder.go index 47941f976..7439fd570 100644 --- a/internal/llm/prompt/coder.go +++ b/internal/llm/prompt/coder.go @@ -9,11 +9,22 @@ import ( "time" "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/termai/internal/llm/models" "github.com/kujtimiihoxha/termai/internal/llm/tools" ) -func CoderOpenAISystemPrompt() string { - basePrompt := `You are termAI, an autonomous CLI-based software engineer. Your job is to reduce user effort by proactively reasoning, inferring context, and solving software engineering tasks end-to-end with minimal prompting. +func CoderPrompt(provider models.ModelProvider) string { + basePrompt := baseAnthropicCoderPrompt + switch provider { + case models.ProviderOpenAI: + basePrompt = baseOpenAICoderPrompt + } + envInfo := getEnvironmentInfo() + + return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation()) +} + +const baseOpenAICoderPrompt = `You are termAI, an autonomous CLI-based software engineer. Your job is to reduce user effort by proactively reasoning, inferring context, and solving software engineering tasks end-to-end with minimal prompting. # Your mindset Act like a competent, efficient software engineer who is familiar with large codebases. You should: @@ -65,13 +76,7 @@ assistant: [searches repo for references, returns file paths and lines] Never commit changes unless the user explicitly asks you to.` - envInfo := getEnvironmentInfo() - - return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation()) -} - -func CoderAnthropicSystemPrompt() string { - basePrompt := `You are termAI, an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user. +const baseAnthropicCoderPrompt = `You are termAI, an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user. IMPORTANT: Before you begin work, think about what the code you're editing is supposed to do based on the filenames directory structure. @@ -166,11 +171,6 @@ NEVER commit changes unless the user explicitly asks you to. It is VERY IMPORTAN You MUST answer concisely with fewer than 4 lines of text (not including tool use or code generation), unless user asks for detail.` - envInfo := getEnvironmentInfo() - - return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation()) -} - func getEnvironmentInfo() string { cwd := config.WorkingDirectory() isGit := isGitRepo(cwd) diff --git a/internal/llm/prompt/prompt.go b/internal/llm/prompt/prompt.go new file mode 100644 index 000000000..63fc2df7b --- /dev/null +++ b/internal/llm/prompt/prompt.go @@ -0,0 +1,19 @@ +package prompt + +import ( + "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/termai/internal/llm/models" +) + +func GetAgentPrompt(agentName config.AgentName, provider models.ModelProvider) string { + switch agentName { + case config.AgentCoder: + return CoderPrompt(provider) + case config.AgentTitle: + return TitlePrompt(provider) + case config.AgentTask: + return TaskPrompt(provider) + default: + return "You are a helpful assistant" + } +} diff --git a/internal/llm/prompt/task.go b/internal/llm/prompt/task.go index ee3c707fa..8bf604ad9 100644 --- a/internal/llm/prompt/task.go +++ b/internal/llm/prompt/task.go @@ -2,11 +2,12 @@ package prompt import ( "fmt" + + "github.com/kujtimiihoxha/termai/internal/llm/models" ) -func TaskAgentSystemPrompt() string { +func TaskPrompt(_ models.ModelProvider) string { agentPrompt := `You are an agent for termAI. Given the user's prompt, you should use the tools available to you to answer the user's question. - Notes: 1. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is .", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...". 2. When relevant, share file names and code snippets relevant to the query diff --git a/internal/llm/prompt/title.go b/internal/llm/prompt/title.go index 5c47f4d64..3023a8550 100644 --- a/internal/llm/prompt/title.go +++ b/internal/llm/prompt/title.go @@ -1,6 +1,8 @@ package prompt -func TitlePrompt() string { +import "github.com/kujtimiihoxha/termai/internal/llm/models" + +func TitlePrompt(_ models.ModelProvider) string { return `you will generate a short title based on the first message a user begins a conversation with - ensure it is not more than 50 characters long - the title should be a summary of the user's message diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go index 93c4308ad..c3a4efc49 100644 --- a/internal/llm/provider/anthropic.go +++ b/internal/llm/provider/anthropic.go @@ -12,187 +12,257 @@ import ( "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/bedrock" "github.com/anthropics/anthropic-sdk-go/option" - "github.com/kujtimiihoxha/termai/internal/llm/models" + "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/llm/tools" + "github.com/kujtimiihoxha/termai/internal/logging" "github.com/kujtimiihoxha/termai/internal/message" ) -type anthropicProvider struct { - client anthropic.Client - model models.Model - maxTokens int64 - apiKey string - systemMessage string - useBedrock bool - disableCache bool +type anthropicOptions struct { + useBedrock bool + disableCache bool + shouldThink func(userMessage string) bool } -type AnthropicOption func(*anthropicProvider) +type AnthropicOption func(*anthropicOptions) -func WithAnthropicSystemMessage(message string) AnthropicOption { - return func(a *anthropicProvider) { - a.systemMessage = message - } +type anthropicClient struct { + providerOptions providerClientOptions + options anthropicOptions + client anthropic.Client } -func WithAnthropicMaxTokens(maxTokens int64) AnthropicOption { - return func(a *anthropicProvider) { - a.maxTokens = maxTokens - } -} +type AnthropicClient ProviderClient -func WithAnthropicModel(model models.Model) AnthropicOption { - return func(a *anthropicProvider) { - a.model = model +func newAnthropicClient(opts providerClientOptions) AnthropicClient { + anthropicOpts := anthropicOptions{} + for _, o := range opts.anthropicOptions { + o(&anthropicOpts) } -} -func WithAnthropicKey(apiKey string) AnthropicOption { - return func(a *anthropicProvider) { - a.apiKey = apiKey + anthropicClientOptions := []option.RequestOption{} + if opts.apiKey != "" { + anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey)) } -} - -func WithAnthropicBedrock() AnthropicOption { - return func(a *anthropicProvider) { - a.useBedrock = true + if anthropicOpts.useBedrock { + anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background())) } -} -func WithAnthropicDisableCache() AnthropicOption { - return func(a *anthropicProvider) { - a.disableCache = true + client := anthropic.NewClient(anthropicClientOptions...) + return &anthropicClient{ + providerOptions: opts, + options: anthropicOpts, + client: client, } } -func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) { - provider := &anthropicProvider{ - maxTokens: 1024, - } +func (a *anthropicClient) convertMessages(messages []message.Message) (anthropicMessages []anthropic.MessageParam) { + cachedBlocks := 0 + for _, msg := range messages { + switch msg.Role { + case message.User: + content := anthropic.NewTextBlock(msg.Content().String()) + if cachedBlocks < 2 && !a.options.disableCache { + content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{ + Type: "ephemeral", + } + cachedBlocks++ + } + anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(content)) - for _, opt := range opts { - opt(provider) - } + case message.Assistant: + blocks := []anthropic.ContentBlockParamUnion{} + if msg.Content().String() != "" { + content := anthropic.NewTextBlock(msg.Content().String()) + if cachedBlocks < 2 && !a.options.disableCache { + content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{ + Type: "ephemeral", + } + cachedBlocks++ + } + blocks = append(blocks, content) + } - if provider.systemMessage == "" { - return nil, errors.New("system message is required") - } + for _, toolCall := range msg.ToolCalls() { + var inputMap map[string]any + err := json.Unmarshal([]byte(toolCall.Input), &inputMap) + if err != nil { + continue + } + blocks = append(blocks, anthropic.ContentBlockParamOfRequestToolUseBlock(toolCall.ID, inputMap, toolCall.Name)) + } - anthropicOptions := []option.RequestOption{} + if len(blocks) == 0 { + logging.Warn("There is a message without content, investigate") + // This should never happend but we log this because we might have a bug in our cleanup method + continue + } + anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...)) - if provider.apiKey != "" { - anthropicOptions = append(anthropicOptions, option.WithAPIKey(provider.apiKey)) - } - if provider.useBedrock { - anthropicOptions = append(anthropicOptions, bedrock.WithLoadDefaultConfig(context.Background())) + case message.Tool: + results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults())) + for i, toolResult := range msg.ToolResults() { + results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError) + } + anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...)) + } } - - provider.client = anthropic.NewClient(anthropicOptions...) - return provider, nil + return } -func (a *anthropicProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { - messages = cleanupMessages(messages) - anthropicMessages := a.convertToAnthropicMessages(messages) - anthropicTools := a.convertToAnthropicTools(tools) - - response, err := a.client.Messages.New( - ctx, - anthropic.MessageNewParams{ - Model: anthropic.Model(a.model.APIModel), - MaxTokens: a.maxTokens, - Temperature: anthropic.Float(0), - Messages: anthropicMessages, - Tools: anthropicTools, - System: []anthropic.TextBlockParam{ - { - Text: a.systemMessage, - CacheControl: anthropic.CacheControlEphemeralParam{ - Type: "ephemeral", - }, - }, +func (a *anthropicClient) convertTools(tools []tools.BaseTool) []anthropic.ToolUnionParam { + anthropicTools := make([]anthropic.ToolUnionParam, len(tools)) + + for i, tool := range tools { + info := tool.Info() + toolParam := anthropic.ToolParam{ + Name: info.Name, + Description: anthropic.String(info.Description), + InputSchema: anthropic.ToolInputSchemaParam{ + Properties: info.Parameters, + // TODO: figure out how we can tell claude the required fields? }, - }, - ) - if err != nil { - return nil, err - } + } - content := "" - for _, block := range response.Content { - if text, ok := block.AsAny().(anthropic.TextBlock); ok { - content += text.Text + if i == len(tools)-1 && !a.options.disableCache { + toolParam.CacheControl = anthropic.CacheControlEphemeralParam{ + Type: "ephemeral", + } } - } - toolCalls := a.extractToolCalls(response.Content) - tokenUsage := a.extractTokenUsage(response.Usage) + anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam} + } - return &ProviderResponse{ - Content: content, - ToolCalls: toolCalls, - Usage: tokenUsage, - }, nil + return anthropicTools } -func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) { - messages = cleanupMessages(messages) - anthropicMessages := a.convertToAnthropicMessages(messages) - anthropicTools := a.convertToAnthropicTools(tools) +func (a *anthropicClient) finishReason(reason string) message.FinishReason { + switch reason { + case "end_turn": + return message.FinishReasonEndTurn + case "max_tokens": + return message.FinishReasonMaxTokens + case "tool_use": + return message.FinishReasonToolUse + case "stop_sequence": + return message.FinishReasonEndTurn + default: + return message.FinishReasonUnknown + } +} +func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, tools []anthropic.ToolUnionParam) anthropic.MessageNewParams { var thinkingParam anthropic.ThinkingConfigParamUnion lastMessage := messages[len(messages)-1] + isUser := lastMessage.Role == anthropic.MessageParamRoleUser + messageContent := "" temperature := anthropic.Float(0) - if lastMessage.Role == message.User && strings.Contains(strings.ToLower(lastMessage.Content().String()), "think") { - thinkingParam = anthropic.ThinkingConfigParamUnion{ - OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{ - BudgetTokens: int64(float64(a.maxTokens) * 0.8), - Type: "enabled", - }, + if isUser { + for _, m := range lastMessage.Content { + if m.OfRequestTextBlock != nil && m.OfRequestTextBlock.Text != "" { + messageContent = m.OfRequestTextBlock.Text + } + } + if messageContent != "" && a.options.shouldThink != nil && a.options.shouldThink(messageContent) { + thinkingParam = anthropic.ThinkingConfigParamUnion{ + OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{ + BudgetTokens: int64(float64(a.providerOptions.maxTokens) * 0.8), + Type: "enabled", + }, + } + temperature = anthropic.Float(1) } - temperature = anthropic.Float(1) } - eventChan := make(chan ProviderEvent) + return anthropic.MessageNewParams{ + Model: anthropic.Model(a.providerOptions.model.APIModel), + MaxTokens: a.providerOptions.maxTokens, + Temperature: temperature, + Messages: messages, + Tools: tools, + Thinking: thinkingParam, + System: []anthropic.TextBlockParam{ + { + Text: a.providerOptions.systemMessage, + CacheControl: anthropic.CacheControlEphemeralParam{ + Type: "ephemeral", + }, + }, + }, + } +} - go func() { - defer close(eventChan) +func (a *anthropicClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (resposne *ProviderResponse, err error) { + preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools)) + cfg := config.Get() + if cfg.Debug { + jsonData, _ := json.Marshal(preparedMessages) + logging.Debug("Prepared messages", "messages", string(jsonData)) + } + attempts := 0 + for { + attempts++ + anthropicResponse, err := a.client.Messages.New( + ctx, + preparedMessages, + ) + // If there is an error we are going to see if we can retry the call + if err != nil { + retry, after, retryErr := a.shouldRetry(attempts, err) + if retryErr != nil { + return nil, retryErr + } + if retry { + logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100)) + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(time.Duration(after) * time.Millisecond): + continue + } + } + return nil, retryErr + } - const maxRetries = 8 - attempts := 0 + content := "" + for _, block := range anthropicResponse.Content { + if text, ok := block.AsAny().(anthropic.TextBlock); ok { + content += text.Text + } + } - for { + return &ProviderResponse{ + Content: content, + ToolCalls: a.toolCalls(*anthropicResponse), + Usage: a.usage(*anthropicResponse), + }, nil + } +} +func (a *anthropicClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { + preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools)) + cfg := config.Get() + if cfg.Debug { + jsonData, _ := json.Marshal(preparedMessages) + logging.Debug("Prepared messages", "messages", string(jsonData)) + } + attempts := 0 + eventChan := make(chan ProviderEvent) + go func() { + for { attempts++ - - stream := a.client.Messages.NewStreaming( + anthropicStream := a.client.Messages.NewStreaming( ctx, - anthropic.MessageNewParams{ - Model: anthropic.Model(a.model.APIModel), - MaxTokens: a.maxTokens, - Temperature: temperature, - Messages: anthropicMessages, - Tools: anthropicTools, - Thinking: thinkingParam, - System: []anthropic.TextBlockParam{ - { - Text: a.systemMessage, - CacheControl: anthropic.CacheControlEphemeralParam{ - Type: "ephemeral", - }, - }, - }, - }, + preparedMessages, ) - accumulatedMessage := anthropic.Message{} - for stream.Next() { - event := stream.Current() + for anthropicStream.Next() { + event := anthropicStream.Current() err := accumulatedMessage.Accumulate(event) if err != nil { eventChan <- ProviderEvent{Type: EventError, Error: err} - return // Don't retry on accumulation errors + continue } switch event := event.AsAny().(type) { @@ -211,6 +281,7 @@ func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []messa Content: event.Delta.Text, } } + // TODO: check if we can somehow stream tool calls case anthropic.ContentBlockStopEvent: eventChan <- ProviderEvent{Type: EventContentStop} @@ -223,84 +294,87 @@ func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []messa } } - toolCalls := a.extractToolCalls(accumulatedMessage.Content) - tokenUsage := a.extractTokenUsage(accumulatedMessage.Usage) - eventChan <- ProviderEvent{ Type: EventComplete, Response: &ProviderResponse{ Content: content, - ToolCalls: toolCalls, - Usage: tokenUsage, - FinishReason: string(accumulatedMessage.StopReason), + ToolCalls: a.toolCalls(accumulatedMessage), + Usage: a.usage(accumulatedMessage), + FinishReason: a.finishReason(string(accumulatedMessage.StopReason)), }, } } } - err := stream.Err() + err := anthropicStream.Err() if err == nil || errors.Is(err, io.EOF) { + close(eventChan) return } - - var apierr *anthropic.Error - if !errors.As(err, &apierr) { - eventChan <- ProviderEvent{Type: EventError, Error: err} - return - } - - if apierr.StatusCode != 429 && apierr.StatusCode != 529 { - eventChan <- ProviderEvent{Type: EventError, Error: err} + // If there is an error we are going to see if we can retry the call + retry, after, retryErr := a.shouldRetry(attempts, err) + if retryErr != nil { + eventChan <- ProviderEvent{Type: EventError, Error: retryErr} + close(eventChan) return } - - if attempts > maxRetries { - eventChan <- ProviderEvent{ - Type: EventError, - Error: errors.New("maximum retry attempts reached for rate limit (429)"), - } - return - } - - retryMs := 0 - retryAfterValues := apierr.Response.Header.Values("Retry-After") - if len(retryAfterValues) > 0 { - var retryAfterSec int - if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryAfterSec); err == nil { - retryMs = retryAfterSec * 1000 - eventChan <- ProviderEvent{ - Type: EventWarning, - Info: fmt.Sprintf("[Rate limited: waiting %d seconds as specified by API]", retryAfterSec), + if retry { + logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100)) + select { + case <-ctx.Done(): + // context cancelled + if ctx.Err() != nil { + eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()} } + close(eventChan) + return + case <-time.After(time.Duration(after) * time.Millisecond): + continue } - } else { - eventChan <- ProviderEvent{ - Type: EventWarning, - Info: fmt.Sprintf("[Retrying due to rate limit... attempt %d of %d]", attempts, maxRetries), - } - - backoffMs := 2000 * (1 << (attempts - 1)) - jitterMs := int(float64(backoffMs) * 0.2) - retryMs = backoffMs + jitterMs } - select { - case <-ctx.Done(): + if ctx.Err() != nil { eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()} - return - case <-time.After(time.Duration(retryMs) * time.Millisecond): - continue } + close(eventChan) + return } }() + return eventChan +} - return eventChan, nil +func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, error) { + var apierr *anthropic.Error + if !errors.As(err, &apierr) { + return false, 0, err + } + + if apierr.StatusCode != 429 && apierr.StatusCode != 529 { + return false, 0, err + } + + if attempts > maxRetries { + return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries) + } + + retryMs := 0 + retryAfterValues := apierr.Response.Header.Values("Retry-After") + + backoffMs := 2000 * (1 << (attempts - 1)) + jitterMs := int(float64(backoffMs) * 0.2) + retryMs = backoffMs + jitterMs + if len(retryAfterValues) > 0 { + if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil { + retryMs = retryMs * 1000 + } + } + return true, int64(retryMs), nil } -func (a *anthropicProvider) extractToolCalls(content []anthropic.ContentBlockUnion) []message.ToolCall { +func (a *anthropicClient) toolCalls(msg anthropic.Message) []message.ToolCall { var toolCalls []message.ToolCall - for _, block := range content { + for _, block := range msg.Content { switch variant := block.AsAny().(type) { case anthropic.ToolUseBlock: toolCall := message.ToolCall{ @@ -316,90 +390,33 @@ func (a *anthropicProvider) extractToolCalls(content []anthropic.ContentBlockUni return toolCalls } -func (a *anthropicProvider) extractTokenUsage(usage anthropic.Usage) TokenUsage { +func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage { return TokenUsage{ - InputTokens: usage.InputTokens, - OutputTokens: usage.OutputTokens, - CacheCreationTokens: usage.CacheCreationInputTokens, - CacheReadTokens: usage.CacheReadInputTokens, + InputTokens: msg.Usage.InputTokens, + OutputTokens: msg.Usage.OutputTokens, + CacheCreationTokens: msg.Usage.CacheCreationInputTokens, + CacheReadTokens: msg.Usage.CacheReadInputTokens, } } -func (a *anthropicProvider) convertToAnthropicTools(tools []tools.BaseTool) []anthropic.ToolUnionParam { - anthropicTools := make([]anthropic.ToolUnionParam, len(tools)) - - for i, tool := range tools { - info := tool.Info() - toolParam := anthropic.ToolParam{ - Name: info.Name, - Description: anthropic.String(info.Description), - InputSchema: anthropic.ToolInputSchemaParam{ - Properties: info.Parameters, - }, - } - - if i == len(tools)-1 && !a.disableCache { - toolParam.CacheControl = anthropic.CacheControlEphemeralParam{ - Type: "ephemeral", - } - } - - anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam} +func WithAnthropicBedrock(useBedrock bool) AnthropicOption { + return func(options *anthropicOptions) { + options.useBedrock = useBedrock } - - return anthropicTools } -func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Message) []anthropic.MessageParam { - anthropicMessages := make([]anthropic.MessageParam, 0, len(messages)) - cachedBlocks := 0 - - for _, msg := range messages { - switch msg.Role { - case message.User: - content := anthropic.NewTextBlock(msg.Content().String()) - if cachedBlocks < 2 && !a.disableCache { - content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{ - Type: "ephemeral", - } - cachedBlocks++ - } - anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(content)) - - case message.Assistant: - blocks := []anthropic.ContentBlockParamUnion{} - if msg.Content().String() != "" { - content := anthropic.NewTextBlock(msg.Content().String()) - if cachedBlocks < 2 && !a.disableCache { - content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{ - Type: "ephemeral", - } - cachedBlocks++ - } - blocks = append(blocks, content) - } - - for _, toolCall := range msg.ToolCalls() { - var inputMap map[string]any - err := json.Unmarshal([]byte(toolCall.Input), &inputMap) - if err != nil { - continue - } - blocks = append(blocks, anthropic.ContentBlockParamOfRequestToolUseBlock(toolCall.ID, inputMap, toolCall.Name)) - } +func WithAnthropicDisableCache() AnthropicOption { + return func(options *anthropicOptions) { + options.disableCache = true + } +} - if len(blocks) > 0 { - anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...)) - } +func DefaultShouldThinkFn(s string) bool { + return strings.Contains(strings.ToLower(s), "think") +} - case message.Tool: - results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults())) - for i, toolResult := range msg.ToolResults() { - results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError) - } - anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...)) - } +func WithAnthropicShouldThinkFn(fn func(string) bool) AnthropicOption { + return func(options *anthropicOptions) { + options.shouldThink = fn } - - return anthropicMessages } diff --git a/internal/llm/provider/bedrock.go b/internal/llm/provider/bedrock.go index 677f4676b..d76925ad1 100644 --- a/internal/llm/provider/bedrock.go +++ b/internal/llm/provider/bedrock.go @@ -7,33 +7,29 @@ import ( "os" "strings" - "github.com/kujtimiihoxha/termai/internal/llm/models" "github.com/kujtimiihoxha/termai/internal/llm/tools" "github.com/kujtimiihoxha/termai/internal/message" ) -type bedrockProvider struct { - childProvider Provider - model models.Model - maxTokens int64 - systemMessage string +type bedrockOptions struct { + // Bedrock specific options can be added here } -func (b *bedrockProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { - return b.childProvider.SendMessages(ctx, messages, tools) -} +type BedrockOption func(*bedrockOptions) -func (b *bedrockProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) { - return b.childProvider.StreamResponse(ctx, messages, tools) +type bedrockClient struct { + providerOptions providerClientOptions + options bedrockOptions + childProvider ProviderClient } -func NewBedrockProvider(opts ...BedrockOption) (Provider, error) { - provider := &bedrockProvider{} - for _, opt := range opts { - opt(provider) - } +type BedrockClient ProviderClient + +func newBedrockClient(opts providerClientOptions) BedrockClient { + bedrockOpts := bedrockOptions{} + // Apply bedrock specific options if they are added in the future - // based on the AWS region prefix the model name with, us, eu, ap, sa, etc. + // Get AWS region from environment region := os.Getenv("AWS_REGION") if region == "" { region = os.Getenv("AWS_DEFAULT_REGION") @@ -43,45 +39,62 @@ func NewBedrockProvider(opts ...BedrockOption) (Provider, error) { region = "us-east-1" // default region } if len(region) < 2 { - return nil, errors.New("AWS_REGION or AWS_DEFAULT_REGION environment variable is invalid") + return &bedrockClient{ + providerOptions: opts, + options: bedrockOpts, + childProvider: nil, // Will cause an error when used + } } + + // Prefix the model name with region regionPrefix := region[:2] - provider.model.APIModel = fmt.Sprintf("%s.%s", regionPrefix, provider.model.APIModel) + modelName := opts.model.APIModel + opts.model.APIModel = fmt.Sprintf("%s.%s", regionPrefix, modelName) - if strings.Contains(string(provider.model.APIModel), "anthropic") { - anthropic, err := NewAnthropicProvider( - WithAnthropicModel(provider.model), - WithAnthropicMaxTokens(provider.maxTokens), - WithAnthropicSystemMessage(provider.systemMessage), - WithAnthropicBedrock(), + // Determine which provider to use based on the model + if strings.Contains(string(opts.model.APIModel), "anthropic") { + // Create Anthropic client with Bedrock configuration + anthropicOpts := opts + anthropicOpts.anthropicOptions = append(anthropicOpts.anthropicOptions, + WithAnthropicBedrock(true), WithAnthropicDisableCache(), ) - provider.childProvider = anthropic - if err != nil { - return nil, err + return &bedrockClient{ + providerOptions: opts, + options: bedrockOpts, + childProvider: newAnthropicClient(anthropicOpts), } - } else { - return nil, errors.New("unsupported model for bedrock provider") } - return provider, nil -} - -type BedrockOption func(*bedrockProvider) -func WithBedrockSystemMessage(message string) BedrockOption { - return func(a *bedrockProvider) { - a.systemMessage = message + // Return client with nil childProvider if model is not supported + // This will cause an error when used + return &bedrockClient{ + providerOptions: opts, + options: bedrockOpts, + childProvider: nil, } } -func WithBedrockMaxTokens(maxTokens int64) BedrockOption { - return func(a *bedrockProvider) { - a.maxTokens = maxTokens +func (b *bedrockClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { + if b.childProvider == nil { + return nil, errors.New("unsupported model for bedrock provider") } + return b.childProvider.send(ctx, messages, tools) } -func WithBedrockModel(model models.Model) BedrockOption { - return func(a *bedrockProvider) { - a.model = model +func (b *bedrockClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { + eventChan := make(chan ProviderEvent) + + if b.childProvider == nil { + go func() { + eventChan <- ProviderEvent{ + Type: EventError, + Error: errors.New("unsupported model for bedrock provider"), + } + close(eventChan) + }() + return eventChan } -} + + return b.childProvider.stream(ctx, messages, tools) +} \ No newline at end of file diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go index 2d1db2b64..804baea28 100644 --- a/internal/llm/provider/gemini.go +++ b/internal/llm/provider/gemini.go @@ -4,80 +4,68 @@ import ( "context" "encoding/json" "errors" + "fmt" + "io" + "strings" + "time" "github.com/google/generative-ai-go/genai" "github.com/google/uuid" - "github.com/kujtimiihoxha/termai/internal/llm/models" + "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/llm/tools" + "github.com/kujtimiihoxha/termai/internal/logging" "github.com/kujtimiihoxha/termai/internal/message" "google.golang.org/api/iterator" "google.golang.org/api/option" ) -type geminiProvider struct { - client *genai.Client - model models.Model - maxTokens int32 - apiKey string - systemMessage string +type geminiOptions struct { + disableCache bool } -type GeminiOption func(*geminiProvider) +type GeminiOption func(*geminiOptions) -func NewGeminiProvider(ctx context.Context, opts ...GeminiOption) (Provider, error) { - provider := &geminiProvider{ - maxTokens: 5000, - } +type geminiClient struct { + providerOptions providerClientOptions + options geminiOptions + client *genai.Client +} - for _, opt := range opts { - opt(provider) - } +type GeminiClient ProviderClient - if provider.systemMessage == "" { - return nil, errors.New("system message is required") +func newGeminiClient(opts providerClientOptions) GeminiClient { + geminiOpts := geminiOptions{} + for _, o := range opts.geminiOptions { + o(&geminiOpts) } - client, err := genai.NewClient(ctx, option.WithAPIKey(provider.apiKey)) + client, err := genai.NewClient(context.Background(), option.WithAPIKey(opts.apiKey)) if err != nil { - return nil, err - } - provider.client = client - - return provider, nil -} - -func WithGeminiSystemMessage(message string) GeminiOption { - return func(p *geminiProvider) { - p.systemMessage = message + logging.Error("Failed to create Gemini client", "error", err) + return nil } -} -func WithGeminiMaxTokens(maxTokens int32) GeminiOption { - return func(p *geminiProvider) { - p.maxTokens = maxTokens + return &geminiClient{ + providerOptions: opts, + options: geminiOpts, + client: client, } } -func WithGeminiModel(model models.Model) GeminiOption { - return func(p *geminiProvider) { - p.model = model - } -} - -func WithGeminiKey(apiKey string) GeminiOption { - return func(p *geminiProvider) { - p.apiKey = apiKey - } -} +func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content { + var history []*genai.Content -func (p *geminiProvider) Close() { - if p.client != nil { - p.client.Close() - } -} + // Add system message first + history = append(history, &genai.Content{ + Parts: []genai.Part{genai.Text(g.providerOptions.systemMessage)}, + Role: "user", + }) -func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*genai.Content { - var history []*genai.Content + // Add a system response to acknowledge the system message + history = append(history, &genai.Content{ + Parts: []genai.Part{genai.Text("I'll help you with that.")}, + Role: "model", + }) for _, msg := range messages { switch msg.Role { @@ -86,6 +74,7 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g Parts: []genai.Part{genai.Text(msg.Content().String())}, Role: "user", }) + case message.Assistant: content := &genai.Content{ Role: "model", @@ -107,6 +96,7 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g } history = append(history, content) + case message.Tool: for _, result := range msg.ToolResults() { response := map[string]interface{}{"result": result.Content} @@ -114,10 +104,11 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g if err == nil { response = parsed } + var toolCall message.ToolCall - for _, msg := range messages { - if msg.Role == message.Assistant { - for _, call := range msg.ToolCalls() { + for _, m := range messages { + if m.Role == message.Assistant { + for _, call := range m.ToolCalls() { if call.ID == result.ToolCallID { toolCall = call break @@ -140,186 +131,358 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g return history } -func (p *geminiProvider) extractTokenUsage(resp *genai.GenerateContentResponse) TokenUsage { - if resp == nil || resp.UsageMetadata == nil { - return TokenUsage{} - } +func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool { + geminiTools := make([]*genai.Tool, 0, len(tools)) - return TokenUsage{ - InputTokens: int64(resp.UsageMetadata.PromptTokenCount), - OutputTokens: int64(resp.UsageMetadata.CandidatesTokenCount), - CacheCreationTokens: 0, // Not directly provided by Gemini - CacheReadTokens: int64(resp.UsageMetadata.CachedContentTokenCount), + for _, tool := range tools { + info := tool.Info() + declaration := &genai.FunctionDeclaration{ + Name: info.Name, + Description: info.Description, + Parameters: &genai.Schema{ + Type: genai.TypeObject, + Properties: convertSchemaProperties(info.Parameters), + Required: info.Required, + }, + } + + geminiTools = append(geminiTools, &genai.Tool{ + FunctionDeclarations: []*genai.FunctionDeclaration{declaration}, + }) } + + return geminiTools } -func (p *geminiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { - messages = cleanupMessages(messages) - model := p.client.GenerativeModel(p.model.APIModel) - model.SetMaxOutputTokens(p.maxTokens) +func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason { + reasonStr := reason.String() + switch { + case reasonStr == "STOP": + return message.FinishReasonEndTurn + case reasonStr == "MAX_TOKENS": + return message.FinishReasonMaxTokens + case strings.Contains(reasonStr, "FUNCTION") || strings.Contains(reasonStr, "TOOL"): + return message.FinishReasonToolUse + default: + return message.FinishReasonUnknown + } +} - model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage)) +func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { + model := g.client.GenerativeModel(g.providerOptions.model.APIModel) + model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens)) + // Convert tools if len(tools) > 0 { - declarations := p.convertToolsToGeminiFunctionDeclarations(tools) - for _, declaration := range declarations { - model.Tools = append(model.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{declaration}}) - } + model.Tools = g.convertTools(tools) } - chat := model.StartChat() - chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message + // Convert messages + geminiMessages := g.convertMessages(messages) - lastUserMsg := messages[len(messages)-1] - resp, err := chat.SendMessage(ctx, genai.Text(lastUserMsg.Content().String())) - if err != nil { - return nil, err + cfg := config.Get() + if cfg.Debug { + jsonData, _ := json.Marshal(geminiMessages) + logging.Debug("Prepared messages", "messages", string(jsonData)) } - var content string - var toolCalls []message.ToolCall + attempts := 0 + for { + attempts++ + chat := model.StartChat() + chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message + + lastMsg := geminiMessages[len(geminiMessages)-1] + var lastText string + for _, part := range lastMsg.Parts { + if text, ok := part.(genai.Text); ok { + lastText = string(text) + break + } + } - if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil { - for _, part := range resp.Candidates[0].Content.Parts { - switch p := part.(type) { - case genai.Text: - content = string(p) - case genai.FunctionCall: - id := "call_" + uuid.New().String() - args, _ := json.Marshal(p.Args) - toolCalls = append(toolCalls, message.ToolCall{ - ID: id, - Name: p.Name, - Input: string(args), - Type: "function", - }) + resp, err := chat.SendMessage(ctx, genai.Text(lastText)) + // If there is an error we are going to see if we can retry the call + if err != nil { + retry, after, retryErr := g.shouldRetry(attempts, err) + if retryErr != nil { + return nil, retryErr } + if retry { + logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100)) + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(time.Duration(after) * time.Millisecond): + continue + } + } + return nil, retryErr } - } - tokenUsage := p.extractTokenUsage(resp) + content := "" + var toolCalls []message.ToolCall + + if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil { + for _, part := range resp.Candidates[0].Content.Parts { + switch p := part.(type) { + case genai.Text: + content = string(p) + case genai.FunctionCall: + id := "call_" + uuid.New().String() + args, _ := json.Marshal(p.Args) + toolCalls = append(toolCalls, message.ToolCall{ + ID: id, + Name: p.Name, + Input: string(args), + Type: "function", + }) + } + } + } - return &ProviderResponse{ - Content: content, - ToolCalls: toolCalls, - Usage: tokenUsage, - }, nil + return &ProviderResponse{ + Content: content, + ToolCalls: toolCalls, + Usage: g.usage(resp), + FinishReason: g.finishReason(resp.Candidates[0].FinishReason), + }, nil + } } -func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) { - messages = cleanupMessages(messages) - model := p.client.GenerativeModel(p.model.APIModel) - model.SetMaxOutputTokens(p.maxTokens) - - model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage)) +func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { + model := g.client.GenerativeModel(g.providerOptions.model.APIModel) + model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens)) + // Convert tools if len(tools) > 0 { - declarations := p.convertToolsToGeminiFunctionDeclarations(tools) - for _, declaration := range declarations { - model.Tools = append(model.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{declaration}}) - } + model.Tools = g.convertTools(tools) } - chat := model.StartChat() - chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message + // Convert messages + geminiMessages := g.convertMessages(messages) - lastUserMsg := messages[len(messages)-1] - - iter := chat.SendMessageStream(ctx, genai.Text(lastUserMsg.Content().String())) + cfg := config.Get() + if cfg.Debug { + jsonData, _ := json.Marshal(geminiMessages) + logging.Debug("Prepared messages", "messages", string(jsonData)) + } + attempts := 0 eventChan := make(chan ProviderEvent) go func() { defer close(eventChan) - var finalResp *genai.GenerateContentResponse - currentContent := "" - toolCalls := []message.ToolCall{} - for { - resp, err := iter.Next() - if err == iterator.Done { - break - } - if err != nil { - eventChan <- ProviderEvent{ - Type: EventError, - Error: err, + attempts++ + chat := model.StartChat() + chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message + + lastMsg := geminiMessages[len(geminiMessages)-1] + var lastText string + for _, part := range lastMsg.Parts { + if text, ok := part.(genai.Text); ok { + lastText = string(text) + break } - return } - finalResp = resp + iter := chat.SendMessageStream(ctx, genai.Text(lastText)) - if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil { - for _, part := range resp.Candidates[0].Content.Parts { - switch p := part.(type) { - case genai.Text: - newText := string(p) - eventChan <- ProviderEvent{ - Type: EventContentDelta, - Content: newText, - } - currentContent += newText - case genai.FunctionCall: - id := "call_" + uuid.New().String() - args, _ := json.Marshal(p.Args) - newCall := message.ToolCall{ - ID: id, - Name: p.Name, - Input: string(args), - Type: "function", - } + currentContent := "" + toolCalls := []message.ToolCall{} + var finalResp *genai.GenerateContentResponse - isNew := true - for _, existing := range toolCalls { - if existing.Name == newCall.Name && existing.Input == newCall.Input { - isNew = false - break + eventChan <- ProviderEvent{Type: EventContentStart} + + for { + resp, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + retry, after, retryErr := g.shouldRetry(attempts, err) + if retryErr != nil { + eventChan <- ProviderEvent{Type: EventError, Error: retryErr} + return + } + if retry { + logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100)) + select { + case <-ctx.Done(): + if ctx.Err() != nil { + eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()} } + + return + case <-time.After(time.Duration(after) * time.Millisecond): + break } + } else { + eventChan <- ProviderEvent{Type: EventError, Error: err} + return + } + } + + finalResp = resp + + if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil { + for _, part := range resp.Candidates[0].Content.Parts { + switch p := part.(type) { + case genai.Text: + newText := string(p) + delta := newText[len(currentContent):] + if delta != "" { + eventChan <- ProviderEvent{ + Type: EventContentDelta, + Content: delta, + } + currentContent = newText + } + case genai.FunctionCall: + id := "call_" + uuid.New().String() + args, _ := json.Marshal(p.Args) + newCall := message.ToolCall{ + ID: id, + Name: p.Name, + Input: string(args), + Type: "function", + } - if isNew { - toolCalls = append(toolCalls, newCall) + isNew := true + for _, existing := range toolCalls { + if existing.Name == newCall.Name && existing.Input == newCall.Input { + isNew = false + break + } + } + + if isNew { + toolCalls = append(toolCalls, newCall) + } } } } } - } - tokenUsage := p.extractTokenUsage(finalResp) + eventChan <- ProviderEvent{Type: EventContentStop} - eventChan <- ProviderEvent{ - Type: EventComplete, - Response: &ProviderResponse{ - Content: currentContent, - ToolCalls: toolCalls, - Usage: tokenUsage, - FinishReason: string(finalResp.Candidates[0].FinishReason.String()), - }, + if finalResp != nil { + eventChan <- ProviderEvent{ + Type: EventComplete, + Response: &ProviderResponse{ + Content: currentContent, + ToolCalls: toolCalls, + Usage: g.usage(finalResp), + FinishReason: g.finishReason(finalResp.Candidates[0].FinishReason), + }, + } + return + } + + // If we get here, we need to retry + if attempts > maxRetries { + eventChan <- ProviderEvent{ + Type: EventError, + Error: fmt.Errorf("maximum retry attempts reached: %d retries", maxRetries), + } + return + } + + // Wait before retrying + select { + case <-ctx.Done(): + if ctx.Err() != nil { + eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()} + } + return + case <-time.After(time.Duration(2000*(1<<(attempts-1))) * time.Millisecond): + continue + } } }() - return eventChan, nil + return eventChan } -func (p *geminiProvider) convertToolsToGeminiFunctionDeclarations(tools []tools.BaseTool) []*genai.FunctionDeclaration { - declarations := make([]*genai.FunctionDeclaration, len(tools)) +func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) { + // Check if error is a rate limit error + if attempts > maxRetries { + return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries) + } - for i, tool := range tools { - info := tool.Info() - declarations[i] = &genai.FunctionDeclaration{ - Name: info.Name, - Description: info.Description, - Parameters: &genai.Schema{ - Type: genai.TypeObject, - Properties: convertSchemaProperties(info.Parameters), - Required: info.Required, - }, + // Gemini doesn't have a standard error type we can check against + // So we'll check the error message for rate limit indicators + if errors.Is(err, io.EOF) { + return false, 0, err + } + + errMsg := err.Error() + isRateLimit := false + + // Check for common rate limit error messages + if contains(errMsg, "rate limit", "quota exceeded", "too many requests") { + isRateLimit = true + } + + if !isRateLimit { + return false, 0, err + } + + // Calculate backoff with jitter + backoffMs := 2000 * (1 << (attempts - 1)) + jitterMs := int(float64(backoffMs) * 0.2) + retryMs := backoffMs + jitterMs + + return true, int64(retryMs), nil +} + +func (g *geminiClient) toolCalls(resp *genai.GenerateContentResponse) []message.ToolCall { + var toolCalls []message.ToolCall + + if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil { + for _, part := range resp.Candidates[0].Content.Parts { + if funcCall, ok := part.(genai.FunctionCall); ok { + id := "call_" + uuid.New().String() + args, _ := json.Marshal(funcCall.Args) + toolCalls = append(toolCalls, message.ToolCall{ + ID: id, + Name: funcCall.Name, + Input: string(args), + Type: "function", + }) + } } } - return declarations + return toolCalls +} + +func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage { + if resp == nil || resp.UsageMetadata == nil { + return TokenUsage{} + } + + return TokenUsage{ + InputTokens: int64(resp.UsageMetadata.PromptTokenCount), + OutputTokens: int64(resp.UsageMetadata.CandidatesTokenCount), + CacheCreationTokens: 0, // Not directly provided by Gemini + CacheReadTokens: int64(resp.UsageMetadata.CachedContentTokenCount), + } +} + +func WithGeminiDisableCache() GeminiOption { + return func(options *geminiOptions) { + options.disableCache = true + } +} + +// Helper functions +func parseJsonToMap(jsonStr string) (map[string]interface{}, error) { + var result map[string]interface{} + err := json.Unmarshal([]byte(jsonStr), &result) + return result, err } func convertSchemaProperties(parameters map[string]interface{}) map[string]*genai.Schema { @@ -396,8 +559,12 @@ func mapJSONTypeToGenAI(jsonType string) genai.Type { } } -func parseJsonToMap(jsonStr string) (map[string]interface{}, error) { - var result map[string]interface{} - err := json.Unmarshal([]byte(jsonStr), &result) - return result, err +func contains(s string, substrs ...string) bool { + for _, substr := range substrs { + if strings.Contains(strings.ToLower(s), strings.ToLower(substr)) { + return true + } + } + return false } + diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index dbfde3fa8..9c2ad2012 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -2,89 +2,65 @@ package provider import ( "context" + "encoding/json" "errors" + "fmt" + "io" + "time" - "github.com/kujtimiihoxha/termai/internal/llm/models" + "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/llm/tools" + "github.com/kujtimiihoxha/termai/internal/logging" "github.com/kujtimiihoxha/termai/internal/message" "github.com/openai/openai-go" "github.com/openai/openai-go/option" ) -type openaiProvider struct { - client openai.Client - model models.Model - maxTokens int64 - baseURL string - apiKey string - systemMessage string +type openaiOptions struct { + baseURL string + disableCache bool } -type OpenAIOption func(*openaiProvider) +type OpenAIOption func(*openaiOptions) -func NewOpenAIProvider(opts ...OpenAIOption) (Provider, error) { - provider := &openaiProvider{ - maxTokens: 5000, - } - - for _, opt := range opts { - opt(provider) - } - - clientOpts := []option.RequestOption{ - option.WithAPIKey(provider.apiKey), - } - if provider.baseURL != "" { - clientOpts = append(clientOpts, option.WithBaseURL(provider.baseURL)) - } - - provider.client = openai.NewClient(clientOpts...) - if provider.systemMessage == "" { - return nil, errors.New("system message is required") - } - - return provider, nil +type openaiClient struct { + providerOptions providerClientOptions + options openaiOptions + client openai.Client } -func WithOpenAISystemMessage(message string) OpenAIOption { - return func(p *openaiProvider) { - p.systemMessage = message - } -} +type OpenAIClient ProviderClient -func WithOpenAIMaxTokens(maxTokens int64) OpenAIOption { - return func(p *openaiProvider) { - p.maxTokens = maxTokens +func newOpenAIClient(opts providerClientOptions) OpenAIClient { + openaiOpts := openaiOptions{} + for _, o := range opts.openaiOptions { + o(&openaiOpts) } -} -func WithOpenAIModel(model models.Model) OpenAIOption { - return func(p *openaiProvider) { - p.model = model + openaiClientOptions := []option.RequestOption{} + if opts.apiKey != "" { + openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey)) } -} - -func WithOpenAIBaseURL(baseURL string) OpenAIOption { - return func(p *openaiProvider) { - p.baseURL = baseURL + if openaiOpts.baseURL != "" { + openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(openaiOpts.baseURL)) } -} -func WithOpenAIKey(apiKey string) OpenAIOption { - return func(p *openaiProvider) { - p.apiKey = apiKey + client := openai.NewClient(openaiClientOptions...) + return &openaiClient{ + providerOptions: opts, + options: openaiOpts, + client: client, } } -func (p *openaiProvider) convertToOpenAIMessages(messages []message.Message) []openai.ChatCompletionMessageParamUnion { - var chatMessages []openai.ChatCompletionMessageParamUnion - - chatMessages = append(chatMessages, openai.SystemMessage(p.systemMessage)) +func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) { + // Add system message first + openaiMessages = append(openaiMessages, openai.SystemMessage(o.providerOptions.systemMessage)) for _, msg := range messages { switch msg.Role { case message.User: - chatMessages = append(chatMessages, openai.UserMessage(msg.Content().String())) + openaiMessages = append(openaiMessages, openai.UserMessage(msg.Content().String())) case message.Assistant: assistantMsg := openai.ChatCompletionAssistantMessageParam{ @@ -111,23 +87,23 @@ func (p *openaiProvider) convertToOpenAIMessages(messages []message.Message) []o } } - chatMessages = append(chatMessages, openai.ChatCompletionMessageParamUnion{ + openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{ OfAssistant: &assistantMsg, }) case message.Tool: for _, result := range msg.ToolResults() { - chatMessages = append(chatMessages, + openaiMessages = append(openaiMessages, openai.ToolMessage(result.Content, result.ToolCallID), ) } } } - return chatMessages + return } -func (p *openaiProvider) convertToOpenAITools(tools []tools.BaseTool) []openai.ChatCompletionToolParam { +func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam { openaiTools := make([]openai.ChatCompletionToolParam, len(tools)) for i, tool := range tools { @@ -148,133 +124,238 @@ func (p *openaiProvider) convertToOpenAITools(tools []tools.BaseTool) []openai.C return openaiTools } -func (p *openaiProvider) extractTokenUsage(usage openai.CompletionUsage) TokenUsage { - cachedTokens := int64(0) - - cachedTokens = usage.PromptTokensDetails.CachedTokens - inputTokens := usage.PromptTokens - cachedTokens - - return TokenUsage{ - InputTokens: inputTokens, - OutputTokens: usage.CompletionTokens, - CacheCreationTokens: 0, // OpenAI doesn't provide this directly - CacheReadTokens: cachedTokens, +func (o *openaiClient) finishReason(reason string) message.FinishReason { + switch reason { + case "stop": + return message.FinishReasonEndTurn + case "length": + return message.FinishReasonMaxTokens + case "tool_calls": + return message.FinishReasonToolUse + default: + return message.FinishReasonUnknown } } -func (p *openaiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { - messages = cleanupMessages(messages) - chatMessages := p.convertToOpenAIMessages(messages) - openaiTools := p.convertToOpenAITools(tools) - - params := openai.ChatCompletionNewParams{ - Model: openai.ChatModel(p.model.APIModel), - Messages: chatMessages, - MaxTokens: openai.Int(p.maxTokens), - Tools: openaiTools, - } - - response, err := p.client.Chat.Completions.New(ctx, params) - if err != nil { - return nil, err +func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams { + return openai.ChatCompletionNewParams{ + Model: openai.ChatModel(o.providerOptions.model.APIModel), + Messages: messages, + MaxTokens: openai.Int(o.providerOptions.maxTokens), + Tools: tools, } +} - content := "" - if response.Choices[0].Message.Content != "" { - content = response.Choices[0].Message.Content +func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) { + params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools)) + cfg := config.Get() + if cfg.Debug { + jsonData, _ := json.Marshal(params) + logging.Debug("Prepared messages", "messages", string(jsonData)) } - - var toolCalls []message.ToolCall - if len(response.Choices[0].Message.ToolCalls) > 0 { - toolCalls = make([]message.ToolCall, len(response.Choices[0].Message.ToolCalls)) - for i, call := range response.Choices[0].Message.ToolCalls { - toolCalls[i] = message.ToolCall{ - ID: call.ID, - Name: call.Function.Name, - Input: call.Function.Arguments, - Type: "function", + attempts := 0 + for { + attempts++ + openaiResponse, err := o.client.Chat.Completions.New( + ctx, + params, + ) + // If there is an error we are going to see if we can retry the call + if err != nil { + retry, after, retryErr := o.shouldRetry(attempts, err) + if retryErr != nil { + return nil, retryErr } + if retry { + logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100)) + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(time.Duration(after) * time.Millisecond): + continue + } + } + return nil, retryErr } - } - tokenUsage := p.extractTokenUsage(response.Usage) + content := "" + if openaiResponse.Choices[0].Message.Content != "" { + content = openaiResponse.Choices[0].Message.Content + } - return &ProviderResponse{ - Content: content, - ToolCalls: toolCalls, - Usage: tokenUsage, - }, nil + return &ProviderResponse{ + Content: content, + ToolCalls: o.toolCalls(*openaiResponse), + Usage: o.usage(*openaiResponse), + FinishReason: o.finishReason(string(openaiResponse.Choices[0].FinishReason)), + }, nil + } } -func (p *openaiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) { - messages = cleanupMessages(messages) - chatMessages := p.convertToOpenAIMessages(messages) - openaiTools := p.convertToOpenAITools(tools) - - params := openai.ChatCompletionNewParams{ - Model: openai.ChatModel(p.model.APIModel), - Messages: chatMessages, - MaxTokens: openai.Int(p.maxTokens), - Tools: openaiTools, - StreamOptions: openai.ChatCompletionStreamOptionsParam{ - IncludeUsage: openai.Bool(true), - }, +func (o *openaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { + params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools)) + params.StreamOptions = openai.ChatCompletionStreamOptionsParam{ + IncludeUsage: openai.Bool(true), } - stream := p.client.Chat.Completions.NewStreaming(ctx, params) + cfg := config.Get() + if cfg.Debug { + jsonData, _ := json.Marshal(params) + logging.Debug("Prepared messages", "messages", string(jsonData)) + } + attempts := 0 eventChan := make(chan ProviderEvent) - toolCalls := make([]message.ToolCall, 0) go func() { - defer close(eventChan) - - acc := openai.ChatCompletionAccumulator{} - currentContent := "" - - for stream.Next() { - chunk := stream.Current() - acc.AddChunk(chunk) - - if tool, ok := acc.JustFinishedToolCall(); ok { - toolCalls = append(toolCalls, message.ToolCall{ - ID: tool.Id, - Name: tool.Name, - Input: tool.Arguments, - Type: "function", - }) - } + for { + attempts++ + openaiStream := o.client.Chat.Completions.NewStreaming( + ctx, + params, + ) + + acc := openai.ChatCompletionAccumulator{} + currentContent := "" + toolCalls := make([]message.ToolCall, 0) + + for openaiStream.Next() { + chunk := openaiStream.Current() + acc.AddChunk(chunk) + + if tool, ok := acc.JustFinishedToolCall(); ok { + toolCalls = append(toolCalls, message.ToolCall{ + ID: tool.Id, + Name: tool.Name, + Input: tool.Arguments, + Type: "function", + }) + } - for _, choice := range chunk.Choices { - if choice.Delta.Content != "" { - eventChan <- ProviderEvent{ - Type: EventContentDelta, - Content: choice.Delta.Content, + for _, choice := range chunk.Choices { + if choice.Delta.Content != "" { + eventChan <- ProviderEvent{ + Type: EventContentDelta, + Content: choice.Delta.Content, + } + currentContent += choice.Delta.Content } - currentContent += choice.Delta.Content } } - } - if err := stream.Err(); err != nil { - eventChan <- ProviderEvent{ - Type: EventError, - Error: err, + err := openaiStream.Err() + if err == nil || errors.Is(err, io.EOF) { + // Stream completed successfully + eventChan <- ProviderEvent{ + Type: EventComplete, + Response: &ProviderResponse{ + Content: currentContent, + ToolCalls: toolCalls, + Usage: o.usage(acc.ChatCompletion), + FinishReason: o.finishReason(string(acc.ChatCompletion.Choices[0].FinishReason)), + }, + } + close(eventChan) + return } + + // If there is an error we are going to see if we can retry the call + retry, after, retryErr := o.shouldRetry(attempts, err) + if retryErr != nil { + eventChan <- ProviderEvent{Type: EventError, Error: retryErr} + close(eventChan) + return + } + if retry { + logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100)) + select { + case <-ctx.Done(): + // context cancelled + if ctx.Err() == nil { + eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()} + } + close(eventChan) + return + case <-time.After(time.Duration(after) * time.Millisecond): + continue + } + } + eventChan <- ProviderEvent{Type: EventError, Error: retryErr} + close(eventChan) return } + }() - tokenUsage := p.extractTokenUsage(acc.Usage) + return eventChan +} - eventChan <- ProviderEvent{ - Type: EventComplete, - Response: &ProviderResponse{ - Content: currentContent, - ToolCalls: toolCalls, - Usage: tokenUsage, - }, +func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) { + var apierr *openai.Error + if !errors.As(err, &apierr) { + return false, 0, err + } + + if apierr.StatusCode != 429 && apierr.StatusCode != 500 { + return false, 0, err + } + + if attempts > maxRetries { + return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries) + } + + retryMs := 0 + retryAfterValues := apierr.Response.Header.Values("Retry-After") + + backoffMs := 2000 * (1 << (attempts - 1)) + jitterMs := int(float64(backoffMs) * 0.2) + retryMs = backoffMs + jitterMs + if len(retryAfterValues) > 0 { + if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil { + retryMs = retryMs * 1000 } - }() + } + return true, int64(retryMs), nil +} - return eventChan, nil +func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall { + var toolCalls []message.ToolCall + + if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 { + for _, call := range completion.Choices[0].Message.ToolCalls { + toolCall := message.ToolCall{ + ID: call.ID, + Name: call.Function.Name, + Input: call.Function.Arguments, + Type: "function", + } + toolCalls = append(toolCalls, toolCall) + } + } + + return toolCalls } + +func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage { + cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens + inputTokens := completion.Usage.PromptTokens - cachedTokens + + return TokenUsage{ + InputTokens: inputTokens, + OutputTokens: completion.Usage.CompletionTokens, + CacheCreationTokens: 0, // OpenAI doesn't provide this directly + CacheReadTokens: cachedTokens, + } +} + +func WithOpenAIBaseURL(baseURL string) OpenAIOption { + return func(options *openaiOptions) { + options.baseURL = baseURL + } +} + +func WithOpenAIDisableCache() OpenAIOption { + return func(options *openaiOptions) { + options.disableCache = true + } +} + diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index 34d91f2b7..1a5b3dc8a 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -2,14 +2,17 @@ package provider import ( "context" + "fmt" + "github.com/kujtimiihoxha/termai/internal/llm/models" "github.com/kujtimiihoxha/termai/internal/llm/tools" "github.com/kujtimiihoxha/termai/internal/message" ) -// EventType represents the type of streaming event type EventType string +const maxRetries = 8 + const ( EventContentStart EventType = "content_start" EventContentDelta EventType = "content_delta" @@ -18,7 +21,6 @@ const ( EventComplete EventType = "complete" EventError EventType = "error" EventWarning EventType = "warning" - EventInfo EventType = "info" ) type TokenUsage struct { @@ -32,61 +34,152 @@ type ProviderResponse struct { Content string ToolCalls []message.ToolCall Usage TokenUsage - FinishReason string + FinishReason message.FinishReason } type ProviderEvent struct { - Type EventType + Type EventType + Content string Thinking string - ToolCall *message.ToolCall - Error error Response *ProviderResponse - // Used for giving users info on e.x retry - Info string + Error error } - type Provider interface { SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) - StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) + StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent + + Model() models.Model +} + +type providerClientOptions struct { + apiKey string + model models.Model + maxTokens int64 + systemMessage string + + anthropicOptions []AnthropicOption + openaiOptions []OpenAIOption + geminiOptions []GeminiOption + bedrockOptions []BedrockOption +} + +type ProviderClientOption func(*providerClientOptions) + +type ProviderClient interface { + send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) + stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent +} + +type baseProvider[C ProviderClient] struct { + options providerClientOptions + client C +} + +func NewProvider(providerName models.ModelProvider, opts ...ProviderClientOption) (Provider, error) { + clientOptions := providerClientOptions{} + for _, o := range opts { + o(&clientOptions) + } + switch providerName { + case models.ProviderAnthropic: + return &baseProvider[AnthropicClient]{ + options: clientOptions, + client: newAnthropicClient(clientOptions), + }, nil + case models.ProviderOpenAI: + return &baseProvider[OpenAIClient]{ + options: clientOptions, + client: newOpenAIClient(clientOptions), + }, nil + case models.ProviderGemini: + return &baseProvider[GeminiClient]{ + options: clientOptions, + client: newGeminiClient(clientOptions), + }, nil + case models.ProviderBedrock: + return &baseProvider[BedrockClient]{ + options: clientOptions, + client: newBedrockClient(clientOptions), + }, nil + case models.ProviderMock: + // TODO: implement mock client for test + panic("not implemented") + } + return nil, fmt.Errorf("provider not supported: %s", providerName) } -func cleanupMessages(messages []message.Message) []message.Message { - // First pass: filter out canceled messages - var cleanedMessages []message.Message +func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) { for _, msg := range messages { - if msg.FinishReason() != "canceled" || len(msg.ToolCalls()) > 0 { - // if there are toolCalls this means we want to return it to the LLM telling it that those tools have been - // cancelled - cleanedMessages = append(cleanedMessages, msg) + // The message has no content + if len(msg.Parts) == 0 { + continue } + cleaned = append(cleaned, msg) } + return +} - // Second pass: filter out tool messages without a corresponding tool call - var result []message.Message - toolMessageIDs := make(map[string]bool) +func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { + messages = p.cleanMessages(messages) + return p.client.send(ctx, messages, tools) +} - for _, msg := range cleanedMessages { - if msg.Role == message.Assistant { - for _, toolCall := range msg.ToolCalls() { - toolMessageIDs[toolCall.ID] = true // Mark as referenced - } - } +func (p *baseProvider[C]) Model() models.Model { + return p.options.model +} + +func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { + messages = p.cleanMessages(messages) + return p.client.stream(ctx, messages, tools) +} + +func WithAPIKey(apiKey string) ProviderClientOption { + return func(options *providerClientOptions) { + options.apiKey = apiKey } +} - // Keep only messages that aren't unreferenced tool messages - for _, msg := range cleanedMessages { - if msg.Role == message.Tool { - for _, toolCall := range msg.ToolResults() { - if referenced, exists := toolMessageIDs[toolCall.ToolCallID]; exists && referenced { - result = append(result, msg) - } - } - } else { - result = append(result, msg) - } +func WithModel(model models.Model) ProviderClientOption { + return func(options *providerClientOptions) { + options.model = model + } +} + +func WithMaxTokens(maxTokens int64) ProviderClientOption { + return func(options *providerClientOptions) { + options.maxTokens = maxTokens + } +} + +func WithSystemMessage(systemMessage string) ProviderClientOption { + return func(options *providerClientOptions) { + options.systemMessage = systemMessage + } +} + +func WithAnthropicOptions(anthropicOptions ...AnthropicOption) ProviderClientOption { + return func(options *providerClientOptions) { + options.anthropicOptions = anthropicOptions + } +} + +func WithOpenAIOptions(openaiOptions ...OpenAIOption) ProviderClientOption { + return func(options *providerClientOptions) { + options.openaiOptions = openaiOptions + } +} + +func WithGeminiOptions(geminiOptions ...GeminiOption) ProviderClientOption { + return func(options *providerClientOptions) { + options.geminiOptions = geminiOptions + } +} + +func WithBedrockOptions(bedrockOptions ...BedrockOption) ProviderClientOption { + return func(options *providerClientOptions) { + options.bedrockOptions = bedrockOptions } - return result } diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index 0cea20878..c7c970e5a 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -23,7 +23,8 @@ type BashPermissionsParams struct { } type BashResponseMetadata struct { - Took int64 `json:"took"` + StartTime int64 `json:"start_time"` + EndTime int64 `json:"end_time"` } type bashTool struct { permissions permission.Service @@ -282,7 +283,6 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) if err != nil { return ToolResponse{}, fmt.Errorf("error executing command: %w", err) } - took := time.Since(startTime).Milliseconds() stdout = truncateOutput(stdout) stderr = truncateOutput(stderr) @@ -311,7 +311,8 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) } metadata := BashResponseMetadata{ - Took: took, + StartTime: startTime.UnixMilli(), + EndTime: time.Now().UnixMilli(), } if stdout == "" { return WithResponseMetadata(NewTextResponse("no output"), metadata), nil diff --git a/internal/llm/tools/bash_test.go b/internal/llm/tools/bash_test.go index 97be3683a..dafb0ccc5 100644 --- a/internal/llm/tools/bash_test.go +++ b/internal/llm/tools/bash_test.go @@ -8,8 +8,6 @@ import ( "testing" "time" - "github.com/kujtimiihoxha/termai/internal/permission" - "github.com/kujtimiihoxha/termai/internal/pubsub" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -340,32 +338,3 @@ func TestCountLines(t *testing.T) { }) } } - -// Mock permission service for testing -type mockPermissionService struct { - *pubsub.Broker[permission.PermissionRequest] - allow bool -} - -func (m *mockPermissionService) GrantPersistant(permission permission.PermissionRequest) { - // Not needed for tests -} - -func (m *mockPermissionService) Grant(permission permission.PermissionRequest) { - // Not needed for tests -} - -func (m *mockPermissionService) Deny(permission permission.PermissionRequest) { - // Not needed for tests -} - -func (m *mockPermissionService) Request(opts permission.CreatePermissionRequest) bool { - return m.allow -} - -func newMockPermissionService(allow bool) permission.Service { - return &mockPermissionService{ - Broker: pubsub.NewBroker[permission.PermissionRequest](), - allow: allow, - } -} diff --git a/internal/llm/tools/edit.go b/internal/llm/tools/edit.go index 08d6d446c..148e7aba7 100644 --- a/internal/llm/tools/edit.go +++ b/internal/llm/tools/edit.go @@ -11,6 +11,7 @@ import ( "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/diff" + "github.com/kujtimiihoxha/termai/internal/history" "github.com/kujtimiihoxha/termai/internal/lsp" "github.com/kujtimiihoxha/termai/internal/permission" ) @@ -35,6 +36,7 @@ type EditResponseMetadata struct { type editTool struct { lspClients map[string]*lsp.Client permissions permission.Service + files history.Service } const ( @@ -88,10 +90,11 @@ When making edits: Remember: when making multiple file edits in a row to the same file, you should prefer to send all edits in a single message with multiple calls to this tool, rather than multiple messages with a single call each.` ) -func NewEditTool(lspClients map[string]*lsp.Client, permissions permission.Service) BaseTool { +func NewEditTool(lspClients map[string]*lsp.Client, permissions permission.Service, files history.Service) BaseTool { return &editTool{ lspClients: lspClients, permissions: permissions, + files: files, } } @@ -153,6 +156,11 @@ func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) if err != nil { return response, nil } + if response.IsError { + // Return early if there was an error during content replacement + // This prevents unnecessary LSP diagnostics processing + return response, nil + } waitForLspDiagnostics(ctx, params.FilePath, e.lspClients) text := fmt.Sprintf("\n%s\n\n", response.Content) @@ -208,6 +216,20 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string) return ToolResponse{}, fmt.Errorf("failed to write file: %w", err) } + // File can't be in the history so we create a new file history + _, err = e.files.Create(ctx, sessionID, filePath, "") + if err != nil { + // Log error but don't fail the operation + return ToolResponse{}, fmt.Errorf("error creating file history: %w", err) + } + + // Add the new content to the file history + _, err = e.files.CreateVersion(ctx, sessionID, filePath, content) + if err != nil { + // Log error but don't fail the operation + fmt.Printf("Error creating file history version: %v\n", err) + } + recordFileWrite(filePath) recordFileRead(filePath) @@ -298,6 +320,29 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string if err != nil { return ToolResponse{}, fmt.Errorf("failed to write file: %w", err) } + + // Check if file exists in history + file, err := e.files.GetByPathAndSession(ctx, filePath, sessionID) + if err != nil { + _, err = e.files.Create(ctx, sessionID, filePath, oldContent) + if err != nil { + // Log error but don't fail the operation + return ToolResponse{}, fmt.Errorf("error creating file history: %w", err) + } + } + if file.Content != oldContent { + // User Manually changed the content store an intermediate version + _, err = e.files.CreateVersion(ctx, sessionID, filePath, oldContent) + if err != nil { + fmt.Printf("Error creating file history version: %v\n", err) + } + } + // Store the new version + _, err = e.files.CreateVersion(ctx, sessionID, filePath, "") + if err != nil { + fmt.Printf("Error creating file history version: %v\n", err) + } + recordFileWrite(filePath) recordFileRead(filePath) @@ -356,6 +401,9 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS newContent := oldContent[:index] + newString + oldContent[index+len(oldString):] + if oldContent == newContent { + return NewTextErrorResponse("new content is the same as old content. No changes made."), nil + } sessionID, messageID := GetContextValues(ctx) if sessionID == "" || messageID == "" { @@ -374,8 +422,7 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS Description: fmt.Sprintf("Replace content in file %s", filePath), Params: EditPermissionsParams{ FilePath: filePath, - - Diff: diff, + Diff: diff, }, }, ) @@ -388,6 +435,28 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS return ToolResponse{}, fmt.Errorf("failed to write file: %w", err) } + // Check if file exists in history + file, err := e.files.GetByPathAndSession(ctx, filePath, sessionID) + if err != nil { + _, err = e.files.Create(ctx, sessionID, filePath, oldContent) + if err != nil { + // Log error but don't fail the operation + return ToolResponse{}, fmt.Errorf("error creating file history: %w", err) + } + } + if file.Content != oldContent { + // User Manually changed the content store an intermediate version + _, err = e.files.CreateVersion(ctx, sessionID, filePath, oldContent) + if err != nil { + fmt.Printf("Error creating file history version: %v\n", err) + } + } + // Store the new version + _, err = e.files.CreateVersion(ctx, sessionID, filePath, newContent) + if err != nil { + fmt.Printf("Error creating file history version: %v\n", err) + } + recordFileWrite(filePath) recordFileRead(filePath) diff --git a/internal/llm/tools/edit_test.go b/internal/llm/tools/edit_test.go index 48a34ed75..0971775dd 100644 --- a/internal/llm/tools/edit_test.go +++ b/internal/llm/tools/edit_test.go @@ -14,7 +14,7 @@ import ( ) func TestEditTool_Info(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) info := tool.Info() assert.Equal(t, EditToolName, info.Name) @@ -34,7 +34,7 @@ func TestEditTool_Run(t *testing.T) { defer os.RemoveAll(tempDir) t.Run("creates a new file successfully", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) filePath := filepath.Join(tempDir, "new_file.txt") content := "This is a test content" @@ -64,7 +64,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("creates file with nested directories", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) filePath := filepath.Join(tempDir, "nested/dirs/new_file.txt") content := "Content in nested directory" @@ -94,7 +94,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("fails to create file that already exists", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file first filePath := filepath.Join(tempDir, "existing_file.txt") @@ -123,7 +123,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("fails to create file when path is a directory", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a directory dirPath := filepath.Join(tempDir, "test_dir") @@ -151,7 +151,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("replaces content successfully", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file first filePath := filepath.Join(tempDir, "replace_content.txt") @@ -191,7 +191,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("deletes content successfully", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file first filePath := filepath.Join(tempDir, "delete_content.txt") @@ -230,7 +230,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles invalid parameters", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) call := ToolCall{ Name: EditToolName, @@ -243,7 +243,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles missing file_path", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) params := EditParams{ FilePath: "", @@ -265,7 +265,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles file not found", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) filePath := filepath.Join(tempDir, "non_existent_file.txt") params := EditParams{ @@ -288,7 +288,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles old_string not found in file", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file first filePath := filepath.Join(tempDir, "content_not_found.txt") @@ -320,7 +320,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles multiple occurrences of old_string", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file with duplicate content filePath := filepath.Join(tempDir, "duplicate_content.txt") @@ -352,7 +352,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles file modified since last read", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file filePath := filepath.Join(tempDir, "modified_file.txt") @@ -394,7 +394,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles file not read before editing", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file filePath := filepath.Join(tempDir, "not_read_file.txt") @@ -423,7 +423,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles permission denied", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(false)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(false), newMockFileHistoryService()) // Create a file filePath := filepath.Join(tempDir, "permission_denied.txt") diff --git a/internal/llm/tools/file.go b/internal/llm/tools/file.go index 9c9707c9c..7f34fdc1f 100644 --- a/internal/llm/tools/file.go +++ b/internal/llm/tools/file.go @@ -3,8 +3,6 @@ package tools import ( "sync" "time" - - "github.com/kujtimiihoxha/termai/internal/config" ) // File record to track when files were read/written @@ -19,14 +17,6 @@ var ( fileRecordMutex sync.RWMutex ) -func removeWorkingDirectoryPrefix(path string) string { - wd := config.WorkingDirectory() - if len(path) > len(wd) && path[:len(wd)] == wd { - return path[len(wd)+1:] - } - return path -} - func recordFileRead(path string) { fileRecordMutex.Lock() defer fileRecordMutex.Unlock() diff --git a/internal/llm/tools/glob.go b/internal/llm/tools/glob.go index bdfc23b4a..7b4fb1187 100644 --- a/internal/llm/tools/glob.go +++ b/internal/llm/tools/glob.go @@ -63,7 +63,7 @@ type GlobParams struct { Path string `json:"path"` } -type GlobMetadata struct { +type GlobResponseMetadata struct { NumberOfFiles int `json:"number_of_files"` Truncated bool `json:"truncated"` } @@ -124,7 +124,7 @@ func (g *globTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) return WithResponseMetadata( NewTextResponse(output), - GlobMetadata{ + GlobResponseMetadata{ NumberOfFiles: len(files), Truncated: truncated, }, diff --git a/internal/llm/tools/grep.go b/internal/llm/tools/grep.go index 7e52821d0..19333f50b 100644 --- a/internal/llm/tools/grep.go +++ b/internal/llm/tools/grep.go @@ -27,7 +27,7 @@ type grepMatch struct { modTime time.Time } -type GrepMetadata struct { +type GrepResponseMetadata struct { NumberOfMatches int `json:"number_of_matches"` Truncated bool `json:"truncated"` } @@ -134,7 +134,7 @@ func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) return WithResponseMetadata( NewTextResponse(output), - GrepMetadata{ + GrepResponseMetadata{ NumberOfMatches: len(matches), Truncated: truncated, }, diff --git a/internal/llm/tools/ls.go b/internal/llm/tools/ls.go index a679f261b..a63bf0eeb 100644 --- a/internal/llm/tools/ls.go +++ b/internal/llm/tools/ls.go @@ -23,7 +23,7 @@ type TreeNode struct { Children []*TreeNode `json:"children,omitempty"` } -type LSMetadata struct { +type LSResponseMetadata struct { NumberOfFiles int `json:"number_of_files"` Truncated bool `json:"truncated"` } @@ -121,7 +121,7 @@ func (l *lsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { return WithResponseMetadata( NewTextResponse(output), - LSMetadata{ + LSResponseMetadata{ NumberOfFiles: len(files), Truncated: truncated, }, diff --git a/internal/llm/tools/mocks_test.go b/internal/llm/tools/mocks_test.go new file mode 100644 index 000000000..321f09ac1 --- /dev/null +++ b/internal/llm/tools/mocks_test.go @@ -0,0 +1,246 @@ +package tools + +import ( + "context" + "fmt" + "sort" + "strconv" + "strings" + "time" + + "github.com/google/uuid" + "github.com/kujtimiihoxha/termai/internal/history" + "github.com/kujtimiihoxha/termai/internal/permission" + "github.com/kujtimiihoxha/termai/internal/pubsub" +) + +// Mock permission service for testing +type mockPermissionService struct { + *pubsub.Broker[permission.PermissionRequest] + allow bool +} + +func (m *mockPermissionService) GrantPersistant(permission permission.PermissionRequest) { + // Not needed for tests +} + +func (m *mockPermissionService) Grant(permission permission.PermissionRequest) { + // Not needed for tests +} + +func (m *mockPermissionService) Deny(permission permission.PermissionRequest) { + // Not needed for tests +} + +func (m *mockPermissionService) Request(opts permission.CreatePermissionRequest) bool { + return m.allow +} + +func newMockPermissionService(allow bool) permission.Service { + return &mockPermissionService{ + Broker: pubsub.NewBroker[permission.PermissionRequest](), + allow: allow, + } +} + +type mockFileHistoryService struct { + *pubsub.Broker[history.File] + files map[string]history.File // ID -> File + timeNow func() int64 +} + +// Create implements history.Service. +func (m *mockFileHistoryService) Create(ctx context.Context, sessionID string, path string, content string) (history.File, error) { + return m.createWithVersion(ctx, sessionID, path, content, history.InitialVersion) +} + +// CreateVersion implements history.Service. +func (m *mockFileHistoryService) CreateVersion(ctx context.Context, sessionID string, path string, content string) (history.File, error) { + var files []history.File + for _, file := range m.files { + if file.Path == path { + files = append(files, file) + } + } + + if len(files) == 0 { + // No previous versions, create initial + return m.Create(ctx, sessionID, path, content) + } + + // Sort files by CreatedAt in descending order + sort.Slice(files, func(i, j int) bool { + return files[i].CreatedAt > files[j].CreatedAt + }) + + // Get the latest version + latestFile := files[0] + latestVersion := latestFile.Version + + // Generate the next version + var nextVersion string + if latestVersion == history.InitialVersion { + nextVersion = "v1" + } else if strings.HasPrefix(latestVersion, "v") { + versionNum, err := strconv.Atoi(latestVersion[1:]) + if err != nil { + // If we can't parse the version, just use a timestamp-based version + nextVersion = fmt.Sprintf("v%d", latestFile.CreatedAt) + } else { + nextVersion = fmt.Sprintf("v%d", versionNum+1) + } + } else { + // If the version format is unexpected, use a timestamp-based version + nextVersion = fmt.Sprintf("v%d", latestFile.CreatedAt) + } + + return m.createWithVersion(ctx, sessionID, path, content, nextVersion) +} + +func (m *mockFileHistoryService) createWithVersion(_ context.Context, sessionID, path, content, version string) (history.File, error) { + now := m.timeNow() + file := history.File{ + ID: uuid.New().String(), + SessionID: sessionID, + Path: path, + Content: content, + Version: version, + CreatedAt: now, + UpdatedAt: now, + } + + m.files[file.ID] = file + m.Publish(pubsub.CreatedEvent, file) + return file, nil +} + +// Delete implements history.Service. +func (m *mockFileHistoryService) Delete(ctx context.Context, id string) error { + file, ok := m.files[id] + if !ok { + return fmt.Errorf("file not found: %s", id) + } + + delete(m.files, id) + m.Publish(pubsub.DeletedEvent, file) + return nil +} + +// DeleteSessionFiles implements history.Service. +func (m *mockFileHistoryService) DeleteSessionFiles(ctx context.Context, sessionID string) error { + files, err := m.ListBySession(ctx, sessionID) + if err != nil { + return err + } + + for _, file := range files { + err = m.Delete(ctx, file.ID) + if err != nil { + return err + } + } + + return nil +} + +// Get implements history.Service. +func (m *mockFileHistoryService) Get(ctx context.Context, id string) (history.File, error) { + file, ok := m.files[id] + if !ok { + return history.File{}, fmt.Errorf("file not found: %s", id) + } + return file, nil +} + +// GetByPathAndSession implements history.Service. +func (m *mockFileHistoryService) GetByPathAndSession(ctx context.Context, path string, sessionID string) (history.File, error) { + var latestFile history.File + var found bool + var latestTime int64 + + for _, file := range m.files { + if file.Path == path && file.SessionID == sessionID { + if !found || file.CreatedAt > latestTime { + latestFile = file + latestTime = file.CreatedAt + found = true + } + } + } + + if !found { + return history.File{}, fmt.Errorf("file not found: %s for session %s", path, sessionID) + } + return latestFile, nil +} + +// ListBySession implements history.Service. +func (m *mockFileHistoryService) ListBySession(ctx context.Context, sessionID string) ([]history.File, error) { + var files []history.File + for _, file := range m.files { + if file.SessionID == sessionID { + files = append(files, file) + } + } + + // Sort by CreatedAt in descending order + sort.Slice(files, func(i, j int) bool { + return files[i].CreatedAt > files[j].CreatedAt + }) + + return files, nil +} + +// ListLatestSessionFiles implements history.Service. +func (m *mockFileHistoryService) ListLatestSessionFiles(ctx context.Context, sessionID string) ([]history.File, error) { + // Map to track the latest file for each path + latestFiles := make(map[string]history.File) + + for _, file := range m.files { + if file.SessionID == sessionID { + existing, ok := latestFiles[file.Path] + if !ok || file.CreatedAt > existing.CreatedAt { + latestFiles[file.Path] = file + } + } + } + + // Convert map to slice + var result []history.File + for _, file := range latestFiles { + result = append(result, file) + } + + // Sort by CreatedAt in descending order + sort.Slice(result, func(i, j int) bool { + return result[i].CreatedAt > result[j].CreatedAt + }) + + return result, nil +} + +// Subscribe implements history.Service. +func (m *mockFileHistoryService) Subscribe(ctx context.Context) <-chan pubsub.Event[history.File] { + return m.Broker.Subscribe(ctx) +} + +// Update implements history.Service. +func (m *mockFileHistoryService) Update(ctx context.Context, file history.File) (history.File, error) { + _, ok := m.files[file.ID] + if !ok { + return history.File{}, fmt.Errorf("file not found: %s", file.ID) + } + + file.UpdatedAt = m.timeNow() + m.files[file.ID] = file + m.Publish(pubsub.UpdatedEvent, file) + return file, nil +} + +func newMockFileHistoryService() history.Service { + return &mockFileHistoryService{ + Broker: pubsub.NewBroker[history.File](), + files: make(map[string]history.File), + timeNow: func() int64 { return time.Now().Unix() }, + } +} diff --git a/internal/llm/tools/shell/shell.go b/internal/llm/tools/shell/shell.go index 64592f67d..4a776478a 100644 --- a/internal/llm/tools/shell/shell.go +++ b/internal/llm/tools/shell/shell.go @@ -83,11 +83,21 @@ func newPersistentShell(cwd string) *PersistentShell { commandQueue: make(chan *commandExecution, 10), } - go shell.processCommands() + go func() { + defer func() { + if r := recover(); r != nil { + fmt.Fprintf(os.Stderr, "Panic in shell command processor: %v\n", r) + shell.isAlive = false + close(shell.commandQueue) + } + }() + shell.processCommands() + }() go func() { err := cmd.Wait() if err != nil { + // Log the error if needed } shell.isAlive = false close(shell.commandQueue) diff --git a/internal/llm/tools/sourcegraph.go b/internal/llm/tools/sourcegraph.go index 17bc610ea..a6f2c8afb 100644 --- a/internal/llm/tools/sourcegraph.go +++ b/internal/llm/tools/sourcegraph.go @@ -18,7 +18,7 @@ type SourcegraphParams struct { Timeout int `json:"timeout,omitempty"` } -type SourcegraphMetadata struct { +type SourcegraphResponseMetadata struct { NumberOfMatches int `json:"number_of_matches"` Truncated bool `json:"truncated"` } diff --git a/internal/llm/tools/tools.go b/internal/llm/tools/tools.go index 07afe1363..bf0f8df0b 100644 --- a/internal/llm/tools/tools.go +++ b/internal/llm/tools/tools.go @@ -14,12 +14,17 @@ type ToolInfo struct { type toolResponseType string +type ( + sessionIDContextKey string + messageIDContextKey string +) + const ( ToolResponseTypeText toolResponseType = "text" ToolResponseTypeImage toolResponseType = "image" - SessionIDContextKey = "session_id" - MessageIDContextKey = "message_id" + SessionIDContextKey sessionIDContextKey = "session_id" + MessageIDContextKey messageIDContextKey = "message_id" ) type ToolResponse struct { diff --git a/internal/llm/tools/write.go b/internal/llm/tools/write.go index 889561d2a..bb49381fd 100644 --- a/internal/llm/tools/write.go +++ b/internal/llm/tools/write.go @@ -10,6 +10,7 @@ import ( "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/diff" + "github.com/kujtimiihoxha/termai/internal/history" "github.com/kujtimiihoxha/termai/internal/lsp" "github.com/kujtimiihoxha/termai/internal/permission" ) @@ -27,6 +28,7 @@ type WritePermissionsParams struct { type writeTool struct { lspClients map[string]*lsp.Client permissions permission.Service + files history.Service } type WriteResponseMetadata struct { @@ -67,10 +69,11 @@ TIPS: - Always include descriptive comments when making changes to existing code` ) -func NewWriteTool(lspClients map[string]*lsp.Client, permissions permission.Service) BaseTool { +func NewWriteTool(lspClients map[string]*lsp.Client, permissions permission.Service, files history.Service) BaseTool { return &writeTool{ lspClients: lspClients, permissions: permissions, + files: files, } } @@ -176,6 +179,28 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error return ToolResponse{}, fmt.Errorf("error writing file: %w", err) } + // Check if file exists in history + file, err := w.files.GetByPathAndSession(ctx, filePath, sessionID) + if err != nil { + _, err = w.files.Create(ctx, sessionID, filePath, oldContent) + if err != nil { + // Log error but don't fail the operation + return ToolResponse{}, fmt.Errorf("error creating file history: %w", err) + } + } + if file.Content != oldContent { + // User Manually changed the content store an intermediate version + _, err = w.files.CreateVersion(ctx, sessionID, filePath, oldContent) + if err != nil { + fmt.Printf("Error creating file history version: %v\n", err) + } + } + // Store the new version + _, err = w.files.CreateVersion(ctx, sessionID, filePath, params.Content) + if err != nil { + fmt.Printf("Error creating file history version: %v\n", err) + } + recordFileWrite(filePath) recordFileRead(filePath) waitForLspDiagnostics(ctx, filePath, w.lspClients) diff --git a/internal/llm/tools/write_test.go b/internal/llm/tools/write_test.go index 50dafc14f..2264f36fb 100644 --- a/internal/llm/tools/write_test.go +++ b/internal/llm/tools/write_test.go @@ -14,7 +14,7 @@ import ( ) func TestWriteTool_Info(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) info := tool.Info() assert.Equal(t, WriteToolName, info.Name) @@ -32,7 +32,7 @@ func TestWriteTool_Run(t *testing.T) { defer os.RemoveAll(tempDir) t.Run("creates a new file successfully", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) filePath := filepath.Join(tempDir, "new_file.txt") content := "This is a test content" @@ -61,7 +61,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("creates file with nested directories", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) filePath := filepath.Join(tempDir, "nested/dirs/new_file.txt") content := "Content in nested directory" @@ -90,7 +90,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("updates existing file", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file first filePath := filepath.Join(tempDir, "existing_file.txt") @@ -127,7 +127,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("handles invalid parameters", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) call := ToolCall{ Name: WriteToolName, @@ -140,7 +140,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("handles missing file_path", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) params := WriteParams{ FilePath: "", @@ -161,7 +161,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("handles missing content", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) params := WriteParams{ FilePath: filepath.Join(tempDir, "file.txt"), @@ -182,7 +182,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("handles writing to a directory path", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a directory dirPath := filepath.Join(tempDir, "test_dir") @@ -208,7 +208,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("handles permission denied", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(false)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(false), newMockFileHistoryService()) filePath := filepath.Join(tempDir, "permission_denied.txt") params := WriteParams{ @@ -234,7 +234,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("detects file modified since last read", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file filePath := filepath.Join(tempDir, "modified_file.txt") @@ -275,7 +275,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("skips writing when content is identical", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file filePath := filepath.Join(tempDir, "identical_content.txt") diff --git a/internal/logging/logger.go b/internal/logging/logger.go index b06391472..7ae2e7b87 100644 --- a/internal/logging/logger.go +++ b/internal/logging/logger.go @@ -1,6 +1,12 @@ package logging -import "log/slog" +import ( + "fmt" + "log/slog" + "os" + "runtime/debug" + "time" +) func Info(msg string, args ...any) { slog.Info(msg, args...) @@ -37,3 +43,36 @@ func ErrorPersist(msg string, args ...any) { args = append(args, persistKeyArg, true) slog.Error(msg, args...) } + +// RecoverPanic is a common function to handle panics gracefully. +// It logs the error, creates a panic log file with stack trace, +// and executes an optional cleanup function before returning. +func RecoverPanic(name string, cleanup func()) { + if r := recover(); r != nil { + // Log the panic + ErrorPersist(fmt.Sprintf("Panic in %s: %v", name, r)) + + // Create a timestamped panic log file + timestamp := time.Now().Format("20060102-150405") + filename := fmt.Sprintf("opencode-panic-%s-%s.log", name, timestamp) + + file, err := os.Create(filename) + if err != nil { + ErrorPersist(fmt.Sprintf("Failed to create panic log: %v", err)) + } else { + defer file.Close() + + // Write panic information and stack trace + fmt.Fprintf(file, "Panic in %s: %v\n\n", name, r) + fmt.Fprintf(file, "Time: %s\n\n", time.Now().Format(time.RFC3339)) + fmt.Fprintf(file, "Stack Trace:\n%s\n", debug.Stack()) + + InfoPersist(fmt.Sprintf("Panic details written to %s", filename)) + } + + // Execute cleanup function if provided + if cleanup != nil { + cleanup() + } + } +} diff --git a/internal/lsp/client.go b/internal/lsp/client.go index e2eedc4fc..0f03e7fcb 100644 --- a/internal/lsp/client.go +++ b/internal/lsp/client.go @@ -97,7 +97,12 @@ func NewClient(ctx context.Context, command string, args ...string) (*Client, er }() // Start message handling loop - go client.handleMessages() + go func() { + defer logging.RecoverPanic("LSP-message-handler", func() { + logging.ErrorPersist("LSP message handler crashed, LSP functionality may be impaired") + }) + client.handleMessages() + }() return client, nil } @@ -374,7 +379,7 @@ func (c *Client) CloseFile(ctx context.Context, filepath string) error { }, } - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Closing file", "file", filepath) } if err := c.Notify(ctx, "textDocument/didClose", params); err != nil { @@ -413,12 +418,12 @@ func (c *Client) CloseAllFiles(ctx context.Context) { // Then close them all for _, filePath := range filesToClose { err := c.CloseFile(ctx, filePath) - if err != nil && cnf.Debug { + if err != nil && cnf.DebugLSP { logging.Warn("Error closing file", "file", filePath, "error", err) } } - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Closed all files", "files", filesToClose) } } diff --git a/internal/lsp/handlers.go b/internal/lsp/handlers.go index 4913c743d..c3088d685 100644 --- a/internal/lsp/handlers.go +++ b/internal/lsp/handlers.go @@ -88,7 +88,7 @@ func HandleServerMessage(params json.RawMessage) { Message string `json:"message"` } if err := json.Unmarshal(params, &msg); err == nil { - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Server message", "type", msg.Type, "message", msg.Message) } } diff --git a/internal/lsp/transport.go b/internal/lsp/transport.go index 4185966f3..89255fd78 100644 --- a/internal/lsp/transport.go +++ b/internal/lsp/transport.go @@ -20,7 +20,7 @@ func WriteMessage(w io.Writer, msg *Message) error { } cnf := config.Get() - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Sending message to server", "method", msg.Method, "id", msg.ID) } @@ -49,7 +49,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) { } line = strings.TrimSpace(line) - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Received header", "line", line) } @@ -65,7 +65,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) { } } - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Content-Length", "length", contentLength) } @@ -76,7 +76,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) { return nil, fmt.Errorf("failed to read content: %w", err) } - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Received content", "content", string(content)) } @@ -95,7 +95,7 @@ func (c *Client) handleMessages() { for { msg, err := ReadMessage(c.stdout) if err != nil { - if cnf.Debug { + if cnf.DebugLSP { logging.Error("Error reading message", "error", err) } return @@ -103,7 +103,7 @@ func (c *Client) handleMessages() { // Handle server->client request (has both Method and ID) if msg.Method != "" && msg.ID != 0 { - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Received request from server", "method", msg.Method, "id", msg.ID) } @@ -157,11 +157,11 @@ func (c *Client) handleMessages() { c.notificationMu.RUnlock() if ok { - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Handling notification", "method", msg.Method) } go handler(msg.Params) - } else if cnf.Debug { + } else if cnf.DebugLSP { logging.Debug("No handler for notification", "method", msg.Method) } continue @@ -174,12 +174,12 @@ func (c *Client) handleMessages() { c.handlersMu.RUnlock() if ok { - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Received response for request", "id", msg.ID) } ch <- msg close(ch) - } else if cnf.Debug { + } else if cnf.DebugLSP { logging.Debug("No handler for response", "id", msg.ID) } } @@ -191,7 +191,7 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any cnf := config.Get() id := c.nextID.Add(1) - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Making call", "method", method, "id", id) } @@ -217,14 +217,14 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any return fmt.Errorf("failed to send request: %w", err) } - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Request sent", "method", method, "id", id) } // Wait for response resp := <-ch - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Received response", "id", id) } @@ -250,7 +250,7 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any // Notify sends a notification (a request without an ID that doesn't expect a response) func (c *Client) Notify(ctx context.Context, method string, params any) error { cnf := config.Get() - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Sending notification", "method", method) } diff --git a/internal/lsp/watcher/watcher.go b/internal/lsp/watcher/watcher.go index b5ef15710..156f38e1a 100644 --- a/internal/lsp/watcher/watcher.go +++ b/internal/lsp/watcher/watcher.go @@ -50,7 +50,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc w.registrations = append(w.registrations, watchers...) // Print detailed registration information for debugging - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Adding file watcher registrations", "id", id, "watchers", len(watchers), @@ -116,7 +116,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc // Skip directories that should be excluded if d.IsDir() { if path != w.workspacePath && shouldExcludeDir(path) { - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Skipping excluded directory", "path", path) } return filepath.SkipDir @@ -136,7 +136,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc }) elapsedTime := time.Since(startTime) - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Workspace scan complete", "filesOpened", filesOpened, "elapsedTime", elapsedTime.Seconds(), @@ -144,7 +144,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc ) } - if err != nil && cnf.Debug { + if err != nil && cnf.DebugLSP { logging.Debug("Error scanning workspace for files to open", "error", err) } }() @@ -175,7 +175,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str // Skip excluded directories (except workspace root) if d.IsDir() && path != workspacePath { if shouldExcludeDir(path) { - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Skipping excluded directory", "path", path) } return filepath.SkipDir @@ -228,7 +228,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str } // Debug logging - if cnf.Debug { + if cnf.DebugLSP { matched, kind := w.isPathWatched(event.Name) logging.Debug("File event", "path", event.Name, @@ -491,7 +491,7 @@ func (w *WorkspaceWatcher) handleFileEvent(ctx context.Context, uri string, chan // notifyFileEvent sends a didChangeWatchedFiles notification for a file event func (w *WorkspaceWatcher) notifyFileEvent(ctx context.Context, uri string, changeType protocol.FileChangeType) error { cnf := config.Get() - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Notifying file event", "uri", uri, "changeType", changeType, @@ -615,7 +615,7 @@ func shouldExcludeFile(filePath string) bool { // Skip large files if info.Size() > maxFileSize { - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Skipping large file", "path", filePath, "size", info.Size(), @@ -648,7 +648,7 @@ func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) { // Check if this path should be watched according to server registrations if watched, _ := w.isPathWatched(path); watched { // Don't need to check if it's already open - the client.OpenFile handles that - if err := w.client.OpenFile(ctx, path); err != nil && cnf.Debug { + if err := w.client.OpenFile(ctx, path); err != nil && cnf.DebugLSP { logging.Error("Error opening file", "path", path, "error", err) } } diff --git a/internal/message/content.go b/internal/message/content.go index 422c04f52..f9e76b11c 100644 --- a/internal/message/content.go +++ b/internal/message/content.go @@ -2,6 +2,7 @@ package message import ( "encoding/base64" + "slices" "time" "github.com/kujtimiihoxha/termai/internal/llm/models" @@ -16,6 +17,20 @@ const ( Tool MessageRole = "tool" ) +type FinishReason string + +const ( + FinishReasonEndTurn FinishReason = "end_turn" + FinishReasonMaxTokens FinishReason = "max_tokens" + FinishReasonToolUse FinishReason = "tool_use" + FinishReasonCanceled FinishReason = "canceled" + FinishReasonError FinishReason = "error" + FinishReasonPermissionDenied FinishReason = "permission_denied" + + // Should never happen + FinishReasonUnknown FinishReason = "unknown" +) + type ContentPart interface { isPart() } @@ -83,8 +98,8 @@ type ToolResult struct { func (ToolResult) isPart() {} type Finish struct { - Reason string `json:"reason"` - Time int64 `json:"time"` + Reason FinishReason `json:"reason"` + Time int64 `json:"time"` } func (Finish) isPart() {} @@ -176,7 +191,7 @@ func (m *Message) FinishPart() *Finish { return nil } -func (m *Message) FinishReason() string { +func (m *Message) FinishReason() FinishReason { for _, part := range m.Parts { if c, ok := part.(Finish); ok { return c.Reason @@ -246,7 +261,14 @@ func (m *Message) SetToolResults(tr []ToolResult) { } } -func (m *Message) AddFinish(reason string) { +func (m *Message) AddFinish(reason FinishReason) { + // remove any existing finish part + for i, part := range m.Parts { + if _, ok := part.(Finish); ok { + m.Parts = slices.Delete(m.Parts, i, i+1) + break + } + } m.Parts = append(m.Parts, Finish{Reason: reason, Time: time.Now().Unix()}) } diff --git a/internal/pubsub/broker.go b/internal/pubsub/broker.go index 633a6d57f..3e70ae095 100644 --- a/internal/pubsub/broker.go +++ b/internal/pubsub/broker.go @@ -5,7 +5,7 @@ import ( "sync" ) -const bufferSize = 1024 * 1024 +const bufferSize = 1024 type Logger interface { Debug(msg string, args ...any) diff --git a/internal/session/session.go b/internal/session/session.go index 9a16224c3..019019df4 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -24,6 +24,7 @@ type Session struct { type Service interface { pubsub.Suscriber[Session] Create(ctx context.Context, title string) (Session, error) + CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) Get(ctx context.Context, id string) (Session, error) List(ctx context.Context) ([]Session, error) @@ -63,6 +64,20 @@ func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessi return session, nil } +func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) { + dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{ + ID: "title-" + parentSessionID, + ParentSessionID: sql.NullString{String: parentSessionID, Valid: true}, + Title: "Generate a title", + }) + if err != nil { + return Session{}, err + } + session := s.fromDBItem(dbSession) + s.Publish(pubsub.CreatedEvent, session) + return session, nil +} + func (s *service) Delete(ctx context.Context, id string) error { session, err := s.Get(ctx, id) if err != nil { diff --git a/internal/tui/components/chat/chat.go b/internal/tui/components/chat/chat.go index e893ec2f5..e98001efa 100644 --- a/internal/tui/components/chat/chat.go +++ b/internal/tui/components/chat/chat.go @@ -19,8 +19,6 @@ type SessionSelectedMsg = session.Session type SessionClearedMsg struct{} -type AgentWorkingMsg bool - type EditorFocusMsg bool func lspsConfigured(width int) string { diff --git a/internal/tui/components/chat/editor.go b/internal/tui/components/chat/editor.go index e87f1ffae..e2f4da9e2 100644 --- a/internal/tui/components/chat/editor.go +++ b/internal/tui/components/chat/editor.go @@ -5,14 +5,17 @@ import ( "github.com/charmbracelet/bubbles/textarea" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" + "github.com/kujtimiihoxha/termai/internal/app" + "github.com/kujtimiihoxha/termai/internal/session" "github.com/kujtimiihoxha/termai/internal/tui/layout" "github.com/kujtimiihoxha/termai/internal/tui/styles" "github.com/kujtimiihoxha/termai/internal/tui/util" ) type editorCmp struct { - textarea textarea.Model - agentWorking bool + app *app.App + session session.Session + textarea textarea.Model } type focusedEditorKeyMaps struct { @@ -32,7 +35,7 @@ var focusedKeyMaps = focusedEditorKeyMaps{ ), Blur: key.NewBinding( key.WithKeys("esc"), - key.WithHelp("esc", "blur editor"), + key.WithHelp("esc", "focus messages"), ), } @@ -52,7 +55,7 @@ func (m *editorCmp) Init() tea.Cmd { } func (m *editorCmp) send() tea.Cmd { - if m.agentWorking { + if m.app.CoderAgent.IsSessionBusy(m.session.ID) { return util.ReportWarn("Agent is working, please wait...") } @@ -66,7 +69,6 @@ func (m *editorCmp) send() tea.Cmd { util.CmdHandler(SendMsg{ Text: value, }), - util.CmdHandler(AgentWorkingMsg(true)), util.CmdHandler(EditorFocusMsg(false)), ) } @@ -74,8 +76,11 @@ func (m *editorCmp) send() tea.Cmd { func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { var cmd tea.Cmd switch msg := msg.(type) { - case AgentWorkingMsg: - m.agentWorking = bool(msg) + case SessionSelectedMsg: + if msg.ID != m.session.ID { + m.session = msg + } + return m, nil case tea.KeyMsg: // if the key does not match any binding, return if m.textarea.Focused() && key.Matches(msg, focusedKeyMaps.Send) { @@ -122,7 +127,7 @@ func (m *editorCmp) BindingKeys() []key.Binding { return bindings } -func NewEditorCmp() tea.Model { +func NewEditorCmp(app *app.App) tea.Model { ti := textarea.New() ti.Prompt = " " ti.ShowLineNumbers = false @@ -138,6 +143,7 @@ func NewEditorCmp() tea.Model { ti.CharLimit = -1 ti.Focus() return &editorCmp{ + app: app, textarea: ti, } } diff --git a/internal/tui/components/chat/messages.go b/internal/tui/components/chat/messages.go index dc21fca29..26a98970e 100644 --- a/internal/tui/components/chat/messages.go +++ b/internal/tui/components/chat/messages.go @@ -6,7 +6,9 @@ import ( "fmt" "math" "strings" + "time" + "github.com/charmbracelet/bubbles/key" "github.com/charmbracelet/bubbles/spinner" "github.com/charmbracelet/bubbles/viewport" tea "github.com/charmbracelet/bubbletea" @@ -17,9 +19,11 @@ import ( "github.com/kujtimiihoxha/termai/internal/llm/agent" "github.com/kujtimiihoxha/termai/internal/llm/models" "github.com/kujtimiihoxha/termai/internal/llm/tools" + "github.com/kujtimiihoxha/termai/internal/logging" "github.com/kujtimiihoxha/termai/internal/message" "github.com/kujtimiihoxha/termai/internal/pubsub" "github.com/kujtimiihoxha/termai/internal/session" + "github.com/kujtimiihoxha/termai/internal/tui/layout" "github.com/kujtimiihoxha/termai/internal/tui/styles" "github.com/kujtimiihoxha/termai/internal/tui/util" ) @@ -32,6 +36,9 @@ const ( toolMessageType ) +// messagesTickMsg is a message sent by the timer to refresh messages +type messagesTickMsg time.Time + type uiMessage struct { ID string messageType uiMessageType @@ -52,24 +59,34 @@ type messagesCmp struct { renderer *glamour.TermRenderer focusRenderer *glamour.TermRenderer cachedContent map[string]string - agentWorking bool spinner spinner.Model needsRerender bool - lastViewport string } func (m *messagesCmp) Init() tea.Cmd { - return tea.Batch(m.viewport.Init()) + return tea.Batch(m.viewport.Init(), m.spinner.Tick, m.tickMessages()) +} + +func (m *messagesCmp) tickMessages() tea.Cmd { + return tea.Tick(time.Second, func(t time.Time) tea.Msg { + return messagesTickMsg(t) + }) } func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { var cmds []tea.Cmd switch msg := msg.(type) { - case AgentWorkingMsg: - m.agentWorking = bool(msg) - if m.agentWorking { - cmds = append(cmds, m.spinner.Tick) + case messagesTickMsg: + // Refresh messages if we have an active session + if m.session.ID != "" { + messages, err := m.app.Messages.List(context.Background(), m.session.ID) + if err == nil { + m.messages = messages + m.needsRerender = true + } } + // Continue ticking + cmds = append(cmds, m.tickMessages()) case EditorFocusMsg: m.writingMode = bool(msg) case SessionSelectedMsg: @@ -84,6 +101,7 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.messages = make([]message.Message, 0) m.currentMsgID = "" m.needsRerender = true + m.cachedContent = make(map[string]string) return m, nil case tea.KeyMsg: @@ -104,6 +122,12 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } if !messageExists { + // If we have messages, ensure the previous last message is not cached + if len(m.messages) > 0 { + lastMsgID := m.messages[len(m.messages)-1].ID + delete(m.cachedContent, lastMsgID) + } + m.messages = append(m.messages, msg.Payload) delete(m.cachedContent, m.currentMsgID) m.currentMsgID = msg.Payload.ID @@ -112,36 +136,40 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } for _, v := range m.messages { for _, c := range v.ToolCalls() { - // the message is being added to the session of a tool called if c.ID == msg.Payload.SessionID { m.needsRerender = true } } } } else if msg.Type == pubsub.UpdatedEvent && msg.Payload.SessionID == m.session.ID { + logging.Debug("Message", "finish", msg.Payload.FinishReason()) for i, v := range m.messages { if v.ID == msg.Payload.ID { - if !m.messages[i].IsFinished() && msg.Payload.IsFinished() && msg.Payload.FinishReason() == "end_turn" || msg.Payload.FinishReason() == "canceled" { - cmds = append(cmds, util.CmdHandler(AgentWorkingMsg(false))) - } m.messages[i] = msg.Payload delete(m.cachedContent, msg.Payload.ID) + + // If this is the last message, ensure it's not cached + if i == len(m.messages)-1 { + delete(m.cachedContent, msg.Payload.ID) + } + m.needsRerender = true break } } } } - if m.agentWorking { - u, cmd := m.spinner.Update(msg) - m.spinner = u - cmds = append(cmds, cmd) - } + oldPos := m.viewport.YPosition u, cmd := m.viewport.Update(msg) m.viewport = u m.needsRerender = m.needsRerender || m.viewport.YPosition != oldPos cmds = append(cmds, cmd) + + spinner, cmd := m.spinner.Update(msg) + m.spinner = spinner + cmds = append(cmds, cmd) + if m.needsRerender { m.renderView() if len(m.messages) > 0 { @@ -157,10 +185,21 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m, tea.Batch(cmds...) } +func (m *messagesCmp) IsAgentWorking() bool { + return m.app.CoderAgent.IsSessionBusy(m.session.ID) +} + func (m *messagesCmp) renderSimpleMessage(msg message.Message, info ...string) string { - if v, ok := m.cachedContent[msg.ID]; ok { - return v + // Check if this is the last message in the list + isLastMessage := len(m.messages) > 0 && m.messages[len(m.messages)-1].ID == msg.ID + + // Only use cache for non-last messages + if !isLastMessage { + if v, ok := m.cachedContent[msg.ID]; ok { + return v + } } + style := styles.BaseStyle. Width(m.width). BorderLeft(true). @@ -191,7 +230,12 @@ func (m *messagesCmp) renderSimpleMessage(msg message.Message, info ...string) s parts..., ), ) - m.cachedContent[msg.ID] = rendered + + // Only cache if it's not the last message + if !isLastMessage { + m.cachedContent[msg.ID] = rendered + } + return rendered } @@ -207,32 +251,71 @@ func formatTimeDifference(unixTime1, unixTime2 int64) string { return fmt.Sprintf("%dm%ds", minutes, seconds) } +func (m *messagesCmp) findToolResponse(callID string) *message.ToolResult { + for _, v := range m.messages { + for _, c := range v.ToolResults() { + if c.ToolCallID == callID { + return &c + } + } + } + return nil +} + func (m *messagesCmp) renderToolCall(toolCall message.ToolCall, isNested bool) string { key := "" value := "" + result := styles.BaseStyle.Foreground(styles.PrimaryColor).Render(m.spinner.View() + " waiting for response...") + + response := m.findToolResponse(toolCall.ID) + if response != nil && response.IsError { + // Clean up error message for display by removing newlines + // This ensures error messages display properly in the UI + errMsg := strings.ReplaceAll(response.Content, "\n", " ") + result = styles.BaseStyle.Foreground(styles.Error).Render(ansi.Truncate(errMsg, 40, "...")) + } else if response != nil { + result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render("Done") + } switch toolCall.Name { // TODO: add result data to the tools case agent.AgentToolName: key = "Task" var params agent.AgentParams json.Unmarshal([]byte(toolCall.Input), ¶ms) - value = params.Prompt - // TODO: handle nested calls + value = strings.ReplaceAll(params.Prompt, "\n", " ") + if response != nil && !response.IsError { + firstRow := strings.ReplaceAll(response.Content, "\n", " ") + result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(ansi.Truncate(firstRow, 40, "...")) + } case tools.BashToolName: key = "Bash" var params tools.BashParams json.Unmarshal([]byte(toolCall.Input), ¶ms) value = params.Command + if response != nil && !response.IsError { + metadata := tools.BashResponseMetadata{} + json.Unmarshal([]byte(response.Metadata), &metadata) + result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("Took %s", formatTimeDifference(metadata.StartTime, metadata.EndTime))) + } + case tools.EditToolName: key = "Edit" var params tools.EditParams json.Unmarshal([]byte(toolCall.Input), ¶ms) value = params.FilePath + if response != nil && !response.IsError { + metadata := tools.EditResponseMetadata{} + json.Unmarshal([]byte(response.Metadata), &metadata) + result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d Additions %d Removals", metadata.Additions, metadata.Removals)) + } case tools.FetchToolName: key = "Fetch" var params tools.FetchParams json.Unmarshal([]byte(toolCall.Input), ¶ms) value = params.URL + if response != nil && !response.IsError { + result = styles.BaseStyle.Foreground(styles.Error).Render(response.Content) + } case tools.GlobToolName: key = "Glob" var params tools.GlobParams @@ -241,6 +324,15 @@ func (m *messagesCmp) renderToolCall(toolCall message.ToolCall, isNested bool) s params.Path = "." } value = fmt.Sprintf("%s (%s)", params.Pattern, params.Path) + if response != nil && !response.IsError { + metadata := tools.GlobResponseMetadata{} + json.Unmarshal([]byte(response.Metadata), &metadata) + if metadata.Truncated { + result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d files found (truncated)", metadata.NumberOfFiles)) + } else { + result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d files found", metadata.NumberOfFiles)) + } + } case tools.GrepToolName: key = "Grep" var params tools.GrepParams @@ -249,19 +341,46 @@ func (m *messagesCmp) renderToolCall(toolCall message.ToolCall, isNested bool) s params.Path = "." } value = fmt.Sprintf("%s (%s)", params.Pattern, params.Path) + if response != nil && !response.IsError { + metadata := tools.GrepResponseMetadata{} + json.Unmarshal([]byte(response.Metadata), &metadata) + if metadata.Truncated { + result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d files found (truncated)", metadata.NumberOfMatches)) + } else { + result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d files found", metadata.NumberOfMatches)) + } + } case tools.LSToolName: - key = "Ls" + key = "ls" var params tools.LSParams json.Unmarshal([]byte(toolCall.Input), ¶ms) if params.Path == "" { params.Path = "." } value = params.Path + if response != nil && !response.IsError { + metadata := tools.LSResponseMetadata{} + json.Unmarshal([]byte(response.Metadata), &metadata) + if metadata.Truncated { + result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d files found (truncated)", metadata.NumberOfFiles)) + } else { + result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d files found", metadata.NumberOfFiles)) + } + } case tools.SourcegraphToolName: key = "Sourcegraph" var params tools.SourcegraphParams json.Unmarshal([]byte(toolCall.Input), ¶ms) value = params.Query + if response != nil && !response.IsError { + metadata := tools.SourcegraphResponseMetadata{} + json.Unmarshal([]byte(response.Metadata), &metadata) + if metadata.Truncated { + result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d matches found (truncated)", metadata.NumberOfMatches)) + } else { + result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d matches found", metadata.NumberOfMatches)) + } + } case tools.ViewToolName: key = "View" var params tools.ViewParams @@ -272,6 +391,12 @@ func (m *messagesCmp) renderToolCall(toolCall message.ToolCall, isNested bool) s var params tools.WriteParams json.Unmarshal([]byte(toolCall.Input), ¶ms) value = params.FilePath + if response != nil && !response.IsError { + metadata := tools.WriteResponseMetadata{} + json.Unmarshal([]byte(response.Metadata), &metadata) + + result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d Additions %d Removals", metadata.Additions, metadata.Removals)) + } default: key = toolCall.Name var params map[string]any @@ -300,14 +425,15 @@ func (m *messagesCmp) renderToolCall(toolCall message.ToolCall, isNested bool) s ) if !isNested { value = valyeStyle. - Width(m.width - lipgloss.Width(keyValye) - 2). Render( ansi.Truncate( - value, - m.width-lipgloss.Width(keyValye)-2, + value+" ", + m.width-lipgloss.Width(keyValye)-2-lipgloss.Width(result), "...", ), ) + value += result + } else { keyValye = keyStyle.Render( fmt.Sprintf(" └ %s: ", key), @@ -409,6 +535,27 @@ func (m *messagesCmp) renderView() { m.uiMessages = make([]uiMessage, 0) pos := 0 + // If we have messages, ensure the last message is not cached + // This ensures we always render the latest content for the most recent message + // which may be actively updating (e.g., during generation) + if len(m.messages) > 0 { + lastMsgID := m.messages[len(m.messages)-1].ID + delete(m.cachedContent, lastMsgID) + } + + // Limit cache to 10 messages + if len(m.cachedContent) > 15 { + // Create a list of keys to delete (oldest messages first) + keys := make([]string, 0, len(m.cachedContent)) + for k := range m.cachedContent { + keys = append(keys, k) + } + // Delete oldest messages until we have 10 or fewer + for i := 0; i < len(keys)-15; i++ { + delete(m.cachedContent, keys[i]) + } + } + for _, v := range m.messages { switch v.Role { case message.User: @@ -487,7 +634,7 @@ func (m *messagesCmp) View() string { func (m *messagesCmp) help() string { text := "" - if m.agentWorking { + if m.IsAgentWorking() { text += styles.BaseStyle.Foreground(styles.PrimaryColor).Bold(true).Render( fmt.Sprintf("%s %s ", m.spinner.View(), "Generating..."), ) @@ -562,9 +709,15 @@ func (m *messagesCmp) SetSession(session session.Session) tea.Cmd { m.messages = messages m.currentMsgID = m.messages[len(m.messages)-1].ID m.needsRerender = true + m.cachedContent = make(map[string]string) return nil } +func (m *messagesCmp) BindingKeys() []key.Binding { + bindings := layout.KeyMapToSlice(m.viewport.KeyMap) + return bindings +} + func NewMessagesCmp(app *app.App) tea.Model { focusRenderer, _ := glamour.NewTermRenderer( glamour.WithStyles(styles.MarkdownTheme(true)), diff --git a/internal/tui/components/chat/sidebar.go b/internal/tui/components/chat/sidebar.go index 51192cf9a..b90269d1a 100644 --- a/internal/tui/components/chat/sidebar.go +++ b/internal/tui/components/chat/sidebar.go @@ -1,10 +1,15 @@ package chat import ( + "context" "fmt" + "strings" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" + "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/termai/internal/diff" + "github.com/kujtimiihoxha/termai/internal/history" "github.com/kujtimiihoxha/termai/internal/pubsub" "github.com/kujtimiihoxha/termai/internal/session" "github.com/kujtimiihoxha/termai/internal/tui/styles" @@ -13,9 +18,33 @@ import ( type sidebarCmp struct { width, height int session session.Session + history history.Service + modFiles map[string]struct { + additions int + removals int + } } func (m *sidebarCmp) Init() tea.Cmd { + if m.history != nil { + ctx := context.Background() + // Subscribe to file events + filesCh := m.history.Subscribe(ctx) + + // Initialize the modified files map + m.modFiles = make(map[string]struct { + additions int + removals int + }) + + // Load initial files and calculate diffs + m.loadModifiedFiles(ctx) + + // Return a command that will send file events to the Update method + return func() tea.Msg { + return <-filesCh + } + } return nil } @@ -27,6 +56,13 @@ func (m *sidebarCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.session = msg.Payload } } + case pubsub.Event[history.File]: + if msg.Payload.SessionID == m.session.ID { + // When a file changes, reload all modified files + // This ensures we have the complete and accurate list + ctx := context.Background() + m.loadModifiedFiles(ctx) + } } return m, nil } @@ -86,18 +122,28 @@ func (m *sidebarCmp) modifiedFile(filePath string, additions, removals int) stri func (m *sidebarCmp) modifiedFiles() string { modifiedFiles := styles.BaseStyle.Width(m.width).Foreground(styles.PrimaryColor).Bold(true).Render("Modified Files:") - files := []struct { - path string - additions int - removals int - }{ - {"file1.txt", 10, 5}, - {"file2.txt", 20, 0}, - {"file3.txt", 0, 15}, + + // If no modified files, show a placeholder message + if m.modFiles == nil || len(m.modFiles) == 0 { + message := "No modified files" + remainingWidth := m.width - lipgloss.Width(modifiedFiles) + if remainingWidth > 0 { + message += strings.Repeat(" ", remainingWidth) + } + return styles.BaseStyle. + Width(m.width). + Render( + lipgloss.JoinVertical( + lipgloss.Top, + modifiedFiles, + styles.BaseStyle.Foreground(styles.ForgroundDim).Render(message), + ), + ) } + var fileViews []string - for _, file := range files { - fileViews = append(fileViews, m.modifiedFile(file.path, file.additions, file.removals)) + for path, stats := range m.modFiles { + fileViews = append(fileViews, m.modifiedFile(path, stats.additions, stats.removals)) } return styles.BaseStyle. @@ -123,8 +169,116 @@ func (m *sidebarCmp) GetSize() (int, int) { return m.width, m.height } -func NewSidebarCmp(session session.Session) tea.Model { +func NewSidebarCmp(session session.Session, history history.Service) tea.Model { return &sidebarCmp{ session: session, + history: history, + } +} + +func (m *sidebarCmp) loadModifiedFiles(ctx context.Context) { + if m.history == nil || m.session.ID == "" { + return + } + + // Get all latest files for this session + latestFiles, err := m.history.ListLatestSessionFiles(ctx, m.session.ID) + if err != nil { + return + } + + // Get all files for this session (to find initial versions) + allFiles, err := m.history.ListBySession(ctx, m.session.ID) + if err != nil { + return + } + + // Process each latest file + for _, file := range latestFiles { + // Skip if this is the initial version (no changes to show) + if file.Version == history.InitialVersion { + continue + } + + // Find the initial version for this specific file + var initialVersion history.File + for _, v := range allFiles { + if v.Path == file.Path && v.Version == history.InitialVersion { + initialVersion = v + break + } + } + + // Skip if we can't find the initial version + if initialVersion.ID == "" { + continue + } + + // Calculate diff between initial and latest version + _, additions, removals := diff.GenerateDiff(initialVersion.Content, file.Content, file.Path) + + // Only add to modified files if there are changes + if additions > 0 || removals > 0 { + // Remove working directory prefix from file path + displayPath := file.Path + workingDir := config.WorkingDirectory() + displayPath = strings.TrimPrefix(displayPath, workingDir) + displayPath = strings.TrimPrefix(displayPath, "/") + + m.modFiles[displayPath] = struct { + additions int + removals int + }{ + additions: additions, + removals: removals, + } + } + } +} + +func (m *sidebarCmp) processFileChanges(ctx context.Context, file history.File) { + // Skip if not the latest version + if file.Version == history.InitialVersion { + return + } + + // Get all versions of this file + fileVersions, err := m.history.ListBySession(ctx, m.session.ID) + if err != nil { + return + } + + // Find the initial version + var initialVersion history.File + for _, v := range fileVersions { + if v.Path == file.Path && v.Version == history.InitialVersion { + initialVersion = v + break + } + } + + // Skip if we can't find the initial version + if initialVersion.ID == "" { + return + } + + // Calculate diff between initial and latest version + _, additions, removals := diff.GenerateDiff(initialVersion.Content, file.Content, file.Path) + + // Only add to modified files if there are changes + if additions > 0 || removals > 0 { + // Remove working directory prefix from file path + displayPath := file.Path + workingDir := config.WorkingDirectory() + displayPath = strings.TrimPrefix(displayPath, workingDir) + displayPath = strings.TrimPrefix(displayPath, "/") + + m.modFiles[displayPath] = struct { + additions int + removals int + }{ + additions: additions, + removals: removals, + } } } diff --git a/internal/tui/components/core/dialog.go b/internal/tui/components/core/dialog.go deleted file mode 100644 index a8fef2e86..000000000 --- a/internal/tui/components/core/dialog.go +++ /dev/null @@ -1,117 +0,0 @@ -package core - -import ( - "github.com/charmbracelet/bubbles/key" - tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/tui/layout" - "github.com/kujtimiihoxha/termai/internal/tui/util" -) - -type SizeableModel interface { - tea.Model - layout.Sizeable -} - -type DialogMsg struct { - Content SizeableModel - WidthRatio float64 - HeightRatio float64 - - MinWidth int - MinHeight int -} - -type DialogCloseMsg struct{} - -type KeyBindings struct { - Return key.Binding -} - -var keys = KeyBindings{ - Return: key.NewBinding( - key.WithKeys("esc"), - key.WithHelp("esc", "close"), - ), -} - -type DialogCmp interface { - tea.Model - layout.Bindings -} - -type dialogCmp struct { - content SizeableModel - screenWidth int - screenHeight int - - widthRatio float64 - heightRatio float64 - - minWidth int - minHeight int - - width int - height int -} - -func (d *dialogCmp) Init() tea.Cmd { - return nil -} - -func (d *dialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { - switch msg := msg.(type) { - case tea.WindowSizeMsg: - d.screenWidth = msg.Width - d.screenHeight = msg.Height - d.width = max(int(float64(d.screenWidth)*d.widthRatio), d.minWidth) - d.height = max(int(float64(d.screenHeight)*d.heightRatio), d.minHeight) - if d.content != nil { - d.content.SetSize(d.width, d.height) - } - return d, nil - case DialogMsg: - d.content = msg.Content - d.widthRatio = msg.WidthRatio - d.heightRatio = msg.HeightRatio - d.minWidth = msg.MinWidth - d.minHeight = msg.MinHeight - d.width = max(int(float64(d.screenWidth)*d.widthRatio), d.minWidth) - d.height = max(int(float64(d.screenHeight)*d.heightRatio), d.minHeight) - if d.content != nil { - d.content.SetSize(d.width, d.height) - } - case DialogCloseMsg: - d.content = nil - return d, nil - case tea.KeyMsg: - if key.Matches(msg, keys.Return) { - return d, util.CmdHandler(DialogCloseMsg{}) - } - } - if d.content != nil { - u, cmd := d.content.Update(msg) - d.content = u.(SizeableModel) - return d, cmd - } - return d, nil -} - -func (d *dialogCmp) BindingKeys() []key.Binding { - bindings := []key.Binding{keys.Return} - if d.content == nil { - return bindings - } - if c, ok := d.content.(layout.Bindings); ok { - return append(bindings, c.BindingKeys()...) - } - return bindings -} - -func (d *dialogCmp) View() string { - return lipgloss.NewStyle().Width(d.width).Height(d.height).Render(d.content.View()) -} - -func NewDialogCmp() DialogCmp { - return &dialogCmp{} -} diff --git a/internal/tui/components/core/help.go b/internal/tui/components/core/help.go deleted file mode 100644 index 4ef857c78..000000000 --- a/internal/tui/components/core/help.go +++ /dev/null @@ -1,119 +0,0 @@ -package core - -import ( - "strings" - - "github.com/charmbracelet/bubbles/key" - tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/tui/styles" -) - -type HelpCmp interface { - tea.Model - SetBindings(bindings []key.Binding) - Height() int -} - -const ( - helpWidgetHeight = 12 -) - -type helpCmp struct { - width int - bindings []key.Binding -} - -func (h *helpCmp) Init() tea.Cmd { - return nil -} - -func (h *helpCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { - switch msg := msg.(type) { - case tea.WindowSizeMsg: - h.width = msg.Width - } - return h, nil -} - -func (h *helpCmp) View() string { - helpKeyStyle := styles.Bold.Foreground(styles.Rosewater).Margin(0, 1, 0, 0) - helpDescStyle := styles.Regular.Foreground(styles.Flamingo) - // Compile list of bindings to render - bindings := removeDuplicateBindings(h.bindings) - // Enumerate through each group of bindings, populating a series of - // pairs of columns, one for keys, one for descriptions - var ( - pairs []string - width int - rows = helpWidgetHeight - 2 - ) - for i := 0; i < len(bindings); i += rows { - var ( - keys []string - descs []string - ) - for j := i; j < min(i+rows, len(bindings)); j++ { - keys = append(keys, helpKeyStyle.Render(bindings[j].Help().Key)) - descs = append(descs, helpDescStyle.Render(bindings[j].Help().Desc)) - } - // Render pair of columns; beyond the first pair, render a three space - // left margin, in order to visually separate the pairs. - var cols []string - if len(pairs) > 0 { - cols = []string{" "} - } - cols = append(cols, - strings.Join(keys, "\n"), - strings.Join(descs, "\n"), - ) - - pair := lipgloss.JoinHorizontal(lipgloss.Top, cols...) - // check whether it exceeds the maximum width avail (the width of the - // terminal, subtracting 2 for the borders). - width += lipgloss.Width(pair) - if width > h.width-2 { - break - } - pairs = append(pairs, pair) - } - - // Join pairs of columns and enclose in a border - content := lipgloss.JoinHorizontal(lipgloss.Top, pairs...) - return styles.DoubleBorder.Height(rows).PaddingLeft(1).Width(h.width - 2).Render(content) -} - -func removeDuplicateBindings(bindings []key.Binding) []key.Binding { - seen := make(map[string]struct{}) - result := make([]key.Binding, 0, len(bindings)) - - // Process bindings in reverse order - for i := len(bindings) - 1; i >= 0; i-- { - b := bindings[i] - k := strings.Join(b.Keys(), " ") - if _, ok := seen[k]; ok { - // duplicate, skip - continue - } - seen[k] = struct{}{} - // Add to the beginning of result to maintain original order - result = append([]key.Binding{b}, result...) - } - - return result -} - -func (h *helpCmp) SetBindings(bindings []key.Binding) { - h.bindings = bindings -} - -func (h helpCmp) Height() int { - return helpWidgetHeight -} - -func NewHelpCmp() HelpCmp { - return &helpCmp{ - width: 0, - bindings: make([]key.Binding, 0), - } -} diff --git a/internal/tui/components/core/status.go b/internal/tui/components/core/status.go index 93ba34507..089dffa2c 100644 --- a/internal/tui/components/core/status.go +++ b/internal/tui/components/core/status.go @@ -1,21 +1,25 @@ package core import ( + "fmt" + "strings" "time" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/llm/models" + "github.com/kujtimiihoxha/termai/internal/lsp" + "github.com/kujtimiihoxha/termai/internal/lsp/protocol" "github.com/kujtimiihoxha/termai/internal/tui/styles" "github.com/kujtimiihoxha/termai/internal/tui/util" - "github.com/kujtimiihoxha/termai/internal/version" ) type statusCmp struct { info util.InfoMsg width int messageTTL time.Duration + lspClients map[string]*lsp.Client } // clearMessageCmd is a command that clears status messages after a timeout @@ -47,20 +51,18 @@ func (m statusCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m, nil } -var ( - versionWidget = styles.Padded.Background(styles.DarkGrey).Foreground(styles.Text).Render(version.Version) - helpWidget = styles.Padded.Background(styles.Grey).Foreground(styles.Text).Render("? help") -) +var helpWidget = styles.Padded.Background(styles.ForgroundMid).Foreground(styles.BackgroundDarker).Bold(true).Render("ctrl+? help") func (m statusCmp) View() string { - status := styles.Padded.Background(styles.Grey).Foreground(styles.Text).Render("? help") + status := helpWidget + diagnostics := styles.Padded.Background(styles.BackgroundDarker).Render(m.projectDiagnostics()) if m.info.Msg != "" { infoStyle := styles.Padded. Foreground(styles.Base). - Width(m.availableFooterMsgWidth()) + Width(m.availableFooterMsgWidth(diagnostics)) switch m.info.Type { case util.InfoTypeInfo: - infoStyle = infoStyle.Background(styles.Blue) + infoStyle = infoStyle.Background(styles.BorderColor) case util.InfoTypeWarn: infoStyle = infoStyle.Background(styles.Peach) case util.InfoTypeError: @@ -68,7 +70,7 @@ func (m statusCmp) View() string { } // Truncate message if it's longer than available width msg := m.info.Msg - availWidth := m.availableFooterMsgWidth() - 10 + availWidth := m.availableFooterMsgWidth(diagnostics) - 10 if len(msg) > availWidth && availWidth > 0 { msg = msg[:availWidth] + "..." } @@ -76,27 +78,81 @@ func (m statusCmp) View() string { } else { status += styles.Padded. Foreground(styles.Base). - Background(styles.LightGrey). - Width(m.availableFooterMsgWidth()). + Background(styles.BackgroundDim). + Width(m.availableFooterMsgWidth(diagnostics)). Render("") } + status += diagnostics status += m.model() - status += versionWidget return status } -func (m statusCmp) availableFooterMsgWidth() int { - // -2 to accommodate padding - return max(0, m.width-lipgloss.Width(helpWidget)-lipgloss.Width(versionWidget)-lipgloss.Width(m.model())) +func (m *statusCmp) projectDiagnostics() string { + errorDiagnostics := []protocol.Diagnostic{} + warnDiagnostics := []protocol.Diagnostic{} + hintDiagnostics := []protocol.Diagnostic{} + infoDiagnostics := []protocol.Diagnostic{} + for _, client := range m.lspClients { + for _, d := range client.GetDiagnostics() { + for _, diag := range d { + switch diag.Severity { + case protocol.SeverityError: + errorDiagnostics = append(errorDiagnostics, diag) + case protocol.SeverityWarning: + warnDiagnostics = append(warnDiagnostics, diag) + case protocol.SeverityHint: + hintDiagnostics = append(hintDiagnostics, diag) + case protocol.SeverityInformation: + infoDiagnostics = append(infoDiagnostics, diag) + } + } + } + } + + if len(errorDiagnostics) == 0 && len(warnDiagnostics) == 0 && len(hintDiagnostics) == 0 && len(infoDiagnostics) == 0 { + return "No diagnostics" + } + + diagnostics := []string{} + + if len(errorDiagnostics) > 0 { + errStr := lipgloss.NewStyle().Foreground(styles.Error).Render(fmt.Sprintf("%s %d", styles.ErrorIcon, len(errorDiagnostics))) + diagnostics = append(diagnostics, errStr) + } + if len(warnDiagnostics) > 0 { + warnStr := lipgloss.NewStyle().Foreground(styles.Warning).Render(fmt.Sprintf("%s %d", styles.WarningIcon, len(warnDiagnostics))) + diagnostics = append(diagnostics, warnStr) + } + if len(hintDiagnostics) > 0 { + hintStr := lipgloss.NewStyle().Foreground(styles.Text).Render(fmt.Sprintf("%s %d", styles.HintIcon, len(hintDiagnostics))) + diagnostics = append(diagnostics, hintStr) + } + if len(infoDiagnostics) > 0 { + infoStr := lipgloss.NewStyle().Foreground(styles.Peach).Render(fmt.Sprintf("%s %d", styles.InfoIcon, len(infoDiagnostics))) + diagnostics = append(diagnostics, infoStr) + } + + return strings.Join(diagnostics, " ") +} + +func (m statusCmp) availableFooterMsgWidth(diagnostics string) int { + return max(0, m.width-lipgloss.Width(helpWidget)-lipgloss.Width(m.model())-lipgloss.Width(diagnostics)) } func (m statusCmp) model() string { - model := models.SupportedModels[config.Get().Model.Coder] + cfg := config.Get() + + coder, ok := cfg.Agents[config.AgentCoder] + if !ok { + return "Unknown" + } + model := models.SupportedModels[coder.Model] return styles.Padded.Background(styles.Grey).Foreground(styles.Text).Render(model.Name) } -func NewStatusCmp() tea.Model { +func NewStatusCmp(lspClients map[string]*lsp.Client) tea.Model { return &statusCmp{ messageTTL: 10 * time.Second, + lspClients: lspClients, } } diff --git a/internal/tui/components/dialog/help.go b/internal/tui/components/dialog/help.go new file mode 100644 index 000000000..1d3c2b077 --- /dev/null +++ b/internal/tui/components/dialog/help.go @@ -0,0 +1,182 @@ +package dialog + +import ( + "strings" + + "github.com/charmbracelet/bubbles/key" + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" + "github.com/kujtimiihoxha/termai/internal/tui/styles" +) + +type helpCmp struct { + width int + height int + keys []key.Binding +} + +func (h *helpCmp) Init() tea.Cmd { + return nil +} + +func (h *helpCmp) SetBindings(k []key.Binding) { + h.keys = k +} + +func (h *helpCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case tea.WindowSizeMsg: + h.width = 80 + h.height = msg.Height + } + return h, nil +} + +func removeDuplicateBindings(bindings []key.Binding) []key.Binding { + seen := make(map[string]struct{}) + result := make([]key.Binding, 0, len(bindings)) + + // Process bindings in reverse order + for i := len(bindings) - 1; i >= 0; i-- { + b := bindings[i] + k := strings.Join(b.Keys(), " ") + if _, ok := seen[k]; ok { + // duplicate, skip + continue + } + seen[k] = struct{}{} + // Add to the beginning of result to maintain original order + result = append([]key.Binding{b}, result...) + } + + return result +} + +func (h *helpCmp) render() string { + helpKeyStyle := styles.Bold.Background(styles.Background).Foreground(styles.Forground).Padding(0, 1, 0, 0) + helpDescStyle := styles.Regular.Background(styles.Background).Foreground(styles.ForgroundMid) + // Compile list of bindings to render + bindings := removeDuplicateBindings(h.keys) + // Enumerate through each group of bindings, populating a series of + // pairs of columns, one for keys, one for descriptions + var ( + pairs []string + width int + rows = 12 - 2 + ) + for i := 0; i < len(bindings); i += rows { + var ( + keys []string + descs []string + ) + for j := i; j < min(i+rows, len(bindings)); j++ { + keys = append(keys, helpKeyStyle.Render(bindings[j].Help().Key)) + descs = append(descs, helpDescStyle.Render(bindings[j].Help().Desc)) + } + // Render pair of columns; beyond the first pair, render a three space + // left margin, in order to visually separate the pairs. + var cols []string + if len(pairs) > 0 { + cols = []string{styles.BaseStyle.Render(" ")} + } + + maxDescWidth := 0 + for _, desc := range descs { + if maxDescWidth < lipgloss.Width(desc) { + maxDescWidth = lipgloss.Width(desc) + } + } + for i := range descs { + remainingWidth := maxDescWidth - lipgloss.Width(descs[i]) + if remainingWidth > 0 { + descs[i] = descs[i] + styles.BaseStyle.Render(strings.Repeat(" ", remainingWidth)) + } + } + maxKeyWidth := 0 + for _, key := range keys { + if maxKeyWidth < lipgloss.Width(key) { + maxKeyWidth = lipgloss.Width(key) + } + } + for i := range keys { + remainingWidth := maxKeyWidth - lipgloss.Width(keys[i]) + if remainingWidth > 0 { + keys[i] = keys[i] + styles.BaseStyle.Render(strings.Repeat(" ", remainingWidth)) + } + } + + cols = append(cols, + strings.Join(keys, "\n"), + strings.Join(descs, "\n"), + ) + + pair := styles.BaseStyle.Render(lipgloss.JoinHorizontal(lipgloss.Top, cols...)) + // check whether it exceeds the maximum width avail (the width of the + // terminal, subtracting 2 for the borders). + width += lipgloss.Width(pair) + if width > h.width-2 { + break + } + pairs = append(pairs, pair) + } + + // https://github.com/charmbracelet/lipgloss/issues/209 + if len(pairs) > 1 { + prefix := pairs[:len(pairs)-1] + lastPair := pairs[len(pairs)-1] + prefix = append(prefix, lipgloss.Place( + lipgloss.Width(lastPair), // width + lipgloss.Height(prefix[0]), // height + lipgloss.Left, // x + lipgloss.Top, // y + lastPair, // content + lipgloss.WithWhitespaceBackground(styles.Background), // background + )) + content := styles.BaseStyle.Width(h.width).Render( + lipgloss.JoinHorizontal( + lipgloss.Top, + prefix..., + ), + ) + return content + } + // Join pairs of columns and enclose in a border + content := styles.BaseStyle.Width(h.width).Render( + lipgloss.JoinHorizontal( + lipgloss.Top, + pairs..., + ), + ) + return content +} + +func (h *helpCmp) View() string { + content := h.render() + header := styles.BaseStyle. + Bold(true). + Width(lipgloss.Width(content)). + Foreground(styles.PrimaryColor). + Render("Keyboard Shortcuts") + + return styles.BaseStyle.Padding(1). + Border(lipgloss.RoundedBorder()). + BorderForeground(styles.ForgroundDim). + Width(h.width). + BorderBackground(styles.Background). + Render( + lipgloss.JoinVertical(lipgloss.Center, + header, + styles.BaseStyle.Render(strings.Repeat(" ", lipgloss.Width(header))), + content, + ), + ) +} + +type HelpCmp interface { + tea.Model + SetBindings([]key.Binding) +} + +func NewHelpCmp() HelpCmp { + return &helpCmp{} +} diff --git a/internal/tui/components/dialog/permission.go b/internal/tui/components/dialog/permission.go index d147f89cd..9c55effde 100644 --- a/internal/tui/components/dialog/permission.go +++ b/internal/tui/components/dialog/permission.go @@ -12,12 +12,9 @@ import ( "github.com/kujtimiihoxha/termai/internal/diff" "github.com/kujtimiihoxha/termai/internal/llm/tools" "github.com/kujtimiihoxha/termai/internal/permission" - "github.com/kujtimiihoxha/termai/internal/tui/components/core" "github.com/kujtimiihoxha/termai/internal/tui/layout" "github.com/kujtimiihoxha/termai/internal/tui/styles" "github.com/kujtimiihoxha/termai/internal/tui/util" - - "github.com/charmbracelet/huh" ) type PermissionAction string @@ -35,69 +32,64 @@ type PermissionResponseMsg struct { Action PermissionAction } -// PermissionDialog interface for permission dialog component -type PermissionDialog interface { +// PermissionDialogCmp interface for permission dialog component +type PermissionDialogCmp interface { tea.Model - layout.Sizeable layout.Bindings + SetPermissions(permission permission.PermissionRequest) } -type keyMap struct { - ChangeFocus key.Binding +type permissionsMapping struct { + LeftRight key.Binding + EnterSpace key.Binding + Allow key.Binding + AllowSession key.Binding + Deny key.Binding + Tab key.Binding } -var keyMapValue = keyMap{ - ChangeFocus: key.NewBinding( +var permissionsKeys = permissionsMapping{ + LeftRight: key.NewBinding( + key.WithKeys("left", "right"), + key.WithHelp("←/→", "switch options"), + ), + EnterSpace: key.NewBinding( + key.WithKeys("enter", " "), + key.WithHelp("enter/space", "confirm"), + ), + Allow: key.NewBinding( + key.WithKeys("a"), + key.WithHelp("a", "allow"), + ), + AllowSession: key.NewBinding( + key.WithKeys("A"), + key.WithHelp("A", "allow for session"), + ), + Deny: key.NewBinding( + key.WithKeys("d"), + key.WithHelp("d", "deny"), + ), + Tab: key.NewBinding( key.WithKeys("tab"), - key.WithHelp("tab", "change focus"), + key.WithHelp("tab", "switch options"), ), } // permissionDialogCmp is the implementation of PermissionDialog type permissionDialogCmp struct { - form *huh.Form width int height int permission permission.PermissionRequest windowSize tea.WindowSizeMsg - r *glamour.TermRenderer contentViewPort viewport.Model - isViewportFocus bool - selectOption *huh.Select[string] -} + selectedOption int // 0: Allow, 1: Allow for session, 2: Deny -// formatDiff formats a diff string with colors for additions and deletions -func formatDiff(diffText string) string { - lines := strings.Split(diffText, "\n") - var formattedLines []string - - // Define styles for different line types - addStyle := lipgloss.NewStyle().Foreground(styles.Green) - removeStyle := lipgloss.NewStyle().Foreground(styles.Red) - headerStyle := lipgloss.NewStyle().Bold(true).Foreground(styles.Blue) - contextStyle := lipgloss.NewStyle().Foreground(styles.SubText0) - - // Process each line - for _, line := range lines { - if strings.HasPrefix(line, "+") { - formattedLines = append(formattedLines, addStyle.Render(line)) - } else if strings.HasPrefix(line, "-") { - formattedLines = append(formattedLines, removeStyle.Render(line)) - } else if strings.HasPrefix(line, "Changes:") || strings.HasPrefix(line, " ...") { - formattedLines = append(formattedLines, headerStyle.Render(line)) - } else if strings.HasPrefix(line, " ") { - formattedLines = append(formattedLines, contextStyle.Render(line)) - } else { - formattedLines = append(formattedLines, line) - } - } - - // Join all formatted lines - return strings.Join(formattedLines, "\n") + diffCache map[string]string + markdownCache map[string]string } func (p *permissionDialogCmp) Init() tea.Cmd { - return nil + return p.contentViewPort.Init() } func (p *permissionDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { @@ -106,373 +98,363 @@ func (p *permissionDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { switch msg := msg.(type) { case tea.WindowSizeMsg: p.windowSize = msg + p.SetSize() + p.markdownCache = make(map[string]string) + p.diffCache = make(map[string]string) case tea.KeyMsg: - if key.Matches(msg, keyMapValue.ChangeFocus) { - p.isViewportFocus = !p.isViewportFocus - if p.isViewportFocus { - p.selectOption.Blur() - // Add a visual indicator for focus change - cmds = append(cmds, tea.Batch( - util.ReportInfo("Viewing content - use arrow keys to scroll"), - )) - } else { - p.selectOption.Focus() - // Add a visual indicator for focus change - cmds = append(cmds, tea.Batch( - util.CmdHandler(util.ReportInfo("Select an action")), - )) - } - return p, tea.Batch(cmds...) - } - } - - if p.isViewportFocus { - viewPort, cmd := p.contentViewPort.Update(msg) - p.contentViewPort = viewPort - cmds = append(cmds, cmd) - } else { - form, cmd := p.form.Update(msg) - if f, ok := form.(*huh.Form); ok { - p.form = f + switch { + case key.Matches(msg, permissionsKeys.LeftRight) || key.Matches(msg, permissionsKeys.Tab): + // Change selected option + p.selectedOption = (p.selectedOption + 1) % 3 + return p, nil + case key.Matches(msg, permissionsKeys.EnterSpace): + // Select current option + return p, p.selectCurrentOption() + case key.Matches(msg, permissionsKeys.Allow): + // Select Allow + return p, util.CmdHandler(PermissionResponseMsg{Action: PermissionAllow, Permission: p.permission}) + case key.Matches(msg, permissionsKeys.AllowSession): + // Select Allow for session + return p, util.CmdHandler(PermissionResponseMsg{Action: PermissionAllowForSession, Permission: p.permission}) + case key.Matches(msg, permissionsKeys.Deny): + // Select Deny + return p, util.CmdHandler(PermissionResponseMsg{Action: PermissionDeny, Permission: p.permission}) + default: + // Pass other keys to viewport + viewPort, cmd := p.contentViewPort.Update(msg) + p.contentViewPort = viewPort cmds = append(cmds, cmd) } - - if p.form.State == huh.StateCompleted { - // Get the selected action - action := p.form.GetString("action") - - // Close the dialog and return the response - return p, tea.Batch( - util.CmdHandler(core.DialogCloseMsg{}), - util.CmdHandler(PermissionResponseMsg{Action: PermissionAction(action), Permission: p.permission}), - ) - } } + return p, tea.Batch(cmds...) } -func (p *permissionDialogCmp) render() string { - keyStyle := lipgloss.NewStyle().Bold(true).Foreground(styles.Rosewater) - valueStyle := lipgloss.NewStyle().Foreground(styles.Peach) +func (p *permissionDialogCmp) selectCurrentOption() tea.Cmd { + var action PermissionAction - form := p.form.View() - - headerParts := []string{ - lipgloss.JoinHorizontal(lipgloss.Left, keyStyle.Render("Tool:"), " ", valueStyle.Render(p.permission.ToolName)), - " ", - lipgloss.JoinHorizontal(lipgloss.Left, keyStyle.Render("Path:"), " ", valueStyle.Render(p.permission.Path)), - " ", + switch p.selectedOption { + case 0: + action = PermissionAllow + case 1: + action = PermissionAllowForSession + case 2: + action = PermissionDeny } - // Create the header content first so it can be used in all cases - headerContent := lipgloss.NewStyle().Padding(0, 1).Render(lipgloss.JoinVertical(lipgloss.Left, headerParts...)) - - r, _ := glamour.NewTermRenderer( - glamour.WithStyles(styles.CatppuccinMarkdownStyle()), - glamour.WithWordWrap(p.width-10), - glamour.WithEmoji(), - ) - - // Handle different tool types - switch p.permission.ToolName { - case tools.BashToolName: - pr := p.permission.Params.(tools.BashPermissionsParams) - headerParts = append(headerParts, keyStyle.Render("Command:")) - content := fmt.Sprintf("```bash\n%s\n```", pr.Command) - - renderedContent, _ := r.Render(content) - p.contentViewPort.Width = p.width - 2 - 2 - - // Calculate content height dynamically based on content - contentLines := len(strings.Split(renderedContent, "\n")) - // Set a reasonable min/max for the viewport height - minContentHeight := 3 - maxContentHeight := p.height - lipgloss.Height(headerContent) - lipgloss.Height(form) - 2 - 2 - 1 - - // Add some padding to the content lines - contentHeight := contentLines + 2 - contentHeight = max(contentHeight, minContentHeight) - contentHeight = min(contentHeight, maxContentHeight) - p.contentViewPort.Height = contentHeight - - p.contentViewPort.SetContent(renderedContent) + return util.CmdHandler(PermissionResponseMsg{Action: action, Permission: p.permission}) +} - // Style the viewport - var contentBorder lipgloss.Border - var borderColor lipgloss.TerminalColor +func (p *permissionDialogCmp) renderButtons() string { + allowStyle := styles.BaseStyle + allowSessionStyle := styles.BaseStyle + denyStyle := styles.BaseStyle + spacerStyle := styles.BaseStyle.Background(styles.Background) + + // Style the selected button + switch p.selectedOption { + case 0: + allowStyle = allowStyle.Background(styles.PrimaryColor).Foreground(styles.Background) + allowSessionStyle = allowSessionStyle.Background(styles.Background).Foreground(styles.PrimaryColor) + denyStyle = denyStyle.Background(styles.Background).Foreground(styles.PrimaryColor) + case 1: + allowStyle = allowStyle.Background(styles.Background).Foreground(styles.PrimaryColor) + allowSessionStyle = allowSessionStyle.Background(styles.PrimaryColor).Foreground(styles.Background) + denyStyle = denyStyle.Background(styles.Background).Foreground(styles.PrimaryColor) + case 2: + allowStyle = allowStyle.Background(styles.Background).Foreground(styles.PrimaryColor) + allowSessionStyle = allowSessionStyle.Background(styles.Background).Foreground(styles.PrimaryColor) + denyStyle = denyStyle.Background(styles.PrimaryColor).Foreground(styles.Background) + } - if p.isViewportFocus { - contentBorder = lipgloss.DoubleBorder() - borderColor = styles.Blue - } else { - contentBorder = lipgloss.RoundedBorder() - borderColor = styles.Flamingo - } + allowButton := allowStyle.Padding(0, 1).Render("Allow (a)") + allowSessionButton := allowSessionStyle.Padding(0, 1).Render("Allow for session (A)") + denyButton := denyStyle.Padding(0, 1).Render("Deny (d)") + + content := lipgloss.JoinHorizontal( + lipgloss.Left, + allowButton, + spacerStyle.Render(" "), + allowSessionButton, + spacerStyle.Render(" "), + denyButton, + spacerStyle.Render(" "), + ) - contentStyle := lipgloss.NewStyle(). - MarginTop(1). - Padding(0, 1). - Border(contentBorder). - BorderForeground(borderColor) + remainingWidth := p.width - lipgloss.Width(content) + if remainingWidth > 0 { + content = spacerStyle.Render(strings.Repeat(" ", remainingWidth)) + content + } + return content +} - if p.isViewportFocus { - contentStyle = contentStyle.BorderBackground(styles.Surface0) - } +func (p *permissionDialogCmp) renderHeader() string { + toolKey := styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render("Tool") + toolValue := styles.BaseStyle. + Foreground(styles.Forground). + Width(p.width - lipgloss.Width(toolKey)). + Render(fmt.Sprintf(": %s", p.permission.ToolName)) - contentFinal := contentStyle.Render(p.contentViewPort.View()) + pathKey := styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render("Path") + pathValue := styles.BaseStyle. + Foreground(styles.Forground). + Width(p.width - lipgloss.Width(pathKey)). + Render(fmt.Sprintf(": %s", p.permission.Path)) - return lipgloss.JoinVertical( - lipgloss.Top, - headerContent, - contentFinal, - form, - ) + headerParts := []string{ + lipgloss.JoinHorizontal( + lipgloss.Left, + toolKey, + toolValue, + ), + styles.BaseStyle.Render(strings.Repeat(" ", p.width)), + lipgloss.JoinHorizontal( + lipgloss.Left, + pathKey, + pathValue, + ), + styles.BaseStyle.Render(strings.Repeat(" ", p.width)), + } + // Add tool-specific header information + switch p.permission.ToolName { + case tools.BashToolName: + headerParts = append(headerParts, styles.BaseStyle.Foreground(styles.ForgroundDim).Width(p.width).Bold(true).Render("Command")) case tools.EditToolName: - pr := p.permission.Params.(tools.EditPermissionsParams) - headerParts = append(headerParts, keyStyle.Render("Update")) - // Recreate header content with the updated headerParts - headerContent = lipgloss.NewStyle().Padding(0, 1).Render(lipgloss.JoinVertical(lipgloss.Left, headerParts...)) - - // Format the diff with colors - - // Set up viewport for the diff content - p.contentViewPort.Width = p.width - 2 - 2 - - // Calculate content height dynamically based on window size - maxContentHeight := p.height - lipgloss.Height(headerContent) - lipgloss.Height(form) - 2 - 2 - 1 - p.contentViewPort.Height = maxContentHeight - diff, err := diff.FormatDiff(pr.Diff, diff.WithTotalWidth(p.contentViewPort.Width)) - if err != nil { - diff = fmt.Sprintf("Error formatting diff: %v", err) - } - p.contentViewPort.SetContent(diff) + headerParts = append(headerParts, styles.BaseStyle.Foreground(styles.ForgroundDim).Width(p.width).Bold(true).Render("Diff")) + case tools.WriteToolName: + headerParts = append(headerParts, styles.BaseStyle.Foreground(styles.ForgroundDim).Width(p.width).Bold(true).Render("Diff")) + case tools.FetchToolName: + headerParts = append(headerParts, styles.BaseStyle.Foreground(styles.ForgroundDim).Width(p.width).Bold(true).Render("URL")) + } - // Style the viewport - var contentBorder lipgloss.Border - var borderColor lipgloss.TerminalColor + return lipgloss.NewStyle().Render(lipgloss.JoinVertical(lipgloss.Left, headerParts...)) +} - if p.isViewportFocus { - contentBorder = lipgloss.DoubleBorder() - borderColor = styles.Blue - } else { - contentBorder = lipgloss.RoundedBorder() - borderColor = styles.Flamingo - } +func (p *permissionDialogCmp) renderBashContent() string { + if pr, ok := p.permission.Params.(tools.BashPermissionsParams); ok { + content := fmt.Sprintf("```bash\n%s\n```", pr.Command) - contentStyle := lipgloss.NewStyle(). - MarginTop(1). - Padding(0, 1). - Border(contentBorder). - BorderForeground(borderColor) + // Use the cache for markdown rendering + renderedContent := p.GetOrSetMarkdown(p.permission.ID, func() (string, error) { + r, _ := glamour.NewTermRenderer( + glamour.WithStyles(styles.MarkdownTheme(true)), + glamour.WithWordWrap(p.width-10), + ) + s, err := r.Render(content) + return styles.ForceReplaceBackgroundWithLipgloss(s, styles.Background), err + }) + + finalContent := styles.BaseStyle. + Width(p.contentViewPort.Width). + Render(renderedContent) + p.contentViewPort.SetContent(finalContent) + return p.styleViewport() + } + return "" +} - if p.isViewportFocus { - contentStyle = contentStyle.BorderBackground(styles.Surface0) - } +func (p *permissionDialogCmp) renderEditContent() string { + if pr, ok := p.permission.Params.(tools.EditPermissionsParams); ok { + diff := p.GetOrSetDiff(p.permission.ID, func() (string, error) { + return diff.FormatDiff(pr.Diff, diff.WithTotalWidth(p.contentViewPort.Width)) + }) - contentFinal := contentStyle.Render(p.contentViewPort.View()) + p.contentViewPort.SetContent(diff) + return p.styleViewport() + } + return "" +} - return lipgloss.JoinVertical( - lipgloss.Top, - headerContent, - contentFinal, - form, - ) +func (p *permissionDialogCmp) renderWriteContent() string { + if pr, ok := p.permission.Params.(tools.WritePermissionsParams); ok { + // Use the cache for diff rendering + diff := p.GetOrSetDiff(p.permission.ID, func() (string, error) { + return diff.FormatDiff(pr.Diff, diff.WithTotalWidth(p.contentViewPort.Width)) + }) - case tools.WriteToolName: - pr := p.permission.Params.(tools.WritePermissionsParams) - headerParts = append(headerParts, keyStyle.Render("Content")) - // Recreate header content with the updated headerParts - headerContent = lipgloss.NewStyle().Padding(0, 1).Render(lipgloss.JoinVertical(lipgloss.Left, headerParts...)) - - // Set up viewport for the content - p.contentViewPort.Width = p.width - 2 - 2 - - // Calculate content height dynamically based on window size - maxContentHeight := p.height - lipgloss.Height(headerContent) - lipgloss.Height(form) - 2 - 2 - 1 - p.contentViewPort.Height = maxContentHeight - diff, err := diff.FormatDiff(pr.Diff, diff.WithTotalWidth(p.contentViewPort.Width)) - if err != nil { - diff = fmt.Sprintf("Error formatting diff: %v", err) - } p.contentViewPort.SetContent(diff) + return p.styleViewport() + } + return "" +} - // Style the viewport - var contentBorder lipgloss.Border - var borderColor lipgloss.TerminalColor +func (p *permissionDialogCmp) renderFetchContent() string { + if pr, ok := p.permission.Params.(tools.FetchPermissionsParams); ok { + content := fmt.Sprintf("```bash\n%s\n```", pr.URL) - if p.isViewportFocus { - contentBorder = lipgloss.DoubleBorder() - borderColor = styles.Blue - } else { - contentBorder = lipgloss.RoundedBorder() - borderColor = styles.Flamingo - } - - contentStyle := lipgloss.NewStyle(). - MarginTop(1). - Padding(0, 1). - Border(contentBorder). - BorderForeground(borderColor) + // Use the cache for markdown rendering + renderedContent := p.GetOrSetMarkdown(p.permission.ID, func() (string, error) { + r, _ := glamour.NewTermRenderer( + glamour.WithStyles(styles.MarkdownTheme(true)), + glamour.WithWordWrap(p.width-10), + ) + s, err := r.Render(content) + return styles.ForceReplaceBackgroundWithLipgloss(s, styles.Background), err + }) - if p.isViewportFocus { - contentStyle = contentStyle.BorderBackground(styles.Surface0) - } + p.contentViewPort.SetContent(renderedContent) + return p.styleViewport() + } + return "" +} - contentFinal := contentStyle.Render(p.contentViewPort.View()) +func (p *permissionDialogCmp) renderDefaultContent() string { + content := p.permission.Description - return lipgloss.JoinVertical( - lipgloss.Top, - headerContent, - contentFinal, - form, + // Use the cache for markdown rendering + renderedContent := p.GetOrSetMarkdown(p.permission.ID, func() (string, error) { + r, _ := glamour.NewTermRenderer( + glamour.WithStyles(styles.CatppuccinMarkdownStyle()), + glamour.WithWordWrap(p.width-10), ) + s, err := r.Render(content) + return styles.ForceReplaceBackgroundWithLipgloss(s, styles.Background), err + }) - case tools.FetchToolName: - pr := p.permission.Params.(tools.FetchPermissionsParams) - headerParts = append(headerParts, keyStyle.Render("URL: "+pr.URL)) - content := p.permission.Description + p.contentViewPort.SetContent(renderedContent) - renderedContent, _ := r.Render(content) - p.contentViewPort.Width = p.width - 2 - 2 - p.contentViewPort.Height = p.height - lipgloss.Height(headerContent) - lipgloss.Height(form) - 2 - 2 - 1 - p.contentViewPort.SetContent(renderedContent) + if renderedContent == "" { + return "" + } - // Style the viewport - contentStyle := lipgloss.NewStyle(). - MarginTop(1). - Padding(0, 1). - Border(lipgloss.RoundedBorder()). - BorderForeground(styles.Flamingo) + return p.styleViewport() +} - contentFinal := contentStyle.Render(p.contentViewPort.View()) - if renderedContent == "" { - contentFinal = "" - } +func (p *permissionDialogCmp) styleViewport() string { + contentStyle := lipgloss.NewStyle(). + Background(styles.Background) - return lipgloss.JoinVertical( - lipgloss.Top, - headerContent, - contentFinal, - form, - ) + return contentStyle.Render(p.contentViewPort.View()) +} +func (p *permissionDialogCmp) render() string { + title := styles.BaseStyle. + Bold(true). + Width(p.width - 4). + Foreground(styles.PrimaryColor). + Render("Permission Required") + // Render header + headerContent := p.renderHeader() + // Render buttons + buttons := p.renderButtons() + + // Calculate content height dynamically based on window size + p.contentViewPort.Height = p.height - lipgloss.Height(headerContent) - lipgloss.Height(buttons) - 2 - lipgloss.Height(title) + p.contentViewPort.Width = p.width - 4 + + // Render content based on tool type + var contentFinal string + switch p.permission.ToolName { + case tools.BashToolName: + contentFinal = p.renderBashContent() + case tools.EditToolName: + contentFinal = p.renderEditContent() + case tools.WriteToolName: + contentFinal = p.renderWriteContent() + case tools.FetchToolName: + contentFinal = p.renderFetchContent() default: - content := p.permission.Description - - renderedContent, _ := r.Render(content) - p.contentViewPort.Width = p.width - 2 - 2 - p.contentViewPort.Height = p.height - lipgloss.Height(headerContent) - lipgloss.Height(form) - 2 - 2 - 1 - p.contentViewPort.SetContent(renderedContent) - - // Style the viewport - contentStyle := lipgloss.NewStyle(). - MarginTop(1). - Padding(0, 1). - Border(lipgloss.RoundedBorder()). - BorderForeground(styles.Flamingo) + contentFinal = p.renderDefaultContent() + } - contentFinal := contentStyle.Render(p.contentViewPort.View()) - if renderedContent == "" { - contentFinal = "" - } + content := lipgloss.JoinVertical( + lipgloss.Top, + title, + styles.BaseStyle.Render(strings.Repeat(" ", lipgloss.Width(title))), + headerContent, + contentFinal, + buttons, + ) - return lipgloss.JoinVertical( - lipgloss.Top, - headerContent, - contentFinal, - form, + return styles.BaseStyle. + Padding(1, 0, 0, 1). + Border(lipgloss.RoundedBorder()). + BorderBackground(styles.Background). + BorderForeground(styles.ForgroundDim). + Width(p.width). + Height(p.height). + Render( + content, ) - } } func (p *permissionDialogCmp) View() string { return p.render() } -func (p *permissionDialogCmp) GetSize() (int, int) { - return p.width, p.height +func (p *permissionDialogCmp) BindingKeys() []key.Binding { + return layout.KeyMapToSlice(helpKeys) } -func (p *permissionDialogCmp) SetSize(width int, height int) { - p.width = width - p.height = height - p.form = p.form.WithWidth(width) +func (p *permissionDialogCmp) SetSize() { + if p.permission.ID == "" { + return + } + switch p.permission.ToolName { + case tools.BashToolName: + p.width = int(float64(p.windowSize.Width) * 0.4) + p.height = int(float64(p.windowSize.Height) * 0.3) + case tools.EditToolName: + p.width = int(float64(p.windowSize.Width) * 0.8) + p.height = int(float64(p.windowSize.Height) * 0.8) + case tools.WriteToolName: + p.width = int(float64(p.windowSize.Width) * 0.8) + p.height = int(float64(p.windowSize.Height) * 0.8) + case tools.FetchToolName: + p.width = int(float64(p.windowSize.Width) * 0.4) + p.height = int(float64(p.windowSize.Height) * 0.3) + default: + p.width = int(float64(p.windowSize.Width) * 0.7) + p.height = int(float64(p.windowSize.Height) * 0.5) + } } -func (p *permissionDialogCmp) BindingKeys() []key.Binding { - return p.form.KeyBinds() +func (p *permissionDialogCmp) SetPermissions(permission permission.PermissionRequest) { + p.permission = permission + p.SetSize() } -func newPermissionDialogCmp(permission permission.PermissionRequest) PermissionDialog { - // Create a note field for displaying the content +// Helper to get or set cached diff content +func (c *permissionDialogCmp) GetOrSetDiff(key string, generator func() (string, error)) string { + if cached, ok := c.diffCache[key]; ok { + return cached + } - // Create select field for the permission options - selectOption := huh.NewSelect[string](). - Key("action"). - Options( - huh.NewOption("Allow", string(PermissionAllow)), - huh.NewOption("Allow for this session", string(PermissionAllowForSession)), - huh.NewOption("Deny", string(PermissionDeny)), - ). - Title("Select an action") + content, err := generator() + if err != nil { + return fmt.Sprintf("Error formatting diff: %v", err) + } - // Apply theme - theme := styles.HuhTheme() + c.diffCache[key] = content - // Setup form width and height - form := huh.NewForm(huh.NewGroup(selectOption)). - WithShowHelp(false). - WithTheme(theme). - WithShowErrors(false) + return content +} - // Focus the form for immediate interaction - selectOption.Focus() +// Helper to get or set cached markdown content +func (c *permissionDialogCmp) GetOrSetMarkdown(key string, generator func() (string, error)) string { + if cached, ok := c.markdownCache[key]; ok { + return cached + } - return &permissionDialogCmp{ - permission: permission, - form: form, - selectOption: selectOption, + content, err := generator() + if err != nil { + return fmt.Sprintf("Error rendering markdown: %v", err) } -} -// NewPermissionDialogCmd creates a new permission dialog command -func NewPermissionDialogCmd(permission permission.PermissionRequest) tea.Cmd { - permDialog := newPermissionDialogCmp(permission) - - // Create the dialog layout - dialogPane := layout.NewSinglePane( - permDialog.(*permissionDialogCmp), - layout.WithSinglePaneBordered(true), - layout.WithSinglePaneFocusable(true), - layout.WithSinglePaneActiveColor(styles.Warning), - layout.WithSinglePaneBorderText(map[layout.BorderPosition]string{ - layout.TopMiddleBorder: " Permission Required ", - }), - ) + c.markdownCache[key] = content - // Focus the dialog - dialogPane.Focus() - widthRatio := 0.7 - heightRatio := 0.6 - minWidth := 100 - minHeight := 30 + return content +} - // Make the dialog size more appropriate for different tools - switch permission.ToolName { - case tools.BashToolName: - // For bash commands, use a more compact dialog - widthRatio = 0.7 - heightRatio = 0.4 // Reduced from 0.5 - minWidth = 100 - minHeight = 20 // Reduced from 30 +func NewPermissionDialogCmp() PermissionDialogCmp { + // Create viewport for content + contentViewport := viewport.New(0, 0) + + return &permissionDialogCmp{ + contentViewPort: contentViewport, + selectedOption: 0, // Default to "Allow" + diffCache: make(map[string]string), + markdownCache: make(map[string]string), } - // Return the dialog command - return util.CmdHandler(core.DialogMsg{ - Content: dialogPane, - WidthRatio: widthRatio, - HeightRatio: heightRatio, - MinWidth: minWidth, - MinHeight: minHeight, - }) } diff --git a/internal/tui/components/dialog/quit.go b/internal/tui/components/dialog/quit.go index 60c1fc0d2..10d9ba8a2 100644 --- a/internal/tui/components/dialog/quit.go +++ b/internal/tui/components/dialog/quit.go @@ -1,28 +1,58 @@ package dialog import ( + "strings" + "github.com/charmbracelet/bubbles/key" tea "github.com/charmbracelet/bubbletea" - "github.com/kujtimiihoxha/termai/internal/tui/components/core" + "github.com/charmbracelet/lipgloss" "github.com/kujtimiihoxha/termai/internal/tui/layout" "github.com/kujtimiihoxha/termai/internal/tui/styles" "github.com/kujtimiihoxha/termai/internal/tui/util" - - "github.com/charmbracelet/huh" ) const question = "Are you sure you want to quit?" +type CloseQuitMsg struct{} + type QuitDialog interface { tea.Model - layout.Sizeable layout.Bindings } type quitDialogCmp struct { - form *huh.Form - width int - height int + selectedNo bool +} + +type helpMapping struct { + LeftRight key.Binding + EnterSpace key.Binding + Yes key.Binding + No key.Binding + Tab key.Binding +} + +var helpKeys = helpMapping{ + LeftRight: key.NewBinding( + key.WithKeys("left", "right"), + key.WithHelp("←/→", "switch options"), + ), + EnterSpace: key.NewBinding( + key.WithKeys("enter", " "), + key.WithHelp("enter/space", "confirm"), + ), + Yes: key.NewBinding( + key.WithKeys("y", "Y"), + key.WithHelp("y/Y", "yes"), + ), + No: key.NewBinding( + key.WithKeys("n", "N"), + key.WithHelp("n/N", "no"), + ), + Tab: key.NewBinding( + key.WithKeys("tab"), + key.WithHelp("tab", "switch options"), + ), } func (q *quitDialogCmp) Init() tea.Cmd { @@ -30,77 +60,73 @@ func (q *quitDialogCmp) Init() tea.Cmd { } func (q *quitDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { - var cmds []tea.Cmd - form, cmd := q.form.Update(msg) - if f, ok := form.(*huh.Form); ok { - q.form = f - cmds = append(cmds, cmd) - } - - if q.form.State == huh.StateCompleted { - v := q.form.GetBool("quit") - if v { + switch msg := msg.(type) { + case tea.KeyMsg: + switch { + case key.Matches(msg, helpKeys.LeftRight) || key.Matches(msg, helpKeys.Tab): + q.selectedNo = !q.selectedNo + return q, nil + case key.Matches(msg, helpKeys.EnterSpace): + if !q.selectedNo { + return q, tea.Quit + } + return q, util.CmdHandler(CloseQuitMsg{}) + case key.Matches(msg, helpKeys.Yes): return q, tea.Quit + case key.Matches(msg, helpKeys.No): + return q, util.CmdHandler(CloseQuitMsg{}) } - cmds = append(cmds, util.CmdHandler(core.DialogCloseMsg{})) } - - return q, tea.Batch(cmds...) + return q, nil } func (q *quitDialogCmp) View() string { - return q.form.View() -} + yesStyle := styles.BaseStyle + noStyle := styles.BaseStyle + spacerStyle := styles.BaseStyle.Background(styles.Background) + + if q.selectedNo { + noStyle = noStyle.Background(styles.PrimaryColor).Foreground(styles.Background) + yesStyle = yesStyle.Background(styles.Background).Foreground(styles.PrimaryColor) + } else { + yesStyle = yesStyle.Background(styles.PrimaryColor).Foreground(styles.Background) + noStyle = noStyle.Background(styles.Background).Foreground(styles.PrimaryColor) + } -func (q *quitDialogCmp) GetSize() (int, int) { - return q.width, q.height -} + yesButton := yesStyle.Padding(0, 1).Render("Yes") + noButton := noStyle.Padding(0, 1).Render("No") + + buttons := lipgloss.JoinHorizontal(lipgloss.Left, yesButton, spacerStyle.Render(" "), noButton) + + width := lipgloss.Width(question) + remainingWidth := width - lipgloss.Width(buttons) + if remainingWidth > 0 { + buttons = spacerStyle.Render(strings.Repeat(" ", remainingWidth)) + buttons + } -func (q *quitDialogCmp) SetSize(width int, height int) { - q.width = width - q.height = height - q.form = q.form.WithWidth(width).WithHeight(height) + content := styles.BaseStyle.Render( + lipgloss.JoinVertical( + lipgloss.Center, + question, + "", + buttons, + ), + ) + + return styles.BaseStyle.Padding(1, 2). + Border(lipgloss.RoundedBorder()). + BorderBackground(styles.Background). + BorderForeground(styles.ForgroundDim). + Width(lipgloss.Width(content) + 4). + Render(content) } func (q *quitDialogCmp) BindingKeys() []key.Binding { - return q.form.KeyBinds() + return layout.KeyMapToSlice(helpKeys) } -func newQuitDialogCmp() QuitDialog { - confirm := huh.NewConfirm(). - Title(question). - Affirmative("Yes!"). - Key("quit"). - Negative("No.") - - theme := styles.HuhTheme() - theme.Focused.FocusedButton = theme.Focused.FocusedButton.Background(styles.Warning) - theme.Blurred.FocusedButton = theme.Blurred.FocusedButton.Background(styles.Warning) - form := huh.NewForm(huh.NewGroup(confirm)). - WithShowHelp(false). - WithWidth(0). - WithHeight(0). - WithTheme(theme). - WithShowErrors(false) - confirm.Focus() +func NewQuitCmp() QuitDialog { return &quitDialogCmp{ - form: form, + selectedNo: true, } } - -func NewQuitDialogCmd() tea.Cmd { - content := layout.NewSinglePane( - newQuitDialogCmp().(*quitDialogCmp), - layout.WithSinglePaneBordered(true), - layout.WithSinglePaneFocusable(true), - layout.WithSinglePaneActiveColor(styles.Warning), - ) - content.Focus() - return util.CmdHandler(core.DialogMsg{ - Content: content, - WidthRatio: 0.2, - HeightRatio: 0.1, - MinWidth: 40, - MinHeight: 5, - }) -} diff --git a/internal/tui/components/logs/details.go b/internal/tui/components/logs/details.go index dbace5508..18eb1a526 100644 --- a/internal/tui/components/logs/details.go +++ b/internal/tui/components/logs/details.go @@ -16,10 +16,8 @@ import ( type DetailComponent interface { tea.Model - layout.Focusable layout.Sizeable layout.Bindings - layout.Bordered } type detailCmp struct { diff --git a/internal/tui/components/logs/table.go b/internal/tui/components/logs/table.go index 9500059b1..6e8eb58b1 100644 --- a/internal/tui/components/logs/table.go +++ b/internal/tui/components/logs/table.go @@ -16,22 +16,14 @@ import ( type TableComponent interface { tea.Model - layout.Focusable layout.Sizeable layout.Bindings - layout.Bordered } type tableCmp struct { table table.Model } -func (i *tableCmp) BorderText() map[layout.BorderPosition]string { - return map[layout.BorderPosition]string{ - layout.TopLeftBorder: "Logs", - } -} - type selectedLogMsg logging.LogMessage func (i *tableCmp) Init() tea.Cmd { @@ -74,20 +66,6 @@ func (i *tableCmp) View() string { return i.table.View() } -func (i *tableCmp) Blur() tea.Cmd { - i.table.Blur() - return nil -} - -func (i *tableCmp) Focus() tea.Cmd { - i.table.Focus() - return nil -} - -func (i *tableCmp) IsFocused() bool { - return i.table.Focused() -} - func (i *tableCmp) GetSize() (int, int) { return i.table.Width(), i.table.Height() } diff --git a/internal/tui/components/repl/editor.go b/internal/tui/components/repl/editor.go deleted file mode 100644 index b659775e0..000000000 --- a/internal/tui/components/repl/editor.go +++ /dev/null @@ -1,201 +0,0 @@ -package repl - -import ( - "strings" - - "github.com/charmbracelet/bubbles/key" - tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/app" - "github.com/kujtimiihoxha/termai/internal/tui/layout" - "github.com/kujtimiihoxha/termai/internal/tui/styles" - "github.com/kujtimiihoxha/termai/internal/tui/util" - "github.com/kujtimiihoxha/vimtea" - "golang.org/x/net/context" -) - -type EditorCmp interface { - tea.Model - layout.Focusable - layout.Sizeable - layout.Bordered - layout.Bindings -} - -type editorCmp struct { - app *app.App - editor vimtea.Editor - editorMode vimtea.EditorMode - sessionID string - focused bool - width int - height int - cancelMessage context.CancelFunc -} - -type editorKeyMap struct { - SendMessage key.Binding - SendMessageI key.Binding - CancelMessage key.Binding - InsertMode key.Binding - NormaMode key.Binding - VisualMode key.Binding - VisualLineMode key.Binding -} - -var editorKeyMapValue = editorKeyMap{ - SendMessage: key.NewBinding( - key.WithKeys("enter"), - key.WithHelp("enter", "send message normal mode"), - ), - SendMessageI: key.NewBinding( - key.WithKeys("ctrl+s"), - key.WithHelp("ctrl+s", "send message insert mode"), - ), - CancelMessage: key.NewBinding( - key.WithKeys("ctrl+x"), - key.WithHelp("ctrl+x", "cancel current message"), - ), - InsertMode: key.NewBinding( - key.WithKeys("i"), - key.WithHelp("i", "insert mode"), - ), - NormaMode: key.NewBinding( - key.WithKeys("esc"), - key.WithHelp("esc", "normal mode"), - ), - VisualMode: key.NewBinding( - key.WithKeys("v"), - key.WithHelp("v", "visual mode"), - ), - VisualLineMode: key.NewBinding( - key.WithKeys("V"), - key.WithHelp("V", "visual line mode"), - ), -} - -func (m *editorCmp) Init() tea.Cmd { - return m.editor.Init() -} - -func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { - switch msg := msg.(type) { - case vimtea.EditorModeMsg: - m.editorMode = msg.Mode - case SelectedSessionMsg: - if msg.SessionID != m.sessionID { - m.sessionID = msg.SessionID - } - } - if m.IsFocused() { - switch msg := msg.(type) { - case tea.KeyMsg: - switch { - case key.Matches(msg, editorKeyMapValue.SendMessage): - if m.editorMode == vimtea.ModeNormal { - return m, m.Send() - } - case key.Matches(msg, editorKeyMapValue.SendMessageI): - if m.editorMode == vimtea.ModeInsert { - return m, m.Send() - } - case key.Matches(msg, editorKeyMapValue.CancelMessage): - return m, m.Cancel() - } - } - u, cmd := m.editor.Update(msg) - m.editor = u.(vimtea.Editor) - return m, cmd - } - return m, nil -} - -func (m *editorCmp) Blur() tea.Cmd { - m.focused = false - return nil -} - -func (m *editorCmp) BorderText() map[layout.BorderPosition]string { - title := "New Message" - if m.focused { - title = lipgloss.NewStyle().Foreground(styles.Primary).Render(title) - } - return map[layout.BorderPosition]string{ - layout.BottomLeftBorder: title, - } -} - -func (m *editorCmp) Focus() tea.Cmd { - m.focused = true - return m.editor.Tick() -} - -func (m *editorCmp) GetSize() (int, int) { - return m.width, m.height -} - -func (m *editorCmp) IsFocused() bool { - return m.focused -} - -func (m *editorCmp) SetSize(width int, height int) { - m.width = width - m.height = height - m.editor.SetSize(width, height) -} - -func (m *editorCmp) Cancel() tea.Cmd { - if m.cancelMessage == nil { - return util.ReportWarn("No message to cancel") - } - - m.cancelMessage() - m.cancelMessage = nil - return util.ReportWarn("Message cancelled") -} - -func (m *editorCmp) Send() tea.Cmd { - if m.cancelMessage != nil { - return util.ReportWarn("Assistant is still working on the previous message") - } - - messages, err := m.app.Messages.List(context.Background(), m.sessionID) - if err != nil { - return util.ReportError(err) - } - if hasUnfinishedMessages(messages) { - return util.ReportWarn("Assistant is still working on the previous message") - } - - content := strings.Join(m.editor.GetBuffer().Lines(), "\n") - if len(content) == 0 { - return util.ReportWarn("Message is empty") - } - ctx, cancel := context.WithCancel(context.Background()) - m.cancelMessage = cancel - go func() { - defer cancel() - m.app.CoderAgent.Generate(ctx, m.sessionID, content) - m.cancelMessage = nil - }() - - return m.editor.Reset() -} - -func (m *editorCmp) View() string { - return m.editor.View() -} - -func (m *editorCmp) BindingKeys() []key.Binding { - return layout.KeyMapToSlice(editorKeyMapValue) -} - -func NewEditorCmp(app *app.App) EditorCmp { - editor := vimtea.NewEditor( - vimtea.WithFileName("message.md"), - ) - return &editorCmp{ - app: app, - editor: editor, - } -} diff --git a/internal/tui/components/repl/messages.go b/internal/tui/components/repl/messages.go deleted file mode 100644 index 260be220e..000000000 --- a/internal/tui/components/repl/messages.go +++ /dev/null @@ -1,513 +0,0 @@ -package repl - -import ( - "context" - "encoding/json" - "fmt" - "sort" - "strings" - - "github.com/charmbracelet/bubbles/key" - "github.com/charmbracelet/bubbles/viewport" - tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/glamour" - "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/app" - "github.com/kujtimiihoxha/termai/internal/llm/agent" - "github.com/kujtimiihoxha/termai/internal/lsp/protocol" - "github.com/kujtimiihoxha/termai/internal/message" - "github.com/kujtimiihoxha/termai/internal/pubsub" - "github.com/kujtimiihoxha/termai/internal/session" - "github.com/kujtimiihoxha/termai/internal/tui/layout" - "github.com/kujtimiihoxha/termai/internal/tui/styles" -) - -type MessagesCmp interface { - tea.Model - layout.Focusable - layout.Bordered - layout.Sizeable - layout.Bindings -} - -type messagesCmp struct { - app *app.App - messages []message.Message - selectedMsgIdx int // Index of the selected message - session session.Session - viewport viewport.Model - mdRenderer *glamour.TermRenderer - width int - height int - focused bool - cachedView string -} - -func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { - switch msg := msg.(type) { - case pubsub.Event[message.Message]: - if msg.Type == pubsub.CreatedEvent { - if msg.Payload.SessionID == m.session.ID { - m.messages = append(m.messages, msg.Payload) - m.renderView() - m.viewport.GotoBottom() - } - for _, v := range m.messages { - for _, c := range v.ToolCalls() { - // the message is being added to the session of a tool called - if c.ID == msg.Payload.SessionID { - m.renderView() - m.viewport.GotoBottom() - } - } - } - } else if msg.Type == pubsub.UpdatedEvent && msg.Payload.SessionID == m.session.ID { - for i, v := range m.messages { - if v.ID == msg.Payload.ID { - m.messages[i] = msg.Payload - m.renderView() - if i == len(m.messages)-1 { - m.viewport.GotoBottom() - } - break - } - } - } - case pubsub.Event[session.Session]: - if msg.Type == pubsub.UpdatedEvent && m.session.ID == msg.Payload.ID { - m.session = msg.Payload - } - case SelectedSessionMsg: - m.session, _ = m.app.Sessions.Get(context.Background(), msg.SessionID) - m.messages, _ = m.app.Messages.List(context.Background(), m.session.ID) - m.renderView() - m.viewport.GotoBottom() - } - if m.focused { - u, cmd := m.viewport.Update(msg) - m.viewport = u - return m, cmd - } - return m, nil -} - -func borderColor(role message.MessageRole) lipgloss.TerminalColor { - switch role { - case message.Assistant: - return styles.Mauve - case message.User: - return styles.Rosewater - } - return styles.Blue -} - -func borderText(msgRole message.MessageRole, currentMessage int) map[layout.BorderPosition]string { - role := "" - icon := "" - switch msgRole { - case message.Assistant: - role = "Assistant" - icon = styles.BotIcon - case message.User: - role = "User" - icon = styles.UserIcon - } - return map[layout.BorderPosition]string{ - layout.TopLeftBorder: lipgloss.NewStyle(). - Padding(0, 1). - Bold(true). - Foreground(styles.Crust). - Background(borderColor(msgRole)). - Render(fmt.Sprintf("%s %s ", role, icon)), - layout.TopRightBorder: lipgloss.NewStyle(). - Padding(0, 1). - Bold(true). - Foreground(styles.Crust). - Background(borderColor(msgRole)). - Render(fmt.Sprintf("#%d ", currentMessage)), - } -} - -func hasUnfinishedMessages(messages []message.Message) bool { - if len(messages) == 0 { - return false - } - for _, msg := range messages { - if !msg.IsFinished() { - return true - } - } - return false -} - -func (m *messagesCmp) renderMessageWithToolCall(content string, tools []message.ToolCall, futureMessages []message.Message) string { - allParts := []string{content} - - leftPaddingValue := 4 - connectorStyle := lipgloss.NewStyle(). - Foreground(styles.Peach). - Bold(true) - - toolCallStyle := lipgloss.NewStyle(). - Border(lipgloss.RoundedBorder()). - BorderForeground(styles.Peach). - Width(m.width-leftPaddingValue-5). - Padding(0, 1) - - toolResultStyle := lipgloss.NewStyle(). - Border(lipgloss.RoundedBorder()). - BorderForeground(styles.Green). - Width(m.width-leftPaddingValue-5). - Padding(0, 1) - - leftPadding := lipgloss.NewStyle().Padding(0, 0, 0, leftPaddingValue) - - runningStyle := lipgloss.NewStyle(). - Foreground(styles.Peach). - Bold(true) - - renderTool := func(toolCall message.ToolCall) string { - toolHeader := lipgloss.NewStyle(). - Bold(true). - Foreground(styles.Blue). - Render(fmt.Sprintf("%s %s", styles.ToolIcon, toolCall.Name)) - - var paramLines []string - var args map[string]interface{} - var paramOrder []string - - json.Unmarshal([]byte(toolCall.Input), &args) - - for key := range args { - paramOrder = append(paramOrder, key) - } - sort.Strings(paramOrder) - - for _, name := range paramOrder { - value := args[name] - paramName := lipgloss.NewStyle(). - Foreground(styles.Peach). - Bold(true). - Render(name) - - truncate := m.width - leftPaddingValue*2 - 10 - if len(fmt.Sprintf("%v", value)) > truncate { - value = fmt.Sprintf("%v", value)[:truncate] + lipgloss.NewStyle().Foreground(styles.Blue).Render("... (truncated)") - } - paramValue := fmt.Sprintf("%v", value) - paramLines = append(paramLines, fmt.Sprintf(" %s: %s", paramName, paramValue)) - } - - paramBlock := lipgloss.JoinVertical(lipgloss.Left, paramLines...) - - toolContent := lipgloss.JoinVertical(lipgloss.Left, toolHeader, paramBlock) - return toolCallStyle.Render(toolContent) - } - - findToolResult := func(toolCallID string, messages []message.Message) *message.ToolResult { - for _, msg := range messages { - if msg.Role == message.Tool { - for _, result := range msg.ToolResults() { - if result.ToolCallID == toolCallID { - return &result - } - } - } - } - return nil - } - - renderToolResult := func(result message.ToolResult) string { - resultHeader := lipgloss.NewStyle(). - Bold(true). - Foreground(styles.Green). - Render(fmt.Sprintf("%s Result", styles.CheckIcon)) - - // Use the same style for both header and border if it's an error - borderColor := styles.Green - if result.IsError { - resultHeader = lipgloss.NewStyle(). - Bold(true). - Foreground(styles.Red). - Render(fmt.Sprintf("%s Error", styles.ErrorIcon)) - borderColor = styles.Red - } - - truncate := 200 - content := result.Content - if len(content) > truncate { - content = content[:truncate] + lipgloss.NewStyle().Foreground(styles.Blue).Render("... (truncated)") - } - - resultContent := lipgloss.JoinVertical(lipgloss.Left, resultHeader, content) - return toolResultStyle.BorderForeground(borderColor).Render(resultContent) - } - - connector := connectorStyle.Render("└─> Tool Calls:") - allParts = append(allParts, connector) - - for _, toolCall := range tools { - toolOutput := renderTool(toolCall) - allParts = append(allParts, leftPadding.Render(toolOutput)) - - result := findToolResult(toolCall.ID, futureMessages) - if result != nil { - - resultOutput := renderToolResult(*result) - allParts = append(allParts, leftPadding.Render(resultOutput)) - - } else if toolCall.Name == agent.AgentToolName { - - runningIndicator := runningStyle.Render(fmt.Sprintf("%s Running...", styles.SpinnerIcon)) - allParts = append(allParts, leftPadding.Render(runningIndicator)) - taskSessionMessages, _ := m.app.Messages.List(context.Background(), toolCall.ID) - for _, msg := range taskSessionMessages { - if msg.Role == message.Assistant { - for _, toolCall := range msg.ToolCalls() { - toolHeader := lipgloss.NewStyle(). - Bold(true). - Foreground(styles.Blue). - Render(fmt.Sprintf("%s %s", styles.ToolIcon, toolCall.Name)) - - var paramLines []string - var args map[string]interface{} - var paramOrder []string - - json.Unmarshal([]byte(toolCall.Input), &args) - - for key := range args { - paramOrder = append(paramOrder, key) - } - sort.Strings(paramOrder) - - for _, name := range paramOrder { - value := args[name] - paramName := lipgloss.NewStyle(). - Foreground(styles.Peach). - Bold(true). - Render(name) - - truncate := 50 - if len(fmt.Sprintf("%v", value)) > truncate { - value = fmt.Sprintf("%v", value)[:truncate] + lipgloss.NewStyle().Foreground(styles.Blue).Render("... (truncated)") - } - paramValue := fmt.Sprintf("%v", value) - paramLines = append(paramLines, fmt.Sprintf(" %s: %s", paramName, paramValue)) - } - - paramBlock := lipgloss.JoinVertical(lipgloss.Left, paramLines...) - toolContent := lipgloss.JoinVertical(lipgloss.Left, toolHeader, paramBlock) - toolOutput := toolCallStyle.BorderForeground(styles.Teal).MaxWidth(m.width - leftPaddingValue*2 - 2).Render(toolContent) - allParts = append(allParts, lipgloss.NewStyle().Padding(0, 0, 0, leftPaddingValue*2).Render(toolOutput)) - } - } - } - - } else { - runningIndicator := runningStyle.Render(fmt.Sprintf("%s Running...", styles.SpinnerIcon)) - allParts = append(allParts, " "+runningIndicator) - } - } - - for _, msg := range futureMessages { - if msg.Content().String() != "" || msg.FinishReason() == "canceled" { - break - } - - for _, toolCall := range msg.ToolCalls() { - toolOutput := renderTool(toolCall) - allParts = append(allParts, " "+strings.ReplaceAll(toolOutput, "\n", "\n ")) - - result := findToolResult(toolCall.ID, futureMessages) - if result != nil { - resultOutput := renderToolResult(*result) - allParts = append(allParts, " "+strings.ReplaceAll(resultOutput, "\n", "\n ")) - } else { - runningIndicator := runningStyle.Render(fmt.Sprintf("%s Running...", styles.SpinnerIcon)) - allParts = append(allParts, " "+runningIndicator) - } - } - } - - return lipgloss.JoinVertical(lipgloss.Left, allParts...) -} - -func (m *messagesCmp) renderView() { - stringMessages := make([]string, 0) - r, _ := glamour.NewTermRenderer( - glamour.WithStyles(styles.CatppuccinMarkdownStyle()), - glamour.WithWordWrap(m.width-20), - glamour.WithEmoji(), - ) - textStyle := lipgloss.NewStyle().Width(m.width - 4) - currentMessage := 1 - displayedMsgCount := 0 // Track the actual displayed messages count - - prevMessageWasUser := false - for inx, msg := range m.messages { - content := msg.Content().String() - if content != "" || prevMessageWasUser || msg.FinishReason() == "canceled" { - if msg.ReasoningContent().String() != "" && content == "" { - content = msg.ReasoningContent().String() - } else if content == "" { - content = "..." - } - if msg.FinishReason() == "canceled" { - content, _ = r.Render(content) - content += lipgloss.NewStyle().Padding(1, 0, 0, 1).Foreground(styles.Error).Render(styles.ErrorIcon + " Canceled") - } else { - content, _ = r.Render(content) - } - - isSelected := inx == m.selectedMsgIdx - - border := lipgloss.DoubleBorder() - activeColor := borderColor(msg.Role) - - if isSelected { - activeColor = styles.Primary // Use primary color for selected message - } - - content = layout.Borderize( - textStyle.Render(content), - layout.BorderOptions{ - InactiveBorder: border, - ActiveBorder: border, - ActiveColor: activeColor, - InactiveColor: borderColor(msg.Role), - EmbeddedText: borderText(msg.Role, currentMessage), - }, - ) - if len(msg.ToolCalls()) > 0 { - content = m.renderMessageWithToolCall(content, msg.ToolCalls(), m.messages[inx+1:]) - } - stringMessages = append(stringMessages, content) - currentMessage++ - displayedMsgCount++ - } - if msg.Role == message.User && msg.Content().String() != "" { - prevMessageWasUser = true - } else { - prevMessageWasUser = false - } - } - m.viewport.SetContent(lipgloss.JoinVertical(lipgloss.Top, stringMessages...)) -} - -func (m *messagesCmp) View() string { - return lipgloss.NewStyle().Padding(1).Render(m.viewport.View()) -} - -func (m *messagesCmp) BindingKeys() []key.Binding { - keys := layout.KeyMapToSlice(m.viewport.KeyMap) - - return keys -} - -func (m *messagesCmp) Blur() tea.Cmd { - m.focused = false - return nil -} - -func (m *messagesCmp) projectDiagnostics() string { - errorDiagnostics := []protocol.Diagnostic{} - warnDiagnostics := []protocol.Diagnostic{} - hintDiagnostics := []protocol.Diagnostic{} - infoDiagnostics := []protocol.Diagnostic{} - for _, client := range m.app.LSPClients { - for _, d := range client.GetDiagnostics() { - for _, diag := range d { - switch diag.Severity { - case protocol.SeverityError: - errorDiagnostics = append(errorDiagnostics, diag) - case protocol.SeverityWarning: - warnDiagnostics = append(warnDiagnostics, diag) - case protocol.SeverityHint: - hintDiagnostics = append(hintDiagnostics, diag) - case protocol.SeverityInformation: - infoDiagnostics = append(infoDiagnostics, diag) - } - } - } - } - - if len(errorDiagnostics) == 0 && len(warnDiagnostics) == 0 && len(hintDiagnostics) == 0 && len(infoDiagnostics) == 0 { - return "No diagnostics" - } - - diagnostics := []string{} - - if len(errorDiagnostics) > 0 { - errStr := lipgloss.NewStyle().Foreground(styles.Error).Render(fmt.Sprintf("%s %d", styles.ErrorIcon, len(errorDiagnostics))) - diagnostics = append(diagnostics, errStr) - } - if len(warnDiagnostics) > 0 { - warnStr := lipgloss.NewStyle().Foreground(styles.Warning).Render(fmt.Sprintf("%s %d", styles.WarningIcon, len(warnDiagnostics))) - diagnostics = append(diagnostics, warnStr) - } - if len(hintDiagnostics) > 0 { - hintStr := lipgloss.NewStyle().Foreground(styles.Text).Render(fmt.Sprintf("%s %d", styles.HintIcon, len(hintDiagnostics))) - diagnostics = append(diagnostics, hintStr) - } - if len(infoDiagnostics) > 0 { - infoStr := lipgloss.NewStyle().Foreground(styles.Peach).Render(fmt.Sprintf("%s %d", styles.InfoIcon, len(infoDiagnostics))) - diagnostics = append(diagnostics, infoStr) - } - - return strings.Join(diagnostics, " ") -} - -func (m *messagesCmp) BorderText() map[layout.BorderPosition]string { - title := m.session.Title - titleWidth := m.width / 2 - if len(title) > titleWidth { - title = title[:titleWidth] + "..." - } - if m.focused { - title = lipgloss.NewStyle().Foreground(styles.Primary).Render(title) - } - borderTest := map[layout.BorderPosition]string{ - layout.TopLeftBorder: title, - layout.BottomRightBorder: m.projectDiagnostics(), - } - if hasUnfinishedMessages(m.messages) { - borderTest[layout.BottomLeftBorder] = lipgloss.NewStyle().Foreground(styles.Peach).Render("Thinking...") - } else { - borderTest[layout.BottomLeftBorder] = lipgloss.NewStyle().Foreground(styles.Text).Render("Sleeping " + styles.SleepIcon + " ") - } - - return borderTest -} - -func (m *messagesCmp) Focus() tea.Cmd { - m.focused = true - return nil -} - -func (m *messagesCmp) GetSize() (int, int) { - return m.width, m.height -} - -func (m *messagesCmp) IsFocused() bool { - return m.focused -} - -func (m *messagesCmp) SetSize(width int, height int) { - m.width = width - m.height = height - m.viewport.Width = width - 2 // padding - m.viewport.Height = height - 2 // padding - m.renderView() -} - -func (m *messagesCmp) Init() tea.Cmd { - return nil -} - -func NewMessagesCmp(app *app.App) MessagesCmp { - return &messagesCmp{ - app: app, - messages: []message.Message{}, - viewport: viewport.New(0, 0), - } -} diff --git a/internal/tui/components/repl/sessions.go b/internal/tui/components/repl/sessions.go deleted file mode 100644 index c83c40367..000000000 --- a/internal/tui/components/repl/sessions.go +++ /dev/null @@ -1,249 +0,0 @@ -package repl - -import ( - "context" - "fmt" - "strings" - - "github.com/charmbracelet/bubbles/key" - "github.com/charmbracelet/bubbles/list" - tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/app" - "github.com/kujtimiihoxha/termai/internal/pubsub" - "github.com/kujtimiihoxha/termai/internal/session" - "github.com/kujtimiihoxha/termai/internal/tui/layout" - "github.com/kujtimiihoxha/termai/internal/tui/styles" - "github.com/kujtimiihoxha/termai/internal/tui/util" -) - -type SessionsCmp interface { - tea.Model - layout.Sizeable - layout.Focusable - layout.Bordered - layout.Bindings -} -type sessionsCmp struct { - app *app.App - list list.Model - focused bool -} - -type listItem struct { - id, title, desc string -} - -func (i listItem) Title() string { return i.title } -func (i listItem) Description() string { return i.desc } -func (i listItem) FilterValue() string { return i.title } - -type InsertSessionsMsg struct { - sessions []session.Session -} - -type SelectedSessionMsg struct { - SessionID string -} - -type sessionsKeyMap struct { - Select key.Binding -} - -var sessionKeyMapValue = sessionsKeyMap{ - Select: key.NewBinding( - key.WithKeys("enter", " "), - key.WithHelp("enter/space", "select session"), - ), -} - -func (i *sessionsCmp) Init() tea.Cmd { - existing, err := i.app.Sessions.List(context.Background()) - if err != nil { - return util.ReportError(err) - } - if len(existing) == 0 || existing[0].MessageCount > 0 { - newSession, err := i.app.Sessions.Create( - context.Background(), - "New Session", - ) - if err != nil { - return util.ReportError(err) - } - existing = append([]session.Session{newSession}, existing...) - } - return tea.Batch( - util.CmdHandler(InsertSessionsMsg{existing}), - util.CmdHandler(SelectedSessionMsg{existing[0].ID}), - ) -} - -func (i *sessionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { - switch msg := msg.(type) { - case InsertSessionsMsg: - items := make([]list.Item, len(msg.sessions)) - for i, s := range msg.sessions { - items[i] = listItem{ - id: s.ID, - title: s.Title, - desc: formatTokensAndCost(s.PromptTokens+s.CompletionTokens, s.Cost), - } - } - return i, i.list.SetItems(items) - case pubsub.Event[session.Session]: - if msg.Type == pubsub.CreatedEvent && msg.Payload.ParentSessionID == "" { - // Check if the session is already in the list - items := i.list.Items() - for _, item := range items { - s := item.(listItem) - if s.id == msg.Payload.ID { - return i, nil - } - } - // insert the new session at the top of the list - items = append([]list.Item{listItem{ - id: msg.Payload.ID, - title: msg.Payload.Title, - desc: formatTokensAndCost(msg.Payload.PromptTokens+msg.Payload.CompletionTokens, msg.Payload.Cost), - }}, items...) - return i, i.list.SetItems(items) - } else if msg.Type == pubsub.UpdatedEvent { - // update the session in the list - items := i.list.Items() - for idx, item := range items { - s := item.(listItem) - if s.id == msg.Payload.ID { - s.title = msg.Payload.Title - s.desc = formatTokensAndCost(msg.Payload.PromptTokens+msg.Payload.CompletionTokens, msg.Payload.Cost) - items[idx] = s - break - } - } - return i, i.list.SetItems(items) - } - - case tea.KeyMsg: - switch { - case key.Matches(msg, sessionKeyMapValue.Select): - selected := i.list.SelectedItem() - if selected == nil { - return i, nil - } - return i, util.CmdHandler(SelectedSessionMsg{selected.(listItem).id}) - } - } - if i.focused { - u, cmd := i.list.Update(msg) - i.list = u - return i, cmd - } - return i, nil -} - -func (i *sessionsCmp) View() string { - return i.list.View() -} - -func (i *sessionsCmp) Blur() tea.Cmd { - i.focused = false - return nil -} - -func (i *sessionsCmp) Focus() tea.Cmd { - i.focused = true - return nil -} - -func (i *sessionsCmp) GetSize() (int, int) { - return i.list.Width(), i.list.Height() -} - -func (i *sessionsCmp) IsFocused() bool { - return i.focused -} - -func (i *sessionsCmp) SetSize(width int, height int) { - i.list.SetSize(width, height) -} - -func (i *sessionsCmp) BorderText() map[layout.BorderPosition]string { - totalCount := len(i.list.Items()) - itemsPerPage := i.list.Paginator.PerPage - currentPage := i.list.Paginator.Page - - current := min(currentPage*itemsPerPage+itemsPerPage, totalCount) - - pageInfo := fmt.Sprintf( - "%d-%d of %d", - currentPage*itemsPerPage+1, - current, - totalCount, - ) - - title := "Sessions" - if i.focused { - title = lipgloss.NewStyle().Foreground(styles.Primary).Render(title) - } - return map[layout.BorderPosition]string{ - layout.TopMiddleBorder: title, - layout.BottomMiddleBorder: pageInfo, - } -} - -func (i *sessionsCmp) BindingKeys() []key.Binding { - return append(layout.KeyMapToSlice(i.list.KeyMap), sessionKeyMapValue.Select) -} - -func formatTokensAndCost(tokens int64, cost float64) string { - // Format tokens in human-readable format (e.g., 110K, 1.2M) - var formattedTokens string - switch { - case tokens >= 1_000_000: - formattedTokens = fmt.Sprintf("%.1fM", float64(tokens)/1_000_000) - case tokens >= 1_000: - formattedTokens = fmt.Sprintf("%.1fK", float64(tokens)/1_000) - default: - formattedTokens = fmt.Sprintf("%d", tokens) - } - - // Remove .0 suffix if present - if strings.HasSuffix(formattedTokens, ".0K") { - formattedTokens = strings.Replace(formattedTokens, ".0K", "K", 1) - } - if strings.HasSuffix(formattedTokens, ".0M") { - formattedTokens = strings.Replace(formattedTokens, ".0M", "M", 1) - } - - // Format cost with $ symbol and 2 decimal places - formattedCost := fmt.Sprintf("$%.2f", cost) - - return fmt.Sprintf("Tokens: %s, Cost: %s", formattedTokens, formattedCost) -} - -func NewSessionsCmp(app *app.App) SessionsCmp { - listDelegate := list.NewDefaultDelegate() - defaultItemStyle := list.NewDefaultItemStyles() - defaultItemStyle.SelectedTitle = defaultItemStyle.SelectedTitle.BorderForeground(styles.Secondary).Foreground(styles.Primary) - defaultItemStyle.SelectedDesc = defaultItemStyle.SelectedDesc.BorderForeground(styles.Secondary).Foreground(styles.Primary) - - defaultStyle := list.DefaultStyles() - defaultStyle.FilterPrompt = defaultStyle.FilterPrompt.Foreground(styles.Secondary) - defaultStyle.FilterCursor = defaultStyle.FilterCursor.Foreground(styles.Flamingo) - - listDelegate.Styles = defaultItemStyle - - listComponent := list.New([]list.Item{}, listDelegate, 0, 0) - listComponent.FilterInput.PromptStyle = defaultStyle.FilterPrompt - listComponent.FilterInput.Cursor.Style = defaultStyle.FilterCursor - listComponent.SetShowTitle(false) - listComponent.SetShowPagination(false) - listComponent.SetShowHelp(false) - listComponent.SetShowStatusBar(false) - listComponent.DisableQuitKeybindings() - - return &sessionsCmp{ - app: app, - list: listComponent, - focused: false, - } -} diff --git a/internal/tui/layout/overlay.go b/internal/tui/layout/overlay.go index 22f9e00fe..4a1bcf661 100644 --- a/internal/tui/layout/overlay.go +++ b/internal/tui/layout/overlay.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/charmbracelet/lipgloss" + "github.com/kujtimiihoxha/termai/internal/tui/styles" "github.com/kujtimiihoxha/termai/internal/tui/util" "github.com/mattn/go-runewidth" "github.com/muesli/ansi" @@ -45,13 +46,15 @@ func PlaceOverlay( if shadow { var shadowbg string = "" shadowchar := lipgloss.NewStyle(). - Foreground(lipgloss.Color("#333333")). + Background(styles.BackgroundDarker). + Foreground(styles.Background). Render("░") + bgchar := styles.BaseStyle.Render(" ") for i := 0; i <= fgHeight; i++ { if i == 0 { - shadowbg += " " + strings.Repeat(" ", fgWidth) + "\n" + shadowbg += bgchar + strings.Repeat(bgchar, fgWidth) + "\n" } else { - shadowbg += " " + strings.Repeat(shadowchar, fgWidth) + "\n" + shadowbg += bgchar + strings.Repeat(shadowchar, fgWidth) + "\n" } } @@ -159,8 +162,6 @@ func max(a, b int) int { return b } - - type whitespace struct { style termenv.Style chars string diff --git a/internal/tui/layout/split.go b/internal/tui/layout/split.go index 0ed85dd6f..6482fc74c 100644 --- a/internal/tui/layout/split.go +++ b/internal/tui/layout/split.go @@ -10,6 +10,7 @@ import ( type SplitPaneLayout interface { tea.Model Sizeable + Bindings SetLeftPanel(panel Container) SetRightPanel(panel Container) SetBottomPanel(panel Container) diff --git a/internal/tui/page/chat.go b/internal/tui/page/chat.go index 439c89e1f..cebc0e461 100644 --- a/internal/tui/page/chat.go +++ b/internal/tui/page/chat.go @@ -37,7 +37,6 @@ var keyMap = ChatKeyMap{ } func (p *chatPage) Init() tea.Cmd { - // TODO: remove cmds := []tea.Cmd{ p.layout.Init(), } @@ -48,9 +47,7 @@ func (p *chatPage) Init() tea.Cmd { cmd := p.setSidebar() cmds = append(cmds, util.CmdHandler(chat.SessionSelectedMsg(p.session)), cmd) } - return tea.Batch( - cmds..., - ) + return tea.Batch(cmds...) } func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { @@ -68,6 +65,13 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { p.session = session.Session{} p.clearSidebar() return p, util.CmdHandler(chat.SessionClearedMsg{}) + case key.Matches(msg, keyMap.Cancel): + if p.session.ID != "" { + // Cancel the current session's generation process + // This allows users to interrupt long-running operations + p.app.CoderAgent.Cancel(p.session.ID) + return p, nil + } } } u, cmd := p.layout.Update(msg) @@ -80,7 +84,7 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { func (p *chatPage) setSidebar() tea.Cmd { sidebarContainer := layout.NewContainer( - chat.NewSidebarCmp(p.session), + chat.NewSidebarCmp(p.session, p.app.History), layout.WithPadding(1, 1, 1, 1), ) p.layout.SetRightPanel(sidebarContainer) @@ -111,14 +115,28 @@ func (p *chatPage) sendMessage(text string) tea.Cmd { cmds = append(cmds, util.CmdHandler(chat.SessionSelectedMsg(session))) } - p.app.CoderAgent.Generate(context.Background(), p.session.ID, text) + p.app.CoderAgent.Run(context.Background(), p.session.ID, text) return tea.Batch(cmds...) } +func (p *chatPage) SetSize(width, height int) { + p.layout.SetSize(width, height) +} + +func (p *chatPage) GetSize() (int, int) { + return p.layout.GetSize() +} + func (p *chatPage) View() string { return p.layout.View() } +func (p *chatPage) BindingKeys() []key.Binding { + bindings := layout.KeyMapToSlice(keyMap) + bindings = append(bindings, p.layout.BindingKeys()...) + return bindings +} + func NewChatPage(app *app.App) tea.Model { messagesContainer := layout.NewContainer( chat.NewMessagesCmp(app), @@ -126,7 +144,7 @@ func NewChatPage(app *app.App) tea.Model { ) editorContainer := layout.NewContainer( - chat.NewEditorCmp(), + chat.NewEditorCmp(app), layout.WithBorder(true, false, false, false), ) return &chatPage{ diff --git a/internal/tui/page/init.go b/internal/tui/page/init.go deleted file mode 100644 index 0a5c6f82a..000000000 --- a/internal/tui/page/init.go +++ /dev/null @@ -1,308 +0,0 @@ -package page - -import ( - "fmt" - "os" - "path/filepath" - "strconv" - - "github.com/charmbracelet/bubbles/key" - tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/huh" - "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/llm/models" - "github.com/kujtimiihoxha/termai/internal/tui/layout" - "github.com/kujtimiihoxha/termai/internal/tui/styles" - "github.com/kujtimiihoxha/termai/internal/tui/util" - "github.com/spf13/viper" -) - -var InitPage PageID = "init" - -type configSaved struct{} - -type initPage struct { - form *huh.Form - width int - height int - saved bool - errorMsg string - statusMsg string - modelOpts []huh.Option[string] - bigModel string - smallModel string - openAIKey string - anthropicKey string - groqKey string - maxTokens string - dataDir string - agent string -} - -func (i *initPage) Init() tea.Cmd { - return i.form.Init() -} - -func (i *initPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { - var cmds []tea.Cmd - - switch msg := msg.(type) { - case tea.WindowSizeMsg: - i.width = msg.Width - 4 // Account for border - i.height = msg.Height - 4 - i.form = i.form.WithWidth(i.width).WithHeight(i.height) - return i, nil - - case configSaved: - i.saved = true - i.statusMsg = "Configuration saved successfully. Press any key to continue." - return i, nil - } - - if i.saved { - switch msg.(type) { - case tea.KeyMsg: - return i, util.CmdHandler(PageChangeMsg{ID: ReplPage}) - } - return i, nil - } - - // Process the form - form, cmd := i.form.Update(msg) - if f, ok := form.(*huh.Form); ok { - i.form = f - cmds = append(cmds, cmd) - } - - if i.form.State == huh.StateCompleted { - // Save configuration to file - configPath := filepath.Join(os.Getenv("HOME"), ".termai.yaml") - maxTokens, _ := strconv.Atoi(i.maxTokens) - config := map[string]any{ - "models": map[string]string{ - "big": i.bigModel, - "small": i.smallModel, - }, - "providers": map[string]any{ - "openai": map[string]string{ - "key": i.openAIKey, - }, - "anthropic": map[string]string{ - "key": i.anthropicKey, - }, - "groq": map[string]string{ - "key": i.groqKey, - }, - "common": map[string]int{ - "max_tokens": maxTokens, - }, - }, - "data": map[string]string{ - "dir": i.dataDir, - }, - "agents": map[string]string{ - "default": i.agent, - }, - "log": map[string]string{ - "level": "info", - }, - } - - // Write config to viper - for k, v := range config { - viper.Set(k, v) - } - - // Save configuration - err := viper.WriteConfigAs(configPath) - if err != nil { - i.errorMsg = fmt.Sprintf("Failed to save configuration: %s", err) - return i, nil - } - - // Return to main page - return i, util.CmdHandler(configSaved{}) - } - - return i, tea.Batch(cmds...) -} - -func (i *initPage) View() string { - if i.saved { - return lipgloss.NewStyle(). - Width(i.width). - Height(i.height). - Align(lipgloss.Center, lipgloss.Center). - Render(lipgloss.JoinVertical( - lipgloss.Center, - lipgloss.NewStyle().Foreground(styles.Green).Render("✓ Configuration Saved"), - "", - lipgloss.NewStyle().Foreground(styles.Blue).Render(i.statusMsg), - )) - } - - view := i.form.View() - if i.errorMsg != "" { - errorBox := lipgloss.NewStyle(). - Padding(1). - Border(lipgloss.RoundedBorder()). - BorderForeground(styles.Red). - Width(i.width - 4). - Render(i.errorMsg) - view = lipgloss.JoinVertical(lipgloss.Left, errorBox, view) - } - return view -} - -func (i *initPage) GetSize() (int, int) { - return i.width, i.height -} - -func (i *initPage) SetSize(width int, height int) { - i.width = width - i.height = height - i.form = i.form.WithWidth(width).WithHeight(height) -} - -func (i *initPage) BindingKeys() []key.Binding { - if i.saved { - return []key.Binding{ - key.NewBinding( - key.WithKeys("enter", "space", "esc"), - key.WithHelp("any key", "continue"), - ), - } - } - return i.form.KeyBinds() -} - -func NewInitPage() tea.Model { - // Create model options - var modelOpts []huh.Option[string] - for id, model := range models.SupportedModels { - modelOpts = append(modelOpts, huh.NewOption(model.Name, string(id))) - } - - // Create agent options - agentOpts := []huh.Option[string]{ - huh.NewOption("Coder", "coder"), - huh.NewOption("Assistant", "assistant"), - } - - // Init page with form - initModel := &initPage{ - modelOpts: modelOpts, - bigModel: string(models.Claude37Sonnet), - smallModel: string(models.Claude37Sonnet), - maxTokens: "4000", - dataDir: ".termai", - agent: "coder", - } - - // API Keys group - apiKeysGroup := huh.NewGroup( - huh.NewNote(). - Title("API Keys"). - Description("You need to provide at least one API key to use termai"), - - huh.NewInput(). - Title("OpenAI API Key"). - Placeholder("sk-..."). - Key("openai_key"). - Value(&initModel.openAIKey), - - huh.NewInput(). - Title("Anthropic API Key"). - Placeholder("sk-ant-..."). - Key("anthropic_key"). - Value(&initModel.anthropicKey), - - huh.NewInput(). - Title("Groq API Key"). - Placeholder("gsk_..."). - Key("groq_key"). - Value(&initModel.groqKey), - ) - - // Model configuration group - modelsGroup := huh.NewGroup( - huh.NewNote(). - Title("Model Configuration"). - Description("Select which models to use"), - - huh.NewSelect[string](). - Title("Big Model"). - Options(modelOpts...). - Key("big_model"). - Value(&initModel.bigModel), - - huh.NewSelect[string](). - Title("Small Model"). - Options(modelOpts...). - Key("small_model"). - Value(&initModel.smallModel), - - huh.NewInput(). - Title("Max Tokens"). - Placeholder("4000"). - Key("max_tokens"). - CharLimit(5). - Validate(func(s string) error { - var n int - _, err := fmt.Sscanf(s, "%d", &n) - if err != nil || n <= 0 { - return fmt.Errorf("must be a positive number") - } - initModel.maxTokens = s - return nil - }). - Value(&initModel.maxTokens), - ) - - // General settings group - generalGroup := huh.NewGroup( - huh.NewNote(). - Title("General Settings"). - Description("Configure general termai settings"), - - huh.NewInput(). - Title("Data Directory"). - Placeholder(".termai"). - Key("data_dir"). - Value(&initModel.dataDir), - - huh.NewSelect[string](). - Title("Default Agent"). - Options(agentOpts...). - Key("agent"). - Value(&initModel.agent), - - huh.NewConfirm(). - Title("Save Configuration"). - Affirmative("Save"). - Negative("Cancel"), - ) - - // Create form with theme - form := huh.NewForm( - apiKeysGroup, - modelsGroup, - generalGroup, - ).WithTheme(styles.HuhTheme()). - WithShowHelp(true). - WithShowErrors(true) - - // Set the form in the model - initModel.form = form - - return layout.NewSinglePane( - initModel, - layout.WithSinglePaneFocusable(true), - layout.WithSinglePaneBordered(true), - layout.WithSinglePaneBorderText( - map[layout.BorderPosition]string{ - layout.TopMiddleBorder: "Welcome to termai - Initial Setup", - }, - ), - ) -} diff --git a/internal/tui/page/logs.go b/internal/tui/page/logs.go index 12afaf6aa..d1e557eab 100644 --- a/internal/tui/page/logs.go +++ b/internal/tui/page/logs.go @@ -8,6 +8,23 @@ import ( var LogsPage PageID = "logs" +type logsPage struct { + table logs.TableComponent + details logs.DetailComponent +} + +func (p *logsPage) Init() tea.Cmd { + return nil +} + +func (p *logsPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + return p, nil +} + +func (p *logsPage) View() string { + return p.table.View() + "\n" + p.details.View() +} + func NewLogsPage() tea.Model { return layout.NewBentoLayout( layout.BentoPanes{ diff --git a/internal/tui/page/repl.go b/internal/tui/page/repl.go deleted file mode 100644 index 47a924b7b..000000000 --- a/internal/tui/page/repl.go +++ /dev/null @@ -1,21 +0,0 @@ -package page - -import ( - tea "github.com/charmbracelet/bubbletea" - "github.com/kujtimiihoxha/termai/internal/app" - "github.com/kujtimiihoxha/termai/internal/tui/components/repl" - "github.com/kujtimiihoxha/termai/internal/tui/layout" -) - -var ReplPage PageID = "repl" - -func NewReplPage(app *app.App) tea.Model { - return layout.NewBentoLayout( - layout.BentoPanes{ - layout.BentoLeftPane: repl.NewSessionsCmp(app), - layout.BentoRightTopPane: repl.NewMessagesCmp(app), - layout.BentoRightBottomPane: repl.NewEditorCmp(app), - }, - layout.WithBentoLayoutCurrentPane(layout.BentoRightBottomPane), - ) -} diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 1b1a1ed50..dff7ad63d 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -1,8 +1,6 @@ package tui import ( - "context" - "github.com/charmbracelet/bubbles/key" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" @@ -12,47 +10,41 @@ import ( "github.com/kujtimiihoxha/termai/internal/pubsub" "github.com/kujtimiihoxha/termai/internal/tui/components/core" "github.com/kujtimiihoxha/termai/internal/tui/components/dialog" - "github.com/kujtimiihoxha/termai/internal/tui/components/repl" "github.com/kujtimiihoxha/termai/internal/tui/layout" "github.com/kujtimiihoxha/termai/internal/tui/page" "github.com/kujtimiihoxha/termai/internal/tui/util" - "github.com/kujtimiihoxha/vimtea" ) type keyMap struct { - Logs key.Binding - Return key.Binding - Back key.Binding - Quit key.Binding - Help key.Binding + Logs key.Binding + Quit key.Binding + Help key.Binding } var keys = keyMap{ Logs: key.NewBinding( - key.WithKeys("L"), - key.WithHelp("L", "logs"), - ), - Return: key.NewBinding( - key.WithKeys("esc"), - key.WithHelp("esc", "close"), - ), - Back: key.NewBinding( - key.WithKeys("backspace"), - key.WithHelp("backspace", "back"), + key.WithKeys("ctrl+l"), + key.WithHelp("ctrl+L", "logs"), ), + Quit: key.NewBinding( - key.WithKeys("ctrl+c", "q"), - key.WithHelp("ctrl+c/q", "quit"), + key.WithKeys("ctrl+c"), + key.WithHelp("ctrl+c", "quit"), ), Help: key.NewBinding( - key.WithKeys("?"), - key.WithHelp("?", "toggle help"), + key.WithKeys("ctrl+_"), + key.WithHelp("ctrl+?", "toggle help"), ), } -var replKeyMap = key.NewBinding( - key.WithKeys("N"), - key.WithHelp("N", "new session"), +var returnKey = key.NewBinding( + key.WithKeys("esc"), + key.WithHelp("esc", "close"), +) + +var logsKeyReturnKey = key.NewBinding( + key.WithKeys("backspace"), + key.WithHelp("backspace", "go back"), ) type appModel struct { @@ -62,18 +54,30 @@ type appModel struct { pages map[page.PageID]tea.Model loadedPages map[page.PageID]bool status tea.Model - help core.HelpCmp - dialog core.DialogCmp app *app.App - dialogVisible bool - editorMode vimtea.EditorMode - showHelp bool + + showPermissions bool + permissions dialog.PermissionDialogCmp + + showHelp bool + help dialog.HelpCmp + + showQuit bool + quit dialog.QuitDialog } func (a appModel) Init() tea.Cmd { + var cmds []tea.Cmd cmd := a.pages[a.currentPage].Init() a.loadedPages[a.currentPage] = true - return cmd + cmds = append(cmds, cmd) + cmd = a.status.Init() + cmds = append(cmds, cmd) + cmd = a.quit.Init() + cmds = append(cmds, cmd) + cmd = a.help.Init() + cmds = append(cmds, cmd) + return tea.Batch(cmds...) } func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { @@ -81,22 +85,20 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { var cmd tea.Cmd switch msg := msg.(type) { case tea.WindowSizeMsg: - var cmds []tea.Cmd msg.Height -= 1 // Make space for the status bar a.width, a.height = msg.Width, msg.Height a.status, _ = a.status.Update(msg) - - uh, _ := a.help.Update(msg) - a.help = uh.(core.HelpCmp) - - p, cmd := a.pages[a.currentPage].Update(msg) + a.pages[a.currentPage], cmd = a.pages[a.currentPage].Update(msg) cmds = append(cmds, cmd) - a.pages[a.currentPage] = p - d, cmd := a.dialog.Update(msg) - cmds = append(cmds, cmd) - a.dialog = d.(core.DialogCmp) + prm, permCmd := a.permissions.Update(msg) + a.permissions = prm.(dialog.PermissionDialogCmp) + cmds = append(cmds, permCmd) + + help, helpCmd := a.help.Update(msg) + a.help = help.(dialog.HelpCmp) + cmds = append(cmds, helpCmd) return a, tea.Batch(cmds...) @@ -141,7 +143,9 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { // Permission case pubsub.Event[permission.PermissionRequest]: - return a, dialog.NewPermissionDialogCmd(msg.Payload) + a.showPermissions = true + a.permissions.SetPermissions(msg.Payload) + return a, nil case dialog.PermissionResponseMsg: switch msg.Action { case dialog.PermissionAllow: @@ -151,91 +155,71 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case dialog.PermissionDeny: a.app.Permissions.Deny(msg.Permission) } - - // Dialog - case core.DialogMsg: - d, cmd := a.dialog.Update(msg) - a.dialog = d.(core.DialogCmp) - a.dialogVisible = true - return a, cmd - case core.DialogCloseMsg: - d, cmd := a.dialog.Update(msg) - a.dialog = d.(core.DialogCmp) - a.dialogVisible = false - return a, cmd - - // Editor - case vimtea.EditorModeMsg: - a.editorMode = msg.Mode + a.showPermissions = false + return a, nil case page.PageChangeMsg: return a, a.moveToPage(msg.ID) + + case dialog.CloseQuitMsg: + a.showQuit = false + return a, nil + case tea.KeyMsg: - if a.editorMode == vimtea.ModeNormal { - switch { - case key.Matches(msg, keys.Quit): - return a, dialog.NewQuitDialogCmd() - case key.Matches(msg, keys.Back): - if a.previousPage != "" { - return a, a.moveToPage(a.previousPage) - } - case key.Matches(msg, keys.Return): - if a.showHelp { - a.ToggleHelp() - return a, nil - } - case key.Matches(msg, replKeyMap): - if a.currentPage == page.ReplPage { - sessions, err := a.app.Sessions.List(context.Background()) - if err != nil { - return a, util.CmdHandler(util.ReportError(err)) - } - lastSession := sessions[0] - if lastSession.MessageCount == 0 { - return a, util.CmdHandler(repl.SelectedSessionMsg{SessionID: lastSession.ID}) - } - s, err := a.app.Sessions.Create(context.Background(), "New Session") - if err != nil { - return a, util.CmdHandler(util.ReportError(err)) - } - return a, util.CmdHandler(repl.SelectedSessionMsg{SessionID: s.ID}) - } - // case key.Matches(msg, keys.Logs): - // return a, a.moveToPage(page.LogsPage) - case msg.String() == "O": - return a, a.moveToPage(page.ReplPage) - case key.Matches(msg, keys.Help): - a.ToggleHelp() + switch { + case key.Matches(msg, keys.Quit): + a.showQuit = !a.showQuit + if a.showHelp { + a.showHelp = false + } + return a, nil + case key.Matches(msg, logsKeyReturnKey): + if a.currentPage == page.LogsPage { + return a, a.moveToPage(page.ChatPage) + } + case key.Matches(msg, returnKey): + if a.showQuit { + a.showQuit = !a.showQuit + return a, nil + } + if a.showHelp { + a.showHelp = !a.showHelp + return a, nil + } + case key.Matches(msg, keys.Logs): + return a, a.moveToPage(page.LogsPage) + case key.Matches(msg, keys.Help): + if a.showQuit { return a, nil } + a.showHelp = !a.showHelp + return a, nil } } - if a.dialogVisible { - d, cmd := a.dialog.Update(msg) - a.dialog = d.(core.DialogCmp) - cmds = append(cmds, cmd) - return a, tea.Batch(cmds...) + if a.showQuit { + q, quitCmd := a.quit.Update(msg) + a.quit = q.(dialog.QuitDialog) + cmds = append(cmds, quitCmd) + // Only block key messages send all other messages down + if _, ok := msg.(tea.KeyMsg); ok { + return a, tea.Batch(cmds...) + } + } + if a.showPermissions { + d, permissionsCmd := a.permissions.Update(msg) + a.permissions = d.(dialog.PermissionDialogCmp) + cmds = append(cmds, permissionsCmd) + // Only block key messages send all other messages down + if _, ok := msg.(tea.KeyMsg); ok { + return a, tea.Batch(cmds...) + } } a.pages[a.currentPage], cmd = a.pages[a.currentPage].Update(msg) cmds = append(cmds, cmd) return a, tea.Batch(cmds...) } -func (a *appModel) ToggleHelp() { - if a.showHelp { - a.showHelp = false - a.height += a.help.Height() - } else { - a.showHelp = true - a.height -= a.help.Height() - } - - if sizable, ok := a.pages[a.currentPage].(layout.Sizeable); ok { - sizable.SetSize(a.width, a.height) - } -} - func (a *appModel) moveToPage(pageID page.PageID) tea.Cmd { var cmd tea.Cmd if _, ok := a.loadedPages[pageID]; !ok { @@ -256,27 +240,55 @@ func (a appModel) View() string { a.pages[a.currentPage].View(), } + components = append(components, a.status.View()) + + appView := lipgloss.JoinVertical(lipgloss.Top, components...) + + if a.showPermissions { + overlay := a.permissions.View() + row := lipgloss.Height(appView) / 2 + row -= lipgloss.Height(overlay) / 2 + col := lipgloss.Width(appView) / 2 + col -= lipgloss.Width(overlay) / 2 + appView = layout.PlaceOverlay( + col, + row, + overlay, + appView, + true, + ) + } + if a.showHelp { bindings := layout.KeyMapToSlice(keys) if p, ok := a.pages[a.currentPage].(layout.Bindings); ok { bindings = append(bindings, p.BindingKeys()...) } - if a.dialogVisible { - bindings = append(bindings, a.dialog.BindingKeys()...) + if a.showPermissions { + bindings = append(bindings, a.permissions.BindingKeys()...) } - if a.currentPage == page.ReplPage { - bindings = append(bindings, replKeyMap) + if a.currentPage == page.LogsPage { + bindings = append(bindings, logsKeyReturnKey) } - a.help.SetBindings(bindings) - components = append(components, a.help.View()) - } - components = append(components, a.status.View()) + a.help.SetBindings(bindings) - appView := lipgloss.JoinVertical(lipgloss.Top, components...) + overlay := a.help.View() + row := lipgloss.Height(appView) / 2 + row -= lipgloss.Height(overlay) / 2 + col := lipgloss.Width(appView) / 2 + col -= lipgloss.Width(overlay) / 2 + appView = layout.PlaceOverlay( + col, + row, + overlay, + appView, + true, + ) + } - if a.dialogVisible { - overlay := a.dialog.View() + if a.showQuit { + overlay := a.quit.View() row := lipgloss.Height(appView) / 2 row -= lipgloss.Height(overlay) / 2 col := lipgloss.Width(appView) / 2 @@ -289,30 +301,23 @@ func (a appModel) View() string { true, ) } + return appView } func New(app *app.App) tea.Model { - // homedir, _ := os.UserHomeDir() - // configPath := filepath.Join(homedir, ".termai.yaml") - // startPage := page.ChatPage - // if _, err := os.Stat(configPath); os.IsNotExist(err) { - // startPage = page.InitPage - // } - return &appModel{ currentPage: startPage, loadedPages: make(map[page.PageID]bool), - status: core.NewStatusCmp(), - help: core.NewHelpCmp(), - dialog: core.NewDialogCmp(), + status: core.NewStatusCmp(app.LSPClients), + help: dialog.NewHelpCmp(), + quit: dialog.NewQuitCmp(), + permissions: dialog.NewPermissionDialogCmp(), app: app, pages: map[page.PageID]tea.Model{ page.ChatPage: page.NewChatPage(app), page.LogsPage: page.NewLogsPage(), - page.InitPage: page.NewInitPage(), - page.ReplPage: page.NewReplPage(app), }, } } diff --git a/main.go b/main.go index 4bc8a22f0..2e6954646 100644 --- a/main.go +++ b/main.go @@ -2,8 +2,15 @@ package main import ( "github.com/kujtimiihoxha/termai/cmd" + "github.com/kujtimiihoxha/termai/internal/logging" ) func main() { + // Set up panic recovery for the main function + defer logging.RecoverPanic("main", func() { + // Perform any necessary cleanup before exit + logging.ErrorPersist("Application terminated due to unhandled panic") + }) + cmd.Execute() } -- cgit v1.2.3 From cc07f7a186995f428436bc1adc66a264a95171a4 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Wed, 16 Apr 2025 21:48:29 +0200 Subject: rename to opencode --- .opencode.json | 11 ++++ cmd/root.go | 14 ++-- go.mod | 2 +- internal/app/app.go | 18 +++--- internal/app/lsp.go | 8 +-- internal/config/config.go | 4 +- internal/db/connect.go | 6 +- internal/diff/diff.go | 4 +- internal/history/file.go | 4 +- internal/llm/agent/agent-tool.go | 10 +-- internal/llm/agent/agent.go | 18 +++--- internal/llm/agent/mcp-tools.go | 10 +-- internal/llm/agent/tools.go | 12 ++-- internal/llm/prompt/coder.go | 97 ++++++++++++---------------- internal/llm/prompt/prompt.go | 4 +- internal/llm/prompt/task.go | 4 +- internal/llm/prompt/title.go | 2 +- internal/llm/provider/anthropic.go | 8 +-- internal/llm/provider/bedrock.go | 4 +- internal/llm/provider/gemini.go | 8 +-- internal/llm/provider/openai.go | 8 +-- internal/llm/provider/provider.go | 6 +- internal/llm/tools/bash.go | 16 ++--- internal/llm/tools/diagnostics.go | 4 +- internal/llm/tools/edit.go | 10 +-- internal/llm/tools/edit_test.go | 2 +- internal/llm/tools/fetch.go | 6 +- internal/llm/tools/glob.go | 2 +- internal/llm/tools/grep.go | 2 +- internal/llm/tools/ls.go | 2 +- internal/llm/tools/mocks_test.go | 6 +- internal/llm/tools/shell/shell.go | 8 +-- internal/llm/tools/sourcegraph.go | 2 +- internal/llm/tools/view.go | 4 +- internal/llm/tools/write.go | 10 +-- internal/llm/tools/write_test.go | 2 +- internal/logging/writer.go | 2 +- internal/lsp/client.go | 6 +- internal/lsp/handlers.go | 8 +-- internal/lsp/language.go | 2 +- internal/lsp/methods.go | 2 +- internal/lsp/transport.go | 4 +- internal/lsp/util/edit.go | 2 +- internal/lsp/watcher/watcher.go | 8 +-- internal/message/content.go | 2 +- internal/message/message.go | 6 +- internal/permission/permission.go | 2 +- internal/session/session.go | 4 +- internal/tui/components/chat/chat.go | 8 +-- internal/tui/components/chat/editor.go | 10 +-- internal/tui/components/chat/messages.go | 22 +++---- internal/tui/components/chat/sidebar.go | 12 ++-- internal/tui/components/core/status.go | 12 ++-- internal/tui/components/dialog/help.go | 2 +- internal/tui/components/dialog/permission.go | 12 ++-- internal/tui/components/dialog/quit.go | 6 +- internal/tui/components/logs/details.go | 6 +- internal/tui/components/logs/table.go | 10 +-- internal/tui/layout/border.go | 2 +- internal/tui/layout/container.go | 2 +- internal/tui/layout/overlay.go | 4 +- internal/tui/layout/split.go | 2 +- internal/tui/page/chat.go | 10 +-- internal/tui/page/logs.go | 4 +- internal/tui/tui.go | 18 +++--- main.go | 4 +- 66 files changed, 266 insertions(+), 266 deletions(-) (limited to 'cmd') diff --git a/.opencode.json b/.opencode.json index b7fc19b52..4b2944f86 100644 --- a/.opencode.json +++ b/.opencode.json @@ -3,5 +3,16 @@ "gopls": { "command": "gopls" } + }, + "agents": { + "coder": { + "model": "gpt-4.1" + }, + "task": { + "model": "gpt-4.1" + }, + "title": { + "model": "gpt-4.1" + } } } diff --git a/cmd/root.go b/cmd/root.go index ff71747d5..f506e9940 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -8,13 +8,13 @@ import ( "time" tea "github.com/charmbracelet/bubbletea" - "github.com/kujtimiihoxha/termai/internal/app" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/db" - "github.com/kujtimiihoxha/termai/internal/llm/agent" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/pubsub" - "github.com/kujtimiihoxha/termai/internal/tui" + "github.com/kujtimiihoxha/opencode/internal/app" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/db" + "github.com/kujtimiihoxha/opencode/internal/llm/agent" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/pubsub" + "github.com/kujtimiihoxha/opencode/internal/tui" zone "github.com/lrstanley/bubblezone" "github.com/spf13/cobra" ) diff --git a/go.mod b/go.mod index 16c88d3a6..822e70dbd 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/kujtimiihoxha/termai +module github.com/kujtimiihoxha/opencode go 1.24.0 diff --git a/internal/app/app.go b/internal/app/app.go index 1c16ccc11..748fdaa7f 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -7,15 +7,15 @@ import ( "sync" "time" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/db" - "github.com/kujtimiihoxha/termai/internal/history" - "github.com/kujtimiihoxha/termai/internal/llm/agent" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/message" - "github.com/kujtimiihoxha/termai/internal/permission" - "github.com/kujtimiihoxha/termai/internal/session" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/db" + "github.com/kujtimiihoxha/opencode/internal/history" + "github.com/kujtimiihoxha/opencode/internal/llm/agent" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/lsp" + "github.com/kujtimiihoxha/opencode/internal/message" + "github.com/kujtimiihoxha/opencode/internal/permission" + "github.com/kujtimiihoxha/opencode/internal/session" ) type App struct { diff --git a/internal/app/lsp.go b/internal/app/lsp.go index 4a762f1a1..d8a35c8b3 100644 --- a/internal/app/lsp.go +++ b/internal/app/lsp.go @@ -4,10 +4,10 @@ import ( "context" "time" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/lsp/watcher" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/lsp" + "github.com/kujtimiihoxha/opencode/internal/lsp/watcher" ) func (app *App) initLSPClients(ctx context.Context) { diff --git a/internal/config/config.go b/internal/config/config.go index 147d6c83a..20a8bac97 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -7,8 +7,8 @@ import ( "os" "strings" - "github.com/kujtimiihoxha/termai/internal/llm/models" - "github.com/kujtimiihoxha/termai/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/llm/models" + "github.com/kujtimiihoxha/opencode/internal/logging" "github.com/spf13/viper" ) diff --git a/internal/db/connect.go b/internal/db/connect.go index 8bba9cad8..e850bc8d0 100644 --- a/internal/db/connect.go +++ b/internal/db/connect.go @@ -12,8 +12,8 @@ import ( "github.com/golang-migrate/migrate/v4/database/sqlite3" _ "github.com/mattn/go-sqlite3" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/logging" ) func Connect() (*sql.DB, error) { @@ -24,7 +24,7 @@ func Connect() (*sql.DB, error) { if err := os.MkdirAll(dataDir, 0o700); err != nil { return nil, fmt.Errorf("failed to create data directory: %w", err) } - dbPath := filepath.Join(dataDir, "termai.db") + dbPath := filepath.Join(dataDir, "opencode.db") // Open the SQLite database db, err := sql.Open("sqlite3", dbPath) if err != nil { diff --git a/internal/diff/diff.go b/internal/diff/diff.go index 829554c7e..f48079c9c 100644 --- a/internal/diff/diff.go +++ b/internal/diff/diff.go @@ -19,8 +19,8 @@ import ( "github.com/charmbracelet/x/ansi" "github.com/go-git/go-git/v5" "github.com/go-git/go-git/v5/plumbing/object" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/logging" "github.com/sergi/go-diff/diffmatchpatch" ) diff --git a/internal/history/file.go b/internal/history/file.go index 82017d4cf..1e8bc50bb 100644 --- a/internal/history/file.go +++ b/internal/history/file.go @@ -7,8 +7,8 @@ import ( "strings" "github.com/google/uuid" - "github.com/kujtimiihoxha/termai/internal/db" - "github.com/kujtimiihoxha/termai/internal/pubsub" + "github.com/kujtimiihoxha/opencode/internal/db" + "github.com/kujtimiihoxha/opencode/internal/pubsub" ) const ( diff --git a/internal/llm/agent/agent-tool.go b/internal/llm/agent/agent-tool.go index 308412bde..be6e09a9b 100644 --- a/internal/llm/agent/agent-tool.go +++ b/internal/llm/agent/agent-tool.go @@ -5,11 +5,11 @@ import ( "encoding/json" "fmt" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/message" - "github.com/kujtimiihoxha/termai/internal/session" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/llm/tools" + "github.com/kujtimiihoxha/opencode/internal/lsp" + "github.com/kujtimiihoxha/opencode/internal/message" + "github.com/kujtimiihoxha/opencode/internal/session" ) type agentTool struct { diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index ab2742ec1..a5dadb89d 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -7,15 +7,15 @@ import ( "strings" "sync" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/models" - "github.com/kujtimiihoxha/termai/internal/llm/prompt" - "github.com/kujtimiihoxha/termai/internal/llm/provider" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/message" - "github.com/kujtimiihoxha/termai/internal/permission" - "github.com/kujtimiihoxha/termai/internal/session" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/llm/models" + "github.com/kujtimiihoxha/opencode/internal/llm/prompt" + "github.com/kujtimiihoxha/opencode/internal/llm/provider" + "github.com/kujtimiihoxha/opencode/internal/llm/tools" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/message" + "github.com/kujtimiihoxha/opencode/internal/permission" + "github.com/kujtimiihoxha/opencode/internal/session" ) // Common errors diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go index c7ea4916c..16dddc1ba 100644 --- a/internal/llm/agent/mcp-tools.go +++ b/internal/llm/agent/mcp-tools.go @@ -5,11 +5,11 @@ import ( "encoding/json" "fmt" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/permission" - "github.com/kujtimiihoxha/termai/internal/version" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/llm/tools" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/permission" + "github.com/kujtimiihoxha/opencode/internal/version" "github.com/mark3labs/mcp-go/client" "github.com/mark3labs/mcp-go/mcp" diff --git a/internal/llm/agent/tools.go b/internal/llm/agent/tools.go index a37f1d65d..409d14273 100644 --- a/internal/llm/agent/tools.go +++ b/internal/llm/agent/tools.go @@ -3,12 +3,12 @@ package agent import ( "context" - "github.com/kujtimiihoxha/termai/internal/history" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/message" - "github.com/kujtimiihoxha/termai/internal/permission" - "github.com/kujtimiihoxha/termai/internal/session" + "github.com/kujtimiihoxha/opencode/internal/history" + "github.com/kujtimiihoxha/opencode/internal/llm/tools" + "github.com/kujtimiihoxha/opencode/internal/lsp" + "github.com/kujtimiihoxha/opencode/internal/message" + "github.com/kujtimiihoxha/opencode/internal/permission" + "github.com/kujtimiihoxha/opencode/internal/session" ) func CoderAgentTools( diff --git a/internal/llm/prompt/coder.go b/internal/llm/prompt/coder.go index 7439fd570..3a06911da 100644 --- a/internal/llm/prompt/coder.go +++ b/internal/llm/prompt/coder.go @@ -8,9 +8,9 @@ import ( "runtime" "time" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/models" - "github.com/kujtimiihoxha/termai/internal/llm/tools" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/llm/models" + "github.com/kujtimiihoxha/opencode/internal/llm/tools" ) func CoderPrompt(provider models.ModelProvider) string { @@ -24,69 +24,58 @@ func CoderPrompt(provider models.ModelProvider) string { return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation()) } -const baseOpenAICoderPrompt = `You are termAI, an autonomous CLI-based software engineer. Your job is to reduce user effort by proactively reasoning, inferring context, and solving software engineering tasks end-to-end with minimal prompting. - -# Your mindset -Act like a competent, efficient software engineer who is familiar with large codebases. You should: -- Think critically about user requests. -- Proactively search the codebase for related information. -- Infer likely commands, tools, or conventions. -- Write and edit code with minimal user input. -- Anticipate next steps (tests, lints, etc.), but never commit unless explicitly told. - -# Context awareness -- Before acting, infer the purpose of a file from its name, directory, and neighboring files. -- If a file or function appears malicious, refuse to interact with it or discuss it. -- If a termai.md file exists, auto-load it as memory. Offer to update it only if new useful info appears (commands, preferences, structure). - -# CLI communication -- Use GitHub-flavored markdown in monospace font. -- Be concise. Never add preambles or postambles unless asked. Max 4 lines per response. -- Never explain your code unless asked. Do not narrate actions. -- Avoid unnecessary questions. Infer, search, act. - -# Behavior guidelines -- Follow project conventions: naming, formatting, libraries, frameworks. -- Before using any library or framework, confirm it’s already used. -- Always look at the surrounding code to match existing style. -- Do not add comments unless the code is complex or the user asks. - -# Autonomy rules -You are allowed and expected to: -- Search for commands, tools, or config files before asking the user. -- Run multiple search tool calls concurrently to gather relevant context. -- Choose test, lint, and typecheck commands based on package files or scripts. -- Offer to store these commands in termai.md if not already present. - -# Example behavior -user: write tests for new feature -assistant: [searches for existing test patterns, finds appropriate location, generates test code using existing style, optionally asks to add test command to termai.md] +const baseOpenAICoderPrompt = ` +You are **OpenCode**, an autonomous CLI assistant for software‑engineering tasks. + +### ── INTERNAL REFLECTION ── +• Silently think step‑by‑step about the user request, directory layout, and tool calls (never reveal this). +• Formulate a plan, then execute without further approval unless a blocker triggers the Ask‑Only‑If rules. + +### ── PUBLIC RESPONSE RULES ── +• Visible reply ≤ 4 lines; no fluff, preamble, or postamble. +• Use GitHub‑flavored Markdown. +• When running a non‑trivial shell command, add ≤ 1 brief purpose sentence. + +### ── CONTEXT & MEMORY ── +• Infer file intent from directory structure before editing. +• Auto‑load 'OpenCode.md'; ask once before writing new reusable commands or style notes. -user: how do I typecheck this codebase? -assistant: [searches for known commands, infers package manager, checks for scripts or config files] -tsc --noEmit +### ── AUTONOMY PRIORITY ── +**Ask‑Only‑If Decision Tree:** +1. **Safety risk?** (e.g., destructive command, secret exposure) → ask. +2. **Critical unknown?** (no docs/tests; cannot infer) → ask. +3. **Tool failure after two self‑attempts?** → ask. +Otherwise, proceed autonomously. -user: is X function used anywhere else? -assistant: [searches repo for references, returns file paths and lines] +### ── SAFETY & STYLE ── +• Mimic existing code style; verify libraries exist before import. +• Never commit unless explicitly told. +• After edits, run lint & type‑check (ask for commands once, then offer to store in 'OpenCode.md'). +• Protect secrets; follow standard security practices :contentReference[oaicite:2]{index=2}. -# Tool usage -- Use parallel calls when possible. -- Use file search and content tools before asking the user. -- Do not ask the user for information unless it cannot be determined via tools. +### ── TOOL USAGE ── +• Batch independent Agent search/file calls in one block for efficiency :contentReference[oaicite:3]{index=3}. +• Communicate with the user only via visible text; do not expose tool output or internal reasoning. -Never commit changes unless the user explicitly asks you to.` +### ── EXAMPLES ── +user: list files +assistant: ls + +user: write tests for new feature +assistant: [searches & edits autonomously, no extra chit‑chat] +` -const baseAnthropicCoderPrompt = `You are termAI, an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user. +const baseAnthropicCoderPrompt = `You are OpenCode, an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user. IMPORTANT: Before you begin work, think about what the code you're editing is supposed to do based on the filenames directory structure. # Memory -If the current working directory contains a file called termai.md, it will be automatically added to your context. This file serves multiple purposes: +If the current working directory contains a file called OpenCode.md, it will be automatically added to your context. This file serves multiple purposes: 1. Storing frequently used bash commands (build, test, lint, etc.) so you can use them without searching each time 2. Recording the user's code style preferences (naming conventions, preferred libraries, etc.) 3. Maintaining useful information about the codebase structure and organization -When you spend time searching for commands to typecheck, lint, build, or test, you should ask the user if it's okay to add those commands to termai.md. Similarly, when learning about code style preferences or important codebase information, ask if it's okay to add that to termai.md so you can remember it for next time. +When you spend time searching for commands to typecheck, lint, build, or test, you should ask the user if it's okay to add those commands to OpenCode.md. Similarly, when learning about code style preferences or important codebase information, ask if it's okay to add that to OpenCode.md so you can remember it for next time. # Tone and style You should be concise, direct, and to the point. When you run a non-trivial bash command, you should explain what the command does and why you are running it, to make sure the user understands what you are doing (this is especially important when you are running a command that will make changes to the user's system). @@ -161,7 +150,7 @@ The user will primarily request you perform software engineering tasks. This inc 1. Use the available search tools to understand the codebase and the user's query. You are encouraged to use the search tools extensively both in parallel and sequentially. 2. Implement the solution using all tools available to you 3. Verify the solution if possible with tests. NEVER assume specific test framework or test script. Check the README or search codebase to determine the testing approach. -4. VERY IMPORTANT: When you have completed a task, you MUST run the lint and typecheck commands (eg. npm run lint, npm run typecheck, ruff, etc.) if they were provided to you to ensure your code is correct. If you are unable to find the correct command, ask the user for the command to run and if they supply it, proactively suggest writing it to termai.md so that you will know to run it next time. +4. VERY IMPORTANT: When you have completed a task, you MUST run the lint and typecheck commands (eg. npm run lint, npm run typecheck, ruff, etc.) if they were provided to you to ensure your code is correct. If you are unable to find the correct command, ask the user for the command to run and if they supply it, proactively suggest writing it to opencode.md so that you will know to run it next time. NEVER commit changes unless the user explicitly asks you to. It is VERY IMPORTANT to only commit when explicitly asked, otherwise the user will feel that you are being too proactive. diff --git a/internal/llm/prompt/prompt.go b/internal/llm/prompt/prompt.go index 63fc2df7b..cdc3560ce 100644 --- a/internal/llm/prompt/prompt.go +++ b/internal/llm/prompt/prompt.go @@ -1,8 +1,8 @@ package prompt import ( - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/models" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/llm/models" ) func GetAgentPrompt(agentName config.AgentName, provider models.ModelProvider) string { diff --git a/internal/llm/prompt/task.go b/internal/llm/prompt/task.go index 8bf604ad9..88cd1a0f4 100644 --- a/internal/llm/prompt/task.go +++ b/internal/llm/prompt/task.go @@ -3,11 +3,11 @@ package prompt import ( "fmt" - "github.com/kujtimiihoxha/termai/internal/llm/models" + "github.com/kujtimiihoxha/opencode/internal/llm/models" ) func TaskPrompt(_ models.ModelProvider) string { - agentPrompt := `You are an agent for termAI. Given the user's prompt, you should use the tools available to you to answer the user's question. + agentPrompt := `You are an agent for OpenCode. Given the user's prompt, you should use the tools available to you to answer the user's question. Notes: 1. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is .", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...". 2. When relevant, share file names and code snippets relevant to the query diff --git a/internal/llm/prompt/title.go b/internal/llm/prompt/title.go index 3023a8550..6e5289b24 100644 --- a/internal/llm/prompt/title.go +++ b/internal/llm/prompt/title.go @@ -1,6 +1,6 @@ package prompt -import "github.com/kujtimiihoxha/termai/internal/llm/models" +import "github.com/kujtimiihoxha/opencode/internal/llm/models" func TitlePrompt(_ models.ModelProvider) string { return `you will generate a short title based on the first message a user begins a conversation with diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go index c3a4efc49..7bbc02103 100644 --- a/internal/llm/provider/anthropic.go +++ b/internal/llm/provider/anthropic.go @@ -12,10 +12,10 @@ import ( "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/bedrock" "github.com/anthropics/anthropic-sdk-go/option" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/llm/tools" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/message" ) type anthropicOptions struct { diff --git a/internal/llm/provider/bedrock.go b/internal/llm/provider/bedrock.go index d76925ad1..9415b30fe 100644 --- a/internal/llm/provider/bedrock.go +++ b/internal/llm/provider/bedrock.go @@ -7,8 +7,8 @@ import ( "os" "strings" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/opencode/internal/llm/tools" + "github.com/kujtimiihoxha/opencode/internal/message" ) type bedrockOptions struct { diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go index 804baea28..384bff900 100644 --- a/internal/llm/provider/gemini.go +++ b/internal/llm/provider/gemini.go @@ -11,10 +11,10 @@ import ( "github.com/google/generative-ai-go/genai" "github.com/google/uuid" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/llm/tools" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/message" "google.golang.org/api/iterator" "google.golang.org/api/option" ) diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index 9c2ad2012..13ce934f2 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -8,10 +8,10 @@ import ( "io" "time" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/llm/tools" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/message" "github.com/openai/openai-go" "github.com/openai/openai-go/option" ) diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index 1a5b3dc8a..e04bee71b 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -4,9 +4,9 @@ import ( "context" "fmt" - "github.com/kujtimiihoxha/termai/internal/llm/models" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/opencode/internal/llm/models" + "github.com/kujtimiihoxha/opencode/internal/llm/tools" + "github.com/kujtimiihoxha/opencode/internal/message" ) type EventType string diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index c7c970e5a..18533b761 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -7,9 +7,9 @@ import ( "strings" "time" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/tools/shell" - "github.com/kujtimiihoxha/termai/internal/permission" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/llm/tools/shell" + "github.com/kujtimiihoxha/opencode/internal/permission" ) type BashParams struct { @@ -122,16 +122,16 @@ When the user asks you to create a new git commit, follow these steps carefully: 4. Create the commit with a message ending with: -🤖 Generated with termai -Co-Authored-By: termai +🤖 Generated with opencode +Co-Authored-By: opencode - In order to ensure good formatting, ALWAYS pass the commit message via a HEREDOC, a la this example: git commit -m "$(cat <<'EOF' Commit message here. - 🤖 Generated with termai - Co-Authored-By: termai + 🤖 Generated with opencode + Co-Authored-By: opencode EOF )" @@ -193,7 +193,7 @@ gh pr create --title "the pr title" --body "$(cat <<'EOF' ## Test plan [Checklist of TODOs for testing the pull request...] -🤖 Generated with termai +🤖 Generated with opencode EOF )" diff --git a/internal/llm/tools/diagnostics.go b/internal/llm/tools/diagnostics.go index b7b2bb8ba..82989c774 100644 --- a/internal/llm/tools/diagnostics.go +++ b/internal/llm/tools/diagnostics.go @@ -9,8 +9,8 @@ import ( "strings" "time" - "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/lsp/protocol" + "github.com/kujtimiihoxha/opencode/internal/lsp" + "github.com/kujtimiihoxha/opencode/internal/lsp/protocol" ) type DiagnosticsParams struct { diff --git a/internal/llm/tools/edit.go b/internal/llm/tools/edit.go index 148e7aba7..6a1616010 100644 --- a/internal/llm/tools/edit.go +++ b/internal/llm/tools/edit.go @@ -9,11 +9,11 @@ import ( "strings" "time" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/diff" - "github.com/kujtimiihoxha/termai/internal/history" - "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/permission" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/diff" + "github.com/kujtimiihoxha/opencode/internal/history" + "github.com/kujtimiihoxha/opencode/internal/lsp" + "github.com/kujtimiihoxha/opencode/internal/permission" ) type EditParams struct { diff --git a/internal/llm/tools/edit_test.go b/internal/llm/tools/edit_test.go index 0971775dd..1b58a0d7d 100644 --- a/internal/llm/tools/edit_test.go +++ b/internal/llm/tools/edit_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - "github.com/kujtimiihoxha/termai/internal/lsp" + "github.com/kujtimiihoxha/opencode/internal/lsp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/internal/llm/tools/fetch.go b/internal/llm/tools/fetch.go index 91bcb36a0..827755863 100644 --- a/internal/llm/tools/fetch.go +++ b/internal/llm/tools/fetch.go @@ -11,8 +11,8 @@ import ( md "github.com/JohannesKaufmann/html-to-markdown" "github.com/PuerkitoBio/goquery" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/permission" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/permission" ) type FetchParams struct { @@ -146,7 +146,7 @@ func (t *fetchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error return ToolResponse{}, fmt.Errorf("failed to create request: %w", err) } - req.Header.Set("User-Agent", "termai/1.0") + req.Header.Set("User-Agent", "opencode/1.0") resp, err := client.Do(req) if err != nil { diff --git a/internal/llm/tools/glob.go b/internal/llm/tools/glob.go index 7b4fb1187..40262ce2b 100644 --- a/internal/llm/tools/glob.go +++ b/internal/llm/tools/glob.go @@ -12,7 +12,7 @@ import ( "time" "github.com/bmatcuk/doublestar/v4" - "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/opencode/internal/config" ) const ( diff --git a/internal/llm/tools/grep.go b/internal/llm/tools/grep.go index 19333f50b..3436dd7eb 100644 --- a/internal/llm/tools/grep.go +++ b/internal/llm/tools/grep.go @@ -13,7 +13,7 @@ import ( "strings" "time" - "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/opencode/internal/config" ) type GrepParams struct { diff --git a/internal/llm/tools/ls.go b/internal/llm/tools/ls.go index a63bf0eeb..05f300c0e 100644 --- a/internal/llm/tools/ls.go +++ b/internal/llm/tools/ls.go @@ -8,7 +8,7 @@ import ( "path/filepath" "strings" - "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/opencode/internal/config" ) type LSParams struct { diff --git a/internal/llm/tools/mocks_test.go b/internal/llm/tools/mocks_test.go index 321f09ac1..81993160c 100644 --- a/internal/llm/tools/mocks_test.go +++ b/internal/llm/tools/mocks_test.go @@ -9,9 +9,9 @@ import ( "time" "github.com/google/uuid" - "github.com/kujtimiihoxha/termai/internal/history" - "github.com/kujtimiihoxha/termai/internal/permission" - "github.com/kujtimiihoxha/termai/internal/pubsub" + "github.com/kujtimiihoxha/opencode/internal/history" + "github.com/kujtimiihoxha/opencode/internal/permission" + "github.com/kujtimiihoxha/opencode/internal/pubsub" ) // Mock permission service for testing diff --git a/internal/llm/tools/shell/shell.go b/internal/llm/tools/shell/shell.go index 4a776478a..e25bdf3ea 100644 --- a/internal/llm/tools/shell/shell.go +++ b/internal/llm/tools/shell/shell.go @@ -126,10 +126,10 @@ func (s *PersistentShell) execCommand(command string, timeout time.Duration, ctx } tempDir := os.TempDir() - stdoutFile := filepath.Join(tempDir, fmt.Sprintf("termai-stdout-%d", time.Now().UnixNano())) - stderrFile := filepath.Join(tempDir, fmt.Sprintf("termai-stderr-%d", time.Now().UnixNano())) - statusFile := filepath.Join(tempDir, fmt.Sprintf("termai-status-%d", time.Now().UnixNano())) - cwdFile := filepath.Join(tempDir, fmt.Sprintf("termai-cwd-%d", time.Now().UnixNano())) + stdoutFile := filepath.Join(tempDir, fmt.Sprintf("opencode-stdout-%d", time.Now().UnixNano())) + stderrFile := filepath.Join(tempDir, fmt.Sprintf("opencode-stderr-%d", time.Now().UnixNano())) + statusFile := filepath.Join(tempDir, fmt.Sprintf("opencode-status-%d", time.Now().UnixNano())) + cwdFile := filepath.Join(tempDir, fmt.Sprintf("opencode-cwd-%d", time.Now().UnixNano())) defer func() { os.Remove(stdoutFile) diff --git a/internal/llm/tools/sourcegraph.go b/internal/llm/tools/sourcegraph.go index a6f2c8afb..0d38c975f 100644 --- a/internal/llm/tools/sourcegraph.go +++ b/internal/llm/tools/sourcegraph.go @@ -218,7 +218,7 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse, } req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", "termai/1.0") + req.Header.Set("User-Agent", "opencode/1.0") resp, err := client.Do(req) if err != nil { diff --git a/internal/llm/tools/view.go b/internal/llm/tools/view.go index 7450a84bf..3fa4ca116 100644 --- a/internal/llm/tools/view.go +++ b/internal/llm/tools/view.go @@ -10,8 +10,8 @@ import ( "path/filepath" "strings" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/lsp" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/lsp" ) type ViewParams struct { diff --git a/internal/llm/tools/write.go b/internal/llm/tools/write.go index bb49381fd..261865c39 100644 --- a/internal/llm/tools/write.go +++ b/internal/llm/tools/write.go @@ -8,11 +8,11 @@ import ( "path/filepath" "time" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/diff" - "github.com/kujtimiihoxha/termai/internal/history" - "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/permission" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/diff" + "github.com/kujtimiihoxha/opencode/internal/history" + "github.com/kujtimiihoxha/opencode/internal/lsp" + "github.com/kujtimiihoxha/opencode/internal/permission" ) type WriteParams struct { diff --git a/internal/llm/tools/write_test.go b/internal/llm/tools/write_test.go index 2264f36fb..b5ecb3fda 100644 --- a/internal/llm/tools/write_test.go +++ b/internal/llm/tools/write_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - "github.com/kujtimiihoxha/termai/internal/lsp" + "github.com/kujtimiihoxha/opencode/internal/lsp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/internal/logging/writer.go b/internal/logging/writer.go index 9fe469c5e..1dc07e853 100644 --- a/internal/logging/writer.go +++ b/internal/logging/writer.go @@ -9,7 +9,7 @@ import ( "time" "github.com/go-logfmt/logfmt" - "github.com/kujtimiihoxha/termai/internal/pubsub" + "github.com/kujtimiihoxha/opencode/internal/pubsub" ) const ( diff --git a/internal/lsp/client.go b/internal/lsp/client.go index 0f03e7fcb..dad07f3c0 100644 --- a/internal/lsp/client.go +++ b/internal/lsp/client.go @@ -13,9 +13,9 @@ import ( "sync/atomic" "time" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/lsp/protocol" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/lsp/protocol" ) type Client struct { diff --git a/internal/lsp/handlers.go b/internal/lsp/handlers.go index c3088d685..7a11286e6 100644 --- a/internal/lsp/handlers.go +++ b/internal/lsp/handlers.go @@ -3,10 +3,10 @@ package lsp import ( "encoding/json" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/lsp/protocol" - "github.com/kujtimiihoxha/termai/internal/lsp/util" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/lsp/protocol" + "github.com/kujtimiihoxha/opencode/internal/lsp/util" ) // Requests diff --git a/internal/lsp/language.go b/internal/lsp/language.go index 2e276c464..65ccd54f3 100644 --- a/internal/lsp/language.go +++ b/internal/lsp/language.go @@ -4,7 +4,7 @@ import ( "path/filepath" "strings" - "github.com/kujtimiihoxha/termai/internal/lsp/protocol" + "github.com/kujtimiihoxha/opencode/internal/lsp/protocol" ) func DetectLanguageID(uri string) protocol.LanguageKind { diff --git a/internal/lsp/methods.go b/internal/lsp/methods.go index 079b3bfe3..ab33d7e1b 100644 --- a/internal/lsp/methods.go +++ b/internal/lsp/methods.go @@ -4,7 +4,7 @@ package lsp import ( "context" - "github.com/kujtimiihoxha/termai/internal/lsp/protocol" + "github.com/kujtimiihoxha/opencode/internal/lsp/protocol" ) // Implementation sends a textDocument/implementation request to the LSP server. diff --git a/internal/lsp/transport.go b/internal/lsp/transport.go index 89255fd78..fe59b0fbb 100644 --- a/internal/lsp/transport.go +++ b/internal/lsp/transport.go @@ -8,8 +8,8 @@ import ( "io" "strings" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/logging" ) // Write writes an LSP message to the given writer diff --git a/internal/lsp/util/edit.go b/internal/lsp/util/edit.go index 3b94fb39f..52f03ee77 100644 --- a/internal/lsp/util/edit.go +++ b/internal/lsp/util/edit.go @@ -7,7 +7,7 @@ import ( "sort" "strings" - "github.com/kujtimiihoxha/termai/internal/lsp/protocol" + "github.com/kujtimiihoxha/opencode/internal/lsp/protocol" ) func applyTextEdits(uri protocol.DocumentUri, edits []protocol.TextEdit) error { diff --git a/internal/lsp/watcher/watcher.go b/internal/lsp/watcher/watcher.go index 156f38e1a..595c78db9 100644 --- a/internal/lsp/watcher/watcher.go +++ b/internal/lsp/watcher/watcher.go @@ -10,10 +10,10 @@ import ( "time" "github.com/fsnotify/fsnotify" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/lsp/protocol" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/lsp" + "github.com/kujtimiihoxha/opencode/internal/lsp/protocol" ) // WorkspaceWatcher manages LSP file watching diff --git a/internal/message/content.go b/internal/message/content.go index f9e76b11c..f52449f4a 100644 --- a/internal/message/content.go +++ b/internal/message/content.go @@ -5,7 +5,7 @@ import ( "slices" "time" - "github.com/kujtimiihoxha/termai/internal/llm/models" + "github.com/kujtimiihoxha/opencode/internal/llm/models" ) type MessageRole string diff --git a/internal/message/message.go b/internal/message/message.go index 2871780a7..f165fcfc7 100644 --- a/internal/message/message.go +++ b/internal/message/message.go @@ -7,9 +7,9 @@ import ( "fmt" "github.com/google/uuid" - "github.com/kujtimiihoxha/termai/internal/db" - "github.com/kujtimiihoxha/termai/internal/llm/models" - "github.com/kujtimiihoxha/termai/internal/pubsub" + "github.com/kujtimiihoxha/opencode/internal/db" + "github.com/kujtimiihoxha/opencode/internal/llm/models" + "github.com/kujtimiihoxha/opencode/internal/pubsub" ) type CreateMessageParams struct { diff --git a/internal/permission/permission.go b/internal/permission/permission.go index 8aa280906..4cb379dea 100644 --- a/internal/permission/permission.go +++ b/internal/permission/permission.go @@ -6,7 +6,7 @@ import ( "time" "github.com/google/uuid" - "github.com/kujtimiihoxha/termai/internal/pubsub" + "github.com/kujtimiihoxha/opencode/internal/pubsub" ) var ErrorPermissionDenied = errors.New("permission denied") diff --git a/internal/session/session.go b/internal/session/session.go index 019019df4..280da1ff0 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -5,8 +5,8 @@ import ( "database/sql" "github.com/google/uuid" - "github.com/kujtimiihoxha/termai/internal/db" - "github.com/kujtimiihoxha/termai/internal/pubsub" + "github.com/kujtimiihoxha/opencode/internal/db" + "github.com/kujtimiihoxha/opencode/internal/pubsub" ) type Session struct { diff --git a/internal/tui/components/chat/chat.go b/internal/tui/components/chat/chat.go index e98001efa..52ff4c8bf 100644 --- a/internal/tui/components/chat/chat.go +++ b/internal/tui/components/chat/chat.go @@ -5,10 +5,10 @@ import ( "github.com/charmbracelet/lipgloss" "github.com/charmbracelet/x/ansi" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/session" - "github.com/kujtimiihoxha/termai/internal/tui/styles" - "github.com/kujtimiihoxha/termai/internal/version" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/session" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/version" ) type SendMsg struct { diff --git a/internal/tui/components/chat/editor.go b/internal/tui/components/chat/editor.go index e2f4da9e2..4d6ef5ca0 100644 --- a/internal/tui/components/chat/editor.go +++ b/internal/tui/components/chat/editor.go @@ -5,11 +5,11 @@ import ( "github.com/charmbracelet/bubbles/textarea" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/app" - "github.com/kujtimiihoxha/termai/internal/session" - "github.com/kujtimiihoxha/termai/internal/tui/layout" - "github.com/kujtimiihoxha/termai/internal/tui/styles" - "github.com/kujtimiihoxha/termai/internal/tui/util" + "github.com/kujtimiihoxha/opencode/internal/app" + "github.com/kujtimiihoxha/opencode/internal/session" + "github.com/kujtimiihoxha/opencode/internal/tui/layout" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/tui/util" ) type editorCmp struct { diff --git a/internal/tui/components/chat/messages.go b/internal/tui/components/chat/messages.go index 26a98970e..c2ce7d88b 100644 --- a/internal/tui/components/chat/messages.go +++ b/internal/tui/components/chat/messages.go @@ -15,17 +15,17 @@ import ( "github.com/charmbracelet/glamour" "github.com/charmbracelet/lipgloss" "github.com/charmbracelet/x/ansi" - "github.com/kujtimiihoxha/termai/internal/app" - "github.com/kujtimiihoxha/termai/internal/llm/agent" - "github.com/kujtimiihoxha/termai/internal/llm/models" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/message" - "github.com/kujtimiihoxha/termai/internal/pubsub" - "github.com/kujtimiihoxha/termai/internal/session" - "github.com/kujtimiihoxha/termai/internal/tui/layout" - "github.com/kujtimiihoxha/termai/internal/tui/styles" - "github.com/kujtimiihoxha/termai/internal/tui/util" + "github.com/kujtimiihoxha/opencode/internal/app" + "github.com/kujtimiihoxha/opencode/internal/llm/agent" + "github.com/kujtimiihoxha/opencode/internal/llm/models" + "github.com/kujtimiihoxha/opencode/internal/llm/tools" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/message" + "github.com/kujtimiihoxha/opencode/internal/pubsub" + "github.com/kujtimiihoxha/opencode/internal/session" + "github.com/kujtimiihoxha/opencode/internal/tui/layout" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/tui/util" ) type uiMessageType int diff --git a/internal/tui/components/chat/sidebar.go b/internal/tui/components/chat/sidebar.go index b90269d1a..54b39f4a1 100644 --- a/internal/tui/components/chat/sidebar.go +++ b/internal/tui/components/chat/sidebar.go @@ -7,12 +7,12 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/diff" - "github.com/kujtimiihoxha/termai/internal/history" - "github.com/kujtimiihoxha/termai/internal/pubsub" - "github.com/kujtimiihoxha/termai/internal/session" - "github.com/kujtimiihoxha/termai/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/diff" + "github.com/kujtimiihoxha/opencode/internal/history" + "github.com/kujtimiihoxha/opencode/internal/pubsub" + "github.com/kujtimiihoxha/opencode/internal/session" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" ) type sidebarCmp struct { diff --git a/internal/tui/components/core/status.go b/internal/tui/components/core/status.go index 089dffa2c..411cac1c5 100644 --- a/internal/tui/components/core/status.go +++ b/internal/tui/components/core/status.go @@ -7,12 +7,12 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/models" - "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/lsp/protocol" - "github.com/kujtimiihoxha/termai/internal/tui/styles" - "github.com/kujtimiihoxha/termai/internal/tui/util" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/llm/models" + "github.com/kujtimiihoxha/opencode/internal/lsp" + "github.com/kujtimiihoxha/opencode/internal/lsp/protocol" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/tui/util" ) type statusCmp struct { diff --git a/internal/tui/components/dialog/help.go b/internal/tui/components/dialog/help.go index 1d3c2b077..6242017f1 100644 --- a/internal/tui/components/dialog/help.go +++ b/internal/tui/components/dialog/help.go @@ -6,7 +6,7 @@ import ( "github.com/charmbracelet/bubbles/key" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" ) type helpCmp struct { diff --git a/internal/tui/components/dialog/permission.go b/internal/tui/components/dialog/permission.go index 9c55effde..200a7970d 100644 --- a/internal/tui/components/dialog/permission.go +++ b/internal/tui/components/dialog/permission.go @@ -9,12 +9,12 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/glamour" "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/diff" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/permission" - "github.com/kujtimiihoxha/termai/internal/tui/layout" - "github.com/kujtimiihoxha/termai/internal/tui/styles" - "github.com/kujtimiihoxha/termai/internal/tui/util" + "github.com/kujtimiihoxha/opencode/internal/diff" + "github.com/kujtimiihoxha/opencode/internal/llm/tools" + "github.com/kujtimiihoxha/opencode/internal/permission" + "github.com/kujtimiihoxha/opencode/internal/tui/layout" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/tui/util" ) type PermissionAction string diff --git a/internal/tui/components/dialog/quit.go b/internal/tui/components/dialog/quit.go index 10d9ba8a2..5bbe6696c 100644 --- a/internal/tui/components/dialog/quit.go +++ b/internal/tui/components/dialog/quit.go @@ -6,9 +6,9 @@ import ( "github.com/charmbracelet/bubbles/key" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/tui/layout" - "github.com/kujtimiihoxha/termai/internal/tui/styles" - "github.com/kujtimiihoxha/termai/internal/tui/util" + "github.com/kujtimiihoxha/opencode/internal/tui/layout" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/tui/util" ) const question = "Are you sure you want to quit?" diff --git a/internal/tui/components/logs/details.go b/internal/tui/components/logs/details.go index 18eb1a526..3a8f17999 100644 --- a/internal/tui/components/logs/details.go +++ b/internal/tui/components/logs/details.go @@ -9,9 +9,9 @@ import ( "github.com/charmbracelet/bubbles/viewport" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/tui/layout" - "github.com/kujtimiihoxha/termai/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/tui/layout" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" ) type DetailComponent interface { diff --git a/internal/tui/components/logs/table.go b/internal/tui/components/logs/table.go index 6e8eb58b1..dc6184e3d 100644 --- a/internal/tui/components/logs/table.go +++ b/internal/tui/components/logs/table.go @@ -7,11 +7,11 @@ import ( "github.com/charmbracelet/bubbles/key" "github.com/charmbracelet/bubbles/table" tea "github.com/charmbracelet/bubbletea" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/pubsub" - "github.com/kujtimiihoxha/termai/internal/tui/layout" - "github.com/kujtimiihoxha/termai/internal/tui/styles" - "github.com/kujtimiihoxha/termai/internal/tui/util" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/pubsub" + "github.com/kujtimiihoxha/opencode/internal/tui/layout" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/tui/util" ) type TableComponent interface { diff --git a/internal/tui/layout/border.go b/internal/tui/layout/border.go index 8fe5c430c..ea9f5e0bc 100644 --- a/internal/tui/layout/border.go +++ b/internal/tui/layout/border.go @@ -5,7 +5,7 @@ import ( "strings" "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" ) type BorderPosition int diff --git a/internal/tui/layout/container.go b/internal/tui/layout/container.go index db07d49fb..603699955 100644 --- a/internal/tui/layout/container.go +++ b/internal/tui/layout/container.go @@ -4,7 +4,7 @@ import ( "github.com/charmbracelet/bubbles/key" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" ) type Container interface { diff --git a/internal/tui/layout/overlay.go b/internal/tui/layout/overlay.go index 4a1bcf661..4c05e8462 100644 --- a/internal/tui/layout/overlay.go +++ b/internal/tui/layout/overlay.go @@ -5,8 +5,8 @@ import ( "strings" "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/tui/styles" - "github.com/kujtimiihoxha/termai/internal/tui/util" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/tui/util" "github.com/mattn/go-runewidth" "github.com/muesli/ansi" "github.com/muesli/reflow/truncate" diff --git a/internal/tui/layout/split.go b/internal/tui/layout/split.go index 6482fc74c..bfb616a53 100644 --- a/internal/tui/layout/split.go +++ b/internal/tui/layout/split.go @@ -4,7 +4,7 @@ import ( "github.com/charmbracelet/bubbles/key" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" ) type SplitPaneLayout interface { diff --git a/internal/tui/page/chat.go b/internal/tui/page/chat.go index cebc0e461..c268e677f 100644 --- a/internal/tui/page/chat.go +++ b/internal/tui/page/chat.go @@ -5,11 +5,11 @@ import ( "github.com/charmbracelet/bubbles/key" tea "github.com/charmbracelet/bubbletea" - "github.com/kujtimiihoxha/termai/internal/app" - "github.com/kujtimiihoxha/termai/internal/session" - "github.com/kujtimiihoxha/termai/internal/tui/components/chat" - "github.com/kujtimiihoxha/termai/internal/tui/layout" - "github.com/kujtimiihoxha/termai/internal/tui/util" + "github.com/kujtimiihoxha/opencode/internal/app" + "github.com/kujtimiihoxha/opencode/internal/session" + "github.com/kujtimiihoxha/opencode/internal/tui/components/chat" + "github.com/kujtimiihoxha/opencode/internal/tui/layout" + "github.com/kujtimiihoxha/opencode/internal/tui/util" ) var ChatPage PageID = "chat" diff --git a/internal/tui/page/logs.go b/internal/tui/page/logs.go index d1e557eab..c77a033f4 100644 --- a/internal/tui/page/logs.go +++ b/internal/tui/page/logs.go @@ -2,8 +2,8 @@ package page import ( tea "github.com/charmbracelet/bubbletea" - "github.com/kujtimiihoxha/termai/internal/tui/components/logs" - "github.com/kujtimiihoxha/termai/internal/tui/layout" + "github.com/kujtimiihoxha/opencode/internal/tui/components/logs" + "github.com/kujtimiihoxha/opencode/internal/tui/layout" ) var LogsPage PageID = "logs" diff --git a/internal/tui/tui.go b/internal/tui/tui.go index dff7ad63d..657de6b6e 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -4,15 +4,15 @@ import ( "github.com/charmbracelet/bubbles/key" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/app" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/permission" - "github.com/kujtimiihoxha/termai/internal/pubsub" - "github.com/kujtimiihoxha/termai/internal/tui/components/core" - "github.com/kujtimiihoxha/termai/internal/tui/components/dialog" - "github.com/kujtimiihoxha/termai/internal/tui/layout" - "github.com/kujtimiihoxha/termai/internal/tui/page" - "github.com/kujtimiihoxha/termai/internal/tui/util" + "github.com/kujtimiihoxha/opencode/internal/app" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/permission" + "github.com/kujtimiihoxha/opencode/internal/pubsub" + "github.com/kujtimiihoxha/opencode/internal/tui/components/core" + "github.com/kujtimiihoxha/opencode/internal/tui/components/dialog" + "github.com/kujtimiihoxha/opencode/internal/tui/layout" + "github.com/kujtimiihoxha/opencode/internal/tui/page" + "github.com/kujtimiihoxha/opencode/internal/tui/util" ) type keyMap struct { diff --git a/main.go b/main.go index 2e6954646..06578c7ef 100644 --- a/main.go +++ b/main.go @@ -1,8 +1,8 @@ package main import ( - "github.com/kujtimiihoxha/termai/cmd" - "github.com/kujtimiihoxha/termai/internal/logging" + "github.com/kujtimiihoxha/opencode/cmd" + "github.com/kujtimiihoxha/opencode/internal/logging" ) func main() { -- cgit v1.2.3 From e7bb99baab5e6968ce0351d6ad219ed21ceec4df Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 21 Apr 2025 13:33:51 +0200 Subject: fix the memory bug --- README.md | 7 +++- cmd/root.go | 24 ++++++------ internal/pubsub/broker.go | 72 ++++++++++++++++++++++------------ internal/tui/components/chat/list.go | 1 + internal/tui/components/core/status.go | 15 +++++-- internal/tui/tui.go | 63 +++++++++++++++++++++++------ 6 files changed, 127 insertions(+), 55 deletions(-) (limited to 'cmd') diff --git a/README.md b/README.md index ef55b6929..075114fc3 100644 --- a/README.md +++ b/README.md @@ -351,9 +351,12 @@ go build -o opencode ## Acknowledgments -OpenCode builds upon the work of several open source projects and developers: +OpenCode gratefully acknowledges the contributions and support from these key individuals: -- [@isaacphi](https://github.com/isaacphi) - LSP client implementation +- [@isaacphi](https://github.com/isaacphi) - For the [mcp-language-server](https://github.com/isaacphi/mcp-language-server) project which provided the foundation for our LSP client implementation +- [@adamdottv](https://github.com/adamdottv) - For the design direction and UI/UX architecture + +Special thanks to the broader open source community whose tools and libraries have made this project possible. ## License diff --git a/cmd/root.go b/cmd/root.go index f506e9940..54280ecaa 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -79,7 +79,7 @@ var rootCmd = &cobra.Command{ initMCPTools(ctx, app) // Setup the subscriptions, this will send services events to the TUI - ch, cancelSubs := setupSubscriptions(app) + ch, cancelSubs := setupSubscriptions(app, ctx) // Create a context for the TUI message handler tuiCtx, tuiCancel := context.WithCancel(ctx) @@ -174,21 +174,21 @@ func setupSubscriber[T any]( defer wg.Done() defer logging.RecoverPanic(fmt.Sprintf("subscription-%s", name), nil) + subCh := subscriber(ctx) + for { select { - case event, ok := <-subscriber(ctx): + case event, ok := <-subCh: if !ok { logging.Info("%s subscription channel closed", name) return } - // Convert generic event to tea.Msg if needed var msg tea.Msg = event - // Non-blocking send with timeout to prevent deadlocks select { case outputCh <- msg: - case <-time.After(500 * time.Millisecond): + case <-time.After(2 * time.Second): logging.Warn("%s message dropped due to slow consumer", name) case <-ctx.Done(): logging.Info("%s subscription cancelled", name) @@ -202,23 +202,21 @@ func setupSubscriber[T any]( }() } -func setupSubscriptions(app *app.App) (chan tea.Msg, func()) { +func setupSubscriptions(app *app.App, parentCtx context.Context) (chan tea.Msg, func()) { ch := make(chan tea.Msg, 100) - // Add a buffer to prevent blocking + wg := sync.WaitGroup{} - ctx, cancel := context.WithCancel(context.Background()) - // Setup each subscription using the helper + ctx, cancel := context.WithCancel(parentCtx) // Inherit from parent context + setupSubscriber(ctx, &wg, "logging", logging.Subscribe, ch) setupSubscriber(ctx, &wg, "sessions", app.Sessions.Subscribe, ch) setupSubscriber(ctx, &wg, "messages", app.Messages.Subscribe, ch) setupSubscriber(ctx, &wg, "permissions", app.Permissions.Subscribe, ch) - // Return channel and a cleanup function cleanupFunc := func() { logging.Info("Cancelling all subscriptions") cancel() // Signal all goroutines to stop - // Wait with a timeout for all goroutines to complete waitCh := make(chan struct{}) go func() { defer logging.RecoverPanic("subscription-cleanup", nil) @@ -229,11 +227,11 @@ func setupSubscriptions(app *app.App) (chan tea.Msg, func()) { select { case <-waitCh: logging.Info("All subscription goroutines completed successfully") + close(ch) // Only close after all writers are confirmed done case <-time.After(5 * time.Second): logging.Warn("Timed out waiting for some subscription goroutines to complete") + close(ch) } - - close(ch) // Safe to close after all writers are done or timed out } return ch, cleanupFunc } diff --git a/internal/pubsub/broker.go b/internal/pubsub/broker.go index d73accffb..0de1be063 100644 --- a/internal/pubsub/broker.go +++ b/internal/pubsub/broker.go @@ -5,47 +5,53 @@ import ( "sync" ) -const bufferSize = 1024 +const bufferSize = 64 -// Broker allows clients to publish events and subscribe to events type Broker[T any] struct { - subs map[chan Event[T]]struct{} // subscriptions - mu sync.Mutex // sync access to map - done chan struct{} // close when broker is shutting down + subs map[chan Event[T]]struct{} + mu sync.RWMutex + done chan struct{} + subCount int + maxEvents int } -// NewBroker constructs a pub/sub broker. func NewBroker[T any]() *Broker[T] { + return NewBrokerWithOptions[T](bufferSize, 1000) +} + +func NewBrokerWithOptions[T any](channelBufferSize, maxEvents int) *Broker[T] { b := &Broker[T]{ - subs: make(map[chan Event[T]]struct{}), - done: make(chan struct{}), + subs: make(map[chan Event[T]]struct{}), + done: make(chan struct{}), + subCount: 0, + maxEvents: maxEvents, } return b } -// Shutdown the broker, terminating any subscriptions. func (b *Broker[T]) Shutdown() { - close(b.done) + select { + case <-b.done: // Already closed + return + default: + close(b.done) + } b.mu.Lock() defer b.mu.Unlock() - // Remove each subscriber entry, so Publish() cannot send any further - // messages, and close each subscriber's channel, so the subscriber cannot - // consume any more messages. for ch := range b.subs { delete(b.subs, ch) close(ch) } + + b.subCount = 0 } -// Subscribe subscribes the caller to a stream of events. The returned channel -// is closed when the broker is shutdown. func (b *Broker[T]) Subscribe(ctx context.Context) <-chan Event[T] { b.mu.Lock() defer b.mu.Unlock() - // Check if broker has shutdown and if so return closed channel select { case <-b.done: ch := make(chan Event[T]) @@ -54,18 +60,16 @@ func (b *Broker[T]) Subscribe(ctx context.Context) <-chan Event[T] { default: } - // Subscribe sub := make(chan Event[T], bufferSize) b.subs[sub] = struct{}{} + b.subCount++ - // Unsubscribe when context is done. go func() { <-ctx.Done() b.mu.Lock() defer b.mu.Unlock() - // Check if broker has shutdown and if so do nothing select { case <-b.done: return @@ -74,21 +78,39 @@ func (b *Broker[T]) Subscribe(ctx context.Context) <-chan Event[T] { delete(b.subs, sub) close(sub) + b.subCount-- }() return sub } -// Publish an event to subscribers. +func (b *Broker[T]) GetSubscriberCount() int { + b.mu.RLock() + defer b.mu.RUnlock() + return b.subCount +} + func (b *Broker[T]) Publish(t EventType, payload T) { - b.mu.Lock() - defer b.mu.Unlock() + b.mu.RLock() + select { + case <-b.done: + b.mu.RUnlock() + return + default: + } + subscribers := make([]chan Event[T], 0, len(b.subs)) for sub := range b.subs { + subscribers = append(subscribers, sub) + } + b.mu.RUnlock() + + event := Event[T]{Type: t, Payload: payload} + + for _, sub := range subscribers { select { - case sub <- Event[T]{Type: t, Payload: payload}: - case <-b.done: - return + case sub <- event: + default: } } } diff --git a/internal/tui/components/chat/list.go b/internal/tui/components/chat/list.go index b09cc4495..03a50541e 100644 --- a/internal/tui/components/chat/list.go +++ b/internal/tui/components/chat/list.go @@ -370,6 +370,7 @@ func (m *messagesCmp) SetSize(width, height int) tea.Cmd { delete(m.cachedContent, msg.ID) } m.uiMessages = make([]uiMessage, 0) + m.renderView() return nil } diff --git a/internal/tui/components/core/status.go b/internal/tui/components/core/status.go index 5a2114e83..8bf3e5166 100644 --- a/internal/tui/components/core/status.go +++ b/internal/tui/components/core/status.go @@ -18,6 +18,11 @@ import ( "github.com/kujtimiihoxha/opencode/internal/tui/util" ) +type StatusCmp interface { + tea.Model + SetHelpMsg(string) +} + type statusCmp struct { info util.InfoMsg width int @@ -146,7 +151,7 @@ func (m *statusCmp) projectDiagnostics() string { break } } - + // If any server is initializing, show that status if initializing { return lipgloss.NewStyle(). @@ -154,7 +159,7 @@ func (m *statusCmp) projectDiagnostics() string { Foreground(styles.Peach). Render(fmt.Sprintf("%s Initializing LSP...", styles.SpinnerIcon)) } - + errorDiagnostics := []protocol.Diagnostic{} warnDiagnostics := []protocol.Diagnostic{} hintDiagnostics := []protocol.Diagnostic{} @@ -235,7 +240,11 @@ func (m statusCmp) model() string { return styles.Padded.Background(styles.Grey).Foreground(styles.Text).Render(model.Name) } -func NewStatusCmp(lspClients map[string]*lsp.Client) tea.Model { +func (m statusCmp) SetHelpMsg(s string) { + helpWidget = styles.Padded.Background(styles.Forground).Foreground(styles.BackgroundDarker).Bold(true).Render(s) +} + +func NewStatusCmp(lspClients map[string]*lsp.Client) StatusCmp { return &statusCmp{ messageTTL: 10 * time.Second, lspClients: lspClients, diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 2a9ed0d70..dec43f7c0 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -39,12 +39,18 @@ var keys = keyMap{ key.WithKeys("ctrl+_"), key.WithHelp("ctrl+?", "toggle help"), ), + SwitchSession: key.NewBinding( key.WithKeys("ctrl+a"), key.WithHelp("ctrl+a", "switch session"), ), } +var helpEsc = key.NewBinding( + key.WithKeys("?"), + key.WithHelp("?", "toggle help"), +) + var returnKey = key.NewBinding( key.WithKeys("esc"), key.WithHelp("esc", "close"), @@ -61,7 +67,7 @@ type appModel struct { previousPage page.PageID pages map[page.PageID]tea.Model loadedPages map[page.PageID]bool - status tea.Model + status core.StatusCmp app *app.App showPermissions bool @@ -75,6 +81,8 @@ type appModel struct { showSessionDialog bool sessionDialog dialog.SessionDialog + + editingMode bool } func (a appModel) Init() tea.Cmd { @@ -101,7 +109,8 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { msg.Height -= 1 // Make space for the status bar a.width, a.height = msg.Width, msg.Height - a.status, _ = a.status.Update(msg) + s, _ := a.status.Update(msg) + a.status = s.(core.StatusCmp) a.pages[a.currentPage], cmd = a.pages[a.currentPage].Update(msg) cmds = append(cmds, cmd) @@ -118,45 +127,56 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { cmds = append(cmds, sessionCmd) return a, tea.Batch(cmds...) - + case chat.EditorFocusMsg: + a.editingMode = bool(msg) // Status case util.InfoMsg: - a.status, cmd = a.status.Update(msg) + s, cmd := a.status.Update(msg) + a.status = s.(core.StatusCmp) cmds = append(cmds, cmd) return a, tea.Batch(cmds...) case pubsub.Event[logging.LogMessage]: if msg.Payload.Persist { switch msg.Payload.Level { case "error": - a.status, cmd = a.status.Update(util.InfoMsg{ + s, cmd := a.status.Update(util.InfoMsg{ Type: util.InfoTypeError, Msg: msg.Payload.Message, TTL: msg.Payload.PersistTime, }) + a.status = s.(core.StatusCmp) + cmds = append(cmds, cmd) case "info": - a.status, cmd = a.status.Update(util.InfoMsg{ + s, cmd := a.status.Update(util.InfoMsg{ Type: util.InfoTypeInfo, Msg: msg.Payload.Message, TTL: msg.Payload.PersistTime, }) + a.status = s.(core.StatusCmp) + cmds = append(cmds, cmd) + case "warn": - a.status, cmd = a.status.Update(util.InfoMsg{ + s, cmd := a.status.Update(util.InfoMsg{ Type: util.InfoTypeWarn, Msg: msg.Payload.Message, TTL: msg.Payload.PersistTime, }) + a.status = s.(core.StatusCmp) + cmds = append(cmds, cmd) default: - a.status, cmd = a.status.Update(util.InfoMsg{ + s, cmd := a.status.Update(util.InfoMsg{ Type: util.InfoTypeInfo, Msg: msg.Payload.Message, TTL: msg.Payload.PersistTime, }) + a.status = s.(core.StatusCmp) + cmds = append(cmds, cmd) } - cmds = append(cmds, cmd) } case util.ClearStatusMsg: - a.status, _ = a.status.Update(msg) + s, _ := a.status.Update(msg) + a.status = s.(core.StatusCmp) // Permission case pubsub.Event[permission.PermissionRequest]: @@ -243,7 +263,16 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } a.showHelp = !a.showHelp return a, nil + case key.Matches(msg, helpEsc): + if !a.editingMode { + if a.showQuit { + return a, nil + } + a.showHelp = !a.showHelp + return a, nil + } } + } if a.showQuit { @@ -275,7 +304,8 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } } - a.status, _ = a.status.Update(msg) + s, _ := a.status.Update(msg) + a.status = s.(core.StatusCmp) a.pages[a.currentPage], cmd = a.pages[a.currentPage].Update(msg) cmds = append(cmds, cmd) return a, tea.Batch(cmds...) @@ -326,6 +356,12 @@ func (a appModel) View() string { ) } + if a.editingMode { + a.status.SetHelpMsg("ctrl+? help") + } else { + a.status.SetHelpMsg("? help") + } + if a.showHelp { bindings := layout.KeyMapToSlice(keys) if p, ok := a.pages[a.currentPage].(layout.Bindings); ok { @@ -337,7 +373,9 @@ func (a appModel) View() string { if a.currentPage == page.LogsPage { bindings = append(bindings, logsKeyReturnKey) } - + if !a.editingMode { + bindings = append(bindings, helpEsc) + } a.help.SetBindings(bindings) overlay := a.help.View() @@ -398,6 +436,7 @@ func New(app *app.App) tea.Model { sessionDialog: dialog.NewSessionDialogCmp(), permissions: dialog.NewPermissionDialogCmp(), app: app, + editingMode: true, pages: map[page.PageID]tea.Model{ page.ChatPage: page.NewChatPage(app), page.LogsPage: page.NewLogsPage(), -- cgit v1.2.3 From a8d5787e8ef561037f73b669128f46ae1b1e8553 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 21 Apr 2025 14:29:03 +0200 Subject: config validation --- .opencode.json | 1 + cmd/schema/README.md | 64 +++++++++ cmd/schema/main.go | 262 ++++++++++++++++++++++++++++++++++++ internal/config/config.go | 276 +++++++++++++++++++++++++++++++++++++- internal/llm/agent/agent.go | 2 +- internal/llm/tools/edit.go | 27 +++- internal/llm/tools/write.go | 11 +- internal/permission/permission.go | 10 +- internal/tui/tui.go | 4 +- internal/version/version.go | 2 +- opencode-schema.json | 269 +++++++++++++++++++++++++++++++++++++ 11 files changed, 911 insertions(+), 17 deletions(-) create mode 100644 cmd/schema/README.md create mode 100644 cmd/schema/main.go create mode 100644 opencode-schema.json (limited to 'cmd') diff --git a/.opencode.json b/.opencode.json index b7fc19b52..c4d1547a0 100644 --- a/.opencode.json +++ b/.opencode.json @@ -1,4 +1,5 @@ { + "$schema": "./opencode-schema.json", "lsp": { "gopls": { "command": "gopls" diff --git a/cmd/schema/README.md b/cmd/schema/README.md new file mode 100644 index 000000000..93ebe9f03 --- /dev/null +++ b/cmd/schema/README.md @@ -0,0 +1,64 @@ +# OpenCode Configuration Schema Generator + +This tool generates a JSON Schema for the OpenCode configuration file. The schema can be used to validate configuration files and provide autocompletion in editors that support JSON Schema. + +## Usage + +```bash +go run cmd/schema/main.go > opencode-schema.json +``` + +This will generate a JSON Schema file that can be used to validate configuration files. + +## Schema Features + +The generated schema includes: + +- All configuration options with descriptions +- Default values where applicable +- Validation for enum values (e.g., model IDs, provider types) +- Required fields +- Type checking + +## Using the Schema + +You can use the generated schema in several ways: + +1. **Editor Integration**: Many editors (VS Code, JetBrains IDEs, etc.) support JSON Schema for validation and autocompletion. You can configure your editor to use the generated schema for `.opencode.json` files. + +2. **Validation Tools**: You can use tools like [jsonschema](https://github.com/Julian/jsonschema) to validate your configuration files against the schema. + +3. **Documentation**: The schema serves as documentation for the configuration options. + +## Example Configuration + +Here's an example configuration that conforms to the schema: + +```json +{ + "data": { + "directory": ".opencode" + }, + "debug": false, + "providers": { + "anthropic": { + "apiKey": "your-api-key" + } + }, + "agents": { + "coder": { + "model": "claude-3.7-sonnet", + "maxTokens": 5000, + "reasoningEffort": "medium" + }, + "task": { + "model": "claude-3.7-sonnet", + "maxTokens": 5000 + }, + "title": { + "model": "claude-3.7-sonnet", + "maxTokens": 80 + } + } +} +``` \ No newline at end of file diff --git a/cmd/schema/main.go b/cmd/schema/main.go new file mode 100644 index 000000000..030c0907e --- /dev/null +++ b/cmd/schema/main.go @@ -0,0 +1,262 @@ +package main + +import ( + "encoding/json" + "fmt" + "os" + + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/llm/models" +) + +// JSONSchemaType represents a JSON Schema type +type JSONSchemaType struct { + Type string `json:"type,omitempty"` + Description string `json:"description,omitempty"` + Properties map[string]any `json:"properties,omitempty"` + Required []string `json:"required,omitempty"` + AdditionalProperties any `json:"additionalProperties,omitempty"` + Enum []any `json:"enum,omitempty"` + Items map[string]any `json:"items,omitempty"` + OneOf []map[string]any `json:"oneOf,omitempty"` + AnyOf []map[string]any `json:"anyOf,omitempty"` + Default any `json:"default,omitempty"` +} + +func main() { + schema := generateSchema() + + // Pretty print the schema + encoder := json.NewEncoder(os.Stdout) + encoder.SetIndent("", " ") + if err := encoder.Encode(schema); err != nil { + fmt.Fprintf(os.Stderr, "Error encoding schema: %v\n", err) + os.Exit(1) + } +} + +func generateSchema() map[string]any { + schema := map[string]any{ + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "OpenCode Configuration", + "description": "Configuration schema for the OpenCode application", + "type": "object", + "properties": map[string]any{}, + } + + // Add Data configuration + schema["properties"].(map[string]any)["data"] = map[string]any{ + "type": "object", + "description": "Storage configuration", + "properties": map[string]any{ + "directory": map[string]any{ + "type": "string", + "description": "Directory where application data is stored", + "default": ".opencode", + }, + }, + "required": []string{"directory"}, + } + + // Add working directory + schema["properties"].(map[string]any)["wd"] = map[string]any{ + "type": "string", + "description": "Working directory for the application", + } + + // Add debug flags + schema["properties"].(map[string]any)["debug"] = map[string]any{ + "type": "boolean", + "description": "Enable debug mode", + "default": false, + } + + schema["properties"].(map[string]any)["debugLSP"] = map[string]any{ + "type": "boolean", + "description": "Enable LSP debug mode", + "default": false, + } + + // Add MCP servers + schema["properties"].(map[string]any)["mcpServers"] = map[string]any{ + "type": "object", + "description": "Model Control Protocol server configurations", + "additionalProperties": map[string]any{ + "type": "object", + "description": "MCP server configuration", + "properties": map[string]any{ + "command": map[string]any{ + "type": "string", + "description": "Command to execute for the MCP server", + }, + "env": map[string]any{ + "type": "array", + "description": "Environment variables for the MCP server", + "items": map[string]any{ + "type": "string", + }, + }, + "args": map[string]any{ + "type": "array", + "description": "Command arguments for the MCP server", + "items": map[string]any{ + "type": "string", + }, + }, + "type": map[string]any{ + "type": "string", + "description": "Type of MCP server", + "enum": []string{"stdio", "sse"}, + "default": "stdio", + }, + "url": map[string]any{ + "type": "string", + "description": "URL for SSE type MCP servers", + }, + "headers": map[string]any{ + "type": "object", + "description": "HTTP headers for SSE type MCP servers", + "additionalProperties": map[string]any{ + "type": "string", + }, + }, + }, + "required": []string{"command"}, + }, + } + + // Add providers + providerSchema := map[string]any{ + "type": "object", + "description": "LLM provider configurations", + "additionalProperties": map[string]any{ + "type": "object", + "description": "Provider configuration", + "properties": map[string]any{ + "apiKey": map[string]any{ + "type": "string", + "description": "API key for the provider", + }, + "disabled": map[string]any{ + "type": "boolean", + "description": "Whether the provider is disabled", + "default": false, + }, + }, + }, + } + + // Add known providers + knownProviders := []string{ + string(models.ProviderAnthropic), + string(models.ProviderOpenAI), + string(models.ProviderGemini), + string(models.ProviderGROQ), + string(models.ProviderBedrock), + } + + providerSchema["additionalProperties"].(map[string]any)["properties"].(map[string]any)["provider"] = map[string]any{ + "type": "string", + "description": "Provider type", + "enum": knownProviders, + } + + schema["properties"].(map[string]any)["providers"] = providerSchema + + // Add agents + agentSchema := map[string]any{ + "type": "object", + "description": "Agent configurations", + "additionalProperties": map[string]any{ + "type": "object", + "description": "Agent configuration", + "properties": map[string]any{ + "model": map[string]any{ + "type": "string", + "description": "Model ID for the agent", + }, + "maxTokens": map[string]any{ + "type": "integer", + "description": "Maximum tokens for the agent", + "minimum": 1, + }, + "reasoningEffort": map[string]any{ + "type": "string", + "description": "Reasoning effort for models that support it (OpenAI, Anthropic)", + "enum": []string{"low", "medium", "high"}, + }, + }, + "required": []string{"model"}, + }, + } + + // Add model enum + modelEnum := []string{} + for modelID := range models.SupportedModels { + modelEnum = append(modelEnum, string(modelID)) + } + agentSchema["additionalProperties"].(map[string]any)["properties"].(map[string]any)["model"].(map[string]any)["enum"] = modelEnum + + // Add specific agent properties + agentProperties := map[string]any{} + knownAgents := []string{ + string(config.AgentCoder), + string(config.AgentTask), + string(config.AgentTitle), + } + + for _, agentName := range knownAgents { + agentProperties[agentName] = map[string]any{ + "$ref": "#/definitions/agent", + } + } + + // Create a combined schema that allows both specific agents and additional ones + combinedAgentSchema := map[string]any{ + "type": "object", + "description": "Agent configurations", + "properties": agentProperties, + "additionalProperties": agentSchema["additionalProperties"], + } + + schema["properties"].(map[string]any)["agents"] = combinedAgentSchema + schema["definitions"] = map[string]any{ + "agent": agentSchema["additionalProperties"], + } + + // Add LSP configuration + schema["properties"].(map[string]any)["lsp"] = map[string]any{ + "type": "object", + "description": "Language Server Protocol configurations", + "additionalProperties": map[string]any{ + "type": "object", + "description": "LSP configuration for a language", + "properties": map[string]any{ + "disabled": map[string]any{ + "type": "boolean", + "description": "Whether the LSP is disabled", + "default": false, + }, + "command": map[string]any{ + "type": "string", + "description": "Command to execute for the LSP server", + }, + "args": map[string]any{ + "type": "array", + "description": "Command arguments for the LSP server", + "items": map[string]any{ + "type": "string", + }, + }, + "options": map[string]any{ + "type": "object", + "description": "Additional options for the LSP server", + }, + }, + "required": []string{"command"}, + }, + } + + return schema +} + diff --git a/internal/config/config.go b/internal/config/config.go index 2dbbcc9ca..13c7d1328 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -120,13 +120,11 @@ func Load(workingDir string, debug bool) (*Config, error) { } applyDefaultValues() - defaultLevel := slog.LevelInfo if cfg.Debug { defaultLevel = slog.LevelDebug } - // if we are in debug mode make the writer a file - if cfg.Debug { + if os.Getenv("OPENCODE_DEV_DEBUG") == "true" { loggingFile := fmt.Sprintf("%s/%s", cfg.Data.Directory, "debug.log") // if file does not exist create it @@ -156,6 +154,11 @@ func Load(workingDir string, debug bool) (*Config, error) { slog.SetDefault(logger) } + // Validate configuration + if err := Validate(); err != nil { + return cfg, fmt.Errorf("config validation failed: %w", err) + } + if cfg.Agents == nil { cfg.Agents = make(map[AgentName]Agent) } @@ -302,6 +305,273 @@ func applyDefaultValues() { } } +// Validate checks if the configuration is valid and applies defaults where needed. +// It validates model IDs and providers, ensuring they are supported. +func Validate() error { + if cfg == nil { + return fmt.Errorf("config not loaded") + } + + // Validate agent models + for name, agent := range cfg.Agents { + // Check if model exists + model, modelExists := models.SupportedModels[agent.Model] + if !modelExists { + logging.Warn("unsupported model configured, reverting to default", + "agent", name, + "configured_model", agent.Model) + + // Set default model based on available providers + if setDefaultModelForAgent(name) { + logging.Info("set default model for agent", "agent", name, "model", cfg.Agents[name].Model) + } else { + return fmt.Errorf("no valid provider available for agent %s", name) + } + continue + } + + // Check if provider for the model is configured + provider := model.Provider + providerCfg, providerExists := cfg.Providers[provider] + + if !providerExists { + // Provider not configured, check if we have environment variables + apiKey := getProviderAPIKey(provider) + if apiKey == "" { + logging.Warn("provider not configured for model, reverting to default", + "agent", name, + "model", agent.Model, + "provider", provider) + + // Set default model based on available providers + if setDefaultModelForAgent(name) { + logging.Info("set default model for agent", "agent", name, "model", cfg.Agents[name].Model) + } else { + return fmt.Errorf("no valid provider available for agent %s", name) + } + } else { + // Add provider with API key from environment + cfg.Providers[provider] = Provider{ + APIKey: apiKey, + } + logging.Info("added provider from environment", "provider", provider) + } + } else if providerCfg.Disabled || providerCfg.APIKey == "" { + // Provider is disabled or has no API key + logging.Warn("provider is disabled or has no API key, reverting to default", + "agent", name, + "model", agent.Model, + "provider", provider) + + // Set default model based on available providers + if setDefaultModelForAgent(name) { + logging.Info("set default model for agent", "agent", name, "model", cfg.Agents[name].Model) + } else { + return fmt.Errorf("no valid provider available for agent %s", name) + } + } + + // Validate max tokens + if agent.MaxTokens <= 0 { + logging.Warn("invalid max tokens, setting to default", + "agent", name, + "model", agent.Model, + "max_tokens", agent.MaxTokens) + + // Update the agent with default max tokens + updatedAgent := cfg.Agents[name] + if model.DefaultMaxTokens > 0 { + updatedAgent.MaxTokens = model.DefaultMaxTokens + } else { + updatedAgent.MaxTokens = 4096 // Fallback default + } + cfg.Agents[name] = updatedAgent + } else if model.ContextWindow > 0 && agent.MaxTokens > model.ContextWindow/2 { + // Ensure max tokens doesn't exceed half the context window (reasonable limit) + logging.Warn("max tokens exceeds half the context window, adjusting", + "agent", name, + "model", agent.Model, + "max_tokens", agent.MaxTokens, + "context_window", model.ContextWindow) + + // Update the agent with adjusted max tokens + updatedAgent := cfg.Agents[name] + updatedAgent.MaxTokens = model.ContextWindow / 2 + cfg.Agents[name] = updatedAgent + } + + // Validate reasoning effort for models that support reasoning + if model.CanReason && provider == models.ProviderOpenAI { + if agent.ReasoningEffort == "" { + // Set default reasoning effort for models that support it + logging.Info("setting default reasoning effort for model that supports reasoning", + "agent", name, + "model", agent.Model) + + // Update the agent with default reasoning effort + updatedAgent := cfg.Agents[name] + updatedAgent.ReasoningEffort = "medium" + cfg.Agents[name] = updatedAgent + } else { + // Check if reasoning effort is valid (low, medium, high) + effort := strings.ToLower(agent.ReasoningEffort) + if effort != "low" && effort != "medium" && effort != "high" { + logging.Warn("invalid reasoning effort, setting to medium", + "agent", name, + "model", agent.Model, + "reasoning_effort", agent.ReasoningEffort) + + // Update the agent with valid reasoning effort + updatedAgent := cfg.Agents[name] + updatedAgent.ReasoningEffort = "medium" + cfg.Agents[name] = updatedAgent + } + } + } else if !model.CanReason && agent.ReasoningEffort != "" { + // Model doesn't support reasoning but reasoning effort is set + logging.Warn("model doesn't support reasoning but reasoning effort is set, ignoring", + "agent", name, + "model", agent.Model, + "reasoning_effort", agent.ReasoningEffort) + + // Update the agent to remove reasoning effort + updatedAgent := cfg.Agents[name] + updatedAgent.ReasoningEffort = "" + cfg.Agents[name] = updatedAgent + } + } + + // Validate providers + for provider, providerCfg := range cfg.Providers { + if providerCfg.APIKey == "" && !providerCfg.Disabled { + logging.Warn("provider has no API key, marking as disabled", "provider", provider) + providerCfg.Disabled = true + cfg.Providers[provider] = providerCfg + } + } + + // Validate LSP configurations + for language, lspConfig := range cfg.LSP { + if lspConfig.Command == "" && !lspConfig.Disabled { + logging.Warn("LSP configuration has no command, marking as disabled", "language", language) + lspConfig.Disabled = true + cfg.LSP[language] = lspConfig + } + } + + return nil +} + +// getProviderAPIKey gets the API key for a provider from environment variables +func getProviderAPIKey(provider models.ModelProvider) string { + switch provider { + case models.ProviderAnthropic: + return os.Getenv("ANTHROPIC_API_KEY") + case models.ProviderOpenAI: + return os.Getenv("OPENAI_API_KEY") + case models.ProviderGemini: + return os.Getenv("GEMINI_API_KEY") + case models.ProviderGROQ: + return os.Getenv("GROQ_API_KEY") + case models.ProviderBedrock: + if hasAWSCredentials() { + return "aws-credentials-available" + } + } + return "" +} + +// setDefaultModelForAgent sets a default model for an agent based on available providers +func setDefaultModelForAgent(agent AgentName) bool { + // Check providers in order of preference + if apiKey := os.Getenv("ANTHROPIC_API_KEY"); apiKey != "" { + maxTokens := int64(5000) + if agent == AgentTitle { + maxTokens = 80 + } + cfg.Agents[agent] = Agent{ + Model: models.Claude37Sonnet, + MaxTokens: maxTokens, + } + return true + } + + if apiKey := os.Getenv("OPENAI_API_KEY"); apiKey != "" { + var model models.ModelID + maxTokens := int64(5000) + reasoningEffort := "" + + switch agent { + case AgentTitle: + model = models.GPT41Mini + maxTokens = 80 + case AgentTask: + model = models.GPT41Mini + default: + model = models.GPT41 + } + + // Check if model supports reasoning + if modelInfo, ok := models.SupportedModels[model]; ok && modelInfo.CanReason { + reasoningEffort = "medium" + } + + cfg.Agents[agent] = Agent{ + Model: model, + MaxTokens: maxTokens, + ReasoningEffort: reasoningEffort, + } + return true + } + + if apiKey := os.Getenv("GEMINI_API_KEY"); apiKey != "" { + var model models.ModelID + maxTokens := int64(5000) + + if agent == AgentTitle { + model = models.Gemini25Flash + maxTokens = 80 + } else { + model = models.Gemini25 + } + + cfg.Agents[agent] = Agent{ + Model: model, + MaxTokens: maxTokens, + } + return true + } + + if apiKey := os.Getenv("GROQ_API_KEY"); apiKey != "" { + maxTokens := int64(5000) + if agent == AgentTitle { + maxTokens = 80 + } + + cfg.Agents[agent] = Agent{ + Model: models.QWENQwq, + MaxTokens: maxTokens, + } + return true + } + + if hasAWSCredentials() { + maxTokens := int64(5000) + if agent == AgentTitle { + maxTokens = 80 + } + + cfg.Agents[agent] = Agent{ + Model: models.BedrockClaude37Sonnet, + MaxTokens: maxTokens, + ReasoningEffort: "medium", // Claude models support reasoning + } + return true + } + + return false +} + // Get returns the current configuration. // It's safe to call this function multiple times. func Get() *Config { diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index ae5bcb231..6c5808eab 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -471,7 +471,7 @@ func createAgentProvider(agentName config.AgentName) (provider.Provider, error) provider.WithReasoningEffort(agentConfig.ReasoningEffort), ), ) - } else if model.Provider == models.ProviderAnthropic && model.CanReason { + } else if model.Provider == models.ProviderAnthropic && model.CanReason && agentName == config.AgentCoder { opts = append( opts, provider.WithAnthropicOptions( diff --git a/internal/llm/tools/edit.go b/internal/llm/tools/edit.go index b7b813ca7..23c44399b 100644 --- a/internal/llm/tools/edit.go +++ b/internal/llm/tools/edit.go @@ -196,11 +196,16 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string) content, filePath, ) + rootDir := config.WorkingDirectory() + permissionPath := filepath.Dir(filePath) + if strings.HasPrefix(filePath, rootDir) { + permissionPath = rootDir + } p := e.permissions.Request( permission.CreatePermissionRequest{ - Path: filepath.Dir(filePath), + Path: permissionPath, ToolName: EditToolName, - Action: "create", + Action: "write", Description: fmt.Sprintf("Create file %s", filePath), Params: EditPermissionsParams{ FilePath: filePath, @@ -301,11 +306,16 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string filePath, ) + rootDir := config.WorkingDirectory() + permissionPath := filepath.Dir(filePath) + if strings.HasPrefix(filePath, rootDir) { + permissionPath = rootDir + } p := e.permissions.Request( permission.CreatePermissionRequest{ - Path: filepath.Dir(filePath), + Path: permissionPath, ToolName: EditToolName, - Action: "delete", + Action: "write", Description: fmt.Sprintf("Delete content from file %s", filePath), Params: EditPermissionsParams{ FilePath: filePath, @@ -415,11 +425,16 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS newContent, filePath, ) + rootDir := config.WorkingDirectory() + permissionPath := filepath.Dir(filePath) + if strings.HasPrefix(filePath, rootDir) { + permissionPath = rootDir + } p := e.permissions.Request( permission.CreatePermissionRequest{ - Path: filepath.Dir(filePath), + Path: permissionPath, ToolName: EditToolName, - Action: "replace", + Action: "write", Description: fmt.Sprintf("Replace content in file %s", filePath), Params: EditPermissionsParams{ FilePath: filePath, diff --git a/internal/llm/tools/write.go b/internal/llm/tools/write.go index 2b3fa3dd0..3a94b47b6 100644 --- a/internal/llm/tools/write.go +++ b/internal/llm/tools/write.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "path/filepath" + "strings" "time" "github.com/kujtimiihoxha/opencode/internal/config" @@ -159,11 +160,17 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error params.Content, filePath, ) + + rootDir := config.WorkingDirectory() + permissionPath := filepath.Dir(filePath) + if strings.HasPrefix(filePath, rootDir) { + permissionPath = rootDir + } p := w.permissions.Request( permission.CreatePermissionRequest{ - Path: filePath, + Path: permissionPath, ToolName: WriteToolName, - Action: "create", + Action: "write", Description: fmt.Sprintf("Create file %s", filePath), Params: WritePermissionsParams{ FilePath: filePath, diff --git a/internal/permission/permission.go b/internal/permission/permission.go index 4cb379dea..06f69a33d 100644 --- a/internal/permission/permission.go +++ b/internal/permission/permission.go @@ -2,10 +2,12 @@ package permission import ( "errors" + "path/filepath" "sync" "time" "github.com/google/uuid" + "github.com/kujtimiihoxha/opencode/internal/config" "github.com/kujtimiihoxha/opencode/internal/pubsub" ) @@ -67,9 +69,13 @@ func (s *permissionService) Deny(permission PermissionRequest) { } func (s *permissionService) Request(opts CreatePermissionRequest) bool { + dir := filepath.Dir(opts.Path) + if dir == "." { + dir = config.WorkingDirectory() + } permission := PermissionRequest{ ID: uuid.New().String(), - Path: opts.Path, + Path: dir, ToolName: opts.ToolName, Description: opts.Description, Action: opts.Action, @@ -77,7 +83,7 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool { } for _, p := range s.sessionPermissions { - if p.ToolName == permission.ToolName && p.Action == permission.Action { + if p.ToolName == permission.ToolName && p.Action == permission.Action && p.SessionID == permission.SessionID && p.Path == permission.Path { return true } } diff --git a/internal/tui/tui.go b/internal/tui/tui.go index dec43f7c0..392b9ec41 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -57,8 +57,8 @@ var returnKey = key.NewBinding( ) var logsKeyReturnKey = key.NewBinding( - key.WithKeys("backspace"), - key.WithHelp("backspace", "go back"), + key.WithKeys("backspace", "q"), + key.WithHelp("backspace/q", "go back"), ) type appModel struct { diff --git a/internal/version/version.go b/internal/version/version.go index 54c576f6c..1e19bea38 100644 --- a/internal/version/version.go +++ b/internal/version/version.go @@ -5,7 +5,7 @@ import "runtime/debug" // Build-time parameters set via -ldflags var Version = "unknown" -// A user may install pug using `go install github.com/leg100/pug@latest` +// A user may install pug using `go install github.com/kujtimiihoxha/opencode@latest`. // without -ldflags, in which case the version above is unset. As a workaround // we use the embedded build version that *is* set when using `go install` (and // is only set for `go install` and not for `go build`). diff --git a/opencode-schema.json b/opencode-schema.json new file mode 100644 index 000000000..452790cdf --- /dev/null +++ b/opencode-schema.json @@ -0,0 +1,269 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "definitions": { + "agent": { + "description": "Agent configuration", + "properties": { + "maxTokens": { + "description": "Maximum tokens for the agent", + "minimum": 1, + "type": "integer" + }, + "model": { + "description": "Model ID for the agent", + "enum": [ + "gemini-2.0-flash", + "bedrock.claude-3.7-sonnet", + "claude-3-opus", + "claude-3.5-sonnet", + "gpt-4o-mini", + "o1", + "o3-mini", + "o1-pro", + "o4-mini", + "claude-3-haiku", + "gpt-4o", + "o3", + "gpt-4.1-mini", + "gpt-4.5-preview", + "gemini-2.5-flash", + "claude-3.5-haiku", + "gpt-4.1", + "gemini-2.0-flash-lite", + "claude-3.7-sonnet", + "o1-mini", + "gpt-4.1-nano", + "gemini-2.5" + ], + "type": "string" + }, + "reasoningEffort": { + "description": "Reasoning effort for models that support it (OpenAI, Anthropic)", + "enum": [ + "low", + "medium", + "high" + ], + "type": "string" + } + }, + "required": [ + "model" + ], + "type": "object" + } + }, + "description": "Configuration schema for the OpenCode application", + "properties": { + "agents": { + "additionalProperties": { + "description": "Agent configuration", + "properties": { + "maxTokens": { + "description": "Maximum tokens for the agent", + "minimum": 1, + "type": "integer" + }, + "model": { + "description": "Model ID for the agent", + "enum": [ + "gemini-2.0-flash", + "bedrock.claude-3.7-sonnet", + "claude-3-opus", + "claude-3.5-sonnet", + "gpt-4o-mini", + "o1", + "o3-mini", + "o1-pro", + "o4-mini", + "claude-3-haiku", + "gpt-4o", + "o3", + "gpt-4.1-mini", + "gpt-4.5-preview", + "gemini-2.5-flash", + "claude-3.5-haiku", + "gpt-4.1", + "gemini-2.0-flash-lite", + "claude-3.7-sonnet", + "o1-mini", + "gpt-4.1-nano", + "gemini-2.5" + ], + "type": "string" + }, + "reasoningEffort": { + "description": "Reasoning effort for models that support it (OpenAI, Anthropic)", + "enum": [ + "low", + "medium", + "high" + ], + "type": "string" + } + }, + "required": [ + "model" + ], + "type": "object" + }, + "description": "Agent configurations", + "properties": { + "coder": { + "$ref": "#/definitions/agent" + }, + "task": { + "$ref": "#/definitions/agent" + }, + "title": { + "$ref": "#/definitions/agent" + } + }, + "type": "object" + }, + "data": { + "description": "Storage configuration", + "properties": { + "directory": { + "default": ".opencode", + "description": "Directory where application data is stored", + "type": "string" + } + }, + "required": [ + "directory" + ], + "type": "object" + }, + "debug": { + "default": false, + "description": "Enable debug mode", + "type": "boolean" + }, + "debugLSP": { + "default": false, + "description": "Enable LSP debug mode", + "type": "boolean" + }, + "lsp": { + "additionalProperties": { + "description": "LSP configuration for a language", + "properties": { + "args": { + "description": "Command arguments for the LSP server", + "items": { + "type": "string" + }, + "type": "array" + }, + "command": { + "description": "Command to execute for the LSP server", + "type": "string" + }, + "disabled": { + "default": false, + "description": "Whether the LSP is disabled", + "type": "boolean" + }, + "options": { + "description": "Additional options for the LSP server", + "type": "object" + } + }, + "required": [ + "command" + ], + "type": "object" + }, + "description": "Language Server Protocol configurations", + "type": "object" + }, + "mcpServers": { + "additionalProperties": { + "description": "MCP server configuration", + "properties": { + "args": { + "description": "Command arguments for the MCP server", + "items": { + "type": "string" + }, + "type": "array" + }, + "command": { + "description": "Command to execute for the MCP server", + "type": "string" + }, + "env": { + "description": "Environment variables for the MCP server", + "items": { + "type": "string" + }, + "type": "array" + }, + "headers": { + "additionalProperties": { + "type": "string" + }, + "description": "HTTP headers for SSE type MCP servers", + "type": "object" + }, + "type": { + "default": "stdio", + "description": "Type of MCP server", + "enum": [ + "stdio", + "sse" + ], + "type": "string" + }, + "url": { + "description": "URL for SSE type MCP servers", + "type": "string" + } + }, + "required": [ + "command" + ], + "type": "object" + }, + "description": "Model Control Protocol server configurations", + "type": "object" + }, + "providers": { + "additionalProperties": { + "description": "Provider configuration", + "properties": { + "apiKey": { + "description": "API key for the provider", + "type": "string" + }, + "disabled": { + "default": false, + "description": "Whether the provider is disabled", + "type": "boolean" + }, + "provider": { + "description": "Provider type", + "enum": [ + "anthropic", + "openai", + "gemini", + "groq", + "bedrock" + ], + "type": "string" + } + }, + "type": "object" + }, + "description": "LLM provider configurations", + "type": "object" + }, + "wd": { + "description": "Working directory for the application", + "type": "string" + } + }, + "title": "OpenCode Configuration", + "type": "object" +} -- cgit v1.2.3 From 1e11805efc9f3feaf9b9696bcaa8a8dd599db0b1 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 21 Apr 2025 15:52:32 +0200 Subject: add description --- cmd/root.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'cmd') diff --git a/cmd/root.go b/cmd/root.go index 54280ecaa..545652a7a 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -21,8 +21,10 @@ import ( var rootCmd = &cobra.Command{ Use: "OpenCode", - Short: "A terminal ai assistant", - Long: `A terminal ai assistant`, + Short: "A terminal AI assistant for software development", + Long: `OpenCode is a powerful terminal-based AI assistant that helps with software development tasks. +It provides an interactive chat interface with AI capabilities, code analysis, and LSP integration +to assist developers in writing, debugging, and understanding code directly from the terminal.`, RunE: func(cmd *cobra.Command, args []string) error { // If the help flag is set, show the help message if cmd.Flag("help").Changed { -- cgit v1.2.3 From ed3518d0755cb5cae25d9d8f1690ab2e60702588 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 21 Apr 2025 16:24:38 +0200 Subject: small things --- cmd/root.go | 8 ++++---- internal/db/connect.go | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) (limited to 'cmd') diff --git a/cmd/root.go b/cmd/root.go index 545652a7a..8777acb82 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -182,7 +182,7 @@ func setupSubscriber[T any]( select { case event, ok := <-subCh: if !ok { - logging.Info("%s subscription channel closed", name) + logging.Info("subscription channel closed", "name", name) return } @@ -191,13 +191,13 @@ func setupSubscriber[T any]( select { case outputCh <- msg: case <-time.After(2 * time.Second): - logging.Warn("%s message dropped due to slow consumer", name) + logging.Warn("message dropped due to slow consumer", "name", name) case <-ctx.Done(): - logging.Info("%s subscription cancelled", name) + logging.Info("subscription cancelled", "name", name) return } case <-ctx.Done(): - logging.Info("%s subscription cancelled", name) + logging.Info("subscription cancelled", "name", name) return } } diff --git a/internal/db/connect.go b/internal/db/connect.go index e850bc8d0..9335bfc26 100644 --- a/internal/db/connect.go +++ b/internal/db/connect.go @@ -48,9 +48,9 @@ func Connect() (*sql.DB, error) { for _, pragma := range pragmas { if _, err = db.Exec(pragma); err != nil { - logging.Warn("Failed to set pragma", pragma, err) + logging.Error("Failed to set pragma", pragma, err) } else { - logging.Warn("Set pragma", "pragma", pragma) + logging.Debug("Set pragma", "pragma", pragma) } } -- cgit v1.2.3