feat(oauth): Oauth nonce (#148)

This commit is contained in:
lejianwen
2025-02-26 16:36:53 +08:00
parent 0dd92311b2
commit 0f16f61ab3
4 changed files with 31 additions and 22 deletions

View File

@@ -283,7 +283,7 @@ func (ct *Login) OidcAuth(c *gin.Context) {
return return
} }
err, state, verifier, url := service.AllService.OauthService.BeginAuth(f.Op) err, state, verifier, nonce, url := service.AllService.OauthService.BeginAuth(f.Op)
if err != nil { if err != nil {
response.Error(c, response.TranslateMsg(c, err.Error())) response.Error(c, response.TranslateMsg(c, err.Error()))
return return
@@ -298,6 +298,7 @@ func (ct *Login) OidcAuth(c *gin.Context) {
DeviceOs: f.DeviceInfo.Os, DeviceOs: f.DeviceInfo.Os,
Uuid: f.Uuid, Uuid: f.Uuid,
Verifier: verifier, Verifier: verifier,
Nonce: nonce,
}, 5*60) }, 5*60)
response.Success(c, gin.H{ response.Success(c, gin.H{

View File

@@ -43,17 +43,18 @@ func (o *Oauth) ToBind(c *gin.Context) {
return return
} }
err, state, verifier, url := service.AllService.OauthService.BeginAuth(f.Op) err, state, verifier, nonce, url := service.AllService.OauthService.BeginAuth(f.Op)
if err != nil { if err != nil {
response.Error(c, response.TranslateMsg(c, err.Error())) response.Error(c, response.TranslateMsg(c, err.Error()))
return return
} }
service.AllService.OauthService.SetOauthCache(state, &service.OauthCacheItem{ service.AllService.OauthService.SetOauthCache(state, &service.OauthCacheItem{
Action: service.OauthActionTypeBind, Action: service.OauthActionTypeBind,
Op: f.Op, Op: f.Op,
UserId: u.Id, UserId: u.Id,
Verifier: verifier, Verifier: verifier,
Nonce: nonce,
}, 5*60) }, 5*60)
response.Success(c, gin.H{ response.Success(c, gin.H{

View File

@@ -32,10 +32,8 @@ func (o *Oauth) OidcAuth(c *gin.Context) {
} }
oauthService := service.AllService.OauthService oauthService := service.AllService.OauthService
var state string
var url string err, state, verifier, nonce, url := oauthService.BeginAuth(f.Op)
var verifier string
err, state, verifier, url = oauthService.BeginAuth(f.Op)
if err != nil { if err != nil {
response.Error(c, response.TranslateMsg(c, err.Error())) response.Error(c, response.TranslateMsg(c, err.Error()))
return return
@@ -50,6 +48,7 @@ func (o *Oauth) OidcAuth(c *gin.Context) {
DeviceOs: f.DeviceInfo.Os, DeviceOs: f.DeviceInfo.Os,
DeviceType: f.DeviceInfo.Type, DeviceType: f.DeviceInfo.Type,
Verifier: verifier, Verifier: verifier,
Nonce: nonce,
}, 5*60) }, 5*60)
//fmt.Println("code url", code, url) //fmt.Println("code url", code, url)
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -160,13 +159,14 @@ func (o *Oauth) OauthCallback(c *gin.Context) {
}) })
return return
} }
nonce := oauthCache.Nonce
op := oauthCache.Op op := oauthCache.Op
action := oauthCache.Action action := oauthCache.Action
verifier := oauthCache.Verifier verifier := oauthCache.Verifier
var user *model.User var user *model.User
// 获取用户信息 // 获取用户信息
code := c.Query("code") code := c.Query("code")
err, oauthUser := oauthService.Callback(code, verifier, op) err, oauthUser := oauthService.Callback(code, verifier, op, nonce)
if err != nil { if err != nil {
c.HTML(http.StatusOK, "oauth_fail.html", gin.H{ c.HTML(http.StatusOK, "oauth_fail.html", gin.H{
"message": response.TranslateMsg(c, "OauthFailed") + response.TranslateMsg(c, err.Error()), "message": response.TranslateMsg(c, "OauthFailed") + response.TranslateMsg(c, err.Error()),

View File

@@ -47,6 +47,7 @@ type OauthCacheItem struct {
Name string `json:"name"` Name string `json:"name"`
Email string `json:"email"` Email string `json:"email"`
Verifier string `json:"verifier"` // used for oauth pkce Verifier string `json:"verifier"` // used for oauth pkce
Nonce string `json:"nonce"`
} }
func (oci *OauthCacheItem) ToOauthUser() *model.OauthUser { func (oci *OauthCacheItem) ToOauthUser() *model.OauthUser {
@@ -93,17 +94,22 @@ func (os *OauthService) DeleteOauthCache(key string) {
OauthCache.Delete(key) OauthCache.Delete(key)
} }
func (os *OauthService) BeginAuth(op string) (error error, state, verifier, url string) { func (os *OauthService) BeginAuth(op string) (error error, state, verifier, nonce, url string) {
state = utils.RandomString(10) + strconv.FormatInt(time.Now().Unix(), 10) state = utils.RandomString(10) + strconv.FormatInt(time.Now().Unix(), 10)
verifier = "" verifier = ""
nonce = ""
if op == model.OauthTypeWebauth { if op == model.OauthTypeWebauth {
url = global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/" + state url = global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/" + state
//url = "http://localhost:8888/_admin/#/oauth/" + code //url = "http://localhost:8888/_admin/#/oauth/" + code
return nil, state, verifier, url return nil, state, verifier, nonce, url
} }
err, oauthInfo, oauthConfig, _ := os.GetOauthConfig(op) err, oauthInfo, oauthConfig, _ := os.GetOauthConfig(op)
if err == nil { if err == nil {
extras := make([]oauth2.AuthCodeOption, 0, 3) extras := make([]oauth2.AuthCodeOption, 0, 3)
nonce = utils.RandomString(10)
extras = append(extras, oauth2.SetAuthURLParam("nonce", nonce))
if oauthInfo.PkceEnable != nil && *oauthInfo.PkceEnable { if oauthInfo.PkceEnable != nil && *oauthInfo.PkceEnable {
extras = append(extras, oauth2.AccessTypeOffline) extras = append(extras, oauth2.AccessTypeOffline)
verifier = oauth2.GenerateVerifier() verifier = oauth2.GenerateVerifier()
@@ -115,10 +121,11 @@ func (os *OauthService) BeginAuth(op string) (error error, state, verifier, url
extras = append(extras, oauth2.SetAuthURLParam("code_challenge_method", "plain"), oauth2.SetAuthURLParam("code_challenge", verifier)) extras = append(extras, oauth2.SetAuthURLParam("code_challenge_method", "plain"), oauth2.SetAuthURLParam("code_challenge", verifier))
} }
} }
return err, state, verifier, oauthConfig.AuthCodeURL(state, extras...)
return err, state, verifier, nonce, oauthConfig.AuthCodeURL(state, extras...)
} }
return err, state, verifier, "" return err, state, verifier, nonce, ""
} }
func (os *OauthService) FetchOidcProvider(issuer string) (error, *oidc.Provider) { func (os *OauthService) FetchOidcProvider(issuer string) (error, *oidc.Provider) {
@@ -280,9 +287,9 @@ func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, provider *oidc.
} }
// githubCallback github回调 // githubCallback github回调
func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, provider *oidc.Provider, code string, verifier string) (error, *model.OauthUser) { func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, provider *oidc.Provider, code, verifier, nonce string) (error, *model.OauthUser) {
var user = &model.GithubUser{} var user = &model.GithubUser{}
err, client := os.callbackBase(oauthConfig, provider, code, verifier, "", user) err, client := os.callbackBase(oauthConfig, provider, code, verifier, nonce, user)
if err != nil { if err != nil {
return err, nil return err, nil
} }
@@ -294,16 +301,16 @@ func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, provider *oid
} }
// oidcCallback oidc回调, 通过code获取用户信息 // oidcCallback oidc回调, 通过code获取用户信息
func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, provider *oidc.Provider, code string, verifier string) (error, *model.OauthUser) { func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, provider *oidc.Provider, code, verifier, nonce string) (error, *model.OauthUser) {
var user = &model.OidcUser{} var user = &model.OidcUser{}
if err, _ := os.callbackBase(oauthConfig, provider, code, verifier, "", user); err != nil { if err, _ := os.callbackBase(oauthConfig, provider, code, verifier, nonce, user); err != nil {
return err, nil return err, nil
} }
return nil, user.ToOauthUser() return nil, user.ToOauthUser()
} }
// Callback: Get user information by code and op(Oauth provider) // Callback: Get user information by code and op(Oauth provider)
func (os *OauthService) Callback(code, verifier, op string) (err error, oauthUser *model.OauthUser) { func (os *OauthService) Callback(code, verifier, op, nonce string) (err error, oauthUser *model.OauthUser) {
err, oauthInfo, oauthConfig, provider := os.GetOauthConfig(op) err, oauthInfo, oauthConfig, provider := os.GetOauthConfig(op)
// oauthType is already validated in GetOauthConfig // oauthType is already validated in GetOauthConfig
if err != nil { if err != nil {
@@ -312,9 +319,9 @@ func (os *OauthService) Callback(code, verifier, op string) (err error, oauthUse
oauthType := oauthInfo.OauthType oauthType := oauthInfo.OauthType
switch oauthType { switch oauthType {
case model.OauthTypeGithub: case model.OauthTypeGithub:
err, oauthUser = os.githubCallback(oauthConfig, provider, code, verifier) err, oauthUser = os.githubCallback(oauthConfig, provider, code, verifier, nonce)
case model.OauthTypeOidc, model.OauthTypeGoogle: case model.OauthTypeOidc, model.OauthTypeGoogle:
err, oauthUser = os.oidcCallback(oauthConfig, provider, code, verifier) err, oauthUser = os.oidcCallback(oauthConfig, provider, code, verifier, nonce)
default: default:
return errors.New("unsupported OAuth type"), nil return errors.New("unsupported OAuth type"), nil
} }