Files
EdgeAPI/internal/db/models/ip_item_dao.go

294 lines
6.5 KiB
Go
Raw Normal View History

2020-11-07 19:40:24 +08:00
package models
import (
"github.com/TeaOSLab/EdgeAPI/internal/errors"
2021-02-02 19:29:36 +08:00
"github.com/TeaOSLab/EdgeAPI/internal/utils"
2020-11-07 19:40:24 +08:00
_ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/lists"
2020-11-07 19:40:24 +08:00
"github.com/iwind/TeaGo/types"
2021-02-02 19:29:36 +08:00
"math"
"time"
2020-11-07 19:40:24 +08:00
)
// IPItemType is the kind of address an IP item describes.
type IPItemType = string

// Values stored in the "state" column of an IP item.
const (
	IPItemStateDisabled = 0 // disabled
	IPItemStateEnabled  = 1 // enabled
)

// Supported IP item types.
const (
	IPItemTypeIPv4 IPItemType = "ipv4" // IPv4
	IPItemTypeIPv6 IPItemType = "ipv6" // IPv6
	IPItemTypeAll  IPItemType = "all"  // matches every IP
)
2020-11-07 19:40:24 +08:00
type IPItemDAO dbs.DAO
func NewIPItemDAO() *IPItemDAO {
return dbs.NewDAO(&IPItemDAO{
DAOObject: dbs.DAOObject{
DB: Tea.Env,
Table: "edgeIPItems",
Model: new(IPItem),
PkName: "id",
},
}).(*IPItemDAO)
}
// SharedIPItemDAO is the package-wide IPItemDAO instance. It is assigned in
// the callback registered with dbs.OnReady and stays nil until then.
var SharedIPItemDAO *IPItemDAO

func init() {
	onDBReady := func() {
		SharedIPItemDAO = NewIPItemDAO()
	}
	dbs.OnReady(onDBReady)
}
// 启用条目
func (this *IPItemDAO) EnableIPItem(tx *dbs.Tx, id int64) error {
_, err := this.Query(tx).
2020-11-07 19:40:24 +08:00
Pk(id).
Set("state", IPItemStateEnabled).
Update()
return err
}
// 禁用条目
func (this *IPItemDAO) DisableIPItem(tx *dbs.Tx, id int64) error {
version, err := SharedIPListDAO.IncreaseVersion(tx)
if err != nil {
return err
}
_, err = this.Query(tx).
2020-11-07 19:40:24 +08:00
Pk(id).
Set("state", IPItemStateDisabled).
Set("version", version).
2020-11-07 19:40:24 +08:00
Update()
if err != nil {
return err
}
return this.NotifyUpdate(tx, id)
2020-11-07 19:40:24 +08:00
}
// 查找启用中的条目
func (this *IPItemDAO) FindEnabledIPItem(tx *dbs.Tx, id int64) (*IPItem, error) {
result, err := this.Query(tx).
2020-11-07 19:40:24 +08:00
Pk(id).
Attr("state", IPItemStateEnabled).
Find()
if result == nil {
return nil, err
}
return result.(*IPItem), err
}
// 创建IP
func (this *IPItemDAO) CreateIPItem(tx *dbs.Tx, listId int64, ipFrom string, ipTo string, expiredAt int64, reason string, itemType IPItemType) (int64, error) {
version, err := SharedIPListDAO.IncreaseVersion(tx)
2020-11-07 19:40:24 +08:00
if err != nil {
return 0, err
}
op := NewIPItemOperator()
op.ListId = listId
op.IpFrom = ipFrom
op.IpTo = ipTo
2021-02-02 19:29:36 +08:00
op.IpFromLong = utils.IP2Long(ipFrom)
op.IpToLong = utils.IP2Long(ipTo)
2020-11-07 19:40:24 +08:00
op.Reason = reason
op.Type = itemType
2020-11-07 19:40:24 +08:00
op.Version = version
if expiredAt < 0 {
expiredAt = 0
}
op.ExpiredAt = expiredAt
op.State = IPItemStateEnabled
err = this.Save(tx, op)
2020-11-07 19:40:24 +08:00
if err != nil {
return 0, err
}
itemId := types.Int64(op.Id)
err = this.NotifyUpdate(tx, itemId)
if err != nil {
return 0, err
}
return itemId, nil
2020-11-07 19:40:24 +08:00
}
// 修改IP
func (this *IPItemDAO) UpdateIPItem(tx *dbs.Tx, itemId int64, ipFrom string, ipTo string, expiredAt int64, reason string, itemType IPItemType) error {
2020-11-07 19:40:24 +08:00
if itemId <= 0 {
return errors.New("invalid itemId")
}
listId, err := this.Query(tx).
2020-11-07 19:40:24 +08:00
Pk(itemId).
Result("listId").
FindInt64Col(0)
if err != nil {
return err
}
if listId == 0 {
return errors.New("not found")
}
version, err := SharedIPListDAO.IncreaseVersion(tx)
2020-11-07 19:40:24 +08:00
if err != nil {
return err
}
op := NewIPItemOperator()
op.Id = itemId
op.IpFrom = ipFrom
op.IpTo = ipTo
2021-02-02 19:29:36 +08:00
op.IpFromLong = utils.IP2Long(ipFrom)
op.IpToLong = utils.IP2Long(ipTo)
2020-11-07 19:40:24 +08:00
op.Reason = reason
op.Type = itemType
2020-11-07 19:40:24 +08:00
if expiredAt < 0 {
expiredAt = 0
}
op.ExpiredAt = expiredAt
op.Version = version
err = this.Save(tx, op)
if err != nil {
return err
}
return this.NotifyUpdate(tx, itemId)
2020-11-07 19:40:24 +08:00
}
// 计算IP数量
func (this *IPItemDAO) CountIPItemsWithListId(tx *dbs.Tx, listId int64) (int64, error) {
return this.Query(tx).
2020-11-07 19:40:24 +08:00
State(IPItemStateEnabled).
Attr("listId", listId).
Count()
}
// 查找IP列表
func (this *IPItemDAO) ListIPItemsWithListId(tx *dbs.Tx, listId int64, offset int64, size int64) (result []*IPItem, err error) {
_, err = this.Query(tx).
2020-11-07 19:40:24 +08:00
State(IPItemStateEnabled).
Attr("listId", listId).
DescPk().
Slice(&result).
Offset(offset).
Limit(size).
FindAll()
return
}
// 根据版本号查找IP列表
func (this *IPItemDAO) ListIPItemsAfterVersion(tx *dbs.Tx, version int64, size int64) (result []*IPItem, err error) {
_, err = this.Query(tx).
// 这里不要设置状态参数,因为我们要知道哪些是删除的
Gt("version", version).
Where("(expiredAt=0 OR expiredAt>:expiredAt)").
Param("expiredAt", time.Now().Unix()).
Asc("version").
Limit(size).
Slice(&result).
FindAll()
return
}
2021-01-03 20:18:07 +08:00
// 查找IPItem对应的列表ID
func (this *IPItemDAO) FindItemListId(tx *dbs.Tx, itemId int64) (int64, error) {
return this.Query(tx).
Pk(itemId).
Result("listId").
FindInt64Col(0)
}
2021-02-02 19:29:36 +08:00
// FindEnabledItemContainsIP returns an enabled item of the given list that
// matches ip, or (nil, nil) when none does.
//
// ip is compared against the numeric ipFromLong/ipToLong columns. An item
// matches when its type is 'all' or its ipFromLong equals ip. For values that
// fit in 32 bits, an item whose range [ipFromLong, ipToLong] (ipToLong > 0)
// contains ip also matches; larger values are only compared for equality.
func (this *IPItemDAO) FindEnabledItemContainsIP(tx *dbs.Tx, listId int64, ip uint64) (*IPItem, error) {
	query := this.Query(tx).
		Attr("listId", listId).
		State(IPItemStateEnabled)
	if ip > math.MaxUint32 {
		query.Where("(type='all' OR ipFromLong=:ip)")
	} else {
		query.Where("(type='all' OR ipFromLong=:ip OR (ipToLong>0 AND ipFromLong<=:ip AND ipToLong>=:ip))")
	}
	// Both conditions reference :ip, so it must be bound in either case
	// (previously it was only bound for the 32-bit branch).
	query.Param("ip", ip)

	one, err := query.Find()
	if err != nil {
		return nil, err
	}
	if one == nil {
		return nil, nil
	}
	return one.(*IPItem), nil
}
// NotifyUpdate creates a NodeTaskTypeIPItemChanged task for every cluster that
// may be affected by a change to the given item.
//
// Clusters are gathered from each firewall policy returned for the item's IP
// list: clusters linked to the policy directly, plus the clusters of servers
// whose web config uses the policy. Each cluster gets at most one task. When
// the item's list ID is 0, nothing is done.
func (this *IPItemDAO) NotifyUpdate(tx *dbs.Tx, itemId int64) error {
	listId, err := this.FindItemListId(tx, itemId)
	if err != nil {
		return err
	}
	if listId == 0 {
		return nil
	}

	policyIds, err := SharedHTTPFirewallPolicyDAO.FindEnabledFirewallPolicyIdsWithIPListId(tx, listId)
	if err != nil {
		return err
	}

	// Cluster IDs in discovery order, without duplicates. IDs <= 0 are skipped:
	// FindServerClusterId can return 0 (presumably a server with no cluster),
	// and a task for cluster 0 would target a nonexistent cluster.
	var clusterIds []int64
	addClusterId := func(clusterId int64) {
		if clusterId > 0 && !lists.ContainsInt64(clusterIds, clusterId) {
			clusterIds = append(clusterIds, clusterId)
		}
	}

	for _, policyId := range policyIds {
		// Clusters that use the policy directly.
		policyClusterIds, err := SharedNodeClusterDAO.FindAllEnabledNodeClusterIdsWithHTTPFirewallPolicyId(tx, policyId)
		if err != nil {
			return err
		}
		for _, clusterId := range policyClusterIds {
			addClusterId(clusterId)
		}

		// Clusters of servers whose web config uses the policy.
		webIds, err := SharedHTTPWebDAO.FindAllWebIdsWithHTTPFirewallPolicyId(tx, policyId)
		if err != nil {
			return err
		}
		for _, webId := range webIds {
			serverId, err := SharedServerDAO.FindEnabledServerIdWithWebId(tx, webId)
			if err != nil {
				return err
			}
			if serverId <= 0 {
				continue
			}
			clusterId, err := SharedServerDAO.FindServerClusterId(tx, serverId)
			if err != nil {
				return err
			}
			addClusterId(clusterId)
		}
	}

	for _, clusterId := range clusterIds {
		if err := SharedNodeTaskDAO.CreateClusterTask(tx, clusterId, NodeTaskTypeIPItemChanged); err != nil {
			return err
		}
	}
	return nil
}