package ai

import (
	"context"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/zeromicro/go-zero/core/logx"

	"github.com/youruser/base/internal/ai/billing"
	"github.com/youruser/base/internal/ai/provider"
	"github.com/youruser/base/internal/svc"
	"github.com/youruser/base/internal/types"
	"github.com/youruser/base/model"
)

// AiChatCompletionsLogic handles chat-completion requests, both blocking and
// streaming: model/provider lookup, API-key selection, quota freeze/settle,
// usage recording, and optional conversation persistence.
type AiChatCompletionsLogic struct {
	logx.Logger
	ctx    context.Context
	svcCtx *svc.ServiceContext
}

// NewAiChatCompletionsLogic builds a logic instance bound to the request
// context and the shared service context.
func NewAiChatCompletionsLogic(ctx context.Context, svcCtx *svc.ServiceContext) *AiChatCompletionsLogic {
	return &AiChatCompletionsLogic{
		Logger: logx.WithContext(ctx),
		ctx:    ctx,
		svcCtx: svcCtx,
	}
}

// Chat handles a non-streaming chat completion. It resolves the model,
// provider, and API key, freezes an estimated quota, performs the provider
// call, settles billing against actual token usage, records a usage row, and
// (when a conversation id is given) persists the exchange. On provider
// failure the frozen quota is released before returning.
func (l *AiChatCompletionsLogic) Chat(req *types.AIChatCompletionRequest) (*types.AIChatCompletionResponse, error) {
	startTime := time.Now()

	// NOTE(review): string context key — assumes auth middleware stores the
	// user id as int64 under "userId"; a typed key would be safer. Confirm
	// against the middleware before changing.
	userId, _ := l.ctx.Value("userId").(int64)
	if userId == 0 {
		return nil, errors.New("unauthorized")
	}

	// 1. Look up model by its external model id.
	aiModel, err := model.AIModelFindByModelId(l.ctx, l.svcCtx.DB, req.Model)
	if err != nil {
		return nil, fmt.Errorf("model not found: %s: %w", req.Model, err)
	}

	// 2. Look up the provider that serves this model.
	aiProvider, err := model.AIProviderFindOne(l.ctx, l.svcCtx.DB, aiModel.ProviderId)
	if err != nil {
		return nil, fmt.Errorf("provider not found for model: %s: %w", req.Model, err)
	}

	// 3. Select an API key (user's own active key preferred, system fallback).
	apiKey, apiKeyId, err := l.selectApiKey(aiProvider.Id, userId)
	if err != nil {
		return nil, err
	}

	// 4. Freeze an estimated cost up front; it is settled (or unfrozen on
	// failure) further down.
	estimatedCost := l.estimateCost(req, aiModel)
	quotaSvc := billing.NewQuotaService()
	if err := quotaSvc.CheckAndFreeze(l.ctx, l.svcCtx.DB, userId, estimatedCost, apiKeyId); err != nil {
		return nil, fmt.Errorf("insufficient balance: %w", err)
	}

	// 5. Build the provider client and perform the call.
	p, err := provider.NewProvider(aiProvider.SdkType, aiProvider.BaseUrl, apiKey)
	if err != nil {
		// Release the frozen quota before bailing out.
		quotaSvc.Unfreeze(l.ctx, l.svcCtx.DB, userId, estimatedCost, apiKeyId)
		return nil, err
	}

	chatReq := l.buildChatRequest(req, aiModel)
	chatResp, err := p.Chat(l.ctx, chatReq)
	if err != nil {
		quotaSvc.Unfreeze(l.ctx, l.svcCtx.DB, userId, estimatedCost, apiKeyId)
		return nil, fmt.Errorf("AI request failed: %w", err)
	}
	latencyMs := int(time.Since(startTime).Milliseconds())

	// 6. Settle: replace the frozen estimate with the actual cost computed
	// from real token counts.
	actualCost := l.calculateCost(chatResp.InputTokens, chatResp.OutputTokens, aiModel)
	quotaSvc.Settle(l.ctx, l.svcCtx.DB, userId, estimatedCost, actualCost, apiKeyId)

	// 7. Record usage for auditing/analytics.
	usageSvc := billing.NewUsageService()
	usageSvc.Record(l.ctx, l.svcCtx.DB, &model.AIUsageRecord{
		UserId:       userId,
		ProviderId:   aiProvider.Id,
		ModelId:      req.Model,
		InputTokens:  chatResp.InputTokens,
		OutputTokens: chatResp.OutputTokens,
		Cost:         actualCost,
		ApiKeyId:     apiKeyId,
		Status:       "ok",
		LatencyMs:    latencyMs,
	})

	// 8. Persist the exchange when it belongs to a conversation.
	if req.ConversationId > 0 {
		l.saveMessages(req, chatResp, aiModel, latencyMs)
	}

	// 9. Build an OpenAI-compatible completion response.
	return &types.AIChatCompletionResponse{
		Id:     fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()),
		Object: "chat.completion",
		Model:  req.Model,
		Choices: []types.AIChatCompletionChoice{
			{
				Index:        0,
				FinishReason: chatResp.FinishReason,
				Message: types.AIChatMessage{
					Role:    "assistant",
					Content: chatResp.Content,
				},
			},
		},
		Usage: types.AIChatCompletionUsage{
			PromptTokens:     chatResp.InputTokens,
			CompletionTokens: chatResp.OutputTokens,
			TotalTokens:      chatResp.InputTokens + chatResp.OutputTokens,
		},
	}, nil
}

// ChatStream handles a streaming chat completion. It returns a channel of
// chunks; after the upstream stream closes, a background goroutine settles
// billing, records usage, and persists the conversation messages.
func (l *AiChatCompletionsLogic) ChatStream(req *types.AIChatCompletionRequest) (<-chan *provider.StreamChunk, error) {
	// NOTE(review): see Chat — same string-keyed context assumption.
	userId, _ := l.ctx.Value("userId").(int64)
	if userId == 0 {
		return nil, errors.New("unauthorized")
	}

	// 1. Look up model.
	aiModel, err := model.AIModelFindByModelId(l.ctx, l.svcCtx.DB, req.Model)
	if err != nil {
		return nil, fmt.Errorf("model not found: %s: %w", req.Model, err)
	}

	// 2. Look up provider.
	aiProvider, err := model.AIProviderFindOne(l.ctx, l.svcCtx.DB, aiModel.ProviderId)
	if err != nil {
		return nil, fmt.Errorf("provider not found for model: %s: %w", req.Model, err)
	}

	// 3. Select API key.
	apiKey, apiKeyId, err := l.selectApiKey(aiProvider.Id, userId)
	if err != nil {
		return nil, err
	}

	// 4. Estimate cost and freeze quota; settled after the stream ends.
	estimatedCost := l.estimateCost(req, aiModel)
	quotaSvc := billing.NewQuotaService()
	if err := quotaSvc.CheckAndFreeze(l.ctx, l.svcCtx.DB, userId, estimatedCost, apiKeyId); err != nil {
		return nil, fmt.Errorf("insufficient balance: %w", err)
	}

	// 5. Build provider client.
	p, err := provider.NewProvider(aiProvider.SdkType, aiProvider.BaseUrl, apiKey)
	if err != nil {
		quotaSvc.Unfreeze(l.ctx, l.svcCtx.DB, userId, estimatedCost, apiKeyId)
		return nil, err
	}

	// 6. Start the upstream stream.
	chatReq := l.buildChatRequest(req, aiModel)
	startTime := time.Now()
	streamChan, err := p.ChatStream(l.ctx, chatReq)
	if err != nil {
		quotaSvc.Unfreeze(l.ctx, l.svcCtx.DB, userId, estimatedCost, apiKeyId)
		return nil, fmt.Errorf("AI stream failed: %w", err)
	}

	// 7. Forward chunks while accumulating content and token counts for
	// post-stream billing.
	outChan := make(chan *provider.StreamChunk, 100)
	go func() {
		defer close(outChan)

		var fullContent strings.Builder
		var inputTokens, outputTokens int
		forwarding := true
		for chunk := range streamChan {
			fullContent.WriteString(chunk.Content)
			if chunk.InputTokens > 0 {
				inputTokens = chunk.InputTokens
			}
			if chunk.OutputTokens > 0 {
				outputTokens = chunk.OutputTokens
			}
			if forwarding {
				// Fix: the original unconditional send could block this
				// goroutine forever once the consumer went away. On context
				// cancellation we stop forwarding but keep draining the
				// upstream channel so billing is still settled.
				select {
				case outChan <- chunk:
				case <-l.ctx.Done():
					forwarding = false
				}
			}
		}
		latencyMs := int(time.Since(startTime).Milliseconds())

		// Post-stream bookkeeping uses a background context because the
		// request context may already be canceled by now.
		actualCost := l.calculateCost(inputTokens, outputTokens, aiModel)
		bgCtx := context.Background()
		quotaSvc.Settle(bgCtx, l.svcCtx.DB, userId, estimatedCost, actualCost, apiKeyId)

		usageSvc := billing.NewUsageService()
		usageSvc.Record(bgCtx, l.svcCtx.DB, &model.AIUsageRecord{
			UserId:       userId,
			ProviderId:   aiProvider.Id,
			ModelId:      req.Model,
			InputTokens:  inputTokens,
			OutputTokens: outputTokens,
			Cost:         actualCost,
			ApiKeyId:     apiKeyId,
			Status:       "ok",
			LatencyMs:    latencyMs,
		})

		// Persist the exchange when it belongs to a conversation.
		if req.ConversationId > 0 {
			totalTokens := int64(inputTokens + outputTokens)
			usageSvc.UpdateConversationStats(bgCtx, l.svcCtx.DB, req.ConversationId, totalTokens, actualCost)

			// Save the user message (last entry in req.Messages).
			if len(req.Messages) > 0 {
				lastMsg := req.Messages[len(req.Messages)-1]
				model.AIChatMessageInsert(bgCtx, l.svcCtx.DB, &model.AIChatMessage{
					ConversationId: req.ConversationId,
					Role:           lastMsg.Role,
					Content:        lastMsg.Content,
					ModelId:        req.Model,
				})
			}

			// Save the accumulated assistant message.
			model.AIChatMessageInsert(bgCtx, l.svcCtx.DB, &model.AIChatMessage{
				ConversationId: req.ConversationId,
				Role:           "assistant",
				Content:        fullContent.String(),
				TokenCount:     outputTokens,
				Cost:           actualCost,
				ModelId:        req.Model,
				LatencyMs:      latencyMs,
			})
		}
	}()

	return outChan, nil
}

// selectApiKey picks an API key for the provider: the user's first active
// personal key wins; otherwise the first system key is used.
func (l *AiChatCompletionsLogic) selectApiKey(providerId, userId int64) (string, int64, error) {
	// Try the user's own keys first.
	userKeys, err := model.AIApiKeyFindByProviderAndUser(l.ctx, l.svcCtx.DB, providerId, userId)
	if err == nil {
		for _, key := range userKeys {
			if key.IsActive {
				return key.KeyValue, key.Id, nil
			}
		}
	}

	// Fall back to a system key.
	systemKeys, err := model.AIApiKeyFindSystemKeys(l.ctx, l.svcCtx.DB, providerId)
	if err != nil || len(systemKeys) == 0 {
		return "", 0, errors.New("no API key available for this provider")
	}
	// NOTE(review): the key id reported for a system key is 0, not
	// systemKeys[0].Id, so billing/usage rows cannot tell which system key
	// served the call — confirm this is intentional before changing.
	return systemKeys[0].KeyValue, 0, nil
}

// estimateCost produces a rough pre-call cost for quota freezing: input is
// approximated at ~4 characters per token; output is MaxTokens when set,
// otherwise a flat 1000-token guess.
func (l *AiChatCompletionsLogic) estimateCost(req *types.AIChatCompletionRequest, aiModel *model.AIModel) float64 {
	totalChars := 0
	for _, msg := range req.Messages {
		totalChars += len(msg.Content)
	}
	estimatedInputTokens := totalChars / 4

	estimatedOutputTokens := 1000
	if req.MaxTokens > 0 {
		estimatedOutputTokens = req.MaxTokens
	}

	// Prices are per 1000 tokens (matches calculateCost).
	return float64(estimatedInputTokens)/1000.0*aiModel.InputPrice +
		float64(estimatedOutputTokens)/1000.0*aiModel.OutputPrice
}

// calculateCost computes the actual cost from real token counts, with model
// prices expressed per 1000 tokens.
func (l *AiChatCompletionsLogic) calculateCost(inputTokens, outputTokens int, aiModel *model.AIModel) float64 {
	return float64(inputTokens)/1000.0*aiModel.InputPrice +
		float64(outputTokens)/1000.0*aiModel.OutputPrice
}

// buildChatRequest converts the API request into a provider request, filling
// in the model's MaxTokens default and a 0.7 temperature default.
func (l *AiChatCompletionsLogic) buildChatRequest(req *types.AIChatCompletionRequest, aiModel *model.AIModel) *provider.ChatRequest {
	messages := make([]provider.ChatMessage, len(req.Messages))
	for i, m := range req.Messages {
		messages[i] = provider.ChatMessage{Role: m.Role, Content: m.Content}
	}

	maxTokens := req.MaxTokens
	if maxTokens <= 0 {
		maxTokens = aiModel.MaxTokens
	}

	// NOTE(review): temperature <= 0 is treated as "unset", which means a
	// caller cannot request temperature 0 (deterministic sampling) — confirm
	// whether that is acceptable for the API contract.
	temperature := req.Temperature
	if temperature <= 0 {
		temperature = 0.7
	}

	return &provider.ChatRequest{
		Model:       req.Model,
		Messages:    messages,
		MaxTokens:   maxTokens,
		Temperature: temperature,
		Stream:      req.Stream,
	}
}

// saveMessages persists the latest user message and the assistant reply to
// the conversation, then updates the conversation's token/cost totals.
func (l *AiChatCompletionsLogic) saveMessages(req *types.AIChatCompletionRequest, resp *provider.ChatResponse, aiModel *model.AIModel, latencyMs int) {
	actualCost := l.calculateCost(resp.InputTokens, resp.OutputTokens, aiModel)

	// Save the user message (last entry in req.Messages).
	if len(req.Messages) > 0 {
		lastMsg := req.Messages[len(req.Messages)-1]
		model.AIChatMessageInsert(l.ctx, l.svcCtx.DB, &model.AIChatMessage{
			ConversationId: req.ConversationId,
			Role:           lastMsg.Role,
			Content:        lastMsg.Content,
			ModelId:        req.Model,
		})
	}

	// Save the assistant message.
	model.AIChatMessageInsert(l.ctx, l.svcCtx.DB, &model.AIChatMessage{
		ConversationId: req.ConversationId,
		Role:           "assistant",
		Content:        resp.Content,
		TokenCount:     resp.OutputTokens,
		Cost:           actualCost,
		ModelId:        req.Model,
		LatencyMs:      latencyMs,
	})

	// Update conversation aggregates.
	totalTokens := int64(resp.InputTokens + resp.OutputTokens)
	billing.NewUsageService().UpdateConversationStats(l.ctx, l.svcCtx.DB, req.ConversationId, totalTokens, actualCost)
}