summaryrefslogtreecommitdiffhomepage
path: root/internal/llm
diff options
context:
space:
mode:
authorKujtim Hoxha <[email protected]>2025-04-03 15:20:15 +0200
committerKujtim Hoxha <[email protected]>2025-04-03 17:23:41 +0200
commitcfdd687216799cb5b47f099f1e7cd5dd16b3bdd0 (patch)
treea822bfde1463a7080c0ea06dd17796d7a1617d3d /internal/llm
parentafd9ad0560d76c2a6d161dad52553b10ff428905 (diff)
downloadopencode-cfdd687216799cb5b47f099f1e7cd5dd16b3bdd0.tar.gz
opencode-cfdd687216799cb5b47f099f1e7cd5dd16b3bdd0.zip
add initial lsp support
Diffstat (limited to 'internal/llm')
-rw-r--r--internal/llm/agent/agent-tool.go2
-rw-r--r--internal/llm/agent/agent.go49
-rw-r--r--internal/llm/agent/coder.go13
-rw-r--r--internal/llm/agent/task.go2
-rw-r--r--internal/llm/prompt/coder.go23
-rw-r--r--internal/llm/provider/anthropic.go21
-rw-r--r--internal/llm/provider/gemini.go256
-rw-r--r--internal/llm/provider/openai.go15
-rw-r--r--internal/llm/provider/provider.go7
-rw-r--r--internal/llm/tools/diagnostics.go229
-rw-r--r--internal/llm/tools/edit.go21
-rw-r--r--internal/llm/tools/shell/shell.go2
-rw-r--r--internal/llm/tools/view.go26
-rw-r--r--internal/llm/tools/write.go18
-rw-r--r--internal/llm/tools/write_test.go76
15 files changed, 511 insertions, 249 deletions
diff --git a/internal/llm/agent/agent-tool.go b/internal/llm/agent/agent-tool.go
index a87f48339..bf5e31f8f 100644
--- a/internal/llm/agent/agent-tool.go
+++ b/internal/llm/agent/agent-tool.go
@@ -91,7 +91,7 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes
if err != nil {
return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil
}
- return tools.NewTextResponse(response.Content), nil
+ return tools.NewTextResponse(response.Content().String()), nil
}
func NewAgentTool(parentSessionID string, app *app.App) tools.BaseTool {
diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go
index 06dbca4e8..cb123e78c 100644
--- a/internal/llm/agent/agent.go
+++ b/internal/llm/agent/agent.go
@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"log"
+ "strings"
"sync"
"github.com/kujtimiihoxha/termai/internal/app"
@@ -33,8 +34,12 @@ func (c *agent) handleTitleGeneration(sessionID, content string) {
c.Context,
[]message.Message{
{
- Role: message.User,
- Content: content,
+ Role: message.User,
+ Parts: []message.ContentPart{
+ message.TextContent{
+ Text: content,
+ },
+ },
},
},
nil,
@@ -49,6 +54,8 @@ func (c *agent) handleTitleGeneration(sessionID, content string) {
}
if response.Content != "" {
session.Title = response.Content
+ session.Title = strings.TrimSpace(session.Title)
+ session.Title = strings.ReplaceAll(session.Title, "\n", " ")
c.Sessions.Save(session)
}
}
@@ -79,17 +86,18 @@ func (c *agent) processEvent(
) error {
switch event.Type {
case provider.EventThinkingDelta:
- assistantMsg.Thinking += event.Thinking
+ assistantMsg.AppendReasoningContent(event.Content)
return c.Messages.Update(*assistantMsg)
case provider.EventContentDelta:
- assistantMsg.Content += event.Content
+ assistantMsg.AppendContent(event.Content)
return c.Messages.Update(*assistantMsg)
case provider.EventError:
log.Println("error", event.Error)
return event.Error
case provider.EventComplete:
- assistantMsg.ToolCalls = event.Response.ToolCalls
+ assistantMsg.SetToolCalls(event.Response.ToolCalls)
+ assistantMsg.AddFinish(event.Response.FinishReason)
err := c.Messages.Update(*assistantMsg)
if err != nil {
return err
@@ -157,18 +165,21 @@ func (c *agent) handleToolExecution(
ctx context.Context,
assistantMsg message.Message,
) (*message.Message, error) {
- if len(assistantMsg.ToolCalls) == 0 {
+ if len(assistantMsg.ToolCalls()) == 0 {
return nil, nil
}
- toolResults, err := c.ExecuteTools(ctx, assistantMsg.ToolCalls, c.tools)
+ toolResults, err := c.ExecuteTools(ctx, assistantMsg.ToolCalls(), c.tools)
if err != nil {
return nil, err
}
-
+ parts := make([]message.ContentPart, 0)
+ for _, toolResult := range toolResults {
+ parts = append(parts, toolResult)
+ }
msg, err := c.Messages.Create(assistantMsg.SessionID, message.CreateMessageParams{
- Role: message.Tool,
- ToolResults: toolResults,
+ Role: message.Tool,
+ Parts: parts,
})
return &msg, err
@@ -185,8 +196,12 @@ func (c *agent) generate(sessionID string, content string) error {
}
userMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
- Role: message.User,
- Content: content,
+ Role: message.User,
+ Parts: []message.ContentPart{
+ message.TextContent{
+ Text: content,
+ },
+ },
})
if err != nil {
return err
@@ -201,8 +216,8 @@ func (c *agent) generate(sessionID string, content string) error {
}
assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
- Role: message.Assistant,
- Content: "",
+ Role: message.Assistant,
+ Parts: []message.ContentPart{},
})
if err != nil {
return err
@@ -210,20 +225,20 @@ func (c *agent) generate(sessionID string, content string) error {
for event := range eventChan {
err = c.processEvent(sessionID, &assistantMsg, event)
if err != nil {
- assistantMsg.Finished = true
+ assistantMsg.AddFinish("error:" + err.Error())
c.Messages.Update(assistantMsg)
return err
}
}
msg, err := c.handleToolExecution(c.Context, assistantMsg)
- assistantMsg.Finished = true
+
c.Messages.Update(assistantMsg)
if err != nil {
return err
}
- if len(assistantMsg.ToolCalls) == 0 {
+ if len(assistantMsg.ToolCalls()) == 0 {
break
}
diff --git a/internal/llm/agent/coder.go b/internal/llm/agent/coder.go
index 38dfd2de1..d167ede99 100644
--- a/internal/llm/agent/coder.go
+++ b/internal/llm/agent/coder.go
@@ -44,20 +44,23 @@ func NewCoderAgent(app *app.App) (Agent, error) {
return nil, err
}
- mcpTools := GetMcpTools(app.Context)
+ otherTools := GetMcpTools(app.Context)
+ if len(app.LSPClients) > 0 {
+ otherTools = append(otherTools, tools.NewDiagnosticsTool(app.LSPClients))
+ }
return &coderAgent{
agent: &agent{
App: app,
tools: append(
[]tools.BaseTool{
tools.NewBashTool(),
- tools.NewEditTool(),
+ tools.NewEditTool(app.LSPClients),
tools.NewGlobTool(),
tools.NewGrepTool(),
tools.NewLsTool(),
- tools.NewViewTool(),
- tools.NewWriteTool(),
- }, mcpTools...,
+ tools.NewViewTool(app.LSPClients),
+ tools.NewWriteTool(app.LSPClients),
+ }, otherTools...,
),
model: model,
agent: agentProvider,
diff --git a/internal/llm/agent/task.go b/internal/llm/agent/task.go
index 9e0311ed1..97611e62b 100644
--- a/internal/llm/agent/task.go
+++ b/internal/llm/agent/task.go
@@ -34,7 +34,7 @@ func NewTaskAgent(app *app.App) (Agent, error) {
tools.NewGlobTool(),
tools.NewGrepTool(),
tools.NewLsTool(),
- tools.NewViewTool(),
+ tools.NewViewTool(app.LSPClients),
},
model: model,
agent: agentProvider,
diff --git a/internal/llm/prompt/coder.go b/internal/llm/prompt/coder.go
index cd71bc005..27bb7e431 100644
--- a/internal/llm/prompt/coder.go
+++ b/internal/llm/prompt/coder.go
@@ -67,7 +67,7 @@ Never commit changes unless the user explicitly asks you to.`
envInfo := getEnvironmentInfo()
- return fmt.Sprintf("%s\n\n%s", basePrompt, envInfo)
+ return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation())
}
func CoderAnthropicSystemPrompt() string {
@@ -168,7 +168,7 @@ You MUST answer concisely with fewer than 4 lines of text (not including tool us
envInfo := getEnvironmentInfo()
- return fmt.Sprintf("%s\n\n%s", basePrompt, envInfo)
+ return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation())
}
func getEnvironmentInfo() string {
@@ -198,6 +198,25 @@ func isGitRepo(dir string) bool {
return err == nil
}
+func lspInformation() string {
+ cfg := config.Get()
+ hasLSP := false
+ for _, v := range cfg.LSP {
+ if !v.Disabled {
+ hasLSP = true
+ break
+ }
+ }
+ if !hasLSP {
+ return ""
+ }
+ return `# LSP Information
+Tools that support it will also include useful diagnostics such as linting and typechecking.
+These diagnostics will be automatically enabled when you run the tool, and will be displayed in the output at the bottom within the <file_diagnostics></file_diagnostics> and <project_diagnostics></project_diagnostics> tags.
+Take necessary actions to fix the issues.
+`
+}
+
func boolToYesNo(b bool) string {
if b {
return "Yes"
diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go
index 63a68b92b..2b960ebca 100644
--- a/internal/llm/provider/anthropic.go
+++ b/internal/llm/provider/anthropic.go
@@ -111,7 +111,7 @@ func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []messa
var thinkingParam anthropic.ThinkingConfigParamUnion
lastMessage := messages[len(messages)-1]
temperature := anthropic.Float(0)
- if lastMessage.Role == message.User && strings.Contains(strings.ToLower(lastMessage.Content), "think") {
+ if lastMessage.Role == message.User && strings.Contains(strings.ToLower(lastMessage.Content().String()), "think") {
thinkingParam = anthropic.ThinkingConfigParamUnion{
OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{
BudgetTokens: int64(float64(a.maxTokens) * 0.8),
@@ -187,9 +187,10 @@ func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []messa
eventChan <- ProviderEvent{
Type: EventComplete,
Response: &ProviderResponse{
- Content: content,
- ToolCalls: toolCalls,
- Usage: tokenUsage,
+ Content: content,
+ ToolCalls: toolCalls,
+ Usage: tokenUsage,
+ FinishReason: string(accumulatedMessage.StopReason),
},
}
}
@@ -263,7 +264,7 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag
for i, msg := range messages {
switch msg.Role {
case message.User:
- content := anthropic.NewTextBlock(msg.Content)
+ content := anthropic.NewTextBlock(msg.Content().String())
if cachedBlocks < 2 {
content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
@@ -274,8 +275,8 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag
case message.Assistant:
blocks := []anthropic.ContentBlockParamUnion{}
- if msg.Content != "" {
- content := anthropic.NewTextBlock(msg.Content)
+ if msg.Content().String() != "" {
+ content := anthropic.NewTextBlock(msg.Content().String())
if cachedBlocks < 2 {
content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
@@ -285,7 +286,7 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag
blocks = append(blocks, content)
}
- for _, toolCall := range msg.ToolCalls {
+ for _, toolCall := range msg.ToolCalls() {
var inputMap map[string]any
err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
if err != nil {
@@ -297,8 +298,8 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag
anthropicMessages[i] = anthropic.NewAssistantMessage(blocks...)
case message.Tool:
- results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults))
- for i, toolResult := range msg.ToolResults {
+ results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults()))
+ for i, toolResult := range msg.ToolResults() {
results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
}
anthropicMessages[i] = anthropic.NewUserMessage(results...)
diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go
index 6b252b581..53ffa154e 100644
--- a/internal/llm/provider/gemini.go
+++ b/internal/llm/provider/gemini.go
@@ -78,7 +78,6 @@ func (p *geminiProvider) Close() {
}
}
-// convertToGeminiHistory converts the message history to Gemini's format
func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*genai.Content {
var history []*genai.Content
@@ -86,7 +85,7 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g
switch msg.Role {
case message.User:
history = append(history, &genai.Content{
- Parts: []genai.Part{genai.Text(msg.Content)},
+ Parts: []genai.Part{genai.Text(msg.Content().String())},
Role: "user",
})
case message.Assistant:
@@ -95,14 +94,12 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g
Parts: []genai.Part{},
}
- // Handle regular content
- if msg.Content != "" {
- content.Parts = append(content.Parts, genai.Text(msg.Content))
+ if msg.Content().String() != "" {
+ content.Parts = append(content.Parts, genai.Text(msg.Content().String()))
}
- // Handle tool calls if any
- if len(msg.ToolCalls) > 0 {
- for _, call := range msg.ToolCalls {
+ if len(msg.ToolCalls()) > 0 {
+ for _, call := range msg.ToolCalls() {
args, _ := parseJsonToMap(call.Input)
content.Parts = append(content.Parts, genai.FunctionCall{
Name: call.Name,
@@ -113,8 +110,7 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g
history = append(history, content)
case message.Tool:
- for _, result := range msg.ToolResults {
- // Parse response content to map if possible
+ for _, result := range msg.ToolResults() {
response := map[string]interface{}{"result": result.Content}
parsed, err := parseJsonToMap(result.Content)
if err == nil {
@@ -123,7 +119,7 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g
var toolCall message.ToolCall
for _, msg := range messages {
if msg.Role == message.Assistant {
- for _, call := range msg.ToolCalls {
+ for _, call := range msg.ToolCalls() {
if call.ID == result.ToolCallID {
toolCall = call
break
@@ -146,108 +142,6 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g
return history
}
-// convertToolsToGeminiFunctionDeclarations converts tool definitions to Gemini's function declarations
-func (p *geminiProvider) convertToolsToGeminiFunctionDeclarations(tools []tools.BaseTool) []*genai.FunctionDeclaration {
- declarations := make([]*genai.FunctionDeclaration, len(tools))
-
- for i, tool := range tools {
- info := tool.Info()
-
- // Convert parameters to genai.Schema format
- properties := make(map[string]*genai.Schema)
- for name, param := range info.Parameters {
- // Try to extract type and description from the parameter
- paramMap, ok := param.(map[string]interface{})
- if !ok {
- // Default to string if unable to determine type
- properties[name] = &genai.Schema{Type: genai.TypeString}
- continue
- }
-
- schemaType := genai.TypeString // Default
- var description string
- var itemsTypeSchema *genai.Schema
- if typeVal, found := paramMap["type"]; found {
- if typeStr, ok := typeVal.(string); ok {
- switch typeStr {
- case "string":
- schemaType = genai.TypeString
- case "number":
- schemaType = genai.TypeNumber
- case "integer":
- schemaType = genai.TypeInteger
- case "boolean":
- schemaType = genai.TypeBoolean
- case "array":
- schemaType = genai.TypeArray
- items, found := paramMap["items"]
- if found {
- itemsMap, ok := items.(map[string]interface{})
- if ok {
- itemsType, found := itemsMap["type"]
- if found {
- itemsTypeStr, ok := itemsType.(string)
- if ok {
- switch itemsTypeStr {
- case "string":
- itemsTypeSchema = &genai.Schema{
- Type: genai.TypeString,
- }
- case "number":
- itemsTypeSchema = &genai.Schema{
- Type: genai.TypeNumber,
- }
- case "integer":
- itemsTypeSchema = &genai.Schema{
- Type: genai.TypeInteger,
- }
- case "boolean":
- itemsTypeSchema = &genai.Schema{
- Type: genai.TypeBoolean,
- }
- }
- }
- }
- }
- }
- case "object":
- schemaType = genai.TypeObject
- if _, found := paramMap["properties"]; !found {
- continue
- }
- // TODO: Add support for other types
- }
- }
- }
-
- if desc, found := paramMap["description"]; found {
- if descStr, ok := desc.(string); ok {
- description = descStr
- }
- }
-
- properties[name] = &genai.Schema{
- Type: schemaType,
- Description: description,
- Items: itemsTypeSchema,
- }
- }
-
- declarations[i] = &genai.FunctionDeclaration{
- Name: info.Name,
- Description: info.Description,
- Parameters: &genai.Schema{
- Type: genai.TypeObject,
- Properties: properties,
- Required: info.Required,
- },
- }
- }
-
- return declarations
-}
-
-// extractTokenUsage extracts token usage information from Gemini's response
func (p *geminiProvider) extractTokenUsage(resp *genai.GenerateContentResponse) TokenUsage {
if resp == nil || resp.UsageMetadata == nil {
return TokenUsage{}
@@ -261,41 +155,28 @@ func (p *geminiProvider) extractTokenUsage(resp *genai.GenerateContentResponse)
}
}
-// SendMessages sends a batch of messages to Gemini and returns the response
func (p *geminiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
- // Create a generative model
model := p.client.GenerativeModel(p.model.APIModel)
model.SetMaxOutputTokens(p.maxTokens)
- // Set system instruction
model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))
- // Set up tools if provided
if len(tools) > 0 {
declarations := p.convertToolsToGeminiFunctionDeclarations(tools)
- model.Tools = []*genai.Tool{{FunctionDeclarations: declarations}}
+ for _, declaration := range declarations {
+ model.Tools = append(model.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{declaration}})
+ }
}
- // Create chat session and set history
chat := model.StartChat()
chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message
- // Get the most recent user message
- var lastUserMsg message.Message
- for i := len(messages) - 1; i >= 0; i-- {
- if messages[i].Role == message.User {
- lastUserMsg = messages[i]
- break
- }
- }
-
- // Send the message
- resp, err := chat.SendMessage(ctx, genai.Text(lastUserMsg.Content))
+ lastUserMsg := messages[len(messages)-1]
+ resp, err := chat.SendMessage(ctx, genai.Text(lastUserMsg.Content().String()))
if err != nil {
return nil, err
}
- // Process the response
var content string
var toolCalls []message.ToolCall
@@ -317,7 +198,6 @@ func (p *geminiProvider) SendMessages(ctx context.Context, messages []message.Me
}
}
- // Extract token usage
tokenUsage := p.extractTokenUsage(resp)
return &ProviderResponse{
@@ -327,16 +207,12 @@ func (p *geminiProvider) SendMessages(ctx context.Context, messages []message.Me
}, nil
}
-// StreamResponse streams the response from Gemini
func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
- // Create a generative model
model := p.client.GenerativeModel(p.model.APIModel)
model.SetMaxOutputTokens(p.maxTokens)
- // Set system instruction
model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage))
- // Set up tools if provided
if len(tools) > 0 {
declarations := p.convertToolsToGeminiFunctionDeclarations(tools)
for _, declaration := range declarations {
@@ -344,14 +220,12 @@ func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.
}
}
- // Create chat session and set history
chat := model.StartChat()
chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message
lastUserMsg := messages[len(messages)-1]
- // Start streaming
- iter := chat.SendMessageStream(ctx, genai.Text(lastUserMsg.Content))
+ iter := chat.SendMessageStream(ctx, genai.Text(lastUserMsg.Content().String()))
eventChan := make(chan ProviderEvent)
@@ -392,7 +266,6 @@ func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.
}
currentContent += newText
case genai.FunctionCall:
- // For function calls, we assume they come complete, not streamed in parts
id := "call_" + uuid.New().String()
args, _ := json.Marshal(p.Args)
newCall := message.ToolCall{
@@ -402,7 +275,6 @@ func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.
Type: "function",
}
- // Check if this is a new tool call
isNew := true
for _, existing := range toolCalls {
if existing.Name == newCall.Name && existing.Input == newCall.Input {
@@ -419,15 +291,15 @@ func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.
}
}
- // Extract token usage from the final response
tokenUsage := p.extractTokenUsage(finalResp)
eventChan <- ProviderEvent{
Type: EventComplete,
Response: &ProviderResponse{
- Content: currentContent,
- ToolCalls: toolCalls,
- Usage: tokenUsage,
+ Content: currentContent,
+ ToolCalls: toolCalls,
+ Usage: tokenUsage,
+ FinishReason: string(finalResp.Candidates[0].FinishReason.String()),
},
}
}()
@@ -435,7 +307,99 @@ func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.
return eventChan, nil
}
-// Helper function to parse JSON string into map
+func (p *geminiProvider) convertToolsToGeminiFunctionDeclarations(tools []tools.BaseTool) []*genai.FunctionDeclaration {
+ declarations := make([]*genai.FunctionDeclaration, len(tools))
+
+ for i, tool := range tools {
+ info := tool.Info()
+ declarations[i] = &genai.FunctionDeclaration{
+ Name: info.Name,
+ Description: info.Description,
+ Parameters: &genai.Schema{
+ Type: genai.TypeObject,
+ Properties: convertSchemaProperties(info.Parameters),
+ Required: info.Required,
+ },
+ }
+ }
+
+ return declarations
+}
+
+func convertSchemaProperties(parameters map[string]interface{}) map[string]*genai.Schema {
+ properties := make(map[string]*genai.Schema)
+
+ for name, param := range parameters {
+ properties[name] = convertToSchema(param)
+ }
+
+ return properties
+}
+
+func convertToSchema(param interface{}) *genai.Schema {
+ schema := &genai.Schema{Type: genai.TypeString}
+
+ paramMap, ok := param.(map[string]interface{})
+ if !ok {
+ return schema
+ }
+
+ if desc, ok := paramMap["description"].(string); ok {
+ schema.Description = desc
+ }
+
+ typeVal, hasType := paramMap["type"]
+ if !hasType {
+ return schema
+ }
+
+ typeStr, ok := typeVal.(string)
+ if !ok {
+ return schema
+ }
+
+ schema.Type = mapJSONTypeToGenAI(typeStr)
+
+ switch typeStr {
+ case "array":
+ schema.Items = processArrayItems(paramMap)
+ case "object":
+ if props, ok := paramMap["properties"].(map[string]interface{}); ok {
+ schema.Properties = convertSchemaProperties(props)
+ }
+ }
+
+ return schema
+}
+
+func processArrayItems(paramMap map[string]interface{}) *genai.Schema {
+ items, ok := paramMap["items"].(map[string]interface{})
+ if !ok {
+ return nil
+ }
+
+ return convertToSchema(items)
+}
+
+func mapJSONTypeToGenAI(jsonType string) genai.Type {
+ switch jsonType {
+ case "string":
+ return genai.TypeString
+ case "number":
+ return genai.TypeNumber
+ case "integer":
+ return genai.TypeInteger
+ case "boolean":
+ return genai.TypeBoolean
+ case "array":
+ return genai.TypeArray
+ case "object":
+ return genai.TypeObject
+ default:
+ return genai.TypeString // Default to string for unknown types
+ }
+}
+
func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
var result map[string]interface{}
err := json.Unmarshal([]byte(jsonStr), &result)
diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go
index d86a58690..c8e04d5ee 100644
--- a/internal/llm/provider/openai.go
+++ b/internal/llm/provider/openai.go
@@ -84,22 +84,22 @@ func (p *openaiProvider) convertToOpenAIMessages(messages []message.Message) []o
for _, msg := range messages {
switch msg.Role {
case message.User:
- chatMessages = append(chatMessages, openai.UserMessage(msg.Content))
+ chatMessages = append(chatMessages, openai.UserMessage(msg.Content().String()))
case message.Assistant:
assistantMsg := openai.ChatCompletionAssistantMessageParam{
Role: "assistant",
}
- if msg.Content != "" {
+ if msg.Content().String() != "" {
assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
- OfString: openai.String(msg.Content),
+ OfString: openai.String(msg.Content().String()),
}
}
- if len(msg.ToolCalls) > 0 {
- assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls))
- for i, call := range msg.ToolCalls {
+ if len(msg.ToolCalls()) > 0 {
+ assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
+ for i, call := range msg.ToolCalls() {
assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
ID: call.ID,
Type: "function",
@@ -116,7 +116,7 @@ func (p *openaiProvider) convertToOpenAIMessages(messages []message.Message) []o
})
case message.Tool:
- for _, result := range msg.ToolResults {
+ for _, result := range msg.ToolResults() {
chatMessages = append(chatMessages,
openai.ToolMessage(result.Content, result.ToolCallID),
)
@@ -276,3 +276,4 @@ func (p *openaiProvider) StreamResponse(ctx context.Context, messages []message.
return eventChan, nil
}
+
diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go
index 9ac1def37..f40429738 100644
--- a/internal/llm/provider/provider.go
+++ b/internal/llm/provider/provider.go
@@ -27,9 +27,10 @@ type TokenUsage struct {
}
type ProviderResponse struct {
- Content string
- ToolCalls []message.ToolCall
- Usage TokenUsage
+ Content string
+ ToolCalls []message.ToolCall
+ Usage TokenUsage
+ FinishReason string
}
type ProviderEvent struct {
diff --git a/internal/llm/tools/diagnostics.go b/internal/llm/tools/diagnostics.go
new file mode 100644
index 000000000..dc90e5860
--- /dev/null
+++ b/internal/llm/tools/diagnostics.go
@@ -0,0 +1,229 @@
+package tools
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "sort"
+ "strings"
+ "time"
+
+ "github.com/kujtimiihoxha/termai/internal/lsp"
+ "github.com/kujtimiihoxha/termai/internal/lsp/protocol"
+)
+
+type diagnosticsTool struct {
+ lspClients map[string]*lsp.Client
+}
+
+const (
+ DiagnosticsToolName = "diagnostics"
+)
+
+type DiagnosticsParams struct {
+ FilePath string `json:"file_path"`
+}
+
+func (b *diagnosticsTool) Info() ToolInfo {
+ return ToolInfo{
+ Name: DiagnosticsToolName,
+ Description: "Get diagnostics for a file and/or project.",
+ Parameters: map[string]any{
+ "file_path": map[string]any{
+ "type": "string",
+ "description": "The path to the file to get diagnostics for (leave w empty for project diagnostics)",
+ },
+ },
+ Required: []string{},
+ }
+}
+
+func (b *diagnosticsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
+ var params DiagnosticsParams
+ if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
+ return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
+ }
+
+ lsps := b.lspClients
+
+ if len(lsps) == 0 {
+ return NewTextErrorResponse("no LSP clients available"), nil
+ }
+
+ if params.FilePath == "" {
+ notifyLspOpenFile(ctx, params.FilePath, lsps)
+ }
+
+ output := appendDiagnostics(params.FilePath, lsps)
+
+ return NewTextResponse(output), nil
+}
+
+func notifyLspOpenFile(ctx context.Context, filePath string, lsps map[string]*lsp.Client) {
+ for _, client := range lsps {
+ err := client.OpenFile(ctx, filePath)
+ if err != nil {
+ // Wait for the file to be opened and diagnostics to be received
+ // TODO: see if we can do this in a more efficient way
+ time.Sleep(2 * time.Second)
+ }
+
+ }
+}
+
+func appendDiagnostics(filePath string, lsps map[string]*lsp.Client) string {
+ fileDiagnostics := []string{}
+ projectDiagnostics := []string{}
+
+ // Enhanced format function that includes more diagnostic information
+ formatDiagnostic := func(pth string, diagnostic protocol.Diagnostic, source string) string {
+ // Base components
+ severity := "Info"
+ switch diagnostic.Severity {
+ case protocol.SeverityError:
+ severity = "Error"
+ case protocol.SeverityWarning:
+ severity = "Warn"
+ case protocol.SeverityHint:
+ severity = "Hint"
+ }
+
+ // Location information
+ location := fmt.Sprintf("%s:%d:%d", pth, diagnostic.Range.Start.Line+1, diagnostic.Range.Start.Character+1)
+
+ // Source information (LSP name)
+ sourceInfo := ""
+ if diagnostic.Source != "" {
+ sourceInfo = diagnostic.Source
+ } else if source != "" {
+ sourceInfo = source
+ }
+
+ // Code information
+ codeInfo := ""
+ if diagnostic.Code != nil {
+ codeInfo = fmt.Sprintf("[%v]", diagnostic.Code)
+ }
+
+ // Tags information
+ tagsInfo := ""
+ if len(diagnostic.Tags) > 0 {
+ tags := []string{}
+ for _, tag := range diagnostic.Tags {
+ switch tag {
+ case protocol.Unnecessary:
+ tags = append(tags, "unnecessary")
+ case protocol.Deprecated:
+ tags = append(tags, "deprecated")
+ }
+ }
+ if len(tags) > 0 {
+ tagsInfo = fmt.Sprintf(" (%s)", strings.Join(tags, ", "))
+ }
+ }
+
+ // Assemble the full diagnostic message
+ return fmt.Sprintf("%s: %s [%s]%s%s %s",
+ severity,
+ location,
+ sourceInfo,
+ codeInfo,
+ tagsInfo,
+ diagnostic.Message)
+ }
+
+ for lspName, client := range lsps {
+ diagnostics := client.GetDiagnostics()
+ if len(diagnostics) > 0 {
+ for location, diags := range diagnostics {
+ isCurrentFile := location.Path() == filePath
+
+ // Group diagnostics by severity for better organization
+ for _, diag := range diags {
+ formattedDiag := formatDiagnostic(location.Path(), diag, lspName)
+
+ if isCurrentFile {
+ fileDiagnostics = append(fileDiagnostics, formattedDiag)
+ } else {
+ projectDiagnostics = append(projectDiagnostics, formattedDiag)
+ }
+ }
+ }
+ }
+ }
+
+ // Sort diagnostics by severity (errors first) and then by location
+ sort.Slice(fileDiagnostics, func(i, j int) bool {
+ iIsError := strings.HasPrefix(fileDiagnostics[i], "Error")
+ jIsError := strings.HasPrefix(fileDiagnostics[j], "Error")
+ if iIsError != jIsError {
+ return iIsError // Errors come first
+ }
+ return fileDiagnostics[i] < fileDiagnostics[j] // Then alphabetically
+ })
+
+ sort.Slice(projectDiagnostics, func(i, j int) bool {
+ iIsError := strings.HasPrefix(projectDiagnostics[i], "Error")
+ jIsError := strings.HasPrefix(projectDiagnostics[j], "Error")
+ if iIsError != jIsError {
+ return iIsError
+ }
+ return projectDiagnostics[i] < projectDiagnostics[j]
+ })
+
+ output := ""
+
+ if len(fileDiagnostics) > 0 {
+ output += "\n<file_diagnostics>\n"
+ if len(fileDiagnostics) > 10 {
+ output += strings.Join(fileDiagnostics[:10], "\n")
+ output += fmt.Sprintf("\n... and %d more diagnostics", len(fileDiagnostics)-10)
+ } else {
+ output += strings.Join(fileDiagnostics, "\n")
+ }
+ output += "\n</file_diagnostics>\n"
+ }
+
+ if len(projectDiagnostics) > 0 {
+ output += "\n<project_diagnostics>\n"
+ if len(projectDiagnostics) > 10 {
+ output += strings.Join(projectDiagnostics[:10], "\n")
+ output += fmt.Sprintf("\n... and %d more diagnostics", len(projectDiagnostics)-10)
+ } else {
+ output += strings.Join(projectDiagnostics, "\n")
+ }
+ output += "\n</project_diagnostics>\n"
+ }
+
+ // Add summary counts
+ if len(fileDiagnostics) > 0 || len(projectDiagnostics) > 0 {
+ fileErrors := countSeverity(fileDiagnostics, "Error")
+ fileWarnings := countSeverity(fileDiagnostics, "Warn")
+ projectErrors := countSeverity(projectDiagnostics, "Error")
+ projectWarnings := countSeverity(projectDiagnostics, "Warn")
+
+ output += "\n<diagnostic_summary>\n"
+ output += fmt.Sprintf("Current file: %d errors, %d warnings\n", fileErrors, fileWarnings)
+ output += fmt.Sprintf("Project: %d errors, %d warnings\n", projectErrors, projectWarnings)
+ output += "</diagnostic_summary>\n"
+ }
+
+ return output
+}
+
+// Helper function to count diagnostics by severity
+func countSeverity(diagnostics []string, severity string) int {
+ count := 0
+ for _, diag := range diagnostics {
+ if strings.HasPrefix(diag, severity) {
+ count++
+ }
+ }
+ return count
+}
+
+func NewDiagnosticsTool(lspClients map[string]*lsp.Client) BaseTool {
+ return &diagnosticsTool{
+ lspClients,
+ }
+}
diff --git a/internal/llm/tools/edit.go b/internal/llm/tools/edit.go
index 8c5427a58..c84bbd7a0 100644
--- a/internal/llm/tools/edit.go
+++ b/internal/llm/tools/edit.go
@@ -10,11 +10,14 @@ import (
"time"
"github.com/kujtimiihoxha/termai/internal/config"
+ "github.com/kujtimiihoxha/termai/internal/lsp"
"github.com/kujtimiihoxha/termai/internal/permission"
"github.com/sergi/go-diff/diffmatchpatch"
)
-type editTool struct{}
+type editTool struct {
+ lspClients map[string]*lsp.Client
+}
const (
EditToolName = "edit"
@@ -71,6 +74,7 @@ func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
params.FilePath = filepath.Join(wd, params.FilePath)
}
+ notifyLspOpenFile(ctx, params.FilePath, e.lspClients)
if params.OldString == "" {
result, err := createNewFile(params.FilePath, params.NewString)
if err != nil {
@@ -91,6 +95,9 @@ func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
if err != nil {
return NewTextErrorResponse(fmt.Sprintf("error replacing content: %s", err)), nil
}
+
+ result = fmt.Sprintf("<result>\n%s\n</result>\n", result)
+ result += appendDiagnostics(params.FilePath, e.lspClients)
return NewTextResponse(result), nil
}
@@ -296,18 +303,18 @@ func GenerateDiff(oldContent, newContent string) string {
switch diff.Type {
case diffmatchpatch.DiffInsert:
- for _, line := range strings.Split(text, "\n") {
+ for line := range strings.SplitSeq(text, "\n") {
_, _ = buff.WriteString("+ " + line + "\n")
}
case diffmatchpatch.DiffDelete:
- for _, line := range strings.Split(text, "\n") {
+ for line := range strings.SplitSeq(text, "\n") {
_, _ = buff.WriteString("- " + line + "\n")
}
case diffmatchpatch.DiffEqual:
if len(text) > 40 {
_, _ = buff.WriteString(" " + text[:20] + "..." + text[len(text)-20:] + "\n")
} else {
- for _, line := range strings.Split(text, "\n") {
+ for line := range strings.SplitSeq(text, "\n") {
_, _ = buff.WriteString(" " + line + "\n")
}
}
@@ -366,6 +373,8 @@ When making edits:
Remember: when making multiple file edits in a row to the same file, you should prefer to send all edits in a single message with multiple calls to this tool, rather than multiple messages with a single call each.`
}
-func NewEditTool() BaseTool {
- return &editTool{}
+func NewEditTool(lspClients map[string]*lsp.Client) BaseTool {
+ return &editTool{
+ lspClients,
+ }
}
diff --git a/internal/llm/tools/shell/shell.go b/internal/llm/tools/shell/shell.go
index d76cb1a2e..64592f67d 100644
--- a/internal/llm/tools/shell/shell.go
+++ b/internal/llm/tools/shell/shell.go
@@ -221,7 +221,7 @@ func (s *PersistentShell) killChildren() {
return
}
- for _, pidStr := range strings.Split(string(output), "\n") {
+ for pidStr := range strings.SplitSeq(string(output), "\n") {
if pidStr = strings.TrimSpace(pidStr); pidStr != "" {
var pid int
fmt.Sscanf(pidStr, "%d", &pid)
diff --git a/internal/llm/tools/view.go b/internal/llm/tools/view.go
index dca522b9c..743cef6f4 100644
--- a/internal/llm/tools/view.go
+++ b/internal/llm/tools/view.go
@@ -11,9 +11,12 @@ import (
"strings"
"github.com/kujtimiihoxha/termai/internal/config"
+ "github.com/kujtimiihoxha/termai/internal/lsp"
)
-type viewTool struct{}
+type viewTool struct {
+ lspClients map[string]*lsp.Client
+}
const (
ViewToolName = "view"
@@ -127,15 +130,18 @@ func (v *viewTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
return NewTextErrorResponse(fmt.Sprintf("Failed to read file: %s", err)), nil
}
+ notifyLspOpenFile(ctx, filePath, v.lspClients)
+ output := "<file>\n"
// Format the output with line numbers
- output := addLineNumbers(content, params.Offset+1)
+ output += addLineNumbers(content, params.Offset+1)
// Add a note if the content was truncated
if lineCount > params.Offset+len(strings.Split(content, "\n")) {
output += fmt.Sprintf("\n\n(File has more lines. Use 'offset' parameter to read beyond line %d)",
params.Offset+len(strings.Split(content, "\n")))
}
-
+ output += "\n</file>\n"
+ output += appendDiagnostics(filePath, v.lspClients)
recordFileRead(filePath)
return NewTextResponse(output), nil
}
@@ -155,10 +161,10 @@ func addLineNumbers(content string, startLine int) string {
numStr := fmt.Sprintf("%d", lineNum)
if len(numStr) >= 6 {
- result = append(result, fmt.Sprintf("%s\t%s", numStr, line))
+ result = append(result, fmt.Sprintf("%s|%s", numStr, line))
} else {
paddedNum := fmt.Sprintf("%6s", numStr)
- result = append(result, fmt.Sprintf("%s\t|%s", paddedNum, line))
+ result = append(result, fmt.Sprintf("%s|%s", paddedNum, line))
}
}
@@ -173,8 +179,9 @@ func readTextFile(filePath string, offset, limit int) (string, int, error) {
defer file.Close()
lineCount := 0
+
+ scanner := NewLineScanner(file)
if offset > 0 {
- scanner := NewLineScanner(file)
for lineCount < offset && scanner.Scan() {
lineCount++
}
@@ -192,7 +199,6 @@ func readTextFile(filePath string, offset, limit int) (string, int, error) {
var lines []string
lineCount = offset
- scanner := NewLineScanner(file)
for scanner.Scan() && len(lines) < limit {
lineCount++
@@ -290,6 +296,8 @@ TIPS:
- When viewing large files, use the offset parameter to read specific sections`
}
-func NewViewTool() BaseTool {
- return &viewTool{}
+func NewViewTool(lspClients map[string]*lsp.Client) BaseTool {
+ return &viewTool{
+ lspClients,
+ }
}
diff --git a/internal/llm/tools/write.go b/internal/llm/tools/write.go
index 003753d08..3d66d64e2 100644
--- a/internal/llm/tools/write.go
+++ b/internal/llm/tools/write.go
@@ -9,10 +9,13 @@ import (
"time"
"github.com/kujtimiihoxha/termai/internal/config"
+ "github.com/kujtimiihoxha/termai/internal/lsp"
"github.com/kujtimiihoxha/termai/internal/permission"
)
-type writeTool struct{}
+type writeTool struct {
+ lspClients map[string]*lsp.Client
+}
const (
WriteToolName = "write"
@@ -96,6 +99,8 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
if err = os.MkdirAll(dir, 0o755); err != nil {
return NewTextErrorResponse(fmt.Sprintf("Failed to create parent directories: %s", err)), nil
}
+
+ notifyLspOpenFile(ctx, filePath, w.lspClients)
p := permission.Default.Request(
permission.CreatePermissionRequest{
Path: filePath,
@@ -122,7 +127,10 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
recordFileWrite(filePath)
recordFileRead(filePath)
- return NewTextResponse(fmt.Sprintf("File successfully written: %s", filePath)), nil
+ result := fmt.Sprintf("File successfully written: %s", filePath)
+ result = fmt.Sprintf("<result>\n%s\n</result>", result)
+ result += appendDiagnostics(filePath, w.lspClients)
+ return NewTextResponse(result), nil
}
func writeDescription() string {
@@ -156,6 +164,8 @@ TIPS:
- Always include descriptive comments when making changes to existing code`
}
-func NewWriteTool() BaseTool {
- return &writeTool{}
+func NewWriteTool(lspClients map[string]*lsp.Client) BaseTool {
+ return &writeTool{
+ lspClients,
+ }
}
diff --git a/internal/llm/tools/write_test.go b/internal/llm/tools/write_test.go
index 1c92e3baa..893a48b62 100644
--- a/internal/llm/tools/write_test.go
+++ b/internal/llm/tools/write_test.go
@@ -8,13 +8,14 @@ import (
"testing"
"time"
+ "github.com/kujtimiihoxha/termai/internal/lsp"
"github.com/kujtimiihoxha/termai/internal/permission"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestWriteTool_Info(t *testing.T) {
- tool := NewWriteTool()
+ tool := NewWriteTool(make(map[string]*lsp.Client))
info := tool.Info()
assert.Equal(t, WriteToolName, info.Name)
@@ -40,11 +41,11 @@ func TestWriteTool_Run(t *testing.T) {
t.Run("creates a new file successfully", func(t *testing.T) {
permission.Default = newMockPermissionService(true)
- tool := NewWriteTool()
-
+ tool := NewWriteTool(make(map[string]*lsp.Client))
+
filePath := filepath.Join(tempDir, "new_file.txt")
content := "This is a test content"
-
+
params := WriteParams{
FilePath: filePath,
Content: content,
@@ -70,11 +71,11 @@ func TestWriteTool_Run(t *testing.T) {
t.Run("creates file with nested directories", func(t *testing.T) {
permission.Default = newMockPermissionService(true)
- tool := NewWriteTool()
-
+ tool := NewWriteTool(make(map[string]*lsp.Client))
+
filePath := filepath.Join(tempDir, "nested/dirs/new_file.txt")
content := "Content in nested directory"
-
+
params := WriteParams{
FilePath: filePath,
Content: content,
@@ -100,17 +101,17 @@ func TestWriteTool_Run(t *testing.T) {
t.Run("updates existing file", func(t *testing.T) {
permission.Default = newMockPermissionService(true)
- tool := NewWriteTool()
-
+ tool := NewWriteTool(make(map[string]*lsp.Client))
+
// Create a file first
filePath := filepath.Join(tempDir, "existing_file.txt")
initialContent := "Initial content"
- err := os.WriteFile(filePath, []byte(initialContent), 0644)
+ err := os.WriteFile(filePath, []byte(initialContent), 0o644)
require.NoError(t, err)
-
+
// Record the file read to avoid modification time check failure
recordFileRead(filePath)
-
+
// Update the file
updatedContent := "Updated content"
params := WriteParams{
@@ -138,8 +139,8 @@ func TestWriteTool_Run(t *testing.T) {
t.Run("handles invalid parameters", func(t *testing.T) {
permission.Default = newMockPermissionService(true)
- tool := NewWriteTool()
-
+ tool := NewWriteTool(make(map[string]*lsp.Client))
+
call := ToolCall{
Name: WriteToolName,
Input: "invalid json",
@@ -152,8 +153,8 @@ func TestWriteTool_Run(t *testing.T) {
t.Run("handles missing file_path", func(t *testing.T) {
permission.Default = newMockPermissionService(true)
- tool := NewWriteTool()
-
+ tool := NewWriteTool(make(map[string]*lsp.Client))
+
params := WriteParams{
FilePath: "",
Content: "Some content",
@@ -174,8 +175,8 @@ func TestWriteTool_Run(t *testing.T) {
t.Run("handles missing content", func(t *testing.T) {
permission.Default = newMockPermissionService(true)
- tool := NewWriteTool()
-
+ tool := NewWriteTool(make(map[string]*lsp.Client))
+
params := WriteParams{
FilePath: filepath.Join(tempDir, "file.txt"),
Content: "",
@@ -196,13 +197,13 @@ func TestWriteTool_Run(t *testing.T) {
t.Run("handles writing to a directory path", func(t *testing.T) {
permission.Default = newMockPermissionService(true)
- tool := NewWriteTool()
-
+ tool := NewWriteTool(make(map[string]*lsp.Client))
+
// Create a directory
dirPath := filepath.Join(tempDir, "test_dir")
- err := os.Mkdir(dirPath, 0755)
+ err := os.Mkdir(dirPath, 0o755)
require.NoError(t, err)
-
+
params := WriteParams{
FilePath: dirPath,
Content: "Some content",
@@ -223,8 +224,8 @@ func TestWriteTool_Run(t *testing.T) {
t.Run("handles permission denied", func(t *testing.T) {
permission.Default = newMockPermissionService(false)
- tool := NewWriteTool()
-
+ tool := NewWriteTool(make(map[string]*lsp.Client))
+
filePath := filepath.Join(tempDir, "permission_denied.txt")
params := WriteParams{
FilePath: filePath,
@@ -242,7 +243,7 @@ func TestWriteTool_Run(t *testing.T) {
response, err := tool.Run(context.Background(), call)
require.NoError(t, err)
assert.Contains(t, response.Content, "Permission denied")
-
+
// Verify file was not created
_, err = os.Stat(filePath)
assert.True(t, os.IsNotExist(err))
@@ -250,14 +251,14 @@ func TestWriteTool_Run(t *testing.T) {
t.Run("detects file modified since last read", func(t *testing.T) {
permission.Default = newMockPermissionService(true)
- tool := NewWriteTool()
-
+ tool := NewWriteTool(make(map[string]*lsp.Client))
+
// Create a file
filePath := filepath.Join(tempDir, "modified_file.txt")
initialContent := "Initial content"
- err := os.WriteFile(filePath, []byte(initialContent), 0644)
+ err := os.WriteFile(filePath, []byte(initialContent), 0o644)
require.NoError(t, err)
-
+
// Record an old read time
fileRecordMutex.Lock()
fileRecords[filePath] = fileRecord{
@@ -265,7 +266,7 @@ func TestWriteTool_Run(t *testing.T) {
readTime: time.Now().Add(-1 * time.Hour),
}
fileRecordMutex.Unlock()
-
+
// Try to update the file
params := WriteParams{
FilePath: filePath,
@@ -283,7 +284,7 @@ func TestWriteTool_Run(t *testing.T) {
response, err := tool.Run(context.Background(), call)
require.NoError(t, err)
assert.Contains(t, response.Content, "has been modified since it was last read")
-
+
// Verify file was not modified
fileContent, err := os.ReadFile(filePath)
require.NoError(t, err)
@@ -292,17 +293,17 @@ func TestWriteTool_Run(t *testing.T) {
t.Run("skips writing when content is identical", func(t *testing.T) {
permission.Default = newMockPermissionService(true)
- tool := NewWriteTool()
-
+ tool := NewWriteTool(make(map[string]*lsp.Client))
+
// Create a file
filePath := filepath.Join(tempDir, "identical_content.txt")
content := "Content that won't change"
- err := os.WriteFile(filePath, []byte(content), 0644)
+ err := os.WriteFile(filePath, []byte(content), 0o644)
require.NoError(t, err)
-
+
// Record a read time
recordFileRead(filePath)
-
+
// Try to write the same content
params := WriteParams{
FilePath: filePath,
@@ -321,4 +322,5 @@ func TestWriteTool_Run(t *testing.T) {
require.NoError(t, err)
assert.Contains(t, response.Content, "already contains the exact content")
})
-} \ No newline at end of file
+}
+