diff --git a/http/controller/admin/login.go b/http/controller/admin/login.go index 3ca71a9..bfbe801 100644 --- a/http/controller/admin/login.go +++ b/http/controller/admin/login.go @@ -283,7 +283,7 @@ func (ct *Login) OidcAuth(c *gin.Context) { return } - err, state, verifier, url := service.AllService.OauthService.BeginAuth(f.Op) + err, state, verifier, nonce, url := service.AllService.OauthService.BeginAuth(f.Op) if err != nil { response.Error(c, response.TranslateMsg(c, err.Error())) return @@ -298,6 +298,7 @@ func (ct *Login) OidcAuth(c *gin.Context) { DeviceOs: f.DeviceInfo.Os, Uuid: f.Uuid, Verifier: verifier, + Nonce: nonce, }, 5*60) response.Success(c, gin.H{ diff --git a/http/controller/admin/oauth.go b/http/controller/admin/oauth.go index 88716d5..efdf9e6 100644 --- a/http/controller/admin/oauth.go +++ b/http/controller/admin/oauth.go @@ -43,17 +43,18 @@ func (o *Oauth) ToBind(c *gin.Context) { return } - err, state, verifier, url := service.AllService.OauthService.BeginAuth(f.Op) + err, state, verifier, nonce, url := service.AllService.OauthService.BeginAuth(f.Op) if err != nil { response.Error(c, response.TranslateMsg(c, err.Error())) return } service.AllService.OauthService.SetOauthCache(state, &service.OauthCacheItem{ - Action: service.OauthActionTypeBind, - Op: f.Op, - UserId: u.Id, - Verifier: verifier, + Action: service.OauthActionTypeBind, + Op: f.Op, + UserId: u.Id, + Verifier: verifier, + Nonce: nonce, }, 5*60) response.Success(c, gin.H{ diff --git a/http/controller/api/ouath.go b/http/controller/api/ouath.go index 106bdf7..2dceecf 100644 --- a/http/controller/api/ouath.go +++ b/http/controller/api/ouath.go @@ -32,10 +32,8 @@ func (o *Oauth) OidcAuth(c *gin.Context) { } oauthService := service.AllService.OauthService - var state string - var url string - var verifier string - err, state, verifier, url = oauthService.BeginAuth(f.Op) + + err, state, verifier, nonce, url := oauthService.BeginAuth(f.Op) if err != nil { response.Error(c, response.TranslateMsg(c, err.Error())) return @@ -50,6 +48,7 @@ func (o *Oauth) OidcAuth(c *gin.Context) { DeviceOs: f.DeviceInfo.Os, DeviceType: f.DeviceInfo.Type, Verifier: verifier, + Nonce: nonce, }, 5*60) //fmt.Println("code url", code, url) c.JSON(http.StatusOK, gin.H{ @@ -160,13 +159,14 @@ func (o *Oauth) OauthCallback(c *gin.Context) { }) return } + nonce := oauthCache.Nonce op := oauthCache.Op action := oauthCache.Action verifier := oauthCache.Verifier var user *model.User // 获取用户信息 code := c.Query("code") - err, oauthUser := oauthService.Callback(code, verifier, op) + err, oauthUser := oauthService.Callback(code, verifier, op, nonce) if err != nil { c.HTML(http.StatusOK, "oauth_fail.html", gin.H{ "message": response.TranslateMsg(c, "OauthFailed") + response.TranslateMsg(c, err.Error()), diff --git a/service/oauth.go b/service/oauth.go index 17f90a2..e380b87 100644 --- a/service/oauth.go +++ b/service/oauth.go @@ -47,6 +47,7 @@ type OauthCacheItem struct { Name string `json:"name"` Email string `json:"email"` Verifier string `json:"verifier"` // used for oauth pkce + Nonce string `json:"nonce"` } func (oci *OauthCacheItem) ToOauthUser() *model.OauthUser { @@ -93,17 +94,22 @@ func (os *OauthService) DeleteOauthCache(key string) { OauthCache.Delete(key) } -func (os *OauthService) BeginAuth(op string) (error error, state, verifier, url string) { +func (os *OauthService) BeginAuth(op string) (error error, state, verifier, nonce, url string) { state = utils.RandomString(10) + strconv.FormatInt(time.Now().Unix(), 10) verifier = "" + nonce = "" if op == model.OauthTypeWebauth { url = global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/" + state //url = "http://localhost:8888/_admin/#/oauth/" + code - return nil, state, verifier, url + return nil, state, verifier, nonce, url } err, oauthInfo, oauthConfig, _ := os.GetOauthConfig(op) if err == nil { extras := make([]oauth2.AuthCodeOption, 0, 3) + + nonce = utils.RandomString(10) + extras = append(extras, oauth2.SetAuthURLParam("nonce", nonce)) + if oauthInfo.PkceEnable != nil && *oauthInfo.PkceEnable { extras = append(extras, oauth2.AccessTypeOffline) verifier = oauth2.GenerateVerifier() @@ -115,10 +121,11 @@ func (os *OauthService) BeginAuth(op string) (error error, state, verifier, url extras = append(extras, oauth2.SetAuthURLParam("code_challenge_method", "plain"), oauth2.SetAuthURLParam("code_challenge", verifier)) } } - return err, state, verifier, oauthConfig.AuthCodeURL(state, extras...) + + return err, state, verifier, nonce, oauthConfig.AuthCodeURL(state, extras...) } - return err, state, verifier, "" + return err, state, verifier, nonce, "" } func (os *OauthService) FetchOidcProvider(issuer string) (error, *oidc.Provider) { @@ -280,9 +287,9 @@ func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, provider *oidc. } // githubCallback github回调 -func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, provider *oidc.Provider, code string, verifier string) (error, *model.OauthUser) { +func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, provider *oidc.Provider, code, verifier, nonce string) (error, *model.OauthUser) { var user = &model.GithubUser{} - err, client := os.callbackBase(oauthConfig, provider, code, verifier, "", user) + err, client := os.callbackBase(oauthConfig, provider, code, verifier, nonce, user) if err != nil { return err, nil } @@ -294,16 +301,16 @@ func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, provider *oid } // oidcCallback oidc回调, 通过code获取用户信息 -func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, provider *oidc.Provider, code string, verifier string) (error, *model.OauthUser) { +func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, provider *oidc.Provider, code, verifier, nonce string) (error, *model.OauthUser) { var user = &model.OidcUser{} - if err, _ := os.callbackBase(oauthConfig, provider, code, verifier, "", user); err != nil { + if err, _ := os.callbackBase(oauthConfig, provider, code, verifier, nonce, user); err != nil { return err, nil } return nil, user.ToOauthUser() } // Callback: Get user information by code and op(Oauth provider) -func (os *OauthService) Callback(code, verifier, op string) (err error, oauthUser *model.OauthUser) { +func (os *OauthService) Callback(code, verifier, op, nonce string) (err error, oauthUser *model.OauthUser) { err, oauthInfo, oauthConfig, provider := os.GetOauthConfig(op) // oauthType is already validated in GetOauthConfig if err != nil { @@ -312,9 +319,9 @@ func (os *OauthService) Callback(code, verifier, op string) (err error, oauthUse oauthType := oauthInfo.OauthType switch oauthType { case model.OauthTypeGithub: - err, oauthUser = os.githubCallback(oauthConfig, provider, code, verifier) + err, oauthUser = os.githubCallback(oauthConfig, provider, code, verifier, nonce) case model.OauthTypeOidc, model.OauthTypeGoogle: - err, oauthUser = os.oidcCallback(oauthConfig, provider, code, verifier) + err, oauthUser = os.oidcCallback(oauthConfig, provider, code, verifier, nonce) default: return errors.New("unsupported OAuth type"), nil }