You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
2.8 KiB
2.8 KiB
JWT 处理
JWT Manager
package jwtx
import (
"fmt"
"time"
"github.com/golang-jwt/jwt/v5"
)
// Claims is the JWT payload used by this package: the application's user
// fields plus the library's registered claims.
type Claims struct {
// UserId is the numeric user identifier, serialized under the JSON key "userId".
UserId int64 `json:"userId"`
// Username is serialized under the JSON key "username".
Username string `json:"username"`
// Email is serialized under the JSON key "email".
Email string `json:"email"`
// RegisteredClaims is embedded; GenerateToken fills in its ExpiresAt,
// NotBefore and IssuedAt fields.
jwt.RegisteredClaims
}
// JWTManager bundles the settings needed to issue and verify tokens.
type JWTManager struct {
	// secret is the key material, converted to bytes when signing and verifying.
	secret string
	// expire is the lifetime added to the issue time of each new token.
	expire time.Duration
}

// NewJWTManager builds a JWTManager from a secret and a token lifetime
// expressed in whole seconds. The inputs are stored as given; no validation
// is performed on either of them.
func NewJWTManager(secret string, expireSeconds int64) *JWTManager {
	m := new(JWTManager)
	m.secret = secret
	m.expire = time.Second * time.Duration(expireSeconds)
	return m
}
// GenerateToken issues an HS256-signed JWT carrying the given user fields.
//
// The token's issued-at and not-before claims are set to the current time and
// its expiry to the current time plus the manager's configured lifetime.
//
// It returns the signed token string, the expiry as a Unix timestamp in
// seconds, and a nil error. If signing fails it returns "", 0 and the error
// reported by the signing call.
func (j *JWTManager) GenerateToken(userId int64, username, email string) (string, int64, error) {
	issuedAt := time.Now()
	// Compute the expiry instant once so the claim and the returned
	// timestamp are derived from the same value.
	expiry := issuedAt.Add(j.expire)

	registered := jwt.RegisteredClaims{
		ExpiresAt: jwt.NewNumericDate(expiry),
		NotBefore: jwt.NewNumericDate(issuedAt),
		IssuedAt:  jwt.NewNumericDate(issuedAt),
	}
	payload := Claims{
		UserId:           userId,
		Username:         username,
		Email:            email,
		RegisteredClaims: registered,
	}

	signed, err := jwt.NewWithClaims(jwt.SigningMethodHS256, payload).SignedString([]byte(j.secret))
	if err != nil {
		return "", 0, err
	}
	return signed, expiry.Unix(), nil
}
// ParseToken parses tokenString into a *Claims value, verifying it with the
// manager's secret.
//
// A token whose signing method is not an HMAC method is refused by the key
// function. Any error from the parsing call is returned unchanged. If parsing
// succeeds but the claims are not *Claims or the token is not marked valid,
// an "invalid token" error is returned.
func (j *JWTManager) ParseToken(tokenString string) (*Claims, error) {
	// keyFunc hands the secret to the parser only for HMAC-signed tokens,
	// so the secret is never used as a key for any other algorithm.
	keyFunc := func(t *jwt.Token) (interface{}, error) {
		if _, isHMAC := t.Method.(*jwt.SigningMethodHMAC); !isHMAC {
			return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
		}
		return []byte(j.secret), nil
	}

	parsed, err := jwt.ParseWithClaims(tokenString, &Claims{}, keyFunc)
	if err != nil {
		return nil, err
	}

	claims, ok := parsed.Claims.(*Claims)
	if !ok || !parsed.Valid {
		return nil, fmt.Errorf("invalid token")
	}
	return claims, nil
}
认证中间件
package middleware
import (
"backend/internal/jwtx"
"backend/internal/svc"
"context"
"net/http"
"strings"
)
// contextKey is a dedicated string type for this package's context keys.
// Because context lookups compare both type and value, a key of this type
// does not match a plain string key with the same text.
type contextKey string
// UserIdKey is the context key under which AuthMiddleware stores the user ID
// taken from the token claims; GetUserId reads the value back.
const UserIdKey contextKey = "userId"
// AuthMiddleware returns a decorator that guards an http.HandlerFunc with
// Bearer-token authentication.
//
// For each request it reads the Authorization header, expects the form
// "Bearer <token>" (scheme matched case-insensitively), and validates the
// token with svcCtx.JWT.ParseToken. On success the user ID from the claims is
// stored in the request context under UserIdKey and next is invoked with the
// derived request.
//
// On failure it replies 401 Unauthorized, includes a "WWW-Authenticate:
// Bearer" challenge, and does not call next. The response bodies are:
//   - "missing authorization header" when the header is absent or empty,
//   - "invalid authorization header" when it is not "<scheme> <token>" with
//     the Bearer scheme,
//   - "invalid token" when ParseToken returns an error.
//
// NOTE(review): svcCtx and svcCtx.JWT are dereferenced without nil checks;
// callers presumably always pass a fully initialised ServiceContext — confirm.
func AuthMiddleware(svcCtx *svc.ServiceContext) func(http.HandlerFunc) http.HandlerFunc {
	// reject writes a 401 together with the Bearer challenge that HTTP
	// requires on every 401 response.
	reject := func(w http.ResponseWriter, msg string) {
		w.Header().Set("WWW-Authenticate", "Bearer")
		http.Error(w, msg, http.StatusUnauthorized)
	}

	return func(next http.HandlerFunc) http.HandlerFunc {
		return func(w http.ResponseWriter, r *http.Request) {
			authHeader := r.Header.Get("Authorization")
			if authHeader == "" {
				reject(w, "missing authorization header")
				return
			}

			parts := strings.SplitN(authHeader, " ", 2)
			if len(parts) != 2 || !strings.EqualFold(parts[0], "bearer") {
				reject(w, "invalid authorization header")
				return
			}

			// The scheme and token may be separated by more than one
			// space; strip the surplus so it is not fed to the parser as
			// part of the token.
			token := strings.TrimSpace(parts[1])

			claims, err := svcCtx.JWT.ParseToken(token)
			if err != nil {
				reject(w, "invalid token")
				return
			}

			ctx := context.WithValue(r.Context(), UserIdKey, claims.UserId)
			next(w, r.WithContext(ctx))
		}
	}
}
// GetUserId returns the int64 stored in ctx under UserIdKey.
// It returns 0 when the key is absent or the stored value is not an int64.
func GetUserId(ctx context.Context) int64 {
	// A failed type assertion leaves id at its zero value, which is the
	// fallback this function reports.
	id, _ := ctx.Value(UserIdKey).(int64)
	return id
}