You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
304 lines
8.3 KiB
304 lines
8.3 KiB
package auth
|
|
|
|
import (
	"context"
	"crypto/rand"
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/youruser/base/internal/svc"
	"github.com/youruser/base/internal/util/jwt"
	"github.com/youruser/base/model"

	"github.com/zeromicro/go-zero/core/logx"
)
|
|
|
|
// casdoorHttpClient is the single HTTP client shared by every request this
// file sends to the Casdoor server. Reusing one client lets the default
// transport pool connections; the 10-second Timeout bounds each request,
// including reading the response body.
var casdoorHttpClient = &http.Client{
	Timeout: time.Duration(10) * time.Second,
}
|
|
|
|
type SSOLogic struct {
|
|
logx.Logger
|
|
ctx context.Context
|
|
svcCtx *svc.ServiceContext
|
|
}
|
|
|
|
func NewSSOLogic(ctx context.Context, svcCtx *svc.ServiceContext) *SSOLogic {
|
|
return &SSOLogic{
|
|
Logger: logx.WithContext(ctx),
|
|
ctx: ctx,
|
|
svcCtx: svcCtx,
|
|
}
|
|
}
|
|
|
|
// GetLoginUrl 生成 Casdoor SSO 登录链接
|
|
func (l *SSOLogic) GetLoginUrl() (map[string]string, error) {
|
|
c := l.svcCtx.Config.Casdoor
|
|
|
|
state, err := generateState()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("生成 state 失败: %v", err)
|
|
}
|
|
|
|
loginUrl := fmt.Sprintf("%s/login/oauth/authorize?client_id=%s&response_type=code&redirect_uri=%s&scope=read&state=%s",
|
|
c.Endpoint,
|
|
url.QueryEscape(c.ClientId),
|
|
url.QueryEscape(c.RedirectUrl),
|
|
url.QueryEscape(state),
|
|
)
|
|
|
|
return map[string]string{
|
|
"login_url": loginUrl,
|
|
}, nil
|
|
}
|
|
|
|
// casdoorTokenResponse is the JSON body returned by Casdoor's
// /api/login/oauth/access_token endpoint. Only AccessToken is read by this
// file; the other fields are decoded but unused.
type casdoorTokenResponse struct {
	AccessToken string `json:"access_token"` // the issued access token (a JWT for Casdoor)
	TokenType   string `json:"token_type"`   // token type as reported by Casdoor
	ExpiresIn   int    `json:"expires_in"`   // token lifetime; seconds per RFC 6749
	Scope       string `json:"scope"`        // scope that was granted
}
|
|
|
|
// casdoorUserInfo holds the user attributes read from the claims of a Casdoor
// access token. getUserInfo populates it; the json tags match the claim keys
// it reads.
type casdoorUserInfo struct {
	Sub               string `json:"sub"`                // subject; the primary Casdoor identifier
	Name              string `json:"name"`               // Casdoor user name; fallback identifier/username
	PreferredUsername string `json:"preferred_username"` // preferred login name, if provided
	Email             string `json:"email"`              // e-mail address, may be empty
	Phone             string `json:"phone"`              // phone number, may be empty
	Avatar            string `json:"avatar"`             // avatar URL; not used by the visible code
}
|
|
|
|
// HandleCallback 处理 SSO 回调
|
|
func (l *SSOLogic) HandleCallback(code, state string) (string, error) {
|
|
if code == "" {
|
|
return "", fmt.Errorf("缺少授权码")
|
|
}
|
|
|
|
c := l.svcCtx.Config.Casdoor
|
|
|
|
// 1. 用 code 换取 access_token
|
|
accessToken, err := l.exchangeToken(code)
|
|
if err != nil {
|
|
l.Errorf("SSO token 交换失败: %v", err)
|
|
return "", fmt.Errorf("token 交换失败: %v", err)
|
|
}
|
|
|
|
// 2. 获取用户信息
|
|
userInfo, err := l.getUserInfo(accessToken)
|
|
if err != nil {
|
|
l.Errorf("SSO 获取用户信息失败: %v", err)
|
|
return "", fmt.Errorf("获取用户信息失败: %v", err)
|
|
}
|
|
|
|
// 3. 查找或创建本地用户
|
|
casdoorId := userInfo.Sub
|
|
if casdoorId == "" {
|
|
casdoorId = userInfo.Name
|
|
}
|
|
|
|
// 从 Casdoor 信息中提取用户名和邮箱
|
|
username := userInfo.PreferredUsername
|
|
if username == "" {
|
|
username = userInfo.Name
|
|
}
|
|
email := userInfo.Email
|
|
if email == "" {
|
|
email = username + "@sso.local"
|
|
}
|
|
|
|
localUser, err := model.FindOneByCasdoorId(l.ctx, l.svcCtx.DB, casdoorId)
|
|
if err != nil {
|
|
if err == model.ErrNotFound {
|
|
// 用户不存在,尝试通过邮箱关联已有本地用户
|
|
existingUser, findErr := model.FindOneByEmail(l.ctx, l.svcCtx.DB, email)
|
|
if findErr == nil {
|
|
existingUser.CasdoorId = casdoorId
|
|
existingUser.UserType = "casdoor"
|
|
if updateErr := model.Update(l.ctx, l.svcCtx.DB, existingUser); updateErr != nil {
|
|
return "", fmt.Errorf("关联用户失败: %v", updateErr)
|
|
}
|
|
localUser = existingUser
|
|
l.Infof("SSO 关联已有用户: userId=%d, casdoorId=%s", existingUser.Id, casdoorId)
|
|
} else {
|
|
// 创建新用户
|
|
newUser := &model.User{
|
|
Username: username,
|
|
Email: email,
|
|
Password: "SSO_NO_PASSWORD", // SSO 用户不使用密码登录
|
|
Phone: userInfo.Phone,
|
|
CasdoorId: casdoorId,
|
|
UserType: "casdoor",
|
|
Role: model.RoleUser,
|
|
Source: model.SourceCasdoor,
|
|
Status: 1,
|
|
}
|
|
|
|
_, insertErr := model.Insert(l.ctx, l.svcCtx.DB, newUser)
|
|
if insertErr != nil {
|
|
l.Errorf("SSO 创建用户失败: %v", insertErr)
|
|
return "", fmt.Errorf("创建用户失败: %v", insertErr)
|
|
}
|
|
|
|
localUser = newUser
|
|
l.Infof("SSO 新用户创建成功: username=%s, casdoorId=%s", username, casdoorId)
|
|
}
|
|
} else {
|
|
return "", fmt.Errorf("查询用户失败: %v", err)
|
|
}
|
|
} else {
|
|
// 已有用户,同步更新 Casdoor 端的最新信息
|
|
updated := false
|
|
if username != "" && localUser.Username != username {
|
|
localUser.Username = username
|
|
updated = true
|
|
}
|
|
if email != "" && localUser.Email != email {
|
|
localUser.Email = email
|
|
updated = true
|
|
}
|
|
if userInfo.Phone != "" && localUser.Phone != userInfo.Phone {
|
|
localUser.Phone = userInfo.Phone
|
|
updated = true
|
|
}
|
|
if updated {
|
|
if updateErr := model.Update(l.ctx, l.svcCtx.DB, localUser); updateErr != nil {
|
|
l.Errorf("SSO 同步用户信息失败: %v", updateErr)
|
|
}
|
|
}
|
|
}
|
|
|
|
// 4. 生成本地 JWT Token
|
|
token, err := jwt.GenerateToken(localUser.Id, localUser.Username, localUser.Role, localUser.CurrentOrgId)
|
|
if err != nil {
|
|
return "", fmt.Errorf("生成 Token 失败: %v", err)
|
|
}
|
|
|
|
l.Infof("SSO 登录成功: userId=%d, username=%s", localUser.Id, localUser.Username)
|
|
|
|
// 5. 构建前端回调 URL
|
|
redirectUrl := fmt.Sprintf("%s/sso/callback?token=%s",
|
|
c.FrontendUrl,
|
|
url.QueryEscape(token),
|
|
)
|
|
|
|
return redirectUrl, nil
|
|
}
|
|
|
|
// exchangeToken 用授权码换取 access_token
|
|
func (l *SSOLogic) exchangeToken(code string) (string, error) {
|
|
c := l.svcCtx.Config.Casdoor
|
|
|
|
tokenUrl := fmt.Sprintf("%s/api/login/oauth/access_token", c.Endpoint)
|
|
|
|
data := url.Values{}
|
|
data.Set("grant_type", "authorization_code")
|
|
data.Set("client_id", c.ClientId)
|
|
data.Set("client_secret", c.ClientSecret)
|
|
data.Set("code", code)
|
|
data.Set("redirect_uri", c.RedirectUrl)
|
|
|
|
req, err := http.NewRequestWithContext(l.ctx, http.MethodPost, tokenUrl,
|
|
strings.NewReader(data.Encode()))
|
|
if err != nil {
|
|
return "", fmt.Errorf("创建请求失败: %v", err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
|
|
resp, err := casdoorHttpClient.Do(req)
|
|
if err != nil {
|
|
return "", fmt.Errorf("请求 token 失败: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return "", fmt.Errorf("读取响应失败: %v", err)
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return "", fmt.Errorf("token 请求返回 %d: %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
var tokenResp casdoorTokenResponse
|
|
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
|
return "", fmt.Errorf("解析 token 响应失败: %v", err)
|
|
}
|
|
|
|
if tokenResp.AccessToken == "" {
|
|
return "", fmt.Errorf("未获取到 access_token, 响应: %s", string(body))
|
|
}
|
|
|
|
return tokenResp.AccessToken, nil
|
|
}
|
|
|
|
// getUserInfo 从 access_token JWT 中解析用户信息
|
|
// Casdoor 的 access_token 本身是一个 JWT,包含完整的用户 claims
|
|
func (l *SSOLogic) getUserInfo(accessToken string) (*casdoorUserInfo, error) {
|
|
// 解析 JWT payload(不验证签名,因为 token 刚从 Casdoor 获取)
|
|
parts := strings.Split(accessToken, ".")
|
|
if len(parts) != 3 {
|
|
return nil, fmt.Errorf("access_token 不是有效的 JWT 格式")
|
|
}
|
|
|
|
// Base64 解码 payload
|
|
payload := parts[1]
|
|
// 补齐 base64 padding
|
|
switch len(payload) % 4 {
|
|
case 2:
|
|
payload += "=="
|
|
case 3:
|
|
payload += "="
|
|
}
|
|
|
|
decoded, err := base64.URLEncoding.DecodeString(payload)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("解码 JWT payload 失败: %v", err)
|
|
}
|
|
|
|
// 解析 JWT claims
|
|
var claims map[string]interface{}
|
|
if err := json.Unmarshal(decoded, &claims); err != nil {
|
|
return nil, fmt.Errorf("解析 JWT claims 失败: %v", err)
|
|
}
|
|
|
|
// 从 claims 中提取用户信息(Casdoor JWT 字段名)
|
|
userInfo := &casdoorUserInfo{
|
|
Sub: getStringClaim(claims, "sub"),
|
|
}
|
|
|
|
// Casdoor JWT 中用户名可能在 name 或 preferred_username 字段
|
|
userInfo.Name = getStringClaim(claims, "name")
|
|
userInfo.PreferredUsername = getStringClaim(claims, "preferred_username")
|
|
userInfo.Email = getStringClaim(claims, "email")
|
|
userInfo.Phone = getStringClaim(claims, "phone")
|
|
userInfo.Avatar = getStringClaim(claims, "avatar")
|
|
|
|
return userInfo, nil
|
|
}
|
|
|
|
// getStringClaim returns claims[key] if it is present and holds a string.
// It returns "" for a missing key, a non-string value (number, bool, object,
// null) or a nil map.
func getStringClaim(claims map[string]interface{}, key string) string {
	value, _ := claims[key].(string)
	return value
}
|
|
|
|
// generateState returns a fresh OAuth "state" value: 16 bytes from
// crypto/rand, hex-encoded as a 32-character lowercase string. The value is
// intended for CSRF protection of the login redirect. An error is returned
// only if the system random source fails.
func generateState() (string, error) {
	var nonce [16]byte
	if _, err := rand.Read(nonce[:]); err != nil {
		return "", err
	}
	return hex.EncodeToString(nonce[:]), nil
}
|
|
|