summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorKujtim Hoxha <[email protected]>2025-04-04 14:23:08 +0200
committerKujtim Hoxha <[email protected]>2025-04-04 14:23:08 +0200
commit6bb1c84f7f7f0430f2808d50c533e923aae4c787 (patch)
tree2a92f077570d505d5ae0387660a0d246de0fa43a
parenteb9877ee20c44b7cd34f78e9110d315db71977f6 (diff)
downloadopencode-6bb1c84f7f7f0430f2808d50c533e923aae4c787.tar.gz
opencode-6bb1c84f7f7f0430f2808d50c533e923aae4c787.zip
Improve Sourcegraph tool with context window and fix diagnostics
- Add context_window parameter to control code context display - Fix LSP diagnostics notification handling with proper async waiting - Switch to keyword search pattern for better results - Add Sourcegraph tool to task agent 🤖 Generated with termai Co-Authored-By: termai <[email protected]>
-rw-r--r--cmd/lsp/main.go12
-rw-r--r--internal/llm/agent/task.go1
-rw-r--r--internal/llm/tools/diagnostics.go62
-rw-r--r--internal/llm/tools/sourcegraph.go32
4 files changed, 87 insertions, 20 deletions
diff --git a/cmd/lsp/main.go b/cmd/lsp/main.go
index da29a2cad..0c7f79329 100644
--- a/cmd/lsp/main.go
+++ b/cmd/lsp/main.go
@@ -1,4 +1,16 @@
package main
+import (
+ "context"
+ "fmt"
+
+ "github.com/kujtimiihoxha/termai/internal/llm/tools"
+)
+
func main() {
+ t := tools.NewSourcegraphTool()
+ r, _ := t.Run(context.Background(), tools.ToolCall{
+ Input: `{"query": "context.WithCancel lang:go"}`,
+ })
+ fmt.Println(r.Content)
}
diff --git a/internal/llm/agent/task.go b/internal/llm/agent/task.go
index 97611e62b..9737d41b8 100644
--- a/internal/llm/agent/task.go
+++ b/internal/llm/agent/task.go
@@ -34,6 +34,7 @@ func NewTaskAgent(app *app.App) (Agent, error) {
tools.NewGlobTool(),
tools.NewGrepTool(),
tools.NewLsTool(),
+ tools.NewSourcegraphTool(),
tools.NewViewTool(app.LSPClients),
},
model: model,
diff --git a/internal/llm/tools/diagnostics.go b/internal/llm/tools/diagnostics.go
index d58dbd9fc..97ac149b6 100644
--- a/internal/llm/tools/diagnostics.go
+++ b/internal/llm/tools/diagnostics.go
@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
+ "maps"
"sort"
"strings"
"time"
@@ -50,7 +51,7 @@ func (b *diagnosticsTool) Run(ctx context.Context, call ToolCall) (ToolResponse,
return NewTextErrorResponse("no LSP clients available"), nil
}
- if params.FilePath == "" {
+ if params.FilePath != "" {
notifyLspOpenFile(ctx, params.FilePath, lsps)
}
@@ -60,15 +61,68 @@ func (b *diagnosticsTool) Run(ctx context.Context, call ToolCall) (ToolResponse,
}
func notifyLspOpenFile(ctx context.Context, filePath string, lsps map[string]*lsp.Client) {
+ // Create a channel to receive diagnostic notifications
+ diagChan := make(chan struct{}, 1)
+
+ // Register a temporary diagnostic handler for each client
for _, client := range lsps {
+ // Store the original diagnostics map to detect changes
+ originalDiags := make(map[protocol.DocumentUri][]protocol.Diagnostic)
+ maps.Copy(originalDiags, client.GetDiagnostics())
+
+ // Create a notification handler that will signal when diagnostics are received
+ handler := func(params json.RawMessage) {
+ var diagParams protocol.PublishDiagnosticsParams
+ if err := json.Unmarshal(params, &diagParams); err != nil {
+ return
+ }
+
+ // If this is for our file or we've received any new diagnostics, signal completion
+ if diagParams.URI.Path() == filePath || hasDiagnosticsChanged(client.GetDiagnostics(), originalDiags) {
+ select {
+ case diagChan <- struct{}{}:
+ // Signal sent
+ default:
+ // Channel already has a value, no need to send again
+ }
+ }
+ }
+
+ // Register our temporary handler
+ client.RegisterNotificationHandler("textDocument/publishDiagnostics", handler)
+
+ // Open the file
err := client.OpenFile(ctx, filePath)
if err != nil {
- // Wait for the file to be opened and diagnostics to be received
- // TODO: see if we can do this in a more efficient way
- time.Sleep(3 * time.Second)
+ // If there's an error opening the file, continue to the next client
+ continue
}
+ }
+
+ // Wait for diagnostics with a reasonable timeout
+ select {
+ case <-diagChan:
+ // Diagnostics received
+ case <-time.After(5 * time.Second):
+ // Timeout after 2 seconds - this is a fallback in case no diagnostics are published
+ case <-ctx.Done():
+ // Context cancelled
+ }
+ // Note: We're not unregistering our handler because the Client.RegisterNotificationHandler
+ // replaces any existing handler, and we'll be replaced by the original handler when
+ // the LSP client is reinitialized or when a new handler is registered.
+}
+
+// hasDiagnosticsChanged checks if there are any new diagnostics compared to the original set
+func hasDiagnosticsChanged(current, original map[protocol.DocumentUri][]protocol.Diagnostic) bool {
+ for uri, diags := range current {
+ origDiags, exists := original[uri]
+ if !exists || len(diags) != len(origDiags) {
+ return true
+ }
}
+ return false
}
func appendDiagnostics(filePath string, lsps map[string]*lsp.Client) string {
diff --git a/internal/llm/tools/sourcegraph.go b/internal/llm/tools/sourcegraph.go
index 50b95c50e..f20ce8a62 100644
--- a/internal/llm/tools/sourcegraph.go
+++ b/internal/llm/tools/sourcegraph.go
@@ -111,15 +111,10 @@ TIPS:
)
type SourcegraphParams struct {
- Query string `json:"query"`
- Count int `json:"count,omitempty"`
- Timeout int `json:"timeout,omitempty"`
-}
-
-type SourcegraphPermissionsParams struct {
- Query string `json:"query"`
- Count int `json:"count,omitempty"`
- Timeout int `json:"timeout,omitempty"`
+ Query string `json:"query"`
+ Count int `json:"count,omitempty"`
+ ContextWindow int `json:"context_window,omitempty"`
+ Timeout int `json:"timeout,omitempty"`
}
type sourcegraphTool struct {
@@ -147,6 +142,10 @@ func (t *sourcegraphTool) Info() ToolInfo {
"type": "number",
"description": "Optional number of results to return (default: 10, max: 20)",
},
+ "context_window": map[string]any{
+ "type": "number",
+ "description": "The context around the match to return (default: 10 lines)",
+ },
"timeout": map[string]any{
"type": "number",
"description": "Optional timeout in seconds (max 120)",
@@ -173,6 +172,9 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse,
params.Count = 20 // Limit to 20 results
}
+ if params.ContextWindow <= 0 {
+ params.ContextWindow = 10 // Default context window
+ }
client := t.client
if params.Timeout > 0 {
maxTimeout := 120 // 2 minutes
@@ -194,7 +196,7 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse,
}
request := graphqlRequest{
- Query: "query Search($query: String!) { search(query: $query, version: V2, patternType: standard ) { results { matchCount, limitHit, resultCount, approximateResultCount, missing { name }, timedout { name }, indexUnavailable, results { __typename, ... on FileMatch { repository { name }, file { path, url, content }, lineMatches { preview, lineNumber, offsetAndLengths } } } } } }",
+ Query: "query Search($query: String!) { search(query: $query, version: V2, patternType: keyword ) { results { matchCount, limitHit, resultCount, approximateResultCount, missing { name }, timedout { name }, indexUnavailable, results { __typename, ... on FileMatch { repository { name }, file { path, url, content }, lineMatches { preview, lineNumber, offsetAndLengths } } } } } }",
}
request.Variables.Query = params.Query
@@ -246,7 +248,7 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse,
}
// Format the results in a readable way
- formattedResults, err := formatSourcegraphResults(result)
+ formattedResults, err := formatSourcegraphResults(result, params.ContextWindow)
if err != nil {
return NewTextErrorResponse("Failed to format results: " + err.Error()), nil
}
@@ -254,7 +256,7 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse,
return NewTextResponse(formattedResults), nil
}
-func formatSourcegraphResults(result map[string]any) (string, error) {
+func formatSourcegraphResults(result map[string]any, contextWindow int) (string, error) {
var buffer strings.Builder
// Check for errors in the GraphQL response
@@ -364,8 +366,7 @@ func formatSourcegraphResults(result map[string]any) (string, error) {
buffer.WriteString("```\n")
// Display context before the match (up to 10 lines)
- contextBefore := 10
- startLine := max(1, int(lineNumber)-contextBefore)
+ startLine := max(1, int(lineNumber)-contextWindow)
for j := startLine - 1; j < int(lineNumber)-1 && j < len(lines); j++ {
if j >= 0 {
@@ -377,8 +378,7 @@ func formatSourcegraphResults(result map[string]any) (string, error) {
buffer.WriteString(fmt.Sprintf("%d| %s\n", int(lineNumber), preview))
// Display context after the match (up to 10 lines)
- contextAfter := 10
- endLine := int(lineNumber) + contextAfter
+ endLine := int(lineNumber) + contextWindow
for j := int(lineNumber); j < endLine && j < len(lines); j++ {
if j < len(lines) {