From f2ea02296573faa14633a5601c6c6f7000355764 Mon Sep 17 00:00:00 2001 From: lejianwen <84855512@qq.com> Date: Sun, 25 May 2025 00:06:06 +0800 Subject: [PATCH] feat(login): Captcha upgrade and add the function to ban IP addresses (#250) --- cmd/apimain.go | 9 + conf/config.yaml | 3 + config/config.go | 14 +- global/global.go | 10 +- http/controller/admin/login.go | 191 ++++++--------------- http/controller/api/login.go | 8 + http/http.go | 2 +- http/middleware/limiter.go | 22 +++ resources/i18n/en.toml | 7 +- resources/i18n/es.toml | 7 +- resources/i18n/fr.toml | 7 +- resources/i18n/ko.toml | 7 +- resources/i18n/ru.toml | 7 +- resources/i18n/zh_CN.toml | 7 +- resources/i18n/zh_TW.toml | 7 +- utils/captcha.go | 48 ++++++ utils/login_limiter.go | 305 +++++++++++++++++++++++++++++++++ utils/login_limiter_test.go | 286 +++++++++++++++++++++++++++++++ 18 files changed, 787 insertions(+), 160 deletions(-) create mode 100644 http/middleware/limiter.go create mode 100644 utils/captcha.go create mode 100644 utils/login_limiter.go create mode 100644 utils/login_limiter_test.go diff --git a/cmd/apimain.go b/cmd/apimain.go index 67f1788..326c98f 100644 --- a/cmd/apimain.go +++ b/cmd/apimain.go @@ -18,6 +18,7 @@ import ( "github.com/spf13/cobra" "os" "strconv" + "time" ) // @title 管理系统API @@ -175,8 +176,16 @@ func InitGlobal() { //service service.New(&global.Config, global.DB, global.Logger, global.Jwt, global.Lock) + global.LoginLimiter = utils.NewLoginLimiter(utils.SecurityPolicy{ + CaptchaThreshold: global.Config.App.CaptchaThreshold, + BanThreshold: global.Config.App.BanThreshold, + AttemptsWindow: 10 * time.Minute, + BanDuration: 30 * time.Minute, + }) + global.LoginLimiter.RegisterProvider(utils.B64StringCaptchaProvider{}) DatabaseAutoUpdate() } + func DatabaseAutoUpdate() { version := 262 diff --git a/conf/config.yaml b/conf/config.yaml index 8f6ca48..01dc071 100644 --- a/conf/config.yaml +++ b/conf/config.yaml @@ -2,10 +2,13 @@ lang: "zh-CN" app: web-client: 1 # 1:启用 0:禁用 register: false #是否开启注册 + captcha-threshold: 3 # <0:disabled, 0 always, >0:enabled + ban-threshold: 0 # 0:disabled, >0:enabled show-swagger: 0 # 1:启用 0:禁用 token-expire: 168h web-sso: true #web auth sso disable-pwd-login: false #禁用密码登录 + admin: title: "RustDesk Api Admin" hello-file: "./conf/admin/hello.html" #优先使用file diff --git a/config/config.go b/config/config.go index df9dc67..45030cf 100644 --- a/config/config.go +++ b/config/config.go @@ -14,12 +14,14 @@ const ( ) type App struct { - WebClient int `mapstructure:"web-client"` - Register bool `mapstructure:"register"` - ShowSwagger int `mapstructure:"show-swagger"` - TokenExpire time.Duration `mapstructure:"token-expire"` - WebSso bool `mapstructure:"web-sso"` - DisablePwdLogin bool `mapstructure:"disable-pwd-login"` + WebClient int `mapstructure:"web-client"` + Register bool `mapstructure:"register"` + ShowSwagger int `mapstructure:"show-swagger"` + TokenExpire time.Duration `mapstructure:"token-expire"` + WebSso bool `mapstructure:"web-sso"` + DisablePwdLogin bool `mapstructure:"disable-pwd-login"` + CaptchaThreshold int `mapstructure:"captcha-threshold"` + BanThreshold int `mapstructure:"ban-threshold"` } type Admin struct { Title string `mapstructure:"title"` diff --git a/global/global.go b/global/global.go index 3c46d9c..f418be7 100644 --- a/global/global.go +++ b/global/global.go @@ -10,6 +10,7 @@ import ( "github.com/lejianwen/rustdesk-api/v2/lib/jwt" "github.com/lejianwen/rustdesk-api/v2/lib/lock" "github.com/lejianwen/rustdesk-api/v2/lib/upload" + "github.com/lejianwen/rustdesk-api/v2/utils" "github.com/nicksnyder/go-i18n/v2/i18n" "github.com/sirupsen/logrus" "github.com/spf13/viper" @@ -31,8 +32,9 @@ var ( ValidStruct func(*gin.Context, interface{}) []string ValidVar func(ctx *gin.Context, field interface{}, tag string) []string } - Oss *upload.Oss - Jwt *jwt.Jwt - Lock lock.Locker - Localizer func(lang string) *i18n.Localizer + Oss *upload.Oss + Jwt *jwt.Jwt + Lock lock.Locker + Localizer func(lang string) *i18n.Localizer + LoginLimiter *utils.LoginLimiter ) diff --git a/http/controller/admin/login.go b/http/controller/admin/login.go index bfbe801..4880f85 100644 --- a/http/controller/admin/login.go +++ b/http/controller/admin/login.go @@ -11,135 +11,11 @@ import ( adResp "github.com/lejianwen/rustdesk-api/v2/http/response/admin" "github.com/lejianwen/rustdesk-api/v2/model" "github.com/lejianwen/rustdesk-api/v2/service" - "github.com/mojocn/base64Captcha" - "sync" - "time" ) type Login struct { } -// Captcha 验证码结构 -type Captcha struct { - Id string `json:"id"` // 验证码 ID - B64 string `json:"b64"` // base64 验证码 - Code string `json:"-"` // 验证码内容 - ExpiresAt time.Time `json:"-"` // 过期时间 -} -type LoginLimiter struct { - mu sync.RWMutex - failCount map[string]int // 记录每个 IP 的失败次数 - timestamp map[string]time.Time // 记录每个 IP 的最后失败时间 - captchas map[string]Captcha // 每个 IP 的验证码 - threshold int // 失败阈值 - expiry time.Duration // 失败记录过期时间 -} - -func NewLoginLimiter(threshold int, expiry time.Duration) *LoginLimiter { - return &LoginLimiter{ - failCount: make(map[string]int), - timestamp: make(map[string]time.Time), - captchas: make(map[string]Captcha), - threshold: threshold, - expiry: expiry, - } -} - -// RecordFailure 记录登录失败 -func (l *LoginLimiter) RecordFailure(ip string) { - l.mu.Lock() - defer l.mu.Unlock() - - // 如果该 IP 的记录已经过期,重置计数 - if lastTime, exists := l.timestamp[ip]; exists && time.Since(lastTime) > l.expiry { - l.failCount[ip] = 0 - } - - // 更新失败次数和时间戳 - l.failCount[ip]++ - l.timestamp[ip] = time.Now() -} - -// NeedsCaptcha 检查是否需要验证码 -func (l *LoginLimiter) NeedsCaptcha(ip string) bool { - l.mu.RLock() - defer l.mu.RUnlock() - - // 检查记录是否存在且未过期 - if lastTime, exists := l.timestamp[ip]; exists && time.Since(lastTime) <= l.expiry { - return l.failCount[ip] >= l.threshold - } - return false -} - -// GenerateCaptcha 为指定 IP 生成验证码 -func (l *LoginLimiter) GenerateCaptcha(ip string) Captcha { - l.mu.Lock() - defer l.mu.Unlock() - - capd := base64Captcha.NewDriverString(50, 150, 5, 10, 4, "1234567890abcdefghijklmnopqrstuvwxyz", nil, nil, nil) - b64cap := base64Captcha.NewCaptcha(capd, base64Captcha.DefaultMemStore) - id, b64s, answer, err := b64cap.Generate() - if err != nil { - global.Logger.Error("Generate captcha failed: " + err.Error()) - return Captcha{} - } - // 保存验证码到对应 IP - l.captchas[ip] = Captcha{ - Id: id, - B64: b64s, - Code: answer, - ExpiresAt: time.Now().Add(5 * time.Minute), - } - return l.captchas[ip] -} - -// VerifyCaptcha 验证指定 IP 的验证码 -func (l *LoginLimiter) VerifyCaptcha(ip, code string) bool { - l.mu.RLock() - defer l.mu.RUnlock() - - // 检查验证码是否存在且未过期 - if captcha, exists := l.captchas[ip]; exists && time.Now().Before(captcha.ExpiresAt) { - return captcha.Code == code - } - return false -} - -// RemoveCaptcha 移除指定 IP 的验证码 -func (l *LoginLimiter) RemoveCaptcha(ip string) { - l.mu.Lock() - defer l.mu.Unlock() - - delete(l.captchas, ip) -} - -// CleanupExpired 清理过期的记录 -func (l *LoginLimiter) CleanupExpired() { - l.mu.Lock() - defer l.mu.Unlock() - - now := time.Now() - for ip, lastTime := range l.timestamp { - if now.Sub(lastTime) > l.expiry { - delete(l.failCount, ip) - delete(l.timestamp, ip) - delete(l.captchas, ip) - } - } -} - -func (l *LoginLimiter) RemoveRecord(ip string) { - l.mu.Lock() - defer l.mu.Unlock() - - delete(l.failCount, ip) - delete(l.timestamp, ip) - delete(l.captchas, ip) -} - -var loginLimiter = NewLoginLimiter(3, 5*time.Minute) - // Login 登录 // @Tags 登录 // @Summary 登录 @@ -156,10 +32,16 @@ func (ct *Login) Login(c *gin.Context) { response.Fail(c, 101, response.TranslateMsg(c, "PwdLoginDisabled")) return } + + // 检查登录限制 + loginLimiter := global.LoginLimiter + clientIp := c.ClientIP() + _, needCaptcha := loginLimiter.CheckSecurityStatus(clientIp) + f := &admin.Login{} err := c.ShouldBindJSON(f) - clientIp := c.ClientIP() if err != nil { + loginLimiter.RecordFailedAttempt(clientIp) global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "ParamsError", c.RemoteIP(), clientIp)) response.Fail(c, 101, response.TranslateMsg(c, "ParamsError")+err.Error()) return @@ -167,13 +49,14 @@ func (ct *Login) Login(c *gin.Context) { errList := global.Validator.ValidStruct(c, f) if len(errList) > 0 { + loginLimiter.RecordFailedAttempt(clientIp) global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "ParamsError", c.RemoteIP(), clientIp)) response.Fail(c, 101, errList[0]) return } // 检查是否需要验证码 - if loginLimiter.NeedsCaptcha(clientIp) { + if needCaptcha { if f.Captcha == "" || !loginLimiter.VerifyCaptcha(clientIp, f.Captcha) { response.Fail(c, 101, response.TranslateMsg(c, "CaptchaError")) return @@ -184,17 +67,22 @@ func (ct *Login) Login(c *gin.Context) { if u.Id == 0 { global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "UsernameOrPasswordError", c.RemoteIP(), clientIp)) - loginLimiter.RecordFailure(clientIp) - if loginLimiter.NeedsCaptcha(clientIp) { - loginLimiter.RemoveCaptcha(clientIp) + loginLimiter.RecordFailedAttempt(clientIp) + // 移除验证码,重新生成 + loginLimiter.RemoveCaptcha(clientIp) + if _, needCaptcha = loginLimiter.CheckSecurityStatus(clientIp); needCaptcha { + response.Fail(c, 110, response.TranslateMsg(c, "UsernameOrPasswordError")) + } else { + response.Fail(c, 101, response.TranslateMsg(c, "UsernameOrPasswordError")) } - response.Fail(c, 101, response.TranslateMsg(c, "UsernameOrPasswordError")) return } if !service.AllService.UserService.CheckUserEnable(u) { - if loginLimiter.NeedsCaptcha(clientIp) { + if needCaptcha { loginLimiter.RemoveCaptcha(clientIp) + response.Fail(c, 110, response.TranslateMsg(c, "UserDisabled")) + return } response.Fail(c, 101, response.TranslateMsg(c, "UserDisabled")) return @@ -209,23 +97,36 @@ func (ct *Login) Login(c *gin.Context) { Platform: f.Platform, }) - // 成功后清除记录 - loginLimiter.RemoveRecord(clientIp) - - // 清理过期记录 - go loginLimiter.CleanupExpired() - + // 登录成功,清除登录限制 + loginLimiter.RemoveAttempts(clientIp) responseLoginSuccess(c, u, ut.Token) } func (ct *Login) Captcha(c *gin.Context) { + loginLimiter := global.LoginLimiter clientIp := c.ClientIP() - if !loginLimiter.NeedsCaptcha(clientIp) { + banned, needCaptcha := loginLimiter.CheckSecurityStatus(clientIp) + if banned { + response.Fail(c, 101, response.TranslateMsg(c, "LoginBanned")) + return + } + if !needCaptcha { response.Fail(c, 101, response.TranslateMsg(c, "NoCaptchaRequired")) return } - captcha := loginLimiter.GenerateCaptcha(clientIp) + err, captcha := loginLimiter.RequireCaptcha(clientIp) + if err != nil { + response.Fail(c, 101, response.TranslateMsg(c, "CaptchaError")+err.Error()) + return + } + err, b64 := loginLimiter.DrawCaptcha(captcha.Content) + if err != nil { + response.Fail(c, 101, response.TranslateMsg(c, "CaptchaError")+err.Error()) + return + } response.Success(c, gin.H{ - "captcha": captcha, + "captcha": gin.H{ + "b64": b64, + }, }) } @@ -257,12 +158,18 @@ func (ct *Login) Logout(c *gin.Context) { // @Failure 500 {object} response.ErrorResponse // @Router /admin/login-options [post] func (ct *Login) LoginOptions(c *gin.Context) { - ip := c.ClientIP() + loginLimiter := global.LoginLimiter + clientIp := c.ClientIP() + banned, needCaptcha := loginLimiter.CheckSecurityStatus(clientIp) + if banned { + response.Fail(c, 101, response.TranslateMsg(c, "LoginBanned")) + return + } ops := service.AllService.OauthService.GetOauthProviders() response.Success(c, gin.H{ "ops": ops, "register": global.Config.App.Register, - "need_captcha": loginLimiter.NeedsCaptcha(ip), + "need_captcha": needCaptcha, }) } diff --git a/http/controller/api/login.go b/http/controller/api/login.go index adaf89f..e2e3cb8 100644 --- a/http/controller/api/login.go +++ b/http/controller/api/login.go @@ -31,10 +31,16 @@ func (l *Login) Login(c *gin.Context) { response.Error(c, response.TranslateMsg(c, "PwdLoginDisabled")) return } + + // 检查登录限制 + loginLimiter := global.LoginLimiter + clientIp := c.ClientIP() + f := &api.LoginForm{} err := c.ShouldBindJSON(f) //fmt.Println(f) if err != nil { + loginLimiter.RecordFailedAttempt(clientIp) global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "ParamsError", c.RemoteIP(), c.ClientIP())) response.Error(c, response.TranslateMsg(c, "ParamsError")+err.Error()) return @@ -42,6 +48,7 @@ func (l *Login) Login(c *gin.Context) { errList := global.Validator.ValidStruct(c, f) if len(errList) > 0 { + loginLimiter.RecordFailedAttempt(clientIp) global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "ParamsError", c.RemoteIP(), c.ClientIP())) response.Error(c, errList[0]) return @@ -50,6 +57,7 @@ func (l *Login) Login(c *gin.Context) { u := service.AllService.UserService.InfoByUsernamePassword(f.Username, f.Password) if u.Id == 0 { + loginLimiter.RecordFailedAttempt(clientIp) global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "UsernameOrPasswordError", c.RemoteIP(), c.ClientIP())) response.Error(c, response.TranslateMsg(c, "UsernameOrPasswordError")) return diff --git a/http/http.go b/http/http.go index d90549e..9d1273a 100644 --- a/http/http.go +++ b/http/http.go @@ -33,7 +33,7 @@ func ApiInit() { g.NoRoute(func(c *gin.Context) { c.String(http.StatusNotFound, "404 not found") }) - g.Use(middleware.Logger(), gin.Recovery()) + g.Use(middleware.Logger(), middleware.Limiter(), gin.Recovery()) router.WebInit(g) router.Init(g) router.ApiInit(g) diff --git a/http/middleware/limiter.go b/http/middleware/limiter.go new file mode 100644 index 0000000..1edc1b5 --- /dev/null +++ b/http/middleware/limiter.go @@ -0,0 +1,22 @@ +package middleware + +import ( + "github.com/gin-gonic/gin" + "github.com/lejianwen/rustdesk-api/v2/global" + "github.com/lejianwen/rustdesk-api/v2/http/response" + "net/http" +) + +func Limiter() gin.HandlerFunc { + return func(c *gin.Context) { + loginLimiter := global.LoginLimiter + clientIp := c.ClientIP() + banned, _ := loginLimiter.CheckSecurityStatus(clientIp) + if banned { + response.Fail(c, http.StatusLocked, response.TranslateMsg(c, "Banned")) + c.Abort() + return + } + c.Next() + } +} diff --git a/resources/i18n/en.toml b/resources/i18n/en.toml index 80c70ba..8cea201 100644 --- a/resources/i18n/en.toml +++ b/resources/i18n/en.toml @@ -142,4 +142,9 @@ other = "Password login disabled." [CannotShareToSelf] description = "Cannot share to self." one = "Cannot share to self." -other = "Cannot share to self." \ No newline at end of file +other = "Cannot share to self." + +[Banned] +description = "Banned." +one = "Banned." +other = "Banned." \ No newline at end of file diff --git a/resources/i18n/es.toml b/resources/i18n/es.toml index d795293..3bcff6f 100644 --- a/resources/i18n/es.toml +++ b/resources/i18n/es.toml @@ -151,4 +151,9 @@ other = "Inicio de sesión con contraseña deshabilitado." [CannotShareToSelf] description = "Cannot share to self." one = "No se puede compartir con uno mismo." -other = "No se puede compartir con uno mismo." \ No newline at end of file +other = "No se puede compartir con uno mismo." + +[Banned] +description = "Banned." +one = "Prohibido." +other = "Prohibido." \ No newline at end of file diff --git a/resources/i18n/fr.toml b/resources/i18n/fr.toml index a6805ca..093f738 100644 --- a/resources/i18n/fr.toml +++ b/resources/i18n/fr.toml @@ -151,4 +151,9 @@ other = "Connexion par mot de passe désactivée." [CannotShareToSelf] description = "Cannot share to self." one = "Impossible de partager avec soi-même." -other = "Impossible de partager avec soi-même." \ No newline at end of file +other = "Impossible de partager avec soi-même." + +[Banned] +description = "Banned." +one = "Banni." +other = "Banni." \ No newline at end of file diff --git a/resources/i18n/ko.toml b/resources/i18n/ko.toml index b273aab..1667c1a 100644 --- a/resources/i18n/ko.toml +++ b/resources/i18n/ko.toml @@ -145,4 +145,9 @@ other = "비밀번호 로그인이 비활성화되었습니다." [CannotShareToSelf] description = "Cannot share to self." one = "자기 자신에게 공유할 수 없습니다." -other = "자기 자신에게 공유할 수 없습니다." \ No newline at end of file +other = "자기 자신에게 공유할 수 없습니다." + +[Banned] +description = "Banned." +one = "금지됨." +other = "금지됨." \ No newline at end of file diff --git a/resources/i18n/ru.toml b/resources/i18n/ru.toml index 675df26..05a0e31 100644 --- a/resources/i18n/ru.toml +++ b/resources/i18n/ru.toml @@ -151,4 +151,9 @@ other = "Вход по паролю отключен." [CannotShareToSelf] description = "Cannot share to self." one = "Нельзя поделиться с собой." -other = "Нельзя поделиться с собой." \ No newline at end of file +other = "Нельзя поделиться с собой." + +[Banned] +description = "Banned." +one = "Заблокировано." +other = "Заблокировано." \ No newline at end of file diff --git a/resources/i18n/zh_CN.toml b/resources/i18n/zh_CN.toml index 030d5c8..6403d03 100644 --- a/resources/i18n/zh_CN.toml +++ b/resources/i18n/zh_CN.toml @@ -144,4 +144,9 @@ other = "密码登录已禁用。" [CannotShareToSelf] description = "Cannot share to self." one = "不能共享给自己。" -other = "不能共享给自己。" \ No newline at end of file +other = "不能共享给自己。" + +[Banned] +description = "Banned." +one = "已被封禁。" +other = "已被封禁。" \ No newline at end of file diff --git a/resources/i18n/zh_TW.toml b/resources/i18n/zh_TW.toml index 3660c27..2197535 100644 --- a/resources/i18n/zh_TW.toml +++ b/resources/i18n/zh_TW.toml @@ -144,4 +144,9 @@ other = "密碼登錄已禁用。" [CannotShareToSelf] description = "Cannot share to self." one = "無法共享給自己。" -other = "無法共享給自己。" \ No newline at end of file +other = "無法共享給自己。" + +[Banned] +description = "Banned." +one = "禁止使用。" +other = "禁止使用。" \ No newline at end of file diff --git a/utils/captcha.go b/utils/captcha.go new file mode 100644 index 0000000..45c365e --- /dev/null +++ b/utils/captcha.go @@ -0,0 +1,48 @@ +package utils + +import ( + "github.com/mojocn/base64Captcha" + "time" +) + +var capdString = base64Captcha.NewDriverString(50, 150, 5, 10, 4, "123456789abcdefghijklmnopqrstuvwxyz", nil, nil, nil) + +var capdMath = base64Captcha.NewDriverMath(50, 150, 5, 10, nil, nil, nil) + +type B64StringCaptchaProvider struct{} + +func (p B64StringCaptchaProvider) Generate(ip string) (string, string, error) { + _, content, answer := capdString.GenerateIdQuestionAnswer() + return content, answer, nil +} + +func (p B64StringCaptchaProvider) Expiration() time.Duration { + return 5 * time.Minute +} +func (p B64StringCaptchaProvider) Draw(content string) (string, error) { + item, err := capdString.DrawCaptcha(content) + if err != nil { + return "", err + } + b64str := item.EncodeB64string() + return b64str, nil +} + +type B64MathCaptchaProvider struct{} + +func (p B64MathCaptchaProvider) Generate(ip string) (string, string, error) { + _, content, answer := capdMath.GenerateIdQuestionAnswer() + return content, answer, nil +} + +func (p B64MathCaptchaProvider) Expiration() time.Duration { + return 5 * time.Minute +} +func (p B64MathCaptchaProvider) Draw(content string) (string, error) { + item, err := capdMath.DrawCaptcha(content) + if err != nil { + return "", err + } + b64str := item.EncodeB64string() + return b64str, nil +} diff --git a/utils/login_limiter.go b/utils/login_limiter.go new file mode 100644 index 0000000..76aeb01 --- /dev/null +++ b/utils/login_limiter.go @@ -0,0 +1,305 @@ +package utils + +import ( + "errors" + "sync" + "time" +) + +// 安全策略配置 +type SecurityPolicy struct { + CaptchaThreshold int // 尝试失败次数达到验证码阈值,小于0表示不启用, 0表示强制启用 + BanThreshold int // 尝试失败次数达到封禁阈值,为0表示不启用 + AttemptsWindow time.Duration + BanDuration time.Duration +} + +// 验证码提供者接口 +type CaptchaProvider interface { + Generate(ip string) (string, string, error) + //Validate(ip, code string) bool + Expiration() time.Duration // 验证码过期时间, 应该小于 AttemptsWindow + Draw(content string) (string, error) // 绘制验证码 +} + +// 验证码元数据 +type CaptchaMeta struct { + Content string + Answer string + ExpiresAt time.Time +} + +// IP封禁记录 +type BanRecord struct { + ExpiresAt time.Time + Reason string +} + +// 登录限制器 +type LoginLimiter struct { + mu sync.Mutex + policy SecurityPolicy + attempts map[string][]time.Time // + captchas map[string]CaptchaMeta + bannedIPs map[string]BanRecord + provider CaptchaProvider + cleanupStop chan struct{} +} + +var defaultSecurityPolicy = SecurityPolicy{ + CaptchaThreshold: 3, + BanThreshold: 5, + AttemptsWindow: 5 * time.Minute, + BanDuration: 30 * time.Minute, +} + +func NewLoginLimiter(policy SecurityPolicy) *LoginLimiter { + // 设置默认值 + if policy.AttemptsWindow == 0 { + policy.AttemptsWindow = 5 * time.Minute + } + if policy.BanDuration == 0 { + policy.BanDuration = 30 * time.Minute + } + + ll := &LoginLimiter{ + policy: policy, + attempts: make(map[string][]time.Time), + captchas: make(map[string]CaptchaMeta), + bannedIPs: make(map[string]BanRecord), + cleanupStop: make(chan struct{}), + } + go ll.cleanupRoutine() + return ll +} + +// 注册验证码提供者 +func (ll *LoginLimiter) RegisterProvider(p CaptchaProvider) { + ll.mu.Lock() + defer ll.mu.Unlock() + ll.provider = p +} + +// isDisabled 检查是否禁用登录限制 +func (ll *LoginLimiter) isDisabled() bool { + return ll.policy.CaptchaThreshold < 0 && ll.policy.BanThreshold == 0 +} + +// 记录登录失败尝试 +func (ll *LoginLimiter) RecordFailedAttempt(ip string) { + if ll.isDisabled() { + return + } + ll.mu.Lock() + defer ll.mu.Unlock() + + if banned, _ := ll.isBanned(ip); banned { + return + } + + now := time.Now() + windowStart := now.Add(-ll.policy.AttemptsWindow) + + // 清理过期尝试 + validAttempts := ll.pruneAttempts(ip, windowStart) + + // 记录新尝试 + validAttempts = append(validAttempts, now) + ll.attempts[ip] = validAttempts + + // 检查封禁条件 + if ll.policy.BanThreshold > 0 && len(validAttempts) >= ll.policy.BanThreshold { + ll.banIP(ip, "excessive failed attempts") + return + } + + return +} + +// 生成验证码 +func (ll *LoginLimiter) RequireCaptcha(ip string) (error, CaptchaMeta) { + ll.mu.Lock() + defer ll.mu.Unlock() + + if ll.provider == nil { + return errors.New("no captcha provider available"), CaptchaMeta{} + } + + content, answer, err := ll.provider.Generate(ip) + if err != nil { + return err, CaptchaMeta{} + } + + // 存储验证码 + ll.captchas[ip] = CaptchaMeta{ + Content: content, + Answer: answer, + ExpiresAt: time.Now().Add(ll.provider.Expiration()), + } + + return nil, ll.captchas[ip] +} + +// 验证验证码 +func (ll *LoginLimiter) VerifyCaptcha(ip, answer string) bool { + ll.mu.Lock() + defer ll.mu.Unlock() + + // 查找匹配验证码 + if ll.provider == nil { + return false + } + + // 获取并验证验证码 + captcha, exists := ll.captchas[ip] + if !exists { + return false + } + + // 清理过期验证码 + if time.Now().After(captcha.ExpiresAt) { + delete(ll.captchas, ip) + return false + } + + // 验证并清理状态 + if answer == captcha.Answer { + delete(ll.captchas, ip) + return true + } + + return false +} + +func (ll *LoginLimiter) DrawCaptcha(content string) (err error, str string) { + str, err = ll.provider.Draw(content) + return +} + +func (ll *LoginLimiter) RemoveCaptcha(ip string) { + ll.mu.Lock() + defer ll.mu.Unlock() + + _, exists := ll.captchas[ip] + if exists { + delete(ll.captchas, ip) + } +} + +// 清除记录窗口 +func (ll *LoginLimiter) RemoveAttempts(ip string) { + ll.mu.Lock() + defer ll.mu.Unlock() + + _, exists := ll.attempts[ip] + if exists { + delete(ll.attempts, ip) + } +} + +// CheckSecurityStatus 检查安全状态 +func (ll *LoginLimiter) CheckSecurityStatus(ip string) (banned bool, captchaRequired bool) { + if ll.isDisabled() { + return + } + ll.mu.Lock() + defer ll.mu.Unlock() + + // 检查封禁状态 + if banned, _ = ll.isBanned(ip); banned { + return + } + + // 清理过期数据 + ll.pruneAttempts(ip, time.Now().Add(-ll.policy.AttemptsWindow)) + ll.pruneCaptchas(ip) + + // 检查验证码要求 + captchaRequired = len(ll.attempts[ip]) >= ll.policy.CaptchaThreshold + + return +} + +// 后台清理任务 +func (ll *LoginLimiter) cleanupRoutine() { + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + ll.cleanupExpired() + case <-ll.cleanupStop: + return + } + } +} + +// 内部工具方法 +func (ll *LoginLimiter) isBanned(ip string) (bool, BanRecord) { + record, exists := ll.bannedIPs[ip] + if !exists { + return false, BanRecord{} + } + if time.Now().After(record.ExpiresAt) { + delete(ll.bannedIPs, ip) + return false, BanRecord{} + } + return true, record +} + +func (ll *LoginLimiter) banIP(ip, reason string) { + ll.bannedIPs[ip] = BanRecord{ + ExpiresAt: time.Now().Add(ll.policy.BanDuration), + Reason: reason, + } + delete(ll.attempts, ip) + delete(ll.captchas, ip) +} + +func (ll *LoginLimiter) pruneAttempts(ip string, cutoff time.Time) []time.Time { + var valid []time.Time + for _, t := range ll.attempts[ip] { + if t.After(cutoff) { + valid = append(valid, t) + } + } + if len(valid) == 0 { + delete(ll.attempts, ip) + } else { + ll.attempts[ip] = valid + } + return valid +} + +func (ll *LoginLimiter) pruneCaptchas(ip string) { + if captcha, exists := ll.captchas[ip]; exists { + if time.Now().After(captcha.ExpiresAt) { + delete(ll.captchas, ip) + } + } +} + +func (ll *LoginLimiter) cleanupExpired() { + ll.mu.Lock() + defer ll.mu.Unlock() + + now := time.Now() + + // 清理封禁记录 + for ip, record := range ll.bannedIPs { + if now.After(record.ExpiresAt) { + delete(ll.bannedIPs, ip) + } + } + + // 清理尝试记录 + for ip := range ll.attempts { + ll.pruneAttempts(ip, now.Add(-ll.policy.AttemptsWindow)) + } + + // 清理验证码 + for ip := range ll.captchas { + ll.pruneCaptchas(ip) + } +} diff --git a/utils/login_limiter_test.go b/utils/login_limiter_test.go new file mode 100644 index 0000000..afff418 --- /dev/null +++ b/utils/login_limiter_test.go @@ -0,0 +1,286 @@ +package utils + +import ( + "fmt" + "testing" + "time" +) + +type MockCaptchaProvider struct{} + +func (p *MockCaptchaProvider) Generate(ip string) (string, string, error) { + return "CONTENT", "MOCK", nil +} + +func (p *MockCaptchaProvider) Validate(ip, code string) bool { + return code == "MOCK" +} + +func (p *MockCaptchaProvider) Expiration() time.Duration { + return 2 * time.Second +} +func (p *MockCaptchaProvider) Draw(content string) (string, error) { + return "MOCK", nil +} + +func TestSecurityWorkflow(t *testing.T) { + policy := SecurityPolicy{ + CaptchaThreshold: 3, + BanThreshold: 5, + AttemptsWindow: 5 * time.Minute, + BanDuration: 5 * time.Minute, + } + limiter := NewLoginLimiter(policy) + ip := "192.168.1.100" + + // 测试正常失败记录 + for i := 0; i < 3; i++ { + limiter.RecordFailedAttempt(ip) + } + isBanned, capRequired := limiter.CheckSecurityStatus(ip) + fmt.Printf("IP: %s, Banned: %v, Captcha Required: %v\n", ip, isBanned, capRequired) + if isBanned { + t.Error("IP should not be banned yet") + } + if !capRequired { + t.Error("Captcha should be required") + } + // 测试触发封禁 + for i := 0; i < 3; i++ { + limiter.RecordFailedAttempt(ip) + isBanned, capRequired = limiter.CheckSecurityStatus(ip) + fmt.Printf("IP: %s, Banned: %v, Captcha Required: %v\n", ip, isBanned, capRequired) + } + + // 测试封禁状态 + if isBanned, _ = limiter.CheckSecurityStatus(ip); !isBanned { + t.Error("IP should be banned") + } +} + +func TestCaptchaFlow(t *testing.T) { + policy := SecurityPolicy{CaptchaThreshold: 2} + limiter := NewLoginLimiter(policy) + limiter.RegisterProvider(&MockCaptchaProvider{}) + ip := "10.0.0.1" + + // 触发验证码要求 + limiter.RecordFailedAttempt(ip) + limiter.RecordFailedAttempt(ip) + + // 检查状态 + if _, need := limiter.CheckSecurityStatus(ip); !need { + t.Error("应该需要验证码") + } + + // 生成验证码 + err, capc := limiter.RequireCaptcha(ip) + if err != nil { + t.Fatalf("生成验证码失败: %v", err) + } + fmt.Printf("验证码内容: %#v\n", capc) + + // 验证成功 + if !limiter.VerifyCaptcha(ip, capc.Answer) { + t.Error("验证码应该验证成功") + } + + limiter.RemoveAttempts(ip) + // 验证后状态 + if banned, need := limiter.CheckSecurityStatus(ip); banned || need { + t.Error("验证成功后应该重置状态") + } +} + +func TestCaptchaMustFlow(t *testing.T) { + policy := SecurityPolicy{CaptchaThreshold: 0} + limiter := NewLoginLimiter(policy) + limiter.RegisterProvider(&MockCaptchaProvider{}) + ip := "10.0.0.1" + + // 检查状态 + if _, need := limiter.CheckSecurityStatus(ip); !need { + t.Error("应该需要验证码") + } + + // 生成验证码 + err, capc := limiter.RequireCaptcha(ip) + if err != nil { + t.Fatalf("生成验证码失败: %v", err) + } + fmt.Printf("验证码内容: %#v\n", capc) + + // 验证成功 + if !limiter.VerifyCaptcha(ip, capc.Answer) { + t.Error("验证码应该验证成功") + } + + // 验证后状态 + if _, need := limiter.CheckSecurityStatus(ip); !need { + t.Error("应该需要验证码") + } +} +func TestAttemptTimeout(t *testing.T) { + policy := SecurityPolicy{CaptchaThreshold: 2, AttemptsWindow: 1 * time.Second} + limiter := NewLoginLimiter(policy) + limiter.RegisterProvider(&MockCaptchaProvider{}) + ip := "10.0.0.1" + + // 触发验证码要求 + limiter.RecordFailedAttempt(ip) + limiter.RecordFailedAttempt(ip) + + // 检查状态 + if _, need := limiter.CheckSecurityStatus(ip); !need { + t.Error("应该需要验证码") + } + + // 生成验证码 + err, _ := limiter.RequireCaptcha(ip) + if err != nil { + t.Fatalf("生成验证码失败: %v", err) + } + // 等待超过 AttemptsWindow + time.Sleep(2 * time.Second) + // 触发验证码要求 + limiter.RecordFailedAttempt(ip) + + // 检查状态 + if _, need := limiter.CheckSecurityStatus(ip); need { + t.Error("不应该需要验证码") + } +} + +func TestCaptchaTimeout(t *testing.T) { + policy := SecurityPolicy{CaptchaThreshold: 2} + limiter := NewLoginLimiter(policy) + limiter.RegisterProvider(&MockCaptchaProvider{}) + ip := "10.0.0.1" + + // 触发验证码要求 + limiter.RecordFailedAttempt(ip) + limiter.RecordFailedAttempt(ip) + + // 检查状态 + if _, need := limiter.CheckSecurityStatus(ip); !need { + t.Error("应该需要验证码") + } + + // 生成验证码 + err, _ := limiter.RequireCaptcha(ip) + if err != nil { + t.Fatalf("生成验证码失败: %v", err) + } + + // 等待超过 CaptchaValidPeriod + time.Sleep(3 * time.Second) + + code := "MOCK" + // 验证成功 + if limiter.VerifyCaptcha(ip, code) { + t.Error("验证码应该已过期") + } + +} + +func TestBanFlow(t *testing.T) { + policy := SecurityPolicy{BanThreshold: 5} + limiter := NewLoginLimiter(policy) + ip := "10.0.0.1" + // 触发ban + for i := 0; i < 5; i++ { + limiter.RecordFailedAttempt(ip) + } + + // 检查状态 + if banned, _ := limiter.CheckSecurityStatus(ip); !banned { + t.Error("should be banned") + } +} +func TestBanDisableFlow(t *testing.T) { + policy := SecurityPolicy{BanThreshold: 0} + limiter := NewLoginLimiter(policy) + ip := "10.0.0.1" + // 触发ban + for i := 0; i < 5; i++ { + limiter.RecordFailedAttempt(ip) + } + + // 检查状态 + if banned, _ := limiter.CheckSecurityStatus(ip); banned { + t.Error("should not be banned") + } +} +func TestBanTimeout(t *testing.T) { + policy := SecurityPolicy{BanThreshold: 5, BanDuration: 1 * time.Second} + limiter := NewLoginLimiter(policy) + ip := "10.0.0.1" + // 触发ban + // 触发ban + for i := 0; i < 5; i++ { + limiter.RecordFailedAttempt(ip) + } + + time.Sleep(2 * time.Second) + + // 检查状态 + if banned, _ := limiter.CheckSecurityStatus(ip); banned { + t.Error("should not be banned") + } +} + +func TestLimiterDisabled(t *testing.T) { + policy := SecurityPolicy{BanThreshold: 0, CaptchaThreshold: -1} + limiter := NewLoginLimiter(policy) + ip := "10.0.0.1" + // 触发ban + for i := 0; i < 5; i++ { + limiter.RecordFailedAttempt(ip) + } + + // 检查状态 + if banned, capNeed := limiter.CheckSecurityStatus(ip); banned || capNeed { + fmt.Printf("IP: %s, Banned: %v, Captcha Required: %v\n", ip, banned, capNeed) + t.Error("should not be banned or need captcha") + } +} + +func TestB64CaptchaFlow(t *testing.T) { + limiter := NewLoginLimiter(defaultSecurityPolicy) + limiter.RegisterProvider(B64StringCaptchaProvider{}) + ip := "10.0.0.1" + + // 触发验证码要求 + limiter.RecordFailedAttempt(ip) + limiter.RecordFailedAttempt(ip) + limiter.RecordFailedAttempt(ip) + + // 检查状态 + if _, need := limiter.CheckSecurityStatus(ip); !need { + t.Error("应该需要验证码") + } + + // 生成验证码 + err, capc := limiter.RequireCaptcha(ip) + if err != nil { + t.Fatalf("生成验证码失败: %v", err) + } + fmt.Printf("验证码内容: %#v\n", capc) + + //draw + err, b64 := limiter.DrawCaptcha(capc.Content) + if err != nil { + t.Fatalf("绘制验证码失败: %v", err) + } + fmt.Printf("验证码内容: %#v\n", b64) + + // 验证成功 + if !limiter.VerifyCaptcha(ip, capc.Answer) { + t.Error("验证码应该验证成功") + } + limiter.RemoveAttempts(ip) + // 验证后状态 + if banned, need := limiter.CheckSecurityStatus(ip); banned || need { + t.Error("验证成功后应该重置状态") + } +}