summaryrefslogtreecommitdiffhomepage
path: root/internal/message
diff options
context:
space:
mode:
authorKujtim Hoxha <[email protected]>2025-03-27 22:35:48 +0100
committerKujtim Hoxha <[email protected]>2025-04-01 13:38:54 +0200
commitafd9ad0560d76c2a6d161dad52553b10ff428905 (patch)
tree69f78b05ff0d7952cd3e3c9332f001e66abb2faf /internal/message
parent904061c243f70696bfe781e97bf4e392e6954d07 (diff)
downloadopencode-afd9ad0560d76c2a6d161dad52553b10ff428905.tar.gz
opencode-afd9ad0560d76c2a6d161dad52553b10ff428905.zip
rework llm
Diffstat (limited to 'internal/message')
-rw-r--r--internal/message/message.go158
1 files changed, 127 insertions, 31 deletions
diff --git a/internal/message/message.go b/internal/message/message.go
index e61fcef6d..157c15c7c 100644
--- a/internal/message/message.go
+++ b/internal/message/message.go
@@ -2,26 +2,65 @@ package message
import (
"context"
+ "database/sql"
"encoding/json"
- "github.com/cloudwego/eino/schema"
"github.com/google/uuid"
"github.com/kujtimiihoxha/termai/internal/db"
"github.com/kujtimiihoxha/termai/internal/pubsub"
)
+type MessageRole string
+
+const (
+ Assistant MessageRole = "assistant"
+ User MessageRole = "user"
+ System MessageRole = "system"
+ Tool MessageRole = "tool"
+)
+
+type ToolResult struct {
+ ToolCallID string
+ Content string
+ IsError bool
+ // TODO: support for images
+}
+
+type ToolCall struct {
+ ID string
+ Name string
+ Input string
+ Type string
+}
+
type Message struct {
- ID string
- SessionID string
- MessageData schema.Message
+ ID string
+ SessionID string
+
+ // NEW
+ Role MessageRole
+ Content string
+ Thinking string
- CreatedAt int64
- UpdatedAt int64
+ Finished bool
+
+ ToolResults []ToolResult
+ ToolCalls []ToolCall
+ CreatedAt int64
+ UpdatedAt int64
+}
+
+type CreateMessageParams struct {
+ Role MessageRole
+ Content string
+ ToolCalls []ToolCall
+ ToolResults []ToolResult
}
type Service interface {
pubsub.Suscriber[Message]
- Create(sessionID string, messageData schema.Message) (Message, error)
+ Create(sessionID string, params CreateMessageParams) (Message, error)
+ Update(message Message) error
Get(id string) (Message, error)
List(sessionID string) ([]Message, error)
Delete(id string) error
@@ -34,35 +73,46 @@ type service struct {
ctx context.Context
}
-func (s *service) Create(sessionID string, messageData schema.Message) (Message, error) {
- messageDataJSON, err := json.Marshal(messageData)
+func (s *service) Delete(id string) error {
+ message, err := s.Get(id)
+ if err != nil {
+ return err
+ }
+ err = s.q.DeleteMessage(s.ctx, message.ID)
+ if err != nil {
+ return err
+ }
+ s.Publish(pubsub.DeletedEvent, message)
+ return nil
+}
+
+func (s *service) Create(sessionID string, params CreateMessageParams) (Message, error) {
+ toolCallsStr, err := json.Marshal(params.ToolCalls)
+ if err != nil {
+ return Message{}, err
+ }
+ toolResultsStr, err := json.Marshal(params.ToolResults)
if err != nil {
return Message{}, err
}
dbMessage, err := s.q.CreateMessage(s.ctx, db.CreateMessageParams{
ID: uuid.New().String(),
SessionID: sessionID,
- MessageData: string(messageDataJSON),
+ Role: string(params.Role),
+ Finished: params.Role != Assistant,
+ Content: params.Content,
+ ToolCalls: sql.NullString{String: string(toolCallsStr), Valid: true},
+ ToolResults: sql.NullString{String: string(toolResultsStr), Valid: true},
})
if err != nil {
return Message{}, err
}
- message := s.fromDBItem(dbMessage)
- s.Publish(pubsub.CreatedEvent, message)
- return message, nil
-}
-
-func (s *service) Delete(id string) error {
- message, err := s.Get(id)
- if err != nil {
- return err
- }
- err = s.q.DeleteMessage(s.ctx, message.ID)
+ message, err := s.fromDBItem(dbMessage)
if err != nil {
- return err
+ return Message{}, err
}
- s.Publish(pubsub.DeletedEvent, message)
- return nil
+ s.Publish(pubsub.CreatedEvent, message)
+ return message, nil
}
func (s *service) DeleteSessionMessages(sessionID string) error {
@@ -81,12 +131,36 @@ func (s *service) DeleteSessionMessages(sessionID string) error {
return nil
}
+func (s *service) Update(message Message) error {
+ toolCallsStr, err := json.Marshal(message.ToolCalls)
+ if err != nil {
+ return err
+ }
+ toolResultsStr, err := json.Marshal(message.ToolResults)
+ if err != nil {
+ return err
+ }
+ err = s.q.UpdateMessage(s.ctx, db.UpdateMessageParams{
+ ID: message.ID,
+ Content: message.Content,
+ Thinking: message.Thinking,
+ Finished: message.Finished,
+ ToolCalls: sql.NullString{String: string(toolCallsStr), Valid: true},
+ ToolResults: sql.NullString{String: string(toolResultsStr), Valid: true},
+ })
+ if err != nil {
+ return err
+ }
+ s.Publish(pubsub.UpdatedEvent, message)
+ return nil
+}
+
func (s *service) Get(id string) (Message, error) {
dbMessage, err := s.q.GetMessage(s.ctx, id)
if err != nil {
return Message{}, err
}
- return s.fromDBItem(dbMessage), nil
+ return s.fromDBItem(dbMessage)
}
func (s *service) List(sessionID string) ([]Message, error) {
@@ -96,21 +170,43 @@ func (s *service) List(sessionID string) ([]Message, error) {
}
messages := make([]Message, len(dbMessages))
for i, dbMessage := range dbMessages {
- messages[i] = s.fromDBItem(dbMessage)
+ messages[i], err = s.fromDBItem(dbMessage)
+ if err != nil {
+ return nil, err
+ }
}
return messages, nil
}
-func (s *service) fromDBItem(item db.Message) Message {
- var messageData schema.Message
- json.Unmarshal([]byte(item.MessageData), &messageData)
+func (s *service) fromDBItem(item db.Message) (Message, error) {
+ toolCalls := make([]ToolCall, 0)
+ if item.ToolCalls.Valid {
+ err := json.Unmarshal([]byte(item.ToolCalls.String), &toolCalls)
+ if err != nil {
+ return Message{}, err
+ }
+ }
+
+ toolResults := make([]ToolResult, 0)
+ if item.ToolResults.Valid {
+ err := json.Unmarshal([]byte(item.ToolResults.String), &toolResults)
+ if err != nil {
+ return Message{}, err
+ }
+ }
+
return Message{
ID: item.ID,
SessionID: item.SessionID,
- MessageData: messageData,
+ Role: MessageRole(item.Role),
+ Content: item.Content,
+ Thinking: item.Thinking,
+ Finished: item.Finished,
+ ToolCalls: toolCalls,
+ ToolResults: toolResults,
CreatedAt: item.CreatedAt,
UpdatedAt: item.UpdatedAt,
- }
+ }, nil
}
func NewService(ctx context.Context, q db.Querier) Service {