healthapp
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

165 lines
3.7 KiB

package ai
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
)
// AliyunBaseURL is the DashScope text-generation endpoint that AliyunClient
// posts its chat requests to.
const (
	AliyunBaseURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
)
// AliyunClient calls Alibaba Cloud Tongyi Qianwen models through the DashScope
// HTTP API. Build one with NewAliyunClient.
type AliyunClient struct {
	apiKey string // sent as the Bearer token on every request
	model  string // model name placed in the request body
}
// NewAliyunClient returns a client configured from cfg. cfg.APIKey becomes the
// bearer token; when cfg.Model is empty the client uses "qwen-turbo".
func NewAliyunClient(cfg *Config) *AliyunClient {
	client := &AliyunClient{apiKey: cfg.APIKey, model: cfg.Model}
	if client.model == "" {
		client.model = "qwen-turbo"
	}
	return client
}
// aliyunRequest is the JSON body posted to AliyunBaseURL.
type aliyunRequest struct {
	Model string `json:"model"`
	Input struct {
		Messages []Message `json:"messages"`
	} `json:"input"`
	Parameters struct {
		ResultFormat string `json:"result_format"`
		MaxTokens    int    `json:"max_tokens,omitempty"`
		// IncrementalOutput makes each streamed event carry only newly
		// generated text. DashScope defaults to false, in which case every
		// event repeats the whole reply generated so far. The field is
		// omitted from the JSON when false.
		IncrementalOutput bool `json:"incremental_output,omitempty"`
	} `json:"parameters"`
}

// aliyunResponse is the JSON body DashScope returns, both for a whole reply and
// for each streamed event. A non-empty Code marks an error described by Message.
// Output.Text and Output.Choices are the two reply shapes the client accepts.
type aliyunResponse struct {
	Output struct {
		Text    string `json:"text"`
		Choices []struct {
			Message struct {
				Content string `json:"content"`
			} `json:"message"`
		} `json:"choices"`
	} `json:"output"`
	Usage struct {
		InputTokens  int `json:"input_tokens"`
		OutputTokens int `json:"output_tokens"`
	} `json:"usage"`
	Code    string `json:"code"`
	Message string `json:"message"`
}

// newRequest builds an authenticated POST to AliyunBaseURL carrying messages
// with the "message" result format. When stream is true it also enables
// server-sent events and incremental output.
func (c *AliyunClient) newRequest(ctx context.Context, messages []Message, stream bool) (*http.Request, error) {
	reqBody := aliyunRequest{Model: c.model}
	reqBody.Input.Messages = messages
	reqBody.Parameters.ResultFormat = "message"
	reqBody.Parameters.IncrementalOutput = stream

	body, err := json.Marshal(reqBody)
	if err != nil {
		return nil, fmt.Errorf("编码AI请求失败: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, AliyunBaseURL, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+c.apiKey)
	if stream {
		req.Header.Set("X-DashScope-SSE", "enable") // enable streaming (SSE) output
	}
	return req, nil
}

// aliyunStatusError returns nil for a 2xx response. For any other status it
// returns an error built from the service's JSON error body (code and message)
// when one can be decoded, and from the HTTP status code otherwise.
// It may consume part of resp.Body.
func aliyunStatusError(resp *http.Response) error {
	if resp.StatusCode >= 200 && resp.StatusCode < 300 {
		return nil
	}
	var apiErr aliyunResponse
	// Bound the read so an unexpected large error page is not decoded in full.
	if json.NewDecoder(io.LimitReader(resp.Body, 1<<20)).Decode(&apiErr) == nil && apiErr.Code != "" {
		return fmt.Errorf("AI服务错误: %s - %s", apiErr.Code, apiErr.Message)
	}
	return fmt.Errorf("AI服务错误: HTTP %d", resp.StatusCode)
}

// Chat sends messages to DashScope and returns the complete reply.
//
// The reply is read from output.choices[0].message.content, or from
// output.text when no choices are present. An error is returned when the API
// key is not configured, the request cannot be built or sent, the service
// answers with a non-2xx status or an error code, the body is not valid JSON,
// or the reply has neither shape.
//
// NOTE(review): http.DefaultClient has no timeout, so the call is bounded only
// by ctx; callers should pass a context with a deadline.
func (c *AliyunClient) Chat(ctx context.Context, messages []Message) (string, error) {
	if c.apiKey == "" {
		return "", fmt.Errorf("阿里云通义千问 API Key 未配置,请在 config.yaml 中设置 ai.aliyun.api_key")
	}
	req, err := c.newRequest(ctx, messages, false)
	if err != nil {
		return "", err
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("调用AI服务失败: %w", err)
	}
	defer resp.Body.Close()

	if err := aliyunStatusError(resp); err != nil {
		return "", err
	}
	var result aliyunResponse
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return "", fmt.Errorf("解析AI响应失败: %w", err)
	}
	if result.Code != "" {
		return "", fmt.Errorf("AI服务错误: %s - %s", result.Code, result.Message)
	}
	// Accept both reply shapes: choices (result_format "message") and plain text.
	if len(result.Output.Choices) > 0 {
		return result.Output.Choices[0].Message.Content, nil
	}
	if result.Output.Text != "" {
		return result.Output.Text, nil
	}
	return "", fmt.Errorf("AI未返回有效响应")
}

// ChatStream sends messages to DashScope with server-sent events enabled and
// writes each piece of the reply to writer as it arrives.
//
// Incremental output is requested, so every event carries only newly
// generated text and the concatenation of all writes is the full reply.
// An error is returned when the API key is not configured, the request cannot
// be built or sent, the service reports an error (by HTTP status or inside an
// event), reading the stream fails, or writer fails. Events whose data is not
// valid JSON are skipped.
//
// NOTE(review): http.DefaultClient has no timeout, so the stream is bounded
// only by ctx.
func (c *AliyunClient) ChatStream(ctx context.Context, messages []Message, writer io.Writer) error {
	if c.apiKey == "" {
		return fmt.Errorf("阿里云通义千问 API Key 未配置")
	}
	req, err := c.newRequest(ctx, messages, true)
	if err != nil {
		return err
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	// Failures come back as a plain JSON body rather than an event stream;
	// without this check the loop below would see no events and report success.
	if err := aliyunStatusError(resp); err != nil {
		return err
	}

	// Parse the SSE stream line by line. bufio.Reader (not Scanner) is used so
	// long lines are not rejected by a token-size limit.
	reader := bufio.NewReader(resp.Body)
	for {
		line, readErr := reader.ReadString('\n')
		if readErr != nil && readErr != io.EOF {
			return readErr
		}
		// ReadString returns the final unterminated line together with io.EOF,
		// so handle the line before checking for end of stream.
		stop, werr := writeAliyunEvent(writer, line)
		if werr != nil || stop {
			return werr
		}
		if readErr == io.EOF {
			return nil
		}
	}
}

// writeAliyunEvent handles one raw line of a DashScope SSE stream. For a
// "data:" line it writes the first choice's content to w. It reports done when
// the "[DONE]" terminator is seen, and returns an error for an event carrying
// an error code or when w fails. Other lines, and data that is not valid JSON,
// are ignored.
func writeAliyunEvent(w io.Writer, line string) (done bool, err error) {
	line = strings.TrimSpace(line)
	if !strings.HasPrefix(line, "data:") {
		return false, nil
	}
	data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
	if data == "[DONE]" {
		return true, nil
	}
	var event aliyunResponse
	if json.Unmarshal([]byte(data), &event) != nil {
		return false, nil // best effort: skip malformed events
	}
	if event.Code != "" {
		return true, fmt.Errorf("AI服务错误: %s - %s", event.Code, event.Message)
	}
	if len(event.Output.Choices) == 0 {
		return false, nil
	}
	content := event.Output.Choices[0].Message.Content
	if content == "" {
		return false, nil
	}
	_, err = io.WriteString(w, content)
	return false, err
}