package auth import ( "context" "crypto/rand" "encoding/base64" "encoding/hex" "encoding/json" "fmt" "io" "net/http" "net/url" "strings" "time" "github.com/youruser/base/internal/svc" "github.com/youruser/base/internal/util/jwt" "github.com/youruser/base/model" "github.com/zeromicro/go-zero/core/logx" ) // casdoorHttpClient 用于与 Casdoor 通信的 HTTP 客户端(带超时) var casdoorHttpClient = &http.Client{Timeout: 10 * time.Second} type SSOLogic struct { logx.Logger ctx context.Context svcCtx *svc.ServiceContext } func NewSSOLogic(ctx context.Context, svcCtx *svc.ServiceContext) *SSOLogic { return &SSOLogic{ Logger: logx.WithContext(ctx), ctx: ctx, svcCtx: svcCtx, } } // GetLoginUrl 生成 Casdoor SSO 登录链接 func (l *SSOLogic) GetLoginUrl() (map[string]string, error) { c := l.svcCtx.Config.Casdoor state, err := generateState() if err != nil { return nil, fmt.Errorf("生成 state 失败: %v", err) } loginUrl := fmt.Sprintf("%s/login/oauth/authorize?client_id=%s&response_type=code&redirect_uri=%s&scope=read&state=%s", c.Endpoint, url.QueryEscape(c.ClientId), url.QueryEscape(c.RedirectUrl), url.QueryEscape(state), ) return map[string]string{ "login_url": loginUrl, }, nil } // casdoorTokenResponse Casdoor token 响应 type casdoorTokenResponse struct { AccessToken string `json:"access_token"` TokenType string `json:"token_type"` ExpiresIn int `json:"expires_in"` Scope string `json:"scope"` } // casdoorUserInfo Casdoor 用户信息 type casdoorUserInfo struct { Sub string `json:"sub"` Name string `json:"name"` PreferredUsername string `json:"preferred_username"` Email string `json:"email"` Phone string `json:"phone"` Avatar string `json:"avatar"` } // HandleCallback 处理 SSO 回调 func (l *SSOLogic) HandleCallback(code, state string) (string, error) { if code == "" { return "", fmt.Errorf("缺少授权码") } c := l.svcCtx.Config.Casdoor // 1. 用 code 换取 access_token accessToken, err := l.exchangeToken(code) if err != nil { l.Errorf("SSO token 交换失败: %v", err) return "", fmt.Errorf("token 交换失败: %v", err) } // 2. 获取用户信息 userInfo, err := l.getUserInfo(accessToken) if err != nil { l.Errorf("SSO 获取用户信息失败: %v", err) return "", fmt.Errorf("获取用户信息失败: %v", err) } // 3. 查找或创建本地用户 casdoorId := userInfo.Sub if casdoorId == "" { casdoorId = userInfo.Name } // 从 Casdoor 信息中提取用户名和邮箱 username := userInfo.PreferredUsername if username == "" { username = userInfo.Name } email := userInfo.Email if email == "" { email = username + "@sso.local" } localUser, err := model.FindOneByCasdoorId(l.ctx, l.svcCtx.DB, casdoorId) if err != nil { if err == model.ErrNotFound { // 用户不存在,尝试通过邮箱关联已有本地用户 existingUser, findErr := model.FindOneByEmail(l.ctx, l.svcCtx.DB, email) if findErr == nil { existingUser.CasdoorId = casdoorId existingUser.UserType = "casdoor" if updateErr := model.Update(l.ctx, l.svcCtx.DB, existingUser); updateErr != nil { return "", fmt.Errorf("关联用户失败: %v", updateErr) } localUser = existingUser l.Infof("SSO 关联已有用户: userId=%d, casdoorId=%s", existingUser.Id, casdoorId) } else { // 创建新用户 newUser := &model.User{ Username: username, Email: email, Password: "SSO_NO_PASSWORD", // SSO 用户不使用密码登录 Phone: userInfo.Phone, CasdoorId: casdoorId, UserType: "casdoor", Role: model.RoleUser, Source: model.SourceCasdoor, Status: 1, } _, insertErr := model.Insert(l.ctx, l.svcCtx.DB, newUser) if insertErr != nil { l.Errorf("SSO 创建用户失败: %v", insertErr) return "", fmt.Errorf("创建用户失败: %v", insertErr) } localUser = newUser l.Infof("SSO 新用户创建成功: username=%s, casdoorId=%s", username, casdoorId) } } else { return "", fmt.Errorf("查询用户失败: %v", err) } } else { // 已有用户,同步更新 Casdoor 端的最新信息 updated := false if username != "" && localUser.Username != username { localUser.Username = username updated = true } if email != "" && localUser.Email != email { localUser.Email = email updated = true } if userInfo.Phone != "" && localUser.Phone != userInfo.Phone { localUser.Phone = userInfo.Phone updated = true } if updated { if updateErr := model.Update(l.ctx, l.svcCtx.DB, localUser); updateErr != nil { l.Errorf("SSO 同步用户信息失败: %v", updateErr) } } } // 4. 生成本地 JWT Token token, err := jwt.GenerateToken(localUser.Id, localUser.Username, localUser.Role) if err != nil { return "", fmt.Errorf("生成 Token 失败: %v", err) } l.Infof("SSO 登录成功: userId=%d, username=%s", localUser.Id, localUser.Username) // 5. 构建前端回调 URL redirectUrl := fmt.Sprintf("%s/sso/callback?token=%s", c.FrontendUrl, url.QueryEscape(token), ) return redirectUrl, nil } // exchangeToken 用授权码换取 access_token func (l *SSOLogic) exchangeToken(code string) (string, error) { c := l.svcCtx.Config.Casdoor tokenUrl := fmt.Sprintf("%s/api/login/oauth/access_token", c.Endpoint) data := url.Values{} data.Set("grant_type", "authorization_code") data.Set("client_id", c.ClientId) data.Set("client_secret", c.ClientSecret) data.Set("code", code) data.Set("redirect_uri", c.RedirectUrl) req, err := http.NewRequestWithContext(l.ctx, http.MethodPost, tokenUrl, strings.NewReader(data.Encode())) if err != nil { return "", fmt.Errorf("创建请求失败: %v", err) } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") resp, err := casdoorHttpClient.Do(req) if err != nil { return "", fmt.Errorf("请求 token 失败: %v", err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { return "", fmt.Errorf("读取响应失败: %v", err) } if resp.StatusCode != http.StatusOK { return "", fmt.Errorf("token 请求返回 %d: %s", resp.StatusCode, string(body)) } var tokenResp casdoorTokenResponse if err := json.Unmarshal(body, &tokenResp); err != nil { return "", fmt.Errorf("解析 token 响应失败: %v", err) } if tokenResp.AccessToken == "" { return "", fmt.Errorf("未获取到 access_token, 响应: %s", string(body)) } return tokenResp.AccessToken, nil } // getUserInfo 从 access_token JWT 中解析用户信息 // Casdoor 的 access_token 本身是一个 JWT,包含完整的用户 claims func (l *SSOLogic) getUserInfo(accessToken string) (*casdoorUserInfo, error) { // 解析 JWT payload(不验证签名,因为 token 刚从 Casdoor 获取) parts := strings.Split(accessToken, ".") if len(parts) != 3 { return nil, fmt.Errorf("access_token 不是有效的 JWT 格式") } // Base64 解码 payload payload := parts[1] // 补齐 base64 padding switch len(payload) % 4 { case 2: payload += "==" case 3: payload += "=" } decoded, err := base64.URLEncoding.DecodeString(payload) if err != nil { return nil, fmt.Errorf("解码 JWT payload 失败: %v", err) } // 解析 JWT claims var claims map[string]interface{} if err := json.Unmarshal(decoded, &claims); err != nil { return nil, fmt.Errorf("解析 JWT claims 失败: %v", err) } // 从 claims 中提取用户信息(Casdoor JWT 字段名) userInfo := &casdoorUserInfo{ Sub: getStringClaim(claims, "sub"), } // Casdoor JWT 中用户名可能在 name 或 preferred_username 字段 userInfo.Name = getStringClaim(claims, "name") userInfo.PreferredUsername = getStringClaim(claims, "preferred_username") userInfo.Email = getStringClaim(claims, "email") userInfo.Phone = getStringClaim(claims, "phone") userInfo.Avatar = getStringClaim(claims, "avatar") return userInfo, nil } // getStringClaim 从 claims map 中安全获取字符串值 func getStringClaim(claims map[string]interface{}, key string) string { if v, ok := claims[key]; ok { if s, ok := v.(string); ok { return s } } return "" } // generateState 生成随机 state 参数(CSRF 防护) func generateState() (string, error) { b := make([]byte, 16) if _, err := rand.Read(b); err != nil { return "", err } return hex.EncodeToString(b), nil }