feat(oidc): add pkce (#150)

This commit is contained in:
Tao Chen
2025-02-19 09:31:25 +08:00
committed by GitHub
parent 99e63cadcf
commit 6f55c5b642
7 changed files with 75 additions and 30 deletions

View File

@@ -166,7 +166,7 @@ func InitGlobal() {
global.Lock = lock.NewLocal() global.Lock = lock.NewLocal()
} }
func DatabaseAutoUpdate() { func DatabaseAutoUpdate() {
version := 260 version := 261
db := global.DB db := global.DB
@@ -210,6 +210,12 @@ func DatabaseAutoUpdate() {
if v.Version < uint(version) { if v.Version < uint(version) {
Migrate(uint(version)) Migrate(uint(version))
} }
// 261迁移
if v.Version < 261 {
// 在oauths表中添加pkce_enable 和 pkce_method 字段
db.Exec("ALTER TABLE oauths ADD COLUMN pkce_enable TINYINT(1) NOT NULL DEFAULT 0")
db.Exec("ALTER TABLE oauths ADD COLUMN pkce_method VARCHAR(20) NOT NULL DEFAULT 'S256'")
}
// 245迁移 // 245迁移
if v.Version < 245 { if v.Version < 245 {
//oauths 表的 oauth_type 字段设置为 op同样的值 //oauths 表的 oauth_type 字段设置为 op同样的值

View File

@@ -283,13 +283,13 @@ func (ct *Login) OidcAuth(c *gin.Context) {
return return
} }
err, code, url := service.AllService.OauthService.BeginAuth(f.Op) err, state, verifier, url := service.AllService.OauthService.BeginAuth(f.Op)
if err != nil { if err != nil {
response.Error(c, response.TranslateMsg(c, err.Error())) response.Error(c, response.TranslateMsg(c, err.Error()))
return return
} }
service.AllService.OauthService.SetOauthCache(code, &service.OauthCacheItem{ service.AllService.OauthService.SetOauthCache(state, &service.OauthCacheItem{
Action: service.OauthActionTypeLogin, Action: service.OauthActionTypeLogin,
Op: f.Op, Op: f.Op,
Id: f.Id, Id: f.Id,
@@ -297,10 +297,11 @@ func (ct *Login) OidcAuth(c *gin.Context) {
// DeviceOs: ct.Platform(c), // DeviceOs: ct.Platform(c),
DeviceOs: f.DeviceInfo.Os, DeviceOs: f.DeviceInfo.Os,
Uuid: f.Uuid, Uuid: f.Uuid,
Verifier: verifier,
}, 5*60) }, 5*60)
response.Success(c, gin.H{ response.Success(c, gin.H{
"code": code, "code": state,
"url": url, "url": url,
}) })
} }

View File

@@ -43,20 +43,21 @@ func (o *Oauth) ToBind(c *gin.Context) {
return return
} }
err, code, url := service.AllService.OauthService.BeginAuth(f.Op) err, state, verifier, url := service.AllService.OauthService.BeginAuth(f.Op)
if err != nil { if err != nil {
response.Error(c, response.TranslateMsg(c, err.Error())) response.Error(c, response.TranslateMsg(c, err.Error()))
return return
} }
service.AllService.OauthService.SetOauthCache(code, &service.OauthCacheItem{ service.AllService.OauthService.SetOauthCache(state, &service.OauthCacheItem{
Action: service.OauthActionTypeBind, Action: service.OauthActionTypeBind,
Op: f.Op, Op: f.Op,
UserId: u.Id, UserId: u.Id,
Verifier: verifier,
}, 5*60) }, 5*60)
response.Success(c, gin.H{ response.Success(c, gin.H{
"code": code, "code": state,
"url": url, "url": url,
}) })
} }

View File

@@ -32,15 +32,16 @@ func (o *Oauth) OidcAuth(c *gin.Context) {
} }
oauthService := service.AllService.OauthService oauthService := service.AllService.OauthService
var code string var state string
var url string var url string
err, code, url = oauthService.BeginAuth(f.Op) var verifier string
err, state, verifier, url = oauthService.BeginAuth(f.Op)
if err != nil { if err != nil {
response.Error(c, response.TranslateMsg(c, err.Error())) response.Error(c, response.TranslateMsg(c, err.Error()))
return return
} }
service.AllService.OauthService.SetOauthCache(code, &service.OauthCacheItem{ service.AllService.OauthService.SetOauthCache(state, &service.OauthCacheItem{
Action: service.OauthActionTypeLogin, Action: service.OauthActionTypeLogin,
Id: f.Id, Id: f.Id,
Op: f.Op, Op: f.Op,
@@ -48,10 +49,11 @@ func (o *Oauth) OidcAuth(c *gin.Context) {
DeviceName: f.DeviceInfo.Name, DeviceName: f.DeviceInfo.Name,
DeviceOs: f.DeviceInfo.Os, DeviceOs: f.DeviceInfo.Os,
DeviceType: f.DeviceInfo.Type, DeviceType: f.DeviceInfo.Type,
Verifier: verifier,
}, 5*60) }, 5*60)
//fmt.Println("code url", code, url) //fmt.Println("code url", code, url)
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"code": code, "code": state,
"url": url, "url": url,
}) })
} }
@@ -156,10 +158,11 @@ func (o *Oauth) OauthCallback(c *gin.Context) {
} }
op := oauthCache.Op op := oauthCache.Op
action := oauthCache.Action action := oauthCache.Action
verifier := oauthCache.Verifier
var user *model.User var user *model.User
// 获取用户信息 // 获取用户信息
code := c.Query("code") code := c.Query("code")
err, oauthUser := oauthService.Callback(code, op) err, oauthUser := oauthService.Callback(code, verifier, op)
if err != nil { if err != nil {
c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthFailed")+response.TranslateMsg(c, err.Error())) c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthFailed")+response.TranslateMsg(c, err.Error()))
return return

View File

@@ -24,6 +24,8 @@ type OauthForm struct {
ClientSecret string `json:"client_secret" validate:"required"` ClientSecret string `json:"client_secret" validate:"required"`
RedirectUrl string `json:"redirect_url" validate:"required"` RedirectUrl string `json:"redirect_url" validate:"required"`
AutoRegister *bool `json:"auto_register"` AutoRegister *bool `json:"auto_register"`
PkceEnable *bool `json:"pkce_enable"`
PkceMethod string `json:"pkce_method"`
} }
func (of *OauthForm) ToOauth() *model.Oauth { func (of *OauthForm) ToOauth() *model.Oauth {
@@ -36,6 +38,8 @@ func (of *OauthForm) ToOauth() *model.Oauth {
AutoRegister: of.AutoRegister, AutoRegister: of.AutoRegister,
Issuer: of.Issuer, Issuer: of.Issuer,
Scopes: of.Scopes, Scopes: of.Scopes,
PkceEnable: of.PkceEnable,
PkceMethod: of.PkceMethod,
} }
oa.Id = of.Id oa.Id = of.Id
return oa return oa

View File

@@ -14,6 +14,8 @@ const (
OauthTypeGoogle string = "google" OauthTypeGoogle string = "google"
OauthTypeOidc string = "oidc" OauthTypeOidc string = "oidc"
OauthTypeWebauth string = "webauth" OauthTypeWebauth string = "webauth"
PKCEMethodS256 string = "S256"
PKCEMethodPlain string = "plain"
) )
// Validate the oauth type // Validate the oauth type
@@ -41,6 +43,8 @@ type Oauth struct {
AutoRegister *bool `json:"auto_register"` AutoRegister *bool `json:"auto_register"`
Scopes string `json:"scopes"` Scopes string `json:"scopes"`
Issuer string `json:"issuer"` Issuer string `json:"issuer"`
PkceEnable *bool `json:"pkce_enable"`
PkceMethod string `json:"pkce_method"`
TimeModel TimeModel
} }
@@ -68,6 +72,13 @@ func (oa *Oauth) FormatOauthInfo() error {
if oauthType == OauthTypeGoogle && issuer == "" { if oauthType == OauthTypeGoogle && issuer == "" {
oa.Issuer = IssuerGoogle oa.Issuer = IssuerGoogle
} }
if oa.PkceEnable == nil {
oa.PkceEnable = new(bool)
*oa.PkceEnable = false
}
if oa.PkceMethod == "" {
oa.PkceMethod = PKCEMethodS256
}
return nil return nil
} }

View File

@@ -45,6 +45,7 @@ type OauthCacheItem struct {
Username string `json:"username"` Username string `json:"username"`
Name string `json:"name"` Name string `json:"name"`
Email string `json:"email"` Email string `json:"email"`
Verifier string `json:"verifier"` // used for oauth pkce
} }
func (oci *OauthCacheItem) ToOauthUser() *model.OauthUser { func (oci *OauthCacheItem) ToOauthUser() *model.OauthUser {
@@ -92,19 +93,32 @@ func (os *OauthService) DeleteOauthCache(key string) {
OauthCache.Delete(key) OauthCache.Delete(key)
} }
func (os *OauthService) BeginAuth(op string) (error error, code, url string) { func (os *OauthService) BeginAuth(op string) (error error, state, verifier, url string) {
code = utils.RandomString(10) + strconv.FormatInt(time.Now().Unix(), 10) state = utils.RandomString(10) + strconv.FormatInt(time.Now().Unix(), 10)
verifier = ""
if op == string(model.OauthTypeWebauth) { if op == string(model.OauthTypeWebauth) {
url = global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/" + code url = global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/" + state
//url = "http://localhost:8888/_admin/#/oauth/" + code //url = "http://localhost:8888/_admin/#/oauth/" + code
return nil, code, url return nil, state, verifier, url
} }
err, _, oauthConfig := os.GetOauthConfig(op) err, oauthInfo, oauthConfig := os.GetOauthConfig(op)
if err == nil { if err == nil {
return err, code, oauthConfig.AuthCodeURL(code) extras := make([]oauth2.AuthCodeOption, 0, 3)
if oauthInfo.PkceEnable != nil && *oauthInfo.PkceEnable {
extras = append(extras, oauth2.AccessTypeOffline)
verifier = oauth2.GenerateVerifier()
switch oauthInfo.PkceMethod {
case model.PKCEMethodS256:
extras = append(extras, oauth2.S256ChallengeOption(verifier))
case model.PKCEMethodPlain:
// oauth2 does not have a plain challenge option, so we add it manually
extras = append(extras, oauth2.SetAuthURLParam("code_challenge_method", "plain"), oauth2.SetAuthURLParam("code_challenge", verifier))
}
}
return err, state, verifier, oauthConfig.AuthCodeURL(state, extras...)
} }
return err, code, "" return err, state, verifier, ""
} }
// Method to fetch OIDC configuration dynamically // Method to fetch OIDC configuration dynamically
@@ -207,15 +221,20 @@ func getHTTPClientWithProxy() *http.Client {
return http.DefaultClient return http.DefaultClient
} }
func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, code string, userEndpoint string, userData interface{}) (err error, client *http.Client) { func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, code string, verifier string, userEndpoint string, userData interface{}) (err error, client *http.Client) {
// 设置代理客户端 // 设置代理客户端
httpClient := getHTTPClientWithProxy() httpClient := getHTTPClientWithProxy()
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient) ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
var exchangeOpts []oauth2.AuthCodeOption
if verifier != "" {
exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(verifier)}
}
// 使用 code 换取 token // 使用 code 换取 token
var token *oauth2.Token var token *oauth2.Token
token, err = oauthConfig.Exchange(ctx, code) token, err = oauthConfig.Exchange(ctx, code, exchangeOpts...)
if err != nil { if err != nil {
global.Logger.Warn("oauthConfig.Exchange() failed: ", err) global.Logger.Warn("oauthConfig.Exchange() failed: ", err)
return errors.New("GetOauthTokenError"), nil return errors.New("GetOauthTokenError"), nil
@@ -244,9 +263,9 @@ func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, code string, us
} }
// githubCallback github回调 // githubCallback github回调
func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string) (error, *model.OauthUser) { func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string, verifier string) (error, *model.OauthUser) {
var user = &model.GithubUser{} var user = &model.GithubUser{}
err, client := os.callbackBase(oauthConfig, code, model.UserEndpointGithub, user) err, client := os.callbackBase(oauthConfig, code, verifier, model.UserEndpointGithub, user)
if err != nil { if err != nil {
return err, nil return err, nil
} }
@@ -258,16 +277,16 @@ func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string)
} }
// oidcCallback oidc回调, 通过code获取用户信息 // oidcCallback oidc回调, 通过code获取用户信息
func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, userInfoEndpoint string) (error, *model.OauthUser) { func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, verifier string, userInfoEndpoint string) (error, *model.OauthUser) {
var user = &model.OidcUser{} var user = &model.OidcUser{}
if err, _ := os.callbackBase(oauthConfig, code, userInfoEndpoint, user); err != nil { if err, _ := os.callbackBase(oauthConfig, code, verifier, userInfoEndpoint, user); err != nil {
return err, nil return err, nil
} }
return nil, user.ToOauthUser() return nil, user.ToOauthUser()
} }
// Callback: Get user information by code and op(Oauth provider) // Callback: Get user information by code and op(Oauth provider)
func (os *OauthService) Callback(code string, op string) (err error, oauthUser *model.OauthUser) { func (os *OauthService) Callback(code, verifier, op string) (err error, oauthUser *model.OauthUser) {
var oauthInfo *model.Oauth var oauthInfo *model.Oauth
var oauthConfig *oauth2.Config var oauthConfig *oauth2.Config
err, oauthInfo, oauthConfig = os.GetOauthConfig(op) err, oauthInfo, oauthConfig = os.GetOauthConfig(op)
@@ -278,13 +297,13 @@ func (os *OauthService) Callback(code string, op string) (err error, oauthUser *
oauthType := oauthInfo.OauthType oauthType := oauthInfo.OauthType
switch oauthType { switch oauthType {
case model.OauthTypeGithub: case model.OauthTypeGithub:
err, oauthUser = os.githubCallback(oauthConfig, code) err, oauthUser = os.githubCallback(oauthConfig, code, verifier)
case model.OauthTypeOidc, model.OauthTypeGoogle: case model.OauthTypeOidc, model.OauthTypeGoogle:
err, endpoint := os.FetchOidcEndpoint(oauthInfo.Issuer) err, endpoint := os.FetchOidcEndpoint(oauthInfo.Issuer)
if err != nil { if err != nil {
return err, nil return err, nil
} }
err, oauthUser = os.oidcCallback(oauthConfig, code, endpoint.UserInfo) err, oauthUser = os.oidcCallback(oauthConfig, code, verifier, endpoint.UserInfo)
default: default:
return errors.New("unsupported OAuth type"), nil return errors.New("unsupported OAuth type"), nil
} }