package main import ( "fmt" "os" "strconv" "time" "github.com/go-redis/redis/v8" "github.com/lejianwen/rustdesk-api/v2/config" "github.com/lejianwen/rustdesk-api/v2/global" "github.com/lejianwen/rustdesk-api/v2/http" "github.com/lejianwen/rustdesk-api/v2/lib/cache" "github.com/lejianwen/rustdesk-api/v2/lib/jwt" "github.com/lejianwen/rustdesk-api/v2/lib/lock" "github.com/lejianwen/rustdesk-api/v2/lib/logger" "github.com/lejianwen/rustdesk-api/v2/lib/orm" "github.com/lejianwen/rustdesk-api/v2/lib/upload" "github.com/lejianwen/rustdesk-api/v2/model" "github.com/lejianwen/rustdesk-api/v2/service" "github.com/lejianwen/rustdesk-api/v2/utils" "github.com/nicksnyder/go-i18n/v2/i18n" "github.com/spf13/cobra" ) const DatabaseVersion = 265 // @title 管理系统API // @version 1.0 // @description 接口 // @basePath /api // @securityDefinitions.apikey token // @in header // @name api-token // @securitydefinitions.apikey BearerAuth // @in header // @name Authorization var rootCmd = &cobra.Command{ Use: "apimain", Short: "RUSTDESK API SERVER", PersistentPreRun: func(cmd *cobra.Command, args []string) { InitGlobal() }, Run: func(cmd *cobra.Command, args []string) { global.Logger.Info("API SERVER START") http.ApiInit() }, } var resetPwdCmd = &cobra.Command{ Use: "reset-admin-pwd [pwd]", Example: "reset-admin-pwd 123456", Short: "Reset Admin Password", Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, args []string) { pwd := args[0] admin := service.AllService.UserService.InfoById(1) if admin.Id == 0 { global.Logger.Warn("user not found! ") return } err := service.AllService.UserService.UpdatePassword(admin, pwd) if err != nil { global.Logger.Error("reset password fail! ", err) return } global.Logger.Info("reset password success! ") }, } var resetUserPwdCmd = &cobra.Command{ Use: "reset-pwd [userId] [pwd]", Example: "reset-pwd 2 123456", Short: "Reset User Password", Args: cobra.ExactArgs(2), Run: func(cmd *cobra.Command, args []string) { userId := args[0] pwd := args[1] uid, err := strconv.Atoi(userId) if err != nil { global.Logger.Warn("userId must be int!") return } if uid <= 0 { global.Logger.Warn("userId must be greater than 0! ") return } u := service.AllService.UserService.InfoById(uint(uid)) if u.Id == 0 { global.Logger.Warn("user not found! ") return } err = service.AllService.UserService.UpdatePassword(u, pwd) if err != nil { global.Logger.Warn("reset password fail! ", err) return } global.Logger.Info("reset password success!") }, } func init() { rootCmd.PersistentFlags().StringVarP(&global.ConfigPath, "config", "c", "./conf/config.yaml", "choose config file") rootCmd.AddCommand(resetPwdCmd, resetUserPwdCmd) } func main() { if err := rootCmd.Execute(); err != nil { global.Logger.Error(err) os.Exit(1) } } func InitGlobal() { //配置解析 global.Viper = config.Init(&global.Config, global.ConfigPath) //日志 global.Logger = logger.New(&logger.Config{ Path: global.Config.Logger.Path, Level: global.Config.Logger.Level, ReportCaller: global.Config.Logger.ReportCaller, }) global.InitI18n() //redis global.Redis = redis.NewClient(&redis.Options{ Addr: global.Config.Redis.Addr, Password: global.Config.Redis.Password, DB: global.Config.Redis.Db, }) //cache if global.Config.Cache.Type == cache.TypeFile { fc := cache.NewFileCache() fc.SetDir(global.Config.Cache.FileDir) global.Cache = fc } else if global.Config.Cache.Type == cache.TypeRedis { global.Cache = cache.NewRedis(&redis.Options{ Addr: global.Config.Cache.RedisAddr, Password: global.Config.Cache.RedisPwd, DB: global.Config.Cache.RedisDb, }) } //gorm if global.Config.Gorm.Type == config.TypeMysql { dsn := fmt.Sprintf("%s:%s@(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&tls=%s", global.Config.Mysql.Username, global.Config.Mysql.Password, global.Config.Mysql.Addr, global.Config.Mysql.Dbname, global.Config.Mysql.Tls, ) global.DB = orm.NewMysql(&orm.MysqlConfig{ Dsn: dsn, MaxIdleConns: global.Config.Gorm.MaxIdleConns, MaxOpenConns: global.Config.Gorm.MaxOpenConns, }, global.Logger) } else if global.Config.Gorm.Type == config.TypePostgresql { dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s TimeZone=%s", global.Config.Postgresql.Host, global.Config.Postgresql.Port, global.Config.Postgresql.User, global.Config.Postgresql.Password, global.Config.Postgresql.Dbname, global.Config.Postgresql.Sslmode, global.Config.Postgresql.TimeZone, ) global.DB = orm.NewPostgresql(&orm.PostgresqlConfig{ Dsn: dsn, MaxIdleConns: global.Config.Gorm.MaxIdleConns, MaxOpenConns: global.Config.Gorm.MaxOpenConns, }, global.Logger) } else { //sqlite global.DB = orm.NewSqlite(&orm.SqliteConfig{ MaxIdleConns: global.Config.Gorm.MaxIdleConns, MaxOpenConns: global.Config.Gorm.MaxOpenConns, }, global.Logger) } //validator global.ApiInitValidator() //oss global.Oss = &upload.Oss{ AccessKeyId: global.Config.Oss.AccessKeyId, AccessKeySecret: global.Config.Oss.AccessKeySecret, Host: global.Config.Oss.Host, CallbackUrl: global.Config.Oss.CallbackUrl, ExpireTime: global.Config.Oss.ExpireTime, MaxByte: global.Config.Oss.MaxByte, } //jwt //fmt.Println(global.Config.Jwt.PrivateKey) global.Jwt = jwt.NewJwt(global.Config.Jwt.Key, global.Config.Jwt.ExpireDuration) //locker global.Lock = lock.NewLocal() //service service.New(&global.Config, global.DB, global.Logger, global.Jwt, global.Lock) global.LoginLimiter = utils.NewLoginLimiter(utils.SecurityPolicy{ CaptchaThreshold: global.Config.App.CaptchaThreshold, BanThreshold: global.Config.App.BanThreshold, AttemptsWindow: 10 * time.Minute, BanDuration: 30 * time.Minute, }) global.LoginLimiter.RegisterProvider(utils.B64StringCaptchaProvider{}) DatabaseAutoUpdate() } func DatabaseAutoUpdate() { version := DatabaseVersion db := global.DB if global.Config.Gorm.Type == config.TypeMysql { //检查存不存在数据库,不存在则创建 dbName := db.Migrator().CurrentDatabase() if dbName == "" { dbName = global.Config.Mysql.Dbname // 移除 DSN 中的数据库名称,以便初始连接时不指定数据库 dsnWithoutDB := fmt.Sprintf("%s:%s@(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", global.Config.Mysql.Username, global.Config.Mysql.Password, global.Config.Mysql.Addr, "", ) //新链接 dbWithoutDB := orm.NewMysql(&orm.MysqlConfig{ Dsn: dsnWithoutDB, }, global.Logger) // 获取底层的 *sql.DB 对象,并确保在程序退出时关闭连接 sqlDBWithoutDB, err := dbWithoutDB.DB() if err != nil { global.Logger.Errorf("获取底层 *sql.DB 对象失败: %v", err) return } defer func() { if err := sqlDBWithoutDB.Close(); err != nil { global.Logger.Errorf("关闭连接失败: %v", err) } }() err = dbWithoutDB.Exec("CREATE DATABASE IF NOT EXISTS " + dbName + " DEFAULT CHARSET utf8mb4").Error if err != nil { global.Logger.Error(err) return } } } if !db.Migrator().HasTable(&model.Version{}) { Migrate(uint(version)) } else { //查找最后一个version var v model.Version db.Last(&v) if v.Version < uint(version) { Migrate(uint(version)) } // 245迁移 if v.Version < 245 { //oauths 表的 oauth_type 字段设置为 op同样的值 db.Exec("update oauths set oauth_type = op") db.Exec("update oauths set issuer = 'https://accounts.google.com' where op = 'google'") db.Exec("update user_thirds set oauth_type = third_type, op = third_type") //通过email迁移旧的google授权 uts := make([]model.UserThird, 0) db.Where("oauth_type = ?", "google").Find(&uts) for _, ut := range uts { if ut.UserId > 0 { db.Model(&model.User{}).Where("id = ?", ut.UserId).Update("email", ut.OpenId) } } } if v.Version < 246 { db.Exec("update oauths set issuer = 'https://accounts.google.com' where op = 'google' and issuer is null") } } } func Migrate(version uint) { global.Logger.Info("Migrating....", version) err := global.DB.AutoMigrate( &model.Version{}, &model.User{}, &model.UserToken{}, &model.Tag{}, &model.AddressBook{}, &model.Peer{}, &model.Group{}, &model.UserThird{}, &model.Oauth{}, &model.LoginLog{}, &model.ShareRecord{}, &model.AuditConn{}, &model.AuditFile{}, &model.AddressBookCollection{}, &model.AddressBookCollectionRule{}, &model.ServerCmd{}, &model.DeviceGroup{}, ) if err != nil { global.Logger.Error("migrate err :=>", err) } global.DB.Create(&model.Version{Version: version}) //如果是初次则创建一个默认用户 var vc int64 global.DB.Model(&model.Version{}).Count(&vc) if vc == 1 { localizer := global.Localizer("") defaultGroup, _ := localizer.LocalizeMessage(&i18n.Message{ ID: "DefaultGroup", }) group := &model.Group{ Name: defaultGroup, Type: model.GroupTypeDefault, } service.AllService.GroupService.Create(group) shareGroup, _ := localizer.LocalizeMessage(&i18n.Message{ ID: "ShareGroup", }) groupShare := &model.Group{ Name: shareGroup, Type: model.GroupTypeShare, } service.AllService.GroupService.Create(groupShare) //是true is_admin := true admin := &model.User{ Username: "admin", Nickname: "Admin", Status: model.COMMON_STATUS_ENABLE, IsAdmin: &is_admin, GroupId: 1, } // 生成随机密码 pwd := utils.RandomString(8) global.Logger.Info("Admin Password Is: ", pwd) var err error admin.Password, err = utils.EncryptPassword(pwd) if err != nil { global.Logger.Fatalf("failed to generate admin password: %v", err) } global.DB.Create(admin) } }