You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
250 lines
6.4 KiB
250 lines
6.4 KiB
package provider
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
)
|
|
|
|
// WenxinProvider implements AIProvider for Baidu Wenxin (ERNIE) models.
// Wenxin uses a unique API format that requires custom HTTP implementation.
type WenxinProvider struct {
	baseUrl string // chat endpoint base; may already include the model path segment (see buildURL)
	apiKey  string // bearer token sent in the Authorization header
}
func NewWenxinProvider(baseUrl, apiKey string) *WenxinProvider {
|
|
if baseUrl == "" {
|
|
baseUrl = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat"
|
|
}
|
|
return &WenxinProvider{
|
|
baseUrl: baseUrl,
|
|
apiKey: apiKey,
|
|
}
|
|
}
|
|
|
|
func (p *WenxinProvider) Name() string {
|
|
return "wenxin"
|
|
}
|
|
|
|
// wenxinMessage is Wenxin's message format: a single conversation turn.
type wenxinMessage struct {
	Role    string `json:"role"`    // conversation role; "system" is filtered out by buildRequest
	Content string `json:"content"` // plain-text message body
}
// wenxinRequest is the request body for Wenxin API.
type wenxinRequest struct {
	Messages []wenxinMessage `json:"messages"` // full conversation history, system turns excluded
	Stream   bool            `json:"stream"`   // true requests SSE streaming responses
	// Optional tuning knobs; zero values are omitted so the API defaults apply.
	Temperature     float64 `json:"temperature,omitempty"`
	MaxOutputTokens int     `json:"max_output_tokens,omitempty"`
}
// wenxinResponse is the non-streaming response from Wenxin API.
type wenxinResponse struct {
	Id               string `json:"id"`                 // response identifier assigned by the API
	Result           string `json:"result"`             // generated completion text
	IsTruncated      bool   `json:"is_truncated"`       // presumably set when output was cut short — confirm against API docs
	NeedClearHistory bool   `json:"need_clear_history"` // API hint that conversation history should be dropped
	// Usage carries token accounting for the request/response pair.
	Usage struct {
		PromptTokens     int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens      int `json:"total_tokens"`
	} `json:"usage"`
	// Wenxin signals API-level failures in-band; non-zero ErrorCode means the
	// call failed even though HTTP transport succeeded.
	ErrorCode int    `json:"error_code"`
	ErrorMsg  string `json:"error_msg"`
}
// wenxinStreamResponse is a single SSE chunk from Wenxin API.
type wenxinStreamResponse struct {
	Id               string `json:"id"`                 // response identifier assigned by the API
	Result           string `json:"result"`             // incremental text for this chunk
	IsEnd            bool   `json:"is_end"`             // true on the final chunk of the stream
	IsTruncated      bool   `json:"is_truncated"`       // presumably set when output was cut short — confirm against API docs
	NeedClearHistory bool   `json:"need_clear_history"` // API hint that conversation history should be dropped
	// Usage is meaningful on the final (is_end) chunk, where ChatStream reads it.
	Usage struct {
		PromptTokens     int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens      int `json:"total_tokens"`
	} `json:"usage"`
	// In-band API error; non-zero ErrorCode aborts the stream.
	ErrorCode int    `json:"error_code"`
	ErrorMsg  string `json:"error_msg"`
}
func (p *WenxinProvider) Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error) {
|
|
wenxinReq := p.buildRequest(req, false)
|
|
|
|
body, err := json.Marshal(wenxinReq)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("wenxin marshal request: %w", err)
|
|
}
|
|
|
|
url := p.buildURL(req.Model)
|
|
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("wenxin create request: %w", err)
|
|
}
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|
httpReq.Header.Set("Authorization", "Bearer "+p.apiKey)
|
|
|
|
resp, err := http.DefaultClient.Do(httpReq)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("wenxin http request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
respBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("wenxin read response: %w", err)
|
|
}
|
|
|
|
var wenxinResp wenxinResponse
|
|
if err := json.Unmarshal(respBody, &wenxinResp); err != nil {
|
|
return nil, fmt.Errorf("wenxin unmarshal response: %w", err)
|
|
}
|
|
|
|
if wenxinResp.ErrorCode != 0 {
|
|
return nil, fmt.Errorf("wenxin error %d: %s", wenxinResp.ErrorCode, wenxinResp.ErrorMsg)
|
|
}
|
|
|
|
return &ChatResponse{
|
|
Content: wenxinResp.Result,
|
|
Model: req.Model,
|
|
InputTokens: wenxinResp.Usage.PromptTokens,
|
|
OutputTokens: wenxinResp.Usage.CompletionTokens,
|
|
FinishReason: "stop",
|
|
}, nil
|
|
}
|
|
|
|
func (p *WenxinProvider) ChatStream(ctx context.Context, req *ChatRequest) (<-chan *StreamChunk, error) {
|
|
wenxinReq := p.buildRequest(req, true)
|
|
|
|
body, err := json.Marshal(wenxinReq)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("wenxin marshal request: %w", err)
|
|
}
|
|
|
|
url := p.buildURL(req.Model)
|
|
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("wenxin create request: %w", err)
|
|
}
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|
httpReq.Header.Set("Authorization", "Bearer "+p.apiKey)
|
|
|
|
resp, err := http.DefaultClient.Do(httpReq)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("wenxin http request: %w", err)
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
resp.Body.Close()
|
|
return nil, fmt.Errorf("wenxin stream returned status %d", resp.StatusCode)
|
|
}
|
|
|
|
ch := make(chan *StreamChunk, 64)
|
|
|
|
go func() {
|
|
defer close(ch)
|
|
defer resp.Body.Close()
|
|
|
|
scanner := bufio.NewScanner(resp.Body)
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
if !strings.HasPrefix(line, "data: ") {
|
|
continue
|
|
}
|
|
data := strings.TrimPrefix(line, "data: ")
|
|
if data == "" {
|
|
continue
|
|
}
|
|
|
|
var streamResp wenxinStreamResponse
|
|
if err := json.Unmarshal([]byte(data), &streamResp); err != nil {
|
|
continue
|
|
}
|
|
|
|
if streamResp.ErrorCode != 0 {
|
|
select {
|
|
case ch <- &StreamChunk{
|
|
Content: fmt.Sprintf("[wenxin error: %s]", streamResp.ErrorMsg),
|
|
Done: true,
|
|
FinishReason: "error",
|
|
}:
|
|
case <-ctx.Done():
|
|
}
|
|
return
|
|
}
|
|
|
|
chunk := &StreamChunk{
|
|
Content: streamResp.Result,
|
|
}
|
|
|
|
if streamResp.IsEnd {
|
|
chunk.Done = true
|
|
chunk.FinishReason = "stop"
|
|
chunk.InputTokens = streamResp.Usage.PromptTokens
|
|
chunk.OutputTokens = streamResp.Usage.CompletionTokens
|
|
}
|
|
|
|
select {
|
|
case ch <- chunk:
|
|
case <-ctx.Done():
|
|
return
|
|
}
|
|
|
|
if streamResp.IsEnd {
|
|
return
|
|
}
|
|
}
|
|
|
|
// Ensure a final Done chunk
|
|
select {
|
|
case ch <- &StreamChunk{Done: true}:
|
|
case <-ctx.Done():
|
|
}
|
|
}()
|
|
|
|
return ch, nil
|
|
}
|
|
|
|
func (p *WenxinProvider) buildRequest(req *ChatRequest, stream bool) *wenxinRequest {
|
|
messages := make([]wenxinMessage, 0, len(req.Messages))
|
|
for _, m := range req.Messages {
|
|
if m.Role == "system" {
|
|
continue // Wenxin doesn't support system role in messages
|
|
}
|
|
messages = append(messages, wenxinMessage{
|
|
Role: m.Role,
|
|
Content: m.Content,
|
|
})
|
|
}
|
|
|
|
wenxinReq := &wenxinRequest{
|
|
Messages: messages,
|
|
Stream: stream,
|
|
}
|
|
if req.Temperature > 0 {
|
|
wenxinReq.Temperature = req.Temperature
|
|
}
|
|
if req.MaxTokens > 0 {
|
|
wenxinReq.MaxOutputTokens = req.MaxTokens
|
|
}
|
|
return wenxinReq
|
|
}
|
|
|
|
func (p *WenxinProvider) buildURL(model string) string {
|
|
// Wenxin uses model-specific endpoints
|
|
// The baseUrl should include the full path for the model
|
|
// e.g., https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-8k
|
|
if strings.Contains(p.baseUrl, model) {
|
|
return p.baseUrl
|
|
}
|
|
return strings.TrimRight(p.baseUrl, "/") + "/" + model
|
|
}
|
|
|