From 82bfcb7592acc2ede37af07fac123d4d5da2300e Mon Sep 17 00:00:00 2001 From: dark Date: Sat, 14 Feb 2026 22:08:26 +0800 Subject: [PATCH] feat: add AI model CRUD functions (7 models) - Provider/Model/ApiKey: standard CRUD + special lookups - Conversation: user-scoped pagination - ChatMessage: conversation-ordered retrieval - UsageRecord: insert + user-scoped pagination - UserQuota: freeze/settle/unfreeze atomic operations --- backend/model/ai_api_key_model.go | 87 +++++++++++++++++++++++ backend/model/ai_chat_message_model.go | 27 +++++++ backend/model/ai_conversation_model.go | 65 +++++++++++++++++ backend/model/ai_model_model.go | 98 ++++++++++++++++++++++++++ backend/model/ai_provider_model.go | 88 +++++++++++++++++++++++ backend/model/ai_usage_record_model.go | 39 ++++++++++ backend/model/ai_user_quota_model.go | 91 ++++++++++++++++++++++++ 7 files changed, 495 insertions(+) create mode 100644 backend/model/ai_api_key_model.go create mode 100644 backend/model/ai_chat_message_model.go create mode 100644 backend/model/ai_conversation_model.go create mode 100644 backend/model/ai_model_model.go create mode 100644 backend/model/ai_provider_model.go create mode 100644 backend/model/ai_usage_record_model.go create mode 100644 backend/model/ai_user_quota_model.go diff --git a/backend/model/ai_api_key_model.go b/backend/model/ai_api_key_model.go new file mode 100644 index 0000000..dfb56d9 --- /dev/null +++ b/backend/model/ai_api_key_model.go @@ -0,0 +1,87 @@ +package model + +import ( + "context" + "errors" + + "gorm.io/gorm" +) + +// AIApiKeyInsert 插入AI API密钥 +func AIApiKeyInsert(ctx context.Context, db *gorm.DB, apiKey *AIApiKey) (int64, error) { + result := db.WithContext(ctx).Create(apiKey) + if result.Error != nil { + return 0, result.Error + } + return apiKey.Id, nil +} + +// AIApiKeyFindOne 根据ID查询AI API密钥 +func AIApiKeyFindOne(ctx context.Context, db *gorm.DB, id int64) (*AIApiKey, error) { + var apiKey AIApiKey + result := db.WithContext(ctx).First(&apiKey, id) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, ErrNotFound + } + return nil, result.Error + } + return &apiKey, nil +} + +// AIApiKeyFindList 查询AI API密钥列表(分页) +func AIApiKeyFindList(ctx context.Context, db *gorm.DB, page, pageSize int64) ([]AIApiKey, int64, error) { + var apiKeys []AIApiKey + var total int64 + + query := db.WithContext(ctx).Model(&AIApiKey{}) + + if err := query.Count(&total).Error; err != nil { + return nil, 0, err + } + + offset := (page - 1) * pageSize + if offset < 0 { + offset = 0 + } + err := query.Order("created_at DESC").Offset(int(offset)).Limit(int(pageSize)).Find(&apiKeys).Error + if err != nil { + return nil, 0, err + } + + return apiKeys, total, nil +} + +// AIApiKeyUpdate 更新AI API密钥 +func AIApiKeyUpdate(ctx context.Context, db *gorm.DB, apiKey *AIApiKey) error { + result := db.WithContext(ctx).Save(apiKey) + return result.Error +} + +// AIApiKeyDelete 删除AI API密钥 +func AIApiKeyDelete(ctx context.Context, db *gorm.DB, id int64) error { + result := db.WithContext(ctx).Delete(&AIApiKey{}, id) + return result.Error +} + +// AIApiKeyFindByProviderAndUser 根据供应商ID和用户ID查询API密钥 +func AIApiKeyFindByProviderAndUser(ctx context.Context, db *gorm.DB, providerId, userId int64) ([]AIApiKey, error) { + var apiKeys []AIApiKey + err := db.WithContext(ctx).Where("provider_id = ? AND user_id = ?", providerId, userId). + Order("created_at DESC").Find(&apiKeys).Error + if err != nil { + return nil, err + } + return apiKeys, nil +} + +// AIApiKeyFindSystemKeys 查询系统级API密钥(userId=0为系统密钥) +func AIApiKeyFindSystemKeys(ctx context.Context, db *gorm.DB, providerId int64) ([]AIApiKey, error) { + var apiKeys []AIApiKey + err := db.WithContext(ctx).Where("provider_id = ? AND user_id = 0", providerId). + Order("created_at DESC").Find(&apiKeys).Error + if err != nil { + return nil, err + } + return apiKeys, nil +} diff --git a/backend/model/ai_chat_message_model.go b/backend/model/ai_chat_message_model.go new file mode 100644 index 0000000..fba8671 --- /dev/null +++ b/backend/model/ai_chat_message_model.go @@ -0,0 +1,27 @@ +package model + +import ( + "context" + + "gorm.io/gorm" +) + +// AIChatMessageInsert 插入AI聊天消息 +func AIChatMessageInsert(ctx context.Context, db *gorm.DB, message *AIChatMessage) (int64, error) { + result := db.WithContext(ctx).Create(message) + if result.Error != nil { + return 0, result.Error + } + return message.Id, nil +} + +// AIChatMessageFindByConversation 根据对话ID查询所有消息(按创建时间升序,不分页) +func AIChatMessageFindByConversation(ctx context.Context, db *gorm.DB, conversationId int64) ([]AIChatMessage, error) { + var messages []AIChatMessage + err := db.WithContext(ctx).Where("conversation_id = ?", conversationId). + Order("created_at ASC").Find(&messages).Error + if err != nil { + return nil, err + } + return messages, nil +} diff --git a/backend/model/ai_conversation_model.go b/backend/model/ai_conversation_model.go new file mode 100644 index 0000000..a41f392 --- /dev/null +++ b/backend/model/ai_conversation_model.go @@ -0,0 +1,65 @@ +package model + +import ( + "context" + "errors" + + "gorm.io/gorm" +) + +// AIConversationInsert 插入AI对话 +func AIConversationInsert(ctx context.Context, db *gorm.DB, conversation *AIConversation) (int64, error) { + result := db.WithContext(ctx).Create(conversation) + if result.Error != nil { + return 0, result.Error + } + return conversation.Id, nil +} + +// AIConversationFindOne 根据ID查询AI对话 +func AIConversationFindOne(ctx context.Context, db *gorm.DB, id int64) (*AIConversation, error) { + var conversation AIConversation + result := db.WithContext(ctx).First(&conversation, id) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, ErrNotFound + } + return nil, result.Error + } + return &conversation, nil +} + +// AIConversationUpdate 更新AI对话 +func AIConversationUpdate(ctx context.Context, db *gorm.DB, conversation *AIConversation) error { + result := db.WithContext(ctx).Save(conversation) + return result.Error +} + +// AIConversationDelete 删除AI对话 +func AIConversationDelete(ctx context.Context, db *gorm.DB, id int64) error { + result := db.WithContext(ctx).Delete(&AIConversation{}, id) + return result.Error +} + +// AIConversationFindByUser 根据用户ID查询对话列表(分页,按更新时间倒序) +func AIConversationFindByUser(ctx context.Context, db *gorm.DB, userId int64, page, pageSize int64) ([]AIConversation, int64, error) { + var conversations []AIConversation + var total int64 + + query := db.WithContext(ctx).Model(&AIConversation{}).Where("user_id = ?", userId) + + if err := query.Count(&total).Error; err != nil { + return nil, 0, err + } + + offset := (page - 1) * pageSize + if offset < 0 { + offset = 0 + } + err := query.Order("updated_at DESC").Offset(int(offset)).Limit(int(pageSize)).Find(&conversations).Error + if err != nil { + return nil, 0, err + } + + return conversations, total, nil +} diff --git a/backend/model/ai_model_model.go b/backend/model/ai_model_model.go new file mode 100644 index 0000000..c2e4098 --- /dev/null +++ b/backend/model/ai_model_model.go @@ -0,0 +1,98 @@ +package model + +import ( + "context" + "errors" + + "gorm.io/gorm" +) + +// AIModelInsert 插入AI模型 +func AIModelInsert(ctx context.Context, db *gorm.DB, aiModel *AIModel) (int64, error) { + result := db.WithContext(ctx).Create(aiModel) + if result.Error != nil { + return 0, result.Error + } + return aiModel.Id, nil +} + +// AIModelFindOne 根据ID查询AI模型 +func AIModelFindOne(ctx context.Context, db *gorm.DB, id int64) (*AIModel, error) { + var aiModel AIModel + result := db.WithContext(ctx).First(&aiModel, id) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, ErrNotFound + } + return nil, result.Error + } + return &aiModel, nil +} + +// AIModelFindList 查询AI模型列表(分页) +func AIModelFindList(ctx context.Context, db *gorm.DB, page, pageSize int64) ([]AIModel, int64, error) { + var models []AIModel + var total int64 + + query := db.WithContext(ctx).Model(&AIModel{}) + + if err := query.Count(&total).Error; err != nil { + return nil, 0, err + } + + offset := (page - 1) * pageSize + if offset < 0 { + offset = 0 + } + err := query.Order("sort_order ASC, id ASC").Offset(int(offset)).Limit(int(pageSize)).Find(&models).Error + if err != nil { + return nil, 0, err + } + + return models, total, nil +} + +// AIModelUpdate 更新AI模型 +func AIModelUpdate(ctx context.Context, db *gorm.DB, aiModel *AIModel) error { + result := db.WithContext(ctx).Save(aiModel) + return result.Error +} + +// AIModelDelete 删除AI模型 +func AIModelDelete(ctx context.Context, db *gorm.DB, id int64) error { + result := db.WithContext(ctx).Delete(&AIModel{}, id) + return result.Error +} + +// AIModelFindByModelId 根据模型标识查询AI模型 +func AIModelFindByModelId(ctx context.Context, db *gorm.DB, modelId string) (*AIModel, error) { + var aiModel AIModel + result := db.WithContext(ctx).Where("model_id = ?", modelId).First(&aiModel) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, ErrNotFound + } + return nil, result.Error + } + return &aiModel, nil +} + +// AIModelFindByProvider 根据供应商ID查询AI模型列表 +func AIModelFindByProvider(ctx context.Context, db *gorm.DB, providerId int64) ([]AIModel, error) { + var models []AIModel + err := db.WithContext(ctx).Where("provider_id = ?", providerId).Order("sort_order ASC, id ASC").Find(&models).Error + if err != nil { + return nil, err + } + return models, nil +} + +// AIModelFindAllActive 查询所有启用的AI模型 +func AIModelFindAllActive(ctx context.Context, db *gorm.DB) ([]AIModel, error) { + var models []AIModel + err := db.WithContext(ctx).Where("is_active = ?", true).Order("sort_order ASC, id ASC").Find(&models).Error + if err != nil { + return nil, err + } + return models, nil +} diff --git a/backend/model/ai_provider_model.go b/backend/model/ai_provider_model.go new file mode 100644 index 0000000..f5be5b5 --- /dev/null +++ b/backend/model/ai_provider_model.go @@ -0,0 +1,88 @@ +package model + +import ( + "context" + "errors" + + "gorm.io/gorm" +) + +// AIProviderInsert 插入AI供应商 +func AIProviderInsert(ctx context.Context, db *gorm.DB, provider *AIProvider) (int64, error) { + result := db.WithContext(ctx).Create(provider) + if result.Error != nil { + return 0, result.Error + } + return provider.Id, nil +} + +// AIProviderFindOne 根据ID查询AI供应商 +func AIProviderFindOne(ctx context.Context, db *gorm.DB, id int64) (*AIProvider, error) { + var provider AIProvider + result := db.WithContext(ctx).First(&provider, id) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, ErrNotFound + } + return nil, result.Error + } + return &provider, nil +} + +// AIProviderFindList 查询AI供应商列表(分页) +func AIProviderFindList(ctx context.Context, db *gorm.DB, page, pageSize int64) ([]AIProvider, int64, error) { + var providers []AIProvider + var total int64 + + query := db.WithContext(ctx).Model(&AIProvider{}) + + if err := query.Count(&total).Error; err != nil { + return nil, 0, err + } + + offset := (page - 1) * pageSize + if offset < 0 { + offset = 0 + } + err := query.Order("sort_order ASC, id ASC").Offset(int(offset)).Limit(int(pageSize)).Find(&providers).Error + if err != nil { + return nil, 0, err + } + + return providers, total, nil +} + +// AIProviderUpdate 更新AI供应商 +func AIProviderUpdate(ctx context.Context, db *gorm.DB, provider *AIProvider) error { + result := db.WithContext(ctx).Save(provider) + return result.Error +} + +// AIProviderDelete 删除AI供应商 +func AIProviderDelete(ctx context.Context, db *gorm.DB, id int64) error { + result := db.WithContext(ctx).Delete(&AIProvider{}, id) + return result.Error +} + +// AIProviderFindByName 根据名称查询AI供应商 +func AIProviderFindByName(ctx context.Context, db *gorm.DB, name string) (*AIProvider, error) { + var provider AIProvider + result := db.WithContext(ctx).Where("name = ?", name).First(&provider) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, ErrNotFound + } + return nil, result.Error + } + return &provider, nil +} + +// AIProviderFindAllActive 查询所有启用的AI供应商 +func AIProviderFindAllActive(ctx context.Context, db *gorm.DB) ([]AIProvider, error) { + var providers []AIProvider + err := db.WithContext(ctx).Where("is_active = ?", true).Order("sort_order ASC, id ASC").Find(&providers).Error + if err != nil { + return nil, err + } + return providers, nil +} diff --git a/backend/model/ai_usage_record_model.go b/backend/model/ai_usage_record_model.go new file mode 100644 index 0000000..e6ea645 --- /dev/null +++ b/backend/model/ai_usage_record_model.go @@ -0,0 +1,39 @@ +package model + +import ( + "context" + + "gorm.io/gorm" +) + +// AIUsageRecordInsert 插入AI使用记录 +func AIUsageRecordInsert(ctx context.Context, db *gorm.DB, record *AIUsageRecord) (int64, error) { + result := db.WithContext(ctx).Create(record) + if result.Error != nil { + return 0, result.Error + } + return record.Id, nil +} + +// AIUsageRecordFindByUser 根据用户ID查询使用记录(分页,按创建时间倒序) +func AIUsageRecordFindByUser(ctx context.Context, db *gorm.DB, userId int64, page, pageSize int64) ([]AIUsageRecord, int64, error) { + var records []AIUsageRecord + var total int64 + + query := db.WithContext(ctx).Model(&AIUsageRecord{}).Where("user_id = ?", userId) + + if err := query.Count(&total).Error; err != nil { + return nil, 0, err + } + + offset := (page - 1) * pageSize + if offset < 0 { + offset = 0 + } + err := query.Order("created_at DESC").Offset(int(offset)).Limit(int(pageSize)).Find(&records).Error + if err != nil { + return nil, 0, err + } + + return records, total, nil +} diff --git a/backend/model/ai_user_quota_model.go b/backend/model/ai_user_quota_model.go new file mode 100644 index 0000000..721254c --- /dev/null +++ b/backend/model/ai_user_quota_model.go @@ -0,0 +1,91 @@ +package model + +import ( + "context" + "errors" + + "gorm.io/gorm" +) + +// AIUserQuotaFindByUser 根据用户ID查询额度 +func AIUserQuotaFindByUser(ctx context.Context, db *gorm.DB, userId int64) (*AIUserQuota, error) { + var quota AIUserQuota + result := db.WithContext(ctx).Where("user_id = ?", userId).First("a) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, ErrNotFound + } + return nil, result.Error + } + return "a, nil +} + +// AIUserQuotaEnsure 查找或创建用户额度记录 +func AIUserQuotaEnsure(ctx context.Context, db *gorm.DB, userId int64) (*AIUserQuota, error) { + var quota AIUserQuota + result := db.WithContext(ctx).Where("user_id = ?", userId).First("a) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + quota = AIUserQuota{UserId: userId} + if err := db.WithContext(ctx).Create("a).Error; err != nil { + return nil, err + } + return "a, nil + } + return nil, result.Error + } + return "a, nil +} + +// AIUserQuotaFreeze 冻结额度(原子操作:balance -= amount, frozen_amount += amount) +func AIUserQuotaFreeze(ctx context.Context, db *gorm.DB, userId int64, amount float64) error { + result := db.WithContext(ctx).Model(&AIUserQuota{}). + Where("user_id = ? AND balance >= ?", userId, amount). + Updates(map[string]interface{}{ + "balance": gorm.Expr("balance - ?", amount), + "frozen_amount": gorm.Expr("frozen_amount + ?", amount), + }) + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + return errors.New("insufficient balance") + } + return nil +} + +// AIUserQuotaSettle 结算额度(原子操作:frozen_amount -= frozenAmount, total_consumed += actualCost, balance += refund) +func AIUserQuotaSettle(ctx context.Context, db *gorm.DB, userId int64, frozenAmount, actualCost float64) error { + refund := frozenAmount - actualCost + result := db.WithContext(ctx).Model(&AIUserQuota{}). + Where("user_id = ?", userId). + Updates(map[string]interface{}{ + "frozen_amount": gorm.Expr("frozen_amount - ?", frozenAmount), + "total_consumed": gorm.Expr("total_consumed + ?", actualCost), + "balance": gorm.Expr("balance + ?", refund), + }) + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + return ErrNotFound + } + return nil +} + +// AIUserQuotaUnfreeze 解冻额度(原子操作:frozen_amount -= amount, balance += amount) +func AIUserQuotaUnfreeze(ctx context.Context, db *gorm.DB, userId int64, amount float64) error { + result := db.WithContext(ctx).Model(&AIUserQuota{}). + Where("user_id = ?", userId). + Updates(map[string]interface{}{ + "frozen_amount": gorm.Expr("frozen_amount - ?", amount), + "balance": gorm.Expr("balance + ?", amount), + }) + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + return ErrNotFound + } + return nil +}