mirror of
https://github.com/lejianwen/rustdesk-api.git
synced 2025-11-29 08:33:21 +00:00
Add oauth callback via proxy
Improved support for environment variables and configuration files, and standardized default behaviors
This commit is contained in:
@@ -45,3 +45,6 @@ oss:
|
|||||||
jwt:
|
jwt:
|
||||||
private-key: "./conf/jwt_pri.pem"
|
private-key: "./conf/jwt_pri.pem"
|
||||||
expire-duration: 360000
|
expire-duration: 360000
|
||||||
|
proxy:
|
||||||
|
enable: false
|
||||||
|
host: ""
|
||||||
@@ -30,6 +30,7 @@ type Config struct {
|
|||||||
Oss Oss
|
Oss Oss
|
||||||
Jwt Jwt
|
Jwt Jwt
|
||||||
Rustdesk Rustdesk
|
Rustdesk Rustdesk
|
||||||
|
Proxy Proxy
|
||||||
}
|
}
|
||||||
|
|
||||||
// Init 初始化配置
|
// Init 初始化配置
|
||||||
|
|||||||
6
config/proxy.go
Normal file
6
config/proxy.go
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
type Proxy struct {
|
||||||
|
Enable bool `mapstructure:"enable"`
|
||||||
|
Host string `mapstructure:"host"`
|
||||||
|
}
|
||||||
@@ -12,6 +12,8 @@ import (
|
|||||||
"golang.org/x/oauth2/google"
|
"golang.org/x/oauth2/google"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"io"
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -166,20 +168,44 @@ func (os *OauthService) GetOauthConfig(op string) (error, *oauth2.Config) {
|
|||||||
return errors.New("ConfigNotFound"), nil
|
return errors.New("ConfigNotFound"), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getHTTPClientWithProxy() *http.Client {
|
||||||
|
if global.Config.Proxy.Enable {
|
||||||
|
if global.Config.Proxy.Host == "" {
|
||||||
|
global.Logger.Warn("Proxy is enabled but proxy host is empty.")
|
||||||
|
return http.DefaultClient
|
||||||
|
}
|
||||||
|
proxyURL, err := url.Parse(global.Config.Proxy.Host)
|
||||||
|
if err != nil {
|
||||||
|
global.Logger.Warn("Invalid proxy URL: ", err)
|
||||||
|
return http.DefaultClient
|
||||||
|
}
|
||||||
|
transport := &http.Transport{
|
||||||
|
Proxy: http.ProxyURL(proxyURL),
|
||||||
|
}
|
||||||
|
return &http.Client{Transport: transport}
|
||||||
|
}
|
||||||
|
return http.DefaultClient
|
||||||
|
}
|
||||||
|
|
||||||
func (os *OauthService) GithubCallback(code string) (error error, userData *GithubUserdata) {
|
func (os *OauthService) GithubCallback(code string) (error error, userData *GithubUserdata) {
|
||||||
err, oauthConfig := os.GetOauthConfig(model.OauthTypeGithub)
|
err, oauthConfig := os.GetOauthConfig(model.OauthTypeGithub)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil
|
return err, nil
|
||||||
}
|
}
|
||||||
token, err := oauthConfig.Exchange(context.Background(), code)
|
|
||||||
|
// 使用代理配置创建 HTTP 客户端
|
||||||
|
httpClient := getHTTPClientWithProxy()
|
||||||
|
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
|
||||||
|
|
||||||
|
token, err := oauthConfig.Exchange(ctx, code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
global.Logger.Warn("oauthConfig.Exchange() failed: ", err)
|
global.Logger.Warn("oauthConfig.Exchange() failed: ", err)
|
||||||
error = errors.New("GetOauthTokenError")
|
error = errors.New("GetOauthTokenError")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 创建一个 HTTP 客户端,并将 access_token 添加到 Authorization 头中
|
// 使用带有代理的 HTTP 客户端获取用户信息
|
||||||
client := oauthConfig.Client(context.Background(), token)
|
client := oauthConfig.Client(ctx, token)
|
||||||
resp, err := client.Get("https://api.github.com/user")
|
resp, err := client.Get("https://api.github.com/user")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
global.Logger.Warn("failed getting user info: ", err)
|
global.Logger.Warn("failed getting user info: ", err)
|
||||||
@@ -193,7 +219,7 @@ func (os *OauthService) GithubCallback(code string) (error error, userData *Gith
|
|||||||
}
|
}
|
||||||
}(resp.Body)
|
}(resp.Body)
|
||||||
|
|
||||||
// 在这里处理 GitHub 用户信息
|
// 解析用户信息
|
||||||
if err = json.NewDecoder(resp.Body).Decode(&userData); err != nil {
|
if err = json.NewDecoder(resp.Body).Decode(&userData); err != nil {
|
||||||
global.Logger.Warn("failed decoding user info: ", err)
|
global.Logger.Warn("failed decoding user info: ", err)
|
||||||
error = errors.New("DecodeOauthUserInfoError")
|
error = errors.New("DecodeOauthUserInfoError")
|
||||||
@@ -204,14 +230,23 @@ func (os *OauthService) GithubCallback(code string) (error error, userData *Gith
|
|||||||
|
|
||||||
func (os *OauthService) GoogleCallback(code string) (error error, userData *GoogleUserdata) {
|
func (os *OauthService) GoogleCallback(code string) (error error, userData *GoogleUserdata) {
|
||||||
err, oauthConfig := os.GetOauthConfig(model.OauthTypeGoogle)
|
err, oauthConfig := os.GetOauthConfig(model.OauthTypeGoogle)
|
||||||
token, err := oauthConfig.Exchange(context.Background(), code)
|
if err != nil {
|
||||||
|
return err, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用代理配置创建 HTTP 客户端
|
||||||
|
httpClient := getHTTPClientWithProxy()
|
||||||
|
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
|
||||||
|
|
||||||
|
token, err := oauthConfig.Exchange(ctx, code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
global.Logger.Warn("oauthConfig.Exchange() failed: ", err)
|
global.Logger.Warn("oauthConfig.Exchange() failed: ", err)
|
||||||
error = errors.New("GetOauthTokenError")
|
error = errors.New("GetOauthTokenError")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 创建 HTTP 客户端,并将 access_token 添加到 Authorization 头中
|
|
||||||
client := oauthConfig.Client(context.Background(), token)
|
// 使用带有代理的 HTTP 客户端获取用户信息
|
||||||
|
client := oauthConfig.Client(ctx, token)
|
||||||
resp, err := client.Get("https://www.googleapis.com/oauth2/v2/userinfo")
|
resp, err := client.Get("https://www.googleapis.com/oauth2/v2/userinfo")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
global.Logger.Warn("failed getting user info: ", err)
|
global.Logger.Warn("failed getting user info: ", err)
|
||||||
@@ -225,8 +260,9 @@ func (os *OauthService) GoogleCallback(code string) (error error, userData *Goog
|
|||||||
}
|
}
|
||||||
}(resp.Body)
|
}(resp.Body)
|
||||||
|
|
||||||
|
// 解析用户信息
|
||||||
if err = json.NewDecoder(resp.Body).Decode(&userData); err != nil {
|
if err = json.NewDecoder(resp.Body).Decode(&userData); err != nil {
|
||||||
global.Logger.Warn("failed decoding user info: %s\n", err)
|
global.Logger.Warn("failed decoding user info: ", err)
|
||||||
error = errors.New("DecodeOauthUserInfoError")
|
error = errors.New("DecodeOauthUserInfoError")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user