healthapp
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

19 KiB

06-AI对话模块

目标

实现 AI 健康问诊对话功能,支持多轮对话、结合用户体质信息、流式响应。


前置要求

  • 体质辨识模块已完成
  • 已有 AI API Key(OpenAI / 通义千问)

实施步骤

步骤 1:创建 AI 客户端抽象

创建 server/internal/service/ai/client.go

package ai

import (
	"context"
	"io"
)

// AIClient abstracts an AI chat backend.
type AIClient interface {
	// Chat sends messages to the backend and returns the reply text, or an error.
	Chat(ctx context.Context, messages []Message) (string, error)
	// ChatStream sends messages to the backend and writes the response to writer.
	ChatStream(ctx context.Context, messages []Message, writer io.Writer) error
}

// Message is a single chat message; it is serialized as JSON with the keys
// "role" and "content".
type Message struct {
	Role    string `json:"role"`    // system, user, assistant
	Content string `json:"content"` // message text
}

// Config holds the settings used to construct an AI client.
type Config struct {
	Provider string // provider name; presumably "openai" or "aliyun" — TODO confirm against callers
	APIKey   string // API key for the provider
	BaseURL  string // API base URL
	Model    string // model name
}

步骤 2:实现 OpenAI 客户端

创建 server/internal/service/ai/openai.go

package ai

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
)

// OpenAIClient talks to an OpenAI-style chat-completions HTTP API.
type OpenAIClient struct {
	apiKey  string // sent as "Authorization: Bearer <apiKey>"
	baseURL string // prefix to which "/chat/completions" is appended
	model   string // value of the "model" field in every request
}

// NewOpenAIClient builds an OpenAIClient from cfg.
//
// An empty cfg.BaseURL is replaced by "https://api.openai.com/v1" and an
// empty cfg.Model by "gpt-3.5-turbo"; cfg.APIKey is copied as is.
func NewOpenAIClient(cfg *Config) *OpenAIClient {
	client := &OpenAIClient{
		apiKey:  cfg.APIKey,
		baseURL: cfg.BaseURL,
		model:   cfg.Model,
	}
	if client.baseURL == "" {
		client.baseURL = "https://api.openai.com/v1"
	}
	if client.model == "" {
		client.model = "gpt-3.5-turbo"
	}
	return client
}

// openAIRequest is the JSON body posted to the chat-completions endpoint.
type openAIRequest struct {
	Model    string    `json:"model"`    // model name
	Messages []Message `json:"messages"` // conversation sent to the model
	Stream   bool      `json:"stream"`   // false in Chat, true in ChatStream
}

// openAIResponse is the subset of the chat-completions response that is decoded.
type openAIResponse struct {
	Choices []struct {
		// Message carries the reply; Chat returns Choices[0].Message.Content.
		Message struct {
			Content string `json:"content"`
		} `json:"message"`
		// Delta is presumably the incremental content of a streamed chunk; it
		// is not read by the code in this file — TODO confirm.
		Delta struct {
			Content string `json:"content"`
		} `json:"delta"`
	} `json:"choices"`
}

// Chat posts messages to c.baseURL+"/chat/completions" in non-streaming mode
// and returns the content of the first choice.
//
// An error is returned when the request cannot be encoded, built or sent,
// when the server answers with a non-2xx status (the status code and a
// bounded snippet of the body are included in the message), when the body is
// not valid JSON, or when the response contains no choices.
func (c *OpenAIClient) Chat(ctx context.Context, messages []Message) (string, error) {
	body, err := json.Marshal(openAIRequest{
		Model:    c.model,
		Messages: messages,
		Stream:   false,
	})
	if err != nil {
		return "", fmt.Errorf("encoding chat request: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/chat/completions", bytes.NewReader(body))
	if err != nil {
		return "", err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+c.apiKey)

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	// Without this check an error payload (bad key, rate limit, ...) decodes
	// into an empty Choices list and is reported as "no response from AI".
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		// Best-effort: a read failure here only shortens the error message.
		detail, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
		return "", fmt.Errorf("openai API error: status %d: %s", resp.StatusCode, bytes.TrimSpace(detail))
	}

	var result openAIResponse
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return "", fmt.Errorf("decoding chat response: %w", err)
	}
	if len(result.Choices) == 0 {
		return "", fmt.Errorf("no response from AI")
	}
	return result.Choices[0].Message.Content, nil
}

// ChatStream posts messages to c.baseURL+"/chat/completions" with streaming
// enabled and copies the raw response body to writer as it arrives. The bytes
// are forwarded unparsed.
//
// An error is returned when the request cannot be encoded, built or sent,
// when the server answers with a non-2xx status (nothing is written to writer
// in that case), when reading the body fails, or when writer rejects a chunk.
func (c *OpenAIClient) ChatStream(ctx context.Context, messages []Message, writer io.Writer) error {
	body, err := json.Marshal(openAIRequest{
		Model:    c.model,
		Messages: messages,
		Stream:   true,
	})
	if err != nil {
		return fmt.Errorf("encoding chat request: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/chat/completions", bytes.NewReader(body))
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+c.apiKey)

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		// Best-effort: a read failure here only shortens the error message.
		detail, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
		return fmt.Errorf("openai API error: status %d: %s", resp.StatusCode, bytes.TrimSpace(detail))
	}

	buf := make([]byte, 1024)
	for {
		n, readErr := resp.Body.Read(buf)
		// A Read may return data together with io.EOF, so the bytes must be
		// forwarded before the error is inspected or the last chunk is lost.
		if n > 0 {
			if _, writeErr := writer.Write(buf[:n]); writeErr != nil {
				return fmt.Errorf("forwarding stream chunk: %w", writeErr)
			}
		}
		if readErr == io.EOF {
			return nil
		}
		if readErr != nil {
			return readErr
		}
	}
}

步骤 2.5:实现阿里云通义千问客户端

创建 server/internal/service/ai/aliyun.go

package ai

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
)

// AliyunBaseURL is the DashScope text-generation endpoint that AliyunClient posts to.
const AliyunBaseURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"

// AliyunClient talks to the Aliyun DashScope text-generation API.
type AliyunClient struct {
	apiKey string // sent as "Authorization: Bearer <apiKey>"
	model  string // value of the "model" field in every request
}

// NewAliyunClient builds an AliyunClient from cfg.
//
// An empty cfg.Model is replaced by "qwen-turbo"; cfg.APIKey is copied as is.
// Other Config fields are not read.
func NewAliyunClient(cfg *Config) *AliyunClient {
	client := &AliyunClient{apiKey: cfg.APIKey, model: cfg.Model}
	if client.model == "" {
		client.model = "qwen-turbo"
	}
	return client
}

// aliyunRequest is the JSON body posted to AliyunBaseURL.
type aliyunRequest struct {
	Model string `json:"model"` // model name
	Input struct {
		Messages []Message `json:"messages"` // conversation sent to the model
	} `json:"input"`
	Parameters struct {
		ResultFormat string `json:"result_format"`        // set to "message" by Chat and ChatStream
		MaxTokens    int    `json:"max_tokens,omitempty"` // omitted from the JSON when zero
	} `json:"parameters"`
}

// aliyunResponse is the subset of the DashScope response that is decoded.
type aliyunResponse struct {
	// Output holds the reply in one of two shapes; Chat prefers Choices and
	// falls back to Text.
	Output struct {
		Text    string `json:"text"`
		Choices []struct {
			Message struct {
				Content string `json:"content"`
			} `json:"message"`
		} `json:"choices"`
	} `json:"output"`
	// Usage is decoded but not read by the code in this file; presumably token
	// counts for the call — TODO confirm.
	Usage struct {
		InputTokens  int `json:"input_tokens"`
		OutputTokens int `json:"output_tokens"`
	} `json:"usage"`
	Code    string `json:"code"`    // a non-empty value is treated by Chat as an API error
	Message string `json:"message"` // reported next to Code in the error returned by Chat
}

// Chat posts messages to AliyunBaseURL with result_format "message" and
// returns the reply text.
//
// The reply is taken from output.choices[0].message.content when choices are
// present, otherwise from output.text. An error is returned when the request
// cannot be encoded, built or sent, when the response carries a non-empty
// "code", when the status is non-2xx, when the body is not valid JSON, or
// when neither reply field is populated.
func (c *AliyunClient) Chat(ctx context.Context, messages []Message) (string, error) {
	reqBody := aliyunRequest{Model: c.model}
	reqBody.Input.Messages = messages
	reqBody.Parameters.ResultFormat = "message"

	body, err := json.Marshal(reqBody)
	if err != nil {
		return "", fmt.Errorf("encoding aliyun request: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, AliyunBaseURL, bytes.NewReader(body))
	if err != nil {
		return "", err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+c.apiKey)

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	var result aliyunResponse
	decodeErr := json.NewDecoder(resp.Body).Decode(&result)

	// A structured API error is the most informative, so report it first.
	if decodeErr == nil && result.Code != "" {
		return "", fmt.Errorf("aliyun API error: %s - %s", result.Code, result.Message)
	}
	// A failing status without a usable JSON error (e.g. a gateway HTML page)
	// is reported by status instead of as a JSON syntax error.
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return "", fmt.Errorf("aliyun API error: status %d", resp.StatusCode)
	}
	if decodeErr != nil {
		return "", fmt.Errorf("decoding aliyun response: %w", decodeErr)
	}

	// Two response shapes are accepted.
	if len(result.Output.Choices) > 0 {
		return result.Output.Choices[0].Message.Content, nil
	}
	if result.Output.Text != "" {
		return result.Output.Text, nil
	}
	return "", fmt.Errorf("no response from AI")
}

// ChatStream posts messages to AliyunBaseURL with the "X-DashScope-SSE: enable"
// header and copies the raw response body to writer as it arrives. The bytes
// are forwarded unparsed.
//
// An error is returned when the request cannot be encoded, built or sent,
// when the server answers with a non-2xx status (nothing is written to writer
// in that case), when reading the body fails, or when writer rejects a chunk.
func (c *AliyunClient) ChatStream(ctx context.Context, messages []Message, writer io.Writer) error {
	reqBody := aliyunRequest{Model: c.model}
	reqBody.Input.Messages = messages
	reqBody.Parameters.ResultFormat = "message"

	body, err := json.Marshal(reqBody)
	if err != nil {
		return fmt.Errorf("encoding aliyun request: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, AliyunBaseURL, bytes.NewReader(body))
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+c.apiKey)
	req.Header.Set("X-DashScope-SSE", "enable") // enable streaming output

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		// Best-effort: a read failure here only shortens the error message.
		detail, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
		return fmt.Errorf("aliyun API error: status %d: %s", resp.StatusCode, bytes.TrimSpace(detail))
	}

	buf := make([]byte, 1024)
	for {
		n, readErr := resp.Body.Read(buf)
		// A Read may return data together with io.EOF, so the bytes must be
		// forwarded before the error is inspected or the last chunk is lost.
		if n > 0 {
			if _, writeErr := writer.Write(buf[:n]); writeErr != nil {
				return fmt.Errorf("forwarding stream chunk: %w", writeErr)
			}
		}
		if readErr == io.EOF {
			return nil
		}
		if readErr != nil {
			return readErr
		}
	}
}

步骤 2.6:创建 AI 客户端工厂

创建 server/internal/service/ai/factory.go

package ai

import "health-ai/internal/config"

// NewAIClient returns the AI client selected by cfg.Provider.
//
// "aliyun" yields an AliyunClient built from cfg.Aliyun; "openai" and every
// other value yield an OpenAIClient built from cfg.OpenAI.
func NewAIClient(cfg *config.AIConfig) AIClient {
	if cfg.Provider == "aliyun" {
		aliyunCfg := Config{
			APIKey: cfg.Aliyun.APIKey,
			Model:  cfg.Aliyun.Model,
		}
		return NewAliyunClient(&aliyunCfg)
	}

	// "openai" and any unrecognized provider share the OpenAI client.
	openAICfg := Config{
		APIKey:  cfg.OpenAI.APIKey,
		BaseURL: cfg.OpenAI.BaseURL,
		Model:   cfg.OpenAI.Model,
	}
	return NewOpenAIClient(&openAICfg)
}

步骤 3:创建对话 Repository

创建 server/internal/repository/impl/conversation.go

package impl

import (
	"health-ai/internal/database"
	"health-ai/internal/model"
)

// ConversationRepository provides persistence operations for conversations and
// their messages. It holds no state of its own; its methods use database.DB.
type ConversationRepository struct{}

// NewConversationRepository returns a new, empty ConversationRepository.
func NewConversationRepository() *ConversationRepository {
	return new(ConversationRepository)
}

// Create inserts conv through database.DB and returns the error reported by
// that call, if any.
func (r *ConversationRepository) Create(conv *model.Conversation) error {
	result := database.DB.Create(conv)
	return result.Error
}

// GetByID loads the conversation with the given id, preloading its Messages.
//
// The returned pointer is always non-nil, even when the lookup fails; callers
// must check the error before using it.
func (r *ConversationRepository) GetByID(id uint) (*model.Conversation, error) {
	conv := new(model.Conversation)
	result := database.DB.Preload("Messages").First(conv, id)
	return conv, result.Error
}

// GetByUserID returns the conversations whose user_id equals userID, ordered
// by updated_at descending.
func (r *ConversationRepository) GetByUserID(userID uint) ([]model.Conversation, error) {
	var list []model.Conversation
	query := database.DB.Where("user_id = ?", userID).Order("updated_at DESC")
	result := query.Find(&list)
	return list, result.Error
}

// Delete removes the conversation with the given id together with its messages.
//
// Messages are deleted first. If that step fails its error is returned and
// the conversation row is left in place, so the caller never ends up with a
// deleted conversation whose messages silently survived.
//
// NOTE(review): the two deletes are not wrapped in a transaction; if the
// second one fails the messages are already gone. Consider a transaction.
func (r *ConversationRepository) Delete(id uint) error {
	if err := database.DB.Where("conversation_id = ?", id).Delete(&model.Message{}).Error; err != nil {
		return err
	}
	return database.DB.Delete(&model.Conversation{}, id).Error
}

// AddMessage inserts msg through database.DB and returns the error reported
// by that call, if any.
func (r *ConversationRepository) AddMessage(msg *model.Message) error {
	result := database.DB.Create(msg)
	return result.Error
}

// GetMessages returns the messages whose conversation_id equals convID,
// ordered by created_at ascending (oldest first).
func (r *ConversationRepository) GetMessages(convID uint) ([]model.Message, error) {
	var history []model.Message
	query := database.DB.Where("conversation_id = ?", convID).Order("created_at ASC")
	result := query.Find(&history)
	return history, result.Error
}

// UpdateTitle sets the title column of the conversation with the given id.
func (r *ConversationRepository) UpdateTitle(id uint, title string) error {
	target := database.DB.Model(&model.Conversation{}).Where("id = ?", id)
	result := target.Update("title", title)
	return result.Error
}

步骤 4:创建对话 Service

创建 server/internal/service/conversation.go

package service

import (
	"context"
	"fmt"
	"io"
	"strings"
	"time"

	"health-ai/internal/config"
	"health-ai/internal/model"
	"health-ai/internal/repository/impl"
	"health-ai/internal/service/ai"
)

// ConversationService implements the conversation use cases: listing,
// creating, reading and deleting conversations, and exchanging messages with
// the AI backend.
type ConversationService struct {
	convRepo         *impl.ConversationRepository // conversations and their messages
	constitutionRepo *impl.ConstitutionRepository // latest constitution assessment, used in the system prompt
	healthRepo       *impl.HealthRepository       // user profile and medications, used in the system prompt
	aiClient         ai.AIClient                  // backend that answers chat requests
}

// NewConversationService wires a ConversationService: the AI client comes
// from the factory using config.AppConfig.AI, and each repository is freshly
// constructed.
func NewConversationService() *ConversationService {
	svc := &ConversationService{
		// The factory picks the concrete client from the configured provider.
		aiClient: ai.NewAIClient(&config.AppConfig.AI),
	}
	svc.convRepo = impl.NewConversationRepository()
	svc.constitutionRepo = impl.NewConstitutionRepository()
	svc.healthRepo = impl.NewHealthRepository()
	return svc
}

// GetConversations returns the conversations belonging to userID, as reported
// by the conversation repository.
func (s *ConversationService) GetConversations(userID uint) ([]model.Conversation, error) {
	convs, err := s.convRepo.GetByUserID(userID)
	return convs, err
}

// CreateConversation stores a new conversation owned by userID and returns it.
//
// An empty title is replaced by a default made of a fixed prefix and the
// current time formatted as "01-02 15:04". On a repository error the returned
// conversation is nil.
func (s *ConversationService) CreateConversation(userID uint, title string) (*model.Conversation, error) {
	conv := &model.Conversation{UserID: userID, Title: title}
	if conv.Title == "" {
		conv.Title = "新对话 " + time.Now().Format("01-02 15:04")
	}

	err := s.convRepo.Create(conv)
	if err != nil {
		return nil, err
	}
	return conv, nil
}

// GetConversation returns the conversation with the given id as loaded by the
// repository. No ownership check is performed here.
func (s *ConversationService) GetConversation(id uint) (*model.Conversation, error) {
	conv, err := s.convRepo.GetByID(id)
	return conv, err
}

// DeleteConversation deletes the conversation with the given id through the
// repository. No ownership check is performed here.
func (s *ConversationService) DeleteConversation(id uint) error {
	err := s.convRepo.Delete(id)
	return err
}

// SendMessage stores content as a user message in conversation convID, asks
// the AI backend for a reply, stores that reply as an assistant message and
// returns it.
//
// The conversation must exist and belong to userID; otherwise an error is
// returned before anything is stored or sent to the AI. Without that check a
// caller could append to another user's conversation and have its history
// sent to the model. Errors from the repository or the AI client are returned
// as is; if the AI call fails, the user message has already been saved.
func (s *ConversationService) SendMessage(ctx context.Context, userID uint, convID uint, content string) (string, error) {
	conv, err := s.convRepo.GetByID(convID)
	if err != nil {
		return "", err
	}
	if conv.UserID != userID {
		// Reported as "not found" so that the existence of other users'
		// conversations is not revealed.
		return "", fmt.Errorf("conversation %d not found", convID)
	}

	// Save the user message.
	userMsg := &model.Message{
		ConversationID: convID,
		Role:           "user",
		Content:        content,
	}
	if err := s.convRepo.AddMessage(userMsg); err != nil {
		return "", err
	}

	// Build the prompt (system prompt plus history) and call the AI.
	messages := s.buildMessages(userID, convID, content)
	reply, err := s.aiClient.Chat(ctx, messages)
	if err != nil {
		return "", err
	}

	// Save the AI reply.
	assistantMsg := &model.Message{
		ConversationID: convID,
		Role:           "assistant",
		Content:        reply,
	}
	if err := s.convRepo.AddMessage(assistantMsg); err != nil {
		return "", err
	}

	return reply, nil
}

// SendMessageStream stores content as a user message in conversation convID
// and streams the AI backend's response to writer.
//
// The conversation must exist and belong to userID; otherwise an error is
// returned before anything is stored or sent to the AI. Errors from the
// repository or the AI client are returned as is.
//
// NOTE(review): unlike SendMessage, the assistant reply is not saved, so
// streamed answers are missing from later history. Saving it would require
// parsing the provider's stream format.
func (s *ConversationService) SendMessageStream(ctx context.Context, userID uint, convID uint, content string, writer io.Writer) error {
	conv, err := s.convRepo.GetByID(convID)
	if err != nil {
		return err
	}
	if conv.UserID != userID {
		// Reported as "not found" so that the existence of other users'
		// conversations is not revealed.
		return fmt.Errorf("conversation %d not found", convID)
	}

	// Save the user message.
	userMsg := &model.Message{
		ConversationID: convID,
		Role:           "user",
		Content:        content,
	}
	if err := s.convRepo.AddMessage(userMsg); err != nil {
		return err
	}

	// Build the prompt (system prompt plus history) and call the streaming API.
	messages := s.buildMessages(userID, convID, content)
	return s.aiClient.ChatStream(ctx, messages, writer)
}

// buildMessages assembles the message list sent to the AI backend: the system
// prompt for userID followed by the most recent stored messages of convID.
//
// At most config.AppConfig.AI.MaxHistoryMessages history entries are kept
// (10 when that setting is zero or negative); older ones are dropped to limit
// prompt size. Callers save the current user message before calling, so it is
// normally the last history entry. When the history cannot be loaded or is
// empty, currentMsg is sent as the only user message so that the model still
// receives the question.
func (s *ConversationService) buildMessages(userID uint, convID uint, currentMsg string) []ai.Message {
	system := ai.Message{
		Role:    "system",
		Content: s.buildSystemPrompt(userID),
	}

	historyMsgs, err := s.convRepo.GetMessages(convID)
	if err != nil || len(historyMsgs) == 0 {
		// Best-effort fallback: answer the current question without context.
		return []ai.Message{system, {Role: "user", Content: currentMsg}}
	}

	maxHistory := config.AppConfig.AI.MaxHistoryMessages
	if maxHistory <= 0 {
		maxHistory = 10 // default: 10 messages
	}
	if len(historyMsgs) > maxHistory {
		historyMsgs = historyMsgs[len(historyMsgs)-maxHistory:]
	}

	messages := make([]ai.Message, 0, len(historyMsgs)+1)
	messages = append(messages, system)
	for _, msg := range historyMsgs {
		messages = append(messages, ai.Message{
			Role:    msg.Role,
			Content: msg.Content,
		})
	}
	return messages
}

// systemPromptTemplate is the system prompt sent to the AI backend. Its three
// %s verbs are filled by buildSystemPrompt, in order, with the user profile,
// the constitution information and the medication history. The full version
// is described in design.md, section 4.3.1.
const systemPromptTemplate = `# 角色定义
你是"健康AI助手",一个专业的健康咨询助理。你基于中医体质辨识理论,为用户提供个性化的健康建议。

## 重要声明
- 你不是专业医师,仅提供健康咨询和养生建议
- 你的建议不能替代医生的诊断和治疗
- 遇到以下情况,必须立即建议用户就医:
  * 胸痛、呼吸困难、剧烈头痛
  * 高烧不退(超过39°C持续24小时)
  * 意识模糊、晕厥
  * 严重外伤、大量出血
  * 持续剧烈腹痛
  * 疑似中风症状(口眼歪斜、肢体无力、言语不清)

## 用户信息
%s

## 用户体质
%s

## 用药历史
%s

## 回答原则
1. 回答控制在200字以内,简洁明了
2. 根据用户体质给出针对性建议
3. 用药建议优先推荐非处方中成药或食疗,注明"建议咨询药师"
4. 不推荐处方药,不做疾病诊断

## 回答格式
【情况分析】一句话概括
【建议】
1. 具体建议
【用药参考】(如适用)
- 药品名称:用法(建议咨询药师)
【提醒】注意事项或就医建议`

// buildSystemPrompt renders systemPromptTemplate for userID.
//
// The three template slots are filled with:
//   - the user profile (gender, age, BMI), or a placeholder when no profile
//     with a positive ID can be loaded;
//   - the primary constitution name and description from the latest
//     assessment, or a placeholder when none can be loaded;
//   - the names of the user's medications, or a placeholder when there are
//     none or they cannot be loaded.
//
// Repository errors are deliberately not propagated: each section falls back
// to its placeholder so that a prompt can always be produced.
//
// The leftover block that used to follow the return statement was removed; it
// was unreachable and did not compile.
func (s *ConversationService) buildSystemPrompt(userID uint) string {
	// User profile.
	userProfile := "暂无"
	if profile, err := s.healthRepo.GetProfileByUserID(userID); err == nil && profile.ID > 0 {
		gender := "未知"
		switch profile.Gender {
		case "male":
			gender = "男"
		case "female":
			gender = "女"
		}
		// calculateAge is not defined in this file; presumably a package-level
		// helper returning the age in years — TODO confirm.
		age := calculateAge(profile.BirthDate)
		userProfile = fmt.Sprintf("性别:%s,年龄:%d岁,BMI:%.1f", gender, age, profile.BMI)
	}

	// Constitution information.
	constitutionInfo := "暂未测评"
	if constitution, err := s.constitutionRepo.GetLatestAssessment(userID); err == nil && constitution.ID > 0 {
		constitutionName := model.ConstitutionNames[constitution.PrimaryConstitution]
		description := model.ConstitutionDescriptions[constitution.PrimaryConstitution]
		constitutionInfo = fmt.Sprintf("主体质:%s\n特征:%s", constitutionName, description)
	}

	// Medication history.
	medicationHistory := "暂无记录"
	if medications, err := s.healthRepo.GetMedications(userID); err == nil && len(medications) > 0 {
		medNames := make([]string, 0, len(medications))
		for _, m := range medications {
			medNames = append(medNames, m.Name)
		}
		medicationHistory = "近期用药:" + strings.Join(medNames, "、")
	}

	return fmt.Sprintf(systemPromptTemplate, userProfile, constitutionInfo, medicationHistory)
}

步骤 5:创建对话 Handler

创建 server/internal/api/handler/conversation.go

package handler

import (
	"health-ai/internal/api/middleware"
	"health-ai/internal/service"
	"health-ai/pkg/response"

	"github.com/gin-gonic/gin"
)

// ConversationHandler exposes the conversation use cases as Gin handlers.
type ConversationHandler struct {
	convService *service.ConversationService // service that every handler method delegates to
}

// NewConversationHandler returns a handler backed by a newly constructed
// ConversationService.
func NewConversationHandler() *ConversationHandler {
	h := new(ConversationHandler)
	h.convService = service.NewConversationService()
	return h
}

// GetConversations responds with the conversations of the user identified by
// middleware.GetUserID, or with a 500 error carrying the service error text.
func (h *ConversationHandler) GetConversations(c *gin.Context) {
	owner := middleware.GetUserID(c)

	convs, err := h.convService.GetConversations(owner)
	if err == nil {
		response.Success(c, convs)
		return
	}
	response.Error(c, 500, err.Error())
}

// CreateConversation creates a conversation for the user identified by
// middleware.GetUserID, using the optional "title" field of the JSON body,
// and responds with the created conversation or a 500 error.
func (h *ConversationHandler) CreateConversation(c *gin.Context) {
	owner := middleware.GetUserID(c)

	var payload struct {
		Title string `json:"title"`
	}
	// The bind error is ignored: a missing or unreadable body leaves Title
	// empty, and the service then substitutes a default title.
	_ = c.ShouldBindJSON(&payload)

	conv, err := h.convService.CreateConversation(owner, payload.Title)
	if err == nil {
		response.Success(c, conv)
		return
	}
	response.Error(c, 500, err.Error())
}

// GetConversation responds with the conversation named by the ":id" path
// parameter.
//
// Responses:
//   - 400 when ":id" is not a positive integer;
//   - 500 with the service error text when loading fails;
//   - 404 when the conversation belongs to a different user, so that other
//     users' conversations are neither readable nor distinguishable from
//     missing ones;
//   - success with the conversation otherwise.
//
// The id is bound with Gin's URI binding, so a malformed id is rejected
// instead of being treated as 0, and no fmt import is needed.
func (h *ConversationHandler) GetConversation(c *gin.Context) {
	var uri struct {
		ID uint `uri:"id" binding:"required"`
	}
	if err := c.ShouldBindUri(&uri); err != nil {
		response.Error(c, 400, "参数错误: "+err.Error())
		return
	}

	conv, err := h.convService.GetConversation(uri.ID)
	if err != nil {
		response.Error(c, 500, err.Error())
		return
	}
	if conv.UserID != middleware.GetUserID(c) {
		response.Error(c, 404, "对话不存在")
		return
	}
	response.Success(c, conv)
}

// DeleteConversation deletes the conversation named by the ":id" path
// parameter.
//
// Responses:
//   - 400 when ":id" is not a positive integer;
//   - 500 with the service error text when loading or deleting fails;
//   - 404 when the conversation belongs to a different user, so that one user
//     cannot delete another user's conversation;
//   - success with a nil payload otherwise.
//
// The id is bound with Gin's URI binding, so a malformed id is rejected
// instead of being treated as 0, and no fmt import is needed.
func (h *ConversationHandler) DeleteConversation(c *gin.Context) {
	var uri struct {
		ID uint `uri:"id" binding:"required"`
	}
	if err := c.ShouldBindUri(&uri); err != nil {
		response.Error(c, 400, "参数错误: "+err.Error())
		return
	}

	// Load first so that ownership can be verified before anything is removed.
	conv, err := h.convService.GetConversation(uri.ID)
	if err != nil {
		response.Error(c, 500, err.Error())
		return
	}
	if conv.UserID != middleware.GetUserID(c) {
		response.Error(c, 404, "对话不存在")
		return
	}

	if err := h.convService.DeleteConversation(uri.ID); err != nil {
		response.Error(c, 500, err.Error())
		return
	}
	response.Success(c, nil)
}

// SendMessageRequest is the JSON body accepted by the SendMessage handler.
type SendMessageRequest struct {
	Content string `json:"content" binding:"required"` // message text; binding fails when it is missing or empty
}

// SendMessage posts the "content" field of the JSON body to the conversation
// named by the ":id" path parameter and responds with the AI reply under the
// "reply" key.
//
// Responses:
//   - 400 when ":id" is not a positive integer or the body fails binding;
//   - 500 with the service error text when sending fails;
//   - success with {"reply": ...} otherwise.
//
// The id is bound with Gin's URI binding, so a malformed id is rejected
// instead of being treated as 0, and no fmt import is needed.
func (h *ConversationHandler) SendMessage(c *gin.Context) {
	userID := middleware.GetUserID(c)

	var uri struct {
		ID uint `uri:"id" binding:"required"`
	}
	if err := c.ShouldBindUri(&uri); err != nil {
		response.Error(c, 400, "参数错误: "+err.Error())
		return
	}

	var req SendMessageRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		response.Error(c, 400, "参数错误: "+err.Error())
		return
	}

	reply, err := h.convService.SendMessage(c.Request.Context(), userID, uri.ID, req.Content)
	if err != nil {
		response.Error(c, 500, err.Error())
		return
	}

	response.Success(c, gin.H{
		"reply": reply,
	})
}

步骤 6:更新路由

在 router.go 中注册对话路由,路径与方法见下方「API 接口说明」。


API 接口说明

GET /api/conversations

获取对话列表

POST /api/conversations

创建新对话

GET /api/conversations/:id

获取对话详情(含消息)

DELETE /api/conversations/:id

删除对话

POST /api/conversations/:id/messages

发送消息并获取 AI 回复


需要创建的文件清单

文件路径 说明
internal/service/ai/client.go AI 客户端接口定义
internal/service/ai/openai.go OpenAI 客户端实现
internal/service/ai/aliyun.go 阿里云通义千问客户端实现
internal/service/ai/factory.go AI 客户端工厂
internal/repository/impl/conversation.go 对话 Repository
internal/service/conversation.go 对话 Service
internal/api/handler/conversation.go 对话 Handler

AI 服务配置说明

config.yaml 中配置 AI 服务:

ai:
  provider: aliyun              # 可选: openai, aliyun
  max_history_messages: 10      # 最大历史消息数
  
  openai:
    api_key: "sk-xxx"           # OpenAI API Key
    base_url: "https://api.openai.com/v1"
    model: "gpt-3.5-turbo"
  
  aliyun:
    api_key: "sk-xxx"           # 阿里云 DashScope API Key
    model: "qwen-turbo"         # 可选: qwen-turbo, qwen-plus, qwen-max

验收标准

  • 创建/获取/删除对话正常
  • 发送消息返回 AI 回复
  • AI 回复结合用户体质和用药历史
  • 对话历史正确保存(限制消息数量)
  • 支持 OpenAI 和阿里云通义千问切换
  • 无 API Key 时给出友好提示
  • 紧急情况提示用户就医

预计耗时

40-50 分钟


下一步

完成后进入 02-后端开发/07-健康档案模块.md