Browse Source
- Add CSV export endpoint (GET /ai/usage/export) with UTF-8 BOM for Excel - Implement Wenxin (Baidu ERNIE) provider with Chat + ChatStream SSE - Add wenxin case to provider factory - Seed wenxin provider + ERNIE 4.0/3.5 models - Add Casbin policy for usage export - Add export button to frontend AIUsagePagemaster
6 changed files with 394 additions and 3 deletions
@ -0,0 +1,250 @@ |
|||||
|
package provider |
||||
|
|
||||
|
import ( |
||||
|
"bufio" |
||||
|
"bytes" |
||||
|
"context" |
||||
|
"encoding/json" |
||||
|
"fmt" |
||||
|
"io" |
||||
|
"net/http" |
||||
|
"strings" |
||||
|
) |
||||
|
|
||||
|
// WenxinProvider implements AIProvider for Baidu Wenxin (ERNIE) models.
|
||||
|
// Wenxin uses a unique API format that requires custom HTTP implementation.
|
||||
|
type WenxinProvider struct { |
||||
|
baseUrl string |
||||
|
apiKey string |
||||
|
} |
||||
|
|
||||
|
func NewWenxinProvider(baseUrl, apiKey string) *WenxinProvider { |
||||
|
if baseUrl == "" { |
||||
|
baseUrl = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat" |
||||
|
} |
||||
|
return &WenxinProvider{ |
||||
|
baseUrl: baseUrl, |
||||
|
apiKey: apiKey, |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func (p *WenxinProvider) Name() string { |
||||
|
return "wenxin" |
||||
|
} |
||||
|
|
||||
|
// wenxinMessage is Wenxin's message format
|
||||
|
type wenxinMessage struct { |
||||
|
Role string `json:"role"` |
||||
|
Content string `json:"content"` |
||||
|
} |
||||
|
|
||||
|
// wenxinRequest is the request body for Wenxin API
|
||||
|
type wenxinRequest struct { |
||||
|
Messages []wenxinMessage `json:"messages"` |
||||
|
Stream bool `json:"stream"` |
||||
|
Temperature float64 `json:"temperature,omitempty"` |
||||
|
MaxOutputTokens int `json:"max_output_tokens,omitempty"` |
||||
|
} |
||||
|
|
||||
|
// wenxinResponse is the non-streaming response from Wenxin API
|
||||
|
type wenxinResponse struct { |
||||
|
Id string `json:"id"` |
||||
|
Result string `json:"result"` |
||||
|
IsTruncated bool `json:"is_truncated"` |
||||
|
NeedClearHistory bool `json:"need_clear_history"` |
||||
|
Usage struct { |
||||
|
PromptTokens int `json:"prompt_tokens"` |
||||
|
CompletionTokens int `json:"completion_tokens"` |
||||
|
TotalTokens int `json:"total_tokens"` |
||||
|
} `json:"usage"` |
||||
|
ErrorCode int `json:"error_code"` |
||||
|
ErrorMsg string `json:"error_msg"` |
||||
|
} |
||||
|
|
||||
|
// wenxinStreamResponse is a single SSE chunk from Wenxin API
|
||||
|
type wenxinStreamResponse struct { |
||||
|
Id string `json:"id"` |
||||
|
Result string `json:"result"` |
||||
|
IsEnd bool `json:"is_end"` |
||||
|
IsTruncated bool `json:"is_truncated"` |
||||
|
NeedClearHistory bool `json:"need_clear_history"` |
||||
|
Usage struct { |
||||
|
PromptTokens int `json:"prompt_tokens"` |
||||
|
CompletionTokens int `json:"completion_tokens"` |
||||
|
TotalTokens int `json:"total_tokens"` |
||||
|
} `json:"usage"` |
||||
|
ErrorCode int `json:"error_code"` |
||||
|
ErrorMsg string `json:"error_msg"` |
||||
|
} |
||||
|
|
||||
|
func (p *WenxinProvider) Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error) { |
||||
|
wenxinReq := p.buildRequest(req, false) |
||||
|
|
||||
|
body, err := json.Marshal(wenxinReq) |
||||
|
if err != nil { |
||||
|
return nil, fmt.Errorf("wenxin marshal request: %w", err) |
||||
|
} |
||||
|
|
||||
|
url := p.buildURL(req.Model) |
||||
|
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body)) |
||||
|
if err != nil { |
||||
|
return nil, fmt.Errorf("wenxin create request: %w", err) |
||||
|
} |
||||
|
httpReq.Header.Set("Content-Type", "application/json") |
||||
|
httpReq.Header.Set("Authorization", "Bearer "+p.apiKey) |
||||
|
|
||||
|
resp, err := http.DefaultClient.Do(httpReq) |
||||
|
if err != nil { |
||||
|
return nil, fmt.Errorf("wenxin http request: %w", err) |
||||
|
} |
||||
|
defer resp.Body.Close() |
||||
|
|
||||
|
respBody, err := io.ReadAll(resp.Body) |
||||
|
if err != nil { |
||||
|
return nil, fmt.Errorf("wenxin read response: %w", err) |
||||
|
} |
||||
|
|
||||
|
var wenxinResp wenxinResponse |
||||
|
if err := json.Unmarshal(respBody, &wenxinResp); err != nil { |
||||
|
return nil, fmt.Errorf("wenxin unmarshal response: %w", err) |
||||
|
} |
||||
|
|
||||
|
if wenxinResp.ErrorCode != 0 { |
||||
|
return nil, fmt.Errorf("wenxin error %d: %s", wenxinResp.ErrorCode, wenxinResp.ErrorMsg) |
||||
|
} |
||||
|
|
||||
|
return &ChatResponse{ |
||||
|
Content: wenxinResp.Result, |
||||
|
Model: req.Model, |
||||
|
InputTokens: wenxinResp.Usage.PromptTokens, |
||||
|
OutputTokens: wenxinResp.Usage.CompletionTokens, |
||||
|
FinishReason: "stop", |
||||
|
}, nil |
||||
|
} |
||||
|
|
||||
|
func (p *WenxinProvider) ChatStream(ctx context.Context, req *ChatRequest) (<-chan *StreamChunk, error) { |
||||
|
wenxinReq := p.buildRequest(req, true) |
||||
|
|
||||
|
body, err := json.Marshal(wenxinReq) |
||||
|
if err != nil { |
||||
|
return nil, fmt.Errorf("wenxin marshal request: %w", err) |
||||
|
} |
||||
|
|
||||
|
url := p.buildURL(req.Model) |
||||
|
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body)) |
||||
|
if err != nil { |
||||
|
return nil, fmt.Errorf("wenxin create request: %w", err) |
||||
|
} |
||||
|
httpReq.Header.Set("Content-Type", "application/json") |
||||
|
httpReq.Header.Set("Authorization", "Bearer "+p.apiKey) |
||||
|
|
||||
|
resp, err := http.DefaultClient.Do(httpReq) |
||||
|
if err != nil { |
||||
|
return nil, fmt.Errorf("wenxin http request: %w", err) |
||||
|
} |
||||
|
|
||||
|
if resp.StatusCode != http.StatusOK { |
||||
|
resp.Body.Close() |
||||
|
return nil, fmt.Errorf("wenxin stream returned status %d", resp.StatusCode) |
||||
|
} |
||||
|
|
||||
|
ch := make(chan *StreamChunk, 64) |
||||
|
|
||||
|
go func() { |
||||
|
defer close(ch) |
||||
|
defer resp.Body.Close() |
||||
|
|
||||
|
scanner := bufio.NewScanner(resp.Body) |
||||
|
for scanner.Scan() { |
||||
|
line := scanner.Text() |
||||
|
if !strings.HasPrefix(line, "data: ") { |
||||
|
continue |
||||
|
} |
||||
|
data := strings.TrimPrefix(line, "data: ") |
||||
|
if data == "" { |
||||
|
continue |
||||
|
} |
||||
|
|
||||
|
var streamResp wenxinStreamResponse |
||||
|
if err := json.Unmarshal([]byte(data), &streamResp); err != nil { |
||||
|
continue |
||||
|
} |
||||
|
|
||||
|
if streamResp.ErrorCode != 0 { |
||||
|
select { |
||||
|
case ch <- &StreamChunk{ |
||||
|
Content: fmt.Sprintf("[wenxin error: %s]", streamResp.ErrorMsg), |
||||
|
Done: true, |
||||
|
FinishReason: "error", |
||||
|
}: |
||||
|
case <-ctx.Done(): |
||||
|
} |
||||
|
return |
||||
|
} |
||||
|
|
||||
|
chunk := &StreamChunk{ |
||||
|
Content: streamResp.Result, |
||||
|
} |
||||
|
|
||||
|
if streamResp.IsEnd { |
||||
|
chunk.Done = true |
||||
|
chunk.FinishReason = "stop" |
||||
|
chunk.InputTokens = streamResp.Usage.PromptTokens |
||||
|
chunk.OutputTokens = streamResp.Usage.CompletionTokens |
||||
|
} |
||||
|
|
||||
|
select { |
||||
|
case ch <- chunk: |
||||
|
case <-ctx.Done(): |
||||
|
return |
||||
|
} |
||||
|
|
||||
|
if streamResp.IsEnd { |
||||
|
return |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// Ensure a final Done chunk
|
||||
|
select { |
||||
|
case ch <- &StreamChunk{Done: true}: |
||||
|
case <-ctx.Done(): |
||||
|
} |
||||
|
}() |
||||
|
|
||||
|
return ch, nil |
||||
|
} |
||||
|
|
||||
|
func (p *WenxinProvider) buildRequest(req *ChatRequest, stream bool) *wenxinRequest { |
||||
|
messages := make([]wenxinMessage, 0, len(req.Messages)) |
||||
|
for _, m := range req.Messages { |
||||
|
if m.Role == "system" { |
||||
|
continue // Wenxin doesn't support system role in messages
|
||||
|
} |
||||
|
messages = append(messages, wenxinMessage{ |
||||
|
Role: m.Role, |
||||
|
Content: m.Content, |
||||
|
}) |
||||
|
} |
||||
|
|
||||
|
wenxinReq := &wenxinRequest{ |
||||
|
Messages: messages, |
||||
|
Stream: stream, |
||||
|
} |
||||
|
if req.Temperature > 0 { |
||||
|
wenxinReq.Temperature = req.Temperature |
||||
|
} |
||||
|
if req.MaxTokens > 0 { |
||||
|
wenxinReq.MaxOutputTokens = req.MaxTokens |
||||
|
} |
||||
|
return wenxinReq |
||||
|
} |
||||
|
|
||||
|
func (p *WenxinProvider) buildURL(model string) string { |
||||
|
// Wenxin uses model-specific endpoints
|
||||
|
// The baseUrl should include the full path for the model
|
||||
|
// e.g., https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-8k
|
||||
|
if strings.Contains(p.baseUrl, model) { |
||||
|
return p.baseUrl |
||||
|
} |
||||
|
return strings.TrimRight(p.baseUrl, "/") + "/" + model |
||||
|
} |
||||
@ -0,0 +1,26 @@ |
|||||
|
package ai |
||||
|
|
||||
|
import ( |
||||
|
"net/http" |
||||
|
|
||||
|
"github.com/youruser/base/internal/logic/ai" |
||||
|
"github.com/youruser/base/internal/svc" |
||||
|
"github.com/youruser/base/internal/types" |
||||
|
"github.com/zeromicro/go-zero/rest/httpx" |
||||
|
) |
||||
|
|
||||
|
// 导出用量记录CSV
|
||||
|
func AiUsageExportHandler(svcCtx *svc.ServiceContext) http.HandlerFunc { |
||||
|
return func(w http.ResponseWriter, r *http.Request) { |
||||
|
var req types.AIUsageRecordListRequest |
||||
|
if err := httpx.Parse(r, &req); err != nil { |
||||
|
httpx.ErrorCtx(r.Context(), w, err) |
||||
|
return |
||||
|
} |
||||
|
|
||||
|
l := ai.NewAiUsageExportLogic(r.Context(), svcCtx) |
||||
|
if err := l.AiUsageExport(&req, w); err != nil { |
||||
|
httpx.ErrorCtx(r.Context(), w, err) |
||||
|
} |
||||
|
} |
||||
|
} |
||||
@ -0,0 +1,95 @@ |
|||||
|
package ai |
||||
|
|
||||
|
import ( |
||||
|
"context" |
||||
|
"encoding/csv" |
||||
|
"fmt" |
||||
|
"net/http" |
||||
|
"time" |
||||
|
|
||||
|
"github.com/youruser/base/internal/svc" |
||||
|
"github.com/youruser/base/internal/types" |
||||
|
"github.com/youruser/base/model" |
||||
|
|
||||
|
"github.com/zeromicro/go-zero/core/logx" |
||||
|
) |
||||
|
|
||||
|
type AiUsageExportLogic struct { |
||||
|
logx.Logger |
||||
|
ctx context.Context |
||||
|
svcCtx *svc.ServiceContext |
||||
|
} |
||||
|
|
||||
|
func NewAiUsageExportLogic(ctx context.Context, svcCtx *svc.ServiceContext) *AiUsageExportLogic { |
||||
|
return &AiUsageExportLogic{ |
||||
|
Logger: logx.WithContext(ctx), |
||||
|
ctx: ctx, |
||||
|
svcCtx: svcCtx, |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func (l *AiUsageExportLogic) AiUsageExport(req *types.AIUsageRecordListRequest, w http.ResponseWriter) error { |
||||
|
// Get current user from context
|
||||
|
userId := req.UserId |
||||
|
if userId == 0 { |
||||
|
if uid, ok := l.ctx.Value("userId").(int64); ok { |
||||
|
userId = uid |
||||
|
} else if uidJson, ok2 := l.ctx.Value("userId").(float64); ok2 { |
||||
|
userId = int64(uidJson) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// Fetch all records (large page size for export)
|
||||
|
records, _, err := model.AIUsageRecordFindList(l.ctx, l.svcCtx.DB, userId, req.ModelId, req.Status, 1, 10000) |
||||
|
if err != nil { |
||||
|
return err |
||||
|
} |
||||
|
|
||||
|
// Build lookup caches
|
||||
|
userMap := make(map[int64]string) |
||||
|
providerMap := make(map[int64]string) |
||||
|
for _, r := range records { |
||||
|
if _, ok := userMap[r.UserId]; !ok { |
||||
|
user, err := model.FindOne(l.ctx, l.svcCtx.DB, r.UserId) |
||||
|
if err == nil { |
||||
|
userMap[r.UserId] = user.Username |
||||
|
} |
||||
|
} |
||||
|
if _, ok := providerMap[r.ProviderId]; !ok { |
||||
|
p, err := model.AIProviderFindOne(l.ctx, l.svcCtx.DB, r.ProviderId) |
||||
|
if err == nil { |
||||
|
providerMap[r.ProviderId] = p.DisplayName |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// Write CSV response
|
||||
|
filename := fmt.Sprintf("ai-usage-%s.csv", time.Now().Format("20060102")) |
||||
|
w.Header().Set("Content-Type", "text/csv; charset=utf-8") |
||||
|
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", filename)) |
||||
|
// BOM for Excel UTF-8 support
|
||||
|
w.Write([]byte{0xEF, 0xBB, 0xBF}) |
||||
|
|
||||
|
writer := csv.NewWriter(w) |
||||
|
defer writer.Flush() |
||||
|
|
||||
|
// Header row
|
||||
|
writer.Write([]string{"时间", "用户", "平台", "模型", "输入Tokens", "输出Tokens", "费用(¥)", "延迟(ms)", "状态", "错误信息"}) |
||||
|
|
||||
|
for _, r := range records { |
||||
|
writer.Write([]string{ |
||||
|
r.CreatedAt.Format("2006-01-02 15:04:05"), |
||||
|
userMap[r.UserId], |
||||
|
providerMap[r.ProviderId], |
||||
|
r.ModelId, |
||||
|
fmt.Sprintf("%d", r.InputTokens), |
||||
|
fmt.Sprintf("%d", r.OutputTokens), |
||||
|
fmt.Sprintf("%.4f", r.Cost), |
||||
|
fmt.Sprintf("%d", r.LatencyMs), |
||||
|
r.Status, |
||||
|
r.ErrorMessage, |
||||
|
}) |
||||
|
} |
||||
|
|
||||
|
return nil |
||||
|
} |
||||
Loading…
Reference in new issue