package middleware import ( "errors" "fmt" "net/http" "strings" "sync" "time" "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v5" ) // JWT Claims type Claims struct { UserID uint `json:"user_id"` Username string `json:"username"` Role string `json:"role"` // student, parent, teacher, admin jwt.RegisteredClaims } // JWT密钥 - 在实际应用中应从环境变量加载 var jwtKey = []byte("ai-myopia-prevention-jwt-secret-key-change-in-production") // GenerateToken 生成JWT Token func GenerateToken(userID uint, username string, role string) (string, error) { expirationTime := time.Now().Add(24 * 7 * time.Hour) // 7天过期 claims := &Claims{ UserID: userID, Username: username, Role: role, RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(expirationTime), IssuedAt: jwt.NewNumericDate(time.Now()), Issuer: "ai-myopia-prevention", }, } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) return token.SignedString(jwtKey) } // ParseToken 解析JWT Token func ParseToken(tokenStr string) (*Claims, error) { claims := &Claims{} token, err := jwt.ParseWithClaims(tokenStr, claims, func(token *jwt.Token) (interface{}, error) { return jwtKey, nil }) if err != nil { return nil, err } if !token.Valid { return nil, errors.New("invalid token") } return claims, nil } // JWTAuthMiddleware JWT认证中间件 func JWTAuthMiddleware() gin.HandlerFunc { return func(c *gin.Context) { authHeader := c.GetHeader("Authorization") if authHeader == "" { c.JSON(http.StatusUnauthorized, gin.H{ "code": 401, "message": "未提供认证信息", }) c.Abort() return } // 检查Authorization头部格式 parts := strings.Split(authHeader, " ") if len(parts) != 2 || parts[0] != "Bearer" { c.JSON(http.StatusUnauthorized, gin.H{ "code": 401, "message": "认证信息格式错误", }) c.Abort() return } tokenStr := parts[1] claims, err := ParseToken(tokenStr) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ "code": 401, "message": "无效的认证信息", }) c.Abort() return } // 将用户信息存储到上下文中 c.Set("user_id", claims.UserID) c.Set("username", claims.Username) c.Set("role", claims.Role) c.Next() } } // RBACMiddleware 基于角色的访问控制中间件 func RBACMiddleware(allowedRoles ...string) gin.HandlerFunc { return func(c *gin.Context) { role, exists := c.Get("role") if !exists { c.JSON(http.StatusForbidden, gin.H{ "code": 403, "message": "无法获取用户角色", }) c.Abort() return } userRole, ok := role.(string) if !ok { c.JSON(http.StatusForbidden, gin.H{ "code": 403, "message": "无效的用户角色", }) c.Abort() return } // 检查用户角色是否在允许的角色列表中 for _, allowedRole := range allowedRoles { if userRole == allowedRole { c.Next() return } } // 管理员可以访问所有接口 if userRole == "admin" { c.Next() return } c.JSON(http.StatusForbidden, gin.H{ "code": 403, "message": "权限不足,无法访问该接口", }) c.Abort() } } // ValidatePasswordStrength 验证密码强度 func ValidatePasswordStrength(password string) error { // 检查长度至少8位 if len(password) < 8 { return fmt.Errorf("密码长度至少8位") } // 检查是否包含大小写字母、数字和特殊字符 hasUpper := false hasLower := false hasDigit := false hasSpecial := false for _, char := range password { switch { case char >= 'A' && char <= 'Z': hasUpper = true case char >= 'a' && char <= 'z': hasLower = true case char >= '0' && char <= '9': hasDigit = true case char == '!' || char == '@' || char == '#' || char == '$' || char == '%' || char == '^' || char == '&' || char == '*' || char == '(' || char == ')' || char == '-' || char == '_' || char == '+' || char == '=' || char == '[' || char == ']' || char == '{' || char == '}' || char == '|' || char == '\\' || char == ':' || char == ';' || char == '"' || char == '\'' || char == '<' || char == '>' || char == ',' || char == '.' || char == '?' || char == '/' || char == '~' || char == '`': hasSpecial = true } } if !hasUpper { return fmt.Errorf("密码必须包含大写字母") } if !hasLower { return fmt.Errorf("密码必须包含小写字母") } if !hasDigit { return fmt.Errorf("密码必须包含数字") } if !hasSpecial { return fmt.Errorf("密码必须包含特殊字符") } return nil } // AttemptInfo 尝试信息 type AttemptInfo struct { Attempts int LastAttempt time.Time Blocked bool BlockUntil time.Time } // LoginRateLimiter 登录速率限制器 type LoginRateLimiter struct { // 存储每个IP的失败尝试次数 attempts map[string]*AttemptInfo mutex sync.RWMutex // 配置参数 maxAttempts int // 最大尝试次数 blockTime time.Duration // 封锁时间 } // NewLoginRateLimiter 创建新的登录速率限制器 func NewLoginRateLimiter(maxAttempts int, blockTime time.Duration) *LoginRateLimiter { limiter := &LoginRateLimiter{ attempts: make(map[string]*AttemptInfo), maxAttempts: maxAttempts, blockTime: blockTime, } // 启动清理协程,定期清理过期记录 go limiter.cleanup() return limiter } // IsBlocked 检查IP是否被封禁 func (l *LoginRateLimiter) IsBlocked(ip string) bool { l.mutex.RLock() info, exists := l.attempts[ip] l.mutex.RUnlock() if !exists { return false } // 如果封禁时间已过,解除封禁 if info.Blocked && time.Now().After(info.BlockUntil) { l.mutex.Lock() delete(l.attempts, ip) l.mutex.Unlock() return false } return info.Blocked } // RecordFailure 记录登录失败 func (l *LoginRateLimiter) RecordFailure(ip string) { l.mutex.Lock() defer l.mutex.Unlock() info, exists := l.attempts[ip] if !exists { info = &AttemptInfo{ Attempts: 1, LastAttempt: time.Now(), } l.attempts[ip] = info return } // 更新尝试次数 info.Attempts++ info.LastAttempt = time.Now() // 如果超过最大尝试次数,封禁IP if info.Attempts >= l.maxAttempts { info.Blocked = true info.BlockUntil = time.Now().Add(l.blockTime) } } // ResetAttempts 重置尝试次数(登录成功后调用) func (l *LoginRateLimiter) ResetAttempts(ip string) { l.mutex.Lock() defer l.mutex.Unlock() delete(l.attempts, ip) } // cleanup 定期清理过期记录 func (l *LoginRateLimiter) cleanup() { ticker := time.NewTicker(1 * time.Minute) defer ticker.Stop() for range ticker.C { l.mutex.Lock() now := time.Now() for ip, info := range l.attempts { // 如果封禁时间已过,删除记录 if info.Blocked && now.After(info.BlockUntil) { delete(l.attempts, ip) } else if !info.Blocked && now.Sub(info.LastAttempt) > l.blockTime { // 如果最后一次尝试时间超过封禁时间,也删除记录(非封禁状态下的旧记录) delete(l.attempts, ip) } } l.mutex.Unlock() } } // RateLimitMiddleware 速率限制中间件 func RateLimitMiddleware(limit int, window time.Duration) gin.HandlerFunc { type RequestInfo struct { Count int Time time.Time } requests := make(map[string][]RequestInfo) mutex := sync.RWMutex{} return func(c *gin.Context) { clientIP := c.ClientIP() mutex.Lock() defer mutex.Unlock() now := time.Now() windowStart := now.Add(-window) // 清理过期请求记录 var validRequests []RequestInfo for _, req := range requests[clientIP] { if req.Time.After(windowStart) { validRequests = append(validRequests, req) } } requests[clientIP] = validRequests // 检查是否超出限制 if len(requests[clientIP]) >= limit { c.JSON(http.StatusTooManyRequests, gin.H{ "code": 429, "message": "请求过于频繁,请稍后再试", }) c.Abort() return } // 记录当前请求 requests[clientIP] = append(requests[clientIP], RequestInfo{ Count: 1, Time: now, }) c.Next() } } // PasswordValidatorMiddleware 密码强度验证中间件 func PasswordValidatorMiddleware() gin.HandlerFunc { return func(c *gin.Context) { // 这个中间件主要用于验证密码强度 // 在需要验证密码强度的接口中使用 c.Next() } } // 全局登录速率限制器实例 var LoginRateLimiterInstance *LoginRateLimiter = NewLoginRateLimiter(5, 15*time.Minute) // 5次失败后封禁15分钟