mirror of
https://gitee.com/dromara/mayfly-go
synced 2025-11-04 00:10:25 +08:00
refactor: oauth2登录调整
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mayfly-go/internal/auth/api/vo"
|
||||
@@ -59,7 +58,7 @@ func (a *Oauth2Login) OAuth2Callback(rc *req.Ctx) {
|
||||
biz.NotEmpty(stateAction, "state已过期, 请重新登录")
|
||||
|
||||
token, err := client.Exchange(rc.GinCtx, code)
|
||||
biz.ErrIsNilAppendErr(err, "获取token失败: %s")
|
||||
biz.ErrIsNilAppendErr(err, "获取OAuth2 accessToken失败: %s")
|
||||
|
||||
// 获取用户信息
|
||||
httpCli := client.Client(rc.GinCtx.Request.Context(), token)
|
||||
@@ -104,7 +103,12 @@ func (a *Oauth2Login) OAuth2Callback(rc *req.Ctx) {
|
||||
err = a.Oauth2App.GetOAuthAccount(&entity.Oauth2Account{
|
||||
AccountId: accountId,
|
||||
}, "account_id", "identity")
|
||||
biz.IsTrue(err != nil, "该账号已被绑定")
|
||||
biz.IsTrue(err != nil, "该账号已被其他用户绑定")
|
||||
|
||||
err = a.Oauth2App.GetOAuthAccount(&entity.Oauth2Account{
|
||||
Identity: userId,
|
||||
}, "account_id", "identity")
|
||||
biz.IsTrue(err != nil, "您已绑定其他账号")
|
||||
|
||||
now := time.Now()
|
||||
err = a.Oauth2App.BindOAuthAccount(&entity.Oauth2Account{
|
||||
@@ -118,13 +122,7 @@ func (a *Oauth2Login) OAuth2Callback(rc *req.Ctx) {
|
||||
"action": "oauthBind",
|
||||
"bind": true,
|
||||
}
|
||||
b, err = json.Marshal(res)
|
||||
biz.ErrIsNil(err, "数据序列化失败")
|
||||
rc.GinCtx.Header("Content-Type", "text/html; charset=utf-8")
|
||||
rc.GinCtx.Writer.WriteHeader(http.StatusOK)
|
||||
_, _ = rc.GinCtx.Writer.WriteString("<html>" +
|
||||
"<script>top.opener.postMessage(" + string(b) + ")</script>" +
|
||||
"</html>")
|
||||
rc.ResData = res
|
||||
} else {
|
||||
panic(biz.NewBizErr("state不合法"))
|
||||
}
|
||||
@@ -132,14 +130,12 @@ func (a *Oauth2Login) OAuth2Callback(rc *req.Ctx) {
|
||||
|
||||
// 指定登录操作
|
||||
func (a *Oauth2Login) doLoginAction(rc *req.Ctx, userId string, oauth *sysentity.ConfigOauth2Login) {
|
||||
clientIp := getIpAndRegion(rc)
|
||||
rc.ReqParam = fmt.Sprintf("oauth2 login username: %s | ip: %s", userId, clientIp)
|
||||
|
||||
// 查询用户是否存在
|
||||
oauthAccount := &entity.Oauth2Account{Identity: userId}
|
||||
err := a.Oauth2App.GetOAuthAccount(oauthAccount, "account_id", "identity")
|
||||
|
||||
var accountId uint64
|
||||
isFirst := false
|
||||
// 不存在,进行注册
|
||||
if err != nil {
|
||||
biz.IsTrue(oauth.AutoRegister, "系统未开启自动注册, 请先让管理员添加对应账号")
|
||||
@@ -164,6 +160,7 @@ func (a *Oauth2Login) doLoginAction(rc *req.Ctx, userId string, oauth *sysentity
|
||||
})
|
||||
biz.ErrIsNilAppendErr(err, "绑定用户失败: %s")
|
||||
accountId = account.Id
|
||||
isFirst = true
|
||||
} else {
|
||||
accountId = oauthAccount.AccountId
|
||||
}
|
||||
@@ -175,15 +172,13 @@ func (a *Oauth2Login) doLoginAction(rc *req.Ctx, userId string, oauth *sysentity
|
||||
err = a.AccountApp.GetAccount(account, "Id", "Name", "Username", "Password", "Status", "LastLoginTime", "LastLoginIp", "OtpSecret")
|
||||
biz.ErrIsNilAppendErr(err, "获取用户信息失败: %s")
|
||||
|
||||
clientIp := getIpAndRegion(rc)
|
||||
rc.ReqParam = fmt.Sprintf("oauth2 login username: %s | ip: %s", account.Username, clientIp)
|
||||
|
||||
res := LastLoginCheck(account, a.ConfigApp.GetConfig(sysentity.ConfigKeyAccountLoginSecurity).ToAccountLoginSecurity(), clientIp)
|
||||
res["action"] = "oauthLogin"
|
||||
b, err := json.Marshal(res)
|
||||
biz.ErrIsNil(err, "数据序列化失败")
|
||||
rc.GinCtx.Header("Content-Type", "text/html; charset=utf-8")
|
||||
rc.GinCtx.Writer.WriteHeader(http.StatusOK)
|
||||
_, _ = rc.GinCtx.Writer.WriteString("<html>" +
|
||||
"<script>top.opener.postMessage(" + string(b) + ")</script>" +
|
||||
"</html>")
|
||||
res["isFirstOauth2Login"] = isFirst
|
||||
rc.ResData = res
|
||||
}
|
||||
|
||||
func (a *Oauth2Login) getOAuthClient() (*oauth2.Config, *sysentity.ConfigOauth2Login) {
|
||||
@@ -198,7 +193,7 @@ func (a *Oauth2Login) getOAuthClient() (*oauth2.Config, *sysentity.ConfigOauth2L
|
||||
AuthURL: oath2LoginConfig.AuthorizationURL,
|
||||
TokenURL: oath2LoginConfig.AccessTokenURL,
|
||||
},
|
||||
RedirectURL: oath2LoginConfig.RedirectURL + "/api/auth/oauth2/callback",
|
||||
RedirectURL: oath2LoginConfig.RedirectURL + "/#/oauth2/callback",
|
||||
Scopes: strings.Split(oath2LoginConfig.Scopes, ","),
|
||||
}
|
||||
return client, oath2LoginConfig
|
||||
@@ -217,3 +212,7 @@ func (a *Oauth2Login) Oauth2Status(ctx *req.Ctx) {
|
||||
|
||||
ctx.ResData = res
|
||||
}
|
||||
|
||||
func (a *Oauth2Login) Oauth2Unbind(rc *req.Ctx) {
|
||||
a.Oauth2App.Unbind(rc.LoginAccount.Id)
|
||||
}
|
||||
|
||||
@@ -12,6 +12,8 @@ type Oauth2 interface {
|
||||
GetOAuthAccount(condition *entity.Oauth2Account, cols ...string) error
|
||||
|
||||
BindOAuthAccount(e *entity.Oauth2Account) error
|
||||
|
||||
Unbind(accountId uint64)
|
||||
}
|
||||
|
||||
func newAuthApp(oauthAccountRepo repository.Oauth2Account) Oauth2 {
|
||||
@@ -36,3 +38,7 @@ func (a *oauth2AppImpl) GetOAuthAccount(condition *entity.Oauth2Account, cols ..
|
||||
func (a *oauth2AppImpl) BindOAuthAccount(e *entity.Oauth2Account) error {
|
||||
return a.oauthAccountRepo.SaveOAuthAccount(e)
|
||||
}
|
||||
|
||||
func (a *oauth2AppImpl) Unbind(accountId uint64) {
|
||||
a.oauthAccountRepo.DeleteBy(&entity.Oauth2Account{AccountId: accountId})
|
||||
}
|
||||
|
||||
@@ -7,4 +7,6 @@ type Oauth2Account interface {
|
||||
GetOAuthAccount(condition *entity.Oauth2Account, cols ...string) error
|
||||
|
||||
SaveOAuthAccount(e *entity.Oauth2Account) error
|
||||
|
||||
DeleteBy(e *entity.Oauth2Account)
|
||||
}
|
||||
|
||||
@@ -22,3 +22,7 @@ func (a *oauth2AccountRepoImpl) SaveOAuthAccount(e *entity.Oauth2Account) error
|
||||
}
|
||||
return gormx.UpdateById(e)
|
||||
}
|
||||
|
||||
func (a *oauth2AccountRepoImpl) DeleteBy(e *entity.Oauth2Account) {
|
||||
gormx.DeleteByCondition(e)
|
||||
}
|
||||
|
||||
@@ -45,9 +45,11 @@ func Init(router *gin.RouterGroup) {
|
||||
req.NewGet("/oauth2/bind", oauth2Login.OAuth2Bind),
|
||||
|
||||
// oauth2回调地址
|
||||
req.NewGet("/oauth2/callback", oauth2Login.OAuth2Callback).Log(req.NewLogSave("oauth2回调")).NoRes().DontNeedToken(),
|
||||
req.NewGet("/oauth2/callback", oauth2Login.OAuth2Callback).Log(req.NewLogSave("oauth2回调")).DontNeedToken(),
|
||||
|
||||
req.NewGet("/oauth2/status", oauth2Login.Oauth2Status),
|
||||
|
||||
req.NewGet("/oauth2/unbind", oauth2Login.Oauth2Unbind).Log(req.NewLogSave("oauth2解绑")),
|
||||
}
|
||||
|
||||
req.BatchSetGroup(rg, reqs[:])
|
||||
|
||||
@@ -105,6 +105,13 @@ func (a *Account) UpdateAccount(rc *req.Ctx) {
|
||||
biz.IsTrue(utils.CheckAccountPasswordLever(updateAccount.Password), "密码强度必须8位以上且包含字⺟⼤⼩写+数字+特殊符号")
|
||||
updateAccount.Password = cryptox.PwdHash(updateAccount.Password)
|
||||
}
|
||||
|
||||
oldAcc := a.AccountApp.GetById(updateAccount.Id)
|
||||
// 账号创建十分钟内允许修改用户名(兼容oauth2首次登录修改用户名),否则不允许修改
|
||||
if oldAcc.CreateTime.Add(10 * time.Minute).Before(time.Now()) {
|
||||
// 禁止更新用户名,防止误传被更新
|
||||
updateAccount.Username = ""
|
||||
}
|
||||
a.AccountApp.Update(updateAccount)
|
||||
}
|
||||
|
||||
@@ -133,6 +140,8 @@ func (a *Account) SaveAccount(rc *req.Ctx) {
|
||||
biz.IsTrue(utils.CheckAccountPasswordLever(account.Password), "密码强度必须8位以上且包含字⺟⼤⼩写+数字+特殊符号")
|
||||
account.Password = cryptox.PwdHash(account.Password)
|
||||
}
|
||||
// 更新操作不允许修改用户名、防止误传更新
|
||||
account.Username = ""
|
||||
a.AccountApp.Update(account)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,9 @@ type AccountCreateForm struct {
|
||||
}
|
||||
|
||||
type AccountUpdateForm struct {
|
||||
Password *string `json:"password" binding:"min=6,max=16"`
|
||||
Name string `json:"name" binding:"max=16"` // 姓名
|
||||
Username string `json:"username" binding:"max=20"`
|
||||
Password *string `json:"password"`
|
||||
}
|
||||
|
||||
type AccountChangePasswordForm struct {
|
||||
|
||||
@@ -14,6 +14,8 @@ import (
|
||||
type Account interface {
|
||||
GetAccount(condition *entity.Account, cols ...string) error
|
||||
|
||||
GetById(id uint64) *entity.Account
|
||||
|
||||
GetPageList(condition *entity.Account, pageParam *model.PageParam, toEntity any, orderBy ...string) *model.PageResult[any]
|
||||
|
||||
Create(account *entity.Account)
|
||||
@@ -38,6 +40,10 @@ func (a *accountAppImpl) GetAccount(condition *entity.Account, cols ...string) e
|
||||
return a.accountRepo.GetAccount(condition, cols...)
|
||||
}
|
||||
|
||||
func (a *accountAppImpl) GetById(id uint64) *entity.Account {
|
||||
return a.accountRepo.GetById(id)
|
||||
}
|
||||
|
||||
func (a *accountAppImpl) GetPageList(condition *entity.Account, pageParam *model.PageParam, toEntity any, orderBy ...string) *model.PageResult[any] {
|
||||
return a.accountRepo.GetPageList(condition, pageParam, toEntity)
|
||||
}
|
||||
@@ -51,8 +57,12 @@ func (a *accountAppImpl) Create(account *entity.Account) {
|
||||
}
|
||||
|
||||
func (a *accountAppImpl) Update(account *entity.Account) {
|
||||
// 禁止更新用户名,防止误传被更新
|
||||
account.Username = ""
|
||||
if account.Username != "" {
|
||||
unAcc := &entity.Account{Username: account.Username}
|
||||
err := a.GetAccount(unAcc)
|
||||
biz.IsTrue(err != nil || unAcc.Id == account.Id, "该用户名已存在")
|
||||
}
|
||||
|
||||
a.accountRepo.Update(account)
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,8 @@ type Account interface {
|
||||
// 根据条件获取账号信息
|
||||
GetAccount(condition *entity.Account, cols ...string) error
|
||||
|
||||
GetById(id uint64) *entity.Account
|
||||
|
||||
GetPageList(condition *entity.Account, pageParam *model.PageParam, toEntity any, orderBy ...string) *model.PageResult[any]
|
||||
|
||||
Insert(account *entity.Account)
|
||||
|
||||
@@ -18,6 +18,14 @@ func (a *accountRepoImpl) GetAccount(condition *entity.Account, cols ...string)
|
||||
return gormx.GetBy(condition, cols...)
|
||||
}
|
||||
|
||||
func (a *accountRepoImpl) GetById(id uint64) *entity.Account {
|
||||
ac := new(entity.Account)
|
||||
if err := gormx.GetById(ac, id); err != nil {
|
||||
return nil
|
||||
}
|
||||
return ac
|
||||
}
|
||||
|
||||
func (m *accountRepoImpl) GetPageList(condition *entity.Account, pageParam *model.PageParam, toEntity any, orderBy ...string) *model.PageResult[any] {
|
||||
qd := gormx.NewQuery(new(entity.Account)).
|
||||
Like("name", condition.Name).
|
||||
|
||||
Reference in New Issue
Block a user