You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
311 lines
8.5 KiB
311 lines
8.5 KiB
package handler
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"healthapi/internal/logic"
|
|
"healthapi/internal/model"
|
|
"healthapi/internal/svc"
|
|
"healthapi/internal/types"
|
|
"healthapi/pkg/ai"
|
|
"healthapi/pkg/errorx"
|
|
|
|
"github.com/zeromicro/go-zero/rest/httpx"
|
|
)
|
|
|
|
func SendMessageStreamHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
var req types.SendMessageReq
|
|
if err := httpx.Parse(r, &req); err != nil {
|
|
httpx.ErrorCtx(r.Context(), w, err)
|
|
return
|
|
}
|
|
|
|
// 获取用户 ID
|
|
userID, err := logic.GetUserIDFromCtx(r.Context())
|
|
if err != nil {
|
|
httpx.ErrorCtx(r.Context(), w, errorx.ErrUnauthorized)
|
|
return
|
|
}
|
|
|
|
// 验证对话属于该用户
|
|
var conversation model.Conversation
|
|
if err := svcCtx.DB.Where("id = ? AND user_id = ?", req.Id, userID).First(&conversation).Error; err != nil {
|
|
httpx.ErrorCtx(r.Context(), w, errorx.ErrNotFound)
|
|
return
|
|
}
|
|
|
|
// 保存用户消息
|
|
userMessage := model.Message{
|
|
ConversationID: conversation.ID,
|
|
Role: model.RoleUser,
|
|
Content: req.Content,
|
|
}
|
|
if err := svcCtx.DB.Create(&userMessage).Error; err != nil {
|
|
httpx.ErrorCtx(r.Context(), w, errorx.ErrServerError)
|
|
return
|
|
}
|
|
|
|
// 设置 SSE 响应头(与原 server 保持一致)
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
w.Header().Set("Cache-Control", "no-cache")
|
|
w.Header().Set("Connection", "keep-alive")
|
|
w.Header().Set("X-Accel-Buffering", "no") // 禁用 nginx 缓冲
|
|
w.Header().Set("Access-Control-Allow-Origin", "*")
|
|
|
|
// 发送用户消息 ID
|
|
msgData, _ := json.Marshal(map[string]interface{}{
|
|
"type": "user_message",
|
|
"message_id": userMessage.ID,
|
|
})
|
|
fmt.Fprintf(w, "data: %s\n\n", msgData)
|
|
if flusher, ok := w.(http.Flusher); ok {
|
|
flusher.Flush()
|
|
}
|
|
|
|
// 获取历史消息
|
|
var historyMessages []model.Message
|
|
svcCtx.DB.Where("conversation_id = ?", conversation.ID).
|
|
Order("created_at DESC").
|
|
Limit(svcCtx.Config.AI.MaxHistoryMessages).
|
|
Find(&historyMessages)
|
|
|
|
// 构建系统提示
|
|
systemPrompt := buildSystemPromptForStream(svcCtx, userID)
|
|
|
|
// 构建 AI 消息
|
|
aiMessages := []ai.Message{{Role: "system", Content: systemPrompt}}
|
|
for i := len(historyMessages) - 1; i >= 0; i-- {
|
|
aiMessages = append(aiMessages, ai.Message{
|
|
Role: historyMessages[i].Role,
|
|
Content: historyMessages[i].Content,
|
|
})
|
|
}
|
|
|
|
// 创建收集器
|
|
collector := &responseCollector{writer: w}
|
|
|
|
// 调用 AI 流式服务
|
|
err = svcCtx.AIClient.ChatStream(r.Context(), aiMessages, collector)
|
|
if err != nil {
|
|
errData, _ := json.Marshal(map[string]interface{}{
|
|
"type": "error",
|
|
"error": err.Error(),
|
|
})
|
|
fmt.Fprintf(w, "data: %s\n\n", errData)
|
|
if flusher, ok := w.(http.Flusher); ok {
|
|
flusher.Flush()
|
|
}
|
|
return
|
|
}
|
|
|
|
// 保存 AI 回复
|
|
assistantMessage := model.Message{
|
|
ConversationID: conversation.ID,
|
|
Role: model.RoleAssistant,
|
|
Content: collector.content,
|
|
}
|
|
svcCtx.DB.Create(&assistantMessage)
|
|
|
|
// 更新对话标题
|
|
if conversation.Title == "新对话" {
|
|
title := req.Content
|
|
if len(title) > 50 {
|
|
title = title[:50] + "..."
|
|
}
|
|
svcCtx.DB.Model(&conversation).Update("title", title)
|
|
}
|
|
|
|
// 发送完成消息(使用 "end" 类型,与原 server 和前端保持一致)
|
|
endData, _ := json.Marshal(map[string]interface{}{
|
|
"type": "end",
|
|
"message_id": assistantMessage.ID,
|
|
})
|
|
fmt.Fprintf(w, "data: %s\n\n", endData)
|
|
if flusher, ok := w.(http.Flusher); ok {
|
|
flusher.Flush()
|
|
}
|
|
}
|
|
}
|
|
|
|
// responseCollector is the io.Writer handed to the AI streaming client.
// Every chunk it receives is forwarded verbatim to the HTTP response. The
// assistant's text is pulled out of the SSE data lines and accumulated here.
// It parses line by line through a buffer, consistent with the original server.
type responseCollector struct {
	// content accumulates the assistant reply text extracted so far.
	content string
	// buffer holds the trailing part of the stream that has not yet been
	// terminated by a newline.
	buffer string
	// writer is the downstream HTTP response that receives every chunk.
	writer http.ResponseWriter
}
|
|
|
|
func (c *responseCollector) Write(p []byte) (n int, err error) {
|
|
// 累积数据到 buffer
|
|
c.buffer += string(p)
|
|
|
|
// 按行处理
|
|
for {
|
|
idx := strings.Index(c.buffer, "\n")
|
|
if idx == -1 {
|
|
break
|
|
}
|
|
line := c.buffer[:idx]
|
|
c.buffer = c.buffer[idx+1:]
|
|
|
|
// 解析 SSE 数据提取内容
|
|
line = strings.TrimSpace(line)
|
|
if strings.HasPrefix(line, "data: ") {
|
|
jsonStr := strings.TrimPrefix(line, "data: ")
|
|
jsonStr = strings.TrimSpace(jsonStr)
|
|
if jsonStr != "" && jsonStr != "[DONE]" {
|
|
var data struct {
|
|
Type string `json:"type"`
|
|
Content string `json:"content"`
|
|
}
|
|
if err := json.Unmarshal([]byte(jsonStr), &data); err == nil {
|
|
if data.Type == "content" {
|
|
c.content += data.Content
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// 同时写入原始 writer
|
|
return c.writer.Write(p)
|
|
}
|
|
|
|
// systemPromptTemplate is the system prompt template (kept identical to the
// original server). It is rendered with fmt.Sprintf and sent to the model as
// the "system" message.
//
// The five %s verbs are filled in this order:
//   - 用户信息 (basic user info)
//   - 用户体质 (constitution)
//   - 用户病史 (medical history)
//   - 用户家族病史 (family medical history)
//   - 用户过敏记录 (allergy records)
//
// The top heading 用户相关信息 means "user-related information". The text
// uses Markdown-style headings.
const systemPromptTemplate = `# 用户相关信息

## 用户信息
%s

## 用户体质
%s

## 用户病史
%s

## 用户家族病史
%s

## 用户过敏记录
%s

`
|
|
|
|
func buildSystemPromptForStream(svcCtx *svc.ServiceContext, userID uint) string {
|
|
var userProfile, constitutionInfo, medicalInfo, familyInfo, allergyInfo string
|
|
|
|
// 获取用户健康档案
|
|
var profile model.HealthProfile
|
|
if err := svcCtx.DB.Where("user_id = ?", userID).First(&profile).Error; err == nil && profile.ID > 0 {
|
|
// 基本信息
|
|
age := calculateAge(profile.BirthDate)
|
|
gender := "未知"
|
|
if profile.Gender == "male" {
|
|
gender = "男"
|
|
} else if profile.Gender == "female" {
|
|
gender = "女"
|
|
}
|
|
bmi := float64(0)
|
|
if profile.Height > 0 && profile.Weight > 0 {
|
|
heightM := float64(profile.Height) / 100
|
|
bmi = float64(profile.Weight) / (heightM * heightM)
|
|
}
|
|
userProfile = fmt.Sprintf("性别:%s,年龄:%d岁,BMI:%.1f", gender, age, bmi)
|
|
|
|
// 获取病史记录
|
|
var medicalHistories []model.MedicalHistory
|
|
svcCtx.DB.Where("health_profile_id = ?", profile.ID).Find(&medicalHistories)
|
|
if len(medicalHistories) > 0 {
|
|
var items []string
|
|
for _, h := range medicalHistories {
|
|
status := "治疗中"
|
|
if h.Status == "cured" {
|
|
status = "已治愈"
|
|
} else if h.Status == "controlled" {
|
|
status = "已控制"
|
|
}
|
|
items = append(items, fmt.Sprintf("- %s(%s,%s)", h.DiseaseName, h.DiagnosedDate, status))
|
|
}
|
|
medicalInfo = fmt.Sprintf("共%d条记录:\n%s", len(medicalHistories), strings.Join(items, "\n"))
|
|
} else {
|
|
medicalInfo = "暂无病史记录"
|
|
}
|
|
|
|
// 获取家族病史
|
|
var familyHistories []model.FamilyHistory
|
|
svcCtx.DB.Where("health_profile_id = ?", profile.ID).Find(&familyHistories)
|
|
if len(familyHistories) > 0 {
|
|
var items []string
|
|
for _, h := range familyHistories {
|
|
relation := h.Relation
|
|
switch relation {
|
|
case "father":
|
|
relation = "父亲"
|
|
case "mother":
|
|
relation = "母亲"
|
|
case "grandparent":
|
|
relation = "祖父母"
|
|
case "sibling":
|
|
relation = "兄弟姐妹"
|
|
}
|
|
items = append(items, fmt.Sprintf("- %s:%s", relation, h.DiseaseName))
|
|
}
|
|
familyInfo = fmt.Sprintf("共%d条记录:\n%s", len(familyHistories), strings.Join(items, "\n"))
|
|
} else {
|
|
familyInfo = "暂无家族病史"
|
|
}
|
|
|
|
// 获取过敏记录
|
|
var allergyRecords []model.AllergyRecord
|
|
svcCtx.DB.Where("health_profile_id = ?", profile.ID).Find(&allergyRecords)
|
|
if len(allergyRecords) > 0 {
|
|
var items []string
|
|
for _, r := range allergyRecords {
|
|
severity := "轻度"
|
|
if r.Severity == "moderate" {
|
|
severity = "中度"
|
|
} else if r.Severity == "severe" {
|
|
severity = "重度"
|
|
}
|
|
items = append(items, fmt.Sprintf("- %s(%s,%s)", r.Allergen, r.AllergyType, severity))
|
|
}
|
|
allergyInfo = fmt.Sprintf("共%d条记录:\n%s", len(allergyRecords), strings.Join(items, "\n"))
|
|
} else {
|
|
allergyInfo = "暂无过敏记录"
|
|
}
|
|
} else {
|
|
userProfile = "暂无基本信息"
|
|
medicalInfo = "暂无病史记录"
|
|
familyInfo = "暂无家族病史"
|
|
allergyInfo = "暂无过敏记录"
|
|
}
|
|
|
|
// 获取用户体质信息
|
|
var assessment model.ConstitutionAssessment
|
|
if err := svcCtx.DB.Where("user_id = ?", userID).Order("assessed_at DESC").First(&assessment).Error; err == nil && assessment.ID > 0 {
|
|
constitutionName := model.ConstitutionNames[assessment.PrimaryConstitution]
|
|
description := model.ConstitutionDescriptions[assessment.PrimaryConstitution]
|
|
constitutionInfo = fmt.Sprintf("主体质:%s\n特征:%s", constitutionName, description)
|
|
} else {
|
|
constitutionInfo = "暂未进行体质测评"
|
|
}
|
|
|
|
return fmt.Sprintf(systemPromptTemplate, userProfile, constitutionInfo, medicalInfo, familyInfo, allergyInfo)
|
|
}
|
|
|
|
// calculateAge returns the number of full years between birthDate and the
// current time. It returns 0 when birthDate is nil or lies in the future.
//
// The birthday check compares month and day, not YearDay(). Day-of-year
// shifts by one after February in leap years, so a YearDay comparison
// subtracts a year on the birthday itself. For example, someone born
// 2000-03-01 would be reported as 0 on 2001-03-01. A 29 February birthday
// counts as reached on 1 March in non-leap years.
//
// Date components are read in each value's own location: birthDate as
// stored, and the current time in local time.
func calculateAge(birthDate *time.Time) int {
	if birthDate == nil {
		return 0
	}

	now := time.Now()
	age := now.Year() - birthDate.Year()

	// Not yet reached this year's birthday: one year less.
	beforeBirthday := now.Month() < birthDate.Month() ||
		(now.Month() == birthDate.Month() && now.Day() < birthDate.Day())
	if beforeBirthday {
		age--
	}

	// Guard against bad data (birth date in the future).
	if age < 0 {
		return 0
	}
	return age
}
|
|
|