summaryrefslogtreecommitdiffhomepage
path: root/internal/llm/tools/shell/shell.go
diff options
context:
space:
mode:
authorKujtim Hoxha <[email protected]>2025-03-23 22:25:31 +0100
committerKujtim Hoxha <[email protected]>2025-03-23 22:25:31 +0100
commite7258e38aeb46281fda474b8b7fcc3eee35edd9f (patch)
tree0ae4a7558b3942519ff137aed7c3cd6a9b473bf5 /internal/llm/tools/shell/shell.go
parent8daa6e774a6e02698c90392e7b2008542f789594 (diff)
downloadopencode-e7258e38aeb46281fda474b8b7fcc3eee35edd9f.tar.gz
opencode-e7258e38aeb46281fda474b8b7fcc3eee35edd9f.zip
initial agent setup
Diffstat (limited to 'internal/llm/tools/shell/shell.go')
-rw-r--r--internal/llm/tools/shell/shell.go294
1 files changed, 294 insertions, 0 deletions
diff --git a/internal/llm/tools/shell/shell.go b/internal/llm/tools/shell/shell.go
new file mode 100644
index 000000000..d63c50dff
--- /dev/null
+++ b/internal/llm/tools/shell/shell.go
@@ -0,0 +1,294 @@
+package shell
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "strings"
+ "sync"
+ "syscall"
+ "time"
+)
+
+type PersistentShell struct {
+ cmd *exec.Cmd
+ stdin *os.File
+ isAlive bool
+ cwd string
+ mu sync.Mutex
+ commandQueue chan *commandExecution
+}
+
+type commandExecution struct {
+ command string
+ timeout time.Duration
+ resultChan chan commandResult
+ ctx context.Context
+}
+
+type commandResult struct {
+ stdout string
+ stderr string
+ exitCode int
+ interrupted bool
+ err error
+}
+
+var (
+ shellInstance *PersistentShell
+ shellInstanceOnce sync.Once
+)
+
+func GetPersistentShell(workingDir string) *PersistentShell {
+ shellInstanceOnce.Do(func() {
+ shellInstance = newPersistentShell(workingDir)
+ })
+
+ if !shellInstance.isAlive {
+ shellInstance = newPersistentShell(shellInstance.cwd)
+ }
+
+ return shellInstance
+}
+
+func newPersistentShell(cwd string) *PersistentShell {
+ shellPath := os.Getenv("SHELL")
+ if shellPath == "" {
+ shellPath = "/bin/bash"
+ }
+
+ cmd := exec.Command(shellPath, "-l")
+ cmd.Dir = cwd
+
+ stdinPipe, err := cmd.StdinPipe()
+ if err != nil {
+ return nil
+ }
+
+ cmd.Env = append(os.Environ(), "GIT_EDITOR=true")
+
+ err = cmd.Start()
+ if err != nil {
+ return nil
+ }
+
+ shell := &PersistentShell{
+ cmd: cmd,
+ stdin: stdinPipe.(*os.File),
+ isAlive: true,
+ cwd: cwd,
+ commandQueue: make(chan *commandExecution, 10),
+ }
+
+ go shell.processCommands()
+
+ go func() {
+ err := cmd.Wait()
+ if err != nil {
+ }
+ shell.isAlive = false
+ close(shell.commandQueue)
+ }()
+
+ return shell
+}
+
+func (s *PersistentShell) processCommands() {
+ for cmd := range s.commandQueue {
+ result := s.execCommand(cmd.command, cmd.timeout, cmd.ctx)
+ cmd.resultChan <- result
+ }
+}
+
+func (s *PersistentShell) execCommand(command string, timeout time.Duration, ctx context.Context) commandResult {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ if !s.isAlive {
+ return commandResult{
+ stderr: "Shell is not alive",
+ exitCode: 1,
+ err: errors.New("shell is not alive"),
+ }
+ }
+
+ tempDir := os.TempDir()
+ stdoutFile := filepath.Join(tempDir, fmt.Sprintf("orbitowl-stdout-%d", time.Now().UnixNano()))
+ stderrFile := filepath.Join(tempDir, fmt.Sprintf("orbitowl-stderr-%d", time.Now().UnixNano()))
+ statusFile := filepath.Join(tempDir, fmt.Sprintf("orbitowl-status-%d", time.Now().UnixNano()))
+ cwdFile := filepath.Join(tempDir, fmt.Sprintf("orbitowl-cwd-%d", time.Now().UnixNano()))
+
+ defer func() {
+ os.Remove(stdoutFile)
+ os.Remove(stderrFile)
+ os.Remove(statusFile)
+ os.Remove(cwdFile)
+ }()
+
+ fullCommand := fmt.Sprintf(`
+eval %s < /dev/null > %s 2> %s
+EXEC_EXIT_CODE=$?
+pwd > %s
+echo $EXEC_EXIT_CODE > %s
+`,
+ shellQuote(command),
+ shellQuote(stdoutFile),
+ shellQuote(stderrFile),
+ shellQuote(cwdFile),
+ shellQuote(statusFile),
+ )
+
+ _, err := s.stdin.Write([]byte(fullCommand + "\n"))
+ if err != nil {
+ return commandResult{
+ stderr: fmt.Sprintf("Failed to write command to shell: %v", err),
+ exitCode: 1,
+ err: err,
+ }
+ }
+
+ interrupted := false
+
+ startTime := time.Now()
+
+ done := make(chan bool)
+ go func() {
+ for {
+ select {
+ case <-ctx.Done():
+ s.killChildren()
+ interrupted = true
+ done <- true
+ return
+
+ case <-time.After(10 * time.Millisecond):
+ if fileExists(statusFile) && fileSize(statusFile) > 0 {
+ done <- true
+ return
+ }
+
+ if timeout > 0 {
+ elapsed := time.Since(startTime)
+ if elapsed > timeout {
+ s.killChildren()
+ interrupted = true
+ done <- true
+ return
+ }
+ }
+ }
+ }
+ }()
+
+ <-done
+
+ stdout := readFileOrEmpty(stdoutFile)
+ stderr := readFileOrEmpty(stderrFile)
+ exitCodeStr := readFileOrEmpty(statusFile)
+ newCwd := readFileOrEmpty(cwdFile)
+
+ exitCode := 0
+ if exitCodeStr != "" {
+ fmt.Sscanf(exitCodeStr, "%d", &exitCode)
+ } else if interrupted {
+ exitCode = 143
+ stderr += "\nCommand execution timed out or was interrupted"
+ }
+
+ if newCwd != "" {
+ s.cwd = strings.TrimSpace(newCwd)
+ }
+
+ return commandResult{
+ stdout: stdout,
+ stderr: stderr,
+ exitCode: exitCode,
+ interrupted: interrupted,
+ }
+}
+
+func (s *PersistentShell) killChildren() {
+ if s.cmd == nil || s.cmd.Process == nil {
+ return
+ }
+
+ pgrepCmd := exec.Command("pgrep", "-P", fmt.Sprintf("%d", s.cmd.Process.Pid))
+ output, err := pgrepCmd.Output()
+ if err != nil {
+ return
+ }
+
+ for _, pidStr := range strings.Split(string(output), "\n") {
+ if pidStr = strings.TrimSpace(pidStr); pidStr != "" {
+ var pid int
+ fmt.Sscanf(pidStr, "%d", &pid)
+ if pid > 0 {
+ proc, err := os.FindProcess(pid)
+ if err == nil {
+ proc.Signal(syscall.SIGTERM)
+ }
+ }
+ }
+ }
+}
+
+func (s *PersistentShell) Exec(ctx context.Context, command string, timeoutMs int) (string, string, int, bool, error) {
+ if !s.isAlive {
+ return "", "Shell is not alive", 1, false, errors.New("shell is not alive")
+ }
+
+ timeout := time.Duration(timeoutMs) * time.Millisecond
+
+ resultChan := make(chan commandResult)
+ s.commandQueue <- &commandExecution{
+ command: command,
+ timeout: timeout,
+ resultChan: resultChan,
+ ctx: ctx,
+ }
+
+ result := <-resultChan
+ return result.stdout, result.stderr, result.exitCode, result.interrupted, result.err
+}
+
+func (s *PersistentShell) Close() {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ if !s.isAlive {
+ return
+ }
+
+ s.stdin.Write([]byte("exit\n"))
+
+ s.cmd.Process.Kill()
+ s.isAlive = false
+}
+
+func shellQuote(s string) string {
+ return "'" + strings.ReplaceAll(s, "'", "'\\''") + "'"
+}
+
+func readFileOrEmpty(path string) string {
+ content, err := os.ReadFile(path)
+ if err != nil {
+ return ""
+ }
+ return string(content)
+}
+
+func fileExists(path string) bool {
+ _, err := os.Stat(path)
+ return err == nil
+}
+
+func fileSize(path string) int64 {
+ info, err := os.Stat(path)
+ if err != nil {
+ return 0
+ }
+ return info.Size()
+}