diff --git a/http/controller/admin/login.go b/http/controller/admin/login.go index 4880f85..ab9e491 100644 --- a/http/controller/admin/login.go +++ b/http/controller/admin/login.go @@ -57,7 +57,7 @@ func (ct *Login) Login(c *gin.Context) { // 检查是否需要验证码 if needCaptcha { - if f.Captcha == "" || !loginLimiter.VerifyCaptcha(clientIp, f.Captcha) { + if f.CaptchaId == "" || f.Captcha == "" || !loginLimiter.VerifyCaptcha(f.CaptchaId, f.Captcha) { response.Fail(c, 101, response.TranslateMsg(c, "CaptchaError")) return } @@ -68,8 +68,6 @@ func (ct *Login) Login(c *gin.Context) { if u.Id == 0 { global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "UsernameOrPasswordError", c.RemoteIP(), clientIp)) loginLimiter.RecordFailedAttempt(clientIp) - // 移除验证码,重新生成 - loginLimiter.RemoveCaptcha(clientIp) if _, needCaptcha = loginLimiter.CheckSecurityStatus(clientIp); needCaptcha { response.Fail(c, 110, response.TranslateMsg(c, "UsernameOrPasswordError")) } else { @@ -80,7 +78,6 @@ func (ct *Login) Login(c *gin.Context) { if !service.AllService.UserService.CheckUserEnable(u) { if needCaptcha { - loginLimiter.RemoveCaptcha(clientIp) response.Fail(c, 110, response.TranslateMsg(c, "UserDisabled")) return } @@ -113,7 +110,7 @@ func (ct *Login) Captcha(c *gin.Context) { response.Fail(c, 101, response.TranslateMsg(c, "NoCaptchaRequired")) return } - err, captcha := loginLimiter.RequireCaptcha(clientIp) + err, captcha := loginLimiter.RequireCaptcha() if err != nil { response.Fail(c, 101, response.TranslateMsg(c, "CaptchaError")+err.Error()) return @@ -125,6 +122,7 @@ func (ct *Login) Captcha(c *gin.Context) { } response.Success(c, gin.H{ "captcha": gin.H{ + "id": captcha.Id, "b64": b64, }, }) diff --git a/http/request/admin/login.go b/http/request/admin/login.go index dca10ae..e7d84d2 100644 --- a/http/request/admin/login.go +++ b/http/request/admin/login.go @@ -1,10 +1,11 @@ package admin type Login struct { - Username string `json:"username" validate:"required" label:"用户名"` - Password string `json:"password,omitempty" validate:"required" label:"密码"` - Platform string `json:"platform" label:"平台"` - Captcha string `json:"captcha,omitempty" label:"验证码"` + Username string `json:"username" validate:"required" label:"用户名"` + Password string `json:"password,omitempty" validate:"required" label:"密码"` + Platform string `json:"platform" label:"平台"` + Captcha string `json:"captcha,omitempty" label:"验证码"` + CaptchaId string `json:"captcha_id,omitempty"` } type LoginLogQuery struct { diff --git a/utils/captcha.go b/utils/captcha.go index 45c365e..27674f8 100644 --- a/utils/captcha.go +++ b/utils/captcha.go @@ -5,15 +5,15 @@ import ( "time" ) -var capdString = base64Captcha.NewDriverString(50, 150, 5, 10, 4, "123456789abcdefghijklmnopqrstuvwxyz", nil, nil, nil) +var capdString = base64Captcha.NewDriverString(50, 150, 0, 5, 4, "123456789abcdefghijklmnopqrstuvwxyz", nil, nil, nil) -var capdMath = base64Captcha.NewDriverMath(50, 150, 5, 10, nil, nil, nil) +var capdMath = base64Captcha.NewDriverMath(50, 150, 3, 10, nil, nil, nil) type B64StringCaptchaProvider struct{} -func (p B64StringCaptchaProvider) Generate(ip string) (string, string, error) { - _, content, answer := capdString.GenerateIdQuestionAnswer() - return content, answer, nil +func (p B64StringCaptchaProvider) Generate() (string, string, string, error) { + id, content, answer := capdString.GenerateIdQuestionAnswer() + return id, content, answer, nil } func (p B64StringCaptchaProvider) Expiration() time.Duration { @@ -30,9 +30,9 @@ func (p B64StringCaptchaProvider) Draw(content string) (string, error) { type B64MathCaptchaProvider struct{} -func (p B64MathCaptchaProvider) Generate(ip string) (string, string, error) { - _, content, answer := capdMath.GenerateIdQuestionAnswer() - return content, answer, nil +func (p B64MathCaptchaProvider) Generate() (string, string, string, error) { + id, content, answer := capdMath.GenerateIdQuestionAnswer() + return id, content, answer, nil } func (p B64MathCaptchaProvider) Expiration() time.Duration { diff --git a/utils/login_limiter.go b/utils/login_limiter.go index 76aeb01..9bc5bb9 100644 --- a/utils/login_limiter.go +++ b/utils/login_limiter.go @@ -16,7 +16,7 @@ type SecurityPolicy struct { // 验证码提供者接口 type CaptchaProvider interface { - Generate(ip string) (string, string, error) + Generate() (id string, content string, answer string, err error) //Validate(ip, code string) bool Expiration() time.Duration // 验证码过期时间, 应该小于 AttemptsWindow Draw(content string) (string, error) // 绘制验证码 @@ -24,6 +24,7 @@ type CaptchaProvider interface { // 验证码元数据 type CaptchaMeta struct { + Id string Content string Answer string ExpiresAt time.Time @@ -117,7 +118,7 @@ func (ll *LoginLimiter) RecordFailedAttempt(ip string) { } // 生成验证码 -func (ll *LoginLimiter) RequireCaptcha(ip string) (error, CaptchaMeta) { +func (ll *LoginLimiter) RequireCaptcha() (error, CaptchaMeta) { ll.mu.Lock() defer ll.mu.Unlock() @@ -125,23 +126,24 @@ func (ll *LoginLimiter) RequireCaptcha(ip string) (error, CaptchaMeta) { return errors.New("no captcha provider available"), CaptchaMeta{} } - content, answer, err := ll.provider.Generate(ip) + id, content, answer, err := ll.provider.Generate() if err != nil { return err, CaptchaMeta{} } // 存储验证码 - ll.captchas[ip] = CaptchaMeta{ + ll.captchas[id] = CaptchaMeta{ + Id: id, Content: content, Answer: answer, ExpiresAt: time.Now().Add(ll.provider.Expiration()), } - return nil, ll.captchas[ip] + return nil, ll.captchas[id] } // 验证验证码 -func (ll *LoginLimiter) VerifyCaptcha(ip, answer string) bool { +func (ll *LoginLimiter) VerifyCaptcha(id, answer string) bool { ll.mu.Lock() defer ll.mu.Unlock() @@ -151,20 +153,20 @@ func (ll *LoginLimiter) VerifyCaptcha(ip, answer string) bool { } // 获取并验证验证码 - captcha, exists := ll.captchas[ip] + captcha, exists := ll.captchas[id] if !exists { return false } // 清理过期验证码 if time.Now().After(captcha.ExpiresAt) { - delete(ll.captchas, ip) + delete(ll.captchas, id) return false } // 验证并清理状态 if answer == captcha.Answer { - delete(ll.captchas, ip) + delete(ll.captchas, id) return true } @@ -176,16 +178,6 @@ func (ll *LoginLimiter) DrawCaptcha(content string) (err error, str string) { return } -func (ll *LoginLimiter) RemoveCaptcha(ip string) { - ll.mu.Lock() - defer ll.mu.Unlock() - - _, exists := ll.captchas[ip] - if exists { - delete(ll.captchas, ip) - } -} - // 清除记录窗口 func (ll *LoginLimiter) RemoveAttempts(ip string) { ll.mu.Lock() @@ -212,7 +204,6 @@ func (ll *LoginLimiter) CheckSecurityStatus(ip string) (banned bool, captchaRequ // 清理过期数据 ll.pruneAttempts(ip, time.Now().Add(-ll.policy.AttemptsWindow)) - ll.pruneCaptchas(ip) // 检查验证码要求 captchaRequired = len(ll.attempts[ip]) >= ll.policy.CaptchaThreshold @@ -272,10 +263,10 @@ func (ll *LoginLimiter) pruneAttempts(ip string, cutoff time.Time) []time.Time { return valid } -func (ll *LoginLimiter) pruneCaptchas(ip string) { - if captcha, exists := ll.captchas[ip]; exists { +func (ll *LoginLimiter) pruneCaptchas(id string) { + if captcha, exists := ll.captchas[id]; exists { if time.Now().After(captcha.ExpiresAt) { - delete(ll.captchas, ip) + delete(ll.captchas, id) } } } @@ -299,7 +290,7 @@ func (ll *LoginLimiter) cleanupExpired() { } // 清理验证码 - for ip := range ll.captchas { - ll.pruneCaptchas(ip) + for id := range ll.captchas { + ll.pruneCaptchas(id) } } diff --git a/utils/login_limiter_test.go b/utils/login_limiter_test.go index afff418..fffcb23 100644 --- a/utils/login_limiter_test.go +++ b/utils/login_limiter_test.go @@ -2,18 +2,18 @@ package utils import ( "fmt" + "github.com/google/uuid" "testing" "time" ) type MockCaptchaProvider struct{} -func (p *MockCaptchaProvider) Generate(ip string) (string, string, error) { - return "CONTENT", "MOCK", nil -} - -func (p *MockCaptchaProvider) Validate(ip, code string) bool { - return code == "MOCK" +func (p *MockCaptchaProvider) Generate() (string, string, string, error) { + id := uuid.New().String() + content := uuid.New().String() + answer := uuid.New().String() + return id, content, answer, nil } func (p *MockCaptchaProvider) Expiration() time.Duration { @@ -74,17 +74,22 @@ func TestCaptchaFlow(t *testing.T) { } // 生成验证码 - err, capc := limiter.RequireCaptcha(ip) + err, capc := limiter.RequireCaptcha() if err != nil { t.Fatalf("生成验证码失败: %v", err) } fmt.Printf("验证码内容: %#v\n", capc) // 验证成功 - if !limiter.VerifyCaptcha(ip, capc.Answer) { + if !limiter.VerifyCaptcha(capc.Id, capc.Answer) { t.Error("验证码应该验证成功") } + // 验证已删除 + if limiter.VerifyCaptcha(capc.Id, capc.Answer) { + t.Error("验证码应该已删除") + } + limiter.RemoveAttempts(ip) // 验证后状态 if banned, need := limiter.CheckSecurityStatus(ip); banned || need { @@ -104,14 +109,14 @@ func TestCaptchaMustFlow(t *testing.T) { } // 生成验证码 - err, capc := limiter.RequireCaptcha(ip) + err, capc := limiter.RequireCaptcha() if err != nil { t.Fatalf("生成验证码失败: %v", err) } fmt.Printf("验证码内容: %#v\n", capc) // 验证成功 - if !limiter.VerifyCaptcha(ip, capc.Answer) { + if !limiter.VerifyCaptcha(capc.Id, capc.Answer) { t.Error("验证码应该验证成功") } @@ -136,7 +141,7 @@ func TestAttemptTimeout(t *testing.T) { } // 生成验证码 - err, _ := limiter.RequireCaptcha(ip) + err, _ := limiter.RequireCaptcha() if err != nil { t.Fatalf("生成验证码失败: %v", err) } @@ -167,7 +172,7 @@ func TestCaptchaTimeout(t *testing.T) { } // 生成验证码 - err, _ := limiter.RequireCaptcha(ip) + err, capc := limiter.RequireCaptcha() if err != nil { t.Fatalf("生成验证码失败: %v", err) } @@ -175,9 +180,8 @@ func TestCaptchaTimeout(t *testing.T) { // 等待超过 CaptchaValidPeriod time.Sleep(3 * time.Second) - code := "MOCK" // 验证成功 - if limiter.VerifyCaptcha(ip, code) { + if limiter.VerifyCaptcha(capc.Id, capc.Answer) { t.Error("验证码应该已过期") } @@ -261,7 +265,7 @@ func TestB64CaptchaFlow(t *testing.T) { } // 生成验证码 - err, capc := limiter.RequireCaptcha(ip) + err, capc := limiter.RequireCaptcha() if err != nil { t.Fatalf("生成验证码失败: %v", err) } @@ -275,7 +279,7 @@ func TestB64CaptchaFlow(t *testing.T) { fmt.Printf("验证码内容: %#v\n", b64) // 验证成功 - if !limiter.VerifyCaptcha(ip, capc.Answer) { + if !limiter.VerifyCaptcha(capc.Id, capc.Answer) { t.Error("验证码应该验证成功") } limiter.RemoveAttempts(ip)