healthapp
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

311 lines
8.5 KiB

package handler
import (
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
"healthapi/internal/logic"
"healthapi/internal/model"
"healthapi/internal/svc"
"healthapi/internal/types"
"healthapi/pkg/ai"
"healthapi/pkg/errorx"
"github.com/zeromicro/go-zero/rest/httpx"
)
// SendMessageStreamHandler handles a new user message in a conversation owned
// by the authenticated user and streams the assistant's reply back as
// Server-Sent Events.
//
// Flow:
//  1. Parse the request.
//  2. Check that the conversation belongs to the caller.
//  3. Persist the user message and emit a "user_message" event with its ID.
//  4. Relay the AI stream to the client while collecting the reply text.
//  5. Persist the reply.
//  6. Name the conversation if it still has the default title "新对话".
//  7. Finish with an "end" event carrying the assistant message ID.
//
// Failures after the SSE headers are sent are reported as "error" events,
// because the HTTP status can no longer change at that point.
func SendMessageStreamHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		var req types.SendMessageReq
		if err := httpx.Parse(r, &req); err != nil {
			httpx.ErrorCtx(r.Context(), w, err)
			return
		}

		// Resolve the authenticated user.
		userID, err := logic.GetUserIDFromCtx(r.Context())
		if err != nil {
			httpx.ErrorCtx(r.Context(), w, errorx.ErrUnauthorized)
			return
		}

		// Make sure the conversation exists and belongs to this user.
		var conversation model.Conversation
		if err := svcCtx.DB.Where("id = ? AND user_id = ?", req.Id, userID).First(&conversation).Error; err != nil {
			httpx.ErrorCtx(r.Context(), w, errorx.ErrNotFound)
			return
		}

		// Persist the user's message before contacting the AI.
		userMessage := model.Message{
			ConversationID: conversation.ID,
			Role:           model.RoleUser,
			Content:        req.Content,
		}
		if err := svcCtx.DB.Create(&userMessage).Error; err != nil {
			httpx.ErrorCtx(r.Context(), w, errorx.ErrServerError)
			return
		}

		// SSE response headers (kept identical to the original server).
		w.Header().Set("Content-Type", "text/event-stream")
		w.Header().Set("Cache-Control", "no-cache")
		w.Header().Set("Connection", "keep-alive")
		w.Header().Set("X-Accel-Buffering", "no") // disable nginx response buffering
		w.Header().Set("Access-Control-Allow-Origin", "*")

		// sendEvent writes one SSE "data:" frame and flushes it so the client
		// sees it immediately.
		sendEvent := func(payload map[string]interface{}) {
			// Payloads hold only strings and message IDs, so Marshal is not
			// expected to fail here.
			data, _ := json.Marshal(payload)
			fmt.Fprintf(w, "data: %s\n\n", data)
			if flusher, ok := w.(http.Flusher); ok {
				flusher.Flush()
			}
		}

		// Tell the client the ID of the stored user message.
		sendEvent(map[string]interface{}{
			"type":       "user_message",
			"message_id": userMessage.ID,
		})

		// Load the most recent history, newest first; it includes the message
		// saved above. If the query fails or yields nothing, fall back to the
		// current message alone so the AI never gets an empty conversation.
		var historyMessages []model.Message
		if err := svcCtx.DB.Where("conversation_id = ?", conversation.ID).
			Order("created_at DESC").
			Limit(svcCtx.Config.AI.MaxHistoryMessages).
			Find(&historyMessages).Error; err != nil || len(historyMessages) == 0 {
			historyMessages = []model.Message{userMessage}
		}

		// System prompt built from the user's stored health data.
		systemPrompt := buildSystemPromptForStream(svcCtx, userID)

		// AI input: the system prompt, then the history in chronological order
		// (reverse of the DESC query order).
		aiMessages := make([]ai.Message, 0, len(historyMessages)+1)
		aiMessages = append(aiMessages, ai.Message{Role: "system", Content: systemPrompt})
		for i := len(historyMessages) - 1; i >= 0; i-- {
			aiMessages = append(aiMessages, ai.Message{
				Role:    historyMessages[i].Role,
				Content: historyMessages[i].Content,
			})
		}

		// Relay the AI stream to the client while collecting the reply text.
		collector := &responseCollector{writer: w}
		if err := svcCtx.AIClient.ChatStream(r.Context(), aiMessages, collector); err != nil {
			sendEvent(map[string]interface{}{
				"type":  "error",
				"error": err.Error(),
			})
			return
		}

		// Persist the assistant's reply.
		assistantMessage := model.Message{
			ConversationID: conversation.ID,
			Role:           model.RoleAssistant,
			Content:        collector.content,
		}
		if err := svcCtx.DB.Create(&assistantMessage).Error; err != nil {
			// The reply was streamed but not stored. Report that, instead of
			// sending an "end" event whose message ID refers to no stored row.
			sendEvent(map[string]interface{}{
				"type":  "error",
				"error": errorx.ErrServerError.Error(),
			})
			return
		}

		// Name a conversation that still has the default title after the user's
		// message. Truncate by rune, not byte, so multi-byte text such as
		// Chinese is never cut mid-character into invalid UTF-8.
		if conversation.Title == "新对话" {
			title := req.Content
			if runes := []rune(title); len(runes) > 50 {
				title = string(runes[:50]) + "..."
			}
			// Best effort: a failed title update must not fail the reply.
			svcCtx.DB.Model(&conversation).Update("title", title)
		}

		// Completion event ("end" type, matching the original server and the
		// frontend).
		sendEvent(map[string]interface{}{
			"type":       "end",
			"message_id": assistantMessage.ID,
		})
	}
}
// responseCollector is the writer handed to the AI client's ChatStream.
//
// It relays every chunk unchanged to the client's ResponseWriter and flushes
// it, so tokens reach the browser as they arrive. It also parses the SSE
// "data:" lines it sees and accumulates the text of "content" events in
// content. A line that is not yet complete is kept in buffer until its
// newline arrives.
type responseCollector struct {
	writer  http.ResponseWriter // client connection the stream is relayed to
	content string              // concatenated text of all "content" events
	buffer  string              // trailing bytes not yet terminated by '\n'
}

// Write relays p to the client, flushes it, and extracts "content" text from
// every complete line received so far. It returns the result of the
// underlying write.
func (c *responseCollector) Write(p []byte) (n int, err error) {
	// Accumulate the chunk and consume every complete line; an incomplete
	// trailing line stays in c.buffer until the rest of it arrives.
	c.buffer += string(p)
	for {
		idx := strings.Index(c.buffer, "\n")
		if idx == -1 {
			break
		}
		line := c.buffer[:idx]
		c.buffer = c.buffer[idx+1:]
		c.collectLine(line)
	}

	// Relay the raw bytes unchanged and push them out immediately. A wrapper
	// without Flush would hide the underlying writer's http.Flusher, and the
	// stream could then sit in the server's response buffer.
	n, err = c.writer.Write(p)
	if err == nil {
		c.Flush()
	}
	return n, err
}

// collectLine appends the text of a `data: {"type":"content","content":...}`
// SSE line to c.content.
//
// The space after "data:" is optional, as in the SSE format. The following
// are ignored: blank payloads, the "[DONE]" sentinel, non-JSON payloads, and
// other event types.
func (c *responseCollector) collectLine(line string) {
	line = strings.TrimSpace(line)
	if !strings.HasPrefix(line, "data:") {
		return
	}
	payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
	if payload == "" || payload == "[DONE]" {
		return
	}
	var event struct {
		Type    string `json:"type"`
		Content string `json:"content"`
	}
	if err := json.Unmarshal([]byte(payload), &event); err != nil {
		return // not a JSON event: nothing to collect
	}
	if event.Type == "content" {
		c.content += event.Content
	}
}

// Flush implements http.Flusher by flushing the underlying ResponseWriter
// when it supports flushing, so streaming code that type-asserts its writer
// for http.Flusher can push output to the client.
func (c *responseCollector) Flush() {
	if f, ok := c.writer.(http.Flusher); ok {
		f.Flush()
	}
}
// systemPromptTemplate is the system prompt template (kept identical to the
// original server). buildSystemPromptForStream fills its five %s verbs, in
// order, with:
//  1. basic user info
//  2. constitution assessment
//  3. medical history
//  4. family medical history
//  5. allergy records
const systemPromptTemplate = `# 用户相关信息
## 用户信息
%s
## 用户体质
%s
## 用户病史
%s
## 用户家族病史
%s
## 用户过敏记录
%s
`
// buildSystemPromptForStream renders systemPromptTemplate with the health
// context stored for userID. The context covers:
//   - basic profile: gender, age, BMI
//   - the most recent constitution assessment
//   - medical history
//   - family medical history
//   - allergy records
//
// Each section starts from a "no data" placeholder and is overwritten only
// when data is found. Database errors are treated like missing data, so
// building the prompt never fails.
func buildSystemPromptForStream(svcCtx *svc.ServiceContext, userID uint) string {
	userProfile := "暂无基本信息"
	constitutionInfo := "暂未进行体质测评"
	medicalInfo := "暂无病史记录"
	familyInfo := "暂无家族病史"
	allergyInfo := "暂无过敏记录"

	// Health profile and the records attached to it.
	var profile model.HealthProfile
	if err := svcCtx.DB.Where("user_id = ?", userID).First(&profile).Error; err == nil && profile.ID > 0 {
		// Basic information. Missing values are shown as "未知" rather than 0,
		// which the model could otherwise read as a real age or BMI.
		gender := "未知"
		switch profile.Gender {
		case "male":
			gender = "男"
		case "female":
			gender = "女"
		}
		age := "未知"
		if profile.BirthDate != nil {
			age = fmt.Sprintf("%d岁", calculateAge(profile.BirthDate))
		}
		bmi := "未知"
		if profile.Height > 0 && profile.Weight > 0 {
			// Height is divided by 100 to get metres (presumably stored in cm).
			heightM := float64(profile.Height) / 100
			bmi = fmt.Sprintf("%.1f", float64(profile.Weight)/(heightM*heightM))
		}
		userProfile = fmt.Sprintf("性别:%s,年龄:%s,BMI:%s", gender, age, bmi)

		// Medical history.
		var medicalHistories []model.MedicalHistory
		svcCtx.DB.Where("health_profile_id = ?", profile.ID).Find(&medicalHistories)
		if len(medicalHistories) > 0 {
			items := make([]string, 0, len(medicalHistories))
			for _, h := range medicalHistories {
				status := "治疗中"
				switch h.Status {
				case "cured":
					status = "已治愈"
				case "controlled":
					status = "已控制"
				}
				items = append(items, fmt.Sprintf("- %s(%s,%s)", h.DiseaseName, h.DiagnosedDate, status))
			}
			medicalInfo = fmt.Sprintf("共%d条记录:\n%s", len(medicalHistories), strings.Join(items, "\n"))
		}

		// Family medical history; unknown relation codes are shown as stored.
		var familyHistories []model.FamilyHistory
		svcCtx.DB.Where("health_profile_id = ?", profile.ID).Find(&familyHistories)
		if len(familyHistories) > 0 {
			items := make([]string, 0, len(familyHistories))
			for _, h := range familyHistories {
				relation := h.Relation
				switch relation {
				case "father":
					relation = "父亲"
				case "mother":
					relation = "母亲"
				case "grandparent":
					relation = "祖父母"
				case "sibling":
					relation = "兄弟姐妹"
				}
				items = append(items, fmt.Sprintf("- %s:%s", relation, h.DiseaseName))
			}
			familyInfo = fmt.Sprintf("共%d条记录:\n%s", len(familyHistories), strings.Join(items, "\n"))
		}

		// Allergy records; any severity other than moderate/severe is shown as mild.
		var allergyRecords []model.AllergyRecord
		svcCtx.DB.Where("health_profile_id = ?", profile.ID).Find(&allergyRecords)
		if len(allergyRecords) > 0 {
			items := make([]string, 0, len(allergyRecords))
			for _, rec := range allergyRecords {
				severity := "轻度"
				switch rec.Severity {
				case "moderate":
					severity = "中度"
				case "severe":
					severity = "重度"
				}
				items = append(items, fmt.Sprintf("- %s(%s,%s)", rec.Allergen, rec.AllergyType, severity))
			}
			allergyInfo = fmt.Sprintf("共%d条记录:\n%s", len(allergyRecords), strings.Join(items, "\n"))
		}
	}

	// Most recent constitution assessment (ordered by assessed_at DESC).
	var assessment model.ConstitutionAssessment
	if err := svcCtx.DB.Where("user_id = ?", userID).Order("assessed_at DESC").First(&assessment).Error; err == nil && assessment.ID > 0 {
		constitutionName := model.ConstitutionNames[assessment.PrimaryConstitution]
		description := model.ConstitutionDescriptions[assessment.PrimaryConstitution]
		constitutionInfo = fmt.Sprintf("主体质:%s\n特征:%s", constitutionName, description)
	}

	return fmt.Sprintf(systemPromptTemplate, userProfile, constitutionInfo, medicalInfo, familyInfo, allergyInfo)
}
// calculateAge returns the number of completed years between *birthDate and
// the current time.
//
// It returns 0 when birthDate is nil or lies in the future. Someone born on
// 29 February turns a year older on 1 March in non-leap years.
func calculateAge(birthDate *time.Time) int {
	if birthDate == nil {
		return 0
	}
	now := time.Now()
	age := now.Year() - birthDate.Year()
	// Compare the calendar month and day rather than YearDay. YearDay shifts
	// by one after February in leap years, which made the result wrong by a
	// year around the birthday whenever exactly one of the two years was a
	// leap year.
	if now.Month() < birthDate.Month() ||
		(now.Month() == birthDate.Month() && now.Day() < birthDate.Day()) {
		age--
	}
	if age < 0 {
		return 0
	}
	return age
}