healthapp
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

159 lines
3.3 KiB

package ai
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
)
// OpenAIClient talks to an OpenAI-compatible chat completions API.
// Construct it with NewOpenAIClient so that defaults are applied.
type OpenAIClient struct {
	// apiKey is sent as a Bearer token; when it is empty, Chat and
	// ChatStream return an error before issuing any request.
	apiKey  string
	// baseURL is the API root that "/chat/completions" is appended to.
	baseURL string
	// model is the model name placed in every request body.
	model   string
}
// NewOpenAIClient builds an OpenAIClient from cfg, which must be non-nil.
//
// An empty cfg.BaseURL falls back to "https://api.openai.com/v1" and an empty
// cfg.Model falls back to "gpt-3.5-turbo". Trailing slashes are stripped from
// the base URL because request paths are built as baseURL+"/chat/completions";
// without this, a configured "https://host/v1/" would yield
// "https://host/v1//chat/completions".
func NewOpenAIClient(cfg *Config) *OpenAIClient {
	// Trim before applying the default so a BaseURL of just "/" also
	// falls back to the public endpoint.
	baseURL := strings.TrimRight(cfg.BaseURL, "/")
	if baseURL == "" {
		baseURL = "https://api.openai.com/v1"
	}
	model := cfg.Model
	if model == "" {
		model = "gpt-3.5-turbo"
	}
	return &OpenAIClient{
		apiKey:  cfg.APIKey,
		baseURL: baseURL,
		model:   model,
	}
}
// openaiRequest is the JSON body posted to the "/chat/completions" endpoint.
type openaiRequest struct {
	// Model is the model name to run.
	Model    string    `json:"model"`
	// Messages is the conversation sent to the model, in order.
	Messages []Message `json:"messages"`
	// Stream asks for an incremental server-sent-event reply. Because of
	// omitempty, a false value is left out of the encoded body entirely.
	Stream   bool      `json:"stream,omitempty"`
}
// openaiResponse decodes both a complete chat completion (read by Chat) and a
// single streamed chunk (read by ChatStream). Error is non-nil when the JSON
// body carries an "error" object.
type openaiResponse struct {
	Choices []struct {
		// Message is the field Chat reads when streaming is disabled.
		Message struct {
			Content string `json:"content"`
		} `json:"message"`
		// Delta is the field ChatStream reads from each streamed chunk.
		Delta struct {
			Content string `json:"content"`
		} `json:"delta"`
	} `json:"choices"`
	Error *struct {
		Message string `json:"message"`
	} `json:"error"`
}
// Chat sends messages to the "/chat/completions" endpoint with streaming
// disabled and returns the content of the first choice.
//
// It returns an error when:
//   - no API key is configured;
//   - the request cannot be encoded, built or sent;
//   - the server answers with a non-2xx status, in which case the API's own
//     error message is used if the body carries one;
//   - a 2xx body is not valid JSON, contains an "error" object, or has no
//     choices.
func (c *OpenAIClient) Chat(ctx context.Context, messages []Message) (string, error) {
	if c.apiKey == "" {
		return "", fmt.Errorf("OpenAI API Key 未配置")
	}
	body, err := json.Marshal(openaiRequest{
		Model:    c.model,
		Messages: messages,
		Stream:   false,
	})
	if err != nil {
		return "", fmt.Errorf("encoding chat request: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/chat/completions", bytes.NewReader(body))
	if err != nil {
		return "", err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+c.apiKey)
	// http.DefaultClient has no timeout of its own; deadlines and
	// cancellation come only from ctx.
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		// The body may be an API error object or an arbitrary gateway page.
		// A bounded prefix is read only to enrich the error, so a read
		// failure here is deliberately ignored.
		snippet, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
		var apiErr openaiResponse
		if json.Unmarshal(snippet, &apiErr) == nil && apiErr.Error != nil {
			return "", fmt.Errorf("OpenAI error: %s", apiErr.Error.Message)
		}
		return "", fmt.Errorf("OpenAI HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(snippet)))
	}

	var result openaiResponse
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return "", fmt.Errorf("decoding OpenAI response: %w", err)
	}
	if result.Error != nil {
		return "", fmt.Errorf("OpenAI error: %s", result.Error.Message)
	}
	if len(result.Choices) == 0 {
		return "", fmt.Errorf("no response")
	}
	return result.Choices[0].Message.Content, nil
}
// ChatStream sends messages to the "/chat/completions" endpoint with streaming
// enabled. It relays every non-empty content delta to writer as
//
//	data: {"type":"content","content":<jsonEscape(delta)>}\n\n
//
// and flushes after each event when writer implements http.Flusher.
//
// It returns nil once the upstream sends "data: [DONE]" or closes the stream.
// Lines without a "data:" prefix and chunks that fail to decode are skipped.
//
// It returns an error when:
//   - no API key is configured;
//   - the request cannot be encoded, built or sent;
//   - the server answers with a non-2xx status;
//   - a chunk carries an "error" object;
//   - reading the stream fails;
//   - writing to writer fails.
func (c *OpenAIClient) ChatStream(ctx context.Context, messages []Message, writer io.Writer) error {
	if c.apiKey == "" {
		return fmt.Errorf("OpenAI API Key 未配置")
	}
	body, err := json.Marshal(openaiRequest{
		Model:    c.model,
		Messages: messages,
		Stream:   true,
	})
	if err != nil {
		return fmt.Errorf("encoding chat request: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/chat/completions", bytes.NewReader(body))
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+c.apiKey)
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	// An error reply is a plain body rather than an event stream. Without this
	// check the loop below would find no "data:" lines and report success.
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		// Best effort: the body is read only to enrich the error message.
		snippet, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
		var apiErr openaiResponse
		if json.Unmarshal(snippet, &apiErr) == nil && apiErr.Error != nil {
			return fmt.Errorf("OpenAI error: %s", apiErr.Error.Message)
		}
		return fmt.Errorf("OpenAI HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(snippet)))
	}

	flusher, canFlush := writer.(http.Flusher)

	// relay handles one raw stream line and reports done when the "[DONE]"
	// terminator is seen.
	relay := func(line string) (done bool, err error) {
		line = strings.TrimSpace(line)
		if !strings.HasPrefix(line, "data:") {
			return false, nil
		}
		data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
		if data == "[DONE]" {
			return true, nil
		}
		var chunk openaiResponse
		if err := json.Unmarshal([]byte(data), &chunk); err != nil {
			return false, nil // malformed chunks are skipped
		}
		if chunk.Error != nil {
			return false, fmt.Errorf("OpenAI error: %s", chunk.Error.Message)
		}
		if len(chunk.Choices) == 0 || chunk.Choices[0].Delta.Content == "" {
			return false, nil
		}
		event := fmt.Sprintf("data: {\"type\":\"content\",\"content\":%s}\n\n", jsonEscape(chunk.Choices[0].Delta.Content))
		if _, err := io.WriteString(writer, event); err != nil {
			// The downstream consumer is gone; stop reading the upstream.
			return false, fmt.Errorf("writing stream event: %w", err)
		}
		if canFlush {
			flusher.Flush()
		}
		return false, nil
	}

	reader := bufio.NewReader(resp.Body)
	for {
		line, readErr := reader.ReadString('\n')
		if readErr != nil && readErr != io.EOF {
			return readErr
		}
		// ReadString returns a final unterminated line together with io.EOF,
		// so the line is handled before EOF ends the loop.
		if done, err := relay(line); done || err != nil {
			return err
		}
		if readErr == io.EOF {
			return nil
		}
	}
}