diff --git a/backend/internal/ai/provider/factory.go b/backend/internal/ai/provider/factory.go
index bae43f5..f7b85c1 100644
--- a/backend/internal/ai/provider/factory.go
+++ b/backend/internal/ai/provider/factory.go
@@ -6,12 +6,15 @@ import "fmt"
 // Supported sdkType values:
 //   - "openai_compat": OpenAI-compatible APIs (OpenAI, Qwen, Zhipu, DeepSeek, etc.)
 //   - "anthropic": Anthropic Claude models
+//   - "wenxin": Baidu Wenxin (ERNIE) models
 func NewProvider(sdkType, baseUrl, apiKey string) (AIProvider, error) {
 	switch sdkType {
 	case "openai_compat":
 		return NewOpenAIProvider(baseUrl, apiKey), nil
 	case "anthropic":
 		return NewAnthropicProvider(baseUrl, apiKey), nil
+	case "wenxin":
+		return NewWenxinProvider(baseUrl, apiKey), nil
 	default:
 		return nil, fmt.Errorf("unsupported sdk_type: %s", sdkType)
 	}
diff --git a/backend/internal/ai/provider/wenxin.go b/backend/internal/ai/provider/wenxin.go
new file mode 100644
index 0000000..5be38be
--- /dev/null
+++ b/backend/internal/ai/provider/wenxin.go
@@ -0,0 +1,279 @@
+package provider
+
+import (
+	"bufio"
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"strings"
+)
+
+// WenxinProvider implements AIProvider for Baidu Wenxin (ERNIE) models.
+// Wenxin uses a unique API format that requires custom HTTP implementation.
+type WenxinProvider struct {
+	baseUrl string
+	apiKey  string
+}
+
+// NewWenxinProvider creates a Wenxin provider; baseUrl falls back to the public endpoint when empty.
+func NewWenxinProvider(baseUrl, apiKey string) *WenxinProvider {
+	if baseUrl == "" {
+		baseUrl = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat"
+	}
+	return &WenxinProvider{
+		baseUrl: baseUrl,
+		apiKey:  apiKey,
+	}
+}
+
+// Name returns the provider identifier used in configuration.
+func (p *WenxinProvider) Name() string {
+	return "wenxin"
+}
+
+// wenxinMessage is Wenxin's message format
+type wenxinMessage struct {
+	Role    string `json:"role"`
+	Content string `json:"content"`
+}
+
+// wenxinRequest is the request body for Wenxin API
+type wenxinRequest struct {
+	Messages        []wenxinMessage `json:"messages"`
+	System          string          `json:"system,omitempty"`
+	Stream          bool            `json:"stream"`
+	Temperature     float64         `json:"temperature,omitempty"`
+	MaxOutputTokens int             `json:"max_output_tokens,omitempty"`
+}
+
+// wenxinResponse is the non-streaming response from Wenxin API
+type wenxinResponse struct {
+	Id               string `json:"id"`
+	Result           string `json:"result"`
+	IsTruncated      bool   `json:"is_truncated"`
+	NeedClearHistory bool   `json:"need_clear_history"`
+	Usage            struct {
+		PromptTokens     int `json:"prompt_tokens"`
+		CompletionTokens int `json:"completion_tokens"`
+		TotalTokens      int `json:"total_tokens"`
+	} `json:"usage"`
+	ErrorCode int    `json:"error_code"`
+	ErrorMsg  string `json:"error_msg"`
+}
+
+// wenxinStreamResponse is a single SSE chunk from Wenxin API
+type wenxinStreamResponse struct {
+	Id               string `json:"id"`
+	Result           string `json:"result"`
+	IsEnd            bool   `json:"is_end"`
+	IsTruncated      bool   `json:"is_truncated"`
+	NeedClearHistory bool   `json:"need_clear_history"`
+	Usage            struct {
+		PromptTokens     int `json:"prompt_tokens"`
+		CompletionTokens int `json:"completion_tokens"`
+		TotalTokens      int `json:"total_tokens"`
+	} `json:"usage"`
+	ErrorCode int    `json:"error_code"`
+	ErrorMsg  string `json:"error_msg"`
+}
+
+// Chat performs a blocking (non-streaming) completion call; cancellation is governed by ctx.
+func (p *WenxinProvider) Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error) {
+	wenxinReq := p.buildRequest(req, false)
+
+	body, err := json.Marshal(wenxinReq)
+	if err != nil {
+		return nil, fmt.Errorf("wenxin marshal request: %w", err)
+	}
+
+	url := p.buildURL(req.Model)
+	httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
+	if err != nil {
+		return nil, fmt.Errorf("wenxin create request: %w", err)
+	}
+	httpReq.Header.Set("Content-Type", "application/json")
+	httpReq.Header.Set("Authorization", "Bearer "+p.apiKey)
+
+	resp, err := http.DefaultClient.Do(httpReq)
+	if err != nil {
+		return nil, fmt.Errorf("wenxin http request: %w", err)
+	}
+	defer resp.Body.Close()
+
+	respBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, fmt.Errorf("wenxin read response: %w", err)
+	}
+
+	// Non-200 responses may carry a non-JSON body (gateway/HTML errors);
+	// surface them directly instead of a confusing unmarshal failure.
+	if resp.StatusCode != http.StatusOK {
+		return nil, fmt.Errorf("wenxin returned status %d: %s", resp.StatusCode, string(respBody))
+	}
+
+	var wenxinResp wenxinResponse
+	if err := json.Unmarshal(respBody, &wenxinResp); err != nil {
+		return nil, fmt.Errorf("wenxin unmarshal response: %w", err)
+	}
+
+	if wenxinResp.ErrorCode != 0 {
+		return nil, fmt.Errorf("wenxin error %d: %s", wenxinResp.ErrorCode, wenxinResp.ErrorMsg)
+	}
+
+	return &ChatResponse{
+		Content:      wenxinResp.Result,
+		Model:        req.Model,
+		InputTokens:  wenxinResp.Usage.PromptTokens,
+		OutputTokens: wenxinResp.Usage.CompletionTokens,
+		FinishReason: "stop",
+	}, nil
+}
+
+// ChatStream performs a streaming completion; chunks arrive on the returned
+// channel, which is closed after a final Done chunk (or on ctx cancellation).
+func (p *WenxinProvider) ChatStream(ctx context.Context, req *ChatRequest) (<-chan *StreamChunk, error) {
+	wenxinReq := p.buildRequest(req, true)
+
+	body, err := json.Marshal(wenxinReq)
+	if err != nil {
+		return nil, fmt.Errorf("wenxin marshal request: %w", err)
+	}
+
+	url := p.buildURL(req.Model)
+	httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
+	if err != nil {
+		return nil, fmt.Errorf("wenxin create request: %w", err)
+	}
+	httpReq.Header.Set("Content-Type", "application/json")
+	httpReq.Header.Set("Authorization", "Bearer "+p.apiKey)
+
+	resp, err := http.DefaultClient.Do(httpReq)
+	if err != nil {
+		return nil, fmt.Errorf("wenxin http request: %w", err)
+	}
+
+	if resp.StatusCode != http.StatusOK {
+		resp.Body.Close()
+		return nil, fmt.Errorf("wenxin stream returned status %d", resp.StatusCode)
+	}
+
+	ch := make(chan *StreamChunk, 64)
+
+	go func() {
+		defer close(ch)
+		defer resp.Body.Close()
+
+		scanner := bufio.NewScanner(resp.Body)
+		// Raise the 64KB default token limit; long SSE data lines would
+		// otherwise abort the scan with ErrTooLong.
+		scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
+		for scanner.Scan() {
+			line := scanner.Text()
+			if !strings.HasPrefix(line, "data: ") {
+				continue
+			}
+			data := strings.TrimPrefix(line, "data: ")
+			if data == "" {
+				continue
+			}
+
+			var streamResp wenxinStreamResponse
+			if err := json.Unmarshal([]byte(data), &streamResp); err != nil {
+				continue
+			}
+
+			if streamResp.ErrorCode != 0 {
+				select {
+				case ch <- &StreamChunk{
+					Content:      fmt.Sprintf("[wenxin error: %s]", streamResp.ErrorMsg),
+					Done:         true,
+					FinishReason: "error",
+				}:
+				case <-ctx.Done():
+				}
+				return
+			}
+
+			chunk := &StreamChunk{
+				Content: streamResp.Result,
+			}
+
+			if streamResp.IsEnd {
+				chunk.Done = true
+				chunk.FinishReason = "stop"
+				chunk.InputTokens = streamResp.Usage.PromptTokens
+				chunk.OutputTokens = streamResp.Usage.CompletionTokens
+			}
+
+			select {
+			case ch <- chunk:
+			case <-ctx.Done():
+				return
+			}
+
+			if streamResp.IsEnd {
+				return
+			}
+		}
+
+		// Surface transport errors (e.g. connection reset mid-stream)
+		// instead of ending the stream as if it completed normally.
+		if err := scanner.Err(); err != nil {
+			select {
+			case ch <- &StreamChunk{Content: fmt.Sprintf("[wenxin stream error: %v]", err), Done: true, FinishReason: "error"}:
+			case <-ctx.Done():
+			}
+			return
+		}
+
+		// Ensure a final Done chunk
+		select {
+		case ch <- &StreamChunk{Done: true}:
+		case <-ctx.Done():
+		}
+	}()
+
+	return ch, nil
+}
+
+// buildRequest converts the generic ChatRequest into Wenxin's request shape.
+func (p *WenxinProvider) buildRequest(req *ChatRequest, stream bool) *wenxinRequest {
+	// Wenxin rejects the "system" role inside messages; collect system
+	// prompts into the request-level "system" field instead of dropping them.
+	var system strings.Builder
+	messages := make([]wenxinMessage, 0, len(req.Messages))
+	for _, m := range req.Messages {
+		if m.Role == "system" {
+			if system.Len() > 0 {
+				system.WriteString("\n")
+			}
+			system.WriteString(m.Content)
+			continue
+		}
+		messages = append(messages, wenxinMessage{
+			Role:    m.Role,
+			Content: m.Content,
+		})
+	}
+
+	wenxinReq := &wenxinRequest{
+		Messages: messages,
+		Stream:   stream,
+		System:   system.String(),
+	}
+	if req.Temperature > 0 {
+		wenxinReq.Temperature = req.Temperature
+	}
+	if req.MaxTokens > 0 {
+		wenxinReq.MaxOutputTokens = req.MaxTokens
+	}
+	return wenxinReq
+}
+
+func (p *WenxinProvider) buildURL(model string) string {
+	// Wenxin uses model-specific endpoints
+	// The baseUrl should include the full path for the model
+	// e.g., https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-8k
+	if strings.Contains(p.baseUrl, model) {
+		return p.baseUrl
+	}
+	return strings.TrimRight(p.baseUrl, "/") + "/" + model
+}
diff --git a/backend/internal/handler/ai/aiusageexporthandler.go b/backend/internal/handler/ai/aiusageexporthandler.go
new file mode 100644
index 0000000..c61ffdf
--- /dev/null
+++ b/backend/internal/handler/ai/aiusageexporthandler.go
@@ -0,0 +1,26 @@
+package ai
+
+import (
+	"net/http"
+
+	"github.com/youruser/base/internal/logic/ai"
+	"github.com/youruser/base/internal/svc"
+	"github.com/youruser/base/internal/types"
+	"github.com/zeromicro/go-zero/rest/httpx"
+)
+
+// AiUsageExportHandler exports usage records as a CSV download.
+func AiUsageExportHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
+	return func(w http.ResponseWriter, r *http.Request) {
+		var req types.AIUsageRecordListRequest
+		if err := httpx.Parse(r, &req); err != nil {
+			httpx.ErrorCtx(r.Context(), w, err)
+			return
+		}
+
+		l := ai.NewAiUsageExportLogic(r.Context(), svcCtx)
+		if err := l.AiUsageExport(&req, w); err != nil {
+			httpx.ErrorCtx(r.Context(), w, err)
+		}
+	}
+}
diff --git a/backend/internal/logic/ai/aiusageexportlogic.go b/backend/internal/logic/ai/aiusageexportlogic.go
new file mode 100644
index 0000000..77f52d7
--- /dev/null
+++ b/backend/internal/logic/ai/aiusageexportlogic.go
@@ -0,0 +1,100 @@
+package ai
+
+import (
+	"context"
+	"encoding/csv"
+	"fmt"
+	"net/http"
+	"time"
+
+	"github.com/youruser/base/internal/svc"
+	"github.com/youruser/base/internal/types"
+	"github.com/youruser/base/model"
+
+	"github.com/zeromicro/go-zero/core/logx"
+)
+
+type AiUsageExportLogic struct {
+	logx.Logger
+	ctx    context.Context
+	svcCtx *svc.ServiceContext
+}
+
+func NewAiUsageExportLogic(ctx context.Context, svcCtx *svc.ServiceContext) *AiUsageExportLogic {
+	return &AiUsageExportLogic{
+		Logger: logx.WithContext(ctx),
+		ctx:    ctx,
+		svcCtx: svcCtx,
+	}
+}
+
+// AiUsageExport streams the user's AI usage records to w as a CSV attachment.
+func (l *AiUsageExportLogic) AiUsageExport(req *types.AIUsageRecordListRequest, w http.ResponseWriter) error {
+	// Get current user from context.
+	// NOTE(security): req.UserId is client-supplied; confirm the route or
+	// middleware restricts non-admin callers before honoring it, otherwise
+	// any user can export another user's records by passing user_id.
+	userId := req.UserId
+	if userId == 0 {
+		if uid, ok := l.ctx.Value("userId").(int64); ok {
+			userId = uid
+		} else if uidJson, ok2 := l.ctx.Value("userId").(float64); ok2 {
+			userId = int64(uidJson)
+		}
+	}
+
+	// Fetch all records (large page size for export)
+	records, _, err := model.AIUsageRecordFindList(l.ctx, l.svcCtx.DB, userId, req.ModelId, req.Status, 1, 10000)
+	if err != nil {
+		return err
+	}
+
+	// Build lookup caches
+	userMap := make(map[int64]string)
+	providerMap := make(map[int64]string)
+	for _, r := range records {
+		if _, ok := userMap[r.UserId]; !ok {
+			user, err := model.FindOne(l.ctx, l.svcCtx.DB, r.UserId)
+			if err == nil {
+				userMap[r.UserId] = user.Username
+			}
+		}
+		if _, ok := providerMap[r.ProviderId]; !ok {
+			p, err := model.AIProviderFindOne(l.ctx, l.svcCtx.DB, r.ProviderId)
+			if err == nil {
+				providerMap[r.ProviderId] = p.DisplayName
+			}
+		}
+	}
+
+	// Write CSV response
+	filename := fmt.Sprintf("ai-usage-%s.csv", time.Now().Format("20060102"))
+	w.Header().Set("Content-Type", "text/csv; charset=utf-8")
+	w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", filename))
+	// BOM for Excel UTF-8 support
+	w.Write([]byte{0xEF, 0xBB, 0xBF})
+
+	writer := csv.NewWriter(w)
+
+	// Header row
+	writer.Write([]string{"时间", "用户", "平台", "模型", "输入Tokens", "输出Tokens", "费用(¥)", "延迟(ms)", "状态", "错误信息"})
+
+	for _, r := range records {
+		writer.Write([]string{
+			r.CreatedAt.Format("2006-01-02 15:04:05"),
+			userMap[r.UserId],
+			providerMap[r.ProviderId],
+			r.ModelId,
+			fmt.Sprintf("%d", r.InputTokens),
+			fmt.Sprintf("%d", r.OutputTokens),
+			fmt.Sprintf("%.4f", r.Cost),
+			fmt.Sprintf("%d", r.LatencyMs),
+			r.Status,
+			r.ErrorMessage,
+		})
+	}
+
+	// Flush and surface any buffered-write error instead of silently
+	// returning success on a truncated export.
+	writer.Flush()
+	return writer.Error()
+}
diff --git a/backend/internal/svc/servicecontext.go b/backend/internal/svc/servicecontext.go
index bbe5573..998671a 100644
--- a/backend/internal/svc/servicecontext.go
+++ b/backend/internal/svc/servicecontext.go
@@ -282,8 +282,9 @@ func seedCasbinPolicies(enforcer *casbin.Enforcer) {
 		{"user", "/api/v1/ai/key/:id", "PUT"},
 		{"user", "/api/v1/ai/key/:id", "DELETE"},
 
-		// AI: user usage records
+		// AI: user usage records & export
 		{"user", "/api/v1/ai/quota/records", "GET"},
+		{"user", "/api/v1/ai/usage/export", "GET"},
 
 		// AI: admin provider/model management
 		{"admin", "/api/v1/ai/providers", "GET"},
@@ -418,6 +419,7 @@ func seedAIProviders(db *gorm.DB) {
 		{Name: "qwen", DisplayName: "阿里千问", BaseUrl: "https://dashscope.aliyuncs.com/compatible-mode/v1", SdkType: "openai_compat", Protocol: "openai", IsActive: true, SortOrder: 3},
 		{Name: "zhipu", DisplayName: "智谱 GLM", BaseUrl: "https://open.bigmodel.cn/api/paas/v4", SdkType: "openai_compat", Protocol: "openai", IsActive: true, SortOrder: 4},
 		{Name: "deepseek", DisplayName: "DeepSeek", BaseUrl: "https://api.deepseek.com/v1", SdkType: "openai_compat", Protocol: "openai", IsActive: true, SortOrder: 5},
+		{Name: "wenxin", DisplayName: "百度文心", BaseUrl: "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat", SdkType: "wenxin", Protocol: "wenxin", IsActive: true, SortOrder: 6},
 	}
 	for _, p := range providers {
 		var existing model.AIProvider
@@ -448,6 +450,8 @@ func seedAIModels(db *gorm.DB) {
 		{ProviderId: providerIds["zhipu"], ModelId: "glm-4-flash", DisplayName: "GLM-4 Flash", InputPrice: 0.0001, OutputPrice: 0.0001, MaxTokens: 4096, ContextWindow: 128000, SupportsStream: true, SupportsVision: false, IsActive: true, SortOrder: 7},
 		{ProviderId: providerIds["deepseek"], ModelId: "deepseek-chat", DisplayName: "DeepSeek Chat", InputPrice: 0.00014, OutputPrice: 0.00028, MaxTokens: 8192, ContextWindow: 64000, SupportsStream: true, SupportsVision: false, IsActive: true, SortOrder: 8},
 		{ProviderId: providerIds["deepseek"], ModelId: "deepseek-reasoner", DisplayName: "DeepSeek Reasoner", InputPrice: 0.00055, OutputPrice: 0.00219, MaxTokens: 8192, ContextWindow: 64000, SupportsStream: true, SupportsVision: false, IsActive: true, SortOrder: 9},
+		{ProviderId: providerIds["wenxin"], ModelId: "ernie-4.0-8k", DisplayName: "文心 ERNIE 4.0", InputPrice: 0.03, OutputPrice: 0.09, MaxTokens: 8192, ContextWindow: 8192, SupportsStream: true, SupportsVision: false, IsActive: true, SortOrder: 10},
+		{ProviderId: providerIds["wenxin"], ModelId: "ernie-3.5-8k", DisplayName: "文心 ERNIE 3.5", InputPrice: 0.0008, OutputPrice: 0.002, MaxTokens: 8192, ContextWindow: 8192, SupportsStream: true, SupportsVision: false, IsActive: true, SortOrder: 11},
 	}
 	for _, m := range models {
 		var existing model.AIModel
diff --git a/frontend/react-shadcn/pc/src/pages/AIUsagePage.tsx b/frontend/react-shadcn/pc/src/pages/AIUsagePage.tsx
index 0e39805..87f4337 100644
--- a/frontend/react-shadcn/pc/src/pages/AIUsagePage.tsx
+++ b/frontend/react-shadcn/pc/src/pages/AIUsagePage.tsx
@@ -1,5 +1,5 @@
 import { useState, useEffect, useCallback } from 'react'
-import { BarChart3, Activity, Coins, Users } from 'lucide-react'
+import { BarChart3, Activity, Coins, Users, Download } from 'lucide-react'
 import { Button } from '@/components/ui/Button'
 import { Card, CardHeader, CardTitle, CardContent } from '@/components/ui/Card'
 import {
@@ -242,8 +242,21 @@ export function AIUsagePage() {
       {/* Usage Records Table */}
-
+
       使用记录 ({total})
+