Files
ai-myopia-prevention/internal/middleware/auth.go
虾司令 881144269c 🚀 AI 近视防控系统 - 生产环境上线版本 v1.0
 已完成功能:
- 后端 Go 服务 (认证/授权/检测)
- JWT 认证 + RBAC 权限控制
- 登录速率限制 (5 次失败锁定 15 分钟)
- 密码强度校验
- 敏感数据脱敏
- Vue3 管理后台
- 路由守卫
- 删除二次确认

📦 部署配置:
- Docker Compose 生产环境配置
- MySQL/Redis/MongoDB 数据库
- Nginx 前端服务
- 强密码安全配置

⚠️ P2 待办 (下次迭代):
- 学生/检测/预警等业务模块实现
- 错误处理统一化
- 缓存策略优化
- 日志分级

📍 生产环境:
- 服务器:192.168.15.222
- 管理后台:http://192.168.15.222:8081
- API 服务:http://192.168.15.222:8080

2026-03-29 上线部署完成
2026-03-29 18:16:41 +08:00

368 lines
8.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package middleware
import (
"errors"
"fmt"
"net/http"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
)
// JWT Claims
type Claims struct {
UserID uint `json:"user_id"`
Username string `json:"username"`
Role string `json:"role"` // student, parent, teacher, admin
jwt.RegisteredClaims
}
// JWT密钥 - 在实际应用中应从环境变量加载
var jwtKey = []byte("ai-myopia-prevention-jwt-secret-key-change-in-production")
// GenerateToken 生成JWT Token
func GenerateToken(userID uint, username string, role string) (string, error) {
expirationTime := time.Now().Add(24 * 7 * time.Hour) // 7天过期
claims := &Claims{
UserID: userID,
Username: username,
Role: role,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expirationTime),
IssuedAt: jwt.NewNumericDate(time.Now()),
Issuer: "ai-myopia-prevention",
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString(jwtKey)
}
// ParseToken 解析JWT Token
func ParseToken(tokenStr string) (*Claims, error) {
claims := &Claims{}
token, err := jwt.ParseWithClaims(tokenStr, claims, func(token *jwt.Token) (interface{}, error) {
return jwtKey, nil
})
if err != nil {
return nil, err
}
if !token.Valid {
return nil, errors.New("invalid token")
}
return claims, nil
}
// JWTAuthMiddleware JWT认证中间件
func JWTAuthMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"code": 401,
"message": "未提供认证信息",
})
c.Abort()
return
}
// 检查Authorization头部格式
parts := strings.Split(authHeader, " ")
if len(parts) != 2 || parts[0] != "Bearer" {
c.JSON(http.StatusUnauthorized, gin.H{
"code": 401,
"message": "认证信息格式错误",
})
c.Abort()
return
}
tokenStr := parts[1]
claims, err := ParseToken(tokenStr)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{
"code": 401,
"message": "无效的认证信息",
})
c.Abort()
return
}
// 将用户信息存储到上下文中
c.Set("user_id", claims.UserID)
c.Set("username", claims.Username)
c.Set("role", claims.Role)
c.Next()
}
}
// RBACMiddleware 基于角色的访问控制中间件
func RBACMiddleware(allowedRoles ...string) gin.HandlerFunc {
return func(c *gin.Context) {
role, exists := c.Get("role")
if !exists {
c.JSON(http.StatusForbidden, gin.H{
"code": 403,
"message": "无法获取用户角色",
})
c.Abort()
return
}
userRole, ok := role.(string)
if !ok {
c.JSON(http.StatusForbidden, gin.H{
"code": 403,
"message": "无效的用户角色",
})
c.Abort()
return
}
// 检查用户角色是否在允许的角色列表中
for _, allowedRole := range allowedRoles {
if userRole == allowedRole {
c.Next()
return
}
}
// 管理员可以访问所有接口
if userRole == "admin" {
c.Next()
return
}
c.JSON(http.StatusForbidden, gin.H{
"code": 403,
"message": "权限不足,无法访问该接口",
})
c.Abort()
}
}
// ValidatePasswordStrength 验证密码强度
func ValidatePasswordStrength(password string) error {
// 检查长度至少8位
if len(password) < 8 {
return fmt.Errorf("密码长度至少8位")
}
// 检查是否包含大小写字母、数字和特殊字符
hasUpper := false
hasLower := false
hasDigit := false
hasSpecial := false
for _, char := range password {
switch {
case char >= 'A' && char <= 'Z':
hasUpper = true
case char >= 'a' && char <= 'z':
hasLower = true
case char >= '0' && char <= '9':
hasDigit = true
case char == '!' || char == '@' || char == '#' || char == '$' ||
char == '%' || char == '^' || char == '&' || char == '*' ||
char == '(' || char == ')' || char == '-' || char == '_' ||
char == '+' || char == '=' || char == '[' || char == ']' ||
char == '{' || char == '}' || char == '|' || char == '\\' ||
char == ':' || char == ';' || char == '"' || char == '\'' ||
char == '<' || char == '>' || char == ',' || char == '.' ||
char == '?' || char == '/' || char == '~' || char == '`':
hasSpecial = true
}
}
if !hasUpper {
return fmt.Errorf("密码必须包含大写字母")
}
if !hasLower {
return fmt.Errorf("密码必须包含小写字母")
}
if !hasDigit {
return fmt.Errorf("密码必须包含数字")
}
if !hasSpecial {
return fmt.Errorf("密码必须包含特殊字符")
}
return nil
}
// AttemptInfo 尝试信息
type AttemptInfo struct {
Attempts int
LastAttempt time.Time
Blocked bool
BlockUntil time.Time
}
// LoginRateLimiter 登录速率限制器
type LoginRateLimiter struct {
// 存储每个IP的失败尝试次数
attempts map[string]*AttemptInfo
mutex sync.RWMutex
// 配置参数
maxAttempts int // 最大尝试次数
blockTime time.Duration // 封锁时间
}
// NewLoginRateLimiter 创建新的登录速率限制器
func NewLoginRateLimiter(maxAttempts int, blockTime time.Duration) *LoginRateLimiter {
limiter := &LoginRateLimiter{
attempts: make(map[string]*AttemptInfo),
maxAttempts: maxAttempts,
blockTime: blockTime,
}
// 启动清理协程,定期清理过期记录
go limiter.cleanup()
return limiter
}
// IsBlocked 检查IP是否被封禁
func (l *LoginRateLimiter) IsBlocked(ip string) bool {
l.mutex.RLock()
info, exists := l.attempts[ip]
l.mutex.RUnlock()
if !exists {
return false
}
// 如果封禁时间已过,解除封禁
if info.Blocked && time.Now().After(info.BlockUntil) {
l.mutex.Lock()
delete(l.attempts, ip)
l.mutex.Unlock()
return false
}
return info.Blocked
}
// RecordFailure 记录登录失败
func (l *LoginRateLimiter) RecordFailure(ip string) {
l.mutex.Lock()
defer l.mutex.Unlock()
info, exists := l.attempts[ip]
if !exists {
info = &AttemptInfo{
Attempts: 1,
LastAttempt: time.Now(),
}
l.attempts[ip] = info
return
}
// 更新尝试次数
info.Attempts++
info.LastAttempt = time.Now()
// 如果超过最大尝试次数封禁IP
if info.Attempts >= l.maxAttempts {
info.Blocked = true
info.BlockUntil = time.Now().Add(l.blockTime)
}
}
// ResetAttempts 重置尝试次数(登录成功后调用)
func (l *LoginRateLimiter) ResetAttempts(ip string) {
l.mutex.Lock()
defer l.mutex.Unlock()
delete(l.attempts, ip)
}
// cleanup 定期清理过期记录
func (l *LoginRateLimiter) cleanup() {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for range ticker.C {
l.mutex.Lock()
now := time.Now()
for ip, info := range l.attempts {
// 如果封禁时间已过,删除记录
if info.Blocked && now.After(info.BlockUntil) {
delete(l.attempts, ip)
} else if !info.Blocked && now.Sub(info.LastAttempt) > l.blockTime {
// 如果最后一次尝试时间超过封禁时间,也删除记录(非封禁状态下的旧记录)
delete(l.attempts, ip)
}
}
l.mutex.Unlock()
}
}
// RateLimitMiddleware 速率限制中间件
func RateLimitMiddleware(limit int, window time.Duration) gin.HandlerFunc {
type RequestInfo struct {
Count int
Time time.Time
}
requests := make(map[string][]RequestInfo)
mutex := sync.RWMutex{}
return func(c *gin.Context) {
clientIP := c.ClientIP()
mutex.Lock()
defer mutex.Unlock()
now := time.Now()
windowStart := now.Add(-window)
// 清理过期请求记录
var validRequests []RequestInfo
for _, req := range requests[clientIP] {
if req.Time.After(windowStart) {
validRequests = append(validRequests, req)
}
}
requests[clientIP] = validRequests
// 检查是否超出限制
if len(requests[clientIP]) >= limit {
c.JSON(http.StatusTooManyRequests, gin.H{
"code": 429,
"message": "请求过于频繁,请稍后再试",
})
c.Abort()
return
}
// 记录当前请求
requests[clientIP] = append(requests[clientIP], RequestInfo{
Count: 1,
Time: now,
})
c.Next()
}
}
// PasswordValidatorMiddleware 密码强度验证中间件
func PasswordValidatorMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// 这个中间件主要用于验证密码强度
// 在需要验证密码强度的接口中使用
c.Next()
}
}
// 全局登录速率限制器实例
var LoginRateLimiterInstance *LoginRateLimiter = NewLoginRateLimiter(5, 15*time.Minute) // 5次失败后封禁15分钟