Files
mayfly-go/server/internal/auth/api/oauth2_login.go

249 lines
7.6 KiB
Go
Raw Normal View History

2023-07-22 20:51:46 +08:00
package api
import (
"context"
2023-07-22 20:51:46 +08:00
"fmt"
"io"
"mayfly-go/internal/auth/api/vo"
"mayfly-go/internal/auth/application"
"mayfly-go/internal/auth/config"
2023-07-22 20:51:46 +08:00
"mayfly-go/internal/auth/domain/entity"
2024-11-20 22:43:53 +08:00
"mayfly-go/internal/auth/imsg"
2023-07-22 20:51:46 +08:00
sysapp "mayfly-go/internal/sys/application"
sysentity "mayfly-go/internal/sys/domain/entity"
"mayfly-go/pkg/biz"
"mayfly-go/pkg/cache"
"mayfly-go/pkg/errorx"
2023-07-22 20:51:46 +08:00
"mayfly-go/pkg/model"
"mayfly-go/pkg/req"
2023-10-12 12:14:56 +08:00
"mayfly-go/pkg/utils/collx"
2023-07-22 20:51:46 +08:00
"mayfly-go/pkg/utils/jsonx"
"mayfly-go/pkg/utils/stringx"
"net/http"
"strconv"
"strings"
"time"
"golang.org/x/oauth2"
)
// Oauth2Login groups the HTTP handlers registered under /auth/oauth2
// (see ReqConfs). Its dependencies are filled in through the `inject:"T"`
// struct tags — NOTE(review): presumably by the project's DI container; confirm.
type Oauth2Login struct {
	oauth2App  application.Oauth2 `inject:"T"` // OAuth2 account lookup / bind / unbind operations
	accountApp sysapp.Account     `inject:"T"` // system account lookup and creation
}
func (o *Oauth2Login) ReqConfs() *req.Confs {
reqs := [...]*req.Conf{
req.NewGet("/config", o.Oauth2Config).DontNeedToken(),
// oauth2登录
req.NewGet("/login", o.OAuth2Login).DontNeedToken(),
req.NewGet("/bind", o.OAuth2Bind),
// oauth2回调地址
req.NewGet("/callback", o.OAuth2Callback).Log(req.NewLogSaveI(imsg.LogOauth2Callback)).DontNeedToken(),
req.NewGet("/status", o.Oauth2Status),
req.NewGet("/unbind", o.Oauth2Unbind).Log(req.NewLogSaveI(imsg.LogOauth2Unbind)),
}
return req.NewConfs("/auth/oauth2", reqs[:]...)
2023-07-22 20:51:46 +08:00
}
func (a *Oauth2Login) OAuth2Login(rc *req.Ctx) {
client, _ := a.getOAuthClient()
state := stringx.Rand(32)
cache.SetStr("oauth2:state:"+state, "login", 5*time.Minute)
2024-02-25 12:46:18 +08:00
rc.Redirect(http.StatusFound, client.AuthCodeURL(state))
2023-07-22 20:51:46 +08:00
}
func (a *Oauth2Login) OAuth2Bind(rc *req.Ctx) {
client, _ := a.getOAuthClient()
state := stringx.Rand(32)
cache.SetStr("oauth2:state:"+state, "bind:"+strconv.FormatUint(rc.GetLoginAccount().Id, 10),
2023-07-22 20:51:46 +08:00
5*time.Minute)
2024-02-25 12:46:18 +08:00
rc.Redirect(http.StatusFound, client.AuthCodeURL(state))
2023-07-22 20:51:46 +08:00
}
func (a *Oauth2Login) OAuth2Callback(rc *req.Ctx) {
client, oauth := a.getOAuthClient()
2024-02-25 12:46:18 +08:00
code := rc.Query("code")
2024-11-20 22:43:53 +08:00
biz.NotEmpty(code, "code cannot be empty")
2023-07-22 20:51:46 +08:00
2024-02-25 12:46:18 +08:00
state := rc.Query("state")
2024-11-20 22:43:53 +08:00
biz.NotEmpty(state, "state canot be empty")
2023-07-22 20:51:46 +08:00
stateAction := cache.GetStr("oauth2:state:" + state)
biz.NotEmpty(stateAction, "state已过期, 请重新登录")
token, err := client.Exchange(rc, code)
2024-11-20 22:43:53 +08:00
biz.ErrIsNilAppendErr(err, "get OAuth2 accessToken fail: %s")
2023-07-22 20:51:46 +08:00
// 获取用户信息
2024-02-25 12:46:18 +08:00
httpCli := client.Client(rc.GetRequest().Context(), token)
2023-07-22 20:51:46 +08:00
resp, err := httpCli.Get(oauth.ResourceURL)
2024-11-20 22:43:53 +08:00
biz.ErrIsNilAppendErr(err, "get user info error: %s")
2023-07-22 20:51:46 +08:00
defer resp.Body.Close()
b, err := io.ReadAll(resp.Body)
2024-11-20 22:43:53 +08:00
biz.ErrIsNilAppendErr(err, "failed to read the response user information: %s")
2023-07-22 20:51:46 +08:00
// UserIdentifier格式为 type:fieldPath。如string:user.username 或 number:user.id
userIdTypeAndFieldPath := strings.Split(oauth.UserIdentifier, ":")
2024-11-20 22:43:53 +08:00
biz.IsTrue(len(userIdTypeAndFieldPath) == 2, "oauth2 configuration property 'UserIdentifier' is not compliant")
2023-07-22 20:51:46 +08:00
// 解析用户唯一标识
userIdFieldPath := userIdTypeAndFieldPath[1]
userId := ""
if userIdTypeAndFieldPath[0] == "string" {
userId, err = jsonx.GetStringByBytes(b, userIdFieldPath)
2024-11-20 22:43:53 +08:00
biz.ErrIsNilAppendErr(err, "failed to resolve the user unique identity: %s")
2023-07-22 20:51:46 +08:00
} else {
intUserId, err := jsonx.GetIntByBytes(b, userIdFieldPath)
2024-11-20 22:43:53 +08:00
biz.ErrIsNilAppendErr(err, "failed to resolve the user unique identity: %s")
2023-07-22 20:51:46 +08:00
userId = fmt.Sprintf("%d", intUserId)
}
2024-11-20 22:43:53 +08:00
biz.NotBlank(userId, "the user unique identification field value cannot be null")
2023-07-22 20:51:46 +08:00
// 判断是登录还是绑定
if stateAction == "login" {
a.doLoginAction(rc, userId, oauth)
} else if sAccountId, ok := strings.CutPrefix(stateAction, "bind:"); ok {
// 绑定
accountId, err := strconv.ParseUint(sAccountId, 10, 64)
2024-11-20 22:43:53 +08:00
biz.ErrIsNilAppendErr(err, "failed to bind user: %s")
2023-07-22 20:51:46 +08:00
account := new(sysentity.Account)
account.Id = accountId
err = a.accountApp.GetByCond(model.NewModelCond(account).Columns("username"))
2024-11-20 22:43:53 +08:00
biz.ErrIsNilAppendErr(err, "this account does not exist")
2023-10-12 12:14:56 +08:00
rc.ReqParam = collx.Kvs("username", account.Username, "type", "bind")
2023-07-22 20:51:46 +08:00
err = a.oauth2App.GetOAuthAccount(&entity.Oauth2Account{
2023-07-22 20:51:46 +08:00
AccountId: accountId,
}, "account_id", "identity")
2024-11-20 22:43:53 +08:00
biz.IsTrue(err != nil, "the account has been linked by another user")
2023-07-24 22:36:07 +08:00
err = a.oauth2App.GetOAuthAccount(&entity.Oauth2Account{
2023-07-24 22:36:07 +08:00
Identity: userId,
}, "account_id", "identity")
2024-11-20 22:43:53 +08:00
biz.IsTrue(err != nil, "you are bound to another account")
2023-07-22 20:51:46 +08:00
now := time.Now()
err = a.oauth2App.BindOAuthAccount(&entity.Oauth2Account{
2023-07-22 20:51:46 +08:00
AccountId: accountId,
Identity: userId,
CreateTime: &now,
UpdateTime: &now,
})
2024-11-20 22:43:53 +08:00
biz.ErrIsNilAppendErr(err, "failed to bind user: %s")
2023-10-12 12:14:56 +08:00
res := collx.M{
2023-07-22 20:51:46 +08:00
"action": "oauthBind",
"bind": true,
}
2023-07-24 22:36:07 +08:00
rc.ResData = res
2023-07-22 20:51:46 +08:00
} else {
2024-11-20 22:43:53 +08:00
panic(errorx.NewBiz("state is invalid"))
2023-07-22 20:51:46 +08:00
}
}
// 指定登录操作
func (a *Oauth2Login) doLoginAction(rc *req.Ctx, userId string, oauth *config.Oauth2Login) {
2023-07-22 20:51:46 +08:00
// 查询用户是否存在
oauthAccount := &entity.Oauth2Account{Identity: userId}
err := a.oauth2App.GetOAuthAccount(oauthAccount, "account_id", "identity")
2024-11-20 22:43:53 +08:00
ctx := rc.MetaCtx
2023-07-22 20:51:46 +08:00
var accountId uint64
2023-07-24 22:36:07 +08:00
isFirst := false
2023-07-22 20:51:46 +08:00
// 不存在,进行注册
if err != nil {
2024-11-20 22:43:53 +08:00
biz.IsTrueI(ctx, oauth.AutoRegister, imsg.ErrOauth2NoAutoRegister)
2023-07-22 20:51:46 +08:00
now := time.Now()
account := &sysentity.Account{
Model: model.Model{
CreateModel: model.CreateModel{
CreateTime: &now,
CreatorId: 0,
Creator: "oauth2",
},
2023-07-22 20:51:46 +08:00
UpdateTime: &now,
},
Name: userId,
Username: userId,
}
biz.ErrIsNil(a.accountApp.Create(context.TODO(), account))
2023-07-22 20:51:46 +08:00
// 绑定
err := a.oauth2App.BindOAuthAccount(&entity.Oauth2Account{
2023-07-22 20:51:46 +08:00
AccountId: account.Id,
Identity: oauthAccount.Identity,
CreateTime: &now,
UpdateTime: &now,
})
2024-11-20 22:43:53 +08:00
biz.ErrIsNilAppendErr(err, "failed to bind user: %s")
2023-07-22 20:51:46 +08:00
accountId = account.Id
2023-07-24 22:36:07 +08:00
isFirst = true
2023-07-22 20:51:46 +08:00
} else {
accountId = oauthAccount.AccountId
}
// 进行登录
account, err := a.accountApp.GetById(accountId, "Id", "Name", "Username", "Password", "Status", "LastLoginTime", "LastLoginIp", "OtpSecret")
2024-11-20 22:43:53 +08:00
biz.ErrIsNilAppendErr(err, "get user info error: %s")
2023-07-22 20:51:46 +08:00
2023-07-24 22:36:07 +08:00
clientIp := getIpAndRegion(rc)
2023-10-12 12:14:56 +08:00
rc.ReqParam = collx.Kvs("username", account.Username, "ip", clientIp, "type", "login")
2023-07-24 22:36:07 +08:00
2024-11-20 22:43:53 +08:00
res := LastLoginCheck(ctx, account, config.GetAccountLoginSecurity(), clientIp)
2023-07-22 20:51:46 +08:00
res["action"] = "oauthLogin"
2023-07-24 22:36:07 +08:00
res["isFirstOauth2Login"] = isFirst
rc.ResData = res
2023-07-22 20:51:46 +08:00
}
func (a *Oauth2Login) getOAuthClient() (*oauth2.Config, *config.Oauth2Login) {
oath2LoginConfig := config.GetOauth2Login()
2024-11-20 22:43:53 +08:00
biz.IsTrue(oath2LoginConfig.Enable, "please configure oauth2 or enable oauth2 login first")
biz.IsTrue(oath2LoginConfig.ClientId != "", "oauth2 clientId cannot be empty")
2023-07-22 20:51:46 +08:00
client := &oauth2.Config{
ClientID: oath2LoginConfig.ClientId,
ClientSecret: oath2LoginConfig.ClientSecret,
Endpoint: oauth2.Endpoint{
AuthURL: oath2LoginConfig.AuthorizationURL,
TokenURL: oath2LoginConfig.AccessTokenURL,
},
2023-07-24 22:36:07 +08:00
RedirectURL: oath2LoginConfig.RedirectURL + "/#/oauth2/callback",
2023-07-22 20:51:46 +08:00
Scopes: strings.Split(oath2LoginConfig.Scopes, ","),
}
return client, oath2LoginConfig
}
func (a *Oauth2Login) Oauth2Status(ctx *req.Ctx) {
res := &vo.Oauth2Status{}
oauth2LoginConfig := config.GetOauth2Login()
2023-07-22 20:51:46 +08:00
res.Enable = oauth2LoginConfig.Enable
if res.Enable {
err := a.oauth2App.GetOAuthAccount(&entity.Oauth2Account{
AccountId: ctx.GetLoginAccount().Id,
2023-07-22 20:51:46 +08:00
}, "account_id", "identity")
res.Bind = err == nil
}
ctx.ResData = res
}
2023-07-24 22:36:07 +08:00
func (a *Oauth2Login) Oauth2Unbind(rc *req.Ctx) {
a.oauth2App.Unbind(rc.GetLoginAccount().Id)
2023-07-24 22:36:07 +08:00
}
// 获取oauth2登录配置信息因为有些字段是敏感字段故单独使用接口获取
func (c *Oauth2Login) Oauth2Config(rc *req.Ctx) {
oauth2LoginConfig := config.GetOauth2Login()
2023-10-12 12:14:56 +08:00
rc.ResData = collx.M{
"enable": oauth2LoginConfig.Enable,
"name": oauth2LoginConfig.Name,
}
}