Compare commits

..

5 Commits

Author SHA1 Message Date
lejianwen
97f98cd6ce chore: update download links for musl cross-compilers 2025-06-05 12:14:17 +08:00
lejianwen
51f2920661 fix: Init sqlite fail(#266) 2025-06-04 09:31:43 +08:00
lejianwen
7a5d141ce8 fix(server): Port custom (#257) 2025-05-30 12:27:37 +08:00
lejianwen
3cef02a0bb fix(webclient): Peer online status 2025-05-29 18:51:37 +08:00
lejianwen
46a7ecc1ba fix: Captcha some problem when users login with same ip 2025-05-27 17:36:20 +08:00
15 changed files with 87 additions and 95 deletions

View File

@@ -115,12 +115,12 @@ jobs:
zip -r ${{ matrix.job.goos}}-${{ matrix.job.platform }}.${{matrix.job.file_ext}} ./release zip -r ${{ matrix.job.goos}}-${{ matrix.job.platform }}.${{matrix.job.file_ext}} ./release
else else
if [ "${{ matrix.job.platform }}" = "arm64" ]; then if [ "${{ matrix.job.platform }}" = "arm64" ]; then
wget https://musl.cc/aarch64-linux-musl-cross.tgz wget https://musl.ljw.red/aarch64-linux-musl-cross.tgz
tar -xf aarch64-linux-musl-cross.tgz tar -xf aarch64-linux-musl-cross.tgz
export PATH=$PATH:$PWD/aarch64-linux-musl-cross/bin export PATH=$PATH:$PWD/aarch64-linux-musl-cross/bin
GOOS=${{ matrix.job.goos }} GOARCH=${{ matrix.job.platform }} CC=aarch64-linux-musl-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go GOOS=${{ matrix.job.goos }} GOARCH=${{ matrix.job.platform }} CC=aarch64-linux-musl-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go
elif [ "${{ matrix.job.platform }}" = "armv7l" ]; then elif [ "${{ matrix.job.platform }}" = "armv7l" ]; then
wget https://musl.cc/armv7l-linux-musleabihf-cross.tgz wget https://musl.ljw.red/armv7l-linux-musleabihf-cross.tgz
tar -xf armv7l-linux-musleabihf-cross.tgz tar -xf armv7l-linux-musleabihf-cross.tgz
export PATH=$PATH:$PWD/armv7l-linux-musleabihf-cross/bin export PATH=$PATH:$PWD/armv7l-linux-musleabihf-cross/bin
GOOS=${{ matrix.job.goos }} GOARCH=arm GOARM=7 CC=armv7l-linux-musleabihf-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go GOOS=${{ matrix.job.goos }} GOARCH=arm GOARM=7 CC=armv7l-linux-musleabihf-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go

View File

@@ -101,12 +101,12 @@ jobs:
zip -r ${{ matrix.job.goos}}-${{ matrix.job.platform }}.${{matrix.job.file_ext}} ./release zip -r ${{ matrix.job.goos}}-${{ matrix.job.platform }}.${{matrix.job.file_ext}} ./release
else else
if [ "${{ matrix.job.platform }}" = "arm64" ]; then if [ "${{ matrix.job.platform }}" = "arm64" ]; then
wget https://musl.cc/aarch64-linux-musl-cross.tgz wget https://musl.ljw.red/aarch64-linux-musl-cross.tgz
tar -xf aarch64-linux-musl-cross.tgz tar -xf aarch64-linux-musl-cross.tgz
export PATH=$PATH:$PWD/aarch64-linux-musl-cross/bin export PATH=$PATH:$PWD/aarch64-linux-musl-cross/bin
GOOS=${{ matrix.job.goos }} GOARCH=${{ matrix.job.platform }} CC=aarch64-linux-musl-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go GOOS=${{ matrix.job.goos }} GOARCH=${{ matrix.job.platform }} CC=aarch64-linux-musl-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go
elif [ "${{ matrix.job.platform }}" = "armv7l" ]; then elif [ "${{ matrix.job.platform }}" = "armv7l" ]; then
wget https://musl.cc/armv7l-linux-musleabihf-cross.tgz wget https://musl.ljw.red/armv7l-linux-musleabihf-cross.tgz
tar -xf armv7l-linux-musleabihf-cross.tgz tar -xf armv7l-linux-musleabihf-cross.tgz
export PATH=$PATH:$PWD/armv7l-linux-musleabihf-cross/bin export PATH=$PATH:$PWD/armv7l-linux-musleabihf-cross/bin
GOOS=${{ matrix.job.goos }} GOARCH=arm GOARM=7 CC=armv7l-linux-musleabihf-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go GOOS=${{ matrix.job.goos }} GOARCH=arm GOARM=7 CC=armv7l-linux-musleabihf-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go

2
.gitignore vendored
View File

@@ -5,4 +5,4 @@ runtime/*
go.sum go.sum
resources/admin resources/admin
release release
data data/rustdeskapi.db

View File

@@ -14,6 +14,9 @@ admin:
title: "RustDesk Api Admin" title: "RustDesk Api Admin"
hello-file: "./conf/admin/hello.html" #优先使用file hello-file: "./conf/admin/hello.html" #优先使用file
hello: "" hello: ""
# ID Server and Relay Server ports https://github.com/lejianwen/rustdesk-api/issues/257
id-server-port: 21116 # ID Server port (for server cmd)
relay-server-port: 21117 # ID Server port (for server cmd)
gin: gin:
api-addr: "0.0.0.0:21114" api-addr: "0.0.0.0:21114"
mode: "release" #release,debug,test mode: "release" #release,debug,test

View File

@@ -25,9 +25,11 @@ type App struct {
BanThreshold int `mapstructure:"ban-threshold"` BanThreshold int `mapstructure:"ban-threshold"`
} }
type Admin struct { type Admin struct {
Title string `mapstructure:"title"` Title string `mapstructure:"title"`
Hello string `mapstructure:"hello"` Hello string `mapstructure:"hello"`
HelloFile string `mapstructure:"hello-file"` HelloFile string `mapstructure:"hello-file"`
IdServerPort int `mapstructure:"id-server-port"`
RelayServerPort int `mapstructure:"relay-server-port"`
} }
type Config struct { type Config struct {
Lang string `mapstructure:"lang"` Lang string `mapstructure:"lang"`
@@ -46,6 +48,15 @@ type Config struct {
Ldap Ldap Ldap Ldap
} }
func (a *Admin) Init() {
if a.IdServerPort == 0 {
a.IdServerPort = DefaultIdServerPort
}
if a.RelayServerPort == 0 {
a.RelayServerPort = DefaultRelayServerPort
}
}
// Init 初始化配置 // Init 初始化配置
func Init(rowVal *Config, path string) *viper.Viper { func Init(rowVal *Config, path string) *viper.Viper {
if path == "" { if path == "" {
@@ -80,7 +91,7 @@ func Init(rowVal *Config, path string) *viper.Viper {
panic(fmt.Errorf("Fatal error config: %s \n", err)) panic(fmt.Errorf("Fatal error config: %s \n", err))
} }
rowVal.Rustdesk.LoadKeyFile() rowVal.Rustdesk.LoadKeyFile()
rowVal.Rustdesk.ParsePort() rowVal.Admin.Init()
return v return v
} }

View File

@@ -2,8 +2,6 @@ package config
import ( import (
"os" "os"
"strconv"
"strings"
) )
const ( const (
@@ -40,19 +38,3 @@ func (rd *Rustdesk) LoadKeyFile() {
return return
} }
} }
func (rd *Rustdesk) ParsePort() {
// Parse port
idres := strings.Split(rd.IdServer, ":")
if len(idres) == 1 {
rd.IdServerPort = DefaultIdServerPort
} else if len(idres) == 2 {
rd.IdServerPort, _ = strconv.Atoi(idres[1])
}
relayres := strings.Split(rd.RelayServer, ":")
if len(relayres) == 1 {
rd.RelayServerPort = DefaultRelayServerPort
} else if len(relayres) == 2 {
rd.RelayServerPort, _ = strconv.Atoi(relayres[1])
}
}

0
data/.gitkeep Normal file
View File

View File

@@ -57,7 +57,7 @@ func (ct *Login) Login(c *gin.Context) {
// 检查是否需要验证码 // 检查是否需要验证码
if needCaptcha { if needCaptcha {
if f.Captcha == "" || !loginLimiter.VerifyCaptcha(clientIp, f.Captcha) { if f.CaptchaId == "" || f.Captcha == "" || !loginLimiter.VerifyCaptcha(f.CaptchaId, f.Captcha) {
response.Fail(c, 101, response.TranslateMsg(c, "CaptchaError")) response.Fail(c, 101, response.TranslateMsg(c, "CaptchaError"))
return return
} }
@@ -68,8 +68,6 @@ func (ct *Login) Login(c *gin.Context) {
if u.Id == 0 { if u.Id == 0 {
global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "UsernameOrPasswordError", c.RemoteIP(), clientIp)) global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "UsernameOrPasswordError", c.RemoteIP(), clientIp))
loginLimiter.RecordFailedAttempt(clientIp) loginLimiter.RecordFailedAttempt(clientIp)
// 移除验证码,重新生成
loginLimiter.RemoveCaptcha(clientIp)
if _, needCaptcha = loginLimiter.CheckSecurityStatus(clientIp); needCaptcha { if _, needCaptcha = loginLimiter.CheckSecurityStatus(clientIp); needCaptcha {
response.Fail(c, 110, response.TranslateMsg(c, "UsernameOrPasswordError")) response.Fail(c, 110, response.TranslateMsg(c, "UsernameOrPasswordError"))
} else { } else {
@@ -80,7 +78,6 @@ func (ct *Login) Login(c *gin.Context) {
if !service.AllService.UserService.CheckUserEnable(u) { if !service.AllService.UserService.CheckUserEnable(u) {
if needCaptcha { if needCaptcha {
loginLimiter.RemoveCaptcha(clientIp)
response.Fail(c, 110, response.TranslateMsg(c, "UserDisabled")) response.Fail(c, 110, response.TranslateMsg(c, "UserDisabled"))
return return
} }
@@ -113,7 +110,7 @@ func (ct *Login) Captcha(c *gin.Context) {
response.Fail(c, 101, response.TranslateMsg(c, "NoCaptchaRequired")) response.Fail(c, 101, response.TranslateMsg(c, "NoCaptchaRequired"))
return return
} }
err, captcha := loginLimiter.RequireCaptcha(clientIp) err, captcha := loginLimiter.RequireCaptcha()
if err != nil { if err != nil {
response.Fail(c, 101, response.TranslateMsg(c, "CaptchaError")+err.Error()) response.Fail(c, 101, response.TranslateMsg(c, "CaptchaError")+err.Error())
return return
@@ -125,6 +122,7 @@ func (ct *Login) Captcha(c *gin.Context) {
} }
response.Success(c, gin.H{ response.Success(c, gin.H{
"captcha": gin.H{ "captcha": gin.H{
"id": captcha.Id,
"b64": b64, "b64": b64,
}, },
}) })

View File

@@ -119,7 +119,16 @@ func (r *Rustdesk) SendCmd(c *gin.Context) {
response.Fail(c, 101, response.TranslateMsg(c, "ParamsError")) response.Fail(c, 101, response.TranslateMsg(c, "ParamsError"))
return return
} }
res, err := service.AllService.ServerCmdService.SendCmd(rc.Target, rc.Cmd, rc.Option)
port := 0
switch rc.Target {
case model.ServerCmdTargetIdServer:
port = global.Config.Admin.IdServerPort - 1
case model.ServerCmdTargetRelayServer:
port = global.Config.Admin.RelayServerPort
}
res, err := service.AllService.ServerCmdService.SendCmd(port, rc.Cmd, rc.Option)
if err != nil { if err != nil {
response.Fail(c, 101, err.Error()) response.Fail(c, 101, err.Error())
return return

View File

@@ -1,10 +1,11 @@
package admin package admin
type Login struct { type Login struct {
Username string `json:"username" validate:"required" label:"用户名"` Username string `json:"username" validate:"required" label:"用户名"`
Password string `json:"password,omitempty" validate:"required" label:"密码"` Password string `json:"password,omitempty" validate:"required" label:"密码"`
Platform string `json:"platform" label:"平台"` Platform string `json:"platform" label:"平台"`
Captcha string `json:"captcha,omitempty" label:"验证码"` Captcha string `json:"captcha,omitempty" label:"验证码"`
CaptchaId string `json:"captcha_id,omitempty"`
} }
type LoginLogQuery struct { type LoginLogQuery struct {

View File

@@ -11550,7 +11550,7 @@ async function or(u) {
let E = [], l = []; let E = [], l = [];
for (let d = 0; d < e.length; d++) { for (let d = 0; d < e.length; d++) {
const c = 1 << 7 - d % 8; const c = 1 << 7 - d % 8;
(s[d / 8] & c) === c ? E.push(e[d]) : l.push(e[d]) (s[Math.floor(d / 8)] & c) === c ? E.push(e[d]) : l.push(e[d])
} }
_t(E, l), n.close(); _t(E, l), n.close();
return return

View File

@@ -40,14 +40,7 @@ func (is *ServerCmdService) Create(u *model.ServerCmd) error {
} }
// SendCmd 发送命令 // SendCmd 发送命令
func (is *ServerCmdService) SendCmd(target string, cmd string, arg string) (string, error) { func (is *ServerCmdService) SendCmd(port int, cmd string, arg string) (string, error) {
port := 0
switch target {
case model.ServerCmdTargetIdServer:
port = Config.Rustdesk.IdServerPort - 1
case model.ServerCmdTargetRelayServer:
port = Config.Rustdesk.RelayServerPort
}
//组装命令 //组装命令
cmd = cmd + " " + arg cmd = cmd + " " + arg
res, err := is.SendSocketCmd("v6", port, cmd) res, err := is.SendSocketCmd("v6", port, cmd)

View File

@@ -5,15 +5,15 @@ import (
"time" "time"
) )
var capdString = base64Captcha.NewDriverString(50, 150, 5, 10, 4, "123456789abcdefghijklmnopqrstuvwxyz", nil, nil, nil) var capdString = base64Captcha.NewDriverString(50, 150, 0, 5, 4, "123456789abcdefghijklmnopqrstuvwxyz", nil, nil, nil)
var capdMath = base64Captcha.NewDriverMath(50, 150, 5, 10, nil, nil, nil) var capdMath = base64Captcha.NewDriverMath(50, 150, 3, 10, nil, nil, nil)
type B64StringCaptchaProvider struct{} type B64StringCaptchaProvider struct{}
func (p B64StringCaptchaProvider) Generate(ip string) (string, string, error) { func (p B64StringCaptchaProvider) Generate() (string, string, string, error) {
_, content, answer := capdString.GenerateIdQuestionAnswer() id, content, answer := capdString.GenerateIdQuestionAnswer()
return content, answer, nil return id, content, answer, nil
} }
func (p B64StringCaptchaProvider) Expiration() time.Duration { func (p B64StringCaptchaProvider) Expiration() time.Duration {
@@ -30,9 +30,9 @@ func (p B64StringCaptchaProvider) Draw(content string) (string, error) {
type B64MathCaptchaProvider struct{} type B64MathCaptchaProvider struct{}
func (p B64MathCaptchaProvider) Generate(ip string) (string, string, error) { func (p B64MathCaptchaProvider) Generate() (string, string, string, error) {
_, content, answer := capdMath.GenerateIdQuestionAnswer() id, content, answer := capdMath.GenerateIdQuestionAnswer()
return content, answer, nil return id, content, answer, nil
} }
func (p B64MathCaptchaProvider) Expiration() time.Duration { func (p B64MathCaptchaProvider) Expiration() time.Duration {

View File

@@ -16,7 +16,7 @@ type SecurityPolicy struct {
// 验证码提供者接口 // 验证码提供者接口
type CaptchaProvider interface { type CaptchaProvider interface {
Generate(ip string) (string, string, error) Generate() (id string, content string, answer string, err error)
//Validate(ip, code string) bool //Validate(ip, code string) bool
Expiration() time.Duration // 验证码过期时间, 应该小于 AttemptsWindow Expiration() time.Duration // 验证码过期时间, 应该小于 AttemptsWindow
Draw(content string) (string, error) // 绘制验证码 Draw(content string) (string, error) // 绘制验证码
@@ -24,6 +24,7 @@ type CaptchaProvider interface {
// 验证码元数据 // 验证码元数据
type CaptchaMeta struct { type CaptchaMeta struct {
Id string
Content string Content string
Answer string Answer string
ExpiresAt time.Time ExpiresAt time.Time
@@ -117,7 +118,7 @@ func (ll *LoginLimiter) RecordFailedAttempt(ip string) {
} }
// 生成验证码 // 生成验证码
func (ll *LoginLimiter) RequireCaptcha(ip string) (error, CaptchaMeta) { func (ll *LoginLimiter) RequireCaptcha() (error, CaptchaMeta) {
ll.mu.Lock() ll.mu.Lock()
defer ll.mu.Unlock() defer ll.mu.Unlock()
@@ -125,23 +126,24 @@ func (ll *LoginLimiter) RequireCaptcha(ip string) (error, CaptchaMeta) {
return errors.New("no captcha provider available"), CaptchaMeta{} return errors.New("no captcha provider available"), CaptchaMeta{}
} }
content, answer, err := ll.provider.Generate(ip) id, content, answer, err := ll.provider.Generate()
if err != nil { if err != nil {
return err, CaptchaMeta{} return err, CaptchaMeta{}
} }
// 存储验证码 // 存储验证码
ll.captchas[ip] = CaptchaMeta{ ll.captchas[id] = CaptchaMeta{
Id: id,
Content: content, Content: content,
Answer: answer, Answer: answer,
ExpiresAt: time.Now().Add(ll.provider.Expiration()), ExpiresAt: time.Now().Add(ll.provider.Expiration()),
} }
return nil, ll.captchas[ip] return nil, ll.captchas[id]
} }
// 验证验证码 // 验证验证码
func (ll *LoginLimiter) VerifyCaptcha(ip, answer string) bool { func (ll *LoginLimiter) VerifyCaptcha(id, answer string) bool {
ll.mu.Lock() ll.mu.Lock()
defer ll.mu.Unlock() defer ll.mu.Unlock()
@@ -151,20 +153,20 @@ func (ll *LoginLimiter) VerifyCaptcha(ip, answer string) bool {
} }
// 获取并验证验证码 // 获取并验证验证码
captcha, exists := ll.captchas[ip] captcha, exists := ll.captchas[id]
if !exists { if !exists {
return false return false
} }
// 清理过期验证码 // 清理过期验证码
if time.Now().After(captcha.ExpiresAt) { if time.Now().After(captcha.ExpiresAt) {
delete(ll.captchas, ip) delete(ll.captchas, id)
return false return false
} }
// 验证并清理状态 // 验证并清理状态
if answer == captcha.Answer { if answer == captcha.Answer {
delete(ll.captchas, ip) delete(ll.captchas, id)
return true return true
} }
@@ -176,16 +178,6 @@ func (ll *LoginLimiter) DrawCaptcha(content string) (err error, str string) {
return return
} }
func (ll *LoginLimiter) RemoveCaptcha(ip string) {
ll.mu.Lock()
defer ll.mu.Unlock()
_, exists := ll.captchas[ip]
if exists {
delete(ll.captchas, ip)
}
}
// 清除记录窗口 // 清除记录窗口
func (ll *LoginLimiter) RemoveAttempts(ip string) { func (ll *LoginLimiter) RemoveAttempts(ip string) {
ll.mu.Lock() ll.mu.Lock()
@@ -212,7 +204,6 @@ func (ll *LoginLimiter) CheckSecurityStatus(ip string) (banned bool, captchaRequ
// 清理过期数据 // 清理过期数据
ll.pruneAttempts(ip, time.Now().Add(-ll.policy.AttemptsWindow)) ll.pruneAttempts(ip, time.Now().Add(-ll.policy.AttemptsWindow))
ll.pruneCaptchas(ip)
// 检查验证码要求 // 检查验证码要求
captchaRequired = len(ll.attempts[ip]) >= ll.policy.CaptchaThreshold captchaRequired = len(ll.attempts[ip]) >= ll.policy.CaptchaThreshold
@@ -272,10 +263,10 @@ func (ll *LoginLimiter) pruneAttempts(ip string, cutoff time.Time) []time.Time {
return valid return valid
} }
func (ll *LoginLimiter) pruneCaptchas(ip string) { func (ll *LoginLimiter) pruneCaptchas(id string) {
if captcha, exists := ll.captchas[ip]; exists { if captcha, exists := ll.captchas[id]; exists {
if time.Now().After(captcha.ExpiresAt) { if time.Now().After(captcha.ExpiresAt) {
delete(ll.captchas, ip) delete(ll.captchas, id)
} }
} }
} }
@@ -299,7 +290,7 @@ func (ll *LoginLimiter) cleanupExpired() {
} }
// 清理验证码 // 清理验证码
for ip := range ll.captchas { for id := range ll.captchas {
ll.pruneCaptchas(ip) ll.pruneCaptchas(id)
} }
} }

View File

@@ -2,18 +2,18 @@ package utils
import ( import (
"fmt" "fmt"
"github.com/google/uuid"
"testing" "testing"
"time" "time"
) )
type MockCaptchaProvider struct{} type MockCaptchaProvider struct{}
func (p *MockCaptchaProvider) Generate(ip string) (string, string, error) { func (p *MockCaptchaProvider) Generate() (string, string, string, error) {
return "CONTENT", "MOCK", nil id := uuid.New().String()
} content := uuid.New().String()
answer := uuid.New().String()
func (p *MockCaptchaProvider) Validate(ip, code string) bool { return id, content, answer, nil
return code == "MOCK"
} }
func (p *MockCaptchaProvider) Expiration() time.Duration { func (p *MockCaptchaProvider) Expiration() time.Duration {
@@ -74,17 +74,22 @@ func TestCaptchaFlow(t *testing.T) {
} }
// 生成验证码 // 生成验证码
err, capc := limiter.RequireCaptcha(ip) err, capc := limiter.RequireCaptcha()
if err != nil { if err != nil {
t.Fatalf("生成验证码失败: %v", err) t.Fatalf("生成验证码失败: %v", err)
} }
fmt.Printf("验证码内容: %#v\n", capc) fmt.Printf("验证码内容: %#v\n", capc)
// 验证成功 // 验证成功
if !limiter.VerifyCaptcha(ip, capc.Answer) { if !limiter.VerifyCaptcha(capc.Id, capc.Answer) {
t.Error("验证码应该验证成功") t.Error("验证码应该验证成功")
} }
// 验证已删除
if limiter.VerifyCaptcha(capc.Id, capc.Answer) {
t.Error("验证码应该已删除")
}
limiter.RemoveAttempts(ip) limiter.RemoveAttempts(ip)
// 验证后状态 // 验证后状态
if banned, need := limiter.CheckSecurityStatus(ip); banned || need { if banned, need := limiter.CheckSecurityStatus(ip); banned || need {
@@ -104,14 +109,14 @@ func TestCaptchaMustFlow(t *testing.T) {
} }
// 生成验证码 // 生成验证码
err, capc := limiter.RequireCaptcha(ip) err, capc := limiter.RequireCaptcha()
if err != nil { if err != nil {
t.Fatalf("生成验证码失败: %v", err) t.Fatalf("生成验证码失败: %v", err)
} }
fmt.Printf("验证码内容: %#v\n", capc) fmt.Printf("验证码内容: %#v\n", capc)
// 验证成功 // 验证成功
if !limiter.VerifyCaptcha(ip, capc.Answer) { if !limiter.VerifyCaptcha(capc.Id, capc.Answer) {
t.Error("验证码应该验证成功") t.Error("验证码应该验证成功")
} }
@@ -136,7 +141,7 @@ func TestAttemptTimeout(t *testing.T) {
} }
// 生成验证码 // 生成验证码
err, _ := limiter.RequireCaptcha(ip) err, _ := limiter.RequireCaptcha()
if err != nil { if err != nil {
t.Fatalf("生成验证码失败: %v", err) t.Fatalf("生成验证码失败: %v", err)
} }
@@ -167,7 +172,7 @@ func TestCaptchaTimeout(t *testing.T) {
} }
// 生成验证码 // 生成验证码
err, _ := limiter.RequireCaptcha(ip) err, capc := limiter.RequireCaptcha()
if err != nil { if err != nil {
t.Fatalf("生成验证码失败: %v", err) t.Fatalf("生成验证码失败: %v", err)
} }
@@ -175,9 +180,8 @@ func TestCaptchaTimeout(t *testing.T) {
// 等待超过 CaptchaValidPeriod // 等待超过 CaptchaValidPeriod
time.Sleep(3 * time.Second) time.Sleep(3 * time.Second)
code := "MOCK"
// 验证成功 // 验证成功
if limiter.VerifyCaptcha(ip, code) { if limiter.VerifyCaptcha(capc.Id, capc.Answer) {
t.Error("验证码应该已过期") t.Error("验证码应该已过期")
} }
@@ -261,7 +265,7 @@ func TestB64CaptchaFlow(t *testing.T) {
} }
// 生成验证码 // 生成验证码
err, capc := limiter.RequireCaptcha(ip) err, capc := limiter.RequireCaptcha()
if err != nil { if err != nil {
t.Fatalf("生成验证码失败: %v", err) t.Fatalf("生成验证码失败: %v", err)
} }
@@ -275,7 +279,7 @@ func TestB64CaptchaFlow(t *testing.T) {
fmt.Printf("验证码内容: %#v\n", b64) fmt.Printf("验证码内容: %#v\n", b64)
// 验证成功 // 验证成功
if !limiter.VerifyCaptcha(ip, capc.Answer) { if !limiter.VerifyCaptcha(capc.Id, capc.Answer) {
t.Error("验证码应该验证成功") t.Error("验证码应该验证成功")
} }
limiter.RemoveAttempts(ip) limiter.RemoveAttempts(ip)