You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

2.8 KiB

JWT 处理

JWT Manager

package jwtx

import (
	"fmt"
	"time"

	"github.com/golang-jwt/jwt/v5"
)

// Claims is the token payload that GenerateToken signs and ParseToken decodes:
// the user's identity fields plus the standard JWT registered claims.
type Claims struct {
	// UserId is the user's numeric ID, encoded under the JSON key "userId".
	UserId   int64  `json:"userId"`
	// Username is encoded under the JSON key "username".
	Username string `json:"username"`
	// Email is encoded under the JSON key "email".
	Email    string `json:"email"`
	// Embedded so the registered claims (exp, nbf, iat, ...) are encoded at the
	// top level of the payload, and so *Claims satisfies the claims interface
	// that jwt.ParseWithClaims expects.
	jwt.RegisteredClaims
}

// JWTManager issues and verifies HS256-signed JWTs using a shared HMAC secret.
type JWTManager struct {
	// secret is the HMAC key used both to sign and to verify tokens.
	secret string
	// expire is how long newly issued tokens remain valid.
	expire time.Duration
}

// NewJWTManager builds a JWTManager that signs with secret and issues tokens
// valid for expireSeconds seconds. Neither argument is validated here.
func NewJWTManager(secret string, expireSeconds int64) *JWTManager {
	lifetime := time.Second * time.Duration(expireSeconds)
	m := new(JWTManager)
	m.secret = secret
	m.expire = lifetime
	return m
}

// GenerateToken issues an HS256-signed JWT for the given user.
//
// The token's exp claim is now plus the manager's configured lifetime, and its
// nbf and iat claims are set to now. It returns the signed token string and the
// expiry as a Unix timestamp in seconds, derived from the same instant that is
// written into the exp claim.
//
// It returns an error if the manager has no secret configured (an empty HMAC
// key would make tokens forgeable by anyone) or if signing fails.
func (j *JWTManager) GenerateToken(userId int64, username, email string) (string, int64, error) {
	if j.secret == "" {
		return "", 0, fmt.Errorf("generate token: jwt secret is not configured")
	}

	now := time.Now()
	// Compute the expiry once so the returned timestamp and the exp claim
	// cannot drift apart if either expression is edited later.
	expiresAt := now.Add(j.expire)

	claims := Claims{
		UserId:   userId,
		Username: username,
		Email:    email,
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(expiresAt),
			NotBefore: jwt.NewNumericDate(now),
			IssuedAt:  jwt.NewNumericDate(now),
		},
	}

	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	tokenString, err := token.SignedString([]byte(j.secret))
	if err != nil {
		return "", 0, fmt.Errorf("sign token: %w", err)
	}

	return tokenString, expiresAt.Unix(), nil
}

// ParseToken verifies tokenString and returns its decoded claims.
//
// The token must satisfy all of the following:
//   - it is signed with HS256, the only method GenerateToken uses;
//   - it is signed with the manager's secret;
//   - it carries an exp claim;
//   - it passes the library's exp/nbf time checks.
//
// On failure the returned error wraps the underlying jwt error, so callers can
// still test it with errors.Is (e.g. against jwt.ErrTokenExpired). An empty
// secret is rejected outright, because it would let anyone forge a token that
// verifies.
func (j *JWTManager) ParseToken(tokenString string) (*Claims, error) {
	if j.secret == "" {
		return nil, fmt.Errorf("parse token: jwt secret is not configured")
	}

	keyFunc := func(token *jwt.Token) (interface{}, error) {
		// Defense in depth on top of WithValidMethods: never hand the HMAC
		// secret to a non-HMAC verifier.
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		return []byte(j.secret), nil
	}

	token, err := jwt.ParseWithClaims(tokenString, &Claims{}, keyFunc,
		// Accept only the algorithm GenerateToken issues, not any HMAC variant.
		jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
		// GenerateToken always sets exp; without this option a token lacking
		// exp would be accepted and never expire.
		jwt.WithExpirationRequired(),
	)
	if err != nil {
		return nil, fmt.Errorf("parse token: %w", err)
	}

	claims, ok := token.Claims.(*Claims)
	if !ok || !token.Valid {
		return nil, fmt.Errorf("invalid token")
	}
	return claims, nil
}

认证中间件

package middleware

import (
	"backend/internal/jwtx"
	"backend/internal/svc"
	"context"
	"net/http"
	"strings"
)

// contextKey is an unexported type for request-context keys set by this
// package, so they cannot collide with keys defined in other packages.
type contextKey string

// UserIdKey is the context key under which AuthMiddleware stores the
// authenticated user's ID.
const UserIdKey = contextKey("userId")

// AuthMiddleware returns middleware that requires a valid bearer token.
//
// It reads the Authorization header and expects the form "Bearer <token>":
//   - The scheme is matched case-insensitively.
//   - One or more spaces may separate the scheme from the token.
//
// The token is parsed with svcCtx.JWT. On success, the user ID from the claims
// is stored in the request context under UserIdKey and next is called.
//
// On failure it responds 401 Unauthorized with a WWW-Authenticate challenge and
// does not call next.
func AuthMiddleware(svcCtx *svc.ServiceContext) func(http.HandlerFunc) http.HandlerFunc {
	return func(next http.HandlerFunc) http.HandlerFunc {
		return func(w http.ResponseWriter, r *http.Request) {
			// unauthorized writes a 401 together with the WWW-Authenticate
			// challenge that RFC 7235/6750 require on every 401 response.
			// The header must be set before http.Error writes the status.
			unauthorized := func(challenge, msg string) {
				w.Header().Set("WWW-Authenticate", challenge)
				http.Error(w, msg, http.StatusUnauthorized)
			}

			authHeader := r.Header.Get("Authorization")
			if authHeader == "" {
				unauthorized("Bearer", "missing authorization header")
				return
			}

			scheme, token, found := strings.Cut(authHeader, " ")
			// RFC 6750 allows one or more spaces between the scheme and the
			// token, so trim the remainder instead of requiring exactly one.
			token = strings.TrimSpace(token)
			if !found || !strings.EqualFold(scheme, "bearer") || token == "" {
				unauthorized("Bearer", "invalid authorization header")
				return
			}

			claims, err := svcCtx.JWT.ParseToken(token)
			if err != nil {
				unauthorized(`Bearer error="invalid_token"`, "invalid token")
				return
			}

			ctx := context.WithValue(r.Context(), UserIdKey, claims.UserId)
			next(w, r.WithContext(ctx))
		}
	}
}

// GetUserId returns the user ID that AuthMiddleware stored in ctx. It returns 0
// when ctx holds no int64 value under UserIdKey, for example when the request
// did not pass through AuthMiddleware.
func GetUserId(ctx context.Context) int64 {
	// A failed comma-ok type assertion yields the zero value, which is exactly
	// the fallback we want.
	userID, _ := ctx.Value(UserIdKey).(int64)
	return userID
}