You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
347 lines
10 KiB
347 lines
10 KiB
package ai
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/youruser/base/internal/ai/billing"
|
|
"github.com/youruser/base/internal/ai/provider"
|
|
"github.com/youruser/base/internal/svc"
|
|
"github.com/youruser/base/internal/types"
|
|
"github.com/youruser/base/model"
|
|
|
|
"github.com/zeromicro/go-zero/core/logx"
|
|
)
|
|
|
|
// AiChatCompletionsLogic implements the chat-completions endpoint:
// non-streaming (Chat) and streaming (ChatStream) AI requests, including
// API-key selection, quota freeze/settle billing, usage recording, and
// optional conversation persistence.
type AiChatCompletionsLogic struct {
	logx.Logger
	ctx    context.Context     // request context; auth middleware stores "userId" in it (see Chat/ChatStream)
	svcCtx *svc.ServiceContext // shared service dependencies (DB handle, etc.)
}
|
|
|
|
func NewAiChatCompletionsLogic(ctx context.Context, svcCtx *svc.ServiceContext) *AiChatCompletionsLogic {
|
|
return &AiChatCompletionsLogic{
|
|
Logger: logx.WithContext(ctx),
|
|
ctx: ctx,
|
|
svcCtx: svcCtx,
|
|
}
|
|
}
|
|
|
|
// Chat handles non-streaming chat completions
|
|
func (l *AiChatCompletionsLogic) Chat(req *types.AIChatCompletionRequest) (*types.AIChatCompletionResponse, error) {
|
|
startTime := time.Now()
|
|
|
|
userId, _ := l.ctx.Value("userId").(int64)
|
|
if userId == 0 {
|
|
return nil, errors.New("unauthorized")
|
|
}
|
|
|
|
// 1. Look up model
|
|
aiModel, err := model.AIModelFindByModelId(l.ctx, l.svcCtx.DB, req.Model)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("model not found: %s", req.Model)
|
|
}
|
|
|
|
// 2. Look up provider
|
|
aiProvider, err := model.AIProviderFindOne(l.ctx, l.svcCtx.DB, aiModel.ProviderId)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("provider not found for model: %s", req.Model)
|
|
}
|
|
|
|
// 3. Select API key
|
|
apiKey, apiKeyId, err := l.selectApiKey(aiProvider.Id, userId)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 4. Estimate cost and freeze
|
|
estimatedCost := l.estimateCost(req, aiModel)
|
|
quotaSvc := billing.NewQuotaService()
|
|
if err := quotaSvc.CheckAndFreeze(l.ctx, l.svcCtx.DB, userId, estimatedCost, apiKeyId); err != nil {
|
|
return nil, fmt.Errorf("insufficient balance: %v", err)
|
|
}
|
|
|
|
// 5. Build provider and call
|
|
p, err := provider.NewProvider(aiProvider.SdkType, aiProvider.BaseUrl, apiKey)
|
|
if err != nil {
|
|
quotaSvc.Unfreeze(l.ctx, l.svcCtx.DB, userId, estimatedCost, apiKeyId)
|
|
return nil, err
|
|
}
|
|
|
|
// Convert messages
|
|
chatReq := l.buildChatRequest(req, aiModel)
|
|
chatResp, err := p.Chat(l.ctx, chatReq)
|
|
if err != nil {
|
|
quotaSvc.Unfreeze(l.ctx, l.svcCtx.DB, userId, estimatedCost, apiKeyId)
|
|
return nil, fmt.Errorf("AI request failed: %v", err)
|
|
}
|
|
|
|
latencyMs := int(time.Since(startTime).Milliseconds())
|
|
|
|
// 6. Calculate actual cost and settle
|
|
actualCost := l.calculateCost(chatResp.InputTokens, chatResp.OutputTokens, aiModel)
|
|
quotaSvc.Settle(l.ctx, l.svcCtx.DB, userId, estimatedCost, actualCost, apiKeyId)
|
|
|
|
// 7. Record usage
|
|
usageSvc := billing.NewUsageService()
|
|
usageSvc.Record(l.ctx, l.svcCtx.DB, &model.AIUsageRecord{
|
|
UserId: userId,
|
|
ProviderId: aiProvider.Id,
|
|
ModelId: req.Model,
|
|
InputTokens: chatResp.InputTokens,
|
|
OutputTokens: chatResp.OutputTokens,
|
|
Cost: actualCost,
|
|
ApiKeyId: apiKeyId,
|
|
Status: "ok",
|
|
LatencyMs: latencyMs,
|
|
})
|
|
|
|
// 8. Save messages to conversation if conversation_id provided
|
|
if req.ConversationId > 0 {
|
|
l.saveMessages(req, chatResp, aiModel, latencyMs)
|
|
}
|
|
|
|
// 9. Build response
|
|
return &types.AIChatCompletionResponse{
|
|
Id: fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()),
|
|
Object: "chat.completion",
|
|
Model: req.Model,
|
|
Choices: []types.AIChatCompletionChoice{
|
|
{
|
|
Index: 0,
|
|
FinishReason: chatResp.FinishReason,
|
|
Message: types.AIChatMessage{
|
|
Role: "assistant",
|
|
Content: chatResp.Content,
|
|
},
|
|
},
|
|
},
|
|
Usage: types.AIChatCompletionUsage{
|
|
PromptTokens: chatResp.InputTokens,
|
|
CompletionTokens: chatResp.OutputTokens,
|
|
TotalTokens: chatResp.InputTokens + chatResp.OutputTokens,
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
// ChatStream handles streaming chat completions
|
|
func (l *AiChatCompletionsLogic) ChatStream(req *types.AIChatCompletionRequest) (<-chan *provider.StreamChunk, error) {
|
|
userId, _ := l.ctx.Value("userId").(int64)
|
|
if userId == 0 {
|
|
return nil, errors.New("unauthorized")
|
|
}
|
|
|
|
// 1. Look up model
|
|
aiModel, err := model.AIModelFindByModelId(l.ctx, l.svcCtx.DB, req.Model)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("model not found: %s", req.Model)
|
|
}
|
|
|
|
// 2. Look up provider
|
|
aiProvider, err := model.AIProviderFindOne(l.ctx, l.svcCtx.DB, aiModel.ProviderId)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("provider not found for model: %s", req.Model)
|
|
}
|
|
|
|
// 3. Select API key
|
|
apiKey, apiKeyId, err := l.selectApiKey(aiProvider.Id, userId)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 4. Estimate and freeze
|
|
estimatedCost := l.estimateCost(req, aiModel)
|
|
quotaSvc := billing.NewQuotaService()
|
|
if err := quotaSvc.CheckAndFreeze(l.ctx, l.svcCtx.DB, userId, estimatedCost, apiKeyId); err != nil {
|
|
return nil, fmt.Errorf("insufficient balance: %v", err)
|
|
}
|
|
|
|
// 5. Build provider
|
|
p, err := provider.NewProvider(aiProvider.SdkType, aiProvider.BaseUrl, apiKey)
|
|
if err != nil {
|
|
quotaSvc.Unfreeze(l.ctx, l.svcCtx.DB, userId, estimatedCost, apiKeyId)
|
|
return nil, err
|
|
}
|
|
|
|
// 6. Start stream
|
|
chatReq := l.buildChatRequest(req, aiModel)
|
|
startTime := time.Now()
|
|
streamChan, err := p.ChatStream(l.ctx, chatReq)
|
|
if err != nil {
|
|
quotaSvc.Unfreeze(l.ctx, l.svcCtx.DB, userId, estimatedCost, apiKeyId)
|
|
return nil, fmt.Errorf("AI stream failed: %v", err)
|
|
}
|
|
|
|
// 7. Wrap stream channel — accumulate content for post-processing
|
|
outChan := make(chan *provider.StreamChunk, 100)
|
|
go func() {
|
|
defer close(outChan)
|
|
|
|
var fullContent string
|
|
var inputTokens, outputTokens int
|
|
|
|
for chunk := range streamChan {
|
|
fullContent += chunk.Content
|
|
if chunk.InputTokens > 0 {
|
|
inputTokens = chunk.InputTokens
|
|
}
|
|
if chunk.OutputTokens > 0 {
|
|
outputTokens = chunk.OutputTokens
|
|
}
|
|
outChan <- chunk
|
|
}
|
|
|
|
latencyMs := int(time.Since(startTime).Milliseconds())
|
|
|
|
// Post-stream: calculate cost, settle billing, record usage
|
|
actualCost := l.calculateCost(inputTokens, outputTokens, aiModel)
|
|
bgCtx := context.Background()
|
|
quotaSvc.Settle(bgCtx, l.svcCtx.DB, userId, estimatedCost, actualCost, apiKeyId)
|
|
|
|
usageSvc := billing.NewUsageService()
|
|
usageSvc.Record(bgCtx, l.svcCtx.DB, &model.AIUsageRecord{
|
|
UserId: userId,
|
|
ProviderId: aiProvider.Id,
|
|
ModelId: req.Model,
|
|
InputTokens: inputTokens,
|
|
OutputTokens: outputTokens,
|
|
Cost: actualCost,
|
|
ApiKeyId: apiKeyId,
|
|
Status: "ok",
|
|
LatencyMs: latencyMs,
|
|
})
|
|
|
|
// Save messages
|
|
if req.ConversationId > 0 {
|
|
totalTokens := int64(inputTokens + outputTokens)
|
|
usageSvc.UpdateConversationStats(bgCtx, l.svcCtx.DB, req.ConversationId, totalTokens, actualCost)
|
|
|
|
// Save user message (last one in req.Messages)
|
|
if len(req.Messages) > 0 {
|
|
lastMsg := req.Messages[len(req.Messages)-1]
|
|
model.AIChatMessageInsert(bgCtx, l.svcCtx.DB, &model.AIChatMessage{
|
|
ConversationId: req.ConversationId,
|
|
Role: lastMsg.Role,
|
|
Content: lastMsg.Content,
|
|
ModelId: req.Model,
|
|
})
|
|
}
|
|
// Save assistant message
|
|
model.AIChatMessageInsert(bgCtx, l.svcCtx.DB, &model.AIChatMessage{
|
|
ConversationId: req.ConversationId,
|
|
Role: "assistant",
|
|
Content: fullContent,
|
|
TokenCount: outputTokens,
|
|
Cost: actualCost,
|
|
ModelId: req.Model,
|
|
LatencyMs: latencyMs,
|
|
})
|
|
}
|
|
}()
|
|
|
|
return outChan, nil
|
|
}
|
|
|
|
// selectApiKey selects the best API key: user's own key first, then system key
|
|
func (l *AiChatCompletionsLogic) selectApiKey(providerId, userId int64) (string, int64, error) {
|
|
// Try user's own key first
|
|
userKeys, err := model.AIApiKeyFindByProviderAndUser(l.ctx, l.svcCtx.DB, providerId, userId)
|
|
if err == nil && len(userKeys) > 0 {
|
|
for _, key := range userKeys {
|
|
if key.IsActive {
|
|
return key.KeyValue, key.Id, nil
|
|
}
|
|
}
|
|
}
|
|
|
|
// Fall back to system key (userId=0)
|
|
systemKeys, err := model.AIApiKeyFindSystemKeys(l.ctx, l.svcCtx.DB, providerId)
|
|
if err != nil || len(systemKeys) == 0 {
|
|
return "", 0, fmt.Errorf("no API key available for this provider")
|
|
}
|
|
|
|
return systemKeys[0].KeyValue, 0, nil
|
|
}
|
|
|
|
// estimateCost estimates cost based on input message length
|
|
func (l *AiChatCompletionsLogic) estimateCost(req *types.AIChatCompletionRequest, aiModel *model.AIModel) float64 {
|
|
// Rough estimation: ~4 chars per token for input, estimate 1000 output tokens
|
|
totalChars := 0
|
|
for _, msg := range req.Messages {
|
|
totalChars += len(msg.Content)
|
|
}
|
|
estimatedInputTokens := totalChars / 4
|
|
estimatedOutputTokens := 1000
|
|
if req.MaxTokens > 0 {
|
|
estimatedOutputTokens = req.MaxTokens
|
|
}
|
|
|
|
cost := float64(estimatedInputTokens)/1000.0*aiModel.InputPrice +
|
|
float64(estimatedOutputTokens)/1000.0*aiModel.OutputPrice
|
|
return cost
|
|
}
|
|
|
|
// calculateCost computes actual cost from real token counts
|
|
func (l *AiChatCompletionsLogic) calculateCost(inputTokens, outputTokens int, aiModel *model.AIModel) float64 {
|
|
return float64(inputTokens)/1000.0*aiModel.InputPrice +
|
|
float64(outputTokens)/1000.0*aiModel.OutputPrice
|
|
}
|
|
|
|
// buildChatRequest converts types request to provider request
|
|
func (l *AiChatCompletionsLogic) buildChatRequest(req *types.AIChatCompletionRequest, aiModel *model.AIModel) *provider.ChatRequest {
|
|
messages := make([]provider.ChatMessage, len(req.Messages))
|
|
for i, m := range req.Messages {
|
|
messages[i] = provider.ChatMessage{Role: m.Role, Content: m.Content}
|
|
}
|
|
|
|
maxTokens := req.MaxTokens
|
|
if maxTokens <= 0 {
|
|
maxTokens = aiModel.MaxTokens
|
|
}
|
|
|
|
temperature := req.Temperature
|
|
if temperature <= 0 {
|
|
temperature = 0.7
|
|
}
|
|
|
|
return &provider.ChatRequest{
|
|
Model: req.Model,
|
|
Messages: messages,
|
|
MaxTokens: maxTokens,
|
|
Temperature: temperature,
|
|
Stream: req.Stream,
|
|
}
|
|
}
|
|
|
|
// saveMessages saves user and assistant messages to conversation
|
|
func (l *AiChatCompletionsLogic) saveMessages(req *types.AIChatCompletionRequest, resp *provider.ChatResponse, aiModel *model.AIModel, latencyMs int) {
|
|
actualCost := l.calculateCost(resp.InputTokens, resp.OutputTokens, aiModel)
|
|
|
|
// Save user message
|
|
if len(req.Messages) > 0 {
|
|
lastMsg := req.Messages[len(req.Messages)-1]
|
|
model.AIChatMessageInsert(l.ctx, l.svcCtx.DB, &model.AIChatMessage{
|
|
ConversationId: req.ConversationId,
|
|
Role: lastMsg.Role,
|
|
Content: lastMsg.Content,
|
|
ModelId: req.Model,
|
|
})
|
|
}
|
|
|
|
// Save assistant message
|
|
model.AIChatMessageInsert(l.ctx, l.svcCtx.DB, &model.AIChatMessage{
|
|
ConversationId: req.ConversationId,
|
|
Role: "assistant",
|
|
Content: resp.Content,
|
|
TokenCount: resp.OutputTokens,
|
|
Cost: actualCost,
|
|
ModelId: req.Model,
|
|
LatencyMs: latencyMs,
|
|
})
|
|
|
|
// Update conversation stats
|
|
totalTokens := int64(resp.InputTokens + resp.OutputTokens)
|
|
billing.NewUsageService().UpdateConversationStats(l.ctx, l.svcCtx.DB, req.ConversationId, totalTokens, actualCost)
|
|
}
|
|
|