Browse Source
- Add CSV export endpoint (GET /ai/usage/export) with UTF-8 BOM for Excel - Implement Wenxin (Baidu ERNIE) provider with Chat + ChatStream SSE - Add wenxin case to provider factory - Seed wenxin provider + ERNIE 4.0/3.5 models - Add Casbin policy for usage export - Add export button to frontend AIUsagePage
6 changed files with 394 additions and 3 deletions
@ -0,0 +1,250 @@ |
|||
package provider |
|||
|
|||
import ( |
|||
"bufio" |
|||
"bytes" |
|||
"context" |
|||
"encoding/json" |
|||
"fmt" |
|||
"io" |
|||
"net/http" |
|||
"strings" |
|||
) |
|||
|
|||
// WenxinProvider implements AIProvider for Baidu Wenxin (ERNIE) models.
// Wenxin uses a unique API format that requires custom HTTP implementation.
type WenxinProvider struct {
	// baseUrl is the chat endpoint root; buildURL appends the model name
	// as the final path segment unless it is already present.
	baseUrl string
	// apiKey is sent as a "Bearer" token in the Authorization header.
	apiKey string
}
|||
|
|||
func NewWenxinProvider(baseUrl, apiKey string) *WenxinProvider { |
|||
if baseUrl == "" { |
|||
baseUrl = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat" |
|||
} |
|||
return &WenxinProvider{ |
|||
baseUrl: baseUrl, |
|||
apiKey: apiKey, |
|||
} |
|||
} |
|||
|
|||
func (p *WenxinProvider) Name() string { |
|||
return "wenxin" |
|||
} |
|||
|
|||
// wenxinMessage is Wenxin's message format: a single conversation turn
// with a role ("user" or "assistant"; system turns are filtered out by
// buildRequest) and its text content.
type wenxinMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}
|||
|
|||
// wenxinRequest is the request body for the Wenxin API. Temperature and
// MaxOutputTokens are omitted when zero so the server-side defaults apply.
type wenxinRequest struct {
	Messages []wenxinMessage `json:"messages"`
	// Stream selects SSE streaming (true) vs. a single JSON response.
	Stream          bool    `json:"stream"`
	Temperature     float64 `json:"temperature,omitempty"`
	MaxOutputTokens int     `json:"max_output_tokens,omitempty"`
}
|||
|
|||
// wenxinResponse is the non-streaming response from the Wenxin API.
// API-level failures are reported in-band: a non-zero ErrorCode with a
// human-readable ErrorMsg, even on an HTTP 200.
type wenxinResponse struct {
	Id string `json:"id"`
	// Result is the full completion text.
	Result           string `json:"result"`
	IsTruncated      bool   `json:"is_truncated"`
	NeedClearHistory bool   `json:"need_clear_history"`
	// Usage carries token accounting for billing/metering.
	Usage struct {
		PromptTokens     int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens      int `json:"total_tokens"`
	} `json:"usage"`
	ErrorCode int    `json:"error_code"`
	ErrorMsg  string `json:"error_msg"`
}
|||
|
|||
// wenxinStreamResponse is a single SSE chunk from the Wenxin API.
// Result holds only this chunk's delta text; IsEnd marks the final chunk,
// which also carries the complete Usage totals. Errors arrive in-band via
// ErrorCode/ErrorMsg, mirroring wenxinResponse.
type wenxinStreamResponse struct {
	Id               string `json:"id"`
	Result           string `json:"result"`
	IsEnd            bool   `json:"is_end"`
	IsTruncated      bool   `json:"is_truncated"`
	NeedClearHistory bool   `json:"need_clear_history"`
	Usage struct {
		PromptTokens     int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens      int `json:"total_tokens"`
	} `json:"usage"`
	ErrorCode int    `json:"error_code"`
	ErrorMsg  string `json:"error_msg"`
}
|||
|
|||
func (p *WenxinProvider) Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error) { |
|||
wenxinReq := p.buildRequest(req, false) |
|||
|
|||
body, err := json.Marshal(wenxinReq) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("wenxin marshal request: %w", err) |
|||
} |
|||
|
|||
url := p.buildURL(req.Model) |
|||
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body)) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("wenxin create request: %w", err) |
|||
} |
|||
httpReq.Header.Set("Content-Type", "application/json") |
|||
httpReq.Header.Set("Authorization", "Bearer "+p.apiKey) |
|||
|
|||
resp, err := http.DefaultClient.Do(httpReq) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("wenxin http request: %w", err) |
|||
} |
|||
defer resp.Body.Close() |
|||
|
|||
respBody, err := io.ReadAll(resp.Body) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("wenxin read response: %w", err) |
|||
} |
|||
|
|||
var wenxinResp wenxinResponse |
|||
if err := json.Unmarshal(respBody, &wenxinResp); err != nil { |
|||
return nil, fmt.Errorf("wenxin unmarshal response: %w", err) |
|||
} |
|||
|
|||
if wenxinResp.ErrorCode != 0 { |
|||
return nil, fmt.Errorf("wenxin error %d: %s", wenxinResp.ErrorCode, wenxinResp.ErrorMsg) |
|||
} |
|||
|
|||
return &ChatResponse{ |
|||
Content: wenxinResp.Result, |
|||
Model: req.Model, |
|||
InputTokens: wenxinResp.Usage.PromptTokens, |
|||
OutputTokens: wenxinResp.Usage.CompletionTokens, |
|||
FinishReason: "stop", |
|||
}, nil |
|||
} |
|||
|
|||
func (p *WenxinProvider) ChatStream(ctx context.Context, req *ChatRequest) (<-chan *StreamChunk, error) { |
|||
wenxinReq := p.buildRequest(req, true) |
|||
|
|||
body, err := json.Marshal(wenxinReq) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("wenxin marshal request: %w", err) |
|||
} |
|||
|
|||
url := p.buildURL(req.Model) |
|||
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body)) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("wenxin create request: %w", err) |
|||
} |
|||
httpReq.Header.Set("Content-Type", "application/json") |
|||
httpReq.Header.Set("Authorization", "Bearer "+p.apiKey) |
|||
|
|||
resp, err := http.DefaultClient.Do(httpReq) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("wenxin http request: %w", err) |
|||
} |
|||
|
|||
if resp.StatusCode != http.StatusOK { |
|||
resp.Body.Close() |
|||
return nil, fmt.Errorf("wenxin stream returned status %d", resp.StatusCode) |
|||
} |
|||
|
|||
ch := make(chan *StreamChunk, 64) |
|||
|
|||
go func() { |
|||
defer close(ch) |
|||
defer resp.Body.Close() |
|||
|
|||
scanner := bufio.NewScanner(resp.Body) |
|||
for scanner.Scan() { |
|||
line := scanner.Text() |
|||
if !strings.HasPrefix(line, "data: ") { |
|||
continue |
|||
} |
|||
data := strings.TrimPrefix(line, "data: ") |
|||
if data == "" { |
|||
continue |
|||
} |
|||
|
|||
var streamResp wenxinStreamResponse |
|||
if err := json.Unmarshal([]byte(data), &streamResp); err != nil { |
|||
continue |
|||
} |
|||
|
|||
if streamResp.ErrorCode != 0 { |
|||
select { |
|||
case ch <- &StreamChunk{ |
|||
Content: fmt.Sprintf("[wenxin error: %s]", streamResp.ErrorMsg), |
|||
Done: true, |
|||
FinishReason: "error", |
|||
}: |
|||
case <-ctx.Done(): |
|||
} |
|||
return |
|||
} |
|||
|
|||
chunk := &StreamChunk{ |
|||
Content: streamResp.Result, |
|||
} |
|||
|
|||
if streamResp.IsEnd { |
|||
chunk.Done = true |
|||
chunk.FinishReason = "stop" |
|||
chunk.InputTokens = streamResp.Usage.PromptTokens |
|||
chunk.OutputTokens = streamResp.Usage.CompletionTokens |
|||
} |
|||
|
|||
select { |
|||
case ch <- chunk: |
|||
case <-ctx.Done(): |
|||
return |
|||
} |
|||
|
|||
if streamResp.IsEnd { |
|||
return |
|||
} |
|||
} |
|||
|
|||
// Ensure a final Done chunk
|
|||
select { |
|||
case ch <- &StreamChunk{Done: true}: |
|||
case <-ctx.Done(): |
|||
} |
|||
}() |
|||
|
|||
return ch, nil |
|||
} |
|||
|
|||
func (p *WenxinProvider) buildRequest(req *ChatRequest, stream bool) *wenxinRequest { |
|||
messages := make([]wenxinMessage, 0, len(req.Messages)) |
|||
for _, m := range req.Messages { |
|||
if m.Role == "system" { |
|||
continue // Wenxin doesn't support system role in messages
|
|||
} |
|||
messages = append(messages, wenxinMessage{ |
|||
Role: m.Role, |
|||
Content: m.Content, |
|||
}) |
|||
} |
|||
|
|||
wenxinReq := &wenxinRequest{ |
|||
Messages: messages, |
|||
Stream: stream, |
|||
} |
|||
if req.Temperature > 0 { |
|||
wenxinReq.Temperature = req.Temperature |
|||
} |
|||
if req.MaxTokens > 0 { |
|||
wenxinReq.MaxOutputTokens = req.MaxTokens |
|||
} |
|||
return wenxinReq |
|||
} |
|||
|
|||
func (p *WenxinProvider) buildURL(model string) string { |
|||
// Wenxin uses model-specific endpoints
|
|||
// The baseUrl should include the full path for the model
|
|||
// e.g., https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-8k
|
|||
if strings.Contains(p.baseUrl, model) { |
|||
return p.baseUrl |
|||
} |
|||
return strings.TrimRight(p.baseUrl, "/") + "/" + model |
|||
} |
|||
@ -0,0 +1,26 @@ |
|||
package ai |
|||
|
|||
import ( |
|||
"net/http" |
|||
|
|||
"github.com/youruser/base/internal/logic/ai" |
|||
"github.com/youruser/base/internal/svc" |
|||
"github.com/youruser/base/internal/types" |
|||
"github.com/zeromicro/go-zero/rest/httpx" |
|||
) |
|||
|
|||
// 导出用量记录CSV
|
|||
func AiUsageExportHandler(svcCtx *svc.ServiceContext) http.HandlerFunc { |
|||
return func(w http.ResponseWriter, r *http.Request) { |
|||
var req types.AIUsageRecordListRequest |
|||
if err := httpx.Parse(r, &req); err != nil { |
|||
httpx.ErrorCtx(r.Context(), w, err) |
|||
return |
|||
} |
|||
|
|||
l := ai.NewAiUsageExportLogic(r.Context(), svcCtx) |
|||
if err := l.AiUsageExport(&req, w); err != nil { |
|||
httpx.ErrorCtx(r.Context(), w, err) |
|||
} |
|||
} |
|||
} |
|||
@ -0,0 +1,95 @@ |
|||
package ai |
|||
|
|||
import ( |
|||
"context" |
|||
"encoding/csv" |
|||
"fmt" |
|||
"net/http" |
|||
"time" |
|||
|
|||
"github.com/youruser/base/internal/svc" |
|||
"github.com/youruser/base/internal/types" |
|||
"github.com/youruser/base/model" |
|||
|
|||
"github.com/zeromicro/go-zero/core/logx" |
|||
) |
|||
|
|||
// AiUsageExportLogic streams AI usage records to the client as CSV.
type AiUsageExportLogic struct {
	logx.Logger
	// ctx carries request scope, cancellation, and the authenticated
	// user id (read via ctx.Value("userId") in AiUsageExport).
	ctx context.Context
	// svcCtx provides shared service dependencies (DB handle, etc.).
	svcCtx *svc.ServiceContext
}
|||
|
|||
func NewAiUsageExportLogic(ctx context.Context, svcCtx *svc.ServiceContext) *AiUsageExportLogic { |
|||
return &AiUsageExportLogic{ |
|||
Logger: logx.WithContext(ctx), |
|||
ctx: ctx, |
|||
svcCtx: svcCtx, |
|||
} |
|||
} |
|||
|
|||
func (l *AiUsageExportLogic) AiUsageExport(req *types.AIUsageRecordListRequest, w http.ResponseWriter) error { |
|||
// Get current user from context
|
|||
userId := req.UserId |
|||
if userId == 0 { |
|||
if uid, ok := l.ctx.Value("userId").(int64); ok { |
|||
userId = uid |
|||
} else if uidJson, ok2 := l.ctx.Value("userId").(float64); ok2 { |
|||
userId = int64(uidJson) |
|||
} |
|||
} |
|||
|
|||
// Fetch all records (large page size for export)
|
|||
records, _, err := model.AIUsageRecordFindList(l.ctx, l.svcCtx.DB, userId, req.ModelId, req.Status, 1, 10000) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
|
|||
// Build lookup caches
|
|||
userMap := make(map[int64]string) |
|||
providerMap := make(map[int64]string) |
|||
for _, r := range records { |
|||
if _, ok := userMap[r.UserId]; !ok { |
|||
user, err := model.FindOne(l.ctx, l.svcCtx.DB, r.UserId) |
|||
if err == nil { |
|||
userMap[r.UserId] = user.Username |
|||
} |
|||
} |
|||
if _, ok := providerMap[r.ProviderId]; !ok { |
|||
p, err := model.AIProviderFindOne(l.ctx, l.svcCtx.DB, r.ProviderId) |
|||
if err == nil { |
|||
providerMap[r.ProviderId] = p.DisplayName |
|||
} |
|||
} |
|||
} |
|||
|
|||
// Write CSV response
|
|||
filename := fmt.Sprintf("ai-usage-%s.csv", time.Now().Format("20060102")) |
|||
w.Header().Set("Content-Type", "text/csv; charset=utf-8") |
|||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", filename)) |
|||
// BOM for Excel UTF-8 support
|
|||
w.Write([]byte{0xEF, 0xBB, 0xBF}) |
|||
|
|||
writer := csv.NewWriter(w) |
|||
defer writer.Flush() |
|||
|
|||
// Header row
|
|||
writer.Write([]string{"时间", "用户", "平台", "模型", "输入Tokens", "输出Tokens", "费用(¥)", "延迟(ms)", "状态", "错误信息"}) |
|||
|
|||
for _, r := range records { |
|||
writer.Write([]string{ |
|||
r.CreatedAt.Format("2006-01-02 15:04:05"), |
|||
userMap[r.UserId], |
|||
providerMap[r.ProviderId], |
|||
r.ModelId, |
|||
fmt.Sprintf("%d", r.InputTokens), |
|||
fmt.Sprintf("%d", r.OutputTokens), |
|||
fmt.Sprintf("%.4f", r.Cost), |
|||
fmt.Sprintf("%d", r.LatencyMs), |
|||
r.Status, |
|||
r.ErrorMessage, |
|||
}) |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
Loading…
Reference in new issue