From 485ae54e9e02b0c1df024c5e2ab244e71017800c Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Sat, 2 Nov 2024 04:01:28 +0800 Subject: [PATCH 01/26] re-construct oauth --- http/controller/admin/login.go | 2 + http/controller/admin/oauth.go | 42 +-- http/controller/admin/user.go | 4 +- http/controller/api/login.go | 20 +- http/controller/api/ouath.go | 106 +++---- http/request/admin/oauth.go | 39 ++- http/request/admin/user.go | 17 +- http/response/admin/user.go | 6 +- http/response/api/user.go | 1 + model/oauth.go | 113 +++++++- model/user.go | 16 ++ model/userThird.go | 19 +- service/oauth.go | 501 ++++++++++++++------------------- service/user.go | 96 +++---- 14 files changed, 491 insertions(+), 491 deletions(-) diff --git a/http/controller/admin/login.go b/http/controller/admin/login.go index ad25393..cd879cb 100644 --- a/http/controller/admin/login.go +++ b/http/controller/admin/login.go @@ -63,6 +63,8 @@ func (ct *Login) Login(c *gin.Context) { response.Success(c, &adResp.LoginPayload{ Token: ut.Token, Username: u.Username, + Email: u.Email, + Avatar: u.Avatar, RouteNames: service.AllService.UserService.RouteNames(u), Nickname: u.Nickname, }) diff --git a/http/controller/admin/oauth.go b/http/controller/admin/oauth.go index 3444f3e..ade2f3c 100644 --- a/http/controller/admin/oauth.go +++ b/http/controller/admin/oauth.go @@ -5,7 +5,6 @@ import ( "Gwen/http/request/admin" adminReq "Gwen/http/request/admin" "Gwen/http/response" - "Gwen/model" "Gwen/service" "github.com/gin-gonic/gin" "strconv" @@ -96,21 +95,23 @@ func (o *Oauth) BindConfirm(c *gin.Context) { response.Fail(c, 101, response.TranslateMsg(c, "ParamsError")) return } - v := service.AllService.OauthService.GetOauthCache(j.Code) - if v == nil { + oauthService := service.AllService.OauthService + oauthCache := oauthService.GetOauthCache(j.Code) + if oauthCache == nil { response.Fail(c, 101, response.TranslateMsg(c, "OauthExpired")) return } - u := service.AllService.UserService.CurUser(c) - err = service.AllService.OauthService.BindOauthUser(v.Op, v.ThirdOpenId, v.ThirdName, u.Id) + oauthUser := oauthCache.ToOauthUser() + user := service.AllService.UserService.CurUser(c) + err = oauthService.BindOauthUser(user.Id, oauthUser, oauthCache.Op) if err != nil { response.Fail(c, 101, response.TranslateMsg(c, "BindFail")) return } - v.UserId = u.Id - service.AllService.OauthService.SetOauthCache(j.Code, v, 0) - response.Success(c, v) + oauthCache.UserId = user.Id + oauthService.SetOauthCache(j.Code, oauthCache, 0) + response.Success(c, oauthCache) } func (o *Oauth) Unbind(c *gin.Context) { @@ -126,28 +127,11 @@ func (o *Oauth) Unbind(c *gin.Context) { response.Fail(c, 101, response.TranslateMsg(c, "ItemNotFound")) return } - if f.Op == model.OauthTypeGithub { - err = service.AllService.OauthService.UnBindGithubUser(u.Id) - if err != nil { - response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed")+err.Error()) - return - } + err = service.AllService.OauthService.UnBindOauthUser(u.Id, f.Op) + if err != nil { + response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed")+err.Error()) + return } - if f.Op == model.OauthTypeGoogle { - err = service.AllService.OauthService.UnBindGoogleUser(u.Id) - if err != nil { - response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed")+err.Error()) - return - } - } - if f.Op == model.OauthTypeOidc { - err = service.AllService.OauthService.UnBindOidcUser(u.Id) - if err != nil { - response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed")+err.Error()) - return - } - } - response.Success(c, nil) } diff --git a/http/controller/admin/user.go b/http/controller/admin/user.go index 67f71b8..76d0aa6 100644 --- a/http/controller/admin/user.go +++ b/http/controller/admin/user.go @@ -286,10 +286,10 @@ func (ct *User) MyOauth(c *gin.Context) { var res []*adResp.UserOauthItem for _, oa := range oal.Oauths { item := &adResp.UserOauthItem{ - ThirdType: oa.Op, + Op: oa.Op, } for _, ut := range uts { - if ut.ThirdType == oa.Op { + if ut.Op == oa.Op { item.Status = 1 break } diff --git a/http/controller/api/login.go b/http/controller/api/login.go index f3d9481..6c0b95d 100644 --- a/http/controller/api/login.go +++ b/http/controller/api/login.go @@ -83,22 +83,10 @@ func (l *Login) Login(c *gin.Context) { // @Failure 500 {object} response.ErrorResponse // @Router /login-options [get] func (l *Login) LoginOptions(c *gin.Context) { - oauthOks := []string{} - err, _ := service.AllService.OauthService.GetOauthConfig(model.OauthTypeGithub) - if err == nil { - oauthOks = append(oauthOks, model.OauthTypeGithub) - } - err, _ = service.AllService.OauthService.GetOauthConfig(model.OauthTypeGoogle) - if err == nil { - oauthOks = append(oauthOks, model.OauthTypeGoogle) - } - err, _ = service.AllService.OauthService.GetOauthConfig(model.OauthTypeOidc) - if err == nil { - oauthOks = append(oauthOks, model.OauthTypeOidc) - } - oauthOks = append(oauthOks, model.OauthTypeWebauth) + ops := service.AllService.OauthService.GetOauthProviders() + ops = append(ops, model.OauthTypeWebauth) var oidcItems []map[string]string - for _, v := range oauthOks { + for _, v := range ops { oidcItems = append(oidcItems, map[string]string{"name": v}) } common, err := json.Marshal(oidcItems) @@ -108,7 +96,7 @@ func (l *Login) LoginOptions(c *gin.Context) { } var res []string res = append(res, "common-oidc/"+string(common)) - for _, v := range oauthOks { + for _, v := range ops { res = append(res, "oidc/"+v) } c.JSON(http.StatusOK, res) diff --git a/http/controller/api/ouath.go b/http/controller/api/ouath.go index 96f7a80..5367117 100644 --- a/http/controller/api/ouath.go +++ b/http/controller/api/ouath.go @@ -9,8 +9,6 @@ import ( "Gwen/service" "github.com/gin-gonic/gin" "net/http" - "strconv" - "strings" ) type Oauth struct { @@ -32,13 +30,17 @@ func (o *Oauth) OidcAuth(c *gin.Context) { response.Error(c, response.TranslateMsg(c, "ParamsError")+err.Error()) return } - //fmt.Println(f) - if f.Op != model.OauthTypeWebauth && f.Op != model.OauthTypeGoogle && f.Op != model.OauthTypeGithub && f.Op != model.OauthTypeOidc { - response.Error(c, response.TranslateMsg(c, "ParamsError")) + + oauthService := service.AllService.OauthService + err = oauthService.ValidateOauthProvider(f.Op) + if err != nil { + response.Error(c, response.TranslateMsg(c, err.Error())) return } - err, code, url := service.AllService.OauthService.BeginAuth(f.Op) + var code string + var url string + err, code, url = oauthService.BeginAuth(f.Op) if err != nil { response.Error(c, response.TranslateMsg(c, err.Error())) return @@ -149,70 +151,43 @@ func (o *Oauth) OauthCallback(c *gin.Context) { c.String(http.StatusInternalServerError, response.TranslateParamMsg(c, "ParamIsEmpty", "state")) return } - cacheKey := state + oauthService := service.AllService.OauthService //从缓存中获取 - v := service.AllService.OauthService.GetOauthCache(cacheKey) - if v == nil { + oauthCache := oauthService.GetOauthCache(cacheKey) + if oauthCache == nil { c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthExpired")) return } - - ty := v.Op - ac := v.Action - var u *model.User - openid := "" - thirdName := "" - //fmt.Println("ty ac ", ty, ac) - - if ty == model.OauthTypeGithub { - code := c.Query("code") - err, userData := service.AllService.OauthService.GithubCallback(code) - if err != nil { - c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthFailed")+response.TranslateMsg(c, err.Error())) - return - } - openid = strconv.Itoa(userData.Id) - thirdName = userData.Login - } else if ty == model.OauthTypeGoogle { - code := c.Query("code") - err, userData := service.AllService.OauthService.GoogleCallback(code) - if err != nil { - c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthFailed")+response.TranslateMsg(c, err.Error())) - return - } - openid = userData.Email - //将空格替换成_ - thirdName = strings.Replace(userData.Name, " ", "_", -1) - } else if ty == model.OauthTypeOidc { - code := c.Query("code") - err, userData := service.AllService.OauthService.OidcCallback(code) - if err != nil { - c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthFailed")+response.TranslateMsg(c, err.Error())) - return - } - openid = userData.Sub - thirdName = userData.PreferredUsername - } else { - c.String(http.StatusInternalServerError, response.TranslateMsg(c, "ParamsError")) + op := oauthCache.Op + action := oauthCache.Action + var user *model.User + // 获取用户信息 + code := c.Query("code") + err, oauthUser := oauthService.Callback(code, op) + if err != nil { + c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthFailed")+response.TranslateMsg(c, err.Error())) return } - if ac == service.OauthActionTypeBind { + userId := oauthCache.UserId + openid := oauthUser.OpenId + if action == service.OauthActionTypeBind { //fmt.Println("bind", ty, userData) - utr := service.AllService.OauthService.UserThirdInfo(ty, openid) + // 检查此openid是否已经绑定过 + utr := oauthService.UserThirdInfo(op, openid) if utr.UserId > 0 { c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthHasBindOtherUser")) return } //绑定 - u = service.AllService.UserService.InfoById(v.UserId) - if u == nil { + user = service.AllService.UserService.InfoById(userId) + if user == nil { c.String(http.StatusInternalServerError, response.TranslateMsg(c, "ItemNotFound")) return } //绑定 - err := service.AllService.OauthService.BindOauthUser(ty, openid, thirdName, v.UserId) + err := oauthService.BindOauthUser(userId, oauthUser, op) if err != nil { c.String(http.StatusInternalServerError, response.TranslateMsg(c, "BindFail")) return @@ -220,42 +195,41 @@ func (o *Oauth) OauthCallback(c *gin.Context) { c.String(http.StatusOK, response.TranslateMsg(c, "BindSuccess")) return - } else if ac == service.OauthActionTypeLogin { + } else if action == service.OauthActionTypeLogin { //登录 - if v.UserId != 0 { + if userId != 0 { c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthHasBeenSuccess")) return } - u = service.AllService.UserService.InfoByGithubId(openid) - if u == nil { - oa := service.AllService.OauthService.InfoByOp(ty) - if !*oa.AutoRegister { + user = service.AllService.UserService.InfoByOauthId(op, openid) + if user == nil { + oauthConfig := oauthService.InfoByOp(op) + if !*oauthConfig.AutoRegister { //c.String(http.StatusInternalServerError, "还未绑定用户,请先绑定") - v.ThirdName = thirdName - v.ThirdOpenId = openid + oauthCache.UpdateFromOauthUser(oauthUser) url := global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/bind/" + cacheKey c.Redirect(http.StatusFound, url) return } //自动注册 - u = service.AllService.UserService.RegisterByOauth(ty, thirdName, openid) - if u.Id == 0 { + user = service.AllService.UserService.RegisterByOauth(oauthUser, op) + if user.Id == 0 { c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthRegisterFailed")) return } } - v.UserId = u.Id - service.AllService.OauthService.SetOauthCache(cacheKey, v, 0) + oauthCache.UserId = user.Id + oauthService.SetOauthCache(cacheKey, oauthCache, 0) // 如果是webadmin,登录成功后跳转到webadmin - if v.DeviceType == "webadmin" { + if oauthCache.DeviceType == "webadmin" { /*service.AllService.UserService.Login(u, &model.LoginLog{ UserId: u.Id, Client: "webadmin", Uuid: "", //must be empty Ip: c.ClientIP(), Type: model.LoginLogTypeOauth, - Platform: v.DeviceOs, + Platform: oauthService.DeviceOs, })*/ url := global.Config.Rustdesk.ApiServer + "/_admin/#/" c.Redirect(http.StatusFound, url) diff --git a/http/request/admin/oauth.go b/http/request/admin/oauth.go index db519f8..8d8ad37 100644 --- a/http/request/admin/oauth.go +++ b/http/request/admin/oauth.go @@ -1,6 +1,9 @@ package admin -import "Gwen/model" +import ( + "Gwen/model" + "strings" +) type BindOauthForm struct { Op string `json:"op" binding:"required"` @@ -13,19 +16,37 @@ type UnBindOauthForm struct { Op string `json:"op" binding:"required"` } type OauthForm struct { - Id uint `json:"id"` - Op string `json:"op" validate:"required"` - Issuer string `json:"issuer" validate:"omitempty,url"` - Scopes string `json:"scopes" validate:"omitempty"` - ClientId string `json:"client_id" validate:"required"` - ClientSecret string `json:"client_secret" validate:"required"` - RedirectUrl string `json:"redirect_url" validate:"required"` - AutoRegister *bool `json:"auto_register"` + Id uint `json:"id"` + Op string `json:"op" validate:"omitempty"` + OauthType string `json:"oauth_type" validate:"required"` + Issuer string `json:"issuer" validate:"omitempty,url"` + Scopes string `json:"scopes" validate:"omitempty"` + ClientId string `json:"client_id" validate:"required"` + ClientSecret string `json:"client_secret" validate:"required"` + RedirectUrl string `json:"redirect_url" validate:"required"` + AutoRegister *bool `json:"auto_register"` } func (of *OauthForm) ToOauth() *model.Oauth { + op := strings.ToLower(of.Op) + op = strings.TrimSpace(op) + if op == "" { + switch of.OauthType { + case model.OauthTypeGithub: + of.Op = "GitHub" + case model.OauthTypeGoogle: + of.Op = "Google" + case model.OauthTypeOidc: + of.Op = "OIDC" + case model.OauthTypeWebauth: + of.Op = "WebAuth" + default: + of.Op = of.OauthType + } + } oa := &model.Oauth{ Op: of.Op, + OauthType: of.OauthType, ClientId: of.ClientId, ClientSecret: of.ClientSecret, RedirectUrl: of.RedirectUrl, diff --git a/http/request/admin/user.go b/http/request/admin/user.go index e29133c..f8093c5 100644 --- a/http/request/admin/user.go +++ b/http/request/admin/user.go @@ -5,20 +5,22 @@ import ( ) type UserForm struct { - Id uint `json:"id"` - Username string `json:"username" validate:"required,gte=4,lte=10"` + Id uint `json:"id"` + Username string `json:"username" validate:"required,gte=4,lte=10"` + Email string `json:"email" validate:"required,email"` //Password string `json:"password" validate:"required,gte=4,lte=20"` - Nickname string `json:"nickname"` - Avatar string `json:"avatar"` - GroupId uint `json:"group_id" validate:"required"` - IsAdmin *bool `json:"is_admin" ` - Status model.StatusCode `json:"status" validate:"required,gte=0"` + Nickname string `json:"nickname"` + Avatar string `json:"avatar"` + GroupId uint `json:"group_id" validate:"required"` + IsAdmin *bool `json:"is_admin" ` + Status model.StatusCode `json:"status" validate:"required,gte=0"` } func (uf *UserForm) FromUser(user *model.User) *UserForm { uf.Id = user.Id uf.Username = user.Username uf.Nickname = user.Nickname + uf.Email = user.Email uf.Avatar = user.Avatar uf.GroupId = user.GroupId uf.IsAdmin = user.IsAdmin @@ -30,6 +32,7 @@ func (uf *UserForm) ToUser() *model.User { user.Id = uf.Id user.Username = uf.Username user.Nickname = uf.Nickname + user.Email = uf.Email user.Avatar = uf.Avatar user.GroupId = uf.GroupId user.IsAdmin = uf.IsAdmin diff --git a/http/response/admin/user.go b/http/response/admin/user.go index d3941e8..857fe69 100644 --- a/http/response/admin/user.go +++ b/http/response/admin/user.go @@ -4,6 +4,8 @@ import "Gwen/model" type LoginPayload struct { Username string `json:"username"` + Email string `json:"email"` + Avatar string `json:"avatar"` Token string `json:"token"` RouteNames []string `json:"route_names"` Nickname string `json:"nickname"` @@ -15,8 +17,8 @@ var UserRouteNames = []string{ var AdminRouteNames = []string{"*"} type UserOauthItem struct { - ThirdType string `json:"third_type"` - Status int `json:"status"` + Op string `json:"op"` + Status int `json:"status"` } type GroupUsersPayload struct { diff --git a/http/response/api/user.go b/http/response/api/user.go index b48a4be..dc97370 100644 --- a/http/response/api/user.go +++ b/http/response/api/user.go @@ -29,6 +29,7 @@ type UserPayload struct { func (up *UserPayload) FromUser(user *model.User) *UserPayload { up.Name = user.Username + up.Email = user.Email up.IsAdmin = user.IsAdmin up.Status = int(user.Status) up.Info = map[string]interface{}{} diff --git a/model/oauth.go b/model/oauth.go index 35e7b96..816a01c 100644 --- a/model/oauth.go +++ b/model/oauth.go @@ -1,23 +1,110 @@ package model +import ( + "strconv" + "fmt" +) + + +const ( + OauthTypeGithub string = "github" + OauthTypeGoogle string = "google" + OauthTypeOidc string = "oidc" + OauthTypeWebauth string = "webauth" +) + + type Oauth struct { IdModel - Op string `json:"op"` - ClientId string `json:"client_id"` - ClientSecret string `json:"client_secret"` - RedirectUrl string `json:"redirect_url"` - AutoRegister *bool `json:"auto_register"` - Scopes string `json:"scopes"` - Issuer string `json:"issuer"` + Op string `json:"op"` + OauthType string `json:"oauth_type"` + ClientId string `json:"client_id"` + ClientSecret string `json:"client_secret"` + RedirectUrl string `json:"redirect_url"` + AutoRegister *bool `json:"auto_register"` + Scopes string `json:"scopes"` + Issuer string `json:"issuer"` TimeModel } -const ( - OauthTypeGithub = "github" - OauthTypeGoogle = "google" - OauthTypeOidc = "oidc" - OauthTypeWebauth = "webauth" -) +type OauthUser struct { + OpenId string `json:"open_id" gorm:"not null;index"` + Name string `json:"name"` + Username string `json:"username"` + Email string `json:"email"` + VerifiedEmail bool `json:"verified_email,omitempty"` +} + +func (ou *OauthUser) ToUser(user *User, overideUsername bool) { + if overideUsername { + user.Username = ou.Username + } + user.Email = ou.Email + user.Nickname = ou.Name + +} + + +type OauthUserBase struct { + Name string `json:"name"` + Email string `json:"email"` +} + + +type OidcUser struct { + OauthUserBase + Sub string `json:"sub"` + VerifiedEmail bool `json:"email_verified"` + PreferredUsername string `json:"preferred_username"` +} + +func (ou *OidcUser) ToOauthUser() *OauthUser { + return &OauthUser{ + OpenId: ou.Sub, + Name: ou.Name, + Username: ou.PreferredUsername, + Email: ou.Email, + VerifiedEmail: ou.VerifiedEmail, + } +} + +type GoogleUser struct { + OauthUserBase + FamilyName string `json:"family_name"` + GivenName string `json:"given_name"` + Id string `json:"id"` + Picture string `json:"picture"` + VerifiedEmail bool `json:"verified_email"` +} + +func (gu *GoogleUser) ToOauthUser() *OauthUser { + return &OauthUser{ + OpenId: gu.Id, + Name: fmt.Sprintf("%s %s", gu.GivenName, gu.FamilyName), + Username: gu.GivenName, + Email: gu.Email, + VerifiedEmail: gu.VerifiedEmail, + } +} + + +type GithubUser struct { + OauthUserBase + Id int `json:"id"` + Login string `json:"login"` +} + +func (gu *GithubUser) ToOauthUser() *OauthUser { + return &OauthUser{ + OpenId: strconv.Itoa(gu.Id), + Name: gu.Name, + Username: gu.Login, + Email: gu.Email, + VerifiedEmail: true, + } +} + + type OauthList struct { Oauths []*Oauth `json:"list"` diff --git a/model/user.go b/model/user.go index 1d83458..4be0049 100644 --- a/model/user.go +++ b/model/user.go @@ -1,8 +1,15 @@ package model +import ( + "fmt" + "gorm.io/gorm" +) + type User struct { IdModel Username string `json:"username" gorm:"default:'';not null;uniqueIndex"` + Email string `json:"email" gorm:"default:'';not null;uniqueIndex"` + // Email string `json:"email" ` Password string `json:"-" gorm:"default:'';not null;"` Nickname string `json:"nickname" gorm:"default:'';not null;"` Avatar string `json:"avatar" gorm:"default:'';not null;"` @@ -12,6 +19,15 @@ type User struct { TimeModel } +// BeforeSave 钩子用于确保 email 字段有合理的默认值 +func (u *User) BeforeSave(tx *gorm.DB) (err error) { + // 如果 email 为空,设置为默认值 + if u.Email == "" { + u.Email = fmt.Sprintf("%s@example.com", u.Username) + } + return nil +} + type UserList struct { Users []*User `json:"list,omitempty"` Pagination diff --git a/model/userThird.go b/model/userThird.go index 4e967b9..6f7f74a 100644 --- a/model/userThird.go +++ b/model/userThird.go @@ -2,11 +2,18 @@ package model type UserThird struct { IdModel - UserId uint `json:"user_id" gorm:"not null;index"` - OpenId string `json:"open_id" gorm:"not null;index"` - UnionId string `json:"union_id" gorm:"not null;"` - ThirdType string `json:"third_type" gorm:"not null;"` - ThirdEmail string `json:"third_email"` - ThirdName string `json:"third_name"` + UserId uint ` json:"user_id" gorm:"not null;index"` + OauthUser + // UnionId string `json:"union_id" gorm:"not null;"` + // OauthType string `json:"oauth_type" gorm:"not null;"` + OauthType string `json:"oauth_type"` + Op string `json:"op" gorm:"not null;"` TimeModel } + +func (u *UserThird) FromOauthUser(userId uint, oauthUser *OauthUser, oauthType string, op string) { + u.UserId = userId + u.OauthUser = *oauthUser + u.OauthType = oauthType + u.Op = op +} \ No newline at end of file diff --git a/service/oauth.go b/service/oauth.go index 2880fef..c28eec5 100644 --- a/service/oauth.go +++ b/service/oauth.go @@ -11,15 +11,20 @@ import ( "golang.org/x/oauth2/github" "golang.org/x/oauth2/google" "gorm.io/gorm" - "io" + // "io" "net/http" "net/url" "strconv" "strings" "sync" "time" + "fmt" ) + +type OauthService struct { +} + // Define a struct to parse the .well-known/openid-configuration response type OidcEndpoint struct { Issuer string `json:"issuer"` @@ -28,73 +33,6 @@ type OidcEndpoint struct { UserInfo string `json:"userinfo_endpoint"` } -type OauthService struct { -} - -type GithubUserdata struct { - AvatarUrl string `json:"avatar_url"` - Bio string `json:"bio"` - Blog string `json:"blog"` - Collaborators int `json:"collaborators"` - Company interface{} `json:"company"` - CreatedAt time.Time `json:"created_at"` - DiskUsage int `json:"disk_usage"` - Email interface{} `json:"email"` - EventsUrl string `json:"events_url"` - Followers int `json:"followers"` - FollowersUrl string `json:"followers_url"` - Following int `json:"following"` - FollowingUrl string `json:"following_url"` - GistsUrl string `json:"gists_url"` - GravatarId string `json:"gravatar_id"` - Hireable interface{} `json:"hireable"` - HtmlUrl string `json:"html_url"` - Id int `json:"id"` - Location interface{} `json:"location"` - Login string `json:"login"` - Name string `json:"name"` - NodeId string `json:"node_id"` - NotificationEmail interface{} `json:"notification_email"` - OrganizationsUrl string `json:"organizations_url"` - OwnedPrivateRepos int `json:"owned_private_repos"` - Plan struct { - Collaborators int `json:"collaborators"` - Name string `json:"name"` - PrivateRepos int `json:"private_repos"` - Space int `json:"space"` - } `json:"plan"` - PrivateGists int `json:"private_gists"` - PublicGists int `json:"public_gists"` - PublicRepos int `json:"public_repos"` - ReceivedEventsUrl string `json:"received_events_url"` - ReposUrl string `json:"repos_url"` - SiteAdmin bool `json:"site_admin"` - StarredUrl string `json:"starred_url"` - SubscriptionsUrl string `json:"subscriptions_url"` - TotalPrivateRepos int `json:"total_private_repos"` - //TwitterUsername interface{} `json:"twitter_username"` - TwoFactorAuthentication bool `json:"two_factor_authentication"` - Type string `json:"type"` - UpdatedAt time.Time `json:"updated_at"` - Url string `json:"url"` -} -type GoogleUserdata struct { - Email string `json:"email"` - FamilyName string `json:"family_name"` - GivenName string `json:"given_name"` - Id string `json:"id"` - Name string `json:"name"` - Picture string `json:"picture"` - VerifiedEmail bool `json:"verified_email"` -} -type OidcUserdata struct { - Sub string `json:"sub"` - Email string `json:"email"` - VerifiedEmail bool `json:"email_verified"` - Name string `json:"name"` - PreferredUsername string `json:"preferred_username"` -} - type OauthCacheItem struct { UserId uint `json:"user_id"` Id string `json:"id"` //rustdesk的设备ID @@ -104,9 +42,19 @@ type OauthCacheItem struct { DeviceName string `json:"device_name"` DeviceOs string `json:"device_os"` DeviceType string `json:"device_type"` - ThirdOpenId string `json:"third_open_id"` - ThirdName string `json:"third_name"` - ThirdEmail string `json:"third_email"` + OpenId string `json:"open_id"` + Username string `json:"username"` + Name string `json:"name"` + Email string `json:"email"` +} + +func (oci *OauthCacheItem) ToOauthUser() *model.OauthUser { + return &model.OauthUser{ + OpenId: oci.OpenId, + Username: oci.Username, + Name: oci.Name, + Email: oci.Email, + } } var OauthCache = &sync.Map{} @@ -116,6 +64,24 @@ const ( OauthActionTypeBind = "bind" ) +func (oa *OauthCacheItem) UpdateFromOauthUser(oauthUser *model.OauthUser) { + oa.OpenId = oauthUser.OpenId + oa.Username = oauthUser.Username + oa.Name = oauthUser.Name + oa.Email = oauthUser.Email +} + +// Validate the oauth type +func (os *OauthService) ValidateOauthType(oauthType string) error { + switch oauthType { + case model.OauthTypeGithub, model.OauthTypeGoogle, model.OauthTypeOidc, model.OauthTypeWebauth: + return nil + default: + return errors.New("invalid Oauth type") + } +} + + func (os *OauthService) GetOauthCache(key string) *OauthCacheItem { v, ok := OauthCache.Load(key) if !ok { @@ -141,12 +107,12 @@ func (os *OauthService) DeleteOauthCache(key string) { func (os *OauthService) BeginAuth(op string) (error error, code, url string) { code = utils.RandomString(10) + strconv.FormatInt(time.Now().Unix(), 10) - if op == model.OauthTypeWebauth { + if op == string(model.OauthTypeWebauth) { url = global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/" + code //url = "http://localhost:8888/_admin/#/oauth/" + code return nil, code, url } - err, conf := os.GetOauthConfig(op) + err, _, conf := os.GetOauthConfig(op) if err == nil { return err, code, conf.AuthCodeURL(code) } @@ -155,7 +121,7 @@ func (os *OauthService) BeginAuth(op string) (error error, code, url string) { } // Method to fetch OIDC configuration dynamically -func FetchOidcConfig(issuer string) (error, OidcEndpoint) { +func (os *OauthService) FetchOidcEndpoint(issuer string) (error, OidcEndpoint) { configURL := strings.TrimSuffix(issuer, "/") + "/.well-known/openid-configuration" // Get the HTTP client (with or without proxy based on configuration) @@ -179,76 +145,55 @@ func FetchOidcConfig(issuer string) (error, OidcEndpoint) { return nil, endpoint } -// GetOauthConfig retrieves the OAuth2 configuration based on the provider type -func (os *OauthService) GetOauthConfig(op string) (error, *oauth2.Config) { - switch op { - case model.OauthTypeGithub: - return os.getGithubConfig() - case model.OauthTypeGoogle: - return os.getGoogleConfig() - case model.OauthTypeOidc: - return os.getOidcConfig() - default: - return errors.New("unsupported OAuth type"), nil +func (os *OauthService) FetchOidcEndpointByOp(op string) (error, OidcEndpoint) { + oauthInfo := os.InfoByOp(op) + if oauthInfo.Issuer == "" { + return errors.New("issuer is empty"), OidcEndpoint{} } + return os.FetchOidcEndpoint(oauthInfo.Issuer) } -// Helper function to get GitHub OAuth2 configuration -func (os *OauthService) getGithubConfig() (error, *oauth2.Config) { - g := os.InfoByOp(model.OauthTypeGithub) - if g.Id == 0 || g.ClientId == "" || g.ClientSecret == "" || g.RedirectUrl == "" { - return errors.New("ConfigNotFound"), nil - } - return nil, &oauth2.Config{ - ClientID: g.ClientId, - ClientSecret: g.ClientSecret, - RedirectURL: g.RedirectUrl, - Endpoint: github.Endpoint, - Scopes: []string{"read:user", "user:email"}, - } -} - -// Helper function to get Google OAuth2 configuration -func (os *OauthService) getGoogleConfig() (error, *oauth2.Config) { - g := os.InfoByOp(model.OauthTypeGoogle) - if g.Id == 0 || g.ClientId == "" || g.ClientSecret == "" || g.RedirectUrl == "" { - return errors.New("ConfigNotFound"), nil - } - return nil, &oauth2.Config{ - ClientID: g.ClientId, - ClientSecret: g.ClientSecret, - RedirectURL: g.RedirectUrl, - Endpoint: google.Endpoint, - Scopes: []string{"https://www.googleapis.com/auth/userinfo.profile", "https://www.googleapis.com/auth/userinfo.email"}, - } -} - -// Helper function to get OIDC OAuth2 configuration -func (os *OauthService) getOidcConfig() (error, *oauth2.Config) { - g := os.InfoByOp(model.OauthTypeOidc) - if g.Id == 0 || g.ClientId == "" || g.ClientSecret == "" || g.RedirectUrl == "" || g.Issuer == "" { - return errors.New("ConfigNotFound"), nil - } - - // Set scopes - scopes := strings.TrimSpace(g.Scopes) - if scopes == "" { - scopes = "openid,profile,email" - } - scopeList := strings.Split(scopes, ",") - err, endpoint := FetchOidcConfig(g.Issuer) +// GetOauthConfig retrieves the OAuth2 configuration based on the provider name +func (os *OauthService) GetOauthConfig(op string) (err error, oauthType string, oauthConfig *oauth2.Config) { + err, oauthType, oauthConfig = os.getOauthConfigGeneral(op) if err != nil { - return err, nil + return err, oauthType, nil } - return nil, &oauth2.Config{ + // Maybe should validate the oauthConfig here + switch oauthType { + case model.OauthTypeGithub: + oauthConfig.Endpoint = github.Endpoint + oauthConfig.Scopes = []string{"read:user", "user:email"} + case model.OauthTypeGoogle: + oauthConfig.Endpoint = google.Endpoint + oauthConfig.Scopes = []string{"https://www.googleapis.com/auth/userinfo.profile", "https://www.googleapis.com/auth/userinfo.email"} + case model.OauthTypeOidc: + err, endpoint := os.FetchOidcEndpointByOp(op) + if err != nil { + return err,oauthType, nil + } + oauthConfig.Endpoint = oauth2.Endpoint{AuthURL: endpoint.AuthURL,TokenURL: endpoint.TokenURL,} + oauthConfig.Scopes = os.getScopesByOp(op) + default: + return errors.New("unsupported OAuth type"), oauthType, nil + } + return nil, oauthType, oauthConfig +} + +// GetOauthConfig retrieves the OAuth2 configuration based on the provider name +func (os *OauthService) getOauthConfigGeneral(op string) (err error, oauthType string, oauthConfig *oauth2.Config) { + g := os.InfoByOp(op) + if g.Id == 0 || g.ClientId == "" || g.ClientSecret == "" { + return errors.New("ConfigNotFound"), "", nil + } + // If the redirect URL is empty, use the default redirect URL + if g.RedirectUrl == "" { + g.RedirectUrl = global.Config.Rustdesk.ApiServer + "/api/oidc/callback" + } + return nil, g.OauthType, &oauth2.Config{ ClientID: g.ClientId, ClientSecret: g.ClientSecret, RedirectURL: g.RedirectUrl, - Endpoint: oauth2.Endpoint{ - AuthURL: endpoint.AuthURL, - TokenURL: endpoint.TokenURL, - }, - Scopes: scopeList, } } @@ -272,194 +217,161 @@ func getHTTPClientWithProxy() *http.Client { return http.DefaultClient } -func (os *OauthService) GithubCallback(code string) (error error, userData *GithubUserdata) { - err, oauthConfig := os.GetOauthConfig(model.OauthTypeGithub) +func (os *OauthService) callbackBase(op string, code string, userEndpoint string, userData interface{}) error { + err, oauthType, oauthConfig := os.GetOauthConfig(op) if err != nil { - return err, nil + return err + } + + // If the OAuth type is OIDC and the user endpoint is empty + // Fetch the OIDC configuration and get the user endpoint + if oauthType == model.OauthTypeOidc && userEndpoint == "" { + err, endpoint := os.FetchOidcEndpointByOp(op) + if err != nil { + global.Logger.Warn("failed fetching OIDC configuration: ", err) + return errors.New("FetchOidcEndpointError") + } + userEndpoint = endpoint.UserInfo } - // 使用代理配置创建 HTTP 客户端 + // 设置代理客户端 httpClient := getHTTPClientWithProxy() ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient) + // 使用 code 换取 token token, err := oauthConfig.Exchange(ctx, code) if err != nil { global.Logger.Warn("oauthConfig.Exchange() failed: ", err) - error = errors.New("GetOauthTokenError") - return + return errors.New("GetOauthTokenError") } - // 使用带有代理的 HTTP 客户端获取用户信息 + // 获取用户信息 client := oauthConfig.Client(ctx, token) - resp, err := client.Get("https://api.github.com/user") + resp, err := client.Get(userEndpoint) if err != nil { global.Logger.Warn("failed getting user info: ", err) - error = errors.New("GetOauthUserInfoError") - return + return errors.New("GetOauthUserInfoError") } - defer func(Body io.ReadCloser) { - err := Body.Close() - if err != nil { - global.Logger.Warn("failed closing response body: ", err) + defer func() { + if closeErr := resp.Body.Close(); closeErr != nil { + global.Logger.Warn("failed closing response body: ", closeErr) } - }(resp.Body) + }() // 解析用户信息 - if err = json.NewDecoder(resp.Body).Decode(&userData); err != nil { + if err = json.NewDecoder(resp.Body).Decode(userData); err != nil { global.Logger.Warn("failed decoding user info: ", err) - error = errors.New("DecodeOauthUserInfoError") - return + return errors.New("DecodeOauthUserInfoError") } - return + + return nil } -func (os *OauthService) GoogleCallback(code string) (error error, userData *GoogleUserdata) { - err, oauthConfig := os.GetOauthConfig(model.OauthTypeGoogle) - if err != nil { +// githubCallback github回调 +func (os *OauthService) githubCallback(code string) (error, *model.OauthUser) { + var user = &model.GithubUser{} + const userEndpoint = "https://api.github.com/user" + if err := os.callbackBase(model.OauthTypeGithub, code, userEndpoint, user); err != nil { return err, nil } - - // 使用代理配置创建 HTTP 客户端 - httpClient := getHTTPClientWithProxy() - ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient) - - token, err := oauthConfig.Exchange(ctx, code) - if err != nil { - global.Logger.Warn("oauthConfig.Exchange() failed: ", err) - error = errors.New("GetOauthTokenError") - return - } - - // 使用带有代理的 HTTP 客户端获取用户信息 - client := oauthConfig.Client(ctx, token) - resp, err := client.Get("https://www.googleapis.com/oauth2/v2/userinfo") - if err != nil { - global.Logger.Warn("failed getting user info: ", err) - error = errors.New("GetOauthUserInfoError") - return - } - defer func(Body io.ReadCloser) { - err := Body.Close() - if err != nil { - global.Logger.Warn("failed closing response body: ", err) - } - }(resp.Body) - - // 解析用户信息 - if err = json.NewDecoder(resp.Body).Decode(&userData); err != nil { - global.Logger.Warn("failed decoding user info: ", err) - error = errors.New("DecodeOauthUserInfoError") - return - } - return + return nil, user.ToOauthUser() } -func (os *OauthService) OidcCallback(code string) (error error, userData *OidcUserdata) { - err, oauthConfig := os.GetOauthConfig(model.OauthTypeOidc) - if err != nil { +// googleCallback google回调 +func (os *OauthService) googleCallback(code string) (error, *model.OauthUser) { + var user = &model.GoogleUser{} + const userEndpoint = "https://www.googleapis.com/oauth2/v2/userinfo" + if err := os.callbackBase(model.OauthTypeGoogle, code, userEndpoint, user); err != nil { return err, nil } - // 使用代理配置创建 HTTP 客户端 - httpClient := getHTTPClientWithProxy() - ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient) - - token, err := oauthConfig.Exchange(ctx, code) - if err != nil { - global.Logger.Warn("oauthConfig.Exchange() failed: ", err) - error = errors.New("GetOauthTokenError") - return - } - - // 使用带有代理的 HTTP 客户端获取用户信息 - client := oauthConfig.Client(ctx, token) - g := os.InfoByOp(model.OauthTypeOidc) - err, endpoint := FetchOidcConfig(g.Issuer) - if err != nil { - global.Logger.Warn("failed fetching OIDC configuration: ", err) - error = errors.New("FetchOidcConfigError") - return - } - resp, err := client.Get(endpoint.UserInfo) - if err != nil { - global.Logger.Warn("failed getting user info: ", err) - error = errors.New("GetOauthUserInfoError") - return - } - defer func(Body io.ReadCloser) { - err := Body.Close() - if err != nil { - global.Logger.Warn("failed closing response body: ", err) - } - }(resp.Body) - - // 解析用户信息 - if err = json.NewDecoder(resp.Body).Decode(&userData); err != nil { - global.Logger.Warn("failed decoding user info: ", err) - error = errors.New("DecodeOauthUserInfoError") - return - } - return + return nil, user.ToOauthUser() } -func (os *OauthService) UserThirdInfo(op, openid string) *model.UserThird { +// oidcCallback oidc回调, 通过code获取用户信息 +func (os *OauthService) oidcCallback(code string, op string) (error, *model.OauthUser,) { + var user = &model.OidcUser{} + if err := os.callbackBase(op, code, "", user); err != nil { + return err, nil + } + return nil, user.ToOauthUser() +} + +// Callback: Get user information by code and op(Oauth provider) +func (os *OauthService) Callback(code string, op string) (err error, oauthUser *model.OauthUser) { + oauthType := os.GetTypeByOp(op) + if err = os.ValidateOauthType(oauthType); err != nil { + return err, nil + } + + switch oauthType { + case model.OauthTypeGithub: + err, oauthUser = os.githubCallback(code) + case model.OauthTypeGoogle: + err, oauthUser = os.googleCallback(code) + case model.OauthTypeOidc: + err, oauthUser = os.oidcCallback(code, op) + default: + return errors.New("unsupported OAuth type"), nil + } + + return err, oauthUser +} + + +func (os *OauthService) UserThirdInfo(op string, openId string) *model.UserThird { ut := &model.UserThird{} - global.DB.Where("open_id = ? and third_type = ?", openid, op).First(ut) + global.DB.Where("open_id = ? and op = ?", openId, op).First(ut) return ut } -func (os *OauthService) BindGithubUser(openid, username string, userId uint) error { - return os.BindOauthUser(model.OauthTypeGithub, openid, username, userId) -} - -func (os *OauthService) BindGoogleUser(email, username string, userId uint) error { - return os.BindOauthUser(model.OauthTypeGoogle, email, username, userId) -} - -func (os *OauthService) BindOidcUser(sub, username string, userId uint) error { - return os.BindOauthUser(model.OauthTypeOidc, sub, username, userId) -} - -func (os *OauthService) BindOauthUser(thirdType, openid, username string, userId uint) error { - utr := &model.UserThird{ - OpenId: openid, - ThirdType: thirdType, - ThirdName: username, - UserId: userId, - } +// BindOauthUser: Bind third party account +func (os *OauthService) BindOauthUser(userId uint, oauthUser *model.OauthUser, op string) error { + utr := &model.UserThird{} + oauthType := os.GetTypeByOp(op) + utr.FromOauthUser(userId, oauthUser, oauthType, op) return global.DB.Create(utr).Error } -func (os *OauthService) UnBindGithubUser(userid uint) error { - return os.UnBindThird(model.OauthTypeGithub, userid) +// UnBindOauthUser: Unbind third party account +func (os *OauthService) UnBindOauthUser(userId uint, op string) error { + return os.UnBindThird(op, userId) } -func (os *OauthService) UnBindGoogleUser(userid uint) error { - return os.UnBindThird(model.OauthTypeGoogle, userid) -} -func (os *OauthService) UnBindOidcUser(userid uint) error { - return os.UnBindThird(model.OauthTypeOidc, userid) -} -func (os *OauthService) UnBindThird(thirdType string, userid uint) error { - return global.DB.Where("user_id = ? and third_type = ?", userid, thirdType).Delete(&model.UserThird{}).Error + +// UnBindThird: Unbind third party account +func (os *OauthService) UnBindThird(op string, userId uint) error { + return global.DB.Where("user_id = ? and op = ?", userId, op).Delete(&model.UserThird{}).Error } // DeleteUserByUserId: When user is deleted, delete all third party bindings -func (os *OauthService) DeleteUserByUserId(userid uint) error { - return global.DB.Where("user_id = ?", userid).Delete(&model.UserThird{}).Error +func (os *OauthService) DeleteUserByUserId(userId uint) error { + return global.DB.Where("user_id = ?", userId).Delete(&model.UserThird{}).Error } -// InfoById 根据id取用户信息 +// InfoById 根据id获取Oauth信息 func (os *OauthService) InfoById(id uint) *model.Oauth { - u := &model.Oauth{} - global.DB.Where("id = ?", id).First(u) - return u + oauthInfo := &model.Oauth{} + global.DB.Where("id = ?", id).First(oauthInfo) + return oauthInfo } -// InfoByOp 根据op取用户信息 +// InfoByOp 根据op获取Oauth信息 func (os *OauthService) InfoByOp(op string) *model.Oauth { - u := &model.Oauth{} - global.DB.Where("op = ?", op).First(u) - return u + oauthInfo := &model.Oauth{} + global.DB.Where("op = ?", op).First(oauthInfo) + return oauthInfo } + +// Helper function to get scopes by operation +func (os *OauthService) getScopesByOp(op string) []string { + scopes := os.InfoByOp(op).Scopes + scopes = strings.TrimSpace(scopes) // 这里使用 `=` 而不是 `:=`,避免重新声明变量 + if scopes == "" { + scopes = "openid,profile,email" + } + return strings.Split(scopes, ",") +} + + func (os *OauthService) List(page, pageSize uint, where func(tx *gorm.DB)) (res *model.OauthList) { res = &model.OauthList{} res.Page = int64(page) @@ -474,16 +386,41 @@ func (os *OauthService) List(page, pageSize uint, where func(tx *gorm.DB)) (res return } +// GetTypeByOp 根据op获取OauthType +func (os *OauthService) GetTypeByOp(op string) string { + oauthInfo := &model.Oauth{} + if global.DB.Where("op = ?", op).First(oauthInfo).Error != nil { + return "" + } + return oauthInfo.OauthType +} + +func (os *OauthService) ValidateOauthProvider(op string) error { + oauthInfo := &model.Oauth{} + // 使用 Gorm 的 Take 方法查找符合条件的记录 + if err := global.DB.Where("op = ?", op).Take(oauthInfo).Error; err != nil { + return fmt.Errorf("OAuth provider with op '%s' not found: %w", op, err) + } + return nil +} + // Create 创建 -func (os *OauthService) Create(u *model.Oauth) error { - res := global.DB.Create(u).Error +func (os *OauthService) Create(oauthInfo *model.Oauth) error { + res := global.DB.Create(oauthInfo).Error return res } -func (os *OauthService) Delete(u *model.Oauth) error { - return global.DB.Delete(u).Error +func (os *OauthService) Delete(oauthInfo *model.Oauth) error { + return global.DB.Delete(oauthInfo).Error } // Update 更新 -func (os *OauthService) Update(u *model.Oauth) error { - return global.DB.Model(u).Updates(u).Error +func (os *OauthService) Update(oauthInfo *model.Oauth) error { + return global.DB.Model(oauthInfo).Updates(oauthInfo).Error } + +// GetOauthProviders 获取所有的provider +func (os *OauthService) GetOauthProviders() []string { + var res []string + global.DB.Model(&model.Oauth{}).Pluck("op", &res) + return res +} \ No newline at end of file diff --git a/service/user.go b/service/user.go index fc3ffa5..5e50af5 100644 --- a/service/user.go +++ b/service/user.go @@ -21,12 +21,20 @@ func (us *UserService) InfoById(id uint) *model.User { global.DB.Where("id = ?", id).First(u) return u } +// InfoByUsername 根据用户名取用户信息 func (us *UserService) InfoByUsername(un string) *model.User { u := &model.User{} global.DB.Where("username = ?", un).First(u) return u } +// InfoByEmail 根据邮箱取用户信息 +func (us *UserService) InfoByEmail(email string) *model.User { + u := &model.User{} + global.DB.Where("email = ?", email).First(u) + return u +} + // InfoByOpenid 根据openid取用户信息 func (us *UserService) InfoByOpenid(openid string) *model.User { u := &model.User{} @@ -216,24 +224,9 @@ func (us *UserService) RouteNames(u *model.User) []string { return adResp.UserRouteNames } -// InfoByGithubId 根据githubid取用户信息 -func (us *UserService) InfoByGithubId(githubId string) *model.User { - return us.InfoByOauthId(model.OauthTypeGithub, githubId) -} - -// InfoByGoogleEmail 根据googleid取用户信息 -func (us *UserService) InfoByGoogleEmail(email string) *model.User { - return us.InfoByOauthId(model.OauthTypeGithub, email) -} - -// InfoByOidcSub 根据oidc取用户信息 -func (us *UserService) InfoByOidcSub(sub string) *model.User { - return us.InfoByOauthId(model.OauthTypeOidc, sub) -} - -// InfoByOauthId 根据oauth取用户信息 -func (us *UserService) InfoByOauthId(thirdType, uid string) *model.User { - ut := AllService.OauthService.UserThirdInfo(thirdType, uid) +// InfoByOauthId 根据oauth的name和openId取用户信息 +func (us *UserService) InfoByOauthId(op string, openId string) *model.User { + ut := AllService.OauthService.UserThirdInfo(op, openId) if ut.Id == 0 { return nil } @@ -244,55 +237,40 @@ func (us *UserService) InfoByOauthId(thirdType, uid string) *model.User { return u } -// RegisterByGithub 注册 -func (us *UserService) RegisterByGithub(githubName string, githubId string) *model.User { - return us.RegisterByOauth(model.OauthTypeGithub, githubName, githubId) -} - -// RegisterByGoogle 注册 -func (us *UserService) RegisterByGoogle(name string, email string) *model.User { - return us.RegisterByOauth(model.OauthTypeGoogle, name, email) -} - -// RegisterByOidc 注册, use PreferredUsername as username, sub as openid -func (us *UserService) RegisterByOidc(PreferredUsername string, sub string) *model.User { - return us.RegisterByOauth(model.OauthTypeOidc, PreferredUsername, sub) -} - // RegisterByOauth 注册 -func (us *UserService) RegisterByOauth(thirdType, thirdName, uid string) *model.User { +func (us *UserService) RegisterByOauth(oauthUser *model.OauthUser , op string) *model.User { global.Lock.Lock("registerByOauth") defer global.Lock.UnLock("registerByOauth") - ut := AllService.OauthService.UserThirdInfo(thirdType, uid) + ut := AllService.OauthService.UserThirdInfo(op, oauthUser.OpenId) if ut.Id != 0 { - u := &model.User{} - global.DB.Where("id = ?", ut.UserId).First(u) - return u + return us.InfoById(ut.UserId) } - + //check if this email has been registered + email := oauthUser.Email + oauthType := AllService.OauthService.GetTypeByOp(op) + user := us.InfoByEmail(email) tx := global.DB.Begin() - ut = &model.UserThird{ - OpenId: uid, - ThirdName: thirdName, - ThirdType: thirdType, + if user.Id != 0 { + ut.FromOauthUser(user.Id, oauthUser, oauthType, op) + } else { + ut = &model.UserThird{} + ut.FromOauthUser(0, oauthUser, oauthType, op) + usernameUnique := us.GenerateUsernameByOauth(oauthUser.Username) + user := &model.User{ + Username: usernameUnique, + GroupId: 1, + } + oauthUser.ToUser(user, false) + tx.Create(user) + if user.Id == 0 { + tx.Rollback() + return user + } + ut.UserId = user.Id } - - username := us.GenerateUsernameByOauth(thirdName) - u := &model.User{ - Username: username, - GroupId: 1, - } - tx.Create(u) - if u.Id == 0 { - tx.Rollback() - return u - } - - ut.UserId = u.Id tx.Create(ut) - tx.Commit() - return u + return user } // GenerateUsernameByOauth 生成用户名 @@ -314,7 +292,7 @@ func (us *UserService) UserThirdsByUserId(userId uint) (res []*model.UserThird) func (us *UserService) UserThirdInfo(userId uint, op string) *model.UserThird { ut := &model.UserThird{} - global.DB.Where("user_id = ? and third_type = ?", userId, op).First(ut) + global.DB.Where("user_id = ? and op = ?", userId, op).First(ut) return ut } From c021ebfbdff4c82b7dccc554051deb77247fa612 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Sat, 2 Nov 2024 04:20:00 +0800 Subject: [PATCH 02/26] fix bug ValidateOauthProvider location --- http/controller/api/ouath.go | 6 ------ service/oauth.go | 4 ++++ 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/http/controller/api/ouath.go b/http/controller/api/ouath.go index 5367117..5f2867f 100644 --- a/http/controller/api/ouath.go +++ b/http/controller/api/ouath.go @@ -32,12 +32,6 @@ func (o *Oauth) OidcAuth(c *gin.Context) { } oauthService := service.AllService.OauthService - err = oauthService.ValidateOauthProvider(f.Op) - if err != nil { - response.Error(c, response.TranslateMsg(c, err.Error())) - return - } - var code string var url string err, code, url = oauthService.BeginAuth(f.Op) diff --git a/service/oauth.go b/service/oauth.go index c28eec5..446f61c 100644 --- a/service/oauth.go +++ b/service/oauth.go @@ -155,6 +155,10 @@ func (os *OauthService) FetchOidcEndpointByOp(op string) (error, OidcEndpoint) { // GetOauthConfig retrieves the OAuth2 configuration based on the provider name func (os *OauthService) GetOauthConfig(op string) (err error, oauthType string, oauthConfig *oauth2.Config) { + err = os.ValidateOauthProvider(op) + if err != nil { + return err, "", nil + } err, oauthType, oauthConfig = os.getOauthConfigGeneral(op) if err != nil { return err, oauthType, nil From d31d669734e03090e3ff54b686786d25e4372dbd Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Sat, 2 Nov 2024 05:07:41 +0800 Subject: [PATCH 03/26] logout should unbind uuid and uid of peer --- model/userToken.go | 7 ++++--- service/peer.go | 15 +++++++++++++++ service/user.go | 29 ++++++++++++++++++++++++----- 3 files changed, 43 insertions(+), 8 deletions(-) diff --git a/model/userToken.go b/model/userToken.go index f80c4f0..a359eef 100644 --- a/model/userToken.go +++ b/model/userToken.go @@ -2,9 +2,10 @@ package model type UserToken struct { IdModel - UserId uint `json:"user_id" gorm:"default:0;not null;index"` - Token string `json:"token" gorm:"default:'';not null;index"` - ExpiredAt int64 `json:"expired_at" gorm:"default:0;not null;"` + UserId uint `json:"user_id" gorm:"default:0;not null;index"` + DeviceUuid string `json:"device_uuid"` + Token string `json:"token" gorm:"default:'';not null;index"` + ExpiredAt int64 `json:"expired_at" gorm:"default:0;not null;"` TimeModel } diff --git a/service/peer.go b/service/peer.go index 483e128..d64eb99 100644 --- a/service/peer.go +++ b/service/peer.go @@ -26,6 +26,13 @@ func (ps *PeerService) InfoByRowId(id uint) *model.Peer { return p } +// FindByUserIdAndUuid 根据用户id和uuid查找peer +func (ps *PeerService) FindByUserIdAndUuid(uuid string,userId uint) *model.Peer { + p := &model.Peer{} + global.DB.Where("uuid = ? and user_id = ?", uuid, userId).First(p) + return p +} + // UuidBindUserId 绑定用户id func (ps *PeerService) UuidBindUserId(uuid string, userId uint) { peer := ps.FindByUuid(uuid) @@ -35,6 +42,14 @@ func (ps *PeerService) UuidBindUserId(uuid string, userId uint) { } } +// UuidUnbindUserId 解绑用户id, 用于用户注销 +func (ps *PeerService) UuidUnbindUserId(uuid string, userId uint) { + peer := ps.FindByUserIdAndUuid(uuid, userId) + if peer.RowId > 0 { + global.DB.Model(peer).Update("user_id", 0) + } +} + // ListByUserIds 根据用户id取列表 func (ps *PeerService) ListByUserIds(userIds []uint, page, pageSize uint) (res *model.PeerList) { res = &model.PeerList{} diff --git a/service/user.go b/service/user.go index 5e50af5..c7f29ce 100644 --- a/service/user.go +++ b/service/user.go @@ -73,9 +73,10 @@ func (us *UserService) GenerateToken(u *model.User) string { func (us *UserService) Login(u *model.User, llog *model.LoginLog) *model.UserToken { token := us.GenerateToken(u) ut := &model.UserToken{ - UserId: u.Id, - Token: token, - ExpiredAt: time.Now().Add(time.Hour * 24 * 7).Unix(), + UserId: u.Id, + Token: token, + DeviceUuid: llog.Uuid, + ExpiredAt: time.Now().Add(time.Hour * 24 * 7).Unix(), } global.DB.Create(ut) llog.UserTokenId = ut.UserId @@ -153,9 +154,27 @@ func (us *UserService) Create(u *model.User) error { return res } -// Logout 退出登录 +// GetUuidByToken 根据token和user取uuid +func (us *UserService) GetUuidByToken(u *model.User, token string) string { + ut := &model.UserToken{} + err :=global.DB.Where("user_id = ? and token = ?", u.Id, token).First(ut).Error + if err != nil { + return "" + } + return ut.DeviceUuid +} + +// Logout 退出登录 -> 删除token, 解绑uuid func (us *UserService) Logout(u *model.User, token string) error { - return global.DB.Where("user_id = ? and token = ?", u.Id, token).Delete(&model.UserToken{}).Error + uuid := us.GetUuidByToken(u, token) + err := global.DB.Where("user_id = ? and token = ?", u.Id, token).Delete(&model.UserToken{}).Error + if err != nil { + return err + } + if uuid != "" { + AllService.PeerService.UuidUnbindUserId(uuid, u.Id) + } + return nil } // Delete 删除用户和oauth信息 From 71d1c431a9c0bdaa16b6e95e030e07e1644d5a6e Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Sat, 2 Nov 2024 05:42:47 +0800 Subject: [PATCH 04/26] add DeviceId to userToken --- http/controller/api/login.go | 1 + http/controller/api/ouath.go | 1 + model/loginLog.go | 1 + model/userToken.go | 3 ++- service/user.go | 1 + 5 files changed, 6 insertions(+), 1 deletion(-) diff --git a/http/controller/api/login.go b/http/controller/api/login.go index 6c0b95d..f8be2e3 100644 --- a/http/controller/api/login.go +++ b/http/controller/api/login.go @@ -60,6 +60,7 @@ func (l *Login) Login(c *gin.Context) { ut := service.AllService.UserService.Login(u, &model.LoginLog{ UserId: u.Id, Client: f.DeviceInfo.Type, + DeviceId: f.Id, Uuid: f.Uuid, Ip: c.ClientIP(), Type: model.LoginLogTypeAccount, diff --git a/http/controller/api/ouath.go b/http/controller/api/ouath.go index 5f2867f..b3a3ba3 100644 --- a/http/controller/api/ouath.go +++ b/http/controller/api/ouath.go @@ -94,6 +94,7 @@ func (o *Oauth) OidcAuthQueryPre(c *gin.Context) (*model.User, *model.UserToken) ut = service.AllService.UserService.Login(u, &model.LoginLog{ UserId: u.Id, Client: v.DeviceType, + DeviceId: v.Id, Uuid: v.Uuid, Ip: c.ClientIP(), Type: model.LoginLogTypeOauth, diff --git a/model/loginLog.go b/model/loginLog.go index 51fd97f..0dbb5f4 100644 --- a/model/loginLog.go +++ b/model/loginLog.go @@ -4,6 +4,7 @@ type LoginLog struct { IdModel UserId uint `json:"user_id" gorm:"default:0;not null;"` Client string `json:"client"` //webadmin,webclient,app, + DeviceId string `json:"device_id"` Uuid string `json:"uuid"` Ip string `json:"ip"` Type string `json:"type"` //account,oauth diff --git a/model/userToken.go b/model/userToken.go index a359eef..fce216e 100644 --- a/model/userToken.go +++ b/model/userToken.go @@ -3,7 +3,8 @@ package model type UserToken struct { IdModel UserId uint `json:"user_id" gorm:"default:0;not null;index"` - DeviceUuid string `json:"device_uuid"` + DeviceUuid string `json:"device_uuid" gorm:"default:'';omitempty;"` + DeviceId string `json:"device_id" gorm:"default:'';omitempty;"` Token string `json:"token" gorm:"default:'';not null;index"` ExpiredAt int64 `json:"expired_at" gorm:"default:0;not null;"` TimeModel diff --git a/service/user.go b/service/user.go index c7f29ce..4074bd1 100644 --- a/service/user.go +++ b/service/user.go @@ -76,6 +76,7 @@ func (us *UserService) Login(u *model.User, llog *model.LoginLog) *model.UserTok UserId: u.Id, Token: token, DeviceUuid: llog.Uuid, + DeviceId: llog.DeviceId, ExpiredAt: time.Now().Add(time.Hour * 24 * 7).Unix(), } global.DB.Create(ut) From dfcc7d54c1fcf2da52cf40529dac9f95bbfa38e2 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Sat, 2 Nov 2024 05:43:55 +0800 Subject: [PATCH 05/26] add email for register --- http/controller/admin/user.go | 1 + http/request/admin/user.go | 1 + 2 files changed, 2 insertions(+) diff --git a/http/controller/admin/user.go b/http/controller/admin/user.go index 76d0aa6..2b1f17c 100644 --- a/http/controller/admin/user.go +++ b/http/controller/admin/user.go @@ -361,6 +361,7 @@ func (ct *User) Register(c *gin.Context) { response.Success(c, &adResp.LoginPayload{ Token: ut.Token, Username: u.Username, + Email: u.Email, RouteNames: service.AllService.UserService.RouteNames(u), Nickname: u.Nickname, }) diff --git a/http/request/admin/user.go b/http/request/admin/user.go index f8093c5..227332d 100644 --- a/http/request/admin/user.go +++ b/http/request/admin/user.go @@ -65,6 +65,7 @@ type GroupUsersQuery struct { type RegisterForm struct { Username string `json:"username" validate:"required,gte=4,lte=10"` + Email string `json:"email" validate:"required,email"` Password string `json:"password" validate:"required,gte=4,lte=20"` ConfirmPassword string `json:"confirm_password" validate:"required,gte=4,lte=20"` } From a4dd39043e7a03ee9f6cbc42bbcc16bde6352c13 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Sat, 2 Nov 2024 07:35:26 +0800 Subject: [PATCH 06/26] add MyPeers for user --- http/controller/admin/user.go | 46 +++++++++++++++++++++++++++++++++++ http/controller/api/login.go | 2 +- http/response/admin/user.go | 2 +- http/router/admin.go | 1 + service/peer.go | 12 +++++++++ 5 files changed, 61 insertions(+), 2 deletions(-) diff --git a/http/controller/admin/user.go b/http/controller/admin/user.go index 2b1f17c..ea8c8cd 100644 --- a/http/controller/admin/user.go +++ b/http/controller/admin/user.go @@ -10,6 +10,7 @@ import ( "github.com/gin-gonic/gin" "gorm.io/gorm" "strconv" + "time" ) type User struct { @@ -299,6 +300,51 @@ func (ct *User) MyOauth(c *gin.Context) { response.Success(c, res) } +// List 列表 +// @Tags 设备 +// @Summary 设备列表 +// @Description 设备列表 +// @Accept json +// @Produce json +// @Param page query int false "页码" +// @Param page_size query int false "页大小" +// @Param time_ago query int false "时间" +// @Param id query string false "ID" +// @Param hostname query string false "主机名" +// @Param uuids query string false "uuids 用逗号分隔" +// @Success 200 {object} response.Response{data=model.PeerList} +// @Failure 500 {object} response.Response +// @Router /admin/user/myPeer [get] +// @Security token +func (ct *User) MyPeer(c *gin.Context) { + query := &admin.PeerQuery{} + if err := c.ShouldBindQuery(query); err != nil { + response.Fail(c, 101, response.TranslateMsg(c, "ParamsError")+err.Error()) + return + } + u := service.AllService.UserService.CurUser(c) + res := service.AllService.PeerService.ListFilterByUserId(query.Page, query.PageSize, func(tx *gorm.DB) { + if query.TimeAgo > 0 { + lt := time.Now().Unix() - int64(query.TimeAgo) + tx.Where("last_online_time < ?", lt) + } + if query.TimeAgo < 0 { + lt := time.Now().Unix() + int64(query.TimeAgo) + tx.Where("last_online_time > ?", lt) + } + if query.Id != "" { + tx.Where("id like ?", "%"+query.Id+"%") + } + if query.Hostname != "" { + tx.Where("hostname like ?", "%"+query.Hostname+"%") + } + if query.Uuids != "" { + tx.Where("uuid in (?)", query.Uuids) + } + }, u.Id) + response.Success(c, res) +} + // groupUsers func (ct *User) GroupUsers(c *gin.Context) { q := &admin.GroupUsersQuery{} diff --git a/http/controller/api/login.go b/http/controller/api/login.go index f8be2e3..1d093b5 100644 --- a/http/controller/api/login.go +++ b/http/controller/api/login.go @@ -60,7 +60,7 @@ func (l *Login) Login(c *gin.Context) { ut := service.AllService.UserService.Login(u, &model.LoginLog{ UserId: u.Id, Client: f.DeviceInfo.Type, - DeviceId: f.Id, + DeviceId: f.Id, Uuid: f.Uuid, Ip: c.ClientIP(), Type: model.LoginLogTypeAccount, diff --git a/http/response/admin/user.go b/http/response/admin/user.go index 857fe69..df2a5ae 100644 --- a/http/response/admin/user.go +++ b/http/response/admin/user.go @@ -12,7 +12,7 @@ type LoginPayload struct { } var UserRouteNames = []string{ - "MyTagList", "MyAddressBookList", "MyInfo", "MyAddressBookCollection", + "MyTagList", "MyAddressBookList", "MyInfo", "MyAddressBookCollection", "MyPeer", } var AdminRouteNames = []string{"*"} diff --git a/http/router/admin.go b/http/router/admin.go index 368a081..bd1aac9 100644 --- a/http/router/admin.go +++ b/http/router/admin.go @@ -53,6 +53,7 @@ func UserBind(rg *gin.RouterGroup) { aR.GET("/current", cont.Current) aR.POST("/changeCurPwd", cont.ChangeCurPwd) aR.POST("/myOauth", cont.MyOauth) + aR.GET("/myPeer", cont.MyPeer) aR.POST("/groupUsers", cont.GroupUsers) } aRP := rg.Group("/user").Use(middleware.AdminPrivilege()) diff --git a/service/peer.go b/service/peer.go index d64eb99..d534d93 100644 --- a/service/peer.go +++ b/service/peer.go @@ -77,6 +77,18 @@ func (ps *PeerService) List(page, pageSize uint, where func(tx *gorm.DB)) (res * return } +// ListFilterByUserId 根据用户id过滤Peer列表 +func (ps *PeerService) ListFilterByUserId(page, pageSize uint, where func(tx *gorm.DB), userId uint) (res *model.PeerList) { + userWhere := func(tx *gorm.DB) { + tx.Where("user_id = ?", userId) + // 如果还有额外的筛选条件,执行它 + if where != nil { + where(tx) + } + } + return ps.List(page, pageSize, userWhere) +} + // Create 创建 func (ps *PeerService) Create(u *model.Peer) error { res := global.DB.Create(u).Error From 6a7ef29089221b4d79897eae5c9d606a0f337d99 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Sat, 2 Nov 2024 08:02:03 +0800 Subject: [PATCH 07/26] delete the token when delete a peer --- service/peer.go | 30 +++++++++++++++++++++++++++--- service/user.go | 10 ++++++++++ 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/service/peer.go b/service/peer.go index d534d93..e351bac 100644 --- a/service/peer.go +++ b/service/peer.go @@ -94,16 +94,40 @@ func (ps *PeerService) Create(u *model.Peer) error { res := global.DB.Create(u).Error return res } + +// Delete 删除, 同时也应该删除token func (ps *PeerService) Delete(u *model.Peer) error { - return global.DB.Delete(u).Error + uuid := u.Uuid + err := global.DB.Delete(u).Error + if err != nil { + return err + } + // 删除token + return AllService.UserService.FlushTokenByUuid(uuid) } -// BatchDelete +// GetUuidListByIDs 根据ids获取uuid列表 +func (ps *PeerService) GetUuidListByIDs(ids []uint) ([]string, error) { + var uuids []string + err := global.DB.Model(&model.Peer{}). + Where("row_id in (?)", ids). + Pluck("uuid", &uuids).Error + return uuids, err +} + +// BatchDelete 批量删除, 同时也应该删除token func (ps *PeerService) BatchDelete(ids []uint) error { - return global.DB.Where("row_id in (?)", ids).Delete(&model.Peer{}).Error + uuids, err := ps.GetUuidListByIDs(ids) + err = global.DB.Where("row_id in (?)", ids).Delete(&model.Peer{}).Error + if err != nil { + return err + } + // 删除token + return AllService.UserService.FlushTokenByUuids(uuids) } // Update 更新 func (ps *PeerService) Update(u *model.Peer) error { return global.DB.Model(u).Updates(u).Error } + diff --git a/service/user.go b/service/user.go index 4074bd1..2d04386 100644 --- a/service/user.go +++ b/service/user.go @@ -220,6 +220,16 @@ func (us *UserService) FlushToken(u *model.User) error { return global.DB.Where("user_id = ?", u.Id).Delete(&model.UserToken{}).Error } +// FlushTokenByUuid 清空token +func (us *UserService) FlushTokenByUuid(uuid string) error { + return global.DB.Where("device_uuid = ?", uuid).Delete(&model.UserToken{}).Error +} + +// FlushTokenByUuids 清空token +func (us *UserService) FlushTokenByUuids(uuids []string) error { + return global.DB.Where("device_uuid in (?)", uuids).Delete(&model.UserToken{}).Error +} + // UpdatePassword 更新密码 func (us *UserService) UpdatePassword(u *model.User, password string) error { u.Password = us.EncryptPassword(password) From d38117107d7fded710c5dc54a8ecb2c7fb168402 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Sat, 2 Nov 2024 08:19:44 +0800 Subject: [PATCH 08/26] When login, peer doesn't exist, it should create --- service/peer.go | 10 +++++++++- service/user.go | 2 +- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/service/peer.go b/service/peer.go index e351bac..2876874 100644 --- a/service/peer.go +++ b/service/peer.go @@ -34,11 +34,19 @@ func (ps *PeerService) FindByUserIdAndUuid(uuid string,userId uint) *model.Peer } // UuidBindUserId 绑定用户id -func (ps *PeerService) UuidBindUserId(uuid string, userId uint) { +func (ps *PeerService) UuidBindUserId(deviceId string, uuid string, userId uint) { peer := ps.FindByUuid(uuid) + // 如果存在则更新 if peer.RowId > 0 { peer.UserId = userId ps.Update(peer) + } else { + // 不存在则创建 + global.DB.Create(&model.Peer{ + Id: deviceId, + Uuid: uuid, + UserId: userId, + }) } } diff --git a/service/user.go b/service/user.go index 2d04386..7859764 100644 --- a/service/user.go +++ b/service/user.go @@ -83,7 +83,7 @@ func (us *UserService) Login(u *model.User, llog *model.LoginLog) *model.UserTok llog.UserTokenId = ut.UserId global.DB.Create(llog) if llog.Uuid != "" { - AllService.PeerService.UuidBindUserId(llog.Uuid, u.Id) + AllService.PeerService.UuidBindUserId(llog.DeviceId, llog.Uuid, u.Id) } return ut } From d85a799d0bca3b3a0e5e990638fcc45f94ca02fe Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Sat, 2 Nov 2024 08:24:07 +0800 Subject: [PATCH 09/26] set user_id=0 at peers, when the user is deleted --- service/peer.go | 5 +++++ service/user.go | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/service/peer.go b/service/peer.go index 2876874..b4eefc8 100644 --- a/service/peer.go +++ b/service/peer.go @@ -58,6 +58,11 @@ func (ps *PeerService) UuidUnbindUserId(uuid string, userId uint) { } } +// EraseUserId 清除用户id, 用于用户删除 +func (ps *PeerService) EraseUserId(userId uint) error { + return global.DB.Model(&model.Peer{}).Where("user_id = ?", userId).Update("user_id", 0).Error +} + // ListByUserIds 根据用户id取列表 func (ps *PeerService) ListByUserIds(userIds []uint, page, pageSize uint) (res *model.PeerList) { res = &model.PeerList{} diff --git a/service/user.go b/service/user.go index 7859764..56e1ae1 100644 --- a/service/user.go +++ b/service/user.go @@ -207,6 +207,11 @@ func (us *UserService) Delete(u *model.User) error { return err } tx.Commit() + // 删除关联的peer + if err := AllService.PeerService.EraseUserId(u.Id); err != nil { + tx.Rollback() + return err + } return nil } From 6fb1fbc5b1654a407983a01cf361eea91f2fb5e6 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Sun, 3 Nov 2024 04:35:39 +0800 Subject: [PATCH 10/26] fix: RegisterByOauth without Email --- service/user.go | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/service/user.go b/service/user.go index 56e1ae1..876b40e 100644 --- a/service/user.go +++ b/service/user.go @@ -10,6 +10,8 @@ import ( "math/rand" "strconv" "time" + "strings" + "fmt" ) type UserService struct { @@ -150,6 +152,8 @@ func (us *UserService) CheckUserEnable(u *model.User) bool { // Create 创建 func (us *UserService) Create(u *model.User) error { + // The initial username should be formatted, and the username should be unique + u.Username = us.formatUsername(u.Username) u.Password = us.EncryptPassword(u.Password) res := global.DB.Create(u).Error return res @@ -282,7 +286,17 @@ func (us *UserService) RegisterByOauth(oauthUser *model.OauthUser , op string) * } //check if this email has been registered email := oauthUser.Email - oauthType := AllService.OauthService.GetTypeByOp(op) + err, oauthType := AllService.OauthService.GetTypeByOp(op) + if err != nil { + return nil + } + // if email is empty, use username and op as email + if email == "" { + email = oauthUser.Username + "@" + op + } + email = strings.ToLower(email) + // update email to oauthUser, in case it contain upper case + oauthUser.Email = email user := us.InfoByEmail(email) tx := global.DB.Begin() if user.Id != 0 { @@ -290,8 +304,10 @@ func (us *UserService) RegisterByOauth(oauthUser *model.OauthUser , op string) * } else { ut = &model.UserThird{} ut.FromOauthUser(0, oauthUser, oauthType, op) - usernameUnique := us.GenerateUsernameByOauth(oauthUser.Username) - user := &model.User{ + // The initial username should be formatted + username := us.formatUsername(oauthUser.Username) + usernameUnique := us.GenerateUsernameByOauth(username) + user = &model.User{ Username: usernameUnique, GroupId: 1, } @@ -361,6 +377,7 @@ func (us *UserService) IsPasswordEmptyByUser(u *model.User) bool { return us.IsPasswordEmptyById(u.Id) } +// Register 注册 func (us *UserService) Register(username string, password string) *model.User { u := &model.User{ Username: username, @@ -394,3 +411,10 @@ func (us *UserService) TokenInfoById(id uint) *model.UserToken { func (us *UserService) DeleteToken(l *model.UserToken) error { return global.DB.Delete(l).Error } + +// Helper functions, used for formatting username +func (us *UserService) formatUsername(username string) string { + username = strings.ReplaceAll(username, " ", "") + username = strings.ToLower(username) + return username +} \ No newline at end of file From 2ceaa0091ba4c6f1fe957763818954a2c85bdae1 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Sun, 3 Nov 2024 05:07:17 +0800 Subject: [PATCH 11/26] fix: Email of Register --- http/controller/admin/user.go | 2 +- service/user.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/http/controller/admin/user.go b/http/controller/admin/user.go index ea8c8cd..561ec1b 100644 --- a/http/controller/admin/user.go +++ b/http/controller/admin/user.go @@ -391,7 +391,7 @@ func (ct *User) Register(c *gin.Context) { response.Fail(c, 101, errList[0]) return } - u := service.AllService.UserService.Register(f.Username, f.Password) + u := service.AllService.UserService.Register(f.Username, f.Email, f.Password) if u == nil || u.Id == 0 { response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed")) return diff --git a/service/user.go b/service/user.go index 876b40e..358240e 100644 --- a/service/user.go +++ b/service/user.go @@ -11,7 +11,6 @@ import ( "strconv" "time" "strings" - "fmt" ) type UserService struct { @@ -378,9 +377,10 @@ func (us *UserService) IsPasswordEmptyByUser(u *model.User) bool { } // Register 注册 -func (us *UserService) Register(username string, password string) *model.User { +func (us *UserService) Register(username string, email string, password string) *model.User { u := &model.User{ Username: username, + Email: email, Password: us.EncryptPassword(password), GroupId: 1, } From 6698877761ed193eb798cb266fbe323d6848d861 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Sun, 3 Nov 2024 05:11:31 +0800 Subject: [PATCH 12/26] fix: email from github --- model/oauth.go | 20 ++++- service/oauth.go | 201 +++++++++++++++++++++++++++++------------------ 2 files changed, 144 insertions(+), 77 deletions(-) diff --git a/model/oauth.go b/model/oauth.go index 816a01c..61ce845 100644 --- a/model/oauth.go +++ b/model/oauth.go @@ -13,6 +13,18 @@ const ( OauthTypeWebauth string = "webauth" ) +const ( + OauthNameGithub string = "GitHub" + OauthNameGoogle string = "Google" + OauthNameOidc string = "OIDC" + OauthNameWebauth string = "WebAuth" +) + +const ( + UserEndpointGithub string = "https://api.github.com/user" + UserEndpointGoogle string = "https://www.googleapis.com/oauth2/v3/userinfo" + UserEndpointOidc string = "" +) type Oauth struct { IdModel @@ -33,6 +45,7 @@ type OauthUser struct { Username string `json:"username"` Email string `json:"email"` VerifiedEmail bool `json:"verified_email,omitempty"` + Picture string `json:"picture,omitempty"` } func (ou *OauthUser) ToUser(user *User, overideUsername bool) { @@ -56,6 +69,7 @@ type OidcUser struct { Sub string `json:"sub"` VerifiedEmail bool `json:"email_verified"` PreferredUsername string `json:"preferred_username"` + Picture string `json:"picture"` } func (ou *OidcUser) ToOauthUser() *OauthUser { @@ -65,6 +79,7 @@ func (ou *OidcUser) ToOauthUser() *OauthUser { Username: ou.PreferredUsername, Email: ou.Email, VerifiedEmail: ou.VerifiedEmail, + Picture: ou.Picture, } } @@ -84,6 +99,7 @@ func (gu *GoogleUser) ToOauthUser() *OauthUser { Username: gu.GivenName, Email: gu.Email, VerifiedEmail: gu.VerifiedEmail, + Picture: gu.Picture, } } @@ -92,6 +108,8 @@ type GithubUser struct { OauthUserBase Id int `json:"id"` Login string `json:"login"` + AvatarUrl string `json:"avatar_url"` + VerifiedEmail bool `json:"verified_email"` } func (gu *GithubUser) ToOauthUser() *OauthUser { @@ -100,7 +118,7 @@ func (gu *GithubUser) ToOauthUser() *OauthUser { Name: gu.Name, Username: gu.Login, Email: gu.Email, - VerifiedEmail: true, + VerifiedEmail: gu.VerifiedEmail, } } diff --git a/service/oauth.go b/service/oauth.go index 446f61c..6cfa083 100644 --- a/service/oauth.go +++ b/service/oauth.go @@ -106,15 +106,14 @@ func (os *OauthService) DeleteOauthCache(key string) { func (os *OauthService) BeginAuth(op string) (error error, code, url string) { code = utils.RandomString(10) + strconv.FormatInt(time.Now().Unix(), 10) - if op == string(model.OauthTypeWebauth) { url = global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/" + code //url = "http://localhost:8888/_admin/#/oauth/" + code return nil, code, url } - err, _, conf := os.GetOauthConfig(op) + err, _, oauthConfig := os.GetOauthConfig(op) if err == nil { - return err, code, conf.AuthCodeURL(code) + return err, code, oauthConfig.AuthCodeURL(code) } return err, code, "" @@ -154,16 +153,17 @@ func (os *OauthService) FetchOidcEndpointByOp(op string) (error, OidcEndpoint) { } // GetOauthConfig retrieves the OAuth2 configuration based on the provider name -func (os *OauthService) GetOauthConfig(op string) (err error, oauthType string, oauthConfig *oauth2.Config) { - err = os.ValidateOauthProvider(op) +func (os *OauthService) GetOauthConfig(op string) (err error, oauthInfo *model.Oauth, oauthConfig *oauth2.Config) { + err, oauthInfo, oauthConfig = os.getOauthConfigGeneral(op) if err != nil { - return err, "", nil - } - err, oauthType, oauthConfig = os.getOauthConfigGeneral(op) - if err != nil { - return err, oauthType, nil + return err, nil, nil } // Maybe should validate the oauthConfig here + oauthType := oauthInfo.OauthType + err = os.ValidateOauthType(oauthType) + if err != nil { + return err, nil, nil + } switch oauthType { case model.OauthTypeGithub: oauthConfig.Endpoint = github.Endpoint @@ -172,32 +172,33 @@ func (os *OauthService) GetOauthConfig(op string) (err error, oauthType string, oauthConfig.Endpoint = google.Endpoint oauthConfig.Scopes = []string{"https://www.googleapis.com/auth/userinfo.profile", "https://www.googleapis.com/auth/userinfo.email"} case model.OauthTypeOidc: - err, endpoint := os.FetchOidcEndpointByOp(op) + var endpoint OidcEndpoint + err, endpoint = os.FetchOidcEndpoint(oauthInfo.Issuer) if err != nil { - return err,oauthType, nil + return err, nil, nil } oauthConfig.Endpoint = oauth2.Endpoint{AuthURL: endpoint.AuthURL,TokenURL: endpoint.TokenURL,} - oauthConfig.Scopes = os.getScopesByOp(op) + oauthConfig.Scopes = os.constructScopes(oauthInfo.Scopes) default: - return errors.New("unsupported OAuth type"), oauthType, nil + return errors.New("unsupported OAuth type"), nil, nil } - return nil, oauthType, oauthConfig + return nil, oauthInfo, oauthConfig } // GetOauthConfig retrieves the OAuth2 configuration based on the provider name -func (os *OauthService) getOauthConfigGeneral(op string) (err error, oauthType string, oauthConfig *oauth2.Config) { - g := os.InfoByOp(op) - if g.Id == 0 || g.ClientId == "" || g.ClientSecret == "" { - return errors.New("ConfigNotFound"), "", nil +func (os *OauthService) getOauthConfigGeneral(op string) (err error, oauthInfo *model.Oauth, oauthConfig *oauth2.Config) { + oauthInfo = os.InfoByOp(op) + if oauthInfo.Id == 0 || oauthInfo.ClientId == "" || oauthInfo.ClientSecret == "" { + return errors.New("ConfigNotFound"), nil, nil } // If the redirect URL is empty, use the default redirect URL - if g.RedirectUrl == "" { - g.RedirectUrl = global.Config.Rustdesk.ApiServer + "/api/oidc/callback" + if oauthInfo.RedirectUrl == "" { + oauthInfo.RedirectUrl = global.Config.Rustdesk.ApiServer + "/api/oidc/callback" } - return nil, g.OauthType, &oauth2.Config{ - ClientID: g.ClientId, - ClientSecret: g.ClientSecret, - RedirectURL: g.RedirectUrl, + return nil, oauthInfo, &oauth2.Config{ + ClientID: oauthInfo.ClientId, + ClientSecret: oauthInfo.ClientSecret, + RedirectURL: oauthInfo.RedirectUrl, } } @@ -221,40 +222,26 @@ func getHTTPClientWithProxy() *http.Client { return http.DefaultClient } -func (os *OauthService) callbackBase(op string, code string, userEndpoint string, userData interface{}) error { - err, oauthType, oauthConfig := os.GetOauthConfig(op) - if err != nil { - return err - } - - // If the OAuth type is OIDC and the user endpoint is empty - // Fetch the OIDC configuration and get the user endpoint - if oauthType == model.OauthTypeOidc && userEndpoint == "" { - err, endpoint := os.FetchOidcEndpointByOp(op) - if err != nil { - global.Logger.Warn("failed fetching OIDC configuration: ", err) - return errors.New("FetchOidcEndpointError") - } - userEndpoint = endpoint.UserInfo - } +func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, code string, userEndpoint string, userData interface{}) (err error, client *http.Client) { // 设置代理客户端 httpClient := getHTTPClientWithProxy() ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient) // 使用 code 换取 token - token, err := oauthConfig.Exchange(ctx, code) + var token *oauth2.Token + token, err = oauthConfig.Exchange(ctx, code) if err != nil { global.Logger.Warn("oauthConfig.Exchange() failed: ", err) - return errors.New("GetOauthTokenError") + return errors.New("GetOauthTokenError"), nil } // 获取用户信息 - client := oauthConfig.Client(ctx, token) + client = oauthConfig.Client(ctx, token) resp, err := client.Get(userEndpoint) if err != nil { global.Logger.Warn("failed getting user info: ", err) - return errors.New("GetOauthUserInfoError") + return errors.New("GetOauthUserInfoError"), nil } defer func() { if closeErr := resp.Body.Close(); closeErr != nil { @@ -265,36 +252,39 @@ func (os *OauthService) callbackBase(op string, code string, userEndpoint string // 解析用户信息 if err = json.NewDecoder(resp.Body).Decode(userData); err != nil { global.Logger.Warn("failed decoding user info: ", err) - return errors.New("DecodeOauthUserInfoError") + return errors.New("DecodeOauthUserInfoError"), nil } - return nil + return nil, client } // githubCallback github回调 -func (os *OauthService) githubCallback(code string) (error, *model.OauthUser) { +func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string) (error, *model.OauthUser) { var user = &model.GithubUser{} - const userEndpoint = "https://api.github.com/user" - if err := os.callbackBase(model.OauthTypeGithub, code, userEndpoint, user); err != nil { + err, client := os.callbackBase(oauthConfig, code, model.UserEndpointGithub, user) + if err != nil { + return err, nil + } + err = os.getGithubPrimaryEmail(client, user) + if err != nil { return err, nil } return nil, user.ToOauthUser() } // googleCallback google回调 -func (os *OauthService) googleCallback(code string) (error, *model.OauthUser) { +func (os *OauthService) googleCallback(oauthConfig *oauth2.Config, code string) (error, *model.OauthUser) { var user = &model.GoogleUser{} - const userEndpoint = "https://www.googleapis.com/oauth2/v2/userinfo" - if err := os.callbackBase(model.OauthTypeGoogle, code, userEndpoint, user); err != nil { + if err, _ := os.callbackBase(oauthConfig, code, model.UserEndpointGoogle, user); err != nil { return err, nil } return nil, user.ToOauthUser() } // oidcCallback oidc回调, 通过code获取用户信息 -func (os *OauthService) oidcCallback(code string, op string) (error, *model.OauthUser,) { +func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, userInfoEndpoint string) (error, *model.OauthUser,) { var user = &model.OidcUser{} - if err := os.callbackBase(op, code, "", user); err != nil { + if err, _ := os.callbackBase(oauthConfig, code, userInfoEndpoint, user); err != nil { return err, nil } return nil, user.ToOauthUser() @@ -302,22 +292,28 @@ func (os *OauthService) oidcCallback(code string, op string) (error, *model.Oaut // Callback: Get user information by code and op(Oauth provider) func (os *OauthService) Callback(code string, op string) (err error, oauthUser *model.OauthUser) { - oauthType := os.GetTypeByOp(op) - if err = os.ValidateOauthType(oauthType); err != nil { - return err, nil - } - - switch oauthType { + var oauthInfo *model.Oauth + var oauthConfig *oauth2.Config + err, oauthInfo, oauthConfig = os.GetOauthConfig(op) + // oauthType is already validated in GetOauthConfig + if err != nil { + return err, nil + } + oauthType := oauthInfo.OauthType + switch oauthType { case model.OauthTypeGithub: - err, oauthUser = os.githubCallback(code) + err, oauthUser = os.githubCallback(oauthConfig, code) case model.OauthTypeGoogle: - err, oauthUser = os.googleCallback(code) + err, oauthUser = os.googleCallback(oauthConfig, code) case model.OauthTypeOidc: - err, oauthUser = os.oidcCallback(code, op) + err, endpoint := os.FetchOidcEndpoint(oauthInfo.Issuer) + if err != nil { + return err, nil + } + err, oauthUser = os.oidcCallback(oauthConfig, code, endpoint.UserInfo) default: return errors.New("unsupported OAuth type"), nil } - return err, oauthUser } @@ -331,7 +327,10 @@ func (os *OauthService) UserThirdInfo(op string, openId string) *model.UserThird // BindOauthUser: Bind third party account func (os *OauthService) BindOauthUser(userId uint, oauthUser *model.OauthUser, op string) error { utr := &model.UserThird{} - oauthType := os.GetTypeByOp(op) + err, oauthType := os.GetTypeByOp(op) + if err != nil { + return err + } utr.FromOauthUser(userId, oauthUser, oauthType, op) return global.DB.Create(utr).Error } @@ -368,14 +367,18 @@ func (os *OauthService) InfoByOp(op string) *model.Oauth { // Helper function to get scopes by operation func (os *OauthService) getScopesByOp(op string) []string { scopes := os.InfoByOp(op).Scopes - scopes = strings.TrimSpace(scopes) // 这里使用 `=` 而不是 `:=`,避免重新声明变量 + return os.constructScopes(scopes) +} + +// Helper function to construct scopes +func (os *OauthService) constructScopes(scopes string) []string { + scopes = strings.TrimSpace(scopes) if scopes == "" { scopes = "openid,profile,email" } return strings.Split(scopes, ",") } - func (os *OauthService) List(page, pageSize uint, where func(tx *gorm.DB)) (res *model.OauthList) { res = &model.OauthList{} res.Page = int64(page) @@ -391,21 +394,30 @@ func (os *OauthService) List(page, pageSize uint, where func(tx *gorm.DB)) (res } // GetTypeByOp 根据op获取OauthType -func (os *OauthService) GetTypeByOp(op string) string { +func (os *OauthService) GetTypeByOp(op string) (error, string) { oauthInfo := &model.Oauth{} if global.DB.Where("op = ?", op).First(oauthInfo).Error != nil { - return "" + return fmt.Errorf("OAuth provider with op '%s' not found", op), "" } - return oauthInfo.OauthType + return nil, oauthInfo.OauthType } +// ValidateOauthProvider 验证Oauth提供者是否正确 func (os *OauthService) ValidateOauthProvider(op string) error { + if !os.IsOauthProviderExist(op) { + return fmt.Errorf("OAuth provider with op '%s' not found", op) + } + return nil +} + +// IsOauthProviderExist 验证Oauth提供者是否存在 +func (os *OauthService) IsOauthProviderExist(op string) bool { oauthInfo := &model.Oauth{} - // 使用 Gorm 的 Take 方法查找符合条件的记录 - if err := global.DB.Where("op = ?", op).Take(oauthInfo).Error; err != nil { - return fmt.Errorf("OAuth provider with op '%s' not found: %w", op, err) - } - return nil + // 使用 Gorm 的 Take 方法查找符合条件的记录 + if err := global.DB.Where("op = ?", op).Take(oauthInfo).Error; err != nil { + return false + } + return true } // Create 创建 @@ -427,4 +439,41 @@ func (os *OauthService) GetOauthProviders() []string { var res []string global.DB.Model(&model.Oauth{}).Pluck("op", &res) return res +} + +// getGithubPrimaryEmail: Get the primary email of the user from Github +func (os *OauthService) getGithubPrimaryEmail(client *http.Client, githubUser *model.GithubUser) error { + // the client is already set with the token + resp, err := client.Get("https://api.github.com/user/emails") + if err != nil { + return fmt.Errorf("failed to fetch emails: %w", err) + } + defer resp.Body.Close() + + // check the response status code + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to fetch emails: %s", resp.Status) + } + + // decode the response + var emails []struct { + Email string `json:"email"` + Primary bool `json:"primary"` + Verified bool `json:"verified"` + } + + if err := json.NewDecoder(resp.Body).Decode(&emails); err != nil { + return fmt.Errorf("failed to decode response: %w", err) + } + + // find the primary verified email + for _, e := range emails { + if e.Primary && e.Verified { + githubUser.Email = e.Email + githubUser.VerifiedEmail = e.Verified + return nil + } + } + + return fmt.Errorf("no primary verified email found") } \ No newline at end of file From 9ed376715facedaf6dab60ac5a897e6d14542253 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Sun, 3 Nov 2024 05:13:22 +0800 Subject: [PATCH 13/26] const var for op name --- http/request/admin/oauth.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/http/request/admin/oauth.go b/http/request/admin/oauth.go index 8d8ad37..98a0e48 100644 --- a/http/request/admin/oauth.go +++ b/http/request/admin/oauth.go @@ -33,13 +33,13 @@ func (of *OauthForm) ToOauth() *model.Oauth { if op == "" { switch of.OauthType { case model.OauthTypeGithub: - of.Op = "GitHub" + of.Op = model.OauthNameGithub case model.OauthTypeGoogle: - of.Op = "Google" + of.Op = model.OauthNameGoogle case model.OauthTypeOidc: - of.Op = "OIDC" + of.Op = model.OauthNameOidc case model.OauthTypeWebauth: - of.Op = "WebAuth" + of.Op = model.OauthNameWebauth default: of.Op = of.OauthType } From 1ca50b5e9d31e3409716b59828592e4d3c0103b0 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Sun, 3 Nov 2024 05:25:10 +0800 Subject: [PATCH 14/26] low case email --- model/userThird.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/model/userThird.go b/model/userThird.go index 6f7f74a..5b8f5ef 100644 --- a/model/userThird.go +++ b/model/userThird.go @@ -1,5 +1,9 @@ package model +import ( + "strings" +) + type UserThird struct { IdModel UserId uint ` json:"user_id" gorm:"not null;index"` @@ -16,4 +20,6 @@ func (u *UserThird) FromOauthUser(userId uint, oauthUser *OauthUser, oauthType s u.OauthUser = *oauthUser u.OauthType = oauthType u.Op = op + // make sure email is lower case + u.Email = strings.ToLower(u.Email) } \ No newline at end of file From 64f28c17d8bc4bc77b6c864e397273a1bbd03518 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Sun, 3 Nov 2024 05:33:59 +0800 Subject: [PATCH 15/26] add Avatar to OauthUser --- model/oauth.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model/oauth.go b/model/oauth.go index 61ce845..94db710 100644 --- a/model/oauth.go +++ b/model/oauth.go @@ -54,7 +54,7 @@ func (ou *OauthUser) ToUser(user *User, overideUsername bool) { } user.Email = ou.Email user.Nickname = ou.Name - + user.Avatar = ou.Picture } From b9efc7302505a6971320dac4aaa8a2a6a7d66ba0 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Sun, 3 Nov 2024 05:34:19 +0800 Subject: [PATCH 16/26] chroe --- http/controller/admin/login.go | 27 +++++++++++---------------- http/response/admin/user.go | 7 +++++++ 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/http/controller/admin/login.go b/http/controller/admin/login.go index cd879cb..bb594f8 100644 --- a/http/controller/admin/login.go +++ b/http/controller/admin/login.go @@ -60,14 +60,7 @@ func (ct *Login) Login(c *gin.Context) { Platform: f.Platform, }) - response.Success(c, &adResp.LoginPayload{ - Token: ut.Token, - Username: u.Username, - Email: u.Email, - Avatar: u.Avatar, - RouteNames: service.AllService.UserService.RouteNames(u), - Nickname: u.Nickname, - }) + responseLoginSuccess(c, u, ut.Token) } // Logout 登出 @@ -165,12 +158,14 @@ func (ct *Login) OidcAuthQuery(c *gin.Context) { if ut == nil { return } - //fmt.Println("u:", u) - //fmt.Println("ut:", ut) - response.Success(c, &adResp.LoginPayload{ - Token: ut.Token, - Username: u.Username, - RouteNames: service.AllService.UserService.RouteNames(u), - Nickname: u.Nickname, - }) + responseLoginSuccess(c, u, ut.Token) } + + +func responseLoginSuccess(c *gin.Context, u *model.User, token string) { + lp := &adResp.LoginPayload{} + lp.FromUser(u) + lp.Token = token + lp.RouteNames = service.AllService.UserService.RouteNames(u) + response.Success(c, lp) +} \ No newline at end of file diff --git a/http/response/admin/user.go b/http/response/admin/user.go index df2a5ae..d441106 100644 --- a/http/response/admin/user.go +++ b/http/response/admin/user.go @@ -11,6 +11,13 @@ type LoginPayload struct { Nickname string `json:"nickname"` } +func (lp *LoginPayload) FromUser(user *model.User) { + lp.Username = user.Username + lp.Email = user.Email + lp.Avatar = user.Avatar + lp.Nickname = user.Nickname +} + var UserRouteNames = []string{ "MyTagList", "MyAddressBookList", "MyInfo", "MyAddressBookCollection", "MyPeer", } From fb9173ed532317b877d371c950f5b71e242e4994 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Sun, 3 Nov 2024 05:37:34 +0800 Subject: [PATCH 17/26] optimize /admin/login-options --- http/controller/admin/login.go | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/http/controller/admin/login.go b/http/controller/admin/login.go index bb594f8..c00e2b6 100644 --- a/http/controller/admin/login.go +++ b/http/controller/admin/login.go @@ -11,7 +11,6 @@ import ( "Gwen/service" "fmt" "github.com/gin-gonic/gin" - "gorm.io/gorm" ) type Login struct { @@ -91,13 +90,7 @@ func (ct *Login) Logout(c *gin.Context) { // @Failure 500 {object} response.ErrorResponse // @Router /admin/login-options [post] func (ct *Login) LoginOptions(c *gin.Context) { - res := service.AllService.OauthService.List(1, 100, func(tx *gorm.DB) { - tx.Select("op").Order("id") - }) - var ops []string - for _, v := range res.Oauths { - ops = append(ops, v.Op) - } + ops := service.AllService.OauthService.GetOauthProviders() response.Success(c, gin.H{ "ops": ops, "register": global.Config.App.Register, From 3cd90c8f741a25d82712796d871e17630049bdf4 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Sun, 3 Nov 2024 16:34:50 +0800 Subject: [PATCH 18/26] fronted for docker-dev --- Dockerfile.dev | 5 ++++- docker-compose-dev.yaml | 2 ++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/Dockerfile.dev b/Dockerfile.dev index 7269a73..d48794f 100644 --- a/Dockerfile.dev +++ b/Dockerfile.dev @@ -42,8 +42,11 @@ RUN if [ "$COUNTRY" = "CN" ] ; then \ fi && \ apk update && apk add --no-cache git +ARG FREONTEND_GIT_REPO=https://github.com/lejianwen/rustdesk-api-web.git +ARG FRONTEND_GIT_BRANCH=master # Clone the frontend repository -RUN git clone https://github.com/lejianwen/rustdesk-api-web . + +RUN git clone -b $FRONTEND_GIT_BRANCH $FREONTEND_GIT_REPO . # Install required tools without caching index to minimize image size RUN if [ "$COUNTRY" = "CN" ] ; then \ diff --git a/docker-compose-dev.yaml b/docker-compose-dev.yaml index 9e30042..6118d4a 100644 --- a/docker-compose-dev.yaml +++ b/docker-compose-dev.yaml @@ -5,6 +5,8 @@ services: dockerfile: Dockerfile.dev args: COUNTRY: CN + FREONTEND_GIT_REPO: https://github.com/lejianwen/rustdesk-api-web.git + FRONTEND_GIT_BRANCH: master # image: lejianwen/rustdesk-api container_name: rustdesk-api environment: From da7b70c471762857185493eab512eb2ffdb30c5e Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Sun, 3 Nov 2024 16:49:03 +0800 Subject: [PATCH 19/26] add err for RegisterByOauth --- http/controller/api/ouath.go | 6 +++--- service/user.go | 11 ++++++----- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/http/controller/api/ouath.go b/http/controller/api/ouath.go index b3a3ba3..3d7ab67 100644 --- a/http/controller/api/ouath.go +++ b/http/controller/api/ouath.go @@ -208,9 +208,9 @@ func (o *Oauth) OauthCallback(c *gin.Context) { } //自动注册 - user = service.AllService.UserService.RegisterByOauth(oauthUser, op) - if user.Id == 0 { - c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthRegisterFailed")) + err, user = service.AllService.UserService.RegisterByOauth(oauthUser, op) + if err != nil { + c.String(http.StatusInternalServerError, response.TranslateMsg(c, err.Error())) return } } diff --git a/service/user.go b/service/user.go index 358240e..4a9fa66 100644 --- a/service/user.go +++ b/service/user.go @@ -11,6 +11,7 @@ import ( "strconv" "time" "strings" + "errors" ) type UserService struct { @@ -276,18 +277,18 @@ func (us *UserService) InfoByOauthId(op string, openId string) *model.User { } // RegisterByOauth 注册 -func (us *UserService) RegisterByOauth(oauthUser *model.OauthUser , op string) *model.User { +func (us *UserService) RegisterByOauth(oauthUser *model.OauthUser , op string) (error, *model.User) { global.Lock.Lock("registerByOauth") defer global.Lock.UnLock("registerByOauth") ut := AllService.OauthService.UserThirdInfo(op, oauthUser.OpenId) if ut.Id != 0 { - return us.InfoById(ut.UserId) + return nil, us.InfoById(ut.UserId) } //check if this email has been registered email := oauthUser.Email err, oauthType := AllService.OauthService.GetTypeByOp(op) if err != nil { - return nil + return err, nil } // if email is empty, use username and op as email if email == "" { @@ -314,13 +315,13 @@ func (us *UserService) RegisterByOauth(oauthUser *model.OauthUser , op string) * tx.Create(user) if user.Id == 0 { tx.Rollback() - return user + return errors.New("OauthRegisterFailed"), user } ut.UserId = user.Id } tx.Create(ut) tx.Commit() - return user + return nil, user } // GenerateUsernameByOauth 生成用户名 From fea3960672b4099db43c701f2c1f29cab51d01cd Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Sun, 3 Nov 2024 16:49:28 +0800 Subject: [PATCH 20/26] fix: Github AvatarUrl to OauthUser --- model/oauth.go | 1 + 1 file changed, 1 insertion(+) diff --git a/model/oauth.go b/model/oauth.go index 94db710..8f2b84e 100644 --- a/model/oauth.go +++ b/model/oauth.go @@ -119,6 +119,7 @@ func (gu *GithubUser) ToOauthUser() *OauthUser { Username: gu.Login, Email: gu.Email, VerifiedEmail: gu.VerifiedEmail, + Picture: gu.AvatarUrl, } } From aee25a6c99ac8776d3f300fe56179d36a612da34 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Sun, 3 Nov 2024 17:19:05 +0800 Subject: [PATCH 21/26] fix: last admin shouldn't be deleted, disabled or demoted --- service/user.go | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/service/user.go b/service/user.go index 4a9fa66..1259559 100644 --- a/service/user.go +++ b/service/user.go @@ -184,6 +184,10 @@ func (us *UserService) Logout(u *model.User, token string) error { // Delete 删除用户和oauth信息 func (us *UserService) Delete(u *model.User) error { + userCount := us.getAdminUserCount() + if userCount <= 1 { + return errors.New("The last admin user cannot be deleted") + } tx := global.DB.Begin() // 删除用户 if err := tx.Delete(u).Error; err != nil { @@ -221,6 +225,15 @@ func (us *UserService) Delete(u *model.User) error { // Update 更新 func (us *UserService) Update(u *model.User) error { + currentUser := us.InfoById(u.Id) + // 如果当前用户是管理员并且 IsAdmin 不为空,进行检查 + if currentUser.IsAdmin != nil && *currentUser.IsAdmin { + adminCount := us.getAdminUserCount() + // 如果这是唯一的管理员,确保不能禁用或取消管理员权限 + if adminCount <= 1 && (u.IsAdmin == nil || !*u.IsAdmin || u.Status == model.COMMON_STATUS_DISABLED) { + return errors.New("The last admin user cannot be disabled or demoted") + } + } return global.DB.Model(u).Updates(u).Error } @@ -418,4 +431,18 @@ func (us *UserService) formatUsername(username string) string { username = strings.ReplaceAll(username, " ", "") username = strings.ToLower(username) return username +} + +// Helper functions, getUserCount +func (us *UserService) getUserCount() int64 { + var count int64 + global.DB.Model(&model.User{}).Count(&count) + return count +} + +// helper functions, getAdminUserCount +func (us *UserService) getAdminUserCount() int64 { + var count int64 + global.DB.Model(&model.User{}).Where("is_admin = ?", true).Count(&count) + return count } \ No newline at end of file From 18d59d704728076475051a80a735af30b5d16755 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Sun, 3 Nov 2024 17:25:27 +0800 Subject: [PATCH 22/26] re-use responseLoginSuccess --- http/controller/admin/user.go | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/http/controller/admin/user.go b/http/controller/admin/user.go index 561ec1b..2ffabf9 100644 --- a/http/controller/admin/user.go +++ b/http/controller/admin/user.go @@ -217,12 +217,7 @@ func (ct *User) Current(c *gin.Context) { u := service.AllService.UserService.CurUser(c) token, _ := c.Get("token") t := token.(string) - response.Success(c, &adResp.LoginPayload{ - Token: t, - Username: u.Username, - RouteNames: service.AllService.UserService.RouteNames(u), - Nickname: u.Nickname, - }) + responseLoginSuccess(c, u, t) } // ChangeCurPwd 修改当前用户密码 @@ -404,11 +399,5 @@ func (ct *User) Register(c *gin.Context) { Ip: c.ClientIP(), Type: model.LoginLogTypeAccount, }) - response.Success(c, &adResp.LoginPayload{ - Token: ut.Token, - Username: u.Username, - Email: u.Email, - RouteNames: service.AllService.UserService.RouteNames(u), - Nickname: u.Nickname, - }) + responseLoginSuccess(c, u, ut.Token) } From 9dfe74562955839a26f9073637f48abb0fce3f9e Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Sun, 3 Nov 2024 18:04:28 +0800 Subject: [PATCH 23/26] fix google --- model/oauth.go | 47 ++++++++++++++++++++++------------------------- service/oauth.go | 4 ++-- 2 files changed, 24 insertions(+), 27 deletions(-) diff --git a/model/oauth.go b/model/oauth.go index 8f2b84e..91b4181 100644 --- a/model/oauth.go +++ b/model/oauth.go @@ -2,9 +2,10 @@ package model import ( "strconv" - "fmt" + "strings" ) +const OIDC_DEFAULT_SCOPES = "openid,profile,email" const ( OauthTypeGithub string = "github" @@ -57,50 +58,45 @@ func (ou *OauthUser) ToUser(user *User, overideUsername bool) { user.Avatar = ou.Picture } - type OauthUserBase struct { Name string `json:"name"` Email string `json:"email"` } - type OidcUser struct { OauthUserBase Sub string `json:"sub"` VerifiedEmail bool `json:"email_verified"` PreferredUsername string `json:"preferred_username"` - Picture string `json:"picture"` + Picture string `json:"picture"` } func (ou *OidcUser) ToOauthUser() *OauthUser { + var username string + // 使用 PreferredUsername,如果不存在,降级到 Email 前缀 + if ou.PreferredUsername != "" { + username = ou.PreferredUsername + } else { + username = strings.ToLower(strings.Split(ou.Email, "@")[0]) + } + return &OauthUser{ - OpenId: ou.Sub, - Name: ou.Name, - Username: ou.PreferredUsername, - Email: ou.Email, - VerifiedEmail: ou.VerifiedEmail, - Picture: ou.Picture, + OpenId: ou.Sub, + Name: ou.Name, + Username: username, + Email: ou.Email, + VerifiedEmail: ou.VerifiedEmail, + Picture: ou.Picture, } } type GoogleUser struct { - OauthUserBase - FamilyName string `json:"family_name"` - GivenName string `json:"given_name"` - Id string `json:"id"` - Picture string `json:"picture"` - VerifiedEmail bool `json:"verified_email"` + OidcUser } +// GoogleUser 使用特定的 Username 规则来调用 ToOauthUser func (gu *GoogleUser) ToOauthUser() *OauthUser { - return &OauthUser{ - OpenId: gu.Id, - Name: fmt.Sprintf("%s %s", gu.GivenName, gu.FamilyName), - Username: gu.GivenName, - Email: gu.Email, - VerifiedEmail: gu.VerifiedEmail, - Picture: gu.Picture, - } + return gu.OidcUser.ToOauthUser() } @@ -113,10 +109,11 @@ type GithubUser struct { } func (gu *GithubUser) ToOauthUser() *OauthUser { + username := strings.ToLower(gu.Login) return &OauthUser{ OpenId: strconv.Itoa(gu.Id), Name: gu.Name, - Username: gu.Login, + Username: username, Email: gu.Email, VerifiedEmail: gu.VerifiedEmail, Picture: gu.AvatarUrl, diff --git a/service/oauth.go b/service/oauth.go index 6cfa083..331ca97 100644 --- a/service/oauth.go +++ b/service/oauth.go @@ -170,7 +170,7 @@ func (os *OauthService) GetOauthConfig(op string) (err error, oauthInfo *model.O oauthConfig.Scopes = []string{"read:user", "user:email"} case model.OauthTypeGoogle: oauthConfig.Endpoint = google.Endpoint - oauthConfig.Scopes = []string{"https://www.googleapis.com/auth/userinfo.profile", "https://www.googleapis.com/auth/userinfo.email"} + oauthConfig.Scopes = os.constructScopes(model.OIDC_DEFAULT_SCOPES) case model.OauthTypeOidc: var endpoint OidcEndpoint err, endpoint = os.FetchOidcEndpoint(oauthInfo.Issuer) @@ -374,7 +374,7 @@ func (os *OauthService) getScopesByOp(op string) []string { func (os *OauthService) constructScopes(scopes string) []string { scopes = strings.TrimSpace(scopes) if scopes == "" { - scopes = "openid,profile,email" + scopes = model.OIDC_DEFAULT_SCOPES } return strings.Split(scopes, ",") } From ca79a63492d51e89447a434140d0ac6803e590b8 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Sun, 3 Nov 2024 21:59:17 +0800 Subject: [PATCH 24/26] fix: call us.IsAdmin(u) to check admin --- service/user.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/service/user.go b/service/user.go index 1259559..0ea71fb 100644 --- a/service/user.go +++ b/service/user.go @@ -227,10 +227,10 @@ func (us *UserService) Delete(u *model.User) error { func (us *UserService) Update(u *model.User) error { currentUser := us.InfoById(u.Id) // 如果当前用户是管理员并且 IsAdmin 不为空,进行检查 - if currentUser.IsAdmin != nil && *currentUser.IsAdmin { + if us.IsAdmin(currentUser) { adminCount := us.getAdminUserCount() // 如果这是唯一的管理员,确保不能禁用或取消管理员权限 - if adminCount <= 1 && (u.IsAdmin == nil || !*u.IsAdmin || u.Status == model.COMMON_STATUS_DISABLED) { + if adminCount <= 1 && ( !us.IsAdmin(u) || u.Status == model.COMMON_STATUS_DISABLED) { return errors.New("The last admin user cannot be disabled or demoted") } } From 5a53f180e48fa3849f541adda70b34b4070e7289 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Sun, 3 Nov 2024 22:23:24 +0800 Subject: [PATCH 25/26] fix: delete check --- service/user.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/service/user.go b/service/user.go index 0ea71fb..774ec1c 100644 --- a/service/user.go +++ b/service/user.go @@ -185,7 +185,7 @@ func (us *UserService) Logout(u *model.User, token string) error { // Delete 删除用户和oauth信息 func (us *UserService) Delete(u *model.User) error { userCount := us.getAdminUserCount() - if userCount <= 1 { + if userCount <= 1 && us.IsAdmin(u) { return errors.New("The last admin user cannot be deleted") } tx := global.DB.Begin() From 3acfb36c5d01821b4ef80e083ad576f6e1ed1f7f Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Mon, 4 Nov 2024 21:30:58 +0800 Subject: [PATCH 26/26] modify google ro re-use oidc --- http/request/admin/oauth.go | 17 ----------- model/oauth.go | 58 ++++++++++++++++++++++++++++++------- service/oauth.go | 39 ++++++++----------------- 3 files changed, 59 insertions(+), 55 deletions(-) diff --git a/http/request/admin/oauth.go b/http/request/admin/oauth.go index 98a0e48..57c96c4 100644 --- a/http/request/admin/oauth.go +++ b/http/request/admin/oauth.go @@ -2,7 +2,6 @@ package admin import ( "Gwen/model" - "strings" ) type BindOauthForm struct { @@ -28,22 +27,6 @@ type OauthForm struct { } func (of *OauthForm) ToOauth() *model.Oauth { - op := strings.ToLower(of.Op) - op = strings.TrimSpace(op) - if op == "" { - switch of.OauthType { - case model.OauthTypeGithub: - of.Op = model.OauthNameGithub - case model.OauthTypeGoogle: - of.Op = model.OauthNameGoogle - case model.OauthTypeOidc: - of.Op = model.OauthNameOidc - case model.OauthTypeWebauth: - of.Op = model.OauthNameWebauth - default: - of.Op = of.OauthType - } - } oa := &model.Oauth{ Op: of.Op, OauthType: of.OauthType, diff --git a/model/oauth.go b/model/oauth.go index 91b4181..e879caf 100644 --- a/model/oauth.go +++ b/model/oauth.go @@ -3,17 +3,29 @@ package model import ( "strconv" "strings" + "errors" ) const OIDC_DEFAULT_SCOPES = "openid,profile,email" const ( + // make sure the value shouldbe lowercase OauthTypeGithub string = "github" OauthTypeGoogle string = "google" OauthTypeOidc string = "oidc" OauthTypeWebauth string = "webauth" ) +// Validate the oauth type +func ValidateOauthType(oauthType string) error { + switch oauthType { + case OauthTypeGithub, OauthTypeGoogle, OauthTypeOidc, OauthTypeWebauth: + return nil + default: + return errors.New("invalid Oauth type") + } +} + const ( OauthNameGithub string = "GitHub" OauthNameGoogle string = "Google" @@ -23,8 +35,7 @@ const ( const ( UserEndpointGithub string = "https://api.github.com/user" - UserEndpointGoogle string = "https://www.googleapis.com/oauth2/v3/userinfo" - UserEndpointOidc string = "" + IssuerGoogle string = "https://accounts.google.com" ) type Oauth struct { @@ -40,6 +51,40 @@ type Oauth struct { TimeModel } + + +// Helper function to format oauth info, it's used in the update and create method +func (oa *Oauth) FormatOauthInfo() error { + oauthType := strings.TrimSpace(oa.OauthType) + err := ValidateOauthType(oa.OauthType) + if err != nil { + return err + } + // check if the op is empty, set the default value + op := strings.TrimSpace(oa.Op) + if op == "" { + switch oauthType { + case OauthTypeGithub: + oa.Op = OauthNameGithub + case OauthTypeGoogle: + oa.Op = OauthNameGoogle + case OauthTypeOidc: + oa.Op = OauthNameOidc + case OauthTypeWebauth: + oa.Op = OauthNameWebauth + default: + oa.Op = oauthType + } + } + // check the issuer, if the oauth type is google and the issuer is empty, set the issuer to the default value + issuer := strings.TrimSpace(oa.Issuer) + // If the oauth type is google and the issuer is empty, set the issuer to the default value + if oauthType == OauthTypeGoogle && issuer == "" { + oa.Issuer = IssuerGoogle + } + return nil +} + type OauthUser struct { OpenId string `json:"open_id" gorm:"not null;index"` Name string `json:"name"` @@ -90,15 +135,6 @@ func (ou *OidcUser) ToOauthUser() *OauthUser { } } -type GoogleUser struct { - OidcUser -} - -// GoogleUser 使用特定的 Username 规则来调用 ToOauthUser -func (gu *GoogleUser) ToOauthUser() *OauthUser { - return gu.OidcUser.ToOauthUser() -} - type GithubUser struct { OauthUserBase diff --git a/service/oauth.go b/service/oauth.go index 331ca97..1966237 100644 --- a/service/oauth.go +++ b/service/oauth.go @@ -9,7 +9,7 @@ import ( "errors" "golang.org/x/oauth2" "golang.org/x/oauth2/github" - "golang.org/x/oauth2/google" + // "golang.org/x/oauth2/google" "gorm.io/gorm" // "io" "net/http" @@ -71,16 +71,6 @@ func (oa *OauthCacheItem) UpdateFromOauthUser(oauthUser *model.OauthUser) { oa.Email = oauthUser.Email } -// Validate the oauth type -func (os *OauthService) ValidateOauthType(oauthType string) error { - switch oauthType { - case model.OauthTypeGithub, model.OauthTypeGoogle, model.OauthTypeOidc, model.OauthTypeWebauth: - return nil - default: - return errors.New("invalid Oauth type") - } -} - func (os *OauthService) GetOauthCache(key string) *OauthCacheItem { v, ok := OauthCache.Load(key) @@ -160,7 +150,7 @@ func (os *OauthService) GetOauthConfig(op string) (err error, oauthInfo *model.O } // Maybe should validate the oauthConfig here oauthType := oauthInfo.OauthType - err = os.ValidateOauthType(oauthType) + err = model.ValidateOauthType(oauthType) if err != nil { return err, nil, nil } @@ -168,10 +158,7 @@ func (os *OauthService) GetOauthConfig(op string) (err error, oauthInfo *model.O case model.OauthTypeGithub: oauthConfig.Endpoint = github.Endpoint oauthConfig.Scopes = []string{"read:user", "user:email"} - case model.OauthTypeGoogle: - oauthConfig.Endpoint = google.Endpoint - oauthConfig.Scopes = os.constructScopes(model.OIDC_DEFAULT_SCOPES) - case model.OauthTypeOidc: + case model.OauthTypeOidc, model.OauthTypeGoogle: var endpoint OidcEndpoint err, endpoint = os.FetchOidcEndpoint(oauthInfo.Issuer) if err != nil { @@ -272,14 +259,6 @@ func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string) return nil, user.ToOauthUser() } -// googleCallback google回调 -func (os *OauthService) googleCallback(oauthConfig *oauth2.Config, code string) (error, *model.OauthUser) { - var user = &model.GoogleUser{} - if err, _ := os.callbackBase(oauthConfig, code, model.UserEndpointGoogle, user); err != nil { - return err, nil - } - return nil, user.ToOauthUser() -} // oidcCallback oidc回调, 通过code获取用户信息 func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, userInfoEndpoint string) (error, *model.OauthUser,) { @@ -303,9 +282,7 @@ func (os *OauthService) Callback(code string, op string) (err error, oauthUser * switch oauthType { case model.OauthTypeGithub: err, oauthUser = os.githubCallback(oauthConfig, code) - case model.OauthTypeGoogle: - err, oauthUser = os.googleCallback(oauthConfig, code) - case model.OauthTypeOidc: + case model.OauthTypeOidc, model.OauthTypeGoogle: err, endpoint := os.FetchOidcEndpoint(oauthInfo.Issuer) if err != nil { return err, nil @@ -422,6 +399,10 @@ func (os *OauthService) IsOauthProviderExist(op string) bool { // Create 创建 func (os *OauthService) Create(oauthInfo *model.Oauth) error { + err := oauthInfo.FormatOauthInfo() + if err != nil { + return err + } res := global.DB.Create(oauthInfo).Error return res } @@ -431,6 +412,10 @@ func (os *OauthService) Delete(oauthInfo *model.Oauth) error { // Update 更新 func (os *OauthService) Update(oauthInfo *model.Oauth) error { + err := oauthInfo.FormatOauthInfo() + if err != nil { + return err + } return global.DB.Model(oauthInfo).Updates(oauthInfo).Error }