From b51896f92cd828489ec437a005d329706296efb1 Mon Sep 17 00:00:00 2001 From: GoEdgeLab Date: Tue, 2 Feb 2021 19:29:36 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E6=A3=80=E6=9F=A5IP=E7=8A=B6?= =?UTF-8?q?=E6=80=81=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- go.mod | 1 + go.sum | 1 + internal/db/models/ip_item_dao.go | 27 +++ internal/db/models/ip_item_model.go | 52 +++--- .../services/service_http_firewall_policy.go | 166 ++++++++++++++++++ .../service_http_firewall_policy_test.go | 47 +++++ internal/setup/sql_data.go | 23 +++ internal/utils/ip.go | 17 +- internal/utils/version.go | 2 +- 9 files changed, 306 insertions(+), 30 deletions(-) create mode 100644 internal/rpc/services/service_http_firewall_policy_test.go diff --git a/go.mod b/go.mod index f101142a..04971ca9 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/StackExchange/wmi v0.0.0-20190523213315-cbe66965904d // indirect github.com/TeaOSLab/EdgeCommon v0.0.0-00010101000000-000000000000 github.com/aliyun/alibaba-cloud-sdk-go v1.61.641 + github.com/cespare/xxhash/v2 v2.1.1 github.com/go-acme/lego/v4 v4.1.2 github.com/go-ole/go-ole v1.2.4 // indirect github.com/go-sql-driver/mysql v1.5.0 diff --git a/go.sum b/go.sum index ea9ffc1b..2d7419c0 100644 --- a/go.sum +++ b/go.sum @@ -60,6 +60,7 @@ github.com/cenkalti/backoff/v4 v4.0.2 h1:JIufpQLbh4DkbQoii76ItQIUFzevQSqOLZca4ea github.com/cenkalti/backoff/v4 v4.0.2/go.mod h1:eEew/i+1Q6OrCDZh3WiXYv3+nJwBASZ8Bog/87DQnVg= github.com/census-instrumentation/opencensus-proto v0.2.0/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= diff --git a/internal/db/models/ip_item_dao.go b/internal/db/models/ip_item_dao.go index 47ffec38..9e04e4b6 100644 --- a/internal/db/models/ip_item_dao.go +++ b/internal/db/models/ip_item_dao.go @@ -2,11 +2,13 @@ package models import ( "github.com/TeaOSLab/EdgeAPI/internal/errors" + "github.com/TeaOSLab/EdgeAPI/internal/utils" _ "github.com/go-sql-driver/mysql" "github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/lists" "github.com/iwind/TeaGo/types" + "math" "time" ) @@ -95,6 +97,8 @@ func (this *IPItemDAO) CreateIPItem(tx *dbs.Tx, listId int64, ipFrom string, ipT op.ListId = listId op.IpFrom = ipFrom op.IpTo = ipTo + op.IpFromLong = utils.IP2Long(ipFrom) + op.IpToLong = utils.IP2Long(ipTo) op.Reason = reason op.Type = itemType op.Version = version @@ -142,6 +146,8 @@ func (this *IPItemDAO) UpdateIPItem(tx *dbs.Tx, itemId int64, ipFrom string, ipT op.Id = itemId op.IpFrom = ipFrom op.IpTo = ipTo + op.IpFromLong = utils.IP2Long(ipFrom) + op.IpToLong = utils.IP2Long(ipTo) op.Reason = reason op.Type = itemType if expiredAt < 0 { @@ -200,6 +206,27 @@ func (this *IPItemDAO) FindItemListId(tx *dbs.Tx, itemId int64) (int64, error) { FindInt64Col(0) } +// 查找包含某个IP的Item +func (this *IPItemDAO) FindEnabledItemContainsIP(tx *dbs.Tx, listId int64, ip uint64) (*IPItem, error) { + query := this.Query(tx). + Attr("listId", listId). + State(IPItemStateEnabled) + if ip > math.MaxUint32 { + query.Where("(type='all' OR ipFromLong=:ip)") + } else { + query.Where("(type='all' OR ipFromLong=:ip OR (ipToLong>0 AND ipFromLong<=:ip AND ipToLong>=:ip))"). + Param("ip", ip) + } + one, err := query.Find() + if err != nil { + return nil, err + } + if one == nil { + return nil, nil + } + return one.(*IPItem), nil +} + // 通知更新 func (this *IPItemDAO) NotifyUpdate(tx *dbs.Tx, itemId int64) error { // 获取ListId diff --git a/internal/db/models/ip_item_model.go b/internal/db/models/ip_item_model.go index 37d6f16d..94e1d7cf 100644 --- a/internal/db/models/ip_item_model.go +++ b/internal/db/models/ip_item_model.go @@ -2,33 +2,37 @@ package models // IP type IPItem struct { - Id uint64 `field:"id"` // ID - ListId uint32 `field:"listId"` // 所属名单ID - Type string `field:"type"` // 类型 - IpFrom string `field:"ipFrom"` // 开始IP - IpTo string `field:"ipTo"` // 结束IP - Version uint64 `field:"version"` // 版本 - CreatedAt uint64 `field:"createdAt"` // 创建时间 - UpdatedAt uint64 `field:"updatedAt"` // 修改时间 - Reason string `field:"reason"` // 加入说明 - Action string `field:"action"` // 动作代号 - State uint8 `field:"state"` // 状态 - ExpiredAt uint64 `field:"expiredAt"` // 过期时间 + Id uint64 `field:"id"` // ID + ListId uint32 `field:"listId"` // 所属名单ID + Type string `field:"type"` // 类型 + IpFrom string `field:"ipFrom"` // 开始IP + IpTo string `field:"ipTo"` // 结束IP + IpFromLong uint64 `field:"ipFromLong"` // 开始IP整型 + IpToLong uint64 `field:"ipToLong"` // 结束IP整型 + Version uint64 `field:"version"` // 版本 + CreatedAt uint64 `field:"createdAt"` // 创建时间 + UpdatedAt uint64 `field:"updatedAt"` // 修改时间 + Reason string `field:"reason"` // 加入说明 + Action string `field:"action"` // 动作代号 + State uint8 `field:"state"` // 状态 + ExpiredAt uint64 `field:"expiredAt"` // 过期时间 } type IPItemOperator struct { - Id interface{} // ID - ListId interface{} // 所属名单ID - Type interface{} // 类型 - IpFrom interface{} // 开始IP - IpTo interface{} // 结束IP - Version interface{} // 版本 - CreatedAt interface{} // 创建时间 - UpdatedAt interface{} // 修改时间 - Reason interface{} // 加入说明 - Action interface{} // 动作代号 - State interface{} // 状态 - ExpiredAt interface{} // 过期时间 + Id interface{} // ID + ListId interface{} // 所属名单ID + Type interface{} // 类型 + IpFrom interface{} // 开始IP + IpTo interface{} // 结束IP + IpFromLong interface{} // 开始IP整型 + IpToLong interface{} // 结束IP整型 + Version interface{} // 版本 + CreatedAt interface{} // 创建时间 + UpdatedAt interface{} // 修改时间 + Reason interface{} // 加入说明 + Action interface{} // 动作代号 + State interface{} // 状态 + ExpiredAt interface{} // 过期时间 } func NewIPItemOperator() *IPItemOperator { diff --git a/internal/rpc/services/service_http_firewall_policy.go b/internal/rpc/services/service_http_firewall_policy.go index 6f8de023..23fae8ad 100644 --- a/internal/rpc/services/service_http_firewall_policy.go +++ b/internal/rpc/services/service_http_firewall_policy.go @@ -4,11 +4,15 @@ import ( "context" "encoding/json" "github.com/TeaOSLab/EdgeAPI/internal/db/models" + "github.com/TeaOSLab/EdgeAPI/internal/db/models/regions" "github.com/TeaOSLab/EdgeAPI/internal/errors" + "github.com/TeaOSLab/EdgeAPI/internal/iplibrary" rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils" + "github.com/TeaOSLab/EdgeAPI/internal/utils" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" "github.com/iwind/TeaGo/lists" + "net" ) // HTTP防火墙(WAF)相关服务 @@ -628,3 +632,165 @@ func (this *HTTPFirewallPolicyService) ImportHTTPFirewallPolicy(ctx context.Cont return this.Success() } + +// 检查IP状态 +func (this *HTTPFirewallPolicyService) CheckHTTPFirewallPolicyIPStatus(ctx context.Context, req *pb.CheckHTTPFirewallPolicyIPStatusRequest) (*pb.CheckHTTPFirewallPolicyIPStatusResponse, error) { + _, err := this.ValidateAdmin(ctx, 0) + if err != nil { + return nil, err + } + + // 校验IP + ip := net.ParseIP(req.Ip) + if len(ip) == 0 { + return &pb.CheckHTTPFirewallPolicyIPStatusResponse{ + IsOk: false, + Error: "请输入正确的IP", + }, nil + } + ipLong := utils.IP2Long(req.Ip) + + tx := this.NullTx() + firewallPolicy, err := models.SharedHTTPFirewallPolicyDAO.ComposeFirewallPolicy(tx, req.HttpFirewallPolicyId) + if err != nil { + return nil, err + } + if firewallPolicy == nil { + return &pb.CheckHTTPFirewallPolicyIPStatusResponse{ + IsOk: false, + Error: "找不到策略信息", + }, nil + } + + // 检查白名单 + if firewallPolicy.Inbound != nil && + firewallPolicy.Inbound.IsOn && + firewallPolicy.Inbound.AllowListRef != nil && + firewallPolicy.Inbound.AllowListRef.IsOn && + firewallPolicy.Inbound.AllowListRef.ListId > 0 { + item, err := models.SharedIPItemDAO.FindEnabledItemContainsIP(tx, firewallPolicy.Inbound.AllowListRef.ListId, ipLong) + if err != nil { + return nil, err + } + if item != nil { + return &pb.CheckHTTPFirewallPolicyIPStatusResponse{ + IsOk: true, + Error: "", + IsFound: true, + IsAllowed: true, + IpList: &pb.IPList{Name: "白名单", Id: firewallPolicy.Inbound.AllowListRef.ListId}, + IpItem: &pb.IPItem{ + Id: int64(item.Id), + IpFrom: item.IpFrom, + IpTo: item.IpTo, + ExpiredAt: int64(item.ExpiredAt), + Reason: item.Reason, + Type: item.Type, + }, + RegionCountry: nil, + RegionProvince: nil, + }, nil + } + } + + // 检查黑名单 + if firewallPolicy.Inbound != nil && + firewallPolicy.Inbound.IsOn && + firewallPolicy.Inbound.AllowListRef != nil && + firewallPolicy.Inbound.AllowListRef.IsOn && + firewallPolicy.Inbound.AllowListRef.ListId > 0 { + item, err := models.SharedIPItemDAO.FindEnabledItemContainsIP(tx, firewallPolicy.Inbound.DenyListRef.ListId, ipLong) + if err != nil { + return nil, err + } + if item != nil { + return &pb.CheckHTTPFirewallPolicyIPStatusResponse{ + IsOk: true, + Error: "", + IsFound: true, + IsAllowed: false, + IpList: &pb.IPList{Name: "黑名单", Id: firewallPolicy.Inbound.DenyListRef.ListId}, + IpItem: &pb.IPItem{ + Id: int64(item.Id), + IpFrom: item.IpFrom, + IpTo: item.IpTo, + ExpiredAt: int64(item.ExpiredAt), + Reason: item.Reason, + Type: item.Type, + }, + RegionCountry: nil, + RegionProvince: nil, + }, nil + } + } + + // 检查封禁的地区和省份 + info, err := iplibrary.SharedLibrary.Lookup(req.Ip) + if err != nil { + return nil, err + } + if info != nil { + if firewallPolicy.Inbound != nil && + firewallPolicy.Inbound.IsOn && + firewallPolicy.Inbound.Region != nil && + firewallPolicy.Inbound.Region.IsOn { + // 检查封禁的地区 + countryId, err := regions.SharedRegionCountryDAO.FindCountryIdWithNameCacheable(tx, info.Country) + if err != nil { + return nil, err + } + if countryId > 0 && lists.ContainsInt64(firewallPolicy.Inbound.Region.DenyCountryIds, countryId) { + return &pb.CheckHTTPFirewallPolicyIPStatusResponse{ + IsOk: true, + Error: "", + IsFound: true, + IsAllowed: false, + IpList: nil, + IpItem: nil, + RegionCountry: &pb.RegionCountry{ + Id: countryId, + Name: info.Country, + }, + RegionProvince: nil, + }, nil + } + + // 检查封禁的省份 + if countryId > 0 { + provinceId, err := regions.SharedRegionProvinceDAO.FindProvinceIdWithNameCacheable(tx, countryId, info.Province) + if err != nil { + return nil, err + } + if provinceId > 0 && lists.ContainsInt64(firewallPolicy.Inbound.Region.DenyProvinceIds, provinceId) { + return &pb.CheckHTTPFirewallPolicyIPStatusResponse{ + IsOk: true, + Error: "", + IsFound: true, + IsAllowed: false, + IpList: nil, + IpItem: nil, + RegionCountry: &pb.RegionCountry{ + Id: countryId, + Name: info.Country, + }, + RegionProvince: &pb.RegionProvince{ + Id: provinceId, + Name: info.Province, + }, + }, nil + } + } + } + } + + return &pb.CheckHTTPFirewallPolicyIPStatusResponse{ + IsOk: true, + Error: "", + IsFound: false, + IsAllowed: false, + IpList: nil, + IpItem: nil, + RegionCountry: nil, + RegionProvince: nil, + }, nil +} diff --git a/internal/rpc/services/service_http_firewall_policy_test.go b/internal/rpc/services/service_http_firewall_policy_test.go new file mode 100644 index 00000000..352bb535 --- /dev/null +++ b/internal/rpc/services/service_http_firewall_policy_test.go @@ -0,0 +1,47 @@ +package services + +import ( + rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils" + "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" + "github.com/iwind/TeaGo/dbs" + "github.com/iwind/TeaGo/logs" + "testing" +) + +func TestHTTPFirewallPolicyService_CheckHTTPFirewallPolicyIPStatus(t *testing.T) { + dbs.NotifyReady() + service := &HTTPFirewallPolicyService{} + + { + resp, err := service.CheckHTTPFirewallPolicyIPStatus(rpcutils.NewMockAdminNodeContext(1), &pb.CheckHTTPFirewallPolicyIPStatusRequest{ + HttpFirewallPolicyId: 14, + Ip: "127.0.0.1", + }) + if err != nil { + t.Fatal(err) + } + logs.PrintAsJSON(resp, t) + } + + { + resp, err := service.CheckHTTPFirewallPolicyIPStatus(rpcutils.NewMockAdminNodeContext(1), &pb.CheckHTTPFirewallPolicyIPStatusRequest{ + HttpFirewallPolicyId: 14, + Ip: "192.168.1.100", + }) + if err != nil { + t.Fatal(err) + } + logs.PrintAsJSON(resp, t) + } + + { + resp, err := service.CheckHTTPFirewallPolicyIPStatus(rpcutils.NewMockAdminNodeContext(1), &pb.CheckHTTPFirewallPolicyIPStatusRequest{ + HttpFirewallPolicyId: 14, + Ip: "221.218.201.94", + }) + if err != nil { + t.Fatal(err) + } + logs.PrintAsJSON(resp, t) + } +} diff --git a/internal/setup/sql_data.go b/internal/setup/sql_data.go index 4c2aa1e4..9d2e0394 100644 --- a/internal/setup/sql_data.go +++ b/internal/setup/sql_data.go @@ -4,6 +4,7 @@ import ( "github.com/TeaOSLab/EdgeAPI/internal/acme" "github.com/TeaOSLab/EdgeAPI/internal/db/models" "github.com/TeaOSLab/EdgeAPI/internal/errors" + "github.com/TeaOSLab/EdgeAPI/internal/utils" "github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/lists" "github.com/iwind/TeaGo/rands" @@ -29,6 +30,9 @@ var upgradeFuncs = []*upgradeVersion{ { "0.0.9", upgradeV0_0_9, }, + { + "0.0.10", upgradeV0_0_10, + }, } // 升级SQL数据 @@ -180,3 +184,22 @@ func upgradeV0_0_9(db *dbs.DB) error { return nil } + +// v0.0.10 +func upgradeV0_0_10(db *dbs.DB) error { + // IP Item列表转换 + ones, _, err := db.FindOnes("SELECT * FROM edgeIPItems ORDER BY id ASC") + if err != nil { + return err + } + for _, one := range ones { + ipFromLong := utils.IP2Long(one.GetString("ipFrom")) + ipToLong := utils.IP2Long(one.GetString("ipTo")) + _, err = db.Exec("UPDATE edgeIPItems SET ipFromLong=?, ipToLong=? WHERE id=?", ipFromLong, ipToLong, one.GetInt64("id")) + if err != nil { + return err + } + } + + return nil +} diff --git a/internal/utils/ip.go b/internal/utils/ip.go index 58105208..99a36746 100644 --- a/internal/utils/ip.go +++ b/internal/utils/ip.go @@ -2,18 +2,25 @@ package utils import ( "encoding/binary" + "github.com/cespare/xxhash/v2" + "math" "net" + "strings" ) // 将IP转换为整型 -func IP2Long(ip string) uint32 { +// 注意IPv6没有顺序 +func IP2Long(ip string) uint64 { + if len(ip) == 0 { + return 0 + } s := net.ParseIP(ip) - if s == nil { + if len(s) == 0 { return 0 } - if len(s) == 16 { - return binary.BigEndian.Uint32(s[12:16]) + if strings.Contains(ip, ":") { + return math.MaxUint32 + xxhash.Sum64(s) } - return binary.BigEndian.Uint32(s) + return uint64(binary.BigEndian.Uint32(s.To4())) } diff --git a/internal/utils/version.go b/internal/utils/version.go index 8940a0cb..5742711d 100644 --- a/internal/utils/version.go +++ b/internal/utils/version.go @@ -14,5 +14,5 @@ func VersionToLong(version string) uint32 { } else if countDots == 0 { version += ".0.0.0" } - return IP2Long(version) + return uint32(IP2Long(version)) }