diff --git a/conf/config.yaml b/conf/config.yaml index 6afca95..11f94dd 100644 --- a/conf/config.yaml +++ b/conf/config.yaml @@ -44,4 +44,7 @@ oss: max-byte: 10240 jwt: private-key: "./conf/jwt_pri.pem" - expire-duration: 360000 \ No newline at end of file + expire-duration: 360000 +proxy: + enable: false + host: "" \ No newline at end of file diff --git a/config/config.go b/config/config.go index 7db7a9a..86d5d4b 100644 --- a/config/config.go +++ b/config/config.go @@ -30,6 +30,7 @@ type Config struct { Oss Oss Jwt Jwt Rustdesk Rustdesk + Proxy Proxy } // Init 初始化配置 diff --git a/config/proxy.go b/config/proxy.go new file mode 100644 index 0000000..c413e7e --- /dev/null +++ b/config/proxy.go @@ -0,0 +1,6 @@ +package config + +type Proxy struct { + Enable bool `mapstructure:"enable"` + Host string `mapstructure:"host"` +} diff --git a/service/oauth.go b/service/oauth.go index c28db44..0ee6add 100644 --- a/service/oauth.go +++ b/service/oauth.go @@ -12,6 +12,8 @@ import ( "golang.org/x/oauth2/google" "gorm.io/gorm" "io" + "net/http" + "net/url" "strconv" "sync" "time" @@ -166,20 +168,44 @@ func (os *OauthService) GetOauthConfig(op string) (error, *oauth2.Config) { return errors.New("ConfigNotFound"), nil } +func getHTTPClientWithProxy() *http.Client { + if global.Config.Proxy.Enable { + if global.Config.Proxy.Host == "" { + global.Logger.Warn("Proxy is enabled but proxy host is empty.") + return http.DefaultClient + } + proxyURL, err := url.Parse(global.Config.Proxy.Host) + if err != nil { + global.Logger.Warn("Invalid proxy URL: ", err) + return http.DefaultClient + } + transport := &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + } + return &http.Client{Transport: transport} + } + return http.DefaultClient +} + func (os *OauthService) GithubCallback(code string) (error error, userData *GithubUserdata) { err, oauthConfig := os.GetOauthConfig(model.OauthTypeGithub) if err != nil { return err, nil } - token, err := oauthConfig.Exchange(context.Background(), code) + + // 使用代理配置创建 HTTP 客户端 + httpClient := getHTTPClientWithProxy() + ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient) + + token, err := oauthConfig.Exchange(ctx, code) if err != nil { global.Logger.Warn("oauthConfig.Exchange() failed: ", err) error = errors.New("GetOauthTokenError") return } - // 创建一个 HTTP 客户端,并将 access_token 添加到 Authorization 头中 - client := oauthConfig.Client(context.Background(), token) + // 使用带有代理的 HTTP 客户端获取用户信息 + client := oauthConfig.Client(ctx, token) resp, err := client.Get("https://api.github.com/user") if err != nil { global.Logger.Warn("failed getting user info: ", err) @@ -193,7 +219,7 @@ func (os *OauthService) GithubCallback(code string) (error error, userData *Gith } }(resp.Body) - // 在这里处理 GitHub 用户信息 + // 解析用户信息 if err = json.NewDecoder(resp.Body).Decode(&userData); err != nil { global.Logger.Warn("failed decoding user info: ", err) error = errors.New("DecodeOauthUserInfoError") @@ -204,14 +230,23 @@ func (os *OauthService) GithubCallback(code string) (error error, userData *Gith func (os *OauthService) GoogleCallback(code string) (error error, userData *GoogleUserdata) { err, oauthConfig := os.GetOauthConfig(model.OauthTypeGoogle) - token, err := oauthConfig.Exchange(context.Background(), code) + if err != nil { + return err, nil + } + + // 使用代理配置创建 HTTP 客户端 + httpClient := getHTTPClientWithProxy() + ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient) + + token, err := oauthConfig.Exchange(ctx, code) if err != nil { global.Logger.Warn("oauthConfig.Exchange() failed: ", err) error = errors.New("GetOauthTokenError") return } - // 创建 HTTP 客户端,并将 access_token 添加到 Authorization 头中 - client := oauthConfig.Client(context.Background(), token) + + // 使用带有代理的 HTTP 客户端获取用户信息 + client := oauthConfig.Client(ctx, token) resp, err := client.Get("https://www.googleapis.com/oauth2/v2/userinfo") if err != nil { global.Logger.Warn("failed getting user info: ", err) @@ -225,8 +260,9 @@ func (os *OauthService) GoogleCallback(code string) (error error, userData *Goog } }(resp.Body) + // 解析用户信息 if err = json.NewDecoder(resp.Body).Decode(&userData); err != nil { - global.Logger.Warn("failed decoding user info: %s\n", err) + global.Logger.Warn("failed decoding user info: ", err) error = errors.New("DecodeOauthUserInfoError") return }