You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

223 lines
6.0 KiB

// 生成token,保存到redis,返回token,redis的key是token,value是user_id
// 过期时间由etc下 yaml文件配置 Auth.AccessExpire确定
// 秘钥由etc下 yaml文件配置 Auth.AccessSecret确定
package utils
import (
"context"
"errors"
"fmt"
"strconv"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/redis/go-redis/v9"
)
// JWT 自定义声明
type JWTClaims struct {
UserID int64 `json:"user_id"`
Username string `json:"username"`
ExString string `json:"ex_string"`
jwt.RegisteredClaims
}
// JWT 工具结构体
type JWTUtil struct {
AccessSecret string
AccessExpire int64
RedisClient *redis.Client
TkStore bool
}
// 创建新的 JWT 工具实例
func NewJWTUtil(accessSecret string, accessExpire int64, redisClient *redis.Client, tkStore bool) *JWTUtil {
return &JWTUtil{
AccessSecret: accessSecret,
AccessExpire: accessExpire,
RedisClient: redisClient,
TkStore: tkStore,
}
}
// 生成 JWT token
func (j *JWTUtil) GenerateToken(ctx context.Context, userID int64, username string) (string, error) {
// 创建声明
claims := &JWTClaims{
UserID: userID,
Username: username,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Second * time.Duration(j.AccessExpire))),
IssuedAt: jwt.NewNumericDate(time.Now()),
NotBefore: jwt.NewNumericDate(time.Now()),
Issuer: "usercenter",
Subject: strconv.FormatInt(userID, 10),
},
ExString: uuid.New().String(),
}
// 使用 HS256 算法生成 token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString([]byte(j.AccessSecret))
if err != nil {
return "", fmt.Errorf("生成 token 失败: %w", err)
}
// 将 token 存储到 Redis,key 是 token,value 是 user_id
if j.TkStore {
err = j.RedisClient.Set(ctx, tokenString, userID, time.Duration(j.AccessExpire)*time.Second).Err()
if err != nil {
return "", fmt.Errorf("存储 token 到 Redis 失败: %w", err)
}
}
return tokenString, nil
}
// 验证 JWT token
func (j *JWTUtil) ValidateToken(ctx context.Context, tokenString string) (*JWTClaims, error) {
// 解析 token
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) {
// 验证签名方法
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("意外的签名方法: %v", token.Header["alg"])
}
return []byte(j.AccessSecret), nil
})
if err != nil {
return nil, fmt.Errorf("解析 token 失败: %w", err)
}
// 验证 token 是否有效
if !token.Valid {
return nil, errors.New("无效的 token")
}
// 获取声明
claims, ok := token.Claims.(*JWTClaims)
if !ok {
return nil, errors.New("无法解析 token 声明")
}
// 验证 token 是否在 Redis 中存在
userID, err := j.RedisClient.Get(ctx, tokenString).Result()
if err != nil {
if err == redis.Nil {
return nil, errors.New("token 已过期或不存在")
}
return nil, fmt.Errorf("从 Redis 获取 token 失败: %w", err)
}
// 验证 Redis 中的 user_id 是否与 token 中的一致
redisUserID, err := strconv.ParseInt(userID, 10, 64)
if err != nil {
return nil, fmt.Errorf("解析 Redis 中的 user_id 失败: %w", err)
}
if redisUserID != claims.UserID {
return nil, errors.New("token 中的用户ID与 Redis 中的不匹配")
}
return claims, nil
}
// 刷新 token
func (j *JWTUtil) RefreshToken(ctx context.Context, tokenString string) (string, error) {
// 先验证当前 token
claims, err := j.ValidateToken(ctx, tokenString)
if err != nil {
return "", fmt.Errorf("验证旧 token 失败: %w", err)
}
// 删除旧的 token
err = j.RedisClient.Del(ctx, tokenString).Err()
if err != nil {
return "", fmt.Errorf("删除旧 token 失败: %w", err)
}
// 生成新的 token
return j.GenerateToken(ctx, claims.UserID, claims.Username)
}
// 删除 token(登出)
func (j *JWTUtil) DeleteToken(ctx context.Context, tokenString string) error {
// 从 Redis 中删除 token
err := j.RedisClient.Del(ctx, tokenString).Err()
if err != nil {
return fmt.Errorf("从 Redis 删除 token 失败: %w", err)
}
return nil
}
// 检查 token 是否存在
func (j *JWTUtil) TokenExists(ctx context.Context, tokenString string) (bool, error) {
exists, err := j.RedisClient.Exists(ctx, tokenString).Result()
if err != nil {
return false, fmt.Errorf("检查 token 是否存在失败: %w", err)
}
return exists == 1, nil
}
// 解析 token(不验证签名,用于获取基本信息)
func (j *JWTUtil) ParseTokenUnverified(tokenString string) (*JWTClaims, error) {
token, _, err := new(jwt.Parser).ParseUnverified(tokenString, &JWTClaims{})
if err != nil {
return nil, fmt.Errorf("解析 token 失败: %w", err)
}
claims, ok := token.Claims.(*JWTClaims)
if !ok {
return nil, errors.New("无法解析 token 声明")
}
return claims, nil
}
// 获取用户所有有效的 token(通过用户ID前缀搜索)
func (j *JWTUtil) GetUserTokens(ctx context.Context, userID int64) ([]string, error) {
// 使用 SCAN 命令搜索包含用户ID的 token
var tokens []string
iter := j.RedisClient.Scan(ctx, 0, "*", 0).Iterator()
for iter.Next(ctx) {
key := iter.Val()
// 获取该 key 对应的 value
value, err := j.RedisClient.Get(ctx, key).Result()
if err != nil {
continue
}
// 检查 value 是否是目标用户ID
if value == strconv.FormatInt(userID, 10) {
tokens = append(tokens, key)
}
}
if err := iter.Err(); err != nil {
return nil, fmt.Errorf("搜索用户 token 失败: %w", err)
}
return tokens, nil
}
// 删除用户所有 token(强制登出)
func (j *JWTUtil) DeleteAllUserTokens(ctx context.Context, userID int64) error {
tokens, err := j.GetUserTokens(ctx, userID)
if err != nil {
return fmt.Errorf("获取用户 token 失败: %w", err)
}
if len(tokens) == 0 {
return nil
}
// 批量删除 token
err = j.RedisClient.Del(ctx, tokens...).Err()
if err != nil {
return fmt.Errorf("批量删除用户 token 失败: %w", err)
}
return nil
}