package impl import ( "health-ai/internal/database" "health-ai/internal/model" ) type ConversationRepository struct{} func NewConversationRepository() *ConversationRepository { return &ConversationRepository{} } // Create 创建对话 func (r *ConversationRepository) Create(conv *model.Conversation) error { return database.DB.Create(conv).Error } // GetByID 根据ID获取对话(含消息) func (r *ConversationRepository) GetByID(id uint) (*model.Conversation, error) { var conv model.Conversation err := database.DB.Preload("Messages").First(&conv, id).Error return &conv, err } // GetByUserID 获取用户的所有对话 func (r *ConversationRepository) GetByUserID(userID uint) ([]model.Conversation, error) { var convs []model.Conversation err := database.DB.Where("user_id = ?", userID).Order("updated_at DESC").Find(&convs).Error return convs, err } // Delete 删除对话(同时删除消息) func (r *ConversationRepository) Delete(id uint) error { // 先删除消息 database.DB.Where("conversation_id = ?", id).Delete(&model.Message{}) return database.DB.Delete(&model.Conversation{}, id).Error } // AddMessage 添加消息 func (r *ConversationRepository) AddMessage(msg *model.Message) error { // 同时更新对话的更新时间 database.DB.Model(&model.Conversation{}).Where("id = ?", msg.ConversationID).Update("updated_at", msg.CreatedAt) return database.DB.Create(msg).Error } // GetMessages 获取对话的消息 func (r *ConversationRepository) GetMessages(convID uint) ([]model.Message, error) { var messages []model.Message err := database.DB.Where("conversation_id = ?", convID).Order("created_at ASC").Find(&messages).Error return messages, err } // GetRecentMessages 获取对话最近的N条消息 func (r *ConversationRepository) GetRecentMessages(convID uint, limit int) ([]model.Message, error) { var messages []model.Message err := database.DB.Where("conversation_id = ?", convID).Order("created_at DESC").Limit(limit).Find(&messages).Error if err != nil { return nil, err } // 反转顺序,使消息按时间正序排列 for i, j := 0, len(messages)-1; i < j; i, j = i+1, j-1 { messages[i], messages[j] = messages[j], messages[i] } return messages, nil } // UpdateTitle 更新对话标题 func (r *ConversationRepository) UpdateTitle(id uint, title string) error { return database.DB.Model(&model.Conversation{}).Where("id = ?", id).Update("title", title).Error } // CheckOwnership 检查对话是否属于用户 func (r *ConversationRepository) CheckOwnership(convID, userID uint) bool { var count int64 database.DB.Model(&model.Conversation{}).Where("id = ? AND user_id = ?", convID, userID).Count(&count) return count > 0 }