feat: OAuth2 登录

This commit is contained in:
王一之
2023-07-21 21:49:49 +08:00
parent 513f8ea012
commit 062d28b6e6
6 changed files with 190 additions and 75 deletions

View File

@@ -20,6 +20,7 @@ import (
"mayfly-go/pkg/req"
"mayfly-go/pkg/utils"
"net/http"
"strconv"
"strings"
"time"
)
@@ -70,7 +71,8 @@ func (a *Auth) OAuth2Callback(rc *req.Ctx) {
if state == "" {
biz.ErrIsNil(errors.New("state不能为空"), "state不能为空")
}
if cache.GetStr("oauth2:state:"+state) == "" {
stateAction := cache.GetStr("oauth2:state:" + state)
if stateAction == "" {
biz.ErrIsNil(errors.New("state已过期请重新登录"), "state已过期请重新登录")
}
token, err := client.Exchange(rc.GinCtx, code)
@@ -115,86 +117,117 @@ func (a *Auth) OAuth2Callback(rc *req.Ctx) {
// 查询用户是否存在
oauthAccount := &entity.OAuthAccount{Identity: userId}
err = a.AuthApp.GetOAuthAccount(oauthAccount, "account_id", "identity")
var accountId uint64
if err != nil {
if err != gorm.ErrRecordNotFound {
biz.ErrIsNil(err, "查询用户失败: "+err.Error())
// 判断是登录还是绑定
if stateAction == "login" {
var accountId uint64
if err != nil {
if err != gorm.ErrRecordNotFound {
biz.ErrIsNil(err, "查询用户失败: "+err.Error())
}
// 不存在,进行注册
if !oauth.AutoRegister {
biz.ErrIsNil(errors.New("用户不存在,请先注册"), "用户不存在,请先注册")
}
now := time.Now()
account := &entity.Account{
Model: model.Model{
CreateTime: &now,
CreatorId: 0,
Creator: "oauth2",
UpdateTime: &now,
},
Name: userId,
Username: userId,
}
a.AccountApp.Create(account)
// 绑定
if err := a.AuthApp.BindOAuthAccount(&entity.OAuthAccount{
AccountId: account.Id,
Identity: oauthAccount.Identity,
CreateTime: &now,
UpdateTime: &now,
}); err != nil {
biz.ErrIsNil(err, "绑定用户失败: "+err.Error())
}
accountId = account.Id
} else {
accountId = oauthAccount.AccountId
}
// 不存在,进行注册
if !oauth.AutoRegister {
biz.ErrIsNil(errors.New("用户不存在,请先注册"), "用户不存在,请先注册")
// 进行登录
account := &entity.Account{
Model: model.Model{Id: accountId},
}
if err := a.AccountApp.GetAccount(account, "Id", "Name", "Username", "Password", "Status", "LastLoginTime", "LastLoginIp", "OtpSecret"); err != nil {
biz.ErrIsNil(err, "获取用户信息失败: "+err.Error())
}
biz.IsTrue(account.IsEnable(), "该账号不可用")
// 访问系统使用的token
accessToken := req.CreateToken(accountId, account.Username)
// 默认为不校验otp
otpStatus := OtpStatusNone
clientIp := rc.GinCtx.ClientIP()
rc.ReqParam = fmt.Sprintf("oauth2 login username: %s | ip: %s", account.Username, clientIp)
res := map[string]any{
"name": account.Name,
"username": account.Username,
"lastLoginTime": account.LastLoginTime,
"lastLoginIp": account.LastLoginIp,
}
accountLoginSecurity := a.ConfigApp.GetConfig(entity.ConfigKeyAccountLoginSecurity).ToAccountLoginSecurity()
// 判断otp
if accountLoginSecurity.UseOtp {
otpInfo, otpurl, otpToken := useOtp(account, accountLoginSecurity.OtpIssuer, accessToken)
otpStatus = otpInfo.OptStatus
if otpurl != "" {
res["otpUrl"] = otpurl
}
accessToken = otpToken
} else {
// 保存登录消息
go saveLogin(a.AccountApp, a.MsgApp, account, rc.GinCtx.ClientIP())
}
// 赋值otp状态
res["action"] = "oauthLogin"
res["otp"] = otpStatus
res["token"] = accessToken
b, err = json.Marshal(res)
biz.ErrIsNil(err, "数据序列化失败")
rc.GinCtx.Header("Content-Type", "text/html; charset=utf-8")
rc.GinCtx.Writer.WriteHeader(http.StatusOK)
_, _ = rc.GinCtx.Writer.WriteString("<html>" +
"<script>top.opener.postMessage(" + string(b) + ")</script>" +
"</html>")
} else if sAccountId, ok := strings.CutPrefix(stateAction, "bind:"); ok {
// 绑定
accountId, err := strconv.ParseUint(sAccountId, 10, 64)
if err != nil {
biz.ErrIsNil(err, "绑定用户失败: "+err.Error())
}
now := time.Now()
account := &entity.Account{
Model: model.Model{
CreateTime: &now,
CreatorId: 0,
Creator: "oauth2",
UpdateTime: &now,
},
Name: userId,
Username: userId,
}
a.AccountApp.Create(account)
// 绑定
if err := a.AuthApp.BindOAuthAccount(&entity.OAuthAccount{
AccountId: account.Id,
AccountId: accountId,
Identity: oauthAccount.Identity,
CreateTime: &now,
UpdateTime: &now,
}); err != nil {
biz.ErrIsNil(err, "绑定用户失败: "+err.Error())
}
accountId = account.Id
} else {
accountId = oauthAccount.AccountId
}
// 进行登录
account := &entity.Account{
Model: model.Model{Id: accountId},
}
if err := a.AccountApp.GetAccount(account, "Id", "Name", "Username", "Password", "Status", "LastLoginTime", "LastLoginIp", "OtpSecret"); err != nil {
biz.ErrIsNil(err, "获取用户信息失败: "+err.Error())
}
biz.IsTrue(account.IsEnable(), "该账号不可用")
// 访问系统使用的token
accessToken := req.CreateToken(accountId, account.Username)
// 默认为不校验otp
otpStatus := OtpStatusNone
clientIp := rc.GinCtx.ClientIP()
rc.ReqParam = fmt.Sprintf("oauth2 login username: %s | ip: %s", account.Username, clientIp)
res := map[string]any{
"name": account.Name,
"username": account.Username,
"lastLoginTime": account.LastLoginTime,
"lastLoginIp": account.LastLoginIp,
}
accountLoginSecurity := a.ConfigApp.GetConfig(entity.ConfigKeyAccountLoginSecurity).ToAccountLoginSecurity()
// 判断otp
if accountLoginSecurity.UseOtp {
otpInfo, otpurl, otpToken := useOtp(account, accountLoginSecurity.OtpIssuer, accessToken)
otpStatus = otpInfo.OptStatus
if otpurl != "" {
res["otpUrl"] = otpurl
res := map[string]any{
"action": "oauthBind",
"bind": true,
}
accessToken = otpToken
b, err = json.Marshal(res)
biz.ErrIsNil(err, "数据序列化失败")
rc.GinCtx.Header("Content-Type", "text/html; charset=utf-8")
rc.GinCtx.Writer.WriteHeader(http.StatusOK)
_, _ = rc.GinCtx.Writer.WriteString("<html>" +
"<script>top.opener.postMessage(" + string(b) + ")</script>" +
"</html>")
} else {
// 保存登录消息
go saveLogin(a.AccountApp, a.MsgApp, account, rc.GinCtx.ClientIP())
biz.ErrIsNil(errors.New("state不合法"), "state不合法")
}
// 赋值otp状态
res["action"] = "oauthLogin"
res["otp"] = otpStatus
res["token"] = accessToken
b, err = json.Marshal(res)
biz.ErrIsNil(err, "数据序列化失败")
rc.GinCtx.Header("Content-Type", "text/html; charset=utf-8")
rc.GinCtx.Writer.WriteHeader(http.StatusOK)
_, _ = rc.GinCtx.Writer.WriteString("<html>" +
"<script>top.opener.postMessage(" + string(b) + ")</script>" +
"</html>")
}
func (a *Auth) getOAuthClient() (*oauth2.Config, *vo.OAuth2VO, error) {
@@ -265,3 +298,42 @@ func (a *Auth) SaveOAuth2(rc *req.Ctx) {
config.Remark = AuthOAuth2Remark
a.ConfigApp.Save(config)
}
func (a *Auth) OAuth2Bind(rc *req.Ctx) {
client, _, err := a.getOAuthClient()
if err != nil {
biz.ErrIsNil(err, "获取oauth2 client失败: "+err.Error())
return
}
state := utils.RandString(32)
cache.SetStr("oauth2:state:"+state, "bind:"+strconv.FormatUint(rc.LoginAccount.Id, 10),
5*time.Minute)
rc.GinCtx.Redirect(http.StatusFound, client.AuthCodeURL(state))
}
func (a *Auth) Auth2Status(ctx *req.Ctx) {
res := &vo.AuthStatusVO{}
config := a.ConfigApp.GetConfig(AuthOAuth2Key)
if config.Value != "" {
oauth2 := &vo.OAuth2VO{}
if err := json.Unmarshal([]byte(config.Value), oauth2); err != nil {
global.Log.Warnf("解析自定义oauth2配置失败err%s", err.Error())
biz.ErrIsNil(err, "解析自定义oauth2配置失败")
} else if oauth2.ClientID != "" {
res.Enable.OAuth2 = true
}
}
if res.Enable.OAuth2 {
err := a.AuthApp.GetOAuthAccount(&entity.OAuthAccount{
AccountId: ctx.LoginAccount.Id,
}, "account_id", "identity")
if err != nil {
if err != gorm.ErrRecordNotFound {
biz.ErrIsNil(err, "查询用户失败: "+err.Error())
}
} else {
res.Bind.OAuth2 = true
}
}
ctx.ResData = res
}

View File

@@ -53,7 +53,7 @@ func (c *Config) SaveConfig(rc *req.Ctx) {
// AuthConfig auth相关配置
func (c *Config) AuthConfig(rc *req.Ctx) {
resp := &vo.OAuth2EnableVO{}
resp := &vo.Auth2EnableVO{}
config := c.ConfigApp.GetConfig(AuthOAuth2Key)
oauth2 := &vo.OAuth2VO{}
if config.Value != "" {

View File

@@ -16,6 +16,11 @@ type AuthVO struct {
*OAuth2VO `json:"oauth2"`
}
type OAuth2EnableVO struct {
type Auth2EnableVO struct {
OAuth2 bool `json:"oauth2"`
}
type AuthStatusVO struct {
Enable Auth2EnableVO `json:"enable"`
Bind Auth2EnableVO `json:"bind"`
}

View File

@@ -24,6 +24,9 @@ func InitSysAuthRouter(router *gin.RouterGroup) {
req.NewPut("/oauth2", r.SaveOAuth2).RequiredPermission(baseP),
req.NewGet("/status", r.Auth2Status),
req.NewGet("/oauth2/bind", r.OAuth2Bind),
req.NewGet("/oauth2/login", r.OAuth2Login).DontNeedToken(),
req.NewGet("/oauth2/callback", r.OAuth2Callback).NoRes().DontNeedToken(),
}