mirror of
https://github.com/TeaOSLab/EdgeAPI.git
synced 2025-11-06 01:50:25 +08:00
升级IP名单权限判断逻辑
This commit is contained in:
@@ -26,5 +26,5 @@ const (
|
|||||||
ReportNodeVersion = "0.1.0"
|
ReportNodeVersion = "0.1.0"
|
||||||
|
|
||||||
// SQLVersion SQL版本号
|
// SQLVersion SQL版本号
|
||||||
SQLVersion = "10"
|
SQLVersion = "11"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -558,6 +558,7 @@ func (this *HTTPFirewallPolicyDAO) CheckUserFirewallPolicy(tx *dbs.Tx, userId in
|
|||||||
}
|
}
|
||||||
|
|
||||||
// FindEnabledFirewallPolicyIdsWithIPListId 查找包含某个IPList的所有策略
|
// FindEnabledFirewallPolicyIdsWithIPListId 查找包含某个IPList的所有策略
|
||||||
|
// TODO 改成通过 serverId 查询
|
||||||
func (this *HTTPFirewallPolicyDAO) FindEnabledFirewallPolicyIdsWithIPListId(tx *dbs.Tx, ipListId int64) ([]int64, error) {
|
func (this *HTTPFirewallPolicyDAO) FindEnabledFirewallPolicyIdsWithIPListId(tx *dbs.Tx, ipListId int64) ([]int64, error) {
|
||||||
ones, err := this.Query(tx).
|
ones, err := this.Query(tx).
|
||||||
ResultPk().
|
ResultPk().
|
||||||
@@ -576,6 +577,7 @@ func (this *HTTPFirewallPolicyDAO) FindEnabledFirewallPolicyIdsWithIPListId(tx *
|
|||||||
}
|
}
|
||||||
|
|
||||||
// FindEnabledFirewallPolicyWithIPListId 查找使用某个IPList的策略
|
// FindEnabledFirewallPolicyWithIPListId 查找使用某个IPList的策略
|
||||||
|
// TODO 改成通过 serverId 查询
|
||||||
func (this *HTTPFirewallPolicyDAO) FindEnabledFirewallPolicyWithIPListId(tx *dbs.Tx, ipListId int64) (*HTTPFirewallPolicy, error) {
|
func (this *HTTPFirewallPolicyDAO) FindEnabledFirewallPolicyWithIPListId(tx *dbs.Tx, ipListId int64) (*HTTPFirewallPolicy, error) {
|
||||||
one, err := this.Query(tx).
|
one, err := this.Query(tx).
|
||||||
State(HTTPFirewallPolicyStateEnabled).
|
State(HTTPFirewallPolicyStateEnabled).
|
||||||
|
|||||||
@@ -138,10 +138,11 @@ func (this *IPListDAO) FindIPListCacheable(tx *dbs.Tx, listId int64) (*IPList, e
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CreateIPList 创建名单
|
// CreateIPList 创建名单
|
||||||
func (this *IPListDAO) CreateIPList(tx *dbs.Tx, userId int64, listType ipconfigs.IPListType, name string, code string, timeoutJSON []byte, description string, isPublic bool, isGlobal bool) (int64, error) {
|
func (this *IPListDAO) CreateIPList(tx *dbs.Tx, userId int64, serverId int64, listType ipconfigs.IPListType, name string, code string, timeoutJSON []byte, description string, isPublic bool, isGlobal bool) (int64, error) {
|
||||||
op := NewIPListOperator()
|
var op = NewIPListOperator()
|
||||||
op.IsOn = true
|
op.IsOn = true
|
||||||
op.UserId = userId
|
op.UserId = userId
|
||||||
|
op.ServerId = serverId
|
||||||
op.State = IPListStateEnabled
|
op.State = IPListStateEnabled
|
||||||
op.Type = listType
|
op.Type = listType
|
||||||
op.Name = name
|
op.Name = name
|
||||||
@@ -189,26 +190,25 @@ func (this *IPListDAO) CheckUserIPList(tx *dbs.Tx, userId int64, listId int64) e
|
|||||||
return ErrNotFound
|
return ErrNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
ok, err := this.Query(tx).
|
// 获取名单信息
|
||||||
|
listOne, err := this.Query(tx).
|
||||||
Pk(listId).
|
Pk(listId).
|
||||||
Attr("userId", userId).
|
Result("userId", "serverId").
|
||||||
Exist()
|
Find()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if ok {
|
if listOne == nil {
|
||||||
|
return ErrNotFound
|
||||||
|
}
|
||||||
|
var list = listOne.(*IPList)
|
||||||
|
if int64(list.UserId) == userId {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查是否被用户的服务所使用
|
var serverId = int64(list.ServerId)
|
||||||
policyIds, err := SharedHTTPFirewallPolicyDAO.FindEnabledFirewallPolicyIdsWithIPListId(tx, listId)
|
if serverId > 0 {
|
||||||
if err != nil {
|
return SharedServerDAO.CheckUserServer(tx, userId, serverId)
|
||||||
return err
|
|
||||||
}
|
|
||||||
for _, policyId := range policyIds {
|
|
||||||
if SharedHTTPFirewallPolicyDAO.CheckUserFirewallPolicy(tx, userId, policyId) == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return ErrNotFound
|
return ErrNotFound
|
||||||
|
|||||||
@@ -20,6 +20,39 @@ func TestIPListDAO_IncreaseVersion(t *testing.T) {
|
|||||||
t.Log("version:", version)
|
t.Log("version:", version)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIPListDAO_CheckUserIPList(t *testing.T) {
|
||||||
|
dbs.NotifyReady()
|
||||||
|
|
||||||
|
var tx *dbs.Tx
|
||||||
|
|
||||||
|
{
|
||||||
|
err := NewIPListDAO().CheckUserIPList(tx, 1, 100)
|
||||||
|
if err == ErrNotFound {
|
||||||
|
t.Log("not found")
|
||||||
|
} else {
|
||||||
|
t.Log(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
err := NewIPListDAO().CheckUserIPList(tx, 1, 85)
|
||||||
|
if err == ErrNotFound {
|
||||||
|
t.Log("not found")
|
||||||
|
} else {
|
||||||
|
t.Log(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
err := NewIPListDAO().CheckUserIPList(tx, 1, 17)
|
||||||
|
if err == ErrNotFound {
|
||||||
|
t.Log("not found")
|
||||||
|
} else {
|
||||||
|
t.Log(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkIPListDAO_IncreaseVersion(b *testing.B) {
|
func BenchmarkIPListDAO_IncreaseVersion(b *testing.B) {
|
||||||
runtime.GOMAXPROCS(1)
|
runtime.GOMAXPROCS(1)
|
||||||
|
|
||||||
@@ -32,3 +65,4 @@ func BenchmarkIPListDAO_IncreaseVersion(b *testing.B) {
|
|||||||
_, _ = dao.IncreaseVersion(tx)
|
_, _ = dao.IncreaseVersion(tx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ type IPList struct {
|
|||||||
Type string `field:"type"` // 类型
|
Type string `field:"type"` // 类型
|
||||||
AdminId uint32 `field:"adminId"` // 用户ID
|
AdminId uint32 `field:"adminId"` // 用户ID
|
||||||
UserId uint32 `field:"userId"` // 用户ID
|
UserId uint32 `field:"userId"` // 用户ID
|
||||||
|
ServerId uint64 `field:"serverId"` // 服务ID
|
||||||
Name string `field:"name"` // 列表名
|
Name string `field:"name"` // 列表名
|
||||||
Code string `field:"code"` // 代号
|
Code string `field:"code"` // 代号
|
||||||
State uint8 `field:"state"` // 状态
|
State uint8 `field:"state"` // 状态
|
||||||
@@ -26,6 +27,7 @@ type IPListOperator struct {
|
|||||||
Type interface{} // 类型
|
Type interface{} // 类型
|
||||||
AdminId interface{} // 用户ID
|
AdminId interface{} // 用户ID
|
||||||
UserId interface{} // 用户ID
|
UserId interface{} // 用户ID
|
||||||
|
ServerId interface{} // 服务ID
|
||||||
Name interface{} // 列表名
|
Name interface{} // 列表名
|
||||||
Code interface{} // 代号
|
Code interface{} // 代号
|
||||||
State interface{} // 状态
|
State interface{} // 状态
|
||||||
|
|||||||
@@ -21,9 +21,20 @@ func (this *IPListService) CreateIPList(ctx context.Context, req *pb.CreateIPLis
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
tx := this.NullTx()
|
var tx = this.NullTx()
|
||||||
|
|
||||||
listId, err := models.SharedIPListDAO.CreateIPList(tx, userId, req.Type, req.Name, req.Code, req.TimeoutJSON, req.Description, req.IsPublic, req.IsGlobal)
|
// 检查用户相关信息
|
||||||
|
if userId > 0 {
|
||||||
|
// 检查服务ID
|
||||||
|
if req.ServerId > 0 {
|
||||||
|
err = models.SharedServerDAO.CheckUserServer(tx, userId, req.ServerId)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
listId, err := models.SharedIPListDAO.CreateIPList(tx, userId, req.ServerId, req.Type, req.Name, req.Code, req.TimeoutJSON, req.Description, req.IsPublic, req.IsGlobal)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -50,12 +61,18 @@ func (this *IPListService) UpdateIPList(ctx context.Context, req *pb.UpdateIPLis
|
|||||||
// FindEnabledIPList 查找IP列表
|
// FindEnabledIPList 查找IP列表
|
||||||
func (this *IPListService) FindEnabledIPList(ctx context.Context, req *pb.FindEnabledIPListRequest) (*pb.FindEnabledIPListResponse, error) {
|
func (this *IPListService) FindEnabledIPList(ctx context.Context, req *pb.FindEnabledIPListRequest) (*pb.FindEnabledIPListResponse, error) {
|
||||||
// 校验请求
|
// 校验请求
|
||||||
_, err := this.ValidateAdmin(ctx, 0)
|
_, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
tx := this.NullTx()
|
var tx = this.NullTx()
|
||||||
|
if userId > 0 {
|
||||||
|
err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
list, err := models.SharedIPListDAO.FindEnabledIPList(tx, req.IpListId, nil)
|
list, err := models.SharedIPListDAO.FindEnabledIPList(tx, req.IpListId, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -75,6 +75,9 @@ var upgradeFuncs = []*upgradeVersion{
|
|||||||
{
|
{
|
||||||
"0.4.7", upgradeV0_4_7,
|
"0.4.7", upgradeV0_4_7,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"0.4.8", upgradeV0_4_8,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpgradeSQLData 升级SQL数据
|
// UpgradeSQLData 升级SQL数据
|
||||||
@@ -672,3 +675,55 @@ func upgradeV0_4_7(db *dbs.DB) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// v0.4.7
|
||||||
|
func upgradeV0_4_8(db *dbs.DB) error {
|
||||||
|
// 设置edgeIPLists中的serverId
|
||||||
|
{
|
||||||
|
firewallPolicyOnes, _, err := db.FindOnes("SELECT inbound,serverId FROM edgeHTTPFirewallPolicies WHERE serverId>0")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, one := range firewallPolicyOnes {
|
||||||
|
var inboundBytes = one.GetBytes("inbound")
|
||||||
|
var serverId = one.GetInt64("serverId")
|
||||||
|
|
||||||
|
var listIds = []int64{}
|
||||||
|
|
||||||
|
if len(inboundBytes) > 0 {
|
||||||
|
var inbound = &firewallconfigs.HTTPFirewallInboundConfig{}
|
||||||
|
err = json.Unmarshal(inboundBytes, inbound)
|
||||||
|
if err == nil { // we ignore errors
|
||||||
|
if inbound.AllowListRef != nil && inbound.AllowListRef.ListId > 0 {
|
||||||
|
listIds = append(listIds, inbound.AllowListRef.ListId)
|
||||||
|
}
|
||||||
|
if inbound.DenyListRef != nil && inbound.DenyListRef.ListId > 0 {
|
||||||
|
listIds = append(listIds, inbound.DenyListRef.ListId)
|
||||||
|
}
|
||||||
|
if inbound.GreyListRef != nil && inbound.GreyListRef.ListId > 0 {
|
||||||
|
listIds = append(listIds, inbound.GreyListRef.ListId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(listIds) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, listId := range listIds {
|
||||||
|
isPublicCol, err := db.FindCol(0, "SELECT isPublic FROM edgeIPLists WHERE id=? LIMIT 1", listId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var isPublic = types.Bool(isPublicCol)
|
||||||
|
if !isPublic {
|
||||||
|
_, err = db.Exec("UPDATE edgeIPLists SET serverId=? WHERE id=?", serverId, listId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -175,3 +175,22 @@ func TestUpgradeSQLData_v0_4_7(t *testing.T) {
|
|||||||
}
|
}
|
||||||
t.Log("ok")
|
t.Log("ok")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUpgradeSQLData_v0_4_8(t *testing.T) {
|
||||||
|
db, err := dbs.NewInstanceFromConfig(&dbs.DBConfig{
|
||||||
|
Driver: "mysql",
|
||||||
|
Dsn: "root:123456@tcp(127.0.0.1:3306)/db_edge?charset=utf8mb4&timeout=30s",
|
||||||
|
Prefix: "edge",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = db.Close()
|
||||||
|
}()
|
||||||
|
err = upgradeV0_4_8(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log("ok")
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user