Add oauth callback via proxy

Improved support for environment variables and configuration files, and standardized default behaviors
This commit is contained in:
Oganneson
2024-10-19 15:46:45 +08:00
parent e142cc00c6
commit 0091a9dd7f
4 changed files with 55 additions and 9 deletions

View File

@@ -44,4 +44,7 @@ oss:
max-byte: 10240 max-byte: 10240
jwt: jwt:
private-key: "./conf/jwt_pri.pem" private-key: "./conf/jwt_pri.pem"
expire-duration: 360000 expire-duration: 360000
proxy:
enable: false
host: ""

View File

@@ -30,6 +30,7 @@ type Config struct {
Oss Oss Oss Oss
Jwt Jwt Jwt Jwt
Rustdesk Rustdesk Rustdesk Rustdesk
Proxy Proxy
} }
// Init 初始化配置 // Init 初始化配置

6
config/proxy.go Normal file
View File

@@ -0,0 +1,6 @@
package config
type Proxy struct {
Enable bool `mapstructure:"enable"`
Host string `mapstructure:"host"`
}

View File

@@ -12,6 +12,8 @@ import (
"golang.org/x/oauth2/google" "golang.org/x/oauth2/google"
"gorm.io/gorm" "gorm.io/gorm"
"io" "io"
"net/http"
"net/url"
"strconv" "strconv"
"sync" "sync"
"time" "time"
@@ -166,20 +168,44 @@ func (os *OauthService) GetOauthConfig(op string) (error, *oauth2.Config) {
return errors.New("ConfigNotFound"), nil return errors.New("ConfigNotFound"), nil
} }
func getHTTPClientWithProxy() *http.Client {
if global.Config.Proxy.Enable {
if global.Config.Proxy.Host == "" {
global.Logger.Warn("Proxy is enabled but proxy host is empty.")
return http.DefaultClient
}
proxyURL, err := url.Parse(global.Config.Proxy.Host)
if err != nil {
global.Logger.Warn("Invalid proxy URL: ", err)
return http.DefaultClient
}
transport := &http.Transport{
Proxy: http.ProxyURL(proxyURL),
}
return &http.Client{Transport: transport}
}
return http.DefaultClient
}
func (os *OauthService) GithubCallback(code string) (error error, userData *GithubUserdata) { func (os *OauthService) GithubCallback(code string) (error error, userData *GithubUserdata) {
err, oauthConfig := os.GetOauthConfig(model.OauthTypeGithub) err, oauthConfig := os.GetOauthConfig(model.OauthTypeGithub)
if err != nil { if err != nil {
return err, nil return err, nil
} }
token, err := oauthConfig.Exchange(context.Background(), code)
// 使用代理配置创建 HTTP 客户端
httpClient := getHTTPClientWithProxy()
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
token, err := oauthConfig.Exchange(ctx, code)
if err != nil { if err != nil {
global.Logger.Warn("oauthConfig.Exchange() failed: ", err) global.Logger.Warn("oauthConfig.Exchange() failed: ", err)
error = errors.New("GetOauthTokenError") error = errors.New("GetOauthTokenError")
return return
} }
// 创建一个 HTTP 客户端,并将 access_token 添加到 Authorization 头中 // 使用带有代理的 HTTP 客户端获取用户信息
client := oauthConfig.Client(context.Background(), token) client := oauthConfig.Client(ctx, token)
resp, err := client.Get("https://api.github.com/user") resp, err := client.Get("https://api.github.com/user")
if err != nil { if err != nil {
global.Logger.Warn("failed getting user info: ", err) global.Logger.Warn("failed getting user info: ", err)
@@ -193,7 +219,7 @@ func (os *OauthService) GithubCallback(code string) (error error, userData *Gith
} }
}(resp.Body) }(resp.Body)
// 在这里处理 GitHub 用户信息 // 解析用户信息
if err = json.NewDecoder(resp.Body).Decode(&userData); err != nil { if err = json.NewDecoder(resp.Body).Decode(&userData); err != nil {
global.Logger.Warn("failed decoding user info: ", err) global.Logger.Warn("failed decoding user info: ", err)
error = errors.New("DecodeOauthUserInfoError") error = errors.New("DecodeOauthUserInfoError")
@@ -204,14 +230,23 @@ func (os *OauthService) GithubCallback(code string) (error error, userData *Gith
func (os *OauthService) GoogleCallback(code string) (error error, userData *GoogleUserdata) { func (os *OauthService) GoogleCallback(code string) (error error, userData *GoogleUserdata) {
err, oauthConfig := os.GetOauthConfig(model.OauthTypeGoogle) err, oauthConfig := os.GetOauthConfig(model.OauthTypeGoogle)
token, err := oauthConfig.Exchange(context.Background(), code) if err != nil {
return err, nil
}
// 使用代理配置创建 HTTP 客户端
httpClient := getHTTPClientWithProxy()
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
token, err := oauthConfig.Exchange(ctx, code)
if err != nil { if err != nil {
global.Logger.Warn("oauthConfig.Exchange() failed: ", err) global.Logger.Warn("oauthConfig.Exchange() failed: ", err)
error = errors.New("GetOauthTokenError") error = errors.New("GetOauthTokenError")
return return
} }
// 创建 HTTP 客户端,并将 access_token 添加到 Authorization 头中
client := oauthConfig.Client(context.Background(), token) // 使用带有代理的 HTTP 客户端获取用户信息
client := oauthConfig.Client(ctx, token)
resp, err := client.Get("https://www.googleapis.com/oauth2/v2/userinfo") resp, err := client.Get("https://www.googleapis.com/oauth2/v2/userinfo")
if err != nil { if err != nil {
global.Logger.Warn("failed getting user info: ", err) global.Logger.Warn("failed getting user info: ", err)
@@ -225,8 +260,9 @@ func (os *OauthService) GoogleCallback(code string) (error error, userData *Goog
} }
}(resp.Body) }(resp.Body)
// 解析用户信息
if err = json.NewDecoder(resp.Body).Decode(&userData); err != nil { if err = json.NewDecoder(resp.Body).Decode(&userData); err != nil {
global.Logger.Warn("failed decoding user info: %s\n", err) global.Logger.Warn("failed decoding user info: ", err)
error = errors.New("DecodeOauthUserInfoError") error = errors.New("DecodeOauthUserInfoError")
return return
} }