package provider import ( "bufio" "bytes" "context" "encoding/json" "fmt" "io" "net/http" "strings" ) // WenxinProvider implements AIProvider for Baidu Wenxin (ERNIE) models. // Wenxin uses a unique API format that requires custom HTTP implementation. type WenxinProvider struct { baseUrl string apiKey string } func NewWenxinProvider(baseUrl, apiKey string) *WenxinProvider { if baseUrl == "" { baseUrl = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat" } return &WenxinProvider{ baseUrl: baseUrl, apiKey: apiKey, } } func (p *WenxinProvider) Name() string { return "wenxin" } // wenxinMessage is Wenxin's message format type wenxinMessage struct { Role string `json:"role"` Content string `json:"content"` } // wenxinRequest is the request body for Wenxin API type wenxinRequest struct { Messages []wenxinMessage `json:"messages"` Stream bool `json:"stream"` Temperature float64 `json:"temperature,omitempty"` MaxOutputTokens int `json:"max_output_tokens,omitempty"` } // wenxinResponse is the non-streaming response from Wenxin API type wenxinResponse struct { Id string `json:"id"` Result string `json:"result"` IsTruncated bool `json:"is_truncated"` NeedClearHistory bool `json:"need_clear_history"` Usage struct { PromptTokens int `json:"prompt_tokens"` CompletionTokens int `json:"completion_tokens"` TotalTokens int `json:"total_tokens"` } `json:"usage"` ErrorCode int `json:"error_code"` ErrorMsg string `json:"error_msg"` } // wenxinStreamResponse is a single SSE chunk from Wenxin API type wenxinStreamResponse struct { Id string `json:"id"` Result string `json:"result"` IsEnd bool `json:"is_end"` IsTruncated bool `json:"is_truncated"` NeedClearHistory bool `json:"need_clear_history"` Usage struct { PromptTokens int `json:"prompt_tokens"` CompletionTokens int `json:"completion_tokens"` TotalTokens int `json:"total_tokens"` } `json:"usage"` ErrorCode int `json:"error_code"` ErrorMsg string `json:"error_msg"` } func (p *WenxinProvider) Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error) { wenxinReq := p.buildRequest(req, false) body, err := json.Marshal(wenxinReq) if err != nil { return nil, fmt.Errorf("wenxin marshal request: %w", err) } url := p.buildURL(req.Model) httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body)) if err != nil { return nil, fmt.Errorf("wenxin create request: %w", err) } httpReq.Header.Set("Content-Type", "application/json") httpReq.Header.Set("Authorization", "Bearer "+p.apiKey) resp, err := http.DefaultClient.Do(httpReq) if err != nil { return nil, fmt.Errorf("wenxin http request: %w", err) } defer resp.Body.Close() respBody, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("wenxin read response: %w", err) } var wenxinResp wenxinResponse if err := json.Unmarshal(respBody, &wenxinResp); err != nil { return nil, fmt.Errorf("wenxin unmarshal response: %w", err) } if wenxinResp.ErrorCode != 0 { return nil, fmt.Errorf("wenxin error %d: %s", wenxinResp.ErrorCode, wenxinResp.ErrorMsg) } return &ChatResponse{ Content: wenxinResp.Result, Model: req.Model, InputTokens: wenxinResp.Usage.PromptTokens, OutputTokens: wenxinResp.Usage.CompletionTokens, FinishReason: "stop", }, nil } func (p *WenxinProvider) ChatStream(ctx context.Context, req *ChatRequest) (<-chan *StreamChunk, error) { wenxinReq := p.buildRequest(req, true) body, err := json.Marshal(wenxinReq) if err != nil { return nil, fmt.Errorf("wenxin marshal request: %w", err) } url := p.buildURL(req.Model) httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body)) if err != nil { return nil, fmt.Errorf("wenxin create request: %w", err) } httpReq.Header.Set("Content-Type", "application/json") httpReq.Header.Set("Authorization", "Bearer "+p.apiKey) resp, err := http.DefaultClient.Do(httpReq) if err != nil { return nil, fmt.Errorf("wenxin http request: %w", err) } if resp.StatusCode != http.StatusOK { resp.Body.Close() return nil, fmt.Errorf("wenxin stream returned status %d", resp.StatusCode) } ch := make(chan *StreamChunk, 64) go func() { defer close(ch) defer resp.Body.Close() scanner := bufio.NewScanner(resp.Body) for scanner.Scan() { line := scanner.Text() if !strings.HasPrefix(line, "data: ") { continue } data := strings.TrimPrefix(line, "data: ") if data == "" { continue } var streamResp wenxinStreamResponse if err := json.Unmarshal([]byte(data), &streamResp); err != nil { continue } if streamResp.ErrorCode != 0 { select { case ch <- &StreamChunk{ Content: fmt.Sprintf("[wenxin error: %s]", streamResp.ErrorMsg), Done: true, FinishReason: "error", }: case <-ctx.Done(): } return } chunk := &StreamChunk{ Content: streamResp.Result, } if streamResp.IsEnd { chunk.Done = true chunk.FinishReason = "stop" chunk.InputTokens = streamResp.Usage.PromptTokens chunk.OutputTokens = streamResp.Usage.CompletionTokens } select { case ch <- chunk: case <-ctx.Done(): return } if streamResp.IsEnd { return } } // Ensure a final Done chunk select { case ch <- &StreamChunk{Done: true}: case <-ctx.Done(): } }() return ch, nil } func (p *WenxinProvider) buildRequest(req *ChatRequest, stream bool) *wenxinRequest { messages := make([]wenxinMessage, 0, len(req.Messages)) for _, m := range req.Messages { if m.Role == "system" { continue // Wenxin doesn't support system role in messages } messages = append(messages, wenxinMessage{ Role: m.Role, Content: m.Content, }) } wenxinReq := &wenxinRequest{ Messages: messages, Stream: stream, } if req.Temperature > 0 { wenxinReq.Temperature = req.Temperature } if req.MaxTokens > 0 { wenxinReq.MaxOutputTokens = req.MaxTokens } return wenxinReq } func (p *WenxinProvider) buildURL(model string) string { // Wenxin uses model-specific endpoints // The baseUrl should include the full path for the model // e.g., https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-8k if strings.Contains(p.baseUrl, model) { return p.baseUrl } return strings.TrimRight(p.baseUrl, "/") + "/" + model }