You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

304 lines
8.3 KiB

package auth
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/youruser/base/internal/svc"
"github.com/youruser/base/internal/util/jwt"
"github.com/youruser/base/model"
"github.com/zeromicro/go-zero/core/logx"
)
// casdoorHttpClient is the single HTTP client shared by every call this file
// makes to Casdoor. Its Timeout caps each whole request (connect, redirects
// and reading the body), so a stalled Casdoor cannot hang a login forever.
var casdoorHttpClient = &http.Client{
	Timeout: 10 * time.Second,
}
// SSOLogic implements the Casdoor single sign-on flow (login URL, callback
// handling, token exchange), scoped to the context given to NewSSOLogic.
type SSOLogic struct {
	// Embedded logger built from ctx by NewSSOLogic; lets methods call
	// l.Infof / l.Errorf directly.
	logx.Logger
	// ctx is passed to DB lookups and outgoing HTTP requests.
	ctx context.Context
	// svcCtx supplies Config.Casdoor and DB.
	svcCtx *svc.ServiceContext
}

// NewSSOLogic returns an SSOLogic whose logger, DB calls and HTTP requests
// are all bound to ctx, using the shared dependencies in svcCtx.
func NewSSOLogic(ctx context.Context, svcCtx *svc.ServiceContext) *SSOLogic {
	logic := new(SSOLogic)
	logic.Logger = logx.WithContext(ctx)
	logic.ctx = ctx
	logic.svcCtx = svcCtx
	return logic
}
// GetLoginUrl builds the Casdoor OAuth2 authorize URL the browser should be
// sent to in order to start SSO login.
//
// The returned map contains:
//   - "login_url": <Endpoint>/login/oauth/authorize with client_id,
//     response_type=code, redirect_uri, scope=read and a fresh random state;
//   - "state": that same state value. It is exposed so the caller can bind it
//     to the user's session (e.g. an HttpOnly cookie) and compare it with the
//     state Casdoor echoes back to HandleCallback. Previously the value was
//     generated and then lost, so it could never be verified.
//
// NOTE(review): nothing in this file persists or checks state yet; the CSRF
// protection it is meant to provide is only effective once a caller does.
//
// An error is returned only if the random state cannot be generated.
func (l *SSOLogic) GetLoginUrl() (map[string]string, error) {
	c := l.svcCtx.Config.Casdoor
	state, err := generateState()
	if err != nil {
		return nil, fmt.Errorf("生成 state 失败: %w", err)
	}

	// url.Values escapes every value (same escaping as url.QueryEscape).
	query := url.Values{}
	query.Set("client_id", c.ClientId)
	query.Set("response_type", "code")
	query.Set("redirect_uri", c.RedirectUrl)
	query.Set("scope", "read")
	query.Set("state", state)

	// TrimRight tolerates an Endpoint configured with a trailing slash.
	loginUrl := strings.TrimRight(c.Endpoint, "/") + "/login/oauth/authorize?" + query.Encode()

	return map[string]string{
		"login_url": loginUrl,
		"state":     state,
	}, nil
}
// casdoorTokenResponse mirrors the JSON body returned by Casdoor's
// /api/login/oauth/access_token endpoint. Only AccessToken is read in this
// file; unknown JSON fields are ignored by encoding/json.
type casdoorTokenResponse struct {
	// AccessToken is the issued token; getUserInfo decodes it as a JWT.
	AccessToken string `json:"access_token"`
	// TokenType as reported by Casdoor.
	TokenType string `json:"token_type"`
	// ExpiresIn is the lifetime reported by Casdoor (presumably seconds — confirm).
	ExpiresIn int `json:"expires_in"`
	// Scope granted to the token.
	Scope string `json:"scope"`
}
// casdoorUserInfo holds the Casdoor user attributes this package uses. In
// this file it is filled by getUserInfo from the access-token JWT claims of
// the same names; the JSON tags document those claim names.
type casdoorUserInfo struct {
	// Sub is the "sub" claim; HandleCallback uses it as the Casdoor id.
	Sub string `json:"sub"`
	// Name is the "name" claim; fallback for both id and username.
	Name string `json:"name"`
	// PreferredUsername is the preferred local username when present.
	PreferredUsername string `json:"preferred_username"`
	// Email may be empty when Casdoor does not supply one.
	Email string `json:"email"`
	// Phone may be empty.
	Phone string `json:"phone"`
	// Avatar is read from the token but not used by HandleCallback.
	Avatar string `json:"avatar"`
}
// HandleCallback completes SSO login after Casdoor redirects back with an
// authorization code, and returns the frontend URL to redirect the browser to.
//
// Steps:
//  1. exchange code for a Casdoor access_token (exchangeToken);
//  2. read the user's claims from that token (getUserInfo);
//  3. map the Casdoor identity onto a local user (resolveSSOLocalUser);
//  4. issue a local JWT for that user;
//  5. return "<FrontendUrl>/sso/callback?token=<jwt>".
//
// Parameters:
//   - code:  authorization code from the callback; must be non-empty.
//   - state: state echoed back by Casdoor.
//
// NOTE(review): state is NOT compared against anything here, so the
// login-CSRF protection the state from GetLoginUrl is meant to give is not
// enforced. The caller must persist the issued state (e.g. in a cookie) and
// reject mismatches before calling this.
//
// NOTE(review): the local JWT travels in the URL query string, where it can
// leak into browser history, access logs and Referer headers.
//
// NOTE(review): localUser.Status is not checked. New SSO users get Status 1;
// if other values mean "disabled", such users can still log in via SSO —
// confirm.
func (l *SSOLogic) HandleCallback(code, state string) (string, error) {
	if code == "" {
		return "", fmt.Errorf("缺少授权码")
	}

	// 1. Exchange the authorization code for an access_token.
	accessToken, err := l.exchangeToken(code)
	if err != nil {
		l.Errorf("SSO token 交换失败: %v", err)
		return "", fmt.Errorf("token 交换失败: %w", err)
	}

	// 2. Read the user's identity from the token claims.
	userInfo, err := l.getUserInfo(accessToken)
	if err != nil {
		l.Errorf("SSO 获取用户信息失败: %v", err)
		return "", fmt.Errorf("获取用户信息失败: %w", err)
	}

	// 3. Resolve (find, link or create) the local user.
	casdoorId := userInfo.Sub
	if casdoorId == "" {
		casdoorId = userInfo.Name
	}
	// An empty id would otherwise be looked up via FindOneByCasdoorId(""). If
	// non-SSO users are stored with an empty CasdoorId, that would log the
	// caller in as one of them, so refuse instead.
	if casdoorId == "" {
		return "", fmt.Errorf("SSO 用户信息缺少 sub/name")
	}
	username := userInfo.PreferredUsername
	if username == "" {
		username = userInfo.Name
	}
	if username == "" {
		username = casdoorId
	}
	localUser, err := l.resolveSSOLocalUser(casdoorId, username, userInfo.Email, userInfo.Phone)
	if err != nil {
		return "", err
	}

	// 4. Issue the local JWT.
	token, err := jwt.GenerateToken(localUser.Id, localUser.Username, localUser.Role)
	if err != nil {
		return "", fmt.Errorf("生成 Token 失败: %w", err)
	}
	l.Infof("SSO 登录成功: userId=%d, username=%s", localUser.Id, localUser.Username)

	// 5. Build the frontend callback URL (trailing slash on FrontendUrl tolerated).
	c := l.svcCtx.Config.Casdoor
	redirectUrl := strings.TrimRight(c.FrontendUrl, "/") + "/sso/callback?token=" + url.QueryEscape(token)
	return redirectUrl, nil
}

// resolveSSOLocalUser returns the local user for a Casdoor identity. It tries,
// in order:
//  1. the user already bound to casdoorId (profile re-synced best-effort);
//  2. an existing user whose email equals the Casdoor-supplied email, which is
//     then bound to casdoorId;
//  3. a newly inserted user.
//
// casdoorEmail is the raw Casdoor email and may be empty. The synthetic
// "<username>@sso.local" fallback is only ever used for new users, never for
// linking or syncing.
func (l *SSOLogic) resolveSSOLocalUser(casdoorId, username, casdoorEmail, phone string) (*model.User, error) {
	localUser, err := model.FindOneByCasdoorId(l.ctx, l.svcCtx.DB, casdoorId)
	if err == nil {
		l.syncSSOUserProfile(localUser, username, casdoorEmail, phone)
		return localUser, nil
	}
	if err != model.ErrNotFound {
		return nil, fmt.Errorf("查询用户失败: %w", err)
	}

	// Link by email only when Casdoor actually supplied one. A synthetic
	// address must not bind this identity to whoever already owns it.
	// NOTE(review): this trusts Casdoor's email claim as proof of ownership of
	// the matching local account; confirm Casdoor verifies emails.
	if casdoorEmail != "" {
		existingUser, findErr := model.FindOneByEmail(l.ctx, l.svcCtx.DB, casdoorEmail)
		switch {
		case findErr == nil:
			existingUser.CasdoorId = casdoorId
			existingUser.UserType = "casdoor"
			if updateErr := model.Update(l.ctx, l.svcCtx.DB, existingUser); updateErr != nil {
				return nil, fmt.Errorf("关联用户失败: %w", updateErr)
			}
			l.Infof("SSO 关联已有用户: userId=%d, casdoorId=%s", existingUser.Id, casdoorId)
			return existingUser, nil
		case findErr != model.ErrNotFound:
			// A failed lookup is not "no such user"; creating one here could
			// duplicate an existing account.
			return nil, fmt.Errorf("查询用户失败: %w", findErr)
		}
	}

	email := casdoorEmail
	if email == "" {
		email = username + "@sso.local"
	}
	newUser := &model.User{
		Username:  username,
		Email:     email,
		Password:  "SSO_NO_PASSWORD", // SSO users never log in with a password
		Phone:     phone,
		CasdoorId: casdoorId,
		UserType:  "casdoor",
		Role:      model.RoleUser,
		Source:    model.SourceCasdoor,
		Status:    1,
	}
	// NOTE(review): Insert's first result is discarded. This assumes Insert
	// fills newUser.Id; otherwise the JWT is issued for Id 0 — confirm.
	if _, insertErr := model.Insert(l.ctx, l.svcCtx.DB, newUser); insertErr != nil {
		l.Errorf("SSO 创建用户失败: %v", insertErr)
		return nil, fmt.Errorf("创建用户失败: %w", insertErr)
	}
	l.Infof("SSO 新用户创建成功: username=%s, casdoorId=%s", username, casdoorId)
	return newUser, nil
}

// syncSSOUserProfile copies changed Casdoor profile fields onto a user already
// bound to Casdoor and persists them. Empty Casdoor values never overwrite
// local data. A failed update is only logged (best effort); u keeps the new
// in-memory values either way.
func (l *SSOLogic) syncSSOUserProfile(u *model.User, username, casdoorEmail, phone string) {
	updated := false
	if username != "" && u.Username != username {
		u.Username = username
		updated = true
	}
	if casdoorEmail != "" && u.Email != casdoorEmail {
		u.Email = casdoorEmail
		updated = true
	}
	if phone != "" && u.Phone != phone {
		u.Phone = phone
		updated = true
	}
	if !updated {
		return
	}
	if err := model.Update(l.ctx, l.svcCtx.DB, u); err != nil {
		l.Errorf("SSO 同步用户信息失败: %v", err)
	}
}
// exchangeToken trades an OAuth2 authorization code for a Casdoor
// access_token by POSTing a form (grant_type=authorization_code, client_id,
// client_secret, code, redirect_uri) to <Endpoint>/api/login/oauth/access_token.
// The request uses casdoorHttpClient and is bound to l.ctx.
//
// It returns an error when the request cannot be built or sent, the body
// cannot be read, the status is not 200, the body is not JSON, or no
// access_token is present. The non-200 and missing-token errors include the
// raw response body.
func (l *SSOLogic) exchangeToken(code string) (string, error) {
	// Cap on how much of the response is buffered. 1 MiB is far above any
	// expected token response, and the cap guards against an unbounded body.
	const maxTokenResponseBytes = 1 << 20

	c := l.svcCtx.Config.Casdoor
	tokenUrl := strings.TrimRight(c.Endpoint, "/") + "/api/login/oauth/access_token"

	form := url.Values{}
	form.Set("grant_type", "authorization_code")
	form.Set("client_id", c.ClientId)
	form.Set("client_secret", c.ClientSecret)
	form.Set("code", code)
	form.Set("redirect_uri", c.RedirectUrl)

	req, err := http.NewRequestWithContext(l.ctx, http.MethodPost, tokenUrl,
		strings.NewReader(form.Encode()))
	if err != nil {
		return "", fmt.Errorf("创建请求失败: %w", err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("Accept", "application/json")

	resp, err := casdoorHttpClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("请求 token 失败: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(io.LimitReader(resp.Body, maxTokenResponseBytes))
	if err != nil {
		return "", fmt.Errorf("读取响应失败: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("token 请求返回 %d: %s", resp.StatusCode, string(body))
	}

	var tokenResp casdoorTokenResponse
	if err := json.Unmarshal(body, &tokenResp); err != nil {
		return "", fmt.Errorf("解析 token 响应失败: %w", err)
	}
	// A 200 reply without access_token (e.g. an error object) is still a failure.
	if tokenResp.AccessToken == "" {
		return "", fmt.Errorf("未获取到 access_token, 响应: %s", string(body))
	}
	return tokenResp.AccessToken, nil
}
// getUserInfo extracts the Casdoor user's profile from the claims of
// accessToken, which Casdoor issues as a JWT (header.payload.signature).
//
// Only the payload segment is decoded. The signature, expiry (exp), issuer
// and audience are NOT verified. That is acceptable only because accessToken
// comes directly from exchangeToken's request to Casdoor. Never call this
// with a token received from a browser or any other untrusted party.
//
// Claims read: sub, name, preferred_username, email, phone, avatar. A claim
// that is missing or is not a JSON string becomes "".
//
// It returns an error if the token does not have three dot-separated parts,
// or if the payload is not valid base64url-encoded JSON.
func (l *SSOLogic) getUserInfo(accessToken string) (*casdoorUserInfo, error) {
	parts := strings.Split(accessToken, ".")
	if len(parts) != 3 {
		return nil, fmt.Errorf("access_token 不是有效的 JWT 格式")
	}

	// JWT segments are unpadded base64url (RFC 7515). Stray "=" padding is
	// trimmed so padded and unpadded payloads both decode.
	decoded, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(parts[1], "="))
	if err != nil {
		return nil, fmt.Errorf("解码 JWT payload 失败: %w", err)
	}

	var claims map[string]interface{}
	if err := json.Unmarshal(decoded, &claims); err != nil {
		return nil, fmt.Errorf("解析 JWT claims 失败: %w", err)
	}

	// Casdoor may put the username in name and/or preferred_username.
	return &casdoorUserInfo{
		Sub:               getStringClaim(claims, "sub"),
		Name:              getStringClaim(claims, "name"),
		PreferredUsername: getStringClaim(claims, "preferred_username"),
		Email:             getStringClaim(claims, "email"),
		Phone:             getStringClaim(claims, "phone"),
		Avatar:            getStringClaim(claims, "avatar"),
	}, nil
}
// getStringClaim returns claims[key] when that entry exists and holds a
// string. A missing key, a nil map, a nil value or a non-string value all
// yield "".
func getStringClaim(claims map[string]interface{}, key string) string {
	// The comma-ok assertion never panics; on any mismatch str is "".
	str, _ := claims[key].(string)
	return str
}
// generateState returns a fresh OAuth2 "state" value for CSRF protection:
// 16 bytes from crypto/rand, hex-encoded into 32 lowercase characters. Any
// error from crypto/rand is returned unchanged.
func generateState() (string, error) {
	var raw [16]byte
	_, err := rand.Read(raw[:])
	if err != nil {
		return "", err
	}
	return hex.EncodeToString(raw[:]), nil
}