升级IP名单权限判断逻辑

This commit is contained in:
GoEdgeLab
2022-06-15 19:22:33 +08:00
parent 6ce7119c14
commit f5665e8e36
9 changed files with 150 additions and 21 deletions

View File

@@ -26,5 +26,5 @@ const (
ReportNodeVersion = "0.1.0" ReportNodeVersion = "0.1.0"
// SQLVersion SQL版本号 // SQLVersion SQL版本号
SQLVersion = "10" SQLVersion = "11"
) )

View File

@@ -558,6 +558,7 @@ func (this *HTTPFirewallPolicyDAO) CheckUserFirewallPolicy(tx *dbs.Tx, userId in
} }
// FindEnabledFirewallPolicyIdsWithIPListId 查找包含某个IPList的所有策略 // FindEnabledFirewallPolicyIdsWithIPListId 查找包含某个IPList的所有策略
// TODO 改成通过 serverId 查询
func (this *HTTPFirewallPolicyDAO) FindEnabledFirewallPolicyIdsWithIPListId(tx *dbs.Tx, ipListId int64) ([]int64, error) { func (this *HTTPFirewallPolicyDAO) FindEnabledFirewallPolicyIdsWithIPListId(tx *dbs.Tx, ipListId int64) ([]int64, error) {
ones, err := this.Query(tx). ones, err := this.Query(tx).
ResultPk(). ResultPk().
@@ -576,6 +577,7 @@ func (this *HTTPFirewallPolicyDAO) FindEnabledFirewallPolicyIdsWithIPListId(tx *
} }
// FindEnabledFirewallPolicyWithIPListId 查找使用某个IPList的策略 // FindEnabledFirewallPolicyWithIPListId 查找使用某个IPList的策略
// TODO 改成通过 serverId 查询
func (this *HTTPFirewallPolicyDAO) FindEnabledFirewallPolicyWithIPListId(tx *dbs.Tx, ipListId int64) (*HTTPFirewallPolicy, error) { func (this *HTTPFirewallPolicyDAO) FindEnabledFirewallPolicyWithIPListId(tx *dbs.Tx, ipListId int64) (*HTTPFirewallPolicy, error) {
one, err := this.Query(tx). one, err := this.Query(tx).
State(HTTPFirewallPolicyStateEnabled). State(HTTPFirewallPolicyStateEnabled).

View File

@@ -138,10 +138,11 @@ func (this *IPListDAO) FindIPListCacheable(tx *dbs.Tx, listId int64) (*IPList, e
} }
// CreateIPList 创建名单 // CreateIPList 创建名单
func (this *IPListDAO) CreateIPList(tx *dbs.Tx, userId int64, listType ipconfigs.IPListType, name string, code string, timeoutJSON []byte, description string, isPublic bool, isGlobal bool) (int64, error) { func (this *IPListDAO) CreateIPList(tx *dbs.Tx, userId int64, serverId int64, listType ipconfigs.IPListType, name string, code string, timeoutJSON []byte, description string, isPublic bool, isGlobal bool) (int64, error) {
op := NewIPListOperator() var op = NewIPListOperator()
op.IsOn = true op.IsOn = true
op.UserId = userId op.UserId = userId
op.ServerId = serverId
op.State = IPListStateEnabled op.State = IPListStateEnabled
op.Type = listType op.Type = listType
op.Name = name op.Name = name
@@ -189,26 +190,25 @@ func (this *IPListDAO) CheckUserIPList(tx *dbs.Tx, userId int64, listId int64) e
return ErrNotFound return ErrNotFound
} }
ok, err := this.Query(tx). // 获取名单信息
listOne, err := this.Query(tx).
Pk(listId). Pk(listId).
Attr("userId", userId). Result("userId", "serverId").
Exist() Find()
if err != nil { if err != nil {
return err return err
} }
if ok { if listOne == nil {
return ErrNotFound
}
var list = listOne.(*IPList)
if int64(list.UserId) == userId {
return nil return nil
} }
// 检查是否被用户的服务所使用 var serverId = int64(list.ServerId)
policyIds, err := SharedHTTPFirewallPolicyDAO.FindEnabledFirewallPolicyIdsWithIPListId(tx, listId) if serverId > 0 {
if err != nil { return SharedServerDAO.CheckUserServer(tx, userId, serverId)
return err
}
for _, policyId := range policyIds {
if SharedHTTPFirewallPolicyDAO.CheckUserFirewallPolicy(tx, userId, policyId) == nil {
return nil
}
} }
return ErrNotFound return ErrNotFound

View File

@@ -20,6 +20,39 @@ func TestIPListDAO_IncreaseVersion(t *testing.T) {
t.Log("version:", version) t.Log("version:", version)
} }
func TestIPListDAO_CheckUserIPList(t *testing.T) {
dbs.NotifyReady()
var tx *dbs.Tx
{
err := NewIPListDAO().CheckUserIPList(tx, 1, 100)
if err == ErrNotFound {
t.Log("not found")
} else {
t.Log(err)
}
}
{
err := NewIPListDAO().CheckUserIPList(tx, 1, 85)
if err == ErrNotFound {
t.Log("not found")
} else {
t.Log(err)
}
}
{
err := NewIPListDAO().CheckUserIPList(tx, 1, 17)
if err == ErrNotFound {
t.Log("not found")
} else {
t.Log(err)
}
}
}
func BenchmarkIPListDAO_IncreaseVersion(b *testing.B) { func BenchmarkIPListDAO_IncreaseVersion(b *testing.B) {
runtime.GOMAXPROCS(1) runtime.GOMAXPROCS(1)
@@ -32,3 +65,4 @@ func BenchmarkIPListDAO_IncreaseVersion(b *testing.B) {
_, _ = dao.IncreaseVersion(tx) _, _ = dao.IncreaseVersion(tx)
} }
} }

View File

@@ -9,6 +9,7 @@ type IPList struct {
Type string `field:"type"` // 类型 Type string `field:"type"` // 类型
AdminId uint32 `field:"adminId"` // 用户ID AdminId uint32 `field:"adminId"` // 用户ID
UserId uint32 `field:"userId"` // 用户ID UserId uint32 `field:"userId"` // 用户ID
ServerId uint64 `field:"serverId"` // 服务ID
Name string `field:"name"` // 列表名 Name string `field:"name"` // 列表名
Code string `field:"code"` // 代号 Code string `field:"code"` // 代号
State uint8 `field:"state"` // 状态 State uint8 `field:"state"` // 状态
@@ -26,6 +27,7 @@ type IPListOperator struct {
Type interface{} // 类型 Type interface{} // 类型
AdminId interface{} // 用户ID AdminId interface{} // 用户ID
UserId interface{} // 用户ID UserId interface{} // 用户ID
ServerId interface{} // 服务ID
Name interface{} // 列表名 Name interface{} // 列表名
Code interface{} // 代号 Code interface{} // 代号
State interface{} // 状态 State interface{} // 状态

View File

@@ -21,9 +21,20 @@ func (this *IPListService) CreateIPList(ctx context.Context, req *pb.CreateIPLis
return nil, err return nil, err
} }
tx := this.NullTx() var tx = this.NullTx()
listId, err := models.SharedIPListDAO.CreateIPList(tx, userId, req.Type, req.Name, req.Code, req.TimeoutJSON, req.Description, req.IsPublic, req.IsGlobal) // 检查用户相关信息
if userId > 0 {
// 检查服务ID
if req.ServerId > 0 {
err = models.SharedServerDAO.CheckUserServer(tx, userId, req.ServerId)
if err != nil {
return nil, err
}
}
}
listId, err := models.SharedIPListDAO.CreateIPList(tx, userId, req.ServerId, req.Type, req.Name, req.Code, req.TimeoutJSON, req.Description, req.IsPublic, req.IsGlobal)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -50,12 +61,18 @@ func (this *IPListService) UpdateIPList(ctx context.Context, req *pb.UpdateIPLis
// FindEnabledIPList 查找IP列表 // FindEnabledIPList 查找IP列表
func (this *IPListService) FindEnabledIPList(ctx context.Context, req *pb.FindEnabledIPListRequest) (*pb.FindEnabledIPListResponse, error) { func (this *IPListService) FindEnabledIPList(ctx context.Context, req *pb.FindEnabledIPListRequest) (*pb.FindEnabledIPListResponse, error) {
// 校验请求 // 校验请求
_, err := this.ValidateAdmin(ctx, 0) _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }
tx := this.NullTx() var tx = this.NullTx()
if userId > 0 {
err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId)
if err != nil {
return nil, err
}
}
list, err := models.SharedIPListDAO.FindEnabledIPList(tx, req.IpListId, nil) list, err := models.SharedIPListDAO.FindEnabledIPList(tx, req.IpListId, nil)
if err != nil { if err != nil {

File diff suppressed because one or more lines are too long

View File

@@ -75,6 +75,9 @@ var upgradeFuncs = []*upgradeVersion{
{ {
"0.4.7", upgradeV0_4_7, "0.4.7", upgradeV0_4_7,
}, },
{
"0.4.8", upgradeV0_4_8,
},
} }
// UpgradeSQLData 升级SQL数据 // UpgradeSQLData 升级SQL数据
@@ -672,3 +675,55 @@ func upgradeV0_4_7(db *dbs.DB) error {
return nil return nil
} }
// v0.4.7
func upgradeV0_4_8(db *dbs.DB) error {
// 设置edgeIPLists中的serverId
{
firewallPolicyOnes, _, err := db.FindOnes("SELECT inbound,serverId FROM edgeHTTPFirewallPolicies WHERE serverId>0")
if err != nil {
return err
}
for _, one := range firewallPolicyOnes {
var inboundBytes = one.GetBytes("inbound")
var serverId = one.GetInt64("serverId")
var listIds = []int64{}
if len(inboundBytes) > 0 {
var inbound = &firewallconfigs.HTTPFirewallInboundConfig{}
err = json.Unmarshal(inboundBytes, inbound)
if err == nil { // we ignore errors
if inbound.AllowListRef != nil && inbound.AllowListRef.ListId > 0 {
listIds = append(listIds, inbound.AllowListRef.ListId)
}
if inbound.DenyListRef != nil && inbound.DenyListRef.ListId > 0 {
listIds = append(listIds, inbound.DenyListRef.ListId)
}
if inbound.GreyListRef != nil && inbound.GreyListRef.ListId > 0 {
listIds = append(listIds, inbound.GreyListRef.ListId)
}
}
}
if len(listIds) == 0 {
continue
}
for _, listId := range listIds {
isPublicCol, err := db.FindCol(0, "SELECT isPublic FROM edgeIPLists WHERE id=? LIMIT 1", listId)
if err != nil {
return err
}
var isPublic = types.Bool(isPublicCol)
if !isPublic {
_, err = db.Exec("UPDATE edgeIPLists SET serverId=? WHERE id=?", serverId, listId)
if err != nil {
return err
}
}
}
}
}
return nil
}

View File

@@ -175,3 +175,22 @@ func TestUpgradeSQLData_v0_4_7(t *testing.T) {
} }
t.Log("ok") t.Log("ok")
} }
func TestUpgradeSQLData_v0_4_8(t *testing.T) {
db, err := dbs.NewInstanceFromConfig(&dbs.DBConfig{
Driver: "mysql",
Dsn: "root:123456@tcp(127.0.0.1:3306)/db_edge?charset=utf8mb4&timeout=30s",
Prefix: "edge",
})
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
err = upgradeV0_4_8(db)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}