package handlers import ( "net/http" "time" "github.com/gin-gonic/gin" "golang.org/x/crypto/bcrypt" "gorm.io/gorm" "ai-myopia-prevention/internal/middleware" "ai-myopia-prevention/internal/utils" ) // AuthService 认证服务 type AuthService struct { DB *gorm.DB } // LoginRequest 登录请求 type LoginRequest struct { Username string `json:"username" binding:"required"` Password string `json:"password" binding:"required"` DeviceID string `json:"device_id"` // 设备ID,用于设备认证 } // LoginResponse 登录响应 type LoginResponse struct { Code int `json:"code"` Message string `json:"message"` Data struct { Token string `json:"token"` ExpiresAt time.Time `json:"expires_at"` UserID uint `json:"user_id"` Username string `json:"username"` Name string `json:"name"` Role string `json:"role"` // student, parent, teacher, admin } `json:"data"` } // RegisterRequest 注册请求 type RegisterRequest struct { Username string `json:"username" binding:"required,min=3,max=32"` Password string `json:"password" binding:"required"` // 移除min=6,改用强度校验 Name string `json:"name" binding:"required"` Phone string `json:"phone" binding:"required"` Role string `json:"role" binding:"required,oneof=student parent teacher"` // 角色 } // RegisterResponse 注册响应 type RegisterResponse struct { Code int `json:"code"` Message string `json:"message"` Data struct { UserID uint `json:"user_id"` } `json:"data"` } // ChangePasswordRequest 修改密码请求 type ChangePasswordRequest struct { OldPassword string `json:"old_password" binding:"required"` NewPassword string `json:"new_password" binding:"required,min=6"` } // UserProfile 用户资料 type UserProfile struct { ID uint `json:"id"` Name string `json:"name"` Username string `json:"username"` Phone string `json:"phone"` Role string `json:"role"` } // NewAuthService 创建认证服务 func NewAuthService(db *gorm.DB) *AuthService { return &AuthService{DB: db} } // Login 用户登录 func (s *AuthService) Login(c *gin.Context) { var req LoginRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{ "code": 400, "message": "参数错误: " + err.Error(), }) return } // 检查IP是否被封禁 ip := c.ClientIP() if middleware.LoginRateLimiterInstance.IsBlocked(ip) { c.JSON(http.StatusTooManyRequests, gin.H{ "code": 429, "message": "登录失败次数过多,请15分钟后重试", }) return } // 根据用户名或手机号查找用户 var user struct { ID uint `json:"id"` Username string `json:"username"` Name string `json:"name"` Phone string `json:"phone"` PasswordHash string `json:"-"` Role string `json:"role"` Status int `json:"status"` LastLoginAt *time.Time `json:"last_login_at"` LastLoginIP string `json:"last_login_ip"` } result := s.DB.Table("user_accounts"). Select("id, username, name, phone, password_hash, role, status"). Where("username = ? OR phone = ?", req.Username, req.Username). First(&user) if result.Error != nil { // 记录登录失败 middleware.LoginRateLimiterInstance.RecordFailure(ip) c.JSON(http.StatusUnauthorized, gin.H{ "code": 401, "message": "用户名或密码错误", }) return } if user.Status != 1 { c.JSON(http.StatusForbidden, gin.H{ "code": 403, "message": "账户已被禁用", }) return } // 验证密码 err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password)) if err != nil { // 记录登录失败 middleware.LoginRateLimiterInstance.RecordFailure(ip) c.JSON(http.StatusUnauthorized, gin.H{ "code": 401, "message": "用户名或密码错误", }) return } // 登录成功,重置失败次数 middleware.LoginRateLimiterInstance.ResetAttempts(ip) // 生成JWT Token token, err := middleware.GenerateToken(user.ID, user.Username, user.Role) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "code": 500, "message": "生成认证令牌失败", }) return } // 更新最后登录时间和IP s.DB.Table("user_accounts"). Where("id = ?", user.ID). Updates(map[string]interface{}{ "last_login_at": time.Now(), "last_login_ip": ip, }) resp := LoginResponse{ Code: 0, Message: "登录成功", } resp.Data.Token = token resp.Data.ExpiresAt = time.Now().Add(time.Hour * 24 * 7) // 7天过期 resp.Data.UserID = user.ID resp.Data.Username = user.Username resp.Data.Name = user.Name resp.Data.Role = user.Role c.JSON(http.StatusOK, resp) } // Register 用户注册 func (s *AuthService) Register(c *gin.Context) { var req RegisterRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{ "code": 400, "message": "参数错误: " + err.Error(), }) return } // 密码强度校验 if err := middleware.ValidatePasswordStrength(req.Password); err != nil { c.JSON(http.StatusBadRequest, gin.H{ "code": 400, "message": "密码强度不够: " + err.Error(), }) return } // 检查用户名是否已存在 var count int64 s.DB.Table("user_accounts").Where("username = ?", req.Username).Count(&count) if count > 0 { c.JSON(http.StatusBadRequest, gin.H{ "code": 400, "message": "用户名已存在", }) return } // 检查手机号是否已存在 s.DB.Table("user_accounts").Where("phone = ?", req.Phone).Count(&count) if count > 0 { c.JSON(http.StatusBadRequest, gin.H{ "code": 400, "message": "手机号已被注册", }) return } // 加密密码 hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "code": 500, "message": "密码加密失败", }) return } // 创建用户账号 userAccount := map[string]interface{}{ "username": req.Username, "password_hash": string(hashedPassword), "phone": req.Phone, "user_type": req.Role, "status": 1, } result := s.DB.Table("user_accounts").Create(userAccount) if result.Error != nil { c.JSON(http.StatusInternalServerError, gin.H{ "code": 500, "message": "注册失败: " + result.Error.Error(), }) return } // 获取创建的用户ID var newUser struct { ID uint `json:"id"` } s.DB.Table("user_accounts").Where("username = ?", req.Username).Order("id DESC").First(&newUser) resp := RegisterResponse{ Code: 0, Message: "注册成功", } resp.Data.UserID = newUser.ID c.JSON(http.StatusOK, resp) } // GetProfile 获取用户资料 func (s *AuthService) GetProfile(c *gin.Context) { // 这里应该是从JWT token中获取用户ID // 为了演示,我们使用一个占位符 userID := c.GetUint("user_id") // 从中间件传递过来的用户ID var user struct { ID uint `json:"id"` Name string `json:"name"` Username string `json:"username"` Phone string `json:"phone"` Role string `json:"role"` } result := s.DB.Table("user_accounts").Select("id, username, name, phone, user_type as role").Where("id = ?", userID).First(&user) if result.Error != nil { if result.Error == gorm.ErrRecordNotFound { c.JSON(http.StatusNotFound, gin.H{ "code": 404, "message": "用户不存在", }) return } c.JSON(http.StatusInternalServerError, gin.H{ "code": 500, "message": "查询用户失败: " + result.Error.Error(), }) return } // 对敏感数据进行脱敏处理 user.Phone = utils.MaskPhone(user.Phone) user.Name = utils.MaskName(user.Name) c.JSON(http.StatusOK, gin.H{ "code": 0, "message": "获取成功", "data": user, }) } // UpdateProfile 更新用户资料 func (s *AuthService) UpdateProfile(c *gin.Context) { userID := c.GetUint("user_id") // 从中间件传递过来的用户ID var req struct { Name string `json:"name"` Phone string `json:"phone"` } if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{ "code": 400, "message": "参数错误: " + err.Error(), }) return } updates := make(map[string]interface{}) if req.Name != "" { updates["name"] = req.Name } if req.Phone != "" { // 对手机号进行验证和格式化 updates["phone"] = req.Phone } if len(updates) == 0 { c.JSON(http.StatusBadRequest, gin.H{ "code": 400, "message": "没有可更新的数据", }) return } result := s.DB.Table("user_accounts").Where("id = ?", userID).Updates(updates) if result.Error != nil { c.JSON(http.StatusInternalServerError, gin.H{ "code": 500, "message": "更新失败: " + result.Error.Error(), }) return } c.JSON(http.StatusOK, gin.H{ "code": 0, "message": "更新成功", }) } // ChangePassword 修改密码 func (s *AuthService) ChangePassword(c *gin.Context) { userID := c.GetUint("user_id") // 从中间件传递过来的用户ID var req ChangePasswordRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{ "code": 400, "message": "参数错误: " + err.Error(), }) return } // 获取当前用户密码 var currentPasswordHash string s.DB.Table("user_accounts").Select("password_hash").Where("id = ?", userID).First(¤tPasswordHash) // 验证旧密码 err := bcrypt.CompareHashAndPassword([]byte(currentPasswordHash), []byte(req.OldPassword)) if err != nil { c.JSON(http.StatusBadRequest, gin.H{ "code": 400, "message": "旧密码不正确", }) return } // 加密新密码 hashedNewPassword, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), bcrypt.DefaultCost) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "code": 500, "message": "密码加密失败", }) return } // 更新密码 result := s.DB.Table("user_accounts"). Where("id = ?", userID). Update("password_hash", string(hashedNewPassword)) if result.Error != nil { c.JSON(http.StatusInternalServerError, gin.H{ "code": 500, "message": "修改密码失败: " + result.Error.Error(), }) return } c.JSON(http.StatusOK, gin.H{ "code": 0, "message": "密码修改成功", }) }