modify google ro re-use oidc

This commit is contained in:
Tao Chen
2024-11-04 21:30:58 +08:00
parent 5a53f180e4
commit 3acfb36c5d
3 changed files with 59 additions and 55 deletions

View File

@@ -2,7 +2,6 @@ package admin
import (
"Gwen/model"
"strings"
)
type BindOauthForm struct {
@@ -28,22 +27,6 @@ type OauthForm struct {
}
func (of *OauthForm) ToOauth() *model.Oauth {
op := strings.ToLower(of.Op)
op = strings.TrimSpace(op)
if op == "" {
switch of.OauthType {
case model.OauthTypeGithub:
of.Op = model.OauthNameGithub
case model.OauthTypeGoogle:
of.Op = model.OauthNameGoogle
case model.OauthTypeOidc:
of.Op = model.OauthNameOidc
case model.OauthTypeWebauth:
of.Op = model.OauthNameWebauth
default:
of.Op = of.OauthType
}
}
oa := &model.Oauth{
Op: of.Op,
OauthType: of.OauthType,

View File

@@ -3,17 +3,29 @@ package model
import (
"strconv"
"strings"
"errors"
)
const OIDC_DEFAULT_SCOPES = "openid,profile,email"
const (
// make sure the value shouldbe lowercase
OauthTypeGithub string = "github"
OauthTypeGoogle string = "google"
OauthTypeOidc string = "oidc"
OauthTypeWebauth string = "webauth"
)
// Validate the oauth type
func ValidateOauthType(oauthType string) error {
switch oauthType {
case OauthTypeGithub, OauthTypeGoogle, OauthTypeOidc, OauthTypeWebauth:
return nil
default:
return errors.New("invalid Oauth type")
}
}
const (
OauthNameGithub string = "GitHub"
OauthNameGoogle string = "Google"
@@ -23,8 +35,7 @@ const (
const (
UserEndpointGithub string = "https://api.github.com/user"
UserEndpointGoogle string = "https://www.googleapis.com/oauth2/v3/userinfo"
UserEndpointOidc string = ""
IssuerGoogle string = "https://accounts.google.com"
)
type Oauth struct {
@@ -40,6 +51,40 @@ type Oauth struct {
TimeModel
}
// Helper function to format oauth info, it's used in the update and create method
func (oa *Oauth) FormatOauthInfo() error {
oauthType := strings.TrimSpace(oa.OauthType)
err := ValidateOauthType(oa.OauthType)
if err != nil {
return err
}
// check if the op is empty, set the default value
op := strings.TrimSpace(oa.Op)
if op == "" {
switch oauthType {
case OauthTypeGithub:
oa.Op = OauthNameGithub
case OauthTypeGoogle:
oa.Op = OauthNameGoogle
case OauthTypeOidc:
oa.Op = OauthNameOidc
case OauthTypeWebauth:
oa.Op = OauthNameWebauth
default:
oa.Op = oauthType
}
}
// check the issuer, if the oauth type is google and the issuer is empty, set the issuer to the default value
issuer := strings.TrimSpace(oa.Issuer)
// If the oauth type is google and the issuer is empty, set the issuer to the default value
if oauthType == OauthTypeGoogle && issuer == "" {
oa.Issuer = IssuerGoogle
}
return nil
}
type OauthUser struct {
OpenId string `json:"open_id" gorm:"not null;index"`
Name string `json:"name"`
@@ -90,15 +135,6 @@ func (ou *OidcUser) ToOauthUser() *OauthUser {
}
}
type GoogleUser struct {
OidcUser
}
// GoogleUser 使用特定的 Username 规则来调用 ToOauthUser
func (gu *GoogleUser) ToOauthUser() *OauthUser {
return gu.OidcUser.ToOauthUser()
}
type GithubUser struct {
OauthUserBase

View File

@@ -9,7 +9,7 @@ import (
"errors"
"golang.org/x/oauth2"
"golang.org/x/oauth2/github"
"golang.org/x/oauth2/google"
// "golang.org/x/oauth2/google"
"gorm.io/gorm"
// "io"
"net/http"
@@ -71,16 +71,6 @@ func (oa *OauthCacheItem) UpdateFromOauthUser(oauthUser *model.OauthUser) {
oa.Email = oauthUser.Email
}
// Validate the oauth type
func (os *OauthService) ValidateOauthType(oauthType string) error {
switch oauthType {
case model.OauthTypeGithub, model.OauthTypeGoogle, model.OauthTypeOidc, model.OauthTypeWebauth:
return nil
default:
return errors.New("invalid Oauth type")
}
}
func (os *OauthService) GetOauthCache(key string) *OauthCacheItem {
v, ok := OauthCache.Load(key)
@@ -160,7 +150,7 @@ func (os *OauthService) GetOauthConfig(op string) (err error, oauthInfo *model.O
}
// Maybe should validate the oauthConfig here
oauthType := oauthInfo.OauthType
err = os.ValidateOauthType(oauthType)
err = model.ValidateOauthType(oauthType)
if err != nil {
return err, nil, nil
}
@@ -168,10 +158,7 @@ func (os *OauthService) GetOauthConfig(op string) (err error, oauthInfo *model.O
case model.OauthTypeGithub:
oauthConfig.Endpoint = github.Endpoint
oauthConfig.Scopes = []string{"read:user", "user:email"}
case model.OauthTypeGoogle:
oauthConfig.Endpoint = google.Endpoint
oauthConfig.Scopes = os.constructScopes(model.OIDC_DEFAULT_SCOPES)
case model.OauthTypeOidc:
case model.OauthTypeOidc, model.OauthTypeGoogle:
var endpoint OidcEndpoint
err, endpoint = os.FetchOidcEndpoint(oauthInfo.Issuer)
if err != nil {
@@ -272,14 +259,6 @@ func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string)
return nil, user.ToOauthUser()
}
// googleCallback google回调
func (os *OauthService) googleCallback(oauthConfig *oauth2.Config, code string) (error, *model.OauthUser) {
var user = &model.GoogleUser{}
if err, _ := os.callbackBase(oauthConfig, code, model.UserEndpointGoogle, user); err != nil {
return err, nil
}
return nil, user.ToOauthUser()
}
// oidcCallback oidc回调, 通过code获取用户信息
func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, userInfoEndpoint string) (error, *model.OauthUser,) {
@@ -303,9 +282,7 @@ func (os *OauthService) Callback(code string, op string) (err error, oauthUser *
switch oauthType {
case model.OauthTypeGithub:
err, oauthUser = os.githubCallback(oauthConfig, code)
case model.OauthTypeGoogle:
err, oauthUser = os.googleCallback(oauthConfig, code)
case model.OauthTypeOidc:
case model.OauthTypeOidc, model.OauthTypeGoogle:
err, endpoint := os.FetchOidcEndpoint(oauthInfo.Issuer)
if err != nil {
return err, nil
@@ -422,6 +399,10 @@ func (os *OauthService) IsOauthProviderExist(op string) bool {
// Create 创建
func (os *OauthService) Create(oauthInfo *model.Oauth) error {
err := oauthInfo.FormatOauthInfo()
if err != nil {
return err
}
res := global.DB.Create(oauthInfo).Error
return res
}
@@ -431,6 +412,10 @@ func (os *OauthService) Delete(oauthInfo *model.Oauth) error {
// Update 更新
func (os *OauthService) Update(oauthInfo *model.Oauth) error {
err := oauthInfo.FormatOauthInfo()
if err != nil {
return err
}
return global.DB.Model(oauthInfo).Updates(oauthInfo).Error
}