diff --git a/cmd/apimain.go b/cmd/apimain.go index e5d73d3..c0488dc 100644 --- a/cmd/apimain.go +++ b/cmd/apimain.go @@ -164,6 +164,9 @@ func InitGlobal() { global.Jwt = jwt.NewJwt(global.Config.Jwt.Key, global.Config.Jwt.ExpireDuration) //locker global.Lock = lock.NewLocal() + + //service + service.New(&global.Config, global.DB, global.Logger, global.Jwt, global.Lock) } func DatabaseAutoUpdate() { version := 262 diff --git a/service/addressBook.go b/service/addressBook.go index f6f88d9..dcd01c4 100644 --- a/service/addressBook.go +++ b/service/addressBook.go @@ -3,7 +3,6 @@ package service import ( "encoding/json" "github.com/google/uuid" - "github.com/lejianwen/rustdesk-api/v2/global" "github.com/lejianwen/rustdesk-api/v2/model" "gorm.io/gorm" "strings" @@ -14,24 +13,24 @@ type AddressBookService struct { func (s *AddressBookService) Info(id string) *model.AddressBook { p := &model.AddressBook{} - global.DB.Where("id = ?", id).First(p) + DB.Where("id = ?", id).First(p) return p } func (s *AddressBookService) InfoByUserIdAndId(userid uint, id string) *model.AddressBook { p := &model.AddressBook{} - global.DB.Where("user_id = ? and id = ?", userid, id).First(p) + DB.Where("user_id = ? and id = ?", userid, id).First(p) return p } func (s *AddressBookService) InfoByUserIdAndIdAndCid(userid uint, id string, cid uint) *model.AddressBook { p := &model.AddressBook{} - global.DB.Where("user_id = ? and id = ? and collection_id = ?", userid, id, cid).First(p) + DB.Where("user_id = ? and id = ? and collection_id = ?", userid, id, cid).First(p) return p } func (s *AddressBookService) InfoByRowId(id uint) *model.AddressBook { p := &model.AddressBook{} - global.DB.Where("row_id = ?", id).First(p) + DB.Where("row_id = ?", id).First(p) return p } func (s *AddressBookService) ListByUserId(userId, page, pageSize uint) (res *model.AddressBookList) { @@ -49,14 +48,14 @@ func (s *AddressBookService) ListByUserIds(userIds []uint, page, pageSize uint) // AddAddressBook func (s *AddressBookService) AddAddressBook(ab *model.AddressBook) error { - return global.DB.Create(ab).Error + return DB.Create(ab).Error } // UpdateAddressBook func (s *AddressBookService) UpdateAddressBook(abs []*model.AddressBook, userId uint) error { //比较peers和数据库中的数据,如果peers中的数据在数据库中不存在,则添加,如果存在则更新,如果数据库中的数据在peers中不存在,则删除 // 开始事务 - tx := global.DB.Begin() + tx := DB.Begin() //1. 获取数据库中的数据 var dbABs []*model.AddressBook tx.Where("user_id = ?", userId).Find(&dbABs) @@ -107,7 +106,7 @@ func (s *AddressBookService) List(page, pageSize uint, where func(tx *gorm.DB)) res = &model.AddressBookList{} res.Page = int64(page) res.PageSize = int64(pageSize) - tx := global.DB.Model(&model.AddressBook{}) + tx := DB.Model(&model.AddressBook{}) if where != nil { where(tx) } @@ -129,38 +128,38 @@ func (s *AddressBookService) FromPeer(peer *model.Peer) (a *model.AddressBook) { // Create 创建 func (s *AddressBookService) Create(u *model.AddressBook) error { - res := global.DB.Create(u).Error + res := DB.Create(u).Error return res } func (s *AddressBookService) Delete(u *model.AddressBook) error { - return global.DB.Delete(u).Error + return DB.Delete(u).Error } // Update 更新 func (s *AddressBookService) Update(u *model.AddressBook) error { - return global.DB.Model(u).Updates(u).Error + return DB.Model(u).Updates(u).Error } // UpdateByMap 更新 func (s *AddressBookService) UpdateByMap(u *model.AddressBook, data map[string]interface{}) error { - return global.DB.Model(u).Updates(data).Error + return DB.Model(u).Updates(data).Error } // UpdateAll 更新 func (s *AddressBookService) UpdateAll(u *model.AddressBook) error { - return global.DB.Model(u).Select("*").Omit("created_at").Updates(u).Error + return DB.Model(u).Select("*").Omit("created_at").Updates(u).Error } // ShareByWebClient 分享 func (s *AddressBookService) ShareByWebClient(m *model.ShareRecord) error { m.ShareToken = uuid.New().String() - return global.DB.Create(m).Error + return DB.Create(m).Error } // SharedPeer func (s *AddressBookService) SharedPeer(shareToken string) *model.ShareRecord { m := &model.ShareRecord{} - global.DB.Where("share_token = ?", shareToken).First(m) + DB.Where("share_token = ?", shareToken).First(m) return m } @@ -190,7 +189,7 @@ func (s *AddressBookService) ListCollection(page, pageSize uint, where func(tx * res = &model.AddressBookCollectionList{} res.Page = int64(page) res.PageSize = int64(pageSize) - tx := global.DB.Model(&model.AddressBookCollection{}) + tx := DB.Model(&model.AddressBookCollection{}) if where != nil { where(tx) } @@ -200,7 +199,7 @@ func (s *AddressBookService) ListCollection(page, pageSize uint, where func(tx * return } func (s *AddressBookService) ListCollectionByIds(ids []uint) (res []*model.AddressBookCollection) { - global.DB.Where("id in ?", ids).Find(&res) + DB.Where("id in ?", ids).Find(&res) return res } @@ -212,20 +211,20 @@ func (s *AddressBookService) ListCollectionByUserId(userId uint) (res *model.Add } func (s *AddressBookService) CollectionInfoById(id uint) *model.AddressBookCollection { p := &model.AddressBookCollection{} - global.DB.Where("id = ?", id).First(p) + DB.Where("id = ?", id).First(p) return p } func (s *AddressBookService) CollectionReadRules(user *model.User) (res []*model.AddressBookCollectionRule) { // personalRules var personalRules []*model.AddressBookCollectionRule - tx2 := global.DB.Model(&model.AddressBookCollectionRule{}) + tx2 := DB.Model(&model.AddressBookCollectionRule{}) tx2.Where("type = ? and to_id = ? and rule > 0", model.ShareAddressBookRuleTypePersonal, user.Id).Find(&personalRules) res = append(res, personalRules...) //group var groupRules []*model.AddressBookCollectionRule - tx3 := global.DB.Model(&model.AddressBookCollectionRule{}) + tx3 := DB.Model(&model.AddressBookCollectionRule{}) tx3.Where("type = ? and to_id = ? and rule > 0", model.ShareAddressBookRuleTypeGroup, user.GroupId).Find(&groupRules) res = append(res, groupRules...) return @@ -238,7 +237,7 @@ func (s *AddressBookService) UserMaxRule(user *model.User, uid, cid uint) int { } max := 0 personalRules := &model.AddressBookCollectionRule{} - tx := global.DB.Model(personalRules) + tx := DB.Model(personalRules) tx.Where("type = ? and collection_id = ? and to_id = ?", model.ShareAddressBookRuleTypePersonal, cid, user.Id).First(&personalRules) if personalRules.Id != 0 { max = personalRules.Rule @@ -248,7 +247,7 @@ func (s *AddressBookService) UserMaxRule(user *model.User, uid, cid uint) int { } groupRules := &model.AddressBookCollectionRule{} - tx2 := global.DB.Model(groupRules) + tx2 := DB.Model(groupRules) tx2.Where("type = ? and collection_id = ? and to_id = ?", model.ShareAddressBookRuleTypeGroup, cid, user.GroupId).First(&groupRules) if groupRules.Id != 0 { if groupRules.Rule > max { @@ -272,16 +271,16 @@ func (s *AddressBookService) CheckUserFullControlPrivilege(user *model.User, uid } func (s *AddressBookService) CreateCollection(t *model.AddressBookCollection) error { - return global.DB.Create(t).Error + return DB.Create(t).Error } func (s *AddressBookService) UpdateCollection(t *model.AddressBookCollection) error { - return global.DB.Model(t).Updates(t).Error + return DB.Model(t).Updates(t).Error } func (s *AddressBookService) DeleteCollection(t *model.AddressBookCollection) error { //删除集合下的所有规则、地址簿,再删除集合 - tx := global.DB.Begin() + tx := DB.Begin() tx.Where("collection_id = ?", t.Id).Delete(&model.AddressBookCollectionRule{}) tx.Where("collection_id = ?", t.Id).Delete(&model.AddressBook{}) tx.Delete(t) @@ -290,23 +289,23 @@ func (s *AddressBookService) DeleteCollection(t *model.AddressBookCollection) er func (s *AddressBookService) RuleInfoById(u uint) *model.AddressBookCollectionRule { p := &model.AddressBookCollectionRule{} - global.DB.Where("id = ?", u).First(p) + DB.Where("id = ?", u).First(p) return p } func (s *AddressBookService) RulePersonalInfoByToIdAndCid(toid, cid uint) *model.AddressBookCollectionRule { p := &model.AddressBookCollectionRule{} - global.DB.Where("type = ? and to_id = ? and collection_id = ?", model.ShareAddressBookRuleTypePersonal, toid, cid).First(p) + DB.Where("type = ? and to_id = ? and collection_id = ?", model.ShareAddressBookRuleTypePersonal, toid, cid).First(p) return p } func (s *AddressBookService) CreateRule(t *model.AddressBookCollectionRule) error { - return global.DB.Create(t).Error + return DB.Create(t).Error } func (s *AddressBookService) ListRules(page uint, size uint, f func(tx *gorm.DB)) *model.AddressBookCollectionRuleList { res := &model.AddressBookCollectionRuleList{} res.Page = int64(page) res.PageSize = int64(size) - tx := global.DB.Model(&model.AddressBookCollectionRule{}) + tx := DB.Model(&model.AddressBookCollectionRule{}) if f != nil { f(tx) } @@ -317,11 +316,11 @@ func (s *AddressBookService) ListRules(page uint, size uint, f func(tx *gorm.DB) } func (s *AddressBookService) UpdateRule(t *model.AddressBookCollectionRule) error { - return global.DB.Model(t).Updates(t).Error + return DB.Model(t).Updates(t).Error } func (s *AddressBookService) DeleteRule(t *model.AddressBookCollectionRule) error { - return global.DB.Delete(t).Error + return DB.Delete(t).Error } // CheckCollectionOwner 检查Collection的所有者 @@ -336,5 +335,5 @@ func (s *AddressBookService) BatchUpdateTags(abs []*model.AddressBook, tags []st ids = append(ids, ab.RowId) } tagsv, _ := json.Marshal(tags) - return global.DB.Model(&model.AddressBook{}).Where("row_id in ?", ids).Update("tags", tagsv).Error + return DB.Model(&model.AddressBook{}).Where("row_id in ?", ids).Update("tags", tagsv).Error } diff --git a/service/audit.go b/service/audit.go index acc3d80..70b110a 100644 --- a/service/audit.go +++ b/service/audit.go @@ -1,7 +1,6 @@ package service import ( - "github.com/lejianwen/rustdesk-api/v2/global" "github.com/lejianwen/rustdesk-api/v2/model" "gorm.io/gorm" ) @@ -13,7 +12,7 @@ func (as *AuditService) AuditConnList(page, pageSize uint, where func(tx *gorm.D res = &model.AuditConnList{} res.Page = int64(page) res.PageSize = int64(pageSize) - tx := global.DB.Model(&model.AuditConn{}) + tx := DB.Model(&model.AuditConn{}) if where != nil { where(tx) } @@ -25,36 +24,36 @@ func (as *AuditService) AuditConnList(page, pageSize uint, where func(tx *gorm.D // Create 创建 func (as *AuditService) CreateAuditConn(u *model.AuditConn) error { - res := global.DB.Create(u).Error + res := DB.Create(u).Error return res } func (as *AuditService) DeleteAuditConn(u *model.AuditConn) error { - return global.DB.Delete(u).Error + return DB.Delete(u).Error } // Update 更新 func (as *AuditService) UpdateAuditConn(u *model.AuditConn) error { - return global.DB.Model(u).Updates(u).Error + return DB.Model(u).Updates(u).Error } // InfoByPeerIdAndConnId func (as *AuditService) InfoByPeerIdAndConnId(peerId string, connId int64) (res *model.AuditConn) { res = &model.AuditConn{} - global.DB.Where("peer_id = ? and conn_id = ?", peerId, connId).First(res) + DB.Where("peer_id = ? and conn_id = ?", peerId, connId).First(res) return } // ConnInfoById func (as *AuditService) ConnInfoById(id uint) (res *model.AuditConn) { res = &model.AuditConn{} - global.DB.Where("id = ?", id).First(res) + DB.Where("id = ?", id).First(res) return } // FileInfoById func (as *AuditService) FileInfoById(id uint) (res *model.AuditFile) { res = &model.AuditFile{} - global.DB.Where("id = ?", id).First(res) + DB.Where("id = ?", id).First(res) return } @@ -62,7 +61,7 @@ func (as *AuditService) AuditFileList(page, pageSize uint, where func(tx *gorm.D res = &model.AuditFileList{} res.Page = int64(page) res.PageSize = int64(pageSize) - tx := global.DB.Model(&model.AuditFile{}) + tx := DB.Model(&model.AuditFile{}) if where != nil { where(tx) } @@ -74,22 +73,22 @@ func (as *AuditService) AuditFileList(page, pageSize uint, where func(tx *gorm.D // CreateAuditFile func (as *AuditService) CreateAuditFile(u *model.AuditFile) error { - res := global.DB.Create(u).Error + res := DB.Create(u).Error return res } func (as *AuditService) DeleteAuditFile(u *model.AuditFile) error { - return global.DB.Delete(u).Error + return DB.Delete(u).Error } // Update 更新 func (as *AuditService) UpdateAuditFile(u *model.AuditFile) error { - return global.DB.Model(u).Updates(u).Error + return DB.Model(u).Updates(u).Error } func (as *AuditService) BatchDeleteAuditConn(ids []uint) error { - return global.DB.Where("id in (?)", ids).Delete(&model.AuditConn{}).Error + return DB.Where("id in (?)", ids).Delete(&model.AuditConn{}).Error } func (as *AuditService) BatchDeleteAuditFile(ids []uint) error { - return global.DB.Where("id in (?)", ids).Delete(&model.AuditFile{}).Error + return DB.Where("id in (?)", ids).Delete(&model.AuditFile{}).Error } diff --git a/service/group.go b/service/group.go index 938bab0..001ea97 100644 --- a/service/group.go +++ b/service/group.go @@ -1,7 +1,6 @@ package service import ( - "github.com/lejianwen/rustdesk-api/v2/global" "github.com/lejianwen/rustdesk-api/v2/model" "gorm.io/gorm" ) @@ -12,7 +11,7 @@ type GroupService struct { // InfoById 根据用户id取用户信息 func (us *GroupService) InfoById(id uint) *model.Group { u := &model.Group{} - global.DB.Where("id = ?", id).First(u) + DB.Where("id = ?", id).First(u) return u } @@ -20,7 +19,7 @@ func (us *GroupService) List(page, pageSize uint, where func(tx *gorm.DB)) (res res = &model.GroupList{} res.Page = int64(page) res.PageSize = int64(pageSize) - tx := global.DB.Model(&model.Group{}) + tx := DB.Model(&model.Group{}) if where != nil { where(tx) } @@ -32,22 +31,22 @@ func (us *GroupService) List(page, pageSize uint, where func(tx *gorm.DB)) (res // Create 创建 func (us *GroupService) Create(u *model.Group) error { - res := global.DB.Create(u).Error + res := DB.Create(u).Error return res } func (us *GroupService) Delete(u *model.Group) error { - return global.DB.Delete(u).Error + return DB.Delete(u).Error } // Update 更新 func (us *GroupService) Update(u *model.Group) error { - return global.DB.Model(u).Updates(u).Error + return DB.Model(u).Updates(u).Error } // DeviceGroupInfoById 根据用户id取用户信息 func (us *GroupService) DeviceGroupInfoById(id uint) *model.DeviceGroup { u := &model.DeviceGroup{} - global.DB.Where("id = ?", id).First(u) + DB.Where("id = ?", id).First(u) return u } @@ -55,7 +54,7 @@ func (us *GroupService) DeviceGroupList(page, pageSize uint, where func(tx *gorm res = &model.DeviceGroupList{} res.Page = int64(page) res.PageSize = int64(pageSize) - tx := global.DB.Model(&model.DeviceGroup{}) + tx := DB.Model(&model.DeviceGroup{}) if where != nil { where(tx) } @@ -66,13 +65,13 @@ func (us *GroupService) DeviceGroupList(page, pageSize uint, where func(tx *gorm } func (us *GroupService) DeviceGroupCreate(u *model.DeviceGroup) error { - res := global.DB.Create(u).Error + res := DB.Create(u).Error return res } func (us *GroupService) DeviceGroupDelete(u *model.DeviceGroup) error { - return global.DB.Delete(u).Error + return DB.Delete(u).Error } func (us *GroupService) DeviceGroupUpdate(u *model.DeviceGroup) error { - return global.DB.Model(u).Updates(u).Error + return DB.Model(u).Updates(u).Error } diff --git a/service/ldap.go b/service/ldap.go index 337b6b5..6d12646 100644 --- a/service/ldap.go +++ b/service/ldap.go @@ -10,7 +10,6 @@ import ( "github.com/go-ldap/ldap/v3" "github.com/lejianwen/rustdesk-api/v2/config" - "github.com/lejianwen/rustdesk-api/v2/global" "github.com/lejianwen/rustdesk-api/v2/model" ) @@ -114,7 +113,7 @@ func (ls *LdapService) Authenticate(username, password string) (*model.User, err if !ldapUser.Enabled { return nil, ErrLdapUserDisabled } - cfg := &global.Config.Ldap + cfg := &Config.Ldap user, err := ls.mapToLocalUser(cfg, ldapUser) if err != nil { return nil, errors.Join(ErrLdapToLocalUserFailed, err) @@ -135,7 +134,7 @@ func (ls *LdapService) mapToLocalUser(cfg *config.Ldap, lu *LdapUser) (*model.Us // If needed, you can set a random password here. newUser.IsAdmin = &isAdmin newUser.GroupId = 1 - if err := global.DB.Create(newUser).Error; err != nil { + if err := DB.Create(newUser).Error; err != nil { return nil, errors.Join(ErrLdapCreateUserFailed, err) } return userService.InfoByUsername(lu.Username), nil @@ -164,7 +163,7 @@ func (ls *LdapService) mapToLocalUser(cfg *config.Ldap, lu *LdapUser) (*model.Us // IsUsernameExists checks if a username exists in LDAP (can be useful for local registration checks). func (ls *LdapService) IsUsernameExists(username string) bool { - cfg := &global.Config.Ldap + cfg := &Config.Ldap if !cfg.Enable { return false } @@ -177,7 +176,7 @@ func (ls *LdapService) IsUsernameExists(username string) bool { // IsEmailExists checks if an email exists in LDAP (can be useful for local registration checks). func (ls *LdapService) IsEmailExists(email string) bool { - cfg := &global.Config.Ldap + cfg := &Config.Ldap if !cfg.Enable { return false } @@ -190,7 +189,7 @@ func (ls *LdapService) IsEmailExists(email string) bool { // GetUserInfoByUsernameLdap returns the user info from LDAP for the given username. func (ls *LdapService) GetUserInfoByUsernameLdap(username string) (*LdapUser, error) { - cfg := &global.Config.Ldap + cfg := &Config.Ldap if !cfg.Enable { return nil, ErrLdapNotEnabled } @@ -210,12 +209,12 @@ func (ls *LdapService) GetUserInfoByUsernameLocal(username string) (*model.User, if err != nil { return &model.User{}, err } - return ls.mapToLocalUser(&global.Config.Ldap, ldapUser) + return ls.mapToLocalUser(&Config.Ldap, ldapUser) } // GetUserInfoByEmailLdap returns the user info from LDAP for the given email. func (ls *LdapService) GetUserInfoByEmailLdap(email string) (*LdapUser, error) { - cfg := &global.Config.Ldap + cfg := &Config.Ldap if !cfg.Enable { return nil, ErrLdapNotEnabled } @@ -235,7 +234,7 @@ func (ls *LdapService) GetUserInfoByEmailLocal(email string) (*model.User, error if err != nil { return &model.User{}, err } - return ls.mapToLocalUser(&global.Config.Ldap, ldapUser) + return ls.mapToLocalUser(&Config.Ldap, ldapUser) } // usernameSearchResult returns the search result for the given username. @@ -453,12 +452,12 @@ func (ls *LdapService) isUserEnabled(cfg *config.Ldap, ldapUser *LdapUser) bool // Account is disabled if the ACCOUNTDISABLE flag (0x2) is set const ACCOUNTDISABLE = 0x2 - ldapUser.Enabled = (userAccountControl&ACCOUNTDISABLE == 0) + ldapUser.Enabled = userAccountControl&ACCOUNTDISABLE == 0 return ldapUser.Enabled } // For other attributes, perform a direct comparison with the expected value - ldapUser.Enabled = (ldapUser.EnableAttrValue == enableAttrValue) + ldapUser.Enabled = ldapUser.EnableAttrValue == enableAttrValue return ldapUser.Enabled } diff --git a/service/loginLog.go b/service/loginLog.go index 719aaf0..0b6711b 100644 --- a/service/loginLog.go +++ b/service/loginLog.go @@ -1,7 +1,6 @@ package service import ( - "github.com/lejianwen/rustdesk-api/v2/global" "github.com/lejianwen/rustdesk-api/v2/model" "gorm.io/gorm" ) @@ -12,7 +11,7 @@ type LoginLogService struct { // InfoById 根据用户id取用户信息 func (us *LoginLogService) InfoById(id uint) *model.LoginLog { u := &model.LoginLog{} - global.DB.Where("id = ?", id).First(u) + DB.Where("id = ?", id).First(u) return u } @@ -20,7 +19,7 @@ func (us *LoginLogService) List(page, pageSize uint, where func(tx *gorm.DB)) (r res = &model.LoginLogList{} res.Page = int64(page) res.PageSize = int64(pageSize) - tx := global.DB.Model(&model.LoginLog{}) + tx := DB.Model(&model.LoginLog{}) if where != nil { where(tx) } @@ -32,20 +31,20 @@ func (us *LoginLogService) List(page, pageSize uint, where func(tx *gorm.DB)) (r // Create 创建 func (us *LoginLogService) Create(u *model.LoginLog) error { - res := global.DB.Create(u).Error + res := DB.Create(u).Error return res } func (us *LoginLogService) Delete(u *model.LoginLog) error { - return global.DB.Delete(u).Error + return DB.Delete(u).Error } // Update 更新 func (us *LoginLogService) Update(u *model.LoginLog) error { - return global.DB.Model(u).Updates(u).Error + return DB.Model(u).Updates(u).Error } func (us *LoginLogService) BatchDelete(ids []uint) error { - return global.DB.Where("id in (?)", ids).Delete(&model.LoginLog{}).Error + return DB.Where("id in (?)", ids).Delete(&model.LoginLog{}).Error } func (us *LoginLogService) SoftDelete(l *model.LoginLog) error { @@ -54,5 +53,5 @@ func (us *LoginLogService) SoftDelete(l *model.LoginLog) error { } func (us *LoginLogService) BatchSoftDelete(uid uint, ids []uint) error { - return global.DB.Model(&model.LoginLog{}).Where("user_id = ? and id in (?)", uid, ids).Update("is_deleted", model.IsDeletedYes).Error + return DB.Model(&model.LoginLog{}).Where("user_id = ? and id in (?)", uid, ids).Update("is_deleted", model.IsDeletedYes).Error } diff --git a/service/oauth.go b/service/oauth.go index e380b87..74a8ab8 100644 --- a/service/oauth.go +++ b/service/oauth.go @@ -5,7 +5,6 @@ import ( "encoding/json" "errors" "github.com/coreos/go-oidc/v3/oidc" - "github.com/lejianwen/rustdesk-api/v2/global" "github.com/lejianwen/rustdesk-api/v2/model" "github.com/lejianwen/rustdesk-api/v2/utils" "golang.org/x/oauth2" @@ -99,7 +98,7 @@ func (os *OauthService) BeginAuth(op string) (error error, state, verifier, nonc verifier = "" nonce = "" if op == model.OauthTypeWebauth { - url = global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/" + state + url = Config.Rustdesk.ApiServer + "/_admin/#/oauth/" + state //url = "http://localhost:8888/_admin/#/oauth/" + code return nil, state, verifier, nonce, url } @@ -164,7 +163,7 @@ func (os *OauthService) GetOauthConfig(op string) (err error, oauthInfo *model.O } // If the redirect URL is empty, use the default redirect URL if oauthInfo.RedirectUrl == "" { - oauthInfo.RedirectUrl = global.Config.Rustdesk.ApiServer + "/api/oidc/callback" + oauthInfo.RedirectUrl = Config.Rustdesk.ApiServer + "/api/oidc/callback" } oauthConfig = &oauth2.Config{ ClientID: oauthInfo.ClientId, @@ -202,14 +201,14 @@ func (os *OauthService) GetOauthConfig(op string) (err error, oauthInfo *model.O func getHTTPClientWithProxy() *http.Client { //add timeout 30s timeout := time.Duration(60) * time.Second - if global.Config.Proxy.Enable { - if global.Config.Proxy.Host == "" { - global.Logger.Warn("Proxy is enabled but proxy host is empty.") + if Config.Proxy.Enable { + if Config.Proxy.Host == "" { + Logger.Warn("Proxy is enabled but proxy host is empty.") return http.DefaultClient } - proxyURL, err := url.Parse(global.Config.Proxy.Host) + proxyURL, err := url.Parse(Config.Proxy.Host) if err != nil { - global.Logger.Warn("Invalid proxy URL: ", err) + Logger.Warn("Invalid proxy URL: ", err) return http.DefaultClient } transport := &http.Transport{ @@ -233,7 +232,7 @@ func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, provider *oidc. token, err := oauthConfig.Exchange(ctx, code, exchangeOpts...) if err != nil { - global.Logger.Warn("oauthConfig.Exchange() failed: ", err) + Logger.Warn("oauthConfig.Exchange() failed: ", err) return errors.New("GetOauthTokenError"), nil } @@ -244,7 +243,7 @@ func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, provider *oidc. v := provider.Verifier(&oidc.Config{ClientID: oauthConfig.ClientID}) idToken, err2 := v.Verify(ctx, rawIDToken) if err2 != nil { - global.Logger.Warn("IdTokenVerifyError: ", err2) + Logger.Warn("IdTokenVerifyError: ", err2) return errors.New("IdTokenVerifyError"), nil } if nonce != "" { @@ -253,12 +252,12 @@ func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, provider *oidc. Nonce string `json:"nonce"` } if err2 = idToken.Claims(&claims); err2 != nil { - global.Logger.Warn("Failed to parse ID Token claims: ", err) + Logger.Warn("Failed to parse ID Token claims: ", err) return errors.New("IDTokenClaimsError"), nil } if claims.Nonce != nonce { - global.Logger.Warn("Nonce does not match") + Logger.Warn("Nonce does not match") return errors.New("NonceDoesNotMatch"), nil } } @@ -268,18 +267,18 @@ func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, provider *oidc. client = oauthConfig.Client(ctx, token) resp, err := client.Get(provider.UserInfoEndpoint()) if err != nil { - global.Logger.Warn("failed getting user info: ", err) + Logger.Warn("failed getting user info: ", err) return errors.New("GetOauthUserInfoError"), nil } defer func() { if closeErr := resp.Body.Close(); closeErr != nil { - global.Logger.Warn("failed closing response body: ", closeErr) + Logger.Warn("failed closing response body: ", closeErr) } }() // 解析用户信息 if err = json.NewDecoder(resp.Body).Decode(userData); err != nil { - global.Logger.Warn("failed decoding user info: ", err) + Logger.Warn("failed decoding user info: ", err) return errors.New("DecodeOauthUserInfoError"), nil } @@ -330,7 +329,7 @@ func (os *OauthService) Callback(code, verifier, op, nonce string) (err error, o func (os *OauthService) UserThirdInfo(op string, openId string) *model.UserThird { ut := &model.UserThird{} - global.DB.Where("open_id = ? and op = ?", openId, op).First(ut) + DB.Where("open_id = ? and op = ?", openId, op).First(ut) return ut } @@ -342,7 +341,7 @@ func (os *OauthService) BindOauthUser(userId uint, oauthUser *model.OauthUser, o return err } utr.FromOauthUser(userId, oauthUser, oauthType, op) - return global.DB.Create(utr).Error + return DB.Create(utr).Error } // UnBindOauthUser: Unbind third party account @@ -352,25 +351,25 @@ func (os *OauthService) UnBindOauthUser(userId uint, op string) error { // UnBindThird: Unbind third party account func (os *OauthService) UnBindThird(op string, userId uint) error { - return global.DB.Where("user_id = ? and op = ?", userId, op).Delete(&model.UserThird{}).Error + return DB.Where("user_id = ? and op = ?", userId, op).Delete(&model.UserThird{}).Error } // DeleteUserByUserId: When user is deleted, delete all third party bindings func (os *OauthService) DeleteUserByUserId(userId uint) error { - return global.DB.Where("user_id = ?", userId).Delete(&model.UserThird{}).Error + return DB.Where("user_id = ?", userId).Delete(&model.UserThird{}).Error } // InfoById 根据id获取Oauth信息 func (os *OauthService) InfoById(id uint) *model.Oauth { oauthInfo := &model.Oauth{} - global.DB.Where("id = ?", id).First(oauthInfo) + DB.Where("id = ?", id).First(oauthInfo) return oauthInfo } // InfoByOp 根据op获取Oauth信息 func (os *OauthService) InfoByOp(op string) *model.Oauth { oauthInfo := &model.Oauth{} - global.DB.Where("op = ?", op).First(oauthInfo) + DB.Where("op = ?", op).First(oauthInfo) return oauthInfo } @@ -393,7 +392,7 @@ func (os *OauthService) List(page, pageSize uint, where func(tx *gorm.DB)) (res res = &model.OauthList{} res.Page = int64(page) res.PageSize = int64(pageSize) - tx := global.DB.Model(&model.Oauth{}) + tx := DB.Model(&model.Oauth{}) if where != nil { where(tx) } @@ -406,7 +405,7 @@ func (os *OauthService) List(page, pageSize uint, where func(tx *gorm.DB)) (res // GetTypeByOp 根据op获取OauthType func (os *OauthService) GetTypeByOp(op string) (error, string) { oauthInfo := &model.Oauth{} - if global.DB.Where("op = ?", op).First(oauthInfo).Error != nil { + if DB.Where("op = ?", op).First(oauthInfo).Error != nil { return fmt.Errorf("OAuth provider with op '%s' not found", op), "" } return nil, oauthInfo.OauthType @@ -424,7 +423,7 @@ func (os *OauthService) ValidateOauthProvider(op string) error { func (os *OauthService) IsOauthProviderExist(op string) bool { oauthInfo := &model.Oauth{} // 使用 Gorm 的 Take 方法查找符合条件的记录 - if err := global.DB.Where("op = ?", op).Take(oauthInfo).Error; err != nil { + if err := DB.Where("op = ?", op).Take(oauthInfo).Error; err != nil { return false } return true @@ -436,11 +435,11 @@ func (os *OauthService) Create(oauthInfo *model.Oauth) error { if err != nil { return err } - res := global.DB.Create(oauthInfo).Error + res := DB.Create(oauthInfo).Error return res } func (os *OauthService) Delete(oauthInfo *model.Oauth) error { - return global.DB.Delete(oauthInfo).Error + return DB.Delete(oauthInfo).Error } // Update 更新 @@ -449,13 +448,13 @@ func (os *OauthService) Update(oauthInfo *model.Oauth) error { if err != nil { return err } - return global.DB.Model(oauthInfo).Updates(oauthInfo).Error + return DB.Model(oauthInfo).Updates(oauthInfo).Error } // GetOauthProviders 获取所有的provider func (os *OauthService) GetOauthProviders() []string { var res []string - global.DB.Model(&model.Oauth{}).Pluck("op", &res) + DB.Model(&model.Oauth{}).Pluck("op", &res) return res } diff --git a/service/peer.go b/service/peer.go index f2f5850..1e2649d 100644 --- a/service/peer.go +++ b/service/peer.go @@ -1,7 +1,6 @@ package service import ( - "github.com/lejianwen/rustdesk-api/v2/global" "github.com/lejianwen/rustdesk-api/v2/model" "gorm.io/gorm" ) @@ -12,24 +11,24 @@ type PeerService struct { // FindById 根据id查找 func (ps *PeerService) FindById(id string) *model.Peer { p := &model.Peer{} - global.DB.Where("id = ?", id).First(p) + DB.Where("id = ?", id).First(p) return p } func (ps *PeerService) FindByUuid(uuid string) *model.Peer { p := &model.Peer{} - global.DB.Where("uuid = ?", uuid).First(p) + DB.Where("uuid = ?", uuid).First(p) return p } func (ps *PeerService) InfoByRowId(id uint) *model.Peer { p := &model.Peer{} - global.DB.Where("row_id = ?", id).First(p) + DB.Where("row_id = ?", id).First(p) return p } // FindByUserIdAndUuid 根据用户id和uuid查找peer func (ps *PeerService) FindByUserIdAndUuid(uuid string, userId uint) *model.Peer { p := &model.Peer{} - global.DB.Where("uuid = ? and user_id = ?", uuid, userId).First(p) + DB.Where("uuid = ? and user_id = ?", uuid, userId).First(p) return p } @@ -43,7 +42,7 @@ func (ps *PeerService) UuidBindUserId(deviceId string, uuid string, userId uint) } else { // 不存在则创建 /*if deviceId != "" { - global.DB.Create(&model.Peer{ + DB.Create(&model.Peer{ Id: deviceId, Uuid: uuid, UserId: userId, @@ -56,13 +55,13 @@ func (ps *PeerService) UuidBindUserId(deviceId string, uuid string, userId uint) func (ps *PeerService) UuidUnbindUserId(uuid string, userId uint) { peer := ps.FindByUserIdAndUuid(uuid, userId) if peer.RowId > 0 { - global.DB.Model(peer).Update("user_id", 0) + DB.Model(peer).Update("user_id", 0) } } // EraseUserId 清除用户id, 用于用户删除 func (ps *PeerService) EraseUserId(userId uint) error { - return global.DB.Model(&model.Peer{}).Where("user_id = ?", userId).Update("user_id", 0).Error + return DB.Model(&model.Peer{}).Where("user_id = ?", userId).Update("user_id", 0).Error } // ListByUserIds 根据用户id取列表 @@ -70,7 +69,7 @@ func (ps *PeerService) ListByUserIds(userIds []uint, page, pageSize uint) (res * res = &model.PeerList{} res.Page = int64(page) res.PageSize = int64(pageSize) - tx := global.DB.Model(&model.Peer{}) + tx := DB.Model(&model.Peer{}) tx.Where("user_id in (?)", userIds) tx.Count(&res.Total) tx.Scopes(Paginate(page, pageSize)) @@ -82,7 +81,7 @@ func (ps *PeerService) List(page, pageSize uint, where func(tx *gorm.DB)) (res * res = &model.PeerList{} res.Page = int64(page) res.PageSize = int64(pageSize) - tx := global.DB.Model(&model.Peer{}) + tx := DB.Model(&model.Peer{}) if where != nil { where(tx) } @@ -106,14 +105,14 @@ func (ps *PeerService) ListFilterByUserId(page, pageSize uint, where func(tx *go // Create 创建 func (ps *PeerService) Create(u *model.Peer) error { - res := global.DB.Create(u).Error + res := DB.Create(u).Error return res } // Delete 删除, 同时也应该删除token func (ps *PeerService) Delete(u *model.Peer) error { uuid := u.Uuid - err := global.DB.Delete(u).Error + err := DB.Delete(u).Error if err != nil { return err } @@ -124,7 +123,7 @@ func (ps *PeerService) Delete(u *model.Peer) error { // GetUuidListByIDs 根据ids获取uuid列表 func (ps *PeerService) GetUuidListByIDs(ids []uint) ([]string, error) { var uuids []string - err := global.DB.Model(&model.Peer{}). + err := DB.Model(&model.Peer{}). Where("row_id in (?)", ids). Pluck("uuid", &uuids).Error return uuids, err @@ -133,7 +132,7 @@ func (ps *PeerService) GetUuidListByIDs(ids []uint) ([]string, error) { // BatchDelete 批量删除, 同时也应该删除token func (ps *PeerService) BatchDelete(ids []uint) error { uuids, err := ps.GetUuidListByIDs(ids) - err = global.DB.Where("row_id in (?)", ids).Delete(&model.Peer{}).Error + err = DB.Where("row_id in (?)", ids).Delete(&model.Peer{}).Error if err != nil { return err } @@ -143,5 +142,5 @@ func (ps *PeerService) BatchDelete(ids []uint) error { // Update 更新 func (ps *PeerService) Update(u *model.Peer) error { - return global.DB.Model(u).Updates(u).Error + return DB.Model(u).Updates(u).Error } diff --git a/service/serverCmd.go b/service/serverCmd.go index 5681054..9f29499 100644 --- a/service/serverCmd.go +++ b/service/serverCmd.go @@ -2,7 +2,6 @@ package service import ( "fmt" - "github.com/lejianwen/rustdesk-api/v2/global" "github.com/lejianwen/rustdesk-api/v2/model" "net" "time" @@ -15,7 +14,7 @@ func (is *ServerCmdService) List(page, pageSize uint) (res *model.ServerCmdList) res = &model.ServerCmdList{} res.Page = int64(page) res.PageSize = int64(pageSize) - tx := global.DB.Model(&model.ServerCmd{}) + tx := DB.Model(&model.ServerCmd{}) tx.Count(&res.Total) tx.Scopes(Paginate(page, pageSize)) tx.Find(&res.ServerCmds) @@ -25,18 +24,18 @@ func (is *ServerCmdService) List(page, pageSize uint) (res *model.ServerCmdList) // Info func (is *ServerCmdService) Info(id uint) *model.ServerCmd { u := &model.ServerCmd{} - global.DB.Where("id = ?", id).First(u) + DB.Where("id = ?", id).First(u) return u } // Delete func (is *ServerCmdService) Delete(u *model.ServerCmd) error { - return global.DB.Delete(u).Error + return DB.Delete(u).Error } // Create func (is *ServerCmdService) Create(u *model.ServerCmd) error { - res := global.DB.Create(u).Error + res := DB.Create(u).Error return res } @@ -45,9 +44,9 @@ func (is *ServerCmdService) SendCmd(target string, cmd string, arg string) (stri port := 0 switch target { case model.ServerCmdTargetIdServer: - port = global.Config.Rustdesk.IdServerPort - 1 + port = Config.Rustdesk.IdServerPort - 1 case model.ServerCmdTargetRelayServer: - port = global.Config.Rustdesk.RelayServerPort + port = Config.Rustdesk.RelayServerPort } //组装命令 cmd = cmd + " " + arg @@ -73,14 +72,14 @@ func (is *ServerCmdService) SendSocketCmd(ty string, port int, cmd string) (stri } conn, err := net.Dial(tcp, fmt.Sprintf("%s:%v", addr, port)) if err != nil { - global.Logger.Debugf("%s connect to id server failed: %v", ty, err) + Logger.Debugf("%s connect to id server failed: %v", ty, err) return "", err } defer conn.Close() //发送命令 _, err = conn.Write([]byte(cmd)) if err != nil { - global.Logger.Debugf("%s send cmd failed: %v", ty, err) + Logger.Debugf("%s send cmd failed: %v", ty, err) return "", err } time.Sleep(100 * time.Millisecond) @@ -88,12 +87,12 @@ func (is *ServerCmdService) SendSocketCmd(ty string, port int, cmd string) (stri buf := make([]byte, 1024) n, err := conn.Read(buf) if err != nil && err.Error() != "EOF" { - global.Logger.Debugf("%s read response failed: %v", ty, err) + Logger.Debugf("%s read response failed: %v", ty, err) return "", err } return string(buf[:n]), nil } func (is *ServerCmdService) Update(f *model.ServerCmd) error { - return global.DB.Model(f).Updates(f).Error + return DB.Model(f).Updates(f).Error } diff --git a/service/service.go b/service/service.go index bfd01db..6657a38 100644 --- a/service/service.go +++ b/service/service.go @@ -1,7 +1,11 @@ package service import ( + "github.com/lejianwen/rustdesk-api/v2/config" + "github.com/lejianwen/rustdesk-api/v2/lib/jwt" + "github.com/lejianwen/rustdesk-api/v2/lib/lock" "github.com/lejianwen/rustdesk-api/v2/model" + log "github.com/sirupsen/logrus" "gorm.io/gorm" ) @@ -21,12 +25,31 @@ type Service struct { *LdapService } -func New() *Service { - all := new(Service) - return all +type Dependencies struct { + Config *config.Config + DB *gorm.DB + Logger *log.Logger + Jwt *jwt.Jwt + Lock *lock.Locker } -var AllService = New() +var Config *config.Config +var DB *gorm.DB +var Logger *log.Logger +var Jwt *jwt.Jwt +var Lock lock.Locker + +var AllService *Service + +func New(c *config.Config, g *gorm.DB, l *log.Logger, j *jwt.Jwt, lo lock.Locker) *Service { + Config = c + DB = g + Logger = l + Jwt = j + Lock = lo + AllService = new(Service) + return AllService +} func Paginate(page, pageSize uint) func(db *gorm.DB) *gorm.DB { return func(db *gorm.DB) *gorm.DB { diff --git a/service/shareRecord.go b/service/shareRecord.go index c6b5892..cfd188c 100644 --- a/service/shareRecord.go +++ b/service/shareRecord.go @@ -1,7 +1,6 @@ package service import ( - "github.com/lejianwen/rustdesk-api/v2/global" "github.com/lejianwen/rustdesk-api/v2/model" "gorm.io/gorm" ) @@ -12,7 +11,7 @@ type ShareRecordService struct { // InfoById 根据用户id取用户信息 func (srs *ShareRecordService) InfoById(id uint) *model.ShareRecord { u := &model.ShareRecord{} - global.DB.Where("id = ?", id).First(u) + DB.Where("id = ?", id).First(u) return u } @@ -20,7 +19,7 @@ func (srs *ShareRecordService) List(page, pageSize uint, where func(tx *gorm.DB) res = &model.ShareRecordList{} res.Page = int64(page) res.PageSize = int64(pageSize) - tx := global.DB.Model(&model.ShareRecord{}) + tx := DB.Model(&model.ShareRecord{}) if where != nil { where(tx) } @@ -32,18 +31,18 @@ func (srs *ShareRecordService) List(page, pageSize uint, where func(tx *gorm.DB) // Create 创建 func (srs *ShareRecordService) Create(u *model.ShareRecord) error { - res := global.DB.Create(u).Error + res := DB.Create(u).Error return res } func (srs *ShareRecordService) Delete(u *model.ShareRecord) error { - return global.DB.Delete(u).Error + return DB.Delete(u).Error } // Update 更新 func (srs *ShareRecordService) Update(u *model.ShareRecord) error { - return global.DB.Model(u).Updates(u).Error + return DB.Model(u).Updates(u).Error } func (srs *ShareRecordService) BatchDelete(ids []uint) error { - return global.DB.Where("id in (?)", ids).Delete(&model.ShareRecord{}).Error + return DB.Where("id in (?)", ids).Delete(&model.ShareRecord{}).Error } diff --git a/service/tag.go b/service/tag.go index 3443af9..eb82532 100644 --- a/service/tag.go +++ b/service/tag.go @@ -1,7 +1,6 @@ package service import ( - "github.com/lejianwen/rustdesk-api/v2/global" "github.com/lejianwen/rustdesk-api/v2/model" "gorm.io/gorm" ) @@ -11,12 +10,12 @@ type TagService struct { func (s *TagService) Info(id uint) *model.Tag { p := &model.Tag{} - global.DB.Where("id = ?", id).First(p) + DB.Where("id = ?", id).First(p) return p } func (s *TagService) InfoByUserIdAndNameAndCollectionId(userid uint, name string, cid uint) *model.Tag { p := &model.Tag{} - global.DB.Where("user_id = ? and name = ? and collection_id = ?", userid, name, cid).First(p) + DB.Where("user_id = ? and name = ? and collection_id = ?", userid, name, cid).First(p) return p } @@ -34,7 +33,7 @@ func (s *TagService) ListByUserIdAndCollectionId(userId, cid uint) (res *model.T return } func (s *TagService) UpdateTags(userId uint, tags map[string]uint) { - tx := global.DB.Begin() + tx := DB.Begin() //先查询所有tag var allTags []*model.Tag tx.Where("user_id = ?", userId).Find(&allTags) @@ -66,7 +65,7 @@ func (s *TagService) UpdateTags(userId uint, tags map[string]uint) { // InfoById 根据用户id取用户信息 func (s *TagService) InfoById(id uint) *model.Tag { u := &model.Tag{} - global.DB.Where("id = ?", id).First(u) + DB.Where("id = ?", id).First(u) return u } @@ -74,7 +73,7 @@ func (s *TagService) List(page, pageSize uint, where func(tx *gorm.DB)) (res *mo res = &model.TagList{} res.Page = int64(page) res.PageSize = int64(pageSize) - tx := global.DB.Model(&model.Tag{}) + tx := DB.Model(&model.Tag{}) if where != nil { where(tx) } @@ -86,14 +85,14 @@ func (s *TagService) List(page, pageSize uint, where func(tx *gorm.DB)) (res *mo // Create 创建 func (s *TagService) Create(u *model.Tag) error { - res := global.DB.Create(u).Error + res := DB.Create(u).Error return res } func (s *TagService) Delete(u *model.Tag) error { - return global.DB.Delete(u).Error + return DB.Delete(u).Error } // Update 更新 func (s *TagService) Update(u *model.Tag) error { - return global.DB.Model(u).Select("*").Omit("created_at").Updates(u).Error + return DB.Model(u).Select("*").Omit("created_at").Updates(u).Error } diff --git a/service/user.go b/service/user.go index 6880843..6a35033 100644 --- a/service/user.go +++ b/service/user.go @@ -2,7 +2,6 @@ package service import ( "errors" - "github.com/lejianwen/rustdesk-api/v2/global" "github.com/lejianwen/rustdesk-api/v2/model" "github.com/lejianwen/rustdesk-api/v2/utils" "math/rand" @@ -20,43 +19,43 @@ type UserService struct { // InfoById 根据用户id取用户信息 func (us *UserService) InfoById(id uint) *model.User { u := &model.User{} - global.DB.Where("id = ?", id).First(u) + DB.Where("id = ?", id).First(u) return u } // InfoByUsername 根据用户名取用户信息 func (us *UserService) InfoByUsername(un string) *model.User { u := &model.User{} - global.DB.Where("username = ?", un).First(u) + DB.Where("username = ?", un).First(u) return u } // InfoByEmail 根据邮箱取用户信息 func (us *UserService) InfoByEmail(email string) *model.User { u := &model.User{} - global.DB.Where("email = ?", email).First(u) + DB.Where("email = ?", email).First(u) return u } // InfoByOpenid 根据openid取用户信息 func (us *UserService) InfoByOpenid(openid string) *model.User { u := &model.User{} - global.DB.Where("openid = ?", openid).First(u) + DB.Where("openid = ?", openid).First(u) return u } // InfoByUsernamePassword 根据用户名密码取用户信息 func (us *UserService) InfoByUsernamePassword(username, password string) *model.User { - if global.Config.Ldap.Enable { + if Config.Ldap.Enable { u, err := AllService.LdapService.Authenticate(username, password) if err == nil { return u } - global.Logger.Errorf("LDAP authentication failed, %v", err) - global.Logger.Warn("Fallback to local database") + Logger.Errorf("LDAP authentication failed, %v", err) + Logger.Warn("Fallback to local database") } u := &model.User{} - global.DB.Where("username = ? and password = ?", username, us.EncryptPassword(password)).First(u) + DB.Where("username = ? and password = ?", username, us.EncryptPassword(password)).First(u) return u } @@ -64,21 +63,21 @@ func (us *UserService) InfoByUsernamePassword(username, password string) *model. func (us *UserService) InfoByAccessToken(token string) (*model.User, *model.UserToken) { u := &model.User{} ut := &model.UserToken{} - global.DB.Where("token = ?", token).First(ut) + DB.Where("token = ?", token).First(ut) if ut.Id == 0 { return u, ut } if ut.ExpiredAt < time.Now().Unix() { return u, ut } - global.DB.Where("id = ?", ut.UserId).First(u) + DB.Where("id = ?", ut.UserId).First(u) return u, ut } // GenerateToken 生成token func (us *UserService) GenerateToken(u *model.User) string { - if len(global.Jwt.Key) > 0 { - return global.Jwt.GenerateToken(u.Id) + if len(Jwt.Key) > 0 { + return Jwt.GenerateToken(u.Id) } return utils.Md5(u.Username + time.Now().String()) } @@ -93,9 +92,9 @@ func (us *UserService) Login(u *model.User, llog *model.LoginLog) *model.UserTok DeviceId: llog.DeviceId, ExpiredAt: us.UserTokenExpireTimestamp(), } - global.DB.Create(ut) + DB.Create(ut) llog.UserTokenId = ut.UserId - global.DB.Create(llog) + DB.Create(llog) if llog.Uuid != "" { AllService.PeerService.UuidBindUserId(llog.DeviceId, llog.Uuid, u.Id) } @@ -116,7 +115,7 @@ func (us *UserService) List(page, pageSize uint, where func(tx *gorm.DB)) (res * res = &model.UserList{} res.Page = int64(page) res.PageSize = int64(pageSize) - tx := global.DB.Model(&model.User{}) + tx := DB.Model(&model.User{}) if where != nil { where(tx) } @@ -127,7 +126,7 @@ func (us *UserService) List(page, pageSize uint, where func(tx *gorm.DB)) (res * } func (us *UserService) ListByIds(ids []uint) (res []*model.User) { - global.DB.Where("id in ?", ids).Find(&res) + DB.Where("id in ?", ids).Find(&res) return res } @@ -141,14 +140,14 @@ func (us *UserService) ListByGroupId(groupId, page, pageSize uint) (res *model.U // ListIdsByGroupId 根据组id取用户id列表 func (us *UserService) ListIdsByGroupId(groupId uint) (ids []uint) { - global.DB.Model(&model.User{}).Where("group_id = ?", groupId).Pluck("id", &ids) + DB.Model(&model.User{}).Where("group_id = ?", groupId).Pluck("id", &ids) return ids } // ListIdAndNameByGroupId 根据组id取用户id和用户名列表 func (us *UserService) ListIdAndNameByGroupId(groupId uint) (res []*model.User) { - global.DB.Model(&model.User{}).Where("group_id = ?", groupId).Select("id, username").Find(&res) + DB.Model(&model.User{}).Where("group_id = ?", groupId).Select("id, username").Find(&res) return res } @@ -170,14 +169,14 @@ func (us *UserService) Create(u *model.User) error { } u.Username = us.formatUsername(u.Username) u.Password = us.EncryptPassword(u.Password) - res := global.DB.Create(u).Error + res := DB.Create(u).Error return res } // GetUuidByToken 根据token和user取uuid func (us *UserService) GetUuidByToken(u *model.User, token string) string { ut := &model.UserToken{} - err := global.DB.Where("user_id = ? and token = ?", u.Id, token).First(ut).Error + err := DB.Where("user_id = ? and token = ?", u.Id, token).First(ut).Error if err != nil { return "" } @@ -187,7 +186,7 @@ func (us *UserService) GetUuidByToken(u *model.User, token string) string { // Logout 退出登录 -> 删除token, 解绑uuid func (us *UserService) Logout(u *model.User, token string) error { uuid := us.GetUuidByToken(u, token) - err := global.DB.Where("user_id = ? and token = ?", u.Id, token).Delete(&model.UserToken{}).Error + err := DB.Where("user_id = ? and token = ?", u.Id, token).Delete(&model.UserToken{}).Error if err != nil { return err } @@ -203,7 +202,7 @@ func (us *UserService) Delete(u *model.User) error { if userCount <= 1 && us.IsAdmin(u) { return errors.New("The last admin user cannot be deleted") } - tx := global.DB.Begin() + tx := DB.Begin() // 删除用户 if err := tx.Delete(u).Error; err != nil { tx.Rollback() @@ -232,7 +231,7 @@ func (us *UserService) Delete(u *model.User) error { tx.Commit() // 删除关联的peer if err := AllService.PeerService.EraseUserId(u.Id); err != nil { - global.Logger.Warn("User deleted successfully, but failed to unlink peer.") + Logger.Warn("User deleted successfully, but failed to unlink peer.") return nil } return nil @@ -249,28 +248,28 @@ func (us *UserService) Update(u *model.User) error { return errors.New("The last admin user cannot be disabled or demoted") } } - return global.DB.Model(u).Updates(u).Error + return DB.Model(u).Updates(u).Error } // FlushToken 清空token func (us *UserService) FlushToken(u *model.User) error { - return global.DB.Where("user_id = ?", u.Id).Delete(&model.UserToken{}).Error + return DB.Where("user_id = ?", u.Id).Delete(&model.UserToken{}).Error } // FlushTokenByUuid 清空token func (us *UserService) FlushTokenByUuid(uuid string) error { - return global.DB.Where("device_uuid = ?", uuid).Delete(&model.UserToken{}).Error + return DB.Where("device_uuid = ?", uuid).Delete(&model.UserToken{}).Error } // FlushTokenByUuids 清空token func (us *UserService) FlushTokenByUuids(uuids []string) error { - return global.DB.Where("device_uuid in (?)", uuids).Delete(&model.UserToken{}).Error + return DB.Where("device_uuid in (?)", uuids).Delete(&model.UserToken{}).Error } // UpdatePassword 更新密码 func (us *UserService) UpdatePassword(u *model.User, password string) error { u.Password = us.EncryptPassword(password) - err := global.DB.Model(u).Update("password", u.Password).Error + err := DB.Model(u).Update("password", u.Password).Error if err != nil { return err } @@ -306,8 +305,8 @@ func (us *UserService) InfoByOauthId(op string, openId string) *model.User { // RegisterByOauth 注册 func (us *UserService) RegisterByOauth(oauthUser *model.OauthUser, op string) (error, *model.User) { - global.Lock.Lock("registerByOauth") - defer global.Lock.UnLock("registerByOauth") + Lock.Lock("registerByOauth") + defer Lock.UnLock("registerByOauth") ut := AllService.OauthService.UserThirdInfo(op, oauthUser.OpenId) if ut.Id != 0 { return nil, us.InfoById(ut.UserId) @@ -335,12 +334,12 @@ func (us *UserService) RegisterByOauth(oauthUser *model.OauthUser, op string) (e } if user.Id != 0 { ut.FromOauthUser(user.Id, oauthUser, oauthType, op) - global.DB.Create(ut) + DB.Create(ut) return nil, user } } - tx := global.DB.Begin() + tx := DB.Begin() ut = &model.UserThird{} ut.FromOauthUser(0, oauthUser, oauthType, op) // The initial username should be formatted @@ -372,27 +371,27 @@ func (us *UserService) GenerateUsernameByOauth(name string) string { // UserThirdsByUserId func (us *UserService) UserThirdsByUserId(userId uint) (res []*model.UserThird) { - global.DB.Where("user_id = ?", userId).Find(&res) + DB.Where("user_id = ?", userId).Find(&res) return res } func (us *UserService) UserThirdInfo(userId uint, op string) *model.UserThird { ut := &model.UserThird{} - global.DB.Where("user_id = ? and op = ?", userId, op).First(ut) + DB.Where("user_id = ? and op = ?", userId, op).First(ut) return ut } // FindLatestUserIdFromLoginLogByUuid 根据uuid查找最后登录的用户id func (us *UserService) FindLatestUserIdFromLoginLogByUuid(uuid string) uint { llog := &model.LoginLog{} - global.DB.Where("uuid = ?", uuid).Order("id desc").First(llog) + DB.Where("uuid = ?", uuid).Order("id desc").First(llog) return llog.UserId } // IsPasswordEmptyById 根据用户id判断密码是否为空,主要用于第三方登录的自动注册 func (us *UserService) IsPasswordEmptyById(id uint) bool { u := &model.User{} - if global.DB.Where("id = ?", id).First(u).Error != nil { + if DB.Where("id = ?", id).First(u).Error != nil { return false } return u.Password == "" @@ -401,7 +400,7 @@ func (us *UserService) IsPasswordEmptyById(id uint) bool { // IsPasswordEmptyByUsername 根据用户id判断密码是否为空,主要用于第三方登录的自动注册 func (us *UserService) IsPasswordEmptyByUsername(username string) bool { u := &model.User{} - if global.DB.Where("username = ?", username).First(u).Error != nil { + if DB.Where("username = ?", username).First(u).Error != nil { return false } return u.Password == "" @@ -431,7 +430,7 @@ func (us *UserService) TokenList(page uint, size uint, f func(tx *gorm.DB)) *mod res := &model.UserTokenList{} res.Page = int64(page) res.PageSize = int64(size) - tx := global.DB.Model(&model.UserToken{}) + tx := DB.Model(&model.UserToken{}) if f != nil { f(tx) } @@ -443,12 +442,12 @@ func (us *UserService) TokenList(page uint, size uint, f func(tx *gorm.DB)) *mod func (us *UserService) TokenInfoById(id uint) *model.UserToken { ut := &model.UserToken{} - global.DB.Where("id = ?", id).First(ut) + DB.Where("id = ?", id).First(ut) return ut } func (us *UserService) DeleteToken(l *model.UserToken) error { - return global.DB.Delete(l).Error + return DB.Delete(l).Error } // Helper functions, used for formatting username @@ -461,20 +460,20 @@ func (us *UserService) formatUsername(username string) string { // Helper functions, getUserCount func (us *UserService) getUserCount() int64 { var count int64 - global.DB.Model(&model.User{}).Count(&count) + DB.Model(&model.User{}).Count(&count) return count } // helper functions, getAdminUserCount func (us *UserService) getAdminUserCount() int64 { var count int64 - global.DB.Model(&model.User{}).Where("is_admin = ?", true).Count(&count) + DB.Model(&model.User{}).Where("is_admin = ?", true).Count(&count) return count } // UserTokenExpireTimestamp 生成用户token过期时间 func (us *UserService) UserTokenExpireTimestamp() int64 { - exp := global.Config.App.TokenExpire + exp := Config.App.TokenExpire if exp == 0 { //默认七天 exp = 604800 @@ -484,7 +483,7 @@ func (us *UserService) UserTokenExpireTimestamp() int64 { func (us *UserService) RefreshAccessToken(ut *model.UserToken) { ut.ExpiredAt = us.UserTokenExpireTimestamp() - global.DB.Model(ut).Update("expired_at", ut.ExpiredAt) + DB.Model(ut).Update("expired_at", ut.ExpiredAt) } func (us *UserService) AutoRefreshAccessToken(ut *model.UserToken) { if ut.ExpiredAt-time.Now().Unix() < 86400 { @@ -493,11 +492,11 @@ func (us *UserService) AutoRefreshAccessToken(ut *model.UserToken) { } func (us *UserService) BatchDeleteUserToken(ids []uint) error { - return global.DB.Where("id in ?", ids).Delete(&model.UserToken{}).Error + return DB.Where("id in ?", ids).Delete(&model.UserToken{}).Error } func (us *UserService) VerifyJWT(token string) (uint, error) { - return global.Jwt.ParseToken(token) + return Jwt.ParseToken(token) } // IsUsernameExists 判断用户名是否存在, it will check the internal database and LDAP(if enabled) @@ -507,7 +506,7 @@ func (us *UserService) IsUsernameExists(username string) bool { func (us *UserService) IsUsernameExistsLocal(username string) bool { u := &model.User{} - global.DB.Where("username = ?", username).First(u) + DB.Where("username = ?", username).First(u) return u.Id != 0 }