mirror of
https://github.com/lejianwen/rustdesk-api.git
synced 2026-02-21 11:51:07 +00:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
97f98cd6ce | ||
|
|
51f2920661 | ||
|
|
7a5d141ce8 | ||
|
|
3cef02a0bb | ||
|
|
46a7ecc1ba |
4
.github/workflows/build.yml
vendored
4
.github/workflows/build.yml
vendored
@@ -115,12 +115,12 @@ jobs:
|
|||||||
zip -r ${{ matrix.job.goos}}-${{ matrix.job.platform }}.${{matrix.job.file_ext}} ./release
|
zip -r ${{ matrix.job.goos}}-${{ matrix.job.platform }}.${{matrix.job.file_ext}} ./release
|
||||||
else
|
else
|
||||||
if [ "${{ matrix.job.platform }}" = "arm64" ]; then
|
if [ "${{ matrix.job.platform }}" = "arm64" ]; then
|
||||||
wget https://musl.cc/aarch64-linux-musl-cross.tgz
|
wget https://musl.ljw.red/aarch64-linux-musl-cross.tgz
|
||||||
tar -xf aarch64-linux-musl-cross.tgz
|
tar -xf aarch64-linux-musl-cross.tgz
|
||||||
export PATH=$PATH:$PWD/aarch64-linux-musl-cross/bin
|
export PATH=$PATH:$PWD/aarch64-linux-musl-cross/bin
|
||||||
GOOS=${{ matrix.job.goos }} GOARCH=${{ matrix.job.platform }} CC=aarch64-linux-musl-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go
|
GOOS=${{ matrix.job.goos }} GOARCH=${{ matrix.job.platform }} CC=aarch64-linux-musl-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go
|
||||||
elif [ "${{ matrix.job.platform }}" = "armv7l" ]; then
|
elif [ "${{ matrix.job.platform }}" = "armv7l" ]; then
|
||||||
wget https://musl.cc/armv7l-linux-musleabihf-cross.tgz
|
wget https://musl.ljw.red/armv7l-linux-musleabihf-cross.tgz
|
||||||
tar -xf armv7l-linux-musleabihf-cross.tgz
|
tar -xf armv7l-linux-musleabihf-cross.tgz
|
||||||
export PATH=$PATH:$PWD/armv7l-linux-musleabihf-cross/bin
|
export PATH=$PATH:$PWD/armv7l-linux-musleabihf-cross/bin
|
||||||
GOOS=${{ matrix.job.goos }} GOARCH=arm GOARM=7 CC=armv7l-linux-musleabihf-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go
|
GOOS=${{ matrix.job.goos }} GOARCH=arm GOARM=7 CC=armv7l-linux-musleabihf-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go
|
||||||
|
|||||||
4
.github/workflows/build_test.yml
vendored
4
.github/workflows/build_test.yml
vendored
@@ -101,12 +101,12 @@ jobs:
|
|||||||
zip -r ${{ matrix.job.goos}}-${{ matrix.job.platform }}.${{matrix.job.file_ext}} ./release
|
zip -r ${{ matrix.job.goos}}-${{ matrix.job.platform }}.${{matrix.job.file_ext}} ./release
|
||||||
else
|
else
|
||||||
if [ "${{ matrix.job.platform }}" = "arm64" ]; then
|
if [ "${{ matrix.job.platform }}" = "arm64" ]; then
|
||||||
wget https://musl.cc/aarch64-linux-musl-cross.tgz
|
wget https://musl.ljw.red/aarch64-linux-musl-cross.tgz
|
||||||
tar -xf aarch64-linux-musl-cross.tgz
|
tar -xf aarch64-linux-musl-cross.tgz
|
||||||
export PATH=$PATH:$PWD/aarch64-linux-musl-cross/bin
|
export PATH=$PATH:$PWD/aarch64-linux-musl-cross/bin
|
||||||
GOOS=${{ matrix.job.goos }} GOARCH=${{ matrix.job.platform }} CC=aarch64-linux-musl-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go
|
GOOS=${{ matrix.job.goos }} GOARCH=${{ matrix.job.platform }} CC=aarch64-linux-musl-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go
|
||||||
elif [ "${{ matrix.job.platform }}" = "armv7l" ]; then
|
elif [ "${{ matrix.job.platform }}" = "armv7l" ]; then
|
||||||
wget https://musl.cc/armv7l-linux-musleabihf-cross.tgz
|
wget https://musl.ljw.red/armv7l-linux-musleabihf-cross.tgz
|
||||||
tar -xf armv7l-linux-musleabihf-cross.tgz
|
tar -xf armv7l-linux-musleabihf-cross.tgz
|
||||||
export PATH=$PATH:$PWD/armv7l-linux-musleabihf-cross/bin
|
export PATH=$PATH:$PWD/armv7l-linux-musleabihf-cross/bin
|
||||||
GOOS=${{ matrix.job.goos }} GOARCH=arm GOARM=7 CC=armv7l-linux-musleabihf-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go
|
GOOS=${{ matrix.job.goos }} GOARCH=arm GOARM=7 CC=armv7l-linux-musleabihf-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go
|
||||||
|
|||||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -5,4 +5,4 @@ runtime/*
|
|||||||
go.sum
|
go.sum
|
||||||
resources/admin
|
resources/admin
|
||||||
release
|
release
|
||||||
data
|
data/rustdeskapi.db
|
||||||
@@ -14,6 +14,9 @@ admin:
|
|||||||
title: "RustDesk Api Admin"
|
title: "RustDesk Api Admin"
|
||||||
hello-file: "./conf/admin/hello.html" #优先使用file
|
hello-file: "./conf/admin/hello.html" #优先使用file
|
||||||
hello: ""
|
hello: ""
|
||||||
|
# ID Server and Relay Server ports https://github.com/lejianwen/rustdesk-api/issues/257
|
||||||
|
id-server-port: 21116 # ID Server port (for server cmd)
|
||||||
|
relay-server-port: 21117 # ID Server port (for server cmd)
|
||||||
gin:
|
gin:
|
||||||
api-addr: "0.0.0.0:21114"
|
api-addr: "0.0.0.0:21114"
|
||||||
mode: "release" #release,debug,test
|
mode: "release" #release,debug,test
|
||||||
|
|||||||
@@ -25,9 +25,11 @@ type App struct {
|
|||||||
BanThreshold int `mapstructure:"ban-threshold"`
|
BanThreshold int `mapstructure:"ban-threshold"`
|
||||||
}
|
}
|
||||||
type Admin struct {
|
type Admin struct {
|
||||||
Title string `mapstructure:"title"`
|
Title string `mapstructure:"title"`
|
||||||
Hello string `mapstructure:"hello"`
|
Hello string `mapstructure:"hello"`
|
||||||
HelloFile string `mapstructure:"hello-file"`
|
HelloFile string `mapstructure:"hello-file"`
|
||||||
|
IdServerPort int `mapstructure:"id-server-port"`
|
||||||
|
RelayServerPort int `mapstructure:"relay-server-port"`
|
||||||
}
|
}
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Lang string `mapstructure:"lang"`
|
Lang string `mapstructure:"lang"`
|
||||||
@@ -46,6 +48,15 @@ type Config struct {
|
|||||||
Ldap Ldap
|
Ldap Ldap
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Admin) Init() {
|
||||||
|
if a.IdServerPort == 0 {
|
||||||
|
a.IdServerPort = DefaultIdServerPort
|
||||||
|
}
|
||||||
|
if a.RelayServerPort == 0 {
|
||||||
|
a.RelayServerPort = DefaultRelayServerPort
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Init 初始化配置
|
// Init 初始化配置
|
||||||
func Init(rowVal *Config, path string) *viper.Viper {
|
func Init(rowVal *Config, path string) *viper.Viper {
|
||||||
if path == "" {
|
if path == "" {
|
||||||
@@ -80,7 +91,7 @@ func Init(rowVal *Config, path string) *viper.Viper {
|
|||||||
panic(fmt.Errorf("Fatal error config: %s \n", err))
|
panic(fmt.Errorf("Fatal error config: %s \n", err))
|
||||||
}
|
}
|
||||||
rowVal.Rustdesk.LoadKeyFile()
|
rowVal.Rustdesk.LoadKeyFile()
|
||||||
rowVal.Rustdesk.ParsePort()
|
rowVal.Admin.Init()
|
||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,8 +2,6 @@ package config
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -40,19 +38,3 @@ func (rd *Rustdesk) LoadKeyFile() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
func (rd *Rustdesk) ParsePort() {
|
|
||||||
// Parse port
|
|
||||||
idres := strings.Split(rd.IdServer, ":")
|
|
||||||
if len(idres) == 1 {
|
|
||||||
rd.IdServerPort = DefaultIdServerPort
|
|
||||||
} else if len(idres) == 2 {
|
|
||||||
rd.IdServerPort, _ = strconv.Atoi(idres[1])
|
|
||||||
}
|
|
||||||
|
|
||||||
relayres := strings.Split(rd.RelayServer, ":")
|
|
||||||
if len(relayres) == 1 {
|
|
||||||
rd.RelayServerPort = DefaultRelayServerPort
|
|
||||||
} else if len(relayres) == 2 {
|
|
||||||
rd.RelayServerPort, _ = strconv.Atoi(relayres[1])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
0
data/.gitkeep
Normal file
0
data/.gitkeep
Normal file
@@ -57,7 +57,7 @@ func (ct *Login) Login(c *gin.Context) {
|
|||||||
|
|
||||||
// 检查是否需要验证码
|
// 检查是否需要验证码
|
||||||
if needCaptcha {
|
if needCaptcha {
|
||||||
if f.Captcha == "" || !loginLimiter.VerifyCaptcha(clientIp, f.Captcha) {
|
if f.CaptchaId == "" || f.Captcha == "" || !loginLimiter.VerifyCaptcha(f.CaptchaId, f.Captcha) {
|
||||||
response.Fail(c, 101, response.TranslateMsg(c, "CaptchaError"))
|
response.Fail(c, 101, response.TranslateMsg(c, "CaptchaError"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -68,8 +68,6 @@ func (ct *Login) Login(c *gin.Context) {
|
|||||||
if u.Id == 0 {
|
if u.Id == 0 {
|
||||||
global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "UsernameOrPasswordError", c.RemoteIP(), clientIp))
|
global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "UsernameOrPasswordError", c.RemoteIP(), clientIp))
|
||||||
loginLimiter.RecordFailedAttempt(clientIp)
|
loginLimiter.RecordFailedAttempt(clientIp)
|
||||||
// 移除验证码,重新生成
|
|
||||||
loginLimiter.RemoveCaptcha(clientIp)
|
|
||||||
if _, needCaptcha = loginLimiter.CheckSecurityStatus(clientIp); needCaptcha {
|
if _, needCaptcha = loginLimiter.CheckSecurityStatus(clientIp); needCaptcha {
|
||||||
response.Fail(c, 110, response.TranslateMsg(c, "UsernameOrPasswordError"))
|
response.Fail(c, 110, response.TranslateMsg(c, "UsernameOrPasswordError"))
|
||||||
} else {
|
} else {
|
||||||
@@ -80,7 +78,6 @@ func (ct *Login) Login(c *gin.Context) {
|
|||||||
|
|
||||||
if !service.AllService.UserService.CheckUserEnable(u) {
|
if !service.AllService.UserService.CheckUserEnable(u) {
|
||||||
if needCaptcha {
|
if needCaptcha {
|
||||||
loginLimiter.RemoveCaptcha(clientIp)
|
|
||||||
response.Fail(c, 110, response.TranslateMsg(c, "UserDisabled"))
|
response.Fail(c, 110, response.TranslateMsg(c, "UserDisabled"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -113,7 +110,7 @@ func (ct *Login) Captcha(c *gin.Context) {
|
|||||||
response.Fail(c, 101, response.TranslateMsg(c, "NoCaptchaRequired"))
|
response.Fail(c, 101, response.TranslateMsg(c, "NoCaptchaRequired"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err, captcha := loginLimiter.RequireCaptcha(clientIp)
|
err, captcha := loginLimiter.RequireCaptcha()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Fail(c, 101, response.TranslateMsg(c, "CaptchaError")+err.Error())
|
response.Fail(c, 101, response.TranslateMsg(c, "CaptchaError")+err.Error())
|
||||||
return
|
return
|
||||||
@@ -125,6 +122,7 @@ func (ct *Login) Captcha(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
response.Success(c, gin.H{
|
response.Success(c, gin.H{
|
||||||
"captcha": gin.H{
|
"captcha": gin.H{
|
||||||
|
"id": captcha.Id,
|
||||||
"b64": b64,
|
"b64": b64,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -119,7 +119,16 @@ func (r *Rustdesk) SendCmd(c *gin.Context) {
|
|||||||
response.Fail(c, 101, response.TranslateMsg(c, "ParamsError"))
|
response.Fail(c, 101, response.TranslateMsg(c, "ParamsError"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
res, err := service.AllService.ServerCmdService.SendCmd(rc.Target, rc.Cmd, rc.Option)
|
|
||||||
|
port := 0
|
||||||
|
switch rc.Target {
|
||||||
|
case model.ServerCmdTargetIdServer:
|
||||||
|
port = global.Config.Admin.IdServerPort - 1
|
||||||
|
case model.ServerCmdTargetRelayServer:
|
||||||
|
port = global.Config.Admin.RelayServerPort
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := service.AllService.ServerCmdService.SendCmd(port, rc.Cmd, rc.Option)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Fail(c, 101, err.Error())
|
response.Fail(c, 101, err.Error())
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
package admin
|
package admin
|
||||||
|
|
||||||
type Login struct {
|
type Login struct {
|
||||||
Username string `json:"username" validate:"required" label:"用户名"`
|
Username string `json:"username" validate:"required" label:"用户名"`
|
||||||
Password string `json:"password,omitempty" validate:"required" label:"密码"`
|
Password string `json:"password,omitempty" validate:"required" label:"密码"`
|
||||||
Platform string `json:"platform" label:"平台"`
|
Platform string `json:"platform" label:"平台"`
|
||||||
Captcha string `json:"captcha,omitempty" label:"验证码"`
|
Captcha string `json:"captcha,omitempty" label:"验证码"`
|
||||||
|
CaptchaId string `json:"captcha_id,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type LoginLogQuery struct {
|
type LoginLogQuery struct {
|
||||||
|
|||||||
2
resources/web2/js/dist/index.js
vendored
2
resources/web2/js/dist/index.js
vendored
@@ -11550,7 +11550,7 @@ async function or(u) {
|
|||||||
let E = [], l = [];
|
let E = [], l = [];
|
||||||
for (let d = 0; d < e.length; d++) {
|
for (let d = 0; d < e.length; d++) {
|
||||||
const c = 1 << 7 - d % 8;
|
const c = 1 << 7 - d % 8;
|
||||||
(s[d / 8] & c) === c ? E.push(e[d]) : l.push(e[d])
|
(s[Math.floor(d / 8)] & c) === c ? E.push(e[d]) : l.push(e[d])
|
||||||
}
|
}
|
||||||
_t(E, l), n.close();
|
_t(E, l), n.close();
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -40,14 +40,7 @@ func (is *ServerCmdService) Create(u *model.ServerCmd) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SendCmd 发送命令
|
// SendCmd 发送命令
|
||||||
func (is *ServerCmdService) SendCmd(target string, cmd string, arg string) (string, error) {
|
func (is *ServerCmdService) SendCmd(port int, cmd string, arg string) (string, error) {
|
||||||
port := 0
|
|
||||||
switch target {
|
|
||||||
case model.ServerCmdTargetIdServer:
|
|
||||||
port = Config.Rustdesk.IdServerPort - 1
|
|
||||||
case model.ServerCmdTargetRelayServer:
|
|
||||||
port = Config.Rustdesk.RelayServerPort
|
|
||||||
}
|
|
||||||
//组装命令
|
//组装命令
|
||||||
cmd = cmd + " " + arg
|
cmd = cmd + " " + arg
|
||||||
res, err := is.SendSocketCmd("v6", port, cmd)
|
res, err := is.SendSocketCmd("v6", port, cmd)
|
||||||
|
|||||||
@@ -5,15 +5,15 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var capdString = base64Captcha.NewDriverString(50, 150, 5, 10, 4, "123456789abcdefghijklmnopqrstuvwxyz", nil, nil, nil)
|
var capdString = base64Captcha.NewDriverString(50, 150, 0, 5, 4, "123456789abcdefghijklmnopqrstuvwxyz", nil, nil, nil)
|
||||||
|
|
||||||
var capdMath = base64Captcha.NewDriverMath(50, 150, 5, 10, nil, nil, nil)
|
var capdMath = base64Captcha.NewDriverMath(50, 150, 3, 10, nil, nil, nil)
|
||||||
|
|
||||||
type B64StringCaptchaProvider struct{}
|
type B64StringCaptchaProvider struct{}
|
||||||
|
|
||||||
func (p B64StringCaptchaProvider) Generate(ip string) (string, string, error) {
|
func (p B64StringCaptchaProvider) Generate() (string, string, string, error) {
|
||||||
_, content, answer := capdString.GenerateIdQuestionAnswer()
|
id, content, answer := capdString.GenerateIdQuestionAnswer()
|
||||||
return content, answer, nil
|
return id, content, answer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p B64StringCaptchaProvider) Expiration() time.Duration {
|
func (p B64StringCaptchaProvider) Expiration() time.Duration {
|
||||||
@@ -30,9 +30,9 @@ func (p B64StringCaptchaProvider) Draw(content string) (string, error) {
|
|||||||
|
|
||||||
type B64MathCaptchaProvider struct{}
|
type B64MathCaptchaProvider struct{}
|
||||||
|
|
||||||
func (p B64MathCaptchaProvider) Generate(ip string) (string, string, error) {
|
func (p B64MathCaptchaProvider) Generate() (string, string, string, error) {
|
||||||
_, content, answer := capdMath.GenerateIdQuestionAnswer()
|
id, content, answer := capdMath.GenerateIdQuestionAnswer()
|
||||||
return content, answer, nil
|
return id, content, answer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p B64MathCaptchaProvider) Expiration() time.Duration {
|
func (p B64MathCaptchaProvider) Expiration() time.Duration {
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ type SecurityPolicy struct {
|
|||||||
|
|
||||||
// 验证码提供者接口
|
// 验证码提供者接口
|
||||||
type CaptchaProvider interface {
|
type CaptchaProvider interface {
|
||||||
Generate(ip string) (string, string, error)
|
Generate() (id string, content string, answer string, err error)
|
||||||
//Validate(ip, code string) bool
|
//Validate(ip, code string) bool
|
||||||
Expiration() time.Duration // 验证码过期时间, 应该小于 AttemptsWindow
|
Expiration() time.Duration // 验证码过期时间, 应该小于 AttemptsWindow
|
||||||
Draw(content string) (string, error) // 绘制验证码
|
Draw(content string) (string, error) // 绘制验证码
|
||||||
@@ -24,6 +24,7 @@ type CaptchaProvider interface {
|
|||||||
|
|
||||||
// 验证码元数据
|
// 验证码元数据
|
||||||
type CaptchaMeta struct {
|
type CaptchaMeta struct {
|
||||||
|
Id string
|
||||||
Content string
|
Content string
|
||||||
Answer string
|
Answer string
|
||||||
ExpiresAt time.Time
|
ExpiresAt time.Time
|
||||||
@@ -117,7 +118,7 @@ func (ll *LoginLimiter) RecordFailedAttempt(ip string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 生成验证码
|
// 生成验证码
|
||||||
func (ll *LoginLimiter) RequireCaptcha(ip string) (error, CaptchaMeta) {
|
func (ll *LoginLimiter) RequireCaptcha() (error, CaptchaMeta) {
|
||||||
ll.mu.Lock()
|
ll.mu.Lock()
|
||||||
defer ll.mu.Unlock()
|
defer ll.mu.Unlock()
|
||||||
|
|
||||||
@@ -125,23 +126,24 @@ func (ll *LoginLimiter) RequireCaptcha(ip string) (error, CaptchaMeta) {
|
|||||||
return errors.New("no captcha provider available"), CaptchaMeta{}
|
return errors.New("no captcha provider available"), CaptchaMeta{}
|
||||||
}
|
}
|
||||||
|
|
||||||
content, answer, err := ll.provider.Generate(ip)
|
id, content, answer, err := ll.provider.Generate()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, CaptchaMeta{}
|
return err, CaptchaMeta{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 存储验证码
|
// 存储验证码
|
||||||
ll.captchas[ip] = CaptchaMeta{
|
ll.captchas[id] = CaptchaMeta{
|
||||||
|
Id: id,
|
||||||
Content: content,
|
Content: content,
|
||||||
Answer: answer,
|
Answer: answer,
|
||||||
ExpiresAt: time.Now().Add(ll.provider.Expiration()),
|
ExpiresAt: time.Now().Add(ll.provider.Expiration()),
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, ll.captchas[ip]
|
return nil, ll.captchas[id]
|
||||||
}
|
}
|
||||||
|
|
||||||
// 验证验证码
|
// 验证验证码
|
||||||
func (ll *LoginLimiter) VerifyCaptcha(ip, answer string) bool {
|
func (ll *LoginLimiter) VerifyCaptcha(id, answer string) bool {
|
||||||
ll.mu.Lock()
|
ll.mu.Lock()
|
||||||
defer ll.mu.Unlock()
|
defer ll.mu.Unlock()
|
||||||
|
|
||||||
@@ -151,20 +153,20 @@ func (ll *LoginLimiter) VerifyCaptcha(ip, answer string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取并验证验证码
|
// 获取并验证验证码
|
||||||
captcha, exists := ll.captchas[ip]
|
captcha, exists := ll.captchas[id]
|
||||||
if !exists {
|
if !exists {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// 清理过期验证码
|
// 清理过期验证码
|
||||||
if time.Now().After(captcha.ExpiresAt) {
|
if time.Now().After(captcha.ExpiresAt) {
|
||||||
delete(ll.captchas, ip)
|
delete(ll.captchas, id)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// 验证并清理状态
|
// 验证并清理状态
|
||||||
if answer == captcha.Answer {
|
if answer == captcha.Answer {
|
||||||
delete(ll.captchas, ip)
|
delete(ll.captchas, id)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -176,16 +178,6 @@ func (ll *LoginLimiter) DrawCaptcha(content string) (err error, str string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ll *LoginLimiter) RemoveCaptcha(ip string) {
|
|
||||||
ll.mu.Lock()
|
|
||||||
defer ll.mu.Unlock()
|
|
||||||
|
|
||||||
_, exists := ll.captchas[ip]
|
|
||||||
if exists {
|
|
||||||
delete(ll.captchas, ip)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 清除记录窗口
|
// 清除记录窗口
|
||||||
func (ll *LoginLimiter) RemoveAttempts(ip string) {
|
func (ll *LoginLimiter) RemoveAttempts(ip string) {
|
||||||
ll.mu.Lock()
|
ll.mu.Lock()
|
||||||
@@ -212,7 +204,6 @@ func (ll *LoginLimiter) CheckSecurityStatus(ip string) (banned bool, captchaRequ
|
|||||||
|
|
||||||
// 清理过期数据
|
// 清理过期数据
|
||||||
ll.pruneAttempts(ip, time.Now().Add(-ll.policy.AttemptsWindow))
|
ll.pruneAttempts(ip, time.Now().Add(-ll.policy.AttemptsWindow))
|
||||||
ll.pruneCaptchas(ip)
|
|
||||||
|
|
||||||
// 检查验证码要求
|
// 检查验证码要求
|
||||||
captchaRequired = len(ll.attempts[ip]) >= ll.policy.CaptchaThreshold
|
captchaRequired = len(ll.attempts[ip]) >= ll.policy.CaptchaThreshold
|
||||||
@@ -272,10 +263,10 @@ func (ll *LoginLimiter) pruneAttempts(ip string, cutoff time.Time) []time.Time {
|
|||||||
return valid
|
return valid
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ll *LoginLimiter) pruneCaptchas(ip string) {
|
func (ll *LoginLimiter) pruneCaptchas(id string) {
|
||||||
if captcha, exists := ll.captchas[ip]; exists {
|
if captcha, exists := ll.captchas[id]; exists {
|
||||||
if time.Now().After(captcha.ExpiresAt) {
|
if time.Now().After(captcha.ExpiresAt) {
|
||||||
delete(ll.captchas, ip)
|
delete(ll.captchas, id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -299,7 +290,7 @@ func (ll *LoginLimiter) cleanupExpired() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 清理验证码
|
// 清理验证码
|
||||||
for ip := range ll.captchas {
|
for id := range ll.captchas {
|
||||||
ll.pruneCaptchas(ip)
|
ll.pruneCaptchas(id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,18 +2,18 @@ package utils
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/google/uuid"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type MockCaptchaProvider struct{}
|
type MockCaptchaProvider struct{}
|
||||||
|
|
||||||
func (p *MockCaptchaProvider) Generate(ip string) (string, string, error) {
|
func (p *MockCaptchaProvider) Generate() (string, string, string, error) {
|
||||||
return "CONTENT", "MOCK", nil
|
id := uuid.New().String()
|
||||||
}
|
content := uuid.New().String()
|
||||||
|
answer := uuid.New().String()
|
||||||
func (p *MockCaptchaProvider) Validate(ip, code string) bool {
|
return id, content, answer, nil
|
||||||
return code == "MOCK"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *MockCaptchaProvider) Expiration() time.Duration {
|
func (p *MockCaptchaProvider) Expiration() time.Duration {
|
||||||
@@ -74,17 +74,22 @@ func TestCaptchaFlow(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 生成验证码
|
// 生成验证码
|
||||||
err, capc := limiter.RequireCaptcha(ip)
|
err, capc := limiter.RequireCaptcha()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("生成验证码失败: %v", err)
|
t.Fatalf("生成验证码失败: %v", err)
|
||||||
}
|
}
|
||||||
fmt.Printf("验证码内容: %#v\n", capc)
|
fmt.Printf("验证码内容: %#v\n", capc)
|
||||||
|
|
||||||
// 验证成功
|
// 验证成功
|
||||||
if !limiter.VerifyCaptcha(ip, capc.Answer) {
|
if !limiter.VerifyCaptcha(capc.Id, capc.Answer) {
|
||||||
t.Error("验证码应该验证成功")
|
t.Error("验证码应该验证成功")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 验证已删除
|
||||||
|
if limiter.VerifyCaptcha(capc.Id, capc.Answer) {
|
||||||
|
t.Error("验证码应该已删除")
|
||||||
|
}
|
||||||
|
|
||||||
limiter.RemoveAttempts(ip)
|
limiter.RemoveAttempts(ip)
|
||||||
// 验证后状态
|
// 验证后状态
|
||||||
if banned, need := limiter.CheckSecurityStatus(ip); banned || need {
|
if banned, need := limiter.CheckSecurityStatus(ip); banned || need {
|
||||||
@@ -104,14 +109,14 @@ func TestCaptchaMustFlow(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 生成验证码
|
// 生成验证码
|
||||||
err, capc := limiter.RequireCaptcha(ip)
|
err, capc := limiter.RequireCaptcha()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("生成验证码失败: %v", err)
|
t.Fatalf("生成验证码失败: %v", err)
|
||||||
}
|
}
|
||||||
fmt.Printf("验证码内容: %#v\n", capc)
|
fmt.Printf("验证码内容: %#v\n", capc)
|
||||||
|
|
||||||
// 验证成功
|
// 验证成功
|
||||||
if !limiter.VerifyCaptcha(ip, capc.Answer) {
|
if !limiter.VerifyCaptcha(capc.Id, capc.Answer) {
|
||||||
t.Error("验证码应该验证成功")
|
t.Error("验证码应该验证成功")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -136,7 +141,7 @@ func TestAttemptTimeout(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 生成验证码
|
// 生成验证码
|
||||||
err, _ := limiter.RequireCaptcha(ip)
|
err, _ := limiter.RequireCaptcha()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("生成验证码失败: %v", err)
|
t.Fatalf("生成验证码失败: %v", err)
|
||||||
}
|
}
|
||||||
@@ -167,7 +172,7 @@ func TestCaptchaTimeout(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 生成验证码
|
// 生成验证码
|
||||||
err, _ := limiter.RequireCaptcha(ip)
|
err, capc := limiter.RequireCaptcha()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("生成验证码失败: %v", err)
|
t.Fatalf("生成验证码失败: %v", err)
|
||||||
}
|
}
|
||||||
@@ -175,9 +180,8 @@ func TestCaptchaTimeout(t *testing.T) {
|
|||||||
// 等待超过 CaptchaValidPeriod
|
// 等待超过 CaptchaValidPeriod
|
||||||
time.Sleep(3 * time.Second)
|
time.Sleep(3 * time.Second)
|
||||||
|
|
||||||
code := "MOCK"
|
|
||||||
// 验证成功
|
// 验证成功
|
||||||
if limiter.VerifyCaptcha(ip, code) {
|
if limiter.VerifyCaptcha(capc.Id, capc.Answer) {
|
||||||
t.Error("验证码应该已过期")
|
t.Error("验证码应该已过期")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -261,7 +265,7 @@ func TestB64CaptchaFlow(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 生成验证码
|
// 生成验证码
|
||||||
err, capc := limiter.RequireCaptcha(ip)
|
err, capc := limiter.RequireCaptcha()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("生成验证码失败: %v", err)
|
t.Fatalf("生成验证码失败: %v", err)
|
||||||
}
|
}
|
||||||
@@ -275,7 +279,7 @@ func TestB64CaptchaFlow(t *testing.T) {
|
|||||||
fmt.Printf("验证码内容: %#v\n", b64)
|
fmt.Printf("验证码内容: %#v\n", b64)
|
||||||
|
|
||||||
// 验证成功
|
// 验证成功
|
||||||
if !limiter.VerifyCaptcha(ip, capc.Answer) {
|
if !limiter.VerifyCaptcha(capc.Id, capc.Answer) {
|
||||||
t.Error("验证码应该验证成功")
|
t.Error("验证码应该验证成功")
|
||||||
}
|
}
|
||||||
limiter.RemoveAttempts(ip)
|
limiter.RemoveAttempts(ip)
|
||||||
|
|||||||
Reference in New Issue
Block a user