mirror of
https://github.com/TeaOSLab/EdgeAPI.git
synced 2025-11-05 17:40:24 +08:00
升级IP名单权限判断逻辑
This commit is contained in:
@@ -26,5 +26,5 @@ const (
|
||||
ReportNodeVersion = "0.1.0"
|
||||
|
||||
// SQLVersion SQL版本号
|
||||
SQLVersion = "10"
|
||||
SQLVersion = "11"
|
||||
)
|
||||
|
||||
@@ -558,6 +558,7 @@ func (this *HTTPFirewallPolicyDAO) CheckUserFirewallPolicy(tx *dbs.Tx, userId in
|
||||
}
|
||||
|
||||
// FindEnabledFirewallPolicyIdsWithIPListId 查找包含某个IPList的所有策略
|
||||
// TODO 改成通过 serverId 查询
|
||||
func (this *HTTPFirewallPolicyDAO) FindEnabledFirewallPolicyIdsWithIPListId(tx *dbs.Tx, ipListId int64) ([]int64, error) {
|
||||
ones, err := this.Query(tx).
|
||||
ResultPk().
|
||||
@@ -576,6 +577,7 @@ func (this *HTTPFirewallPolicyDAO) FindEnabledFirewallPolicyIdsWithIPListId(tx *
|
||||
}
|
||||
|
||||
// FindEnabledFirewallPolicyWithIPListId 查找使用某个IPList的策略
|
||||
// TODO 改成通过 serverId 查询
|
||||
func (this *HTTPFirewallPolicyDAO) FindEnabledFirewallPolicyWithIPListId(tx *dbs.Tx, ipListId int64) (*HTTPFirewallPolicy, error) {
|
||||
one, err := this.Query(tx).
|
||||
State(HTTPFirewallPolicyStateEnabled).
|
||||
|
||||
@@ -138,10 +138,11 @@ func (this *IPListDAO) FindIPListCacheable(tx *dbs.Tx, listId int64) (*IPList, e
|
||||
}
|
||||
|
||||
// CreateIPList 创建名单
|
||||
func (this *IPListDAO) CreateIPList(tx *dbs.Tx, userId int64, listType ipconfigs.IPListType, name string, code string, timeoutJSON []byte, description string, isPublic bool, isGlobal bool) (int64, error) {
|
||||
op := NewIPListOperator()
|
||||
func (this *IPListDAO) CreateIPList(tx *dbs.Tx, userId int64, serverId int64, listType ipconfigs.IPListType, name string, code string, timeoutJSON []byte, description string, isPublic bool, isGlobal bool) (int64, error) {
|
||||
var op = NewIPListOperator()
|
||||
op.IsOn = true
|
||||
op.UserId = userId
|
||||
op.ServerId = serverId
|
||||
op.State = IPListStateEnabled
|
||||
op.Type = listType
|
||||
op.Name = name
|
||||
@@ -189,26 +190,25 @@ func (this *IPListDAO) CheckUserIPList(tx *dbs.Tx, userId int64, listId int64) e
|
||||
return ErrNotFound
|
||||
}
|
||||
|
||||
ok, err := this.Query(tx).
|
||||
// 获取名单信息
|
||||
listOne, err := this.Query(tx).
|
||||
Pk(listId).
|
||||
Attr("userId", userId).
|
||||
Exist()
|
||||
Result("userId", "serverId").
|
||||
Find()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ok {
|
||||
if listOne == nil {
|
||||
return ErrNotFound
|
||||
}
|
||||
var list = listOne.(*IPList)
|
||||
if int64(list.UserId) == userId {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查是否被用户的服务所使用
|
||||
policyIds, err := SharedHTTPFirewallPolicyDAO.FindEnabledFirewallPolicyIdsWithIPListId(tx, listId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, policyId := range policyIds {
|
||||
if SharedHTTPFirewallPolicyDAO.CheckUserFirewallPolicy(tx, userId, policyId) == nil {
|
||||
return nil
|
||||
}
|
||||
var serverId = int64(list.ServerId)
|
||||
if serverId > 0 {
|
||||
return SharedServerDAO.CheckUserServer(tx, userId, serverId)
|
||||
}
|
||||
|
||||
return ErrNotFound
|
||||
|
||||
@@ -20,6 +20,39 @@ func TestIPListDAO_IncreaseVersion(t *testing.T) {
|
||||
t.Log("version:", version)
|
||||
}
|
||||
|
||||
func TestIPListDAO_CheckUserIPList(t *testing.T) {
|
||||
dbs.NotifyReady()
|
||||
|
||||
var tx *dbs.Tx
|
||||
|
||||
{
|
||||
err := NewIPListDAO().CheckUserIPList(tx, 1, 100)
|
||||
if err == ErrNotFound {
|
||||
t.Log("not found")
|
||||
} else {
|
||||
t.Log(err)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
err := NewIPListDAO().CheckUserIPList(tx, 1, 85)
|
||||
if err == ErrNotFound {
|
||||
t.Log("not found")
|
||||
} else {
|
||||
t.Log(err)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
err := NewIPListDAO().CheckUserIPList(tx, 1, 17)
|
||||
if err == ErrNotFound {
|
||||
t.Log("not found")
|
||||
} else {
|
||||
t.Log(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkIPListDAO_IncreaseVersion(b *testing.B) {
|
||||
runtime.GOMAXPROCS(1)
|
||||
|
||||
@@ -32,3 +65,4 @@ func BenchmarkIPListDAO_IncreaseVersion(b *testing.B) {
|
||||
_, _ = dao.IncreaseVersion(tx)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ type IPList struct {
|
||||
Type string `field:"type"` // 类型
|
||||
AdminId uint32 `field:"adminId"` // 用户ID
|
||||
UserId uint32 `field:"userId"` // 用户ID
|
||||
ServerId uint64 `field:"serverId"` // 服务ID
|
||||
Name string `field:"name"` // 列表名
|
||||
Code string `field:"code"` // 代号
|
||||
State uint8 `field:"state"` // 状态
|
||||
@@ -26,6 +27,7 @@ type IPListOperator struct {
|
||||
Type interface{} // 类型
|
||||
AdminId interface{} // 用户ID
|
||||
UserId interface{} // 用户ID
|
||||
ServerId interface{} // 服务ID
|
||||
Name interface{} // 列表名
|
||||
Code interface{} // 代号
|
||||
State interface{} // 状态
|
||||
|
||||
@@ -21,9 +21,20 @@ func (this *IPListService) CreateIPList(ctx context.Context, req *pb.CreateIPLis
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tx := this.NullTx()
|
||||
var tx = this.NullTx()
|
||||
|
||||
listId, err := models.SharedIPListDAO.CreateIPList(tx, userId, req.Type, req.Name, req.Code, req.TimeoutJSON, req.Description, req.IsPublic, req.IsGlobal)
|
||||
// 检查用户相关信息
|
||||
if userId > 0 {
|
||||
// 检查服务ID
|
||||
if req.ServerId > 0 {
|
||||
err = models.SharedServerDAO.CheckUserServer(tx, userId, req.ServerId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
listId, err := models.SharedIPListDAO.CreateIPList(tx, userId, req.ServerId, req.Type, req.Name, req.Code, req.TimeoutJSON, req.Description, req.IsPublic, req.IsGlobal)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -50,12 +61,18 @@ func (this *IPListService) UpdateIPList(ctx context.Context, req *pb.UpdateIPLis
|
||||
// FindEnabledIPList 查找IP列表
|
||||
func (this *IPListService) FindEnabledIPList(ctx context.Context, req *pb.FindEnabledIPListRequest) (*pb.FindEnabledIPListResponse, error) {
|
||||
// 校验请求
|
||||
_, err := this.ValidateAdmin(ctx, 0)
|
||||
_, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tx := this.NullTx()
|
||||
var tx = this.NullTx()
|
||||
if userId > 0 {
|
||||
err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
list, err := models.SharedIPListDAO.FindEnabledIPList(tx, req.IpListId, nil)
|
||||
if err != nil {
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -75,6 +75,9 @@ var upgradeFuncs = []*upgradeVersion{
|
||||
{
|
||||
"0.4.7", upgradeV0_4_7,
|
||||
},
|
||||
{
|
||||
"0.4.8", upgradeV0_4_8,
|
||||
},
|
||||
}
|
||||
|
||||
// UpgradeSQLData 升级SQL数据
|
||||
@@ -672,3 +675,55 @@ func upgradeV0_4_7(db *dbs.DB) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// v0.4.7
|
||||
func upgradeV0_4_8(db *dbs.DB) error {
|
||||
// 设置edgeIPLists中的serverId
|
||||
{
|
||||
firewallPolicyOnes, _, err := db.FindOnes("SELECT inbound,serverId FROM edgeHTTPFirewallPolicies WHERE serverId>0")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, one := range firewallPolicyOnes {
|
||||
var inboundBytes = one.GetBytes("inbound")
|
||||
var serverId = one.GetInt64("serverId")
|
||||
|
||||
var listIds = []int64{}
|
||||
|
||||
if len(inboundBytes) > 0 {
|
||||
var inbound = &firewallconfigs.HTTPFirewallInboundConfig{}
|
||||
err = json.Unmarshal(inboundBytes, inbound)
|
||||
if err == nil { // we ignore errors
|
||||
if inbound.AllowListRef != nil && inbound.AllowListRef.ListId > 0 {
|
||||
listIds = append(listIds, inbound.AllowListRef.ListId)
|
||||
}
|
||||
if inbound.DenyListRef != nil && inbound.DenyListRef.ListId > 0 {
|
||||
listIds = append(listIds, inbound.DenyListRef.ListId)
|
||||
}
|
||||
if inbound.GreyListRef != nil && inbound.GreyListRef.ListId > 0 {
|
||||
listIds = append(listIds, inbound.GreyListRef.ListId)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(listIds) == 0 {
|
||||
continue
|
||||
}
|
||||
for _, listId := range listIds {
|
||||
isPublicCol, err := db.FindCol(0, "SELECT isPublic FROM edgeIPLists WHERE id=? LIMIT 1", listId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var isPublic = types.Bool(isPublicCol)
|
||||
if !isPublic {
|
||||
_, err = db.Exec("UPDATE edgeIPLists SET serverId=? WHERE id=?", serverId, listId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -175,3 +175,22 @@ func TestUpgradeSQLData_v0_4_7(t *testing.T) {
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
func TestUpgradeSQLData_v0_4_8(t *testing.T) {
|
||||
db, err := dbs.NewInstanceFromConfig(&dbs.DBConfig{
|
||||
Driver: "mysql",
|
||||
Dsn: "root:123456@tcp(127.0.0.1:3306)/db_edge?charset=utf8mb4&timeout=30s",
|
||||
Prefix: "edge",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
_ = db.Close()
|
||||
}()
|
||||
err = upgradeV0_4_8(db)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user