feat(oidc): add pkce (#150)
This commit is contained in:
@@ -166,7 +166,7 @@ func InitGlobal() {
|
||||
global.Lock = lock.NewLocal()
|
||||
}
|
||||
func DatabaseAutoUpdate() {
|
||||
version := 260
|
||||
version := 261
|
||||
|
||||
db := global.DB
|
||||
|
||||
@@ -210,6 +210,12 @@ func DatabaseAutoUpdate() {
|
||||
if v.Version < uint(version) {
|
||||
Migrate(uint(version))
|
||||
}
|
||||
// 261迁移
|
||||
if v.Version < 261 {
|
||||
// 在oauths表中添加pkce_enable 和 pkce_method 字段
|
||||
db.Exec("ALTER TABLE oauths ADD COLUMN pkce_enable TINYINT(1) NOT NULL DEFAULT 0")
|
||||
db.Exec("ALTER TABLE oauths ADD COLUMN pkce_method VARCHAR(20) NOT NULL DEFAULT 'S256'")
|
||||
}
|
||||
// 245迁移
|
||||
if v.Version < 245 {
|
||||
//oauths 表的 oauth_type 字段设置为 op同样的值
|
||||
|
||||
@@ -283,13 +283,13 @@ func (ct *Login) OidcAuth(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
err, code, url := service.AllService.OauthService.BeginAuth(f.Op)
|
||||
err, state, verifier, url := service.AllService.OauthService.BeginAuth(f.Op)
|
||||
if err != nil {
|
||||
response.Error(c, response.TranslateMsg(c, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
service.AllService.OauthService.SetOauthCache(code, &service.OauthCacheItem{
|
||||
service.AllService.OauthService.SetOauthCache(state, &service.OauthCacheItem{
|
||||
Action: service.OauthActionTypeLogin,
|
||||
Op: f.Op,
|
||||
Id: f.Id,
|
||||
@@ -297,10 +297,11 @@ func (ct *Login) OidcAuth(c *gin.Context) {
|
||||
// DeviceOs: ct.Platform(c),
|
||||
DeviceOs: f.DeviceInfo.Os,
|
||||
Uuid: f.Uuid,
|
||||
Verifier: verifier,
|
||||
}, 5*60)
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"code": code,
|
||||
"code": state,
|
||||
"url": url,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -43,20 +43,21 @@ func (o *Oauth) ToBind(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
err, code, url := service.AllService.OauthService.BeginAuth(f.Op)
|
||||
err, state, verifier, url := service.AllService.OauthService.BeginAuth(f.Op)
|
||||
if err != nil {
|
||||
response.Error(c, response.TranslateMsg(c, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
service.AllService.OauthService.SetOauthCache(code, &service.OauthCacheItem{
|
||||
service.AllService.OauthService.SetOauthCache(state, &service.OauthCacheItem{
|
||||
Action: service.OauthActionTypeBind,
|
||||
Op: f.Op,
|
||||
UserId: u.Id,
|
||||
Verifier: verifier,
|
||||
}, 5*60)
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"code": code,
|
||||
"code": state,
|
||||
"url": url,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -32,15 +32,16 @@ func (o *Oauth) OidcAuth(c *gin.Context) {
|
||||
}
|
||||
|
||||
oauthService := service.AllService.OauthService
|
||||
var code string
|
||||
var state string
|
||||
var url string
|
||||
err, code, url = oauthService.BeginAuth(f.Op)
|
||||
var verifier string
|
||||
err, state, verifier, url = oauthService.BeginAuth(f.Op)
|
||||
if err != nil {
|
||||
response.Error(c, response.TranslateMsg(c, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
service.AllService.OauthService.SetOauthCache(code, &service.OauthCacheItem{
|
||||
service.AllService.OauthService.SetOauthCache(state, &service.OauthCacheItem{
|
||||
Action: service.OauthActionTypeLogin,
|
||||
Id: f.Id,
|
||||
Op: f.Op,
|
||||
@@ -48,10 +49,11 @@ func (o *Oauth) OidcAuth(c *gin.Context) {
|
||||
DeviceName: f.DeviceInfo.Name,
|
||||
DeviceOs: f.DeviceInfo.Os,
|
||||
DeviceType: f.DeviceInfo.Type,
|
||||
Verifier: verifier,
|
||||
}, 5*60)
|
||||
//fmt.Println("code url", code, url)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": code,
|
||||
"code": state,
|
||||
"url": url,
|
||||
})
|
||||
}
|
||||
@@ -156,10 +158,11 @@ func (o *Oauth) OauthCallback(c *gin.Context) {
|
||||
}
|
||||
op := oauthCache.Op
|
||||
action := oauthCache.Action
|
||||
verifier := oauthCache.Verifier
|
||||
var user *model.User
|
||||
// 获取用户信息
|
||||
code := c.Query("code")
|
||||
err, oauthUser := oauthService.Callback(code, op)
|
||||
err, oauthUser := oauthService.Callback(code, verifier, op)
|
||||
if err != nil {
|
||||
c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthFailed")+response.TranslateMsg(c, err.Error()))
|
||||
return
|
||||
|
||||
@@ -24,6 +24,8 @@ type OauthForm struct {
|
||||
ClientSecret string `json:"client_secret" validate:"required"`
|
||||
RedirectUrl string `json:"redirect_url" validate:"required"`
|
||||
AutoRegister *bool `json:"auto_register"`
|
||||
PkceEnable *bool `json:"pkce_enable"`
|
||||
PkceMethod string `json:"pkce_method"`
|
||||
}
|
||||
|
||||
func (of *OauthForm) ToOauth() *model.Oauth {
|
||||
@@ -36,6 +38,8 @@ func (of *OauthForm) ToOauth() *model.Oauth {
|
||||
AutoRegister: of.AutoRegister,
|
||||
Issuer: of.Issuer,
|
||||
Scopes: of.Scopes,
|
||||
PkceEnable: of.PkceEnable,
|
||||
PkceMethod: of.PkceMethod,
|
||||
}
|
||||
oa.Id = of.Id
|
||||
return oa
|
||||
|
||||
@@ -14,6 +14,8 @@ const (
|
||||
OauthTypeGoogle string = "google"
|
||||
OauthTypeOidc string = "oidc"
|
||||
OauthTypeWebauth string = "webauth"
|
||||
PKCEMethodS256 string = "S256"
|
||||
PKCEMethodPlain string = "plain"
|
||||
)
|
||||
|
||||
// Validate the oauth type
|
||||
@@ -41,6 +43,8 @@ type Oauth struct {
|
||||
AutoRegister *bool `json:"auto_register"`
|
||||
Scopes string `json:"scopes"`
|
||||
Issuer string `json:"issuer"`
|
||||
PkceEnable *bool `json:"pkce_enable"`
|
||||
PkceMethod string `json:"pkce_method"`
|
||||
TimeModel
|
||||
}
|
||||
|
||||
@@ -68,6 +72,13 @@ func (oa *Oauth) FormatOauthInfo() error {
|
||||
if oauthType == OauthTypeGoogle && issuer == "" {
|
||||
oa.Issuer = IssuerGoogle
|
||||
}
|
||||
if oa.PkceEnable == nil {
|
||||
oa.PkceEnable = new(bool)
|
||||
*oa.PkceEnable = false
|
||||
}
|
||||
if oa.PkceMethod == "" {
|
||||
oa.PkceMethod = PKCEMethodS256
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -45,6 +45,7 @@ type OauthCacheItem struct {
|
||||
Username string `json:"username"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
Verifier string `json:"verifier"` // used for oauth pkce
|
||||
}
|
||||
|
||||
func (oci *OauthCacheItem) ToOauthUser() *model.OauthUser {
|
||||
@@ -92,19 +93,32 @@ func (os *OauthService) DeleteOauthCache(key string) {
|
||||
OauthCache.Delete(key)
|
||||
}
|
||||
|
||||
func (os *OauthService) BeginAuth(op string) (error error, code, url string) {
|
||||
code = utils.RandomString(10) + strconv.FormatInt(time.Now().Unix(), 10)
|
||||
func (os *OauthService) BeginAuth(op string) (error error, state, verifier, url string) {
|
||||
state = utils.RandomString(10) + strconv.FormatInt(time.Now().Unix(), 10)
|
||||
verifier = ""
|
||||
if op == string(model.OauthTypeWebauth) {
|
||||
url = global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/" + code
|
||||
url = global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/" + state
|
||||
//url = "http://localhost:8888/_admin/#/oauth/" + code
|
||||
return nil, code, url
|
||||
return nil, state, verifier, url
|
||||
}
|
||||
err, _, oauthConfig := os.GetOauthConfig(op)
|
||||
err, oauthInfo, oauthConfig := os.GetOauthConfig(op)
|
||||
if err == nil {
|
||||
return err, code, oauthConfig.AuthCodeURL(code)
|
||||
extras := make([]oauth2.AuthCodeOption, 0, 3)
|
||||
if oauthInfo.PkceEnable != nil && *oauthInfo.PkceEnable {
|
||||
extras = append(extras, oauth2.AccessTypeOffline)
|
||||
verifier = oauth2.GenerateVerifier()
|
||||
switch oauthInfo.PkceMethod {
|
||||
case model.PKCEMethodS256:
|
||||
extras = append(extras, oauth2.S256ChallengeOption(verifier))
|
||||
case model.PKCEMethodPlain:
|
||||
// oauth2 does not have a plain challenge option, so we add it manually
|
||||
extras = append(extras, oauth2.SetAuthURLParam("code_challenge_method", "plain"), oauth2.SetAuthURLParam("code_challenge", verifier))
|
||||
}
|
||||
}
|
||||
return err, state, verifier, oauthConfig.AuthCodeURL(state, extras...)
|
||||
}
|
||||
|
||||
return err, code, ""
|
||||
return err, state, verifier, ""
|
||||
}
|
||||
|
||||
// Method to fetch OIDC configuration dynamically
|
||||
@@ -207,15 +221,20 @@ func getHTTPClientWithProxy() *http.Client {
|
||||
return http.DefaultClient
|
||||
}
|
||||
|
||||
func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, code string, userEndpoint string, userData interface{}) (err error, client *http.Client) {
|
||||
func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, code string, verifier string, userEndpoint string, userData interface{}) (err error, client *http.Client) {
|
||||
|
||||
// 设置代理客户端
|
||||
httpClient := getHTTPClientWithProxy()
|
||||
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
|
||||
|
||||
var exchangeOpts []oauth2.AuthCodeOption
|
||||
if verifier != "" {
|
||||
exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(verifier)}
|
||||
}
|
||||
|
||||
// 使用 code 换取 token
|
||||
var token *oauth2.Token
|
||||
token, err = oauthConfig.Exchange(ctx, code)
|
||||
token, err = oauthConfig.Exchange(ctx, code, exchangeOpts...)
|
||||
if err != nil {
|
||||
global.Logger.Warn("oauthConfig.Exchange() failed: ", err)
|
||||
return errors.New("GetOauthTokenError"), nil
|
||||
@@ -244,9 +263,9 @@ func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, code string, us
|
||||
}
|
||||
|
||||
// githubCallback github回调
|
||||
func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string) (error, *model.OauthUser) {
|
||||
func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string, verifier string) (error, *model.OauthUser) {
|
||||
var user = &model.GithubUser{}
|
||||
err, client := os.callbackBase(oauthConfig, code, model.UserEndpointGithub, user)
|
||||
err, client := os.callbackBase(oauthConfig, code, verifier, model.UserEndpointGithub, user)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
@@ -258,16 +277,16 @@ func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string)
|
||||
}
|
||||
|
||||
// oidcCallback oidc回调, 通过code获取用户信息
|
||||
func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, userInfoEndpoint string) (error, *model.OauthUser) {
|
||||
func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, verifier string, userInfoEndpoint string) (error, *model.OauthUser) {
|
||||
var user = &model.OidcUser{}
|
||||
if err, _ := os.callbackBase(oauthConfig, code, userInfoEndpoint, user); err != nil {
|
||||
if err, _ := os.callbackBase(oauthConfig, code, verifier, userInfoEndpoint, user); err != nil {
|
||||
return err, nil
|
||||
}
|
||||
return nil, user.ToOauthUser()
|
||||
}
|
||||
|
||||
// Callback: Get user information by code and op(Oauth provider)
|
||||
func (os *OauthService) Callback(code string, op string) (err error, oauthUser *model.OauthUser) {
|
||||
func (os *OauthService) Callback(code, verifier, op string) (err error, oauthUser *model.OauthUser) {
|
||||
var oauthInfo *model.Oauth
|
||||
var oauthConfig *oauth2.Config
|
||||
err, oauthInfo, oauthConfig = os.GetOauthConfig(op)
|
||||
@@ -278,13 +297,13 @@ func (os *OauthService) Callback(code string, op string) (err error, oauthUser *
|
||||
oauthType := oauthInfo.OauthType
|
||||
switch oauthType {
|
||||
case model.OauthTypeGithub:
|
||||
err, oauthUser = os.githubCallback(oauthConfig, code)
|
||||
err, oauthUser = os.githubCallback(oauthConfig, code, verifier)
|
||||
case model.OauthTypeOidc, model.OauthTypeGoogle:
|
||||
err, endpoint := os.FetchOidcEndpoint(oauthInfo.Issuer)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
err, oauthUser = os.oidcCallback(oauthConfig, code, endpoint.UserInfo)
|
||||
err, oauthUser = os.oidcCallback(oauthConfig, code, verifier, endpoint.UserInfo)
|
||||
default:
|
||||
return errors.New("unsupported OAuth type"), nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user