diff --git a/backend/internal/logic/ai/aiconversationcreatelogic.go b/backend/internal/logic/ai/aiconversationcreatelogic.go index 1ee8c94..68461f5 100644 --- a/backend/internal/logic/ai/aiconversationcreatelogic.go +++ b/backend/internal/logic/ai/aiconversationcreatelogic.go @@ -5,9 +5,11 @@ package ai import ( "context" + "errors" "github.com/youruser/base/internal/svc" "github.com/youruser/base/internal/types" + "github.com/youruser/base/model" "github.com/zeromicro/go-zero/core/logx" ) @@ -28,7 +30,41 @@ func NewAiConversationCreateLogic(ctx context.Context, svcCtx *svc.ServiceContex } func (l *AiConversationCreateLogic) AiConversationCreate(req *types.AIConversationCreateRequest) (resp *types.AIConversationInfo, err error) { - // todo: add your logic here and delete this line + userId, _ := l.ctx.Value("userId").(int64) + if userId == 0 { + return nil, errors.New("unauthorized") + } + + title := req.Title + if title == "" { + title = "新对话" + } + + conv := &model.AIConversation{ + UserId: userId, + Title: title, + ModelId: req.ModelId, + } + + // Look up model to get provider ID + if req.ModelId != "" { + aiModel, err := model.AIModelFindByModelId(l.ctx, l.svcCtx.DB, req.ModelId) + if err == nil { + conv.ProviderId = aiModel.ProviderId + } + } + + _, err = model.AIConversationInsert(l.ctx, l.svcCtx.DB, conv) + if err != nil { + return nil, err + } - return + return &types.AIConversationInfo{ + Id: conv.Id, + Title: conv.Title, + ModelId: conv.ModelId, + ProviderId: conv.ProviderId, + CreatedAt: conv.CreatedAt.Format("2006-01-02 15:04:05"), + UpdatedAt: conv.UpdatedAt.Format("2006-01-02 15:04:05"), + }, nil } diff --git a/backend/internal/logic/ai/aiconversationdeletelogic.go b/backend/internal/logic/ai/aiconversationdeletelogic.go index 7eeb69f..987f1d4 100644 --- a/backend/internal/logic/ai/aiconversationdeletelogic.go +++ b/backend/internal/logic/ai/aiconversationdeletelogic.go @@ -5,9 +5,11 @@ package ai import ( "context" + "errors" "github.com/youruser/base/internal/svc" "github.com/youruser/base/internal/types" + "github.com/youruser/base/model" "github.com/zeromicro/go-zero/core/logx" ) @@ -28,7 +30,22 @@ func NewAiConversationDeleteLogic(ctx context.Context, svcCtx *svc.ServiceContex } func (l *AiConversationDeleteLogic) AiConversationDelete(req *types.AIConversationDeleteRequest) (resp *types.Response, err error) { - // todo: add your logic here and delete this line + userId, _ := l.ctx.Value("userId").(int64) + if userId == 0 { + return nil, errors.New("unauthorized") + } + + conv, err := model.AIConversationFindOne(l.ctx, l.svcCtx.DB, req.Id) + if err != nil { + return nil, err + } + if conv.UserId != userId { + return nil, errors.New("forbidden") + } + + if err := model.AIConversationDelete(l.ctx, l.svcCtx.DB, req.Id); err != nil { + return nil, err + } - return + return &types.Response{Success: true}, nil } diff --git a/backend/internal/logic/ai/aiconversationgetlogic.go b/backend/internal/logic/ai/aiconversationgetlogic.go index 936449e..06b2ea3 100644 --- a/backend/internal/logic/ai/aiconversationgetlogic.go +++ b/backend/internal/logic/ai/aiconversationgetlogic.go @@ -5,9 +5,11 @@ package ai import ( "context" + "errors" "github.com/youruser/base/internal/svc" "github.com/youruser/base/internal/types" + "github.com/youruser/base/model" "github.com/zeromicro/go-zero/core/logx" ) @@ -28,7 +30,52 @@ func NewAiConversationGetLogic(ctx context.Context, svcCtx *svc.ServiceContext) } func (l *AiConversationGetLogic) AiConversationGet(req *types.AIConversationGetRequest) (resp *types.AIConversationDetailResponse, err error) { - // todo: add your logic here and delete this line + userId, _ := l.ctx.Value("userId").(int64) + if userId == 0 { + return nil, errors.New("unauthorized") + } + + conv, err := model.AIConversationFindOne(l.ctx, l.svcCtx.DB, req.Id) + if err != nil { + return nil, err + } + if conv.UserId != userId { + return nil, errors.New("forbidden") + } + + // Get messages + messages, err := model.AIChatMessageFindByConversation(l.ctx, l.svcCtx.DB, conv.Id) + if err != nil { + return nil, err + } + + msgList := make([]types.AIMessageInfo, len(messages)) + for i, m := range messages { + msgList[i] = types.AIMessageInfo{ + Id: m.Id, + ConversationId: m.ConversationId, + Role: m.Role, + Content: m.Content, + TokenCount: m.TokenCount, + Cost: m.Cost, + ModelId: m.ModelId, + LatencyMs: m.LatencyMs, + CreatedAt: m.CreatedAt.Format("2006-01-02 15:04:05"), + } + } - return + return &types.AIConversationDetailResponse{ + Conversation: types.AIConversationInfo{ + Id: conv.Id, + Title: conv.Title, + ModelId: conv.ModelId, + ProviderId: conv.ProviderId, + TotalTokens: conv.TotalTokens, + TotalCost: conv.TotalCost, + IsArchived: conv.IsArchived, + CreatedAt: conv.CreatedAt.Format("2006-01-02 15:04:05"), + UpdatedAt: conv.UpdatedAt.Format("2006-01-02 15:04:05"), + }, + Messages: msgList, + }, nil } diff --git a/backend/internal/logic/ai/aiconversationlistlogic.go b/backend/internal/logic/ai/aiconversationlistlogic.go index de650ad..52a7295 100644 --- a/backend/internal/logic/ai/aiconversationlistlogic.go +++ b/backend/internal/logic/ai/aiconversationlistlogic.go @@ -5,9 +5,11 @@ package ai import ( "context" + "errors" "github.com/youruser/base/internal/svc" "github.com/youruser/base/internal/types" + "github.com/youruser/base/model" "github.com/zeromicro/go-zero/core/logx" ) @@ -28,7 +30,30 @@ func NewAiConversationListLogic(ctx context.Context, svcCtx *svc.ServiceContext) } func (l *AiConversationListLogic) AiConversationList(req *types.AIConversationListRequest) (resp *types.AIConversationListResponse, err error) { - // todo: add your logic here and delete this line + userId, _ := l.ctx.Value("userId").(int64) + if userId == 0 { + return nil, errors.New("unauthorized") + } + + conversations, total, err := model.AIConversationFindByUser(l.ctx, l.svcCtx.DB, userId, req.Page, req.PageSize) + if err != nil { + return nil, err + } + + list := make([]types.AIConversationInfo, len(conversations)) + for i, c := range conversations { + list[i] = types.AIConversationInfo{ + Id: c.Id, + Title: c.Title, + ModelId: c.ModelId, + ProviderId: c.ProviderId, + TotalTokens: c.TotalTokens, + TotalCost: c.TotalCost, + IsArchived: c.IsArchived, + CreatedAt: c.CreatedAt.Format("2006-01-02 15:04:05"), + UpdatedAt: c.UpdatedAt.Format("2006-01-02 15:04:05"), + } + } - return + return &types.AIConversationListResponse{List: list, Total: total}, nil } diff --git a/backend/internal/logic/ai/aiconversationupdatelogic.go b/backend/internal/logic/ai/aiconversationupdatelogic.go index 724d1f8..7418e12 100644 --- a/backend/internal/logic/ai/aiconversationupdatelogic.go +++ b/backend/internal/logic/ai/aiconversationupdatelogic.go @@ -5,9 +5,11 @@ package ai import ( "context" + "errors" "github.com/youruser/base/internal/svc" "github.com/youruser/base/internal/types" + "github.com/youruser/base/model" "github.com/zeromicro/go-zero/core/logx" ) @@ -28,7 +30,33 @@ func NewAiConversationUpdateLogic(ctx context.Context, svcCtx *svc.ServiceContex } func (l *AiConversationUpdateLogic) AiConversationUpdate(req *types.AIConversationUpdateRequest) (resp *types.AIConversationInfo, err error) { - // todo: add your logic here and delete this line + userId, _ := l.ctx.Value("userId").(int64) + if userId == 0 { + return nil, errors.New("unauthorized") + } + + conv, err := model.AIConversationFindOne(l.ctx, l.svcCtx.DB, req.Id) + if err != nil { + return nil, err + } + if conv.UserId != userId { + return nil, errors.New("forbidden") + } + + conv.Title = req.Title + if err := model.AIConversationUpdate(l.ctx, l.svcCtx.DB, conv); err != nil { + return nil, err + } - return + return &types.AIConversationInfo{ + Id: conv.Id, + Title: conv.Title, + ModelId: conv.ModelId, + ProviderId: conv.ProviderId, + TotalTokens: conv.TotalTokens, + TotalCost: conv.TotalCost, + IsArchived: conv.IsArchived, + CreatedAt: conv.CreatedAt.Format("2006-01-02 15:04:05"), + UpdatedAt: conv.UpdatedAt.Format("2006-01-02 15:04:05"), + }, nil } diff --git a/backend/internal/logic/ai/aimodellistlogic.go b/backend/internal/logic/ai/aimodellistlogic.go index 8d1295f..79ecd9c 100644 --- a/backend/internal/logic/ai/aimodellistlogic.go +++ b/backend/internal/logic/ai/aimodellistlogic.go @@ -8,6 +8,7 @@ import ( "github.com/youruser/base/internal/svc" "github.com/youruser/base/internal/types" + "github.com/youruser/base/model" "github.com/zeromicro/go-zero/core/logx" ) @@ -28,7 +29,34 @@ func NewAiModelListLogic(ctx context.Context, svcCtx *svc.ServiceContext) *AiMod } func (l *AiModelListLogic) AiModelList() (resp *types.AIModelListResponse, err error) { - // todo: add your logic here and delete this line + models, err := model.AIModelFindAllActive(l.ctx, l.svcCtx.DB) + if err != nil { + return nil, err + } + + // Build provider name map + providerNames := make(map[int64]string) + providers, _ := model.AIProviderFindAllActive(l.ctx, l.svcCtx.DB) + for _, p := range providers { + providerNames[p.Id] = p.DisplayName + } + + list := make([]types.AIModelInfo, len(models)) + for i, m := range models { + list[i] = types.AIModelInfo{ + Id: m.Id, + ProviderId: m.ProviderId, + ProviderName: providerNames[m.ProviderId], + ModelId: m.ModelId, + DisplayName: m.DisplayName, + InputPrice: m.InputPrice, + OutputPrice: m.OutputPrice, + MaxTokens: m.MaxTokens, + ContextWindow: m.ContextWindow, + SupportsStream: m.SupportsStream, + SupportsVision: m.SupportsVision, + } + } - return + return &types.AIModelListResponse{List: list}, nil } diff --git a/backend/internal/logic/ai/aiquotamelogic.go b/backend/internal/logic/ai/aiquotamelogic.go index 4b3b0fc..e22c1d2 100644 --- a/backend/internal/logic/ai/aiquotamelogic.go +++ b/backend/internal/logic/ai/aiquotamelogic.go @@ -5,9 +5,11 @@ package ai import ( "context" + "errors" "github.com/youruser/base/internal/svc" "github.com/youruser/base/internal/types" + "github.com/youruser/base/model" "github.com/zeromicro/go-zero/core/logx" ) @@ -28,7 +30,20 @@ func NewAiQuotaMeLogic(ctx context.Context, svcCtx *svc.ServiceContext) *AiQuota } func (l *AiQuotaMeLogic) AiQuotaMe() (resp *types.AIQuotaInfo, err error) { - // todo: add your logic here and delete this line + userId, _ := l.ctx.Value("userId").(int64) + if userId == 0 { + return nil, errors.New("unauthorized") + } + + quota, err := model.AIUserQuotaEnsure(l.ctx, l.svcCtx.DB, userId) + if err != nil { + return nil, err + } - return + return &types.AIQuotaInfo{ + Balance: quota.Balance, + TotalRecharged: quota.TotalRecharged, + TotalConsumed: quota.TotalConsumed, + FrozenAmount: quota.FrozenAmount, + }, nil }