package handler import ( "encoding/json" "fmt" "net/http" "strings" "time" "healthapi/internal/logic" "healthapi/internal/model" "healthapi/internal/svc" "healthapi/internal/types" "healthapi/pkg/ai" "healthapi/pkg/errorx" "github.com/zeromicro/go-zero/rest/httpx" ) func SendMessageStreamHandler(svcCtx *svc.ServiceContext) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var req types.SendMessageReq if err := httpx.Parse(r, &req); err != nil { httpx.ErrorCtx(r.Context(), w, err) return } // 获取用户 ID userID, err := logic.GetUserIDFromCtx(r.Context()) if err != nil { httpx.ErrorCtx(r.Context(), w, errorx.ErrUnauthorized) return } // 验证对话属于该用户 var conversation model.Conversation if err := svcCtx.DB.Where("id = ? AND user_id = ?", req.Id, userID).First(&conversation).Error; err != nil { httpx.ErrorCtx(r.Context(), w, errorx.ErrNotFound) return } // 保存用户消息 userMessage := model.Message{ ConversationID: conversation.ID, Role: model.RoleUser, Content: req.Content, } if err := svcCtx.DB.Create(&userMessage).Error; err != nil { httpx.ErrorCtx(r.Context(), w, errorx.ErrServerError) return } // 设置 SSE 响应头(与原 server 保持一致) w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") w.Header().Set("X-Accel-Buffering", "no") // 禁用 nginx 缓冲 w.Header().Set("Access-Control-Allow-Origin", "*") // 发送用户消息 ID msgData, _ := json.Marshal(map[string]interface{}{ "type": "user_message", "message_id": userMessage.ID, }) fmt.Fprintf(w, "data: %s\n\n", msgData) if flusher, ok := w.(http.Flusher); ok { flusher.Flush() } // 获取历史消息 var historyMessages []model.Message svcCtx.DB.Where("conversation_id = ?", conversation.ID). Order("created_at DESC"). Limit(svcCtx.Config.AI.MaxHistoryMessages). Find(&historyMessages) // 构建系统提示 systemPrompt := buildSystemPromptForStream(svcCtx, userID) // 构建 AI 消息 aiMessages := []ai.Message{{Role: "system", Content: systemPrompt}} for i := len(historyMessages) - 1; i >= 0; i-- { aiMessages = append(aiMessages, ai.Message{ Role: historyMessages[i].Role, Content: historyMessages[i].Content, }) } // 创建收集器 collector := &responseCollector{writer: w} // 调用 AI 流式服务 err = svcCtx.AIClient.ChatStream(r.Context(), aiMessages, collector) if err != nil { errData, _ := json.Marshal(map[string]interface{}{ "type": "error", "error": err.Error(), }) fmt.Fprintf(w, "data: %s\n\n", errData) if flusher, ok := w.(http.Flusher); ok { flusher.Flush() } return } // 保存 AI 回复 assistantMessage := model.Message{ ConversationID: conversation.ID, Role: model.RoleAssistant, Content: collector.content, } svcCtx.DB.Create(&assistantMessage) // 更新对话标题 if conversation.Title == "新对话" { title := req.Content if len(title) > 50 { title = title[:50] + "..." } svcCtx.DB.Model(&conversation).Update("title", title) } // 发送完成消息(使用 "end" 类型,与原 server 和前端保持一致) endData, _ := json.Marshal(map[string]interface{}{ "type": "end", "message_id": assistantMessage.ID, }) fmt.Fprintf(w, "data: %s\n\n", endData) if flusher, ok := w.(http.Flusher); ok { flusher.Flush() } } } // responseCollector 收集响应内容(使用缓冲区按行解析,与原 server 一致) type responseCollector struct { writer http.ResponseWriter content string buffer string } func (c *responseCollector) Write(p []byte) (n int, err error) { // 累积数据到 buffer c.buffer += string(p) // 按行处理 for { idx := strings.Index(c.buffer, "\n") if idx == -1 { break } line := c.buffer[:idx] c.buffer = c.buffer[idx+1:] // 解析 SSE 数据提取内容 line = strings.TrimSpace(line) if strings.HasPrefix(line, "data: ") { jsonStr := strings.TrimPrefix(line, "data: ") jsonStr = strings.TrimSpace(jsonStr) if jsonStr != "" && jsonStr != "[DONE]" { var data struct { Type string `json:"type"` Content string `json:"content"` } if err := json.Unmarshal([]byte(jsonStr), &data); err == nil { if data.Type == "content" { c.content += data.Content } } } } } // 同时写入原始 writer return c.writer.Write(p) } // 系统提示词模板(与原 server 保持一致) const systemPromptTemplate = `# 用户相关信息 ## 用户信息 %s ## 用户体质 %s ## 用户病史 %s ## 用户家族病史 %s ## 用户过敏记录 %s ` func buildSystemPromptForStream(svcCtx *svc.ServiceContext, userID uint) string { var userProfile, constitutionInfo, medicalInfo, familyInfo, allergyInfo string // 获取用户健康档案 var profile model.HealthProfile if err := svcCtx.DB.Where("user_id = ?", userID).First(&profile).Error; err == nil && profile.ID > 0 { // 基本信息 age := calculateAge(profile.BirthDate) gender := "未知" if profile.Gender == "male" { gender = "男" } else if profile.Gender == "female" { gender = "女" } bmi := float64(0) if profile.Height > 0 && profile.Weight > 0 { heightM := float64(profile.Height) / 100 bmi = float64(profile.Weight) / (heightM * heightM) } userProfile = fmt.Sprintf("性别:%s,年龄:%d岁,BMI:%.1f", gender, age, bmi) // 获取病史记录 var medicalHistories []model.MedicalHistory svcCtx.DB.Where("health_profile_id = ?", profile.ID).Find(&medicalHistories) if len(medicalHistories) > 0 { var items []string for _, h := range medicalHistories { status := "治疗中" if h.Status == "cured" { status = "已治愈" } else if h.Status == "controlled" { status = "已控制" } items = append(items, fmt.Sprintf("- %s(%s,%s)", h.DiseaseName, h.DiagnosedDate, status)) } medicalInfo = fmt.Sprintf("共%d条记录:\n%s", len(medicalHistories), strings.Join(items, "\n")) } else { medicalInfo = "暂无病史记录" } // 获取家族病史 var familyHistories []model.FamilyHistory svcCtx.DB.Where("health_profile_id = ?", profile.ID).Find(&familyHistories) if len(familyHistories) > 0 { var items []string for _, h := range familyHistories { relation := h.Relation switch relation { case "father": relation = "父亲" case "mother": relation = "母亲" case "grandparent": relation = "祖父母" case "sibling": relation = "兄弟姐妹" } items = append(items, fmt.Sprintf("- %s:%s", relation, h.DiseaseName)) } familyInfo = fmt.Sprintf("共%d条记录:\n%s", len(familyHistories), strings.Join(items, "\n")) } else { familyInfo = "暂无家族病史" } // 获取过敏记录 var allergyRecords []model.AllergyRecord svcCtx.DB.Where("health_profile_id = ?", profile.ID).Find(&allergyRecords) if len(allergyRecords) > 0 { var items []string for _, r := range allergyRecords { severity := "轻度" if r.Severity == "moderate" { severity = "中度" } else if r.Severity == "severe" { severity = "重度" } items = append(items, fmt.Sprintf("- %s(%s,%s)", r.Allergen, r.AllergyType, severity)) } allergyInfo = fmt.Sprintf("共%d条记录:\n%s", len(allergyRecords), strings.Join(items, "\n")) } else { allergyInfo = "暂无过敏记录" } } else { userProfile = "暂无基本信息" medicalInfo = "暂无病史记录" familyInfo = "暂无家族病史" allergyInfo = "暂无过敏记录" } // 获取用户体质信息 var assessment model.ConstitutionAssessment if err := svcCtx.DB.Where("user_id = ?", userID).Order("assessed_at DESC").First(&assessment).Error; err == nil && assessment.ID > 0 { constitutionName := model.ConstitutionNames[assessment.PrimaryConstitution] description := model.ConstitutionDescriptions[assessment.PrimaryConstitution] constitutionInfo = fmt.Sprintf("主体质:%s\n特征:%s", constitutionName, description) } else { constitutionInfo = "暂未进行体质测评" } return fmt.Sprintf(systemPromptTemplate, userProfile, constitutionInfo, medicalInfo, familyInfo, allergyInfo) } // calculateAge 计算年龄 func calculateAge(birthDate *time.Time) int { if birthDate == nil { return 0 } now := time.Now() age := now.Year() - birthDate.Year() if now.YearDay() < birthDate.YearDay() { age-- } return age }