diff options
| author | Kujtim Hoxha <[email protected]> | 2025-04-03 15:20:15 +0200 |
|---|---|---|
| committer | Kujtim Hoxha <[email protected]> | 2025-04-03 17:23:41 +0200 |
| commit | cfdd687216799cb5b47f099f1e7cd5dd16b3bdd0 (patch) | |
| tree | a822bfde1463a7080c0ea06dd17796d7a1617d3d /internal/message | |
| parent | afd9ad0560d76c2a6d161dad52553b10ff428905 (diff) | |
| download | opencode-cfdd687216799cb5b47f099f1e7cd5dd16b3bdd0.tar.gz opencode-cfdd687216799cb5b47f099f1e7cd5dd16b3bdd0.zip | |
add initial lsp support
Diffstat (limited to 'internal/message')
| -rw-r--r-- | internal/message/content.go | 244 | ||||
| -rw-r--r-- | internal/message/message.go | 245 |
2 files changed, 393 insertions, 96 deletions
diff --git a/internal/message/content.go b/internal/message/content.go new file mode 100644 index 000000000..2604cd68a --- /dev/null +++ b/internal/message/content.go @@ -0,0 +1,244 @@ +package message + +import ( + "encoding/base64" +) + +type MessageRole string + +const ( + Assistant MessageRole = "assistant" + User MessageRole = "user" + System MessageRole = "system" + Tool MessageRole = "tool" +) + +type ContentPart interface { + isPart() +} + +type ReasoningContent struct { + Thinking string `json:"thinking"` +} + +func (tc ReasoningContent) String() string { + return tc.Thinking +} +func (ReasoningContent) isPart() {} + +type TextContent struct { + Text string `json:"text"` +} + +func (tc TextContent) String() string { + return tc.Text +} + +func (TextContent) isPart() {} + +type ImageURLContent struct { + URL string `json:"url"` + Detail string `json:"detail,omitempty"` +} + +func (iuc ImageURLContent) String() string { + return iuc.URL +} + +func (ImageURLContent) isPart() {} + +type BinaryContent struct { + MIMEType string + Data []byte +} + +func (bc BinaryContent) String() string { + base64Encoded := base64.StdEncoding.EncodeToString(bc.Data) + return "data:" + bc.MIMEType + ";base64," + base64Encoded +} + +func (BinaryContent) isPart() {} + +type ToolCall struct { + ID string `json:"id"` + Name string `json:"name"` + Input string `json:"input"` + Type string `json:"type"` + Finished bool `json:"finished"` +} + +func (ToolCall) isPart() {} + +type ToolResult struct { + ToolCallID string `json:"tool_call_id"` + Name string `json:"name"` + Content string `json:"content"` + IsError bool `json:"is_error"` +} + +func (ToolResult) isPart() {} + +type Finish struct { + Reason string `json:"reason"` +} + +func (Finish) isPart() {} + +type Message struct { + ID string + Role MessageRole + SessionID string + Parts []ContentPart + + CreatedAt int64 + UpdatedAt int64 +} + +func (m *Message) Content() TextContent { + for _, part := range m.Parts { + if c, ok := part.(TextContent); ok { + return c + } + } + return TextContent{} +} + +func (m *Message) ReasoningContent() ReasoningContent { + for _, part := range m.Parts { + if c, ok := part.(ReasoningContent); ok { + return c + } + } + return ReasoningContent{} +} + +func (m *Message) ImageURLContent() []ImageURLContent { + imageURLContents := make([]ImageURLContent, 0) + for _, part := range m.Parts { + if c, ok := part.(ImageURLContent); ok { + imageURLContents = append(imageURLContents, c) + } + } + return imageURLContents +} + +func (m *Message) BinaryContent() []BinaryContent { + binaryContents := make([]BinaryContent, 0) + for _, part := range m.Parts { + if c, ok := part.(BinaryContent); ok { + binaryContents = append(binaryContents, c) + } + } + return binaryContents +} + +func (m *Message) ToolCalls() []ToolCall { + toolCalls := make([]ToolCall, 0) + for _, part := range m.Parts { + if c, ok := part.(ToolCall); ok { + toolCalls = append(toolCalls, c) + } + } + return toolCalls +} + +func (m *Message) ToolResults() []ToolResult { + toolResults := make([]ToolResult, 0) + for _, part := range m.Parts { + if c, ok := part.(ToolResult); ok { + toolResults = append(toolResults, c) + } + } + return toolResults +} + +func (m *Message) IsFinished() bool { + for _, part := range m.Parts { + if _, ok := part.(Finish); ok { + return true + } + } + return false +} + +func (m *Message) FinishReason() string { + for _, part := range m.Parts { + if c, ok := part.(Finish); ok { + return c.Reason + } + } + return "" +} + +func (m *Message) IsThinking() bool { + if m.ReasoningContent().Thinking != "" && m.Content().Text == "" && !m.IsFinished() { + return true + } + return false +} + +func (m *Message) AppendContent(delta string) { + found := false + for i, part := range m.Parts { + if c, ok := part.(TextContent); ok { + m.Parts[i] = TextContent{Text: c.Text + delta} + found = true + } + } + if !found { + m.Parts = append(m.Parts, TextContent{Text: delta}) + } +} + +func (m *Message) AppendReasoningContent(delta string) { + found := false + for i, part := range m.Parts { + if c, ok := part.(ReasoningContent); ok { + m.Parts[i] = ReasoningContent{Thinking: c.Thinking + delta} + found = true + } + } + if !found { + m.Parts = append(m.Parts, ReasoningContent{Thinking: delta}) + } +} + +func (m *Message) AddToolCall(tc ToolCall) { + for i, part := range m.Parts { + if c, ok := part.(ToolCall); ok { + if c.ID == tc.ID { + m.Parts[i] = tc + return + } + } + } + m.Parts = append(m.Parts, tc) +} + +func (m *Message) SetToolCalls(tc []ToolCall) { + for _, toolCall := range tc { + m.Parts = append(m.Parts, toolCall) + } +} + +func (m *Message) AddToolResult(tr ToolResult) { + m.Parts = append(m.Parts, tr) +} + +func (m *Message) SetToolResults(tr []ToolResult) { + for _, toolResult := range tr { + m.Parts = append(m.Parts, toolResult) + } +} + +func (m *Message) AddFinish(reason string) { + m.Parts = append(m.Parts, Finish{Reason: reason}) +} + +func (m *Message) AddImageURL(url, detail string) { + m.Parts = append(m.Parts, ImageURLContent{URL: url, Detail: detail}) +} + +func (m *Message) AddBinary(mimeType string, data []byte) { + m.Parts = append(m.Parts, BinaryContent{MIMEType: mimeType, Data: data}) +} diff --git a/internal/message/message.go b/internal/message/message.go index 157c15c7c..13cf54048 100644 --- a/internal/message/message.go +++ b/internal/message/message.go @@ -2,59 +2,17 @@ package message import ( "context" - "database/sql" "encoding/json" + "fmt" "github.com/google/uuid" "github.com/kujtimiihoxha/termai/internal/db" "github.com/kujtimiihoxha/termai/internal/pubsub" ) -type MessageRole string - -const ( - Assistant MessageRole = "assistant" - User MessageRole = "user" - System MessageRole = "system" - Tool MessageRole = "tool" -) - -type ToolResult struct { - ToolCallID string - Content string - IsError bool - // TODO: support for images -} - -type ToolCall struct { - ID string - Name string - Input string - Type string -} - -type Message struct { - ID string - SessionID string - - // NEW - Role MessageRole - Content string - Thinking string - - Finished bool - - ToolResults []ToolResult - ToolCalls []ToolCall - CreatedAt int64 - UpdatedAt int64 -} - type CreateMessageParams struct { - Role MessageRole - Content string - ToolCalls []ToolCall - ToolResults []ToolResult + Role MessageRole + Parts []ContentPart } type Service interface { @@ -73,6 +31,14 @@ type service struct { ctx context.Context } +func NewService(ctx context.Context, q db.Querier) Service { + return &service{ + Broker: pubsub.NewBroker[Message](), + q: q, + ctx: ctx, + } +} + func (s *service) Delete(id string) error { message, err := s.Get(id) if err != nil { @@ -87,22 +53,21 @@ func (s *service) Delete(id string) error { } func (s *service) Create(sessionID string, params CreateMessageParams) (Message, error) { - toolCallsStr, err := json.Marshal(params.ToolCalls) - if err != nil { - return Message{}, err + if params.Role != Assistant { + params.Parts = append(params.Parts, Finish{ + Reason: "stop", + }) } - toolResultsStr, err := json.Marshal(params.ToolResults) + partsJSON, err := marshallParts(params.Parts) if err != nil { return Message{}, err } + dbMessage, err := s.q.CreateMessage(s.ctx, db.CreateMessageParams{ - ID: uuid.New().String(), - SessionID: sessionID, - Role: string(params.Role), - Finished: params.Role != Assistant, - Content: params.Content, - ToolCalls: sql.NullString{String: string(toolCallsStr), Valid: true}, - ToolResults: sql.NullString{String: string(toolResultsStr), Valid: true}, + ID: uuid.New().String(), + SessionID: sessionID, + Role: string(params.Role), + Parts: string(partsJSON), }) if err != nil { return Message{}, err @@ -132,21 +97,13 @@ func (s *service) DeleteSessionMessages(sessionID string) error { } func (s *service) Update(message Message) error { - toolCallsStr, err := json.Marshal(message.ToolCalls) - if err != nil { - return err - } - toolResultsStr, err := json.Marshal(message.ToolResults) + parts, err := marshallParts(message.Parts) if err != nil { return err } err = s.q.UpdateMessage(s.ctx, db.UpdateMessageParams{ - ID: message.ID, - Content: message.Content, - Thinking: message.Thinking, - Finished: message.Finished, - ToolCalls: sql.NullString{String: string(toolCallsStr), Valid: true}, - ToolResults: sql.NullString{String: string(toolResultsStr), Valid: true}, + ID: message.ID, + Parts: string(parts), }) if err != nil { return err @@ -179,40 +136,136 @@ func (s *service) List(sessionID string) ([]Message, error) { } func (s *service) fromDBItem(item db.Message) (Message, error) { - toolCalls := make([]ToolCall, 0) - if item.ToolCalls.Valid { - err := json.Unmarshal([]byte(item.ToolCalls.String), &toolCalls) - if err != nil { - return Message{}, err - } + parts, err := unmarshallParts([]byte(item.Parts)) + if err != nil { + return Message{}, err } + return Message{ + ID: item.ID, + SessionID: item.SessionID, + Role: MessageRole(item.Role), + Parts: parts, + CreatedAt: item.CreatedAt, + UpdatedAt: item.UpdatedAt, + }, nil +} - toolResults := make([]ToolResult, 0) - if item.ToolResults.Valid { - err := json.Unmarshal([]byte(item.ToolResults.String), &toolResults) - if err != nil { - return Message{}, err +type partType string + +const ( + reasoningType partType = "reasoning" + textType partType = "text" + imageURLType partType = "image_url" + binaryType partType = "binary" + toolCallType partType = "tool_call" + toolResultType partType = "tool_result" + finishType partType = "finish" +) + +type partWrapper struct { + Type partType `json:"type"` + Data ContentPart `json:"data"` +} + +func marshallParts(parts []ContentPart) ([]byte, error) { + wrappedParts := make([]partWrapper, len(parts)) + + for i, part := range parts { + var typ partType + + switch part.(type) { + case ReasoningContent: + typ = reasoningType + case TextContent: + typ = textType + case ImageURLContent: + typ = imageURLType + case BinaryContent: + typ = binaryType + case ToolCall: + typ = toolCallType + case ToolResult: + typ = toolResultType + case Finish: + typ = finishType + default: + return nil, fmt.Errorf("unknown part type: %T", part) } - } - return Message{ - ID: item.ID, - SessionID: item.SessionID, - Role: MessageRole(item.Role), - Content: item.Content, - Thinking: item.Thinking, - Finished: item.Finished, - ToolCalls: toolCalls, - ToolResults: toolResults, - CreatedAt: item.CreatedAt, - UpdatedAt: item.UpdatedAt, - }, nil + wrappedParts[i] = partWrapper{ + Type: typ, + Data: part, + } + } + return json.Marshal(wrappedParts) } -func NewService(ctx context.Context, q db.Querier) Service { - return &service{ - Broker: pubsub.NewBroker[Message](), - q: q, - ctx: ctx, +func unmarshallParts(data []byte) ([]ContentPart, error) { + temp := []json.RawMessage{} + + if err := json.Unmarshal(data, &temp); err != nil { + return nil, err } + + parts := make([]ContentPart, 0) + + for _, rawPart := range temp { + var wrapper struct { + Type partType `json:"type"` + Data json.RawMessage `json:"data"` + } + + if err := json.Unmarshal(rawPart, &wrapper); err != nil { + return nil, err + } + + switch wrapper.Type { + case reasoningType: + part := ReasoningContent{} + if err := json.Unmarshal(wrapper.Data, &part); err != nil { + return nil, err + } + parts = append(parts, part) + case textType: + part := TextContent{} + if err := json.Unmarshal(wrapper.Data, &part); err != nil { + return nil, err + } + parts = append(parts, part) + case imageURLType: + part := ImageURLContent{} + if err := json.Unmarshal(wrapper.Data, &part); err != nil { + return nil, err + } + case binaryType: + part := BinaryContent{} + if err := json.Unmarshal(wrapper.Data, &part); err != nil { + return nil, err + } + parts = append(parts, part) + case toolCallType: + part := ToolCall{} + if err := json.Unmarshal(wrapper.Data, &part); err != nil { + return nil, err + } + parts = append(parts, part) + case toolResultType: + part := ToolResult{} + if err := json.Unmarshal(wrapper.Data, &part); err != nil { + return nil, err + } + parts = append(parts, part) + case finishType: + part := Finish{} + if err := json.Unmarshal(wrapper.Data, &part); err != nil { + return nil, err + } + parts = append(parts, part) + default: + return nil, fmt.Errorf("unknown part type: %s", wrapper.Type) + } + + } + + return parts, nil } |
