package provider

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
)

// WenxinProvider implements AIProvider for Baidu Wenxin (ERNIE) models.
// Wenxin uses a unique API format that requires a custom HTTP implementation.
type WenxinProvider struct {
	baseUrl string
	apiKey  string
}

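// NewWenxinProvider creates a WenxinProvider. If baseUrl is empty, the
// default Wenxin chat endpoint is used.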
func NewWenxinProvider(baseUrl, apiKey string) *WenxinProvider {
	if baseUrl == "" {
		baseUrl = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat"
	}
	return &WenxinProvider{
		baseUrl: baseUrl,
		apiKey:  apiKey,
	}
}

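// Name returns the provider identifier "wenxin".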
func (p *WenxinProvider) Name() string {
	return "wenxin"
}

// wenxinMessage is Wenxin's message format
type wenxinMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// wenxinRequest is the request body for the Wenxin API
type wenxinRequest struct {
	Messages        []wenxinMessage `json:"messages"`
	Stream          bool            `json:"stream"`
	Temperature     float64         `json:"temperature,omitempty"`
	MaxOutputTokens int             `json:"max_output_tokens,omitempty"`
}

// wenxinResponse is the non-streaming response from the Wenxin API
type wenxinResponse struct {
	Id               string `json:"id"`
	Result           string `json:"result"`
	IsTruncated      bool   `json:"is_truncated"`
	NeedClearHistory bool   `json:"need_clear_history"`
	Usage            struct {
		PromptTokens     int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens      int `json:"total_tokens"`
	} `json:"usage"`
	ErrorCode int    `json:"error_code"`
	ErrorMsg  string `json:"error_msg"`
}

// wenxinStreamResponse is a single SSE chunk from the Wenxin API
type wenxinStreamResponse struct {
	Id               string `json:"id"`
	Result           string `json:"result"`
	IsEnd            bool   `json:"is_end"`
	IsTruncated      bool   `json:"is_truncated"`
	NeedClearHistory bool   `json:"need_clear_history"`
	Usage            struct {
		PromptTokens     int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens      int `json:"total_tokens"`
	} `json:"usage"`
	ErrorCode int    `json:"error_code"`
	ErrorMsg  string `json:"error_msg"`
}

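// Chat sends a non-streaming chat request to the Wenxin API and returns the
// completed response.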
func (p *WenxinProvider) Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error) {
	wenxinReq := p.buildRequest(req, false)
	body, err := json.Marshal(wenxinReq)
	if err != nil {
		return nil, fmt.Errorf("wenxin marshal request: %w", err)
	}
	url := p.buildURL(req.Model)
	httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("wenxin create request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+p.apiKey)
	resp, err := http.DefaultClient.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("wenxin http request: %w", err)
	}
	defer resp.Body.Close()
	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("wenxin read response: %w", err)
	}
	var wenxinResp wenxinResponse
	if err := json.Unmarshal(respBody, &wenxinResp); err != nil {
		return nil, fmt.Errorf("wenxin unmarshal response: %w", err)
	}
	if wenxinResp.ErrorCode != 0 {
		return nil, fmt.Errorf("wenxin error %d: %s", wenxinResp.ErrorCode, wenxinResp.ErrorMsg)
	}
	return &ChatResponse{
		Content:      wenxinResp.Result,
		Model:        req.Model,
		InputTokens:  wenxinResp.Usage.PromptTokens,
		OutputTokens: wenxinResp.Usage.CompletionTokens,
		FinishReason: "stop",
	}, nil
}

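// ChatStream sends a streaming chat request to the Wenxin API and returns a
// channel of chunks parsed from the SSE response. The channel is closed when
// the stream ends or ctx is cancelled.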
func (p *WenxinProvider) ChatStream(ctx context.Context, req *ChatRequest) (<-chan *StreamChunk, error) {
	wenxinReq := p.buildRequest(req, true)
	body, err := json.Marshal(wenxinReq)
	if err != nil {
		return nil, fmt.Errorf("wenxin marshal request: %w", err)
	}
	url := p.buildURL(req.Model)
	httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("wenxin create request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+p.apiKey)
	resp, err := http.DefaultClient.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("wenxin http request: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		resp.Body.Close()
		return nil, fmt.Errorf("wenxin stream returned status %d", resp.StatusCode)
	}
	ch := make(chan *StreamChunk, 64)
	go func() {
		defer close(ch)
		defer resp.Body.Close()
		scanner := bufio.NewScanner(resp.Body)
		// Raise the scanner's limit so SSE lines larger than bufio's 64KB
		// default token size don't silently end the stream early.
		scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
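		// Wenxin streams SSE lines of the form "data: {json}"; each chunk
		// carries a partial result, and is_end marks the final chunk.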
		for scanner.Scan() {
			line := scanner.Text()
			if !strings.HasPrefix(line, "data: ") {
				continue
			}
			data := strings.TrimPrefix(line, "data: ")
			if data == "" {
				continue
			}
			var streamResp wenxinStreamResponse
			if err := json.Unmarshal([]byte(data), &streamResp); err != nil {
				continue
			}
			if streamResp.ErrorCode != 0 {
				select {
				case ch <- &StreamChunk{
					Content:      fmt.Sprintf("[wenxin error: %s]", streamResp.ErrorMsg),
					Done:         true,
					FinishReason: "error",
				}:
				case <-ctx.Done():
				}
				return
			}
			chunk := &StreamChunk{
				Content: streamResp.Result,
			}
			if streamResp.IsEnd {
				chunk.Done = true
				chunk.FinishReason = "stop"
				chunk.InputTokens = streamResp.Usage.PromptTokens
				chunk.OutputTokens = streamResp.Usage.CompletionTokens
			}
			select {
			case ch <- chunk:
			case <-ctx.Done():
				return
			}
			if streamResp.IsEnd {
				return
			}
		}
		// Ensure a final Done chunk
		select {
		case ch <- &StreamChunk{Done: true}:
		case <-ctx.Done():
		}
	}()
	return ch, nil
}

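// buildRequest converts a ChatRequest into Wenxin's request format.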
func (p *WenxinProvider) buildRequest(req *ChatRequest, stream bool) *wenxinRequest {
	messages := make([]wenxinMessage, 0, len(req.Messages))
	for _, m := range req.Messages {
		if m.Role == "system" {
			continue // Wenxin doesn't support system role in messages
		}
		messages = append(messages, wenxinMessage{
			Role:    m.Role,
			Content: m.Content,
		})
	}
	wenxinReq := &wenxinRequest{
		Messages: messages,
		Stream:   stream,
	}
	if req.Temperature > 0 {
		wenxinReq.Temperature = req.Temperature
	}
	if req.MaxTokens > 0 {
		wenxinReq.MaxOutputTokens = req.MaxTokens
	}
	return wenxinReq
}

// buildURL returns the model-specific Wenxin endpoint, e.g.
// https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-8k.
// If baseUrl does not already include the model, the model name is appended.
func (p *WenxinProvider) buildURL(model string) string {
	if strings.Contains(p.baseUrl, model) {
		return p.baseUrl
	}
	return strings.TrimRight(p.baseUrl, "/") + "/" + model
}