From f0b4b0d7c6fd4c9ea77049f6cf7365e702e9c136 Mon Sep 17 00:00:00 2001 From: lejianwen <84855512@qq.com> Date: Fri, 21 Feb 2025 09:49:41 +0800 Subject: [PATCH] style(oidc): Oidc style --- go.mod | 2 + service/oauth.go | 179 +++++++++++++++++++++++++---------------------- 2 files changed, 97 insertions(+), 84 deletions(-) diff --git a/go.mod b/go.mod index c05d998..67bd924 100644 --- a/go.mod +++ b/go.mod @@ -36,9 +36,11 @@ require ( github.com/bytedance/sonic v1.8.0 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect + github.com/coreos/go-oidc/v3 v3.12.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/go-asn1-ber/asn1-ber v1.5.7 // indirect + github.com/go-jose/go-jose/v4 v4.0.2 // indirect github.com/go-ldap/ldap/v3 v3.4.10 // indirect github.com/go-openapi/jsonpointer v0.19.5 // indirect github.com/go-openapi/jsonreference v0.19.6 // indirect diff --git a/service/oauth.go b/service/oauth.go index b99ec23..17f90a2 100644 --- a/service/oauth.go +++ b/service/oauth.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "errors" + "github.com/coreos/go-oidc/v3/oidc" "github.com/lejianwen/rustdesk-api/v2/global" "github.com/lejianwen/rustdesk-api/v2/model" "github.com/lejianwen/rustdesk-api/v2/utils" @@ -45,7 +46,7 @@ type OauthCacheItem struct { Username string `json:"username"` Name string `json:"name"` Email string `json:"email"` - Verifier string `json:"verifier"` // used for oauth pkce + Verifier string `json:"verifier"` // used for oauth pkce } func (oci *OauthCacheItem) ToOauthUser() *model.OauthUser { @@ -82,10 +83,9 @@ func (os *OauthService) GetOauthCache(key string) *OauthCacheItem { func (os *OauthService) SetOauthCache(key string, item *OauthCacheItem, expire uint) { OauthCache.Store(key, item) if expire > 0 { - go func() { - time.Sleep(time.Duration(expire) * time.Second) + time.AfterFunc(time.Duration(expire)*time.Second, func() { os.DeleteOauthCache(key) - }() + }) } } @@ -96,12 +96,12 @@ func (os *OauthService) DeleteOauthCache(key string) { func (os *OauthService) BeginAuth(op string) (error error, state, verifier, url string) { state = utils.RandomString(10) + strconv.FormatInt(time.Now().Unix(), 10) verifier = "" - if op == string(model.OauthTypeWebauth) { + if op == model.OauthTypeWebauth { url = global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/" + state //url = "http://localhost:8888/_admin/#/oauth/" + code return nil, state, verifier, url } - err, oauthInfo, oauthConfig := os.GetOauthConfig(op) + err, oauthInfo, oauthConfig, _ := os.GetOauthConfig(op) if err == nil { extras := make([]oauth2.AuthCodeOption, 0, 3) if oauthInfo.PkceEnable != nil && *oauthInfo.PkceEnable { @@ -121,88 +121,80 @@ func (os *OauthService) BeginAuth(op string) (error error, state, verifier, url return err, state, verifier, "" } -// Method to fetch OIDC configuration dynamically -func (os *OauthService) FetchOidcEndpoint(issuer string) (error, OidcEndpoint) { - configURL := strings.TrimSuffix(issuer, "/") + "/.well-known/openid-configuration" +func (os *OauthService) FetchOidcProvider(issuer string) (error, *oidc.Provider) { // Get the HTTP client (with or without proxy based on configuration) client := getHTTPClientWithProxy() - resp, err := client.Get(configURL) + ctx := oidc.ClientContext(context.Background(), client) + + provider, err := oidc.NewProvider(ctx, issuer) if err != nil { - return errors.New("failed to fetch OIDC configuration"), OidcEndpoint{} - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return errors.New("OIDC configuration not found, status code: %d"), OidcEndpoint{} + return err, nil } - var endpoint OidcEndpoint - if err := json.NewDecoder(resp.Body).Decode(&endpoint); err != nil { - return errors.New("failed to parse OIDC configuration"), OidcEndpoint{} - } - - return nil, endpoint + return nil, provider } -func (os *OauthService) FetchOidcEndpointByOp(op string) (error, OidcEndpoint) { - oauthInfo := os.InfoByOp(op) - if oauthInfo.Issuer == "" { - return errors.New("issuer is empty"), OidcEndpoint{} - } - return os.FetchOidcEndpoint(oauthInfo.Issuer) +func (os *OauthService) GithubProvider() *oidc.Provider { + return (&oidc.ProviderConfig{ + IssuerURL: "", + AuthURL: github.Endpoint.AuthURL, + TokenURL: github.Endpoint.TokenURL, + DeviceAuthURL: github.Endpoint.DeviceAuthURL, + UserInfoURL: model.UserEndpointGithub, + JWKSURL: "", + Algorithms: nil, + }).NewProvider(context.Background()) } // GetOauthConfig retrieves the OAuth2 configuration based on the provider name -func (os *OauthService) GetOauthConfig(op string) (err error, oauthInfo *model.Oauth, oauthConfig *oauth2.Config) { - err, oauthInfo, oauthConfig = os.getOauthConfigGeneral(op) - if err != nil { - return err, nil, nil - } - // Maybe should validate the oauthConfig here - oauthType := oauthInfo.OauthType - err = model.ValidateOauthType(oauthType) - if err != nil { - return err, nil, nil - } - switch oauthType { - case model.OauthTypeGithub: - oauthConfig.Endpoint = github.Endpoint - oauthConfig.Scopes = []string{"read:user", "user:email"} - case model.OauthTypeOidc, model.OauthTypeGoogle: - var endpoint OidcEndpoint - err, endpoint = os.FetchOidcEndpoint(oauthInfo.Issuer) - if err != nil { - return err, nil, nil - } - oauthConfig.Endpoint = oauth2.Endpoint{AuthURL: endpoint.AuthURL, TokenURL: endpoint.TokenURL} - oauthConfig.Scopes = os.constructScopes(oauthInfo.Scopes) - default: - return errors.New("unsupported OAuth type"), nil, nil - } - return nil, oauthInfo, oauthConfig -} - -// GetOauthConfig retrieves the OAuth2 configuration based on the provider name -func (os *OauthService) getOauthConfigGeneral(op string) (err error, oauthInfo *model.Oauth, oauthConfig *oauth2.Config) { +func (os *OauthService) GetOauthConfig(op string) (err error, oauthInfo *model.Oauth, oauthConfig *oauth2.Config, provider *oidc.Provider) { + //err, oauthInfo, oauthConfig = os.getOauthConfigGeneral(op) oauthInfo = os.InfoByOp(op) if oauthInfo.Id == 0 || oauthInfo.ClientId == "" || oauthInfo.ClientSecret == "" { - return errors.New("ConfigNotFound"), nil, nil + return errors.New("ConfigNotFound"), nil, nil, nil } // If the redirect URL is empty, use the default redirect URL if oauthInfo.RedirectUrl == "" { oauthInfo.RedirectUrl = global.Config.Rustdesk.ApiServer + "/api/oidc/callback" } - return nil, oauthInfo, &oauth2.Config{ + oauthConfig = &oauth2.Config{ ClientID: oauthInfo.ClientId, ClientSecret: oauthInfo.ClientSecret, RedirectURL: oauthInfo.RedirectUrl, } + + // Maybe should validate the oauthConfig here + oauthType := oauthInfo.OauthType + err = model.ValidateOauthType(oauthType) + if err != nil { + return err, nil, nil, nil + } + switch oauthType { + case model.OauthTypeGithub: + oauthConfig.Endpoint = github.Endpoint + oauthConfig.Scopes = []string{"read:user", "user:email"} + provider = os.GithubProvider() + //case model.OauthTypeGoogle: //google单独出来,可以少一次FetchOidcEndpoint请求 + // oauthConfig.Endpoint = google.Endpoint + // oauthConfig.Scopes = os.constructScopes(oauthInfo.Scopes) + case model.OauthTypeOidc, model.OauthTypeGoogle: + err, provider = os.FetchOidcProvider(oauthInfo.Issuer) + if err != nil { + return err, nil, nil, nil + } + oauthConfig.Endpoint = provider.Endpoint() + oauthConfig.Scopes = os.constructScopes(oauthInfo.Scopes) + default: + return errors.New("unsupported OAuth type"), nil, nil, nil + } + return nil, oauthInfo, oauthConfig, provider } func getHTTPClientWithProxy() *http.Client { - //todo add timeout + //add timeout 30s + timeout := time.Duration(60) * time.Second if global.Config.Proxy.Enable { if global.Config.Proxy.Host == "" { global.Logger.Warn("Proxy is enabled but proxy host is empty.") @@ -216,33 +208,58 @@ func getHTTPClientWithProxy() *http.Client { transport := &http.Transport{ Proxy: http.ProxyURL(proxyURL), } - return &http.Client{Transport: transport} + return &http.Client{Transport: transport, Timeout: timeout} } return http.DefaultClient } - -func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, code string, verifier string, userEndpoint string, userData interface{}) (err error, client *http.Client) { +func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, provider *oidc.Provider, code string, verifier string, nonce string, userData interface{}) (err error, client *http.Client) { // 设置代理客户端 httpClient := getHTTPClientWithProxy() ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient) - var exchangeOpts []oauth2.AuthCodeOption + exchangeOpts := make([]oauth2.AuthCodeOption, 0, 1) if verifier != "" { - exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(verifier)} + exchangeOpts = append(exchangeOpts, oauth2.VerifierOption(verifier)) } - // 使用 code 换取 token - var token *oauth2.Token - token, err = oauthConfig.Exchange(ctx, code, exchangeOpts...) + token, err := oauthConfig.Exchange(ctx, code, exchangeOpts...) + if err != nil { global.Logger.Warn("oauthConfig.Exchange() failed: ", err) return errors.New("GetOauthTokenError"), nil } + // 获取 ID Token, github没有id_token + rawIDToken, ok := token.Extra("id_token").(string) + if ok && rawIDToken != "" { + // 验证 ID Token + v := provider.Verifier(&oidc.Config{ClientID: oauthConfig.ClientID}) + idToken, err2 := v.Verify(ctx, rawIDToken) + if err2 != nil { + global.Logger.Warn("IdTokenVerifyError: ", err2) + return errors.New("IdTokenVerifyError"), nil + } + if nonce != "" { + // 验证 nonce + var claims struct { + Nonce string `json:"nonce"` + } + if err2 = idToken.Claims(&claims); err2 != nil { + global.Logger.Warn("Failed to parse ID Token claims: ", err) + return errors.New("IDTokenClaimsError"), nil + } + + if claims.Nonce != nonce { + global.Logger.Warn("Nonce does not match") + return errors.New("NonceDoesNotMatch"), nil + } + } + } + // 获取用户信息 client = oauthConfig.Client(ctx, token) - resp, err := client.Get(userEndpoint) + resp, err := client.Get(provider.UserInfoEndpoint()) if err != nil { global.Logger.Warn("failed getting user info: ", err) return errors.New("GetOauthUserInfoError"), nil @@ -263,9 +280,9 @@ func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, code string, ve } // githubCallback github回调 -func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string, verifier string) (error, *model.OauthUser) { +func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, provider *oidc.Provider, code string, verifier string) (error, *model.OauthUser) { var user = &model.GithubUser{} - err, client := os.callbackBase(oauthConfig, code, verifier, model.UserEndpointGithub, user) + err, client := os.callbackBase(oauthConfig, provider, code, verifier, "", user) if err != nil { return err, nil } @@ -277,9 +294,9 @@ func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string, } // oidcCallback oidc回调, 通过code获取用户信息 -func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, verifier string, userInfoEndpoint string) (error, *model.OauthUser) { +func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, provider *oidc.Provider, code string, verifier string) (error, *model.OauthUser) { var user = &model.OidcUser{} - if err, _ := os.callbackBase(oauthConfig, code, verifier, userInfoEndpoint, user); err != nil { + if err, _ := os.callbackBase(oauthConfig, provider, code, verifier, "", user); err != nil { return err, nil } return nil, user.ToOauthUser() @@ -287,9 +304,7 @@ func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, ve // Callback: Get user information by code and op(Oauth provider) func (os *OauthService) Callback(code, verifier, op string) (err error, oauthUser *model.OauthUser) { - var oauthInfo *model.Oauth - var oauthConfig *oauth2.Config - err, oauthInfo, oauthConfig = os.GetOauthConfig(op) + err, oauthInfo, oauthConfig, provider := os.GetOauthConfig(op) // oauthType is already validated in GetOauthConfig if err != nil { return err, nil @@ -297,13 +312,9 @@ func (os *OauthService) Callback(code, verifier, op string) (err error, oauthUse oauthType := oauthInfo.OauthType switch oauthType { case model.OauthTypeGithub: - err, oauthUser = os.githubCallback(oauthConfig, code, verifier) + err, oauthUser = os.githubCallback(oauthConfig, provider, code, verifier) case model.OauthTypeOidc, model.OauthTypeGoogle: - err, endpoint := os.FetchOidcEndpoint(oauthInfo.Issuer) - if err != nil { - return err, nil - } - err, oauthUser = os.oidcCallback(oauthConfig, code, verifier, endpoint.UserInfo) + err, oauthUser = os.oidcCallback(oauthConfig, provider, code, verifier) default: return errors.New("unsupported OAuth type"), nil }