diff --git a/model/oauth.go b/model/oauth.go index 816a01c..61ce845 100644 --- a/model/oauth.go +++ b/model/oauth.go @@ -13,6 +13,18 @@ const ( OauthTypeWebauth string = "webauth" ) +const ( + OauthNameGithub string = "GitHub" + OauthNameGoogle string = "Google" + OauthNameOidc string = "OIDC" + OauthNameWebauth string = "WebAuth" +) + +const ( + UserEndpointGithub string = "https://api.github.com/user" + UserEndpointGoogle string = "https://www.googleapis.com/oauth2/v3/userinfo" + UserEndpointOidc string = "" +) type Oauth struct { IdModel @@ -33,6 +45,7 @@ type OauthUser struct { Username string `json:"username"` Email string `json:"email"` VerifiedEmail bool `json:"verified_email,omitempty"` + Picture string `json:"picture,omitempty"` } func (ou *OauthUser) ToUser(user *User, overideUsername bool) { @@ -56,6 +69,7 @@ type OidcUser struct { Sub string `json:"sub"` VerifiedEmail bool `json:"email_verified"` PreferredUsername string `json:"preferred_username"` + Picture string `json:"picture"` } func (ou *OidcUser) ToOauthUser() *OauthUser { @@ -65,6 +79,7 @@ func (ou *OidcUser) ToOauthUser() *OauthUser { Username: ou.PreferredUsername, Email: ou.Email, VerifiedEmail: ou.VerifiedEmail, + Picture: ou.Picture, } } @@ -84,6 +99,7 @@ func (gu *GoogleUser) ToOauthUser() *OauthUser { Username: gu.GivenName, Email: gu.Email, VerifiedEmail: gu.VerifiedEmail, + Picture: gu.Picture, } } @@ -92,6 +108,8 @@ type GithubUser struct { OauthUserBase Id int `json:"id"` Login string `json:"login"` + AvatarUrl string `json:"avatar_url"` + VerifiedEmail bool `json:"verified_email"` } func (gu *GithubUser) ToOauthUser() *OauthUser { @@ -100,7 +118,7 @@ func (gu *GithubUser) ToOauthUser() *OauthUser { Name: gu.Name, Username: gu.Login, Email: gu.Email, - VerifiedEmail: true, + VerifiedEmail: gu.VerifiedEmail, } } diff --git a/service/oauth.go b/service/oauth.go index 446f61c..6cfa083 100644 --- a/service/oauth.go +++ b/service/oauth.go @@ -106,15 +106,14 @@ func (os *OauthService) DeleteOauthCache(key string) { func (os *OauthService) BeginAuth(op string) (error error, code, url string) { code = utils.RandomString(10) + strconv.FormatInt(time.Now().Unix(), 10) - if op == string(model.OauthTypeWebauth) { url = global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/" + code //url = "http://localhost:8888/_admin/#/oauth/" + code return nil, code, url } - err, _, conf := os.GetOauthConfig(op) + err, _, oauthConfig := os.GetOauthConfig(op) if err == nil { - return err, code, conf.AuthCodeURL(code) + return err, code, oauthConfig.AuthCodeURL(code) } return err, code, "" @@ -154,16 +153,17 @@ func (os *OauthService) FetchOidcEndpointByOp(op string) (error, OidcEndpoint) { } // GetOauthConfig retrieves the OAuth2 configuration based on the provider name -func (os *OauthService) GetOauthConfig(op string) (err error, oauthType string, oauthConfig *oauth2.Config) { - err = os.ValidateOauthProvider(op) +func (os *OauthService) GetOauthConfig(op string) (err error, oauthInfo *model.Oauth, oauthConfig *oauth2.Config) { + err, oauthInfo, oauthConfig = os.getOauthConfigGeneral(op) if err != nil { - return err, "", nil - } - err, oauthType, oauthConfig = os.getOauthConfigGeneral(op) - if err != nil { - return err, oauthType, nil + return err, nil, nil } // Maybe should validate the oauthConfig here + oauthType := oauthInfo.OauthType + err = os.ValidateOauthType(oauthType) + if err != nil { + return err, nil, nil + } switch oauthType { case model.OauthTypeGithub: oauthConfig.Endpoint = github.Endpoint @@ -172,32 +172,33 @@ func (os *OauthService) GetOauthConfig(op string) (err error, oauthType string, oauthConfig.Endpoint = google.Endpoint oauthConfig.Scopes = []string{"https://www.googleapis.com/auth/userinfo.profile", "https://www.googleapis.com/auth/userinfo.email"} case model.OauthTypeOidc: - err, endpoint := os.FetchOidcEndpointByOp(op) + var endpoint OidcEndpoint + err, endpoint = os.FetchOidcEndpoint(oauthInfo.Issuer) if err != nil { - return err,oauthType, nil + return err, nil, nil } oauthConfig.Endpoint = oauth2.Endpoint{AuthURL: endpoint.AuthURL,TokenURL: endpoint.TokenURL,} - oauthConfig.Scopes = os.getScopesByOp(op) + oauthConfig.Scopes = os.constructScopes(oauthInfo.Scopes) default: - return errors.New("unsupported OAuth type"), oauthType, nil + return errors.New("unsupported OAuth type"), nil, nil } - return nil, oauthType, oauthConfig + return nil, oauthInfo, oauthConfig } // GetOauthConfig retrieves the OAuth2 configuration based on the provider name -func (os *OauthService) getOauthConfigGeneral(op string) (err error, oauthType string, oauthConfig *oauth2.Config) { - g := os.InfoByOp(op) - if g.Id == 0 || g.ClientId == "" || g.ClientSecret == "" { - return errors.New("ConfigNotFound"), "", nil +func (os *OauthService) getOauthConfigGeneral(op string) (err error, oauthInfo *model.Oauth, oauthConfig *oauth2.Config) { + oauthInfo = os.InfoByOp(op) + if oauthInfo.Id == 0 || oauthInfo.ClientId == "" || oauthInfo.ClientSecret == "" { + return errors.New("ConfigNotFound"), nil, nil } // If the redirect URL is empty, use the default redirect URL - if g.RedirectUrl == "" { - g.RedirectUrl = global.Config.Rustdesk.ApiServer + "/api/oidc/callback" + if oauthInfo.RedirectUrl == "" { + oauthInfo.RedirectUrl = global.Config.Rustdesk.ApiServer + "/api/oidc/callback" } - return nil, g.OauthType, &oauth2.Config{ - ClientID: g.ClientId, - ClientSecret: g.ClientSecret, - RedirectURL: g.RedirectUrl, + return nil, oauthInfo, &oauth2.Config{ + ClientID: oauthInfo.ClientId, + ClientSecret: oauthInfo.ClientSecret, + RedirectURL: oauthInfo.RedirectUrl, } } @@ -221,40 +222,26 @@ func getHTTPClientWithProxy() *http.Client { return http.DefaultClient } -func (os *OauthService) callbackBase(op string, code string, userEndpoint string, userData interface{}) error { - err, oauthType, oauthConfig := os.GetOauthConfig(op) - if err != nil { - return err - } - - // If the OAuth type is OIDC and the user endpoint is empty - // Fetch the OIDC configuration and get the user endpoint - if oauthType == model.OauthTypeOidc && userEndpoint == "" { - err, endpoint := os.FetchOidcEndpointByOp(op) - if err != nil { - global.Logger.Warn("failed fetching OIDC configuration: ", err) - return errors.New("FetchOidcEndpointError") - } - userEndpoint = endpoint.UserInfo - } +func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, code string, userEndpoint string, userData interface{}) (err error, client *http.Client) { // 设置代理客户端 httpClient := getHTTPClientWithProxy() ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient) // 使用 code 换取 token - token, err := oauthConfig.Exchange(ctx, code) + var token *oauth2.Token + token, err = oauthConfig.Exchange(ctx, code) if err != nil { global.Logger.Warn("oauthConfig.Exchange() failed: ", err) - return errors.New("GetOauthTokenError") + return errors.New("GetOauthTokenError"), nil } // 获取用户信息 - client := oauthConfig.Client(ctx, token) + client = oauthConfig.Client(ctx, token) resp, err := client.Get(userEndpoint) if err != nil { global.Logger.Warn("failed getting user info: ", err) - return errors.New("GetOauthUserInfoError") + return errors.New("GetOauthUserInfoError"), nil } defer func() { if closeErr := resp.Body.Close(); closeErr != nil { @@ -265,36 +252,39 @@ func (os *OauthService) callbackBase(op string, code string, userEndpoint string // 解析用户信息 if err = json.NewDecoder(resp.Body).Decode(userData); err != nil { global.Logger.Warn("failed decoding user info: ", err) - return errors.New("DecodeOauthUserInfoError") + return errors.New("DecodeOauthUserInfoError"), nil } - return nil + return nil, client } // githubCallback github回调 -func (os *OauthService) githubCallback(code string) (error, *model.OauthUser) { +func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string) (error, *model.OauthUser) { var user = &model.GithubUser{} - const userEndpoint = "https://api.github.com/user" - if err := os.callbackBase(model.OauthTypeGithub, code, userEndpoint, user); err != nil { + err, client := os.callbackBase(oauthConfig, code, model.UserEndpointGithub, user) + if err != nil { + return err, nil + } + err = os.getGithubPrimaryEmail(client, user) + if err != nil { return err, nil } return nil, user.ToOauthUser() } // googleCallback google回调 -func (os *OauthService) googleCallback(code string) (error, *model.OauthUser) { +func (os *OauthService) googleCallback(oauthConfig *oauth2.Config, code string) (error, *model.OauthUser) { var user = &model.GoogleUser{} - const userEndpoint = "https://www.googleapis.com/oauth2/v2/userinfo" - if err := os.callbackBase(model.OauthTypeGoogle, code, userEndpoint, user); err != nil { + if err, _ := os.callbackBase(oauthConfig, code, model.UserEndpointGoogle, user); err != nil { return err, nil } return nil, user.ToOauthUser() } // oidcCallback oidc回调, 通过code获取用户信息 -func (os *OauthService) oidcCallback(code string, op string) (error, *model.OauthUser,) { +func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, userInfoEndpoint string) (error, *model.OauthUser,) { var user = &model.OidcUser{} - if err := os.callbackBase(op, code, "", user); err != nil { + if err, _ := os.callbackBase(oauthConfig, code, userInfoEndpoint, user); err != nil { return err, nil } return nil, user.ToOauthUser() @@ -302,22 +292,28 @@ func (os *OauthService) oidcCallback(code string, op string) (error, *model.Oaut // Callback: Get user information by code and op(Oauth provider) func (os *OauthService) Callback(code string, op string) (err error, oauthUser *model.OauthUser) { - oauthType := os.GetTypeByOp(op) - if err = os.ValidateOauthType(oauthType); err != nil { - return err, nil - } - - switch oauthType { + var oauthInfo *model.Oauth + var oauthConfig *oauth2.Config + err, oauthInfo, oauthConfig = os.GetOauthConfig(op) + // oauthType is already validated in GetOauthConfig + if err != nil { + return err, nil + } + oauthType := oauthInfo.OauthType + switch oauthType { case model.OauthTypeGithub: - err, oauthUser = os.githubCallback(code) + err, oauthUser = os.githubCallback(oauthConfig, code) case model.OauthTypeGoogle: - err, oauthUser = os.googleCallback(code) + err, oauthUser = os.googleCallback(oauthConfig, code) case model.OauthTypeOidc: - err, oauthUser = os.oidcCallback(code, op) + err, endpoint := os.FetchOidcEndpoint(oauthInfo.Issuer) + if err != nil { + return err, nil + } + err, oauthUser = os.oidcCallback(oauthConfig, code, endpoint.UserInfo) default: return errors.New("unsupported OAuth type"), nil } - return err, oauthUser } @@ -331,7 +327,10 @@ func (os *OauthService) UserThirdInfo(op string, openId string) *model.UserThird // BindOauthUser: Bind third party account func (os *OauthService) BindOauthUser(userId uint, oauthUser *model.OauthUser, op string) error { utr := &model.UserThird{} - oauthType := os.GetTypeByOp(op) + err, oauthType := os.GetTypeByOp(op) + if err != nil { + return err + } utr.FromOauthUser(userId, oauthUser, oauthType, op) return global.DB.Create(utr).Error } @@ -368,14 +367,18 @@ func (os *OauthService) InfoByOp(op string) *model.Oauth { // Helper function to get scopes by operation func (os *OauthService) getScopesByOp(op string) []string { scopes := os.InfoByOp(op).Scopes - scopes = strings.TrimSpace(scopes) // 这里使用 `=` 而不是 `:=`,避免重新声明变量 + return os.constructScopes(scopes) +} + +// Helper function to construct scopes +func (os *OauthService) constructScopes(scopes string) []string { + scopes = strings.TrimSpace(scopes) if scopes == "" { scopes = "openid,profile,email" } return strings.Split(scopes, ",") } - func (os *OauthService) List(page, pageSize uint, where func(tx *gorm.DB)) (res *model.OauthList) { res = &model.OauthList{} res.Page = int64(page) @@ -391,21 +394,30 @@ func (os *OauthService) List(page, pageSize uint, where func(tx *gorm.DB)) (res } // GetTypeByOp 根据op获取OauthType -func (os *OauthService) GetTypeByOp(op string) string { +func (os *OauthService) GetTypeByOp(op string) (error, string) { oauthInfo := &model.Oauth{} if global.DB.Where("op = ?", op).First(oauthInfo).Error != nil { - return "" + return fmt.Errorf("OAuth provider with op '%s' not found", op), "" } - return oauthInfo.OauthType + return nil, oauthInfo.OauthType } +// ValidateOauthProvider 验证Oauth提供者是否正确 func (os *OauthService) ValidateOauthProvider(op string) error { + if !os.IsOauthProviderExist(op) { + return fmt.Errorf("OAuth provider with op '%s' not found", op) + } + return nil +} + +// IsOauthProviderExist 验证Oauth提供者是否存在 +func (os *OauthService) IsOauthProviderExist(op string) bool { oauthInfo := &model.Oauth{} - // 使用 Gorm 的 Take 方法查找符合条件的记录 - if err := global.DB.Where("op = ?", op).Take(oauthInfo).Error; err != nil { - return fmt.Errorf("OAuth provider with op '%s' not found: %w", op, err) - } - return nil + // 使用 Gorm 的 Take 方法查找符合条件的记录 + if err := global.DB.Where("op = ?", op).Take(oauthInfo).Error; err != nil { + return false + } + return true } // Create 创建 @@ -427,4 +439,41 @@ func (os *OauthService) GetOauthProviders() []string { var res []string global.DB.Model(&model.Oauth{}).Pluck("op", &res) return res +} + +// getGithubPrimaryEmail: Get the primary email of the user from Github +func (os *OauthService) getGithubPrimaryEmail(client *http.Client, githubUser *model.GithubUser) error { + // the client is already set with the token + resp, err := client.Get("https://api.github.com/user/emails") + if err != nil { + return fmt.Errorf("failed to fetch emails: %w", err) + } + defer resp.Body.Close() + + // check the response status code + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to fetch emails: %s", resp.Status) + } + + // decode the response + var emails []struct { + Email string `json:"email"` + Primary bool `json:"primary"` + Verified bool `json:"verified"` + } + + if err := json.NewDecoder(resp.Body).Decode(&emails); err != nil { + return fmt.Errorf("failed to decode response: %w", err) + } + + // find the primary verified email + for _, e := range emails { + if e.Primary && e.Verified { + githubUser.Email = e.Email + githubUser.VerifiedEmail = e.Verified + return nil + } + } + + return fmt.Errorf("no primary verified email found") } \ No newline at end of file