diff --git a/http/controller/admin/login.go b/http/controller/admin/login.go index ad25393..cd879cb 100644 --- a/http/controller/admin/login.go +++ b/http/controller/admin/login.go @@ -63,6 +63,8 @@ func (ct *Login) Login(c *gin.Context) { response.Success(c, &adResp.LoginPayload{ Token: ut.Token, Username: u.Username, + Email: u.Email, + Avatar: u.Avatar, RouteNames: service.AllService.UserService.RouteNames(u), Nickname: u.Nickname, }) diff --git a/http/controller/admin/oauth.go b/http/controller/admin/oauth.go index 3444f3e..ade2f3c 100644 --- a/http/controller/admin/oauth.go +++ b/http/controller/admin/oauth.go @@ -5,7 +5,6 @@ import ( "Gwen/http/request/admin" adminReq "Gwen/http/request/admin" "Gwen/http/response" - "Gwen/model" "Gwen/service" "github.com/gin-gonic/gin" "strconv" @@ -96,21 +95,23 @@ func (o *Oauth) BindConfirm(c *gin.Context) { response.Fail(c, 101, response.TranslateMsg(c, "ParamsError")) return } - v := service.AllService.OauthService.GetOauthCache(j.Code) - if v == nil { + oauthService := service.AllService.OauthService + oauthCache := oauthService.GetOauthCache(j.Code) + if oauthCache == nil { response.Fail(c, 101, response.TranslateMsg(c, "OauthExpired")) return } - u := service.AllService.UserService.CurUser(c) - err = service.AllService.OauthService.BindOauthUser(v.Op, v.ThirdOpenId, v.ThirdName, u.Id) + oauthUser := oauthCache.ToOauthUser() + user := service.AllService.UserService.CurUser(c) + err = oauthService.BindOauthUser(user.Id, oauthUser, oauthCache.Op) if err != nil { response.Fail(c, 101, response.TranslateMsg(c, "BindFail")) return } - v.UserId = u.Id - service.AllService.OauthService.SetOauthCache(j.Code, v, 0) - response.Success(c, v) + oauthCache.UserId = user.Id + oauthService.SetOauthCache(j.Code, oauthCache, 0) + response.Success(c, oauthCache) } func (o *Oauth) Unbind(c *gin.Context) { @@ -126,28 +127,11 @@ func (o *Oauth) Unbind(c *gin.Context) { response.Fail(c, 101, response.TranslateMsg(c, "ItemNotFound")) return } - if f.Op == model.OauthTypeGithub { - err = service.AllService.OauthService.UnBindGithubUser(u.Id) - if err != nil { - response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed")+err.Error()) - return - } + err = service.AllService.OauthService.UnBindOauthUser(u.Id, f.Op) + if err != nil { + response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed")+err.Error()) + return } - if f.Op == model.OauthTypeGoogle { - err = service.AllService.OauthService.UnBindGoogleUser(u.Id) - if err != nil { - response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed")+err.Error()) - return - } - } - if f.Op == model.OauthTypeOidc { - err = service.AllService.OauthService.UnBindOidcUser(u.Id) - if err != nil { - response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed")+err.Error()) - return - } - } - response.Success(c, nil) } diff --git a/http/controller/admin/user.go b/http/controller/admin/user.go index 67f71b8..76d0aa6 100644 --- a/http/controller/admin/user.go +++ b/http/controller/admin/user.go @@ -286,10 +286,10 @@ func (ct *User) MyOauth(c *gin.Context) { var res []*adResp.UserOauthItem for _, oa := range oal.Oauths { item := &adResp.UserOauthItem{ - ThirdType: oa.Op, + Op: oa.Op, } for _, ut := range uts { - if ut.ThirdType == oa.Op { + if ut.Op == oa.Op { item.Status = 1 break } diff --git a/http/controller/api/login.go b/http/controller/api/login.go index f3d9481..6c0b95d 100644 --- a/http/controller/api/login.go +++ b/http/controller/api/login.go @@ -83,22 +83,10 @@ func (l *Login) Login(c *gin.Context) { // @Failure 500 {object} response.ErrorResponse // @Router /login-options [get] func (l *Login) LoginOptions(c *gin.Context) { - oauthOks := []string{} - err, _ := service.AllService.OauthService.GetOauthConfig(model.OauthTypeGithub) - if err == nil { - oauthOks = append(oauthOks, model.OauthTypeGithub) - } - err, _ = service.AllService.OauthService.GetOauthConfig(model.OauthTypeGoogle) - if err == nil { - oauthOks = append(oauthOks, model.OauthTypeGoogle) - } - err, _ = service.AllService.OauthService.GetOauthConfig(model.OauthTypeOidc) - if err == nil { - oauthOks = append(oauthOks, model.OauthTypeOidc) - } - oauthOks = append(oauthOks, model.OauthTypeWebauth) + ops := service.AllService.OauthService.GetOauthProviders() + ops = append(ops, model.OauthTypeWebauth) var oidcItems []map[string]string - for _, v := range oauthOks { + for _, v := range ops { oidcItems = append(oidcItems, map[string]string{"name": v}) } common, err := json.Marshal(oidcItems) @@ -108,7 +96,7 @@ func (l *Login) LoginOptions(c *gin.Context) { } var res []string res = append(res, "common-oidc/"+string(common)) - for _, v := range oauthOks { + for _, v := range ops { res = append(res, "oidc/"+v) } c.JSON(http.StatusOK, res) diff --git a/http/controller/api/ouath.go b/http/controller/api/ouath.go index 96f7a80..5367117 100644 --- a/http/controller/api/ouath.go +++ b/http/controller/api/ouath.go @@ -9,8 +9,6 @@ import ( "Gwen/service" "github.com/gin-gonic/gin" "net/http" - "strconv" - "strings" ) type Oauth struct { @@ -32,13 +30,17 @@ func (o *Oauth) OidcAuth(c *gin.Context) { response.Error(c, response.TranslateMsg(c, "ParamsError")+err.Error()) return } - //fmt.Println(f) - if f.Op != model.OauthTypeWebauth && f.Op != model.OauthTypeGoogle && f.Op != model.OauthTypeGithub && f.Op != model.OauthTypeOidc { - response.Error(c, response.TranslateMsg(c, "ParamsError")) + + oauthService := service.AllService.OauthService + err = oauthService.ValidateOauthProvider(f.Op) + if err != nil { + response.Error(c, response.TranslateMsg(c, err.Error())) return } - err, code, url := service.AllService.OauthService.BeginAuth(f.Op) + var code string + var url string + err, code, url = oauthService.BeginAuth(f.Op) if err != nil { response.Error(c, response.TranslateMsg(c, err.Error())) return @@ -149,70 +151,43 @@ func (o *Oauth) OauthCallback(c *gin.Context) { c.String(http.StatusInternalServerError, response.TranslateParamMsg(c, "ParamIsEmpty", "state")) return } - cacheKey := state + oauthService := service.AllService.OauthService //从缓存中获取 - v := service.AllService.OauthService.GetOauthCache(cacheKey) - if v == nil { + oauthCache := oauthService.GetOauthCache(cacheKey) + if oauthCache == nil { c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthExpired")) return } - - ty := v.Op - ac := v.Action - var u *model.User - openid := "" - thirdName := "" - //fmt.Println("ty ac ", ty, ac) - - if ty == model.OauthTypeGithub { - code := c.Query("code") - err, userData := service.AllService.OauthService.GithubCallback(code) - if err != nil { - c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthFailed")+response.TranslateMsg(c, err.Error())) - return - } - openid = strconv.Itoa(userData.Id) - thirdName = userData.Login - } else if ty == model.OauthTypeGoogle { - code := c.Query("code") - err, userData := service.AllService.OauthService.GoogleCallback(code) - if err != nil { - c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthFailed")+response.TranslateMsg(c, err.Error())) - return - } - openid = userData.Email - //将空格替换成_ - thirdName = strings.Replace(userData.Name, " ", "_", -1) - } else if ty == model.OauthTypeOidc { - code := c.Query("code") - err, userData := service.AllService.OauthService.OidcCallback(code) - if err != nil { - c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthFailed")+response.TranslateMsg(c, err.Error())) - return - } - openid = userData.Sub - thirdName = userData.PreferredUsername - } else { - c.String(http.StatusInternalServerError, response.TranslateMsg(c, "ParamsError")) + op := oauthCache.Op + action := oauthCache.Action + var user *model.User + // 获取用户信息 + code := c.Query("code") + err, oauthUser := oauthService.Callback(code, op) + if err != nil { + c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthFailed")+response.TranslateMsg(c, err.Error())) return } - if ac == service.OauthActionTypeBind { + userId := oauthCache.UserId + openid := oauthUser.OpenId + if action == service.OauthActionTypeBind { //fmt.Println("bind", ty, userData) - utr := service.AllService.OauthService.UserThirdInfo(ty, openid) + // 检查此openid是否已经绑定过 + utr := oauthService.UserThirdInfo(op, openid) if utr.UserId > 0 { c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthHasBindOtherUser")) return } //绑定 - u = service.AllService.UserService.InfoById(v.UserId) - if u == nil { + user = service.AllService.UserService.InfoById(userId) + if user == nil { c.String(http.StatusInternalServerError, response.TranslateMsg(c, "ItemNotFound")) return } //绑定 - err := service.AllService.OauthService.BindOauthUser(ty, openid, thirdName, v.UserId) + err := oauthService.BindOauthUser(userId, oauthUser, op) if err != nil { c.String(http.StatusInternalServerError, response.TranslateMsg(c, "BindFail")) return @@ -220,42 +195,41 @@ func (o *Oauth) OauthCallback(c *gin.Context) { c.String(http.StatusOK, response.TranslateMsg(c, "BindSuccess")) return - } else if ac == service.OauthActionTypeLogin { + } else if action == service.OauthActionTypeLogin { //登录 - if v.UserId != 0 { + if userId != 0 { c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthHasBeenSuccess")) return } - u = service.AllService.UserService.InfoByGithubId(openid) - if u == nil { - oa := service.AllService.OauthService.InfoByOp(ty) - if !*oa.AutoRegister { + user = service.AllService.UserService.InfoByOauthId(op, openid) + if user == nil { + oauthConfig := oauthService.InfoByOp(op) + if !*oauthConfig.AutoRegister { //c.String(http.StatusInternalServerError, "还未绑定用户,请先绑定") - v.ThirdName = thirdName - v.ThirdOpenId = openid + oauthCache.UpdateFromOauthUser(oauthUser) url := global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/bind/" + cacheKey c.Redirect(http.StatusFound, url) return } //自动注册 - u = service.AllService.UserService.RegisterByOauth(ty, thirdName, openid) - if u.Id == 0 { + user = service.AllService.UserService.RegisterByOauth(oauthUser, op) + if user.Id == 0 { c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthRegisterFailed")) return } } - v.UserId = u.Id - service.AllService.OauthService.SetOauthCache(cacheKey, v, 0) + oauthCache.UserId = user.Id + oauthService.SetOauthCache(cacheKey, oauthCache, 0) // 如果是webadmin,登录成功后跳转到webadmin - if v.DeviceType == "webadmin" { + if oauthCache.DeviceType == "webadmin" { /*service.AllService.UserService.Login(u, &model.LoginLog{ UserId: u.Id, Client: "webadmin", Uuid: "", //must be empty Ip: c.ClientIP(), Type: model.LoginLogTypeOauth, - Platform: v.DeviceOs, + Platform: oauthService.DeviceOs, })*/ url := global.Config.Rustdesk.ApiServer + "/_admin/#/" c.Redirect(http.StatusFound, url) diff --git a/http/request/admin/oauth.go b/http/request/admin/oauth.go index db519f8..8d8ad37 100644 --- a/http/request/admin/oauth.go +++ b/http/request/admin/oauth.go @@ -1,6 +1,9 @@ package admin -import "Gwen/model" +import ( + "Gwen/model" + "strings" +) type BindOauthForm struct { Op string `json:"op" binding:"required"` @@ -13,19 +16,37 @@ type UnBindOauthForm struct { Op string `json:"op" binding:"required"` } type OauthForm struct { - Id uint `json:"id"` - Op string `json:"op" validate:"required"` - Issuer string `json:"issuer" validate:"omitempty,url"` - Scopes string `json:"scopes" validate:"omitempty"` - ClientId string `json:"client_id" validate:"required"` - ClientSecret string `json:"client_secret" validate:"required"` - RedirectUrl string `json:"redirect_url" validate:"required"` - AutoRegister *bool `json:"auto_register"` + Id uint `json:"id"` + Op string `json:"op" validate:"omitempty"` + OauthType string `json:"oauth_type" validate:"required"` + Issuer string `json:"issuer" validate:"omitempty,url"` + Scopes string `json:"scopes" validate:"omitempty"` + ClientId string `json:"client_id" validate:"required"` + ClientSecret string `json:"client_secret" validate:"required"` + RedirectUrl string `json:"redirect_url" validate:"required"` + AutoRegister *bool `json:"auto_register"` } func (of *OauthForm) ToOauth() *model.Oauth { + op := strings.ToLower(of.Op) + op = strings.TrimSpace(op) + if op == "" { + switch of.OauthType { + case model.OauthTypeGithub: + of.Op = "GitHub" + case model.OauthTypeGoogle: + of.Op = "Google" + case model.OauthTypeOidc: + of.Op = "OIDC" + case model.OauthTypeWebauth: + of.Op = "WebAuth" + default: + of.Op = of.OauthType + } + } oa := &model.Oauth{ Op: of.Op, + OauthType: of.OauthType, ClientId: of.ClientId, ClientSecret: of.ClientSecret, RedirectUrl: of.RedirectUrl, diff --git a/http/request/admin/user.go b/http/request/admin/user.go index e29133c..f8093c5 100644 --- a/http/request/admin/user.go +++ b/http/request/admin/user.go @@ -5,20 +5,22 @@ import ( ) type UserForm struct { - Id uint `json:"id"` - Username string `json:"username" validate:"required,gte=4,lte=10"` + Id uint `json:"id"` + Username string `json:"username" validate:"required,gte=4,lte=10"` + Email string `json:"email" validate:"required,email"` //Password string `json:"password" validate:"required,gte=4,lte=20"` - Nickname string `json:"nickname"` - Avatar string `json:"avatar"` - GroupId uint `json:"group_id" validate:"required"` - IsAdmin *bool `json:"is_admin" ` - Status model.StatusCode `json:"status" validate:"required,gte=0"` + Nickname string `json:"nickname"` + Avatar string `json:"avatar"` + GroupId uint `json:"group_id" validate:"required"` + IsAdmin *bool `json:"is_admin" ` + Status model.StatusCode `json:"status" validate:"required,gte=0"` } func (uf *UserForm) FromUser(user *model.User) *UserForm { uf.Id = user.Id uf.Username = user.Username uf.Nickname = user.Nickname + uf.Email = user.Email uf.Avatar = user.Avatar uf.GroupId = user.GroupId uf.IsAdmin = user.IsAdmin @@ -30,6 +32,7 @@ func (uf *UserForm) ToUser() *model.User { user.Id = uf.Id user.Username = uf.Username user.Nickname = uf.Nickname + user.Email = uf.Email user.Avatar = uf.Avatar user.GroupId = uf.GroupId user.IsAdmin = uf.IsAdmin diff --git a/http/response/admin/user.go b/http/response/admin/user.go index d3941e8..857fe69 100644 --- a/http/response/admin/user.go +++ b/http/response/admin/user.go @@ -4,6 +4,8 @@ import "Gwen/model" type LoginPayload struct { Username string `json:"username"` + Email string `json:"email"` + Avatar string `json:"avatar"` Token string `json:"token"` RouteNames []string `json:"route_names"` Nickname string `json:"nickname"` @@ -15,8 +17,8 @@ var UserRouteNames = []string{ var AdminRouteNames = []string{"*"} type UserOauthItem struct { - ThirdType string `json:"third_type"` - Status int `json:"status"` + Op string `json:"op"` + Status int `json:"status"` } type GroupUsersPayload struct { diff --git a/http/response/api/user.go b/http/response/api/user.go index b48a4be..dc97370 100644 --- a/http/response/api/user.go +++ b/http/response/api/user.go @@ -29,6 +29,7 @@ type UserPayload struct { func (up *UserPayload) FromUser(user *model.User) *UserPayload { up.Name = user.Username + up.Email = user.Email up.IsAdmin = user.IsAdmin up.Status = int(user.Status) up.Info = map[string]interface{}{} diff --git a/model/oauth.go b/model/oauth.go index 35e7b96..816a01c 100644 --- a/model/oauth.go +++ b/model/oauth.go @@ -1,23 +1,110 @@ package model +import ( + "strconv" + "fmt" +) + + +const ( + OauthTypeGithub string = "github" + OauthTypeGoogle string = "google" + OauthTypeOidc string = "oidc" + OauthTypeWebauth string = "webauth" +) + + type Oauth struct { IdModel - Op string `json:"op"` - ClientId string `json:"client_id"` - ClientSecret string `json:"client_secret"` - RedirectUrl string `json:"redirect_url"` - AutoRegister *bool `json:"auto_register"` - Scopes string `json:"scopes"` - Issuer string `json:"issuer"` + Op string `json:"op"` + OauthType string `json:"oauth_type"` + ClientId string `json:"client_id"` + ClientSecret string `json:"client_secret"` + RedirectUrl string `json:"redirect_url"` + AutoRegister *bool `json:"auto_register"` + Scopes string `json:"scopes"` + Issuer string `json:"issuer"` TimeModel } -const ( - OauthTypeGithub = "github" - OauthTypeGoogle = "google" - OauthTypeOidc = "oidc" - OauthTypeWebauth = "webauth" -) +type OauthUser struct { + OpenId string `json:"open_id" gorm:"not null;index"` + Name string `json:"name"` + Username string `json:"username"` + Email string `json:"email"` + VerifiedEmail bool `json:"verified_email,omitempty"` +} + +func (ou *OauthUser) ToUser(user *User, overideUsername bool) { + if overideUsername { + user.Username = ou.Username + } + user.Email = ou.Email + user.Nickname = ou.Name + +} + + +type OauthUserBase struct { + Name string `json:"name"` + Email string `json:"email"` +} + + +type OidcUser struct { + OauthUserBase + Sub string `json:"sub"` + VerifiedEmail bool `json:"email_verified"` + PreferredUsername string `json:"preferred_username"` +} + +func (ou *OidcUser) ToOauthUser() *OauthUser { + return &OauthUser{ + OpenId: ou.Sub, + Name: ou.Name, + Username: ou.PreferredUsername, + Email: ou.Email, + VerifiedEmail: ou.VerifiedEmail, + } +} + +type GoogleUser struct { + OauthUserBase + FamilyName string `json:"family_name"` + GivenName string `json:"given_name"` + Id string `json:"id"` + Picture string `json:"picture"` + VerifiedEmail bool `json:"verified_email"` +} + +func (gu *GoogleUser) ToOauthUser() *OauthUser { + return &OauthUser{ + OpenId: gu.Id, + Name: fmt.Sprintf("%s %s", gu.GivenName, gu.FamilyName), + Username: gu.GivenName, + Email: gu.Email, + VerifiedEmail: gu.VerifiedEmail, + } +} + + +type GithubUser struct { + OauthUserBase + Id int `json:"id"` + Login string `json:"login"` +} + +func (gu *GithubUser) ToOauthUser() *OauthUser { + return &OauthUser{ + OpenId: strconv.Itoa(gu.Id), + Name: gu.Name, + Username: gu.Login, + Email: gu.Email, + VerifiedEmail: true, + } +} + + type OauthList struct { Oauths []*Oauth `json:"list"` diff --git a/model/user.go b/model/user.go index 1d83458..4be0049 100644 --- a/model/user.go +++ b/model/user.go @@ -1,8 +1,15 @@ package model +import ( + "fmt" + "gorm.io/gorm" +) + type User struct { IdModel Username string `json:"username" gorm:"default:'';not null;uniqueIndex"` + Email string `json:"email" gorm:"default:'';not null;uniqueIndex"` + // Email string `json:"email" ` Password string `json:"-" gorm:"default:'';not null;"` Nickname string `json:"nickname" gorm:"default:'';not null;"` Avatar string `json:"avatar" gorm:"default:'';not null;"` @@ -12,6 +19,15 @@ type User struct { TimeModel } +// BeforeSave 钩子用于确保 email 字段有合理的默认值 +func (u *User) BeforeSave(tx *gorm.DB) (err error) { + // 如果 email 为空,设置为默认值 + if u.Email == "" { + u.Email = fmt.Sprintf("%s@example.com", u.Username) + } + return nil +} + type UserList struct { Users []*User `json:"list,omitempty"` Pagination diff --git a/model/userThird.go b/model/userThird.go index 4e967b9..6f7f74a 100644 --- a/model/userThird.go +++ b/model/userThird.go @@ -2,11 +2,18 @@ package model type UserThird struct { IdModel - UserId uint `json:"user_id" gorm:"not null;index"` - OpenId string `json:"open_id" gorm:"not null;index"` - UnionId string `json:"union_id" gorm:"not null;"` - ThirdType string `json:"third_type" gorm:"not null;"` - ThirdEmail string `json:"third_email"` - ThirdName string `json:"third_name"` + UserId uint ` json:"user_id" gorm:"not null;index"` + OauthUser + // UnionId string `json:"union_id" gorm:"not null;"` + // OauthType string `json:"oauth_type" gorm:"not null;"` + OauthType string `json:"oauth_type"` + Op string `json:"op" gorm:"not null;"` TimeModel } + +func (u *UserThird) FromOauthUser(userId uint, oauthUser *OauthUser, oauthType string, op string) { + u.UserId = userId + u.OauthUser = *oauthUser + u.OauthType = oauthType + u.Op = op +} \ No newline at end of file diff --git a/service/oauth.go b/service/oauth.go index 2880fef..c28eec5 100644 --- a/service/oauth.go +++ b/service/oauth.go @@ -11,15 +11,20 @@ import ( "golang.org/x/oauth2/github" "golang.org/x/oauth2/google" "gorm.io/gorm" - "io" + // "io" "net/http" "net/url" "strconv" "strings" "sync" "time" + "fmt" ) + +type OauthService struct { +} + // Define a struct to parse the .well-known/openid-configuration response type OidcEndpoint struct { Issuer string `json:"issuer"` @@ -28,73 +33,6 @@ type OidcEndpoint struct { UserInfo string `json:"userinfo_endpoint"` } -type OauthService struct { -} - -type GithubUserdata struct { - AvatarUrl string `json:"avatar_url"` - Bio string `json:"bio"` - Blog string `json:"blog"` - Collaborators int `json:"collaborators"` - Company interface{} `json:"company"` - CreatedAt time.Time `json:"created_at"` - DiskUsage int `json:"disk_usage"` - Email interface{} `json:"email"` - EventsUrl string `json:"events_url"` - Followers int `json:"followers"` - FollowersUrl string `json:"followers_url"` - Following int `json:"following"` - FollowingUrl string `json:"following_url"` - GistsUrl string `json:"gists_url"` - GravatarId string `json:"gravatar_id"` - Hireable interface{} `json:"hireable"` - HtmlUrl string `json:"html_url"` - Id int `json:"id"` - Location interface{} `json:"location"` - Login string `json:"login"` - Name string `json:"name"` - NodeId string `json:"node_id"` - NotificationEmail interface{} `json:"notification_email"` - OrganizationsUrl string `json:"organizations_url"` - OwnedPrivateRepos int `json:"owned_private_repos"` - Plan struct { - Collaborators int `json:"collaborators"` - Name string `json:"name"` - PrivateRepos int `json:"private_repos"` - Space int `json:"space"` - } `json:"plan"` - PrivateGists int `json:"private_gists"` - PublicGists int `json:"public_gists"` - PublicRepos int `json:"public_repos"` - ReceivedEventsUrl string `json:"received_events_url"` - ReposUrl string `json:"repos_url"` - SiteAdmin bool `json:"site_admin"` - StarredUrl string `json:"starred_url"` - SubscriptionsUrl string `json:"subscriptions_url"` - TotalPrivateRepos int `json:"total_private_repos"` - //TwitterUsername interface{} `json:"twitter_username"` - TwoFactorAuthentication bool `json:"two_factor_authentication"` - Type string `json:"type"` - UpdatedAt time.Time `json:"updated_at"` - Url string `json:"url"` -} -type GoogleUserdata struct { - Email string `json:"email"` - FamilyName string `json:"family_name"` - GivenName string `json:"given_name"` - Id string `json:"id"` - Name string `json:"name"` - Picture string `json:"picture"` - VerifiedEmail bool `json:"verified_email"` -} -type OidcUserdata struct { - Sub string `json:"sub"` - Email string `json:"email"` - VerifiedEmail bool `json:"email_verified"` - Name string `json:"name"` - PreferredUsername string `json:"preferred_username"` -} - type OauthCacheItem struct { UserId uint `json:"user_id"` Id string `json:"id"` //rustdesk的设备ID @@ -104,9 +42,19 @@ type OauthCacheItem struct { DeviceName string `json:"device_name"` DeviceOs string `json:"device_os"` DeviceType string `json:"device_type"` - ThirdOpenId string `json:"third_open_id"` - ThirdName string `json:"third_name"` - ThirdEmail string `json:"third_email"` + OpenId string `json:"open_id"` + Username string `json:"username"` + Name string `json:"name"` + Email string `json:"email"` +} + +func (oci *OauthCacheItem) ToOauthUser() *model.OauthUser { + return &model.OauthUser{ + OpenId: oci.OpenId, + Username: oci.Username, + Name: oci.Name, + Email: oci.Email, + } } var OauthCache = &sync.Map{} @@ -116,6 +64,24 @@ const ( OauthActionTypeBind = "bind" ) +func (oa *OauthCacheItem) UpdateFromOauthUser(oauthUser *model.OauthUser) { + oa.OpenId = oauthUser.OpenId + oa.Username = oauthUser.Username + oa.Name = oauthUser.Name + oa.Email = oauthUser.Email +} + +// Validate the oauth type +func (os *OauthService) ValidateOauthType(oauthType string) error { + switch oauthType { + case model.OauthTypeGithub, model.OauthTypeGoogle, model.OauthTypeOidc, model.OauthTypeWebauth: + return nil + default: + return errors.New("invalid Oauth type") + } +} + + func (os *OauthService) GetOauthCache(key string) *OauthCacheItem { v, ok := OauthCache.Load(key) if !ok { @@ -141,12 +107,12 @@ func (os *OauthService) DeleteOauthCache(key string) { func (os *OauthService) BeginAuth(op string) (error error, code, url string) { code = utils.RandomString(10) + strconv.FormatInt(time.Now().Unix(), 10) - if op == model.OauthTypeWebauth { + if op == string(model.OauthTypeWebauth) { url = global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/" + code //url = "http://localhost:8888/_admin/#/oauth/" + code return nil, code, url } - err, conf := os.GetOauthConfig(op) + err, _, conf := os.GetOauthConfig(op) if err == nil { return err, code, conf.AuthCodeURL(code) } @@ -155,7 +121,7 @@ func (os *OauthService) BeginAuth(op string) (error error, code, url string) { } // Method to fetch OIDC configuration dynamically -func FetchOidcConfig(issuer string) (error, OidcEndpoint) { +func (os *OauthService) FetchOidcEndpoint(issuer string) (error, OidcEndpoint) { configURL := strings.TrimSuffix(issuer, "/") + "/.well-known/openid-configuration" // Get the HTTP client (with or without proxy based on configuration) @@ -179,76 +145,55 @@ func FetchOidcConfig(issuer string) (error, OidcEndpoint) { return nil, endpoint } -// GetOauthConfig retrieves the OAuth2 configuration based on the provider type -func (os *OauthService) GetOauthConfig(op string) (error, *oauth2.Config) { - switch op { - case model.OauthTypeGithub: - return os.getGithubConfig() - case model.OauthTypeGoogle: - return os.getGoogleConfig() - case model.OauthTypeOidc: - return os.getOidcConfig() - default: - return errors.New("unsupported OAuth type"), nil +func (os *OauthService) FetchOidcEndpointByOp(op string) (error, OidcEndpoint) { + oauthInfo := os.InfoByOp(op) + if oauthInfo.Issuer == "" { + return errors.New("issuer is empty"), OidcEndpoint{} } + return os.FetchOidcEndpoint(oauthInfo.Issuer) } -// Helper function to get GitHub OAuth2 configuration -func (os *OauthService) getGithubConfig() (error, *oauth2.Config) { - g := os.InfoByOp(model.OauthTypeGithub) - if g.Id == 0 || g.ClientId == "" || g.ClientSecret == "" || g.RedirectUrl == "" { - return errors.New("ConfigNotFound"), nil - } - return nil, &oauth2.Config{ - ClientID: g.ClientId, - ClientSecret: g.ClientSecret, - RedirectURL: g.RedirectUrl, - Endpoint: github.Endpoint, - Scopes: []string{"read:user", "user:email"}, - } -} - -// Helper function to get Google OAuth2 configuration -func (os *OauthService) getGoogleConfig() (error, *oauth2.Config) { - g := os.InfoByOp(model.OauthTypeGoogle) - if g.Id == 0 || g.ClientId == "" || g.ClientSecret == "" || g.RedirectUrl == "" { - return errors.New("ConfigNotFound"), nil - } - return nil, &oauth2.Config{ - ClientID: g.ClientId, - ClientSecret: g.ClientSecret, - RedirectURL: g.RedirectUrl, - Endpoint: google.Endpoint, - Scopes: []string{"https://www.googleapis.com/auth/userinfo.profile", "https://www.googleapis.com/auth/userinfo.email"}, - } -} - -// Helper function to get OIDC OAuth2 configuration -func (os *OauthService) getOidcConfig() (error, *oauth2.Config) { - g := os.InfoByOp(model.OauthTypeOidc) - if g.Id == 0 || g.ClientId == "" || g.ClientSecret == "" || g.RedirectUrl == "" || g.Issuer == "" { - return errors.New("ConfigNotFound"), nil - } - - // Set scopes - scopes := strings.TrimSpace(g.Scopes) - if scopes == "" { - scopes = "openid,profile,email" - } - scopeList := strings.Split(scopes, ",") - err, endpoint := FetchOidcConfig(g.Issuer) +// GetOauthConfig retrieves the OAuth2 configuration based on the provider name +func (os *OauthService) GetOauthConfig(op string) (err error, oauthType string, oauthConfig *oauth2.Config) { + err, oauthType, oauthConfig = os.getOauthConfigGeneral(op) if err != nil { - return err, nil + return err, oauthType, nil } - return nil, &oauth2.Config{ + // Maybe should validate the oauthConfig here + switch oauthType { + case model.OauthTypeGithub: + oauthConfig.Endpoint = github.Endpoint + oauthConfig.Scopes = []string{"read:user", "user:email"} + case model.OauthTypeGoogle: + oauthConfig.Endpoint = google.Endpoint + oauthConfig.Scopes = []string{"https://www.googleapis.com/auth/userinfo.profile", "https://www.googleapis.com/auth/userinfo.email"} + case model.OauthTypeOidc: + err, endpoint := os.FetchOidcEndpointByOp(op) + if err != nil { + return err,oauthType, nil + } + oauthConfig.Endpoint = oauth2.Endpoint{AuthURL: endpoint.AuthURL,TokenURL: endpoint.TokenURL,} + oauthConfig.Scopes = os.getScopesByOp(op) + default: + return errors.New("unsupported OAuth type"), oauthType, nil + } + return nil, oauthType, oauthConfig +} + +// GetOauthConfig retrieves the OAuth2 configuration based on the provider name +func (os *OauthService) getOauthConfigGeneral(op string) (err error, oauthType string, oauthConfig *oauth2.Config) { + g := os.InfoByOp(op) + if g.Id == 0 || g.ClientId == "" || g.ClientSecret == "" { + return errors.New("ConfigNotFound"), "", nil + } + // If the redirect URL is empty, use the default redirect URL + if g.RedirectUrl == "" { + g.RedirectUrl = global.Config.Rustdesk.ApiServer + "/api/oidc/callback" + } + return nil, g.OauthType, &oauth2.Config{ ClientID: g.ClientId, ClientSecret: g.ClientSecret, RedirectURL: g.RedirectUrl, - Endpoint: oauth2.Endpoint{ - AuthURL: endpoint.AuthURL, - TokenURL: endpoint.TokenURL, - }, - Scopes: scopeList, } } @@ -272,194 +217,161 @@ func getHTTPClientWithProxy() *http.Client { return http.DefaultClient } -func (os *OauthService) GithubCallback(code string) (error error, userData *GithubUserdata) { - err, oauthConfig := os.GetOauthConfig(model.OauthTypeGithub) +func (os *OauthService) callbackBase(op string, code string, userEndpoint string, userData interface{}) error { + err, oauthType, oauthConfig := os.GetOauthConfig(op) if err != nil { - return err, nil + return err + } + + // If the OAuth type is OIDC and the user endpoint is empty + // Fetch the OIDC configuration and get the user endpoint + if oauthType == model.OauthTypeOidc && userEndpoint == "" { + err, endpoint := os.FetchOidcEndpointByOp(op) + if err != nil { + global.Logger.Warn("failed fetching OIDC configuration: ", err) + return errors.New("FetchOidcEndpointError") + } + userEndpoint = endpoint.UserInfo } - // 使用代理配置创建 HTTP 客户端 + // 设置代理客户端 httpClient := getHTTPClientWithProxy() ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient) + // 使用 code 换取 token token, err := oauthConfig.Exchange(ctx, code) if err != nil { global.Logger.Warn("oauthConfig.Exchange() failed: ", err) - error = errors.New("GetOauthTokenError") - return + return errors.New("GetOauthTokenError") } - // 使用带有代理的 HTTP 客户端获取用户信息 + // 获取用户信息 client := oauthConfig.Client(ctx, token) - resp, err := client.Get("https://api.github.com/user") + resp, err := client.Get(userEndpoint) if err != nil { global.Logger.Warn("failed getting user info: ", err) - error = errors.New("GetOauthUserInfoError") - return + return errors.New("GetOauthUserInfoError") } - defer func(Body io.ReadCloser) { - err := Body.Close() - if err != nil { - global.Logger.Warn("failed closing response body: ", err) + defer func() { + if closeErr := resp.Body.Close(); closeErr != nil { + global.Logger.Warn("failed closing response body: ", closeErr) } - }(resp.Body) + }() // 解析用户信息 - if err = json.NewDecoder(resp.Body).Decode(&userData); err != nil { + if err = json.NewDecoder(resp.Body).Decode(userData); err != nil { global.Logger.Warn("failed decoding user info: ", err) - error = errors.New("DecodeOauthUserInfoError") - return + return errors.New("DecodeOauthUserInfoError") } - return + + return nil } -func (os *OauthService) GoogleCallback(code string) (error error, userData *GoogleUserdata) { - err, oauthConfig := os.GetOauthConfig(model.OauthTypeGoogle) - if err != nil { +// githubCallback github回调 +func (os *OauthService) githubCallback(code string) (error, *model.OauthUser) { + var user = &model.GithubUser{} + const userEndpoint = "https://api.github.com/user" + if err := os.callbackBase(model.OauthTypeGithub, code, userEndpoint, user); err != nil { return err, nil } - - // 使用代理配置创建 HTTP 客户端 - httpClient := getHTTPClientWithProxy() - ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient) - - token, err := oauthConfig.Exchange(ctx, code) - if err != nil { - global.Logger.Warn("oauthConfig.Exchange() failed: ", err) - error = errors.New("GetOauthTokenError") - return - } - - // 使用带有代理的 HTTP 客户端获取用户信息 - client := oauthConfig.Client(ctx, token) - resp, err := client.Get("https://www.googleapis.com/oauth2/v2/userinfo") - if err != nil { - global.Logger.Warn("failed getting user info: ", err) - error = errors.New("GetOauthUserInfoError") - return - } - defer func(Body io.ReadCloser) { - err := Body.Close() - if err != nil { - global.Logger.Warn("failed closing response body: ", err) - } - }(resp.Body) - - // 解析用户信息 - if err = json.NewDecoder(resp.Body).Decode(&userData); err != nil { - global.Logger.Warn("failed decoding user info: ", err) - error = errors.New("DecodeOauthUserInfoError") - return - } - return + return nil, user.ToOauthUser() } -func (os *OauthService) OidcCallback(code string) (error error, userData *OidcUserdata) { - err, oauthConfig := os.GetOauthConfig(model.OauthTypeOidc) - if err != nil { +// googleCallback google回调 +func (os *OauthService) googleCallback(code string) (error, *model.OauthUser) { + var user = &model.GoogleUser{} + const userEndpoint = "https://www.googleapis.com/oauth2/v2/userinfo" + if err := os.callbackBase(model.OauthTypeGoogle, code, userEndpoint, user); err != nil { return err, nil } - // 使用代理配置创建 HTTP 客户端 - httpClient := getHTTPClientWithProxy() - ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient) - - token, err := oauthConfig.Exchange(ctx, code) - if err != nil { - global.Logger.Warn("oauthConfig.Exchange() failed: ", err) - error = errors.New("GetOauthTokenError") - return - } - - // 使用带有代理的 HTTP 客户端获取用户信息 - client := oauthConfig.Client(ctx, token) - g := os.InfoByOp(model.OauthTypeOidc) - err, endpoint := FetchOidcConfig(g.Issuer) - if err != nil { - global.Logger.Warn("failed fetching OIDC configuration: ", err) - error = errors.New("FetchOidcConfigError") - return - } - resp, err := client.Get(endpoint.UserInfo) - if err != nil { - global.Logger.Warn("failed getting user info: ", err) - error = errors.New("GetOauthUserInfoError") - return - } - defer func(Body io.ReadCloser) { - err := Body.Close() - if err != nil { - global.Logger.Warn("failed closing response body: ", err) - } - }(resp.Body) - - // 解析用户信息 - if err = json.NewDecoder(resp.Body).Decode(&userData); err != nil { - global.Logger.Warn("failed decoding user info: ", err) - error = errors.New("DecodeOauthUserInfoError") - return - } - return + return nil, user.ToOauthUser() } -func (os *OauthService) UserThirdInfo(op, openid string) *model.UserThird { +// oidcCallback oidc回调, 通过code获取用户信息 +func (os *OauthService) oidcCallback(code string, op string) (error, *model.OauthUser,) { + var user = &model.OidcUser{} + if err := os.callbackBase(op, code, "", user); err != nil { + return err, nil + } + return nil, user.ToOauthUser() +} + +// Callback: Get user information by code and op(Oauth provider) +func (os *OauthService) Callback(code string, op string) (err error, oauthUser *model.OauthUser) { + oauthType := os.GetTypeByOp(op) + if err = os.ValidateOauthType(oauthType); err != nil { + return err, nil + } + + switch oauthType { + case model.OauthTypeGithub: + err, oauthUser = os.githubCallback(code) + case model.OauthTypeGoogle: + err, oauthUser = os.googleCallback(code) + case model.OauthTypeOidc: + err, oauthUser = os.oidcCallback(code, op) + default: + return errors.New("unsupported OAuth type"), nil + } + + return err, oauthUser +} + + +func (os *OauthService) UserThirdInfo(op string, openId string) *model.UserThird { ut := &model.UserThird{} - global.DB.Where("open_id = ? and third_type = ?", openid, op).First(ut) + global.DB.Where("open_id = ? and op = ?", openId, op).First(ut) return ut } -func (os *OauthService) BindGithubUser(openid, username string, userId uint) error { - return os.BindOauthUser(model.OauthTypeGithub, openid, username, userId) -} - -func (os *OauthService) BindGoogleUser(email, username string, userId uint) error { - return os.BindOauthUser(model.OauthTypeGoogle, email, username, userId) -} - -func (os *OauthService) BindOidcUser(sub, username string, userId uint) error { - return os.BindOauthUser(model.OauthTypeOidc, sub, username, userId) -} - -func (os *OauthService) BindOauthUser(thirdType, openid, username string, userId uint) error { - utr := &model.UserThird{ - OpenId: openid, - ThirdType: thirdType, - ThirdName: username, - UserId: userId, - } +// BindOauthUser: Bind third party account +func (os *OauthService) BindOauthUser(userId uint, oauthUser *model.OauthUser, op string) error { + utr := &model.UserThird{} + oauthType := os.GetTypeByOp(op) + utr.FromOauthUser(userId, oauthUser, oauthType, op) return global.DB.Create(utr).Error } -func (os *OauthService) UnBindGithubUser(userid uint) error { - return os.UnBindThird(model.OauthTypeGithub, userid) +// UnBindOauthUser: Unbind third party account +func (os *OauthService) UnBindOauthUser(userId uint, op string) error { + return os.UnBindThird(op, userId) } -func (os *OauthService) UnBindGoogleUser(userid uint) error { - return os.UnBindThird(model.OauthTypeGoogle, userid) -} -func (os *OauthService) UnBindOidcUser(userid uint) error { - return os.UnBindThird(model.OauthTypeOidc, userid) -} -func (os *OauthService) UnBindThird(thirdType string, userid uint) error { - return global.DB.Where("user_id = ? and third_type = ?", userid, thirdType).Delete(&model.UserThird{}).Error + +// UnBindThird: Unbind third party account +func (os *OauthService) UnBindThird(op string, userId uint) error { + return global.DB.Where("user_id = ? and op = ?", userId, op).Delete(&model.UserThird{}).Error } // DeleteUserByUserId: When user is deleted, delete all third party bindings -func (os *OauthService) DeleteUserByUserId(userid uint) error { - return global.DB.Where("user_id = ?", userid).Delete(&model.UserThird{}).Error +func (os *OauthService) DeleteUserByUserId(userId uint) error { + return global.DB.Where("user_id = ?", userId).Delete(&model.UserThird{}).Error } -// InfoById 根据id取用户信息 +// InfoById 根据id获取Oauth信息 func (os *OauthService) InfoById(id uint) *model.Oauth { - u := &model.Oauth{} - global.DB.Where("id = ?", id).First(u) - return u + oauthInfo := &model.Oauth{} + global.DB.Where("id = ?", id).First(oauthInfo) + return oauthInfo } -// InfoByOp 根据op取用户信息 +// InfoByOp 根据op获取Oauth信息 func (os *OauthService) InfoByOp(op string) *model.Oauth { - u := &model.Oauth{} - global.DB.Where("op = ?", op).First(u) - return u + oauthInfo := &model.Oauth{} + global.DB.Where("op = ?", op).First(oauthInfo) + return oauthInfo } + +// Helper function to get scopes by operation +func (os *OauthService) getScopesByOp(op string) []string { + scopes := os.InfoByOp(op).Scopes + scopes = strings.TrimSpace(scopes) // 这里使用 `=` 而不是 `:=`,避免重新声明变量 + if scopes == "" { + scopes = "openid,profile,email" + } + return strings.Split(scopes, ",") +} + + func (os *OauthService) List(page, pageSize uint, where func(tx *gorm.DB)) (res *model.OauthList) { res = &model.OauthList{} res.Page = int64(page) @@ -474,16 +386,41 @@ func (os *OauthService) List(page, pageSize uint, where func(tx *gorm.DB)) (res return } +// GetTypeByOp 根据op获取OauthType +func (os *OauthService) GetTypeByOp(op string) string { + oauthInfo := &model.Oauth{} + if global.DB.Where("op = ?", op).First(oauthInfo).Error != nil { + return "" + } + return oauthInfo.OauthType +} + +func (os *OauthService) ValidateOauthProvider(op string) error { + oauthInfo := &model.Oauth{} + // 使用 Gorm 的 Take 方法查找符合条件的记录 + if err := global.DB.Where("op = ?", op).Take(oauthInfo).Error; err != nil { + return fmt.Errorf("OAuth provider with op '%s' not found: %w", op, err) + } + return nil +} + // Create 创建 -func (os *OauthService) Create(u *model.Oauth) error { - res := global.DB.Create(u).Error +func (os *OauthService) Create(oauthInfo *model.Oauth) error { + res := global.DB.Create(oauthInfo).Error return res } -func (os *OauthService) Delete(u *model.Oauth) error { - return global.DB.Delete(u).Error +func (os *OauthService) Delete(oauthInfo *model.Oauth) error { + return global.DB.Delete(oauthInfo).Error } // Update 更新 -func (os *OauthService) Update(u *model.Oauth) error { - return global.DB.Model(u).Updates(u).Error +func (os *OauthService) Update(oauthInfo *model.Oauth) error { + return global.DB.Model(oauthInfo).Updates(oauthInfo).Error } + +// GetOauthProviders 获取所有的provider +func (os *OauthService) GetOauthProviders() []string { + var res []string + global.DB.Model(&model.Oauth{}).Pluck("op", &res) + return res +} \ No newline at end of file diff --git a/service/user.go b/service/user.go index fc3ffa5..5e50af5 100644 --- a/service/user.go +++ b/service/user.go @@ -21,12 +21,20 @@ func (us *UserService) InfoById(id uint) *model.User { global.DB.Where("id = ?", id).First(u) return u } +// InfoByUsername 根据用户名取用户信息 func (us *UserService) InfoByUsername(un string) *model.User { u := &model.User{} global.DB.Where("username = ?", un).First(u) return u } +// InfoByEmail 根据邮箱取用户信息 +func (us *UserService) InfoByEmail(email string) *model.User { + u := &model.User{} + global.DB.Where("email = ?", email).First(u) + return u +} + // InfoByOpenid 根据openid取用户信息 func (us *UserService) InfoByOpenid(openid string) *model.User { u := &model.User{} @@ -216,24 +224,9 @@ func (us *UserService) RouteNames(u *model.User) []string { return adResp.UserRouteNames } -// InfoByGithubId 根据githubid取用户信息 -func (us *UserService) InfoByGithubId(githubId string) *model.User { - return us.InfoByOauthId(model.OauthTypeGithub, githubId) -} - -// InfoByGoogleEmail 根据googleid取用户信息 -func (us *UserService) InfoByGoogleEmail(email string) *model.User { - return us.InfoByOauthId(model.OauthTypeGithub, email) -} - -// InfoByOidcSub 根据oidc取用户信息 -func (us *UserService) InfoByOidcSub(sub string) *model.User { - return us.InfoByOauthId(model.OauthTypeOidc, sub) -} - -// InfoByOauthId 根据oauth取用户信息 -func (us *UserService) InfoByOauthId(thirdType, uid string) *model.User { - ut := AllService.OauthService.UserThirdInfo(thirdType, uid) +// InfoByOauthId 根据oauth的name和openId取用户信息 +func (us *UserService) InfoByOauthId(op string, openId string) *model.User { + ut := AllService.OauthService.UserThirdInfo(op, openId) if ut.Id == 0 { return nil } @@ -244,55 +237,40 @@ func (us *UserService) InfoByOauthId(thirdType, uid string) *model.User { return u } -// RegisterByGithub 注册 -func (us *UserService) RegisterByGithub(githubName string, githubId string) *model.User { - return us.RegisterByOauth(model.OauthTypeGithub, githubName, githubId) -} - -// RegisterByGoogle 注册 -func (us *UserService) RegisterByGoogle(name string, email string) *model.User { - return us.RegisterByOauth(model.OauthTypeGoogle, name, email) -} - -// RegisterByOidc 注册, use PreferredUsername as username, sub as openid -func (us *UserService) RegisterByOidc(PreferredUsername string, sub string) *model.User { - return us.RegisterByOauth(model.OauthTypeOidc, PreferredUsername, sub) -} - // RegisterByOauth 注册 -func (us *UserService) RegisterByOauth(thirdType, thirdName, uid string) *model.User { +func (us *UserService) RegisterByOauth(oauthUser *model.OauthUser , op string) *model.User { global.Lock.Lock("registerByOauth") defer global.Lock.UnLock("registerByOauth") - ut := AllService.OauthService.UserThirdInfo(thirdType, uid) + ut := AllService.OauthService.UserThirdInfo(op, oauthUser.OpenId) if ut.Id != 0 { - u := &model.User{} - global.DB.Where("id = ?", ut.UserId).First(u) - return u + return us.InfoById(ut.UserId) } - + //check if this email has been registered + email := oauthUser.Email + oauthType := AllService.OauthService.GetTypeByOp(op) + user := us.InfoByEmail(email) tx := global.DB.Begin() - ut = &model.UserThird{ - OpenId: uid, - ThirdName: thirdName, - ThirdType: thirdType, + if user.Id != 0 { + ut.FromOauthUser(user.Id, oauthUser, oauthType, op) + } else { + ut = &model.UserThird{} + ut.FromOauthUser(0, oauthUser, oauthType, op) + usernameUnique := us.GenerateUsernameByOauth(oauthUser.Username) + user := &model.User{ + Username: usernameUnique, + GroupId: 1, + } + oauthUser.ToUser(user, false) + tx.Create(user) + if user.Id == 0 { + tx.Rollback() + return user + } + ut.UserId = user.Id } - - username := us.GenerateUsernameByOauth(thirdName) - u := &model.User{ - Username: username, - GroupId: 1, - } - tx.Create(u) - if u.Id == 0 { - tx.Rollback() - return u - } - - ut.UserId = u.Id tx.Create(ut) - tx.Commit() - return u + return user } // GenerateUsernameByOauth 生成用户名 @@ -314,7 +292,7 @@ func (us *UserService) UserThirdsByUserId(userId uint) (res []*model.UserThird) func (us *UserService) UserThirdInfo(userId uint, op string) *model.UserThird { ut := &model.UserThird{} - global.DB.Where("user_id = ? and third_type = ?", userId, op).First(ut) + global.DB.Where("user_id = ? and op = ?", userId, op).First(ut) return ut }