用户端可以添加WAF 黑白名单

This commit is contained in:
GoEdgeLab
2021-01-03 20:18:07 +08:00
parent 155dd5b798
commit 065ddbe734
9 changed files with 129 additions and 17 deletions

4
go.mod
View File

@@ -4,8 +4,6 @@ go 1.15
replace github.com/TeaOSLab/EdgeCommon => ../EdgeCommon
replace github.com/iwind/TeaGo => /Users/WorkSpace/TeaGo
require (
github.com/StackExchange/wmi v0.0.0-20190523213315-cbe66965904d // indirect
github.com/TeaOSLab/EdgeCommon v0.0.0-00010101000000-000000000000
@@ -15,7 +13,7 @@ require (
github.com/go-sql-driver/mysql v1.5.0
github.com/go-yaml/yaml v2.1.0+incompatible
github.com/golang/protobuf v1.4.2
github.com/iwind/TeaGo v0.0.0-20200923021120-f5d76441fe9e
github.com/iwind/TeaGo v0.0.0-20210103021650-62acfa30bcea // indirect
github.com/lionsoul2014/ip2region v2.2.0-release+incompatible
github.com/mozillazg/go-pinyin v0.18.0
github.com/pkg/sftp v1.12.0

3
go.sum
View File

@@ -173,6 +173,9 @@ github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/iij/doapi v0.0.0-20190504054126-0bbf12d6d7df/go.mod h1:QMZY7/J/KSQEhKWFeDesPjMj+wCHReeknARU3wqlyN4=
github.com/iwind/TeaGo v0.0.0-20200923021120-f5d76441fe9e/go.mod h1:KU4mS7QNiZ7QWEuDBk1zw0/Q2LrAPZv3tycEFBsuUwc=
github.com/iwind/TeaGo v0.0.0-20210103021650-62acfa30bcea h1:ACgcrVeyHpKt8K6k92RrmPndxVq7Qx+zNXZBzRKv3+0=
github.com/iwind/TeaGo v0.0.0-20210103021650-62acfa30bcea/go.mod h1:KU4mS7QNiZ7QWEuDBk1zw0/Q2LrAPZv3tycEFBsuUwc=
github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k=
github.com/jmespath/go-jmespath v0.3.0 h1:OS12ieG61fsCg5+qLJ+SsW9NicxNkg3b25OyT2yCeUc=
github.com/jmespath/go-jmespath v0.3.0/go.mod h1:9QtRXoHjLGCJ5IBSaohpXITPlowMeeYCZ7fLUTSywik=

View File

@@ -99,8 +99,9 @@ func (this *HTTPFirewallPolicyDAO) FindAllEnabledFirewallPolicies(tx *dbs.Tx) (r
}
// 创建策略
func (this *HTTPFirewallPolicyDAO) CreateFirewallPolicy(tx *dbs.Tx, isOn bool, name string, description string, inboundJSON []byte, outboundJSON []byte) (int64, error) {
func (this *HTTPFirewallPolicyDAO) CreateFirewallPolicy(tx *dbs.Tx, userId int64, isOn bool, name string, description string, inboundJSON []byte, outboundJSON []byte) (int64, error) {
op := NewHTTPFirewallPolicyOperator()
op.UserId = userId
op.State = HTTPFirewallPolicyStateEnabled
op.IsOn = isOn
op.Name = name
@@ -282,3 +283,18 @@ func (this *HTTPFirewallPolicyDAO) ComposeFirewallPolicy(tx *dbs.Tx, policyId in
return config, nil
}
// 检查用户防火墙策略
func (this *HTTPFirewallPolicyDAO) CheckUserFirewallPolicy(tx *dbs.Tx, userId int64, firewallPolicyId int64) error {
ok, err := this.Query(tx).
Pk(firewallPolicyId).
Attr("userId", userId).
Exist()
if err != nil {
return err
}
if !ok {
return ErrNotFound
}
return nil
}

View File

@@ -242,7 +242,18 @@ func (this *HTTPWebDAO) ComposeWebConfig(tx *dbs.Tx, webId int64) (*serverconfig
}
config.FirewallRef = firewallRef
// 暂不支持自定义防火墙策略设置,因为同一个集群下的服务需要集中管理
// 自定义防火墙设置
if firewallRef.FirewallPolicyId > 0 {
firewallPolicy, err := SharedHTTPFirewallPolicyDAO.ComposeFirewallPolicy(tx, firewallRef.FirewallPolicyId)
if err != nil {
return nil, err
}
if firewallPolicy == nil {
config.FirewallRef = nil
} else {
config.FirewallPolicy = firewallPolicy
}
}
}
// 路径规则

View File

@@ -166,3 +166,11 @@ func (this *IPItemDAO) ListIPItemsAfterVersion(tx *dbs.Tx, version int64, size i
FindAll()
return
}
// 查找IPItem对应的列表ID
func (this *IPItemDAO) FindItemListId(tx *dbs.Tx, itemId int64) (int64, error) {
return this.Query(tx).
Pk(itemId).
Result("listId").
FindInt64Col(0)
}

View File

@@ -75,9 +75,10 @@ func (this *IPListDAO) FindIPListName(tx *dbs.Tx, id int64) (string, error) {
}
// 创建名单
func (this *IPListDAO) CreateIPList(tx *dbs.Tx, listType ipconfigs.IPListType, name string, code string, timeoutJSON []byte) (int64, error) {
func (this *IPListDAO) CreateIPList(tx *dbs.Tx, userId int64, listType ipconfigs.IPListType, name string, code string, timeoutJSON []byte) (int64, error) {
op := NewIPListOperator()
op.IsOn = true
op.UserId = userId
op.State = IPListStateEnabled
op.Type = listType
op.Name = name
@@ -128,3 +129,18 @@ func (this *IPListDAO) IncreaseVersion(tx *dbs.Tx) (int64, error) {
err = SharedSysSettingDAO.UpdateSetting(tx, SettingCodeIPListVersion, []byte(numberutils.FormatInt64(value)))
return value, nil
}
// 检查用户权限
func (this *IPListDAO) CheckUserIPList(tx *dbs.Tx, userId int64, listId int64) error {
ok, err := this.Query(tx).
Pk(listId).
Attr("userId", userId).
Exist()
if err != nil {
return err
}
if ok {
return nil
}
return ErrNotFound
}

View File

@@ -49,14 +49,14 @@ func (this *HTTPFirewallPolicyService) FindAllEnabledHTTPFirewallPolicies(ctx co
// 创建防火墙策略
func (this *HTTPFirewallPolicyService) CreateHTTPFirewallPolicy(ctx context.Context, req *pb.CreateHTTPFirewallPolicyRequest) (*pb.CreateHTTPFirewallPolicyResponse, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
_, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
tx := this.NullTx()
policyId, err := models.SharedHTTPFirewallPolicyDAO.CreateFirewallPolicy(tx, req.IsOn, req.Name, req.Description, nil, nil)
policyId, err := models.SharedHTTPFirewallPolicyDAO.CreateFirewallPolicy(tx, userId, req.IsOn, req.Name, req.Description, nil, nil)
if err != nil {
return nil, err
}
@@ -263,13 +263,20 @@ func (this *HTTPFirewallPolicyService) UpdateHTTPFirewallPolicyGroups(ctx contex
// 修改inbound信息
func (this *HTTPFirewallPolicyService) UpdateHTTPFirewallInboundConfig(ctx context.Context, req *pb.UpdateHTTPFirewallInboundConfigRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
_, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
tx := this.NullTx()
if userId > 0 {
err = models.SharedHTTPFirewallPolicyDAO.CheckUserFirewallPolicy(tx, userId, req.HttpFirewallPolicyId)
if err != nil {
return nil, err
}
}
err = models.SharedHTTPFirewallPolicyDAO.UpdateFirewallPolicyInbound(tx, req.HttpFirewallPolicyId, req.InboundJSON)
if err != nil {
return nil, err

View File

@@ -15,13 +15,20 @@ type IPItemService struct {
// 创建IP
func (this *IPItemService) CreateIPItem(ctx context.Context, req *pb.CreateIPItemRequest) (*pb.CreateIPItemResponse, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
_, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
tx := this.NullTx()
if userId > 0 {
err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId)
if err != nil {
return nil, err
}
}
itemId, err := models.SharedIPItemDAO.CreateIPItem(tx, req.IpListId, req.IpFrom, req.IpTo, req.ExpiredAt, req.Reason)
if err != nil {
return nil, err
@@ -33,13 +40,25 @@ func (this *IPItemService) CreateIPItem(ctx context.Context, req *pb.CreateIPIte
// 修改IP
func (this *IPItemService) UpdateIPItem(ctx context.Context, req *pb.UpdateIPItemRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
_, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
tx := this.NullTx()
if userId > 0 {
listId, err := models.SharedIPItemDAO.FindItemListId(tx, req.IpItemId)
if err != nil {
return nil, err
}
err = models.SharedIPListDAO.CheckUserIPList(tx, userId, listId)
if err != nil {
return nil, err
}
}
err = models.SharedIPItemDAO.UpdateIPItem(tx, req.IpItemId, req.IpFrom, req.IpTo, req.ExpiredAt, req.Reason)
if err != nil {
return nil, err
@@ -50,13 +69,25 @@ func (this *IPItemService) UpdateIPItem(ctx context.Context, req *pb.UpdateIPIte
// 删除IP
func (this *IPItemService) DeleteIPItem(ctx context.Context, req *pb.DeleteIPItemRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
_, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
tx := this.NullTx()
if userId > 0 {
listId, err := models.SharedIPItemDAO.FindItemListId(tx, req.IpItemId)
if err != nil {
return nil, err
}
err = models.SharedIPListDAO.CheckUserIPList(tx, userId, listId)
if err != nil {
return nil, err
}
}
err = models.SharedIPItemDAO.DisableIPItem(tx, req.IpItemId)
if err != nil {
return nil, err
@@ -67,13 +98,20 @@ func (this *IPItemService) DeleteIPItem(ctx context.Context, req *pb.DeleteIPIte
// 计算IP数量
func (this *IPItemService) CountIPItemsWithListId(ctx context.Context, req *pb.CountIPItemsWithListIdRequest) (*pb.RPCCountResponse, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
_, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
tx := this.NullTx()
if userId > 0 {
err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId)
if err != nil {
return nil, err
}
}
count, err := models.SharedIPItemDAO.CountIPItemsWithListId(tx, req.IpListId)
if err != nil {
return nil, err
@@ -84,13 +122,20 @@ func (this *IPItemService) CountIPItemsWithListId(ctx context.Context, req *pb.C
// 列出单页的IP
func (this *IPItemService) ListIPItemsWithListId(ctx context.Context, req *pb.ListIPItemsWithListIdRequest) (*pb.ListIPItemsWithListIdResponse, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
_, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
tx := this.NullTx()
if userId > 0 {
err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId)
if err != nil {
return nil, err
}
}
items, err := models.SharedIPItemDAO.ListIPItemsWithListId(tx, req.IpListId, req.Offset, req.Size)
if err != nil {
return nil, err
@@ -113,7 +158,7 @@ func (this *IPItemService) ListIPItemsWithListId(ctx context.Context, req *pb.Li
// 查找单个IP
func (this *IPItemService) FindEnabledIPItem(ctx context.Context, req *pb.FindEnabledIPItemRequest) (*pb.FindEnabledIPItemResponse, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
_, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
@@ -127,6 +172,14 @@ func (this *IPItemService) FindEnabledIPItem(ctx context.Context, req *pb.FindEn
if item == nil {
return &pb.FindEnabledIPItemResponse{IpItem: nil}, nil
}
if userId > 0 {
err = models.SharedIPListDAO.CheckUserIPList(tx, userId, int64(item.ListId))
if err != nil {
return nil, err
}
}
return &pb.FindEnabledIPItemResponse{IpItem: &pb.IPItem{
Id: int64(item.Id),
IpFrom: item.IpFrom,

View File

@@ -15,14 +15,14 @@ type IPListService struct {
// 创建IP列表
func (this *IPListService) CreateIPList(ctx context.Context, req *pb.CreateIPListRequest) (*pb.CreateIPListResponse, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
_, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
tx := this.NullTx()
listId, err := models.SharedIPListDAO.CreateIPList(tx, req.Type, req.Name, req.Code, req.TimeoutJSON)
listId, err := models.SharedIPListDAO.CreateIPList(tx, userId, req.Type, req.Name, req.Code, req.TimeoutJSON)
if err != nil {
return nil, err
}