Compare commits

..

11 Commits

Author SHA1 Message Date
lejianwen
2948eaaa5c chore: Update Go version to 1.23 in build configurations 2025-06-16 15:41:16 +08:00
lejianwen
8641ba5c0c docs: Update swagger docs 2025-06-16 12:31:48 +08:00
lejianwen
60b7a18fe7 feat: Add PostgreSQL support and refactor MySQL DSN handling (#284) 2025-06-16 12:26:08 +08:00
lejianwen
ca068816ae feat: Add start time in /api/sysinfover 2025-06-16 12:23:48 +08:00
lejianwen
06648d9a6c fix(admin): Use admin-hello first
(#274) (#255)
2025-06-15 15:33:12 +08:00
puyujian
8a8abd5163 feat(oauth): 支持linux.do登录 (#280)
* 支持linux.do登录

* 修正
2025-06-15 15:32:20 +08:00
lejianwen
97f98cd6ce chore: update download links for musl cross-compilers 2025-06-05 12:14:17 +08:00
lejianwen
51f2920661 fix: Init sqlite fail(#266) 2025-06-04 09:31:43 +08:00
lejianwen
7a5d141ce8 fix(server): Port custom (#257) 2025-05-30 12:27:37 +08:00
lejianwen
3cef02a0bb fix(webclient): Peer online status 2025-05-29 18:51:37 +08:00
lejianwen
46a7ecc1ba fix: Captcha some problem when users login with same ip 2025-05-27 17:36:20 +08:00
33 changed files with 411 additions and 149 deletions

View File

@@ -66,7 +66,7 @@ jobs:
- name: Set up Go environment - name: Set up Go environment
uses: actions/setup-go@v4 uses: actions/setup-go@v4
with: with:
go-version: '1.22' # 选择 Go 版本 go-version: '1.23' # 选择 Go 版本
- name: Set up npm - name: Set up npm
uses: actions/setup-node@v2 uses: actions/setup-node@v2
@@ -115,12 +115,12 @@ jobs:
zip -r ${{ matrix.job.goos}}-${{ matrix.job.platform }}.${{matrix.job.file_ext}} ./release zip -r ${{ matrix.job.goos}}-${{ matrix.job.platform }}.${{matrix.job.file_ext}} ./release
else else
if [ "${{ matrix.job.platform }}" = "arm64" ]; then if [ "${{ matrix.job.platform }}" = "arm64" ]; then
wget https://musl.cc/aarch64-linux-musl-cross.tgz wget https://musl.ljw.red/aarch64-linux-musl-cross.tgz
tar -xf aarch64-linux-musl-cross.tgz tar -xf aarch64-linux-musl-cross.tgz
export PATH=$PATH:$PWD/aarch64-linux-musl-cross/bin export PATH=$PATH:$PWD/aarch64-linux-musl-cross/bin
GOOS=${{ matrix.job.goos }} GOARCH=${{ matrix.job.platform }} CC=aarch64-linux-musl-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go GOOS=${{ matrix.job.goos }} GOARCH=${{ matrix.job.platform }} CC=aarch64-linux-musl-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go
elif [ "${{ matrix.job.platform }}" = "armv7l" ]; then elif [ "${{ matrix.job.platform }}" = "armv7l" ]; then
wget https://musl.cc/armv7l-linux-musleabihf-cross.tgz wget https://musl.ljw.red/armv7l-linux-musleabihf-cross.tgz
tar -xf armv7l-linux-musleabihf-cross.tgz tar -xf armv7l-linux-musleabihf-cross.tgz
export PATH=$PATH:$PWD/armv7l-linux-musleabihf-cross/bin export PATH=$PATH:$PWD/armv7l-linux-musleabihf-cross/bin
GOOS=${{ matrix.job.goos }} GOARCH=arm GOARM=7 CC=armv7l-linux-musleabihf-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go GOOS=${{ matrix.job.goos }} GOARCH=arm GOARM=7 CC=armv7l-linux-musleabihf-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go
@@ -147,6 +147,7 @@ jobs:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Generate Changelog - name: Generate Changelog
if: startsWith(github.ref, 'refs/tags/') && github.event_name == 'push'
run: npx changelogithub # or changelogithub@0.12 if ensure the stable result run: npx changelogithub # or changelogithub@0.12 if ensure the stable result
env: env:
GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}} GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}}

View File

@@ -61,7 +61,7 @@ jobs:
- name: Set up Go environment - name: Set up Go environment
uses: actions/setup-go@v4 uses: actions/setup-go@v4
with: with:
go-version: '1.22' # 选择 Go 版本 go-version: '1.23' # 选择 Go 版本
- name: Set up npm - name: Set up npm
uses: actions/setup-node@v2 uses: actions/setup-node@v2
@@ -101,12 +101,12 @@ jobs:
zip -r ${{ matrix.job.goos}}-${{ matrix.job.platform }}.${{matrix.job.file_ext}} ./release zip -r ${{ matrix.job.goos}}-${{ matrix.job.platform }}.${{matrix.job.file_ext}} ./release
else else
if [ "${{ matrix.job.platform }}" = "arm64" ]; then if [ "${{ matrix.job.platform }}" = "arm64" ]; then
wget https://musl.cc/aarch64-linux-musl-cross.tgz wget https://musl.ljw.red/aarch64-linux-musl-cross.tgz
tar -xf aarch64-linux-musl-cross.tgz tar -xf aarch64-linux-musl-cross.tgz
export PATH=$PATH:$PWD/aarch64-linux-musl-cross/bin export PATH=$PATH:$PWD/aarch64-linux-musl-cross/bin
GOOS=${{ matrix.job.goos }} GOARCH=${{ matrix.job.platform }} CC=aarch64-linux-musl-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go GOOS=${{ matrix.job.goos }} GOARCH=${{ matrix.job.platform }} CC=aarch64-linux-musl-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go
elif [ "${{ matrix.job.platform }}" = "armv7l" ]; then elif [ "${{ matrix.job.platform }}" = "armv7l" ]; then
wget https://musl.cc/armv7l-linux-musleabihf-cross.tgz wget https://musl.ljw.red/armv7l-linux-musleabihf-cross.tgz
tar -xf armv7l-linux-musleabihf-cross.tgz tar -xf armv7l-linux-musleabihf-cross.tgz
export PATH=$PATH:$PWD/armv7l-linux-musleabihf-cross/bin export PATH=$PATH:$PWD/armv7l-linux-musleabihf-cross/bin
GOOS=${{ matrix.job.goos }} GOARCH=arm GOARM=7 CC=armv7l-linux-musleabihf-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go GOOS=${{ matrix.job.goos }} GOARCH=arm GOARM=7 CC=armv7l-linux-musleabihf-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go

2
.gitignore vendored
View File

@@ -5,4 +5,4 @@ runtime/*
go.sum go.sum
resources/admin resources/admin
release release
data data/rustdeskapi.db

View File

@@ -1,6 +1,7 @@
package main package main
import ( import (
"fmt"
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
"github.com/lejianwen/rustdesk-api/v2/config" "github.com/lejianwen/rustdesk-api/v2/config"
"github.com/lejianwen/rustdesk-api/v2/global" "github.com/lejianwen/rustdesk-api/v2/global"
@@ -140,18 +141,40 @@ func InitGlobal() {
} }
//gorm //gorm
if global.Config.Gorm.Type == config.TypeMysql { if global.Config.Gorm.Type == config.TypeMysql {
dns := global.Config.Mysql.Username + ":" + global.Config.Mysql.Password + "@(" + global.Config.Mysql.Addr + ")/" + global.Config.Mysql.Dbname + "?charset=utf8mb4&parseTime=True&loc=Local"
dsn := fmt.Sprintf("%s:%s@(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
global.Config.Mysql.Username,
global.Config.Mysql.Password,
global.Config.Mysql.Addr,
global.Config.Mysql.Dbname,
)
global.DB = orm.NewMysql(&orm.MysqlConfig{ global.DB = orm.NewMysql(&orm.MysqlConfig{
Dns: dns, Dsn: dsn,
MaxIdleConns: global.Config.Gorm.MaxIdleConns, MaxIdleConns: global.Config.Gorm.MaxIdleConns,
MaxOpenConns: global.Config.Gorm.MaxOpenConns, MaxOpenConns: global.Config.Gorm.MaxOpenConns,
}) }, global.Logger)
} else if global.Config.Gorm.Type == config.TypePostgresql {
dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s TimeZone=%s",
global.Config.Postgresql.Host,
global.Config.Postgresql.Port,
global.Config.Postgresql.User,
global.Config.Postgresql.Password,
global.Config.Postgresql.Dbname,
global.Config.Postgresql.Sslmode,
global.Config.Postgresql.TimeZone,
)
global.DB = orm.NewPostgresql(&orm.PostgresqlConfig{
Dsn: dsn,
MaxIdleConns: global.Config.Gorm.MaxIdleConns,
MaxOpenConns: global.Config.Gorm.MaxOpenConns,
}, global.Logger)
} else { } else {
//sqlite //sqlite
global.DB = orm.NewSqlite(&orm.SqliteConfig{ global.DB = orm.NewSqlite(&orm.SqliteConfig{
MaxIdleConns: global.Config.Gorm.MaxIdleConns, MaxIdleConns: global.Config.Gorm.MaxIdleConns,
MaxOpenConns: global.Config.Gorm.MaxOpenConns, MaxOpenConns: global.Config.Gorm.MaxOpenConns,
}) }, global.Logger)
} }
//validator //validator
@@ -197,11 +220,17 @@ func DatabaseAutoUpdate() {
if dbName == "" { if dbName == "" {
dbName = global.Config.Mysql.Dbname dbName = global.Config.Mysql.Dbname
// 移除 DSN 中的数据库名称,以便初始连接时不指定数据库 // 移除 DSN 中的数据库名称,以便初始连接时不指定数据库
dsnWithoutDB := global.Config.Mysql.Username + ":" + global.Config.Mysql.Password + "@(" + global.Config.Mysql.Addr + ")/?charset=utf8mb4&parseTime=True&loc=Local" dsnWithoutDB := fmt.Sprintf("%s:%s@(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
global.Config.Mysql.Username,
global.Config.Mysql.Password,
global.Config.Mysql.Addr,
"",
)
//新链接 //新链接
dbWithoutDB := orm.NewMysql(&orm.MysqlConfig{ dbWithoutDB := orm.NewMysql(&orm.MysqlConfig{
Dns: dsnWithoutDB, Dsn: dsnWithoutDB,
}) }, global.Logger)
// 获取底层的 *sql.DB 对象,并确保在程序退出时关闭连接 // 获取底层的 *sql.DB 对象,并确保在程序退出时关闭连接
sqlDBWithoutDB, err := dbWithoutDB.DB() sqlDBWithoutDB, err := dbWithoutDB.DB()
if err != nil { if err != nil {

View File

@@ -11,9 +11,12 @@ app:
disable-pwd-login: false #禁用密码登录 disable-pwd-login: false #禁用密码登录
admin: admin:
title: "RustDesk Api Admin" title: "RustDesk API Admin"
hello-file: "./conf/admin/hello.html" #优先使用file hello-file: "./conf/admin/hello.html" #优先使用file
hello: "" hello: ""
# ID Server and Relay Server ports https://github.com/lejianwen/rustdesk-api/issues/257
id-server-port: 21116 # ID Server port (for server cmd)
relay-server-port: 21117 # ID Server port (for server cmd)
gin: gin:
api-addr: "0.0.0.0:21114" api-addr: "0.0.0.0:21114"
mode: "release" #release,debug,test mode: "release" #release,debug,test
@@ -28,6 +31,16 @@ mysql:
password: "" password: ""
addr: "" addr: ""
dbname: "" dbname: ""
postgresql:
host: "127.0.0.1"
port: "5432"
user: ""
password: ""
dbname: "postgres"
sslmode: "disable" # disable, require, verify-ca, verify-full
time-zone: "Asia/Shanghai" # Time zone for PostgreSQL connection
rustdesk: rustdesk:
id-server: "192.168.1.66:21116" id-server: "192.168.1.66:21116"
relay-server: "192.168.1.66:21117" relay-server: "192.168.1.66:21117"

View File

@@ -25,25 +25,37 @@ type App struct {
BanThreshold int `mapstructure:"ban-threshold"` BanThreshold int `mapstructure:"ban-threshold"`
} }
type Admin struct { type Admin struct {
Title string `mapstructure:"title"` Title string `mapstructure:"title"`
Hello string `mapstructure:"hello"` Hello string `mapstructure:"hello"`
HelloFile string `mapstructure:"hello-file"` HelloFile string `mapstructure:"hello-file"`
IdServerPort int `mapstructure:"id-server-port"`
RelayServerPort int `mapstructure:"relay-server-port"`
} }
type Config struct { type Config struct {
Lang string `mapstructure:"lang"` Lang string `mapstructure:"lang"`
App App App App
Admin Admin Admin Admin
Gorm Gorm Gorm Gorm
Mysql Mysql Mysql Mysql
Gin Gin Postgresql Postgresql
Logger Logger Gin Gin
Redis Redis Logger Logger
Cache Cache Redis Redis
Oss Oss Cache Cache
Jwt Jwt Oss Oss
Rustdesk Rustdesk Jwt Jwt
Proxy Proxy Rustdesk Rustdesk
Ldap Ldap Proxy Proxy
Ldap Ldap
}
func (a *Admin) Init() {
if a.IdServerPort == 0 {
a.IdServerPort = DefaultIdServerPort
}
if a.RelayServerPort == 0 {
a.RelayServerPort = DefaultRelayServerPort
}
} }
// Init 初始化配置 // Init 初始化配置
@@ -80,7 +92,7 @@ func Init(rowVal *Config, path string) *viper.Viper {
panic(fmt.Errorf("Fatal error config: %s \n", err)) panic(fmt.Errorf("Fatal error config: %s \n", err))
} }
rowVal.Rustdesk.LoadKeyFile() rowVal.Rustdesk.LoadKeyFile()
rowVal.Rustdesk.ParsePort() rowVal.Admin.Init()
return v return v
} }

View File

@@ -1,8 +1,9 @@
package config package config
const ( const (
TypeSqlite = "sqlite" TypeSqlite = "sqlite"
TypeMysql = "mysql" TypeMysql = "mysql"
TypePostgresql = "postgresql"
) )
type Gorm struct { type Gorm struct {
@@ -17,3 +18,13 @@ type Mysql struct {
Password string `mapstructure:"password"` Password string `mapstructure:"password"`
Dbname string `mapstructure:"dbname"` Dbname string `mapstructure:"dbname"`
} }
type Postgresql struct {
Host string `mapstructure:"host"`
Port string `mapstructure:"port"`
User string `mapstructure:"user"`
Password string `mapstructure:"password"`
Dbname string `mapstructure:"dbname"`
Sslmode string `mapstructure:"sslmode"` // "disable", "require", "verify-ca", "verify-full"
TimeZone string `mapstructure:"time-zone"` // e.g., "Asia/Shanghai"
}

View File

@@ -18,3 +18,9 @@ type OidcOauth struct {
ClientSecret string `mapstructure:"client-secret"` ClientSecret string `mapstructure:"client-secret"`
RedirectUrl string `mapstructure:"redirect-url"` RedirectUrl string `mapstructure:"redirect-url"`
} }
type LinuxdoOauth struct {
ClientId string `mapstructure:"client-id"`
ClientSecret string `mapstructure:"client-secret"`
RedirectUrl string `mapstructure:"redirect-url"`
}

View File

@@ -2,8 +2,6 @@ package config
import ( import (
"os" "os"
"strconv"
"strings"
) )
const ( const (
@@ -40,19 +38,3 @@ func (rd *Rustdesk) LoadKeyFile() {
return return
} }
} }
func (rd *Rustdesk) ParsePort() {
// Parse port
idres := strings.Split(rd.IdServer, ":")
if len(idres) == 1 {
rd.IdServerPort = DefaultIdServerPort
} else if len(idres) == 2 {
rd.IdServerPort, _ = strconv.Atoi(idres[1])
}
relayres := strings.Split(rd.RelayServer, ":")
if len(relayres) == 1 {
rd.RelayServerPort = DefaultRelayServerPort
} else if len(relayres) == 2 {
rd.RelayServerPort, _ = strconv.Atoi(relayres[1])
}
}

0
data/.gitkeep Normal file
View File

View File

@@ -1,4 +1,4 @@
// Package admin Content generated by swaggo/swag. DO NOT EDIT // Package admin Code generated by swaggo/swag. DO NOT EDIT
package admin package admin
import "github.com/swaggo/swag" import "github.com/swaggo/swag"
@@ -5828,6 +5828,9 @@ const docTemplateadmin = `{
"captcha": { "captcha": {
"type": "string" "type": "string"
}, },
"captcha_id": {
"type": "string"
},
"password": { "password": {
"type": "string" "type": "string"
}, },

View File

@@ -5821,6 +5821,9 @@
"captcha": { "captcha": {
"type": "string" "type": "string"
}, },
"captcha_id": {
"type": "string"
},
"password": { "password": {
"type": "string" "type": "string"
}, },

View File

@@ -297,6 +297,8 @@ definitions:
properties: properties:
captcha: captcha:
type: string type: string
captcha_id:
type: string
password: password:
type: string type: string
platform: platform:

View File

@@ -1,4 +1,4 @@
// Package api Content generated by swaggo/swag. DO NOT EDIT // Package api Code generated by swaggo/swag. DO NOT EDIT
package api package api
import "github.com/swaggo/swag" import "github.com/swaggo/swag"
@@ -1208,7 +1208,7 @@ const docTemplateapi = `{
"application/json" "application/json"
], ],
"tags": [ "tags": [
"地址" "System"
], ],
"summary": "提交系统信息", "summary": "提交系统信息",
"parameters": [ "parameters": [
@@ -1238,6 +1238,35 @@ const docTemplateapi = `{
} }
} }
}, },
"/sysinfo_ver": {
"post": {
"description": "获取系统版本信息",
"consumes": [
"application/json"
],
"produces": [
"application/json"
],
"tags": [
"System"
],
"summary": "获取系统版本信息",
"responses": {
"200": {
"description": "OK",
"schema": {
"type": "string"
}
},
"500": {
"description": "Internal Server Error",
"schema": {
"$ref": "#/definitions/response.ErrorResponse"
}
}
}
}
},
"/users": { "/users": {
"get": { "get": {
"security": [ "security": [

View File

@@ -1201,7 +1201,7 @@
"application/json" "application/json"
], ],
"tags": [ "tags": [
"地址" "System"
], ],
"summary": "提交系统信息", "summary": "提交系统信息",
"parameters": [ "parameters": [
@@ -1231,6 +1231,35 @@
} }
} }
}, },
"/sysinfo_ver": {
"post": {
"description": "获取系统版本信息",
"consumes": [
"application/json"
],
"produces": [
"application/json"
],
"tags": [
"System"
],
"summary": "获取系统版本信息",
"responses": {
"200": {
"description": "OK",
"schema": {
"type": "string"
}
},
"500": {
"description": "Internal Server Error",
"schema": {
"$ref": "#/definitions/response.ErrorResponse"
}
}
}
}
},
"/users": { "/users": {
"get": { "get": {
"security": [ "security": [

View File

@@ -973,7 +973,26 @@ paths:
$ref: '#/definitions/response.ErrorResponse' $ref: '#/definitions/response.ErrorResponse'
summary: 提交系统信息 summary: 提交系统信息
tags: tags:
- 地址 - System
/sysinfo_ver:
post:
consumes:
- application/json
description: 获取系统版本信息
produces:
- application/json
responses:
"200":
description: OK
schema:
type: string
"500":
description: Internal Server Error
schema:
$ref: '#/definitions/response.ErrorResponse'
summary: 获取系统版本信息
tags:
- System
/users: /users:
get: get:
consumes: consumes:

23
go.mod
View File

@@ -1,19 +1,23 @@
module github.com/lejianwen/rustdesk-api/v2 module github.com/lejianwen/rustdesk-api/v2
go 1.22 go 1.23
toolchain go1.23.10
require ( require (
github.com/BurntSushi/toml v1.3.2 github.com/BurntSushi/toml v1.3.2
github.com/antonfisher/nested-logrus-formatter v1.3.1 github.com/antonfisher/nested-logrus-formatter v1.3.1
github.com/fsnotify/fsnotify v1.5.1 github.com/coreos/go-oidc/v3 v3.12.0
github.com/fvbock/endless v0.0.0-20170109170031-447134032cb6 github.com/fvbock/endless v0.0.0-20170109170031-447134032cb6
github.com/gin-gonic/gin v1.9.0 github.com/gin-gonic/gin v1.9.0
github.com/go-ldap/ldap/v3 v3.4.10
github.com/go-playground/locales v0.14.1 github.com/go-playground/locales v0.14.1
github.com/go-playground/universal-translator v0.18.1 github.com/go-playground/universal-translator v0.18.1
github.com/go-playground/validator/v10 v10.26.0 github.com/go-playground/validator/v10 v10.26.0
github.com/go-redis/redis/v8 v8.11.4 github.com/go-redis/redis/v8 v8.11.4
github.com/golang-jwt/jwt/v5 v5.2.1 github.com/golang-jwt/jwt/v5 v5.2.1
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
github.com/mojocn/base64Captcha v1.3.6
github.com/nicksnyder/go-i18n/v2 v2.4.0 github.com/nicksnyder/go-i18n/v2 v2.4.0
github.com/sirupsen/logrus v1.8.1 github.com/sirupsen/logrus v1.8.1
github.com/spf13/cobra v1.8.1 github.com/spf13/cobra v1.8.1
@@ -24,8 +28,9 @@ require (
golang.org/x/oauth2 v0.23.0 golang.org/x/oauth2 v0.23.0
golang.org/x/text v0.22.0 golang.org/x/text v0.22.0
gorm.io/driver/mysql v1.5.7 gorm.io/driver/mysql v1.5.7
gorm.io/driver/postgres v1.6.0
gorm.io/driver/sqlite v1.5.6 gorm.io/driver/sqlite v1.5.6
gorm.io/gorm v1.25.7 gorm.io/gorm v1.25.10
) )
require ( require (
@@ -36,13 +41,12 @@ require (
github.com/bytedance/sonic v1.8.0 // indirect github.com/bytedance/sonic v1.8.0 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/coreos/go-oidc/v3 v3.12.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/fsnotify/fsnotify v1.5.1 // indirect
github.com/gabriel-vasile/mimetype v1.4.8 // indirect github.com/gabriel-vasile/mimetype v1.4.8 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-asn1-ber/asn1-ber v1.5.7 // indirect github.com/go-asn1-ber/asn1-ber v1.5.7 // indirect
github.com/go-jose/go-jose/v4 v4.0.2 // indirect github.com/go-jose/go-jose/v4 v4.0.2 // indirect
github.com/go-ldap/ldap/v3 v3.4.10 // indirect
github.com/go-openapi/jsonpointer v0.19.5 // indirect github.com/go-openapi/jsonpointer v0.19.5 // indirect
github.com/go-openapi/jsonreference v0.19.6 // indirect github.com/go-openapi/jsonreference v0.19.6 // indirect
github.com/go-openapi/spec v0.20.4 // indirect github.com/go-openapi/spec v0.20.4 // indirect
@@ -52,6 +56,10 @@ require (
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/pgx/v5 v5.6.0 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect github.com/jinzhu/now v1.1.5 // indirect
github.com/josharian/intern v1.0.0 // indirect github.com/josharian/intern v1.0.0 // indirect
@@ -65,9 +73,9 @@ require (
github.com/mitchellh/mapstructure v1.4.2 // indirect github.com/mitchellh/mapstructure v1.4.2 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/mojocn/base64Captcha v1.3.6 // indirect
github.com/pelletier/go-toml v1.9.4 // indirect github.com/pelletier/go-toml v1.9.4 // indirect
github.com/pelletier/go-toml/v2 v2.0.6 // indirect github.com/pelletier/go-toml/v2 v2.0.6 // indirect
github.com/rogpeppe/go-internal v1.14.1 // indirect
github.com/spf13/afero v1.6.0 // indirect github.com/spf13/afero v1.6.0 // indirect
github.com/spf13/cast v1.4.1 // indirect github.com/spf13/cast v1.4.1 // indirect
github.com/spf13/jwalterweatherman v1.1.0 // indirect github.com/spf13/jwalterweatherman v1.1.0 // indirect
@@ -79,8 +87,9 @@ require (
golang.org/x/crypto v0.33.0 // indirect golang.org/x/crypto v0.33.0 // indirect
golang.org/x/image v0.13.0 // indirect golang.org/x/image v0.13.0 // indirect
golang.org/x/net v0.34.0 // indirect golang.org/x/net v0.34.0 // indirect
golang.org/x/sync v0.11.0 // indirect
golang.org/x/sys v0.30.0 // indirect golang.org/x/sys v0.30.0 // indirect
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect golang.org/x/tools v0.26.0 // indirect
google.golang.org/protobuf v1.33.0 // indirect google.golang.org/protobuf v1.33.0 // indirect
gopkg.in/ini.v1 v1.63.2 // indirect gopkg.in/ini.v1 v1.63.2 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect

View File

@@ -78,11 +78,13 @@ func (co *Config) AdminConfig(c *gin.Context) {
} }
hello := global.Config.Admin.Hello hello := global.Config.Admin.Hello
helloFile := global.Config.Admin.HelloFile if hello == "" {
if helloFile != "" { helloFile := global.Config.Admin.HelloFile
b, err := os.ReadFile(helloFile) if helloFile != "" {
if err == nil && len(b) > 0 { b, err := os.ReadFile(helloFile)
hello = string(b) if err == nil && len(b) > 0 {
hello = string(b)
}
} }
} }

View File

@@ -57,7 +57,7 @@ func (ct *Login) Login(c *gin.Context) {
// 检查是否需要验证码 // 检查是否需要验证码
if needCaptcha { if needCaptcha {
if f.Captcha == "" || !loginLimiter.VerifyCaptcha(clientIp, f.Captcha) { if f.CaptchaId == "" || f.Captcha == "" || !loginLimiter.VerifyCaptcha(f.CaptchaId, f.Captcha) {
response.Fail(c, 101, response.TranslateMsg(c, "CaptchaError")) response.Fail(c, 101, response.TranslateMsg(c, "CaptchaError"))
return return
} }
@@ -68,8 +68,6 @@ func (ct *Login) Login(c *gin.Context) {
if u.Id == 0 { if u.Id == 0 {
global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "UsernameOrPasswordError", c.RemoteIP(), clientIp)) global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "UsernameOrPasswordError", c.RemoteIP(), clientIp))
loginLimiter.RecordFailedAttempt(clientIp) loginLimiter.RecordFailedAttempt(clientIp)
// 移除验证码,重新生成
loginLimiter.RemoveCaptcha(clientIp)
if _, needCaptcha = loginLimiter.CheckSecurityStatus(clientIp); needCaptcha { if _, needCaptcha = loginLimiter.CheckSecurityStatus(clientIp); needCaptcha {
response.Fail(c, 110, response.TranslateMsg(c, "UsernameOrPasswordError")) response.Fail(c, 110, response.TranslateMsg(c, "UsernameOrPasswordError"))
} else { } else {
@@ -80,7 +78,6 @@ func (ct *Login) Login(c *gin.Context) {
if !service.AllService.UserService.CheckUserEnable(u) { if !service.AllService.UserService.CheckUserEnable(u) {
if needCaptcha { if needCaptcha {
loginLimiter.RemoveCaptcha(clientIp)
response.Fail(c, 110, response.TranslateMsg(c, "UserDisabled")) response.Fail(c, 110, response.TranslateMsg(c, "UserDisabled"))
return return
} }
@@ -113,7 +110,7 @@ func (ct *Login) Captcha(c *gin.Context) {
response.Fail(c, 101, response.TranslateMsg(c, "NoCaptchaRequired")) response.Fail(c, 101, response.TranslateMsg(c, "NoCaptchaRequired"))
return return
} }
err, captcha := loginLimiter.RequireCaptcha(clientIp) err, captcha := loginLimiter.RequireCaptcha()
if err != nil { if err != nil {
response.Fail(c, 101, response.TranslateMsg(c, "CaptchaError")+err.Error()) response.Fail(c, 101, response.TranslateMsg(c, "CaptchaError")+err.Error())
return return
@@ -125,6 +122,7 @@ func (ct *Login) Captcha(c *gin.Context) {
} }
response.Success(c, gin.H{ response.Success(c, gin.H{
"captcha": gin.H{ "captcha": gin.H{
"id": captcha.Id,
"b64": b64, "b64": b64,
}, },
}) })

View File

@@ -119,7 +119,16 @@ func (r *Rustdesk) SendCmd(c *gin.Context) {
response.Fail(c, 101, response.TranslateMsg(c, "ParamsError")) response.Fail(c, 101, response.TranslateMsg(c, "ParamsError"))
return return
} }
res, err := service.AllService.ServerCmdService.SendCmd(rc.Target, rc.Cmd, rc.Option)
port := 0
switch rc.Target {
case model.ServerCmdTargetIdServer:
port = global.Config.Admin.IdServerPort - 1
case model.ServerCmdTargetRelayServer:
port = global.Config.Admin.RelayServerPort
}
res, err := service.AllService.ServerCmdService.SendCmd(port, rc.Cmd, rc.Option)
if err != nil { if err != nil {
response.Fail(c, 101, err.Error()) response.Fail(c, 101, err.Error())
return return

View File

@@ -1,6 +1,7 @@
package api package api
import ( import (
"fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding" "github.com/gin-gonic/gin/binding"
requstform "github.com/lejianwen/rustdesk-api/v2/http/request/api" requstform "github.com/lejianwen/rustdesk-api/v2/http/request/api"
@@ -13,7 +14,7 @@ type Peer struct {
} }
// SysInfo // SysInfo
// @Tags 地址 // @Tags System
// @Summary 提交系统信息 // @Summary 提交系统信息
// @Description 提交系统信息 // @Description 提交系统信息
// @Accept json // @Accept json
@@ -57,8 +58,19 @@ func (p *Peer) SysInfo(c *gin.Context) {
c.String(http.StatusOK, "SYSINFO_UPDATED") c.String(http.StatusOK, "SYSINFO_UPDATED")
} }
// SysInfoVer
// @Tags System
// @Summary 获取系统版本信息
// @Description 获取系统版本信息
// @Accept json
// @Produce json
// @Success 200 {string} string ""
// @Failure 500 {object} response.ErrorResponse
// @Router /sysinfo_ver [post]
func (p *Peer) SysInfoVer(c *gin.Context) { func (p *Peer) SysInfoVer(c *gin.Context) {
//读取resources/version文件 //读取resources/version文件
v := service.AllService.AppService.GetAppVersion() v := service.AllService.AppService.GetAppVersion()
// 加上启动时间方便client上传信息
v = fmt.Sprintf("%s\n%s", v, service.AllService.AppService.GetStartTime())
c.String(http.StatusOK, v) c.String(http.StatusOK, v)
} }

View File

@@ -1,10 +1,11 @@
package admin package admin
type Login struct { type Login struct {
Username string `json:"username" validate:"required" label:"用户名"` Username string `json:"username" validate:"required" label:"用户名"`
Password string `json:"password,omitempty" validate:"required" label:"密码"` Password string `json:"password,omitempty" validate:"required" label:"密码"`
Platform string `json:"platform" label:"平台"` Platform string `json:"platform" label:"平台"`
Captcha string `json:"captcha,omitempty" label:"验证码"` Captcha string `json:"captcha,omitempty" label:"验证码"`
CaptchaId string `json:"captcha_id,omitempty"`
} }
type LoginLogQuery struct { type LoginLogQuery struct {

View File

@@ -2,7 +2,6 @@ package orm
import ( import (
"fmt" "fmt"
"github.com/lejianwen/rustdesk-api/v2/global"
"gorm.io/driver/mysql" "gorm.io/driver/mysql"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/logger" "gorm.io/gorm/logger"
@@ -10,14 +9,14 @@ import (
) )
type MysqlConfig struct { type MysqlConfig struct {
Dns string Dsn string
MaxIdleConns int MaxIdleConns int
MaxOpenConns int MaxOpenConns int
} }
func NewMysql(mysqlConf *MysqlConfig) *gorm.DB { func NewMysql(mysqlConf *MysqlConfig, logwriter logger.Writer) *gorm.DB {
db, err := gorm.Open(mysql.New(mysql.Config{ db, err := gorm.Open(mysql.New(mysql.Config{
DSN: mysqlConf.Dns, // DSN data source name DSN: mysqlConf.Dsn, // DSN data source name
DefaultStringSize: 256, // string 类型字段的默认长度 DefaultStringSize: 256, // string 类型字段的默认长度
//DisableDatetimePrecision: true, // 禁用 datetime 精度MySQL 5.6 之前的数据库不支持 //DisableDatetimePrecision: true, // 禁用 datetime 精度MySQL 5.6 之前的数据库不支持
//DontSupportRenameIndex: true, // 重命名索引时采用删除并新建的方式MySQL 5.7 之前的数据库和 MariaDB 不支持重命名索引 //DontSupportRenameIndex: true, // 重命名索引时采用删除并新建的方式MySQL 5.7 之前的数据库和 MariaDB 不支持重命名索引
@@ -26,7 +25,7 @@ func NewMysql(mysqlConf *MysqlConfig) *gorm.DB {
}), &gorm.Config{ }), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true, DisableForeignKeyConstraintWhenMigrating: true,
Logger: logger.New( Logger: logger.New(
global.Logger, // io writer logwriter, // io writer
logger.Config{ logger.Config{
SlowThreshold: time.Second, // Slow SQL threshold SlowThreshold: time.Second, // Slow SQL threshold
LogLevel: logger.Warn, // Log level LogLevel: logger.Warn, // Log level

45
lib/orm/postgresql.go Normal file
View File

@@ -0,0 +1,45 @@
package orm
import (
"fmt"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"time"
)
type PostgresqlConfig struct {
Dsn string
MaxIdleConns int
MaxOpenConns int
}
func NewPostgresql(conf *PostgresqlConfig, logwriter logger.Writer) *gorm.DB {
db, err := gorm.Open(postgres.Open(conf.Dsn), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
Logger: logger.New(
logwriter, // io writer
logger.Config{
SlowThreshold: time.Second, // Slow SQL threshold
LogLevel: logger.Warn, // Log level
//IgnoreRecordNotFoundError: true, // Ignore ErrRecordNotFound error for logger
ParameterizedQueries: true, // Don't include params in the SQL log
Colorful: true,
},
),
})
if err != nil {
fmt.Println(err)
}
sqlDB, err2 := db.DB()
if err2 != nil {
fmt.Println(err2)
}
// SetMaxIdleConns 设置空闲连接池中连接的最大数量
sqlDB.SetMaxIdleConns(conf.MaxIdleConns)
// SetMaxOpenConns 设置打开数据库连接的最大数量。
sqlDB.SetMaxOpenConns(conf.MaxOpenConns)
return db
}

View File

@@ -2,7 +2,6 @@ package orm
import ( import (
"fmt" "fmt"
"github.com/lejianwen/rustdesk-api/v2/global"
"gorm.io/driver/sqlite" "gorm.io/driver/sqlite"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/logger" "gorm.io/gorm/logger"
@@ -14,11 +13,11 @@ type SqliteConfig struct {
MaxOpenConns int MaxOpenConns int
} }
func NewSqlite(sqliteConf *SqliteConfig) *gorm.DB { func NewSqlite(sqliteConf *SqliteConfig, logwriter logger.Writer) *gorm.DB {
db, err := gorm.Open(sqlite.Open("./data/rustdeskapi.db"), &gorm.Config{ db, err := gorm.Open(sqlite.Open("./data/rustdeskapi.db"), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true, DisableForeignKeyConstraintWhenMigrating: true,
Logger: logger.New( Logger: logger.New(
global.Logger, // io writer logwriter, // io writer
logger.Config{ logger.Config{
SlowThreshold: time.Second, // Slow SQL threshold SlowThreshold: time.Second, // Slow SQL threshold
LogLevel: logger.Warn, // Log level LogLevel: logger.Warn, // Log level

View File

@@ -14,6 +14,7 @@ const (
OauthTypeGoogle string = "google" OauthTypeGoogle string = "google"
OauthTypeOidc string = "oidc" OauthTypeOidc string = "oidc"
OauthTypeWebauth string = "webauth" OauthTypeWebauth string = "webauth"
OauthTypeLinuxdo string = "linuxdo"
PKCEMethodS256 string = "S256" PKCEMethodS256 string = "S256"
PKCEMethodPlain string = "plain" PKCEMethodPlain string = "plain"
) )
@@ -21,7 +22,7 @@ const (
// Validate the oauth type // Validate the oauth type
func ValidateOauthType(oauthType string) error { func ValidateOauthType(oauthType string) error {
switch oauthType { switch oauthType {
case OauthTypeGithub, OauthTypeGoogle, OauthTypeOidc, OauthTypeWebauth: case OauthTypeGithub, OauthTypeGoogle, OauthTypeOidc, OauthTypeWebauth, OauthTypeLinuxdo:
return nil return nil
default: default:
return errors.New("invalid Oauth type") return errors.New("invalid Oauth type")
@@ -30,6 +31,7 @@ func ValidateOauthType(oauthType string) error {
const ( const (
UserEndpointGithub string = "https://api.github.com/user" UserEndpointGithub string = "https://api.github.com/user"
UserEndpointLinuxdo string = "https://connect.linux.do/api/user"
IssuerGoogle string = "https://accounts.google.com" IssuerGoogle string = "https://accounts.google.com"
) )
@@ -60,6 +62,8 @@ func (oa *Oauth) FormatOauthInfo() error {
oa.Op = OauthTypeGithub oa.Op = OauthTypeGithub
case OauthTypeGoogle: case OauthTypeGoogle:
oa.Op = OauthTypeGoogle oa.Op = OauthTypeGoogle
case OauthTypeLinuxdo:
oa.Op = OauthTypeLinuxdo
} }
// check if the op is empty, set the default value // check if the op is empty, set the default value
op := strings.TrimSpace(oa.Op) op := strings.TrimSpace(oa.Op)
@@ -152,6 +156,24 @@ func (gu *GithubUser) ToOauthUser() *OauthUser {
} }
} }
type LinuxdoUser struct {
OauthUserBase
Id int `json:"id"`
Username string `json:"username"`
Avatar string `json:"avatar_url"`
}
func (lu *LinuxdoUser) ToOauthUser() *OauthUser {
return &OauthUser{
OpenId: strconv.Itoa(lu.Id),
Name: lu.Name,
Username: strings.ToLower(lu.Username),
Email: lu.Email,
VerifiedEmail: true, // linux.do 用户邮箱默认已验证
Picture: lu.Avatar,
}
}
type OauthList struct { type OauthList struct {
Oauths []*Oauth `json:"list"` Oauths []*Oauth `json:"list"`
Pagination Pagination

View File

@@ -11550,7 +11550,7 @@ async function or(u) {
let E = [], l = []; let E = [], l = [];
for (let d = 0; d < e.length; d++) { for (let d = 0; d < e.length; d++) {
const c = 1 << 7 - d % 8; const c = 1 << 7 - d % 8;
(s[d / 8] & c) === c ? E.push(e[d]) : l.push(e[d]) (s[Math.floor(d / 8)] & c) === c ? E.push(e[d]) : l.push(e[d])
} }
_t(E, l), n.close(); _t(E, l), n.close();
return return

View File

@@ -3,13 +3,14 @@ package service
import ( import (
"os" "os"
"sync" "sync"
"time"
) )
type AppService struct { type AppService struct {
} }
var version = "" var version = ""
var startTime = ""
var once = &sync.Once{} var once = &sync.Once{}
func (a *AppService) GetAppVersion() string { func (a *AppService) GetAppVersion() string {
@@ -26,3 +27,13 @@ func (a *AppService) GetAppVersion() string {
}) })
return version return version
} }
func init() {
// Initialize the AppService if needed
startTime = time.Now().Format("2006-01-02 15:04:05")
}
// GetStartTime
func (a *AppService) GetStartTime() string {
return startTime
}

View File

@@ -154,6 +154,18 @@ func (os *OauthService) GithubProvider() *oidc.Provider {
}).NewProvider(context.Background()) }).NewProvider(context.Background())
} }
func (os *OauthService) LinuxdoProvider() *oidc.Provider {
return (&oidc.ProviderConfig{
IssuerURL: "",
AuthURL: "https://connect.linux.do/oauth2/authorize",
TokenURL: "https://connect.linux.do/oauth2/token",
DeviceAuthURL: "",
UserInfoURL: model.UserEndpointLinuxdo,
JWKSURL: "",
Algorithms: nil,
}).NewProvider(context.Background())
}
// GetOauthConfig retrieves the OAuth2 configuration based on the provider name // GetOauthConfig retrieves the OAuth2 configuration based on the provider name
func (os *OauthService) GetOauthConfig(op string) (err error, oauthInfo *model.Oauth, oauthConfig *oauth2.Config, provider *oidc.Provider) { func (os *OauthService) GetOauthConfig(op string) (err error, oauthInfo *model.Oauth, oauthConfig *oauth2.Config, provider *oidc.Provider) {
//err, oauthInfo, oauthConfig = os.getOauthConfigGeneral(op) //err, oauthInfo, oauthConfig = os.getOauthConfigGeneral(op)
@@ -182,6 +194,10 @@ func (os *OauthService) GetOauthConfig(op string) (err error, oauthInfo *model.O
oauthConfig.Endpoint = github.Endpoint oauthConfig.Endpoint = github.Endpoint
oauthConfig.Scopes = []string{"read:user", "user:email"} oauthConfig.Scopes = []string{"read:user", "user:email"}
provider = os.GithubProvider() provider = os.GithubProvider()
case model.OauthTypeLinuxdo:
provider = os.LinuxdoProvider()
oauthConfig.Endpoint = provider.Endpoint()
oauthConfig.Scopes = []string{"profile"}
//case model.OauthTypeGoogle: //google单独出来可以少一次FetchOidcEndpoint请求 //case model.OauthTypeGoogle: //google单独出来可以少一次FetchOidcEndpoint请求
// oauthConfig.Endpoint = google.Endpoint // oauthConfig.Endpoint = google.Endpoint
// oauthConfig.Scopes = os.constructScopes(oauthInfo.Scopes) // oauthConfig.Scopes = os.constructScopes(oauthInfo.Scopes)
@@ -299,6 +315,16 @@ func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, provider *oid
return nil, user.ToOauthUser() return nil, user.ToOauthUser()
} }
// linuxdoCallback linux.do回调
func (os *OauthService) linuxdoCallback(oauthConfig *oauth2.Config, provider *oidc.Provider, code, verifier, nonce string) (error, *model.OauthUser) {
var user = &model.LinuxdoUser{}
err, _ := os.callbackBase(oauthConfig, provider, code, verifier, nonce, user)
if err != nil {
return err, nil
}
return nil, user.ToOauthUser()
}
// oidcCallback oidc回调, 通过code获取用户信息 // oidcCallback oidc回调, 通过code获取用户信息
func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, provider *oidc.Provider, code, verifier, nonce string) (error, *model.OauthUser) { func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, provider *oidc.Provider, code, verifier, nonce string) (error, *model.OauthUser) {
var user = &model.OidcUser{} var user = &model.OidcUser{}
@@ -319,6 +345,8 @@ func (os *OauthService) Callback(code, verifier, op, nonce string) (err error, o
switch oauthType { switch oauthType {
case model.OauthTypeGithub: case model.OauthTypeGithub:
err, oauthUser = os.githubCallback(oauthConfig, provider, code, verifier, nonce) err, oauthUser = os.githubCallback(oauthConfig, provider, code, verifier, nonce)
case model.OauthTypeLinuxdo:
err, oauthUser = os.linuxdoCallback(oauthConfig, provider, code, verifier, nonce)
case model.OauthTypeOidc, model.OauthTypeGoogle: case model.OauthTypeOidc, model.OauthTypeGoogle:
err, oauthUser = os.oidcCallback(oauthConfig, provider, code, verifier, nonce) err, oauthUser = os.oidcCallback(oauthConfig, provider, code, verifier, nonce)
default: default:

View File

@@ -40,14 +40,7 @@ func (is *ServerCmdService) Create(u *model.ServerCmd) error {
} }
// SendCmd 发送命令 // SendCmd 发送命令
func (is *ServerCmdService) SendCmd(target string, cmd string, arg string) (string, error) { func (is *ServerCmdService) SendCmd(port int, cmd string, arg string) (string, error) {
port := 0
switch target {
case model.ServerCmdTargetIdServer:
port = Config.Rustdesk.IdServerPort - 1
case model.ServerCmdTargetRelayServer:
port = Config.Rustdesk.RelayServerPort
}
//组装命令 //组装命令
cmd = cmd + " " + arg cmd = cmd + " " + arg
res, err := is.SendSocketCmd("v6", port, cmd) res, err := is.SendSocketCmd("v6", port, cmd)

View File

@@ -5,15 +5,15 @@ import (
"time" "time"
) )
var capdString = base64Captcha.NewDriverString(50, 150, 5, 10, 4, "123456789abcdefghijklmnopqrstuvwxyz", nil, nil, nil) var capdString = base64Captcha.NewDriverString(50, 150, 0, 5, 4, "123456789abcdefghijklmnopqrstuvwxyz", nil, nil, nil)
var capdMath = base64Captcha.NewDriverMath(50, 150, 5, 10, nil, nil, nil) var capdMath = base64Captcha.NewDriverMath(50, 150, 3, 10, nil, nil, nil)
type B64StringCaptchaProvider struct{} type B64StringCaptchaProvider struct{}
func (p B64StringCaptchaProvider) Generate(ip string) (string, string, error) { func (p B64StringCaptchaProvider) Generate() (string, string, string, error) {
_, content, answer := capdString.GenerateIdQuestionAnswer() id, content, answer := capdString.GenerateIdQuestionAnswer()
return content, answer, nil return id, content, answer, nil
} }
func (p B64StringCaptchaProvider) Expiration() time.Duration { func (p B64StringCaptchaProvider) Expiration() time.Duration {
@@ -30,9 +30,9 @@ func (p B64StringCaptchaProvider) Draw(content string) (string, error) {
type B64MathCaptchaProvider struct{} type B64MathCaptchaProvider struct{}
func (p B64MathCaptchaProvider) Generate(ip string) (string, string, error) { func (p B64MathCaptchaProvider) Generate() (string, string, string, error) {
_, content, answer := capdMath.GenerateIdQuestionAnswer() id, content, answer := capdMath.GenerateIdQuestionAnswer()
return content, answer, nil return id, content, answer, nil
} }
func (p B64MathCaptchaProvider) Expiration() time.Duration { func (p B64MathCaptchaProvider) Expiration() time.Duration {

View File

@@ -16,7 +16,7 @@ type SecurityPolicy struct {
// 验证码提供者接口 // 验证码提供者接口
type CaptchaProvider interface { type CaptchaProvider interface {
Generate(ip string) (string, string, error) Generate() (id string, content string, answer string, err error)
//Validate(ip, code string) bool //Validate(ip, code string) bool
Expiration() time.Duration // 验证码过期时间, 应该小于 AttemptsWindow Expiration() time.Duration // 验证码过期时间, 应该小于 AttemptsWindow
Draw(content string) (string, error) // 绘制验证码 Draw(content string) (string, error) // 绘制验证码
@@ -24,6 +24,7 @@ type CaptchaProvider interface {
// 验证码元数据 // 验证码元数据
type CaptchaMeta struct { type CaptchaMeta struct {
Id string
Content string Content string
Answer string Answer string
ExpiresAt time.Time ExpiresAt time.Time
@@ -117,7 +118,7 @@ func (ll *LoginLimiter) RecordFailedAttempt(ip string) {
} }
// 生成验证码 // 生成验证码
func (ll *LoginLimiter) RequireCaptcha(ip string) (error, CaptchaMeta) { func (ll *LoginLimiter) RequireCaptcha() (error, CaptchaMeta) {
ll.mu.Lock() ll.mu.Lock()
defer ll.mu.Unlock() defer ll.mu.Unlock()
@@ -125,23 +126,24 @@ func (ll *LoginLimiter) RequireCaptcha(ip string) (error, CaptchaMeta) {
return errors.New("no captcha provider available"), CaptchaMeta{} return errors.New("no captcha provider available"), CaptchaMeta{}
} }
content, answer, err := ll.provider.Generate(ip) id, content, answer, err := ll.provider.Generate()
if err != nil { if err != nil {
return err, CaptchaMeta{} return err, CaptchaMeta{}
} }
// 存储验证码 // 存储验证码
ll.captchas[ip] = CaptchaMeta{ ll.captchas[id] = CaptchaMeta{
Id: id,
Content: content, Content: content,
Answer: answer, Answer: answer,
ExpiresAt: time.Now().Add(ll.provider.Expiration()), ExpiresAt: time.Now().Add(ll.provider.Expiration()),
} }
return nil, ll.captchas[ip] return nil, ll.captchas[id]
} }
// 验证验证码 // 验证验证码
func (ll *LoginLimiter) VerifyCaptcha(ip, answer string) bool { func (ll *LoginLimiter) VerifyCaptcha(id, answer string) bool {
ll.mu.Lock() ll.mu.Lock()
defer ll.mu.Unlock() defer ll.mu.Unlock()
@@ -151,20 +153,20 @@ func (ll *LoginLimiter) VerifyCaptcha(ip, answer string) bool {
} }
// 获取并验证验证码 // 获取并验证验证码
captcha, exists := ll.captchas[ip] captcha, exists := ll.captchas[id]
if !exists { if !exists {
return false return false
} }
// 清理过期验证码 // 清理过期验证码
if time.Now().After(captcha.ExpiresAt) { if time.Now().After(captcha.ExpiresAt) {
delete(ll.captchas, ip) delete(ll.captchas, id)
return false return false
} }
// 验证并清理状态 // 验证并清理状态
if answer == captcha.Answer { if answer == captcha.Answer {
delete(ll.captchas, ip) delete(ll.captchas, id)
return true return true
} }
@@ -176,16 +178,6 @@ func (ll *LoginLimiter) DrawCaptcha(content string) (err error, str string) {
return return
} }
func (ll *LoginLimiter) RemoveCaptcha(ip string) {
ll.mu.Lock()
defer ll.mu.Unlock()
_, exists := ll.captchas[ip]
if exists {
delete(ll.captchas, ip)
}
}
// 清除记录窗口 // 清除记录窗口
func (ll *LoginLimiter) RemoveAttempts(ip string) { func (ll *LoginLimiter) RemoveAttempts(ip string) {
ll.mu.Lock() ll.mu.Lock()
@@ -212,7 +204,6 @@ func (ll *LoginLimiter) CheckSecurityStatus(ip string) (banned bool, captchaRequ
// 清理过期数据 // 清理过期数据
ll.pruneAttempts(ip, time.Now().Add(-ll.policy.AttemptsWindow)) ll.pruneAttempts(ip, time.Now().Add(-ll.policy.AttemptsWindow))
ll.pruneCaptchas(ip)
// 检查验证码要求 // 检查验证码要求
captchaRequired = len(ll.attempts[ip]) >= ll.policy.CaptchaThreshold captchaRequired = len(ll.attempts[ip]) >= ll.policy.CaptchaThreshold
@@ -272,10 +263,10 @@ func (ll *LoginLimiter) pruneAttempts(ip string, cutoff time.Time) []time.Time {
return valid return valid
} }
func (ll *LoginLimiter) pruneCaptchas(ip string) { func (ll *LoginLimiter) pruneCaptchas(id string) {
if captcha, exists := ll.captchas[ip]; exists { if captcha, exists := ll.captchas[id]; exists {
if time.Now().After(captcha.ExpiresAt) { if time.Now().After(captcha.ExpiresAt) {
delete(ll.captchas, ip) delete(ll.captchas, id)
} }
} }
} }
@@ -299,7 +290,7 @@ func (ll *LoginLimiter) cleanupExpired() {
} }
// 清理验证码 // 清理验证码
for ip := range ll.captchas { for id := range ll.captchas {
ll.pruneCaptchas(ip) ll.pruneCaptchas(id)
} }
} }

View File

@@ -2,18 +2,18 @@ package utils
import ( import (
"fmt" "fmt"
"github.com/google/uuid"
"testing" "testing"
"time" "time"
) )
type MockCaptchaProvider struct{} type MockCaptchaProvider struct{}
func (p *MockCaptchaProvider) Generate(ip string) (string, string, error) { func (p *MockCaptchaProvider) Generate() (string, string, string, error) {
return "CONTENT", "MOCK", nil id := uuid.New().String()
} content := uuid.New().String()
answer := uuid.New().String()
func (p *MockCaptchaProvider) Validate(ip, code string) bool { return id, content, answer, nil
return code == "MOCK"
} }
func (p *MockCaptchaProvider) Expiration() time.Duration { func (p *MockCaptchaProvider) Expiration() time.Duration {
@@ -74,17 +74,22 @@ func TestCaptchaFlow(t *testing.T) {
} }
// 生成验证码 // 生成验证码
err, capc := limiter.RequireCaptcha(ip) err, capc := limiter.RequireCaptcha()
if err != nil { if err != nil {
t.Fatalf("生成验证码失败: %v", err) t.Fatalf("生成验证码失败: %v", err)
} }
fmt.Printf("验证码内容: %#v\n", capc) fmt.Printf("验证码内容: %#v\n", capc)
// 验证成功 // 验证成功
if !limiter.VerifyCaptcha(ip, capc.Answer) { if !limiter.VerifyCaptcha(capc.Id, capc.Answer) {
t.Error("验证码应该验证成功") t.Error("验证码应该验证成功")
} }
// 验证已删除
if limiter.VerifyCaptcha(capc.Id, capc.Answer) {
t.Error("验证码应该已删除")
}
limiter.RemoveAttempts(ip) limiter.RemoveAttempts(ip)
// 验证后状态 // 验证后状态
if banned, need := limiter.CheckSecurityStatus(ip); banned || need { if banned, need := limiter.CheckSecurityStatus(ip); banned || need {
@@ -104,14 +109,14 @@ func TestCaptchaMustFlow(t *testing.T) {
} }
// 生成验证码 // 生成验证码
err, capc := limiter.RequireCaptcha(ip) err, capc := limiter.RequireCaptcha()
if err != nil { if err != nil {
t.Fatalf("生成验证码失败: %v", err) t.Fatalf("生成验证码失败: %v", err)
} }
fmt.Printf("验证码内容: %#v\n", capc) fmt.Printf("验证码内容: %#v\n", capc)
// 验证成功 // 验证成功
if !limiter.VerifyCaptcha(ip, capc.Answer) { if !limiter.VerifyCaptcha(capc.Id, capc.Answer) {
t.Error("验证码应该验证成功") t.Error("验证码应该验证成功")
} }
@@ -136,7 +141,7 @@ func TestAttemptTimeout(t *testing.T) {
} }
// 生成验证码 // 生成验证码
err, _ := limiter.RequireCaptcha(ip) err, _ := limiter.RequireCaptcha()
if err != nil { if err != nil {
t.Fatalf("生成验证码失败: %v", err) t.Fatalf("生成验证码失败: %v", err)
} }
@@ -167,7 +172,7 @@ func TestCaptchaTimeout(t *testing.T) {
} }
// 生成验证码 // 生成验证码
err, _ := limiter.RequireCaptcha(ip) err, capc := limiter.RequireCaptcha()
if err != nil { if err != nil {
t.Fatalf("生成验证码失败: %v", err) t.Fatalf("生成验证码失败: %v", err)
} }
@@ -175,9 +180,8 @@ func TestCaptchaTimeout(t *testing.T) {
// 等待超过 CaptchaValidPeriod // 等待超过 CaptchaValidPeriod
time.Sleep(3 * time.Second) time.Sleep(3 * time.Second)
code := "MOCK"
// 验证成功 // 验证成功
if limiter.VerifyCaptcha(ip, code) { if limiter.VerifyCaptcha(capc.Id, capc.Answer) {
t.Error("验证码应该已过期") t.Error("验证码应该已过期")
} }
@@ -261,7 +265,7 @@ func TestB64CaptchaFlow(t *testing.T) {
} }
// 生成验证码 // 生成验证码
err, capc := limiter.RequireCaptcha(ip) err, capc := limiter.RequireCaptcha()
if err != nil { if err != nil {
t.Fatalf("生成验证码失败: %v", err) t.Fatalf("生成验证码失败: %v", err)
} }
@@ -275,7 +279,7 @@ func TestB64CaptchaFlow(t *testing.T) {
fmt.Printf("验证码内容: %#v\n", b64) fmt.Printf("验证码内容: %#v\n", b64)
// 验证成功 // 验证成功
if !limiter.VerifyCaptcha(ip, capc.Answer) { if !limiter.VerifyCaptcha(capc.Id, capc.Answer) {
t.Error("验证码应该验证成功") t.Error("验证码应该验证成功")
} }
limiter.RemoveAttempts(ip) limiter.RemoveAttempts(ip)