healthapp
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

78 lines
2.7 KiB

package impl
import (
"health-ai/internal/database"
"health-ai/internal/model"
)
// ConversationRepository groups the persistence operations for conversations
// and their messages. The struct has no fields, so it holds no state itself.
type ConversationRepository struct{}

// NewConversationRepository returns a pointer to a new, empty
// ConversationRepository.
func NewConversationRepository() *ConversationRepository {
	return new(ConversationRepository)
}
// Create inserts conv as a new conversation record and returns whatever
// error the database layer reports.
func (r *ConversationRepository) Create(conv *model.Conversation) error {
	result := database.DB.Create(conv)
	return result.Error
}
// GetByID loads the conversation with the given primary key, preloading its
// "Messages" association.
//
// When the lookup fails it returns a nil conversation together with the
// error, so callers never receive a zero-value struct alongside a non-nil
// error. NOTE(review): with GORM, First is expected to report an error when
// no row matches — confirm against the database package.
func (r *ConversationRepository) GetByID(id uint) (*model.Conversation, error) {
	var conv model.Conversation
	if err := database.DB.Preload("Messages").First(&conv, id).Error; err != nil {
		return nil, err
	}
	return &conv, nil
}
// GetByUserID lists every conversation whose user_id equals userID, ordered
// by updated_at descending. The result slice and the query error are
// returned as the database layer produced them.
func (r *ConversationRepository) GetByUserID(userID uint) ([]model.Conversation, error) {
	var list []model.Conversation
	query := database.DB.Where("user_id = ?", userID).Order("updated_at DESC")
	result := query.Find(&list)
	return list, result.Error
}
// Delete removes the conversation with the given ID together with all of its
// messages.
//
// Both deletes run inside a single transaction: the messages are removed
// first, then the conversation. If either statement fails the transaction is
// rolled back and the error is returned, so a failure cannot leave orphaned
// messages or a conversation stripped of its messages.
//
// NOTE(review): assumes database.DB exposes GORM's Begin/Rollback/Commit —
// confirm against the database package.
func (r *ConversationRepository) Delete(id uint) error {
	tx := database.DB.Begin()
	if tx.Error != nil {
		return tx.Error
	}

	// Child rows go first so the parent is never deleted while its messages
	// remain.
	if err := tx.Where("conversation_id = ?", id).Delete(&model.Message{}).Error; err != nil {
		tx.Rollback()
		return err
	}
	if err := tx.Delete(&model.Conversation{}, id).Error; err != nil {
		tx.Rollback()
		return err
	}
	return tx.Commit().Error
}
// AddMessage stores msg and then sets the owning conversation's updated_at
// to msg.CreatedAt.
//
// The insert runs before the timestamp update for two reasons: the
// conversation must not be bumped when the message could not be saved, and
// msg.CreatedAt may only be filled in while the record is being created
// (presumably by GORM's automatic timestamping — confirm against the model).
// Both statements share one transaction; any failure rolls it back and is
// returned to the caller.
//
// NOTE(review): assumes database.DB exposes GORM's Begin/Rollback/Commit —
// confirm against the database package.
func (r *ConversationRepository) AddMessage(msg *model.Message) error {
	tx := database.DB.Begin()
	if tx.Error != nil {
		return tx.Error
	}

	if err := tx.Create(msg).Error; err != nil {
		tx.Rollback()
		return err
	}

	// Keep the conversation's ordering key in step with its newest message.
	err := tx.Model(&model.Conversation{}).
		Where("id = ?", msg.ConversationID).
		Update("updated_at", msg.CreatedAt).Error
	if err != nil {
		tx.Rollback()
		return err
	}
	return tx.Commit().Error
}
// GetMessages returns the messages whose conversation_id equals convID,
// ordered by created_at ascending, along with any query error.
func (r *ConversationRepository) GetMessages(convID uint) ([]model.Message, error) {
	var history []model.Message
	result := database.DB.
		Where("conversation_id = ?", convID).
		Order("created_at ASC").
		Find(&history)
	return history, result.Error
}
// GetRecentMessages returns up to limit of the newest messages of the
// conversation identified by convID, arranged oldest-first.
//
// On a query error it returns a nil slice and that error.
func (r *ConversationRepository) GetRecentMessages(convID uint, limit int) ([]model.Message, error) {
	var recent []model.Message
	result := database.DB.
		Where("conversation_id = ?", convID).
		Order("created_at DESC").
		Limit(limit).
		Find(&recent)
	if result.Error != nil {
		return nil, result.Error
	}

	// The query yields newest-first; mirror the slice in place so callers
	// receive the messages in chronological order.
	last := len(recent) - 1
	for i := 0; i < len(recent)/2; i++ {
		recent[i], recent[last-i] = recent[last-i], recent[i]
	}
	return recent, nil
}
// UpdateTitle sets the title column of the conversation with the given ID
// and returns any error reported by the database layer.
func (r *ConversationRepository) UpdateTitle(id uint, title string) error {
	target := database.DB.Model(&model.Conversation{}).Where("id = ?", id)
	result := target.Update("title", title)
	return result.Error
}
// CheckOwnership reports whether a conversation with ID convID exists whose
// user_id equals userID.
//
// The signature has no error return, so a failed count query is treated as
// "not owned": the check fails closed rather than relying on the counter
// happening to stay at zero.
func (r *ConversationRepository) CheckOwnership(convID, userID uint) bool {
	var count int64
	err := database.DB.Model(&model.Conversation{}).
		Where("id = ? AND user_id = ?", convID, userID).
		Count(&count).Error
	if err != nil {
		// Ownership cannot be confirmed; deny.
		return false
	}
	return count > 0
}