From 065ddbe734806b404336a9e4eff77573d60fe183 Mon Sep 17 00:00:00 2001 From: GoEdgeLab Date: Sun, 3 Jan 2021 20:18:07 +0800 Subject: [PATCH] =?UTF-8?q?=E7=94=A8=E6=88=B7=E7=AB=AF=E5=8F=AF=E4=BB=A5?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0WAF=20=E9=BB=91=E7=99=BD=E5=90=8D=E5=8D=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- go.mod | 4 +- go.sum | 3 + .../db/models/http_firewall_policy_dao.go | 18 ++++- internal/db/models/http_web_dao.go | 13 +++- internal/db/models/ip_item_dao.go | 8 +++ internal/db/models/ip_list_dao.go | 18 ++++- .../services/service_http_firewall_policy.go | 13 +++- internal/rpc/services/service_ip_item.go | 65 +++++++++++++++++-- internal/rpc/services/service_ip_list.go | 4 +- 9 files changed, 129 insertions(+), 17 deletions(-) diff --git a/go.mod b/go.mod index 59bb9466..22e45213 100644 --- a/go.mod +++ b/go.mod @@ -4,8 +4,6 @@ go 1.15 replace github.com/TeaOSLab/EdgeCommon => ../EdgeCommon -replace github.com/iwind/TeaGo => /Users/WorkSpace/TeaGo - require ( github.com/StackExchange/wmi v0.0.0-20190523213315-cbe66965904d // indirect github.com/TeaOSLab/EdgeCommon v0.0.0-00010101000000-000000000000 @@ -15,7 +13,7 @@ require ( github.com/go-sql-driver/mysql v1.5.0 github.com/go-yaml/yaml v2.1.0+incompatible github.com/golang/protobuf v1.4.2 - github.com/iwind/TeaGo v0.0.0-20200923021120-f5d76441fe9e + github.com/iwind/TeaGo v0.0.0-20210103021650-62acfa30bcea // indirect github.com/lionsoul2014/ip2region v2.2.0-release+incompatible github.com/mozillazg/go-pinyin v0.18.0 github.com/pkg/sftp v1.12.0 diff --git a/go.sum b/go.sum index 358d441c..4561b00d 100644 --- a/go.sum +++ b/go.sum @@ -173,6 +173,9 @@ github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/iij/doapi v0.0.0-20190504054126-0bbf12d6d7df/go.mod h1:QMZY7/J/KSQEhKWFeDesPjMj+wCHReeknARU3wqlyN4= +github.com/iwind/TeaGo v0.0.0-20200923021120-f5d76441fe9e/go.mod h1:KU4mS7QNiZ7QWEuDBk1zw0/Q2LrAPZv3tycEFBsuUwc= +github.com/iwind/TeaGo v0.0.0-20210103021650-62acfa30bcea h1:ACgcrVeyHpKt8K6k92RrmPndxVq7Qx+zNXZBzRKv3+0= +github.com/iwind/TeaGo v0.0.0-20210103021650-62acfa30bcea/go.mod h1:KU4mS7QNiZ7QWEuDBk1zw0/Q2LrAPZv3tycEFBsuUwc= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= github.com/jmespath/go-jmespath v0.3.0 h1:OS12ieG61fsCg5+qLJ+SsW9NicxNkg3b25OyT2yCeUc= github.com/jmespath/go-jmespath v0.3.0/go.mod h1:9QtRXoHjLGCJ5IBSaohpXITPlowMeeYCZ7fLUTSywik= diff --git a/internal/db/models/http_firewall_policy_dao.go b/internal/db/models/http_firewall_policy_dao.go index da90e691..9c280822 100644 --- a/internal/db/models/http_firewall_policy_dao.go +++ b/internal/db/models/http_firewall_policy_dao.go @@ -99,8 +99,9 @@ func (this *HTTPFirewallPolicyDAO) FindAllEnabledFirewallPolicies(tx *dbs.Tx) (r } // 创建策略 -func (this *HTTPFirewallPolicyDAO) CreateFirewallPolicy(tx *dbs.Tx, isOn bool, name string, description string, inboundJSON []byte, outboundJSON []byte) (int64, error) { +func (this *HTTPFirewallPolicyDAO) CreateFirewallPolicy(tx *dbs.Tx, userId int64, isOn bool, name string, description string, inboundJSON []byte, outboundJSON []byte) (int64, error) { op := NewHTTPFirewallPolicyOperator() + op.UserId = userId op.State = HTTPFirewallPolicyStateEnabled op.IsOn = isOn op.Name = name @@ -282,3 +283,18 @@ func (this *HTTPFirewallPolicyDAO) ComposeFirewallPolicy(tx *dbs.Tx, policyId in return config, nil } + +// 检查用户防火墙策略 +func (this *HTTPFirewallPolicyDAO) CheckUserFirewallPolicy(tx *dbs.Tx, userId int64, firewallPolicyId int64) error { + ok, err := this.Query(tx). + Pk(firewallPolicyId). + Attr("userId", userId). + Exist() + if err != nil { + return err + } + if !ok { + return ErrNotFound + } + return nil +} diff --git a/internal/db/models/http_web_dao.go b/internal/db/models/http_web_dao.go index 405a41d4..af33b8d4 100644 --- a/internal/db/models/http_web_dao.go +++ b/internal/db/models/http_web_dao.go @@ -242,7 +242,18 @@ func (this *HTTPWebDAO) ComposeWebConfig(tx *dbs.Tx, webId int64) (*serverconfig } config.FirewallRef = firewallRef - // 暂不支持自定义防火墙策略设置,因为同一个集群下的服务需要集中管理 + // 自定义防火墙设置 + if firewallRef.FirewallPolicyId > 0 { + firewallPolicy, err := SharedHTTPFirewallPolicyDAO.ComposeFirewallPolicy(tx, firewallRef.FirewallPolicyId) + if err != nil { + return nil, err + } + if firewallPolicy == nil { + config.FirewallRef = nil + } else { + config.FirewallPolicy = firewallPolicy + } + } } // 路径规则 diff --git a/internal/db/models/ip_item_dao.go b/internal/db/models/ip_item_dao.go index cfed27c7..84d6a431 100644 --- a/internal/db/models/ip_item_dao.go +++ b/internal/db/models/ip_item_dao.go @@ -166,3 +166,11 @@ func (this *IPItemDAO) ListIPItemsAfterVersion(tx *dbs.Tx, version int64, size i FindAll() return } + +// 查找IPItem对应的列表ID +func (this *IPItemDAO) FindItemListId(tx *dbs.Tx, itemId int64) (int64, error) { + return this.Query(tx). + Pk(itemId). + Result("listId"). + FindInt64Col(0) +} diff --git a/internal/db/models/ip_list_dao.go b/internal/db/models/ip_list_dao.go index 531c2741..daaf1758 100644 --- a/internal/db/models/ip_list_dao.go +++ b/internal/db/models/ip_list_dao.go @@ -75,9 +75,10 @@ func (this *IPListDAO) FindIPListName(tx *dbs.Tx, id int64) (string, error) { } // 创建名单 -func (this *IPListDAO) CreateIPList(tx *dbs.Tx, listType ipconfigs.IPListType, name string, code string, timeoutJSON []byte) (int64, error) { +func (this *IPListDAO) CreateIPList(tx *dbs.Tx, userId int64, listType ipconfigs.IPListType, name string, code string, timeoutJSON []byte) (int64, error) { op := NewIPListOperator() op.IsOn = true + op.UserId = userId op.State = IPListStateEnabled op.Type = listType op.Name = name @@ -128,3 +129,18 @@ func (this *IPListDAO) IncreaseVersion(tx *dbs.Tx) (int64, error) { err = SharedSysSettingDAO.UpdateSetting(tx, SettingCodeIPListVersion, []byte(numberutils.FormatInt64(value))) return value, nil } + +// 检查用户权限 +func (this *IPListDAO) CheckUserIPList(tx *dbs.Tx, userId int64, listId int64) error { + ok, err := this.Query(tx). + Pk(listId). + Attr("userId", userId). + Exist() + if err != nil { + return err + } + if ok { + return nil + } + return ErrNotFound +} diff --git a/internal/rpc/services/service_http_firewall_policy.go b/internal/rpc/services/service_http_firewall_policy.go index 73050c40..26778abe 100644 --- a/internal/rpc/services/service_http_firewall_policy.go +++ b/internal/rpc/services/service_http_firewall_policy.go @@ -49,14 +49,14 @@ func (this *HTTPFirewallPolicyService) FindAllEnabledHTTPFirewallPolicies(ctx co // 创建防火墙策略 func (this *HTTPFirewallPolicyService) CreateHTTPFirewallPolicy(ctx context.Context, req *pb.CreateHTTPFirewallPolicyRequest) (*pb.CreateHTTPFirewallPolicyResponse, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } tx := this.NullTx() - policyId, err := models.SharedHTTPFirewallPolicyDAO.CreateFirewallPolicy(tx, req.IsOn, req.Name, req.Description, nil, nil) + policyId, err := models.SharedHTTPFirewallPolicyDAO.CreateFirewallPolicy(tx, userId, req.IsOn, req.Name, req.Description, nil, nil) if err != nil { return nil, err } @@ -263,13 +263,20 @@ func (this *HTTPFirewallPolicyService) UpdateHTTPFirewallPolicyGroups(ctx contex // 修改inbound信息 func (this *HTTPFirewallPolicyService) UpdateHTTPFirewallInboundConfig(ctx context.Context, req *pb.UpdateHTTPFirewallInboundConfigRequest) (*pb.RPCSuccess, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } tx := this.NullTx() + if userId > 0 { + err = models.SharedHTTPFirewallPolicyDAO.CheckUserFirewallPolicy(tx, userId, req.HttpFirewallPolicyId) + if err != nil { + return nil, err + } + } + err = models.SharedHTTPFirewallPolicyDAO.UpdateFirewallPolicyInbound(tx, req.HttpFirewallPolicyId, req.InboundJSON) if err != nil { return nil, err diff --git a/internal/rpc/services/service_ip_item.go b/internal/rpc/services/service_ip_item.go index 1bf1ef6d..6d66af6c 100644 --- a/internal/rpc/services/service_ip_item.go +++ b/internal/rpc/services/service_ip_item.go @@ -15,13 +15,20 @@ type IPItemService struct { // 创建IP func (this *IPItemService) CreateIPItem(ctx context.Context, req *pb.CreateIPItemRequest) (*pb.CreateIPItemResponse, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } tx := this.NullTx() + if userId > 0 { + err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId) + if err != nil { + return nil, err + } + } + itemId, err := models.SharedIPItemDAO.CreateIPItem(tx, req.IpListId, req.IpFrom, req.IpTo, req.ExpiredAt, req.Reason) if err != nil { return nil, err @@ -33,13 +40,25 @@ func (this *IPItemService) CreateIPItem(ctx context.Context, req *pb.CreateIPIte // 修改IP func (this *IPItemService) UpdateIPItem(ctx context.Context, req *pb.UpdateIPItemRequest) (*pb.RPCSuccess, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } tx := this.NullTx() + if userId > 0 { + listId, err := models.SharedIPItemDAO.FindItemListId(tx, req.IpItemId) + if err != nil { + return nil, err + } + + err = models.SharedIPListDAO.CheckUserIPList(tx, userId, listId) + if err != nil { + return nil, err + } + } + err = models.SharedIPItemDAO.UpdateIPItem(tx, req.IpItemId, req.IpFrom, req.IpTo, req.ExpiredAt, req.Reason) if err != nil { return nil, err @@ -50,13 +69,25 @@ func (this *IPItemService) UpdateIPItem(ctx context.Context, req *pb.UpdateIPIte // 删除IP func (this *IPItemService) DeleteIPItem(ctx context.Context, req *pb.DeleteIPItemRequest) (*pb.RPCSuccess, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } tx := this.NullTx() + if userId > 0 { + listId, err := models.SharedIPItemDAO.FindItemListId(tx, req.IpItemId) + if err != nil { + return nil, err + } + + err = models.SharedIPListDAO.CheckUserIPList(tx, userId, listId) + if err != nil { + return nil, err + } + } + err = models.SharedIPItemDAO.DisableIPItem(tx, req.IpItemId) if err != nil { return nil, err @@ -67,13 +98,20 @@ func (this *IPItemService) DeleteIPItem(ctx context.Context, req *pb.DeleteIPIte // 计算IP数量 func (this *IPItemService) CountIPItemsWithListId(ctx context.Context, req *pb.CountIPItemsWithListIdRequest) (*pb.RPCCountResponse, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } tx := this.NullTx() + if userId > 0 { + err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId) + if err != nil { + return nil, err + } + } + count, err := models.SharedIPItemDAO.CountIPItemsWithListId(tx, req.IpListId) if err != nil { return nil, err @@ -84,13 +122,20 @@ func (this *IPItemService) CountIPItemsWithListId(ctx context.Context, req *pb.C // 列出单页的IP func (this *IPItemService) ListIPItemsWithListId(ctx context.Context, req *pb.ListIPItemsWithListIdRequest) (*pb.ListIPItemsWithListIdResponse, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } tx := this.NullTx() + if userId > 0 { + err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId) + if err != nil { + return nil, err + } + } + items, err := models.SharedIPItemDAO.ListIPItemsWithListId(tx, req.IpListId, req.Offset, req.Size) if err != nil { return nil, err @@ -113,7 +158,7 @@ func (this *IPItemService) ListIPItemsWithListId(ctx context.Context, req *pb.Li // 查找单个IP func (this *IPItemService) FindEnabledIPItem(ctx context.Context, req *pb.FindEnabledIPItemRequest) (*pb.FindEnabledIPItemResponse, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } @@ -127,6 +172,14 @@ func (this *IPItemService) FindEnabledIPItem(ctx context.Context, req *pb.FindEn if item == nil { return &pb.FindEnabledIPItemResponse{IpItem: nil}, nil } + + if userId > 0 { + err = models.SharedIPListDAO.CheckUserIPList(tx, userId, int64(item.ListId)) + if err != nil { + return nil, err + } + } + return &pb.FindEnabledIPItemResponse{IpItem: &pb.IPItem{ Id: int64(item.Id), IpFrom: item.IpFrom, diff --git a/internal/rpc/services/service_ip_list.go b/internal/rpc/services/service_ip_list.go index 25981828..b8d638b3 100644 --- a/internal/rpc/services/service_ip_list.go +++ b/internal/rpc/services/service_ip_list.go @@ -15,14 +15,14 @@ type IPListService struct { // 创建IP列表 func (this *IPListService) CreateIPList(ctx context.Context, req *pb.CreateIPListRequest) (*pb.CreateIPListResponse, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } tx := this.NullTx() - listId, err := models.SharedIPListDAO.CreateIPList(tx, req.Type, req.Name, req.Code, req.TimeoutJSON) + listId, err := models.SharedIPListDAO.CreateIPList(tx, userId, req.Type, req.Name, req.Code, req.TimeoutJSON) if err != nil { return nil, err }