diff --git a/internal/db/models/login_session_dao.go b/internal/db/models/login_session_dao.go index 3e469d1e..208c2c5d 100644 --- a/internal/db/models/login_session_dao.go +++ b/internal/db/models/login_session_dao.go @@ -84,7 +84,6 @@ func (this *LoginSessionDAO) WriteSessionValue(tx *dbs.Tx, sid string, key strin return err } var sessionId int64 - var isNewSession = false var valueMap = maps.Map{} if sessionOne != nil { var session = sessionOne.(*LoginSession) @@ -113,7 +112,6 @@ func (this *LoginSessionDAO) WriteSessionValue(tx *dbs.Tx, sid string, key strin if err != nil { return err } - isNewSession = true } var sessionOp = NewLoginSessionOperator() @@ -133,24 +131,17 @@ func (this *LoginSessionDAO) WriteSessionValue(tx *dbs.Tx, sid string, key strin if adminId > 0 || userId > 0 { sessionOp.AdminId = adminId sessionOp.UserId = userId - - if isNewSession { - // 删除此用户之前创建的SESSION,不再保存以往的SESSION,避免安全问题 - err = this.Query(tx). - ResultPk(). - Attr("adminId", adminId). - Attr("userId", userId). - Neq("sid", sid). - DeleteQuickly() - if err != nil { - return err - } - } } // 写入数据 valueMap[key] = value sessionOp.Values = valueMap.AsJSON() + + // IP + if key == "@ip" { + sessionOp.Ip = value + } + return this.Save(tx, sessionOp) } @@ -182,3 +173,45 @@ func (this *LoginSessionDAO) FindSession(tx *dbs.Tx, sid string) (*LoginSession, } return session, nil } + +func (this *LoginSessionDAO) ClearOldSessions(tx *dbs.Tx, adminId int64, userId int64, sid string, ip string) error { + // 删除此用户之前创建的SESSION + err := this.Query(tx). + Attr("adminId", adminId). + Attr("userId", userId). + Neq("sid", sid). + Neq("ip", ip). // 同一个IP允许多个SID,因为有人可能会同时使用手机端和PC端 + DeleteQuickly() + if err != nil { + return err + } + + // 删除过多的SESSION + oldOnes, queryErr := this.Query(tx). + ResultPk(). + Attr("adminId", adminId). + Attr("userId", userId). + Neq("sid", sid). + AscPk(). + FindAll() + if queryErr != nil { + return queryErr + } + var oldCount = len(oldOnes) + if oldCount > 3 { + for _, oldOne := range oldOnes[:oldCount-3] { + var oldId = oldOne.(*LoginSession).Id + if oldOne.(*LoginSession).Sid == sid { + continue + } + err = this.Query(tx). + Pk(oldId). + DeleteQuickly() + if err != nil { + return err + } + } + } + + return nil +} diff --git a/internal/rpc/services/service_login_session.go b/internal/rpc/services/service_login_session.go index 0192a44d..49595550 100644 --- a/internal/rpc/services/service_login_session.go +++ b/internal/rpc/services/service_login_session.go @@ -84,3 +84,31 @@ func (this *LoginSessionService) FindLoginSession(ctx context.Context, req *pb.F }, }, nil } + +// ClearOldLoginSessions 清理老的SESSION +func (this *LoginSessionService) ClearOldLoginSessions(ctx context.Context, req *pb.ClearOldLoginSessionsRequest) (*pb.RPCSuccess, error) { + _, _, err := this.ValidateAdminAndUser(ctx, false) + if err != nil { + return nil, err + } + + if len(req.Sid) == 0 { + return nil, errors.New("'token' should not be empty") + } + + var tx = this.NullTx() + session, err := models.SharedLoginSessionDAO.FindSession(tx, req.Sid) + if err != nil { + return nil, err + } + if session == nil || !session.IsAvailable() { + return nil, errors.New("invalid sid") + } + + err = models.SharedLoginSessionDAO.ClearOldSessions(tx, int64(session.AdminId), int64(session.UserId), req.Sid, req.Ip) + if err != nil { + return nil, err + } + + return this.Success() +}