mirror of
https://github.com/lejianwen/rustdesk-api.git
synced 2025-11-29 08:33:21 +00:00
* fix: redirects after oauth can potentially misalign with the server's actual hostname * feat: remove `RedirectURL` from oauth config, as it should be checked by the provider rather than the client * feat: align oauth endpoint with the hostname in requests
355 lines
9.6 KiB
Go
355 lines
9.6 KiB
Go
package main
|
|
|
|
import (
|
|
"fmt"
|
|
"os"
|
|
"strconv"
|
|
"time"
|
|
|
|
"github.com/go-redis/redis/v8"
|
|
"github.com/lejianwen/rustdesk-api/v2/config"
|
|
"github.com/lejianwen/rustdesk-api/v2/global"
|
|
"github.com/lejianwen/rustdesk-api/v2/http"
|
|
"github.com/lejianwen/rustdesk-api/v2/lib/cache"
|
|
"github.com/lejianwen/rustdesk-api/v2/lib/jwt"
|
|
"github.com/lejianwen/rustdesk-api/v2/lib/lock"
|
|
"github.com/lejianwen/rustdesk-api/v2/lib/logger"
|
|
"github.com/lejianwen/rustdesk-api/v2/lib/orm"
|
|
"github.com/lejianwen/rustdesk-api/v2/lib/upload"
|
|
"github.com/lejianwen/rustdesk-api/v2/model"
|
|
"github.com/lejianwen/rustdesk-api/v2/service"
|
|
"github.com/lejianwen/rustdesk-api/v2/utils"
|
|
"github.com/nicksnyder/go-i18n/v2/i18n"
|
|
"github.com/spf13/cobra"
|
|
)
|
|
|
|
// @title 管理系统API
|
|
// @version 1.0
|
|
// @description 接口
|
|
// @basePath /api
|
|
// @securityDefinitions.apikey token
|
|
// @in header
|
|
// @name api-token
|
|
// @securitydefinitions.apikey BearerAuth
|
|
// @in header
|
|
// @name Authorization
|
|
|
|
var rootCmd = &cobra.Command{
|
|
Use: "apimain",
|
|
Short: "RUSTDESK API SERVER",
|
|
PersistentPreRun: func(cmd *cobra.Command, args []string) {
|
|
InitGlobal()
|
|
},
|
|
Run: func(cmd *cobra.Command, args []string) {
|
|
global.Logger.Info("API SERVER START")
|
|
http.ApiInit()
|
|
},
|
|
}
|
|
|
|
var resetPwdCmd = &cobra.Command{
|
|
Use: "reset-admin-pwd [pwd]",
|
|
Example: "reset-admin-pwd 123456",
|
|
Short: "Reset Admin Password",
|
|
Args: cobra.ExactArgs(1),
|
|
Run: func(cmd *cobra.Command, args []string) {
|
|
pwd := args[0]
|
|
admin := service.AllService.UserService.InfoById(1)
|
|
if admin.Id == 0 {
|
|
global.Logger.Warn("user not found! ")
|
|
return
|
|
}
|
|
err := service.AllService.UserService.UpdatePassword(admin, pwd)
|
|
if err != nil {
|
|
global.Logger.Error("reset password fail! ", err)
|
|
return
|
|
}
|
|
global.Logger.Info("reset password success! ")
|
|
},
|
|
}
|
|
var resetUserPwdCmd = &cobra.Command{
|
|
Use: "reset-pwd [userId] [pwd]",
|
|
Example: "reset-pwd 2 123456",
|
|
Short: "Reset User Password",
|
|
Args: cobra.ExactArgs(2),
|
|
Run: func(cmd *cobra.Command, args []string) {
|
|
userId := args[0]
|
|
pwd := args[1]
|
|
uid, err := strconv.Atoi(userId)
|
|
if err != nil {
|
|
global.Logger.Warn("userId must be int!")
|
|
return
|
|
}
|
|
if uid <= 0 {
|
|
global.Logger.Warn("userId must be greater than 0! ")
|
|
return
|
|
}
|
|
u := service.AllService.UserService.InfoById(uint(uid))
|
|
if u.Id == 0 {
|
|
global.Logger.Warn("user not found! ")
|
|
return
|
|
}
|
|
err = service.AllService.UserService.UpdatePassword(u, pwd)
|
|
if err != nil {
|
|
global.Logger.Warn("reset password fail! ", err)
|
|
return
|
|
}
|
|
global.Logger.Info("reset password success!")
|
|
},
|
|
}
|
|
|
|
func init() {
|
|
rootCmd.PersistentFlags().StringVarP(&global.ConfigPath, "config", "c", "./conf/config.yaml", "choose config file")
|
|
rootCmd.AddCommand(resetPwdCmd, resetUserPwdCmd)
|
|
}
|
|
func main() {
|
|
if err := rootCmd.Execute(); err != nil {
|
|
global.Logger.Error(err)
|
|
os.Exit(1)
|
|
}
|
|
}
|
|
|
|
func InitGlobal() {
|
|
//配置解析
|
|
global.Viper = config.Init(&global.Config, global.ConfigPath)
|
|
|
|
//日志
|
|
global.Logger = logger.New(&logger.Config{
|
|
Path: global.Config.Logger.Path,
|
|
Level: global.Config.Logger.Level,
|
|
ReportCaller: global.Config.Logger.ReportCaller,
|
|
})
|
|
|
|
global.InitI18n()
|
|
|
|
//redis
|
|
global.Redis = redis.NewClient(&redis.Options{
|
|
Addr: global.Config.Redis.Addr,
|
|
Password: global.Config.Redis.Password,
|
|
DB: global.Config.Redis.Db,
|
|
})
|
|
|
|
//cache
|
|
if global.Config.Cache.Type == cache.TypeFile {
|
|
fc := cache.NewFileCache()
|
|
fc.SetDir(global.Config.Cache.FileDir)
|
|
global.Cache = fc
|
|
} else if global.Config.Cache.Type == cache.TypeRedis {
|
|
global.Cache = cache.NewRedis(&redis.Options{
|
|
Addr: global.Config.Cache.RedisAddr,
|
|
Password: global.Config.Cache.RedisPwd,
|
|
DB: global.Config.Cache.RedisDb,
|
|
})
|
|
}
|
|
//gorm
|
|
if global.Config.Gorm.Type == config.TypeMysql {
|
|
|
|
dsn := fmt.Sprintf("%s:%s@(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
|
|
global.Config.Mysql.Username,
|
|
global.Config.Mysql.Password,
|
|
global.Config.Mysql.Addr,
|
|
global.Config.Mysql.Dbname,
|
|
)
|
|
|
|
global.DB = orm.NewMysql(&orm.MysqlConfig{
|
|
Dsn: dsn,
|
|
MaxIdleConns: global.Config.Gorm.MaxIdleConns,
|
|
MaxOpenConns: global.Config.Gorm.MaxOpenConns,
|
|
}, global.Logger)
|
|
} else if global.Config.Gorm.Type == config.TypePostgresql {
|
|
dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s TimeZone=%s",
|
|
global.Config.Postgresql.Host,
|
|
global.Config.Postgresql.Port,
|
|
global.Config.Postgresql.User,
|
|
global.Config.Postgresql.Password,
|
|
global.Config.Postgresql.Dbname,
|
|
global.Config.Postgresql.Sslmode,
|
|
global.Config.Postgresql.TimeZone,
|
|
)
|
|
global.DB = orm.NewPostgresql(&orm.PostgresqlConfig{
|
|
Dsn: dsn,
|
|
MaxIdleConns: global.Config.Gorm.MaxIdleConns,
|
|
MaxOpenConns: global.Config.Gorm.MaxOpenConns,
|
|
}, global.Logger)
|
|
} else {
|
|
//sqlite
|
|
global.DB = orm.NewSqlite(&orm.SqliteConfig{
|
|
MaxIdleConns: global.Config.Gorm.MaxIdleConns,
|
|
MaxOpenConns: global.Config.Gorm.MaxOpenConns,
|
|
}, global.Logger)
|
|
}
|
|
|
|
//validator
|
|
global.ApiInitValidator()
|
|
|
|
//oss
|
|
global.Oss = &upload.Oss{
|
|
AccessKeyId: global.Config.Oss.AccessKeyId,
|
|
AccessKeySecret: global.Config.Oss.AccessKeySecret,
|
|
Host: global.Config.Oss.Host,
|
|
CallbackUrl: global.Config.Oss.CallbackUrl,
|
|
ExpireTime: global.Config.Oss.ExpireTime,
|
|
MaxByte: global.Config.Oss.MaxByte,
|
|
}
|
|
|
|
//jwt
|
|
//fmt.Println(global.Config.Jwt.PrivateKey)
|
|
global.Jwt = jwt.NewJwt(global.Config.Jwt.Key, global.Config.Jwt.ExpireDuration)
|
|
//locker
|
|
global.Lock = lock.NewLocal()
|
|
|
|
//service
|
|
service.New(&global.Config, global.DB, global.Logger, global.Jwt, global.Lock)
|
|
|
|
global.LoginLimiter = utils.NewLoginLimiter(utils.SecurityPolicy{
|
|
CaptchaThreshold: global.Config.App.CaptchaThreshold,
|
|
BanThreshold: global.Config.App.BanThreshold,
|
|
AttemptsWindow: 10 * time.Minute,
|
|
BanDuration: 30 * time.Minute,
|
|
})
|
|
global.LoginLimiter.RegisterProvider(utils.B64StringCaptchaProvider{})
|
|
DatabaseAutoUpdate()
|
|
}
|
|
|
|
func DatabaseAutoUpdate() {
|
|
version := 263
|
|
|
|
db := global.DB
|
|
|
|
if global.Config.Gorm.Type == config.TypeMysql {
|
|
//检查存不存在数据库,不存在则创建
|
|
dbName := db.Migrator().CurrentDatabase()
|
|
if dbName == "" {
|
|
dbName = global.Config.Mysql.Dbname
|
|
// 移除 DSN 中的数据库名称,以便初始连接时不指定数据库
|
|
dsnWithoutDB := fmt.Sprintf("%s:%s@(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
|
|
global.Config.Mysql.Username,
|
|
global.Config.Mysql.Password,
|
|
global.Config.Mysql.Addr,
|
|
"",
|
|
)
|
|
|
|
//新链接
|
|
dbWithoutDB := orm.NewMysql(&orm.MysqlConfig{
|
|
Dsn: dsnWithoutDB,
|
|
}, global.Logger)
|
|
// 获取底层的 *sql.DB 对象,并确保在程序退出时关闭连接
|
|
sqlDBWithoutDB, err := dbWithoutDB.DB()
|
|
if err != nil {
|
|
global.Logger.Errorf("获取底层 *sql.DB 对象失败: %v", err)
|
|
return
|
|
}
|
|
defer func() {
|
|
if err := sqlDBWithoutDB.Close(); err != nil {
|
|
global.Logger.Errorf("关闭连接失败: %v", err)
|
|
}
|
|
}()
|
|
|
|
err = dbWithoutDB.Exec("CREATE DATABASE IF NOT EXISTS " + dbName + " DEFAULT CHARSET utf8mb4").Error
|
|
if err != nil {
|
|
global.Logger.Error(err)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
if !db.Migrator().HasTable(&model.Version{}) {
|
|
Migrate(uint(version))
|
|
} else {
|
|
//查找最后一个version
|
|
var v model.Version
|
|
db.Last(&v)
|
|
if v.Version < uint(version) {
|
|
Migrate(uint(version))
|
|
}
|
|
|
|
// 245迁移
|
|
if v.Version < 245 {
|
|
//oauths 表的 oauth_type 字段设置为 op同样的值
|
|
db.Exec("update oauths set oauth_type = op")
|
|
db.Exec("update oauths set issuer = 'https://accounts.google.com' where op = 'google'")
|
|
db.Exec("update user_thirds set oauth_type = third_type, op = third_type")
|
|
//通过email迁移旧的google授权
|
|
uts := make([]model.UserThird, 0)
|
|
db.Where("oauth_type = ?", "google").Find(&uts)
|
|
for _, ut := range uts {
|
|
if ut.UserId > 0 {
|
|
db.Model(&model.User{}).Where("id = ?", ut.UserId).Update("email", ut.OpenId)
|
|
}
|
|
}
|
|
}
|
|
if v.Version < 246 {
|
|
db.Exec("update oauths set issuer = 'https://accounts.google.com' where op = 'google' and issuer is null")
|
|
}
|
|
}
|
|
|
|
}
|
|
func Migrate(version uint) {
|
|
global.Logger.Info("Migrating....", version)
|
|
err := global.DB.AutoMigrate(
|
|
&model.Version{},
|
|
&model.User{},
|
|
&model.UserToken{},
|
|
&model.Tag{},
|
|
&model.AddressBook{},
|
|
&model.Peer{},
|
|
&model.Group{},
|
|
&model.UserThird{},
|
|
&model.Oauth{},
|
|
&model.LoginLog{},
|
|
&model.ShareRecord{},
|
|
&model.AuditConn{},
|
|
&model.AuditFile{},
|
|
&model.AddressBookCollection{},
|
|
&model.AddressBookCollectionRule{},
|
|
&model.ServerCmd{},
|
|
&model.DeviceGroup{},
|
|
)
|
|
if err != nil {
|
|
global.Logger.Error("migrate err :=>", err)
|
|
}
|
|
global.DB.Create(&model.Version{Version: version})
|
|
//如果是初次则创建一个默认用户
|
|
var vc int64
|
|
global.DB.Model(&model.Version{}).Count(&vc)
|
|
if vc == 1 {
|
|
localizer := global.Localizer("")
|
|
defaultGroup, _ := localizer.LocalizeMessage(&i18n.Message{
|
|
ID: "DefaultGroup",
|
|
})
|
|
group := &model.Group{
|
|
Name: defaultGroup,
|
|
Type: model.GroupTypeDefault,
|
|
}
|
|
service.AllService.GroupService.Create(group)
|
|
|
|
shareGroup, _ := localizer.LocalizeMessage(&i18n.Message{
|
|
ID: "ShareGroup",
|
|
})
|
|
groupShare := &model.Group{
|
|
Name: shareGroup,
|
|
Type: model.GroupTypeShare,
|
|
}
|
|
service.AllService.GroupService.Create(groupShare)
|
|
//是true
|
|
is_admin := true
|
|
admin := &model.User{
|
|
Username: "admin",
|
|
Nickname: "Admin",
|
|
Status: model.COMMON_STATUS_ENABLE,
|
|
IsAdmin: &is_admin,
|
|
GroupId: 1,
|
|
}
|
|
|
|
// 生成随机密码
|
|
pwd := utils.RandomString(8)
|
|
global.Logger.Info("Admin Password Is: ", pwd)
|
|
var err error
|
|
admin.Password, err = utils.EncryptPassword(pwd)
|
|
if err != nil {
|
|
global.Logger.Fatalf("failed to generate admin password: %v", err)
|
|
}
|
|
global.DB.Create(admin)
|
|
}
|
|
|
|
}
|