up oauth re

This commit is contained in:
ljw
2024-11-05 09:48:02 +08:00
parent 4321a41cd7
commit daeae19194
9 changed files with 170 additions and 175 deletions

View File

@@ -12,16 +12,15 @@ import (
// "golang.org/x/oauth2/google"
"gorm.io/gorm"
// "io"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"
"fmt"
)
type OauthService struct {
}
@@ -34,26 +33,26 @@ type OidcEndpoint struct {
}
type OauthCacheItem struct {
UserId uint `json:"user_id"`
Id string `json:"id"` //rustdesk的设备ID
Op string `json:"op"`
Action string `json:"action"`
Uuid string `json:"uuid"`
DeviceName string `json:"device_name"`
DeviceOs string `json:"device_os"`
DeviceType string `json:"device_type"`
OpenId string `json:"open_id"`
Username string `json:"username"`
Name string `json:"name"`
Email string `json:"email"`
UserId uint `json:"user_id"`
Id string `json:"id"` //rustdesk的设备ID
Op string `json:"op"`
Action string `json:"action"`
Uuid string `json:"uuid"`
DeviceName string `json:"device_name"`
DeviceOs string `json:"device_os"`
DeviceType string `json:"device_type"`
OpenId string `json:"open_id"`
Username string `json:"username"`
Name string `json:"name"`
Email string `json:"email"`
}
func (oci *OauthCacheItem) ToOauthUser() *model.OauthUser {
return &model.OauthUser{
OpenId: oci.OpenId,
OpenId: oci.OpenId,
Username: oci.Username,
Name: oci.Name,
Email: oci.Email,
Name: oci.Name,
Email: oci.Email,
}
}
@@ -64,14 +63,13 @@ const (
OauthActionTypeBind = "bind"
)
func (oa *OauthCacheItem) UpdateFromOauthUser(oauthUser *model.OauthUser) {
oa.OpenId = oauthUser.OpenId
oa.Username = oauthUser.Username
oa.Name = oauthUser.Name
oa.Email = oauthUser.Email
func (oci *OauthCacheItem) UpdateFromOauthUser(oauthUser *model.OauthUser) {
oci.OpenId = oauthUser.OpenId
oci.Username = oauthUser.Username
oci.Name = oauthUser.Name
oci.Email = oauthUser.Email
}
func (os *OauthService) GetOauthCache(key string) *OauthCacheItem {
v, ok := OauthCache.Load(key)
if !ok {
@@ -164,7 +162,7 @@ func (os *OauthService) GetOauthConfig(op string) (err error, oauthInfo *model.O
if err != nil {
return err, nil, nil
}
oauthConfig.Endpoint = oauth2.Endpoint{AuthURL: endpoint.AuthURL,TokenURL: endpoint.TokenURL,}
oauthConfig.Endpoint = oauth2.Endpoint{AuthURL: endpoint.AuthURL, TokenURL: endpoint.TokenURL}
oauthConfig.Scopes = os.constructScopes(oauthInfo.Scopes)
default:
return errors.New("unsupported OAuth type"), nil, nil
@@ -259,9 +257,8 @@ func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string)
return nil, user.ToOauthUser()
}
// oidcCallback oidc回调, 通过code获取用户信息
func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, userInfoEndpoint string) (error, *model.OauthUser,) {
func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, userInfoEndpoint string) (error, *model.OauthUser) {
var user = &model.OidcUser{}
if err, _ := os.callbackBase(oauthConfig, code, userInfoEndpoint, user); err != nil {
return err, nil
@@ -280,21 +277,20 @@ func (os *OauthService) Callback(code string, op string) (err error, oauthUser *
}
oauthType := oauthInfo.OauthType
switch oauthType {
case model.OauthTypeGithub:
err, oauthUser = os.githubCallback(oauthConfig, code)
case model.OauthTypeOidc, model.OauthTypeGoogle:
case model.OauthTypeGithub:
err, oauthUser = os.githubCallback(oauthConfig, code)
case model.OauthTypeOidc, model.OauthTypeGoogle:
err, endpoint := os.FetchOidcEndpoint(oauthInfo.Issuer)
if err != nil {
return err, nil
}
err, oauthUser = os.oidcCallback(oauthConfig, code, endpoint.UserInfo)
default:
return errors.New("unsupported OAuth type"), nil
}
return err, oauthUser
err, oauthUser = os.oidcCallback(oauthConfig, code, endpoint.UserInfo)
default:
return errors.New("unsupported OAuth type"), nil
}
return err, oauthUser
}
func (os *OauthService) UserThirdInfo(op string, openId string) *model.UserThird {
ut := &model.UserThird{}
global.DB.Where("open_id = ? and op = ?", openId, op).First(ut)
@@ -343,17 +339,17 @@ func (os *OauthService) InfoByOp(op string) *model.Oauth {
// Helper function to get scopes by operation
func (os *OauthService) getScopesByOp(op string) []string {
scopes := os.InfoByOp(op).Scopes
scopes := os.InfoByOp(op).Scopes
return os.constructScopes(scopes)
}
// Helper function to construct scopes
func (os *OauthService) constructScopes(scopes string) []string {
scopes = strings.TrimSpace(scopes)
if scopes == "" {
scopes = model.OIDC_DEFAULT_SCOPES
}
return strings.Split(scopes, ",")
scopes = strings.TrimSpace(scopes)
if scopes == "" {
scopes = model.OIDC_DEFAULT_SCOPES
}
return strings.Split(scopes, ",")
}
func (os *OauthService) List(page, pageSize uint, where func(tx *gorm.DB)) (res *model.OauthList) {
@@ -461,4 +457,4 @@ func (os *OauthService) getGithubPrimaryEmail(client *http.Client, githubUser *m
}
return fmt.Errorf("no primary verified email found")
}
}

View File

@@ -5,13 +5,13 @@ import (
adResp "Gwen/http/response/admin"
"Gwen/model"
"Gwen/utils"
"errors"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"math/rand"
"strconv"
"time"
"strings"
"errors"
"time"
)
type UserService struct {
@@ -23,6 +23,7 @@ func (us *UserService) InfoById(id uint) *model.User {
global.DB.Where("id = ?", id).First(u)
return u
}
// InfoByUsername 根据用户名取用户信息
func (us *UserService) InfoByUsername(un string) *model.User {
u := &model.User{}
@@ -75,11 +76,11 @@ func (us *UserService) GenerateToken(u *model.User) string {
func (us *UserService) Login(u *model.User, llog *model.LoginLog) *model.UserToken {
token := us.GenerateToken(u)
ut := &model.UserToken{
UserId: u.Id,
Token: token,
UserId: u.Id,
Token: token,
DeviceUuid: llog.Uuid,
DeviceId: llog.DeviceId,
ExpiredAt: time.Now().Add(time.Hour * 24 * 7).Unix(),
ExpiredAt: time.Now().Add(time.Hour * 24 * 7).Unix(),
}
global.DB.Create(ut)
llog.UserTokenId = ut.UserId
@@ -162,7 +163,7 @@ func (us *UserService) Create(u *model.User) error {
// GetUuidByToken 根据token和user取uuid
func (us *UserService) GetUuidByToken(u *model.User, token string) string {
ut := &model.UserToken{}
err :=global.DB.Where("user_id = ? and token = ?", u.Id, token).First(ut).Error
err := global.DB.Where("user_id = ? and token = ?", u.Id, token).First(ut).Error
if err != nil {
return ""
}
@@ -214,12 +215,12 @@ func (us *UserService) Delete(u *model.User) error {
tx.Rollback()
return err
}
tx.Commit()
// 删除关联的peer
if err := AllService.PeerService.EraseUserId(u.Id); err != nil {
tx.Rollback()
return err
}
tx.Commit()
return nil
}
@@ -230,7 +231,7 @@ func (us *UserService) Update(u *model.User) error {
if us.IsAdmin(currentUser) {
adminCount := us.getAdminUserCount()
// 如果这是唯一的管理员,确保不能禁用或取消管理员权限
if adminCount <= 1 && ( !us.IsAdmin(u) || u.Status == model.COMMON_STATUS_DISABLED) {
if adminCount <= 1 && (!us.IsAdmin(u) || u.Status == model.COMMON_STATUS_DISABLED) {
return errors.New("The last admin user cannot be disabled or demoted")
}
}
@@ -290,48 +291,49 @@ func (us *UserService) InfoByOauthId(op string, openId string) *model.User {
}
// RegisterByOauth 注册
func (us *UserService) RegisterByOauth(oauthUser *model.OauthUser , op string) (error, *model.User) {
func (us *UserService) RegisterByOauth(oauthUser *model.OauthUser, op string) (error, *model.User) {
global.Lock.Lock("registerByOauth")
defer global.Lock.UnLock("registerByOauth")
ut := AllService.OauthService.UserThirdInfo(op, oauthUser.OpenId)
if ut.Id != 0 {
return nil, us.InfoById(ut.UserId)
}
//check if this email has been registered
email := oauthUser.Email
err, oauthType := AllService.OauthService.GetTypeByOp(op)
if err != nil {
return err, nil
}
// if email is empty, use username and op as email
if email == "" {
email = oauthUser.Username + "@" + op
}
email = strings.ToLower(email)
// update email to oauthUser, in case it contain upper case
oauthUser.Email = email
user := us.InfoByEmail(email)
tx := global.DB.Begin()
if user.Id != 0 {
ut.FromOauthUser(user.Id, oauthUser, oauthType, op)
} else {
ut = &model.UserThird{}
ut.FromOauthUser(0, oauthUser, oauthType, op)
// The initial username should be formatted
username := us.formatUsername(oauthUser.Username)
usernameUnique := us.GenerateUsernameByOauth(username)
user = &model.User{
Username: usernameUnique,
GroupId: 1,
//check if this email has been registered
email := oauthUser.Email
// only email is not empty
if email != "" {
email = strings.ToLower(email)
// update email to oauthUser, in case it contain upper case
oauthUser.Email = email
user := us.InfoByEmail(email)
if user.Id != 0 {
ut.FromOauthUser(user.Id, oauthUser, oauthType, op)
global.DB.Create(ut)
return nil, user
}
oauthUser.ToUser(user, false)
tx.Create(user)
if user.Id == 0 {
tx.Rollback()
return errors.New("OauthRegisterFailed"), user
}
ut.UserId = user.Id
}
tx := global.DB.Begin()
ut = &model.UserThird{}
ut.FromOauthUser(0, oauthUser, oauthType, op)
// The initial username should be formatted
username := us.formatUsername(oauthUser.Username)
usernameUnique := us.GenerateUsernameByOauth(username)
user := &model.User{
Username: usernameUnique,
GroupId: 1,
}
oauthUser.ToUser(user, false)
tx.Create(user)
if user.Id == 0 {
tx.Rollback()
return errors.New("OauthRegisterFailed"), user
}
ut.UserId = user.Id
tx.Create(ut)
tx.Commit()
return nil, user
@@ -433,7 +435,7 @@ func (us *UserService) formatUsername(username string) string {
return username
}
// Helper functions, getUserCount
// Helper functions, getUserCount
func (us *UserService) getUserCount() int64 {
var count int64
global.DB.Model(&model.User{}).Count(&count)
@@ -445,4 +447,4 @@ func (us *UserService) getAdminUserCount() int64 {
var count int64
global.DB.Model(&model.User{}).Where("is_admin = ?", true).Count(&count)
return count
}
}