From daeae19194cced7c96255aa3cb02c5ede8ee22a4 Mon Sep 17 00:00:00 2001 From: ljw <84855512@qq.com> Date: Tue, 5 Nov 2024 09:48:02 +0800 Subject: [PATCH] up oauth re --- cmd/apimain.go | 17 ++++++- http/controller/admin/oauth.go | 13 +++-- http/controller/api/ouath.go | 2 +- http/request/admin/user.go | 16 +++--- model/oauth.go | 91 ++++++++++++++-------------------- model/user.go | 23 ++++----- model/userThird.go | 21 ++++---- service/oauth.go | 80 ++++++++++++++---------------- service/user.go | 82 +++++++++++++++--------------- 9 files changed, 170 insertions(+), 175 deletions(-) diff --git a/cmd/apimain.go b/cmd/apimain.go index bde9723..b2b9c83 100644 --- a/cmd/apimain.go +++ b/cmd/apimain.go @@ -101,7 +101,7 @@ func main() { } func DatabaseAutoUpdate() { - version := 244 + version := 245 db := global.DB @@ -146,6 +146,21 @@ func DatabaseAutoUpdate() { if v.Version < uint(version) { Migrate(uint(version)) } + // 245迁移 + if v.Version < 245 { + //oauths 表的 oauth_type 字段设置为 op同样的值 + db.Exec("update oauths set oauth_type = op") + db.Exec("update oauths set issuer = 'https://accounts.google.com' where op = 'google' and issuer = ''") + db.Exec("update user_thirds set oauth_type = third_type, op = third_type") + //通过email迁移旧的google授权 + uts := make([]model.UserThird, 0) + db.Where("oauth_type = ?", "google").Find(&uts) + for _, ut := range uts { + if ut.UserId > 0 { + db.Model(&model.User{}).Where("id = ?", ut.UserId).Update("email", ut.OpenId) + } + } + } } } diff --git a/http/controller/admin/oauth.go b/http/controller/admin/oauth.go index ade2f3c..ba03761 100644 --- a/http/controller/admin/oauth.go +++ b/http/controller/admin/oauth.go @@ -180,15 +180,18 @@ func (o *Oauth) Create(c *gin.Context) { response.Fail(c, 101, errList[0]) return } - - ex := service.AllService.OauthService.InfoByOp(f.Op) + u := f.ToOauth() + err := u.FormatOauthInfo() + if err != nil { + response.Fail(c, 101, response.TranslateMsg(c, "ParamsError")+err.Error()) + return + } + ex := service.AllService.OauthService.InfoByOp(u.Op) if ex.Id > 0 { response.Fail(c, 101, response.TranslateMsg(c, "ItemExists")) return } - - u := f.ToOauth() - err := service.AllService.OauthService.Create(u) + err = service.AllService.OauthService.Create(u) if err != nil { response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed")+err.Error()) return diff --git a/http/controller/api/ouath.go b/http/controller/api/ouath.go index 3d7ab67..58624ab 100644 --- a/http/controller/api/ouath.go +++ b/http/controller/api/ouath.go @@ -217,7 +217,7 @@ func (o *Oauth) OauthCallback(c *gin.Context) { oauthCache.UserId = user.Id oauthService.SetOauthCache(cacheKey, oauthCache, 0) // 如果是webadmin,登录成功后跳转到webadmin - if oauthCache.DeviceType == "webadmin" { + if oauthCache.DeviceType == model.LoginLogClientWebAdmin { /*service.AllService.UserService.Login(u, &model.LoginLog{ UserId: u.Id, Client: "webadmin", diff --git a/http/request/admin/user.go b/http/request/admin/user.go index 227332d..47b39c3 100644 --- a/http/request/admin/user.go +++ b/http/request/admin/user.go @@ -5,15 +5,15 @@ import ( ) type UserForm struct { - Id uint `json:"id"` - Username string `json:"username" validate:"required,gte=4,lte=10"` - Email string `json:"email" validate:"required,email"` + Id uint `json:"id"` + Username string `json:"username" validate:"required,gte=4,lte=10"` + Email string `json:"email"` //validate:"required,email" email不强制 //Password string `json:"password" validate:"required,gte=4,lte=20"` - Nickname string `json:"nickname"` - Avatar string `json:"avatar"` - GroupId uint `json:"group_id" validate:"required"` - IsAdmin *bool `json:"is_admin" ` - Status model.StatusCode `json:"status" validate:"required,gte=0"` + Nickname string `json:"nickname"` + Avatar string `json:"avatar"` + GroupId uint `json:"group_id" validate:"required"` + IsAdmin *bool `json:"is_admin" ` + Status model.StatusCode `json:"status" validate:"required,gte=0"` } func (uf *UserForm) FromUser(user *model.User) *UserForm { diff --git a/model/oauth.go b/model/oauth.go index e879caf..b6004ab 100644 --- a/model/oauth.go +++ b/model/oauth.go @@ -1,9 +1,9 @@ package model import ( + "errors" "strconv" "strings" - "errors" ) const OIDC_DEFAULT_SCOPES = "openid,profile,email" @@ -27,32 +27,23 @@ func ValidateOauthType(oauthType string) error { } const ( - OauthNameGithub string = "GitHub" - OauthNameGoogle string = "Google" - OauthNameOidc string = "OIDC" - OauthNameWebauth string = "WebAuth" -) - -const ( - UserEndpointGithub string = "https://api.github.com/user" - IssuerGoogle string = "https://accounts.google.com" + UserEndpointGithub string = "https://api.github.com/user" + IssuerGoogle string = "https://accounts.google.com" ) type Oauth struct { IdModel - Op string `json:"op"` - OauthType string `json:"oauth_type"` - ClientId string `json:"client_id"` - ClientSecret string `json:"client_secret"` - RedirectUrl string `json:"redirect_url"` - AutoRegister *bool `json:"auto_register"` - Scopes string `json:"scopes"` - Issuer string `json:"issuer"` + Op string `json:"op"` + OauthType string `json:"oauth_type"` + ClientId string `json:"client_id"` + ClientSecret string `json:"client_secret"` + RedirectUrl string `json:"redirect_url"` + AutoRegister *bool `json:"auto_register"` + Scopes string `json:"scopes"` + Issuer string `json:"issuer"` TimeModel } - - // Helper function to format oauth info, it's used in the update and create method func (oa *Oauth) FormatOauthInfo() error { oauthType := strings.TrimSpace(oa.OauthType) @@ -60,25 +51,20 @@ func (oa *Oauth) FormatOauthInfo() error { if err != nil { return err } + switch oauthType { + case OauthTypeGithub: + oa.Op = OauthTypeGithub + case OauthTypeGoogle: + oa.Op = OauthTypeGoogle + } // check if the op is empty, set the default value op := strings.TrimSpace(oa.Op) - if op == "" { - switch oauthType { - case OauthTypeGithub: - oa.Op = OauthNameGithub - case OauthTypeGoogle: - oa.Op = OauthNameGoogle - case OauthTypeOidc: - oa.Op = OauthNameOidc - case OauthTypeWebauth: - oa.Op = OauthNameWebauth - default: - oa.Op = oauthType - } + if op == "" && oauthType == OauthTypeOidc { + oa.Op = OauthTypeOidc } // check the issuer, if the oauth type is google and the issuer is empty, set the issuer to the default value issuer := strings.TrimSpace(oa.Issuer) - // If the oauth type is google and the issuer is empty, set the issuer to the default value + // If the oauth type is google and the issuer is empty, set the issuer to the default value if oauthType == OauthTypeGoogle && issuer == "" { oa.Issuer = IssuerGoogle } @@ -86,12 +72,12 @@ func (oa *Oauth) FormatOauthInfo() error { } type OauthUser struct { - OpenId string `json:"open_id" gorm:"not null;index"` - Name string `json:"name"` - Username string `json:"username"` - Email string `json:"email"` - VerifiedEmail bool `json:"verified_email,omitempty"` - Picture string `json:"picture,omitempty"` + OpenId string `json:"open_id" gorm:"not null;index"` + Name string `json:"name"` + Username string `json:"username"` + Email string `json:"email"` + VerifiedEmail bool `json:"verified_email,omitempty"` + Picture string `json:"picture,omitempty"` } func (ou *OauthUser) ToUser(user *User, overideUsername bool) { @@ -122,7 +108,7 @@ func (ou *OidcUser) ToOauthUser() *OauthUser { if ou.PreferredUsername != "" { username = ou.PreferredUsername } else { - username = strings.ToLower(strings.Split(ou.Email, "@")[0]) + username = strings.ToLower(ou.Email) } return &OauthUser{ @@ -135,29 +121,26 @@ func (ou *OidcUser) ToOauthUser() *OauthUser { } } - type GithubUser struct { OauthUserBase - Id int `json:"id"` - Login string `json:"login"` - AvatarUrl string `json:"avatar_url"` - VerifiedEmail bool `json:"verified_email"` + Id int `json:"id"` + Login string `json:"login"` + AvatarUrl string `json:"avatar_url"` + VerifiedEmail bool `json:"verified_email"` } func (gu *GithubUser) ToOauthUser() *OauthUser { username := strings.ToLower(gu.Login) return &OauthUser{ - OpenId: strconv.Itoa(gu.Id), - Name: gu.Name, - Username: username, - Email: gu.Email, - VerifiedEmail: gu.VerifiedEmail, - Picture: gu.AvatarUrl, + OpenId: strconv.Itoa(gu.Id), + Name: gu.Name, + Username: username, + Email: gu.Email, + VerifiedEmail: gu.VerifiedEmail, + Picture: gu.AvatarUrl, } } - - type OauthList struct { Oauths []*Oauth `json:"list"` Pagination diff --git a/model/user.go b/model/user.go index 4be0049..fe2f99d 100644 --- a/model/user.go +++ b/model/user.go @@ -1,14 +1,9 @@ package model -import ( - "fmt" - "gorm.io/gorm" -) - type User struct { IdModel - Username string `json:"username" gorm:"default:'';not null;uniqueIndex"` - Email string `json:"email" gorm:"default:'';not null;uniqueIndex"` + Username string `json:"username" gorm:"default:'';not null;uniqueIndex"` + Email string `json:"email" gorm:"default:'';not null;index"` // Email string `json:"email" ` Password string `json:"-" gorm:"default:'';not null;"` Nickname string `json:"nickname" gorm:"default:'';not null;"` @@ -20,13 +15,13 @@ type User struct { } // BeforeSave 钩子用于确保 email 字段有合理的默认值 -func (u *User) BeforeSave(tx *gorm.DB) (err error) { - // 如果 email 为空,设置为默认值 - if u.Email == "" { - u.Email = fmt.Sprintf("%s@example.com", u.Username) - } - return nil -} +//func (u *User) BeforeSave(tx *gorm.DB) (err error) { +// // 如果 email 为空,设置为默认值 +// if u.Email == "" { +// u.Email = fmt.Sprintf("%s@example.com", u.Username) +// } +// return nil +//} type UserList struct { Users []*User `json:"list,omitempty"` diff --git a/model/userThird.go b/model/userThird.go index 5b8f5ef..eab222e 100644 --- a/model/userThird.go +++ b/model/userThird.go @@ -6,20 +6,21 @@ import ( type UserThird struct { IdModel - UserId uint ` json:"user_id" gorm:"not null;index"` + UserId uint `json:"user_id" gorm:"not null;index"` OauthUser - // UnionId string `json:"union_id" gorm:"not null;"` + UnionId string `json:"union_id" gorm:"default:'';not null;"` // OauthType string `json:"oauth_type" gorm:"not null;"` - OauthType string `json:"oauth_type"` - Op string `json:"op" gorm:"not null;"` + ThirdType string `json:"third_type" gorm:"default:'';not null;"` //deprecated + OauthType string `json:"oauth_type" gorm:"default:'';not null;"` + Op string `json:"op" gorm:"default:'';not null;"` TimeModel } func (u *UserThird) FromOauthUser(userId uint, oauthUser *OauthUser, oauthType string, op string) { - u.UserId = userId - u.OauthUser = *oauthUser - u.OauthType = oauthType - u.Op = op + u.UserId = userId + u.OauthUser = *oauthUser + u.OauthType = oauthType + u.Op = op // make sure email is lower case - u.Email = strings.ToLower(u.Email) -} \ No newline at end of file + u.Email = strings.ToLower(u.Email) +} diff --git a/service/oauth.go b/service/oauth.go index 1966237..0359c48 100644 --- a/service/oauth.go +++ b/service/oauth.go @@ -12,16 +12,15 @@ import ( // "golang.org/x/oauth2/google" "gorm.io/gorm" // "io" + "fmt" "net/http" "net/url" "strconv" "strings" "sync" "time" - "fmt" ) - type OauthService struct { } @@ -34,26 +33,26 @@ type OidcEndpoint struct { } type OauthCacheItem struct { - UserId uint `json:"user_id"` - Id string `json:"id"` //rustdesk的设备ID - Op string `json:"op"` - Action string `json:"action"` - Uuid string `json:"uuid"` - DeviceName string `json:"device_name"` - DeviceOs string `json:"device_os"` - DeviceType string `json:"device_type"` - OpenId string `json:"open_id"` - Username string `json:"username"` - Name string `json:"name"` - Email string `json:"email"` + UserId uint `json:"user_id"` + Id string `json:"id"` //rustdesk的设备ID + Op string `json:"op"` + Action string `json:"action"` + Uuid string `json:"uuid"` + DeviceName string `json:"device_name"` + DeviceOs string `json:"device_os"` + DeviceType string `json:"device_type"` + OpenId string `json:"open_id"` + Username string `json:"username"` + Name string `json:"name"` + Email string `json:"email"` } func (oci *OauthCacheItem) ToOauthUser() *model.OauthUser { return &model.OauthUser{ - OpenId: oci.OpenId, + OpenId: oci.OpenId, Username: oci.Username, - Name: oci.Name, - Email: oci.Email, + Name: oci.Name, + Email: oci.Email, } } @@ -64,14 +63,13 @@ const ( OauthActionTypeBind = "bind" ) -func (oa *OauthCacheItem) UpdateFromOauthUser(oauthUser *model.OauthUser) { - oa.OpenId = oauthUser.OpenId - oa.Username = oauthUser.Username - oa.Name = oauthUser.Name - oa.Email = oauthUser.Email +func (oci *OauthCacheItem) UpdateFromOauthUser(oauthUser *model.OauthUser) { + oci.OpenId = oauthUser.OpenId + oci.Username = oauthUser.Username + oci.Name = oauthUser.Name + oci.Email = oauthUser.Email } - func (os *OauthService) GetOauthCache(key string) *OauthCacheItem { v, ok := OauthCache.Load(key) if !ok { @@ -164,7 +162,7 @@ func (os *OauthService) GetOauthConfig(op string) (err error, oauthInfo *model.O if err != nil { return err, nil, nil } - oauthConfig.Endpoint = oauth2.Endpoint{AuthURL: endpoint.AuthURL,TokenURL: endpoint.TokenURL,} + oauthConfig.Endpoint = oauth2.Endpoint{AuthURL: endpoint.AuthURL, TokenURL: endpoint.TokenURL} oauthConfig.Scopes = os.constructScopes(oauthInfo.Scopes) default: return errors.New("unsupported OAuth type"), nil, nil @@ -259,9 +257,8 @@ func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string) return nil, user.ToOauthUser() } - // oidcCallback oidc回调, 通过code获取用户信息 -func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, userInfoEndpoint string) (error, *model.OauthUser,) { +func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, userInfoEndpoint string) (error, *model.OauthUser) { var user = &model.OidcUser{} if err, _ := os.callbackBase(oauthConfig, code, userInfoEndpoint, user); err != nil { return err, nil @@ -280,21 +277,20 @@ func (os *OauthService) Callback(code string, op string) (err error, oauthUser * } oauthType := oauthInfo.OauthType switch oauthType { - case model.OauthTypeGithub: - err, oauthUser = os.githubCallback(oauthConfig, code) - case model.OauthTypeOidc, model.OauthTypeGoogle: + case model.OauthTypeGithub: + err, oauthUser = os.githubCallback(oauthConfig, code) + case model.OauthTypeOidc, model.OauthTypeGoogle: err, endpoint := os.FetchOidcEndpoint(oauthInfo.Issuer) if err != nil { return err, nil } - err, oauthUser = os.oidcCallback(oauthConfig, code, endpoint.UserInfo) - default: - return errors.New("unsupported OAuth type"), nil - } - return err, oauthUser + err, oauthUser = os.oidcCallback(oauthConfig, code, endpoint.UserInfo) + default: + return errors.New("unsupported OAuth type"), nil + } + return err, oauthUser } - func (os *OauthService) UserThirdInfo(op string, openId string) *model.UserThird { ut := &model.UserThird{} global.DB.Where("open_id = ? and op = ?", openId, op).First(ut) @@ -343,17 +339,17 @@ func (os *OauthService) InfoByOp(op string) *model.Oauth { // Helper function to get scopes by operation func (os *OauthService) getScopesByOp(op string) []string { - scopes := os.InfoByOp(op).Scopes + scopes := os.InfoByOp(op).Scopes return os.constructScopes(scopes) } // Helper function to construct scopes func (os *OauthService) constructScopes(scopes string) []string { - scopes = strings.TrimSpace(scopes) - if scopes == "" { - scopes = model.OIDC_DEFAULT_SCOPES - } - return strings.Split(scopes, ",") + scopes = strings.TrimSpace(scopes) + if scopes == "" { + scopes = model.OIDC_DEFAULT_SCOPES + } + return strings.Split(scopes, ",") } func (os *OauthService) List(page, pageSize uint, where func(tx *gorm.DB)) (res *model.OauthList) { @@ -461,4 +457,4 @@ func (os *OauthService) getGithubPrimaryEmail(client *http.Client, githubUser *m } return fmt.Errorf("no primary verified email found") -} \ No newline at end of file +} diff --git a/service/user.go b/service/user.go index 774ec1c..c6e80ed 100644 --- a/service/user.go +++ b/service/user.go @@ -5,13 +5,13 @@ import ( adResp "Gwen/http/response/admin" "Gwen/model" "Gwen/utils" + "errors" "github.com/gin-gonic/gin" "gorm.io/gorm" "math/rand" "strconv" - "time" "strings" - "errors" + "time" ) type UserService struct { @@ -23,6 +23,7 @@ func (us *UserService) InfoById(id uint) *model.User { global.DB.Where("id = ?", id).First(u) return u } + // InfoByUsername 根据用户名取用户信息 func (us *UserService) InfoByUsername(un string) *model.User { u := &model.User{} @@ -75,11 +76,11 @@ func (us *UserService) GenerateToken(u *model.User) string { func (us *UserService) Login(u *model.User, llog *model.LoginLog) *model.UserToken { token := us.GenerateToken(u) ut := &model.UserToken{ - UserId: u.Id, - Token: token, + UserId: u.Id, + Token: token, DeviceUuid: llog.Uuid, DeviceId: llog.DeviceId, - ExpiredAt: time.Now().Add(time.Hour * 24 * 7).Unix(), + ExpiredAt: time.Now().Add(time.Hour * 24 * 7).Unix(), } global.DB.Create(ut) llog.UserTokenId = ut.UserId @@ -162,7 +163,7 @@ func (us *UserService) Create(u *model.User) error { // GetUuidByToken 根据token和user取uuid func (us *UserService) GetUuidByToken(u *model.User, token string) string { ut := &model.UserToken{} - err :=global.DB.Where("user_id = ? and token = ?", u.Id, token).First(ut).Error + err := global.DB.Where("user_id = ? and token = ?", u.Id, token).First(ut).Error if err != nil { return "" } @@ -214,12 +215,12 @@ func (us *UserService) Delete(u *model.User) error { tx.Rollback() return err } - tx.Commit() // 删除关联的peer if err := AllService.PeerService.EraseUserId(u.Id); err != nil { tx.Rollback() return err } + tx.Commit() return nil } @@ -230,7 +231,7 @@ func (us *UserService) Update(u *model.User) error { if us.IsAdmin(currentUser) { adminCount := us.getAdminUserCount() // 如果这是唯一的管理员,确保不能禁用或取消管理员权限 - if adminCount <= 1 && ( !us.IsAdmin(u) || u.Status == model.COMMON_STATUS_DISABLED) { + if adminCount <= 1 && (!us.IsAdmin(u) || u.Status == model.COMMON_STATUS_DISABLED) { return errors.New("The last admin user cannot be disabled or demoted") } } @@ -290,48 +291,49 @@ func (us *UserService) InfoByOauthId(op string, openId string) *model.User { } // RegisterByOauth 注册 -func (us *UserService) RegisterByOauth(oauthUser *model.OauthUser , op string) (error, *model.User) { +func (us *UserService) RegisterByOauth(oauthUser *model.OauthUser, op string) (error, *model.User) { global.Lock.Lock("registerByOauth") defer global.Lock.UnLock("registerByOauth") ut := AllService.OauthService.UserThirdInfo(op, oauthUser.OpenId) if ut.Id != 0 { return nil, us.InfoById(ut.UserId) } - //check if this email has been registered - email := oauthUser.Email err, oauthType := AllService.OauthService.GetTypeByOp(op) if err != nil { return err, nil } - // if email is empty, use username and op as email - if email == "" { - email = oauthUser.Username + "@" + op - } - email = strings.ToLower(email) - // update email to oauthUser, in case it contain upper case - oauthUser.Email = email - user := us.InfoByEmail(email) - tx := global.DB.Begin() - if user.Id != 0 { - ut.FromOauthUser(user.Id, oauthUser, oauthType, op) - } else { - ut = &model.UserThird{} - ut.FromOauthUser(0, oauthUser, oauthType, op) - // The initial username should be formatted - username := us.formatUsername(oauthUser.Username) - usernameUnique := us.GenerateUsernameByOauth(username) - user = &model.User{ - Username: usernameUnique, - GroupId: 1, + //check if this email has been registered + email := oauthUser.Email + // only email is not empty + if email != "" { + email = strings.ToLower(email) + // update email to oauthUser, in case it contain upper case + oauthUser.Email = email + user := us.InfoByEmail(email) + if user.Id != 0 { + ut.FromOauthUser(user.Id, oauthUser, oauthType, op) + global.DB.Create(ut) + return nil, user } - oauthUser.ToUser(user, false) - tx.Create(user) - if user.Id == 0 { - tx.Rollback() - return errors.New("OauthRegisterFailed"), user - } - ut.UserId = user.Id } + + tx := global.DB.Begin() + ut = &model.UserThird{} + ut.FromOauthUser(0, oauthUser, oauthType, op) + // The initial username should be formatted + username := us.formatUsername(oauthUser.Username) + usernameUnique := us.GenerateUsernameByOauth(username) + user := &model.User{ + Username: usernameUnique, + GroupId: 1, + } + oauthUser.ToUser(user, false) + tx.Create(user) + if user.Id == 0 { + tx.Rollback() + return errors.New("OauthRegisterFailed"), user + } + ut.UserId = user.Id tx.Create(ut) tx.Commit() return nil, user @@ -433,7 +435,7 @@ func (us *UserService) formatUsername(username string) string { return username } -// Helper functions, getUserCount +// Helper functions, getUserCount func (us *UserService) getUserCount() int64 { var count int64 global.DB.Model(&model.User{}).Count(&count) @@ -445,4 +447,4 @@ func (us *UserService) getAdminUserCount() int64 { var count int64 global.DB.Model(&model.User{}).Where("is_admin = ?", true).Count(&count) return count -} \ No newline at end of file +}