diff --git a/cmd/apimain.go b/cmd/apimain.go index 41a6999..c0d6672 100644 --- a/cmd/apimain.go +++ b/cmd/apimain.go @@ -166,7 +166,7 @@ func InitGlobal() { global.Lock = lock.NewLocal() } func DatabaseAutoUpdate() { - version := 260 + version := 261 db := global.DB @@ -210,6 +210,12 @@ func DatabaseAutoUpdate() { if v.Version < uint(version) { Migrate(uint(version)) } + // 261迁移 + if v.Version < 261 { + // 在oauths表中添加pkce_enable 和 pkce_method 字段 + db.Exec("ALTER TABLE oauths ADD COLUMN pkce_enable TINYINT(1) NOT NULL DEFAULT 0") + db.Exec("ALTER TABLE oauths ADD COLUMN pkce_method VARCHAR(20) NOT NULL DEFAULT 'S256'") + } // 245迁移 if v.Version < 245 { //oauths 表的 oauth_type 字段设置为 op同样的值 diff --git a/http/controller/admin/login.go b/http/controller/admin/login.go index 68fffef..3ca71a9 100644 --- a/http/controller/admin/login.go +++ b/http/controller/admin/login.go @@ -283,13 +283,13 @@ func (ct *Login) OidcAuth(c *gin.Context) { return } - err, code, url := service.AllService.OauthService.BeginAuth(f.Op) + err, state, verifier, url := service.AllService.OauthService.BeginAuth(f.Op) if err != nil { response.Error(c, response.TranslateMsg(c, err.Error())) return } - service.AllService.OauthService.SetOauthCache(code, &service.OauthCacheItem{ + service.AllService.OauthService.SetOauthCache(state, &service.OauthCacheItem{ Action: service.OauthActionTypeLogin, Op: f.Op, Id: f.Id, @@ -297,10 +297,11 @@ func (ct *Login) OidcAuth(c *gin.Context) { // DeviceOs: ct.Platform(c), DeviceOs: f.DeviceInfo.Os, Uuid: f.Uuid, + Verifier: verifier, }, 5*60) response.Success(c, gin.H{ - "code": code, + "code": state, "url": url, }) } diff --git a/http/controller/admin/oauth.go b/http/controller/admin/oauth.go index eb2b4fc..88716d5 100644 --- a/http/controller/admin/oauth.go +++ b/http/controller/admin/oauth.go @@ -43,20 +43,21 @@ func (o *Oauth) ToBind(c *gin.Context) { return } - err, code, url := service.AllService.OauthService.BeginAuth(f.Op) + err, state, verifier, url := service.AllService.OauthService.BeginAuth(f.Op) if err != nil { response.Error(c, response.TranslateMsg(c, err.Error())) return } - service.AllService.OauthService.SetOauthCache(code, &service.OauthCacheItem{ + service.AllService.OauthService.SetOauthCache(state, &service.OauthCacheItem{ Action: service.OauthActionTypeBind, - Op: f.Op, - UserId: u.Id, + Op: f.Op, + UserId: u.Id, + Verifier: verifier, }, 5*60) response.Success(c, gin.H{ - "code": code, + "code": state, "url": url, }) } diff --git a/http/controller/api/ouath.go b/http/controller/api/ouath.go index b1eaa41..d18070f 100644 --- a/http/controller/api/ouath.go +++ b/http/controller/api/ouath.go @@ -32,15 +32,16 @@ func (o *Oauth) OidcAuth(c *gin.Context) { } oauthService := service.AllService.OauthService - var code string + var state string var url string - err, code, url = oauthService.BeginAuth(f.Op) + var verifier string + err, state, verifier, url = oauthService.BeginAuth(f.Op) if err != nil { response.Error(c, response.TranslateMsg(c, err.Error())) return } - service.AllService.OauthService.SetOauthCache(code, &service.OauthCacheItem{ + service.AllService.OauthService.SetOauthCache(state, &service.OauthCacheItem{ Action: service.OauthActionTypeLogin, Id: f.Id, Op: f.Op, @@ -48,10 +49,11 @@ func (o *Oauth) OidcAuth(c *gin.Context) { DeviceName: f.DeviceInfo.Name, DeviceOs: f.DeviceInfo.Os, DeviceType: f.DeviceInfo.Type, + Verifier: verifier, }, 5*60) //fmt.Println("code url", code, url) c.JSON(http.StatusOK, gin.H{ - "code": code, + "code": state, "url": url, }) } @@ -156,10 +158,11 @@ func (o *Oauth) OauthCallback(c *gin.Context) { } op := oauthCache.Op action := oauthCache.Action + verifier := oauthCache.Verifier var user *model.User // 获取用户信息 code := c.Query("code") - err, oauthUser := oauthService.Callback(code, op) + err, oauthUser := oauthService.Callback(code, verifier, op) if err != nil { c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthFailed")+response.TranslateMsg(c, err.Error())) return diff --git a/http/request/admin/oauth.go b/http/request/admin/oauth.go index 866b37a..00da0a3 100644 --- a/http/request/admin/oauth.go +++ b/http/request/admin/oauth.go @@ -24,6 +24,8 @@ type OauthForm struct { ClientSecret string `json:"client_secret" validate:"required"` RedirectUrl string `json:"redirect_url" validate:"required"` AutoRegister *bool `json:"auto_register"` + PkceEnable *bool `json:"pkce_enable"` + PkceMethod string `json:"pkce_method"` } func (of *OauthForm) ToOauth() *model.Oauth { @@ -36,6 +38,8 @@ func (of *OauthForm) ToOauth() *model.Oauth { AutoRegister: of.AutoRegister, Issuer: of.Issuer, Scopes: of.Scopes, + PkceEnable: of.PkceEnable, + PkceMethod: of.PkceMethod, } oa.Id = of.Id return oa diff --git a/model/oauth.go b/model/oauth.go index b6004ab..29d8014 100644 --- a/model/oauth.go +++ b/model/oauth.go @@ -14,6 +14,8 @@ const ( OauthTypeGoogle string = "google" OauthTypeOidc string = "oidc" OauthTypeWebauth string = "webauth" + PKCEMethodS256 string = "S256" + PKCEMethodPlain string = "plain" ) // Validate the oauth type @@ -41,6 +43,8 @@ type Oauth struct { AutoRegister *bool `json:"auto_register"` Scopes string `json:"scopes"` Issuer string `json:"issuer"` + PkceEnable *bool `json:"pkce_enable"` + PkceMethod string `json:"pkce_method"` TimeModel } @@ -68,6 +72,13 @@ func (oa *Oauth) FormatOauthInfo() error { if oauthType == OauthTypeGoogle && issuer == "" { oa.Issuer = IssuerGoogle } + if oa.PkceEnable == nil { + oa.PkceEnable = new(bool) + *oa.PkceEnable = false + } + if oa.PkceMethod == "" { + oa.PkceMethod = PKCEMethodS256 + } return nil } diff --git a/service/oauth.go b/service/oauth.go index f1d97aa..b99ec23 100644 --- a/service/oauth.go +++ b/service/oauth.go @@ -45,6 +45,7 @@ type OauthCacheItem struct { Username string `json:"username"` Name string `json:"name"` Email string `json:"email"` + Verifier string `json:"verifier"` // used for oauth pkce } func (oci *OauthCacheItem) ToOauthUser() *model.OauthUser { @@ -92,19 +93,32 @@ func (os *OauthService) DeleteOauthCache(key string) { OauthCache.Delete(key) } -func (os *OauthService) BeginAuth(op string) (error error, code, url string) { - code = utils.RandomString(10) + strconv.FormatInt(time.Now().Unix(), 10) +func (os *OauthService) BeginAuth(op string) (error error, state, verifier, url string) { + state = utils.RandomString(10) + strconv.FormatInt(time.Now().Unix(), 10) + verifier = "" if op == string(model.OauthTypeWebauth) { - url = global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/" + code + url = global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/" + state //url = "http://localhost:8888/_admin/#/oauth/" + code - return nil, code, url + return nil, state, verifier, url } - err, _, oauthConfig := os.GetOauthConfig(op) + err, oauthInfo, oauthConfig := os.GetOauthConfig(op) if err == nil { - return err, code, oauthConfig.AuthCodeURL(code) + extras := make([]oauth2.AuthCodeOption, 0, 3) + if oauthInfo.PkceEnable != nil && *oauthInfo.PkceEnable { + extras = append(extras, oauth2.AccessTypeOffline) + verifier = oauth2.GenerateVerifier() + switch oauthInfo.PkceMethod { + case model.PKCEMethodS256: + extras = append(extras, oauth2.S256ChallengeOption(verifier)) + case model.PKCEMethodPlain: + // oauth2 does not have a plain challenge option, so we add it manually + extras = append(extras, oauth2.SetAuthURLParam("code_challenge_method", "plain"), oauth2.SetAuthURLParam("code_challenge", verifier)) + } + } + return err, state, verifier, oauthConfig.AuthCodeURL(state, extras...) } - return err, code, "" + return err, state, verifier, "" } // Method to fetch OIDC configuration dynamically @@ -207,15 +221,20 @@ func getHTTPClientWithProxy() *http.Client { return http.DefaultClient } -func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, code string, userEndpoint string, userData interface{}) (err error, client *http.Client) { +func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, code string, verifier string, userEndpoint string, userData interface{}) (err error, client *http.Client) { // 设置代理客户端 httpClient := getHTTPClientWithProxy() ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient) + var exchangeOpts []oauth2.AuthCodeOption + if verifier != "" { + exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(verifier)} + } + // 使用 code 换取 token var token *oauth2.Token - token, err = oauthConfig.Exchange(ctx, code) + token, err = oauthConfig.Exchange(ctx, code, exchangeOpts...) if err != nil { global.Logger.Warn("oauthConfig.Exchange() failed: ", err) return errors.New("GetOauthTokenError"), nil @@ -244,9 +263,9 @@ func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, code string, us } // githubCallback github回调 -func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string) (error, *model.OauthUser) { +func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string, verifier string) (error, *model.OauthUser) { var user = &model.GithubUser{} - err, client := os.callbackBase(oauthConfig, code, model.UserEndpointGithub, user) + err, client := os.callbackBase(oauthConfig, code, verifier, model.UserEndpointGithub, user) if err != nil { return err, nil } @@ -258,16 +277,16 @@ func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string) } // oidcCallback oidc回调, 通过code获取用户信息 -func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, userInfoEndpoint string) (error, *model.OauthUser) { +func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, verifier string, userInfoEndpoint string) (error, *model.OauthUser) { var user = &model.OidcUser{} - if err, _ := os.callbackBase(oauthConfig, code, userInfoEndpoint, user); err != nil { + if err, _ := os.callbackBase(oauthConfig, code, verifier, userInfoEndpoint, user); err != nil { return err, nil } return nil, user.ToOauthUser() } // Callback: Get user information by code and op(Oauth provider) -func (os *OauthService) Callback(code string, op string) (err error, oauthUser *model.OauthUser) { +func (os *OauthService) Callback(code, verifier, op string) (err error, oauthUser *model.OauthUser) { var oauthInfo *model.Oauth var oauthConfig *oauth2.Config err, oauthInfo, oauthConfig = os.GetOauthConfig(op) @@ -278,13 +297,13 @@ func (os *OauthService) Callback(code string, op string) (err error, oauthUser * oauthType := oauthInfo.OauthType switch oauthType { case model.OauthTypeGithub: - err, oauthUser = os.githubCallback(oauthConfig, code) + err, oauthUser = os.githubCallback(oauthConfig, code, verifier) case model.OauthTypeOidc, model.OauthTypeGoogle: err, endpoint := os.FetchOidcEndpoint(oauthInfo.Issuer) if err != nil { return err, nil } - err, oauthUser = os.oidcCallback(oauthConfig, code, endpoint.UserInfo) + err, oauthUser = os.oidcCallback(oauthConfig, code, verifier, endpoint.UserInfo) default: return errors.New("unsupported OAuth type"), nil }