diff options
| author | Kujtim Hoxha <[email protected]> | 2025-03-27 22:35:48 +0100 |
|---|---|---|
| committer | Kujtim Hoxha <[email protected]> | 2025-04-01 13:38:54 +0200 |
| commit | afd9ad0560d76c2a6d161dad52553b10ff428905 (patch) | |
| tree | 69f78b05ff0d7952cd3e3c9332f001e66abb2faf /internal/message | |
| parent | 904061c243f70696bfe781e97bf4e392e6954d07 (diff) | |
| download | opencode-afd9ad0560d76c2a6d161dad52553b10ff428905.tar.gz opencode-afd9ad0560d76c2a6d161dad52553b10ff428905.zip | |
rework llm
Diffstat (limited to 'internal/message')
| -rw-r--r-- | internal/message/message.go | 158 |
1 files changed, 127 insertions, 31 deletions
diff --git a/internal/message/message.go b/internal/message/message.go index e61fcef6d..157c15c7c 100644 --- a/internal/message/message.go +++ b/internal/message/message.go @@ -2,26 +2,65 @@ package message import ( "context" + "database/sql" "encoding/json" - "github.com/cloudwego/eino/schema" "github.com/google/uuid" "github.com/kujtimiihoxha/termai/internal/db" "github.com/kujtimiihoxha/termai/internal/pubsub" ) +type MessageRole string + +const ( + Assistant MessageRole = "assistant" + User MessageRole = "user" + System MessageRole = "system" + Tool MessageRole = "tool" +) + +type ToolResult struct { + ToolCallID string + Content string + IsError bool + // TODO: support for images +} + +type ToolCall struct { + ID string + Name string + Input string + Type string +} + type Message struct { - ID string - SessionID string - MessageData schema.Message + ID string + SessionID string + + // NEW + Role MessageRole + Content string + Thinking string - CreatedAt int64 - UpdatedAt int64 + Finished bool + + ToolResults []ToolResult + ToolCalls []ToolCall + CreatedAt int64 + UpdatedAt int64 +} + +type CreateMessageParams struct { + Role MessageRole + Content string + ToolCalls []ToolCall + ToolResults []ToolResult } type Service interface { pubsub.Suscriber[Message] - Create(sessionID string, messageData schema.Message) (Message, error) + Create(sessionID string, params CreateMessageParams) (Message, error) + Update(message Message) error Get(id string) (Message, error) List(sessionID string) ([]Message, error) Delete(id string) error @@ -34,35 +73,46 @@ type service struct { ctx context.Context } -func (s *service) Create(sessionID string, messageData schema.Message) (Message, error) { - messageDataJSON, err := json.Marshal(messageData) +func (s *service) Delete(id string) error { + message, err := s.Get(id) + if err != nil { + return err + } + err = s.q.DeleteMessage(s.ctx, message.ID) + if err != nil { + return err + } + s.Publish(pubsub.DeletedEvent, message) + return nil +} + +func (s *service) Create(sessionID string, params CreateMessageParams) (Message, error) { + toolCallsStr, err := json.Marshal(params.ToolCalls) + if err != nil { + return Message{}, err + } + toolResultsStr, err := json.Marshal(params.ToolResults) if err != nil { return Message{}, err } dbMessage, err := s.q.CreateMessage(s.ctx, db.CreateMessageParams{ ID: uuid.New().String(), SessionID: sessionID, - MessageData: string(messageDataJSON), + Role: string(params.Role), + Finished: params.Role != Assistant, + Content: params.Content, + ToolCalls: sql.NullString{String: string(toolCallsStr), Valid: true}, + ToolResults: sql.NullString{String: string(toolResultsStr), Valid: true}, }) if err != nil { return Message{}, err } - message := s.fromDBItem(dbMessage) - s.Publish(pubsub.CreatedEvent, message) - return message, nil -} - -func (s *service) Delete(id string) error { - message, err := s.Get(id) - if err != nil { - return err - } - err = s.q.DeleteMessage(s.ctx, message.ID) + message, err := s.fromDBItem(dbMessage) if err != nil { - return err + return Message{}, err } - s.Publish(pubsub.DeletedEvent, message) - return nil + s.Publish(pubsub.CreatedEvent, message) + return message, nil } func (s *service) DeleteSessionMessages(sessionID string) error { @@ -81,12 +131,36 @@ func (s *service) DeleteSessionMessages(sessionID string) error { return nil } +func (s *service) Update(message Message) error { + toolCallsStr, err := json.Marshal(message.ToolCalls) + if err != nil { + return err + } + toolResultsStr, err := json.Marshal(message.ToolResults) + if err != nil { + return err + } + err = s.q.UpdateMessage(s.ctx, db.UpdateMessageParams{ + ID: message.ID, + Content: message.Content, + Thinking: message.Thinking, + Finished: message.Finished, + ToolCalls: sql.NullString{String: string(toolCallsStr), Valid: true}, + ToolResults: sql.NullString{String: string(toolResultsStr), Valid: true}, + }) + if err != nil { + return err + } + s.Publish(pubsub.UpdatedEvent, message) + return nil +} + func (s *service) Get(id string) (Message, error) { dbMessage, err := s.q.GetMessage(s.ctx, id) if err != nil { return Message{}, err } - return s.fromDBItem(dbMessage), nil + return s.fromDBItem(dbMessage) } func (s *service) List(sessionID string) ([]Message, error) { @@ -96,21 +170,43 @@ func (s *service) List(sessionID string) ([]Message, error) { } messages := make([]Message, len(dbMessages)) for i, dbMessage := range dbMessages { - messages[i] = s.fromDBItem(dbMessage) + messages[i], err = s.fromDBItem(dbMessage) + if err != nil { + return nil, err + } } return messages, nil } -func (s *service) fromDBItem(item db.Message) Message { - var messageData schema.Message - json.Unmarshal([]byte(item.MessageData), &messageData) +func (s *service) fromDBItem(item db.Message) (Message, error) { + toolCalls := make([]ToolCall, 0) + if item.ToolCalls.Valid { + err := json.Unmarshal([]byte(item.ToolCalls.String), &toolCalls) + if err != nil { + return Message{}, err + } + } + + toolResults := make([]ToolResult, 0) + if item.ToolResults.Valid { + err := json.Unmarshal([]byte(item.ToolResults.String), &toolResults) + if err != nil { + return Message{}, err + } + } + return Message{ ID: item.ID, SessionID: item.SessionID, - MessageData: messageData, + Role: MessageRole(item.Role), + Content: item.Content, + Thinking: item.Thinking, + Finished: item.Finished, + ToolCalls: toolCalls, + ToolResults: toolResults, CreatedAt: item.CreatedAt, UpdatedAt: item.UpdatedAt, - } + }, nil } func NewService(ctx context.Context, q db.Querier) Service { |
