diff --git a/http/controller/admin/login.go b/http/controller/admin/login.go index a67738d..0ff8deb 100644 --- a/http/controller/admin/login.go +++ b/http/controller/admin/login.go @@ -169,8 +169,8 @@ func (ct *Login) LoginOptions(c *gin.Context) { "ops": ops, "register": global.Config.App.Register, "need_captcha": needCaptcha, - "disable_pwd": global.Config.App.DisablePwdLogin, - "auto_oidc": global.Config.App.DisablePwdLogin && len(ops) == 1, + "disable_pwd": global.Config.App.DisablePwdLogin, + "auto_oidc": global.Config.App.DisablePwdLogin && len(ops) == 1, }) } @@ -191,7 +191,7 @@ func (ct *Login) OidcAuth(c *gin.Context) { return } - err, state, verifier, nonce, url := service.AllService.OauthService.BeginAuth(c, f.Op) + err, state, verifier, nonce, url := service.AllService.OauthService.BeginAuth(f.Op) if err != nil { response.Error(c, response.TranslateMsg(c, err.Error())) return diff --git a/http/controller/admin/oauth.go b/http/controller/admin/oauth.go index 6ae1c3a..811fabb 100644 --- a/http/controller/admin/oauth.go +++ b/http/controller/admin/oauth.go @@ -44,7 +44,7 @@ func (o *Oauth) ToBind(c *gin.Context) { return } - err, state, verifier, nonce, url := service.AllService.OauthService.BeginAuth(c, f.Op) + err, state, verifier, nonce, url := service.AllService.OauthService.BeginAuth(f.Op) if err != nil { response.Error(c, response.TranslateMsg(c, err.Error())) return diff --git a/http/controller/api/ouath.go b/http/controller/api/ouath.go index 0ea2ebd..7a61d81 100644 --- a/http/controller/api/ouath.go +++ b/http/controller/api/ouath.go @@ -36,7 +36,7 @@ func (o *Oauth) OidcAuth(c *gin.Context) { oauthService := service.AllService.OauthService - err, state, verifier, nonce, url := oauthService.BeginAuth(c, f.Op) + err, state, verifier, nonce, url := oauthService.BeginAuth(f.Op) if err != nil { response.Error(c, response.TranslateMsg(c, err.Error())) return @@ -170,7 +170,7 @@ func (o *Oauth) OauthCallback(c *gin.Context) { var user *model.User // 获取用户信息 code := c.Query("code") - err, oauthUser := oauthService.Callback(c, code, verifier, op, nonce) + err, oauthUser := oauthService.Callback(code, verifier, op, nonce) if err != nil { c.HTML(http.StatusOK, "oauth_fail.html", gin.H{ "message": "OauthFailed", diff --git a/model/oauth.go b/model/oauth.go index 98a4be1..6294384 100644 --- a/model/oauth.go +++ b/model/oauth.go @@ -41,6 +41,7 @@ type Oauth struct { OauthType string `json:"oauth_type"` ClientId string `json:"client_id"` ClientSecret string `json:"client_secret"` + //RedirectUrl string `json:"redirect_url"` AutoRegister *bool `json:"auto_register"` Scopes string `json:"scopes"` Issuer string `json:"issuer"` diff --git a/service/oauth.go b/service/oauth.go index fbeff8b..48c7ddd 100644 --- a/service/oauth.go +++ b/service/oauth.go @@ -6,7 +6,6 @@ import ( "errors" "github.com/coreos/go-oidc/v3/oidc" - "github.com/gin-gonic/gin" "github.com/lejianwen/rustdesk-api/v2/model" "github.com/lejianwen/rustdesk-api/v2/utils" "golang.org/x/oauth2" @@ -96,20 +95,16 @@ func (os *OauthService) DeleteOauthCache(key string) { OauthCache.Delete(key) } -func (os *OauthService) BeginAuth(c *gin.Context, op string) (error error, state, verifier, nonce, url string) { +func (os *OauthService) BeginAuth(op string) (error error, state, verifier, nonce, url string) { state = utils.RandomString(10) + strconv.FormatInt(time.Now().Unix(), 10) verifier = "" nonce = "" if op == model.OauthTypeWebauth { - host := c.GetHeader("Origin") - if host == "" { - host = Config.Rustdesk.ApiServer - } - url = host + "/_admin/#/oauth/" + state + url = Config.Rustdesk.ApiServer + "/_admin/#/oauth/" + state //url = "http://localhost:8888/_admin/#/oauth/" + code return nil, state, verifier, nonce, url } - err, oauthInfo, oauthConfig, _ := os.GetOauthConfig(c, op) + err, oauthInfo, oauthConfig, _ := os.GetOauthConfig(op) if err == nil { extras := make([]oauth2.AuthCodeOption, 0, 3) @@ -174,18 +169,16 @@ func (os *OauthService) LinuxdoProvider() *oidc.Provider { } // GetOauthConfig retrieves the OAuth2 configuration based on the provider name -func (os *OauthService) GetOauthConfig(c *gin.Context, op string) (err error, oauthInfo *model.Oauth, oauthConfig *oauth2.Config, provider *oidc.Provider) { +func (os *OauthService) GetOauthConfig(op string) (err error, oauthInfo *model.Oauth, oauthConfig *oauth2.Config, provider *oidc.Provider) { //err, oauthInfo, oauthConfig = os.getOauthConfigGeneral(op) oauthInfo = os.InfoByOp(op) if oauthInfo.Id == 0 || oauthInfo.ClientId == "" || oauthInfo.ClientSecret == "" { return errors.New("ConfigNotFound"), nil, nil, nil } - redirectUrl := os.buildRedirectURL(c) - Logger.Debug("Redirect URL: ", redirectUrl) oauthConfig = &oauth2.Config{ ClientID: oauthInfo.ClientId, ClientSecret: oauthInfo.ClientSecret, - RedirectURL: redirectUrl, + RedirectURL: Config.Rustdesk.ApiServer + "/api/oidc/callback", } // Maybe should validate the oauthConfig here @@ -340,8 +333,8 @@ func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, provider *oidc. } // Callback: Get user information by code and op(Oauth provider) -func (os *OauthService) Callback(c *gin.Context, code, verifier, op, nonce string) (err error, oauthUser *model.OauthUser) { - err, oauthInfo, oauthConfig, provider := os.GetOauthConfig(c, op) +func (os *OauthService) Callback(code, verifier, op, nonce string) (err error, oauthUser *model.OauthUser) { + err, oauthInfo, oauthConfig, provider := os.GetOauthConfig(op) // oauthType is already validated in GetOauthConfig if err != nil { return err, nil @@ -527,22 +520,3 @@ func (os *OauthService) getGithubPrimaryEmail(client *http.Client, githubUser *m return fmt.Errorf("no primary verified email found") } - -func (os *OauthService) buildRedirectURL(c *gin.Context) string { - baseUrl := Config.Rustdesk.ApiServer - host := c.Request.Host - - if host != "" { - scheme := c.GetHeader("X-Forwarded-Proto") - if scheme == "" { - if c.Request.TLS != nil { - scheme = "https" - } else { - scheme = "http" - } - } - baseUrl = fmt.Sprintf("%s://%s", scheme, host) - } - - return fmt.Sprintf("%s/api/oidc/callback", baseUrl) -} \ No newline at end of file