You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
72 lines
1.4 KiB
72 lines
1.4 KiB
package middleware
|
|
|
|
import (
|
|
"strings"
|
|
|
|
"health-ai/pkg/jwt"
|
|
"health-ai/pkg/response"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
// AuthRequired JWT认证中间件
|
|
func AuthRequired() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
authHeader := c.GetHeader("Authorization")
|
|
if authHeader == "" {
|
|
response.Unauthorized(c, "未提供认证信息")
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
parts := strings.SplitN(authHeader, " ", 2)
|
|
if len(parts) != 2 || parts[0] != "Bearer" {
|
|
response.Unauthorized(c, "认证格式错误,请使用 Bearer Token")
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
claims, err := jwt.ParseToken(parts[1])
|
|
if err != nil {
|
|
response.Unauthorized(c, "Token无效或已过期")
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
// 将用户ID存入上下文
|
|
c.Set("userID", claims.UserID)
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// GetUserID 从上下文获取用户ID
|
|
func GetUserID(c *gin.Context) uint {
|
|
userID, exists := c.Get("userID")
|
|
if !exists {
|
|
return 0
|
|
}
|
|
return userID.(uint)
|
|
}
|
|
|
|
// OptionalAuth 可选认证中间件(不强制要求登录)
|
|
func OptionalAuth() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
authHeader := c.GetHeader("Authorization")
|
|
if authHeader == "" {
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
parts := strings.SplitN(authHeader, " ", 2)
|
|
if len(parts) != 2 || parts[0] != "Bearer" {
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
claims, err := jwt.ParseToken(parts[1])
|
|
if err == nil {
|
|
c.Set("userID", claims.UserID)
|
|
}
|
|
c.Next()
|
|
}
|
|
}
|
|
|