You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

347 lines
10 KiB

package ai
import (
	"context"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/youruser/base/internal/ai/billing"
	"github.com/youruser/base/internal/ai/provider"
	"github.com/youruser/base/internal/svc"
	"github.com/youruser/base/internal/types"
	"github.com/youruser/base/model"

	"github.com/zeromicro/go-zero/core/logx"
)
// AiChatCompletionsLogic serves chat-completion requests — both blocking
// (Chat) and streaming (ChatStream) — on top of configurable AI providers,
// with quota freezing/settlement and usage recording.
type AiChatCompletionsLogic struct {
	logx.Logger                 // request-scoped structured logger
	ctx    context.Context      // request context; carries the "userId" value
	svcCtx *svc.ServiceContext  // shared dependencies (DB handle, config, ...)
}
// NewAiChatCompletionsLogic constructs the chat-completions logic handler,
// binding the request context and the shared service context.
func NewAiChatCompletionsLogic(ctx context.Context, svcCtx *svc.ServiceContext) *AiChatCompletionsLogic {
	logic := new(AiChatCompletionsLogic)
	logic.Logger = logx.WithContext(ctx)
	logic.ctx = ctx
	logic.svcCtx = svcCtx
	return logic
}
// Chat handles a non-streaming chat completion request.
//
// Flow: authenticate → resolve model and provider → select an API key →
// freeze an estimated cost → call the provider → settle the actual cost,
// record usage, and (when a conversation is targeted) persist messages.
// Returns an OpenAI-style completion response.
func (l *AiChatCompletionsLogic) Chat(req *types.AIChatCompletionRequest) (*types.AIChatCompletionResponse, error) {
	startTime := time.Now()

	// The auth middleware is expected to place the user id on the context;
	// a zero value means the request is unauthenticated.
	userId, _ := l.ctx.Value("userId").(int64)
	if userId == 0 {
		return nil, errors.New("unauthorized")
	}

	// 1. Resolve the requested model.
	// Wrap with %w so callers can inspect the underlying cause via errors.Is/As
	// (the original formatting discarded it).
	aiModel, err := model.AIModelFindByModelId(l.ctx, l.svcCtx.DB, req.Model)
	if err != nil {
		return nil, fmt.Errorf("model not found: %s: %w", req.Model, err)
	}

	// 2. Resolve the provider that serves this model.
	aiProvider, err := model.AIProviderFindOne(l.ctx, l.svcCtx.DB, aiModel.ProviderId)
	if err != nil {
		return nil, fmt.Errorf("provider not found for model: %s: %w", req.Model, err)
	}

	// 3. Select an API key (user's own key first, then a system key).
	apiKey, apiKeyId, err := l.selectApiKey(aiProvider.Id, userId)
	if err != nil {
		return nil, err
	}

	// 4. Freeze an estimated cost up front so the user cannot overdraw
	// while the (possibly slow) provider call is in flight.
	estimatedCost := l.estimateCost(req, aiModel)
	quotaSvc := billing.NewQuotaService()
	if err := quotaSvc.CheckAndFreeze(l.ctx, l.svcCtx.DB, userId, estimatedCost, apiKeyId); err != nil {
		return nil, fmt.Errorf("insufficient balance: %w", err)
	}

	// 5. Build the provider client and perform the call. On any failure
	// before settlement, release the frozen funds.
	p, err := provider.NewProvider(aiProvider.SdkType, aiProvider.BaseUrl, apiKey)
	if err != nil {
		quotaSvc.Unfreeze(l.ctx, l.svcCtx.DB, userId, estimatedCost, apiKeyId)
		return nil, err
	}
	chatReq := l.buildChatRequest(req, aiModel)
	chatResp, err := p.Chat(l.ctx, chatReq)
	if err != nil {
		quotaSvc.Unfreeze(l.ctx, l.svcCtx.DB, userId, estimatedCost, apiKeyId)
		return nil, fmt.Errorf("AI request failed: %w", err)
	}
	latencyMs := int(time.Since(startTime).Milliseconds())

	// 6. Settle: replace the frozen estimate with the actual cost computed
	// from the provider-reported token counts.
	actualCost := l.calculateCost(chatResp.InputTokens, chatResp.OutputTokens, aiModel)
	quotaSvc.Settle(l.ctx, l.svcCtx.DB, userId, estimatedCost, actualCost, apiKeyId)

	// 7. Record a usage row for reporting/auditing.
	usageSvc := billing.NewUsageService()
	usageSvc.Record(l.ctx, l.svcCtx.DB, &model.AIUsageRecord{
		UserId:       userId,
		ProviderId:   aiProvider.Id,
		ModelId:      req.Model,
		InputTokens:  chatResp.InputTokens,
		OutputTokens: chatResp.OutputTokens,
		Cost:         actualCost,
		ApiKeyId:     apiKeyId,
		Status:       "ok",
		LatencyMs:    latencyMs,
	})

	// 8. Persist messages when the request targets a conversation.
	if req.ConversationId > 0 {
		l.saveMessages(req, chatResp, aiModel, latencyMs)
	}

	// 9. Shape an OpenAI-compatible response.
	return &types.AIChatCompletionResponse{
		Id:     fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()),
		Object: "chat.completion",
		Model:  req.Model,
		Choices: []types.AIChatCompletionChoice{
			{
				Index:        0,
				FinishReason: chatResp.FinishReason,
				Message: types.AIChatMessage{
					Role:    "assistant",
					Content: chatResp.Content,
				},
			},
		},
		Usage: types.AIChatCompletionUsage{
			PromptTokens:     chatResp.InputTokens,
			CompletionTokens: chatResp.OutputTokens,
			TotalTokens:      chatResp.InputTokens + chatResp.OutputTokens,
		},
	}, nil
}
// ChatStream handles a streaming chat completion request.
//
// It freezes an estimated cost, starts the provider stream, and returns a
// channel of chunks. A background goroutine forwards chunks to the caller
// while accumulating content and token counts; once the provider stream
// ends it settles billing, records usage, and persists messages. The
// post-stream bookkeeping runs on context.Background() because l.ctx is
// typically canceled as soon as the HTTP response completes.
func (l *AiChatCompletionsLogic) ChatStream(req *types.AIChatCompletionRequest) (<-chan *provider.StreamChunk, error) {
	userId, _ := l.ctx.Value("userId").(int64)
	if userId == 0 {
		return nil, errors.New("unauthorized")
	}

	// 1. Resolve the requested model (wrap with %w to keep the cause chain).
	aiModel, err := model.AIModelFindByModelId(l.ctx, l.svcCtx.DB, req.Model)
	if err != nil {
		return nil, fmt.Errorf("model not found: %s: %w", req.Model, err)
	}

	// 2. Resolve the provider that serves this model.
	aiProvider, err := model.AIProviderFindOne(l.ctx, l.svcCtx.DB, aiModel.ProviderId)
	if err != nil {
		return nil, fmt.Errorf("provider not found for model: %s: %w", req.Model, err)
	}

	// 3. Select an API key (user's own key first, then a system key).
	apiKey, apiKeyId, err := l.selectApiKey(aiProvider.Id, userId)
	if err != nil {
		return nil, err
	}

	// 4. Freeze an estimated cost up front.
	estimatedCost := l.estimateCost(req, aiModel)
	quotaSvc := billing.NewQuotaService()
	if err := quotaSvc.CheckAndFreeze(l.ctx, l.svcCtx.DB, userId, estimatedCost, apiKeyId); err != nil {
		return nil, fmt.Errorf("insufficient balance: %w", err)
	}

	// 5. Build the provider client; release frozen funds on failure.
	p, err := provider.NewProvider(aiProvider.SdkType, aiProvider.BaseUrl, apiKey)
	if err != nil {
		quotaSvc.Unfreeze(l.ctx, l.svcCtx.DB, userId, estimatedCost, apiKeyId)
		return nil, err
	}

	// 6. Start the upstream stream.
	chatReq := l.buildChatRequest(req, aiModel)
	startTime := time.Now()
	streamChan, err := p.ChatStream(l.ctx, chatReq)
	if err != nil {
		quotaSvc.Unfreeze(l.ctx, l.svcCtx.DB, userId, estimatedCost, apiKeyId)
		return nil, fmt.Errorf("AI stream failed: %w", err)
	}

	// 7. Forward chunks to the caller while accumulating content and the
	// final token counts for post-stream billing.
	outChan := make(chan *provider.StreamChunk, 100)
	go func() {
		defer close(outChan)

		var content strings.Builder // avoids quadratic string += per chunk
		var inputTokens, outputTokens int
		forwarding := true
		for chunk := range streamChan {
			content.WriteString(chunk.Content)
			// Providers report usage on (typically) the final chunk; keep
			// the latest non-zero values seen.
			if chunk.InputTokens > 0 {
				inputTokens = chunk.InputTokens
			}
			if chunk.OutputTokens > 0 {
				outputTokens = chunk.OutputTokens
			}
			if !forwarding {
				continue // consumer is gone; keep draining for billing
			}
			// BUG FIX: a plain send on outChan could block forever once the
			// consumer stops reading (client disconnect), leaking this
			// goroutine and never settling the frozen quota. On context
			// cancellation we stop forwarding but keep draining streamChan
			// so settlement still happens with the real token counts.
			select {
			case outChan <- chunk:
			case <-l.ctx.Done():
				forwarding = false
			}
		}
		latencyMs := int(time.Since(startTime).Milliseconds())

		// Post-stream: settle billing and record usage on a background
		// context, as the request context may already be canceled.
		actualCost := l.calculateCost(inputTokens, outputTokens, aiModel)
		bgCtx := context.Background()
		quotaSvc.Settle(bgCtx, l.svcCtx.DB, userId, estimatedCost, actualCost, apiKeyId)
		usageSvc := billing.NewUsageService()
		usageSvc.Record(bgCtx, l.svcCtx.DB, &model.AIUsageRecord{
			UserId:       userId,
			ProviderId:   aiProvider.Id,
			ModelId:      req.Model,
			InputTokens:  inputTokens,
			OutputTokens: outputTokens,
			Cost:         actualCost,
			ApiKeyId:     apiKeyId,
			Status:       "ok",
			LatencyMs:    latencyMs,
		})

		// Persist messages and conversation stats when requested.
		if req.ConversationId > 0 {
			totalTokens := int64(inputTokens + outputTokens)
			usageSvc.UpdateConversationStats(bgCtx, l.svcCtx.DB, req.ConversationId, totalTokens, actualCost)
			// The last entry in req.Messages is treated as the new user turn.
			if len(req.Messages) > 0 {
				lastMsg := req.Messages[len(req.Messages)-1]
				model.AIChatMessageInsert(bgCtx, l.svcCtx.DB, &model.AIChatMessage{
					ConversationId: req.ConversationId,
					Role:           lastMsg.Role,
					Content:        lastMsg.Content,
					ModelId:        req.Model,
				})
			}
			// Save the assembled assistant reply with usage metadata.
			model.AIChatMessageInsert(bgCtx, l.svcCtx.DB, &model.AIChatMessage{
				ConversationId: req.ConversationId,
				Role:           "assistant",
				Content:        content.String(),
				TokenCount:     outputTokens,
				Cost:           actualCost,
				ModelId:        req.Model,
				LatencyMs:      latencyMs,
			})
		}
	}()
	return outChan, nil
}
// selectApiKey picks the API key to use for a provider: the user's own
// active key is preferred; otherwise an active system key is used.
//
// Returns the key value, the key id used for billing attribution, and an
// error when no usable key exists.
//
// NOTE(review): system keys are returned with key id 0 — presumably a
// sentinel the billing layer interprets as "system-owned"; confirm before
// changing it to the real row id.
func (l *AiChatCompletionsLogic) selectApiKey(providerId, userId int64) (string, int64, error) {
	// Prefer the user's own active key; a lookup error simply falls
	// through to the system-key path (best-effort).
	if userKeys, err := model.AIApiKeyFindByProviderAndUser(l.ctx, l.svcCtx.DB, providerId, userId); err == nil {
		for _, key := range userKeys {
			if key.IsActive {
				return key.KeyValue, key.Id, nil
			}
		}
	}

	// Fall back to a system key (owned by userId=0).
	systemKeys, err := model.AIApiKeyFindSystemKeys(l.ctx, l.svcCtx.DB, providerId)
	if err != nil || len(systemKeys) == 0 {
		return "", 0, errors.New("no API key available for this provider")
	}
	// BUG FIX: user keys were filtered by IsActive but system keys were
	// not, so a disabled/revoked system key could still be handed out and
	// sent upstream. Only return an active system key.
	for _, key := range systemKeys {
		if key.IsActive {
			return key.KeyValue, 0, nil
		}
	}
	return "", 0, errors.New("no active API key available for this provider")
}
// estimateCost predicts a request's cost before it is sent so that amount
// can be frozen against the user's balance. Input tokens are approximated
// at roughly 4 characters per token; output defaults to 1000 tokens unless
// the request caps MaxTokens.
func (l *AiChatCompletionsLogic) estimateCost(req *types.AIChatCompletionRequest, aiModel *model.AIModel) float64 {
	var inputChars int
	for _, msg := range req.Messages {
		inputChars += len(msg.Content)
	}
	estimatedInput := inputChars / 4
	estimatedOutput := 1000
	if req.MaxTokens > 0 {
		estimatedOutput = req.MaxTokens
	}
	// Price the estimates with the same formula used for real usage.
	return l.calculateCost(estimatedInput, estimatedOutput, aiModel)
}
// calculateCost converts real token counts into a monetary cost using the
// model's per-1000-token input and output prices.
func (l *AiChatCompletionsLogic) calculateCost(inputTokens, outputTokens int, aiModel *model.AIModel) float64 {
	inputCost := float64(inputTokens) / 1000.0 * aiModel.InputPrice
	outputCost := float64(outputTokens) / 1000.0 * aiModel.OutputPrice
	return inputCost + outputCost
}
// buildChatRequest translates the API-layer request into the provider
// layer's ChatRequest, falling back to the model's MaxTokens and a 0.7
// temperature when the request leaves those fields unset.
func (l *AiChatCompletionsLogic) buildChatRequest(req *types.AIChatCompletionRequest, aiModel *model.AIModel) *provider.ChatRequest {
	msgs := make([]provider.ChatMessage, 0, len(req.Messages))
	for _, m := range req.Messages {
		msgs = append(msgs, provider.ChatMessage{Role: m.Role, Content: m.Content})
	}

	out := &provider.ChatRequest{
		Model:       req.Model,
		Messages:    msgs,
		MaxTokens:   req.MaxTokens,
		Temperature: req.Temperature,
		Stream:      req.Stream,
	}
	if out.MaxTokens <= 0 {
		out.MaxTokens = aiModel.MaxTokens
	}
	if out.Temperature <= 0 {
		out.Temperature = 0.7
	}
	return out
}
// saveMessages persists the latest user message and the assistant reply to
// the conversation identified by req.ConversationId, then rolls token and
// cost totals up onto the conversation row.
//
// NOTE(review): the results of AIChatMessageInsert and
// UpdateConversationStats are discarded, so persistence failures are
// silent — confirm this best-effort behavior is intended.
func (l *AiChatCompletionsLogic) saveMessages(req *types.AIChatCompletionRequest, resp *provider.ChatResponse, aiModel *model.AIModel, latencyMs int) {
	// Price the exchange from the provider-reported token counts.
	actualCost := l.calculateCost(resp.InputTokens, resp.OutputTokens, aiModel)
	// Save the user's message. Only the last entry of req.Messages is
	// stored — presumably the earlier entries are resent history; confirm
	// against the client contract.
	if len(req.Messages) > 0 {
		lastMsg := req.Messages[len(req.Messages)-1]
		model.AIChatMessageInsert(l.ctx, l.svcCtx.DB, &model.AIChatMessage{
			ConversationId: req.ConversationId,
			Role:           lastMsg.Role,
			Content:        lastMsg.Content,
			ModelId:        req.Model,
		})
	}
	// Save the assistant reply with token/cost/latency metadata.
	model.AIChatMessageInsert(l.ctx, l.svcCtx.DB, &model.AIChatMessage{
		ConversationId: req.ConversationId,
		Role:           "assistant",
		Content:        resp.Content,
		TokenCount:     resp.OutputTokens,
		Cost:           actualCost,
		ModelId:        req.Model,
		LatencyMs:      latencyMs,
	})
	// Update the conversation's aggregate counters.
	totalTokens := int64(resp.InputTokens + resp.OutputTokens)
	billing.NewUsageService().UpdateConversationStats(l.ctx, l.svcCtx.DB, req.ConversationId, totalTokens, actualCost)
}