Merge branch 'oauth_re' of https://github.com/IamTaoChen/rustdesk-api into IamTaoChen-oauth_re

This commit is contained in:
ljw
2024-11-05 08:24:48 +08:00
20 changed files with 822 additions and 555 deletions

View File

@@ -11,7 +11,6 @@ import (
"Gwen/service"
"fmt"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type Login struct {
@@ -60,12 +59,7 @@ func (ct *Login) Login(c *gin.Context) {
Platform: f.Platform,
})
response.Success(c, &adResp.LoginPayload{
Token: ut.Token,
Username: u.Username,
RouteNames: service.AllService.UserService.RouteNames(u),
Nickname: u.Nickname,
})
responseLoginSuccess(c, u, ut.Token)
}
// Logout 登出
@@ -96,13 +90,7 @@ func (ct *Login) Logout(c *gin.Context) {
// @Failure 500 {object} response.ErrorResponse
// @Router /admin/login-options [post]
func (ct *Login) LoginOptions(c *gin.Context) {
res := service.AllService.OauthService.List(1, 100, func(tx *gorm.DB) {
tx.Select("op").Order("id")
})
var ops []string
for _, v := range res.Oauths {
ops = append(ops, v.Op)
}
ops := service.AllService.OauthService.GetOauthProviders()
response.Success(c, gin.H{
"ops": ops,
"register": global.Config.App.Register,
@@ -163,12 +151,14 @@ func (ct *Login) OidcAuthQuery(c *gin.Context) {
if ut == nil {
return
}
//fmt.Println("u:", u)
//fmt.Println("ut:", ut)
response.Success(c, &adResp.LoginPayload{
Token: ut.Token,
Username: u.Username,
RouteNames: service.AllService.UserService.RouteNames(u),
Nickname: u.Nickname,
})
responseLoginSuccess(c, u, ut.Token)
}
func responseLoginSuccess(c *gin.Context, u *model.User, token string) {
lp := &adResp.LoginPayload{}
lp.FromUser(u)
lp.Token = token
lp.RouteNames = service.AllService.UserService.RouteNames(u)
response.Success(c, lp)
}

View File

@@ -5,7 +5,6 @@ import (
"Gwen/http/request/admin"
adminReq "Gwen/http/request/admin"
"Gwen/http/response"
"Gwen/model"
"Gwen/service"
"github.com/gin-gonic/gin"
"strconv"
@@ -96,21 +95,23 @@ func (o *Oauth) BindConfirm(c *gin.Context) {
response.Fail(c, 101, response.TranslateMsg(c, "ParamsError"))
return
}
v := service.AllService.OauthService.GetOauthCache(j.Code)
if v == nil {
oauthService := service.AllService.OauthService
oauthCache := oauthService.GetOauthCache(j.Code)
if oauthCache == nil {
response.Fail(c, 101, response.TranslateMsg(c, "OauthExpired"))
return
}
u := service.AllService.UserService.CurUser(c)
err = service.AllService.OauthService.BindOauthUser(v.Op, v.ThirdOpenId, v.ThirdName, u.Id)
oauthUser := oauthCache.ToOauthUser()
user := service.AllService.UserService.CurUser(c)
err = oauthService.BindOauthUser(user.Id, oauthUser, oauthCache.Op)
if err != nil {
response.Fail(c, 101, response.TranslateMsg(c, "BindFail"))
return
}
v.UserId = u.Id
service.AllService.OauthService.SetOauthCache(j.Code, v, 0)
response.Success(c, v)
oauthCache.UserId = user.Id
oauthService.SetOauthCache(j.Code, oauthCache, 0)
response.Success(c, oauthCache)
}
func (o *Oauth) Unbind(c *gin.Context) {
@@ -126,28 +127,11 @@ func (o *Oauth) Unbind(c *gin.Context) {
response.Fail(c, 101, response.TranslateMsg(c, "ItemNotFound"))
return
}
if f.Op == model.OauthTypeGithub {
err = service.AllService.OauthService.UnBindGithubUser(u.Id)
if err != nil {
response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed")+err.Error())
return
}
err = service.AllService.OauthService.UnBindOauthUser(u.Id, f.Op)
if err != nil {
response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed")+err.Error())
return
}
if f.Op == model.OauthTypeGoogle {
err = service.AllService.OauthService.UnBindGoogleUser(u.Id)
if err != nil {
response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed")+err.Error())
return
}
}
if f.Op == model.OauthTypeOidc {
err = service.AllService.OauthService.UnBindOidcUser(u.Id)
if err != nil {
response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed")+err.Error())
return
}
}
response.Success(c, nil)
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"strconv"
"time"
)
type User struct {
@@ -216,12 +217,7 @@ func (ct *User) Current(c *gin.Context) {
u := service.AllService.UserService.CurUser(c)
token, _ := c.Get("token")
t := token.(string)
response.Success(c, &adResp.LoginPayload{
Token: t,
Username: u.Username,
RouteNames: service.AllService.UserService.RouteNames(u),
Nickname: u.Nickname,
})
responseLoginSuccess(c, u, t)
}
// ChangeCurPwd 修改当前用户密码
@@ -286,10 +282,10 @@ func (ct *User) MyOauth(c *gin.Context) {
var res []*adResp.UserOauthItem
for _, oa := range oal.Oauths {
item := &adResp.UserOauthItem{
ThirdType: oa.Op,
Op: oa.Op,
}
for _, ut := range uts {
if ut.ThirdType == oa.Op {
if ut.Op == oa.Op {
item.Status = 1
break
}
@@ -299,6 +295,51 @@ func (ct *User) MyOauth(c *gin.Context) {
response.Success(c, res)
}
// List 列表
// @Tags 设备
// @Summary 设备列表
// @Description 设备列表
// @Accept json
// @Produce json
// @Param page query int false "页码"
// @Param page_size query int false "页大小"
// @Param time_ago query int false "时间"
// @Param id query string false "ID"
// @Param hostname query string false "主机名"
// @Param uuids query string false "uuids 用逗号分隔"
// @Success 200 {object} response.Response{data=model.PeerList}
// @Failure 500 {object} response.Response
// @Router /admin/user/myPeer [get]
// @Security token
func (ct *User) MyPeer(c *gin.Context) {
query := &admin.PeerQuery{}
if err := c.ShouldBindQuery(query); err != nil {
response.Fail(c, 101, response.TranslateMsg(c, "ParamsError")+err.Error())
return
}
u := service.AllService.UserService.CurUser(c)
res := service.AllService.PeerService.ListFilterByUserId(query.Page, query.PageSize, func(tx *gorm.DB) {
if query.TimeAgo > 0 {
lt := time.Now().Unix() - int64(query.TimeAgo)
tx.Where("last_online_time < ?", lt)
}
if query.TimeAgo < 0 {
lt := time.Now().Unix() + int64(query.TimeAgo)
tx.Where("last_online_time > ?", lt)
}
if query.Id != "" {
tx.Where("id like ?", "%"+query.Id+"%")
}
if query.Hostname != "" {
tx.Where("hostname like ?", "%"+query.Hostname+"%")
}
if query.Uuids != "" {
tx.Where("uuid in (?)", query.Uuids)
}
}, u.Id)
response.Success(c, res)
}
// groupUsers
func (ct *User) GroupUsers(c *gin.Context) {
q := &admin.GroupUsersQuery{}
@@ -345,7 +386,7 @@ func (ct *User) Register(c *gin.Context) {
response.Fail(c, 101, errList[0])
return
}
u := service.AllService.UserService.Register(f.Username, f.Password)
u := service.AllService.UserService.Register(f.Username, f.Email, f.Password)
if u == nil || u.Id == 0 {
response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed"))
return
@@ -358,10 +399,5 @@ func (ct *User) Register(c *gin.Context) {
Ip: c.ClientIP(),
Type: model.LoginLogTypeAccount,
})
response.Success(c, &adResp.LoginPayload{
Token: ut.Token,
Username: u.Username,
RouteNames: service.AllService.UserService.RouteNames(u),
Nickname: u.Nickname,
})
responseLoginSuccess(c, u, ut.Token)
}

View File

@@ -60,6 +60,7 @@ func (l *Login) Login(c *gin.Context) {
ut := service.AllService.UserService.Login(u, &model.LoginLog{
UserId: u.Id,
Client: f.DeviceInfo.Type,
DeviceId: f.Id,
Uuid: f.Uuid,
Ip: c.ClientIP(),
Type: model.LoginLogTypeAccount,
@@ -83,22 +84,10 @@ func (l *Login) Login(c *gin.Context) {
// @Failure 500 {object} response.ErrorResponse
// @Router /login-options [get]
func (l *Login) LoginOptions(c *gin.Context) {
oauthOks := []string{}
err, _ := service.AllService.OauthService.GetOauthConfig(model.OauthTypeGithub)
if err == nil {
oauthOks = append(oauthOks, model.OauthTypeGithub)
}
err, _ = service.AllService.OauthService.GetOauthConfig(model.OauthTypeGoogle)
if err == nil {
oauthOks = append(oauthOks, model.OauthTypeGoogle)
}
err, _ = service.AllService.OauthService.GetOauthConfig(model.OauthTypeOidc)
if err == nil {
oauthOks = append(oauthOks, model.OauthTypeOidc)
}
oauthOks = append(oauthOks, model.OauthTypeWebauth)
ops := service.AllService.OauthService.GetOauthProviders()
ops = append(ops, model.OauthTypeWebauth)
var oidcItems []map[string]string
for _, v := range oauthOks {
for _, v := range ops {
oidcItems = append(oidcItems, map[string]string{"name": v})
}
common, err := json.Marshal(oidcItems)
@@ -108,7 +97,7 @@ func (l *Login) LoginOptions(c *gin.Context) {
}
var res []string
res = append(res, "common-oidc/"+string(common))
for _, v := range oauthOks {
for _, v := range ops {
res = append(res, "oidc/"+v)
}
c.JSON(http.StatusOK, res)

View File

@@ -9,8 +9,6 @@ import (
"Gwen/service"
"github.com/gin-gonic/gin"
"net/http"
"strconv"
"strings"
)
type Oauth struct {
@@ -32,13 +30,11 @@ func (o *Oauth) OidcAuth(c *gin.Context) {
response.Error(c, response.TranslateMsg(c, "ParamsError")+err.Error())
return
}
//fmt.Println(f)
if f.Op != model.OauthTypeWebauth && f.Op != model.OauthTypeGoogle && f.Op != model.OauthTypeGithub && f.Op != model.OauthTypeOidc {
response.Error(c, response.TranslateMsg(c, "ParamsError"))
return
}
err, code, url := service.AllService.OauthService.BeginAuth(f.Op)
oauthService := service.AllService.OauthService
var code string
var url string
err, code, url = oauthService.BeginAuth(f.Op)
if err != nil {
response.Error(c, response.TranslateMsg(c, err.Error()))
return
@@ -98,6 +94,7 @@ func (o *Oauth) OidcAuthQueryPre(c *gin.Context) (*model.User, *model.UserToken)
ut = service.AllService.UserService.Login(u, &model.LoginLog{
UserId: u.Id,
Client: v.DeviceType,
DeviceId: v.Id,
Uuid: v.Uuid,
Ip: c.ClientIP(),
Type: model.LoginLogTypeOauth,
@@ -149,70 +146,43 @@ func (o *Oauth) OauthCallback(c *gin.Context) {
c.String(http.StatusInternalServerError, response.TranslateParamMsg(c, "ParamIsEmpty", "state"))
return
}
cacheKey := state
oauthService := service.AllService.OauthService
//从缓存中获取
v := service.AllService.OauthService.GetOauthCache(cacheKey)
if v == nil {
oauthCache := oauthService.GetOauthCache(cacheKey)
if oauthCache == nil {
c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthExpired"))
return
}
ty := v.Op
ac := v.Action
var u *model.User
openid := ""
thirdName := ""
//fmt.Println("ty ac ", ty, ac)
if ty == model.OauthTypeGithub {
code := c.Query("code")
err, userData := service.AllService.OauthService.GithubCallback(code)
if err != nil {
c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthFailed")+response.TranslateMsg(c, err.Error()))
return
}
openid = strconv.Itoa(userData.Id)
thirdName = userData.Login
} else if ty == model.OauthTypeGoogle {
code := c.Query("code")
err, userData := service.AllService.OauthService.GoogleCallback(code)
if err != nil {
c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthFailed")+response.TranslateMsg(c, err.Error()))
return
}
openid = userData.Email
//将空格替换成_
thirdName = strings.Replace(userData.Name, " ", "_", -1)
} else if ty == model.OauthTypeOidc {
code := c.Query("code")
err, userData := service.AllService.OauthService.OidcCallback(code)
if err != nil {
c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthFailed")+response.TranslateMsg(c, err.Error()))
return
}
openid = userData.Sub
thirdName = userData.PreferredUsername
} else {
c.String(http.StatusInternalServerError, response.TranslateMsg(c, "ParamsError"))
op := oauthCache.Op
action := oauthCache.Action
var user *model.User
// 获取用户信息
code := c.Query("code")
err, oauthUser := oauthService.Callback(code, op)
if err != nil {
c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthFailed")+response.TranslateMsg(c, err.Error()))
return
}
if ac == service.OauthActionTypeBind {
userId := oauthCache.UserId
openid := oauthUser.OpenId
if action == service.OauthActionTypeBind {
//fmt.Println("bind", ty, userData)
utr := service.AllService.OauthService.UserThirdInfo(ty, openid)
// 检查此openid是否已经绑定过
utr := oauthService.UserThirdInfo(op, openid)
if utr.UserId > 0 {
c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthHasBindOtherUser"))
return
}
//绑定
u = service.AllService.UserService.InfoById(v.UserId)
if u == nil {
user = service.AllService.UserService.InfoById(userId)
if user == nil {
c.String(http.StatusInternalServerError, response.TranslateMsg(c, "ItemNotFound"))
return
}
//绑定
err := service.AllService.OauthService.BindOauthUser(ty, openid, thirdName, v.UserId)
err := oauthService.BindOauthUser(userId, oauthUser, op)
if err != nil {
c.String(http.StatusInternalServerError, response.TranslateMsg(c, "BindFail"))
return
@@ -220,42 +190,41 @@ func (o *Oauth) OauthCallback(c *gin.Context) {
c.String(http.StatusOK, response.TranslateMsg(c, "BindSuccess"))
return
} else if ac == service.OauthActionTypeLogin {
} else if action == service.OauthActionTypeLogin {
//登录
if v.UserId != 0 {
if userId != 0 {
c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthHasBeenSuccess"))
return
}
u = service.AllService.UserService.InfoByGithubId(openid)
if u == nil {
oa := service.AllService.OauthService.InfoByOp(ty)
if !*oa.AutoRegister {
user = service.AllService.UserService.InfoByOauthId(op, openid)
if user == nil {
oauthConfig := oauthService.InfoByOp(op)
if !*oauthConfig.AutoRegister {
//c.String(http.StatusInternalServerError, "还未绑定用户,请先绑定")
v.ThirdName = thirdName
v.ThirdOpenId = openid
oauthCache.UpdateFromOauthUser(oauthUser)
url := global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/bind/" + cacheKey
c.Redirect(http.StatusFound, url)
return
}
//自动注册
u = service.AllService.UserService.RegisterByOauth(ty, thirdName, openid)
if u.Id == 0 {
c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthRegisterFailed"))
err, user = service.AllService.UserService.RegisterByOauth(oauthUser, op)
if err != nil {
c.String(http.StatusInternalServerError, response.TranslateMsg(c, err.Error()))
return
}
}
v.UserId = u.Id
service.AllService.OauthService.SetOauthCache(cacheKey, v, 0)
oauthCache.UserId = user.Id
oauthService.SetOauthCache(cacheKey, oauthCache, 0)
// 如果是webadmin登录成功后跳转到webadmin
if v.DeviceType == "webadmin" {
if oauthCache.DeviceType == "webadmin" {
/*service.AllService.UserService.Login(u, &model.LoginLog{
UserId: u.Id,
Client: "webadmin",
Uuid: "", //must be empty
Ip: c.ClientIP(),
Type: model.LoginLogTypeOauth,
Platform: v.DeviceOs,
Platform: oauthService.DeviceOs,
})*/
url := global.Config.Rustdesk.ApiServer + "/_admin/#/"
c.Redirect(http.StatusFound, url)