diff --git a/internal/db/models/user_dao.go b/internal/db/models/user_dao.go index 5b1efa40..8153d091 100644 --- a/internal/db/models/user_dao.go +++ b/internal/db/models/user_dao.go @@ -10,6 +10,7 @@ import ( _ "github.com/go-sql-driver/mysql" "github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/dbs" + "github.com/iwind/TeaGo/lists" "github.com/iwind/TeaGo/types" stringutil "github.com/iwind/TeaGo/utils/string" timeutil "github.com/iwind/TeaGo/utils/time" @@ -352,7 +353,7 @@ func (this *UserDAO) FindUserClusterId(tx *dbs.Tx, userId int64) (int64, error) FindInt64Col(0) } -// UpdateUserFeatures 更新用户Features +// UpdateUserFeatures 更新单个用户Features func (this *UserDAO) UpdateUserFeatures(tx *dbs.Tx, userId int64, featuresJSON []byte) error { if userId <= 0 { return errors.New("invalid userId") @@ -370,6 +371,74 @@ func (this *UserDAO) UpdateUserFeatures(tx *dbs.Tx, userId int64, featuresJSON [ return nil } +// UpdateUsersFeatures 更新所有用户的Features +func (this *UserDAO) UpdateUsersFeatures(tx *dbs.Tx, featureCodes []string, overwrite bool) error { + if featureCodes == nil { + featureCodes = []string{} + } + if overwrite { + featureCodesJSON, err := json.Marshal(featureCodes) + if err != nil { + return err + } + err = this.Query(tx). + State(UserStateEnabled). + Set("features", featureCodesJSON). + UpdateQuickly() + return err + } + + var lastId int64 + const size = 1000 + for { + ones, _, err := this.Query(tx). + Result("id", "features"). + State(UserStateEnabled). + Gt("id", lastId). + Limit(size). + AscPk(). + FindOnes() + if err != nil { + return err + } + for _, one := range ones { + var userId = one.GetInt64("id") + var userFeaturesJSON = one.GetBytes("features") + var userFeatures = []string{} + if len(userFeaturesJSON) > 0 { + err = json.Unmarshal(userFeaturesJSON, &userFeatures) + if err != nil { + return err + } + } + for _, featureCode := range featureCodes { + if !lists.ContainsString(userFeatures, featureCode) { + userFeatures = append(userFeatures, featureCode) + } + } + userFeaturesJSON, err = json.Marshal(userFeatures) + if err != nil { + return err + } + err = this.Query(tx). + Pk(userId). + Set("features", userFeaturesJSON). + UpdateQuickly() + if err != nil { + return err + } + } + + if len(ones) < size { + break + } + + lastId += size + } + + return nil +} + // FindUserFeatures 查找用户Features func (this *UserDAO) FindUserFeatures(tx *dbs.Tx, userId int64) ([]*userconfigs.UserFeature, error) { featuresJSON, err := this.Query(tx). diff --git a/internal/db/models/user_dao_test.go b/internal/db/models/user_dao_test.go index 97c24b56..95838d5a 100644 --- a/internal/db/models/user_dao_test.go +++ b/internal/db/models/user_dao_test.go @@ -1,5 +1,21 @@ package models import ( + "github.com/TeaOSLab/EdgeCommon/pkg/userconfigs" _ "github.com/go-sql-driver/mysql" + _ "github.com/iwind/TeaGo/bootstrap" + "github.com/iwind/TeaGo/dbs" + "testing" ) + +func TestUserDAO_UpdateUserFeatures(t *testing.T) { + var dao = NewUserDAO() + var tx *dbs.Tx + err := dao.UpdateUsersFeatures(tx, []string{ + userconfigs.UserFeatureCodeFinance, + userconfigs.UserFeatureCodeServerACME, + }, false) + if err != nil { + t.Fatal(err) + } +} diff --git a/internal/rpc/services/service_user.go b/internal/rpc/services/service_user.go index d2e0ba72..a0be4848 100644 --- a/internal/rpc/services/service_user.go +++ b/internal/rpc/services/service_user.go @@ -557,6 +557,22 @@ func (this *UserService) UpdateUserFeatures(ctx context.Context, req *pb.UpdateU return this.Success() } +// UpdateAllUsersFeatures 设置所有用户能使用的功能 +func (this *UserService) UpdateAllUsersFeatures(ctx context.Context, req *pb.UpdateAllUsersFeaturesRequest) (*pb.RPCSuccess, error) { + _, err := this.ValidateAdmin(ctx) + if err != nil { + return nil, err + } + + var tx = this.NullTx() + err = models.SharedUserDAO.UpdateUsersFeatures(tx, req.FeatureCodes, req.Overwrite) + if err != nil { + return nil, err + } + + return this.Success() +} + // FindUserFeatures 获取用户所有的功能列表 func (this *UserService) FindUserFeatures(ctx context.Context, req *pb.FindUserFeaturesRequest) (*pb.FindUserFeaturesResponse, error) { _, userId, err := this.ValidateAdminAndUser(ctx)