You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
78 lines
2.7 KiB
78 lines
2.7 KiB
package impl
|
|
|
|
import (
|
|
"health-ai/internal/database"
|
|
"health-ai/internal/model"
|
|
)
|
|
|
|
// ConversationRepository groups the persistence operations for conversations
// and their messages. It carries no state of its own; the methods defined for
// it in this file all go through the package-level database.DB handle.
type ConversationRepository struct{}

// NewConversationRepository returns a ready-to-use ConversationRepository.
func NewConversationRepository() *ConversationRepository {
	repo := new(ConversationRepository)
	return repo
}
|
|
|
|
// Create 创建对话
|
|
func (r *ConversationRepository) Create(conv *model.Conversation) error {
|
|
return database.DB.Create(conv).Error
|
|
}
|
|
|
|
// GetByID 根据ID获取对话(含消息)
|
|
func (r *ConversationRepository) GetByID(id uint) (*model.Conversation, error) {
|
|
var conv model.Conversation
|
|
err := database.DB.Preload("Messages").First(&conv, id).Error
|
|
return &conv, err
|
|
}
|
|
|
|
// GetByUserID 获取用户的所有对话
|
|
func (r *ConversationRepository) GetByUserID(userID uint) ([]model.Conversation, error) {
|
|
var convs []model.Conversation
|
|
err := database.DB.Where("user_id = ?", userID).Order("updated_at DESC").Find(&convs).Error
|
|
return convs, err
|
|
}
|
|
|
|
// Delete 删除对话(同时删除消息)
|
|
func (r *ConversationRepository) Delete(id uint) error {
|
|
// 先删除消息
|
|
database.DB.Where("conversation_id = ?", id).Delete(&model.Message{})
|
|
return database.DB.Delete(&model.Conversation{}, id).Error
|
|
}
|
|
|
|
// AddMessage 添加消息
|
|
func (r *ConversationRepository) AddMessage(msg *model.Message) error {
|
|
// 同时更新对话的更新时间
|
|
database.DB.Model(&model.Conversation{}).Where("id = ?", msg.ConversationID).Update("updated_at", msg.CreatedAt)
|
|
return database.DB.Create(msg).Error
|
|
}
|
|
|
|
// GetMessages 获取对话的消息
|
|
func (r *ConversationRepository) GetMessages(convID uint) ([]model.Message, error) {
|
|
var messages []model.Message
|
|
err := database.DB.Where("conversation_id = ?", convID).Order("created_at ASC").Find(&messages).Error
|
|
return messages, err
|
|
}
|
|
|
|
// GetRecentMessages 获取对话最近的N条消息
|
|
func (r *ConversationRepository) GetRecentMessages(convID uint, limit int) ([]model.Message, error) {
|
|
var messages []model.Message
|
|
err := database.DB.Where("conversation_id = ?", convID).Order("created_at DESC").Limit(limit).Find(&messages).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// 反转顺序,使消息按时间正序排列
|
|
for i, j := 0, len(messages)-1; i < j; i, j = i+1, j-1 {
|
|
messages[i], messages[j] = messages[j], messages[i]
|
|
}
|
|
return messages, nil
|
|
}
|
|
|
|
// UpdateTitle 更新对话标题
|
|
func (r *ConversationRepository) UpdateTitle(id uint, title string) error {
|
|
return database.DB.Model(&model.Conversation{}).Where("id = ?", id).Update("title", title).Error
|
|
}
|
|
|
|
// CheckOwnership 检查对话是否属于用户
|
|
func (r *ConversationRepository) CheckOwnership(convID, userID uint) bool {
|
|
var count int64
|
|
database.DB.Model(&model.Conversation{}).Where("id = ? AND user_id = ?", convID, userID).Count(&count)
|
|
return count > 0
|
|
}
|
|
|