diff options
| author | Kujtim Hoxha <[email protected]> | 2025-03-23 22:25:31 +0100 |
|---|---|---|
| committer | Kujtim Hoxha <[email protected]> | 2025-03-23 22:25:31 +0100 |
| commit | e7258e38aeb46281fda474b8b7fcc3eee35edd9f (patch) | |
| tree | 0ae4a7558b3942519ff137aed7c3cd6a9b473bf5 /internal/llm/tools/shell/shell.go | |
| parent | 8daa6e774a6e02698c90392e7b2008542f789594 (diff) | |
| download | opencode-e7258e38aeb46281fda474b8b7fcc3eee35edd9f.tar.gz opencode-e7258e38aeb46281fda474b8b7fcc3eee35edd9f.zip | |
initial agent setup
Diffstat (limited to 'internal/llm/tools/shell/shell.go')
| -rw-r--r-- | internal/llm/tools/shell/shell.go | 294 |
1 files changed, 294 insertions, 0 deletions
diff --git a/internal/llm/tools/shell/shell.go b/internal/llm/tools/shell/shell.go new file mode 100644 index 000000000..d63c50dff --- /dev/null +++ b/internal/llm/tools/shell/shell.go @@ -0,0 +1,294 @@ +package shell + +import ( + "context" + "errors" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "sync" + "syscall" + "time" +) + +type PersistentShell struct { + cmd *exec.Cmd + stdin *os.File + isAlive bool + cwd string + mu sync.Mutex + commandQueue chan *commandExecution +} + +type commandExecution struct { + command string + timeout time.Duration + resultChan chan commandResult + ctx context.Context +} + +type commandResult struct { + stdout string + stderr string + exitCode int + interrupted bool + err error +} + +var ( + shellInstance *PersistentShell + shellInstanceOnce sync.Once +) + +func GetPersistentShell(workingDir string) *PersistentShell { + shellInstanceOnce.Do(func() { + shellInstance = newPersistentShell(workingDir) + }) + + if !shellInstance.isAlive { + shellInstance = newPersistentShell(shellInstance.cwd) + } + + return shellInstance +} + +func newPersistentShell(cwd string) *PersistentShell { + shellPath := os.Getenv("SHELL") + if shellPath == "" { + shellPath = "/bin/bash" + } + + cmd := exec.Command(shellPath, "-l") + cmd.Dir = cwd + + stdinPipe, err := cmd.StdinPipe() + if err != nil { + return nil + } + + cmd.Env = append(os.Environ(), "GIT_EDITOR=true") + + err = cmd.Start() + if err != nil { + return nil + } + + shell := &PersistentShell{ + cmd: cmd, + stdin: stdinPipe.(*os.File), + isAlive: true, + cwd: cwd, + commandQueue: make(chan *commandExecution, 10), + } + + go shell.processCommands() + + go func() { + err := cmd.Wait() + if err != nil { + } + shell.isAlive = false + close(shell.commandQueue) + }() + + return shell +} + +func (s *PersistentShell) processCommands() { + for cmd := range s.commandQueue { + result := s.execCommand(cmd.command, cmd.timeout, cmd.ctx) + cmd.resultChan <- result + } +} + +func (s *PersistentShell) execCommand(command string, timeout time.Duration, ctx context.Context) commandResult { + s.mu.Lock() + defer s.mu.Unlock() + + if !s.isAlive { + return commandResult{ + stderr: "Shell is not alive", + exitCode: 1, + err: errors.New("shell is not alive"), + } + } + + tempDir := os.TempDir() + stdoutFile := filepath.Join(tempDir, fmt.Sprintf("orbitowl-stdout-%d", time.Now().UnixNano())) + stderrFile := filepath.Join(tempDir, fmt.Sprintf("orbitowl-stderr-%d", time.Now().UnixNano())) + statusFile := filepath.Join(tempDir, fmt.Sprintf("orbitowl-status-%d", time.Now().UnixNano())) + cwdFile := filepath.Join(tempDir, fmt.Sprintf("orbitowl-cwd-%d", time.Now().UnixNano())) + + defer func() { + os.Remove(stdoutFile) + os.Remove(stderrFile) + os.Remove(statusFile) + os.Remove(cwdFile) + }() + + fullCommand := fmt.Sprintf(` +eval %s < /dev/null > %s 2> %s +EXEC_EXIT_CODE=$? +pwd > %s +echo $EXEC_EXIT_CODE > %s +`, + shellQuote(command), + shellQuote(stdoutFile), + shellQuote(stderrFile), + shellQuote(cwdFile), + shellQuote(statusFile), + ) + + _, err := s.stdin.Write([]byte(fullCommand + "\n")) + if err != nil { + return commandResult{ + stderr: fmt.Sprintf("Failed to write command to shell: %v", err), + exitCode: 1, + err: err, + } + } + + interrupted := false + + startTime := time.Now() + + done := make(chan bool) + go func() { + for { + select { + case <-ctx.Done(): + s.killChildren() + interrupted = true + done <- true + return + + case <-time.After(10 * time.Millisecond): + if fileExists(statusFile) && fileSize(statusFile) > 0 { + done <- true + return + } + + if timeout > 0 { + elapsed := time.Since(startTime) + if elapsed > timeout { + s.killChildren() + interrupted = true + done <- true + return + } + } + } + } + }() + + <-done + + stdout := readFileOrEmpty(stdoutFile) + stderr := readFileOrEmpty(stderrFile) + exitCodeStr := readFileOrEmpty(statusFile) + newCwd := readFileOrEmpty(cwdFile) + + exitCode := 0 + if exitCodeStr != "" { + fmt.Sscanf(exitCodeStr, "%d", &exitCode) + } else if interrupted { + exitCode = 143 + stderr += "\nCommand execution timed out or was interrupted" + } + + if newCwd != "" { + s.cwd = strings.TrimSpace(newCwd) + } + + return commandResult{ + stdout: stdout, + stderr: stderr, + exitCode: exitCode, + interrupted: interrupted, + } +} + +func (s *PersistentShell) killChildren() { + if s.cmd == nil || s.cmd.Process == nil { + return + } + + pgrepCmd := exec.Command("pgrep", "-P", fmt.Sprintf("%d", s.cmd.Process.Pid)) + output, err := pgrepCmd.Output() + if err != nil { + return + } + + for _, pidStr := range strings.Split(string(output), "\n") { + if pidStr = strings.TrimSpace(pidStr); pidStr != "" { + var pid int + fmt.Sscanf(pidStr, "%d", &pid) + if pid > 0 { + proc, err := os.FindProcess(pid) + if err == nil { + proc.Signal(syscall.SIGTERM) + } + } + } + } +} + +func (s *PersistentShell) Exec(ctx context.Context, command string, timeoutMs int) (string, string, int, bool, error) { + if !s.isAlive { + return "", "Shell is not alive", 1, false, errors.New("shell is not alive") + } + + timeout := time.Duration(timeoutMs) * time.Millisecond + + resultChan := make(chan commandResult) + s.commandQueue <- &commandExecution{ + command: command, + timeout: timeout, + resultChan: resultChan, + ctx: ctx, + } + + result := <-resultChan + return result.stdout, result.stderr, result.exitCode, result.interrupted, result.err +} + +func (s *PersistentShell) Close() { + s.mu.Lock() + defer s.mu.Unlock() + + if !s.isAlive { + return + } + + s.stdin.Write([]byte("exit\n")) + + s.cmd.Process.Kill() + s.isAlive = false +} + +func shellQuote(s string) string { + return "'" + strings.ReplaceAll(s, "'", "'\\''") + "'" +} + +func readFileOrEmpty(path string) string { + content, err := os.ReadFile(path) + if err != nil { + return "" + } + return string(content) +} + +func fileExists(path string) bool { + _, err := os.Stat(path) + return err == nil +} + +func fileSize(path string) int64 { + info, err := os.Stat(path) + if err != nil { + return 0 + } + return info.Size() +} |
