Browse Source

feat: implement AI chat completions with SSE streaming

- Custom SSE handler (replaces goctl stub)
- Chat() for non-streaming, ChatStream() for SSE
- Provider selection (user key > system key)
- Quota freeze/settle/unfreeze billing flow
- Usage recording + conversation message saving
master
dark 1 month ago
parent
commit
f1d1549595
  1. 44
      backend/internal/handler/ai/aichatcompletionshandler.go
  2. 327
      backend/internal/logic/ai/aichatcompletionslogic.go

44
backend/internal/handler/ai/aichatcompletionshandler.go

@ -1,9 +1,8 @@
// Code scaffolded by goctl. Safe to edit.
// goctl 1.9.2
package ai package ai
import ( import (
"encoding/json"
"fmt"
"net/http" "net/http"
"github.com/youruser/base/internal/logic/ai" "github.com/youruser/base/internal/logic/ai"
@ -12,21 +11,52 @@ import (
"github.com/zeromicro/go-zero/rest/httpx" "github.com/zeromicro/go-zero/rest/httpx"
) )
// AI 对话补全
func AiChatCompletionsHandler(svcCtx *svc.ServiceContext) http.HandlerFunc { func AiChatCompletionsHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
var req types.AIChatCompletionRequest var req types.AIChatCompletionRequest
if err := httpx.Parse(r, &req); err != nil { if err := httpx.ParseJsonBody(r, &req); err != nil {
httpx.ErrorCtx(r.Context(), w, err) httpx.ErrorCtx(r.Context(), w, err)
return return
} }
l := ai.NewAiChatCompletionsLogic(r.Context(), svcCtx) l := ai.NewAiChatCompletionsLogic(r.Context(), svcCtx)
err := l.AiChatCompletions(&req)
if req.Stream {
// SSE streaming mode
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("Access-Control-Allow-Origin", "*")
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "streaming not supported", http.StatusInternalServerError)
return
}
streamChan, err := l.ChatStream(&req)
if err != nil {
errData, _ := json.Marshal(map[string]string{"error": err.Error()})
fmt.Fprintf(w, "data: %s\n\n", errData)
flusher.Flush()
return
}
for chunk := range streamChan {
data, _ := json.Marshal(chunk)
fmt.Fprintf(w, "data: %s\n\n", data)
flusher.Flush()
}
fmt.Fprintf(w, "data: [DONE]\n\n")
flusher.Flush()
} else {
// Normal (non-streaming) mode
resp, err := l.Chat(&req)
if err != nil { if err != nil {
httpx.ErrorCtx(r.Context(), w, err) httpx.ErrorCtx(r.Context(), w, err)
} else { } else {
httpx.Ok(w) httpx.OkJsonCtx(r.Context(), w, resp)
}
} }
} }
} }

327
backend/internal/logic/ai/aichatcompletionslogic.go

@ -1,13 +1,16 @@
// Code scaffolded by goctl. Safe to edit.
// goctl 1.9.2
package ai package ai
import (
	"context"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/youruser/base/internal/ai/billing"
	"github.com/youruser/base/internal/ai/provider"
	"github.com/youruser/base/internal/svc"
	"github.com/youruser/base/internal/types"
	"github.com/youruser/base/model"
	"github.com/zeromicro/go-zero/core/logx"
)
@ -18,7 +21,6 @@ type AiChatCompletionsLogic struct {
svcCtx *svc.ServiceContext svcCtx *svc.ServiceContext
} }
// AI 对话补全
func NewAiChatCompletionsLogic(ctx context.Context, svcCtx *svc.ServiceContext) *AiChatCompletionsLogic { func NewAiChatCompletionsLogic(ctx context.Context, svcCtx *svc.ServiceContext) *AiChatCompletionsLogic {
return &AiChatCompletionsLogic{ return &AiChatCompletionsLogic{
Logger: logx.WithContext(ctx), Logger: logx.WithContext(ctx),
@ -27,8 +29,319 @@ func NewAiChatCompletionsLogic(ctx context.Context, svcCtx *svc.ServiceContext)
} }
} }
// Chat handles a non-streaming chat completion end to end: resolve the
// model and provider, select an API key, freeze an estimated quota, call
// the provider, settle billing on actual token counts, record usage, and
// optionally persist the conversation messages.
func (l *AiChatCompletionsLogic) Chat(req *types.AIChatCompletionRequest) (*types.AIChatCompletionResponse, error) {
	start := time.Now()

	userId, _ := l.ctx.Value("userId").(int64)
	if userId == 0 {
		return nil, errors.New("unauthorized")
	}

	// Resolve the requested model and its provider.
	mdl, err := model.AIModelFindByModelId(l.ctx, l.svcCtx.DB, req.Model)
	if err != nil {
		return nil, fmt.Errorf("model not found: %s", req.Model)
	}
	prov, err := model.AIProviderFindOne(l.ctx, l.svcCtx.DB, mdl.ProviderId)
	if err != nil {
		return nil, fmt.Errorf("provider not found for model: %s", req.Model)
	}

	// Pick an API key (user-owned first, then system).
	apiKey, apiKeyId, err := l.selectApiKey(prov.Id, userId)
	if err != nil {
		return nil, err
	}

	// Freeze the estimated cost before spending anything; every failure
	// from here until settlement must release the frozen amount.
	estimatedCost := l.estimateCost(req, mdl)
	quotaSvc := billing.NewQuotaService()
	if err := quotaSvc.CheckAndFreeze(l.ctx, l.svcCtx.DB, userId, estimatedCost, apiKeyId); err != nil {
		return nil, fmt.Errorf("insufficient balance: %v", err)
	}

	p, err := provider.NewProvider(prov.SdkType, prov.BaseUrl, apiKey)
	if err != nil {
		quotaSvc.Unfreeze(l.ctx, l.svcCtx.DB, userId, estimatedCost, apiKeyId)
		return nil, err
	}

	chatResp, err := p.Chat(l.ctx, l.buildChatRequest(req, mdl))
	if err != nil {
		quotaSvc.Unfreeze(l.ctx, l.svcCtx.DB, userId, estimatedCost, apiKeyId)
		return nil, fmt.Errorf("AI request failed: %v", err)
	}
	latencyMs := int(time.Since(start).Milliseconds())

	// Settle billing against actual token usage and record it.
	actualCost := l.calculateCost(chatResp.InputTokens, chatResp.OutputTokens, mdl)
	quotaSvc.Settle(l.ctx, l.svcCtx.DB, userId, estimatedCost, actualCost, apiKeyId)

	usageSvc := billing.NewUsageService()
	usageSvc.Record(l.ctx, l.svcCtx.DB, &model.AIUsageRecord{
		UserId:       userId,
		ProviderId:   prov.Id,
		ModelId:      req.Model,
		InputTokens:  chatResp.InputTokens,
		OutputTokens: chatResp.OutputTokens,
		Cost:         actualCost,
		ApiKeyId:     apiKeyId,
		Status:       "ok",
		LatencyMs:    latencyMs,
	})

	// Persist messages when the caller attached a conversation.
	if req.ConversationId > 0 {
		l.saveMessages(req, chatResp, mdl, latencyMs)
	}

	return &types.AIChatCompletionResponse{
		Id:     fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()),
		Object: "chat.completion",
		Model:  req.Model,
		Choices: []types.AIChatCompletionChoice{
			{
				Index:        0,
				FinishReason: chatResp.FinishReason,
				Message: types.AIChatMessage{
					Role:    "assistant",
					Content: chatResp.Content,
				},
			},
		},
		Usage: types.AIChatCompletionUsage{
			PromptTokens:     chatResp.InputTokens,
			CompletionTokens: chatResp.OutputTokens,
			TotalTokens:      chatResp.InputTokens + chatResp.OutputTokens,
		},
	}, nil
}
// ChatStream handles streaming chat completions. Pre-flight work (auth,
// model/provider lookup, key selection, quota freeze) runs synchronously;
// the returned channel relays provider chunks while a goroutine collects
// the full reply so billing settlement, usage recording and message
// persistence can run after the stream drains.
func (l *AiChatCompletionsLogic) ChatStream(req *types.AIChatCompletionRequest) (<-chan *provider.StreamChunk, error) {
	userId, _ := l.ctx.Value("userId").(int64)
	if userId == 0 {
		return nil, errors.New("unauthorized")
	}

	// Resolve the requested model and its provider.
	mdl, err := model.AIModelFindByModelId(l.ctx, l.svcCtx.DB, req.Model)
	if err != nil {
		return nil, fmt.Errorf("model not found: %s", req.Model)
	}
	prov, err := model.AIProviderFindOne(l.ctx, l.svcCtx.DB, mdl.ProviderId)
	if err != nil {
		return nil, fmt.Errorf("provider not found for model: %s", req.Model)
	}

	// Pick an API key (user-owned first, then system).
	apiKey, apiKeyId, err := l.selectApiKey(prov.Id, userId)
	if err != nil {
		return nil, err
	}

	// Freeze the estimated cost; failures past this point must unfreeze.
	estimatedCost := l.estimateCost(req, mdl)
	quotaSvc := billing.NewQuotaService()
	if err := quotaSvc.CheckAndFreeze(l.ctx, l.svcCtx.DB, userId, estimatedCost, apiKeyId); err != nil {
		return nil, fmt.Errorf("insufficient balance: %v", err)
	}

	p, err := provider.NewProvider(prov.SdkType, prov.BaseUrl, apiKey)
	if err != nil {
		quotaSvc.Unfreeze(l.ctx, l.svcCtx.DB, userId, estimatedCost, apiKeyId)
		return nil, err
	}

	start := time.Now()
	upstream, err := p.ChatStream(l.ctx, l.buildChatRequest(req, mdl))
	if err != nil {
		quotaSvc.Unfreeze(l.ctx, l.svcCtx.DB, userId, estimatedCost, apiKeyId)
		return nil, fmt.Errorf("AI stream failed: %v", err)
	}

	// Relay chunks while accumulating the full reply and token counters
	// for post-stream accounting. strings.Builder avoids the quadratic
	// cost of `fullContent += chunk` across many small chunks.
	out := make(chan *provider.StreamChunk, 100)
	go func() {
		defer close(out)

		var full strings.Builder
		var inputTokens, outputTokens int
		for chunk := range upstream {
			full.WriteString(chunk.Content)
			// Token counts arrive on whichever chunk carries them
			// (typically the final one); keep the latest non-zero value.
			if chunk.InputTokens > 0 {
				inputTokens = chunk.InputTokens
			}
			if chunk.OutputTokens > 0 {
				outputTokens = chunk.OutputTokens
			}
			out <- chunk
		}
		latencyMs := int(time.Since(start).Milliseconds())

		// The request context may already be cancelled once the stream
		// ends, so all post-processing uses a background context.
		actualCost := l.calculateCost(inputTokens, outputTokens, mdl)
		bgCtx := context.Background()
		quotaSvc.Settle(bgCtx, l.svcCtx.DB, userId, estimatedCost, actualCost, apiKeyId)

		usageSvc := billing.NewUsageService()
		usageSvc.Record(bgCtx, l.svcCtx.DB, &model.AIUsageRecord{
			UserId:       userId,
			ProviderId:   prov.Id,
			ModelId:      req.Model,
			InputTokens:  inputTokens,
			OutputTokens: outputTokens,
			Cost:         actualCost,
			ApiKeyId:     apiKeyId,
			Status:       "ok",
			LatencyMs:    latencyMs,
		})

		if req.ConversationId > 0 {
			totalTokens := int64(inputTokens + outputTokens)
			usageSvc.UpdateConversationStats(bgCtx, l.svcCtx.DB, req.ConversationId, totalTokens, actualCost)

			// Persist the last user message (the newest in req.Messages),
			// then the assembled assistant reply.
			if len(req.Messages) > 0 {
				lastMsg := req.Messages[len(req.Messages)-1]
				model.AIChatMessageInsert(bgCtx, l.svcCtx.DB, &model.AIChatMessage{
					ConversationId: req.ConversationId,
					Role:           lastMsg.Role,
					Content:        lastMsg.Content,
					ModelId:        req.Model,
				})
			}
			model.AIChatMessageInsert(bgCtx, l.svcCtx.DB, &model.AIChatMessage{
				ConversationId: req.ConversationId,
				Role:           "assistant",
				Content:        full.String(),
				TokenCount:     outputTokens,
				Cost:           actualCost,
				ModelId:        req.Model,
				LatencyMs:      latencyMs,
			})
		}
	}()

	return out, nil
}
// selectApiKey picks an API key for the provider: an active key owned by
// the user wins; otherwise the first system key is used.
// NOTE(review): the system-key fallback returns apiKeyId 0 (not the
// key's own Id) and does not filter on IsActive — confirm both are
// intentional before relying on the returned id for billing.
func (l *AiChatCompletionsLogic) selectApiKey(providerId, userId int64) (string, int64, error) {
	// Prefer the user's own active key; lookup errors simply fall through
	// to the system-key path (best-effort by design).
	if userKeys, err := model.AIApiKeyFindByProviderAndUser(l.ctx, l.svcCtx.DB, providerId, userId); err == nil {
		for _, k := range userKeys {
			if k.IsActive {
				return k.KeyValue, k.Id, nil
			}
		}
	}

	// Fall back to a system key (owned by userId=0).
	systemKeys, err := model.AIApiKeyFindSystemKeys(l.ctx, l.svcCtx.DB, providerId)
	if err != nil || len(systemKeys) == 0 {
		return "", 0, fmt.Errorf("no API key available for this provider")
	}
	return systemKeys[0].KeyValue, 0, nil
}
// estimateCost gives a rough pre-flight cost estimate for quota freezing:
// roughly 4 characters per input token, and MaxTokens (default 1000)
// assumed output tokens, priced per 1K tokens.
func (l *AiChatCompletionsLogic) estimateCost(req *types.AIChatCompletionRequest, aiModel *model.AIModel) float64 {
	var totalChars int
	for _, msg := range req.Messages {
		totalChars += len(msg.Content)
	}

	inTokens := totalChars / 4
	outTokens := 1000
	if req.MaxTokens > 0 {
		outTokens = req.MaxTokens
	}

	return float64(inTokens)/1000.0*aiModel.InputPrice +
		float64(outTokens)/1000.0*aiModel.OutputPrice
}
// calculateCost converts real token counts into a cost using the model's
// per-1K-token input and output prices.
func (l *AiChatCompletionsLogic) calculateCost(inputTokens, outputTokens int, aiModel *model.AIModel) float64 {
	inCost := float64(inputTokens) / 1000.0 * aiModel.InputPrice
	outCost := float64(outputTokens) / 1000.0 * aiModel.OutputPrice
	return inCost + outCost
}
// buildChatRequest maps the API request onto the provider's request type,
// defaulting MaxTokens to the model's limit and Temperature to 0.7.
// NOTE(review): Temperature <= 0 is treated as "unset", so an explicit
// temperature of 0 cannot be requested — confirm that is acceptable.
func (l *AiChatCompletionsLogic) buildChatRequest(req *types.AIChatCompletionRequest, aiModel *model.AIModel) *provider.ChatRequest {
	msgs := make([]provider.ChatMessage, 0, len(req.Messages))
	for _, m := range req.Messages {
		msgs = append(msgs, provider.ChatMessage{Role: m.Role, Content: m.Content})
	}

	out := &provider.ChatRequest{
		Model:       req.Model,
		Messages:    msgs,
		MaxTokens:   req.MaxTokens,
		Temperature: req.Temperature,
		Stream:      req.Stream,
	}
	if out.MaxTokens <= 0 {
		out.MaxTokens = aiModel.MaxTokens
	}
	if out.Temperature <= 0 {
		out.Temperature = 0.7
	}
	return out
}
// saveMessages saves user and assistant messages to conversation
func (l *AiChatCompletionsLogic) saveMessages(req *types.AIChatCompletionRequest, resp *provider.ChatResponse, aiModel *model.AIModel, latencyMs int) {
actualCost := l.calculateCost(resp.InputTokens, resp.OutputTokens, aiModel)
// Save user message
if len(req.Messages) > 0 {
lastMsg := req.Messages[len(req.Messages)-1]
model.AIChatMessageInsert(l.ctx, l.svcCtx.DB, &model.AIChatMessage{
ConversationId: req.ConversationId,
Role: lastMsg.Role,
Content: lastMsg.Content,
ModelId: req.Model,
})
}
// Save assistant message
model.AIChatMessageInsert(l.ctx, l.svcCtx.DB, &model.AIChatMessage{
ConversationId: req.ConversationId,
Role: "assistant",
Content: resp.Content,
TokenCount: resp.OutputTokens,
Cost: actualCost,
ModelId: req.Model,
LatencyMs: latencyMs,
})
return nil // Update conversation stats
totalTokens := int64(resp.InputTokens + resp.OutputTokens)
billing.NewUsageService().UpdateConversationStats(l.ctx, l.svcCtx.DB, req.ConversationId, totalTokens, actualCost)
} }

Loading…
Cancel
Save