|
|
|
@ -1,13 +1,16 @@ |
|
|
|
// Code scaffolded by goctl. Safe to edit.
|
|
|
|
// goctl 1.9.2
|
|
|
|
|
|
|
|
package ai |
|
|
|
|
|
|
|
import ( |
|
|
|
"context" |
|
|
|
"errors" |
|
|
|
"fmt" |
|
|
|
"time" |
|
|
|
|
|
|
|
"github.com/youruser/base/internal/ai/billing" |
|
|
|
"github.com/youruser/base/internal/ai/provider" |
|
|
|
"github.com/youruser/base/internal/svc" |
|
|
|
"github.com/youruser/base/internal/types" |
|
|
|
"github.com/youruser/base/model" |
|
|
|
|
|
|
|
"github.com/zeromicro/go-zero/core/logx" |
|
|
|
) |
|
|
|
@ -18,7 +21,6 @@ type AiChatCompletionsLogic struct { |
|
|
|
svcCtx *svc.ServiceContext |
|
|
|
} |
|
|
|
|
|
|
|
// AI 对话补全
|
|
|
|
func NewAiChatCompletionsLogic(ctx context.Context, svcCtx *svc.ServiceContext) *AiChatCompletionsLogic { |
|
|
|
return &AiChatCompletionsLogic{ |
|
|
|
Logger: logx.WithContext(ctx), |
|
|
|
@ -27,8 +29,319 @@ func NewAiChatCompletionsLogic(ctx context.Context, svcCtx *svc.ServiceContext) |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
func (l *AiChatCompletionsLogic) AiChatCompletions(req *types.AIChatCompletionRequest) error { |
|
|
|
// todo: add your logic here and delete this line
|
|
|
|
// Chat handles a non-streaming chat completion end to end:
// auth check → model/provider lookup → API-key selection → quota freeze →
// provider call → cost settlement → usage recording → optional persistence.
//
// Billing contract (order matters): estimatedCost is frozen BEFORE the
// provider call and is either unfrozen on any subsequent failure or settled
// against the actual cost on success. Do not reorder these steps.
func (l *AiChatCompletionsLogic) Chat(req *types.AIChatCompletionRequest) (*types.AIChatCompletionResponse, error) {
	startTime := time.Now()

	// userId is expected in the context (presumably set by auth middleware —
	// confirm against the handler chain); a failed assertion leaves it 0,
	// which is treated as unauthenticated.
	userId, _ := l.ctx.Value("userId").(int64)
	if userId == 0 {
		return nil, errors.New("unauthorized")
	}

	// 1. Look up the model by its public model id (req.Model).
	aiModel, err := model.AIModelFindByModelId(l.ctx, l.svcCtx.DB, req.Model)
	if err != nil {
		return nil, fmt.Errorf("model not found: %s", req.Model)
	}

	// 2. Look up the provider that hosts this model.
	aiProvider, err := model.AIProviderFindOne(l.ctx, l.svcCtx.DB, aiModel.ProviderId)
	if err != nil {
		return nil, fmt.Errorf("provider not found for model: %s", req.Model)
	}

	// 3. Select an API key: user-owned key preferred, system key fallback.
	apiKey, apiKeyId, err := l.selectApiKey(aiProvider.Id, userId)
	if err != nil {
		return nil, err
	}

	// 4. Estimate the cost and freeze that amount of the user's quota up
	// front, so a slow provider call cannot overdraw the balance.
	estimatedCost := l.estimateCost(req, aiModel)
	quotaSvc := billing.NewQuotaService()
	if err := quotaSvc.CheckAndFreeze(l.ctx, l.svcCtx.DB, userId, estimatedCost, apiKeyId); err != nil {
		return nil, fmt.Errorf("insufficient balance: %v", err)
	}

	// 5. Build the provider client. From this point on, every failure path
	// must release the frozen quota before returning.
	p, err := provider.NewProvider(aiProvider.SdkType, aiProvider.BaseUrl, apiKey)
	if err != nil {
		quotaSvc.Unfreeze(l.ctx, l.svcCtx.DB, userId, estimatedCost, apiKeyId)
		return nil, err
	}

	// Convert the HTTP-layer messages into the provider SDK request shape
	// and perform the (blocking) completion call.
	chatReq := l.buildChatRequest(req, aiModel)
	chatResp, err := p.Chat(l.ctx, chatReq)
	if err != nil {
		quotaSvc.Unfreeze(l.ctx, l.svcCtx.DB, userId, estimatedCost, apiKeyId)
		return nil, fmt.Errorf("AI request failed: %v", err)
	}

	latencyMs := int(time.Since(startTime).Milliseconds())

	// 6. Settle: replace the frozen estimate with the actual cost computed
	// from the provider-reported token counts.
	actualCost := l.calculateCost(chatResp.InputTokens, chatResp.OutputTokens, aiModel)
	quotaSvc.Settle(l.ctx, l.svcCtx.DB, userId, estimatedCost, actualCost, apiKeyId)

	// 7. Record usage for audit/analytics. NOTE(review): the Record error is
	// discarded — presumably best-effort by design; confirm that silently
	// losing a usage record is acceptable.
	usageSvc := billing.NewUsageService()
	usageSvc.Record(l.ctx, l.svcCtx.DB, &model.AIUsageRecord{
		UserId:       userId,
		ProviderId:   aiProvider.Id,
		ModelId:      req.Model,
		InputTokens:  chatResp.InputTokens,
		OutputTokens: chatResp.OutputTokens,
		Cost:         actualCost,
		ApiKeyId:     apiKeyId,
		Status:       "ok",
		LatencyMs:    latencyMs,
	})

	// 8. Persist the exchange when the caller attached it to a conversation.
	if req.ConversationId > 0 {
		l.saveMessages(req, chatResp, aiModel, latencyMs)
	}

	// 9. Build an OpenAI-style response envelope with a timestamp-derived id.
	return &types.AIChatCompletionResponse{
		Id:     fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()),
		Object: "chat.completion",
		Model:  req.Model,
		Choices: []types.AIChatCompletionChoice{
			{
				Index:        0,
				FinishReason: chatResp.FinishReason,
				Message: types.AIChatMessage{
					Role:    "assistant",
					Content: chatResp.Content,
				},
			},
		},
		Usage: types.AIChatCompletionUsage{
			PromptTokens:     chatResp.InputTokens,
			CompletionTokens: chatResp.OutputTokens,
			TotalTokens:      chatResp.InputTokens + chatResp.OutputTokens,
		},
	}, nil
}
|
|
|
|
|
|
|
// ChatStream handles streaming chat completions. It performs the same
// pre-flight steps as Chat (auth → lookups → key selection → quota freeze),
// then proxies the provider's stream through its own channel so it can
// accumulate the full content and token counts for post-stream billing and
// persistence.
//
// The returned channel is closed when the provider stream ends; the caller
// MUST drain it.
func (l *AiChatCompletionsLogic) ChatStream(req *types.AIChatCompletionRequest) (<-chan *provider.StreamChunk, error) {
	// See Chat: userId comes from the request context; 0 ⇒ unauthenticated.
	userId, _ := l.ctx.Value("userId").(int64)
	if userId == 0 {
		return nil, errors.New("unauthorized")
	}

	// 1. Look up the model by its public model id.
	aiModel, err := model.AIModelFindByModelId(l.ctx, l.svcCtx.DB, req.Model)
	if err != nil {
		return nil, fmt.Errorf("model not found: %s", req.Model)
	}

	// 2. Look up the provider that hosts this model.
	aiProvider, err := model.AIProviderFindOne(l.ctx, l.svcCtx.DB, aiModel.ProviderId)
	if err != nil {
		return nil, fmt.Errorf("provider not found for model: %s", req.Model)
	}

	// 3. Select an API key (user key preferred, system key fallback).
	apiKey, apiKeyId, err := l.selectApiKey(aiProvider.Id, userId)
	if err != nil {
		return nil, err
	}

	// 4. Estimate and freeze quota before opening the stream.
	estimatedCost := l.estimateCost(req, aiModel)
	quotaSvc := billing.NewQuotaService()
	if err := quotaSvc.CheckAndFreeze(l.ctx, l.svcCtx.DB, userId, estimatedCost, apiKeyId); err != nil {
		return nil, fmt.Errorf("insufficient balance: %v", err)
	}

	// 5. Build the provider client; unfreeze on any failure from here on.
	p, err := provider.NewProvider(aiProvider.SdkType, aiProvider.BaseUrl, apiKey)
	if err != nil {
		quotaSvc.Unfreeze(l.ctx, l.svcCtx.DB, userId, estimatedCost, apiKeyId)
		return nil, err
	}

	// 6. Start the provider stream. Latency is measured from stream start to
	// stream end (time-to-last-chunk), not time-to-first-chunk.
	chatReq := l.buildChatRequest(req, aiModel)
	startTime := time.Now()
	streamChan, err := p.ChatStream(l.ctx, chatReq)
	if err != nil {
		quotaSvc.Unfreeze(l.ctx, l.svcCtx.DB, userId, estimatedCost, apiKeyId)
		return nil, fmt.Errorf("AI stream failed: %v", err)
	}

	// 7. Wrap the stream: forward every chunk while accumulating the full
	// content and the final token counts for billing/persistence after the
	// stream completes.
	// NOTE(review): if the consumer stops reading, this goroutine blocks on
	// the send once the 100-chunk buffer fills and billing never settles —
	// confirm the HTTP handler always drains outChan even on client abort.
	outChan := make(chan *provider.StreamChunk, 100)
	go func() {
		defer close(outChan)

		var fullContent string
		var inputTokens, outputTokens int

		for chunk := range streamChan {
			fullContent += chunk.Content
			// Token counts arrive on whichever chunk(s) carry them (usually
			// the final one); keep the last non-zero values seen.
			if chunk.InputTokens > 0 {
				inputTokens = chunk.InputTokens
			}
			if chunk.OutputTokens > 0 {
				outputTokens = chunk.OutputTokens
			}
			outChan <- chunk
		}

		latencyMs := int(time.Since(startTime).Milliseconds())

		// Post-stream: settle billing and record usage on a background
		// context — l.ctx may already be canceled once the client's request
		// finishes, and settlement must still run.
		actualCost := l.calculateCost(inputTokens, outputTokens, aiModel)
		bgCtx := context.Background()
		quotaSvc.Settle(bgCtx, l.svcCtx.DB, userId, estimatedCost, actualCost, apiKeyId)

		usageSvc := billing.NewUsageService()
		usageSvc.Record(bgCtx, l.svcCtx.DB, &model.AIUsageRecord{
			UserId:       userId,
			ProviderId:   aiProvider.Id,
			ModelId:      req.Model,
			InputTokens:  inputTokens,
			OutputTokens: outputTokens,
			Cost:         actualCost,
			ApiKeyId:     apiKeyId,
			Status:       "ok",
			LatencyMs:    latencyMs,
		})

		// Persist messages and refresh conversation aggregates when the
		// request is attached to a conversation.
		if req.ConversationId > 0 {
			totalTokens := int64(inputTokens + outputTokens)
			usageSvc.UpdateConversationStats(bgCtx, l.svcCtx.DB, req.ConversationId, totalTokens, actualCost)

			// Save the user message (the last entry in req.Messages is the
			// new turn for this request).
			if len(req.Messages) > 0 {
				lastMsg := req.Messages[len(req.Messages)-1]
				model.AIChatMessageInsert(bgCtx, l.svcCtx.DB, &model.AIChatMessage{
					ConversationId: req.ConversationId,
					Role:           lastMsg.Role,
					Content:        lastMsg.Content,
					ModelId:        req.Model,
				})
			}
			// Save the accumulated assistant reply with its metadata.
			model.AIChatMessageInsert(bgCtx, l.svcCtx.DB, &model.AIChatMessage{
				ConversationId: req.ConversationId,
				Role:           "assistant",
				Content:        fullContent,
				TokenCount:     outputTokens,
				Cost:           actualCost,
				ModelId:        req.Model,
				LatencyMs:      latencyMs,
			})
		}
	}()

	return outChan, nil
}
|
|
|
|
|
|
|
// selectApiKey selects the best API key: user's own key first, then system key
|
|
|
|
func (l *AiChatCompletionsLogic) selectApiKey(providerId, userId int64) (string, int64, error) { |
|
|
|
// Try user's own key first
|
|
|
|
userKeys, err := model.AIApiKeyFindByProviderAndUser(l.ctx, l.svcCtx.DB, providerId, userId) |
|
|
|
if err == nil && len(userKeys) > 0 { |
|
|
|
for _, key := range userKeys { |
|
|
|
if key.IsActive { |
|
|
|
return key.KeyValue, key.Id, nil |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Fall back to system key (userId=0)
|
|
|
|
systemKeys, err := model.AIApiKeyFindSystemKeys(l.ctx, l.svcCtx.DB, providerId) |
|
|
|
if err != nil || len(systemKeys) == 0 { |
|
|
|
return "", 0, fmt.Errorf("no API key available for this provider") |
|
|
|
} |
|
|
|
|
|
|
|
return systemKeys[0].KeyValue, 0, nil |
|
|
|
} |
|
|
|
|
|
|
|
// estimateCost estimates cost based on input message length
|
|
|
|
func (l *AiChatCompletionsLogic) estimateCost(req *types.AIChatCompletionRequest, aiModel *model.AIModel) float64 { |
|
|
|
// Rough estimation: ~4 chars per token for input, estimate 1000 output tokens
|
|
|
|
totalChars := 0 |
|
|
|
for _, msg := range req.Messages { |
|
|
|
totalChars += len(msg.Content) |
|
|
|
} |
|
|
|
estimatedInputTokens := totalChars / 4 |
|
|
|
estimatedOutputTokens := 1000 |
|
|
|
if req.MaxTokens > 0 { |
|
|
|
estimatedOutputTokens = req.MaxTokens |
|
|
|
} |
|
|
|
|
|
|
|
cost := float64(estimatedInputTokens)/1000.0*aiModel.InputPrice + |
|
|
|
float64(estimatedOutputTokens)/1000.0*aiModel.OutputPrice |
|
|
|
return cost |
|
|
|
} |
|
|
|
|
|
|
|
// calculateCost computes actual cost from real token counts
|
|
|
|
func (l *AiChatCompletionsLogic) calculateCost(inputTokens, outputTokens int, aiModel *model.AIModel) float64 { |
|
|
|
return float64(inputTokens)/1000.0*aiModel.InputPrice + |
|
|
|
float64(outputTokens)/1000.0*aiModel.OutputPrice |
|
|
|
} |
|
|
|
|
|
|
|
// buildChatRequest converts types request to provider request
|
|
|
|
func (l *AiChatCompletionsLogic) buildChatRequest(req *types.AIChatCompletionRequest, aiModel *model.AIModel) *provider.ChatRequest { |
|
|
|
messages := make([]provider.ChatMessage, len(req.Messages)) |
|
|
|
for i, m := range req.Messages { |
|
|
|
messages[i] = provider.ChatMessage{Role: m.Role, Content: m.Content} |
|
|
|
} |
|
|
|
|
|
|
|
maxTokens := req.MaxTokens |
|
|
|
if maxTokens <= 0 { |
|
|
|
maxTokens = aiModel.MaxTokens |
|
|
|
} |
|
|
|
|
|
|
|
temperature := req.Temperature |
|
|
|
if temperature <= 0 { |
|
|
|
temperature = 0.7 |
|
|
|
} |
|
|
|
|
|
|
|
return &provider.ChatRequest{ |
|
|
|
Model: req.Model, |
|
|
|
Messages: messages, |
|
|
|
MaxTokens: maxTokens, |
|
|
|
Temperature: temperature, |
|
|
|
Stream: req.Stream, |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// saveMessages saves user and assistant messages to conversation
|
|
|
|
func (l *AiChatCompletionsLogic) saveMessages(req *types.AIChatCompletionRequest, resp *provider.ChatResponse, aiModel *model.AIModel, latencyMs int) { |
|
|
|
actualCost := l.calculateCost(resp.InputTokens, resp.OutputTokens, aiModel) |
|
|
|
|
|
|
|
// Save user message
|
|
|
|
if len(req.Messages) > 0 { |
|
|
|
lastMsg := req.Messages[len(req.Messages)-1] |
|
|
|
model.AIChatMessageInsert(l.ctx, l.svcCtx.DB, &model.AIChatMessage{ |
|
|
|
ConversationId: req.ConversationId, |
|
|
|
Role: lastMsg.Role, |
|
|
|
Content: lastMsg.Content, |
|
|
|
ModelId: req.Model, |
|
|
|
}) |
|
|
|
} |
|
|
|
|
|
|
|
// Save assistant message
|
|
|
|
model.AIChatMessageInsert(l.ctx, l.svcCtx.DB, &model.AIChatMessage{ |
|
|
|
ConversationId: req.ConversationId, |
|
|
|
Role: "assistant", |
|
|
|
Content: resp.Content, |
|
|
|
TokenCount: resp.OutputTokens, |
|
|
|
Cost: actualCost, |
|
|
|
ModelId: req.Model, |
|
|
|
LatencyMs: latencyMs, |
|
|
|
}) |
|
|
|
|
|
|
|
return nil |
|
|
|
// Update conversation stats
|
|
|
|
totalTokens := int64(resp.InputTokens + resp.OutputTokens) |
|
|
|
billing.NewUsageService().UpdateConversationStats(l.ctx, l.svcCtx.DB, req.ConversationId, totalTokens, actualCost) |
|
|
|
} |
|
|
|
|