From f1d154959575454f3a1fdebf463e1ed8438c3ba0 Mon Sep 17 00:00:00 2001 From: dark Date: Sat, 14 Feb 2026 22:15:09 +0800 Subject: [PATCH] feat: implement AI chat completions with SSE streaming - Custom SSE handler (replaces goctl stub) - Chat() for non-streaming, ChatStream() for SSE - Provider selection (user key > system key) - Quota freeze/settle/unfreeze billing flow - Usage recording + conversation message saving --- .../handler/ai/aichatcompletionshandler.go | 48 ++- .../logic/ai/aichatcompletionslogic.go | 327 +++++++++++++++++- 2 files changed, 359 insertions(+), 16 deletions(-) diff --git a/backend/internal/handler/ai/aichatcompletionshandler.go b/backend/internal/handler/ai/aichatcompletionshandler.go index 0c1a27b..d3e44fa 100644 --- a/backend/internal/handler/ai/aichatcompletionshandler.go +++ b/backend/internal/handler/ai/aichatcompletionshandler.go @@ -1,9 +1,8 @@ -// Code scaffolded by goctl. Safe to edit. -// goctl 1.9.2 - package ai import ( + "encoding/json" + "fmt" "net/http" "github.com/youruser/base/internal/logic/ai" @@ -12,21 +11,52 @@ import ( "github.com/zeromicro/go-zero/rest/httpx" ) -// AI 对话补全 func AiChatCompletionsHandler(svcCtx *svc.ServiceContext) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var req types.AIChatCompletionRequest - if err := httpx.Parse(r, &req); err != nil { + if err := httpx.ParseJsonBody(r, &req); err != nil { httpx.ErrorCtx(r.Context(), w, err) return } l := ai.NewAiChatCompletionsLogic(r.Context(), svcCtx) - err := l.AiChatCompletions(&req) - if err != nil { - httpx.ErrorCtx(r.Context(), w, err) + + if req.Stream { + // SSE streaming mode + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("Access-Control-Allow-Origin", "*") + + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "streaming not supported", http.StatusInternalServerError) + return + } + + streamChan, err := l.ChatStream(&req) + if err != nil { + errData, _ := json.Marshal(map[string]string{"error": err.Error()}) + fmt.Fprintf(w, "data: %s\n\n", errData) + flusher.Flush() + return + } + + for chunk := range streamChan { + data, _ := json.Marshal(chunk) + fmt.Fprintf(w, "data: %s\n\n", data) + flusher.Flush() + } + fmt.Fprintf(w, "data: [DONE]\n\n") + flusher.Flush() } else { - httpx.Ok(w) + // Normal (non-streaming) mode + resp, err := l.Chat(&req) + if err != nil { + httpx.ErrorCtx(r.Context(), w, err) + } else { + httpx.OkJsonCtx(r.Context(), w, resp) + } } } } diff --git a/backend/internal/logic/ai/aichatcompletionslogic.go b/backend/internal/logic/ai/aichatcompletionslogic.go index be12fba..b1c29ab 100644 --- a/backend/internal/logic/ai/aichatcompletionslogic.go +++ b/backend/internal/logic/ai/aichatcompletionslogic.go @@ -1,13 +1,16 @@ -// Code scaffolded by goctl. Safe to edit. -// goctl 1.9.2 - package ai import ( "context" + "errors" + "fmt" + "time" + "github.com/youruser/base/internal/ai/billing" + "github.com/youruser/base/internal/ai/provider" "github.com/youruser/base/internal/svc" "github.com/youruser/base/internal/types" + "github.com/youruser/base/model" "github.com/zeromicro/go-zero/core/logx" ) @@ -18,7 +21,6 @@ type AiChatCompletionsLogic struct { svcCtx *svc.ServiceContext } -// AI 对话补全 func NewAiChatCompletionsLogic(ctx context.Context, svcCtx *svc.ServiceContext) *AiChatCompletionsLogic { return &AiChatCompletionsLogic{ Logger: logx.WithContext(ctx), @@ -27,8 +29,319 @@ func NewAiChatCompletionsLogic(ctx context.Context, svcCtx *svc.ServiceContext) } } -func (l *AiChatCompletionsLogic) AiChatCompletions(req *types.AIChatCompletionRequest) error { - // todo: add your logic here and delete this line +// Chat handles non-streaming chat completions +func (l *AiChatCompletionsLogic) Chat(req *types.AIChatCompletionRequest) (*types.AIChatCompletionResponse, error) { + startTime := time.Now() + + userId, _ := l.ctx.Value("userId").(int64) + if userId == 0 { + return nil, errors.New("unauthorized") + } + + // 1. Look up model + aiModel, err := model.AIModelFindByModelId(l.ctx, l.svcCtx.DB, req.Model) + if err != nil { + return nil, fmt.Errorf("model not found: %s", req.Model) + } + + // 2. Look up provider + aiProvider, err := model.AIProviderFindOne(l.ctx, l.svcCtx.DB, aiModel.ProviderId) + if err != nil { + return nil, fmt.Errorf("provider not found for model: %s", req.Model) + } + + // 3. Select API key + apiKey, apiKeyId, err := l.selectApiKey(aiProvider.Id, userId) + if err != nil { + return nil, err + } + + // 4. Estimate cost and freeze + estimatedCost := l.estimateCost(req, aiModel) + quotaSvc := billing.NewQuotaService() + if err := quotaSvc.CheckAndFreeze(l.ctx, l.svcCtx.DB, userId, estimatedCost, apiKeyId); err != nil { + return nil, fmt.Errorf("insufficient balance: %v", err) + } + + // 5. Build provider and call + p, err := provider.NewProvider(aiProvider.SdkType, aiProvider.BaseUrl, apiKey) + if err != nil { + quotaSvc.Unfreeze(l.ctx, l.svcCtx.DB, userId, estimatedCost, apiKeyId) + return nil, err + } + + // Convert messages + chatReq := l.buildChatRequest(req, aiModel) + chatResp, err := p.Chat(l.ctx, chatReq) + if err != nil { + quotaSvc.Unfreeze(l.ctx, l.svcCtx.DB, userId, estimatedCost, apiKeyId) + return nil, fmt.Errorf("AI request failed: %v", err) + } + + latencyMs := int(time.Since(startTime).Milliseconds()) + + // 6. Calculate actual cost and settle + actualCost := l.calculateCost(chatResp.InputTokens, chatResp.OutputTokens, aiModel) + quotaSvc.Settle(l.ctx, l.svcCtx.DB, userId, estimatedCost, actualCost, apiKeyId) + + // 7. Record usage + usageSvc := billing.NewUsageService() + usageSvc.Record(l.ctx, l.svcCtx.DB, &model.AIUsageRecord{ + UserId: userId, + ProviderId: aiProvider.Id, + ModelId: req.Model, + InputTokens: chatResp.InputTokens, + OutputTokens: chatResp.OutputTokens, + Cost: actualCost, + ApiKeyId: apiKeyId, + Status: "ok", + LatencyMs: latencyMs, + }) + + // 8. Save messages to conversation if conversation_id provided + if req.ConversationId > 0 { + l.saveMessages(req, chatResp, aiModel, latencyMs) + } + + // 9. Build response + return &types.AIChatCompletionResponse{ + Id: fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()), + Object: "chat.completion", + Model: req.Model, + Choices: []types.AIChatCompletionChoice{ + { + Index: 0, + FinishReason: chatResp.FinishReason, + Message: types.AIChatMessage{ + Role: "assistant", + Content: chatResp.Content, + }, + }, + }, + Usage: types.AIChatCompletionUsage{ + PromptTokens: chatResp.InputTokens, + CompletionTokens: chatResp.OutputTokens, + TotalTokens: chatResp.InputTokens + chatResp.OutputTokens, + }, + }, nil +} + +// ChatStream handles streaming chat completions +func (l *AiChatCompletionsLogic) ChatStream(req *types.AIChatCompletionRequest) (<-chan *provider.StreamChunk, error) { + userId, _ := l.ctx.Value("userId").(int64) + if userId == 0 { + return nil, errors.New("unauthorized") + } + + // 1. Look up model + aiModel, err := model.AIModelFindByModelId(l.ctx, l.svcCtx.DB, req.Model) + if err != nil { + return nil, fmt.Errorf("model not found: %s", req.Model) + } + + // 2. Look up provider + aiProvider, err := model.AIProviderFindOne(l.ctx, l.svcCtx.DB, aiModel.ProviderId) + if err != nil { + return nil, fmt.Errorf("provider not found for model: %s", req.Model) + } + + // 3. Select API key + apiKey, apiKeyId, err := l.selectApiKey(aiProvider.Id, userId) + if err != nil { + return nil, err + } + + // 4. Estimate and freeze + estimatedCost := l.estimateCost(req, aiModel) + quotaSvc := billing.NewQuotaService() + if err := quotaSvc.CheckAndFreeze(l.ctx, l.svcCtx.DB, userId, estimatedCost, apiKeyId); err != nil { + return nil, fmt.Errorf("insufficient balance: %v", err) + } + + // 5. Build provider + p, err := provider.NewProvider(aiProvider.SdkType, aiProvider.BaseUrl, apiKey) + if err != nil { + quotaSvc.Unfreeze(l.ctx, l.svcCtx.DB, userId, estimatedCost, apiKeyId) + return nil, err + } + + // 6. Start stream + chatReq := l.buildChatRequest(req, aiModel) + startTime := time.Now() + streamChan, err := p.ChatStream(l.ctx, chatReq) + if err != nil { + quotaSvc.Unfreeze(l.ctx, l.svcCtx.DB, userId, estimatedCost, apiKeyId) + return nil, fmt.Errorf("AI stream failed: %v", err) + } + + // 7. Wrap stream channel — accumulate content for post-processing + outChan := make(chan *provider.StreamChunk, 100) + go func() { + defer close(outChan) + + var fullContent string + var inputTokens, outputTokens int + + for chunk := range streamChan { + fullContent += chunk.Content + if chunk.InputTokens > 0 { + inputTokens = chunk.InputTokens + } + if chunk.OutputTokens > 0 { + outputTokens = chunk.OutputTokens + } + outChan <- chunk + } + + latencyMs := int(time.Since(startTime).Milliseconds()) + + // Post-stream: calculate cost, settle billing, record usage + actualCost := l.calculateCost(inputTokens, outputTokens, aiModel) + bgCtx := context.Background() + quotaSvc.Settle(bgCtx, l.svcCtx.DB, userId, estimatedCost, actualCost, apiKeyId) + + usageSvc := billing.NewUsageService() + usageSvc.Record(bgCtx, l.svcCtx.DB, &model.AIUsageRecord{ + UserId: userId, + ProviderId: aiProvider.Id, + ModelId: req.Model, + InputTokens: inputTokens, + OutputTokens: outputTokens, + Cost: actualCost, + ApiKeyId: apiKeyId, + Status: "ok", + LatencyMs: latencyMs, + }) + + // Save messages + if req.ConversationId > 0 { + totalTokens := int64(inputTokens + outputTokens) + usageSvc.UpdateConversationStats(bgCtx, l.svcCtx.DB, req.ConversationId, totalTokens, actualCost) + + // Save user message (last one in req.Messages) + if len(req.Messages) > 0 { + lastMsg := req.Messages[len(req.Messages)-1] + model.AIChatMessageInsert(bgCtx, l.svcCtx.DB, &model.AIChatMessage{ + ConversationId: req.ConversationId, + Role: lastMsg.Role, + Content: lastMsg.Content, + ModelId: req.Model, + }) + } + // Save assistant message + model.AIChatMessageInsert(bgCtx, l.svcCtx.DB, &model.AIChatMessage{ + ConversationId: req.ConversationId, + Role: "assistant", + Content: fullContent, + TokenCount: outputTokens, + Cost: actualCost, + ModelId: req.Model, + LatencyMs: latencyMs, + }) + } + }() + + return outChan, nil +} + +// selectApiKey selects the best API key: user's own key first, then system key +func (l *AiChatCompletionsLogic) selectApiKey(providerId, userId int64) (string, int64, error) { + // Try user's own key first + userKeys, err := model.AIApiKeyFindByProviderAndUser(l.ctx, l.svcCtx.DB, providerId, userId) + if err == nil && len(userKeys) > 0 { + for _, key := range userKeys { + if key.IsActive { + return key.KeyValue, key.Id, nil + } + } + } + + // Fall back to system key (userId=0) + systemKeys, err := model.AIApiKeyFindSystemKeys(l.ctx, l.svcCtx.DB, providerId) + if err != nil || len(systemKeys) == 0 { + return "", 0, fmt.Errorf("no API key available for this provider") + } + + return systemKeys[0].KeyValue, 0, nil +} + +// estimateCost estimates cost based on input message length +func (l *AiChatCompletionsLogic) estimateCost(req *types.AIChatCompletionRequest, aiModel *model.AIModel) float64 { + // Rough estimation: ~4 chars per token for input, estimate 1000 output tokens + totalChars := 0 + for _, msg := range req.Messages { + totalChars += len(msg.Content) + } + estimatedInputTokens := totalChars / 4 + estimatedOutputTokens := 1000 + if req.MaxTokens > 0 { + estimatedOutputTokens = req.MaxTokens + } + + cost := float64(estimatedInputTokens)/1000.0*aiModel.InputPrice + + float64(estimatedOutputTokens)/1000.0*aiModel.OutputPrice + return cost +} + +// calculateCost computes actual cost from real token counts +func (l *AiChatCompletionsLogic) calculateCost(inputTokens, outputTokens int, aiModel *model.AIModel) float64 { + return float64(inputTokens)/1000.0*aiModel.InputPrice + + float64(outputTokens)/1000.0*aiModel.OutputPrice +} + +// buildChatRequest converts types request to provider request +func (l *AiChatCompletionsLogic) buildChatRequest(req *types.AIChatCompletionRequest, aiModel *model.AIModel) *provider.ChatRequest { + messages := make([]provider.ChatMessage, len(req.Messages)) + for i, m := range req.Messages { + messages[i] = provider.ChatMessage{Role: m.Role, Content: m.Content} + } + + maxTokens := req.MaxTokens + if maxTokens <= 0 { + maxTokens = aiModel.MaxTokens + } + + temperature := req.Temperature + if temperature <= 0 { + temperature = 0.7 + } + + return &provider.ChatRequest{ + Model: req.Model, + Messages: messages, + MaxTokens: maxTokens, + Temperature: temperature, + Stream: req.Stream, + } +} + +// saveMessages saves user and assistant messages to conversation +func (l *AiChatCompletionsLogic) saveMessages(req *types.AIChatCompletionRequest, resp *provider.ChatResponse, aiModel *model.AIModel, latencyMs int) { + actualCost := l.calculateCost(resp.InputTokens, resp.OutputTokens, aiModel) + + // Save user message + if len(req.Messages) > 0 { + lastMsg := req.Messages[len(req.Messages)-1] + model.AIChatMessageInsert(l.ctx, l.svcCtx.DB, &model.AIChatMessage{ + ConversationId: req.ConversationId, + Role: lastMsg.Role, + Content: lastMsg.Content, + ModelId: req.Model, + }) + } + + // Save assistant message + model.AIChatMessageInsert(l.ctx, l.svcCtx.DB, &model.AIChatMessage{ + ConversationId: req.ConversationId, + Role: "assistant", + Content: resp.Content, + TokenCount: resp.OutputTokens, + Cost: actualCost, + ModelId: req.Model, + LatencyMs: latencyMs, + }) - return nil + // Update conversation stats + totalTokens := int64(resp.InputTokens + resp.OutputTokens) + billing.NewUsageService().UpdateConversationStats(l.ctx, l.svcCtx.DB, req.ConversationId, totalTokens, actualCost) }