refactor: oauth2登录调整

This commit is contained in:
meilin.huang
2023-07-24 22:36:07 +08:00
parent 155ae65b4a
commit 5083b2bdfe
19 changed files with 227 additions and 60 deletions

View File

@@ -1,7 +1,6 @@
package api
import (
"encoding/json"
"fmt"
"io"
"mayfly-go/internal/auth/api/vo"
@@ -59,7 +58,7 @@ func (a *Oauth2Login) OAuth2Callback(rc *req.Ctx) {
biz.NotEmpty(stateAction, "state已过期, 请重新登录")
token, err := client.Exchange(rc.GinCtx, code)
biz.ErrIsNilAppendErr(err, "获取token失败: %s")
biz.ErrIsNilAppendErr(err, "获取OAuth2 accessToken失败: %s")
// 获取用户信息
httpCli := client.Client(rc.GinCtx.Request.Context(), token)
@@ -104,7 +103,12 @@ func (a *Oauth2Login) OAuth2Callback(rc *req.Ctx) {
err = a.Oauth2App.GetOAuthAccount(&entity.Oauth2Account{
AccountId: accountId,
}, "account_id", "identity")
biz.IsTrue(err != nil, "该账号已被绑定")
biz.IsTrue(err != nil, "该账号已被其他用户绑定")
err = a.Oauth2App.GetOAuthAccount(&entity.Oauth2Account{
Identity: userId,
}, "account_id", "identity")
biz.IsTrue(err != nil, "您已绑定其他账号")
now := time.Now()
err = a.Oauth2App.BindOAuthAccount(&entity.Oauth2Account{
@@ -118,13 +122,7 @@ func (a *Oauth2Login) OAuth2Callback(rc *req.Ctx) {
"action": "oauthBind",
"bind": true,
}
b, err = json.Marshal(res)
biz.ErrIsNil(err, "数据序列化失败")
rc.GinCtx.Header("Content-Type", "text/html; charset=utf-8")
rc.GinCtx.Writer.WriteHeader(http.StatusOK)
_, _ = rc.GinCtx.Writer.WriteString("<html>" +
"<script>top.opener.postMessage(" + string(b) + ")</script>" +
"</html>")
rc.ResData = res
} else {
panic(biz.NewBizErr("state不合法"))
}
@@ -132,14 +130,12 @@ func (a *Oauth2Login) OAuth2Callback(rc *req.Ctx) {
// 指定登录操作
func (a *Oauth2Login) doLoginAction(rc *req.Ctx, userId string, oauth *sysentity.ConfigOauth2Login) {
clientIp := getIpAndRegion(rc)
rc.ReqParam = fmt.Sprintf("oauth2 login username: %s | ip: %s", userId, clientIp)
// 查询用户是否存在
oauthAccount := &entity.Oauth2Account{Identity: userId}
err := a.Oauth2App.GetOAuthAccount(oauthAccount, "account_id", "identity")
var accountId uint64
isFirst := false
// 不存在,进行注册
if err != nil {
biz.IsTrue(oauth.AutoRegister, "系统未开启自动注册, 请先让管理员添加对应账号")
@@ -164,6 +160,7 @@ func (a *Oauth2Login) doLoginAction(rc *req.Ctx, userId string, oauth *sysentity
})
biz.ErrIsNilAppendErr(err, "绑定用户失败: %s")
accountId = account.Id
isFirst = true
} else {
accountId = oauthAccount.AccountId
}
@@ -175,15 +172,13 @@ func (a *Oauth2Login) doLoginAction(rc *req.Ctx, userId string, oauth *sysentity
err = a.AccountApp.GetAccount(account, "Id", "Name", "Username", "Password", "Status", "LastLoginTime", "LastLoginIp", "OtpSecret")
biz.ErrIsNilAppendErr(err, "获取用户信息失败: %s")
clientIp := getIpAndRegion(rc)
rc.ReqParam = fmt.Sprintf("oauth2 login username: %s | ip: %s", account.Username, clientIp)
res := LastLoginCheck(account, a.ConfigApp.GetConfig(sysentity.ConfigKeyAccountLoginSecurity).ToAccountLoginSecurity(), clientIp)
res["action"] = "oauthLogin"
b, err := json.Marshal(res)
biz.ErrIsNil(err, "数据序列化失败")
rc.GinCtx.Header("Content-Type", "text/html; charset=utf-8")
rc.GinCtx.Writer.WriteHeader(http.StatusOK)
_, _ = rc.GinCtx.Writer.WriteString("<html>" +
"<script>top.opener.postMessage(" + string(b) + ")</script>" +
"</html>")
res["isFirstOauth2Login"] = isFirst
rc.ResData = res
}
func (a *Oauth2Login) getOAuthClient() (*oauth2.Config, *sysentity.ConfigOauth2Login) {
@@ -198,7 +193,7 @@ func (a *Oauth2Login) getOAuthClient() (*oauth2.Config, *sysentity.ConfigOauth2L
AuthURL: oath2LoginConfig.AuthorizationURL,
TokenURL: oath2LoginConfig.AccessTokenURL,
},
RedirectURL: oath2LoginConfig.RedirectURL + "/api/auth/oauth2/callback",
RedirectURL: oath2LoginConfig.RedirectURL + "/#/oauth2/callback",
Scopes: strings.Split(oath2LoginConfig.Scopes, ","),
}
return client, oath2LoginConfig
@@ -217,3 +212,7 @@ func (a *Oauth2Login) Oauth2Status(ctx *req.Ctx) {
ctx.ResData = res
}
func (a *Oauth2Login) Oauth2Unbind(rc *req.Ctx) {
a.Oauth2App.Unbind(rc.LoginAccount.Id)
}

View File

@@ -12,6 +12,8 @@ type Oauth2 interface {
GetOAuthAccount(condition *entity.Oauth2Account, cols ...string) error
BindOAuthAccount(e *entity.Oauth2Account) error
Unbind(accountId uint64)
}
func newAuthApp(oauthAccountRepo repository.Oauth2Account) Oauth2 {
@@ -36,3 +38,7 @@ func (a *oauth2AppImpl) GetOAuthAccount(condition *entity.Oauth2Account, cols ..
func (a *oauth2AppImpl) BindOAuthAccount(e *entity.Oauth2Account) error {
return a.oauthAccountRepo.SaveOAuthAccount(e)
}
func (a *oauth2AppImpl) Unbind(accountId uint64) {
a.oauthAccountRepo.DeleteBy(&entity.Oauth2Account{AccountId: accountId})
}

View File

@@ -7,4 +7,6 @@ type Oauth2Account interface {
GetOAuthAccount(condition *entity.Oauth2Account, cols ...string) error
SaveOAuthAccount(e *entity.Oauth2Account) error
DeleteBy(e *entity.Oauth2Account)
}

View File

@@ -22,3 +22,7 @@ func (a *oauth2AccountRepoImpl) SaveOAuthAccount(e *entity.Oauth2Account) error
}
return gormx.UpdateById(e)
}
func (a *oauth2AccountRepoImpl) DeleteBy(e *entity.Oauth2Account) {
gormx.DeleteByCondition(e)
}

View File

@@ -45,9 +45,11 @@ func Init(router *gin.RouterGroup) {
req.NewGet("/oauth2/bind", oauth2Login.OAuth2Bind),
// oauth2回调地址
req.NewGet("/oauth2/callback", oauth2Login.OAuth2Callback).Log(req.NewLogSave("oauth2回调")).NoRes().DontNeedToken(),
req.NewGet("/oauth2/callback", oauth2Login.OAuth2Callback).Log(req.NewLogSave("oauth2回调")).DontNeedToken(),
req.NewGet("/oauth2/status", oauth2Login.Oauth2Status),
req.NewGet("/oauth2/unbind", oauth2Login.Oauth2Unbind).Log(req.NewLogSave("oauth2解绑")),
}
req.BatchSetGroup(rg, reqs[:])

View File

@@ -105,6 +105,13 @@ func (a *Account) UpdateAccount(rc *req.Ctx) {
biz.IsTrue(utils.CheckAccountPasswordLever(updateAccount.Password), "密码强度必须8位以上且包含字⺟⼤⼩写+数字+特殊符号")
updateAccount.Password = cryptox.PwdHash(updateAccount.Password)
}
oldAcc := a.AccountApp.GetById(updateAccount.Id)
// 账号创建十分钟内允许修改用户名兼容oauth2首次登录修改用户名否则不允许修改
if oldAcc.CreateTime.Add(10 * time.Minute).Before(time.Now()) {
// 禁止更新用户名,防止误传被更新
updateAccount.Username = ""
}
a.AccountApp.Update(updateAccount)
}
@@ -133,6 +140,8 @@ func (a *Account) SaveAccount(rc *req.Ctx) {
biz.IsTrue(utils.CheckAccountPasswordLever(account.Password), "密码强度必须8位以上且包含字⺟⼤⼩写+数字+特殊符号")
account.Password = cryptox.PwdHash(account.Password)
}
// 更新操作不允许修改用户名、防止误传更新
account.Username = ""
a.AccountApp.Update(account)
}
}

View File

@@ -8,7 +8,9 @@ type AccountCreateForm struct {
}
type AccountUpdateForm struct {
Password *string `json:"password" binding:"min=6,max=16"`
Name string `json:"name" binding:"max=16"` // 姓名
Username string `json:"username" binding:"max=20"`
Password *string `json:"password"`
}
type AccountChangePasswordForm struct {

View File

@@ -14,6 +14,8 @@ import (
type Account interface {
GetAccount(condition *entity.Account, cols ...string) error
GetById(id uint64) *entity.Account
GetPageList(condition *entity.Account, pageParam *model.PageParam, toEntity any, orderBy ...string) *model.PageResult[any]
Create(account *entity.Account)
@@ -38,6 +40,10 @@ func (a *accountAppImpl) GetAccount(condition *entity.Account, cols ...string) e
return a.accountRepo.GetAccount(condition, cols...)
}
func (a *accountAppImpl) GetById(id uint64) *entity.Account {
return a.accountRepo.GetById(id)
}
func (a *accountAppImpl) GetPageList(condition *entity.Account, pageParam *model.PageParam, toEntity any, orderBy ...string) *model.PageResult[any] {
return a.accountRepo.GetPageList(condition, pageParam, toEntity)
}
@@ -51,8 +57,12 @@ func (a *accountAppImpl) Create(account *entity.Account) {
}
func (a *accountAppImpl) Update(account *entity.Account) {
// 禁止更新用户名,防止误传被更新
account.Username = ""
if account.Username != "" {
unAcc := &entity.Account{Username: account.Username}
err := a.GetAccount(unAcc)
biz.IsTrue(err != nil || unAcc.Id == account.Id, "该用户名已存在")
}
a.accountRepo.Update(account)
}

View File

@@ -9,6 +9,8 @@ type Account interface {
// 根据条件获取账号信息
GetAccount(condition *entity.Account, cols ...string) error
GetById(id uint64) *entity.Account
GetPageList(condition *entity.Account, pageParam *model.PageParam, toEntity any, orderBy ...string) *model.PageResult[any]
Insert(account *entity.Account)

View File

@@ -18,6 +18,14 @@ func (a *accountRepoImpl) GetAccount(condition *entity.Account, cols ...string)
return gormx.GetBy(condition, cols...)
}
func (a *accountRepoImpl) GetById(id uint64) *entity.Account {
ac := new(entity.Account)
if err := gormx.GetById(ac, id); err != nil {
return nil
}
return ac
}
func (m *accountRepoImpl) GetPageList(condition *entity.Account, pageParam *model.PageParam, toEntity any, orderBy ...string) *model.PageResult[any] {
qd := gormx.NewQuery(new(entity.Account)).
Like("name", condition.Name).