diff --git a/internal/db/models/http_firewall_policy_dao.go b/internal/db/models/http_firewall_policy_dao.go index 719ff5f5..4e1a64eb 100644 --- a/internal/db/models/http_firewall_policy_dao.go +++ b/internal/db/models/http_firewall_policy_dao.go @@ -63,6 +63,11 @@ func (this *HTTPFirewallPolicyDAO) DisableHTTPFirewallPolicy(tx *dbs.Tx, policyI return err } + err = this.NotifyDisable(tx, policyId) + if err != nil { + return err + } + return this.NotifyUpdate(tx, policyId) } @@ -482,6 +487,23 @@ func (this *HTTPFirewallPolicyDAO) UpdateFirewallPolicyServerId(tx *dbs.Tx, poli return err } +// FindFirewallPolicyIdsWithServerId 查找服务独立关联的策略IDs +func (this *HTTPFirewallPolicyDAO) FindFirewallPolicyIdsWithServerId(tx *dbs.Tx, serverId int64) ([]int64, error) { + var result = []int64{} + ones, err := this.Query(tx). + Attr("serverId", serverId). + State(HTTPFirewallPolicyStateEnabled). + Result("id"). + FindAll() + if err != nil { + return nil, err + } + for _, one := range ones { + result = append(result, int64(one.(*HTTPFirewallPolicy).Id)) + } + return result, nil +} + // NotifyUpdate 通知更新 func (this *HTTPFirewallPolicyDAO) NotifyUpdate(tx *dbs.Tx, policyId int64) error { webIds, err := SharedHTTPWebDAO.FindAllWebIdsWithHTTPFirewallPolicyId(tx, policyId) @@ -508,3 +530,65 @@ func (this *HTTPFirewallPolicyDAO) NotifyUpdate(tx *dbs.Tx, policyId int64) erro return nil } + +// NotifyDisable 通知禁用 +func (this *HTTPFirewallPolicyDAO) NotifyDisable(tx *dbs.Tx, policyId int64) error { + if policyId <= 0 { + return nil + } + + // 禁用IP名单 + inboundString, err := this.Query(tx). + Pk(policyId). + Result("inbound"). + FindStringCol("") + if err != nil { + return err + } + if len(inboundString) > 0 { + var inboundConfig = &firewallconfigs.HTTPFirewallInboundConfig{} + err = json.Unmarshal([]byte(inboundString), inboundConfig) + if err != nil { + // 不处理错误 + return nil + } + + if inboundConfig.AllowListRef != nil && inboundConfig.AllowListRef.ListId > 0 { + err = SharedIPListDAO.DisableIPList(tx, inboundConfig.AllowListRef.ListId) + if err != nil { + return err + } + + err = SharedIPItemDAO.DisableIPItemsWithListId(tx, inboundConfig.AllowListRef.ListId) + if err != nil { + return err + } + } + + if inboundConfig.DenyListRef != nil && inboundConfig.DenyListRef.ListId > 0 { + err = SharedIPListDAO.DisableIPList(tx, inboundConfig.DenyListRef.ListId) + if err != nil { + return err + } + + err = SharedIPItemDAO.DisableIPItemsWithListId(tx, inboundConfig.DenyListRef.ListId) + if err != nil { + return err + } + } + + if inboundConfig.GreyListRef != nil && inboundConfig.GreyListRef.ListId > 0 { + err = SharedIPListDAO.DisableIPList(tx, inboundConfig.GreyListRef.ListId) + if err != nil { + return err + } + + err = SharedIPItemDAO.DisableIPItemsWithListId(tx, inboundConfig.GreyListRef.ListId) + if err != nil { + return err + } + } + } + + return nil +} diff --git a/internal/db/models/ip_item_dao.go b/internal/db/models/ip_item_dao.go index 8b2f7e18..3c912e5d 100644 --- a/internal/db/models/ip_item_dao.go +++ b/internal/db/models/ip_item_dao.go @@ -75,6 +75,21 @@ func (this *IPItemDAO) DisableIPItem(tx *dbs.Tx, id int64) error { return this.NotifyUpdate(tx, id) } +// DisableIPItemsWithListId 禁用某个IP名单内的所有IP +func (this *IPItemDAO) DisableIPItemsWithListId(tx *dbs.Tx, listId int64) error { + version, err := SharedIPListDAO.IncreaseVersion(tx) + if err != nil { + return err + } + + return this.Query(tx). + Attr("listId", listId). + State(IPItemStateEnabled). + Set("version", version). + Set("state", IPItemStateDisabled). + UpdateQuickly() +} + // FindEnabledIPItem 查找启用中的条目 func (this *IPItemDAO) FindEnabledIPItem(tx *dbs.Tx, id int64) (*IPItem, error) { result, err := this.Query(tx). @@ -249,6 +264,7 @@ func (this *IPItemDAO) ListIPItemsAfterVersion(tx *dbs.Tx, version int64, size i Where("(expiredAt=0 OR expiredAt>:expiredAt)"). Param("expiredAt", time.Now().Unix()). Asc("version"). + Asc("id"). Limit(size). Slice(&result). FindAll() diff --git a/internal/db/models/server_dao.go b/internal/db/models/server_dao.go index d511529d..edc3200c 100644 --- a/internal/db/models/server_dao.go +++ b/internal/db/models/server_dao.go @@ -74,6 +74,13 @@ func (this *ServerDAO) DisableServer(tx *dbs.Tx, serverId int64) (err error) { if err != nil { return err } + + // 删除对应的操作 + err = this.NotifyDisable(tx, serverId) + if err != nil { + return err + } + err = this.NotifyUpdate(tx, serverId) if err != nil { return err @@ -2150,3 +2157,25 @@ func (this *ServerDAO) NotifyDNSUpdate(tx *dbs.Tx, serverId int64) error { } return dns.SharedDNSTaskDAO.CreateServerTask(tx, serverId, dns.DNSTaskTypeServerChange) } + +// NotifyDisable 通知禁用 +func (this *ServerDAO) NotifyDisable(tx *dbs.Tx, serverId int64) error { + // 禁用缓存策略相关的内容 + policyIds, err := SharedHTTPFirewallPolicyDAO.FindFirewallPolicyIdsWithServerId(tx, serverId) + if err != nil { + return err + } + for _, policyId := range policyIds { + err = SharedHTTPFirewallPolicyDAO.DisableHTTPFirewallPolicy(tx, policyId) + if err != nil { + return err + } + + err = SharedHTTPFirewallPolicyDAO.NotifyDisable(tx, policyId) + if err != nil { + return err + } + } + + return nil +} diff --git a/internal/db/models/sys_locker_dao_test.go b/internal/db/models/sys_locker_dao_test.go index 3574a1f5..13b73bad 100644 --- a/internal/db/models/sys_locker_dao_test.go +++ b/internal/db/models/sys_locker_dao_test.go @@ -34,7 +34,8 @@ func TestSysLocker_Increase(t *testing.T) { defer wg.Done() v, err := NewSysLockerDAO().Increase(nil, "hello", 0) if err != nil { - t.Fatal(err) + t.Log("err:", err) + return } t.Log("v:", v) }() diff --git a/internal/rpc/services/service_ip_list.go b/internal/rpc/services/service_ip_list.go index 3768da20..475f238e 100644 --- a/internal/rpc/services/service_ip_list.go +++ b/internal/rpc/services/service_ip_list.go @@ -97,12 +97,12 @@ func (this *IPListService) ListEnabledIPLists(ctx context.Context, req *pb.ListE } var tx = this.NullTx() - lists, err := models.SharedIPListDAO.ListEnabledIPLists(tx, req.Type, req.IsPublic, req.Keyword, req.Offset, req.Size) + ipLists, err := models.SharedIPListDAO.ListEnabledIPLists(tx, req.Type, req.IsPublic, req.Keyword, req.Offset, req.Size) if err != nil { return nil, err } var pbLists []*pb.IPList - for _, list := range lists { + for _, list := range ipLists { pbLists = append(pbLists, &pb.IPList{ Id: int64(list.Id), IsOn: list.IsOn == 1, @@ -129,6 +129,13 @@ func (this *IPListService) DeleteIPList(ctx context.Context, req *pb.DeleteIPLis if err != nil { return nil, err } + + // 删除所有IP + err = models.SharedIPItemDAO.DisableIPItemsWithListId(tx, req.IpListId) + if err != nil { + return nil, err + } + return this.Success() }