summaryrefslogtreecommitdiffhomepage
path: root/internal
diff options
context:
space:
mode:
authoradamdottv <[email protected]>2025-05-15 12:44:16 -0500
committeradamdottv <[email protected]>2025-05-15 12:44:16 -0500
commita65e593ab4f35e1a647832ba36be2c696e1f5165 (patch)
tree38eed9994399e426a418c1e0424fe55133b648f4 /internal
parent5d9058eb74581091d84b2cd935927da636b3dd37 (diff)
downloadopencode-a65e593ab4f35e1a647832ba36be2c696e1f5165.tar.gz
opencode-a65e593ab4f35e1a647832ba36be2c696e1f5165.zip
feat: batch tool
Diffstat (limited to 'internal')
-rw-r--r--internal/llm/agent/tools.go64
-rw-r--r--internal/llm/tools/batch.go191
-rw-r--r--internal/llm/tools/batch_test.go224
-rw-r--r--internal/tui/components/chat/message.go40
4 files changed, 498 insertions, 21 deletions
diff --git a/internal/llm/agent/tools.go b/internal/llm/agent/tools.go
index dba437bd2..157b5bf59 100644
--- a/internal/llm/agent/tools.go
+++ b/internal/llm/agent/tools.go
@@ -21,30 +21,41 @@ func PrimaryAgentTools(
ctx := context.Background()
mcpTools := GetMcpTools(ctx, permissions)
- return append(
- []tools.BaseTool{
- tools.NewBashTool(permissions),
- tools.NewEditTool(lspClients, permissions, history),
- tools.NewFetchTool(permissions),
- tools.NewGlobTool(),
- tools.NewGrepTool(),
- tools.NewLsTool(),
- tools.NewViewTool(lspClients),
- tools.NewPatchTool(lspClients, permissions, history),
- tools.NewWriteTool(lspClients, permissions, history),
- tools.NewDiagnosticsTool(lspClients),
- tools.NewDefinitionTool(lspClients),
- tools.NewReferencesTool(lspClients),
- tools.NewDocSymbolsTool(lspClients),
- tools.NewWorkspaceSymbolsTool(lspClients),
- tools.NewCodeActionTool(lspClients),
- NewAgentTool(sessions, messages, lspClients),
- }, mcpTools...,
- )
+ // Create the list of tools
+ toolsList := []tools.BaseTool{
+ tools.NewBashTool(permissions),
+ tools.NewEditTool(lspClients, permissions, history),
+ tools.NewFetchTool(permissions),
+ tools.NewGlobTool(),
+ tools.NewGrepTool(),
+ tools.NewLsTool(),
+ tools.NewViewTool(lspClients),
+ tools.NewPatchTool(lspClients, permissions, history),
+ tools.NewWriteTool(lspClients, permissions, history),
+ tools.NewDiagnosticsTool(lspClients),
+ tools.NewDefinitionTool(lspClients),
+ tools.NewReferencesTool(lspClients),
+ tools.NewDocSymbolsTool(lspClients),
+ tools.NewWorkspaceSymbolsTool(lspClients),
+ tools.NewCodeActionTool(lspClients),
+ NewAgentTool(sessions, messages, lspClients),
+ }
+
+ // Create a map of tools for the batch tool
+ toolsMap := make(map[string]tools.BaseTool)
+ for _, tool := range toolsList {
+ toolsMap[tool.Info().Name] = tool
+ }
+
+ // Add the batch tool with access to all other tools
+ toolsList = append(toolsList, tools.NewBatchTool(toolsMap))
+
+ return append(toolsList, mcpTools...)
}
func TaskAgentTools(lspClients map[string]*lsp.Client) []tools.BaseTool {
- return []tools.BaseTool{
+ // Create the list of tools
+ toolsList := []tools.BaseTool{
tools.NewGlobTool(),
tools.NewGrepTool(),
tools.NewLsTool(),
@@ -54,4 +65,15 @@ func TaskAgentTools(lspClients map[string]*lsp.Client) []tools.BaseTool {
tools.NewDocSymbolsTool(lspClients),
tools.NewWorkspaceSymbolsTool(lspClients),
}
+
+ // Create a map of tools for the batch tool
+ toolsMap := make(map[string]tools.BaseTool)
+ for _, tool := range toolsList {
+ toolsMap[tool.Info().Name] = tool
+ }
+
+ // Add the batch tool with access to all other tools
+ toolsList = append(toolsList, tools.NewBatchTool(toolsMap))
+
+ return toolsList
}
diff --git a/internal/llm/tools/batch.go b/internal/llm/tools/batch.go
new file mode 100644
index 000000000..55101a50f
--- /dev/null
+++ b/internal/llm/tools/batch.go
@@ -0,0 +1,191 @@
+package tools
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "strings"
+ "sync"
+)
+
+type BatchToolCall struct {
+ Name string `json:"name"`
+ Input json.RawMessage `json:"input"`
+}
+
+type BatchParams struct {
+ Calls []BatchToolCall `json:"calls"`
+}
+
+type BatchToolResult struct {
+ ToolName string `json:"tool_name"`
+ ToolInput json.RawMessage `json:"tool_input"`
+ Result json.RawMessage `json:"result"`
+ Error string `json:"error,omitempty"`
+ // Added for better formatting and separation between results
+ Separator string `json:"separator,omitempty"`
+}
+
+type BatchResult struct {
+ Results []BatchToolResult `json:"results"`
+}
+
+type batchTool struct {
+ tools map[string]BaseTool
+}
+
+const (
+ BatchToolName = "batch"
+ BatchToolDescription = `Executes multiple tool calls in parallel and returns their results.
+
+WHEN TO USE THIS TOOL:
+- Use when you need to run multiple independent tool calls at once
+- Helpful for improving performance by parallelizing operations
+- Great for gathering information from multiple sources simultaneously
+
+HOW TO USE:
+- Provide an array of tool calls, each with a name and input
+- Each tool call will be executed in parallel
+- Results are returned in the same order as the input calls
+
+FEATURES:
+- Runs tool calls concurrently for better performance
+- Returns both results and errors for each call
+- Maintains the order of results to match input calls
+
+LIMITATIONS:
+- All tools must be available in the current context
+- Complex error handling may be required for some use cases
+- Not suitable for tool calls that depend on each other's results
+
+TIPS:
+- Use for independent operations like multiple file reads or searches
+- Great for batch operations like searching multiple directories
+- Combine with other tools for more complex workflows`
+)
+
+func NewBatchTool(tools map[string]BaseTool) BaseTool {
+ return &batchTool{
+ tools: tools,
+ }
+}
+
+func (b *batchTool) Info() ToolInfo {
+ return ToolInfo{
+ Name: BatchToolName,
+ Description: BatchToolDescription,
+ Parameters: map[string]any{
+ "calls": map[string]any{
+ "type": "array",
+ "description": "Array of tool calls to execute in parallel",
+ "items": map[string]any{
+ "type": "object",
+ "properties": map[string]any{
+ "name": map[string]any{
+ "type": "string",
+ "description": "Name of the tool to call",
+ },
+ "input": map[string]any{
+ "type": "object",
+ "description": "Input parameters for the tool",
+ },
+ },
+ "required": []string{"name", "input"},
+ },
+ },
+ },
+ Required: []string{"calls"},
+ }
+}
+
+func (b *batchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
+ var params BatchParams
+ if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
+ return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
+ }
+
+ if len(params.Calls) == 0 {
+ return NewTextErrorResponse("no tool calls provided"), nil
+ }
+
+ var wg sync.WaitGroup
+ results := make([]BatchToolResult, len(params.Calls))
+
+ for i, toolCall := range params.Calls {
+ wg.Add(1)
+ go func(index int, tc BatchToolCall) {
+ defer wg.Done()
+
+ // Create separator for better visual distinction between results
+ separator := ""
+ if index > 0 {
+ separator = fmt.Sprintf("\n%s\n", strings.Repeat("=", 80))
+ }
+
+ result := BatchToolResult{
+ ToolName: tc.Name,
+ ToolInput: tc.Input,
+ Separator: separator,
+ }
+
+ tool, ok := b.tools[tc.Name]
+ if !ok {
+ result.Error = fmt.Sprintf("tool not found: %s", tc.Name)
+ results[index] = result
+ return
+ }
+
+ // Create a proper ToolCall object
+ callObj := ToolCall{
+ ID: fmt.Sprintf("batch-%d", index),
+ Name: tc.Name,
+ Input: string(tc.Input),
+ }
+
+ response, err := tool.Run(ctx, callObj)
+ if err != nil {
+ result.Error = fmt.Sprintf("error executing tool %s: %s", tc.Name, err)
+ results[index] = result
+ return
+ }
+
+ // Standardize metadata format if present
+ if response.Metadata != "" {
+ var metadata map[string]interface{}
+ if err := json.Unmarshal([]byte(response.Metadata), &metadata); err == nil {
+ // Add tool name to metadata for better context
+ metadata["tool"] = tc.Name
+
+ // Re-marshal with consistent formatting
+ if metadataBytes, err := json.MarshalIndent(metadata, "", " "); err == nil {
+ response.Metadata = string(metadataBytes)
+ }
+ }
+ }
+
+ // Convert the response to JSON
+ responseJSON, err := json.Marshal(response)
+ if err != nil {
+ result.Error = fmt.Sprintf("error marshaling response: %s", err)
+ results[index] = result
+ return
+ }
+
+ result.Result = responseJSON
+ results[index] = result
+ }(i, toolCall)
+ }
+
+ wg.Wait()
+
+ batchResult := BatchResult{
+ Results: results,
+ }
+
+ resultJSON, err := json.Marshal(batchResult)
+ if err != nil {
+ return NewTextErrorResponse(fmt.Sprintf("error marshaling batch result: %s", err)), nil
+ }
+
+ return NewTextResponse(string(resultJSON)), nil
+} \ No newline at end of file
diff --git a/internal/llm/tools/batch_test.go b/internal/llm/tools/batch_test.go
new file mode 100644
index 000000000..1d5f05640
--- /dev/null
+++ b/internal/llm/tools/batch_test.go
@@ -0,0 +1,224 @@
+package tools
+
+import (
+ "context"
+ "encoding/json"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+// MockTool is a simple tool implementation for testing
+type MockTool struct {
+ name string
+ description string
+ response ToolResponse
+ err error
+}
+
+func (m *MockTool) Info() ToolInfo {
+ return ToolInfo{
+ Name: m.name,
+ Description: m.description,
+ Parameters: map[string]any{},
+ Required: []string{},
+ }
+}
+
+func (m *MockTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
+ return m.response, m.err
+}
+
+func TestBatchTool(t *testing.T) {
+ t.Parallel()
+
+ t.Run("successful batch execution", func(t *testing.T) {
+ t.Parallel()
+
+ // Create mock tools
+ mockTools := map[string]BaseTool{
+ "tool1": &MockTool{
+ name: "tool1",
+ description: "Mock Tool 1",
+ response: NewTextResponse("Tool 1 Response"),
+ err: nil,
+ },
+ "tool2": &MockTool{
+ name: "tool2",
+ description: "Mock Tool 2",
+ response: NewTextResponse("Tool 2 Response"),
+ err: nil,
+ },
+ }
+
+ // Create batch tool
+ batchTool := NewBatchTool(mockTools)
+
+ // Create batch call
+ input := `{
+ "calls": [
+ {
+ "name": "tool1",
+ "input": {}
+ },
+ {
+ "name": "tool2",
+ "input": {}
+ }
+ ]
+ }`
+
+ call := ToolCall{
+ ID: "test-batch",
+ Name: "batch",
+ Input: input,
+ }
+
+ // Execute batch
+ response, err := batchTool.Run(context.Background(), call)
+
+ // Verify results
+ assert.NoError(t, err)
+ assert.Equal(t, ToolResponseTypeText, response.Type)
+ assert.False(t, response.IsError)
+
+ // Parse the response
+ var batchResult BatchResult
+ err = json.Unmarshal([]byte(response.Content), &batchResult)
+ assert.NoError(t, err)
+
+ // Verify batch results
+ assert.Len(t, batchResult.Results, 2)
+ assert.Empty(t, batchResult.Results[0].Error)
+ assert.Empty(t, batchResult.Results[1].Error)
+ assert.Empty(t, batchResult.Results[0].Separator)
+ assert.NotEmpty(t, batchResult.Results[1].Separator)
+
+ // Verify individual results
+ var result1 ToolResponse
+ err = json.Unmarshal(batchResult.Results[0].Result, &result1)
+ assert.NoError(t, err)
+ assert.Equal(t, "Tool 1 Response", result1.Content)
+
+ var result2 ToolResponse
+ err = json.Unmarshal(batchResult.Results[1].Result, &result2)
+ assert.NoError(t, err)
+ assert.Equal(t, "Tool 2 Response", result2.Content)
+ })
+
+ t.Run("tool not found", func(t *testing.T) {
+ t.Parallel()
+
+ // Create mock tools
+ mockTools := map[string]BaseTool{
+ "tool1": &MockTool{
+ name: "tool1",
+ description: "Mock Tool 1",
+ response: NewTextResponse("Tool 1 Response"),
+ err: nil,
+ },
+ }
+
+ // Create batch tool
+ batchTool := NewBatchTool(mockTools)
+
+ // Create batch call with non-existent tool
+ input := `{
+ "calls": [
+ {
+ "name": "tool1",
+ "input": {}
+ },
+ {
+ "name": "nonexistent",
+ "input": {}
+ }
+ ]
+ }`
+
+ call := ToolCall{
+ ID: "test-batch",
+ Name: "batch",
+ Input: input,
+ }
+
+ // Execute batch
+ response, err := batchTool.Run(context.Background(), call)
+
+ // Verify results
+ assert.NoError(t, err)
+ assert.Equal(t, ToolResponseTypeText, response.Type)
+ assert.False(t, response.IsError)
+
+ // Parse the response
+ var batchResult BatchResult
+ err = json.Unmarshal([]byte(response.Content), &batchResult)
+ assert.NoError(t, err)
+
+ // Verify batch results
+ assert.Len(t, batchResult.Results, 2)
+ assert.Empty(t, batchResult.Results[0].Error)
+ assert.Contains(t, batchResult.Results[1].Error, "tool not found: nonexistent")
+ })
+
+ t.Run("empty calls", func(t *testing.T) {
+ t.Parallel()
+
+ // Create batch tool with empty tools map
+ batchTool := NewBatchTool(map[string]BaseTool{})
+
+ // Create batch call with empty calls
+ input := `{
+ "calls": []
+ }`
+
+ call := ToolCall{
+ ID: "test-batch",
+ Name: "batch",
+ Input: input,
+ }
+
+ // Execute batch
+ response, err := batchTool.Run(context.Background(), call)
+
+ // Verify results
+ assert.NoError(t, err)
+ assert.Equal(t, ToolResponseTypeText, response.Type)
+ assert.True(t, response.IsError)
+ assert.Contains(t, response.Content, "no tool calls provided")
+ })
+
+ t.Run("invalid input", func(t *testing.T) {
+ t.Parallel()
+
+ // Create batch tool with empty tools map
+ batchTool := NewBatchTool(map[string]BaseTool{})
+
+ // Create batch call with invalid JSON
+ input := `{
+ "calls": [
+ {
+ "name": "tool1",
+ "input": {
+ "invalid": json
+ }
+ }
+ ]
+ }`
+
+ call := ToolCall{
+ ID: "test-batch",
+ Name: "batch",
+ Input: input,
+ }
+
+ // Execute batch
+ response, err := batchTool.Run(context.Background(), call)
+
+ // Verify results
+ assert.NoError(t, err)
+ assert.Equal(t, ToolResponseTypeText, response.Type)
+ assert.True(t, response.IsError)
+ assert.Contains(t, response.Content, "error parsing parameters")
+ })
+} \ No newline at end of file
diff --git a/internal/tui/components/chat/message.go b/internal/tui/components/chat/message.go
index 58c0aed49..f887337ae 100644
--- a/internal/tui/components/chat/message.go
+++ b/internal/tui/components/chat/message.go
@@ -266,6 +266,8 @@ func toolName(name string) string {
return "Write"
case tools.PatchToolName:
return "Patch"
+ case tools.BatchToolName:
+ return "Batch"
}
return name
}
@@ -292,6 +294,8 @@ func getToolAction(name string) string {
return "Preparing write..."
case tools.PatchToolName:
return "Preparing patch..."
+ case tools.BatchToolName:
+ return "Running batch operations..."
}
return "Working..."
}
@@ -443,6 +447,10 @@ func renderToolParams(paramWidth int, toolCall message.ToolCall) string {
json.Unmarshal([]byte(toolCall.Input), &params)
filePath := removeWorkingDirPrefix(params.FilePath)
return renderParams(paramWidth, filePath)
+ case tools.BatchToolName:
+ var params tools.BatchParams
+ json.Unmarshal([]byte(toolCall.Input), &params)
+ return renderParams(paramWidth, fmt.Sprintf("%d parallel calls", len(params.Calls)))
default:
input := strings.ReplaceAll(toolCall.Input, "\n", " ")
params = renderParams(paramWidth, input)
@@ -540,6 +548,38 @@ func renderToolResponse(toolCall message.ToolCall, response message.ToolResult,
toMarkdown(resultContent, true, width),
t.Background(),
)
+ case tools.BatchToolName:
+ var batchResult tools.BatchResult
+ if err := json.Unmarshal([]byte(resultContent), &batchResult); err != nil {
+ return baseStyle.Width(width).Foreground(t.Error()).Render(fmt.Sprintf("Error parsing batch result: %s", err))
+ }
+
+ var toolCalls []string
+ for i, result := range batchResult.Results {
+ toolName := toolName(result.ToolName)
+
+ // Format the tool input as a string
+ inputStr := string(result.ToolInput)
+
+ // Format the result
+ var resultStr string
+ if result.Error != "" {
+ resultStr = fmt.Sprintf("Error: %s", result.Error)
+ } else {
+ var toolResponse tools.ToolResponse
+ if err := json.Unmarshal(result.Result, &toolResponse); err != nil {
+ resultStr = "Error parsing tool response"
+ } else {
+ resultStr = truncateHeight(toolResponse.Content, 3)
+ }
+ }
+
+ // Format the tool call
+ toolCall := fmt.Sprintf("%d. %s: %s\n %s", i+1, toolName, inputStr, resultStr)
+ toolCalls = append(toolCalls, toolCall)
+ }
+
+ return baseStyle.Width(width).Foreground(t.TextMuted()).Render(strings.Join(toolCalls, "\n\n"))
default:
resultContent = fmt.Sprintf("```text\n%s\n```", resultContent)
return styles.ForceReplaceBackgroundWithLipgloss(