From f49c26cdabfd980d27d724cab303c3cd8ef9a57c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E7=A5=A5=E8=B6=85?= Date: Fri, 1 Jan 2021 23:31:30 +0800 Subject: [PATCH] =?UTF-8?q?=E6=89=80=E6=9C=89=E6=95=B0=E6=8D=AE=E5=BA=93?= =?UTF-8?q?=E7=9B=B8=E5=85=B3=E7=9A=84=E6=93=8D=E4=BD=9C=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E4=BA=8B=E5=8A=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- go.mod | 4 +- go.sum | 5 - internal/db/models/acme_authentication_dao.go | 8 +- internal/db/models/acme_task_dao.go | 78 +++--- internal/db/models/acme_task_log_dao.go | 8 +- internal/db/models/acme_user_dao.go | 40 ++-- internal/db/models/admin_dao.go | 68 +++--- internal/db/models/api_access_token_dao.go | 10 +- internal/db/models/api_node_dao.go | 60 ++--- internal/db/models/api_node_model_ext.go | 9 +- internal/db/models/api_token_dao.go | 24 +- internal/db/models/db_node_dao.go | 36 +-- internal/db/models/db_node_initializer.go | 4 +- internal/db/models/dns_domain_dao.go | 64 ++--- internal/db/models/dns_provider_dao.go | 40 ++-- internal/db/models/file_chunk_dao.go | 20 +- internal/db/models/file_dao.go | 20 +- internal/db/models/http_access_log_dao.go | 24 +- .../db/models/http_access_log_policy_dao.go | 30 +-- internal/db/models/http_cache_policy_dao.go | 48 ++-- .../db/models/http_firewall_policy_dao.go | 58 ++--- internal/db/models/http_firewall_rule_dao.go | 26 +- .../db/models/http_firewall_rule_group_dao.go | 50 ++-- .../db/models/http_firewall_rule_set_dao.go | 38 +-- internal/db/models/http_gzip_dao.go | 30 +-- internal/db/models/http_header_dao.go | 34 +-- internal/db/models/http_header_policy_dao.go | 54 ++--- internal/db/models/http_location_dao.go | 64 ++--- internal/db/models/http_page_dao.go | 30 +-- internal/db/models/http_rewrite_rule_dao.go | 30 +-- internal/db/models/http_web_dao.go | 120 +++++----- internal/db/models/http_websocket_dao.go | 24 +- internal/db/models/ip_item_dao.go | 40 ++-- internal/db/models/ip_library_dao.go | 24 +- internal/db/models/ip_list_dao.go | 32 +-- internal/db/models/log_dao.go | 24 +- internal/db/models/login_dao.go | 34 +-- internal/db/models/message_dao.go | 56 ++--- internal/db/models/node_cluster_dao.go | 164 ++++++------- internal/db/models/node_dao.go | 194 +++++++-------- internal/db/models/node_grant_dao.go | 36 +-- internal/db/models/node_group_dao.go | 32 +-- internal/db/models/node_ip_address_dao.go | 48 ++-- internal/db/models/node_log_dao.go | 16 +- internal/db/models/node_login_dao.go | 32 +-- internal/db/models/node_price_item_dao.go | 32 +-- internal/db/models/node_region_dao.go | 46 ++-- internal/db/models/origin_dao.go | 40 ++-- internal/db/models/provider_dao.go | 16 +- internal/db/models/region_city_dao.go | 24 +- internal/db/models/region_country_dao.go | 32 +-- internal/db/models/region_provider_dao.go | 16 +- internal/db/models/region_province_dao.go | 32 +-- internal/db/models/reverse_proxy_dao.go | 48 ++-- internal/db/models/server_daily_stat_dao.go | 20 +- internal/db/models/server_dao.go | 225 +++++++++--------- internal/db/models/server_group_dao.go | 32 +-- internal/db/models/ssl_cert_dao.go | 58 ++--- internal/db/models/ssl_cert_group_dao.go | 16 +- internal/db/models/ssl_policy_dao.go | 42 ++-- internal/db/models/sub_user_dao.go | 16 +- internal/db/models/sys_event_dao.go | 12 +- internal/db/models/sys_event_types.go | 11 +- internal/db/models/sys_locker_dao.go | 14 +- internal/db/models/sys_setting_dao.go | 20 +- internal/db/models/tcp_firewall_policy_dao.go | 6 +- internal/db/models/user_access_key_dao.go | 32 +-- internal/db/models/user_bill_dao.go | 34 +-- internal/db/models/user_dao.go | 68 +++--- internal/db/models/user_node_dao.go | 52 ++-- internal/db/models/user_node_model_ext.go | 2 +- internal/db/models/utils.go | 4 +- internal/installers/queue.go | 40 ++-- internal/iplibrary/manager.go | 2 +- internal/iplibrary/updater.go | 8 +- internal/nodes/api_node.go | 6 +- internal/nodes/node_status_executor.go | 2 +- internal/nodes/rest_server.go | 2 +- internal/remotelogs/utils.go | 2 +- .../services/service_acme_authentication.go | 4 +- internal/rpc/services/service_acme_task.go | 55 +++-- internal/rpc/services/service_acme_user.go | 34 ++- internal/rpc/services/service_admin.go | 72 ++++-- .../rpc/services/service_api_access_token.go | 7 +- internal/rpc/services/service_api_node.go | 28 ++- internal/rpc/services/service_base.go | 6 + internal/rpc/services/service_db_node.go | 28 ++- internal/rpc/services/service_dns.go | 7 +- internal/rpc/services/service_dns_domain.go | 80 +++++-- internal/rpc/services/service_dns_provider.go | 32 ++- internal/rpc/services/service_file.go | 8 +- internal/rpc/services/service_file_chunk.go | 13 +- .../rpc/services/service_http_access_log.go | 14 +- .../service_http_access_log_policy.go | 5 +- .../rpc/services/service_http_cache_policy.go | 32 ++- .../services/service_http_firewall_policy.go | 86 ++++--- .../service_http_firewall_rule_group.go | 24 +- .../service_http_firewall_rule_set.go | 16 +- internal/rpc/services/service_http_header.go | 12 +- .../services/service_http_header_policy.go | 28 ++- .../rpc/services/service_http_location.go | 40 +++- internal/rpc/services/service_http_page.go | 12 +- .../rpc/services/service_http_rewrite_rule.go | 8 +- internal/rpc/services/service_http_web.go | 72 ++++-- .../rpc/services/service_http_websocket.go | 8 +- internal/rpc/services/service_ip_item.go | 30 ++- internal/rpc/services/service_ip_library.go | 32 ++- internal/rpc/services/service_ip_list.go | 12 +- internal/rpc/services/service_log.go | 32 ++- internal/rpc/services/service_login.go | 11 +- internal/rpc/services/service_message.go | 28 ++- internal/rpc/services/service_node.go | 196 +++++++++------ internal/rpc/services/service_node_cluster.go | 134 ++++++++--- internal/rpc/services/service_node_grant.go | 26 +- internal/rpc/services/service_node_group.go | 24 +- .../rpc/services/service_node_ip_address.go | 28 ++- internal/rpc/services/service_node_log.go | 12 +- .../rpc/services/service_node_price_item.go | 24 +- internal/rpc/services/service_node_region.go | 40 +++- internal/rpc/services/service_node_stream.go | 12 +- internal/rpc/services/service_origin.go | 18 +- .../rpc/services/service_region_country.go | 9 +- .../rpc/services/service_region_province.go | 10 +- .../rpc/services/service_reverse_proxy.go | 28 ++- internal/rpc/services/service_server.go | 218 +++++++++++------ .../rpc/services/service_server_daily_stat.go | 4 +- internal/rpc/services/service_server_group.go | 24 +- internal/rpc/services/service_ssl_cert.go | 34 ++- internal/rpc/services/service_ssl_policy.go | 17 +- internal/rpc/services/service_sys_setting.go | 8 +- internal/rpc/services/service_user.go | 81 +++++-- .../rpc/services/service_user_access_key.go | 20 +- internal/rpc/services/service_user_bill.go | 14 +- internal/rpc/services/service_user_node.go | 32 ++- internal/rpc/services/sevice_http_gzip.go | 12 +- internal/rpc/utils/utils.go | 6 +- internal/setup/setup.go | 8 +- internal/tasks/event_looper.go | 8 +- internal/tasks/health_check_cluster_task.go | 6 +- internal/tasks/health_check_executor.go | 12 +- internal/tasks/health_check_task.go | 2 +- internal/tasks/log_task.go | 16 +- internal/tasks/message_task.go | 2 +- internal/tasks/node_log_cleaner_task.go | 2 +- internal/tasks/node_monitor_task.go | 12 +- .../tasks/ssl_cert_expire_check_executor.go | 36 +-- 146 files changed, 2845 insertions(+), 2068 deletions(-) diff --git a/go.mod b/go.mod index 73bae235..59bb9466 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,8 @@ go 1.15 replace github.com/TeaOSLab/EdgeCommon => ../EdgeCommon +replace github.com/iwind/TeaGo => /Users/WorkSpace/TeaGo + require ( github.com/StackExchange/wmi v0.0.0-20190523213315-cbe66965904d // indirect github.com/TeaOSLab/EdgeCommon v0.0.0-00010101000000-000000000000 @@ -13,7 +15,7 @@ require ( github.com/go-sql-driver/mysql v1.5.0 github.com/go-yaml/yaml v2.1.0+incompatible github.com/golang/protobuf v1.4.2 - github.com/iwind/TeaGo v0.0.0-20201209122854-4c8b1780a42b + github.com/iwind/TeaGo v0.0.0-20200923021120-f5d76441fe9e github.com/lionsoul2014/ip2region v2.2.0-release+incompatible github.com/mozillazg/go-pinyin v0.18.0 github.com/pkg/sftp v1.12.0 diff --git a/go.sum b/go.sum index d29a3305..358d441c 100644 --- a/go.sum +++ b/go.sum @@ -173,9 +173,6 @@ github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/iij/doapi v0.0.0-20190504054126-0bbf12d6d7df/go.mod h1:QMZY7/J/KSQEhKWFeDesPjMj+wCHReeknARU3wqlyN4= -github.com/iwind/TeaGo v0.0.0-20200923021120-f5d76441fe9e/go.mod h1:KU4mS7QNiZ7QWEuDBk1zw0/Q2LrAPZv3tycEFBsuUwc= -github.com/iwind/TeaGo v0.0.0-20201209122854-4c8b1780a42b h1:tLO0mUyXn4Szo6SPEtJmeR2aQSHBXy9MsfnaLlulQA0= -github.com/iwind/TeaGo v0.0.0-20201209122854-4c8b1780a42b/go.mod h1:KU4mS7QNiZ7QWEuDBk1zw0/Q2LrAPZv3tycEFBsuUwc= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= github.com/jmespath/go-jmespath v0.3.0 h1:OS12ieG61fsCg5+qLJ+SsW9NicxNkg3b25OyT2yCeUc= github.com/jmespath/go-jmespath v0.3.0/go.mod h1:9QtRXoHjLGCJ5IBSaohpXITPlowMeeYCZ7fLUTSywik= @@ -282,8 +279,6 @@ github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQD github.com/sacloud/libsacloud v1.36.2/go.mod h1:P7YAOVmnIn3DKHqCZcUKYUXmSwGBm3yS7IBEjKVSrjg= github.com/shirou/gopsutil v2.20.9+incompatible h1:msXs2frUV+O/JLva9EDLpuJ84PrFsdCTCQex8PUdtkQ= github.com/shirou/gopsutil v2.20.9+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= -github.com/shirou/gopsutil v3.20.11+incompatible h1:LJr4ZQK4mPpIV5gOa4jCOKOGb4ty4DZO54I4FGqIpto= -github.com/shirou/gopsutil v3.20.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= diff --git a/internal/db/models/acme_authentication_dao.go b/internal/db/models/acme_authentication_dao.go index 77ade7a4..221fc708 100644 --- a/internal/db/models/acme_authentication_dao.go +++ b/internal/db/models/acme_authentication_dao.go @@ -28,19 +28,19 @@ func init() { } // 创建认证信息 -func (this *ACMEAuthenticationDAO) CreateAuth(taskId int64, domain string, token string, key string) error { +func (this *ACMEAuthenticationDAO) CreateAuth(tx *dbs.Tx, taskId int64, domain string, token string, key string) error { op := NewACMEAuthenticationOperator() op.TaskId = taskId op.Domain = domain op.Token = token op.Key = key - err := this.Save(op) + err := this.Save(tx, op) return err } // 根据令牌查找认证信息 -func (this *ACMEAuthenticationDAO) FindAuthWithToken(token string) (*ACMEAuthentication, error) { - one, err := this.Query(). +func (this *ACMEAuthenticationDAO) FindAuthWithToken(tx *dbs.Tx, token string) (*ACMEAuthentication, error) { + one, err := this.Query(tx). Attr("token", token). DescPk(). Find() diff --git a/internal/db/models/acme_task_dao.go b/internal/db/models/acme_task_dao.go index 2efd5799..5d87f81d 100644 --- a/internal/db/models/acme_task_dao.go +++ b/internal/db/models/acme_task_dao.go @@ -41,8 +41,8 @@ func init() { } // 启用条目 -func (this *ACMETaskDAO) EnableACMETask(id int64) error { - _, err := this.Query(). +func (this *ACMETaskDAO) EnableACMETask(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", ACMETaskStateEnabled). Update() @@ -50,8 +50,8 @@ func (this *ACMETaskDAO) EnableACMETask(id int64) error { } // 禁用条目 -func (this *ACMETaskDAO) DisableACMETask(id int64) error { - _, err := this.Query(). +func (this *ACMETaskDAO) DisableACMETask(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", ACMETaskStateDisabled). Update() @@ -59,8 +59,8 @@ func (this *ACMETaskDAO) DisableACMETask(id int64) error { } // 查找启用中的条目 -func (this *ACMETaskDAO) FindEnabledACMETask(id int64) (*ACMETask, error) { - result, err := this.Query(). +func (this *ACMETaskDAO) FindEnabledACMETask(tx *dbs.Tx, id int64) (*ACMETask, error) { + result, err := this.Query(tx). Pk(id). Attr("state", ACMETaskStateEnabled). Find() @@ -71,24 +71,24 @@ func (this *ACMETaskDAO) FindEnabledACMETask(id int64) (*ACMETask, error) { } // 计算某个ACME用户相关的任务数量 -func (this *ACMETaskDAO) CountACMETasksWithACMEUserId(acmeUserId int64) (int64, error) { - return this.Query(). +func (this *ACMETaskDAO) CountACMETasksWithACMEUserId(tx *dbs.Tx, acmeUserId int64) (int64, error) { + return this.Query(tx). State(ACMETaskStateEnabled). Attr("acmeUserId", acmeUserId). Count() } // 计算某个DNS服务商相关的任务数量 -func (this *ACMETaskDAO) CountACMETasksWithDNSProviderId(dnsProviderId int64) (int64, error) { - return this.Query(). +func (this *ACMETaskDAO) CountACMETasksWithDNSProviderId(tx *dbs.Tx, dnsProviderId int64) (int64, error) { + return this.Query(tx). State(ACMETaskStateEnabled). Attr("dnsProviderId", dnsProviderId). Count() } // 停止某个证书相关任务 -func (this *ACMETaskDAO) DisableAllTasksWithCertId(certId int64) error { - _, err := this.Query(). +func (this *ACMETaskDAO) DisableAllTasksWithCertId(tx *dbs.Tx, certId int64) error { + _, err := this.Query(tx). Attr("certId", certId). Set("state", ACMETaskStateDisabled). Update() @@ -96,15 +96,15 @@ func (this *ACMETaskDAO) DisableAllTasksWithCertId(certId int64) error { } // 计算所有任务数量 -func (this *ACMETaskDAO) CountAllEnabledACMETasks(adminId int64, userId int64) (int64, error) { - return NewQuery(this, adminId, userId). +func (this *ACMETaskDAO) CountAllEnabledACMETasks(tx *dbs.Tx, adminId int64, userId int64) (int64, error) { + return NewQuery(tx, this, adminId, userId). State(ACMETaskStateEnabled). Count() } // 列出单页任务 -func (this *ACMETaskDAO) ListEnabledACMETasks(adminId int64, userId int64, offset int64, size int64) (result []*ACMETask, err error) { - _, err = NewQuery(this, adminId, userId). +func (this *ACMETaskDAO) ListEnabledACMETasks(tx *dbs.Tx, adminId int64, userId int64, offset int64, size int64) (result []*ACMETask, err error) { + _, err = NewQuery(tx, this, adminId, userId). State(ACMETaskStateEnabled). DescPk(). Offset(offset). @@ -115,7 +115,7 @@ func (this *ACMETaskDAO) ListEnabledACMETasks(adminId int64, userId int64, offse } // 创建任务 -func (this *ACMETaskDAO) CreateACMETask(adminId int64, userId int64, authType acme.AuthType, acmeUserId int64, dnsProviderId int64, dnsDomain string, domains []string, autoRenew bool) (int64, error) { +func (this *ACMETaskDAO) CreateACMETask(tx *dbs.Tx, adminId int64, userId int64, authType acme.AuthType, acmeUserId int64, dnsProviderId int64, dnsDomain string, domains []string, autoRenew bool) (int64, error) { op := NewACMETaskOperator() op.AdminId = adminId op.UserId = userId @@ -137,7 +137,7 @@ func (this *ACMETaskDAO) CreateACMETask(adminId int64, userId int64, authType ac op.AutoRenew = autoRenew op.IsOn = true op.State = ACMETaskStateEnabled - err := this.Save(op) + err := this.Save(tx, op) if err != nil { return 0, err } @@ -145,7 +145,7 @@ func (this *ACMETaskDAO) CreateACMETask(adminId int64, userId int64, authType ac } // 修改任务 -func (this *ACMETaskDAO) UpdateACMETask(acmeTaskId int64, acmeUserId int64, dnsProviderId int64, dnsDomain string, domains []string, autoRenew bool) error { +func (this *ACMETaskDAO) UpdateACMETask(tx *dbs.Tx, acmeTaskId int64, acmeUserId int64, dnsProviderId int64, dnsDomain string, domains []string, autoRenew bool) error { if acmeTaskId <= 0 { return errors.New("invalid acmeTaskId") } @@ -167,20 +167,20 @@ func (this *ACMETaskDAO) UpdateACMETask(acmeTaskId int64, acmeUserId int64, dnsP } op.AutoRenew = autoRenew - err := this.Save(op) + err := this.Save(tx, op) return err } // 检查权限 -func (this *ACMETaskDAO) CheckACMETask(adminId int64, userId int64, acmeTaskId int64) (bool, error) { - return NewQuery(this, adminId, userId). +func (this *ACMETaskDAO) CheckACMETask(tx *dbs.Tx, adminId int64, userId int64, acmeTaskId int64) (bool, error) { + return NewQuery(tx, this, adminId, userId). State(ACMETaskStateEnabled). Pk(acmeTaskId). Exist() } // 设置任务关联的证书 -func (this *ACMETaskDAO) UpdateACMETaskCert(taskId int64, certId int64) error { +func (this *ACMETaskDAO) UpdateACMETaskCert(tx *dbs.Tx, taskId int64, certId int64) error { if taskId <= 0 { return errors.New("invalid taskId") } @@ -188,16 +188,16 @@ func (this *ACMETaskDAO) UpdateACMETaskCert(taskId int64, certId int64) error { op := NewACMETaskOperator() op.Id = taskId op.CertId = certId - err := this.Save(op) + err := this.Save(tx, op) return err } // 执行任务并记录日志 -func (this *ACMETaskDAO) RunTask(taskId int64) (isOk bool, errMsg string, resultCertId int64) { - isOk, errMsg, resultCertId = this.runTaskWithoutLog(taskId) +func (this *ACMETaskDAO) RunTask(tx *dbs.Tx, taskId int64) (isOk bool, errMsg string, resultCertId int64) { + isOk, errMsg, resultCertId = this.runTaskWithoutLog(tx, taskId) // 记录日志 - err := SharedACMETaskLogDAO.CreateACMETaskLog(taskId, isOk, errMsg) + err := SharedACMETaskLogDAO.CreateACMETaskLog(tx, taskId, isOk, errMsg) if err != nil { logs.Error(err) } @@ -206,8 +206,8 @@ func (this *ACMETaskDAO) RunTask(taskId int64) (isOk bool, errMsg string, result } // 执行任务但并不记录日志 -func (this *ACMETaskDAO) runTaskWithoutLog(taskId int64) (isOk bool, errMsg string, resultCertId int64) { - task, err := this.FindEnabledACMETask(taskId) +func (this *ACMETaskDAO) runTaskWithoutLog(tx *dbs.Tx, taskId int64) (isOk bool, errMsg string, resultCertId int64) { + task, err := this.FindEnabledACMETask(tx, taskId) if err != nil { errMsg = "查询任务信息时出错:" + err.Error() return @@ -222,7 +222,7 @@ func (this *ACMETaskDAO) runTaskWithoutLog(taskId int64) (isOk bool, errMsg stri } // ACME用户 - user, err := SharedACMEUserDAO.FindEnabledACMEUser(int64(task.AcmeUserId)) + user, err := SharedACMEUserDAO.FindEnabledACMEUser(tx, int64(task.AcmeUserId)) if err != nil { errMsg = "查询ACME用户时出错:" + err.Error() return @@ -244,7 +244,7 @@ func (this *ACMETaskDAO) runTaskWithoutLog(taskId int64) (isOk bool, errMsg stri return err } - err = SharedACMEUserDAO.UpdateACMEUserRegistration(int64(user.Id), resourceJSON) + err = SharedACMEUserDAO.UpdateACMEUserRegistration(tx, int64(user.Id), resourceJSON) return err }) @@ -259,7 +259,7 @@ func (this *ACMETaskDAO) runTaskWithoutLog(taskId int64) (isOk bool, errMsg stri var acmeTask *acme.Task = nil if task.AuthType == acme.AuthTypeDNS { // DNS服务商 - dnsProvider, err := SharedDNSProviderDAO.FindEnabledDNSProvider(int64(task.DnsProviderId)) + dnsProvider, err := SharedDNSProviderDAO.FindEnabledDNSProvider(tx, int64(task.DnsProviderId)) if err != nil { errMsg = "查找DNS服务商账号信息时出错:" + err.Error() return @@ -301,7 +301,7 @@ func (this *ACMETaskDAO) runTaskWithoutLog(taskId int64) (isOk bool, errMsg stri acmeRequest := acme.NewRequest(acmeTask) acmeRequest.OnAuth(func(domain, token, keyAuth string) { - err := SharedACMEAuthenticationDAO.CreateAuth(taskId, domain, token, keyAuth) + err := SharedACMEAuthenticationDAO.CreateAuth(tx, taskId, domain, token, keyAuth) if err != nil { logs.Println("[ACME]write authentication to database error: " + err.Error()) } @@ -326,7 +326,7 @@ func (this *ACMETaskDAO) runTaskWithoutLog(taskId int64) (isOk bool, errMsg stri // 保存证书 resultCertId = int64(task.CertId) if resultCertId > 0 { - cert, err := SharedSSLCertDAO.FindEnabledSSLCert(resultCertId) + cert, err := SharedSSLCertDAO.FindEnabledSSLCert(tx, resultCertId) if err != nil { errMsg = "证书生成成功,但查询已绑定的证书时出错:" + err.Error() return @@ -335,7 +335,7 @@ func (this *ACMETaskDAO) runTaskWithoutLog(taskId int64) (isOk bool, errMsg stri errMsg = "证书已被管理员或用户删除" // 禁用 - err = SharedACMETaskDAO.DisableACMETask(taskId) + err = SharedACMETaskDAO.DisableACMETask(tx, taskId) if err != nil { errMsg = "禁用失效的ACME任务出错:" + err.Error() } @@ -343,26 +343,26 @@ func (this *ACMETaskDAO) runTaskWithoutLog(taskId int64) (isOk bool, errMsg stri return } - err = SharedSSLCertDAO.UpdateCert(resultCertId, cert.IsOn == 1, cert.Name, cert.Description, cert.ServerName, cert.IsCA == 1, certData, keyData, sslConfig.TimeBeginAt, sslConfig.TimeEndAt, sslConfig.DNSNames, sslConfig.CommonNames) + err = SharedSSLCertDAO.UpdateCert(tx, resultCertId, cert.IsOn == 1, cert.Name, cert.Description, cert.ServerName, cert.IsCA == 1, certData, keyData, sslConfig.TimeBeginAt, sslConfig.TimeEndAt, sslConfig.DNSNames, sslConfig.CommonNames) if err != nil { errMsg = "证书生成成功,但是修改数据库中的证书信息时出错:" + err.Error() return } } else { - resultCertId, err = SharedSSLCertDAO.CreateCert(int64(task.AdminId), int64(task.UserId), true, task.DnsDomain+"免费证书", "免费申请的证书", "", false, certData, keyData, sslConfig.TimeBeginAt, sslConfig.TimeEndAt, sslConfig.DNSNames, sslConfig.CommonNames) + resultCertId, err = SharedSSLCertDAO.CreateCert(tx, int64(task.AdminId), int64(task.UserId), true, task.DnsDomain+"免费证书", "免费申请的证书", "", false, certData, keyData, sslConfig.TimeBeginAt, sslConfig.TimeEndAt, sslConfig.DNSNames, sslConfig.CommonNames) if err != nil { errMsg = "证书生成成功,但是保存到数据库失败:" + err.Error() return } - err = SharedSSLCertDAO.UpdateCertACME(resultCertId, int64(task.Id)) + err = SharedSSLCertDAO.UpdateCertACME(tx, resultCertId, int64(task.Id)) if err != nil { errMsg = "证书生成成功,修改证书ACME信息时出错:" + err.Error() return } // 设置成功 - err = SharedACMETaskDAO.UpdateACMETaskCert(taskId, resultCertId) + err = SharedACMETaskDAO.UpdateACMETaskCert(tx, taskId, resultCertId) if err != nil { errMsg = "证书生成成功,设置任务关联的证书时出错:" + err.Error() return diff --git a/internal/db/models/acme_task_log_dao.go b/internal/db/models/acme_task_log_dao.go index 23812e30..d0fc797f 100644 --- a/internal/db/models/acme_task_log_dao.go +++ b/internal/db/models/acme_task_log_dao.go @@ -28,18 +28,18 @@ func init() { } // 生成日志 -func (this *ACMETaskLogDAO) CreateACMETaskLog(taskId int64, isOk bool, errMsg string) error { +func (this *ACMETaskLogDAO) CreateACMETaskLog(tx *dbs.Tx, taskId int64, isOk bool, errMsg string) error { op := NewACMETaskLogOperator() op.TaskId = taskId op.Error = errMsg op.IsOk = isOk - err := this.Save(op) + err := this.Save(tx, op) return err } // 取得任务的最后一条执行日志 -func (this *ACMETaskLogDAO) FindLatestACMETasKLog(taskId int64) (*ACMETaskLog, error) { - one, err := this.Query(). +func (this *ACMETaskLogDAO) FindLatestACMETasKLog(tx *dbs.Tx, taskId int64) (*ACMETaskLog, error) { + one, err := this.Query(tx). Attr("taskId", taskId). DescPk(). Find() diff --git a/internal/db/models/acme_user_dao.go b/internal/db/models/acme_user_dao.go index decfd11d..b0879065 100644 --- a/internal/db/models/acme_user_dao.go +++ b/internal/db/models/acme_user_dao.go @@ -40,8 +40,8 @@ func init() { } // 启用条目 -func (this *ACMEUserDAO) EnableACMEUser(id int64) error { - _, err := this.Query(). +func (this *ACMEUserDAO) EnableACMEUser(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", ACMEUserStateEnabled). Update() @@ -49,8 +49,8 @@ func (this *ACMEUserDAO) EnableACMEUser(id int64) error { } // 禁用条目 -func (this *ACMEUserDAO) DisableACMEUser(id int64) error { - _, err := this.Query(). +func (this *ACMEUserDAO) DisableACMEUser(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", ACMEUserStateDisabled). Update() @@ -58,8 +58,8 @@ func (this *ACMEUserDAO) DisableACMEUser(id int64) error { } // 查找启用中的条目 -func (this *ACMEUserDAO) FindEnabledACMEUser(id int64) (*ACMEUser, error) { - result, err := this.Query(). +func (this *ACMEUserDAO) FindEnabledACMEUser(tx *dbs.Tx, id int64) (*ACMEUser, error) { + result, err := this.Query(tx). Pk(id). Attr("state", ACMEUserStateEnabled). Find() @@ -70,7 +70,7 @@ func (this *ACMEUserDAO) FindEnabledACMEUser(id int64) (*ACMEUser, error) { } // 创建用户 -func (this *ACMEUserDAO) CreateACMEUser(adminId int64, userId int64, email string, description string) (int64, error) { +func (this *ACMEUserDAO) CreateACMEUser(tx *dbs.Tx, adminId int64, userId int64, email string, description string) (int64, error) { // 生成私钥 privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { @@ -90,7 +90,7 @@ func (this *ACMEUserDAO) CreateACMEUser(adminId int64, userId int64, email strin op.Description = description op.PrivateKey = privateKeyText op.State = ACMEUserStateEnabled - err = this.Save(op) + err = this.Save(tx, op) if err != nil { return 0, err } @@ -98,32 +98,32 @@ func (this *ACMEUserDAO) CreateACMEUser(adminId int64, userId int64, email strin } // 修改用户信息 -func (this *ACMEUserDAO) UpdateACMEUser(acmeUserId int64, description string) error { +func (this *ACMEUserDAO) UpdateACMEUser(tx *dbs.Tx, acmeUserId int64, description string) error { if acmeUserId <= 0 { return errors.New("invalid acmeUserId") } op := NewACMEUserOperator() op.Id = acmeUserId op.Description = description - err := this.Save(op) + err := this.Save(tx, op) return err } // 修改用户ACME注册信息 -func (this *ACMEUserDAO) UpdateACMEUserRegistration(acmeUserId int64, registrationJSON []byte) error { +func (this *ACMEUserDAO) UpdateACMEUserRegistration(tx *dbs.Tx, acmeUserId int64, registrationJSON []byte) error { if acmeUserId <= 0 { return errors.New("invalid acmeUserId") } op := NewACMEUserOperator() op.Id = acmeUserId op.Registration = registrationJSON - err := this.Save(op) + err := this.Save(tx, op) return err } // 计算用户数量 -func (this *ACMEUserDAO) CountACMEUsersWithAdminId(adminId int64, userId int64) (int64, error) { - query := this.Query() +func (this *ACMEUserDAO) CountACMEUsersWithAdminId(tx *dbs.Tx, adminId int64, userId int64) (int64, error) { + query := this.Query(tx) if adminId > 0 { query.Attr("adminId", adminId) } @@ -137,8 +137,8 @@ func (this *ACMEUserDAO) CountACMEUsersWithAdminId(adminId int64, userId int64) } // 列出当前管理员的用户 -func (this *ACMEUserDAO) ListACMEUsers(adminId int64, userId int64, offset int64, size int64) (result []*ACMEUser, err error) { - query := this.Query() +func (this *ACMEUserDAO) ListACMEUsers(tx *dbs.Tx, adminId int64, userId int64, offset int64, size int64) (result []*ACMEUser, err error) { + query := this.Query(tx) if adminId > 0 { query.Attr("adminId", adminId) } @@ -157,13 +157,13 @@ func (this *ACMEUserDAO) ListACMEUsers(adminId int64, userId int64, offset int64 } // 查找所有用户 -func (this *ACMEUserDAO) FindAllACMEUsers(adminId int64, userId int64) (result []*ACMEUser, err error) { +func (this *ACMEUserDAO) FindAllACMEUsers(tx *dbs.Tx, adminId int64, userId int64) (result []*ACMEUser, err error) { // 防止没有传入条件导致返回的数据过多 if adminId <= 0 && userId <= 0 { return nil, errors.New("'adminId' or 'userId' should not be empty") } - query := this.Query() + query := this.Query(tx) if adminId > 0 { query.Attr("adminId", adminId) } @@ -179,12 +179,12 @@ func (this *ACMEUserDAO) FindAllACMEUsers(adminId int64, userId int64) (result [ } // 检查用户权限 -func (this *ACMEUserDAO) CheckACMEUser(acmeUserId int64, adminId int64, userId int64) (bool, error) { +func (this *ACMEUserDAO) CheckACMEUser(tx *dbs.Tx, acmeUserId int64, adminId int64, userId int64) (bool, error) { if acmeUserId <= 0 { return false, nil } - query := this.Query() + query := this.Query(tx) if adminId > 0 { query.Attr("adminId", adminId) } else if userId > 0 { diff --git a/internal/db/models/admin_dao.go b/internal/db/models/admin_dao.go index 20e725b0..10baf036 100644 --- a/internal/db/models/admin_dao.go +++ b/internal/db/models/admin_dao.go @@ -36,24 +36,24 @@ func init() { } // 启用条目 -func (this *AdminDAO) EnableAdmin(id int64) (rowsAffected int64, err error) { - return this.Query(). +func (this *AdminDAO) EnableAdmin(tx *dbs.Tx, id int64) (rowsAffected int64, err error) { + return this.Query(tx). Pk(id). Set("state", AdminStateEnabled). Update() } // 禁用条目 -func (this *AdminDAO) DisableAdmin(id int64) (rowsAffected int64, err error) { - return this.Query(). +func (this *AdminDAO) DisableAdmin(tx *dbs.Tx, id int64) (rowsAffected int64, err error) { + return this.Query(tx). Pk(id). Set("state", AdminStateDisabled). Update() } // 查找启用中的条目 -func (this *AdminDAO) FindEnabledAdmin(id int64) (*Admin, error) { - result, err := this.Query(). +func (this *AdminDAO) FindEnabledAdmin(tx *dbs.Tx, id int64) (*Admin, error) { + result, err := this.Query(tx). Pk(id). Attr("state", AdminStateEnabled). Find() @@ -64,27 +64,27 @@ func (this *AdminDAO) FindEnabledAdmin(id int64) (*Admin, error) { } // 检查管理员是否存在 -func (this *AdminDAO) ExistEnabledAdmin(adminId int64) (bool, error) { - return this.Query(). +func (this *AdminDAO) ExistEnabledAdmin(tx *dbs.Tx, adminId int64) (bool, error) { + return this.Query(tx). Pk(adminId). State(AdminStateEnabled). Exist() } // 获取管理员名称 -func (this *AdminDAO) FindAdminFullname(adminId int64) (string, error) { - return this.Query(). +func (this *AdminDAO) FindAdminFullname(tx *dbs.Tx, adminId int64) (string, error) { + return this.Query(tx). Pk(adminId). Result("fullname"). FindStringCol("") } // 检查用户名、密码 -func (this *AdminDAO) CheckAdminPassword(username string, encryptedPassword string) (int64, error) { +func (this *AdminDAO) CheckAdminPassword(tx *dbs.Tx, username string, encryptedPassword string) (int64, error) { if len(username) == 0 || len(encryptedPassword) == 0 { return 0, nil } - return this.Query(). + return this.Query(tx). Attr("username", username). Attr("password", encryptedPassword). Attr("state", AdminStateEnabled). @@ -94,8 +94,8 @@ func (this *AdminDAO) CheckAdminPassword(username string, encryptedPassword stri } // 根据用户名查询管理员ID -func (this *AdminDAO) FindAdminIdWithUsername(username string) (int64, error) { - one, err := this.Query(). +func (this *AdminDAO) FindAdminIdWithUsername(tx *dbs.Tx, username string) (int64, error) { + one, err := this.Query(tx). Attr("username", username). State(AdminStateEnabled). ResultPk(). @@ -110,19 +110,19 @@ func (this *AdminDAO) FindAdminIdWithUsername(username string) (int64, error) { } // 更改管理员密码 -func (this *AdminDAO) UpdateAdminPassword(adminId int64, password string) error { +func (this *AdminDAO) UpdateAdminPassword(tx *dbs.Tx, adminId int64, password string) error { if adminId <= 0 { return errors.New("invalid adminId") } op := NewAdminOperator() op.Id = adminId op.Password = stringutil.Md5(password) - err := this.Save(op) + err := this.Save(tx, op) return err } // 创建管理员 -func (this *AdminDAO) CreateAdmin(username string, password string, fullname string, isSuper bool, modulesJSON []byte) (int64, error) { +func (this *AdminDAO) CreateAdmin(tx *dbs.Tx, username string, password string, fullname string, isSuper bool, modulesJSON []byte) (int64, error) { op := NewAdminOperator() op.IsOn = true op.State = AdminStateEnabled @@ -135,7 +135,7 @@ func (this *AdminDAO) CreateAdmin(username string, password string, fullname str } else { op.Modules = "[]" } - err := this.Save(op) + err := this.Save(tx, op) if err != nil { return 0, err } @@ -143,19 +143,19 @@ func (this *AdminDAO) CreateAdmin(username string, password string, fullname str } // 修改管理员个人资料 -func (this *AdminDAO) UpdateAdminInfo(adminId int64, fullname string) error { +func (this *AdminDAO) UpdateAdminInfo(tx *dbs.Tx, adminId int64, fullname string) error { if adminId <= 0 { return errors.New("invalid adminId") } op := NewAdminOperator() op.Id = adminId op.Fullname = fullname - err := this.Save(op) + err := this.Save(tx, op) return err } // 修改管理员详细信息 -func (this *AdminDAO) UpdateAdmin(adminId int64, username string, password string, fullname string, isSuper bool, modulesJSON []byte, isOn bool) error { +func (this *AdminDAO) UpdateAdmin(tx *dbs.Tx, adminId int64, username string, password string, fullname string, isSuper bool, modulesJSON []byte, isOn bool) error { if adminId <= 0 { return errors.New("invalid adminId") } @@ -173,13 +173,13 @@ func (this *AdminDAO) UpdateAdmin(adminId int64, username string, password strin op.Modules = "[]" } op.IsOn = isOn - err := this.Save(op) + err := this.Save(tx, op) return err } // 检查用户名是否存在 -func (this *AdminDAO) CheckAdminUsername(adminId int64, username string) (bool, error) { - query := this.Query(). +func (this *AdminDAO) CheckAdminUsername(tx *dbs.Tx, adminId int64, username string) (bool, error) { + query := this.Query(tx). State(AdminStateEnabled). Attr("username", username) if adminId > 0 { @@ -191,7 +191,7 @@ func (this *AdminDAO) CheckAdminUsername(adminId int64, username string) (bool, } // 修改管理员登录信息 -func (this *AdminDAO) UpdateAdminLogin(adminId int64, username string, password string) error { +func (this *AdminDAO) UpdateAdminLogin(tx *dbs.Tx, adminId int64, username string, password string) error { if adminId <= 0 { return errors.New("invalid adminId") } @@ -201,19 +201,19 @@ func (this *AdminDAO) UpdateAdminLogin(adminId int64, username string, password if len(password) > 0 { op.Password = stringutil.Md5(password) } - err := this.Save(op) + err := this.Save(tx, op) return err } // 修改管理员可以管理的模块 -func (this *AdminDAO) UpdateAdminModules(adminId int64, allowModulesJSON []byte) error { +func (this *AdminDAO) UpdateAdminModules(tx *dbs.Tx, adminId int64, allowModulesJSON []byte) error { if adminId <= 0 { return errors.New("invalid adminId") } op := NewAdminOperator() op.Id = adminId op.Modules = allowModulesJSON - err := this.Save(op) + err := this.Save(tx, op) if err != nil { return err } @@ -221,8 +221,8 @@ func (this *AdminDAO) UpdateAdminModules(adminId int64, allowModulesJSON []byte) } // 查询所有管理的权限 -func (this *AdminDAO) FindAllAdminModules() (result []*Admin, err error) { - _, err = this.Query(). +func (this *AdminDAO) FindAllAdminModules(tx *dbs.Tx) (result []*Admin, err error) { + _, err = this.Query(tx). State(AdminStateEnabled). Attr("isOn", true). Result("id", "modules", "isSuper"). @@ -232,15 +232,15 @@ func (this *AdminDAO) FindAllAdminModules() (result []*Admin, err error) { } // 计算所有管理员数量 -func (this *AdminDAO) CountAllEnabledAdmins() (int64, error) { - return this.Query(). +func (this *AdminDAO) CountAllEnabledAdmins(tx *dbs.Tx) (int64, error) { + return this.Query(tx). State(AdminStateEnabled). Count() } // 列出单页的管理员 -func (this *AdminDAO) ListEnabledAdmins(offset int64, size int64) (result []*Admin, err error) { - _, err = this.Query(). +func (this *AdminDAO) ListEnabledAdmins(tx *dbs.Tx, offset int64, size int64) (result []*Admin, err error) { + _, err = this.Query(tx). State(AdminStateEnabled). Result("id", "isOn", "username", "fullname", "isSuper", "createdAt"). Offset(offset). diff --git a/internal/db/models/api_access_token_dao.go b/internal/db/models/api_access_token_dao.go index 61e211b4..c02ffcd1 100644 --- a/internal/db/models/api_access_token_dao.go +++ b/internal/db/models/api_access_token_dao.go @@ -30,9 +30,9 @@ func init() { } // 生成AccessToken -func (this *APIAccessTokenDAO) GenerateAccessToken(userId int64) (token string, expiresAt int64, err error) { +func (this *APIAccessTokenDAO) GenerateAccessToken(tx *dbs.Tx, userId int64) (token string, expiresAt int64, err error) { // 查询以前的 - accessToken, err := this.Query(). + accessToken, err := this.Query(tx). Attr("userId", userId). Find() if err != nil { @@ -52,13 +52,13 @@ func (this *APIAccessTokenDAO) GenerateAccessToken(userId int64) (token string, op.Token = token op.CreatedAt = time.Now().Unix() op.ExpiredAt = expiresAt - err = this.Save(op) + err = this.Save(tx, op) return } // 查找AccessToken -func (this *APIAccessTokenDAO) FindAccessToken(token string) (*APIAccessToken, error) { - one, err := this.Query(). +func (this *APIAccessTokenDAO) FindAccessToken(tx *dbs.Tx, token string) (*APIAccessToken, error) { + one, err := this.Query(tx). Attr("token", token). Find() if one == nil || err != nil { diff --git a/internal/db/models/api_node_dao.go b/internal/db/models/api_node_dao.go index 685e5a68..a60f26a9 100644 --- a/internal/db/models/api_node_dao.go +++ b/internal/db/models/api_node_dao.go @@ -39,8 +39,8 @@ func init() { } // 启用条目 -func (this *APINodeDAO) EnableAPINode(id int64) error { - _, err := this.Query(). +func (this *APINodeDAO) EnableAPINode(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", APINodeStateEnabled). Update() @@ -48,8 +48,8 @@ func (this *APINodeDAO) EnableAPINode(id int64) error { } // 禁用条目 -func (this *APINodeDAO) DisableAPINode(id int64) error { - _, err := this.Query(). +func (this *APINodeDAO) DisableAPINode(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", APINodeStateDisabled). Update() @@ -57,8 +57,8 @@ func (this *APINodeDAO) DisableAPINode(id int64) error { } // 查找启用中的条目 -func (this *APINodeDAO) FindEnabledAPINode(id int64) (*APINode, error) { - result, err := this.Query(). +func (this *APINodeDAO) FindEnabledAPINode(tx *dbs.Tx, id int64) (*APINode, error) { + result, err := this.Query(tx). Pk(id). Attr("state", APINodeStateEnabled). Find() @@ -69,8 +69,8 @@ func (this *APINodeDAO) FindEnabledAPINode(id int64) (*APINode, error) { } // 根据ID和Secret查找节点 -func (this *APINodeDAO) FindEnabledAPINodeWithUniqueIdAndSecret(uniqueId string, secret string) (*APINode, error) { - one, err := this.Query(). +func (this *APINodeDAO) FindEnabledAPINodeWithUniqueIdAndSecret(tx *dbs.Tx, uniqueId string, secret string) (*APINode, error) { + one, err := this.Query(tx). State(APINodeStateEnabled). Attr("uniqueId", uniqueId). Attr("secret", secret). @@ -82,21 +82,21 @@ func (this *APINodeDAO) FindEnabledAPINodeWithUniqueIdAndSecret(uniqueId string, } // 根据主键查找名称 -func (this *APINodeDAO) FindAPINodeName(id int64) (string, error) { - return this.Query(). +func (this *APINodeDAO) FindAPINodeName(tx *dbs.Tx, id int64) (string, error) { + return this.Query(tx). Pk(id). Result("name"). FindStringCol("") } // 创建API节点 -func (this *APINodeDAO) CreateAPINode(name string, description string, httpJSON []byte, httpsJSON []byte, restIsOn bool, restHTTPJSON []byte, restHTTPSJSON []byte, accessAddrsJSON []byte, isOn bool) (nodeId int64, err error) { - uniqueId, err := this.genUniqueId() +func (this *APINodeDAO) CreateAPINode(tx *dbs.Tx, name string, description string, httpJSON []byte, httpsJSON []byte, restIsOn bool, restHTTPJSON []byte, restHTTPSJSON []byte, accessAddrsJSON []byte, isOn bool) (nodeId int64, err error) { + uniqueId, err := this.genUniqueId(tx) if err != nil { return 0, err } secret := rands.String(32) - err = NewApiTokenDAO().CreateAPIToken(uniqueId, secret, NodeRoleAPI) + err = NewApiTokenDAO().CreateAPIToken(tx, uniqueId, secret, NodeRoleAPI) if err != nil { return } @@ -126,7 +126,7 @@ func (this *APINodeDAO) CreateAPINode(name string, description string, httpJSON } op.State = NodeStateEnabled - err = this.Save(op) + err = this.Save(tx, op) if err != nil { return } @@ -135,7 +135,7 @@ func (this *APINodeDAO) CreateAPINode(name string, description string, httpJSON } // 修改API节点 -func (this *APINodeDAO) UpdateAPINode(nodeId int64, name string, description string, httpJSON []byte, httpsJSON []byte, restIsOn bool, restHTTPJSON []byte, restHTTPSJSON []byte, accessAddrsJSON []byte, isOn bool) error { +func (this *APINodeDAO) UpdateAPINode(tx *dbs.Tx, nodeId int64, name string, description string, httpJSON []byte, httpsJSON []byte, restIsOn bool, restHTTPJSON []byte, restHTTPSJSON []byte, accessAddrsJSON []byte, isOn bool) error { if nodeId <= 0 { return errors.New("invalid nodeId") } @@ -173,13 +173,13 @@ func (this *APINodeDAO) UpdateAPINode(nodeId int64, name string, description str op.AccessAddrs = "[]" } - err := this.Save(op) + err := this.Save(tx, op) return err } // 列出所有可用API节点 -func (this *APINodeDAO) FindAllEnabledAPINodes() (result []*APINode, err error) { - _, err = this.Query(). +func (this *APINodeDAO) FindAllEnabledAPINodes(tx *dbs.Tx) (result []*APINode, err error) { + _, err = this.Query(tx). Attr("clusterId", 0). // 非集群专用 State(APINodeStateEnabled). Desc("order"). @@ -190,8 +190,8 @@ func (this *APINodeDAO) FindAllEnabledAPINodes() (result []*APINode, err error) } // 列出所有可用而且启用的API节点 -func (this *APINodeDAO) FindAllEnabledAndOnAPINodes() (result []*APINode, err error) { - _, err = this.Query(). +func (this *APINodeDAO) FindAllEnabledAndOnAPINodes(tx *dbs.Tx) (result []*APINode, err error) { + _, err = this.Query(tx). Attr("clusterId", 0). // 非集群专用 Attr("isOn", true). State(APINodeStateEnabled). @@ -203,15 +203,15 @@ func (this *APINodeDAO) FindAllEnabledAndOnAPINodes() (result []*APINode, err er } // 计算API节点数量 -func (this *APINodeDAO) CountAllEnabledAPINodes() (int64, error) { - return this.Query(). +func (this *APINodeDAO) CountAllEnabledAPINodes(tx *dbs.Tx) (int64, error) { + return this.Query(tx). State(APINodeStateEnabled). Count() } // 列出单页的API节点 -func (this *APINodeDAO) ListEnabledAPINodes(offset int64, size int64) (result []*APINode, err error) { - _, err = this.Query(). +func (this *APINodeDAO) ListEnabledAPINodes(tx *dbs.Tx, offset int64, size int64) (result []*APINode, err error) { + _, err = this.Query(tx). Attr("clusterId", 0). // 非集群专用 State(APINodeStateEnabled). Offset(offset). @@ -224,7 +224,7 @@ func (this *APINodeDAO) ListEnabledAPINodes(offset int64, size int64) (result [] } // 根据主机名和端口获取ID -func (this *APINodeDAO) FindEnabledAPINodeIdWithAddr(protocol string, host string, port int) (int64, error) { +func (this *APINodeDAO) FindEnabledAPINodeIdWithAddr(tx *dbs.Tx, protocol string, host string, port int) (int64, error) { addr := maps.Map{ "protocol": protocol, "host": host, @@ -235,7 +235,7 @@ func (this *APINodeDAO) FindEnabledAPINodeIdWithAddr(protocol string, host strin return 0, err } - one, err := this.Query(). + one, err := this.Query(tx). State(APINodeStateEnabled). Where("JSON_CONTAINS(accessAddrs, :addr)"). Param("addr", string(addrJSON)). @@ -251,8 +251,8 @@ func (this *APINodeDAO) FindEnabledAPINodeIdWithAddr(protocol string, host strin } // 设置API节点状态 -func (this *APINodeDAO) UpdateAPINodeStatus(apiNodeId int64, statusJSON []byte) error { - _, err := this.Query(). +func (this *APINodeDAO) UpdateAPINodeStatus(tx *dbs.Tx, apiNodeId int64, statusJSON []byte) error { + _, err := this.Query(tx). Pk(apiNodeId). Set("status", statusJSON). Update() @@ -260,10 +260,10 @@ func (this *APINodeDAO) UpdateAPINodeStatus(apiNodeId int64, statusJSON []byte) } // 生成唯一ID -func (this *APINodeDAO) genUniqueId() (string, error) { +func (this *APINodeDAO) genUniqueId(tx *dbs.Tx) (string, error) { for { uniqueId := rands.HexString(32) - ok, err := this.Query(). + ok, err := this.Query(tx). Attr("uniqueId", uniqueId). Exist() if err != nil { diff --git a/internal/db/models/api_node_model_ext.go b/internal/db/models/api_node_model_ext.go index 1162d035..9c398c01 100644 --- a/internal/db/models/api_node_model_ext.go +++ b/internal/db/models/api_node_model_ext.go @@ -3,6 +3,7 @@ package models import ( "encoding/json" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" + "github.com/iwind/TeaGo/dbs" ) // 解析HTTP配置 @@ -25,7 +26,7 @@ func (this *APINode) DecodeHTTP() (*serverconfigs.HTTPProtocolConfig, error) { } // 解析HTTPS配置 -func (this *APINode) DecodeHTTPS() (*serverconfigs.HTTPSProtocolConfig, error) { +func (this *APINode) DecodeHTTPS(tx *dbs.Tx) (*serverconfigs.HTTPSProtocolConfig, error) { if !IsNotNull(this.Https) { return nil, nil } @@ -43,7 +44,7 @@ func (this *APINode) DecodeHTTPS() (*serverconfigs.HTTPSProtocolConfig, error) { if config.SSLPolicyRef != nil { policyId := config.SSLPolicyRef.SSLPolicyId if policyId > 0 { - sslPolicy, err := SharedSSLPolicyDAO.ComposePolicyConfig(policyId) + sslPolicy, err := SharedSSLPolicyDAO.ComposePolicyConfig(tx, policyId) if err != nil { return nil, err } @@ -117,7 +118,7 @@ func (this *APINode) DecodeRestHTTP() (*serverconfigs.HTTPProtocolConfig, error) } // 解析HTTPS配置 -func (this *APINode) DecodeRestHTTPS() (*serverconfigs.HTTPSProtocolConfig, error) { +func (this *APINode) DecodeRestHTTPS(tx *dbs.Tx) (*serverconfigs.HTTPSProtocolConfig, error) { if this.RestIsOn != 1 { return nil, nil } @@ -138,7 +139,7 @@ func (this *APINode) DecodeRestHTTPS() (*serverconfigs.HTTPSProtocolConfig, erro if config.SSLPolicyRef != nil { policyId := config.SSLPolicyRef.SSLPolicyId if policyId > 0 { - sslPolicy, err := SharedSSLPolicyDAO.ComposePolicyConfig(policyId) + sslPolicy, err := SharedSSLPolicyDAO.ComposePolicyConfig(tx, policyId) if err != nil { return nil, err } diff --git a/internal/db/models/api_token_dao.go b/internal/db/models/api_token_dao.go index e53a2340..190b4150 100644 --- a/internal/db/models/api_token_dao.go +++ b/internal/db/models/api_token_dao.go @@ -33,24 +33,24 @@ func init() { } // 启用条目 -func (this *ApiTokenDAO) EnableApiToken(id uint32) (rowsAffected int64, err error) { - return this.Query(). +func (this *ApiTokenDAO) EnableApiToken(tx *dbs.Tx, id uint32) (rowsAffected int64, err error) { + return this.Query(tx). Pk(id). Set("state", ApiTokenStateEnabled). Update() } // 禁用条目 -func (this *ApiTokenDAO) DisableApiToken(id uint32) (rowsAffected int64, err error) { - return this.Query(). +func (this *ApiTokenDAO) DisableApiToken(tx *dbs.Tx, id uint32) (rowsAffected int64, err error) { + return this.Query(tx). Pk(id). Set("state", ApiTokenStateDisabled). Update() } // 查找启用中的条目 -func (this *ApiTokenDAO) FindEnabledApiToken(id uint32) (*ApiToken, error) { - result, err := this.Query(). +func (this *ApiTokenDAO) FindEnabledApiToken(tx *dbs.Tx, id uint32) (*ApiToken, error) { + result, err := this.Query(tx). Pk(id). Attr("state", ApiTokenStateEnabled). Find() @@ -62,8 +62,8 @@ func (this *ApiTokenDAO) FindEnabledApiToken(id uint32) (*ApiToken, error) { // 获取节点Token信息 // TODO 需要添加缓存 -func (this *ApiTokenDAO) FindEnabledTokenWithNode(nodeId string) (*ApiToken, error) { - one, err := this.Query(). +func (this *ApiTokenDAO) FindEnabledTokenWithNode(tx *dbs.Tx, nodeId string) (*ApiToken, error) { + one, err := this.Query(tx). Attr("nodeId", nodeId). State(ApiTokenStateEnabled). Find() @@ -74,8 +74,8 @@ func (this *ApiTokenDAO) FindEnabledTokenWithNode(nodeId string) (*ApiToken, err } // 根据角色获取节点 -func (this *ApiTokenDAO) FindEnabledTokenWithRole(role string) (*ApiToken, error) { - one, err := this.Query(). +func (this *ApiTokenDAO) FindEnabledTokenWithRole(tx *dbs.Tx, role string) (*ApiToken, error) { + one, err := this.Query(tx). Attr("role", role). State(ApiTokenStateEnabled). Find() @@ -86,12 +86,12 @@ func (this *ApiTokenDAO) FindEnabledTokenWithRole(role string) (*ApiToken, error } // 保存API Token -func (this *ApiTokenDAO) CreateAPIToken(nodeId string, secret string, role NodeRole) error { +func (this *ApiTokenDAO) CreateAPIToken(tx *dbs.Tx, nodeId string, secret string, role NodeRole) error { op := NewApiTokenOperator() op.NodeId = nodeId op.Secret = secret op.Role = role op.State = ApiTokenStateEnabled - err := this.Save(op) + err := this.Save(tx, op) return err } diff --git a/internal/db/models/db_node_dao.go b/internal/db/models/db_node_dao.go index 29a0b43a..96fa6966 100644 --- a/internal/db/models/db_node_dao.go +++ b/internal/db/models/db_node_dao.go @@ -35,8 +35,8 @@ func init() { } // 启用条目 -func (this *DBNodeDAO) EnableDBNode(id int64) error { - _, err := this.Query(). +func (this *DBNodeDAO) EnableDBNode(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", DBNodeStateEnabled). Update() @@ -44,8 +44,8 @@ func (this *DBNodeDAO) EnableDBNode(id int64) error { } // 禁用条目 -func (this *DBNodeDAO) DisableDBNode(id int64) error { - _, err := this.Query(). +func (this *DBNodeDAO) DisableDBNode(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", DBNodeStateDisabled). Update() @@ -53,8 +53,8 @@ func (this *DBNodeDAO) DisableDBNode(id int64) error { } // 查找启用中的条目 -func (this *DBNodeDAO) FindEnabledDBNode(id int64) (*DBNode, error) { - result, err := this.Query(). +func (this *DBNodeDAO) FindEnabledDBNode(tx *dbs.Tx, id int64) (*DBNode, error) { + result, err := this.Query(tx). Pk(id). Attr("state", DBNodeStateEnabled). Find() @@ -65,23 +65,23 @@ func (this *DBNodeDAO) FindEnabledDBNode(id int64) (*DBNode, error) { } // 根据主键查找名称 -func (this *DBNodeDAO) FindDBNodeName(id int64) (string, error) { - return this.Query(). +func (this *DBNodeDAO) FindDBNodeName(tx *dbs.Tx, id int64) (string, error) { + return this.Query(tx). Pk(id). Result("name"). FindStringCol("") } // 计算可用的节点数量 -func (this *DBNodeDAO) CountAllEnabledNodes() (int64, error) { - return this.Query(). +func (this *DBNodeDAO) CountAllEnabledNodes(tx *dbs.Tx) (int64, error) { + return this.Query(tx). State(DBNodeStateEnabled). Count() } // 获取单页的节点 -func (this *DBNodeDAO) ListEnabledNodes(offset int64, size int64) (result []*DBNode, err error) { - _, err = this.Query(). +func (this *DBNodeDAO) ListEnabledNodes(tx *dbs.Tx, offset int64, size int64) (result []*DBNode, err error) { + _, err = this.Query(tx). State(DBNodeStateEnabled). Offset(offset). Limit(size). @@ -92,7 +92,7 @@ func (this *DBNodeDAO) ListEnabledNodes(offset int64, size int64) (result []*DBN } // 创建节点 -func (this *DBNodeDAO) CreateDBNode(isOn bool, name string, description string, host string, port int32, database string, username string, password string, charset string) (int64, error) { +func (this *DBNodeDAO) CreateDBNode(tx *dbs.Tx, isOn bool, name string, description string, host string, port int32, database string, username string, password string, charset string) (int64, error) { op := NewDBNodeOperator() op.State = NodeStateEnabled op.IsOn = isOn @@ -104,7 +104,7 @@ func (this *DBNodeDAO) CreateDBNode(isOn bool, name string, description string, op.Username = username op.Password = password op.Charset = charset - err := this.Save(op) + err := this.Save(tx, op) if err != nil { return 0, err } @@ -112,7 +112,7 @@ func (this *DBNodeDAO) CreateDBNode(isOn bool, name string, description string, } // 修改节点 -func (this *DBNodeDAO) UpdateNode(nodeId int64, isOn bool, name string, description string, host string, port int32, database string, username string, password string, charset string) error { +func (this *DBNodeDAO) UpdateNode(tx *dbs.Tx, nodeId int64, isOn bool, name string, description string, host string, port int32, database string, username string, password string, charset string) error { if nodeId <= 0 { return errors.New("invalid nodeId") } @@ -127,13 +127,13 @@ func (this *DBNodeDAO) UpdateNode(nodeId int64, isOn bool, name string, descript op.Username = username op.Password = password op.Charset = charset - err := this.Save(op) + err := this.Save(tx, op) return err } // 查找所有可用的数据库节点 -func (this *DBNodeDAO) FindAllEnabledAndOnDBNodes() (result []*DBNode, err error) { - _, err = this.Query(). +func (this *DBNodeDAO) FindAllEnabledAndOnDBNodes(tx *dbs.Tx) (result []*DBNode, err error) { + _, err = this.Query(tx). State(DBNodeStateEnabled). Attr("isOn", true). Slice(&result). diff --git a/internal/db/models/db_node_initializer.go b/internal/db/models/db_node_initializer.go index a4c259b5..197bbda3 100644 --- a/internal/db/models/db_node_initializer.go +++ b/internal/db/models/db_node_initializer.go @@ -149,7 +149,7 @@ func (this *DBNodeInitializer) Start() { // 单次运行 func (this *DBNodeInitializer) loop() error { - dbNodes, err := SharedDBNodeDAO.FindAllEnabledAndOnDBNodes() + dbNodes, err := SharedDBNodeDAO.FindAllEnabledAndOnDBNodes(nil) if err != nil { return err } @@ -218,7 +218,7 @@ func (this *DBNodeInitializer) loop() error { logs.Println("[DB_NODE]create first table in database node failed: " + err.Error()) // 创建节点日志 - createLogErr := SharedNodeLogDAO.CreateLog(NodeRoleDatabase, nodeId, "error", "ACCESS_LOG", "can not create access log table: "+err.Error(), time.Now().Unix()) + createLogErr := SharedNodeLogDAO.CreateLog(nil, NodeRoleDatabase, nodeId, "error", "ACCESS_LOG", "can not create access log table: "+err.Error(), time.Now().Unix()) if createLogErr != nil { logs.Println("[NODE_LOG]" + createLogErr.Error()) } diff --git a/internal/db/models/dns_domain_dao.go b/internal/db/models/dns_domain_dao.go index 5b124c7d..017421eb 100644 --- a/internal/db/models/dns_domain_dao.go +++ b/internal/db/models/dns_domain_dao.go @@ -40,8 +40,8 @@ func init() { } // 启用条目 -func (this *DNSDomainDAO) EnableDNSDomain(id int64) error { - _, err := this.Query(). +func (this *DNSDomainDAO) EnableDNSDomain(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", DNSDomainStateEnabled). Update() @@ -49,8 +49,8 @@ func (this *DNSDomainDAO) EnableDNSDomain(id int64) error { } // 禁用条目 -func (this *DNSDomainDAO) DisableDNSDomain(id int64) error { - _, err := this.Query(). +func (this *DNSDomainDAO) DisableDNSDomain(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", DNSDomainStateDisabled). Update() @@ -58,8 +58,8 @@ func (this *DNSDomainDAO) DisableDNSDomain(id int64) error { } // 查找启用中的条目 -func (this *DNSDomainDAO) FindEnabledDNSDomain(id int64) (*DNSDomain, error) { - result, err := this.Query(). +func (this *DNSDomainDAO) FindEnabledDNSDomain(tx *dbs.Tx, id int64) (*DNSDomain, error) { + result, err := this.Query(tx). Pk(id). Attr("state", DNSDomainStateEnabled). Find() @@ -70,15 +70,15 @@ func (this *DNSDomainDAO) FindEnabledDNSDomain(id int64) (*DNSDomain, error) { } // 根据主键查找名称 -func (this *DNSDomainDAO) FindDNSDomainName(id int64) (string, error) { - return this.Query(). +func (this *DNSDomainDAO) FindDNSDomainName(tx *dbs.Tx, id int64) (string, error) { + return this.Query(tx). Pk(id). Result("name"). FindStringCol("") } // 创建域名 -func (this *DNSDomainDAO) CreateDomain(adminId int64, userId int64, providerId int64, name string) (int64, error) { +func (this *DNSDomainDAO) CreateDomain(tx *dbs.Tx, adminId int64, userId int64, providerId int64, name string) (int64, error) { op := NewDNSDomainOperator() op.ProviderId = providerId op.AdminId = adminId @@ -86,7 +86,7 @@ func (this *DNSDomainDAO) CreateDomain(adminId int64, userId int64, providerId i op.Name = name op.State = DNSDomainStateEnabled op.IsOn = true - err := this.Save(op) + err := this.Save(tx, op) if err != nil { return 0, err } @@ -94,7 +94,7 @@ func (this *DNSDomainDAO) CreateDomain(adminId int64, userId int64, providerId i } // 修改域名 -func (this *DNSDomainDAO) UpdateDomain(domainId int64, name string, isOn bool) error { +func (this *DNSDomainDAO) UpdateDomain(tx *dbs.Tx, domainId int64, name string, isOn bool) error { if domainId <= 0 { return errors.New("invalid domainId") } @@ -102,7 +102,7 @@ func (this *DNSDomainDAO) UpdateDomain(domainId int64, name string, isOn bool) e op.Id = domainId op.Name = name op.IsOn = isOn - err := this.Save(op) + err := this.Save(tx, op) if err != nil { return err } @@ -110,8 +110,8 @@ func (this *DNSDomainDAO) UpdateDomain(domainId int64, name string, isOn bool) e } // 查询一个服务商下面的所有域名 -func (this *DNSDomainDAO) FindAllEnabledDomainsWithProviderId(providerId int64) (result []*DNSDomain, err error) { - _, err = this.Query(). +func (this *DNSDomainDAO) FindAllEnabledDomainsWithProviderId(tx *dbs.Tx, providerId int64) (result []*DNSDomain, err error) { + _, err = this.Query(tx). State(DNSDomainStateEnabled). Attr("providerId", providerId). AscPk(). @@ -121,27 +121,27 @@ func (this *DNSDomainDAO) FindAllEnabledDomainsWithProviderId(providerId int64) } // 计算某个服务商下的域名数量 -func (this *DNSDomainDAO) CountAllEnabledDomainsWithProviderId(providerId int64) (int64, error) { - return this.Query(). +func (this *DNSDomainDAO) CountAllEnabledDomainsWithProviderId(tx *dbs.Tx, providerId int64) (int64, error) { + return this.Query(tx). State(DNSDomainStateEnabled). Attr("providerId", providerId). Count() } // 更新域名数据 -func (this *DNSDomainDAO) UpdateDomainData(domainId int64, data string) error { +func (this *DNSDomainDAO) UpdateDomainData(tx *dbs.Tx, domainId int64, data string) error { if domainId <= 0 { return errors.New("invalid domainId") } op := NewDNSDomainOperator() op.Id = domainId op.Data = data - err := this.Save(op) + err := this.Save(tx, op) return err } // 更新域名解析记录 -func (this *DNSDomainDAO) UpdateDomainRecords(domainId int64, recordsJSON []byte) error { +func (this *DNSDomainDAO) UpdateDomainRecords(tx *dbs.Tx, domainId int64, recordsJSON []byte) error { if domainId <= 0 { return errors.New("invalid domainId") } @@ -149,12 +149,12 @@ func (this *DNSDomainDAO) UpdateDomainRecords(domainId int64, recordsJSON []byte op.Id = domainId op.Records = recordsJSON op.DataUpdatedAt = time.Now().Unix() - err := this.Save(op) + err := this.Save(tx, op) return err } // 更新线路 -func (this *DNSDomainDAO) UpdateDomainRoutes(domainId int64, routesJSON []byte) error { +func (this *DNSDomainDAO) UpdateDomainRoutes(tx *dbs.Tx, domainId int64, routesJSON []byte) error { if domainId <= 0 { return errors.New("invalid domainId") } @@ -162,13 +162,13 @@ func (this *DNSDomainDAO) UpdateDomainRoutes(domainId int64, routesJSON []byte) op.Id = domainId op.Routes = routesJSON op.DataUpdatedAt = time.Now().Unix() - err := this.Save(op) + err := this.Save(tx, op) return err } // 查找域名线路 -func (this *DNSDomainDAO) FindDomainRoutes(domainId int64) ([]*dnsclients.Route, error) { - routes, err := this.Query(). +func (this *DNSDomainDAO) FindDomainRoutes(tx *dbs.Tx, domainId int64) ([]*dnsclients.Route, error) { + routes, err := this.Query(tx). Pk(domainId). Result("routes"). FindStringCol("") @@ -187,8 +187,8 @@ func (this *DNSDomainDAO) FindDomainRoutes(domainId int64) ([]*dnsclients.Route, } // 查找线路名称 -func (this *DNSDomainDAO) FindDomainRouteName(domainId int64, routeCode string) (string, error) { - routes, err := this.FindDomainRoutes(domainId) +func (this *DNSDomainDAO) FindDomainRouteName(tx *dbs.Tx, domainId int64, routeCode string) (string, error) { + routes, err := this.FindDomainRoutes(tx, domainId) if err != nil { return "", err } @@ -201,15 +201,15 @@ func (this *DNSDomainDAO) FindDomainRouteName(domainId int64, routeCode string) } // 判断是否有域名可选 -func (this *DNSDomainDAO) ExistAvailableDomains() (bool, error) { - subQuery, err := SharedDNSProviderDAO.Query(). +func (this *DNSDomainDAO) ExistAvailableDomains(tx *dbs.Tx) (bool, error) { + subQuery, err := SharedDNSProviderDAO.Query(tx). Where("state=1"). // 这里要使用非变量 ResultPk(). AsSQL() if err != nil { return false, err } - return this.Query(). + return this.Query(tx). State(DNSDomainStateEnabled). Attr("isOn", true). Where("providerId IN (" + subQuery + ")"). @@ -217,7 +217,7 @@ func (this *DNSDomainDAO) ExistAvailableDomains() (bool, error) { } // 检查域名解析记录是否存在 -func (this *DNSDomainDAO) ExistDomainRecord(domainId int64, recordName string, recordType string, recordRoute string, recordValue string) (bool, error) { +func (this *DNSDomainDAO) ExistDomainRecord(tx *dbs.Tx, domainId int64, recordName string, recordType string, recordRoute string, recordValue string) (bool, error) { query := maps.Map{ "name": recordName, "type": recordType, @@ -230,7 +230,7 @@ func (this *DNSDomainDAO) ExistDomainRecord(domainId int64, recordName string, r // CNAME兼容点(.)符号 if recordType == "CNAME" && !strings.HasSuffix(recordValue, ".") { - b, err := this.ExistDomainRecord(domainId, recordName, recordType, recordRoute, recordValue+".") + b, err := this.ExistDomainRecord(tx, domainId, recordName, recordType, recordRoute, recordValue+".") if err != nil { return false, err } @@ -240,7 +240,7 @@ func (this *DNSDomainDAO) ExistDomainRecord(domainId int64, recordName string, r } } recordType = strings.ToUpper(recordType) - return this.Query(). + return this.Query(tx). Pk(domainId). Where("JSON_CONTAINS(records, :query)"). Param("query", query.AsJSON()). diff --git a/internal/db/models/dns_provider_dao.go b/internal/db/models/dns_provider_dao.go index 2c6f2247..6deea611 100644 --- a/internal/db/models/dns_provider_dao.go +++ b/internal/db/models/dns_provider_dao.go @@ -36,8 +36,8 @@ func init() { } // 启用条目 -func (this *DNSProviderDAO) EnableDNSProvider(id int64) error { - _, err := this.Query(). +func (this *DNSProviderDAO) EnableDNSProvider(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", DNSProviderStateEnabled). Update() @@ -45,8 +45,8 @@ func (this *DNSProviderDAO) EnableDNSProvider(id int64) error { } // 禁用条目 -func (this *DNSProviderDAO) DisableDNSProvider(id int64) error { - _, err := this.Query(). +func (this *DNSProviderDAO) DisableDNSProvider(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", DNSProviderStateDisabled). Update() @@ -54,8 +54,8 @@ func (this *DNSProviderDAO) DisableDNSProvider(id int64) error { } // 查找启用中的条目 -func (this *DNSProviderDAO) FindEnabledDNSProvider(id int64) (*DNSProvider, error) { - result, err := this.Query(). +func (this *DNSProviderDAO) FindEnabledDNSProvider(tx *dbs.Tx, id int64) (*DNSProvider, error) { + result, err := this.Query(tx). Pk(id). Attr("state", DNSProviderStateEnabled). Find() @@ -66,7 +66,7 @@ func (this *DNSProviderDAO) FindEnabledDNSProvider(id int64) (*DNSProvider, erro } // 创建服务商 -func (this *DNSProviderDAO) CreateDNSProvider(adminId int64, userId int64, providerType string, name string, apiParamsJSON []byte) (int64, error) { +func (this *DNSProviderDAO) CreateDNSProvider(tx *dbs.Tx, adminId int64, userId int64, providerType string, name string, apiParamsJSON []byte) (int64, error) { op := NewDNSProviderOperator() op.AdminId = adminId op.UserId = userId @@ -76,7 +76,7 @@ func (this *DNSProviderDAO) CreateDNSProvider(adminId int64, userId int64, provi op.ApiParams = apiParamsJSON } op.State = DNSProviderStateEnabled - err := this.Save(op) + err := this.Save(tx, op) if err != nil { return 0, err } @@ -84,7 +84,7 @@ func (this *DNSProviderDAO) CreateDNSProvider(adminId int64, userId int64, provi } // 修改服务商 -func (this *DNSProviderDAO) UpdateDNSProvider(dnsProviderId int64, name string, apiParamsJSON []byte) error { +func (this *DNSProviderDAO) UpdateDNSProvider(tx *dbs.Tx, dnsProviderId int64, name string, apiParamsJSON []byte) error { if dnsProviderId <= 0 { return errors.New("invalid dnsProviderId") } @@ -98,7 +98,7 @@ func (this *DNSProviderDAO) UpdateDNSProvider(dnsProviderId int64, name string, op.ApiParams = apiParamsJSON } - err := this.Save(op) + err := this.Save(tx, op) if err != nil { return err } @@ -106,15 +106,15 @@ func (this *DNSProviderDAO) UpdateDNSProvider(dnsProviderId int64, name string, } // 计算服务商数量 -func (this *DNSProviderDAO) CountAllEnabledDNSProviders(adminId int64, userId int64) (int64, error) { - return NewQuery(this, adminId, userId). +func (this *DNSProviderDAO) CountAllEnabledDNSProviders(tx *dbs.Tx, adminId int64, userId int64) (int64, error) { + return NewQuery(tx, this, adminId, userId). State(DNSProviderStateEnabled). Count() } // 列出单页服务商 -func (this *DNSProviderDAO) ListEnabledDNSProviders(adminId int64, userId int64, offset int64, size int64) (result []*DNSProvider, err error) { - _, err = NewQuery(this, adminId, userId). +func (this *DNSProviderDAO) ListEnabledDNSProviders(tx *dbs.Tx, adminId int64, userId int64, offset int64, size int64) (result []*DNSProvider, err error) { + _, err = NewQuery(tx, this, adminId, userId). State(DNSProviderStateEnabled). Offset(offset). Limit(size). @@ -125,8 +125,8 @@ func (this *DNSProviderDAO) ListEnabledDNSProviders(adminId int64, userId int64, } // 列出所有服务商 -func (this *DNSProviderDAO) FindAllEnabledDNSProviders(adminId int64, userId int64) (result []*DNSProvider, err error) { - _, err = NewQuery(this, adminId, userId). +func (this *DNSProviderDAO) FindAllEnabledDNSProviders(tx *dbs.Tx, adminId int64, userId int64) (result []*DNSProvider, err error) { + _, err = NewQuery(tx, this, adminId, userId). State(DNSProviderStateEnabled). DescPk(). Slice(&result). @@ -135,8 +135,8 @@ func (this *DNSProviderDAO) FindAllEnabledDNSProviders(adminId int64, userId int } // 查询某个类型下的所有服务商 -func (this *DNSProviderDAO) FindAllEnabledDNSProvidersWithType(providerType string) (result []*DNSProvider, err error) { - _, err = this.Query(). +func (this *DNSProviderDAO) FindAllEnabledDNSProvidersWithType(tx *dbs.Tx, providerType string) (result []*DNSProvider, err error) { + _, err = this.Query(tx). State(DNSProviderStateEnabled). Attr("type", providerType). DescPk(). @@ -146,8 +146,8 @@ func (this *DNSProviderDAO) FindAllEnabledDNSProvidersWithType(providerType stri } // 更新数据更新时间 -func (this *DNSProviderDAO) UpdateProviderDataUpdatedTime(providerId int64) error { - _, err := this.Query(). +func (this *DNSProviderDAO) UpdateProviderDataUpdatedTime(tx *dbs.Tx, providerId int64) error { + _, err := this.Query(tx). Pk(providerId). Set("dataUpdatedAt", time.Now().Unix()). Update() diff --git a/internal/db/models/file_chunk_dao.go b/internal/db/models/file_chunk_dao.go index c704bc6b..27e6bc63 100644 --- a/internal/db/models/file_chunk_dao.go +++ b/internal/db/models/file_chunk_dao.go @@ -30,11 +30,11 @@ func init() { } // 创建文件Chunk -func (this *FileChunkDAO) CreateFileChunk(fileId int64, data []byte) (int64, error) { +func (this *FileChunkDAO) CreateFileChunk(tx *dbs.Tx, fileId int64, data []byte) (int64, error) { op := NewFileChunkOperator() op.FileId = fileId op.Data = data - err := this.Save(op) + err := this.Save(tx, op) if err != nil { return 0, err } @@ -42,8 +42,8 @@ func (this *FileChunkDAO) CreateFileChunk(fileId int64, data []byte) (int64, err } // 列出所有的文件Chunk -func (this *FileChunkDAO) FindAllFileChunks(fileId int64) (result []*FileChunk, err error) { - _, err = this.Query(). +func (this *FileChunkDAO) FindAllFileChunks(tx *dbs.Tx, fileId int64) (result []*FileChunk, err error) { + _, err = this.Query(tx). Attr("fileId", fileId). AscPk(). Slice(&result). @@ -52,8 +52,8 @@ func (this *FileChunkDAO) FindAllFileChunks(fileId int64) (result []*FileChunk, } // 读取文件的所有片段ID -func (this *FileChunkDAO) FindAllFileChunkIds(fileId int64) ([]int64, error) { - ones, err := this.Query(). +func (this *FileChunkDAO) FindAllFileChunkIds(tx *dbs.Tx, fileId int64) ([]int64, error) { + ones, err := this.Query(tx). Attr("fileId", fileId). AscPk(). ResultPk(). @@ -69,19 +69,19 @@ func (this *FileChunkDAO) FindAllFileChunkIds(fileId int64) ([]int64, error) { } // 删除以前的文件 -func (this *FileChunkDAO) DeleteFileChunks(fileId int64) error { +func (this *FileChunkDAO) DeleteFileChunks(tx *dbs.Tx, fileId int64) error { if fileId <= 0 { return errors.New("invalid fileId") } - _, err := this.Query(). + _, err := this.Query(tx). Attr("fileId", fileId). Delete() return err } // 根据ID查找片段 -func (this *FileChunkDAO) FindFileChunk(chunkId int64) (*FileChunk, error) { - one, err := this.Query(). +func (this *FileChunkDAO) FindFileChunk(tx *dbs.Tx, chunkId int64) (*FileChunk, error) { + one, err := this.Query(tx). Pk(chunkId). Find() if err != nil { diff --git a/internal/db/models/file_dao.go b/internal/db/models/file_dao.go index 657f6c9e..c59004cd 100644 --- a/internal/db/models/file_dao.go +++ b/internal/db/models/file_dao.go @@ -34,8 +34,8 @@ func init() { } // 启用条目 -func (this *FileDAO) EnableFile(id int64) error { - _, err := this.Query(). +func (this *FileDAO) EnableFile(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", FileStateEnabled). Update() @@ -43,8 +43,8 @@ func (this *FileDAO) EnableFile(id int64) error { } // 禁用条目 -func (this *FileDAO) DisableFile(id int64) error { - _, err := this.Query(). +func (this *FileDAO) DisableFile(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", FileStateDisabled). Update() @@ -52,8 +52,8 @@ func (this *FileDAO) DisableFile(id int64) error { } // 查找启用中的条目 -func (this *FileDAO) FindEnabledFile(id int64) (*File, error) { - result, err := this.Query(). +func (this *FileDAO) FindEnabledFile(tx *dbs.Tx, id int64) (*File, error) { + result, err := this.Query(tx). Pk(id). Attr("state", FileStateEnabled). Find() @@ -64,14 +64,14 @@ func (this *FileDAO) FindEnabledFile(id int64) (*File, error) { } // 创建文件 -func (this *FileDAO) CreateFile(businessType, description string, filename string, size int64) (int64, error) { +func (this *FileDAO) CreateFile(tx *dbs.Tx, businessType, description string, filename string, size int64) (int64, error) { op := NewFileOperator() op.Type = businessType op.Description = description op.State = FileStateEnabled op.Size = size op.Filename = filename - err := this.Save(op) + err := this.Save(tx, op) if err != nil { return 0, err } @@ -80,8 +80,8 @@ func (this *FileDAO) CreateFile(businessType, description string, filename strin } // 将文件置为已完成 -func (this *FileDAO) UpdateFileIsFinished(fileId int64) error { - _, err := this.Query(). +func (this *FileDAO) UpdateFileIsFinished(tx *dbs.Tx, fileId int64) error { + _, err := this.Query(tx). Pk(fileId). Set("isFinished", true). Update() diff --git a/internal/db/models/http_access_log_dao.go b/internal/db/models/http_access_log_dao.go index be05616e..867b96ac 100644 --- a/internal/db/models/http_access_log_dao.go +++ b/internal/db/models/http_access_log_dao.go @@ -42,7 +42,7 @@ func NewHTTPAccessLogDAO() *HTTPAccessLogDAO { } // 创建访问日志 -func (this *HTTPAccessLogDAO) CreateHTTPAccessLogs(accessLogs []*pb.HTTPAccessLog) error { +func (this *HTTPAccessLogDAO) CreateHTTPAccessLogs(tx *dbs.Tx, accessLogs []*pb.HTTPAccessLog) error { dao := randomAccessLogDAO() if dao == nil { dao = &HTTPAccessLogDAOWrapper{ @@ -50,11 +50,11 @@ func (this *HTTPAccessLogDAO) CreateHTTPAccessLogs(accessLogs []*pb.HTTPAccessLo NodeId: 0, } } - return this.CreateHTTPAccessLogsWithDAO(dao, accessLogs) + return this.CreateHTTPAccessLogsWithDAO(tx, dao, accessLogs) } // 使用特定的DAO创建访问日志 -func (this *HTTPAccessLogDAO) CreateHTTPAccessLogsWithDAO(daoWrapper *HTTPAccessLogDAOWrapper, accessLogs []*pb.HTTPAccessLog) error { +func (this *HTTPAccessLogDAO) CreateHTTPAccessLogsWithDAO(tx *dbs.Tx, daoWrapper *HTTPAccessLogDAOWrapper, accessLogs []*pb.HTTPAccessLog) error { if daoWrapper == nil { return errors.New("dao should not be nil") } @@ -90,7 +90,7 @@ func (this *HTTPAccessLogDAO) CreateHTTPAccessLogsWithDAO(daoWrapper *HTTPAccess } fields["content"] = content - _, err = dao.Query(). + _, err = dao.Query(tx). Table(table). Sets(fields). Insert() @@ -101,7 +101,7 @@ func (this *HTTPAccessLogDAO) CreateHTTPAccessLogsWithDAO(daoWrapper *HTTPAccess if err != nil { return err } - _, err = dao.Query(). + _, err = dao.Query(tx). Table(table). Sets(fields). Insert() @@ -116,7 +116,7 @@ func (this *HTTPAccessLogDAO) CreateHTTPAccessLogsWithDAO(daoWrapper *HTTPAccess } // 读取往前的 单页访问日志 -func (this *HTTPAccessLogDAO) ListAccessLogs(lastRequestId string, size int64, day string, serverId int64, reverse bool, hasError bool, firewallPolicyId int64, firewallRuleGroupId int64, firewallRuleSetId int64) (result []*HTTPAccessLog, nextLastRequestId string, hasMore bool, err error) { +func (this *HTTPAccessLogDAO) ListAccessLogs(tx *dbs.Tx, lastRequestId string, size int64, day string, serverId int64, reverse bool, hasError bool, firewallPolicyId int64, firewallRuleGroupId int64, firewallRuleSetId int64) (result []*HTTPAccessLog, nextLastRequestId string, hasMore bool, err error) { if len(day) != 8 { return } @@ -126,18 +126,18 @@ func (this *HTTPAccessLogDAO) ListAccessLogs(lastRequestId string, size int64, d size = 1000 } - result, nextLastRequestId, err = this.listAccessLogs(lastRequestId, size, day, serverId, reverse, hasError, firewallPolicyId, firewallRuleGroupId, firewallRuleSetId) + result, nextLastRequestId, err = this.listAccessLogs(tx, lastRequestId, size, day, serverId, reverse, hasError, firewallPolicyId, firewallRuleGroupId, firewallRuleSetId) if err != nil || int64(len(result)) < size { return } - moreResult, _, _ := this.listAccessLogs(nextLastRequestId, 1, day, serverId, reverse, hasError, firewallPolicyId, firewallRuleGroupId, firewallRuleSetId) + moreResult, _, _ := this.listAccessLogs(tx, nextLastRequestId, 1, day, serverId, reverse, hasError, firewallPolicyId, firewallRuleGroupId, firewallRuleSetId) hasMore = len(moreResult) > 0 return } // 读取往前的单页访问日志 -func (this *HTTPAccessLogDAO) listAccessLogs(lastRequestId string, size int64, day string, serverId int64, reverse bool, hasError bool, firewallPolicyId int64, firewallRuleGroupId int64, firewallRuleSetId int64) (result []*HTTPAccessLog, nextLastRequestId string, err error) { +func (this *HTTPAccessLogDAO) listAccessLogs(tx *dbs.Tx, lastRequestId string, size int64, day string, serverId int64, reverse bool, hasError bool, firewallPolicyId int64, firewallRuleGroupId int64, firewallRuleSetId int64) (result []*HTTPAccessLog, nextLastRequestId string, err error) { if size <= 0 { return nil, lastRequestId, nil } @@ -177,7 +177,7 @@ func (this *HTTPAccessLogDAO) listAccessLogs(lastRequestId string, size int64, d return } - query := dao.Query() + query := dao.Query(tx) // 条件 if serverId > 0 { @@ -262,7 +262,7 @@ func (this *HTTPAccessLogDAO) listAccessLogs(lastRequestId string, size int64, d } // 根据请求ID获取访问日志 -func (this *HTTPAccessLogDAO) FindAccessLogWithRequestId(requestId string) (*HTTPAccessLog, error) { +func (this *HTTPAccessLogDAO) FindAccessLogWithRequestId(tx *dbs.Tx, requestId string) (*HTTPAccessLog, error) { if !regexp.MustCompile(`^\d{30,}`).MatchString(requestId) { return nil, errors.New("invalid requestId") } @@ -301,7 +301,7 @@ func (this *HTTPAccessLogDAO) FindAccessLogWithRequestId(requestId string) (*HTT return } - one, err := dao.Query(). + one, err := dao.Query(tx). Table(tableName). Attr("requestId", requestId). Find() diff --git a/internal/db/models/http_access_log_policy_dao.go b/internal/db/models/http_access_log_policy_dao.go index fd1daf07..4adde44b 100644 --- a/internal/db/models/http_access_log_policy_dao.go +++ b/internal/db/models/http_access_log_policy_dao.go @@ -39,19 +39,19 @@ func init() { func (this *HTTPAccessLogPolicyDAO) Init() { this.DAOObject.Init() this.DAOObject.OnUpdate(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) this.DAOObject.OnInsert(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) this.DAOObject.OnDelete(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) } // 启用条目 -func (this *HTTPAccessLogPolicyDAO) EnableHTTPAccessLogPolicy(id int64) error { - _, err := this.Query(). +func (this *HTTPAccessLogPolicyDAO) EnableHTTPAccessLogPolicy(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", HTTPAccessLogPolicyStateEnabled). Update() @@ -59,8 +59,8 @@ func (this *HTTPAccessLogPolicyDAO) EnableHTTPAccessLogPolicy(id int64) error { } // 禁用条目 -func (this *HTTPAccessLogPolicyDAO) DisableHTTPAccessLogPolicy(id int64) error { - _, err := this.Query(). +func (this *HTTPAccessLogPolicyDAO) DisableHTTPAccessLogPolicy(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", HTTPAccessLogPolicyStateDisabled). Update() @@ -68,8 +68,8 @@ func (this *HTTPAccessLogPolicyDAO) DisableHTTPAccessLogPolicy(id int64) error { } // 查找启用中的条目 -func (this *HTTPAccessLogPolicyDAO) FindEnabledHTTPAccessLogPolicy(id int64) (*HTTPAccessLogPolicy, error) { - result, err := this.Query(). +func (this *HTTPAccessLogPolicyDAO) FindEnabledHTTPAccessLogPolicy(tx *dbs.Tx, id int64) (*HTTPAccessLogPolicy, error) { + result, err := this.Query(tx). Pk(id). Attr("state", HTTPAccessLogPolicyStateEnabled). Find() @@ -80,16 +80,16 @@ func (this *HTTPAccessLogPolicyDAO) FindEnabledHTTPAccessLogPolicy(id int64) (*H } // 根据主键查找名称 -func (this *HTTPAccessLogPolicyDAO) FindHTTPAccessLogPolicyName(id int64) (string, error) { - return this.Query(). +func (this *HTTPAccessLogPolicyDAO) FindHTTPAccessLogPolicyName(tx *dbs.Tx, id int64) (string, error) { + return this.Query(tx). Pk(id). Result("name"). FindStringCol("") } // 查找所有可用策略信息 -func (this *HTTPAccessLogPolicyDAO) FindAllEnabledAccessLogPolicies() (result []*HTTPAccessLogPolicy, err error) { - _, err = this.Query(). +func (this *HTTPAccessLogPolicyDAO) FindAllEnabledAccessLogPolicies(tx *dbs.Tx) (result []*HTTPAccessLogPolicy, err error) { + _, err = this.Query(tx). State(HTTPAccessLogPolicyStateEnabled). DescPk(). Slice(&result). @@ -98,8 +98,8 @@ func (this *HTTPAccessLogPolicyDAO) FindAllEnabledAccessLogPolicies() (result [] } // 组合配置 -func (this *HTTPAccessLogPolicyDAO) ComposeAccessLogPolicyConfig(policyId int64) (*serverconfigs.HTTPAccessLogStoragePolicy, error) { - policy, err := this.FindEnabledHTTPAccessLogPolicy(policyId) +func (this *HTTPAccessLogPolicyDAO) ComposeAccessLogPolicyConfig(tx *dbs.Tx, policyId int64) (*serverconfigs.HTTPAccessLogStoragePolicy, error) { + policy, err := this.FindEnabledHTTPAccessLogPolicy(tx, policyId) if err != nil { return nil, err } diff --git a/internal/db/models/http_cache_policy_dao.go b/internal/db/models/http_cache_policy_dao.go index 83059489..fedb1dfc 100644 --- a/internal/db/models/http_cache_policy_dao.go +++ b/internal/db/models/http_cache_policy_dao.go @@ -41,19 +41,19 @@ func init() { func (this *HTTPCachePolicyDAO) Init() { this.DAOObject.Init() this.DAOObject.OnUpdate(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) this.DAOObject.OnInsert(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) this.DAOObject.OnDelete(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) } // 启用条目 -func (this *HTTPCachePolicyDAO) EnableHTTPCachePolicy(id int64) error { - _, err := this.Query(). +func (this *HTTPCachePolicyDAO) EnableHTTPCachePolicy(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", HTTPCachePolicyStateEnabled). Update() @@ -61,8 +61,8 @@ func (this *HTTPCachePolicyDAO) EnableHTTPCachePolicy(id int64) error { } // 禁用条目 -func (this *HTTPCachePolicyDAO) DisableHTTPCachePolicy(id int64) error { - _, err := this.Query(). +func (this *HTTPCachePolicyDAO) DisableHTTPCachePolicy(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", HTTPCachePolicyStateDisabled). Update() @@ -70,8 +70,8 @@ func (this *HTTPCachePolicyDAO) DisableHTTPCachePolicy(id int64) error { } // 查找启用中的条目 -func (this *HTTPCachePolicyDAO) FindEnabledHTTPCachePolicy(id int64) (*HTTPCachePolicy, error) { - result, err := this.Query(). +func (this *HTTPCachePolicyDAO) FindEnabledHTTPCachePolicy(tx *dbs.Tx, id int64) (*HTTPCachePolicy, error) { + result, err := this.Query(tx). Pk(id). Attr("state", HTTPCachePolicyStateEnabled). Find() @@ -82,16 +82,16 @@ func (this *HTTPCachePolicyDAO) FindEnabledHTTPCachePolicy(id int64) (*HTTPCache } // 根据主键查找名称 -func (this *HTTPCachePolicyDAO) FindHTTPCachePolicyName(id int64) (string, error) { - return this.Query(). +func (this *HTTPCachePolicyDAO) FindHTTPCachePolicyName(tx *dbs.Tx, id int64) (string, error) { + return this.Query(tx). Pk(id). Result("name"). FindStringCol("") } // 查找所有可用的缓存策略 -func (this *HTTPCachePolicyDAO) FindAllEnabledCachePolicies() (result []*HTTPCachePolicy, err error) { - _, err = this.Query(). +func (this *HTTPCachePolicyDAO) FindAllEnabledCachePolicies(tx *dbs.Tx) (result []*HTTPCachePolicy, err error) { + _, err = this.Query(tx). State(HTTPCachePolicyStateEnabled). DescPk(). Slice(&result). @@ -100,7 +100,7 @@ func (this *HTTPCachePolicyDAO) FindAllEnabledCachePolicies() (result []*HTTPCac } // 创建缓存策略 -func (this *HTTPCachePolicyDAO) CreateCachePolicy(isOn bool, name string, description string, capacityJSON []byte, maxKeys int64, maxSizeJSON []byte, storageType string, storageOptionsJSON []byte) (int64, error) { +func (this *HTTPCachePolicyDAO) CreateCachePolicy(tx *dbs.Tx, isOn bool, name string, description string, capacityJSON []byte, maxKeys int64, maxSizeJSON []byte, storageType string, storageOptionsJSON []byte) (int64, error) { op := NewHTTPCachePolicyOperator() op.State = HTTPCachePolicyStateEnabled op.IsOn = isOn @@ -117,7 +117,7 @@ func (this *HTTPCachePolicyDAO) CreateCachePolicy(isOn bool, name string, descri if len(storageOptionsJSON) > 0 { op.Options = storageOptionsJSON } - err := this.Save(op) + err := this.Save(tx, op) if err != nil { return 0, err } @@ -125,7 +125,7 @@ func (this *HTTPCachePolicyDAO) CreateCachePolicy(isOn bool, name string, descri } // 修改缓存策略 -func (this *HTTPCachePolicyDAO) UpdateCachePolicy(policyId int64, isOn bool, name string, description string, capacityJSON []byte, maxKeys int64, maxSizeJSON []byte, storageType string, storageOptionsJSON []byte) error { +func (this *HTTPCachePolicyDAO) UpdateCachePolicy(tx *dbs.Tx, policyId int64, isOn bool, name string, description string, capacityJSON []byte, maxKeys int64, maxSizeJSON []byte, storageType string, storageOptionsJSON []byte) error { if policyId <= 0 { return errors.New("invalid policyId") } @@ -146,13 +146,13 @@ func (this *HTTPCachePolicyDAO) UpdateCachePolicy(policyId int64, isOn bool, nam if len(storageOptionsJSON) > 0 { op.Options = storageOptionsJSON } - err := this.Save(op) + err := this.Save(tx, op) return errors.Wrap(err) } // 组合配置 -func (this *HTTPCachePolicyDAO) ComposeCachePolicy(policyId int64) (*serverconfigs.HTTPCachePolicy, error) { - policy, err := this.FindEnabledHTTPCachePolicy(policyId) +func (this *HTTPCachePolicyDAO) ComposeCachePolicy(tx *dbs.Tx, policyId int64) (*serverconfigs.HTTPCachePolicy, error) { + policy, err := this.FindEnabledHTTPCachePolicy(tx, policyId) if err != nil { return nil, err } @@ -203,15 +203,15 @@ func (this *HTTPCachePolicyDAO) ComposeCachePolicy(policyId int64) (*serverconfi } // 计算可用缓存策略数量 -func (this *HTTPCachePolicyDAO) CountAllEnabledHTTPCachePolicies() (int64, error) { - return this.Query(). +func (this *HTTPCachePolicyDAO) CountAllEnabledHTTPCachePolicies(tx *dbs.Tx) (int64, error) { + return this.Query(tx). State(HTTPCachePolicyStateEnabled). Count() } // 列出单页的缓存策略 -func (this *HTTPCachePolicyDAO) ListEnabledHTTPCachePolicies(offset int64, size int64) ([]*serverconfigs.HTTPCachePolicy, error) { - ones, err := this.Query(). +func (this *HTTPCachePolicyDAO) ListEnabledHTTPCachePolicies(tx *dbs.Tx, offset int64, size int64) ([]*serverconfigs.HTTPCachePolicy, error) { + ones, err := this.Query(tx). State(HTTPCachePolicyStateEnabled). ResultPk(). Offset(offset). @@ -231,7 +231,7 @@ func (this *HTTPCachePolicyDAO) ListEnabledHTTPCachePolicies(offset int64, size cachePolicies := []*serverconfigs.HTTPCachePolicy{} for _, policyId := range cachePolicyIds { - cachePolicyConfig, err := this.ComposeCachePolicy(policyId) + cachePolicyConfig, err := this.ComposeCachePolicy(tx, policyId) if err != nil { return nil, errors.Wrap(err) } diff --git a/internal/db/models/http_firewall_policy_dao.go b/internal/db/models/http_firewall_policy_dao.go index 43bf5da6..da90e691 100644 --- a/internal/db/models/http_firewall_policy_dao.go +++ b/internal/db/models/http_firewall_policy_dao.go @@ -40,19 +40,19 @@ func init() { func (this *HTTPFirewallPolicyDAO) Init() { this.DAOObject.Init() this.DAOObject.OnUpdate(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) this.DAOObject.OnInsert(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) this.DAOObject.OnDelete(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) } // 启用条目 -func (this *HTTPFirewallPolicyDAO) EnableHTTPFirewallPolicy(id int64) error { - _, err := this.Query(). +func (this *HTTPFirewallPolicyDAO) EnableHTTPFirewallPolicy(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", HTTPFirewallPolicyStateEnabled). Update() @@ -60,8 +60,8 @@ func (this *HTTPFirewallPolicyDAO) EnableHTTPFirewallPolicy(id int64) error { } // 禁用条目 -func (this *HTTPFirewallPolicyDAO) DisableHTTPFirewallPolicy(id int64) error { - _, err := this.Query(). +func (this *HTTPFirewallPolicyDAO) DisableHTTPFirewallPolicy(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", HTTPFirewallPolicyStateDisabled). Update() @@ -69,8 +69,8 @@ func (this *HTTPFirewallPolicyDAO) DisableHTTPFirewallPolicy(id int64) error { } // 查找启用中的条目 -func (this *HTTPFirewallPolicyDAO) FindEnabledHTTPFirewallPolicy(id int64) (*HTTPFirewallPolicy, error) { - result, err := this.Query(). +func (this *HTTPFirewallPolicyDAO) FindEnabledHTTPFirewallPolicy(tx *dbs.Tx, id int64) (*HTTPFirewallPolicy, error) { + result, err := this.Query(tx). Pk(id). Attr("state", HTTPFirewallPolicyStateEnabled). Find() @@ -81,16 +81,16 @@ func (this *HTTPFirewallPolicyDAO) FindEnabledHTTPFirewallPolicy(id int64) (*HTT } // 根据主键查找名称 -func (this *HTTPFirewallPolicyDAO) FindHTTPFirewallPolicyName(id int64) (string, error) { - return this.Query(). +func (this *HTTPFirewallPolicyDAO) FindHTTPFirewallPolicyName(tx *dbs.Tx, id int64) (string, error) { + return this.Query(tx). Pk(id). Result("name"). FindStringCol("") } // 查找所有可用策略 -func (this *HTTPFirewallPolicyDAO) FindAllEnabledFirewallPolicies() (result []*HTTPFirewallPolicy, err error) { - _, err = this.Query(). +func (this *HTTPFirewallPolicyDAO) FindAllEnabledFirewallPolicies(tx *dbs.Tx) (result []*HTTPFirewallPolicy, err error) { + _, err = this.Query(tx). State(HTTPFirewallPolicyStateEnabled). DescPk(). Slice(&result). @@ -99,7 +99,7 @@ func (this *HTTPFirewallPolicyDAO) FindAllEnabledFirewallPolicies() (result []*H } // 创建策略 -func (this *HTTPFirewallPolicyDAO) CreateFirewallPolicy(isOn bool, name string, description string, inboundJSON []byte, outboundJSON []byte) (int64, error) { +func (this *HTTPFirewallPolicyDAO) CreateFirewallPolicy(tx *dbs.Tx, isOn bool, name string, description string, inboundJSON []byte, outboundJSON []byte) (int64, error) { op := NewHTTPFirewallPolicyOperator() op.State = HTTPFirewallPolicyStateEnabled op.IsOn = isOn @@ -111,12 +111,12 @@ func (this *HTTPFirewallPolicyDAO) CreateFirewallPolicy(isOn bool, name string, if len(outboundJSON) > 0 { op.Outbound = outboundJSON } - err := this.Save(op) + err := this.Save(tx, op) return types.Int64(op.Id), err } // 修改策略的Inbound和Outbound -func (this *HTTPFirewallPolicyDAO) UpdateFirewallPolicyInboundAndOutbound(policyId int64, inboundJSON []byte, outboundJSON []byte) error { +func (this *HTTPFirewallPolicyDAO) UpdateFirewallPolicyInboundAndOutbound(tx *dbs.Tx, policyId int64, inboundJSON []byte, outboundJSON []byte) error { if policyId <= 0 { return errors.New("invalid policyId") } @@ -132,12 +132,12 @@ func (this *HTTPFirewallPolicyDAO) UpdateFirewallPolicyInboundAndOutbound(policy } else { op.Outbound = "null" } - err := this.Save(op) + err := this.Save(tx, op) return err } // 修改策略的Inbound -func (this *HTTPFirewallPolicyDAO) UpdateFirewallPolicyInbound(policyId int64, inboundJSON []byte) error { +func (this *HTTPFirewallPolicyDAO) UpdateFirewallPolicyInbound(tx *dbs.Tx, policyId int64, inboundJSON []byte) error { if policyId <= 0 { return errors.New("invalid policyId") } @@ -148,12 +148,12 @@ func (this *HTTPFirewallPolicyDAO) UpdateFirewallPolicyInbound(policyId int64, i } else { op.Inbound = "null" } - err := this.Save(op) + err := this.Save(tx, op) return err } // 修改策略 -func (this *HTTPFirewallPolicyDAO) UpdateFirewallPolicy(policyId int64, isOn bool, name string, description string, inboundJSON []byte, outboundJSON []byte, blockOptionsJSON []byte) error { +func (this *HTTPFirewallPolicyDAO) UpdateFirewallPolicy(tx *dbs.Tx, policyId int64, isOn bool, name string, description string, inboundJSON []byte, outboundJSON []byte, blockOptionsJSON []byte) error { if policyId <= 0 { return errors.New("invalid policyId") } @@ -175,20 +175,20 @@ func (this *HTTPFirewallPolicyDAO) UpdateFirewallPolicy(policyId int64, isOn boo if len(blockOptionsJSON) > 0 { op.BlockOptions = blockOptionsJSON } - err := this.Save(op) + err := this.Save(tx, op) return err } // 计算所有可用的策略数量 -func (this *HTTPFirewallPolicyDAO) CountAllEnabledFirewallPolicies() (int64, error) { - return this.Query(). +func (this *HTTPFirewallPolicyDAO) CountAllEnabledFirewallPolicies(tx *dbs.Tx) (int64, error) { + return this.Query(tx). State(HTTPFirewallPolicyStateEnabled). Count() } // 列出单页的策略 -func (this *HTTPFirewallPolicyDAO) ListEnabledFirewallPolicies(offset int64, size int64) (result []*HTTPFirewallPolicy, err error) { - _, err = this.Query(). +func (this *HTTPFirewallPolicyDAO) ListEnabledFirewallPolicies(tx *dbs.Tx, offset int64, size int64) (result []*HTTPFirewallPolicy, err error) { + _, err = this.Query(tx). State(HTTPFirewallPolicyStateEnabled). Offset(offset). Limit(size). @@ -199,8 +199,8 @@ func (this *HTTPFirewallPolicyDAO) ListEnabledFirewallPolicies(offset int64, siz } // 组合策略配置 -func (this *HTTPFirewallPolicyDAO) ComposeFirewallPolicy(policyId int64) (*firewallconfigs.HTTPFirewallPolicy, error) { - policy, err := this.FindEnabledHTTPFirewallPolicy(policyId) +func (this *HTTPFirewallPolicyDAO) ComposeFirewallPolicy(tx *dbs.Tx, policyId int64) (*firewallconfigs.HTTPFirewallPolicy, error) { + policy, err := this.FindEnabledHTTPFirewallPolicy(tx, policyId) if err != nil { return nil, err } @@ -226,7 +226,7 @@ func (this *HTTPFirewallPolicyDAO) ComposeFirewallPolicy(policyId int64) (*firew resultGroups := []*firewallconfigs.HTTPFirewallRuleGroup{} for _, groupRef := range inbound.GroupRefs { - groupConfig, err := SharedHTTPFirewallRuleGroupDAO.ComposeFirewallRuleGroup(groupRef.GroupId) + groupConfig, err := SharedHTTPFirewallRuleGroupDAO.ComposeFirewallRuleGroup(tx, groupRef.GroupId) if err != nil { return nil, err } @@ -254,7 +254,7 @@ func (this *HTTPFirewallPolicyDAO) ComposeFirewallPolicy(policyId int64) (*firew resultGroups := []*firewallconfigs.HTTPFirewallRuleGroup{} for _, groupRef := range outbound.GroupRefs { - groupConfig, err := SharedHTTPFirewallRuleGroupDAO.ComposeFirewallRuleGroup(groupRef.GroupId) + groupConfig, err := SharedHTTPFirewallRuleGroupDAO.ComposeFirewallRuleGroup(tx, groupRef.GroupId) if err != nil { return nil, err } diff --git a/internal/db/models/http_firewall_rule_dao.go b/internal/db/models/http_firewall_rule_dao.go index 4e5f3a52..b1b60bde 100644 --- a/internal/db/models/http_firewall_rule_dao.go +++ b/internal/db/models/http_firewall_rule_dao.go @@ -39,19 +39,19 @@ func init() { func (this *HTTPFirewallRuleDAO) Init() { this.DAOObject.Init() this.DAOObject.OnUpdate(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) this.DAOObject.OnInsert(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) this.DAOObject.OnDelete(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) } // 启用条目 -func (this *HTTPFirewallRuleDAO) EnableHTTPFirewallRule(id int64) error { - _, err := this.Query(). +func (this *HTTPFirewallRuleDAO) EnableHTTPFirewallRule(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", HTTPFirewallRuleStateEnabled). Update() @@ -59,8 +59,8 @@ func (this *HTTPFirewallRuleDAO) EnableHTTPFirewallRule(id int64) error { } // 禁用条目 -func (this *HTTPFirewallRuleDAO) DisableHTTPFirewallRule(id int64) error { - _, err := this.Query(). +func (this *HTTPFirewallRuleDAO) DisableHTTPFirewallRule(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", HTTPFirewallRuleStateDisabled). Update() @@ -68,8 +68,8 @@ func (this *HTTPFirewallRuleDAO) DisableHTTPFirewallRule(id int64) error { } // 查找启用中的条目 -func (this *HTTPFirewallRuleDAO) FindEnabledHTTPFirewallRule(id int64) (*HTTPFirewallRule, error) { - result, err := this.Query(). +func (this *HTTPFirewallRuleDAO) FindEnabledHTTPFirewallRule(tx *dbs.Tx, id int64) (*HTTPFirewallRule, error) { + result, err := this.Query(tx). Pk(id). Attr("state", HTTPFirewallRuleStateEnabled). Find() @@ -80,8 +80,8 @@ func (this *HTTPFirewallRuleDAO) FindEnabledHTTPFirewallRule(id int64) (*HTTPFir } // 组合配置 -func (this *HTTPFirewallRuleDAO) ComposeFirewallRule(ruleId int64) (*firewallconfigs.HTTPFirewallRule, error) { - rule, err := this.FindEnabledHTTPFirewallRule(ruleId) +func (this *HTTPFirewallRuleDAO) ComposeFirewallRule(tx *dbs.Tx, ruleId int64) (*firewallconfigs.HTTPFirewallRule, error) { + rule, err := this.FindEnabledHTTPFirewallRule(tx, ruleId) if err != nil { return nil, err } @@ -121,7 +121,7 @@ func (this *HTTPFirewallRuleDAO) ComposeFirewallRule(ruleId int64) (*firewallcon } // 从配置中配置规则 -func (this *HTTPFirewallRuleDAO) CreateOrUpdateRuleFromConfig(ruleConfig *firewallconfigs.HTTPFirewallRule) (int64, error) { +func (this *HTTPFirewallRuleDAO) CreateOrUpdateRuleFromConfig(tx *dbs.Tx, ruleConfig *firewallconfigs.HTTPFirewallRule) (int64, error) { op := NewHTTPFirewallRuleOperator() op.Id = ruleConfig.Id op.State = HTTPFirewallRuleStateEnabled @@ -150,7 +150,7 @@ func (this *HTTPFirewallRuleDAO) CreateOrUpdateRuleFromConfig(ruleConfig *firewa } op.CheckpointOptions = checkpointOptionsJSON } - err := this.Save(op) + err := this.Save(tx, op) if err != nil { return 0, err } diff --git a/internal/db/models/http_firewall_rule_group_dao.go b/internal/db/models/http_firewall_rule_group_dao.go index 626da28b..b2c9cdc0 100644 --- a/internal/db/models/http_firewall_rule_group_dao.go +++ b/internal/db/models/http_firewall_rule_group_dao.go @@ -40,19 +40,19 @@ func init() { func (this *HTTPFirewallRuleGroupDAO) Init() { this.DAOObject.Init() this.DAOObject.OnUpdate(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) this.DAOObject.OnInsert(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) this.DAOObject.OnDelete(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) } // 启用条目 -func (this *HTTPFirewallRuleGroupDAO) EnableHTTPFirewallRuleGroup(id int64) error { - _, err := this.Query(). +func (this *HTTPFirewallRuleGroupDAO) EnableHTTPFirewallRuleGroup(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", HTTPFirewallRuleGroupStateEnabled). Update() @@ -60,8 +60,8 @@ func (this *HTTPFirewallRuleGroupDAO) EnableHTTPFirewallRuleGroup(id int64) erro } // 禁用条目 -func (this *HTTPFirewallRuleGroupDAO) DisableHTTPFirewallRuleGroup(id int64) error { - _, err := this.Query(). +func (this *HTTPFirewallRuleGroupDAO) DisableHTTPFirewallRuleGroup(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", HTTPFirewallRuleGroupStateDisabled). Update() @@ -69,8 +69,8 @@ func (this *HTTPFirewallRuleGroupDAO) DisableHTTPFirewallRuleGroup(id int64) err } // 查找启用中的条目 -func (this *HTTPFirewallRuleGroupDAO) FindEnabledHTTPFirewallRuleGroup(id int64) (*HTTPFirewallRuleGroup, error) { - result, err := this.Query(). +func (this *HTTPFirewallRuleGroupDAO) FindEnabledHTTPFirewallRuleGroup(tx *dbs.Tx, id int64) (*HTTPFirewallRuleGroup, error) { + result, err := this.Query(tx). Pk(id). Attr("state", HTTPFirewallRuleGroupStateEnabled). Find() @@ -81,16 +81,16 @@ func (this *HTTPFirewallRuleGroupDAO) FindEnabledHTTPFirewallRuleGroup(id int64) } // 根据主键查找名称 -func (this *HTTPFirewallRuleGroupDAO) FindHTTPFirewallRuleGroupName(id int64) (string, error) { - return this.Query(). +func (this *HTTPFirewallRuleGroupDAO) FindHTTPFirewallRuleGroupName(tx *dbs.Tx, id int64) (string, error) { + return this.Query(tx). Pk(id). Result("name"). FindStringCol("") } // 组合配置 -func (this *HTTPFirewallRuleGroupDAO) ComposeFirewallRuleGroup(groupId int64) (*firewallconfigs.HTTPFirewallRuleGroup, error) { - group, err := this.FindEnabledHTTPFirewallRuleGroup(groupId) +func (this *HTTPFirewallRuleGroupDAO) ComposeFirewallRuleGroup(tx *dbs.Tx, groupId int64) (*firewallconfigs.HTTPFirewallRuleGroup, error) { + group, err := this.FindEnabledHTTPFirewallRuleGroup(tx, groupId) if err != nil { return nil, err } @@ -111,7 +111,7 @@ func (this *HTTPFirewallRuleGroupDAO) ComposeFirewallRuleGroup(groupId int64) (* return nil, err } for _, setRef := range setRefs { - setConfig, err := SharedHTTPFirewallRuleSetDAO.ComposeFirewallRuleSet(setRef.SetId) + setConfig, err := SharedHTTPFirewallRuleSetDAO.ComposeFirewallRuleSet(tx, setRef.SetId) if err != nil { return nil, err } @@ -126,7 +126,7 @@ func (this *HTTPFirewallRuleGroupDAO) ComposeFirewallRuleGroup(groupId int64) (* } // 从配置中创建分组 -func (this *HTTPFirewallRuleGroupDAO) CreateGroupFromConfig(groupConfig *firewallconfigs.HTTPFirewallRuleGroup) (int64, error) { +func (this *HTTPFirewallRuleGroupDAO) CreateGroupFromConfig(tx *dbs.Tx, groupConfig *firewallconfigs.HTTPFirewallRuleGroup) (int64, error) { op := NewHTTPFirewallRuleGroupOperator() op.IsOn = groupConfig.IsOn op.Name = groupConfig.Name @@ -137,7 +137,7 @@ func (this *HTTPFirewallRuleGroupDAO) CreateGroupFromConfig(groupConfig *firewal // sets setRefs := []*firewallconfigs.HTTPFirewallRuleSetRef{} for _, setConfig := range groupConfig.Sets { - setId, err := SharedHTTPFirewallRuleSetDAO.CreateOrUpdateSetFromConfig(setConfig) + setId, err := SharedHTTPFirewallRuleSetDAO.CreateOrUpdateSetFromConfig(tx, setConfig) if err != nil { return 0, err } @@ -151,7 +151,7 @@ func (this *HTTPFirewallRuleGroupDAO) CreateGroupFromConfig(groupConfig *firewal return 0, err } op.Sets = setRefsJSON - err = this.Save(op) + err = this.Save(tx, op) if err != nil { return 0, err } @@ -159,8 +159,8 @@ func (this *HTTPFirewallRuleGroupDAO) CreateGroupFromConfig(groupConfig *firewal } // 修改开启状态 -func (this *HTTPFirewallRuleGroupDAO) UpdateGroupIsOn(groupId int64, isOn bool) error { - _, err := this.Query(). +func (this *HTTPFirewallRuleGroupDAO) UpdateGroupIsOn(tx *dbs.Tx, groupId int64, isOn bool) error { + _, err := this.Query(tx). Pk(groupId). Set("isOn", isOn). Update() @@ -168,13 +168,13 @@ func (this *HTTPFirewallRuleGroupDAO) UpdateGroupIsOn(groupId int64, isOn bool) } // 创建分组 -func (this *HTTPFirewallRuleGroupDAO) CreateGroup(isOn bool, name string, description string) (int64, error) { +func (this *HTTPFirewallRuleGroupDAO) CreateGroup(tx *dbs.Tx, isOn bool, name string, description string) (int64, error) { op := NewHTTPFirewallRuleGroupOperator() op.State = HTTPFirewallRuleStateEnabled op.IsOn = isOn op.Name = name op.Description = description - err := this.Save(op) + err := this.Save(tx, op) if err != nil { return 0, err } @@ -182,7 +182,7 @@ func (this *HTTPFirewallRuleGroupDAO) CreateGroup(isOn bool, name string, descri } // 修改分组 -func (this *HTTPFirewallRuleGroupDAO) UpdateGroup(groupId int64, isOn bool, name string, description string) error { +func (this *HTTPFirewallRuleGroupDAO) UpdateGroup(tx *dbs.Tx, groupId int64, isOn bool, name string, description string) error { if groupId <= 0 { return errors.New("invalid groupId") } @@ -191,18 +191,18 @@ func (this *HTTPFirewallRuleGroupDAO) UpdateGroup(groupId int64, isOn bool, name op.IsOn = isOn op.Name = name op.Description = description - err := this.Save(op) + err := this.Save(tx, op) return err } // 修改分组中的规则集 -func (this *HTTPFirewallRuleGroupDAO) UpdateGroupSets(groupId int64, setsJSON []byte) error { +func (this *HTTPFirewallRuleGroupDAO) UpdateGroupSets(tx *dbs.Tx, groupId int64, setsJSON []byte) error { if groupId <= 0 { return errors.New("invalid groupId") } op := NewHTTPFirewallRuleGroupOperator() op.Id = groupId op.Sets = setsJSON - err := this.Save(op) + err := this.Save(tx, op) return err } diff --git a/internal/db/models/http_firewall_rule_set_dao.go b/internal/db/models/http_firewall_rule_set_dao.go index 0fd410ea..d6589796 100644 --- a/internal/db/models/http_firewall_rule_set_dao.go +++ b/internal/db/models/http_firewall_rule_set_dao.go @@ -41,19 +41,19 @@ func init() { func (this *HTTPFirewallRuleSetDAO) Init() { this.DAOObject.Init() this.DAOObject.OnUpdate(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) this.DAOObject.OnInsert(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) this.DAOObject.OnDelete(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) } // 启用条目 -func (this *HTTPFirewallRuleSetDAO) EnableHTTPFirewallRuleSet(id int64) error { - _, err := this.Query(). +func (this *HTTPFirewallRuleSetDAO) EnableHTTPFirewallRuleSet(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", HTTPFirewallRuleSetStateEnabled). Update() @@ -61,8 +61,8 @@ func (this *HTTPFirewallRuleSetDAO) EnableHTTPFirewallRuleSet(id int64) error { } // 禁用条目 -func (this *HTTPFirewallRuleSetDAO) DisableHTTPFirewallRuleSet(id int64) error { - _, err := this.Query(). +func (this *HTTPFirewallRuleSetDAO) DisableHTTPFirewallRuleSet(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", HTTPFirewallRuleSetStateDisabled). Update() @@ -70,8 +70,8 @@ func (this *HTTPFirewallRuleSetDAO) DisableHTTPFirewallRuleSet(id int64) error { } // 查找启用中的条目 -func (this *HTTPFirewallRuleSetDAO) FindEnabledHTTPFirewallRuleSet(id int64) (*HTTPFirewallRuleSet, error) { - result, err := this.Query(). +func (this *HTTPFirewallRuleSetDAO) FindEnabledHTTPFirewallRuleSet(tx *dbs.Tx, id int64) (*HTTPFirewallRuleSet, error) { + result, err := this.Query(tx). Pk(id). Attr("state", HTTPFirewallRuleSetStateEnabled). Find() @@ -82,16 +82,16 @@ func (this *HTTPFirewallRuleSetDAO) FindEnabledHTTPFirewallRuleSet(id int64) (*H } // 根据主键查找名称 -func (this *HTTPFirewallRuleSetDAO) FindHTTPFirewallRuleSetName(id int64) (string, error) { - return this.Query(). +func (this *HTTPFirewallRuleSetDAO) FindHTTPFirewallRuleSetName(tx *dbs.Tx, id int64) (string, error) { + return this.Query(tx). Pk(id). Result("name"). FindStringCol("") } // 组合配置 -func (this *HTTPFirewallRuleSetDAO) ComposeFirewallRuleSet(setId int64) (*firewallconfigs.HTTPFirewallRuleSet, error) { - set, err := this.FindEnabledHTTPFirewallRuleSet(setId) +func (this *HTTPFirewallRuleSetDAO) ComposeFirewallRuleSet(tx *dbs.Tx, setId int64) (*firewallconfigs.HTTPFirewallRuleSet, error) { + set, err := this.FindEnabledHTTPFirewallRuleSet(tx, setId) if err != nil { return nil, err } @@ -113,7 +113,7 @@ func (this *HTTPFirewallRuleSetDAO) ComposeFirewallRuleSet(setId int64) (*firewa return nil, err } for _, ruleRef := range ruleRefs { - ruleConfig, err := SharedHTTPFirewallRuleDAO.ComposeFirewallRule(ruleRef.RuleId) + ruleConfig, err := SharedHTTPFirewallRuleDAO.ComposeFirewallRule(tx, ruleRef.RuleId) if err != nil { return nil, err } @@ -138,7 +138,7 @@ func (this *HTTPFirewallRuleSetDAO) ComposeFirewallRuleSet(setId int64) (*firewa } // 从配置中创建规则集 -func (this *HTTPFirewallRuleSetDAO) CreateOrUpdateSetFromConfig(setConfig *firewallconfigs.HTTPFirewallRuleSet) (int64, error) { +func (this *HTTPFirewallRuleSetDAO) CreateOrUpdateSetFromConfig(tx *dbs.Tx, setConfig *firewallconfigs.HTTPFirewallRuleSet) (int64, error) { op := NewHTTPFirewallRuleSetOperator() op.State = HTTPFirewallRuleSetStateEnabled op.Id = setConfig.Id @@ -162,7 +162,7 @@ func (this *HTTPFirewallRuleSetDAO) CreateOrUpdateSetFromConfig(setConfig *firew // rules ruleRefs := []*firewallconfigs.HTTPFirewallRuleRef{} for _, ruleConfig := range setConfig.Rules { - ruleId, err := SharedHTTPFirewallRuleDAO.CreateOrUpdateRuleFromConfig(ruleConfig) + ruleId, err := SharedHTTPFirewallRuleDAO.CreateOrUpdateRuleFromConfig(tx, ruleConfig) if err != nil { return 0, err } @@ -176,7 +176,7 @@ func (this *HTTPFirewallRuleSetDAO) CreateOrUpdateSetFromConfig(setConfig *firew return 0, err } op.Rules = ruleRefsJSON - err = this.Save(op) + err = this.Save(tx, op) if err != nil { return 0, err } @@ -184,11 +184,11 @@ func (this *HTTPFirewallRuleSetDAO) CreateOrUpdateSetFromConfig(setConfig *firew } // 设置是否启用 -func (this *HTTPFirewallRuleSetDAO) UpdateRuleSetIsOn(ruleSetId int64, isOn bool) error { +func (this *HTTPFirewallRuleSetDAO) UpdateRuleSetIsOn(tx *dbs.Tx, ruleSetId int64, isOn bool) error { if ruleSetId <= 0 { return errors.New("invalid ruleSetId") } - _, err := this.Query(). + _, err := this.Query(tx). Pk(ruleSetId). Set("isOn", isOn). Update() diff --git a/internal/db/models/http_gzip_dao.go b/internal/db/models/http_gzip_dao.go index 98c11b01..07a9c1e8 100644 --- a/internal/db/models/http_gzip_dao.go +++ b/internal/db/models/http_gzip_dao.go @@ -41,19 +41,19 @@ func init() { func (this *HTTPGzipDAO) Init() { this.DAOObject.Init() this.DAOObject.OnUpdate(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) this.DAOObject.OnInsert(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) this.DAOObject.OnDelete(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) } // 启用条目 -func (this *HTTPGzipDAO) EnableHTTPGzip(id int64) error { - _, err := this.Query(). +func (this *HTTPGzipDAO) EnableHTTPGzip(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", HTTPGzipStateEnabled). Update() @@ -61,8 +61,8 @@ func (this *HTTPGzipDAO) EnableHTTPGzip(id int64) error { } // 禁用条目 -func (this *HTTPGzipDAO) DisableHTTPGzip(id int64) error { - _, err := this.Query(). +func (this *HTTPGzipDAO) DisableHTTPGzip(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", HTTPGzipStateDisabled). Update() @@ -70,8 +70,8 @@ func (this *HTTPGzipDAO) DisableHTTPGzip(id int64) error { } // 查找启用中的条目 -func (this *HTTPGzipDAO) FindEnabledHTTPGzip(id int64) (*HTTPGzip, error) { - result, err := this.Query(). +func (this *HTTPGzipDAO) FindEnabledHTTPGzip(tx *dbs.Tx, id int64) (*HTTPGzip, error) { + result, err := this.Query(tx). Pk(id). Attr("state", HTTPGzipStateEnabled). Find() @@ -82,8 +82,8 @@ func (this *HTTPGzipDAO) FindEnabledHTTPGzip(id int64) (*HTTPGzip, error) { } // 组合配置 -func (this *HTTPGzipDAO) ComposeGzipConfig(gzipId int64) (*serverconfigs.HTTPGzipConfig, error) { - gzip, err := this.FindEnabledHTTPGzip(gzipId) +func (this *HTTPGzipDAO) ComposeGzipConfig(tx *dbs.Tx, gzipId int64) (*serverconfigs.HTTPGzipConfig, error) { + gzip, err := this.FindEnabledHTTPGzip(tx, gzipId) if err != nil { return nil, err } @@ -126,7 +126,7 @@ func (this *HTTPGzipDAO) ComposeGzipConfig(gzipId int64) (*serverconfigs.HTTPGzi } // 创建Gzip -func (this *HTTPGzipDAO) CreateGzip(level int, minLengthJSON []byte, maxLengthJSON []byte, condsJSON []byte) (int64, error) { +func (this *HTTPGzipDAO) CreateGzip(tx *dbs.Tx, level int, minLengthJSON []byte, maxLengthJSON []byte, condsJSON []byte) (int64, error) { op := NewHTTPGzipOperator() op.State = HTTPGzipStateEnabled op.IsOn = true @@ -140,7 +140,7 @@ func (this *HTTPGzipDAO) CreateGzip(level int, minLengthJSON []byte, maxLengthJS if len(condsJSON) > 0 { op.Conds = JSONBytes(condsJSON) } - err := this.Save(op) + err := this.Save(tx, op) if err != nil { return 0, err } @@ -148,7 +148,7 @@ func (this *HTTPGzipDAO) CreateGzip(level int, minLengthJSON []byte, maxLengthJS } // 修改Gzip -func (this *HTTPGzipDAO) UpdateGzip(gzipId int64, level int, minLengthJSON []byte, maxLengthJSON []byte, condsJSON []byte) error { +func (this *HTTPGzipDAO) UpdateGzip(tx *dbs.Tx, gzipId int64, level int, minLengthJSON []byte, maxLengthJSON []byte, condsJSON []byte) error { if gzipId <= 0 { return errors.New("invalid gzipId") } @@ -164,6 +164,6 @@ func (this *HTTPGzipDAO) UpdateGzip(gzipId int64, level int, minLengthJSON []byt if len(condsJSON) > 0 { op.Conds = JSONBytes(condsJSON) } - err := this.Save(op) + err := this.Save(tx, op) return err } diff --git a/internal/db/models/http_header_dao.go b/internal/db/models/http_header_dao.go index e9998d29..c6deb92a 100644 --- a/internal/db/models/http_header_dao.go +++ b/internal/db/models/http_header_dao.go @@ -40,19 +40,19 @@ func init() { func (this *HTTPHeaderDAO) Init() { this.DAOObject.Init() this.DAOObject.OnUpdate(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) this.DAOObject.OnInsert(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) this.DAOObject.OnDelete(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) } // 启用条目 -func (this *HTTPHeaderDAO) EnableHTTPHeader(id int64) error { - _, err := this.Query(). +func (this *HTTPHeaderDAO) EnableHTTPHeader(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", HTTPHeaderStateEnabled). Update() @@ -60,8 +60,8 @@ func (this *HTTPHeaderDAO) EnableHTTPHeader(id int64) error { } // 禁用条目 -func (this *HTTPHeaderDAO) DisableHTTPHeader(id uint32) error { - _, err := this.Query(). +func (this *HTTPHeaderDAO) DisableHTTPHeader(tx *dbs.Tx, id uint32) error { + _, err := this.Query(tx). Pk(id). Set("state", HTTPHeaderStateDisabled). Update() @@ -69,8 +69,8 @@ func (this *HTTPHeaderDAO) DisableHTTPHeader(id uint32) error { } // 查找启用中的条目 -func (this *HTTPHeaderDAO) FindEnabledHTTPHeader(id int64) (*HTTPHeader, error) { - result, err := this.Query(). +func (this *HTTPHeaderDAO) FindEnabledHTTPHeader(tx *dbs.Tx, id int64) (*HTTPHeader, error) { + result, err := this.Query(tx). Pk(id). Attr("state", HTTPHeaderStateEnabled). Find() @@ -81,15 +81,15 @@ func (this *HTTPHeaderDAO) FindEnabledHTTPHeader(id int64) (*HTTPHeader, error) } // 根据主键查找名称 -func (this *HTTPHeaderDAO) FindHTTPHeaderName(id int64) (string, error) { - return this.Query(). +func (this *HTTPHeaderDAO) FindHTTPHeaderName(tx *dbs.Tx, id int64) (string, error) { + return this.Query(tx). Pk(id). Result("name"). FindStringCol("") } // 创建Header -func (this *HTTPHeaderDAO) CreateHeader(name string, value string) (int64, error) { +func (this *HTTPHeaderDAO) CreateHeader(tx *dbs.Tx, name string, value string) (int64, error) { op := NewHTTPHeaderOperator() op.State = HTTPHeaderStateEnabled op.IsOn = true @@ -105,7 +105,7 @@ func (this *HTTPHeaderDAO) CreateHeader(name string, value string) (int64, error } op.Status = statusJSON - err = this.Save(op) + err = this.Save(tx, op) if err != nil { return 0, err } @@ -113,7 +113,7 @@ func (this *HTTPHeaderDAO) CreateHeader(name string, value string) (int64, error } // 修改Header -func (this *HTTPHeaderDAO) UpdateHeader(headerId int64, name string, value string) error { +func (this *HTTPHeaderDAO) UpdateHeader(tx *dbs.Tx, headerId int64, name string, value string) error { if headerId <= 0 { return errors.New("invalid headerId") } @@ -122,7 +122,7 @@ func (this *HTTPHeaderDAO) UpdateHeader(headerId int64, name string, value strin op.Id = headerId op.Name = name op.Value = value - err := this.Save(op) + err := this.Save(tx, op) // TODO 更新相关配置 @@ -130,8 +130,8 @@ func (this *HTTPHeaderDAO) UpdateHeader(headerId int64, name string, value strin } // 组合Header配置 -func (this *HTTPHeaderDAO) ComposeHeaderConfig(headerId int64) (*shared.HTTPHeaderConfig, error) { - header, err := this.FindEnabledHTTPHeader(headerId) +func (this *HTTPHeaderDAO) ComposeHeaderConfig(tx *dbs.Tx, headerId int64) (*shared.HTTPHeaderConfig, error) { + header, err := this.FindEnabledHTTPHeader(tx, headerId) if err != nil { return nil, err } diff --git a/internal/db/models/http_header_policy_dao.go b/internal/db/models/http_header_policy_dao.go index aa782cdd..c4f95c6b 100644 --- a/internal/db/models/http_header_policy_dao.go +++ b/internal/db/models/http_header_policy_dao.go @@ -40,19 +40,19 @@ func init() { func (this *HTTPHeaderPolicyDAO) Init() { this.DAOObject.Init() this.DAOObject.OnUpdate(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) this.DAOObject.OnInsert(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) this.DAOObject.OnDelete(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) } // 启用条目 -func (this *HTTPHeaderPolicyDAO) EnableHTTPHeaderPolicy(id int64) error { - _, err := this.Query(). +func (this *HTTPHeaderPolicyDAO) EnableHTTPHeaderPolicy(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", HTTPHeaderPolicyStateEnabled). Update() @@ -60,8 +60,8 @@ func (this *HTTPHeaderPolicyDAO) EnableHTTPHeaderPolicy(id int64) error { } // 禁用条目 -func (this *HTTPHeaderPolicyDAO) DisableHTTPHeaderPolicy(id int64) error { - _, err := this.Query(). +func (this *HTTPHeaderPolicyDAO) DisableHTTPHeaderPolicy(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", HTTPHeaderPolicyStateDisabled). Update() @@ -69,8 +69,8 @@ func (this *HTTPHeaderPolicyDAO) DisableHTTPHeaderPolicy(id int64) error { } // 查找启用中的条目 -func (this *HTTPHeaderPolicyDAO) FindEnabledHTTPHeaderPolicy(id int64) (*HTTPHeaderPolicy, error) { - result, err := this.Query(). +func (this *HTTPHeaderPolicyDAO) FindEnabledHTTPHeaderPolicy(tx *dbs.Tx, id int64) (*HTTPHeaderPolicy, error) { + result, err := this.Query(tx). Pk(id). Attr("state", HTTPHeaderPolicyStateEnabled). Find() @@ -81,11 +81,11 @@ func (this *HTTPHeaderPolicyDAO) FindEnabledHTTPHeaderPolicy(id int64) (*HTTPHea } // 创建策略 -func (this *HTTPHeaderPolicyDAO) CreateHeaderPolicy() (int64, error) { +func (this *HTTPHeaderPolicyDAO) CreateHeaderPolicy(tx *dbs.Tx) (int64, error) { op := NewHTTPHeaderPolicyOperator() op.IsOn = true op.State = HTTPHeaderPolicyStateEnabled - err := this.Save(op) + err := this.Save(tx, op) if err != nil { return 0, err } @@ -93,7 +93,7 @@ func (this *HTTPHeaderPolicyDAO) CreateHeaderPolicy() (int64, error) { } // 修改AddHeaders -func (this *HTTPHeaderPolicyDAO) UpdateAddingHeaders(policyId int64, headersJSON []byte) error { +func (this *HTTPHeaderPolicyDAO) UpdateAddingHeaders(tx *dbs.Tx, policyId int64, headersJSON []byte) error { if policyId <= 0 { return errors.New("invalid policyId") } @@ -101,13 +101,13 @@ func (this *HTTPHeaderPolicyDAO) UpdateAddingHeaders(policyId int64, headersJSON op := NewHTTPHeaderPolicyOperator() op.Id = policyId op.AddHeaders = headersJSON - err := this.Save(op) + err := this.Save(tx, op) return err } // 修改SetHeaders -func (this *HTTPHeaderPolicyDAO) UpdateSettingHeaders(policyId int64, headersJSON []byte) error { +func (this *HTTPHeaderPolicyDAO) UpdateSettingHeaders(tx *dbs.Tx, policyId int64, headersJSON []byte) error { if policyId <= 0 { return errors.New("invalid policyId") } @@ -115,13 +115,13 @@ func (this *HTTPHeaderPolicyDAO) UpdateSettingHeaders(policyId int64, headersJSO op := NewHTTPHeaderPolicyOperator() op.Id = policyId op.SetHeaders = headersJSON - err := this.Save(op) + err := this.Save(tx, op) return err } // 修改ReplaceHeaders -func (this *HTTPHeaderPolicyDAO) UpdateReplacingHeaders(policyId int64, headersJSON []byte) error { +func (this *HTTPHeaderPolicyDAO) UpdateReplacingHeaders(tx *dbs.Tx, policyId int64, headersJSON []byte) error { if policyId <= 0 { return errors.New("invalid policyId") } @@ -129,13 +129,13 @@ func (this *HTTPHeaderPolicyDAO) UpdateReplacingHeaders(policyId int64, headersJ op := NewHTTPHeaderPolicyOperator() op.Id = policyId op.ReplaceHeaders = headersJSON - err := this.Save(op) + err := this.Save(tx, op) return err } // 修改AddTrailers -func (this *HTTPHeaderPolicyDAO) UpdateAddingTrailers(policyId int64, headersJSON []byte) error { +func (this *HTTPHeaderPolicyDAO) UpdateAddingTrailers(tx *dbs.Tx, policyId int64, headersJSON []byte) error { if policyId <= 0 { return errors.New("invalid policyId") } @@ -143,13 +143,13 @@ func (this *HTTPHeaderPolicyDAO) UpdateAddingTrailers(policyId int64, headersJSO op := NewHTTPHeaderPolicyOperator() op.Id = policyId op.AddTrailers = headersJSON - err := this.Save(op) + err := this.Save(tx, op) return err } // 修改DeleteHeaders -func (this *HTTPHeaderPolicyDAO) UpdateDeletingHeaders(policyId int64, headerNames []string) error { +func (this *HTTPHeaderPolicyDAO) UpdateDeletingHeaders(tx *dbs.Tx, policyId int64, headerNames []string) error { if policyId <= 0 { return errors.New("invalid policyId") } @@ -162,14 +162,14 @@ func (this *HTTPHeaderPolicyDAO) UpdateDeletingHeaders(policyId int64, headerNam op := NewHTTPHeaderPolicyOperator() op.Id = policyId op.DeleteHeaders = string(namesJSON) - err = this.Save(op) + err = this.Save(tx, op) return err } // 组合配置 -func (this *HTTPHeaderPolicyDAO) ComposeHeaderPolicyConfig(headerPolicyId int64) (*shared.HTTPHeaderPolicy, error) { - policy, err := this.FindEnabledHTTPHeaderPolicy(headerPolicyId) +func (this *HTTPHeaderPolicyDAO) ComposeHeaderPolicyConfig(tx *dbs.Tx, headerPolicyId int64) (*shared.HTTPHeaderPolicy, error) { + policy, err := this.FindEnabledHTTPHeaderPolicy(tx, headerPolicyId) if err != nil { return nil, err } @@ -190,7 +190,7 @@ func (this *HTTPHeaderPolicyDAO) ComposeHeaderPolicyConfig(headerPolicyId int64) } if len(refs) > 0 { for _, ref := range refs { - headerConfig, err := SharedHTTPHeaderDAO.ComposeHeaderConfig(ref.HeaderId) + headerConfig, err := SharedHTTPHeaderDAO.ComposeHeaderConfig(tx, ref.HeaderId) if err != nil { return nil, err } @@ -209,7 +209,7 @@ func (this *HTTPHeaderPolicyDAO) ComposeHeaderPolicyConfig(headerPolicyId int64) if len(refs) > 0 { resultRefs := []*shared.HTTPHeaderRef{} for _, ref := range refs { - headerConfig, err := SharedHTTPHeaderDAO.ComposeHeaderConfig(ref.HeaderId) + headerConfig, err := SharedHTTPHeaderDAO.ComposeHeaderConfig(tx, ref.HeaderId) if err != nil { return nil, err } @@ -233,7 +233,7 @@ func (this *HTTPHeaderPolicyDAO) ComposeHeaderPolicyConfig(headerPolicyId int64) if len(refs) > 0 { resultRefs := []*shared.HTTPHeaderRef{} for _, ref := range refs { - headerConfig, err := SharedHTTPHeaderDAO.ComposeHeaderConfig(ref.HeaderId) + headerConfig, err := SharedHTTPHeaderDAO.ComposeHeaderConfig(tx, ref.HeaderId) if err != nil { return nil, err } @@ -257,7 +257,7 @@ func (this *HTTPHeaderPolicyDAO) ComposeHeaderPolicyConfig(headerPolicyId int64) if len(refs) > 0 { resultRefs := []*shared.HTTPHeaderRef{} for _, ref := range refs { - headerConfig, err := SharedHTTPHeaderDAO.ComposeHeaderConfig(ref.HeaderId) + headerConfig, err := SharedHTTPHeaderDAO.ComposeHeaderConfig(tx, ref.HeaderId) if err != nil { return nil, err } diff --git a/internal/db/models/http_location_dao.go b/internal/db/models/http_location_dao.go index 360bf10c..71434e7d 100644 --- a/internal/db/models/http_location_dao.go +++ b/internal/db/models/http_location_dao.go @@ -40,19 +40,19 @@ func init() { func (this *HTTPLocationDAO) Init() { this.DAOObject.Init() this.DAOObject.OnUpdate(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) this.DAOObject.OnInsert(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) this.DAOObject.OnDelete(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) } // 启用条目 -func (this *HTTPLocationDAO) EnableHTTPLocation(id int64) error { - _, err := this.Query(). +func (this *HTTPLocationDAO) EnableHTTPLocation(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", HTTPLocationStateEnabled). Update() @@ -60,8 +60,8 @@ func (this *HTTPLocationDAO) EnableHTTPLocation(id int64) error { } // 禁用条目 -func (this *HTTPLocationDAO) DisableHTTPLocation(id int64) error { - _, err := this.Query(). +func (this *HTTPLocationDAO) DisableHTTPLocation(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", HTTPLocationStateDisabled). Update() @@ -69,8 +69,8 @@ func (this *HTTPLocationDAO) DisableHTTPLocation(id int64) error { } // 查找启用中的条目 -func (this *HTTPLocationDAO) FindEnabledHTTPLocation(id int64) (*HTTPLocation, error) { - result, err := this.Query(). +func (this *HTTPLocationDAO) FindEnabledHTTPLocation(tx *dbs.Tx, id int64) (*HTTPLocation, error) { + result, err := this.Query(tx). Pk(id). Attr("state", HTTPLocationStateEnabled). Find() @@ -81,15 +81,15 @@ func (this *HTTPLocationDAO) FindEnabledHTTPLocation(id int64) (*HTTPLocation, e } // 根据主键查找名称 -func (this *HTTPLocationDAO) FindHTTPLocationName(id int64) (string, error) { - return this.Query(). +func (this *HTTPLocationDAO) FindHTTPLocationName(tx *dbs.Tx, id int64) (string, error) { + return this.Query(tx). Pk(id). Result("name"). FindStringCol("") } // 创建路径规则 -func (this *HTTPLocationDAO) CreateLocation(parentId int64, name string, pattern string, description string, isBreak bool) (int64, error) { +func (this *HTTPLocationDAO) CreateLocation(tx *dbs.Tx, parentId int64, name string, pattern string, description string, isBreak bool) (int64, error) { op := NewHTTPLocationOperator() op.IsOn = true op.State = HTTPLocationStateEnabled @@ -98,7 +98,7 @@ func (this *HTTPLocationDAO) CreateLocation(parentId int64, name string, pattern op.Pattern = pattern op.Description = description op.IsBreak = isBreak - err := this.Save(op) + err := this.Save(tx, op) if err != nil { return 0, err } @@ -106,7 +106,7 @@ func (this *HTTPLocationDAO) CreateLocation(parentId int64, name string, pattern } // 修改路径规则 -func (this *HTTPLocationDAO) UpdateLocation(locationId int64, name string, pattern string, description string, isOn bool, isBreak bool) error { +func (this *HTTPLocationDAO) UpdateLocation(tx *dbs.Tx, locationId int64, name string, pattern string, description string, isOn bool, isBreak bool) error { if locationId <= 0 { return errors.New("invalid locationId") } @@ -117,13 +117,13 @@ func (this *HTTPLocationDAO) UpdateLocation(locationId int64, name string, patte op.Description = description op.IsOn = isOn op.IsBreak = isBreak - err := this.Save(op) + err := this.Save(tx, op) return err } // 组合配置 -func (this *HTTPLocationDAO) ComposeLocationConfig(locationId int64) (*serverconfigs.HTTPLocationConfig, error) { - location, err := this.FindEnabledHTTPLocation(locationId) +func (this *HTTPLocationDAO) ComposeLocationConfig(tx *dbs.Tx, locationId int64) (*serverconfigs.HTTPLocationConfig, error) { + location, err := this.FindEnabledHTTPLocation(tx, locationId) if err != nil { return nil, err } @@ -142,7 +142,7 @@ func (this *HTTPLocationDAO) ComposeLocationConfig(locationId int64) (*servercon // web if location.WebId > 0 { - webConfig, err := SharedHTTPWebDAO.ComposeWebConfig(int64(location.WebId)) + webConfig, err := SharedHTTPWebDAO.ComposeWebConfig(tx, int64(location.WebId)) if err != nil { return nil, err } @@ -158,7 +158,7 @@ func (this *HTTPLocationDAO) ComposeLocationConfig(locationId int64) (*servercon } config.ReverseProxyRef = ref if ref.ReverseProxyId > 0 { - reverseProxyConfig, err := SharedReverseProxyDAO.ComposeReverseProxyConfig(ref.ReverseProxyId) + reverseProxyConfig, err := SharedReverseProxyDAO.ComposeReverseProxyConfig(tx, ref.ReverseProxyId) if err != nil { return nil, err } @@ -170,8 +170,8 @@ func (this *HTTPLocationDAO) ComposeLocationConfig(locationId int64) (*servercon } // 查找反向代理设置 -func (this *HTTPLocationDAO) FindLocationReverseProxy(locationId int64) (*serverconfigs.ReverseProxyRef, error) { - refString, err := this.Query(). +func (this *HTTPLocationDAO) FindLocationReverseProxy(tx *dbs.Tx, locationId int64) (*serverconfigs.ReverseProxyRef, error) { + refString, err := this.Query(tx). Pk(locationId). Result("reverseProxy"). FindStringCol("") @@ -190,20 +190,20 @@ func (this *HTTPLocationDAO) FindLocationReverseProxy(locationId int64) (*server } // 更改反向代理设置 -func (this *HTTPLocationDAO) UpdateLocationReverseProxy(locationId int64, reverseProxyJSON []byte) error { +func (this *HTTPLocationDAO) UpdateLocationReverseProxy(tx *dbs.Tx, locationId int64, reverseProxyJSON []byte) error { if locationId <= 0 { return errors.New("invalid locationId") } op := NewHTTPLocationOperator() op.Id = locationId op.ReverseProxy = JSONBytes(reverseProxyJSON) - err := this.Save(op) + err := this.Save(tx, op) return err } // 查找WebId -func (this *HTTPLocationDAO) FindLocationWebId(locationId int64) (int64, error) { - webId, err := this.Query(). +func (this *HTTPLocationDAO) FindLocationWebId(tx *dbs.Tx, locationId int64) (int64, error) { + webId, err := this.Query(tx). Pk(locationId). Result("webId"). FindIntCol(0) @@ -211,25 +211,25 @@ func (this *HTTPLocationDAO) FindLocationWebId(locationId int64) (int64, error) } // 更改Web设置 -func (this *HTTPLocationDAO) UpdateLocationWeb(locationId int64, webId int64) error { +func (this *HTTPLocationDAO) UpdateLocationWeb(tx *dbs.Tx, locationId int64, webId int64) error { if locationId <= 0 { return errors.New("invalid locationId") } op := NewHTTPLocationOperator() op.Id = locationId op.WebId = webId - err := this.Save(op) + err := this.Save(tx, op) return err } // 转换引用为配置 -func (this *HTTPLocationDAO) ConvertLocationRefs(refs []*serverconfigs.HTTPLocationRef) (locations []*serverconfigs.HTTPLocationConfig, err error) { +func (this *HTTPLocationDAO) ConvertLocationRefs(tx *dbs.Tx, refs []*serverconfigs.HTTPLocationRef) (locations []*serverconfigs.HTTPLocationConfig, err error) { for _, ref := range refs { - config, err := this.ComposeLocationConfig(ref.LocationId) + config, err := this.ComposeLocationConfig(tx, ref.LocationId) if err != nil { return nil, err } - children, err := this.ConvertLocationRefs(ref.Children) + children, err := this.ConvertLocationRefs(tx, ref.Children) if err != nil { return nil, err } @@ -241,11 +241,11 @@ func (this *HTTPLocationDAO) ConvertLocationRefs(refs []*serverconfigs.HTTPLocat } // 根据WebId查找LocationId -func (this *HTTPLocationDAO) FindEnabledLocationIdWithWebId(webId int64) (locationId int64, err error) { +func (this *HTTPLocationDAO) FindEnabledLocationIdWithWebId(tx *dbs.Tx, webId int64) (locationId int64, err error) { if webId <= 0 { return } - return this.Query(). + return this.Query(tx). Attr("webId", webId). ResultPk(). FindInt64Col(0) diff --git a/internal/db/models/http_page_dao.go b/internal/db/models/http_page_dao.go index 6a0744cd..c1603400 100644 --- a/internal/db/models/http_page_dao.go +++ b/internal/db/models/http_page_dao.go @@ -40,19 +40,19 @@ func init() { func (this *HTTPPageDAO) Init() { this.DAOObject.Init() this.DAOObject.OnUpdate(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) this.DAOObject.OnInsert(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) this.DAOObject.OnDelete(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) } // 启用条目 -func (this *HTTPPageDAO) EnableHTTPPage(id int64) error { - _, err := this.Query(). +func (this *HTTPPageDAO) EnableHTTPPage(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", HTTPPageStateEnabled). Update() @@ -60,8 +60,8 @@ func (this *HTTPPageDAO) EnableHTTPPage(id int64) error { } // 禁用条目 -func (this *HTTPPageDAO) DisableHTTPPage(id int64) error { - _, err := this.Query(). +func (this *HTTPPageDAO) DisableHTTPPage(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", HTTPPageStateDisabled). Update() @@ -69,8 +69,8 @@ func (this *HTTPPageDAO) DisableHTTPPage(id int64) error { } // 查找启用中的条目 -func (this *HTTPPageDAO) FindEnabledHTTPPage(id int64) (*HTTPPage, error) { - result, err := this.Query(). +func (this *HTTPPageDAO) FindEnabledHTTPPage(tx *dbs.Tx, id int64) (*HTTPPage, error) { + result, err := this.Query(tx). Pk(id). Attr("state", HTTPPageStateEnabled). Find() @@ -81,7 +81,7 @@ func (this *HTTPPageDAO) FindEnabledHTTPPage(id int64) (*HTTPPage, error) { } // 创建Page -func (this *HTTPPageDAO) CreatePage(statusList []string, url string, newStatus int) (pageId int64, err error) { +func (this *HTTPPageDAO) CreatePage(tx *dbs.Tx, statusList []string, url string, newStatus int) (pageId int64, err error) { op := NewHTTPPageOperator() op.IsOn = true op.State = HTTPPageStateEnabled @@ -95,7 +95,7 @@ func (this *HTTPPageDAO) CreatePage(statusList []string, url string, newStatus i } op.Url = url op.NewStatus = newStatus - err = this.Save(op) + err = this.Save(tx, op) if err != nil { return 0, err } @@ -104,7 +104,7 @@ func (this *HTTPPageDAO) CreatePage(statusList []string, url string, newStatus i } // 修改Page -func (this *HTTPPageDAO) UpdatePage(pageId int64, statusList []string, url string, newStatus int) error { +func (this *HTTPPageDAO) UpdatePage(tx *dbs.Tx, pageId int64, statusList []string, url string, newStatus int) error { if pageId <= 0 { return errors.New("invalid pageId") } @@ -125,14 +125,14 @@ func (this *HTTPPageDAO) UpdatePage(pageId int64, statusList []string, url strin op.Url = url op.NewStatus = newStatus - err = this.Save(op) + err = this.Save(tx, op) return err } // 组合配置 -func (this *HTTPPageDAO) ComposePageConfig(pageId int64) (*serverconfigs.HTTPPageConfig, error) { - page, err := this.FindEnabledHTTPPage(pageId) +func (this *HTTPPageDAO) ComposePageConfig(tx *dbs.Tx, pageId int64) (*serverconfigs.HTTPPageConfig, error) { + page, err := this.FindEnabledHTTPPage(tx, pageId) if err != nil { return nil, err } diff --git a/internal/db/models/http_rewrite_rule_dao.go b/internal/db/models/http_rewrite_rule_dao.go index 1dce4914..31d39a0a 100644 --- a/internal/db/models/http_rewrite_rule_dao.go +++ b/internal/db/models/http_rewrite_rule_dao.go @@ -39,19 +39,19 @@ func init() { func (this *HTTPRewriteRuleDAO) Init() { this.DAOObject.Init() this.DAOObject.OnUpdate(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) this.DAOObject.OnInsert(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) this.DAOObject.OnDelete(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) } // 启用条目 -func (this *HTTPRewriteRuleDAO) EnableHTTPRewriteRule(id int64) error { - _, err := this.Query(). +func (this *HTTPRewriteRuleDAO) EnableHTTPRewriteRule(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", HTTPRewriteRuleStateEnabled). Update() @@ -59,8 +59,8 @@ func (this *HTTPRewriteRuleDAO) EnableHTTPRewriteRule(id int64) error { } // 禁用条目 -func (this *HTTPRewriteRuleDAO) DisableHTTPRewriteRule(id int64) error { - _, err := this.Query(). +func (this *HTTPRewriteRuleDAO) DisableHTTPRewriteRule(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", HTTPRewriteRuleStateDisabled). Update() @@ -68,8 +68,8 @@ func (this *HTTPRewriteRuleDAO) DisableHTTPRewriteRule(id int64) error { } // 查找启用中的条目 -func (this *HTTPRewriteRuleDAO) FindEnabledHTTPRewriteRule(id int64) (*HTTPRewriteRule, error) { - result, err := this.Query(). +func (this *HTTPRewriteRuleDAO) FindEnabledHTTPRewriteRule(tx *dbs.Tx, id int64) (*HTTPRewriteRule, error) { + result, err := this.Query(tx). Pk(id). Attr("state", HTTPRewriteRuleStateEnabled). Find() @@ -80,8 +80,8 @@ func (this *HTTPRewriteRuleDAO) FindEnabledHTTPRewriteRule(id int64) (*HTTPRewri } // 构造配置 -func (this *HTTPRewriteRuleDAO) ComposeRewriteRule(rewriteRuleId int64) (*serverconfigs.HTTPRewriteRule, error) { - rule, err := this.FindEnabledHTTPRewriteRule(rewriteRuleId) +func (this *HTTPRewriteRuleDAO) ComposeRewriteRule(tx *dbs.Tx, rewriteRuleId int64) (*serverconfigs.HTTPRewriteRule, error) { + rule, err := this.FindEnabledHTTPRewriteRule(tx, rewriteRuleId) if err != nil { return nil, err } @@ -103,7 +103,7 @@ func (this *HTTPRewriteRuleDAO) ComposeRewriteRule(rewriteRuleId int64) (*server } // 创建规则 -func (this *HTTPRewriteRuleDAO) CreateRewriteRule(pattern string, replace string, mode string, redirectStatus int, isBreak bool, proxyHost string, withQuery bool, isOn bool) (int64, error) { +func (this *HTTPRewriteRuleDAO) CreateRewriteRule(tx *dbs.Tx, pattern string, replace string, mode string, redirectStatus int, isBreak bool, proxyHost string, withQuery bool, isOn bool) (int64, error) { op := NewHTTPRewriteRuleOperator() op.State = HTTPRewriteRuleStateEnabled op.IsOn = isOn @@ -115,12 +115,12 @@ func (this *HTTPRewriteRuleDAO) CreateRewriteRule(pattern string, replace string op.IsBreak = isBreak op.WithQuery = withQuery op.ProxyHost = proxyHost - err := this.Save(op) + err := this.Save(tx, op) return types.Int64(op.Id), err } // 修改规则 -func (this *HTTPRewriteRuleDAO) UpdateRewriteRule(rewriteRuleId int64, pattern string, replace string, mode string, redirectStatus int, isBreak bool, proxyHost string, withQuery bool, isOn bool) error { +func (this *HTTPRewriteRuleDAO) UpdateRewriteRule(tx *dbs.Tx, rewriteRuleId int64, pattern string, replace string, mode string, redirectStatus int, isBreak bool, proxyHost string, withQuery bool, isOn bool) error { if rewriteRuleId <= 0 { return errors.New("invalid rewriteRuleId") } @@ -134,6 +134,6 @@ func (this *HTTPRewriteRuleDAO) UpdateRewriteRule(rewriteRuleId int64, pattern s op.IsBreak = isBreak op.WithQuery = withQuery op.ProxyHost = proxyHost - err := this.Save(op) + err := this.Save(tx, op) return err } diff --git a/internal/db/models/http_web_dao.go b/internal/db/models/http_web_dao.go index 2f75f92d..405a41d4 100644 --- a/internal/db/models/http_web_dao.go +++ b/internal/db/models/http_web_dao.go @@ -43,19 +43,19 @@ func init() { func (this *HTTPWebDAO) Init() { this.DAOObject.Init() this.DAOObject.OnUpdate(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) this.DAOObject.OnInsert(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) this.DAOObject.OnDelete(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) } // 启用条目 -func (this *HTTPWebDAO) EnableHTTPWeb(id int64) error { - _, err := this.Query(). +func (this *HTTPWebDAO) EnableHTTPWeb(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", HTTPWebStateEnabled). Update() @@ -63,8 +63,8 @@ func (this *HTTPWebDAO) EnableHTTPWeb(id int64) error { } // 禁用条目 -func (this *HTTPWebDAO) DisableHTTPWeb(id int64) error { - _, err := this.Query(). +func (this *HTTPWebDAO) DisableHTTPWeb(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", HTTPWebStateDisabled). Update() @@ -72,8 +72,8 @@ func (this *HTTPWebDAO) DisableHTTPWeb(id int64) error { } // 查找启用中的条目 -func (this *HTTPWebDAO) FindEnabledHTTPWeb(id int64) (*HTTPWeb, error) { - result, err := this.Query(). +func (this *HTTPWebDAO) FindEnabledHTTPWeb(tx *dbs.Tx, id int64) (*HTTPWeb, error) { + result, err := this.Query(tx). Pk(id). Attr("state", HTTPWebStateEnabled). Find() @@ -84,8 +84,8 @@ func (this *HTTPWebDAO) FindEnabledHTTPWeb(id int64) (*HTTPWeb, error) { } // 组合配置 -func (this *HTTPWebDAO) ComposeWebConfig(webId int64) (*serverconfigs.HTTPWebConfig, error) { - web, err := SharedHTTPWebDAO.FindEnabledHTTPWeb(webId) +func (this *HTTPWebDAO) ComposeWebConfig(tx *dbs.Tx, webId int64) (*serverconfigs.HTTPWebConfig, error) { + web, err := SharedHTTPWebDAO.FindEnabledHTTPWeb(tx, webId) if err != nil { return nil, err } @@ -116,7 +116,7 @@ func (this *HTTPWebDAO) ComposeWebConfig(webId int64) (*serverconfigs.HTTPWebCon } config.GzipRef = gzipRef - gzipConfig, err := SharedHTTPGzipDAO.ComposeGzipConfig(gzipRef.GzipId) + gzipConfig, err := SharedHTTPGzipDAO.ComposeGzipConfig(tx, gzipRef.GzipId) if err != nil { return nil, err } @@ -143,7 +143,7 @@ func (this *HTTPWebDAO) ComposeWebConfig(webId int64) (*serverconfigs.HTTPWebCon config.RequestHeaderPolicyRef = ref if ref.HeaderPolicyId > 0 { - headerPolicy, err := SharedHTTPHeaderPolicyDAO.ComposeHeaderPolicyConfig(ref.HeaderPolicyId) + headerPolicy, err := SharedHTTPHeaderPolicyDAO.ComposeHeaderPolicyConfig(tx, ref.HeaderPolicyId) if err != nil { return nil, err } @@ -162,7 +162,7 @@ func (this *HTTPWebDAO) ComposeWebConfig(webId int64) (*serverconfigs.HTTPWebCon config.ResponseHeaderPolicyRef = ref if ref.HeaderPolicyId > 0 { - headerPolicy, err := SharedHTTPHeaderPolicyDAO.ComposeHeaderPolicyConfig(ref.HeaderPolicyId) + headerPolicy, err := SharedHTTPHeaderPolicyDAO.ComposeHeaderPolicyConfig(tx, ref.HeaderPolicyId) if err != nil { return nil, err } @@ -190,7 +190,7 @@ func (this *HTTPWebDAO) ComposeWebConfig(webId int64) (*serverconfigs.HTTPWebCon return nil, err } for index, page := range pages { - pageConfig, err := SharedHTTPPageDAO.ComposePageConfig(page.Id) + pageConfig, err := SharedHTTPPageDAO.ComposePageConfig(tx, page.Id) if err != nil { return nil, err } @@ -255,7 +255,7 @@ func (this *HTTPWebDAO) ComposeWebConfig(webId int64) (*serverconfigs.HTTPWebCon if len(refs) > 0 { config.LocationRefs = refs - locations, err := SharedHTTPLocationDAO.ConvertLocationRefs(refs) + locations, err := SharedHTTPLocationDAO.ConvertLocationRefs(tx, refs) if err != nil { return nil, err } @@ -282,7 +282,7 @@ func (this *HTTPWebDAO) ComposeWebConfig(webId int64) (*serverconfigs.HTTPWebCon } config.WebsocketRef = ref if ref.WebsocketId > 0 { - websocketConfig, err := SharedHTTPWebsocketDAO.ComposeWebsocketConfig(ref.WebsocketId) + websocketConfig, err := SharedHTTPWebsocketDAO.ComposeWebsocketConfig(tx, ref.WebsocketId) if err != nil { return nil, err } @@ -300,7 +300,7 @@ func (this *HTTPWebDAO) ComposeWebConfig(webId int64) (*serverconfigs.HTTPWebCon return nil, err } for _, ref := range refs { - rewriteRule, err := SharedHTTPRewriteRuleDAO.ComposeRewriteRule(ref.RewriteRuleId) + rewriteRule, err := SharedHTTPRewriteRuleDAO.ComposeRewriteRule(tx, ref.RewriteRuleId) if err != nil { return nil, err } @@ -315,7 +315,7 @@ func (this *HTTPWebDAO) ComposeWebConfig(webId int64) (*serverconfigs.HTTPWebCon } // 创建Web配置 -func (this *HTTPWebDAO) CreateWeb(adminId int64, userId int64, rootJSON []byte) (int64, error) { +func (this *HTTPWebDAO) CreateWeb(tx *dbs.Tx, adminId int64, userId int64, rootJSON []byte) (int64, error) { op := NewHTTPWebOperator() op.State = HTTPWebStateEnabled op.AdminId = adminId @@ -323,7 +323,7 @@ func (this *HTTPWebDAO) CreateWeb(adminId int64, userId int64, rootJSON []byte) if len(rootJSON) > 0 { op.Root = JSONBytes(rootJSON) } - err := this.Save(op) + err := this.Save(tx, op) if err != nil { return 0, err } @@ -331,188 +331,188 @@ func (this *HTTPWebDAO) CreateWeb(adminId int64, userId int64, rootJSON []byte) } // 修改Web配置 -func (this *HTTPWebDAO) UpdateWeb(webId int64, rootJSON []byte) error { +func (this *HTTPWebDAO) UpdateWeb(tx *dbs.Tx, webId int64, rootJSON []byte) error { if webId <= 0 { return errors.New("invalid webId") } op := NewHTTPWebOperator() op.Id = webId op.Root = JSONBytes(rootJSON) - err := this.Save(op) + err := this.Save(tx, op) return err } // 修改Gzip配置 -func (this *HTTPWebDAO) UpdateWebGzip(webId int64, gzipJSON []byte) error { +func (this *HTTPWebDAO) UpdateWebGzip(tx *dbs.Tx, webId int64, gzipJSON []byte) error { if webId <= 0 { return errors.New("invalid webId") } op := NewHTTPWebOperator() op.Id = webId op.Gzip = JSONBytes(gzipJSON) - err := this.Save(op) + err := this.Save(tx, op) return err } // 修改字符编码 -func (this *HTTPWebDAO) UpdateWebCharset(webId int64, charsetJSON []byte) error { +func (this *HTTPWebDAO) UpdateWebCharset(tx *dbs.Tx, webId int64, charsetJSON []byte) error { if webId <= 0 { return errors.New("invalid webId") } op := NewHTTPWebOperator() op.Id = webId op.Charset = JSONBytes(charsetJSON) - err := this.Save(op) + err := this.Save(tx, op) return err } // 更改请求Header策略 -func (this *HTTPWebDAO) UpdateWebRequestHeaderPolicy(webId int64, headerPolicyJSON []byte) error { +func (this *HTTPWebDAO) UpdateWebRequestHeaderPolicy(tx *dbs.Tx, webId int64, headerPolicyJSON []byte) error { if webId <= 0 { return errors.New("invalid webId") } op := NewHTTPWebOperator() op.Id = webId op.RequestHeader = JSONBytes(headerPolicyJSON) - err := this.Save(op) + err := this.Save(tx, op) return err } // 更改响应Header策略 -func (this *HTTPWebDAO) UpdateWebResponseHeaderPolicy(webId int64, headerPolicyJSON []byte) error { +func (this *HTTPWebDAO) UpdateWebResponseHeaderPolicy(tx *dbs.Tx, webId int64, headerPolicyJSON []byte) error { if webId <= 0 { return errors.New("invalid webId") } op := NewHTTPWebOperator() op.Id = webId op.ResponseHeader = JSONBytes(headerPolicyJSON) - err := this.Save(op) + err := this.Save(tx, op) return err } // 更改特殊页面配置 -func (this *HTTPWebDAO) UpdateWebPages(webId int64, pagesJSON []byte) error { +func (this *HTTPWebDAO) UpdateWebPages(tx *dbs.Tx, webId int64, pagesJSON []byte) error { if webId <= 0 { return errors.New("invalid webId") } op := NewHTTPWebOperator() op.Id = webId op.Pages = JSONBytes(pagesJSON) - err := this.Save(op) + err := this.Save(tx, op) return err } // 更改Shutdown配置 -func (this *HTTPWebDAO) UpdateWebShutdown(webId int64, shutdownJSON []byte) error { +func (this *HTTPWebDAO) UpdateWebShutdown(tx *dbs.Tx, webId int64, shutdownJSON []byte) error { if webId <= 0 { return errors.New("invalid webId") } op := NewHTTPWebOperator() op.Id = webId op.Shutdown = JSONBytes(shutdownJSON) - err := this.Save(op) + err := this.Save(tx, op) return err } // 更改访问日志策略 -func (this *HTTPWebDAO) UpdateWebAccessLogConfig(webId int64, accessLogJSON []byte) error { +func (this *HTTPWebDAO) UpdateWebAccessLogConfig(tx *dbs.Tx, webId int64, accessLogJSON []byte) error { if webId <= 0 { return errors.New("invalid webId") } op := NewHTTPWebOperator() op.Id = webId op.AccessLog = JSONBytes(accessLogJSON) - err := this.Save(op) + err := this.Save(tx, op) return err } // 更改统计配置 -func (this *HTTPWebDAO) UpdateWebStat(webId int64, statJSON []byte) error { +func (this *HTTPWebDAO) UpdateWebStat(tx *dbs.Tx, webId int64, statJSON []byte) error { if webId <= 0 { return errors.New("invalid webId") } op := NewHTTPWebOperator() op.Id = webId op.Stat = JSONBytes(statJSON) - err := this.Save(op) + err := this.Save(tx, op) return err } // 更改缓存配置 -func (this *HTTPWebDAO) UpdateWebCache(webId int64, cacheJSON []byte) error { +func (this *HTTPWebDAO) UpdateWebCache(tx *dbs.Tx, webId int64, cacheJSON []byte) error { if webId <= 0 { return errors.New("invalid webId") } op := NewHTTPWebOperator() op.Id = webId op.Cache = JSONBytes(cacheJSON) - err := this.Save(op) + err := this.Save(tx, op) return err } // 更改防火墙配置 -func (this *HTTPWebDAO) UpdateWebFirewall(webId int64, firewallJSON []byte) error { +func (this *HTTPWebDAO) UpdateWebFirewall(tx *dbs.Tx, webId int64, firewallJSON []byte) error { if webId <= 0 { return errors.New("invalid webId") } op := NewHTTPWebOperator() op.Id = webId op.Firewall = JSONBytes(firewallJSON) - err := this.Save(op) + err := this.Save(tx, op) return err } // 更改路径规则配置 -func (this *HTTPWebDAO) UpdateWebLocations(webId int64, locationsJSON []byte) error { +func (this *HTTPWebDAO) UpdateWebLocations(tx *dbs.Tx, webId int64, locationsJSON []byte) error { if webId <= 0 { return errors.New("invalid webId") } op := NewHTTPWebOperator() op.Id = webId op.Locations = JSONBytes(locationsJSON) - err := this.Save(op) + err := this.Save(tx, op) return err } // 更改跳转到HTTPS设置 -func (this *HTTPWebDAO) UpdateWebRedirectToHTTPS(webId int64, redirectToHTTPSJSON []byte) error { +func (this *HTTPWebDAO) UpdateWebRedirectToHTTPS(tx *dbs.Tx, webId int64, redirectToHTTPSJSON []byte) error { if webId <= 0 { return errors.New("invalid webId") } op := NewHTTPWebOperator() op.Id = webId op.RedirectToHttps = JSONBytes(redirectToHTTPSJSON) - err := this.Save(op) + err := this.Save(tx, op) return err } // 修改Websocket设置 -func (this *HTTPWebDAO) UpdateWebsocket(webId int64, websocketJSON []byte) error { +func (this *HTTPWebDAO) UpdateWebsocket(tx *dbs.Tx, webId int64, websocketJSON []byte) error { if webId <= 0 { return errors.New("invalid webId") } op := NewHTTPWebOperator() op.Id = webId op.Websocket = JSONBytes(websocketJSON) - err := this.Save(op) + err := this.Save(tx, op) return err } // 修改重写规则设置 -func (this *HTTPWebDAO) UpdateWebRewriteRules(webId int64, rewriteRulesJSON []byte) error { +func (this *HTTPWebDAO) UpdateWebRewriteRules(tx *dbs.Tx, webId int64, rewriteRulesJSON []byte) error { if webId <= 0 { return errors.New("invalid webId") } op := NewHTTPWebOperator() op.Id = webId op.RewriteRules = JSONBytes(rewriteRulesJSON) - err := this.Save(op) + err := this.Save(tx, op) return err } // 根据缓存策略ID查找所有的WebId -func (this *HTTPWebDAO) FindAllWebIdsWithCachePolicyId(cachePolicyId int64) ([]int64, error) { - ones, err := this.Query(). +func (this *HTTPWebDAO) FindAllWebIdsWithCachePolicyId(tx *dbs.Tx, cachePolicyId int64) ([]int64, error) { + ones, err := this.Query(tx). State(HTTPWebStateEnabled). ResultPk(). Where(`JSON_CONTAINS(cache, '{"cachePolicyId": ` + strconv.FormatInt(cachePolicyId, 10) + ` }', '$.cacheRefs')`). @@ -527,7 +527,7 @@ func (this *HTTPWebDAO) FindAllWebIdsWithCachePolicyId(cachePolicyId int64) ([]i // 判断是否为Location for { - locationId, err := SharedHTTPLocationDAO.FindEnabledLocationIdWithWebId(webId) + locationId, err := SharedHTTPLocationDAO.FindEnabledLocationIdWithWebId(tx, webId) if err != nil { return nil, err } @@ -542,7 +542,7 @@ func (this *HTTPWebDAO) FindAllWebIdsWithCachePolicyId(cachePolicyId int64) ([]i // 查找包含此Location的Web // TODO 需要支持嵌套的Location查询 - webId, err = this.FindEnabledWebIdWithLocationId(locationId) + webId, err = this.FindEnabledWebIdWithLocationId(tx, locationId) if err != nil { return nil, err } @@ -555,8 +555,8 @@ func (this *HTTPWebDAO) FindAllWebIdsWithCachePolicyId(cachePolicyId int64) ([]i } // 根据防火墙策略ID查找所有的WebId -func (this *HTTPWebDAO) FindAllWebIdsWithHTTPFirewallPolicyId(firewallPolicyId int64) ([]int64, error) { - ones, err := this.Query(). +func (this *HTTPWebDAO) FindAllWebIdsWithHTTPFirewallPolicyId(tx *dbs.Tx, firewallPolicyId int64) ([]int64, error) { + ones, err := this.Query(tx). State(HTTPWebStateEnabled). ResultPk(). Where(`JSON_CONTAINS(firewall, '{"isOn": true, "firewallPolicyId": ` + strconv.FormatInt(firewallPolicyId, 10) + ` }')`). @@ -571,7 +571,7 @@ func (this *HTTPWebDAO) FindAllWebIdsWithHTTPFirewallPolicyId(firewallPolicyId i // 判断是否为Location for { - locationId, err := SharedHTTPLocationDAO.FindEnabledLocationIdWithWebId(webId) + locationId, err := SharedHTTPLocationDAO.FindEnabledLocationIdWithWebId(tx, webId) if err != nil { return nil, err } @@ -586,7 +586,7 @@ func (this *HTTPWebDAO) FindAllWebIdsWithHTTPFirewallPolicyId(firewallPolicyId i // 查找包含此Location的Web // TODO 需要支持嵌套的Location查询 - webId, err = this.FindEnabledWebIdWithLocationId(locationId) + webId, err = this.FindEnabledWebIdWithLocationId(tx, locationId) if err != nil { return nil, err } @@ -599,8 +599,8 @@ func (this *HTTPWebDAO) FindAllWebIdsWithHTTPFirewallPolicyId(firewallPolicyId i } // 查找包含某个Location的Web -func (this *HTTPWebDAO) FindEnabledWebIdWithLocationId(locationId int64) (webId int64, err error) { - return this.Query(). +func (this *HTTPWebDAO) FindEnabledWebIdWithLocationId(tx *dbs.Tx, locationId int64) (webId int64, err error) { + return this.Query(tx). State(HTTPWebStateEnabled). ResultPk(). Where(`JSON_CONTAINS(locations, '{"locationId": ` + strconv.FormatInt(locationId, 10) + ` }')`). diff --git a/internal/db/models/http_websocket_dao.go b/internal/db/models/http_websocket_dao.go index 394503f7..2d7aaebc 100644 --- a/internal/db/models/http_websocket_dao.go +++ b/internal/db/models/http_websocket_dao.go @@ -38,8 +38,8 @@ func init() { } // 启用条目 -func (this *HTTPWebsocketDAO) EnableHTTPWebsocket(id int64) error { - _, err := this.Query(). +func (this *HTTPWebsocketDAO) EnableHTTPWebsocket(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", HTTPWebsocketStateEnabled). Update() @@ -47,8 +47,8 @@ func (this *HTTPWebsocketDAO) EnableHTTPWebsocket(id int64) error { } // 禁用条目 -func (this *HTTPWebsocketDAO) DisableHTTPWebsocket(id int64) error { - _, err := this.Query(). +func (this *HTTPWebsocketDAO) DisableHTTPWebsocket(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", HTTPWebsocketStateDisabled). Update() @@ -56,8 +56,8 @@ func (this *HTTPWebsocketDAO) DisableHTTPWebsocket(id int64) error { } // 查找启用中的条目 -func (this *HTTPWebsocketDAO) FindEnabledHTTPWebsocket(id int64) (*HTTPWebsocket, error) { - result, err := this.Query(). +func (this *HTTPWebsocketDAO) FindEnabledHTTPWebsocket(tx *dbs.Tx, id int64) (*HTTPWebsocket, error) { + result, err := this.Query(tx). Pk(id). Attr("state", HTTPWebsocketStateEnabled). Find() @@ -68,8 +68,8 @@ func (this *HTTPWebsocketDAO) FindEnabledHTTPWebsocket(id int64) (*HTTPWebsocket } // 组合配置 -func (this *HTTPWebsocketDAO) ComposeWebsocketConfig(websocketId int64) (*serverconfigs.HTTPWebsocketConfig, error) { - websocket, err := this.FindEnabledHTTPWebsocket(websocketId) +func (this *HTTPWebsocketDAO) ComposeWebsocketConfig(tx *dbs.Tx, websocketId int64) (*serverconfigs.HTTPWebsocketConfig, error) { + websocket, err := this.FindEnabledHTTPWebsocket(tx, websocketId) if err != nil { return nil, err } @@ -106,7 +106,7 @@ func (this *HTTPWebsocketDAO) ComposeWebsocketConfig(websocketId int64) (*server } // 创建Websocket配置 -func (this *HTTPWebsocketDAO) CreateWebsocket(handshakeTimeoutJSON []byte, allowAllOrigins bool, allowedOrigins []string, requestSameOrigin bool, requestOrigin string) (websocketId int64, err error) { +func (this *HTTPWebsocketDAO) CreateWebsocket(tx *dbs.Tx, handshakeTimeoutJSON []byte, allowAllOrigins bool, allowedOrigins []string, requestSameOrigin bool, requestOrigin string) (websocketId int64, err error) { op := NewHTTPWebsocketOperator() op.IsOn = true op.State = HTTPWebsocketStateEnabled @@ -123,12 +123,12 @@ func (this *HTTPWebsocketDAO) CreateWebsocket(handshakeTimeoutJSON []byte, allow } op.RequestSameOrigin = requestSameOrigin op.RequestOrigin = requestOrigin - err = this.Save(op) + err = this.Save(tx, op) return types.Int64(op.Id), err } // 修改Websocket配置 -func (this *HTTPWebsocketDAO) UpdateWebsocket(websocketId int64, handshakeTimeoutJSON []byte, allowAllOrigins bool, allowedOrigins []string, requestSameOrigin bool, requestOrigin string) error { +func (this *HTTPWebsocketDAO) UpdateWebsocket(tx *dbs.Tx, websocketId int64, handshakeTimeoutJSON []byte, allowAllOrigins bool, allowedOrigins []string, requestSameOrigin bool, requestOrigin string) error { if websocketId <= 0 { return errors.New("invalid websocketId") } @@ -149,6 +149,6 @@ func (this *HTTPWebsocketDAO) UpdateWebsocket(websocketId int64, handshakeTimeou } op.RequestSameOrigin = requestSameOrigin op.RequestOrigin = requestOrigin - err := this.Save(op) + err := this.Save(tx, op) return err } diff --git a/internal/db/models/ip_item_dao.go b/internal/db/models/ip_item_dao.go index 01aebbc8..cfed27c7 100644 --- a/internal/db/models/ip_item_dao.go +++ b/internal/db/models/ip_item_dao.go @@ -36,8 +36,8 @@ func init() { } // 启用条目 -func (this *IPItemDAO) EnableIPItem(id int64) error { - _, err := this.Query(). +func (this *IPItemDAO) EnableIPItem(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", IPItemStateEnabled). Update() @@ -45,13 +45,13 @@ func (this *IPItemDAO) EnableIPItem(id int64) error { } // 禁用条目 -func (this *IPItemDAO) DisableIPItem(id int64) error { - version, err := SharedIPListDAO.IncreaseVersion() +func (this *IPItemDAO) DisableIPItem(tx *dbs.Tx, id int64) error { + version, err := SharedIPListDAO.IncreaseVersion(tx) if err != nil { return err } - _, err = this.Query(). + _, err = this.Query(tx). Pk(id). Set("state", IPItemStateDisabled). Set("version", version). @@ -60,8 +60,8 @@ func (this *IPItemDAO) DisableIPItem(id int64) error { } // 查找启用中的条目 -func (this *IPItemDAO) FindEnabledIPItem(id int64) (*IPItem, error) { - result, err := this.Query(). +func (this *IPItemDAO) FindEnabledIPItem(tx *dbs.Tx, id int64) (*IPItem, error) { + result, err := this.Query(tx). Pk(id). Attr("state", IPItemStateEnabled). Find() @@ -72,8 +72,8 @@ func (this *IPItemDAO) FindEnabledIPItem(id int64) (*IPItem, error) { } // 创建IP -func (this *IPItemDAO) CreateIPItem(listId int64, ipFrom string, ipTo string, expiredAt int64, reason string) (int64, error) { - version, err := SharedIPListDAO.IncreaseVersion() +func (this *IPItemDAO) CreateIPItem(tx *dbs.Tx, listId int64, ipFrom string, ipTo string, expiredAt int64, reason string) (int64, error) { + version, err := SharedIPListDAO.IncreaseVersion(tx) if err != nil { return 0, err } @@ -89,7 +89,7 @@ func (this *IPItemDAO) CreateIPItem(listId int64, ipFrom string, ipTo string, ex } op.ExpiredAt = expiredAt op.State = IPItemStateEnabled - err = this.Save(op) + err = this.Save(tx, op) if err != nil { return 0, err } @@ -97,12 +97,12 @@ func (this *IPItemDAO) CreateIPItem(listId int64, ipFrom string, ipTo string, ex } // 修改IP -func (this *IPItemDAO) UpdateIPItem(itemId int64, ipFrom string, ipTo string, expiredAt int64, reason string) error { +func (this *IPItemDAO) UpdateIPItem(tx *dbs.Tx, itemId int64, ipFrom string, ipTo string, expiredAt int64, reason string) error { if itemId <= 0 { return errors.New("invalid itemId") } - listId, err := this.Query(). + listId, err := this.Query(tx). Pk(itemId). Result("listId"). FindInt64Col(0) @@ -113,7 +113,7 @@ func (this *IPItemDAO) UpdateIPItem(itemId int64, ipFrom string, ipTo string, ex return errors.New("not found") } - version, err := SharedIPListDAO.IncreaseVersion() + version, err := SharedIPListDAO.IncreaseVersion(tx) if err != nil { return err } @@ -128,21 +128,21 @@ func (this *IPItemDAO) UpdateIPItem(itemId int64, ipFrom string, ipTo string, ex } op.ExpiredAt = expiredAt op.Version = version - err = this.Save(op) + err = this.Save(tx, op) return err } // 计算IP数量 -func (this *IPItemDAO) CountIPItemsWithListId(listId int64) (int64, error) { - return this.Query(). +func (this *IPItemDAO) CountIPItemsWithListId(tx *dbs.Tx, listId int64) (int64, error) { + return this.Query(tx). State(IPItemStateEnabled). Attr("listId", listId). Count() } // 查找IP列表 -func (this *IPItemDAO) ListIPItemsWithListId(listId int64, offset int64, size int64) (result []*IPItem, err error) { - _, err = this.Query(). +func (this *IPItemDAO) ListIPItemsWithListId(tx *dbs.Tx, listId int64, offset int64, size int64) (result []*IPItem, err error) { + _, err = this.Query(tx). State(IPItemStateEnabled). Attr("listId", listId). DescPk(). @@ -154,8 +154,8 @@ func (this *IPItemDAO) ListIPItemsWithListId(listId int64, offset int64, size in } // 根据版本号查找IP列表 -func (this *IPItemDAO) ListIPItemsAfterVersion(version int64, size int64) (result []*IPItem, err error) { - _, err = this.Query(). +func (this *IPItemDAO) ListIPItemsAfterVersion(tx *dbs.Tx, version int64, size int64) (result []*IPItem, err error) { + _, err = this.Query(tx). // 这里不要设置状态参数,因为我们要知道哪些是删除的 Gt("version", version). Where("(expiredAt=0 OR expiredAt>:expiredAt)"). diff --git a/internal/db/models/ip_library_dao.go b/internal/db/models/ip_library_dao.go index 201748ca..cff83ae6 100644 --- a/internal/db/models/ip_library_dao.go +++ b/internal/db/models/ip_library_dao.go @@ -34,8 +34,8 @@ func init() { } // 启用条目 -func (this *IPLibraryDAO) EnableIPLibrary(id int64) error { - _, err := this.Query(). +func (this *IPLibraryDAO) EnableIPLibrary(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", IPLibraryStateEnabled). Update() @@ -43,8 +43,8 @@ func (this *IPLibraryDAO) EnableIPLibrary(id int64) error { } // 禁用条目 -func (this *IPLibraryDAO) DisableIPLibrary(id int64) error { - _, err := this.Query(). +func (this *IPLibraryDAO) DisableIPLibrary(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", IPLibraryStateDisabled). Update() @@ -52,8 +52,8 @@ func (this *IPLibraryDAO) DisableIPLibrary(id int64) error { } // 查找启用中的条目 -func (this *IPLibraryDAO) FindEnabledIPLibrary(id int64) (*IPLibrary, error) { - result, err := this.Query(). +func (this *IPLibraryDAO) FindEnabledIPLibrary(tx *dbs.Tx, id int64) (*IPLibrary, error) { + result, err := this.Query(tx). Pk(id). Attr("state", IPLibraryStateEnabled). Find() @@ -64,8 +64,8 @@ func (this *IPLibraryDAO) FindEnabledIPLibrary(id int64) (*IPLibrary, error) { } // 查找某个类型的IP库列表 -func (this *IPLibraryDAO) FindAllEnabledIPLibrariesWithType(libraryType string) (result []*IPLibrary, err error) { - _, err = this.Query(). +func (this *IPLibraryDAO) FindAllEnabledIPLibrariesWithType(tx *dbs.Tx, libraryType string) (result []*IPLibrary, err error) { + _, err = this.Query(tx). State(IPLibraryStateEnabled). Attr("type", libraryType). DescPk(). @@ -75,8 +75,8 @@ func (this *IPLibraryDAO) FindAllEnabledIPLibrariesWithType(libraryType string) } // 查找某个类型的最新的IP库 -func (this *IPLibraryDAO) FindLatestIPLibraryWithType(libraryType string) (*IPLibrary, error) { - one, err := this.Query(). +func (this *IPLibraryDAO) FindLatestIPLibraryWithType(tx *dbs.Tx, libraryType string) (*IPLibrary, error) { + one, err := this.Query(tx). State(IPLibraryStateEnabled). Attr("type", libraryType). DescPk(). @@ -91,12 +91,12 @@ func (this *IPLibraryDAO) FindLatestIPLibraryWithType(libraryType string) (*IPLi } // 创建新的IP库 -func (this *IPLibraryDAO) CreateIPLibrary(libraryType string, fileId int64) (int64, error) { +func (this *IPLibraryDAO) CreateIPLibrary(tx *dbs.Tx, libraryType string, fileId int64) (int64, error) { op := NewIPLibraryOperator() op.Type = libraryType op.FileId = fileId op.State = IPLibraryStateEnabled - err := this.Save(op) + err := this.Save(tx, op) if err != nil { return 0, err } diff --git a/internal/db/models/ip_list_dao.go b/internal/db/models/ip_list_dao.go index b7929130..531c2741 100644 --- a/internal/db/models/ip_list_dao.go +++ b/internal/db/models/ip_list_dao.go @@ -37,8 +37,8 @@ func init() { } // 启用条目 -func (this *IPListDAO) EnableIPList(id int64) error { - _, err := this.Query(). +func (this *IPListDAO) EnableIPList(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", IPListStateEnabled). Update() @@ -46,8 +46,8 @@ func (this *IPListDAO) EnableIPList(id int64) error { } // 禁用条目 -func (this *IPListDAO) DisableIPList(id int64) error { - _, err := this.Query(). +func (this *IPListDAO) DisableIPList(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", IPListStateDisabled). Update() @@ -55,8 +55,8 @@ func (this *IPListDAO) DisableIPList(id int64) error { } // 查找启用中的条目 -func (this *IPListDAO) FindEnabledIPList(id int64) (*IPList, error) { - result, err := this.Query(). +func (this *IPListDAO) FindEnabledIPList(tx *dbs.Tx, id int64) (*IPList, error) { + result, err := this.Query(tx). Pk(id). Attr("state", IPListStateEnabled). Find() @@ -67,15 +67,15 @@ func (this *IPListDAO) FindEnabledIPList(id int64) (*IPList, error) { } // 根据主键查找名称 -func (this *IPListDAO) FindIPListName(id int64) (string, error) { - return this.Query(). +func (this *IPListDAO) FindIPListName(tx *dbs.Tx, id int64) (string, error) { + return this.Query(tx). Pk(id). Result("name"). FindStringCol("") } // 创建名单 -func (this *IPListDAO) CreateIPList(listType ipconfigs.IPListType, name string, code string, timeoutJSON []byte) (int64, error) { +func (this *IPListDAO) CreateIPList(tx *dbs.Tx, listType ipconfigs.IPListType, name string, code string, timeoutJSON []byte) (int64, error) { op := NewIPListOperator() op.IsOn = true op.State = IPListStateEnabled @@ -85,7 +85,7 @@ func (this *IPListDAO) CreateIPList(listType ipconfigs.IPListType, name string, if len(timeoutJSON) > 0 { op.Timeout = timeoutJSON } - err := this.Save(op) + err := this.Save(tx, op) if err != nil { return 0, err } @@ -93,7 +93,7 @@ func (this *IPListDAO) CreateIPList(listType ipconfigs.IPListType, name string, } // 修改名单 -func (this *IPListDAO) UpdateIPList(listId int64, name string, code string, timeoutJSON []byte) error { +func (this *IPListDAO) UpdateIPList(tx *dbs.Tx, listId int64, name string, code string, timeoutJSON []byte) error { if listId <= 0 { return errors.New("invalid listId") } @@ -106,18 +106,18 @@ func (this *IPListDAO) UpdateIPList(listId int64, name string, code string, time } else { op.Timeout = "null" } - err := this.Save(op) + err := this.Save(tx, op) return err } // 增加版本 -func (this *IPListDAO) IncreaseVersion() (int64, error) { - valueJSON, err := SharedSysSettingDAO.ReadSetting(SettingCodeIPListVersion) +func (this *IPListDAO) IncreaseVersion(tx *dbs.Tx) (int64, error) { + valueJSON, err := SharedSysSettingDAO.ReadSetting(tx, SettingCodeIPListVersion) if err != nil { return 0, err } if len(valueJSON) == 0 { - err = SharedSysSettingDAO.UpdateSetting(SettingCodeIPListVersion, []byte("1")) + err = SharedSysSettingDAO.UpdateSetting(tx, SettingCodeIPListVersion, []byte("1")) if err != nil { return 0, err } @@ -125,6 +125,6 @@ func (this *IPListDAO) IncreaseVersion() (int64, error) { } value := types.Int64(string(valueJSON)) + 1 - err = SharedSysSettingDAO.UpdateSetting(SettingCodeIPListVersion, []byte(numberutils.FormatInt64(value))) + err = SharedSysSettingDAO.UpdateSetting(tx, SettingCodeIPListVersion, []byte(numberutils.FormatInt64(value))) return value, nil } diff --git a/internal/db/models/log_dao.go b/internal/db/models/log_dao.go index 2d33e21d..4df92a4f 100644 --- a/internal/db/models/log_dao.go +++ b/internal/db/models/log_dao.go @@ -34,7 +34,7 @@ func init() { } // 创建管理员日志 -func (this *LogDAO) CreateLog(adminType string, adminId int64, level string, description string, action string, ip string) error { +func (this *LogDAO) CreateLog(tx *dbs.Tx, adminType string, adminId int64, level string, description string, action string, ip string) error { op := NewLogOperator() op.Level = level op.Description = description @@ -53,16 +53,16 @@ func (this *LogDAO) CreateLog(adminType string, adminId int64, level string, des op.Day = timeutil.Format("Ymd") op.Type = LogTypeAdmin - err := this.Save(op) + err := this.Save(tx, op) return err } // 计算所有日志数量 -func (this *LogDAO) CountLogs(dayFrom string, dayTo string, keyword string, userType string) (int64, error) { +func (this *LogDAO) CountLogs(tx *dbs.Tx, dayFrom string, dayTo string, keyword string, userType string) (int64, error) { dayFrom = this.formatDay(dayFrom) dayTo = this.formatDay(dayTo) - query := this.Query() + query := this.Query(tx) if len(dayFrom) > 0 { query.Gte("day", dayFrom) @@ -87,11 +87,11 @@ func (this *LogDAO) CountLogs(dayFrom string, dayTo string, keyword string, user } // 列出单页日志 -func (this *LogDAO) ListLogs(offset int64, size int64, dayFrom string, dayTo string, keyword string, userType string) (result []*Log, err error) { +func (this *LogDAO) ListLogs(tx *dbs.Tx, offset int64, size int64, dayFrom string, dayTo string, keyword string, userType string) (result []*Log, err error) { dayFrom = this.formatDay(dayFrom) dayTo = this.formatDay(dayTo) - query := this.Query() + query := this.Query(tx) if len(dayFrom) > 0 { query.Gte("day", dayFrom) } @@ -121,28 +121,28 @@ func (this *LogDAO) ListLogs(offset int64, size int64, dayFrom string, dayTo str } // 物理删除日志 -func (this *LogDAO) DeleteLogPermanently(logId int64) error { +func (this *LogDAO) DeleteLogPermanently(tx *dbs.Tx, logId int64) error { if logId <= 0 { return errors.New("invalid logId") } - _, err := this.Delete(logId) + _, err := this.Delete(tx, logId) return err } // 物理删除所有日志 -func (this *LogDAO) DeleteAllLogsPermanently() error { - _, err := this.Query(). +func (this *LogDAO) DeleteAllLogsPermanently(tx *dbs.Tx) error { + _, err := this.Query(tx). Delete() return err } // 物理删除某些天之前的日志 -func (this *LogDAO) DeleteLogsPermanentlyBeforeDays(days int) error { +func (this *LogDAO) DeleteLogsPermanentlyBeforeDays(tx *dbs.Tx, days int) error { if days <= 0 { days = 0 } untilDay := timeutil.Format("Ymd", time.Now().AddDate(0, 0, -days)) - _, err := this.Query(). + _, err := this.Query(tx). Lte("day", untilDay). Delete() return err diff --git a/internal/db/models/login_dao.go b/internal/db/models/login_dao.go index 5fbeadc1..419ca3a1 100644 --- a/internal/db/models/login_dao.go +++ b/internal/db/models/login_dao.go @@ -41,8 +41,8 @@ func init() { } // 启用条目 -func (this *LoginDAO) EnableLogin(id int64) error { - _, err := this.Query(). +func (this *LoginDAO) EnableLogin(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", LoginStateEnabled). Update() @@ -50,8 +50,8 @@ func (this *LoginDAO) EnableLogin(id int64) error { } // 禁用条目 -func (this *LoginDAO) DisableLogin(id int64) error { - _, err := this.Query(). +func (this *LoginDAO) DisableLogin(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", LoginStateDisabled). Update() @@ -59,8 +59,8 @@ func (this *LoginDAO) DisableLogin(id int64) error { } // 查找启用中的条目 -func (this *LoginDAO) FindEnabledLogin(id int64) (*Login, error) { - result, err := this.Query(). +func (this *LoginDAO) FindEnabledLogin(tx *dbs.Tx, id int64) (*Login, error) { + result, err := this.Query(tx). Pk(id). Attr("state", LoginStateEnabled). Find() @@ -71,7 +71,7 @@ func (this *LoginDAO) FindEnabledLogin(id int64) (*Login, error) { } // 创建认证 -func (this *LoginDAO) CreateLogin(Id int64, loginType LoginType, params maps.Map) (int64, error) { +func (this *LoginDAO) CreateLogin(tx *dbs.Tx, Id int64, loginType LoginType, params maps.Map) (int64, error) { if Id <= 0 { return 0, errors.New("invalid Id") } @@ -84,13 +84,13 @@ func (this *LoginDAO) CreateLogin(Id int64, loginType LoginType, params maps.Map op.Params = params.AsJSON() op.State = LoginStateEnabled op.IsOn = true - return this.SaveInt64(op) + return this.SaveInt64(tx, op) } // 修改认证 -func (this *LoginDAO) UpdateLogin(adminId int64, loginType LoginType, params maps.Map, isOn bool) error { +func (this *LoginDAO) UpdateLogin(tx *dbs.Tx, adminId int64, loginType LoginType, params maps.Map, isOn bool) error { // 是否已经存在 - loginId, err := this.Query(). + loginId, err := this.Query(tx). Attr("adminId", adminId). Attr("type", loginType). State(LoginStateEnabled). @@ -114,12 +114,12 @@ func (this *LoginDAO) UpdateLogin(adminId int64, loginType LoginType, params map op.IsOn = isOn op.Params = params.AsJSON() - return this.Save(op) + return this.Save(tx, op) } // 禁用相关认证 -func (this *LoginDAO) DisableLoginWithAdminId(adminId int64, loginType LoginType) error { - _, err := this.Query(). +func (this *LoginDAO) DisableLoginWithAdminId(tx *dbs.Tx, adminId int64, loginType LoginType) error { + _, err := this.Query(tx). Attr("adminId", adminId). Attr("type", loginType). Set("isOn", false). @@ -128,8 +128,8 @@ func (this *LoginDAO) DisableLoginWithAdminId(adminId int64, loginType LoginType } // 查找管理员相关的认证 -func (this *LoginDAO) FindEnabledLoginWithAdminId(adminId int64, loginType LoginType) (*Login, error) { - one, err := this.Query(). +func (this *LoginDAO) FindEnabledLoginWithAdminId(tx *dbs.Tx, adminId int64, loginType LoginType) (*Login, error) { + one, err := this.Query(tx). Attr("adminId", adminId). Attr("type", loginType). State(LoginStateEnabled). @@ -141,8 +141,8 @@ func (this *LoginDAO) FindEnabledLoginWithAdminId(adminId int64, loginType Login } // 检查某个认证是否启用 -func (this *LoginDAO) CheckLoginIsOn(adminId int64, loginType LoginType) (bool, error) { - return this.Query(). +func (this *LoginDAO) CheckLoginIsOn(tx *dbs.Tx, adminId int64, loginType LoginType) (bool, error) { + return this.Query(tx). Attr("adminId", adminId). Attr("type", loginType). State(LoginStateEnabled). diff --git a/internal/db/models/message_dao.go b/internal/db/models/message_dao.go index 443be5f1..9080c137 100644 --- a/internal/db/models/message_dao.go +++ b/internal/db/models/message_dao.go @@ -61,8 +61,8 @@ func init() { } // 启用条目 -func (this *MessageDAO) EnableMessage(id int64) error { - _, err := this.Query(). +func (this *MessageDAO) EnableMessage(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", MessageStateEnabled). Update() @@ -70,8 +70,8 @@ func (this *MessageDAO) EnableMessage(id int64) error { } // 禁用条目 -func (this *MessageDAO) DisableMessage(id int64) error { - _, err := this.Query(). +func (this *MessageDAO) DisableMessage(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", MessageStateDisabled). Update() @@ -79,8 +79,8 @@ func (this *MessageDAO) DisableMessage(id int64) error { } // 查找启用中的条目 -func (this *MessageDAO) FindEnabledMessage(id int64) (*Message, error) { - result, err := this.Query(). +func (this *MessageDAO) FindEnabledMessage(tx *dbs.Tx, id int64) (*Message, error) { + result, err := this.Query(tx). Pk(id). Attr("state", MessageStateEnabled). Find() @@ -91,19 +91,19 @@ func (this *MessageDAO) FindEnabledMessage(id int64) (*Message, error) { } // 创建集群消息 -func (this *MessageDAO) CreateClusterMessage(clusterId int64, messageType MessageType, level string, body string, paramsJSON []byte) error { - _, err := this.createMessage(clusterId, 0, messageType, level, body, paramsJSON) +func (this *MessageDAO) CreateClusterMessage(tx *dbs.Tx, clusterId int64, messageType MessageType, level string, body string, paramsJSON []byte) error { + _, err := this.createMessage(tx, clusterId, 0, messageType, level, body, paramsJSON) return err } // 创建节点消息 -func (this *MessageDAO) CreateNodeMessage(clusterId int64, nodeId int64, messageType MessageType, level string, body string, paramsJSON []byte) error { - _, err := this.createMessage(clusterId, nodeId, messageType, level, body, paramsJSON) +func (this *MessageDAO) CreateNodeMessage(tx *dbs.Tx, clusterId int64, nodeId int64, messageType MessageType, level string, body string, paramsJSON []byte) error { + _, err := this.createMessage(tx, clusterId, nodeId, messageType, level, body, paramsJSON) return err } // 创建普通消息 -func (this *MessageDAO) CreateMessage(adminId int64, userId int64, messageType MessageType, level string, body string, paramsJSON []byte) error { +func (this *MessageDAO) CreateMessage(tx *dbs.Tx, adminId int64, userId int64, messageType MessageType, level string, body string, paramsJSON []byte) error { h := md5.New() h.Write([]byte(body)) h.Write(paramsJSON) @@ -122,14 +122,14 @@ func (this *MessageDAO) CreateMessage(adminId int64, userId int64, messageType M op.IsRead = false op.Day = timeutil.Format("Ymd") op.Hash = hash - err := this.Save(op) + err := this.Save(tx, op) return err } // 删除某天之前的消息 -func (this *MessageDAO) DeleteMessagesBeforeDay(dayTime time.Time) error { +func (this *MessageDAO) DeleteMessagesBeforeDay(tx *dbs.Tx, dayTime time.Time) error { day := timeutil.Format("Ymd", dayTime) - _, err := this.Query(). + _, err := this.Query(tx). Where("day<:day"). Param("day", day). Delete() @@ -137,8 +137,8 @@ func (this *MessageDAO) DeleteMessagesBeforeDay(dayTime time.Time) error { } // 计算未读消息数量 -func (this *MessageDAO) CountUnreadMessages(adminId int64, userId int64) (int64, error) { - query := this.Query(). +func (this *MessageDAO) CountUnreadMessages(tx *dbs.Tx, adminId int64, userId int64) (int64, error) { + query := this.Query(tx). Attr("isRead", false) if adminId > 0 { query.Where("(adminId=:adminId OR (adminId=0 AND userId=0))"). @@ -150,8 +150,8 @@ func (this *MessageDAO) CountUnreadMessages(adminId int64, userId int64) (int64, } // 列出单页未读消息 -func (this *MessageDAO) ListUnreadMessages(adminId int64, userId int64, offset int64, size int64) (result []*Message, err error) { - query := this.Query(). +func (this *MessageDAO) ListUnreadMessages(tx *dbs.Tx, adminId int64, userId int64, offset int64, size int64) (result []*Message, err error) { + query := this.Query(tx). Attr("isRead", false) if adminId > 0 { query.Where("(adminId=:adminId OR (adminId=0 AND userId=0))"). @@ -169,22 +169,22 @@ func (this *MessageDAO) ListUnreadMessages(adminId int64, userId int64, offset i } // 设置消息已读状态 -func (this *MessageDAO) UpdateMessageRead(messageId int64, b bool) error { +func (this *MessageDAO) UpdateMessageRead(tx *dbs.Tx, messageId int64, b bool) error { if messageId <= 0 { return errors.New("invalid messageId") } op := NewMessageOperator() op.Id = messageId op.IsRead = b - err := this.Save(op) + err := this.Save(tx, op) return err } // 设置一组消息为已读状态 -func (this *MessageDAO) UpdateMessagesRead(messageIds []int64, b bool) error { +func (this *MessageDAO) UpdateMessagesRead(tx *dbs.Tx, messageIds []int64, b bool) error { // 这里我们一个一个更改,因为In语句不容易Prepare,且效率不高 for _, messageId := range messageIds { - err := this.UpdateMessageRead(messageId, b) + err := this.UpdateMessageRead(tx, messageId, b) if err != nil { return err } @@ -193,8 +193,8 @@ func (this *MessageDAO) UpdateMessagesRead(messageIds []int64, b bool) error { } // 设置所有消息为已读 -func (this *MessageDAO) UpdateAllMessagesRead(adminId int64, userId int64) error { - query := this.Query(). +func (this *MessageDAO) UpdateAllMessagesRead(tx *dbs.Tx, adminId int64, userId int64) error { + query := this.Query(tx). Attr("isRead", false) if adminId > 0 { query.Where("(adminId=:adminId OR (adminId=0 AND userId=0))"). @@ -209,11 +209,11 @@ func (this *MessageDAO) UpdateAllMessagesRead(adminId int64, userId int64) error } // 检查消息权限 -func (this *MessageDAO) CheckMessageUser(messageId int64, adminId int64, userId int64) (bool, error) { +func (this *MessageDAO) CheckMessageUser(tx *dbs.Tx, messageId int64, adminId int64, userId int64) (bool, error) { if messageId <= 0 || (adminId <= 0 && userId <= 0) { return false, nil } - query := this.Query(). + query := this.Query(tx). Pk(messageId) if adminId > 0 { query.Where("(adminId=:adminId OR (adminId=0 AND userId=0))"). @@ -225,7 +225,7 @@ func (this *MessageDAO) CheckMessageUser(messageId int64, adminId int64, userId } // 创建消息 -func (this *MessageDAO) createMessage(clusterId int64, nodeId int64, messageType MessageType, level string, body string, paramsJSON []byte) (int64, error) { +func (this *MessageDAO) createMessage(tx *dbs.Tx, clusterId int64, nodeId int64, messageType MessageType, level string, body string, paramsJSON []byte) (int64, error) { h := md5.New() h.Write([]byte(body)) h.Write(paramsJSON) @@ -251,7 +251,7 @@ func (this *MessageDAO) createMessage(clusterId int64, nodeId int64, messageType op.Day = timeutil.Format("Ymd") op.Hash = hash - err := this.Save(op) + err := this.Save(tx, op) if err != nil { return 0, err } diff --git a/internal/db/models/node_cluster_dao.go b/internal/db/models/node_cluster_dao.go index 805253f0..dea77e45 100644 --- a/internal/db/models/node_cluster_dao.go +++ b/internal/db/models/node_cluster_dao.go @@ -43,8 +43,8 @@ func init() { } // 启用条目 -func (this *NodeClusterDAO) EnableNodeCluster(id int64) error { - _, err := this.Query(). +func (this *NodeClusterDAO) EnableNodeCluster(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", NodeClusterStateEnabled). Update() @@ -52,8 +52,8 @@ func (this *NodeClusterDAO) EnableNodeCluster(id int64) error { } // 禁用条目 -func (this *NodeClusterDAO) DisableNodeCluster(id int64) error { - _, err := this.Query(). +func (this *NodeClusterDAO) DisableNodeCluster(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", NodeClusterStateDisabled). Update() @@ -61,8 +61,8 @@ func (this *NodeClusterDAO) DisableNodeCluster(id int64) error { } // 查找集群 -func (this *NodeClusterDAO) FindEnabledNodeCluster(id int64) (*NodeCluster, error) { - result, err := this.Query(). +func (this *NodeClusterDAO) FindEnabledNodeCluster(tx *dbs.Tx, id int64) (*NodeCluster, error) { + result, err := this.Query(tx). Pk(id). Attr("state", NodeClusterStateEnabled). Find() @@ -74,8 +74,8 @@ func (this *NodeClusterDAO) FindEnabledNodeCluster(id int64) (*NodeCluster, erro // 根据UniqueId获取ID // TODO 增加缓存 -func (this *NodeClusterDAO) FindEnabledClusterIdWithUniqueId(uniqueId string) (int64, error) { - return this.Query(). +func (this *NodeClusterDAO) FindEnabledClusterIdWithUniqueId(tx *dbs.Tx, uniqueId string) (int64, error) { + return this.Query(tx). State(NodeClusterStateEnabled). Attr("uniqueId", uniqueId). ResultPk(). @@ -83,16 +83,16 @@ func (this *NodeClusterDAO) FindEnabledClusterIdWithUniqueId(uniqueId string) (i } // 根据主键查找名称 -func (this *NodeClusterDAO) FindNodeClusterName(id int64) (string, error) { - return this.Query(). +func (this *NodeClusterDAO) FindNodeClusterName(tx *dbs.Tx, id int64) (string, error) { + return this.Query(tx). Pk(id). Result("name"). FindStringCol("") } // 查找所有可用的集群 -func (this *NodeClusterDAO) FindAllEnableClusters() (result []*NodeCluster, err error) { - _, err = this.Query(). +func (this *NodeClusterDAO) FindAllEnableClusters(tx *dbs.Tx) (result []*NodeCluster, err error) { + _, err = this.Query(tx). State(NodeClusterStateEnabled). Slice(&result). Desc("order"). @@ -102,14 +102,14 @@ func (this *NodeClusterDAO) FindAllEnableClusters() (result []*NodeCluster, err } // 创建集群 -func (this *NodeClusterDAO) CreateCluster(adminId int64, name string, grantId int64, installDir string, dnsDomainId int64, dnsName string, cachePolicyId int64, httpFirewallPolicyId int64) (clusterId int64, err error) { - uniqueId, err := this.genUniqueId() +func (this *NodeClusterDAO) CreateCluster(tx *dbs.Tx, adminId int64, name string, grantId int64, installDir string, dnsDomainId int64, dnsName string, cachePolicyId int64, httpFirewallPolicyId int64) (clusterId int64, err error) { + uniqueId, err := this.genUniqueId(tx) if err != nil { return 0, err } secret := rands.String(32) - err = SharedApiTokenDAO.CreateAPIToken(uniqueId, secret, NodeRoleCluster) + err = SharedApiTokenDAO.CreateAPIToken(tx, uniqueId, secret, NodeRoleCluster) if err != nil { return 0, err } @@ -144,7 +144,7 @@ func (this *NodeClusterDAO) CreateCluster(adminId int64, name string, grantId in op.UniqueId = uniqueId op.Secret = secret op.State = NodeClusterStateEnabled - err = this.Save(op) + err = this.Save(tx, op) if err != nil { return 0, err } @@ -153,7 +153,7 @@ func (this *NodeClusterDAO) CreateCluster(adminId int64, name string, grantId in } // 修改集群 -func (this *NodeClusterDAO) UpdateCluster(clusterId int64, name string, grantId int64, installDir string) error { +func (this *NodeClusterDAO) UpdateCluster(tx *dbs.Tx, clusterId int64, name string, grantId int64, installDir string) error { if clusterId <= 0 { return errors.New("invalid clusterId") } @@ -162,13 +162,13 @@ func (this *NodeClusterDAO) UpdateCluster(clusterId int64, name string, grantId op.Name = name op.GrantId = grantId op.InstallDir = installDir - err := this.Save(op) + err := this.Save(tx, op) return err } // 计算所有集群数量 -func (this *NodeClusterDAO) CountAllEnabledClusters(keyword string) (int64, error) { - query := this.Query(). +func (this *NodeClusterDAO) CountAllEnabledClusters(tx *dbs.Tx, keyword string) (int64, error) { + query := this.Query(tx). State(NodeClusterStateEnabled) if len(keyword) > 0 { query.Where("(name LIKE :keyword OR dnsName like :keyword)"). @@ -178,8 +178,8 @@ func (this *NodeClusterDAO) CountAllEnabledClusters(keyword string) (int64, erro } // 列出单页集群 -func (this *NodeClusterDAO) ListEnabledClusters(keyword string, offset, size int64) (result []*NodeCluster, err error) { - query := this.Query(). +func (this *NodeClusterDAO) ListEnabledClusters(tx *dbs.Tx, keyword string, offset, size int64) (result []*NodeCluster, err error) { + query := this.Query(tx). State(NodeClusterStateEnabled) if len(keyword) > 0 { query.Where("(name LIKE :keyword OR dnsName like :keyword)"). @@ -195,8 +195,8 @@ func (this *NodeClusterDAO) ListEnabledClusters(keyword string, offset, size int } // 查找所有API节点地址 -func (this *NodeClusterDAO) FindAllAPINodeAddrsWithCluster(clusterId int64) (result []string, err error) { - one, err := this.Query(). +func (this *NodeClusterDAO) FindAllAPINodeAddrsWithCluster(tx *dbs.Tx, clusterId int64) (result []string, err error) { + one, err := this.Query(tx). Pk(clusterId). Result("useAllAPINodes", "apiNodes"). Find() @@ -208,7 +208,7 @@ func (this *NodeClusterDAO) FindAllAPINodeAddrsWithCluster(clusterId int64) (res } cluster := one.(*NodeCluster) if cluster.UseAllAPINodes == 1 { - apiNodes, err := SharedAPINodeDAO.FindAllEnabledAPINodes() + apiNodes, err := SharedAPINodeDAO.FindAllEnabledAPINodes(tx) if err != nil { return nil, err } @@ -234,7 +234,7 @@ func (this *NodeClusterDAO) FindAllAPINodeAddrsWithCluster(clusterId int64) (res return nil, err } for _, apiNodeId := range apiNodeIds { - apiNode, err := SharedAPINodeDAO.FindEnabledAPINode(apiNodeId) + apiNode, err := SharedAPINodeDAO.FindEnabledAPINode(tx, apiNodeId) if err != nil { return nil, err } @@ -251,8 +251,8 @@ func (this *NodeClusterDAO) FindAllAPINodeAddrsWithCluster(clusterId int64) (res } // 查找健康检查设置 -func (this *NodeClusterDAO) FindClusterHealthCheckConfig(clusterId int64) (*serverconfigs.HealthCheckConfig, error) { - col, err := this.Query(). +func (this *NodeClusterDAO) FindClusterHealthCheckConfig(tx *dbs.Tx, clusterId int64) (*serverconfigs.HealthCheckConfig, error) { + col, err := this.Query(tx). Pk(clusterId). Result("healthCheck"). FindStringCol("") @@ -272,28 +272,28 @@ func (this *NodeClusterDAO) FindClusterHealthCheckConfig(clusterId int64) (*serv } // 修改健康检查设置 -func (this *NodeClusterDAO) UpdateClusterHealthCheck(clusterId int64, healthCheckJSON []byte) error { +func (this *NodeClusterDAO) UpdateClusterHealthCheck(tx *dbs.Tx, clusterId int64, healthCheckJSON []byte) error { if clusterId <= 0 { return errors.New("invalid clusterId '" + strconv.FormatInt(clusterId, 10) + "'") } op := NewNodeClusterOperator() op.Id = clusterId op.HealthCheck = healthCheckJSON - err := this.Save(op) + err := this.Save(tx, op) return err } // 计算使用某个认证的集群数量 -func (this *NodeClusterDAO) CountAllEnabledClustersWithGrantId(grantId int64) (int64, error) { - return this.Query(). +func (this *NodeClusterDAO) CountAllEnabledClustersWithGrantId(tx *dbs.Tx, grantId int64) (int64, error) { + return this.Query(tx). State(NodeClusterStateEnabled). Attr("grantId", grantId). Count() } // 获取使用某个认证的所有集群 -func (this *NodeClusterDAO) FindAllEnabledClustersWithGrantId(grantId int64) (result []*NodeCluster, err error) { - _, err = this.Query(). +func (this *NodeClusterDAO) FindAllEnabledClustersWithGrantId(tx *dbs.Tx, grantId int64) (result []*NodeCluster, err error) { + _, err = this.Query(tx). State(NodeClusterStateEnabled). Attr("grantId", grantId). Slice(&result). @@ -303,8 +303,8 @@ func (this *NodeClusterDAO) FindAllEnabledClustersWithGrantId(grantId int64) (re } // 计算使用某个DNS服务商的集群数量 -func (this *NodeClusterDAO) CountAllEnabledClustersWithDNSProviderId(dnsProviderId int64) (int64, error) { - return this.Query(). +func (this *NodeClusterDAO) CountAllEnabledClustersWithDNSProviderId(tx *dbs.Tx, dnsProviderId int64) (int64, error) { + return this.Query(tx). State(NodeClusterStateEnabled). Where("dnsDomainId IN (SELECT id FROM "+SharedDNSDomainDAO.Table+" WHERE state=1 AND providerId=:providerId)"). Param("providerId", dnsProviderId). @@ -312,8 +312,8 @@ func (this *NodeClusterDAO) CountAllEnabledClustersWithDNSProviderId(dnsProvider } // 获取所有使用某个DNS服务商的集群 -func (this *NodeClusterDAO) FindAllEnabledClustersWithDNSProviderId(dnsProviderId int64) (result []*NodeCluster, err error) { - _, err = this.Query(). +func (this *NodeClusterDAO) FindAllEnabledClustersWithDNSProviderId(tx *dbs.Tx, dnsProviderId int64) (result []*NodeCluster, err error) { + _, err = this.Query(tx). State(NodeClusterStateEnabled). Where("dnsDomainId IN (SELECT id FROM "+SharedDNSDomainDAO.Table+" WHERE state=1 AND providerId=:providerId)"). Param("providerId", dnsProviderId). @@ -324,16 +324,16 @@ func (this *NodeClusterDAO) FindAllEnabledClustersWithDNSProviderId(dnsProviderI } // 计算使用某个DNS域名的集群数量 -func (this *NodeClusterDAO) CountAllEnabledClustersWithDNSDomainId(dnsDomainId int64) (int64, error) { - return this.Query(). +func (this *NodeClusterDAO) CountAllEnabledClustersWithDNSDomainId(tx *dbs.Tx, dnsDomainId int64) (int64, error) { + return this.Query(tx). State(NodeClusterStateEnabled). Attr("dnsDomainId", dnsDomainId). Count() } // 查询使用某个DNS域名的集群ID列表 -func (this *NodeClusterDAO) FindAllEnabledClusterIdsWithDNSDomainId(dnsDomainId int64) ([]int64, error) { - ones, err := this.Query(). +func (this *NodeClusterDAO) FindAllEnabledClusterIdsWithDNSDomainId(tx *dbs.Tx, dnsDomainId int64) ([]int64, error) { + ones, err := this.Query(tx). State(NodeClusterStateEnabled). Attr("dnsDomainId", dnsDomainId). ResultPk(). @@ -349,8 +349,8 @@ func (this *NodeClusterDAO) FindAllEnabledClusterIdsWithDNSDomainId(dnsDomainId } // 查询使用某个DNS域名的所有集群域名 -func (this *NodeClusterDAO) FindAllEnabledClustersWithDNSDomainId(dnsDomainId int64) (result []*NodeCluster, err error) { - _, err = this.Query(). +func (this *NodeClusterDAO) FindAllEnabledClustersWithDNSDomainId(tx *dbs.Tx, dnsDomainId int64) (result []*NodeCluster, err error) { + _, err = this.Query(tx). State(NodeClusterStateEnabled). Attr("dnsDomainId", dnsDomainId). Result("id", "name", "dnsName", "dnsDomainId"). @@ -360,8 +360,8 @@ func (this *NodeClusterDAO) FindAllEnabledClustersWithDNSDomainId(dnsDomainId in } // 查询已经设置了域名的集群 -func (this *NodeClusterDAO) FindAllEnabledClustersHaveDNSDomain() (result []*NodeCluster, err error) { - _, err = this.Query(). +func (this *NodeClusterDAO) FindAllEnabledClustersHaveDNSDomain(tx *dbs.Tx) (result []*NodeCluster, err error) { + _, err = this.Query(tx). State(NodeClusterStateEnabled). Gt("dnsDomainId", 0). Result("id", "name", "dnsName", "dnsDomainId"). @@ -371,16 +371,16 @@ func (this *NodeClusterDAO) FindAllEnabledClustersHaveDNSDomain() (result []*Nod } // 查找集群的认证ID -func (this *NodeClusterDAO) FindClusterGrantId(clusterId int64) (int64, error) { - return this.Query(). +func (this *NodeClusterDAO) FindClusterGrantId(tx *dbs.Tx, clusterId int64) (int64, error) { + return this.Query(tx). Pk(clusterId). Result("grantId"). FindInt64Col(0) } // 查找DNS信息 -func (this *NodeClusterDAO) FindClusterDNSInfo(clusterId int64) (*NodeCluster, error) { - one, err := this.Query(). +func (this *NodeClusterDAO) FindClusterDNSInfo(tx *dbs.Tx, clusterId int64) (*NodeCluster, error) { + one, err := this.Query(tx). Pk(clusterId). Result("id", "name", "dnsName", "dnsDomainId", "dns"). Find() @@ -394,8 +394,8 @@ func (this *NodeClusterDAO) FindClusterDNSInfo(clusterId int64) (*NodeCluster, e } // 检查某个子域名是否可用 -func (this *NodeClusterDAO) ExistClusterDNSName(dnsName string, excludeClusterId int64) (bool, error) { - return this.Query(). +func (this *NodeClusterDAO) ExistClusterDNSName(tx *dbs.Tx, dnsName string, excludeClusterId int64) (bool, error) { + return this.Query(tx). Attr("dnsName", dnsName). State(NodeClusterStateEnabled). Where("id!=:clusterId"). @@ -404,7 +404,7 @@ func (this *NodeClusterDAO) ExistClusterDNSName(dnsName string, excludeClusterId } // 修改集群DNS相关信息 -func (this *NodeClusterDAO) UpdateClusterDNS(clusterId int64, dnsName string, dnsDomainId int64, nodesAutoSync bool, serversAutoSync bool) error { +func (this *NodeClusterDAO) UpdateClusterDNS(tx *dbs.Tx, clusterId int64, dnsName string, dnsDomainId int64, nodesAutoSync bool, serversAutoSync bool) error { if clusterId <= 0 { return errors.New("invalid clusterId") } @@ -423,17 +423,17 @@ func (this *NodeClusterDAO) UpdateClusterDNS(clusterId int64, dnsName string, dn } op.Dns = dnsJSON - err = this.Save(op) + err = this.Save(tx, op) return err } // 检查集群的DNS问题 -func (this *NodeClusterDAO) CheckClusterDNS(cluster *NodeCluster) (issues []*pb.DNSIssue, err error) { +func (this *NodeClusterDAO) CheckClusterDNS(tx *dbs.Tx, cluster *NodeCluster) (issues []*pb.DNSIssue, err error) { clusterId := int64(cluster.Id) domainId := int64(cluster.DnsDomainId) // 检查域名 - domain, err := SharedDNSDomainDAO.FindEnabledDNSDomain(domainId) + domain, err := SharedDNSDomainDAO.FindEnabledDNSDomain(tx, domainId) if err != nil { return nil, err } @@ -465,7 +465,7 @@ func (this *NodeClusterDAO) CheckClusterDNS(cluster *NodeCluster) (issues []*pb. // TODO 检查域名是否已解析 // 检查节点 - nodes, err := SharedNodeDAO.FindAllEnabledNodesDNSWithClusterId(clusterId) + nodes, err := SharedNodeDAO.FindAllEnabledNodesDNSWithClusterId(tx, clusterId) if err != nil { return nil, err } @@ -515,7 +515,7 @@ func (this *NodeClusterDAO) CheckClusterDNS(cluster *NodeCluster) (issues []*pb. } // 检查IP地址 - ipAddr, err := SharedNodeIPAddressDAO.FindFirstNodeIPAddress(nodeId) + ipAddr, err := SharedNodeIPAddressDAO.FindFirstNodeIPAddress(tx, nodeId) if err != nil { return nil, err } @@ -540,16 +540,16 @@ func (this *NodeClusterDAO) CheckClusterDNS(cluster *NodeCluster) (issues []*pb. } // 查找集群所属管理员 -func (this *NodeClusterDAO) FindClusterAdminId(clusterId int64) (int64, error) { - return this.Query(). +func (this *NodeClusterDAO) FindClusterAdminId(tx *dbs.Tx, clusterId int64) (int64, error) { + return this.Query(tx). Pk(clusterId). Result("adminId"). FindInt64Col(0) } // 查找集群的TOA设置 -func (this *NodeClusterDAO) FindClusterTOAConfig(clusterId int64) (*nodeconfigs.TOAConfig, error) { - toa, err := this.Query(). +func (this *NodeClusterDAO) FindClusterTOAConfig(tx *dbs.Tx, clusterId int64) (*nodeconfigs.TOAConfig, error) { + toa, err := this.Query(tx). Pk(clusterId). Result("toa"). FindStringCol("") @@ -569,28 +569,28 @@ func (this *NodeClusterDAO) FindClusterTOAConfig(clusterId int64) (*nodeconfigs. } // 修改集群的TOA设置 -func (this *NodeClusterDAO) UpdateClusterTOA(clusterId int64, toaJSON []byte) error { +func (this *NodeClusterDAO) UpdateClusterTOA(tx *dbs.Tx, clusterId int64, toaJSON []byte) error { if clusterId <= 0 { return errors.New("invalid clusterId") } op := NewNodeClusterOperator() op.Id = clusterId op.Toa = toaJSON - err := this.Save(op) + err := this.Save(tx, op) return err } // 计算使用某个缓存策略的集群数量 -func (this *NodeClusterDAO) CountAllEnabledNodeClustersWithHTTPCachePolicyId(httpCachePolicyId int64) (int64, error) { - return this.Query(). +func (this *NodeClusterDAO) CountAllEnabledNodeClustersWithHTTPCachePolicyId(tx *dbs.Tx, httpCachePolicyId int64) (int64, error) { + return this.Query(tx). State(NodeClusterStateEnabled). Attr("cachePolicyId", httpCachePolicyId). Count() } // 查找使用缓存策略的所有集群 -func (this *NodeClusterDAO) FindAllEnabledNodeClustersWithHTTPCachePolicyId(httpCachePolicyId int64) (result []*NodeCluster, err error) { - _, err = this.Query(). +func (this *NodeClusterDAO) FindAllEnabledNodeClustersWithHTTPCachePolicyId(tx *dbs.Tx, httpCachePolicyId int64) (result []*NodeCluster, err error) { + _, err = this.Query(tx). State(NodeClusterStateEnabled). Attr("cachePolicyId", httpCachePolicyId). DescPk(). @@ -600,16 +600,16 @@ func (this *NodeClusterDAO) FindAllEnabledNodeClustersWithHTTPCachePolicyId(http } // 计算使用某个WAF策略的集群数量 -func (this *NodeClusterDAO) CountAllEnabledNodeClustersWithHTTPFirewallPolicyId(httpFirewallPolicyId int64) (int64, error) { - return this.Query(). +func (this *NodeClusterDAO) CountAllEnabledNodeClustersWithHTTPFirewallPolicyId(tx *dbs.Tx, httpFirewallPolicyId int64) (int64, error) { + return this.Query(tx). State(NodeClusterStateEnabled). Attr("httpFirewallPolicyId", httpFirewallPolicyId). Count() } // 查找使用WAF策略的所有集群 -func (this *NodeClusterDAO) FindAllEnabledNodeClustersWithHTTPFirewallPolicyId(httpFirewallPolicyId int64) (result []*NodeCluster, err error) { - _, err = this.Query(). +func (this *NodeClusterDAO) FindAllEnabledNodeClustersWithHTTPFirewallPolicyId(tx *dbs.Tx, httpFirewallPolicyId int64) (result []*NodeCluster, err error) { + _, err = this.Query(tx). State(NodeClusterStateEnabled). Attr("httpFirewallPolicyId", httpFirewallPolicyId). DescPk(). @@ -619,16 +619,16 @@ func (this *NodeClusterDAO) FindAllEnabledNodeClustersWithHTTPFirewallPolicyId(h } // 获取集群的WAF策略ID -func (this *NodeClusterDAO) FindClusterHTTPFirewallPolicyId(clusterId int64) (int64, error) { - return this.Query(). +func (this *NodeClusterDAO) FindClusterHTTPFirewallPolicyId(tx *dbs.Tx, clusterId int64) (int64, error) { + return this.Query(tx). Pk(clusterId). Result("httpFirewallPolicyId"). FindInt64Col(0) } // 设置集群的缓存策略 -func (this *NodeClusterDAO) UpdateNodeClusterHTTPCachePolicyId(clusterId int64, httpCachePolicyId int64) error { - _, err := this.Query(). +func (this *NodeClusterDAO) UpdateNodeClusterHTTPCachePolicyId(tx *dbs.Tx, clusterId int64, httpCachePolicyId int64) error { + _, err := this.Query(tx). Pk(clusterId). Set("cachePolicyId", httpCachePolicyId). Update() @@ -636,16 +636,16 @@ func (this *NodeClusterDAO) UpdateNodeClusterHTTPCachePolicyId(clusterId int64, } // 获取集群的缓存策略ID -func (this *NodeClusterDAO) FindClusterHTTPCachePolicyId(clusterId int64) (int64, error) { - return this.Query(). +func (this *NodeClusterDAO) FindClusterHTTPCachePolicyId(tx *dbs.Tx, clusterId int64) (int64, error) { + return this.Query(tx). Pk(clusterId). Result("cachePolicyId"). FindInt64Col(0) } // 设置集群的WAF策略 -func (this *NodeClusterDAO) UpdateNodeClusterHTTPFirewallPolicyId(clusterId int64, httpFirewallPolicyId int64) error { - _, err := this.Query(). +func (this *NodeClusterDAO) UpdateNodeClusterHTTPFirewallPolicyId(tx *dbs.Tx, clusterId int64, httpFirewallPolicyId int64) error { + _, err := this.Query(tx). Pk(clusterId). Set("httpFirewallPolicyId", httpFirewallPolicyId). Update() @@ -653,10 +653,10 @@ func (this *NodeClusterDAO) UpdateNodeClusterHTTPFirewallPolicyId(clusterId int6 } // 生成唯一ID -func (this *NodeClusterDAO) genUniqueId() (string, error) { +func (this *NodeClusterDAO) genUniqueId(tx *dbs.Tx) (string, error) { for { uniqueId := rands.HexString(32) - ok, err := this.Query(). + ok, err := this.Query(tx). Attr("uniqueId", uniqueId). Exist() if err != nil { diff --git a/internal/db/models/node_dao.go b/internal/db/models/node_dao.go index 13eb815b..8d03854f 100644 --- a/internal/db/models/node_dao.go +++ b/internal/db/models/node_dao.go @@ -42,16 +42,16 @@ func init() { } // 启用条目 -func (this *NodeDAO) EnableNode(id uint32) (rowsAffected int64, err error) { - return this.Query(). +func (this *NodeDAO) EnableNode(tx *dbs.Tx, id uint32) (rowsAffected int64, err error) { + return this.Query(tx). Pk(id). Set("state", NodeStateEnabled). Update() } // 禁用条目 -func (this *NodeDAO) DisableNode(id int64) (err error) { - _, err = this.Query(). +func (this *NodeDAO) DisableNode(tx *dbs.Tx, id int64) (err error) { + _, err = this.Query(tx). Pk(id). Set("state", NodeStateDisabled). Update() @@ -59,8 +59,8 @@ func (this *NodeDAO) DisableNode(id int64) (err error) { } // 查找启用中的条目 -func (this *NodeDAO) FindEnabledNode(id int64) (*Node, error) { - result, err := this.Query(). +func (this *NodeDAO) FindEnabledNode(tx *dbs.Tx, id int64) (*Node, error) { + result, err := this.Query(tx). Pk(id). Attr("state", NodeStateEnabled). Find() @@ -71,8 +71,8 @@ func (this *NodeDAO) FindEnabledNode(id int64) (*Node, error) { } // 根据主键查找名称 -func (this *NodeDAO) FindNodeName(id uint32) (string, error) { - name, err := this.Query(). +func (this *NodeDAO) FindNodeName(tx *dbs.Tx, id uint32) (string, error) { + name, err := this.Query(tx). Pk(id). Result("name"). FindCol("") @@ -80,8 +80,8 @@ func (this *NodeDAO) FindNodeName(id uint32) (string, error) { } // 创建节点 -func (this *NodeDAO) CreateNode(adminId int64, name string, clusterId int64, groupId int64, regionId int64) (nodeId int64, err error) { - uniqueId, err := this.genUniqueId() +func (this *NodeDAO) CreateNode(tx *dbs.Tx, adminId int64, name string, clusterId int64, groupId int64, regionId int64) (nodeId int64, err error) { + uniqueId, err := this.genUniqueId(tx) if err != nil { return 0, err } @@ -89,7 +89,7 @@ func (this *NodeDAO) CreateNode(adminId int64, name string, clusterId int64, gro secret := rands.String(32) // 保存API Token - err = SharedApiTokenDAO.CreateAPIToken(uniqueId, secret, NodeRoleNode) + err = SharedApiTokenDAO.CreateAPIToken(tx, uniqueId, secret, NodeRoleNode) if err != nil { return } @@ -104,7 +104,7 @@ func (this *NodeDAO) CreateNode(adminId int64, name string, clusterId int64, gro op.RegionId = regionId op.IsOn = 1 op.State = NodeStateEnabled - err = this.Save(op) + err = this.Save(tx, op) if err != nil { return 0, err } @@ -113,7 +113,7 @@ func (this *NodeDAO) CreateNode(adminId int64, name string, clusterId int64, gro } // 修改节点 -func (this *NodeDAO) UpdateNode(nodeId int64, name string, clusterId int64, groupId int64, regionId int64, maxCPU int32, isOn bool) error { +func (this *NodeDAO) UpdateNode(tx *dbs.Tx, nodeId int64, name string, clusterId int64, groupId int64, regionId int64, maxCPU int32, isOn bool) error { if nodeId <= 0 { return errors.New("invalid nodeId") } @@ -126,25 +126,25 @@ func (this *NodeDAO) UpdateNode(nodeId int64, name string, clusterId int64, grou op.LatestVersion = dbs.SQL("latestVersion+1") op.MaxCPU = maxCPU op.IsOn = isOn - err := this.Save(op) + err := this.Save(tx, op) return err } // 更新节点版本 -func (this *NodeDAO) UpdateNodeLatestVersion(nodeId int64) error { +func (this *NodeDAO) UpdateNodeLatestVersion(tx *dbs.Tx, nodeId int64) error { if nodeId <= 0 { return errors.New("invalid nodeId") } op := NewNodeOperator() op.Id = nodeId op.LatestVersion = dbs.SQL("latestVersion+1") - err := this.Save(op) + err := this.Save(tx, op) return err } // 批量更新节点版本 -func (this *NodeDAO) IncreaseAllNodesLatestVersionMatch(clusterId int64) error { - _, err := this.Query(). +func (this *NodeDAO) IncreaseAllNodesLatestVersionMatch(tx *dbs.Tx, clusterId int64) error { + _, err := this.Query(tx). Attr("clusterId", clusterId). Set("latestVersion", dbs.SQL("latestVersion+1")). Update() @@ -153,11 +153,11 @@ func (this *NodeDAO) IncreaseAllNodesLatestVersionMatch(clusterId int64) error { } // 同步集群中的节点版本 -func (this *NodeDAO) SyncNodeVersionsWithCluster(clusterId int64) error { +func (this *NodeDAO) SyncNodeVersionsWithCluster(tx *dbs.Tx, clusterId int64) error { if clusterId <= 0 { return errors.New("invalid cluster") } - _, err := this.Query(). + _, err := this.Query(tx). Attr("clusterId", clusterId). Set("version", dbs.SQL("latestVersion")). Update() @@ -165,8 +165,8 @@ func (this *NodeDAO) SyncNodeVersionsWithCluster(clusterId int64) error { } // 取得有变更的集群 -func (this *NodeDAO) FindChangedClusterIds() ([]int64, error) { - ones, _, err := this.Query(). +func (this *NodeDAO) FindChangedClusterIds(tx *dbs.Tx) ([]int64, error) { + ones, _, err := this.Query(tx). State(NodeStateEnabled). Gt("latestVersion", 0). Where("version!=latestVersion"). @@ -183,15 +183,15 @@ func (this *NodeDAO) FindChangedClusterIds() ([]int64, error) { } // 计算所有节点数量 -func (this *NodeDAO) CountAllEnabledNodes() (int64, error) { - return this.Query(). +func (this *NodeDAO) CountAllEnabledNodes(tx *dbs.Tx) (int64, error) { + return this.Query(tx). State(NodeStateEnabled). Count() } // 列出单页节点 -func (this *NodeDAO) ListEnabledNodesMatch(offset int64, size int64, clusterId int64, installState configutils.BoolState, activeState configutils.BoolState, keyword string, groupId int64, regionId int64) (result []*Node, err error) { - query := this.Query(). +func (this *NodeDAO) ListEnabledNodesMatch(tx *dbs.Tx, offset int64, size int64, clusterId int64, installState configutils.BoolState, activeState configutils.BoolState, keyword string, groupId int64, regionId int64) (result []*Node, err error) { + query := this.Query(tx). State(NodeStateEnabled). Offset(offset). Limit(size). @@ -244,8 +244,8 @@ func (this *NodeDAO) ListEnabledNodesMatch(offset int64, size int64, clusterId i } // 根据节点ID和密钥查询节点 -func (this *NodeDAO) FindEnabledNodeWithUniqueIdAndSecret(uniqueId string, secret string) (*Node, error) { - one, err := this.Query(). +func (this *NodeDAO) FindEnabledNodeWithUniqueIdAndSecret(tx *dbs.Tx, uniqueId string, secret string) (*Node, error) { + one, err := this.Query(tx). Attr("uniqueId", uniqueId). Attr("secret", secret). State(NodeStateEnabled). @@ -259,8 +259,8 @@ func (this *NodeDAO) FindEnabledNodeWithUniqueIdAndSecret(uniqueId string, secre } // 根据节点ID获取节点 -func (this *NodeDAO) FindEnabledNodeWithUniqueId(uniqueId string) (*Node, error) { - one, err := this.Query(). +func (this *NodeDAO) FindEnabledNodeWithUniqueId(tx *dbs.Tx, uniqueId string) (*Node, error) { + one, err := this.Query(tx). Attr("uniqueId", uniqueId). State(NodeStateEnabled). Find() @@ -273,8 +273,8 @@ func (this *NodeDAO) FindEnabledNodeWithUniqueId(uniqueId string) (*Node, error) } // 获取节点集群ID -func (this *NodeDAO) FindNodeClusterId(nodeId int64) (int64, error) { - col, err := this.Query(). +func (this *NodeDAO) FindNodeClusterId(tx *dbs.Tx, nodeId int64) (int64, error) { + col, err := this.Query(tx). Pk(nodeId). Result("clusterId"). FindCol(0) @@ -282,8 +282,8 @@ func (this *NodeDAO) FindNodeClusterId(nodeId int64) (int64, error) { } // 匹配节点并返回节点ID -func (this *NodeDAO) FindAllNodeIdsMatch(clusterId int64) (result []int64, err error) { - query := this.Query() +func (this *NodeDAO) FindAllNodeIdsMatch(tx *dbs.Tx, clusterId int64) (result []int64, err error) { + query := this.Query(tx) query.State(NodeStateEnabled) if clusterId > 0 { query.Attr("clusterId", clusterId) @@ -300,8 +300,8 @@ func (this *NodeDAO) FindAllNodeIdsMatch(clusterId int64) (result []int64, err e } // 获取一个集群的所有节点 -func (this *NodeDAO) FindAllEnabledNodesWithClusterId(clusterId int64) (result []*Node, err error) { - _, err = this.Query(). +func (this *NodeDAO) FindAllEnabledNodesWithClusterId(tx *dbs.Tx, clusterId int64) (result []*Node, err error) { + _, err = this.Query(tx). State(NodeStateEnabled). Attr("clusterId", clusterId). DescPk(). @@ -311,8 +311,8 @@ func (this *NodeDAO) FindAllEnabledNodesWithClusterId(clusterId int64) (result [ } // 取得一个集群离线的节点 -func (this *NodeDAO) FindAllInactiveNodesWithClusterId(clusterId int64) (result []*Node, err error) { - _, err = this.Query(). +func (this *NodeDAO) FindAllInactiveNodesWithClusterId(tx *dbs.Tx, clusterId int64) (result []*Node, err error) { + _, err = this.Query(tx). State(NodeStateEnabled). Attr("clusterId", clusterId). Attr("isOn", true). // 只监控启用的节点 @@ -326,8 +326,8 @@ func (this *NodeDAO) FindAllInactiveNodesWithClusterId(clusterId int64) (result } // 计算节点数量 -func (this *NodeDAO) CountAllEnabledNodesMatch(clusterId int64, installState configutils.BoolState, activeState configutils.BoolState, keyword string, groupId int64, regionId int64) (int64, error) { - query := this.Query() +func (this *NodeDAO) CountAllEnabledNodesMatch(tx *dbs.Tx, clusterId int64, installState configutils.BoolState, activeState configutils.BoolState, keyword string, groupId int64, regionId int64) (int64, error) { + query := this.Query(tx) query.State(NodeStateEnabled) // 集群 @@ -375,8 +375,8 @@ func (this *NodeDAO) CountAllEnabledNodesMatch(clusterId int64, installState con } // 更改节点状态 -func (this *NodeDAO) UpdateNodeStatus(nodeId int64, statusJSON []byte) error { - _, err := this.Query(). +func (this *NodeDAO) UpdateNodeStatus(tx *dbs.Tx, nodeId int64, statusJSON []byte) error { + _, err := this.Query(tx). Pk(nodeId). Set("status", string(statusJSON)). Update() @@ -384,12 +384,12 @@ func (this *NodeDAO) UpdateNodeStatus(nodeId int64, statusJSON []byte) error { } // 更改节点在线状态 -func (this *NodeDAO) UpdateNodeIsActive(nodeId int64, isActive bool) error { +func (this *NodeDAO) UpdateNodeIsActive(tx *dbs.Tx, nodeId int64, isActive bool) error { b := "true" if !isActive { b = "false" } - _, err := this.Query(). + _, err := this.Query(tx). Pk(nodeId). Where("status IS NOT NULL"). Set("status", dbs.SQL("JSON_SET(status, '$.isActive', "+b+")")). @@ -398,8 +398,8 @@ func (this *NodeDAO) UpdateNodeIsActive(nodeId int64, isActive bool) error { } // 设置节点安装状态 -func (this *NodeDAO) UpdateNodeIsInstalled(nodeId int64, isInstalled bool) error { - _, err := this.Query(). +func (this *NodeDAO) UpdateNodeIsInstalled(tx *dbs.Tx, nodeId int64, isInstalled bool) error { + _, err := this.Query(tx). Pk(nodeId). Set("isInstalled", isInstalled). Set("installStatus", "null"). // 重置安装状态 @@ -408,8 +408,8 @@ func (this *NodeDAO) UpdateNodeIsInstalled(nodeId int64, isInstalled bool) error } // 查询节点的安装状态 -func (this *NodeDAO) FindNodeInstallStatus(nodeId int64) (*NodeInstallStatus, error) { - node, err := this.Query(). +func (this *NodeDAO) FindNodeInstallStatus(tx *dbs.Tx, nodeId int64) (*NodeInstallStatus, error) { + node, err := this.Query(tx). Pk(nodeId). Result("installStatus", "isInstalled"). Find() @@ -439,9 +439,9 @@ func (this *NodeDAO) FindNodeInstallStatus(nodeId int64) (*NodeInstallStatus, er } // 修改节点的安装状态 -func (this *NodeDAO) UpdateNodeInstallStatus(nodeId int64, status *NodeInstallStatus) error { +func (this *NodeDAO) UpdateNodeInstallStatus(tx *dbs.Tx, nodeId int64, status *NodeInstallStatus) error { if status == nil { - _, err := this.Query(). + _, err := this.Query(tx). Pk(nodeId). Set("installStatus", "null"). Update() @@ -452,7 +452,7 @@ func (this *NodeDAO) UpdateNodeInstallStatus(nodeId int64, status *NodeInstallSt if err != nil { return err } - _, err = this.Query(). + _, err = this.Query(tx). Pk(nodeId). Set("installStatus", string(data)). Update() @@ -461,8 +461,8 @@ func (this *NodeDAO) UpdateNodeInstallStatus(nodeId int64, status *NodeInstallSt // 组合配置 // TODO 提升运行速度 -func (this *NodeDAO) ComposeNodeConfig(nodeId int64) (*nodeconfigs.NodeConfig, error) { - node, err := this.FindEnabledNode(nodeId) +func (this *NodeDAO) ComposeNodeConfig(tx *dbs.Tx, nodeId int64) (*nodeconfigs.NodeConfig, error) { + node, err := this.FindEnabledNode(tx, nodeId) if err != nil { return nil, err } @@ -482,7 +482,7 @@ func (this *NodeDAO) ComposeNodeConfig(nodeId int64) (*nodeconfigs.NodeConfig, e } // 获取所有的服务 - servers, err := SharedServerDAO.FindAllEnabledServersWithNode(int64(node.Id)) + servers, err := SharedServerDAO.FindAllEnabledServersWithNode(tx, int64(node.Id)) if err != nil { return nil, err } @@ -502,7 +502,7 @@ func (this *NodeDAO) ComposeNodeConfig(nodeId int64) (*nodeconfigs.NodeConfig, e // 全局设置 // TODO 根据用户的不同读取不同的全局设置 - settingJSON, err := SharedSysSettingDAO.ReadSetting(SettingCodeServerGlobalConfig) + settingJSON, err := SharedSysSettingDAO.ReadSetting(tx, SettingCodeServerGlobalConfig) if err != nil { return nil, err } @@ -517,12 +517,12 @@ func (this *NodeDAO) ComposeNodeConfig(nodeId int64) (*nodeconfigs.NodeConfig, e // WAF clusterId := int64(node.ClusterId) - httpFirewallPolicyId, err := SharedNodeClusterDAO.FindClusterHTTPFirewallPolicyId(clusterId) + httpFirewallPolicyId, err := SharedNodeClusterDAO.FindClusterHTTPFirewallPolicyId(tx, clusterId) if err != nil { return nil, err } if httpFirewallPolicyId > 0 { - firewallPolicy, err := SharedHTTPFirewallPolicyDAO.ComposeFirewallPolicy(httpFirewallPolicyId) + firewallPolicy, err := SharedHTTPFirewallPolicyDAO.ComposeFirewallPolicy(tx, httpFirewallPolicyId) if err != nil { return nil, err } @@ -532,12 +532,12 @@ func (this *NodeDAO) ComposeNodeConfig(nodeId int64) (*nodeconfigs.NodeConfig, e } // 缓存策略 - httpCachePolicyId, err := SharedNodeClusterDAO.FindClusterHTTPCachePolicyId(clusterId) + httpCachePolicyId, err := SharedNodeClusterDAO.FindClusterHTTPCachePolicyId(tx, clusterId) if err != nil { return nil, err } if httpCachePolicyId > 0 { - cachePolicy, err := SharedHTTPCachePolicyDAO.ComposeCachePolicy(httpCachePolicyId) + cachePolicy, err := SharedHTTPCachePolicyDAO.ComposeCachePolicy(tx, httpCachePolicyId) if err != nil { return nil, err } @@ -547,7 +547,7 @@ func (this *NodeDAO) ComposeNodeConfig(nodeId int64) (*nodeconfigs.NodeConfig, e } // TOA - toaConfig, err := SharedNodeClusterDAO.FindClusterTOAConfig(clusterId) + toaConfig, err := SharedNodeClusterDAO.FindClusterTOAConfig(tx, clusterId) if err != nil { return nil, err } @@ -557,7 +557,7 @@ func (this *NodeDAO) ComposeNodeConfig(nodeId int64) (*nodeconfigs.NodeConfig, e } // 修改当前连接的API节点 -func (this *NodeDAO) UpdateNodeConnectedAPINodes(nodeId int64, apiNodeIds []int64) error { +func (this *NodeDAO) UpdateNodeConnectedAPINodes(tx *dbs.Tx, nodeId int64, apiNodeIds []int64) error { if nodeId <= 0 { return errors.New("invalid nodeId") } @@ -574,14 +574,14 @@ func (this *NodeDAO) UpdateNodeConnectedAPINodes(nodeId int64, apiNodeIds []int6 } else { op.ConnectedAPINodes = "[]" } - err := this.Save(op) + err := this.Save(tx, op) return err } // 根据UniqueId获取ID // TODO 增加缓存 -func (this *NodeDAO) FindEnabledNodeIdWithUniqueId(uniqueId string) (int64, error) { - return this.Query(). +func (this *NodeDAO) FindEnabledNodeIdWithUniqueId(tx *dbs.Tx, uniqueId string) (int64, error) { + return this.Query(tx). State(NodeStateEnabled). Attr("uniqueId", uniqueId). ResultPk(). @@ -589,8 +589,8 @@ func (this *NodeDAO) FindEnabledNodeIdWithUniqueId(uniqueId string) (int64, erro } // 计算使用某个认证的节点数量 -func (this *NodeDAO) CountAllEnabledNodesWithGrantId(grantId int64) (int64, error) { - return this.Query(). +func (this *NodeDAO) CountAllEnabledNodesWithGrantId(tx *dbs.Tx, grantId int64) (int64, error) { + return this.Query(tx). State(NodeStateEnabled). Where("id IN (SELECT nodeId FROM edgeNodeLogins WHERE type='ssh' AND JSON_CONTAINS(params, :grantParam))"). Param("grantParam", string(maps.Map{"grantId": grantId}.AsJSON())). @@ -599,8 +599,8 @@ func (this *NodeDAO) CountAllEnabledNodesWithGrantId(grantId int64) (int64, erro } // 查找使用某个认证的所有节点 -func (this *NodeDAO) FindAllEnabledNodesWithGrantId(grantId int64) (result []*Node, err error) { - _, err = this.Query(). +func (this *NodeDAO) FindAllEnabledNodesWithGrantId(tx *dbs.Tx, grantId int64) (result []*Node, err error) { + _, err = this.Query(tx). State(NodeStateEnabled). Where("id IN (SELECT nodeId FROM edgeNodeLogins WHERE type='ssh' AND JSON_CONTAINS(params, :grantParam))"). Param("grantParam", string(maps.Map{"grantId": grantId}.AsJSON())). @@ -612,8 +612,8 @@ func (this *NodeDAO) FindAllEnabledNodesWithGrantId(grantId int64) (result []*No } // 查找所有未安装的节点 -func (this *NodeDAO) FindAllNotInstalledNodesWithClusterId(clusterId int64) (result []*Node, err error) { - _, err = this.Query(). +func (this *NodeDAO) FindAllNotInstalledNodesWithClusterId(tx *dbs.Tx, clusterId int64) (result []*Node, err error) { + _, err = this.Query(tx). State(NodeStateEnabled). Attr("clusterId", clusterId). Attr("isInstalled", false). @@ -624,8 +624,8 @@ func (this *NodeDAO) FindAllNotInstalledNodesWithClusterId(clusterId int64) (res } // 计算所有低于某个版本的节点数量 -func (this *NodeDAO) CountAllLowerVersionNodesWithClusterId(clusterId int64, os string, arch string, version string) (int64, error) { - return this.Query(). +func (this *NodeDAO) CountAllLowerVersionNodesWithClusterId(tx *dbs.Tx, clusterId int64, os string, arch string, version string) (int64, error) { + return this.Query(tx). State(NodeStateEnabled). Attr("clusterId", clusterId). Where("status IS NOT NULL"). @@ -639,8 +639,8 @@ func (this *NodeDAO) CountAllLowerVersionNodesWithClusterId(clusterId int64, os } // 查找所有低于某个版本的节点 -func (this *NodeDAO) FindAllLowerVersionNodesWithClusterId(clusterId int64, os string, arch string, version string) (result []*Node, err error) { - _, err = this.Query(). +func (this *NodeDAO) FindAllLowerVersionNodesWithClusterId(tx *dbs.Tx, clusterId int64, os string, arch string, version string) (result []*Node, err error) { + _, err = this.Query(tx). State(NodeStateEnabled). Attr("clusterId", clusterId). Where("status IS NOT NULL"). @@ -657,24 +657,24 @@ func (this *NodeDAO) FindAllLowerVersionNodesWithClusterId(clusterId int64, os s } // 查找某个节点分组下的所有节点数量 -func (this *NodeDAO) CountAllEnabledNodesWithGroupId(groupId int64) (int64, error) { - return this.Query(). +func (this *NodeDAO) CountAllEnabledNodesWithGroupId(tx *dbs.Tx, groupId int64) (int64, error) { + return this.Query(tx). State(NodeStateEnabled). Attr("groupId", groupId). Count() } // 查找某个节点区域下的所有节点数量 -func (this *NodeDAO) CountAllEnabledNodesWithRegionId(regionId int64) (int64, error) { - return this.Query(). +func (this *NodeDAO) CountAllEnabledNodesWithRegionId(tx *dbs.Tx, regionId int64) (int64, error) { + return this.Query(tx). State(NodeStateEnabled). Attr("regionId", regionId). Count() } // 获取一个集群的节点DNS信息 -func (this *NodeDAO) FindAllEnabledNodesDNSWithClusterId(clusterId int64) (result []*Node, err error) { - _, err = this.Query(). +func (this *NodeDAO) FindAllEnabledNodesDNSWithClusterId(tx *dbs.Tx, clusterId int64) (result []*Node, err error) { + _, err = this.Query(tx). State(NodeStateEnabled). Attr("clusterId", clusterId). Attr("isOn", true). @@ -687,8 +687,8 @@ func (this *NodeDAO) FindAllEnabledNodesDNSWithClusterId(clusterId int64) (resul } // 计算一个集群的节点DNS数量 -func (this *NodeDAO) CountAllEnabledNodesDNSWithClusterId(clusterId int64) (result int64, err error) { - return this.Query(). +func (this *NodeDAO) CountAllEnabledNodesDNSWithClusterId(tx *dbs.Tx, clusterId int64) (result int64, err error) { + return this.Query(tx). State(NodeStateEnabled). Attr("clusterId", clusterId). Attr("isOn", true). @@ -700,8 +700,8 @@ func (this *NodeDAO) CountAllEnabledNodesDNSWithClusterId(clusterId int64) (resu } // 获取单个节点的DNS信息 -func (this *NodeDAO) FindEnabledNodeDNS(nodeId int64) (*Node, error) { - one, err := this.Query(). +func (this *NodeDAO) FindEnabledNodeDNS(tx *dbs.Tx, nodeId int64) (*Node, error) { + one, err := this.Query(tx). State(NodeStateEnabled). Pk(nodeId). Result("id", "name", "dnsRoutes", "clusterId", "isOn"). @@ -713,7 +713,7 @@ func (this *NodeDAO) FindEnabledNodeDNS(nodeId int64) (*Node, error) { } // 修改节点的DNS信息 -func (this *NodeDAO) UpdateNodeDNS(nodeId int64, routes map[int64][]string) error { +func (this *NodeDAO) UpdateNodeDNS(tx *dbs.Tx, nodeId int64, routes map[int64][]string) error { if nodeId <= 0 { return errors.New("invalid nodeId") } @@ -727,16 +727,16 @@ func (this *NodeDAO) UpdateNodeDNS(nodeId int64, routes map[int64][]string) erro op := NewNodeOperator() op.Id = nodeId op.DnsRoutes = routesJSON - err = this.Save(op) + err = this.Save(tx, op) return err } // 计算节点上线|下线状态 -func (this *NodeDAO) UpdateNodeUp(nodeId int64, isUp bool, maxUp int, maxDown int) (changed bool, err error) { +func (this *NodeDAO) UpdateNodeUp(tx *dbs.Tx, nodeId int64, isUp bool, maxUp int, maxDown int) (changed bool, err error) { if nodeId <= 0 { return false, errors.New("invalid nodeId") } - one, err := this.Query(). + one, err := this.Query(tx). Pk(nodeId). Result("isUp", "countUp", "countDown"). Find() @@ -779,7 +779,7 @@ func (this *NodeDAO) UpdateNodeUp(nodeId int64, isUp bool, maxUp int, maxDown in op.CountUp = countUp op.CountDown = countDown - err = this.Save(op) + err = this.Save(tx, op) if err != nil { return false, err } @@ -787,11 +787,11 @@ func (this *NodeDAO) UpdateNodeUp(nodeId int64, isUp bool, maxUp int, maxDown in } // 修改节点活跃状态 -func (this *NodeDAO) UpdateNodeActive(nodeId int64, isActive bool) error { +func (this *NodeDAO) UpdateNodeActive(tx *dbs.Tx, nodeId int64, isActive bool) error { if nodeId <= 0 { return errors.New("invalid nodeId") } - _, err := this.Query(). + _, err := this.Query(tx). Pk(nodeId). Set("isActive", isActive). Update() @@ -799,8 +799,8 @@ func (this *NodeDAO) UpdateNodeActive(nodeId int64, isActive bool) error { } // 检查节点活跃状态 -func (this *NodeDAO) FindNodeActive(nodeId int64) (bool, error) { - isActive, err := this.Query(). +func (this *NodeDAO) FindNodeActive(tx *dbs.Tx, nodeId int64) (bool, error) { + isActive, err := this.Query(tx). Pk(nodeId). Result("isActive"). FindIntCol(0) @@ -811,18 +811,18 @@ func (this *NodeDAO) FindNodeActive(nodeId int64) (bool, error) { } // 查找节点的版本号 -func (this *NodeDAO) FindNodeVersion(nodeId int64) (int64, error) { - return this.Query(). +func (this *NodeDAO) FindNodeVersion(tx *dbs.Tx, nodeId int64) (int64, error) { + return this.Query(tx). Pk(nodeId). Result("version"). FindInt64Col(0) } // 生成唯一ID -func (this *NodeDAO) genUniqueId() (string, error) { +func (this *NodeDAO) genUniqueId(tx *dbs.Tx) (string, error) { for { uniqueId := rands.HexString(32) - ok, err := this.Query(). + ok, err := this.Query(tx). Attr("uniqueId", uniqueId). Exist() if err != nil { diff --git a/internal/db/models/node_grant_dao.go b/internal/db/models/node_grant_dao.go index 3c32bf56..0dba7013 100644 --- a/internal/db/models/node_grant_dao.go +++ b/internal/db/models/node_grant_dao.go @@ -35,16 +35,16 @@ func init() { } // 启用条目 -func (this *NodeGrantDAO) EnableNodeGrant(id uint32) (rowsAffected int64, err error) { - return this.Query(). +func (this *NodeGrantDAO) EnableNodeGrant(tx *dbs.Tx, id uint32) (rowsAffected int64, err error) { + return this.Query(tx). Pk(id). Set("state", NodeGrantStateEnabled). Update() } // 禁用条目 -func (this *NodeGrantDAO) DisableNodeGrant(id int64) (err error) { - _, err = this.Query(). +func (this *NodeGrantDAO) DisableNodeGrant(tx *dbs.Tx, id int64) (err error) { + _, err = this.Query(tx). Pk(id). Set("state", NodeGrantStateDisabled). Update() @@ -52,8 +52,8 @@ func (this *NodeGrantDAO) DisableNodeGrant(id int64) (err error) { } // 查找启用中的条目 -func (this *NodeGrantDAO) FindEnabledNodeGrant(id int64) (*NodeGrant, error) { - result, err := this.Query(). +func (this *NodeGrantDAO) FindEnabledNodeGrant(tx *dbs.Tx, id int64) (*NodeGrant, error) { + result, err := this.Query(tx). Pk(id). Attr("state", NodeGrantStateEnabled). Find() @@ -64,8 +64,8 @@ func (this *NodeGrantDAO) FindEnabledNodeGrant(id int64) (*NodeGrant, error) { } // 根据主键查找名称 -func (this *NodeGrantDAO) FindNodeGrantName(id uint32) (string, error) { - name, err := this.Query(). +func (this *NodeGrantDAO) FindNodeGrantName(tx *dbs.Tx, id uint32) (string, error) { + name, err := this.Query(tx). Pk(id). Result("name"). FindCol("") @@ -73,7 +73,7 @@ func (this *NodeGrantDAO) FindNodeGrantName(id uint32) (string, error) { } // 创建认证信息 -func (this *NodeGrantDAO) CreateGrant(adminId int64, name string, method string, username string, password string, privateKey string, description string, nodeId int64) (grantId int64, err error) { +func (this *NodeGrantDAO) CreateGrant(tx *dbs.Tx, adminId int64, name string, method string, username string, password string, privateKey string, description string, nodeId int64) (grantId int64, err error) { op := NewNodeGrantOperator() op.AdminId = adminId op.Name = name @@ -90,12 +90,12 @@ func (this *NodeGrantDAO) CreateGrant(adminId int64, name string, method string, op.Description = description op.NodeId = nodeId op.State = NodeGrantStateEnabled - err = this.Save(op) + err = this.Save(tx, op) return types.Int64(op.Id), err } // 修改认证信息 -func (this *NodeGrantDAO) UpdateGrant(grantId int64, name string, method string, username string, password string, privateKey string, description string, nodeId int64) error { +func (this *NodeGrantDAO) UpdateGrant(tx *dbs.Tx, grantId int64, name string, method string, username string, password string, privateKey string, description string, nodeId int64) error { if grantId <= 0 { return errors.New("invalid grantId") } @@ -115,20 +115,20 @@ func (this *NodeGrantDAO) UpdateGrant(grantId int64, name string, method string, } op.Description = description op.NodeId = nodeId - err := this.Save(op) + err := this.Save(tx, op) return err } // 计算所有认证信息数量 -func (this *NodeGrantDAO) CountAllEnabledGrants() (int64, error) { - return this.Query(). +func (this *NodeGrantDAO) CountAllEnabledGrants(tx *dbs.Tx) (int64, error) { + return this.Query(tx). State(NodeGrantStateEnabled). Count() } // 列出单页的认证信息 -func (this *NodeGrantDAO) ListEnabledGrants(offset int64, size int64) (result []*NodeGrant, err error) { - _, err = this.Query(). +func (this *NodeGrantDAO) ListEnabledGrants(tx *dbs.Tx, offset int64, size int64) (result []*NodeGrant, err error) { + _, err = this.Query(tx). State(NodeGrantStateEnabled). Offset(offset). Size(size). @@ -139,8 +139,8 @@ func (this *NodeGrantDAO) ListEnabledGrants(offset int64, size int64) (result [] } // 列出所有的认证信息 -func (this *NodeGrantDAO) FindAllEnabledGrants() (result []*NodeGrant, err error) { - _, err = this.Query(). +func (this *NodeGrantDAO) FindAllEnabledGrants(tx *dbs.Tx) (result []*NodeGrant, err error) { + _, err = this.Query(tx). State(NodeGrantStateEnabled). DescPk(). Slice(&result). diff --git a/internal/db/models/node_group_dao.go b/internal/db/models/node_group_dao.go index 36e91bda..68eb83bb 100644 --- a/internal/db/models/node_group_dao.go +++ b/internal/db/models/node_group_dao.go @@ -35,24 +35,24 @@ func init() { } // 启用条目 -func (this *NodeGroupDAO) EnableNodeGroup(id int64) (rowsAffected int64, err error) { - return this.Query(). +func (this *NodeGroupDAO) EnableNodeGroup(tx *dbs.Tx, id int64) (rowsAffected int64, err error) { + return this.Query(tx). Pk(id). Set("state", NodeGroupStateEnabled). Update() } // 禁用条目 -func (this *NodeGroupDAO) DisableNodeGroup(id int64) (rowsAffected int64, err error) { - return this.Query(). +func (this *NodeGroupDAO) DisableNodeGroup(tx *dbs.Tx, id int64) (rowsAffected int64, err error) { + return this.Query(tx). Pk(id). Set("state", NodeGroupStateDisabled). Update() } // 查找启用中的条目 -func (this *NodeGroupDAO) FindEnabledNodeGroup(id int64) (*NodeGroup, error) { - result, err := this.Query(). +func (this *NodeGroupDAO) FindEnabledNodeGroup(tx *dbs.Tx, id int64) (*NodeGroup, error) { + result, err := this.Query(tx). Pk(id). Attr("state", NodeGroupStateEnabled). Find() @@ -63,8 +63,8 @@ func (this *NodeGroupDAO) FindEnabledNodeGroup(id int64) (*NodeGroup, error) { } // 根据主键查找名称 -func (this *NodeGroupDAO) FindNodeGroupName(id int64) (string, error) { - name, err := this.Query(). +func (this *NodeGroupDAO) FindNodeGroupName(tx *dbs.Tx, id int64) (string, error) { + name, err := this.Query(tx). Pk(id). Result("name"). FindCol("") @@ -72,12 +72,12 @@ func (this *NodeGroupDAO) FindNodeGroupName(id int64) (string, error) { } // 创建分组 -func (this *NodeGroupDAO) CreateNodeGroup(clusterId int64, name string) (int64, error) { +func (this *NodeGroupDAO) CreateNodeGroup(tx *dbs.Tx, clusterId int64, name string) (int64, error) { op := NewNodeGroupOperator() op.ClusterId = clusterId op.Name = name op.State = NodeGroupStateEnabled - err := this.Save(op) + err := this.Save(tx, op) if err != nil { return 0, err } @@ -85,20 +85,20 @@ func (this *NodeGroupDAO) CreateNodeGroup(clusterId int64, name string) (int64, } // 修改分组 -func (this *NodeGroupDAO) UpdateNodeGroup(groupId int64, name string) error { +func (this *NodeGroupDAO) UpdateNodeGroup(tx *dbs.Tx, groupId int64, name string) error { if groupId <= 0 { return errors.New("invalid groupId") } op := NewNodeGroupOperator() op.Id = groupId op.Name = name - err := this.Save(op) + err := this.Save(tx, op) return err } // 查询所有分组 -func (this *NodeGroupDAO) FindAllEnabledGroupsWithClusterId(clusterId int64) (result []*NodeGroup, err error) { - _, err = this.Query(). +func (this *NodeGroupDAO) FindAllEnabledGroupsWithClusterId(tx *dbs.Tx, clusterId int64) (result []*NodeGroup, err error) { + _, err = this.Query(tx). State(NodeGroupStateEnabled). Attr("clusterId", clusterId). Desc("order"). @@ -109,9 +109,9 @@ func (this *NodeGroupDAO) FindAllEnabledGroupsWithClusterId(clusterId int64) (re } // 保存排序 -func (this *NodeGroupDAO) UpdateGroupOrders(groupIds []int64) error { +func (this *NodeGroupDAO) UpdateGroupOrders(tx *dbs.Tx, groupIds []int64) error { for index, groupId := range groupIds { - _, err := this.Query(). + _, err := this.Query(tx). Pk(groupId). Set("order", len(groupIds)-index). Update() diff --git a/internal/db/models/node_ip_address_dao.go b/internal/db/models/node_ip_address_dao.go index 67a2686b..a976774d 100644 --- a/internal/db/models/node_ip_address_dao.go +++ b/internal/db/models/node_ip_address_dao.go @@ -35,8 +35,8 @@ func init() { } // 启用条目 -func (this *NodeIPAddressDAO) EnableAddress(id int64) (err error) { - _, err = this.Query(). +func (this *NodeIPAddressDAO) EnableAddress(tx *dbs.Tx, id int64) (err error) { + _, err = this.Query(tx). Pk(id). Set("state", NodeIPAddressStateEnabled). Update() @@ -44,8 +44,8 @@ func (this *NodeIPAddressDAO) EnableAddress(id int64) (err error) { } // 禁用IP地址 -func (this *NodeIPAddressDAO) DisableAddress(id int64) (err error) { - _, err = this.Query(). +func (this *NodeIPAddressDAO) DisableAddress(tx *dbs.Tx, id int64) (err error) { + _, err = this.Query(tx). Pk(id). Set("state", NodeIPAddressStateDisabled). Update() @@ -53,11 +53,11 @@ func (this *NodeIPAddressDAO) DisableAddress(id int64) (err error) { } // 禁用节点的所有的IP地址 -func (this *NodeIPAddressDAO) DisableAllAddressesWithNodeId(nodeId int64) error { +func (this *NodeIPAddressDAO) DisableAllAddressesWithNodeId(tx *dbs.Tx, nodeId int64) error { if nodeId <= 0 { return errors.New("invalid nodeId") } - _, err := this.Query(). + _, err := this.Query(tx). Attr("nodeId", nodeId). Set("state", NodeIPAddressStateDisabled). Update() @@ -65,8 +65,8 @@ func (this *NodeIPAddressDAO) DisableAllAddressesWithNodeId(nodeId int64) error } // 查找启用中的IP地址 -func (this *NodeIPAddressDAO) FindEnabledAddress(id int64) (*NodeIPAddress, error) { - result, err := this.Query(). +func (this *NodeIPAddressDAO) FindEnabledAddress(tx *dbs.Tx, id int64) (*NodeIPAddress, error) { + result, err := this.Query(tx). Pk(id). Attr("state", NodeIPAddressStateEnabled). Find() @@ -77,22 +77,22 @@ func (this *NodeIPAddressDAO) FindEnabledAddress(id int64) (*NodeIPAddress, erro } // 根据主键查找名称 -func (this *NodeIPAddressDAO) FindAddressName(id int64) (string, error) { - return this.Query(). +func (this *NodeIPAddressDAO) FindAddressName(tx *dbs.Tx, id int64) (string, error) { + return this.Query(tx). Pk(id). Result("name"). FindStringCol("") } // 创建IP地址 -func (this *NodeIPAddressDAO) CreateAddress(nodeId int64, name string, ip string, canAccess bool) (addressId int64, err error) { +func (this *NodeIPAddressDAO) CreateAddress(tx *dbs.Tx, nodeId int64, name string, ip string, canAccess bool) (addressId int64, err error) { op := NewNodeIPAddressOperator() op.NodeId = nodeId op.Name = name op.Ip = ip op.CanAccess = canAccess op.State = NodeIPAddressStateEnabled - err = this.Save(op) + err = this.Save(tx, op) if err != nil { return 0, err } @@ -100,7 +100,7 @@ func (this *NodeIPAddressDAO) CreateAddress(nodeId int64, name string, ip string } // 修改IP地址 -func (this *NodeIPAddressDAO) UpdateAddress(addressId int64, name string, ip string, canAccess bool) (err error) { +func (this *NodeIPAddressDAO) UpdateAddress(tx *dbs.Tx, addressId int64, name string, ip string, canAccess bool) (err error) { if addressId <= 0 { return errors.New("invalid addressId") } @@ -111,25 +111,25 @@ func (this *NodeIPAddressDAO) UpdateAddress(addressId int64, name string, ip str op.Ip = ip op.CanAccess = canAccess op.State = NodeIPAddressStateEnabled // 恢复状态 - err = this.Save(op) + err = this.Save(tx, op) return err } // 修改IP地址中的IP -func (this *NodeIPAddressDAO) UpdateAddressIP(addressId int64, ip string) error { +func (this *NodeIPAddressDAO) UpdateAddressIP(tx *dbs.Tx, addressId int64, ip string) error { if addressId <= 0 { return errors.New("invalid addressId") } op := NewNodeIPAddressOperator() op.Id = addressId op.Ip = ip - err := this.Save(op) + err := this.Save(tx, op) return err } // 修改IP地址所属节点 -func (this *NodeIPAddressDAO) UpdateAddressNodeId(addressId int64, nodeId int64) error { - _, err := this.Query(). +func (this *NodeIPAddressDAO) UpdateAddressNodeId(tx *dbs.Tx, addressId int64, nodeId int64) error { + _, err := this.Query(tx). Pk(addressId). Set("nodeId", nodeId). Set("state", NodeIPAddressStateEnabled). // 恢复状态 @@ -138,8 +138,8 @@ func (this *NodeIPAddressDAO) UpdateAddressNodeId(addressId int64, nodeId int64) } // 查找节点的所有的IP地址 -func (this *NodeIPAddressDAO) FindAllEnabledAddressesWithNode(nodeId int64) (result []*NodeIPAddress, err error) { - _, err = this.Query(). +func (this *NodeIPAddressDAO) FindAllEnabledAddressesWithNode(tx *dbs.Tx, nodeId int64) (result []*NodeIPAddress, err error) { + _, err = this.Query(tx). Attr("nodeId", nodeId). State(NodeIPAddressStateEnabled). Desc("order"). @@ -150,8 +150,8 @@ func (this *NodeIPAddressDAO) FindAllEnabledAddressesWithNode(nodeId int64) (res } // 查找节点的第一个可访问的IP地址 -func (this *NodeIPAddressDAO) FindFirstNodeIPAddress(nodeId int64) (string, error) { - return this.Query(). +func (this *NodeIPAddressDAO) FindFirstNodeIPAddress(tx *dbs.Tx, nodeId int64) (string, error) { + return this.Query(tx). Attr("nodeId", nodeId). State(NodeIPAddressStateEnabled). Attr("canAccess", true). @@ -162,8 +162,8 @@ func (this *NodeIPAddressDAO) FindFirstNodeIPAddress(nodeId int64) (string, erro } // 查找节点的第一个可访问的IP地址ID -func (this *NodeIPAddressDAO) FindFirstNodeIPAddressId(nodeId int64) (int64, error) { - return this.Query(). +func (this *NodeIPAddressDAO) FindFirstNodeIPAddressId(tx *dbs.Tx, nodeId int64) (int64, error) { + return this.Query(tx). Attr("nodeId", nodeId). State(NodeIPAddressStateEnabled). Attr("canAccess", true). diff --git a/internal/db/models/node_log_dao.go b/internal/db/models/node_log_dao.go index 8b6f8529..25d731bc 100644 --- a/internal/db/models/node_log_dao.go +++ b/internal/db/models/node_log_dao.go @@ -34,7 +34,7 @@ func init() { } // 创建日志 -func (this *NodeLogDAO) CreateLog(nodeRole NodeRole, nodeId int64, level string, tag string, description string, createdAt int64) error { +func (this *NodeLogDAO) CreateLog(tx *dbs.Tx, nodeRole NodeRole, nodeId int64, level string, tag string, description string, createdAt int64) error { op := NewNodeLogOperator() op.Role = nodeRole op.NodeId = nodeId @@ -43,18 +43,18 @@ func (this *NodeLogDAO) CreateLog(nodeRole NodeRole, nodeId int64, level string, op.Description = description op.CreatedAt = createdAt op.Day = timeutil.FormatTime("Ymd", createdAt) - err := this.Save(op) + err := this.Save(tx, op) return err } // 清除超出一定日期的日志 -func (this *NodeLogDAO) DeleteExpiredLogs(days int) error { +func (this *NodeLogDAO) DeleteExpiredLogs(tx *dbs.Tx, days int) error { if days <= 0 { return errors.New("invalid days '" + strconv.Itoa(days) + "'") } date := time.Now().AddDate(0, 0, -days) expireDay := timeutil.Format("Ymd", date) - _, err := this.Query(). + _, err := this.Query(tx). Where("day<=:day"). Param("day", expireDay). Delete() @@ -62,16 +62,16 @@ func (this *NodeLogDAO) DeleteExpiredLogs(days int) error { } // 计算节点数量 -func (this *NodeLogDAO) CountNodeLogs(role string, nodeId int64) (int64, error) { - return this.Query(). +func (this *NodeLogDAO) CountNodeLogs(tx *dbs.Tx, role string, nodeId int64) (int64, error) { + return this.Query(tx). Attr("nodeId", nodeId). Attr("role", role). Count() } // 列出单页日志 -func (this *NodeLogDAO) ListNodeLogs(role string, nodeId int64, offset int64, size int64) (result []*NodeLog, err error) { - _, err = this.Query(). +func (this *NodeLogDAO) ListNodeLogs(tx *dbs.Tx, role string, nodeId int64, offset int64, size int64) (result []*NodeLog, err error) { + _, err = this.Query(tx). Attr("nodeId", nodeId). Attr("role", role). Offset(offset). diff --git a/internal/db/models/node_login_dao.go b/internal/db/models/node_login_dao.go index 661c3c87..a9dca32a 100644 --- a/internal/db/models/node_login_dao.go +++ b/internal/db/models/node_login_dao.go @@ -37,24 +37,24 @@ func init() { } // 启用条目 -func (this *NodeLoginDAO) EnableNodeLogin(id uint32) (rowsAffected int64, err error) { - return this.Query(). +func (this *NodeLoginDAO) EnableNodeLogin(tx *dbs.Tx, id uint32) (rowsAffected int64, err error) { + return this.Query(tx). Pk(id). Set("state", NodeLoginStateEnabled). Update() } // 禁用条目 -func (this *NodeLoginDAO) DisableNodeLogin(id uint32) (rowsAffected int64, err error) { - return this.Query(). +func (this *NodeLoginDAO) DisableNodeLogin(tx *dbs.Tx, id uint32) (rowsAffected int64, err error) { + return this.Query(tx). Pk(id). Set("state", NodeLoginStateDisabled). Update() } // 查找启用中的条目 -func (this *NodeLoginDAO) FindEnabledNodeLogin(id uint32) (*NodeLogin, error) { - result, err := this.Query(). +func (this *NodeLoginDAO) FindEnabledNodeLogin(tx *dbs.Tx, id uint32) (*NodeLogin, error) { + result, err := this.Query(tx). Pk(id). Attr("state", NodeLoginStateEnabled). Find() @@ -65,8 +65,8 @@ func (this *NodeLoginDAO) FindEnabledNodeLogin(id uint32) (*NodeLogin, error) { } // 根据主键查找名称 -func (this *NodeLoginDAO) FindNodeLoginName(id uint32) (string, error) { - name, err := this.Query(). +func (this *NodeLoginDAO) FindNodeLoginName(tx *dbs.Tx, id uint32) (string, error) { + name, err := this.Query(tx). Pk(id). Result("name"). FindCol("") @@ -74,19 +74,19 @@ func (this *NodeLoginDAO) FindNodeLoginName(id uint32) (string, error) { } // 创建认证 -func (this *NodeLoginDAO) CreateNodeLogin(nodeId int64, name string, loginType string, paramsJSON []byte) (loginId int64, err error) { +func (this *NodeLoginDAO) CreateNodeLogin(tx *dbs.Tx, nodeId int64, name string, loginType string, paramsJSON []byte) (loginId int64, err error) { login := NewNodeLoginOperator() login.NodeId = nodeId login.Name = name login.Type = loginType login.Params = string(paramsJSON) login.State = NodeLoginStateEnabled - err = this.Save(login) + err = this.Save(tx, login) return types.Int64(login.Id), err } // 修改认证 -func (this *NodeLoginDAO) UpdateNodeLogin(loginId int64, name string, loginType string, paramsJSON []byte) error { +func (this *NodeLoginDAO) UpdateNodeLogin(tx *dbs.Tx, loginId int64, name string, loginType string, paramsJSON []byte) error { if loginId <= 0 { return errors.New("invalid loginId") } @@ -95,13 +95,13 @@ func (this *NodeLoginDAO) UpdateNodeLogin(loginId int64, name string, loginType login.Name = name login.Type = loginType login.Params = string(paramsJSON) - err := this.Save(login) + err := this.Save(tx, login) return err } // 查找认证 -func (this *NodeLoginDAO) FindEnabledNodeLoginWithNodeId(nodeId int64) (*NodeLogin, error) { - one, err := this.Query(). +func (this *NodeLoginDAO) FindEnabledNodeLoginWithNodeId(tx *dbs.Tx, nodeId int64) (*NodeLogin, error) { + one, err := this.Query(tx). Attr("nodeId", nodeId). State(NodeLoginStateEnabled). Find() @@ -115,8 +115,8 @@ func (this *NodeLoginDAO) FindEnabledNodeLoginWithNodeId(nodeId int64) (*NodeLog } // 禁用某个节点的认证 -func (this *NodeLoginDAO) DisableNodeLogins(nodeId int64) error { - _, err := this.Query(). +func (this *NodeLoginDAO) DisableNodeLogins(tx *dbs.Tx, nodeId int64) error { + _, err := this.Query(tx). Attr("nodeId", nodeId). Set("state", NodeLoginStateDisabled). Update() diff --git a/internal/db/models/node_price_item_dao.go b/internal/db/models/node_price_item_dao.go index 4acaf24c..ef611020 100644 --- a/internal/db/models/node_price_item_dao.go +++ b/internal/db/models/node_price_item_dao.go @@ -36,8 +36,8 @@ func init() { } // 启用条目 -func (this *NodePriceItemDAO) EnableNodePriceItem(id int64) error { - _, err := this.Query(). +func (this *NodePriceItemDAO) EnableNodePriceItem(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", NodePriceItemStateEnabled). Update() @@ -45,8 +45,8 @@ func (this *NodePriceItemDAO) EnableNodePriceItem(id int64) error { } // 禁用条目 -func (this *NodePriceItemDAO) DisableNodePriceItem(id int64) error { - _, err := this.Query(). +func (this *NodePriceItemDAO) DisableNodePriceItem(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", NodePriceItemStateDisabled). Update() @@ -54,8 +54,8 @@ func (this *NodePriceItemDAO) DisableNodePriceItem(id int64) error { } // 查找启用中的条目 -func (this *NodePriceItemDAO) FindEnabledNodePriceItem(id int64) (*NodePriceItem, error) { - result, err := this.Query(). +func (this *NodePriceItemDAO) FindEnabledNodePriceItem(tx *dbs.Tx, id int64) (*NodePriceItem, error) { + result, err := this.Query(tx). Pk(id). Attr("state", NodePriceItemStateEnabled). Find() @@ -66,15 +66,15 @@ func (this *NodePriceItemDAO) FindEnabledNodePriceItem(id int64) (*NodePriceItem } // 根据主键查找名称 -func (this *NodePriceItemDAO) FindNodePriceItemName(id int64) (string, error) { - return this.Query(). +func (this *NodePriceItemDAO) FindNodePriceItemName(tx *dbs.Tx, id int64) (string, error) { + return this.Query(tx). Pk(id). Result("name"). FindStringCol("") } // 创建价格 -func (this *NodePriceItemDAO) CreateItem(name string, itemType string, bitsFrom, bitsTo int64) (int64, error) { +func (this *NodePriceItemDAO) CreateItem(tx *dbs.Tx, name string, itemType string, bitsFrom, bitsTo int64) (int64, error) { op := NewNodePriceItemOperator() op.Name = name op.Type = itemType @@ -82,11 +82,11 @@ func (this *NodePriceItemDAO) CreateItem(name string, itemType string, bitsFrom, op.BitsTo = bitsTo op.IsOn = true op.State = NodePriceItemStateEnabled - return this.SaveInt64(op) + return this.SaveInt64(tx, op) } // 修改价格 -func (this *NodePriceItemDAO) UpdateItem(itemId int64, name string, bitsFrom, bitsTo int64) error { +func (this *NodePriceItemDAO) UpdateItem(tx *dbs.Tx, itemId int64, name string, bitsFrom, bitsTo int64) error { if itemId <= 0 { return errors.New("invalid itemId") } @@ -95,12 +95,12 @@ func (this *NodePriceItemDAO) UpdateItem(itemId int64, name string, bitsFrom, bi op.Name = name op.BitsFrom = bitsFrom op.BitsTo = bitsTo - return this.Save(op) + return this.Save(tx, op) } // 列出某个区域的所有价格 -func (this *NodePriceItemDAO) FindAllEnabledRegionPrices(priceType string) (result []*NodePriceItem, err error) { - _, err = this.Query(). +func (this *NodePriceItemDAO) FindAllEnabledRegionPrices(tx *dbs.Tx, priceType string) (result []*NodePriceItem, err error) { + _, err = this.Query(tx). Attr("type", priceType). State(NodePriceItemStateEnabled). Asc("bitsFrom"). @@ -110,8 +110,8 @@ func (this *NodePriceItemDAO) FindAllEnabledRegionPrices(priceType string) (resu } // 列出某个区域的所有启用的价格 -func (this *NodePriceItemDAO) FindAllEnabledAndOnRegionPrices(priceType string) (result []*NodePriceItem, err error) { - _, err = this.Query(). +func (this *NodePriceItemDAO) FindAllEnabledAndOnRegionPrices(tx *dbs.Tx, priceType string) (result []*NodePriceItem, err error) { + _, err = this.Query(tx). Attr("type", priceType). State(NodePriceItemStateEnabled). Attr("isOn", true). diff --git a/internal/db/models/node_region_dao.go b/internal/db/models/node_region_dao.go index 831f27fc..d87e0acf 100644 --- a/internal/db/models/node_region_dao.go +++ b/internal/db/models/node_region_dao.go @@ -36,8 +36,8 @@ func init() { } // 启用条目 -func (this *NodeRegionDAO) EnableNodeRegion(id int64) error { - _, err := this.Query(). +func (this *NodeRegionDAO) EnableNodeRegion(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", NodeRegionStateEnabled). Update() @@ -45,8 +45,8 @@ func (this *NodeRegionDAO) EnableNodeRegion(id int64) error { } // 禁用条目 -func (this *NodeRegionDAO) DisableNodeRegion(id int64) error { - _, err := this.Query(). +func (this *NodeRegionDAO) DisableNodeRegion(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", NodeRegionStateDisabled). Update() @@ -54,8 +54,8 @@ func (this *NodeRegionDAO) DisableNodeRegion(id int64) error { } // 查找启用中的条目 -func (this *NodeRegionDAO) FindEnabledNodeRegion(id int64) (*NodeRegion, error) { - result, err := this.Query(). +func (this *NodeRegionDAO) FindEnabledNodeRegion(tx *dbs.Tx, id int64) (*NodeRegion, error) { + result, err := this.Query(tx). Pk(id). Attr("state", NodeRegionStateEnabled). Find() @@ -66,26 +66,26 @@ func (this *NodeRegionDAO) FindEnabledNodeRegion(id int64) (*NodeRegion, error) } // 根据主键查找名称 -func (this *NodeRegionDAO) FindNodeRegionName(id int64) (string, error) { - return this.Query(). +func (this *NodeRegionDAO) FindNodeRegionName(tx *dbs.Tx, id int64) (string, error) { + return this.Query(tx). Pk(id). Result("name"). FindStringCol("") } // 创建区域 -func (this *NodeRegionDAO) CreateRegion(adminId int64, name string, description string) (int64, error) { +func (this *NodeRegionDAO) CreateRegion(tx *dbs.Tx, adminId int64, name string, description string) (int64, error) { op := NewNodeRegionOperator() op.AdminId = adminId op.Name = name op.Description = description op.State = NodeRegionStateEnabled op.IsOn = true - return this.SaveInt64(op) + return this.SaveInt64(tx, op) } // 修改区域 -func (this *NodeRegionDAO) UpdateRegion(regionId int64, name string, description string, isOn bool) error { +func (this *NodeRegionDAO) UpdateRegion(tx *dbs.Tx, regionId int64, name string, description string, isOn bool) error { if regionId <= 0 { return errors.New("invalid regionId") } @@ -94,12 +94,12 @@ func (this *NodeRegionDAO) UpdateRegion(regionId int64, name string, description op.Name = name op.Description = description op.IsOn = isOn - return this.Save(op) + return this.Save(tx, op) } // 列出所有区域 -func (this *NodeRegionDAO) FindAllEnabledRegions() (result []*NodeRegion, err error) { - _, err = this.Query(). +func (this *NodeRegionDAO) FindAllEnabledRegions(tx *dbs.Tx) (result []*NodeRegion, err error) { + _, err = this.Query(tx). State(NodeRegionStateEnabled). Desc("order"). AscPk(). @@ -109,8 +109,8 @@ func (this *NodeRegionDAO) FindAllEnabledRegions() (result []*NodeRegion, err er } // 列出所有价格 -func (this *NodeRegionDAO) FindAllEnabledRegionPrices() (result []*NodeRegion, err error) { - _, err = this.Query(). +func (this *NodeRegionDAO) FindAllEnabledRegionPrices(tx *dbs.Tx) (result []*NodeRegion, err error) { + _, err = this.Query(tx). State(NodeRegionStateEnabled). Desc("order"). AscPk(). @@ -121,8 +121,8 @@ func (this *NodeRegionDAO) FindAllEnabledRegionPrices() (result []*NodeRegion, e } // 列出所有启用的区域 -func (this *NodeRegionDAO) FindAllEnabledAndOnRegions() (result []*NodeRegion, err error) { - _, err = this.Query(). +func (this *NodeRegionDAO) FindAllEnabledAndOnRegions(tx *dbs.Tx) (result []*NodeRegion, err error) { + _, err = this.Query(tx). State(NodeRegionStateEnabled). Attr("isOn", true). Desc("order"). @@ -133,10 +133,10 @@ func (this *NodeRegionDAO) FindAllEnabledAndOnRegions() (result []*NodeRegion, e } // 排序 -func (this *NodeRegionDAO) UpdateRegionOrders(regionIds []int64) error { +func (this *NodeRegionDAO) UpdateRegionOrders(tx *dbs.Tx, regionIds []int64) error { order := len(regionIds) for _, regionId := range regionIds { - _, err := this.Query(). + _, err := this.Query(tx). Pk(regionId). Set("order", order). Update() @@ -149,8 +149,8 @@ func (this *NodeRegionDAO) UpdateRegionOrders(regionIds []int64) error { } // 修改价格项价格 -func (this *NodeRegionDAO) UpdateRegionItemPrice(regionId int64, itemId int64, price float32) error { - one, err := this.Query(). +func (this *NodeRegionDAO) UpdateRegionItemPrice(tx *dbs.Tx, regionId int64, itemId int64, price float32) error { + one, err := this.Query(tx). Pk(regionId). Result("prices"). Find() @@ -173,7 +173,7 @@ func (this *NodeRegionDAO) UpdateRegionItemPrice(regionId int64, itemId int64, p if err != nil { return err } - _, err = this.Query(). + _, err = this.Query(tx). Pk(regionId). Set("prices", pricesJSON). Update() diff --git a/internal/db/models/origin_dao.go b/internal/db/models/origin_dao.go index 639c0373..e0ef62b3 100644 --- a/internal/db/models/origin_dao.go +++ b/internal/db/models/origin_dao.go @@ -42,19 +42,19 @@ func init() { func (this *OriginDAO) Init() { this.DAOObject.Init() this.DAOObject.OnUpdate(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) this.DAOObject.OnInsert(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) this.DAOObject.OnDelete(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) } // 启用条目 -func (this *OriginDAO) EnableOrigin(id int64) error { - _, err := this.Query(). +func (this *OriginDAO) EnableOrigin(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", OriginStateEnabled). Update() @@ -62,8 +62,8 @@ func (this *OriginDAO) EnableOrigin(id int64) error { } // 禁用条目 -func (this *OriginDAO) DisableOrigin(id int64) error { - _, err := this.Query(). +func (this *OriginDAO) DisableOrigin(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", OriginStateDisabled). Update() @@ -71,8 +71,8 @@ func (this *OriginDAO) DisableOrigin(id int64) error { } // 查找启用中的条目 -func (this *OriginDAO) FindEnabledOrigin(id int64) (*Origin, error) { - result, err := this.Query(). +func (this *OriginDAO) FindEnabledOrigin(tx *dbs.Tx, id int64) (*Origin, error) { + result, err := this.Query(tx). Pk(id). Attr("state", OriginStateEnabled). Find() @@ -83,15 +83,15 @@ func (this *OriginDAO) FindEnabledOrigin(id int64) (*Origin, error) { } // 根据主键查找名称 -func (this *OriginDAO) FindOriginName(id int64) (string, error) { - return this.Query(). +func (this *OriginDAO) FindOriginName(tx *dbs.Tx, id int64) (string, error) { + return this.Query(tx). Pk(id). Result("name"). FindStringCol("") } // 创建源站 -func (this *OriginDAO) CreateOrigin(adminId int64, userId int64, name string, addrJSON string, description string, weight int32, isOn bool) (originId int64, err error) { +func (this *OriginDAO) CreateOrigin(tx *dbs.Tx, adminId int64, userId int64, name string, addrJSON string, description string, weight int32, isOn bool) (originId int64, err error) { op := NewOriginOperator() op.AdminId = adminId op.UserId = userId @@ -104,7 +104,7 @@ func (this *OriginDAO) CreateOrigin(adminId int64, userId int64, name string, ad } op.Weight = weight op.State = OriginStateEnabled - err = this.Save(op) + err = this.Save(tx, op) if err != nil { return } @@ -112,7 +112,7 @@ func (this *OriginDAO) CreateOrigin(adminId int64, userId int64, name string, ad } // 修改源站 -func (this *OriginDAO) UpdateOrigin(originId int64, name string, addrJSON string, description string, weight int32, isOn bool) error { +func (this *OriginDAO) UpdateOrigin(tx *dbs.Tx, originId int64, name string, addrJSON string, description string, weight int32, isOn bool) error { if originId <= 0 { return errors.New("invalid originId") } @@ -127,13 +127,13 @@ func (this *OriginDAO) UpdateOrigin(originId int64, name string, addrJSON string op.Weight = weight op.IsOn = isOn op.Version = dbs.SQL("version+1") - err := this.Save(op) + err := this.Save(tx, op) return err } // 将源站信息转换为配置 -func (this *OriginDAO) ComposeOriginConfig(originId int64) (*serverconfigs.OriginConfig, error) { - origin, err := this.FindEnabledOrigin(originId) +func (this *OriginDAO) ComposeOriginConfig(tx *dbs.Tx, originId int64) (*serverconfigs.OriginConfig, error) { + origin, err := this.FindEnabledOrigin(tx, originId) if err != nil { return nil, err } @@ -202,7 +202,7 @@ func (this *OriginDAO) ComposeOriginConfig(originId int64) (*serverconfigs.Origi config.RequestHeaderPolicyRef = ref if ref.HeaderPolicyId > 0 { - headerPolicy, err := SharedHTTPHeaderPolicyDAO.ComposeHeaderPolicyConfig(ref.HeaderPolicyId) + headerPolicy, err := SharedHTTPHeaderPolicyDAO.ComposeHeaderPolicyConfig(tx, ref.HeaderPolicyId) if err != nil { return nil, err } @@ -221,7 +221,7 @@ func (this *OriginDAO) ComposeOriginConfig(originId int64) (*serverconfigs.Origi config.ResponseHeaderPolicyRef = ref if ref.HeaderPolicyId > 0 { - headerPolicy, err := SharedHTTPHeaderPolicyDAO.ComposeHeaderPolicyConfig(ref.HeaderPolicyId) + headerPolicy, err := SharedHTTPHeaderPolicyDAO.ComposeHeaderPolicyConfig(tx, ref.HeaderPolicyId) if err != nil { return nil, err } @@ -248,7 +248,7 @@ func (this *OriginDAO) ComposeOriginConfig(originId int64) (*serverconfigs.Origi } config.CertRef = ref if ref.CertId > 0 { - certConfig, err := SharedSSLCertDAO.ComposeCertConfig(ref.CertId) + certConfig, err := SharedSSLCertDAO.ComposeCertConfig(tx, ref.CertId) if err != nil { return nil, err } diff --git a/internal/db/models/provider_dao.go b/internal/db/models/provider_dao.go index a3a14ccc..b850df6b 100644 --- a/internal/db/models/provider_dao.go +++ b/internal/db/models/provider_dao.go @@ -33,24 +33,24 @@ func init() { } // 启用条目 -func (this *ProviderDAO) EnableProvider(id int64) (rowsAffected int64, err error) { - return this.Query(). +func (this *ProviderDAO) EnableProvider(tx *dbs.Tx, id int64) (rowsAffected int64, err error) { + return this.Query(tx). Pk(id). Set("state", ProviderStateEnabled). Update() } // 禁用条目 -func (this *ProviderDAO) DisableProvider(id int64) (rowsAffected int64, err error) { - return this.Query(). +func (this *ProviderDAO) DisableProvider(tx *dbs.Tx, id int64) (rowsAffected int64, err error) { + return this.Query(tx). Pk(id). Set("state", ProviderStateDisabled). Update() } // 查找启用中的条目 -func (this *ProviderDAO) FindEnabledProvider(id int64) (*Provider, error) { - result, err := this.Query(). +func (this *ProviderDAO) FindEnabledProvider(tx *dbs.Tx, id int64) (*Provider, error) { + result, err := this.Query(tx). Pk(id). Attr("state", ProviderStateEnabled). Find() @@ -61,8 +61,8 @@ func (this *ProviderDAO) FindEnabledProvider(id int64) (*Provider, error) { } // 查找供应商名称 -func (this *ProviderDAO) FindProviderName(providerId int64) (string, error) { - return this.Query(). +func (this *ProviderDAO) FindProviderName(tx *dbs.Tx, providerId int64) (string, error) { + return this.Query(tx). Pk(providerId). Result("name"). FindStringCol("") diff --git a/internal/db/models/region_city_dao.go b/internal/db/models/region_city_dao.go index b4fa5408..4d941175 100644 --- a/internal/db/models/region_city_dao.go +++ b/internal/db/models/region_city_dao.go @@ -35,8 +35,8 @@ func init() { } // 启用条目 -func (this *RegionCityDAO) EnableRegionCity(id uint32) error { - _, err := this.Query(). +func (this *RegionCityDAO) EnableRegionCity(tx *dbs.Tx, id uint32) error { + _, err := this.Query(tx). Pk(id). Set("state", RegionCityStateEnabled). Update() @@ -44,8 +44,8 @@ func (this *RegionCityDAO) EnableRegionCity(id uint32) error { } // 禁用条目 -func (this *RegionCityDAO) DisableRegionCity(id uint32) error { - _, err := this.Query(). +func (this *RegionCityDAO) DisableRegionCity(tx *dbs.Tx, id uint32) error { + _, err := this.Query(tx). Pk(id). Set("state", RegionCityStateDisabled). Update() @@ -53,8 +53,8 @@ func (this *RegionCityDAO) DisableRegionCity(id uint32) error { } // 查找启用中的条目 -func (this *RegionCityDAO) FindEnabledRegionCity(id uint32) (*RegionCity, error) { - result, err := this.Query(). +func (this *RegionCityDAO) FindEnabledRegionCity(tx *dbs.Tx, id uint32) (*RegionCity, error) { + result, err := this.Query(tx). Pk(id). Attr("state", RegionCityStateEnabled). Find() @@ -65,23 +65,23 @@ func (this *RegionCityDAO) FindEnabledRegionCity(id uint32) (*RegionCity, error) } // 根据主键查找名称 -func (this *RegionCityDAO) FindRegionCityName(id uint32) (string, error) { - return this.Query(). +func (this *RegionCityDAO) FindRegionCityName(tx *dbs.Tx, id uint32) (string, error) { + return this.Query(tx). Pk(id). Result("name"). FindStringCol("") } // 根据数据ID查找城市 -func (this *RegionCityDAO) FindCityWithDataId(dataId string) (int64, error) { - return this.Query(). +func (this *RegionCityDAO) FindCityWithDataId(tx *dbs.Tx, dataId string) (int64, error) { + return this.Query(tx). Attr("dataId", dataId). ResultPk(). FindInt64Col(0) } // 创建城市 -func (this *RegionCityDAO) CreateCity(provinceId int64, name string, dataId string) (int64, error) { +func (this *RegionCityDAO) CreateCity(tx *dbs.Tx, provinceId int64, name string, dataId string) (int64, error) { op := NewRegionCityOperator() op.ProvinceId = provinceId op.Name = name @@ -94,7 +94,7 @@ func (this *RegionCityDAO) CreateCity(provinceId int64, name string, dataId stri return 0, err } op.Codes = codesJSON - err = this.Save(op) + err = this.Save(tx, op) if err != nil { return 0, err } diff --git a/internal/db/models/region_country_dao.go b/internal/db/models/region_country_dao.go index df30eb14..22ca9a5f 100644 --- a/internal/db/models/region_country_dao.go +++ b/internal/db/models/region_country_dao.go @@ -37,8 +37,8 @@ func init() { } // 启用条目 -func (this *RegionCountryDAO) EnableRegionCountry(id uint32) error { - _, err := this.Query(). +func (this *RegionCountryDAO) EnableRegionCountry(tx *dbs.Tx, id uint32) error { + _, err := this.Query(tx). Pk(id). Set("state", RegionCountryStateEnabled). Update() @@ -46,8 +46,8 @@ func (this *RegionCountryDAO) EnableRegionCountry(id uint32) error { } // 禁用条目 -func (this *RegionCountryDAO) DisableRegionCountry(id int64) error { - _, err := this.Query(). +func (this *RegionCountryDAO) DisableRegionCountry(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", RegionCountryStateDisabled). Update() @@ -55,8 +55,8 @@ func (this *RegionCountryDAO) DisableRegionCountry(id int64) error { } // 查找启用中的条目 -func (this *RegionCountryDAO) FindEnabledRegionCountry(id int64) (*RegionCountry, error) { - result, err := this.Query(). +func (this *RegionCountryDAO) FindEnabledRegionCountry(tx *dbs.Tx, id int64) (*RegionCountry, error) { + result, err := this.Query(tx). Pk(id). Attr("state", RegionCountryStateEnabled). Find() @@ -67,16 +67,16 @@ func (this *RegionCountryDAO) FindEnabledRegionCountry(id int64) (*RegionCountry } // 根据主键查找名称 -func (this *RegionCountryDAO) FindRegionCountryName(id int64) (string, error) { - return this.Query(). +func (this *RegionCountryDAO) FindRegionCountryName(tx *dbs.Tx, id int64) (string, error) { + return this.Query(tx). Pk(id). Result("name"). FindStringCol("") } // 根据数据ID查找国家 -func (this *RegionCountryDAO) FindCountryIdWithDataId(dataId string) (int64, error) { - return this.Query(). +func (this *RegionCountryDAO) FindCountryIdWithDataId(tx *dbs.Tx, dataId string) (int64, error) { + return this.Query(tx). Attr("dataId", dataId). ResultPk(). FindInt64Col(0) @@ -84,8 +84,8 @@ func (this *RegionCountryDAO) FindCountryIdWithDataId(dataId string) (int64, err // 根据国家名查找国家ID // TODO 加入缓存 -func (this *RegionCountryDAO) FindCountryIdWithCountryName(countryName string) (int64, error) { - return this.Query(). +func (this *RegionCountryDAO) FindCountryIdWithCountryName(tx *dbs.Tx, countryName string) (int64, error) { + return this.Query(tx). Where("JSON_CONTAINS(codes, :countryName)"). Param("countryName", "\""+countryName+"\""). // 查询的需要是个JSON字符串,所以这里加双引号 ResultPk(). @@ -93,7 +93,7 @@ func (this *RegionCountryDAO) FindCountryIdWithCountryName(countryName string) ( } // 根据数据ID创建国家 -func (this *RegionCountryDAO) CreateCountry(name string, dataId string) (int64, error) { +func (this *RegionCountryDAO) CreateCountry(tx *dbs.Tx, name string, dataId string) (int64, error) { op := NewRegionCountryOperator() op.Name = name @@ -114,7 +114,7 @@ func (this *RegionCountryDAO) CreateCountry(name string, dataId string) (int64, op.DataId = dataId op.State = RegionCountryStateEnabled - err = this.Save(op) + err = this.Save(tx, op) if err != nil { return 0, err } @@ -122,8 +122,8 @@ func (this *RegionCountryDAO) CreateCountry(name string, dataId string) (int64, } // 查找所有可用的国家 -func (this *RegionCountryDAO) FindAllEnabledCountriesOrderByPinyin() (result []*RegionCountry, err error) { - _, err = this.Query(). +func (this *RegionCountryDAO) FindAllEnabledCountriesOrderByPinyin(tx *dbs.Tx) (result []*RegionCountry, err error) { + _, err = this.Query(tx). State(RegionCountryStateEnabled). Slice(&result). Asc("pinyin"). diff --git a/internal/db/models/region_provider_dao.go b/internal/db/models/region_provider_dao.go index a7831c68..4291c574 100644 --- a/internal/db/models/region_provider_dao.go +++ b/internal/db/models/region_provider_dao.go @@ -33,8 +33,8 @@ func init() { } // 启用条目 -func (this *RegionProviderDAO) EnableRegionProvider(id uint32) error { - _, err := this.Query(). +func (this *RegionProviderDAO) EnableRegionProvider(tx *dbs.Tx, id uint32) error { + _, err := this.Query(tx). Pk(id). Set("state", RegionProviderStateEnabled). Update() @@ -42,8 +42,8 @@ func (this *RegionProviderDAO) EnableRegionProvider(id uint32) error { } // 禁用条目 -func (this *RegionProviderDAO) DisableRegionProvider(id uint32) error { - _, err := this.Query(). +func (this *RegionProviderDAO) DisableRegionProvider(tx *dbs.Tx, id uint32) error { + _, err := this.Query(tx). Pk(id). Set("state", RegionProviderStateDisabled). Update() @@ -51,8 +51,8 @@ func (this *RegionProviderDAO) DisableRegionProvider(id uint32) error { } // 查找启用中的条目 -func (this *RegionProviderDAO) FindEnabledRegionProvider(id uint32) (*RegionProvider, error) { - result, err := this.Query(). +func (this *RegionProviderDAO) FindEnabledRegionProvider(tx *dbs.Tx, id uint32) (*RegionProvider, error) { + result, err := this.Query(tx). Pk(id). Attr("state", RegionProviderStateEnabled). Find() @@ -63,8 +63,8 @@ func (this *RegionProviderDAO) FindEnabledRegionProvider(id uint32) (*RegionProv } // 根据主键查找名称 -func (this *RegionProviderDAO) FindRegionProviderName(id uint32) (string, error) { - return this.Query(). +func (this *RegionProviderDAO) FindRegionProviderName(tx *dbs.Tx, id uint32) (string, error) { + return this.Query(tx). Pk(id). Result("name"). FindStringCol("") diff --git a/internal/db/models/region_province_dao.go b/internal/db/models/region_province_dao.go index 69957558..1bf20af6 100644 --- a/internal/db/models/region_province_dao.go +++ b/internal/db/models/region_province_dao.go @@ -35,8 +35,8 @@ func init() { } // 启用条目 -func (this *RegionProvinceDAO) EnableRegionProvince(id int64) error { - _, err := this.Query(). +func (this *RegionProvinceDAO) EnableRegionProvince(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", RegionProvinceStateEnabled). Update() @@ -44,8 +44,8 @@ func (this *RegionProvinceDAO) EnableRegionProvince(id int64) error { } // 禁用条目 -func (this *RegionProvinceDAO) DisableRegionProvince(id int64) error { - _, err := this.Query(). +func (this *RegionProvinceDAO) DisableRegionProvince(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", RegionProvinceStateDisabled). Update() @@ -53,8 +53,8 @@ func (this *RegionProvinceDAO) DisableRegionProvince(id int64) error { } // 查找启用中的条目 -func (this *RegionProvinceDAO) FindEnabledRegionProvince(id int64) (*RegionProvince, error) { - result, err := this.Query(). +func (this *RegionProvinceDAO) FindEnabledRegionProvince(tx *dbs.Tx, id int64) (*RegionProvince, error) { + result, err := this.Query(tx). Pk(id). Attr("state", RegionProvinceStateEnabled). Find() @@ -65,16 +65,16 @@ func (this *RegionProvinceDAO) FindEnabledRegionProvince(id int64) (*RegionProvi } // 根据主键查找名称 -func (this *RegionProvinceDAO) FindRegionProvinceName(id int64) (string, error) { - return this.Query(). +func (this *RegionProvinceDAO) FindRegionProvinceName(tx *dbs.Tx, id int64) (string, error) { + return this.Query(tx). Pk(id). Result("name"). FindStringCol("") } // 根据数据ID查找省份 -func (this *RegionProvinceDAO) FindProvinceIdWithDataId(dataId string) (int64, error) { - return this.Query(). +func (this *RegionProvinceDAO) FindProvinceIdWithDataId(tx *dbs.Tx, dataId string) (int64, error) { + return this.Query(tx). Attr("dataId", dataId). ResultPk(). FindInt64Col(0) @@ -82,8 +82,8 @@ func (this *RegionProvinceDAO) FindProvinceIdWithDataId(dataId string) (int64, e // 根据省份名查找省份ID // TODO 加入缓存 -func (this *RegionProvinceDAO) FindProvinceIdWithProvinceName(provinceName string) (int64, error) { - return this.Query(). +func (this *RegionProvinceDAO) FindProvinceIdWithProvinceName(tx *dbs.Tx, provinceName string) (int64, error) { + return this.Query(tx). Where("JSON_CONTAINS(codes, :provinceName)"). Param("provinceName", "\""+provinceName+"\""). // 查询的需要是个JSON字符串,所以这里加双引号 ResultPk(). @@ -91,7 +91,7 @@ func (this *RegionProvinceDAO) FindProvinceIdWithProvinceName(provinceName strin } // 创建省份 -func (this *RegionProvinceDAO) CreateProvince(countryId int64, name string, dataId string) (int64, error) { +func (this *RegionProvinceDAO) CreateProvince(tx *dbs.Tx, countryId int64, name string, dataId string) (int64, error) { op := NewRegionProvinceOperator() op.CountryId = countryId op.Name = name @@ -104,7 +104,7 @@ func (this *RegionProvinceDAO) CreateProvince(countryId int64, name string, data return 0, err } op.Codes = codesJSON - err = this.Save(op) + err = this.Save(tx, op) if err != nil { return 0, err } @@ -112,8 +112,8 @@ func (this *RegionProvinceDAO) CreateProvince(countryId int64, name string, data } // 查找所有省份 -func (this *RegionProvinceDAO) FindAllEnabledProvincesWithCountryId(countryId int64) (result []*RegionProvince, err error) { - _, err = this.Query(). +func (this *RegionProvinceDAO) FindAllEnabledProvincesWithCountryId(tx *dbs.Tx, countryId int64) (result []*RegionProvince, err error) { + _, err = this.Query(tx). State(RegionProvinceStateEnabled). Attr("countryId", countryId). Asc(). diff --git a/internal/db/models/reverse_proxy_dao.go b/internal/db/models/reverse_proxy_dao.go index 6f9832a8..e61b3b90 100644 --- a/internal/db/models/reverse_proxy_dao.go +++ b/internal/db/models/reverse_proxy_dao.go @@ -40,19 +40,19 @@ func init() { func (this *ReverseProxyDAO) Init() { this.DAOObject.Init() this.DAOObject.OnUpdate(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) this.DAOObject.OnInsert(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) this.DAOObject.OnDelete(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) } // 启用条目 -func (this *ReverseProxyDAO) EnableReverseProxy(id int64) error { - _, err := this.Query(). +func (this *ReverseProxyDAO) EnableReverseProxy(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", ReverseProxyStateEnabled). Update() @@ -63,8 +63,8 @@ func (this *ReverseProxyDAO) EnableReverseProxy(id int64) error { } // 禁用条目 -func (this *ReverseProxyDAO) DisableReverseProxy(id int64) error { - _, err := this.Query(). +func (this *ReverseProxyDAO) DisableReverseProxy(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", ReverseProxyStateDisabled). Update() @@ -75,8 +75,8 @@ func (this *ReverseProxyDAO) DisableReverseProxy(id int64) error { } // 查找启用中的条目 -func (this *ReverseProxyDAO) FindEnabledReverseProxy(id int64) (*ReverseProxy, error) { - result, err := this.Query(). +func (this *ReverseProxyDAO) FindEnabledReverseProxy(tx *dbs.Tx, id int64) (*ReverseProxy, error) { + result, err := this.Query(tx). Pk(id). Attr("state", ReverseProxyStateEnabled). Find() @@ -87,8 +87,8 @@ func (this *ReverseProxyDAO) FindEnabledReverseProxy(id int64) (*ReverseProxy, e } // 根据iD组合配置 -func (this *ReverseProxyDAO) ComposeReverseProxyConfig(reverseProxyId int64) (*serverconfigs.ReverseProxyConfig, error) { - reverseProxy, err := this.FindEnabledReverseProxy(reverseProxyId) +func (this *ReverseProxyDAO) ComposeReverseProxyConfig(tx *dbs.Tx, reverseProxyId int64) (*serverconfigs.ReverseProxyConfig, error) { + reverseProxy, err := this.FindEnabledReverseProxy(tx, reverseProxyId) if err != nil { return nil, err } @@ -120,7 +120,7 @@ func (this *ReverseProxyDAO) ComposeReverseProxyConfig(reverseProxyId int64) (*s return nil, err } for _, ref := range originRefs { - originConfig, err := SharedOriginDAO.ComposeOriginConfig(ref.OriginId) + originConfig, err := SharedOriginDAO.ComposeOriginConfig(tx, ref.OriginId) if err != nil { return nil, err } @@ -137,7 +137,7 @@ func (this *ReverseProxyDAO) ComposeReverseProxyConfig(reverseProxyId int64) (*s return nil, err } for _, originConfig := range originRefs { - originConfig, err := SharedOriginDAO.ComposeOriginConfig(originConfig.OriginId) + originConfig, err := SharedOriginDAO.ComposeOriginConfig(tx, originConfig.OriginId) if err != nil { return nil, err } @@ -151,7 +151,7 @@ func (this *ReverseProxyDAO) ComposeReverseProxyConfig(reverseProxyId int64) (*s } // 创建反向代理 -func (this *ReverseProxyDAO) CreateReverseProxy(adminId int64, userId int64, schedulingJSON []byte, primaryOriginsJSON []byte, backupOriginsJSON []byte) (int64, error) { +func (this *ReverseProxyDAO) CreateReverseProxy(tx *dbs.Tx, adminId int64, userId int64, schedulingJSON []byte, primaryOriginsJSON []byte, backupOriginsJSON []byte) (int64, error) { op := NewReverseProxyOperator() op.IsOn = true op.State = ReverseProxyStateEnabled @@ -167,7 +167,7 @@ func (this *ReverseProxyDAO) CreateReverseProxy(adminId int64, userId int64, sch if len(backupOriginsJSON) > 0 { op.BackupOrigins = string(backupOriginsJSON) } - err := this.Save(op) + err := this.Save(tx, op) if err != nil { return 0, err } @@ -176,7 +176,7 @@ func (this *ReverseProxyDAO) CreateReverseProxy(adminId int64, userId int64, sch } // 修改反向代理调度算法 -func (this *ReverseProxyDAO) UpdateReverseProxyScheduling(reverseProxyId int64, schedulingJSON []byte) error { +func (this *ReverseProxyDAO) UpdateReverseProxyScheduling(tx *dbs.Tx, reverseProxyId int64, schedulingJSON []byte) error { if reverseProxyId <= 0 { return errors.New("invalid reverseProxyId") } @@ -187,13 +187,13 @@ func (this *ReverseProxyDAO) UpdateReverseProxyScheduling(reverseProxyId int64, } else { op.Scheduling = "null" } - err := this.Save(op) + err := this.Save(tx, op) return err } // 修改主要源站 -func (this *ReverseProxyDAO) UpdateReverseProxyPrimaryOrigins(reverseProxyId int64, origins []byte) error { +func (this *ReverseProxyDAO) UpdateReverseProxyPrimaryOrigins(tx *dbs.Tx, reverseProxyId int64, origins []byte) error { if reverseProxyId <= 0 { return errors.New("invalid reverseProxyId") } @@ -204,13 +204,13 @@ func (this *ReverseProxyDAO) UpdateReverseProxyPrimaryOrigins(reverseProxyId int } else { op.PrimaryOrigins = "[]" } - err := this.Save(op) + err := this.Save(tx, op) return err } // 修改备用源站 -func (this *ReverseProxyDAO) UpdateReverseProxyBackupOrigins(reverseProxyId int64, origins []byte) error { +func (this *ReverseProxyDAO) UpdateReverseProxyBackupOrigins(tx *dbs.Tx, reverseProxyId int64, origins []byte) error { if reverseProxyId <= 0 { return errors.New("invalid reverseProxyId") } @@ -221,13 +221,13 @@ func (this *ReverseProxyDAO) UpdateReverseProxyBackupOrigins(reverseProxyId int6 } else { op.BackupOrigins = "[]" } - err := this.Save(op) + err := this.Save(tx, op) return err } // 修改是否启用 -func (this *ReverseProxyDAO) UpdateReverseProxy(reverseProxyId int64, requestHostType int8, requestHost string, requestURI string, stripPrefix string, autoFlush bool) error { +func (this *ReverseProxyDAO) UpdateReverseProxy(tx *dbs.Tx, reverseProxyId int64, requestHostType int8, requestHost string, requestURI string, stripPrefix string, autoFlush bool) error { if reverseProxyId <= 0 { return errors.New("invalid reverseProxyId") } @@ -244,11 +244,11 @@ func (this *ReverseProxyDAO) UpdateReverseProxy(reverseProxyId int64, requestHos op.RequestURI = requestURI op.StripPrefix = stripPrefix op.AutoFlush = autoFlush - err := this.Save(op) + err := this.Save(tx, op) return err } // 通知更新 func (this *ReverseProxyDAO) CreateEvent() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) } diff --git a/internal/db/models/server_daily_stat_dao.go b/internal/db/models/server_daily_stat_dao.go index 7b0779fa..5f9812a9 100644 --- a/internal/db/models/server_daily_stat_dao.go +++ b/internal/db/models/server_daily_stat_dao.go @@ -31,13 +31,13 @@ func init() { } // 提交数据 -func (this *ServerDailyStatDAO) SaveStats(stats []*pb.ServerDailyStat) error { +func (this *ServerDailyStatDAO) SaveStats(tx *dbs.Tx, stats []*pb.ServerDailyStat) error { for _, stat := range stats { day := timeutil.FormatTime("Ymd", stat.CreatedAt) timeFrom := timeutil.FormatTime("His", stat.CreatedAt) timeTo := timeutil.FormatTime("His", stat.CreatedAt+5*60) // 5分钟 - _, _, err := this.Query(). + _, _, err := this.Query(tx). Param("bytes", stat.Bytes). InsertOrUpdate(maps.Map{ "serverId": stat.ServerId, @@ -58,8 +58,8 @@ func (this *ServerDailyStatDAO) SaveStats(stats []*pb.ServerDailyStat) error { // 根据用户计算某月合计 // month 格式为YYYYMM -func (this *ServerDailyStatDAO) SumUserMonthly(userId int64, regionId int64, month string) (int64, error) { - query := this.Query() +func (this *ServerDailyStatDAO) SumUserMonthly(tx *dbs.Tx, userId int64, regionId int64, month string) (int64, error) { + query := this.Query(tx) if regionId > 0 { query.Attr("regionId", regionId) } @@ -71,8 +71,8 @@ func (this *ServerDailyStatDAO) SumUserMonthly(userId int64, regionId int64, mon // 获取某月带宽峰值 // month 格式为YYYYMM -func (this *ServerDailyStatDAO) SumUserMonthlyPeek(userId int64, regionId int64, month string) (int64, error) { - query := this.Query() +func (this *ServerDailyStatDAO) SumUserMonthlyPeek(tx *dbs.Tx, userId int64, regionId int64, month string) (int64, error) { + query := this.Query(tx) if regionId > 0 { query.Attr("regionId", regionId) } @@ -88,8 +88,8 @@ func (this *ServerDailyStatDAO) SumUserMonthlyPeek(userId int64, regionId int64, // 获取某天流量总和 // day 格式为YYYYMMDD -func (this *ServerDailyStatDAO) SumUserDaily(userId int64, regionId int64, day string) (int64, error) { - query := this.Query() +func (this *ServerDailyStatDAO) SumUserDaily(tx *dbs.Tx, userId int64, regionId int64, day string) (int64, error) { + query := this.Query(tx) if regionId > 0 { query.Attr("regionId", regionId) } @@ -102,8 +102,8 @@ func (this *ServerDailyStatDAO) SumUserDaily(userId int64, regionId int64, day s // 获取某天带宽峰值 // day 格式为YYYYMMDD -func (this *ServerDailyStatDAO) SumUserDailyPeek(userId int64, regionId int64, day string) (int64, error) { - query := this.Query() +func (this *ServerDailyStatDAO) SumUserDailyPeek(tx *dbs.Tx, userId int64, regionId int64, day string) (int64, error) { + query := this.Query(tx) if regionId > 0 { query.Attr("regionId", regionId) } diff --git a/internal/db/models/server_dao.go b/internal/db/models/server_dao.go index 2287a131..079f8d89 100644 --- a/internal/db/models/server_dao.go +++ b/internal/db/models/server_dao.go @@ -54,16 +54,16 @@ func (this *ServerDAO) Init() { } // 启用条目 -func (this *ServerDAO) EnableServer(id uint32) (rowsAffected int64, err error) { - return this.Query(). +func (this *ServerDAO) EnableServer(tx *dbs.Tx, id uint32) (rowsAffected int64, err error) { + return this.Query(tx). Pk(id). Set("state", ServerStateEnabled). Update() } // 禁用条目 -func (this *ServerDAO) DisableServer(id int64) (err error) { - _, err = this.Query(). +func (this *ServerDAO) DisableServer(tx *dbs.Tx, id int64) (err error) { + _, err = this.Query(tx). Pk(id). Set("state", ServerStateDisabled). Update() @@ -71,8 +71,8 @@ func (this *ServerDAO) DisableServer(id int64) (err error) { } // 查找启用中的条目 -func (this *ServerDAO) FindEnabledServer(id int64) (*Server, error) { - result, err := this.Query(). +func (this *ServerDAO) FindEnabledServer(tx *dbs.Tx, id int64) (*Server, error) { + result, err := this.Query(tx). Pk(id). Attr("state", ServerStateEnabled). Find() @@ -83,15 +83,16 @@ func (this *ServerDAO) FindEnabledServer(id int64) (*Server, error) { } // 查找服务类型 -func (this *ServerDAO) FindEnabledServerType(serverId int64) (string, error) { - return this.Query(). +func (this *ServerDAO) FindEnabledServerType(tx *dbs.Tx, serverId int64) (string, error) { + return this.Query(tx). Pk(serverId). Result("type"). FindStringCol("") } // 创建服务 -func (this *ServerDAO) CreateServer(adminId int64, +func (this *ServerDAO) CreateServer(tx *dbs.Tx, + adminId int64, userId int64, serverType serverconfigs.ServerType, name string, @@ -166,7 +167,7 @@ func (this *ServerDAO) CreateServer(adminId int64, op.GroupIds = groupIdsJSON } - dnsName, err := this.genDNSName() + dnsName, err := this.genDNSName(tx) if err != nil { return 0, err } @@ -175,7 +176,7 @@ func (this *ServerDAO) CreateServer(adminId int64, op.Version = 1 op.IsOn = 1 op.State = ServerStateEnabled - err = this.Save(op) + err = this.Save(tx, op) if err != nil { return 0, err @@ -183,7 +184,7 @@ func (this *ServerDAO) CreateServer(adminId int64, serverId = types.Int64(op.Id) - _, err = this.RenewServerConfig(serverId, false) + _, err = this.RenewServerConfig(tx, serverId, false) if err != nil { return serverId, err } @@ -197,7 +198,7 @@ func (this *ServerDAO) CreateServer(adminId int64, } // 修改服务基本信息 -func (this *ServerDAO) UpdateServerBasic(serverId int64, name string, description string, clusterId int64, isOn bool, groupIds []int64) error { +func (this *ServerDAO) UpdateServerBasic(tx *dbs.Tx, serverId int64, name string, description string, clusterId int64, isOn bool, groupIds []int64) error { if serverId <= 0 { return errors.New("serverId should not be smaller than 0") } @@ -218,12 +219,12 @@ func (this *ServerDAO) UpdateServerBasic(serverId int64, name string, descriptio op.GroupIds = groupIdsJSON } - err := this.Save(op) + err := this.Save(tx, op) if err != nil { return err } - _, err = this.RenewServerConfig(serverId, false) + _, err = this.RenewServerConfig(tx, serverId, false) if err != nil { return err } @@ -232,8 +233,8 @@ func (this *ServerDAO) UpdateServerBasic(serverId int64, name string, descriptio } // 修复服务是否启用 -func (this *ServerDAO) UpdateServerIsOn(serverId int64, isOn bool) error { - _, err := this.Query(). +func (this *ServerDAO) UpdateServerIsOn(tx *dbs.Tx, serverId int64, isOn bool) error { + _, err := this.Query(tx). Pk(serverId). Set("isOn", isOn). Update() @@ -241,13 +242,13 @@ func (this *ServerDAO) UpdateServerIsOn(serverId int64, isOn bool) error { } // 修改服务配置 -func (this *ServerDAO) UpdateServerConfig(serverId int64, configJSON []byte, updateMd5 bool) (isChanged bool, err error) { +func (this *ServerDAO) UpdateServerConfig(tx *dbs.Tx, serverId int64, configJSON []byte, updateMd5 bool) (isChanged bool, err error) { if serverId <= 0 { return false, errors.New("serverId should not be smaller than 0") } // 查询以前的md5 - oldConfigMd5, err := this.Query(). + oldConfigMd5, err := this.Query(tx). Pk(serverId). Result("configMd5"). FindStringCol("") @@ -255,7 +256,7 @@ func (this *ServerDAO) UpdateServerConfig(serverId int64, configJSON []byte, upd return false, err } - globalConfig, err := SharedSysSettingDAO.ReadSetting(SettingCodeServerGlobalConfig) + globalConfig, err := SharedSysSettingDAO.ReadSetting(tx, SettingCodeServerGlobalConfig) if err != nil { return false, err } @@ -279,19 +280,19 @@ func (this *ServerDAO) UpdateServerConfig(serverId int64, configJSON []byte, upd if updateMd5 { op.ConfigMd5 = newConfigMd5 } - err = this.Save(op) + err = this.Save(tx, op) return true, err } // 修改HTTP配置 -func (this *ServerDAO) UpdateServerHTTP(serverId int64, config []byte) error { +func (this *ServerDAO) UpdateServerHTTP(tx *dbs.Tx, serverId int64, config []byte) error { if serverId <= 0 { return errors.New("serverId should not be smaller than 0") } if len(config) == 0 { config = []byte("null") } - _, err := this.Query(). + _, err := this.Query(tx). Pk(serverId). Set("http", string(config)). Update() @@ -299,7 +300,7 @@ func (this *ServerDAO) UpdateServerHTTP(serverId int64, config []byte) error { return err } - _, err = this.RenewServerConfig(serverId, false) + _, err = this.RenewServerConfig(tx, serverId, false) if err != nil { return err } @@ -308,14 +309,14 @@ func (this *ServerDAO) UpdateServerHTTP(serverId int64, config []byte) error { } // 修改HTTPS配置 -func (this *ServerDAO) UpdateServerHTTPS(serverId int64, httpsJSON []byte) error { +func (this *ServerDAO) UpdateServerHTTPS(tx *dbs.Tx, serverId int64, httpsJSON []byte) error { if serverId <= 0 { return errors.New("serverId should not be smaller than 0") } if len(httpsJSON) == 0 { httpsJSON = []byte("null") } - _, err := this.Query(). + _, err := this.Query(tx). Pk(serverId). Set("https", string(httpsJSON)). Update() @@ -323,7 +324,7 @@ func (this *ServerDAO) UpdateServerHTTPS(serverId int64, httpsJSON []byte) error return err } - _, err = this.RenewServerConfig(serverId, false) + _, err = this.RenewServerConfig(tx, serverId, false) if err != nil { return err } @@ -332,14 +333,14 @@ func (this *ServerDAO) UpdateServerHTTPS(serverId int64, httpsJSON []byte) error } // 修改TCP配置 -func (this *ServerDAO) UpdateServerTCP(serverId int64, config []byte) error { +func (this *ServerDAO) UpdateServerTCP(tx *dbs.Tx, serverId int64, config []byte) error { if serverId <= 0 { return errors.New("serverId should not be smaller than 0") } if len(config) == 0 { config = []byte("null") } - _, err := this.Query(). + _, err := this.Query(tx). Pk(serverId). Set("tcp", string(config)). Update() @@ -351,14 +352,14 @@ func (this *ServerDAO) UpdateServerTCP(serverId int64, config []byte) error { } // 修改TLS配置 -func (this *ServerDAO) UpdateServerTLS(serverId int64, config []byte) error { +func (this *ServerDAO) UpdateServerTLS(tx *dbs.Tx, serverId int64, config []byte) error { if serverId <= 0 { return errors.New("serverId should not be smaller than 0") } if len(config) == 0 { config = []byte("null") } - _, err := this.Query(). + _, err := this.Query(tx). Pk(serverId). Set("tls", string(config)). Update() @@ -370,14 +371,14 @@ func (this *ServerDAO) UpdateServerTLS(serverId int64, config []byte) error { } // 修改Unix配置 -func (this *ServerDAO) UpdateServerUnix(serverId int64, config []byte) error { +func (this *ServerDAO) UpdateServerUnix(tx *dbs.Tx, serverId int64, config []byte) error { if serverId <= 0 { return errors.New("serverId should not be smaller than 0") } if len(config) == 0 { config = []byte("null") } - _, err := this.Query(). + _, err := this.Query(tx). Pk(serverId). Set("unix", string(config)). Update() @@ -389,14 +390,14 @@ func (this *ServerDAO) UpdateServerUnix(serverId int64, config []byte) error { } // 修改UDP配置 -func (this *ServerDAO) UpdateServerUDP(serverId int64, config []byte) error { +func (this *ServerDAO) UpdateServerUDP(tx *dbs.Tx, serverId int64, config []byte) error { if serverId <= 0 { return errors.New("serverId should not be smaller than 0") } if len(config) == 0 { config = []byte("null") } - _, err := this.Query(). + _, err := this.Query(tx). Pk(serverId). Set("udp", string(config)). Update() @@ -408,11 +409,11 @@ func (this *ServerDAO) UpdateServerUDP(serverId int64, config []byte) error { } // 修改Web配置 -func (this *ServerDAO) UpdateServerWeb(serverId int64, webId int64) error { +func (this *ServerDAO) UpdateServerWeb(tx *dbs.Tx, serverId int64, webId int64) error { if serverId <= 0 { return errors.New("serverId should not be smaller than 0") } - _, err := this.Query(). + _, err := this.Query(tx). Pk(serverId). Set("webId", webId). Update() @@ -423,22 +424,22 @@ func (this *ServerDAO) UpdateServerWeb(serverId int64, webId int64) error { } // 初始化Web配置 -func (this *ServerDAO) InitServerWeb(serverId int64) (int64, error) { +func (this *ServerDAO) InitServerWeb(tx *dbs.Tx, serverId int64) (int64, error) { if serverId <= 0 { return 0, errors.New("serverId should not be smaller than 0") } - adminId, userId, err := this.FindServerAdminIdAndUserId(serverId) + adminId, userId, err := this.FindServerAdminIdAndUserId(tx, serverId) if err != nil { return 0, err } - webId, err := SharedHTTPWebDAO.CreateWeb(adminId, userId, nil) + webId, err := SharedHTTPWebDAO.CreateWeb(tx, adminId, userId, nil) if err != nil { return 0, err } - _, err = this.Query(). + _, err = this.Query(tx). Pk(serverId). Set("webId", webId). Update() @@ -455,11 +456,11 @@ func (this *ServerDAO) InitServerWeb(serverId int64) (int64, error) { } // 查找ServerNames配置 -func (this *ServerDAO) FindServerNames(serverId int64) (serverNamesJSON []byte, isAuditing bool, auditingServerNamesJSON []byte, auditingResultJSON []byte, err error) { +func (this *ServerDAO) FindServerNames(tx *dbs.Tx, serverId int64) (serverNamesJSON []byte, isAuditing bool, auditingServerNamesJSON []byte, auditingResultJSON []byte, err error) { if serverId <= 0 { return } - one, err := this.Query(). + one, err := this.Query(tx). Pk(serverId). Result("serverNames", "isAuditing", "auditingServerNames", "auditingResult"). Find() @@ -474,7 +475,7 @@ func (this *ServerDAO) FindServerNames(serverId int64) (serverNamesJSON []byte, } // 修改ServerNames配置 -func (this *ServerDAO) UpdateServerNames(serverId int64, serverNames []byte) error { +func (this *ServerDAO) UpdateServerNames(tx *dbs.Tx, serverId int64, serverNames []byte) error { if serverId <= 0 { return errors.New("serverId should not be smaller than 0") } @@ -486,7 +487,7 @@ func (this *ServerDAO) UpdateServerNames(serverId int64, serverNames []byte) err serverNames = []byte("[]") } op.ServerNames = serverNames - err := this.Save(op) + err := this.Save(tx, op) if err != nil { return err } @@ -494,7 +495,7 @@ func (this *ServerDAO) UpdateServerNames(serverId int64, serverNames []byte) err } // 修改域名审核 -func (this *ServerDAO) UpdateAuditingServerNames(serverId int64, isAuditing bool, auditingServerNamesJSON []byte) error { +func (this *ServerDAO) UpdateAuditingServerNames(tx *dbs.Tx, serverId int64, isAuditing bool, auditingServerNamesJSON []byte) error { if serverId <= 0 { return errors.New("serverId should not be smaller than 0") } @@ -508,7 +509,7 @@ func (this *ServerDAO) UpdateAuditingServerNames(serverId int64, isAuditing bool op.AuditingServerNames = auditingServerNamesJSON } op.AuditingResult = `{"isOk":true}` - err := this.Save(op) + err := this.Save(tx, op) if err != nil { return err } @@ -516,7 +517,7 @@ func (this *ServerDAO) UpdateAuditingServerNames(serverId int64, isAuditing bool } // 修改域名审核结果 -func (this *ServerDAO) UpdateServerAuditing(serverId int64, result *pb.ServerNameAuditingResult) error { +func (this *ServerDAO) UpdateServerAuditing(tx *dbs.Tx, serverId int64, result *pb.ServerNameAuditingResult) error { if serverId <= 0 { return errors.New("invalid serverId") } @@ -537,18 +538,18 @@ func (this *ServerDAO) UpdateServerAuditing(serverId int64, result *pb.ServerNam if result.IsOk { op.ServerNames = dbs.SQL("auditingServerNames") } - return this.Save(op) + return this.Save(tx, op) } // 修改反向代理配置 -func (this *ServerDAO) UpdateServerReverseProxy(serverId int64, config []byte) error { +func (this *ServerDAO) UpdateServerReverseProxy(tx *dbs.Tx, serverId int64, config []byte) error { if serverId <= 0 { return errors.New("serverId should not be smaller than 0") } op := NewServerOperator() op.Id = serverId op.ReverseProxy = JSONBytes(config) - err := this.Save(op) + err := this.Save(tx, op) if err != nil { return err } @@ -557,8 +558,8 @@ func (this *ServerDAO) UpdateServerReverseProxy(serverId int64, config []byte) e } // 计算所有可用服务数量 -func (this *ServerDAO) CountAllEnabledServersMatch(groupId int64, keyword string, userId int64, clusterId int64, auditingFlag configutils.BoolState) (int64, error) { - query := this.Query(). +func (this *ServerDAO) CountAllEnabledServersMatch(tx *dbs.Tx, groupId int64, keyword string, userId int64, clusterId int64, auditingFlag configutils.BoolState) (int64, error) { + query := this.Query(tx). State(ServerStateEnabled) if groupId > 0 { query.Where("JSON_CONTAINS(groupIds, :groupId)"). @@ -581,8 +582,8 @@ func (this *ServerDAO) CountAllEnabledServersMatch(groupId int64, keyword string } // 列出单页的服务 -func (this *ServerDAO) ListEnabledServersMatch(offset int64, size int64, groupId int64, keyword string, userId int64, clusterId int64, auditingFlag int32) (result []*Server, err error) { - query := this.Query(). +func (this *ServerDAO) ListEnabledServersMatch(tx *dbs.Tx, offset int64, size int64, groupId int64, keyword string, userId int64, clusterId int64, auditingFlag int32) (result []*Server, err error) { + query := this.Query(tx). State(ServerStateEnabled). Offset(offset). Limit(size). @@ -612,9 +613,9 @@ func (this *ServerDAO) ListEnabledServersMatch(offset int64, size int64, groupId } // 获取节点中的所有服务 -func (this *ServerDAO) FindAllEnabledServersWithNode(nodeId int64) (result []*Server, err error) { +func (this *ServerDAO) FindAllEnabledServersWithNode(tx *dbs.Tx, nodeId int64) (result []*Server, err error) { // 节点所在集群 - clusterId, err := SharedNodeDAO.FindNodeClusterId(nodeId) + clusterId, err := SharedNodeDAO.FindNodeClusterId(tx, nodeId) if err != nil { return nil, err } @@ -622,7 +623,7 @@ func (this *ServerDAO) FindAllEnabledServersWithNode(nodeId int64) (result []*Se return nil, nil } - _, err = this.Query(). + _, err = this.Query(tx). Attr("clusterId", clusterId). State(ServerStateEnabled). AscPk(). @@ -632,8 +633,8 @@ func (this *ServerDAO) FindAllEnabledServersWithNode(nodeId int64) (result []*Se } // 获取所有的服务ID -func (this *ServerDAO) FindAllEnabledServerIds() (serverIds []int64, err error) { - ones, err := this.Query(). +func (this *ServerDAO) FindAllEnabledServerIds(tx *dbs.Tx) (serverIds []int64, err error) { + ones, err := this.Query(tx). State(ServerStateEnabled). AscPk(). ResultPk(). @@ -645,8 +646,8 @@ func (this *ServerDAO) FindAllEnabledServerIds() (serverIds []int64, err error) } // 查找服务的搜索条件 -func (this *ServerDAO) FindServerNodeFilters(serverId int64) (isOk bool, clusterId int64, err error) { - one, err := this.Query(). +func (this *ServerDAO) FindServerNodeFilters(tx *dbs.Tx, serverId int64) (isOk bool, clusterId int64, err error) { + one, err := this.Query(tx). Pk(serverId). Result("clusterId"). Find() @@ -662,8 +663,8 @@ func (this *ServerDAO) FindServerNodeFilters(serverId int64) (isOk bool, cluster } // 构造服务的Config -func (this *ServerDAO) ComposeServerConfig(serverId int64) (*serverconfigs.ServerConfig, error) { - server, err := this.FindEnabledServer(serverId) +func (this *ServerDAO) ComposeServerConfig(tx *dbs.Tx, serverId int64) (*serverconfigs.ServerConfig, error) { + server, err := this.FindEnabledServer(tx, serverId) if err != nil { return nil, err } @@ -690,12 +691,12 @@ func (this *ServerDAO) ComposeServerConfig(serverId int64) (*serverconfigs.Serve // CNAME if server.ClusterId > 0 && len(server.DnsName) > 0 { - clusterDNS, err := SharedNodeClusterDAO.FindClusterDNSInfo(int64(server.ClusterId)) + clusterDNS, err := SharedNodeClusterDAO.FindClusterDNSInfo(tx, int64(server.ClusterId)) if err != nil { return nil, err } if clusterDNS != nil && clusterDNS.DnsDomainId > 0 { - domain, err := SharedDNSDomainDAO.FindEnabledDNSDomain(int64(clusterDNS.DnsDomainId)) + domain, err := SharedDNSDomainDAO.FindEnabledDNSDomain(tx, int64(clusterDNS.DnsDomainId)) if err != nil { return nil, err } @@ -726,7 +727,7 @@ func (this *ServerDAO) ComposeServerConfig(serverId int64) (*serverconfigs.Serve // SSL if httpsConfig.SSLPolicyRef != nil && httpsConfig.SSLPolicyRef.SSLPolicyId > 0 { - sslPolicyConfig, err := SharedSSLPolicyDAO.ComposePolicyConfig(httpsConfig.SSLPolicyRef.SSLPolicyId) + sslPolicyConfig, err := SharedSSLPolicyDAO.ComposePolicyConfig(tx, httpsConfig.SSLPolicyRef.SSLPolicyId) if err != nil { return nil, err } @@ -758,7 +759,7 @@ func (this *ServerDAO) ComposeServerConfig(serverId int64) (*serverconfigs.Serve // SSL if tlsConfig.SSLPolicyRef != nil { - sslPolicyConfig, err := SharedSSLPolicyDAO.ComposePolicyConfig(tlsConfig.SSLPolicyRef.SSLPolicyId) + sslPolicyConfig, err := SharedSSLPolicyDAO.ComposePolicyConfig(tx, tlsConfig.SSLPolicyRef.SSLPolicyId) if err != nil { return nil, err } @@ -792,7 +793,7 @@ func (this *ServerDAO) ComposeServerConfig(serverId int64) (*serverconfigs.Serve // Web if server.WebId > 0 { - webConfig, err := SharedHTTPWebDAO.ComposeWebConfig(int64(server.WebId)) + webConfig, err := SharedHTTPWebDAO.ComposeWebConfig(tx, int64(server.WebId)) if err != nil { return nil, err } @@ -810,7 +811,7 @@ func (this *ServerDAO) ComposeServerConfig(serverId int64) (*serverconfigs.Serve } config.ReverseProxyRef = reverseProxyRef - reverseProxyConfig, err := SharedReverseProxyDAO.ComposeReverseProxyConfig(reverseProxyRef.ReverseProxyId) + reverseProxyConfig, err := SharedReverseProxyDAO.ComposeReverseProxyConfig(tx, reverseProxyRef.ReverseProxyId) if err != nil { return nil, err } @@ -823,8 +824,8 @@ func (this *ServerDAO) ComposeServerConfig(serverId int64) (*serverconfigs.Serve } // 更新服务的Config配置 -func (this *ServerDAO) RenewServerConfig(serverId int64, updateMd5 bool) (isChanged bool, err error) { - serverConfig, err := this.ComposeServerConfig(serverId) +func (this *ServerDAO) RenewServerConfig(tx *dbs.Tx, serverId int64, updateMd5 bool) (isChanged bool, err error) { + serverConfig, err := this.ComposeServerConfig(tx, serverId) if err != nil { return false, err } @@ -832,12 +833,12 @@ func (this *ServerDAO) RenewServerConfig(serverId int64, updateMd5 bool) (isChan if err != nil { return false, err } - return this.UpdateServerConfig(serverId, data, updateMd5) + return this.UpdateServerConfig(tx, serverId, data, updateMd5) } // 根据条件获取反向代理配置 -func (this *ServerDAO) FindReverseProxyRef(serverId int64) (*serverconfigs.ReverseProxyRef, error) { - reverseProxy, err := this.Query(). +func (this *ServerDAO) FindReverseProxyRef(tx *dbs.Tx, serverId int64) (*serverconfigs.ReverseProxyRef, error) { + reverseProxy, err := this.Query(tx). Pk(serverId). Result("reverseProxy"). FindStringCol("") @@ -853,8 +854,8 @@ func (this *ServerDAO) FindReverseProxyRef(serverId int64) (*serverconfigs.Rever } // 查找Server对应的WebId -func (this *ServerDAO) FindServerWebId(serverId int64) (int64, error) { - webId, err := this.Query(). +func (this *ServerDAO) FindServerWebId(tx *dbs.Tx, serverId int64) (int64, error) { + webId, err := this.Query(tx). Pk(serverId). Result("webId"). FindIntCol(0) @@ -865,7 +866,7 @@ func (this *ServerDAO) FindServerWebId(serverId int64) (int64, error) { } // 计算使用SSL策略的所有服务数量 -func (this *ServerDAO) CountAllEnabledServersWithSSLPolicyIds(sslPolicyIds []int64) (count int64, err error) { +func (this *ServerDAO) CountAllEnabledServersWithSSLPolicyIds(tx *dbs.Tx, sslPolicyIds []int64) (count int64, err error) { if len(sslPolicyIds) == 0 { return } @@ -873,7 +874,7 @@ func (this *ServerDAO) CountAllEnabledServersWithSSLPolicyIds(sslPolicyIds []int for _, policyId := range sslPolicyIds { policyStringIds = append(policyStringIds, strconv.FormatInt(policyId, 10)) } - return this.Query(). + return this.Query(tx). State(ServerStateEnabled). Where("(FIND_IN_SET(JSON_EXTRACT(https, '$.sslPolicyRef.sslPolicyId'), :policyIds) OR FIND_IN_SET(JSON_EXTRACT(tls, '$.sslPolicyRef.sslPolicyId'), :policyIds))"). Param("policyIds", strings.Join(policyStringIds, ",")). @@ -881,7 +882,7 @@ func (this *ServerDAO) CountAllEnabledServersWithSSLPolicyIds(sslPolicyIds []int } // 查找使用某个SSL策略的所有服务 -func (this *ServerDAO) FindAllEnabledServersWithSSLPolicyIds(sslPolicyIds []int64) (result []*Server, err error) { +func (this *ServerDAO) FindAllEnabledServersWithSSLPolicyIds(tx *dbs.Tx, sslPolicyIds []int64) (result []*Server, err error) { if len(sslPolicyIds) == 0 { return } @@ -889,7 +890,7 @@ func (this *ServerDAO) FindAllEnabledServersWithSSLPolicyIds(sslPolicyIds []int6 for _, policyId := range sslPolicyIds { policyStringIds = append(policyStringIds, strconv.FormatInt(policyId, 10)) } - _, err = this.Query(). + _, err = this.Query(tx). State(ServerStateEnabled). Result("id", "name", "https", "tls", "isOn", "type"). Where("(FIND_IN_SET(JSON_EXTRACT(https, '$.sslPolicyRef.sslPolicyId'), :policyIds) OR FIND_IN_SET(JSON_EXTRACT(tls, '$.sslPolicyRef.sslPolicyId'), :policyIds))"). @@ -901,22 +902,22 @@ func (this *ServerDAO) FindAllEnabledServersWithSSLPolicyIds(sslPolicyIds []int6 } // 计算使用某个缓存策略的所有服务数量 -func (this *ServerDAO) CountEnabledServersWithWebIds(webIds []int64) (count int64, err error) { +func (this *ServerDAO) CountEnabledServersWithWebIds(tx *dbs.Tx, webIds []int64) (count int64, err error) { if len(webIds) == 0 { return } - return this.Query(). + return this.Query(tx). State(ServerStateEnabled). Attr("webId", webIds). Count() } // 查找使用某个缓存策略的所有服务 -func (this *ServerDAO) FindAllEnabledServersWithWebIds(webIds []int64) (result []*Server, err error) { +func (this *ServerDAO) FindAllEnabledServersWithWebIds(tx *dbs.Tx, webIds []int64) (result []*Server, err error) { if len(webIds) == 0 { return } - _, err = this.Query(). + _, err = this.Query(tx). State(ServerStateEnabled). Attr("webId", webIds). AscPk(). @@ -926,16 +927,16 @@ func (this *ServerDAO) FindAllEnabledServersWithWebIds(webIds []int64) (result [ } // 计算使用某个集群的所有服务数量 -func (this *ServerDAO) CountAllEnabledServersWithNodeClusterId(clusterId int64) (int64, error) { - return this.Query(). +func (this *ServerDAO) CountAllEnabledServersWithNodeClusterId(tx *dbs.Tx, clusterId int64) (int64, error) { + return this.Query(tx). State(ServerStateEnabled). Attr("clusterId", clusterId). Count() } // 计算使用某个分组的服务数量 -func (this *ServerDAO) CountAllEnabledServersWithGroupId(groupId int64) (int64, error) { - return this.Query(). +func (this *ServerDAO) CountAllEnabledServersWithGroupId(tx *dbs.Tx, groupId int64) (int64, error) { + return this.Query(tx). State(ServerStateEnabled). Where("JSON_CONTAINS(groupIds, :groupId)"). Param("groupId", numberutils.FormatInt64(groupId)). @@ -943,15 +944,15 @@ func (this *ServerDAO) CountAllEnabledServersWithGroupId(groupId int64) (int64, } // 查询使用某个DNS域名的所有服务域名 -func (this *ServerDAO) FindAllServerDNSNamesWithDNSDomainId(dnsDomainId int64) ([]string, error) { - clusterIds, err := SharedNodeClusterDAO.FindAllEnabledClusterIdsWithDNSDomainId(dnsDomainId) +func (this *ServerDAO) FindAllServerDNSNamesWithDNSDomainId(tx *dbs.Tx, dnsDomainId int64) ([]string, error) { + clusterIds, err := SharedNodeClusterDAO.FindAllEnabledClusterIdsWithDNSDomainId(tx, dnsDomainId) if err != nil { return nil, err } if len(clusterIds) == 0 { return nil, nil } - ones, err := this.Query(). + ones, err := this.Query(tx). State(ServerStateEnabled). Attr("isOn", true). Attr("clusterId", clusterIds). @@ -973,8 +974,8 @@ func (this *ServerDAO) FindAllServerDNSNamesWithDNSDomainId(dnsDomainId int64) ( } // 获取某个集群下的服务DNS信息 -func (this *ServerDAO) FindAllServersDNSWithClusterId(clusterId int64) (result []*Server, err error) { - _, err = this.Query(). +func (this *ServerDAO) FindAllServersDNSWithClusterId(tx *dbs.Tx, clusterId int64) (result []*Server, err error) { + _, err = this.Query(tx). State(ServerStateEnabled). Attr("isOn", true). Attr("isAuditing", false). // 不在审核中 @@ -987,45 +988,45 @@ func (this *ServerDAO) FindAllServersDNSWithClusterId(clusterId int64) (result [ } // 重新生成子域名 -func (this *ServerDAO) GenerateServerDNSName(serverId int64) (string, error) { +func (this *ServerDAO) GenerateServerDNSName(tx *dbs.Tx, serverId int64) (string, error) { if serverId <= 0 { return "", errors.New("invalid serverId") } - dnsName, err := this.genDNSName() + dnsName, err := this.genDNSName(tx) if err != nil { return "", err } op := NewServerOperator() op.Id = serverId op.DnsName = dnsName - err = this.Save(op) + err = this.Save(tx, op) return dnsName, err } // 创建事件 func (this *ServerDAO) createEvent() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) } // 查询当前服务的集群ID -func (this *ServerDAO) FindServerClusterId(serverId int64) (int64, error) { - return this.Query(). +func (this *ServerDAO) FindServerClusterId(tx *dbs.Tx, serverId int64) (int64, error) { + return this.Query(tx). Pk(serverId). Result("clusterId"). FindInt64Col(0) } // 查询服务的DNS名称 -func (this *ServerDAO) FindServerDNSName(serverId int64) (string, error) { - return this.Query(). +func (this *ServerDAO) FindServerDNSName(tx *dbs.Tx, serverId int64) (string, error) { + return this.Query(tx). Pk(serverId). Result("dnsName"). FindStringCol("") } // 获取当前服务的管理员ID和用户ID -func (this *ServerDAO) FindServerAdminIdAndUserId(serverId int64) (adminId int64, userId int64, err error) { - one, err := this.Query(). +func (this *ServerDAO) FindServerAdminIdAndUserId(tx *dbs.Tx, serverId int64) (adminId int64, userId int64, err error) { + one, err := this.Query(tx). Pk(serverId). Result("adminId", "userId"). Find() @@ -1039,11 +1040,11 @@ func (this *ServerDAO) FindServerAdminIdAndUserId(serverId int64) (adminId int64 } // 检查用户服务 -func (this *ServerDAO) CheckUserServer(serverId int64, userId int64) error { +func (this *ServerDAO) CheckUserServer(tx *dbs.Tx, serverId int64, userId int64) error { if serverId <= 0 || userId <= 0 { return ErrNotFound } - ok, err := this.Query(). + ok, err := this.Query(tx). Pk(serverId). Attr("userId", userId). Exist() @@ -1057,8 +1058,8 @@ func (this *ServerDAO) CheckUserServer(serverId int64, userId int64) error { } // 设置一个用户下的所有服务的所属集群 -func (this *ServerDAO) UpdateUserServersClusterId(userId int64, clusterId int64) error { - _, err := this.Query(). +func (this *ServerDAO) UpdateUserServersClusterId(tx *dbs.Tx, userId int64, clusterId int64) error { + _, err := this.Query(tx). Attr("userId", userId). Set("clusterId", clusterId). Update() @@ -1066,8 +1067,8 @@ func (this *ServerDAO) UpdateUserServersClusterId(userId int64, clusterId int64) } // 查找用户的所有的服务 -func (this *ServerDAO) FindAllEnabledServersWithUserId(userId int64) (result []*Server, err error) { - _, err = this.Query(). +func (this *ServerDAO) FindAllEnabledServersWithUserId(tx *dbs.Tx, userId int64) (result []*Server, err error) { + _, err = this.Query(tx). State(ServerStateEnabled). Attr("userId", userId). DescPk(). @@ -1077,10 +1078,10 @@ func (this *ServerDAO) FindAllEnabledServersWithUserId(userId int64) (result []* } // 生成DNS Name -func (this *ServerDAO) genDNSName() (string, error) { +func (this *ServerDAO) genDNSName(tx *dbs.Tx) (string, error) { for { dnsName := rands.HexString(8) - exist, err := this.Query(). + exist, err := this.Query(tx). Attr("dnsName", dnsName). Exist() if err != nil { diff --git a/internal/db/models/server_group_dao.go b/internal/db/models/server_group_dao.go index b7ed812c..77e6e5bf 100644 --- a/internal/db/models/server_group_dao.go +++ b/internal/db/models/server_group_dao.go @@ -35,8 +35,8 @@ func init() { } // 启用条目 -func (this *ServerGroupDAO) EnableServerGroup(id int64) error { - _, err := this.Query(). +func (this *ServerGroupDAO) EnableServerGroup(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", ServerGroupStateEnabled). Update() @@ -44,8 +44,8 @@ func (this *ServerGroupDAO) EnableServerGroup(id int64) error { } // 禁用条目 -func (this *ServerGroupDAO) DisableServerGroup(id int64) error { - _, err := this.Query(). +func (this *ServerGroupDAO) DisableServerGroup(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", ServerGroupStateDisabled). Update() @@ -53,8 +53,8 @@ func (this *ServerGroupDAO) DisableServerGroup(id int64) error { } // 查找启用中的条目 -func (this *ServerGroupDAO) FindEnabledServerGroup(id int64) (*ServerGroup, error) { - result, err := this.Query(). +func (this *ServerGroupDAO) FindEnabledServerGroup(tx *dbs.Tx, id int64) (*ServerGroup, error) { + result, err := this.Query(tx). Pk(id). Attr("state", ServerGroupStateEnabled). Find() @@ -65,19 +65,19 @@ func (this *ServerGroupDAO) FindEnabledServerGroup(id int64) (*ServerGroup, erro } // 根据主键查找名称 -func (this *ServerGroupDAO) FindServerGroupName(id int64) (string, error) { - return this.Query(). +func (this *ServerGroupDAO) FindServerGroupName(tx *dbs.Tx, id int64) (string, error) { + return this.Query(tx). Pk(id). Result("name"). FindStringCol("") } // 创建分组 -func (this *ServerGroupDAO) CreateGroup(name string) (groupId int64, err error) { +func (this *ServerGroupDAO) CreateGroup(tx *dbs.Tx, name string) (groupId int64, err error) { op := NewServerGroupOperator() op.State = ServerGroupStateEnabled op.Name = name - err = this.Save(op) + err = this.Save(tx, op) if err != nil { return 0, err } @@ -85,20 +85,20 @@ func (this *ServerGroupDAO) CreateGroup(name string) (groupId int64, err error) } // 修改分组 -func (this *ServerGroupDAO) UpdateGroup(groupId int64, name string) error { +func (this *ServerGroupDAO) UpdateGroup(tx *dbs.Tx, groupId int64, name string) error { if groupId <= 0 { return errors.New("invalid groupId") } op := NewServerGroupOperator() op.Id = groupId op.Name = name - err := this.Save(op) + err := this.Save(tx, op) return err } // 查找所有分组 -func (this *ServerGroupDAO) FindAllEnabledGroups() (result []*ServerGroup, err error) { - _, err = this.Query(). +func (this *ServerGroupDAO) FindAllEnabledGroups(tx *dbs.Tx) (result []*ServerGroup, err error) { + _, err = this.Query(tx). State(ServerGroupStateEnabled). Desc("order"). AscPk(). @@ -108,9 +108,9 @@ func (this *ServerGroupDAO) FindAllEnabledGroups() (result []*ServerGroup, err e } // 修改分组排序 -func (this *ServerGroupDAO) UpdateGroupOrders(groupIds []int64) error { +func (this *ServerGroupDAO) UpdateGroupOrders(tx *dbs.Tx, groupIds []int64) error { for index, groupId := range groupIds { - _, err := this.Query(). + _, err := this.Query(tx). Pk(groupId). Set("order", len(groupIds)-index). Update() diff --git a/internal/db/models/ssl_cert_dao.go b/internal/db/models/ssl_cert_dao.go index ed50a173..7018d1da 100644 --- a/internal/db/models/ssl_cert_dao.go +++ b/internal/db/models/ssl_cert_dao.go @@ -42,19 +42,19 @@ func init() { func (this *SSLCertDAO) Init() { this.DAOObject.Init() this.DAOObject.OnUpdate(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) this.DAOObject.OnInsert(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) this.DAOObject.OnDelete(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) } // 启用条目 -func (this *SSLCertDAO) EnableSSLCert(id int64) error { - _, err := this.Query(). +func (this *SSLCertDAO) EnableSSLCert(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", SSLCertStateEnabled). Update() @@ -62,8 +62,8 @@ func (this *SSLCertDAO) EnableSSLCert(id int64) error { } // 禁用条目 -func (this *SSLCertDAO) DisableSSLCert(id int64) error { - _, err := this.Query(). +func (this *SSLCertDAO) DisableSSLCert(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", SSLCertStateDisabled). Update() @@ -71,8 +71,8 @@ func (this *SSLCertDAO) DisableSSLCert(id int64) error { } // 查找启用中的条目 -func (this *SSLCertDAO) FindEnabledSSLCert(id int64) (*SSLCert, error) { - result, err := this.Query(). +func (this *SSLCertDAO) FindEnabledSSLCert(tx *dbs.Tx, id int64) (*SSLCert, error) { + result, err := this.Query(tx). Pk(id). Attr("state", SSLCertStateEnabled). Find() @@ -83,15 +83,15 @@ func (this *SSLCertDAO) FindEnabledSSLCert(id int64) (*SSLCert, error) { } // 根据主键查找名称 -func (this *SSLCertDAO) FindSSLCertName(id int64) (string, error) { - return this.Query(). +func (this *SSLCertDAO) FindSSLCertName(tx *dbs.Tx, id int64) (string, error) { + return this.Query(tx). Pk(id). Result("name"). FindStringCol("") } // 创建证书 -func (this *SSLCertDAO) CreateCert(adminId int64, userId int64, isOn bool, name string, description string, serverName string, isCA bool, certData []byte, keyData []byte, timeBeginAt int64, timeEndAt int64, dnsNames []string, commonNames []string) (int64, error) { +func (this *SSLCertDAO) CreateCert(tx *dbs.Tx, adminId int64, userId int64, isOn bool, name string, description string, serverName string, isCA bool, certData []byte, keyData []byte, timeBeginAt int64, timeEndAt int64, dnsNames []string, commonNames []string) (int64, error) { op := NewSSLCertOperator() op.AdminId = adminId op.UserId = userId @@ -118,7 +118,7 @@ func (this *SSLCertDAO) CreateCert(adminId int64, userId int64, isOn bool, name } op.CommonNames = commonNamesJSON - err = this.Save(op) + err = this.Save(tx, op) if err != nil { return 0, err } @@ -126,7 +126,7 @@ func (this *SSLCertDAO) CreateCert(adminId int64, userId int64, isOn bool, name } // 修改证书 -func (this *SSLCertDAO) UpdateCert(certId int64, isOn bool, name string, description string, serverName string, isCA bool, certData []byte, keyData []byte, timeBeginAt int64, timeEndAt int64, dnsNames []string, commonNames []string) error { +func (this *SSLCertDAO) UpdateCert(tx *dbs.Tx, certId int64, isOn bool, name string, description string, serverName string, isCA bool, certData []byte, keyData []byte, timeBeginAt int64, timeEndAt int64, dnsNames []string, commonNames []string) error { if certId <= 0 { return errors.New("invalid certId") } @@ -161,13 +161,13 @@ func (this *SSLCertDAO) UpdateCert(certId int64, isOn bool, name string, descrip } op.CommonNames = commonNamesJSON - err = this.Save(op) + err = this.Save(tx, op) return err } // 组合配置 -func (this *SSLCertDAO) ComposeCertConfig(certId int64) (*sslconfigs.SSLCertConfig, error) { - cert, err := this.FindEnabledSSLCert(certId) +func (this *SSLCertDAO) ComposeCertConfig(tx *dbs.Tx, certId int64) (*sslconfigs.SSLCertConfig, error) { + cert, err := this.FindEnabledSSLCert(tx, certId) if err != nil { return nil, err } @@ -210,8 +210,8 @@ func (this *SSLCertDAO) ComposeCertConfig(certId int64) (*sslconfigs.SSLCertConf } // 计算符合条件的证书数量 -func (this *SSLCertDAO) CountCerts(isCA bool, isAvailable bool, isExpired bool, expiringDays int64, keyword string, userId int64) (int64, error) { - query := this.Query(). +func (this *SSLCertDAO) CountCerts(tx *dbs.Tx, isCA bool, isAvailable bool, isExpired bool, expiringDays int64, keyword string, userId int64) (int64, error) { + query := this.Query(tx). State(SSLCertStateEnabled) if isCA { query.Attr("isCA", true) @@ -240,8 +240,8 @@ func (this *SSLCertDAO) CountCerts(isCA bool, isAvailable bool, isExpired bool, } // 列出符合条件的证书 -func (this *SSLCertDAO) ListCertIds(isCA bool, isAvailable bool, isExpired bool, expiringDays int64, keyword string, userId int64, offset int64, size int64) (certIds []int64, err error) { - query := this.Query(). +func (this *SSLCertDAO) ListCertIds(tx *dbs.Tx, isCA bool, isAvailable bool, isExpired bool, expiringDays int64, keyword string, userId int64, offset int64, size int64) (certIds []int64, err error) { + query := this.Query(tx). State(SSLCertStateEnabled) if isCA { query.Attr("isCA", true) @@ -285,7 +285,7 @@ func (this *SSLCertDAO) ListCertIds(isCA bool, isAvailable bool, isExpired bool, } // 设置证书的ACME信息 -func (this *SSLCertDAO) UpdateCertACME(certId int64, acmeTaskId int64) error { +func (this *SSLCertDAO) UpdateCertACME(tx *dbs.Tx, certId int64, acmeTaskId int64) error { if certId <= 0 { return errors.New("invalid certId") } @@ -293,19 +293,19 @@ func (this *SSLCertDAO) UpdateCertACME(certId int64, acmeTaskId int64) error { op.Id = certId op.AcmeTaskId = acmeTaskId op.IsACME = true - err := this.Save(op) + err := this.Save(tx, op) return err } // 查找需要自动更新的任务 // 这里我们只返回有限的字段以节省内存 -func (this *SSLCertDAO) FindAllExpiringCerts(days int) (result []*SSLCert, err error) { +func (this *SSLCertDAO) FindAllExpiringCerts(tx *dbs.Tx, days int) (result []*SSLCert, err error) { if days < 0 { days = 0 } deltaSeconds := int64(days * 86400) - _, err = this.Query(). + _, err = this.Query(tx). State(SSLCertStateEnabled). Where("FROM_UNIXTIME(timeEndAt, '%Y-%m-%d')=:day AND FROM_UNIXTIME(notifiedAt, '%Y-%m-%d')!=:today"). Param("day", timeutil.FormatTime("Y-m-d", time.Now().Unix()+deltaSeconds)). @@ -318,8 +318,8 @@ func (this *SSLCertDAO) FindAllExpiringCerts(days int) (result []*SSLCert, err e } // 设置当前证书事件通知时间 -func (this *SSLCertDAO) UpdateCertNotifiedAt(certId int64) error { - _, err := this.Query(). +func (this *SSLCertDAO) UpdateCertNotifiedAt(tx *dbs.Tx, certId int64) error { + _, err := this.Query(tx). Pk(certId). Set("notifiedAt", time.Now().Unix()). Update() @@ -327,11 +327,11 @@ func (this *SSLCertDAO) UpdateCertNotifiedAt(certId int64) error { } // 检查用户权限 -func (this *SSLCertDAO) CheckUserCert(certId int64, userId int64) error { +func (this *SSLCertDAO) CheckUserCert(tx *dbs.Tx, certId int64, userId int64) error { if certId <= 0 || userId <= 0 { return errors.New("not found") } - ok, err := this.Query(). + ok, err := this.Query(tx). Pk(certId). Attr("userId", userId). State(SSLCertStateEnabled). diff --git a/internal/db/models/ssl_cert_group_dao.go b/internal/db/models/ssl_cert_group_dao.go index e965215b..de37f4b5 100644 --- a/internal/db/models/ssl_cert_group_dao.go +++ b/internal/db/models/ssl_cert_group_dao.go @@ -33,8 +33,8 @@ func init() { } // 启用条目 -func (this *SSLCertGroupDAO) EnableSSLCertGroup(id uint32) error { - _, err := this.Query(). +func (this *SSLCertGroupDAO) EnableSSLCertGroup(tx *dbs.Tx, id uint32) error { + _, err := this.Query(tx). Pk(id). Set("state", SSLCertGroupStateEnabled). Update() @@ -42,8 +42,8 @@ func (this *SSLCertGroupDAO) EnableSSLCertGroup(id uint32) error { } // 禁用条目 -func (this *SSLCertGroupDAO) DisableSSLCertGroup(id uint32) error { - _, err := this.Query(). +func (this *SSLCertGroupDAO) DisableSSLCertGroup(tx *dbs.Tx, id uint32) error { + _, err := this.Query(tx). Pk(id). Set("state", SSLCertGroupStateDisabled). Update() @@ -51,8 +51,8 @@ func (this *SSLCertGroupDAO) DisableSSLCertGroup(id uint32) error { } // 查找启用中的条目 -func (this *SSLCertGroupDAO) FindEnabledSSLCertGroup(id uint32) (*SSLCertGroup, error) { - result, err := this.Query(). +func (this *SSLCertGroupDAO) FindEnabledSSLCertGroup(tx *dbs.Tx, id uint32) (*SSLCertGroup, error) { + result, err := this.Query(tx). Pk(id). Attr("state", SSLCertGroupStateEnabled). Find() @@ -63,8 +63,8 @@ func (this *SSLCertGroupDAO) FindEnabledSSLCertGroup(id uint32) (*SSLCertGroup, } // 根据主键查找名称 -func (this *SSLCertGroupDAO) FindSSLCertGroupName(id uint32) (string, error) { - return this.Query(). +func (this *SSLCertGroupDAO) FindSSLCertGroupName(tx *dbs.Tx, id uint32) (string, error) { + return this.Query(tx). Pk(id). Result("name"). FindStringCol("") diff --git a/internal/db/models/ssl_policy_dao.go b/internal/db/models/ssl_policy_dao.go index e653d853..8ece89f9 100644 --- a/internal/db/models/ssl_policy_dao.go +++ b/internal/db/models/ssl_policy_dao.go @@ -41,19 +41,19 @@ func init() { func (this *SSLPolicyDAO) Init() { this.DAOObject.Init() this.DAOObject.OnUpdate(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) this.DAOObject.OnInsert(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) this.DAOObject.OnDelete(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) } // 启用条目 -func (this *SSLPolicyDAO) EnableSSLPolicy(id int64) error { - _, err := this.Query(). +func (this *SSLPolicyDAO) EnableSSLPolicy(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", SSLPolicyStateEnabled). Update() @@ -61,8 +61,8 @@ func (this *SSLPolicyDAO) EnableSSLPolicy(id int64) error { } // 禁用条目 -func (this *SSLPolicyDAO) DisableSSLPolicy(id int64) error { - _, err := this.Query(). +func (this *SSLPolicyDAO) DisableSSLPolicy(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", SSLPolicyStateDisabled). Update() @@ -70,8 +70,8 @@ func (this *SSLPolicyDAO) DisableSSLPolicy(id int64) error { } // 查找启用中的条目 -func (this *SSLPolicyDAO) FindEnabledSSLPolicy(id int64) (*SSLPolicy, error) { - result, err := this.Query(). +func (this *SSLPolicyDAO) FindEnabledSSLPolicy(tx *dbs.Tx, id int64) (*SSLPolicy, error) { + result, err := this.Query(tx). Pk(id). Attr("state", SSLPolicyStateEnabled). Find() @@ -82,8 +82,8 @@ func (this *SSLPolicyDAO) FindEnabledSSLPolicy(id int64) (*SSLPolicy, error) { } // 组合配置 -func (this *SSLPolicyDAO) ComposePolicyConfig(policyId int64) (*sslconfigs.SSLPolicy, error) { - policy, err := this.FindEnabledSSLPolicy(policyId) +func (this *SSLPolicyDAO) ComposePolicyConfig(tx *dbs.Tx, policyId int64) (*sslconfigs.SSLPolicy, error) { + policy, err := this.FindEnabledSSLPolicy(tx, policyId) if err != nil { return nil, err } @@ -106,7 +106,7 @@ func (this *SSLPolicyDAO) ComposePolicyConfig(policyId int64) (*sslconfigs.SSLPo } if len(refs) > 0 { for _, ref := range refs { - certConfig, err := SharedSSLCertDAO.ComposeCertConfig(ref.CertId) + certConfig, err := SharedSSLCertDAO.ComposeCertConfig(tx, ref.CertId) if err != nil { return nil, err } @@ -128,7 +128,7 @@ func (this *SSLPolicyDAO) ComposePolicyConfig(policyId int64) (*sslconfigs.SSLPo } if len(refs) > 0 { for _, ref := range refs { - certConfig, err := SharedSSLCertDAO.ComposeCertConfig(ref.CertId) + certConfig, err := SharedSSLCertDAO.ComposeCertConfig(tx, ref.CertId) if err != nil { return nil, err } @@ -166,12 +166,12 @@ func (this *SSLPolicyDAO) ComposePolicyConfig(policyId int64) (*sslconfigs.SSLPo } // 查询使用单个证书的所有策略ID -func (this *SSLPolicyDAO) FindAllEnabledPolicyIdsWithCertId(certId int64) (policyIds []int64, err error) { +func (this *SSLPolicyDAO) FindAllEnabledPolicyIdsWithCertId(tx *dbs.Tx, certId int64) (policyIds []int64, err error) { if certId <= 0 { return } - ones, err := this.Query(). + ones, err := this.Query(tx). State(SSLPolicyStateEnabled). ResultPk(). Where(`JSON_CONTAINS(certs, '{"certId": ` + strconv.FormatInt(certId, 10) + ` }')`). @@ -187,7 +187,7 @@ func (this *SSLPolicyDAO) FindAllEnabledPolicyIdsWithCertId(certId int64) (polic } // 创建Policy -func (this *SSLPolicyDAO) CreatePolicy(adminId int64, userId int64, http2Enabled bool, minVersion string, certsJSON []byte, hstsJSON []byte, clientAuthType int32, clientCACertsJSON []byte, cipherSuitesIsOn bool, cipherSuites []string) (int64, error) { +func (this *SSLPolicyDAO) CreatePolicy(tx *dbs.Tx, adminId int64, userId int64, http2Enabled bool, minVersion string, certsJSON []byte, hstsJSON []byte, clientAuthType int32, clientCACertsJSON []byte, cipherSuitesIsOn bool, cipherSuites []string) (int64, error) { op := NewSSLPolicyOperator() op.State = SSLPolicyStateEnabled op.IsOn = true @@ -217,7 +217,7 @@ func (this *SSLPolicyDAO) CreatePolicy(adminId int64, userId int64, http2Enabled } op.CipherSuites = cipherSuitesJSON } - err := this.Save(op) + err := this.Save(tx, op) if err != nil { return 0, err } @@ -226,7 +226,7 @@ func (this *SSLPolicyDAO) CreatePolicy(adminId int64, userId int64, http2Enabled // 修改Policy // 创建Policy -func (this *SSLPolicyDAO) UpdatePolicy(policyId int64, http2Enabled bool, minVersion string, certsJSON []byte, hstsJSON []byte, clientAuthType int32, clientCACertsJSON []byte, cipherSuitesIsOn bool, cipherSuites []string) error { +func (this *SSLPolicyDAO) UpdatePolicy(tx *dbs.Tx, policyId int64, http2Enabled bool, minVersion string, certsJSON []byte, hstsJSON []byte, clientAuthType int32, clientCACertsJSON []byte, cipherSuitesIsOn bool, cipherSuites []string) error { if policyId <= 0 { return errors.New("invalid policyId") } @@ -258,16 +258,16 @@ func (this *SSLPolicyDAO) UpdatePolicy(policyId int64, http2Enabled bool, minVer } else { op.CipherSuites = "[]" } - err := this.Save(op) + err := this.Save(tx, op) return err } // 检查是否为用户所属策略 -func (this *SSLPolicyDAO) CheckUserPolicy(policyId int64, userId int64) error { +func (this *SSLPolicyDAO) CheckUserPolicy(tx *dbs.Tx, policyId int64, userId int64) error { if policyId <= 0 || userId <= 0 { return errors.New("not found") } - ok, err := this.Query(). + ok, err := this.Query(tx). State(SSLPolicyStateEnabled). Pk(policyId). Attr("userId", userId). diff --git a/internal/db/models/sub_user_dao.go b/internal/db/models/sub_user_dao.go index ba822c45..74254f0c 100644 --- a/internal/db/models/sub_user_dao.go +++ b/internal/db/models/sub_user_dao.go @@ -33,8 +33,8 @@ func init() { } // 启用条目 -func (this *SubUserDAO) EnableSubUser(id uint32) error { - _, err := this.Query(). +func (this *SubUserDAO) EnableSubUser(tx *dbs.Tx, id uint32) error { + _, err := this.Query(tx). Pk(id). Set("state", SubUserStateEnabled). Update() @@ -42,8 +42,8 @@ func (this *SubUserDAO) EnableSubUser(id uint32) error { } // 禁用条目 -func (this *SubUserDAO) DisableSubUser(id uint32) error { - _, err := this.Query(). +func (this *SubUserDAO) DisableSubUser(tx *dbs.Tx, id uint32) error { + _, err := this.Query(tx). Pk(id). Set("state", SubUserStateDisabled). Update() @@ -51,8 +51,8 @@ func (this *SubUserDAO) DisableSubUser(id uint32) error { } // 查找启用中的条目 -func (this *SubUserDAO) FindEnabledSubUser(id uint32) (*SubUser, error) { - result, err := this.Query(). +func (this *SubUserDAO) FindEnabledSubUser(tx *dbs.Tx, id uint32) (*SubUser, error) { + result, err := this.Query(tx). Pk(id). Attr("state", SubUserStateEnabled). Find() @@ -63,8 +63,8 @@ func (this *SubUserDAO) FindEnabledSubUser(id uint32) (*SubUser, error) { } // 根据主键查找名称 -func (this *SubUserDAO) FindSubUserName(id uint32) (string, error) { - return this.Query(). +func (this *SubUserDAO) FindSubUserName(tx *dbs.Tx, id uint32) (string, error) { + return this.Query(tx). Pk(id). Result("name"). FindStringCol("") diff --git a/internal/db/models/sys_event_dao.go b/internal/db/models/sys_event_dao.go index c9fcc413..045770ab 100644 --- a/internal/db/models/sys_event_dao.go +++ b/internal/db/models/sys_event_dao.go @@ -30,7 +30,7 @@ func init() { } // 创建事件 -func (this *SysEventDAO) CreateEvent(event EventInterface) error { +func (this *SysEventDAO) CreateEvent(tx *dbs.Tx, event EventInterface) error { if event == nil { return errors.New("event should not be nil") } @@ -44,13 +44,13 @@ func (this *SysEventDAO) CreateEvent(event EventInterface) error { } op.Params = eventJSON - err = this.Save(op) + err = this.Save(tx, op) return err } // 查找事件 -func (this *SysEventDAO) FindEvents(size int64) (result []*SysEvent, err error) { - _, err = this.Query(). +func (this *SysEventDAO) FindEvents(tx *dbs.Tx, size int64) (result []*SysEvent, err error) { + _, err = this.Query(tx). Asc(). Limit(size). Slice(&result). @@ -59,8 +59,8 @@ func (this *SysEventDAO) FindEvents(size int64) (result []*SysEvent, err error) } // 删除事件 -func (this *SysEventDAO) DeleteEvent(eventId int64) error { - _, err := this.Query(). +func (this *SysEventDAO) DeleteEvent(tx *dbs.Tx, eventId int64) error { + _, err := this.Query(tx). Pk(eventId). Delete() return err diff --git a/internal/db/models/sys_event_types.go b/internal/db/models/sys_event_types.go index 83635f16..10a21321 100644 --- a/internal/db/models/sys_event_types.go +++ b/internal/db/models/sys_event_types.go @@ -1,6 +1,7 @@ package models import ( + "github.com/iwind/TeaGo/dbs" "reflect" ) @@ -33,12 +34,14 @@ func (this *ServerChangeEvent) Type() string { } func (this *ServerChangeEvent) Run() error { - serverIds, err := SharedServerDAO.FindAllEnabledServerIds() + var tx *dbs.Tx + + serverIds, err := SharedServerDAO.FindAllEnabledServerIds(tx) if err != nil { return err } for _, serverId := range serverIds { - isChanged, err := SharedServerDAO.RenewServerConfig(serverId, true) + isChanged, err := SharedServerDAO.RenewServerConfig(tx, serverId, true) if err != nil { return err } @@ -47,14 +50,14 @@ func (this *ServerChangeEvent) Run() error { } // 检查节点是否需要更新 - isOk, clusterId, err := SharedServerDAO.FindServerNodeFilters(serverId) + isOk, clusterId, err := SharedServerDAO.FindServerNodeFilters(tx, serverId) if err != nil { return err } if !isOk { continue } - err = SharedNodeDAO.IncreaseAllNodesLatestVersionMatch(clusterId) + err = SharedNodeDAO.IncreaseAllNodesLatestVersionMatch(tx, clusterId) if err != nil { return err } diff --git a/internal/db/models/sys_locker_dao.go b/internal/db/models/sys_locker_dao.go index 27c918ab..caaed9f9 100644 --- a/internal/db/models/sys_locker_dao.go +++ b/internal/db/models/sys_locker_dao.go @@ -30,10 +30,10 @@ func init() { } // 开锁 -func (this *SysLockerDAO) Lock(key string, timeout int64) (bool, error) { +func (this *SysLockerDAO) Lock(tx *dbs.Tx, key string, timeout int64) (bool, error) { maxErrors := 5 for { - one, err := this.Query(). + one, err := this.Query(tx). Attr("key", key). Find() if err != nil { @@ -50,7 +50,7 @@ func (this *SysLockerDAO) Lock(key string, timeout int64) (bool, error) { op.Key = key op.TimeoutAt = time.Now().Unix() + timeout op.Version = 1 - err := this.Save(op) + err := this.Save(tx, op) if err != nil { maxErrors-- if maxErrors < 0 { @@ -73,7 +73,7 @@ func (this *SysLockerDAO) Lock(key string, timeout int64) (bool, error) { op.Id = locker.Id op.Version = locker.Version + 1 op.TimeoutAt = time.Now().Unix() + timeout - err = this.Save(op) + err = this.Save(tx, op) if err != nil { maxErrors-- if maxErrors < 0 { @@ -83,7 +83,7 @@ func (this *SysLockerDAO) Lock(key string, timeout int64) (bool, error) { } // 再次查询版本 - version, err := this.Query(). + version, err := this.Query(tx). Attr("key", key). Result("version"). FindCol("0") @@ -103,8 +103,8 @@ func (this *SysLockerDAO) Lock(key string, timeout int64) (bool, error) { } // 解锁 -func (this *SysLockerDAO) Unlock(key string) error { - _, err := this.Query(). +func (this *SysLockerDAO) Unlock(tx *dbs.Tx, key string) error { + _, err := this.Query(tx). Attr("key", key). Set("timeoutAt", time.Now().Unix()-86400*365). Update() diff --git a/internal/db/models/sys_setting_dao.go b/internal/db/models/sys_setting_dao.go index 4af21728..72180e94 100644 --- a/internal/db/models/sys_setting_dao.go +++ b/internal/db/models/sys_setting_dao.go @@ -42,7 +42,7 @@ func init() { } // 设置配置 -func (this *SysSettingDAO) UpdateSetting(codeFormat string, valueJSON []byte, codeFormatArgs ...interface{}) error { +func (this *SysSettingDAO) UpdateSetting(tx *dbs.Tx, codeFormat string, valueJSON []byte, codeFormatArgs ...interface{}) error { if len(codeFormatArgs) > 0 { codeFormat = fmt.Sprintf(codeFormat, codeFormatArgs...) } @@ -50,7 +50,7 @@ func (this *SysSettingDAO) UpdateSetting(codeFormat string, valueJSON []byte, co countRetries := 3 var lastErr error for i := 0; i < countRetries; i++ { - settingId, err := this.Query(). + settingId, err := this.Query(tx). Attr("code", codeFormat). ResultPk(). FindInt64Col(0) @@ -63,7 +63,7 @@ func (this *SysSettingDAO) UpdateSetting(codeFormat string, valueJSON []byte, co op := NewSysSettingOperator() op.Code = codeFormat op.Value = valueJSON - err = this.Save(op) + err = this.Save(tx, op) if err != nil { lastErr = err @@ -77,7 +77,7 @@ func (this *SysSettingDAO) UpdateSetting(codeFormat string, valueJSON []byte, co op := NewSysSettingOperator() op.Id = settingId op.Value = valueJSON - err = this.Save(op) + err = this.Save(tx, op) if err != nil { return err } @@ -87,11 +87,11 @@ func (this *SysSettingDAO) UpdateSetting(codeFormat string, valueJSON []byte, co } // 读取配置 -func (this *SysSettingDAO) ReadSetting(code string, codeFormatArgs ...interface{}) (valueJSON []byte, err error) { +func (this *SysSettingDAO) ReadSetting(tx *dbs.Tx, code string, codeFormatArgs ...interface{}) (valueJSON []byte, err error) { if len(codeFormatArgs) > 0 { code = fmt.Sprintf(code, codeFormatArgs...) } - col, err := this.Query(). + col, err := this.Query(tx). Attr("code", code). Result("value"). FindStringCol("") @@ -99,8 +99,8 @@ func (this *SysSettingDAO) ReadSetting(code string, codeFormatArgs ...interface{ } // 对比配置中的数字大小 -func (this *SysSettingDAO) CompareInt64Setting(code string, anotherValue int64) (int8, error) { - valueJSON, err := this.ReadSetting(code) +func (this *SysSettingDAO) CompareInt64Setting(tx *dbs.Tx, code string, anotherValue int64) (int8, error) { + valueJSON, err := this.ReadSetting(tx, code) if err != nil { return 0, err } @@ -115,8 +115,8 @@ func (this *SysSettingDAO) CompareInt64Setting(code string, anotherValue int64) } // 读取全局配置 -func (this *SysSettingDAO) ReadGlobalConfig() (*serverconfigs.GlobalConfig, error) { - globalConfigData, err := this.ReadSetting(SettingCodeServerGlobalConfig) +func (this *SysSettingDAO) ReadGlobalConfig(tx *dbs.Tx) (*serverconfigs.GlobalConfig, error) { + globalConfigData, err := this.ReadSetting(tx, SettingCodeServerGlobalConfig) if err != nil { return nil, err } diff --git a/internal/db/models/tcp_firewall_policy_dao.go b/internal/db/models/tcp_firewall_policy_dao.go index 8fb56fc0..0d6aa71f 100644 --- a/internal/db/models/tcp_firewall_policy_dao.go +++ b/internal/db/models/tcp_firewall_policy_dao.go @@ -31,12 +31,12 @@ func init() { func (this *TCPFirewallPolicyDAO) Init() { this.DAOObject.Init() this.DAOObject.OnUpdate(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) this.DAOObject.OnInsert(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) this.DAOObject.OnDelete(func() error { - return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) }) } diff --git a/internal/db/models/user_access_key_dao.go b/internal/db/models/user_access_key_dao.go index 7547c669..9ce811d4 100644 --- a/internal/db/models/user_access_key_dao.go +++ b/internal/db/models/user_access_key_dao.go @@ -35,8 +35,8 @@ func init() { } // 启用条目 -func (this *UserAccessKeyDAO) EnableUserAccessKey(id int64) error { - _, err := this.Query(). +func (this *UserAccessKeyDAO) EnableUserAccessKey(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", UserAccessKeyStateEnabled). Update() @@ -44,8 +44,8 @@ func (this *UserAccessKeyDAO) EnableUserAccessKey(id int64) error { } // 禁用条目 -func (this *UserAccessKeyDAO) DisableUserAccessKey(id int64) error { - _, err := this.Query(). +func (this *UserAccessKeyDAO) DisableUserAccessKey(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", UserAccessKeyStateDisabled). Update() @@ -53,8 +53,8 @@ func (this *UserAccessKeyDAO) DisableUserAccessKey(id int64) error { } // 查找启用中的条目 -func (this *UserAccessKeyDAO) FindEnabledUserAccessKey(id int64) (*UserAccessKey, error) { - result, err := this.Query(). +func (this *UserAccessKeyDAO) FindEnabledUserAccessKey(tx *dbs.Tx, id int64) (*UserAccessKey, error) { + result, err := this.Query(tx). Pk(id). Attr("state", UserAccessKeyStateEnabled). Find() @@ -65,7 +65,7 @@ func (this *UserAccessKeyDAO) FindEnabledUserAccessKey(id int64) (*UserAccessKey } // 创建Key -func (this *UserAccessKeyDAO) CreateAccessKey(userId int64, description string) (int64, error) { +func (this *UserAccessKeyDAO) CreateAccessKey(tx *dbs.Tx, userId int64, description string) (int64, error) { if userId <= 0 { return 0, errors.New("invalid userId") } @@ -76,12 +76,12 @@ func (this *UserAccessKeyDAO) CreateAccessKey(userId int64, description string) op.Secret = rands.String(32) op.IsOn = true op.State = UserAccessKeyStateEnabled - return this.SaveInt64(op) + return this.SaveInt64(tx, op) } // 查找用户所有的Key -func (this *UserAccessKeyDAO) FindAllEnabledAccessKeys(userId int64) (result []*UserAccessKey, err error) { - _, err = this.Query(). +func (this *UserAccessKeyDAO) FindAllEnabledAccessKeys(tx *dbs.Tx, userId int64) (result []*UserAccessKey, err error) { + _, err = this.Query(tx). State(UserAccessKeyStateEnabled). DescPk(). Slice(&result). @@ -90,8 +90,8 @@ func (this *UserAccessKeyDAO) FindAllEnabledAccessKeys(userId int64) (result []* } // 检查用户的AccessKey -func (this *UserAccessKeyDAO) CheckUserAccessKey(userId int64, accessKeyId int64) (bool, error) { - return this.Query(). +func (this *UserAccessKeyDAO) CheckUserAccessKey(tx *dbs.Tx, userId int64, accessKeyId int64) (bool, error) { + return this.Query(tx). Pk(accessKeyId). State(UserAccessKeyStateEnabled). Attr("userId", userId). @@ -99,11 +99,11 @@ func (this *UserAccessKeyDAO) CheckUserAccessKey(userId int64, accessKeyId int64 } // 设置是否启用 -func (this *UserAccessKeyDAO) UpdateAccessKeyIsOn(accessKeyId int64, isOn bool) error { +func (this *UserAccessKeyDAO) UpdateAccessKeyIsOn(tx *dbs.Tx, accessKeyId int64, isOn bool) error { if accessKeyId <= 0 { return errors.New("invalid accessKeyId") } - _, err := this.Query(). + _, err := this.Query(tx). Pk(accessKeyId). Set("isOn", isOn). Update() @@ -111,8 +111,8 @@ func (this *UserAccessKeyDAO) UpdateAccessKeyIsOn(accessKeyId int64, isOn bool) } // 根据UniqueId查找AccessKey -func (this *UserAccessKeyDAO) FindAccessKeyWithUniqueId(uniqueId string) (*UserAccessKey, error) { - one, err := this.Query(). +func (this *UserAccessKeyDAO) FindAccessKeyWithUniqueId(tx *dbs.Tx, uniqueId string) (*UserAccessKey, error) { + one, err := this.Query(tx). Attr("uniqueId", uniqueId). Attr("isOn", true). State(UserAccessKeyStateEnabled). diff --git a/internal/db/models/user_bill_dao.go b/internal/db/models/user_bill_dao.go index f0a8b2b4..2d5a46f7 100644 --- a/internal/db/models/user_bill_dao.go +++ b/internal/db/models/user_bill_dao.go @@ -36,8 +36,8 @@ func init() { } // 计算账单数量 -func (this *UserBillDAO) CountAllUserBills(isPaid int32, userId int64, month string) (int64, error) { - query := this.Query() +func (this *UserBillDAO) CountAllUserBills(tx *dbs.Tx, isPaid int32, userId int64, month string) (int64, error) { + query := this.Query(tx) if isPaid == 0 { query.Attr("isPaid", 0) } else if isPaid > 0 { @@ -53,8 +53,8 @@ func (this *UserBillDAO) CountAllUserBills(isPaid int32, userId int64, month str } // 列出单页账单 -func (this *UserBillDAO) ListUserBills(isPaid int32, userId int64, month string, offset, size int64) (result []*UserBill, err error) { - query := this.Query() +func (this *UserBillDAO) ListUserBills(tx *dbs.Tx, isPaid int32, userId int64, month string, offset, size int64) (result []*UserBill, err error) { + query := this.Query(tx) if isPaid == 0 { query.Attr("isPaid", 0) } else if isPaid > 0 { @@ -76,7 +76,7 @@ func (this *UserBillDAO) ListUserBills(isPaid int32, userId int64, month string, } // 创建账单 -func (this *UserBillDAO) CreateBill(userId int64, billType BillType, description string, amount float32, month string) (int64, error) { +func (this *UserBillDAO) CreateBill(tx *dbs.Tx, userId int64, billType BillType, description string, amount float32, month string) (int64, error) { op := NewUserBillOperator() op.UserId = userId op.Type = billType @@ -84,12 +84,12 @@ func (this *UserBillDAO) CreateBill(userId int64, billType BillType, description op.Amount = amount op.Month = month op.IsPaid = false - return this.SaveInt64(op) + return this.SaveInt64(tx, op) } // 检查是否有当月账单 -func (this *UserBillDAO) ExistBill(userId int64, billType BillType, month string) (bool, error) { - return this.Query(). +func (this *UserBillDAO) ExistBill(tx *dbs.Tx, userId int64, billType BillType, month string) (bool, error) { + return this.Query(tx). Attr("userId", userId). Attr("month", month). Attr("type", billType). @@ -98,12 +98,12 @@ func (this *UserBillDAO) ExistBill(userId int64, billType BillType, month string // 生成账单 // month 格式YYYYMM -func (this *UserBillDAO) GenerateBills(month string) error { +func (this *UserBillDAO) GenerateBills(tx *dbs.Tx, month string) error { // 用户 offset := int64(0) size := int64(100) // 每次只查询N次,防止由于执行时间过长而锁表 for { - userIds, err := SharedUserDAO.ListEnabledUserIds(offset, size) + userIds, err := SharedUserDAO.ListEnabledUserIds(tx, offset, size) if err != nil { return err } @@ -114,7 +114,7 @@ func (this *UserBillDAO) GenerateBills(month string) error { for _, userId := range userIds { // CDN流量账单 - err := this.generateTrafficBill(userId, month) + err := this.generateTrafficBill(tx, userId, month) if err != nil { return err } @@ -126,9 +126,9 @@ func (this *UserBillDAO) GenerateBills(month string) error { // 生成CDN流量账单 // month 格式YYYYMM -func (this *UserBillDAO) generateTrafficBill(userId int64, month string) error { +func (this *UserBillDAO) generateTrafficBill(tx *dbs.Tx, userId int64, month string) error { // 检查是否已经有账单了 - b, err := this.ExistBill(userId, BillTypeTraffic, month) + b, err := this.ExistBill(tx, userId, BillTypeTraffic, month) if err != nil { return err } @@ -137,7 +137,7 @@ func (this *UserBillDAO) generateTrafficBill(userId int64, month string) error { } // TODO 优化使用缓存 - regions, err := SharedNodeRegionDAO.FindAllEnabledRegionPrices() + regions, err := SharedNodeRegionDAO.FindAllEnabledRegionPrices(tx) if err != nil { return err } @@ -145,7 +145,7 @@ func (this *UserBillDAO) generateTrafficBill(userId int64, month string) error { return nil } - priceItems, err := SharedNodePriceItemDAO.FindAllEnabledRegionPrices(NodePriceTypeTraffic) + priceItems, err := SharedNodePriceItemDAO.FindAllEnabledRegionPrices(tx, NodePriceTypeTraffic) if err != nil { return err } @@ -164,7 +164,7 @@ func (this *UserBillDAO) generateTrafficBill(userId int64, month string) error { return err } - trafficBytes, err := SharedServerDailyStatDAO.SumUserMonthly(userId, int64(region.Id), month) + trafficBytes, err := SharedServerDailyStatDAO.SumUserMonthly(tx, userId, int64(region.Id), month) if err != nil { return err } @@ -192,7 +192,7 @@ func (this *UserBillDAO) generateTrafficBill(userId int64, month string) error { } // 创建账单 - _, err = this.CreateBill(userId, BillTypeTraffic, "按流量计费", cost, month) + _, err = this.CreateBill(tx, userId, BillTypeTraffic, "按流量计费", cost, month) return err } diff --git a/internal/db/models/user_dao.go b/internal/db/models/user_dao.go index c7c180b5..02df7da4 100644 --- a/internal/db/models/user_dao.go +++ b/internal/db/models/user_dao.go @@ -37,24 +37,24 @@ func init() { } // 启用条目 -func (this *UserDAO) EnableUser(id int64) (rowsAffected int64, err error) { - return this.Query(). +func (this *UserDAO) EnableUser(tx *dbs.Tx, id int64) (rowsAffected int64, err error) { + return this.Query(tx). Pk(id). Set("state", UserStateEnabled). Update() } // 禁用条目 -func (this *UserDAO) DisableUser(id int64) (rowsAffected int64, err error) { - return this.Query(). +func (this *UserDAO) DisableUser(tx *dbs.Tx, id int64) (rowsAffected int64, err error) { + return this.Query(tx). Pk(id). Set("state", UserStateDisabled). Update() } // 查找启用中的条目 -func (this *UserDAO) FindEnabledUser(id int64) (*User, error) { - result, err := this.Query(). +func (this *UserDAO) FindEnabledUser(tx *dbs.Tx, id int64) (*User, error) { + result, err := this.Query(tx). Pk(id). Attr("state", UserStateEnabled). Find() @@ -65,8 +65,8 @@ func (this *UserDAO) FindEnabledUser(id int64) (*User, error) { } // 查找用户基本信息 -func (this *UserDAO) FindEnabledBasicUser(id int64) (*User, error) { - result, err := this.Query(). +func (this *UserDAO) FindEnabledBasicUser(tx *dbs.Tx, id int64) (*User, error) { + result, err := this.Query(tx). Pk(id). Attr("state", UserStateEnabled). Result("id", "fullname", "username"). @@ -78,15 +78,15 @@ func (this *UserDAO) FindEnabledBasicUser(id int64) (*User, error) { } // 获取管理员名称 -func (this *UserDAO) FindUserFullname(userId int64) (string, error) { - return this.Query(). +func (this *UserDAO) FindUserFullname(tx *dbs.Tx, userId int64) (string, error) { + return this.Query(tx). Pk(userId). Result("fullname"). FindStringCol("") } // 创建用户 -func (this *UserDAO) CreateUser(username string, password string, fullname string, mobile string, tel string, email string, remark string, source string, clusterId int64) (int64, error) { +func (this *UserDAO) CreateUser(tx *dbs.Tx, username string, password string, fullname string, mobile string, tel string, email string, remark string, source string, clusterId int64) (int64, error) { op := NewUserOperator() op.Username = username op.Password = stringutil.Md5(password) @@ -100,7 +100,7 @@ func (this *UserDAO) CreateUser(username string, password string, fullname strin op.IsOn = true op.State = UserStateEnabled - err := this.Save(op) + err := this.Save(tx, op) if err != nil { return 0, err } @@ -108,7 +108,7 @@ func (this *UserDAO) CreateUser(username string, password string, fullname strin } // 修改用户 -func (this *UserDAO) UpdateUser(userId int64, username string, password string, fullname string, mobile string, tel string, email string, remark string, isOn bool, clusterId int64) error { +func (this *UserDAO) UpdateUser(tx *dbs.Tx, userId int64, username string, password string, fullname string, mobile string, tel string, email string, remark string, isOn bool, clusterId int64) error { if userId <= 0 { return errors.New("invalid userId") } @@ -125,23 +125,23 @@ func (this *UserDAO) UpdateUser(userId int64, username string, password string, op.Remark = remark op.IsOn = isOn op.ClusterId = clusterId - err := this.Save(op) + err := this.Save(tx, op) return err } // 修改用户基本信息 -func (this *UserDAO) UpdateUserInfo(userId int64, fullname string) error { +func (this *UserDAO) UpdateUserInfo(tx *dbs.Tx, userId int64, fullname string) error { if userId <= 0 { return errors.New("invalid userId") } op := NewUserOperator() op.Id = userId op.Fullname = fullname - return this.Save(op) + return this.Save(tx, op) } // 修改用户登录信息 -func (this *UserDAO) UpdateUserLogin(userId int64, username string, password string) error { +func (this *UserDAO) UpdateUserLogin(tx *dbs.Tx, userId int64, username string, password string) error { if userId <= 0 { return errors.New("invalid userId") } @@ -151,13 +151,13 @@ func (this *UserDAO) UpdateUserLogin(userId int64, username string, password str if len(password) > 0 { op.Password = stringutil.Md5(password) } - err := this.Save(op) + err := this.Save(tx, op) return err } // 计算用户数量 -func (this *UserDAO) CountAllEnabledUsers(keyword string) (int64, error) { - query := this.Query() +func (this *UserDAO) CountAllEnabledUsers(tx *dbs.Tx, keyword string) (int64, error) { + query := this.Query(tx) query.State(UserStateEnabled) if len(keyword) > 0 { query.Where("(username LIKE :keyword OR fullname LIKE :keyword OR mobile LIKE :keyword OR email LIKE :keyword OR tel LIKE :keyword OR remark LIKE :keyword)"). @@ -167,8 +167,8 @@ func (this *UserDAO) CountAllEnabledUsers(keyword string) (int64, error) { } // 列出单页用户 -func (this *UserDAO) ListEnabledUsers(keyword string) (result []*User, err error) { - query := this.Query() +func (this *UserDAO) ListEnabledUsers(tx *dbs.Tx, keyword string) (result []*User, err error) { + query := this.Query(tx) query.State(UserStateEnabled) if len(keyword) > 0 { query.Where("(username LIKE :keyword OR fullname LIKE :keyword OR mobile LIKE :keyword OR email LIKE :keyword OR tel LIKE :keyword OR remark LIKE :keyword)"). @@ -182,8 +182,8 @@ func (this *UserDAO) ListEnabledUsers(keyword string) (result []*User, err error } // 检查用户名是否存在 -func (this *UserDAO) ExistUser(userId int64, username string) (bool, error) { - return this.Query(). +func (this *UserDAO) ExistUser(tx *dbs.Tx, userId int64, username string) (bool, error) { + return this.Query(tx). State(UserStateEnabled). Attr("username", username). Neq("id", userId). @@ -191,8 +191,8 @@ func (this *UserDAO) ExistUser(userId int64, username string) (bool, error) { } // 列出单页的用户ID -func (this *UserDAO) ListEnabledUserIds(offset, size int64) ([]int64, error) { - ones, _, err := this.Query(). +func (this *UserDAO) ListEnabledUserIds(tx *dbs.Tx, offset, size int64) ([]int64, error) { + ones, _, err := this.Query(tx). ResultPk(). State(UserStateEnabled). Offset(offset). @@ -210,11 +210,11 @@ func (this *UserDAO) ListEnabledUserIds(offset, size int64) ([]int64, error) { } // 检查用户名、密码 -func (this *UserDAO) CheckUserPassword(username string, encryptedPassword string) (int64, error) { +func (this *UserDAO) CheckUserPassword(tx *dbs.Tx, username string, encryptedPassword string) (int64, error) { if len(username) == 0 || len(encryptedPassword) == 0 { return 0, nil } - return this.Query(). + return this.Query(tx). Attr("username", username). Attr("password", encryptedPassword). Attr("state", UserStateEnabled). @@ -224,22 +224,22 @@ func (this *UserDAO) CheckUserPassword(username string, encryptedPassword string } // 查找用户所在集群 -func (this *UserDAO) FindUserClusterId(userId int64) (int64, error) { - return this.Query(). +func (this *UserDAO) FindUserClusterId(tx *dbs.Tx, userId int64) (int64, error) { + return this.Query(tx). Pk(userId). Result("clusterId"). FindInt64Col(0) } // 更新用户Features -func (this *UserDAO) UpdateUserFeatures(userId int64, featuresJSON []byte) error { +func (this *UserDAO) UpdateUserFeatures(tx *dbs.Tx, userId int64, featuresJSON []byte) error { if userId <= 0 { return errors.New("invalid userId") } if len(featuresJSON) == 0 { featuresJSON = []byte("[]") } - _, err := this.Query(). + _, err := this.Query(tx). Pk(userId). Set("features", featuresJSON). Update() @@ -250,8 +250,8 @@ func (this *UserDAO) UpdateUserFeatures(userId int64, featuresJSON []byte) error } // 查找用户Features -func (this *UserDAO) FindUserFeatures(userId int64) ([]*UserFeature, error) { - featuresJSON, err := this.Query(). +func (this *UserDAO) FindUserFeatures(tx *dbs.Tx, userId int64) ([]*UserFeature, error) { + featuresJSON, err := this.Query(tx). Pk(userId). Result("features"). FindStringCol("") diff --git a/internal/db/models/user_node_dao.go b/internal/db/models/user_node_dao.go index 708eb90e..56dc1b39 100644 --- a/internal/db/models/user_node_dao.go +++ b/internal/db/models/user_node_dao.go @@ -39,8 +39,8 @@ func init() { } // 启用条目 -func (this *UserNodeDAO) EnableUserNode(id uint32) error { - _, err := this.Query(). +func (this *UserNodeDAO) EnableUserNode(tx *dbs.Tx, id uint32) error { + _, err := this.Query(tx). Pk(id). Set("state", UserNodeStateEnabled). Update() @@ -48,8 +48,8 @@ func (this *UserNodeDAO) EnableUserNode(id uint32) error { } // 禁用条目 -func (this *UserNodeDAO) DisableUserNode(id int64) error { - _, err := this.Query(). +func (this *UserNodeDAO) DisableUserNode(tx *dbs.Tx, id int64) error { + _, err := this.Query(tx). Pk(id). Set("state", UserNodeStateDisabled). Update() @@ -57,8 +57,8 @@ func (this *UserNodeDAO) DisableUserNode(id int64) error { } // 查找启用中的条目 -func (this *UserNodeDAO) FindEnabledUserNode(id int64) (*UserNode, error) { - result, err := this.Query(). +func (this *UserNodeDAO) FindEnabledUserNode(tx *dbs.Tx, id int64) (*UserNode, error) { + result, err := this.Query(tx). Pk(id). Attr("state", UserNodeStateEnabled). Find() @@ -69,16 +69,16 @@ func (this *UserNodeDAO) FindEnabledUserNode(id int64) (*UserNode, error) { } // 根据主键查找名称 -func (this *UserNodeDAO) FindUserNodeName(id int64) (string, error) { - return this.Query(). +func (this *UserNodeDAO) FindUserNodeName(tx *dbs.Tx, id int64) (string, error) { + return this.Query(tx). Pk(id). Result("name"). FindStringCol("") } // 列出所有可用用户节点 -func (this *UserNodeDAO) FindAllEnabledUserNodes() (result []*UserNode, err error) { - _, err = this.Query(). +func (this *UserNodeDAO) FindAllEnabledUserNodes(tx *dbs.Tx) (result []*UserNode, err error) { + _, err = this.Query(tx). State(UserNodeStateEnabled). Desc("order"). AscPk(). @@ -88,15 +88,15 @@ func (this *UserNodeDAO) FindAllEnabledUserNodes() (result []*UserNode, err erro } // 计算用户节点数量 -func (this *UserNodeDAO) CountAllEnabledUserNodes() (int64, error) { - return this.Query(). +func (this *UserNodeDAO) CountAllEnabledUserNodes(tx *dbs.Tx) (int64, error) { + return this.Query(tx). State(UserNodeStateEnabled). Count() } // 列出单页的用户节点 -func (this *UserNodeDAO) ListEnabledUserNodes(offset int64, size int64) (result []*UserNode, err error) { - _, err = this.Query(). +func (this *UserNodeDAO) ListEnabledUserNodes(tx *dbs.Tx, offset int64, size int64) (result []*UserNode, err error) { + _, err = this.Query(tx). State(UserNodeStateEnabled). Offset(offset). Limit(size). @@ -108,7 +108,7 @@ func (this *UserNodeDAO) ListEnabledUserNodes(offset int64, size int64) (result } // 根据主机名和端口获取ID -func (this *UserNodeDAO) FindEnabledUserNodeIdWithAddr(protocol string, host string, port int) (int64, error) { +func (this *UserNodeDAO) FindEnabledUserNodeIdWithAddr(tx *dbs.Tx, protocol string, host string, port int) (int64, error) { addr := maps.Map{ "protocol": protocol, "host": host, @@ -119,7 +119,7 @@ func (this *UserNodeDAO) FindEnabledUserNodeIdWithAddr(protocol string, host str return 0, err } - one, err := this.Query(). + one, err := this.Query(tx). State(UserNodeStateEnabled). Where("JSON_CONTAINS(accessAddrs, :addr)"). Param("addr", string(addrJSON)). @@ -135,13 +135,13 @@ func (this *UserNodeDAO) FindEnabledUserNodeIdWithAddr(protocol string, host str } // 创建用户节点 -func (this *UserNodeDAO) CreateUserNode(name string, description string, httpJSON []byte, httpsJSON []byte, accessAddrsJSON []byte, isOn bool) (nodeId int64, err error) { - uniqueId, err := this.genUniqueId() +func (this *UserNodeDAO) CreateUserNode(tx *dbs.Tx, name string, description string, httpJSON []byte, httpsJSON []byte, accessAddrsJSON []byte, isOn bool) (nodeId int64, err error) { + uniqueId, err := this.genUniqueId(tx) if err != nil { return 0, err } secret := rands.String(32) - err = NewApiTokenDAO().CreateAPIToken(uniqueId, secret, NodeRoleUser) + err = NewApiTokenDAO().CreateAPIToken(tx, uniqueId, secret, NodeRoleUser) if err != nil { return } @@ -164,7 +164,7 @@ func (this *UserNodeDAO) CreateUserNode(name string, description string, httpJSO } op.State = NodeStateEnabled - err = this.Save(op) + err = this.Save(tx, op) if err != nil { return } @@ -173,7 +173,7 @@ func (this *UserNodeDAO) CreateUserNode(name string, description string, httpJSO } // 修改用户节点 -func (this *UserNodeDAO) UpdateUserNode(nodeId int64, name string, description string, httpJSON []byte, httpsJSON []byte, accessAddrsJSON []byte, isOn bool) error { +func (this *UserNodeDAO) UpdateUserNode(tx *dbs.Tx, nodeId int64, name string, description string, httpJSON []byte, httpsJSON []byte, accessAddrsJSON []byte, isOn bool) error { if nodeId <= 0 { return errors.New("invalid nodeId") } @@ -200,13 +200,13 @@ func (this *UserNodeDAO) UpdateUserNode(nodeId int64, name string, description s op.AccessAddrs = "null" } - err := this.Save(op) + err := this.Save(tx, op) return err } // 根据唯一ID获取节点信息 -func (this *UserNodeDAO) FindEnabledUserNodeWithUniqueId(uniqueId string) (*UserNode, error) { - result, err := this.Query(). +func (this *UserNodeDAO) FindEnabledUserNodeWithUniqueId(tx *dbs.Tx, uniqueId string) (*UserNode, error) { + result, err := this.Query(tx). Attr("uniqueId", uniqueId). Attr("state", UserNodeStateEnabled). Find() @@ -217,10 +217,10 @@ func (this *UserNodeDAO) FindEnabledUserNodeWithUniqueId(uniqueId string) (*User } // 生成唯一ID -func (this *UserNodeDAO) genUniqueId() (string, error) { +func (this *UserNodeDAO) genUniqueId(tx *dbs.Tx) (string, error) { for { uniqueId := rands.HexString(32) - ok, err := this.Query(). + ok, err := this.Query(tx). Attr("uniqueId", uniqueId). Exist() if err != nil { diff --git a/internal/db/models/user_node_model_ext.go b/internal/db/models/user_node_model_ext.go index 81ecb2aa..fec93bcc 100644 --- a/internal/db/models/user_node_model_ext.go +++ b/internal/db/models/user_node_model_ext.go @@ -43,7 +43,7 @@ func (this *UserNode) DecodeHTTPS() (*serverconfigs.HTTPSProtocolConfig, error) if config.SSLPolicyRef != nil { policyId := config.SSLPolicyRef.SSLPolicyId if policyId > 0 { - sslPolicy, err := SharedSSLPolicyDAO.ComposePolicyConfig(policyId) + sslPolicy, err := SharedSSLPolicyDAO.ComposePolicyConfig(nil, policyId) if err != nil { return nil, err } diff --git a/internal/db/models/utils.go b/internal/db/models/utils.go index 6e00330b..bb4431f0 100644 --- a/internal/db/models/utils.go +++ b/internal/db/models/utils.go @@ -16,8 +16,8 @@ func IsNotNull(data string) bool { } // 构造Query -func NewQuery(dao dbs.DAOWrapper, adminId int64, userId int64) *dbs.Query { - query := dao.Object().Query() +func NewQuery(tx *dbs.Tx, dao dbs.DAOWrapper, adminId int64, userId int64) *dbs.Query { + query := dao.Object().Query(tx) if adminId > 0 { query.Attr("adminId", adminId) } diff --git a/internal/installers/queue.go b/internal/installers/queue.go index 17a7b322..353a3196 100644 --- a/internal/installers/queue.go +++ b/internal/installers/queue.go @@ -29,7 +29,7 @@ func (this *Queue) InstallNodeProcess(nodeId int64, isUpgrading bool) error { installStatus.IsRunning = true installStatus.UpdatedAt = time.Now().Unix() - err := models.SharedNodeDAO.UpdateNodeInstallStatus(nodeId, installStatus) + err := models.SharedNodeDAO.UpdateNodeInstallStatus(nil, nodeId, installStatus) if err != nil { return err } @@ -39,7 +39,7 @@ func (this *Queue) InstallNodeProcess(nodeId int64, isUpgrading bool) error { go func() { for ticker.Wait() { installStatus.UpdatedAt = time.Now().Unix() - err := models.SharedNodeDAO.UpdateNodeInstallStatus(nodeId, installStatus) + err := models.SharedNodeDAO.UpdateNodeInstallStatus(nil, nodeId, installStatus) if err != nil { logs.Println("[INSTALL]" + err.Error()) continue @@ -61,14 +61,14 @@ func (this *Queue) InstallNodeProcess(nodeId int64, isUpgrading bool) error { } else { installStatus.IsOk = true } - err = models.SharedNodeDAO.UpdateNodeInstallStatus(nodeId, installStatus) + err = models.SharedNodeDAO.UpdateNodeInstallStatus(nil, nodeId, installStatus) if err != nil { return err } // 修改为已安装 if installStatus.IsOk { - err = models.SharedNodeDAO.UpdateNodeIsInstalled(nodeId, true) + err = models.SharedNodeDAO.UpdateNodeIsInstalled(nil, nodeId, true) if err != nil { return err } @@ -79,7 +79,7 @@ func (this *Queue) InstallNodeProcess(nodeId int64, isUpgrading bool) error { // 安装边缘节点 func (this *Queue) InstallNode(nodeId int64, installStatus *models.NodeInstallStatus, isUpgrading bool) error { - node, err := models.SharedNodeDAO.FindEnabledNode(nodeId) + node, err := models.SharedNodeDAO.FindEnabledNode(nil, nodeId) if err != nil { return err } @@ -88,7 +88,7 @@ func (this *Queue) InstallNode(nodeId int64, installStatus *models.NodeInstallSt } // 登录信息 - login, err := models.SharedNodeLoginDAO.FindEnabledNodeLoginWithNodeId(nodeId) + login, err := models.SharedNodeLoginDAO.FindEnabledNodeLoginWithNodeId(nil, nodeId) if err != nil { return err } @@ -113,7 +113,7 @@ func (this *Queue) InstallNode(nodeId int64, installStatus *models.NodeInstallSt if loginParams.GrantId == 0 { // 从集群中读取 - grantId, err := models.SharedNodeClusterDAO.FindClusterGrantId(int64(node.ClusterId)) + grantId, err := models.SharedNodeClusterDAO.FindClusterGrantId(nil, int64(node.ClusterId)) if err != nil { return err } @@ -123,7 +123,7 @@ func (this *Queue) InstallNode(nodeId int64, installStatus *models.NodeInstallSt } loginParams.GrantId = grantId } - grant, err := models.SharedNodeGrantDAO.FindEnabledNodeGrant(loginParams.GrantId) + grant, err := models.SharedNodeGrantDAO.FindEnabledNodeGrant(nil, loginParams.GrantId) if err != nil { return err } @@ -136,7 +136,7 @@ func (this *Queue) InstallNode(nodeId int64, installStatus *models.NodeInstallSt installDir := node.InstallDir if len(installDir) == 0 { clusterId := node.ClusterId - cluster, err := models.SharedNodeClusterDAO.FindEnabledNodeCluster(int64(clusterId)) + cluster, err := models.SharedNodeClusterDAO.FindEnabledNodeCluster(nil, int64(clusterId)) if err != nil { return err } @@ -151,7 +151,7 @@ func (this *Queue) InstallNode(nodeId int64, installStatus *models.NodeInstallSt } // API终端 - apiNodes, err := models.SharedAPINodeDAO.FindAllEnabledAndOnAPINodes() + apiNodes, err := models.SharedAPINodeDAO.FindAllEnabledAndOnAPINodes(nil) if err != nil { return err } @@ -198,7 +198,7 @@ func (this *Queue) InstallNode(nodeId int64, installStatus *models.NodeInstallSt // 启动边缘节点 func (this *Queue) StartNode(nodeId int64) error { - node, err := models.SharedNodeDAO.FindEnabledNode(nodeId) + node, err := models.SharedNodeDAO.FindEnabledNode(nil, nodeId) if err != nil { return err } @@ -207,7 +207,7 @@ func (this *Queue) StartNode(nodeId int64) error { } // 登录信息 - login, err := models.SharedNodeLoginDAO.FindEnabledNodeLoginWithNodeId(nodeId) + login, err := models.SharedNodeLoginDAO.FindEnabledNodeLoginWithNodeId(nil, nodeId) if err != nil { return err } @@ -229,7 +229,7 @@ func (this *Queue) StartNode(nodeId int64) error { if loginParams.GrantId == 0 { // 从集群中读取 - grantId, err := models.SharedNodeClusterDAO.FindClusterGrantId(int64(node.ClusterId)) + grantId, err := models.SharedNodeClusterDAO.FindClusterGrantId(nil, int64(node.ClusterId)) if err != nil { return err } @@ -238,7 +238,7 @@ func (this *Queue) StartNode(nodeId int64) error { } loginParams.GrantId = grantId } - grant, err := models.SharedNodeGrantDAO.FindEnabledNodeGrant(loginParams.GrantId) + grant, err := models.SharedNodeGrantDAO.FindEnabledNodeGrant(nil, loginParams.GrantId) if err != nil { return err } @@ -250,7 +250,7 @@ func (this *Queue) StartNode(nodeId int64) error { installDir := node.InstallDir if len(installDir) == 0 { clusterId := node.ClusterId - cluster, err := models.SharedNodeClusterDAO.FindEnabledNodeCluster(int64(clusterId)) + cluster, err := models.SharedNodeClusterDAO.FindEnabledNodeCluster(nil, int64(clusterId)) if err != nil { return err } @@ -299,7 +299,7 @@ func (this *Queue) StartNode(nodeId int64) error { // 停止节点 func (this *Queue) StopNode(nodeId int64) error { - node, err := models.SharedNodeDAO.FindEnabledNode(nodeId) + node, err := models.SharedNodeDAO.FindEnabledNode(nil, nodeId) if err != nil { return err } @@ -308,7 +308,7 @@ func (this *Queue) StopNode(nodeId int64) error { } // 登录信息 - login, err := models.SharedNodeLoginDAO.FindEnabledNodeLoginWithNodeId(nodeId) + login, err := models.SharedNodeLoginDAO.FindEnabledNodeLoginWithNodeId(nil, nodeId) if err != nil { return err } @@ -330,7 +330,7 @@ func (this *Queue) StopNode(nodeId int64) error { if loginParams.GrantId == 0 { // 从集群中读取 - grantId, err := models.SharedNodeClusterDAO.FindClusterGrantId(int64(node.ClusterId)) + grantId, err := models.SharedNodeClusterDAO.FindClusterGrantId(nil, int64(node.ClusterId)) if err != nil { return err } @@ -339,7 +339,7 @@ func (this *Queue) StopNode(nodeId int64) error { } loginParams.GrantId = grantId } - grant, err := models.SharedNodeGrantDAO.FindEnabledNodeGrant(loginParams.GrantId) + grant, err := models.SharedNodeGrantDAO.FindEnabledNodeGrant(nil, loginParams.GrantId) if err != nil { return err } @@ -351,7 +351,7 @@ func (this *Queue) StopNode(nodeId int64) error { installDir := node.InstallDir if len(installDir) == 0 { clusterId := node.ClusterId - cluster, err := models.SharedNodeClusterDAO.FindEnabledNodeCluster(int64(clusterId)) + cluster, err := models.SharedNodeClusterDAO.FindEnabledNodeCluster(nil, int64(clusterId)) if err != nil { return err } diff --git a/internal/iplibrary/manager.go b/internal/iplibrary/manager.go index 39a732da..5c0bf930 100644 --- a/internal/iplibrary/manager.go +++ b/internal/iplibrary/manager.go @@ -38,7 +38,7 @@ func NewManager() *Manager { func (this *Manager) Load() (LibraryInterface, error) { // 当前正在使用的IP库代号 - config, err := models.SharedSysSettingDAO.ReadGlobalConfig() + config, err := models.SharedSysSettingDAO.ReadGlobalConfig(nil) if err != nil { return nil, err } diff --git a/internal/iplibrary/updater.go b/internal/iplibrary/updater.go index 0b593564..88cf16a3 100644 --- a/internal/iplibrary/updater.go +++ b/internal/iplibrary/updater.go @@ -44,7 +44,7 @@ func (this *Updater) Start() { // 单次任务 func (this *Updater) loop() error { - config, err := models.SharedSysSettingDAO.ReadGlobalConfig() + config, err := models.SharedSysSettingDAO.ReadGlobalConfig(nil) if err != nil { return err } @@ -52,7 +52,7 @@ func (this *Updater) loop() error { if len(code) == 0 { code = serverconfigs.DefaultIPLibraryType } - lib, err := models.SharedIPLibraryDAO.FindLatestIPLibraryWithType(code) + lib, err := models.SharedIPLibraryDAO.FindLatestIPLibraryWithType(nil, code) if err != nil { return err } @@ -74,7 +74,7 @@ func (this *Updater) loop() error { } // 开始下载 - chunkIds, err := models.SharedFileChunkDAO.FindAllFileChunkIds(int64(lib.FileId)) + chunkIds, err := models.SharedFileChunkDAO.FindAllFileChunkIds(nil, int64(lib.FileId)) if err != nil { return err } @@ -93,7 +93,7 @@ func (this *Updater) loop() error { } }() for _, chunkId := range chunkIds { - chunk, err := models.SharedFileChunkDAO.FindFileChunk(chunkId) + chunk, err := models.SharedFileChunkDAO.FindFileChunk(nil, chunkId) if err != nil { return err } diff --git a/internal/nodes/api_node.go b/internal/nodes/api_node.go index 33ed3891..e47b00cc 100644 --- a/internal/nodes/api_node.go +++ b/internal/nodes/api_node.go @@ -55,7 +55,7 @@ func (this *APINode) Start() { sharedAPIConfig = config // 校验 - apiNode, err := models.SharedAPINodeDAO.FindEnabledAPINodeWithUniqueIdAndSecret(config.NodeId, config.Secret) + apiNode, err := models.SharedAPINodeDAO.FindEnabledAPINodeWithUniqueIdAndSecret(nil, config.NodeId, config.Secret) if err != nil { logs.Println("[API_NODE]start failed: read api node from database failed: " + err.Error()) return @@ -230,7 +230,7 @@ func (this *APINode) listenPorts(apiNode *models.APINode) (isListening bool) { } // HTTPS - httpsConfig, err := apiNode.DecodeHTTPS() + httpsConfig, err := apiNode.DecodeHTTPS(nil) if err != nil { remotelogs.Error("API_NODE", "decode https config: "+err.Error()) return @@ -296,7 +296,7 @@ func (this *APINode) listenPorts(apiNode *models.APINode) (isListening bool) { } // Rest HTTPS - restHTTPSConfig, err := apiNode.DecodeRestHTTPS() + restHTTPSConfig, err := apiNode.DecodeRestHTTPS(nil) if err != nil { remotelogs.Error("API_NODE", "decode REST https config: "+err.Error()) return diff --git a/internal/nodes/node_status_executor.go b/internal/nodes/node_status_executor.go index e7c2d875..32f79999 100644 --- a/internal/nodes/node_status_executor.go +++ b/internal/nodes/node_status_executor.go @@ -75,7 +75,7 @@ func (this *NodeStatusExecutor) update() { remotelogs.Error("NODE_STATUS", "serial NodeStatus fail: "+err.Error()) return } - err = models.SharedAPINodeDAO.UpdateAPINodeStatus(sharedAPIConfig.NumberId(), jsonData) + err = models.SharedAPINodeDAO.UpdateAPINodeStatus(nil, sharedAPIConfig.NumberId(), jsonData) if err != nil { remotelogs.Error("NODE_STATUS", "rpc UpdateNodeStatus() failed: "+err.Error()) return diff --git a/internal/nodes/rest_server.go b/internal/nodes/rest_server.go index 520972ed..760689d0 100644 --- a/internal/nodes/rest_server.go +++ b/internal/nodes/rest_server.go @@ -90,7 +90,7 @@ func (this *RestServer) handle(writer http.ResponseWriter, req *http.Request) { return } - accessToken, err := models.SharedAPIAccessTokenDAO.FindAccessToken(token) + accessToken, err := models.SharedAPIAccessTokenDAO.FindAccessToken(nil, token) if err != nil { this.writeJSON(writer, maps.Map{ "code": 400, diff --git a/internal/remotelogs/utils.go b/internal/remotelogs/utils.go index 96216266..c7a3e96c 100644 --- a/internal/remotelogs/utils.go +++ b/internal/remotelogs/utils.go @@ -99,7 +99,7 @@ Loop: for { select { case log := <-logChan: - err := models.SharedNodeLogDAO.CreateLog(models.NodeRoleAPI, log.NodeId, log.Level, log.Tag, log.Description, log.CreatedAt) + err := models.SharedNodeLogDAO.CreateLog(nil, models.NodeRoleAPI, log.NodeId, log.Level, log.Tag, log.Description, log.CreatedAt) if err != nil { return err } diff --git a/internal/rpc/services/service_acme_authentication.go b/internal/rpc/services/service_acme_authentication.go index a3b1805a..3cfa9b7c 100644 --- a/internal/rpc/services/service_acme_authentication.go +++ b/internal/rpc/services/service_acme_authentication.go @@ -22,7 +22,9 @@ func (this *ACMEAuthenticationService) FindACMEAuthenticationKeyWithToken(ctx co return nil, errors.New("'token' should not be empty") } - auth, err := models.SharedACMEAuthenticationDAO.FindAuthWithToken(req.Token) + tx := this.NullTx() + + auth, err := models.SharedACMEAuthenticationDAO.FindAuthWithToken(tx, req.Token) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_acme_task.go b/internal/rpc/services/service_acme_task.go index 7bc48a9b..ae8f1d54 100644 --- a/internal/rpc/services/service_acme_task.go +++ b/internal/rpc/services/service_acme_task.go @@ -22,7 +22,9 @@ func (this *ACMETaskService) CountAllEnabledACMETasksWithACMEUserId(ctx context. // TODO 校验权限 - count, err := models.SharedACMETaskDAO.CountACMETasksWithACMEUserId(req.AcmeUserId) + tx := this.NullTx() + + count, err := models.SharedACMETaskDAO.CountACMETasksWithACMEUserId(tx, req.AcmeUserId) if err != nil { return nil, err } @@ -38,7 +40,9 @@ func (this *ACMETaskService) CountEnabledACMETasksWithDNSProviderId(ctx context. // TODO 校验权限 - count, err := models.SharedACMETaskDAO.CountACMETasksWithDNSProviderId(req.DnsProviderId) + tx := this.NullTx() + + count, err := models.SharedACMETaskDAO.CountACMETasksWithDNSProviderId(tx, req.DnsProviderId) if err != nil { return nil, err } @@ -52,7 +56,9 @@ func (this *ACMETaskService) CountAllEnabledACMETasks(ctx context.Context, req * return nil, err } - count, err := models.SharedACMETaskDAO.CountAllEnabledACMETasks(req.AdminId, req.UserId) + tx := this.NullTx() + + count, err := models.SharedACMETaskDAO.CountAllEnabledACMETasks(tx, req.AdminId, req.UserId) if err != nil { return nil, err } @@ -66,7 +72,9 @@ func (this *ACMETaskService) ListEnabledACMETasks(ctx context.Context, req *pb.L return nil, err } - tasks, err := models.SharedACMETaskDAO.ListEnabledACMETasks(req.AdminId, req.UserId, req.Offset, req.Size) + tx := this.NullTx() + + tasks, err := models.SharedACMETaskDAO.ListEnabledACMETasks(tx, req.AdminId, req.UserId, req.Offset, req.Size) if err != nil { return nil, err } @@ -74,7 +82,7 @@ func (this *ACMETaskService) ListEnabledACMETasks(ctx context.Context, req *pb.L result := []*pb.ACMETask{} for _, task := range tasks { // ACME用户 - acmeUser, err := models.SharedACMEUserDAO.FindEnabledACMEUser(int64(task.AcmeUserId)) + acmeUser, err := models.SharedACMEUserDAO.FindEnabledACMEUser(tx, int64(task.AcmeUserId)) if err != nil { return nil, err } @@ -91,7 +99,7 @@ func (this *ACMETaskService) ListEnabledACMETasks(ctx context.Context, req *pb.L var pbProvider *pb.DNSProvider if task.AuthType == acme.AuthTypeDNS { // DNS - provider, err := models.SharedDNSProviderDAO.FindEnabledDNSProvider(int64(task.DnsProviderId)) + provider, err := models.SharedDNSProviderDAO.FindEnabledDNSProvider(tx, int64(task.DnsProviderId)) if err != nil { return nil, err } @@ -109,7 +117,7 @@ func (this *ACMETaskService) ListEnabledACMETasks(ctx context.Context, req *pb.L // 证书 var pbCert *pb.SSLCert = nil if task.CertId > 0 { - cert, err := models.SharedSSLCertDAO.FindEnabledSSLCert(int64(task.CertId)) + cert, err := models.SharedSSLCertDAO.FindEnabledSSLCert(tx, int64(task.CertId)) if err != nil { return nil, err } @@ -127,7 +135,7 @@ func (this *ACMETaskService) ListEnabledACMETasks(ctx context.Context, req *pb.L // 最近一条日志 var pbTaskLog *pb.ACMETaskLog = nil - taskLog, err := models.SharedACMETaskLogDAO.FindLatestACMETasKLog(int64(task.Id)) + taskLog, err := models.SharedACMETaskLogDAO.FindLatestACMETasKLog(tx, int64(task.Id)) if err != nil { return nil, err } @@ -169,7 +177,8 @@ func (this *ACMETaskService) CreateACMETask(ctx context.Context, req *pb.CreateA req.AuthType = acme.AuthTypeDNS } - taskId, err := models.SharedACMETaskDAO.CreateACMETask(adminId, userId, req.AuthType, req.AcmeUserId, req.DnsProviderId, req.DnsDomain, req.Domains, req.AutoRenew) + tx := this.NullTx() + taskId, err := models.SharedACMETaskDAO.CreateACMETask(tx, adminId, userId, req.AuthType, req.AcmeUserId, req.DnsProviderId, req.DnsDomain, req.Domains, req.AutoRenew) if err != nil { return nil, err } @@ -183,7 +192,9 @@ func (this *ACMETaskService) UpdateACMETask(ctx context.Context, req *pb.UpdateA return nil, err } - canAccess, err := models.SharedACMETaskDAO.CheckACMETask(adminId, userId, req.AcmeTaskId) + tx := this.NullTx() + + canAccess, err := models.SharedACMETaskDAO.CheckACMETask(tx, adminId, userId, req.AcmeTaskId) if err != nil { return nil, err } @@ -191,7 +202,7 @@ func (this *ACMETaskService) UpdateACMETask(ctx context.Context, req *pb.UpdateA return nil, this.PermissionError() } - err = models.SharedACMETaskDAO.UpdateACMETask(req.AcmeTaskId, req.AcmeUserId, req.DnsProviderId, req.DnsDomain, req.Domains, req.AutoRenew) + err = models.SharedACMETaskDAO.UpdateACMETask(tx, req.AcmeTaskId, req.AcmeUserId, req.DnsProviderId, req.DnsDomain, req.Domains, req.AutoRenew) if err != nil { return nil, err } @@ -205,7 +216,9 @@ func (this *ACMETaskService) DeleteACMETask(ctx context.Context, req *pb.DeleteA return nil, err } - canAccess, err := models.SharedACMETaskDAO.CheckACMETask(adminId, userId, req.AcmeTaskId) + tx := this.NullTx() + + canAccess, err := models.SharedACMETaskDAO.CheckACMETask(tx, adminId, userId, req.AcmeTaskId) if err != nil { return nil, err } @@ -213,7 +226,7 @@ func (this *ACMETaskService) DeleteACMETask(ctx context.Context, req *pb.DeleteA return nil, this.PermissionError() } - err = models.SharedACMETaskDAO.DisableACMETask(req.AcmeTaskId) + err = models.SharedACMETaskDAO.DisableACMETask(tx, req.AcmeTaskId) if err != nil { return nil, err } @@ -227,7 +240,9 @@ func (this *ACMETaskService) RunACMETask(ctx context.Context, req *pb.RunACMETas return nil, err } - canAccess, err := models.SharedACMETaskDAO.CheckACMETask(adminId, userId, req.AcmeTaskId) + tx := this.NullTx() + + canAccess, err := models.SharedACMETaskDAO.CheckACMETask(tx, adminId, userId, req.AcmeTaskId) if err != nil { return nil, err } @@ -235,7 +250,7 @@ func (this *ACMETaskService) RunACMETask(ctx context.Context, req *pb.RunACMETas return nil, this.PermissionError() } - isOk, msg, certId := models.SharedACMETaskDAO.RunTask(req.AcmeTaskId) + isOk, msg, certId := models.SharedACMETaskDAO.RunTask(tx, req.AcmeTaskId) return &pb.RunACMETaskResponse{ IsOk: isOk, @@ -251,7 +266,9 @@ func (this *ACMETaskService) FindEnabledACMETask(ctx context.Context, req *pb.Fi return nil, err } - canAccess, err := models.SharedACMETaskDAO.CheckACMETask(adminId, userId, req.AcmeTaskId) + tx := this.NullTx() + + canAccess, err := models.SharedACMETaskDAO.CheckACMETask(tx, adminId, userId, req.AcmeTaskId) if err != nil { return nil, err } @@ -259,7 +276,7 @@ func (this *ACMETaskService) FindEnabledACMETask(ctx context.Context, req *pb.Fi return nil, this.PermissionError() } - task, err := models.SharedACMETaskDAO.FindEnabledACMETask(req.AcmeTaskId) + task, err := models.SharedACMETaskDAO.FindEnabledACMETask(tx, req.AcmeTaskId) if err != nil { return nil, err } @@ -270,7 +287,7 @@ func (this *ACMETaskService) FindEnabledACMETask(ctx context.Context, req *pb.Fi // 用户 var pbACMEUser *pb.ACMEUser = nil if task.AcmeUserId > 0 { - acmeUser, err := models.SharedACMEUserDAO.FindEnabledACMEUser(int64(task.AcmeUserId)) + acmeUser, err := models.SharedACMEUserDAO.FindEnabledACMEUser(tx, int64(task.AcmeUserId)) if err != nil { return nil, err } @@ -286,7 +303,7 @@ func (this *ACMETaskService) FindEnabledACMETask(ctx context.Context, req *pb.Fi // DNS var pbProvider *pb.DNSProvider - provider, err := models.SharedDNSProviderDAO.FindEnabledDNSProvider(int64(task.DnsProviderId)) + provider, err := models.SharedDNSProviderDAO.FindEnabledDNSProvider(tx, int64(task.DnsProviderId)) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_acme_user.go b/internal/rpc/services/service_acme_user.go index b21cd66f..4804ed5b 100644 --- a/internal/rpc/services/service_acme_user.go +++ b/internal/rpc/services/service_acme_user.go @@ -19,7 +19,9 @@ func (this *ACMEUserService) CreateACMEUser(ctx context.Context, req *pb.CreateA return nil, err } - acmeUserId, err := models.SharedACMEUserDAO.CreateACMEUser(adminId, userId, req.Email, req.Description) + tx := this.NullTx() + + acmeUserId, err := models.SharedACMEUserDAO.CreateACMEUser(tx, adminId, userId, req.Email, req.Description) if err != nil { return nil, err } @@ -34,8 +36,10 @@ func (this *ACMEUserService) UpdateACMEUser(ctx context.Context, req *pb.UpdateA return nil, err } + tx := this.NullTx() + // 检查是否有权限 - b, err := models.SharedACMEUserDAO.CheckACMEUser(req.AcmeUserId, adminId, userId) + b, err := models.SharedACMEUserDAO.CheckACMEUser(tx, req.AcmeUserId, adminId, userId) if err != nil { return nil, err } @@ -43,7 +47,7 @@ func (this *ACMEUserService) UpdateACMEUser(ctx context.Context, req *pb.UpdateA return nil, this.PermissionError() } - err = models.SharedACMEUserDAO.UpdateACMEUser(req.AcmeUserId, req.Description) + err = models.SharedACMEUserDAO.UpdateACMEUser(tx, req.AcmeUserId, req.Description) if err != nil { return nil, err } @@ -58,8 +62,10 @@ func (this *ACMEUserService) DeleteACMEUser(ctx context.Context, req *pb.DeleteA return nil, err } + tx := this.NullTx() + // 检查是否有权限 - b, err := models.SharedACMEUserDAO.CheckACMEUser(req.AcmeUserId, adminId, userId) + b, err := models.SharedACMEUserDAO.CheckACMEUser(tx, req.AcmeUserId, adminId, userId) if err != nil { return nil, err } @@ -67,7 +73,7 @@ func (this *ACMEUserService) DeleteACMEUser(ctx context.Context, req *pb.DeleteA return nil, this.PermissionError() } - err = models.SharedACMEUserDAO.DisableACMEUser(req.AcmeUserId) + err = models.SharedACMEUserDAO.DisableACMEUser(tx, req.AcmeUserId) if err != nil { return nil, err } @@ -82,7 +88,9 @@ func (this *ACMEUserService) CountACMEUsers(ctx context.Context, req *pb.CountAc return nil, err } - count, err := models.SharedACMEUserDAO.CountACMEUsersWithAdminId(adminId, userId) + tx := this.NullTx() + + count, err := models.SharedACMEUserDAO.CountACMEUsersWithAdminId(tx, adminId, userId) if err != nil { return nil, err } @@ -97,7 +105,9 @@ func (this *ACMEUserService) ListACMEUsers(ctx context.Context, req *pb.ListACME return nil, err } - acmeUsers, err := models.SharedACMEUserDAO.ListACMEUsers(adminId, userId, req.Offset, req.Size) + tx := this.NullTx() + + acmeUsers, err := models.SharedACMEUserDAO.ListACMEUsers(tx, adminId, userId, req.Offset, req.Size) if err != nil { return nil, err } @@ -121,8 +131,10 @@ func (this *ACMEUserService) FindEnabledACMEUser(ctx context.Context, req *pb.Fi return nil, err } + tx := this.NullTx() + // 检查是否有权限 - b, err := models.SharedACMEUserDAO.CheckACMEUser(req.AcmeUserId, adminId, userId) + b, err := models.SharedACMEUserDAO.CheckACMEUser(tx, req.AcmeUserId, adminId, userId) if err != nil { return nil, err } @@ -130,7 +142,7 @@ func (this *ACMEUserService) FindEnabledACMEUser(ctx context.Context, req *pb.Fi return nil, this.PermissionError() } - acmeUser, err := models.SharedACMEUserDAO.FindEnabledACMEUser(req.AcmeUserId) + acmeUser, err := models.SharedACMEUserDAO.FindEnabledACMEUser(tx, req.AcmeUserId) if err != nil { return nil, err } @@ -153,7 +165,9 @@ func (this *ACMEUserService) FindAllACMEUsers(ctx context.Context, req *pb.FindA return nil, err } - acmeUsers, err := models.SharedACMEUserDAO.FindAllACMEUsers(adminId, userId) + tx := this.NullTx() + + acmeUsers, err := models.SharedACMEUserDAO.FindAllACMEUsers(tx, adminId, userId) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_admin.go b/internal/rpc/services/service_admin.go index 85136cc8..d9f1a848 100644 --- a/internal/rpc/services/service_admin.go +++ b/internal/rpc/services/service_admin.go @@ -32,7 +32,9 @@ func (this *AdminService) LoginAdmin(ctx context.Context, req *pb.LoginAdminRequ }, nil } - adminId, err := models.SharedAdminDAO.CheckAdminPassword(req.Username, req.Password) + tx := this.NullTx() + + adminId, err := models.SharedAdminDAO.CheckAdminPassword(tx, req.Username, req.Password) if err != nil { utils.PrintError(err) return nil, err @@ -65,7 +67,9 @@ func (this *AdminService) CheckAdminExists(ctx context.Context, req *pb.CheckAdm }, nil } - ok, err := models.SharedAdminDAO.ExistEnabledAdmin(req.AdminId) + tx := this.NullTx() + + ok, err := models.SharedAdminDAO.ExistEnabledAdmin(tx, req.AdminId) if err != nil { return nil, err } @@ -83,7 +87,9 @@ func (this *AdminService) CheckAdminUsername(ctx context.Context, req *pb.CheckA return nil, err } - exists, err := models.SharedAdminDAO.CheckAdminUsername(req.AdminId, req.Username) + tx := this.NullTx() + + exists, err := models.SharedAdminDAO.CheckAdminUsername(tx, req.AdminId, req.Username) if err != nil { return nil, err } @@ -99,7 +105,9 @@ func (this *AdminService) FindAdminFullname(ctx context.Context, req *pb.FindAdm return nil, err } - fullname, err := models.SharedAdminDAO.FindAdminFullname(req.AdminId) + tx := this.NullTx() + + fullname, err := models.SharedAdminDAO.FindAdminFullname(tx, req.AdminId) if err != nil { utils.PrintError(err) return nil, err @@ -119,7 +127,9 @@ func (this *AdminService) FindEnabledAdmin(ctx context.Context, req *pb.FindEnab // TODO 检查权限 - admin, err := models.SharedAdminDAO.FindEnabledAdmin(req.AdminId) + tx := this.NullTx() + + admin, err := models.SharedAdminDAO.FindEnabledAdmin(tx, req.AdminId) if err != nil { return nil, err } @@ -146,7 +156,7 @@ func (this *AdminService) FindEnabledAdmin(ctx context.Context, req *pb.FindEnab // OTP认证 var pbOtpAuth *pb.Login = nil { - adminAuth, err := models.SharedLoginDAO.FindEnabledLoginWithAdminId(int64(admin.Id), models.LoginTypeOTP) + adminAuth, err := models.SharedLoginDAO.FindEnabledLoginWithAdminId(tx, int64(admin.Id), models.LoginTypeOTP) if err != nil { return nil, err } @@ -180,18 +190,20 @@ func (this *AdminService) CreateOrUpdateAdmin(ctx context.Context, req *pb.Creat return nil, err } - adminId, err := models.SharedAdminDAO.FindAdminIdWithUsername(req.Username) + tx := this.NullTx() + + adminId, err := models.SharedAdminDAO.FindAdminIdWithUsername(tx, req.Username) if err != nil { return nil, err } if adminId > 0 { - err = models.SharedAdminDAO.UpdateAdminPassword(adminId, req.Password) + err = models.SharedAdminDAO.UpdateAdminPassword(tx, adminId, req.Password) if err != nil { return nil, err } return &pb.CreateOrUpdateAdminResponse{AdminId: adminId}, nil } - adminId, err = models.SharedAdminDAO.CreateAdmin(req.Username, req.Password, "管理员", true, nil) + adminId, err = models.SharedAdminDAO.CreateAdmin(tx, req.Username, req.Password, "管理员", true, nil) if err != nil { return nil, err } @@ -206,7 +218,9 @@ func (this *AdminService) UpdateAdminInfo(ctx context.Context, req *pb.UpdateAdm return nil, err } - err = models.SharedAdminDAO.UpdateAdminInfo(req.AdminId, req.Fullname) + tx := this.NullTx() + + err = models.SharedAdminDAO.UpdateAdminInfo(tx, req.AdminId, req.Fullname) if err != nil { return nil, err } @@ -221,7 +235,9 @@ func (this *AdminService) UpdateAdminLogin(ctx context.Context, req *pb.UpdateAd return nil, err } - exists, err := models.SharedAdminDAO.CheckAdminUsername(req.AdminId, req.Username) + tx := this.NullTx() + + exists, err := models.SharedAdminDAO.CheckAdminUsername(tx, req.AdminId, req.Username) if err != nil { return nil, err } @@ -229,7 +245,7 @@ func (this *AdminService) UpdateAdminLogin(ctx context.Context, req *pb.UpdateAd return nil, errors.New("username already been token") } - err = models.SharedAdminDAO.UpdateAdminLogin(req.AdminId, req.Username, req.Password) + err = models.SharedAdminDAO.UpdateAdminLogin(tx, req.AdminId, req.Username, req.Password) if err != nil { return nil, err } @@ -245,7 +261,9 @@ func (this *AdminService) FindAllAdminModules(ctx context.Context, req *pb.FindA // TODO 检查权限 - admins, err := models.SharedAdminDAO.FindAllAdminModules() + tx := this.NullTx() + + admins, err := models.SharedAdminDAO.FindAllAdminModules(tx) if err != nil { return nil, err } @@ -289,7 +307,9 @@ func (this *AdminService) CreateAdmin(ctx context.Context, req *pb.CreateAdminRe // TODO 检查权限 - adminId, err := models.SharedAdminDAO.CreateAdmin(req.Username, req.Password, req.Fullname, req.IsSuper, req.ModulesJSON) + tx := this.NullTx() + + adminId, err := models.SharedAdminDAO.CreateAdmin(tx, req.Username, req.Password, req.Fullname, req.IsSuper, req.ModulesJSON) if err != nil { return nil, err } @@ -306,7 +326,9 @@ func (this *AdminService) UpdateAdmin(ctx context.Context, req *pb.UpdateAdminRe // TODO 检查权限 - err = models.SharedAdminDAO.UpdateAdmin(req.AdminId, req.Username, req.Password, req.Fullname, req.IsSuper, req.ModulesJSON, req.IsOn) + tx := this.NullTx() + + err = models.SharedAdminDAO.UpdateAdmin(tx, req.AdminId, req.Username, req.Password, req.Fullname, req.IsSuper, req.ModulesJSON, req.IsOn) if err != nil { return nil, err } @@ -323,7 +345,9 @@ func (this *AdminService) CountAllEnabledAdmins(ctx context.Context, req *pb.Cou // TODO 检查权限 - count, err := models.SharedAdminDAO.CountAllEnabledAdmins() + tx := this.NullTx() + + count, err := models.SharedAdminDAO.CountAllEnabledAdmins(tx) if err != nil { return nil, err } @@ -339,7 +363,9 @@ func (this *AdminService) ListEnabledAdmins(ctx context.Context, req *pb.ListEna // TODO 检查权限 - admins, err := models.SharedAdminDAO.ListEnabledAdmins(req.Offset, req.Size) + tx := this.NullTx() + + admins, err := models.SharedAdminDAO.ListEnabledAdmins(tx, req.Offset, req.Size) if err != nil { return nil, err } @@ -348,7 +374,7 @@ func (this *AdminService) ListEnabledAdmins(ctx context.Context, req *pb.ListEna for _, admin := range admins { var pbOtpAuth *pb.Login = nil { - adminAuth, err := models.SharedLoginDAO.FindEnabledLoginWithAdminId(int64(admin.Id), models.LoginTypeOTP) + adminAuth, err := models.SharedLoginDAO.FindEnabledLoginWithAdminId(tx, int64(admin.Id), models.LoginTypeOTP) if err != nil { return nil, err } @@ -385,9 +411,11 @@ func (this *AdminService) DeleteAdmin(ctx context.Context, req *pb.DeleteAdminRe // TODO 检查权限 + tx := this.NullTx() + // TODO 超级管理员用户是不能删除的,或者要至少留一个超级管理员用户 - _, err = models.SharedAdminDAO.DisableAdmin(req.AdminId) + _, err = models.SharedAdminDAO.DisableAdmin(tx, req.AdminId) if err != nil { return nil, err } @@ -406,7 +434,9 @@ func (this *AdminService) CheckAdminOTPWithUsername(ctx context.Context, req *pb return &pb.CheckAdminOTPWithUsernameResponse{RequireOTP: false}, nil } - adminId, err := models.SharedAdminDAO.FindAdminIdWithUsername(req.Username) + tx := this.NullTx() + + adminId, err := models.SharedAdminDAO.FindAdminIdWithUsername(tx, req.Username) if err != nil { return nil, err } @@ -414,7 +444,7 @@ func (this *AdminService) CheckAdminOTPWithUsername(ctx context.Context, req *pb return &pb.CheckAdminOTPWithUsernameResponse{RequireOTP: false}, nil } - otpIsOn, err := models.SharedLoginDAO.CheckLoginIsOn(adminId, "otp") + otpIsOn, err := models.SharedLoginDAO.CheckLoginIsOn(tx, adminId, "otp") if err != nil { return nil, err } diff --git a/internal/rpc/services/service_api_access_token.go b/internal/rpc/services/service_api_access_token.go index 8019b815..67eea975 100644 --- a/internal/rpc/services/service_api_access_token.go +++ b/internal/rpc/services/service_api_access_token.go @@ -9,12 +9,15 @@ import ( // AccessToken相关服务 type APIAccessTokenService struct { + BaseService } // 获取AccessToken func (this *APIAccessTokenService) GetAPIAccessToken(ctx context.Context, req *pb.GetAPIAccessTokenRequest) (*pb.GetAPIAccessTokenResponse, error) { if req.Type == "user" { // 用户 - accessKey, err := models.SharedUserAccessKeyDAO.FindAccessKeyWithUniqueId(req.AccessKeyId) + tx := this.NullTx() + + accessKey, err := models.SharedUserAccessKeyDAO.FindAccessKeyWithUniqueId(tx, req.AccessKeyId) if err != nil { return nil, err } @@ -26,7 +29,7 @@ func (this *APIAccessTokenService) GetAPIAccessToken(ctx context.Context, req *p } // 创建AccessToken - token, expiresAt, err := models.SharedAPIAccessTokenDAO.GenerateAccessToken(int64(accessKey.UserId)) + token, expiresAt, err := models.SharedAPIAccessTokenDAO.GenerateAccessToken(tx, int64(accessKey.UserId)) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_api_node.go b/internal/rpc/services/service_api_node.go index f21459ea..70410913 100644 --- a/internal/rpc/services/service_api_node.go +++ b/internal/rpc/services/service_api_node.go @@ -19,7 +19,9 @@ func (this *APINodeService) CreateAPINode(ctx context.Context, req *pb.CreateAPI return nil, err } - nodeId, err := models.SharedAPINodeDAO.CreateAPINode(req.Name, req.Description, req.HttpJSON, req.HttpsJSON, req.RestIsOn, req.RestHTTPJSON, req.RestHTTPSJSON, req.AccessAddrsJSON, req.IsOn) + tx := this.NullTx() + + nodeId, err := models.SharedAPINodeDAO.CreateAPINode(tx, req.Name, req.Description, req.HttpJSON, req.HttpsJSON, req.RestIsOn, req.RestHTTPJSON, req.RestHTTPSJSON, req.AccessAddrsJSON, req.IsOn) if err != nil { return nil, err } @@ -34,7 +36,9 @@ func (this *APINodeService) UpdateAPINode(ctx context.Context, req *pb.UpdateAPI return nil, err } - err = models.SharedAPINodeDAO.UpdateAPINode(req.NodeId, req.Name, req.Description, req.HttpJSON, req.HttpsJSON, req.RestIsOn, req.RestHTTPJSON, req.RestHTTPSJSON, req.AccessAddrsJSON, req.IsOn) + tx := this.NullTx() + + err = models.SharedAPINodeDAO.UpdateAPINode(tx, req.NodeId, req.Name, req.Description, req.HttpJSON, req.HttpsJSON, req.RestIsOn, req.RestHTTPJSON, req.RestHTTPSJSON, req.AccessAddrsJSON, req.IsOn) if err != nil { return nil, err } @@ -49,7 +53,9 @@ func (this *APINodeService) DeleteAPINode(ctx context.Context, req *pb.DeleteAPI return nil, err } - err = models.SharedAPINodeDAO.DisableAPINode(req.NodeId) + tx := this.NullTx() + + err = models.SharedAPINodeDAO.DisableAPINode(tx, req.NodeId) if err != nil { return nil, err } @@ -64,7 +70,9 @@ func (this *APINodeService) FindAllEnabledAPINodes(ctx context.Context, req *pb. return nil, err } - nodes, err := models.SharedAPINodeDAO.FindAllEnabledAPINodes() + tx := this.NullTx() + + nodes, err := models.SharedAPINodeDAO.FindAllEnabledAPINodes(tx) if err != nil { return nil, err } @@ -101,7 +109,9 @@ func (this *APINodeService) CountAllEnabledAPINodes(ctx context.Context, req *pb return nil, err } - count, err := models.SharedAPINodeDAO.CountAllEnabledAPINodes() + tx := this.NullTx() + + count, err := models.SharedAPINodeDAO.CountAllEnabledAPINodes(tx) if err != nil { return nil, err } @@ -116,7 +126,9 @@ func (this *APINodeService) ListEnabledAPINodes(ctx context.Context, req *pb.Lis return nil, err } - nodes, err := models.SharedAPINodeDAO.ListEnabledAPINodes(req.Offset, req.Size) + tx := this.NullTx() + + nodes, err := models.SharedAPINodeDAO.ListEnabledAPINodes(tx, req.Offset, req.Size) if err != nil { return nil, err } @@ -157,7 +169,9 @@ func (this *APINodeService) FindEnabledAPINode(ctx context.Context, req *pb.Find return nil, err } - node, err := models.SharedAPINodeDAO.FindEnabledAPINode(req.NodeId) + tx := this.NullTx() + + node, err := models.SharedAPINodeDAO.FindEnabledAPINode(tx, req.NodeId) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_base.go b/internal/rpc/services/service_base.go index 05443e86..e85bb450 100644 --- a/internal/rpc/services/service_base.go +++ b/internal/rpc/services/service_base.go @@ -5,6 +5,7 @@ import ( "github.com/TeaOSLab/EdgeAPI/internal/errors" rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" + "github.com/iwind/TeaGo/dbs" ) type BaseService struct { @@ -85,3 +86,8 @@ func (this *BaseService) SuccessCount(count int64) (*pb.RPCCountResponse, error) func (this *BaseService) PermissionError() error { return errors.New("Permission Denied") } + +// 空的数据库事务 +func (this *BaseService) NullTx() *dbs.Tx { + return nil +} diff --git a/internal/rpc/services/service_db_node.go b/internal/rpc/services/service_db_node.go index 87f1c3fd..644293d6 100644 --- a/internal/rpc/services/service_db_node.go +++ b/internal/rpc/services/service_db_node.go @@ -20,7 +20,10 @@ func (this *DBNodeService) CreateDBNode(ctx context.Context, req *pb.CreateDBNod if err != nil { return nil, err } - nodeId, err := models.SharedDBNodeDAO.CreateDBNode(req.IsOn, req.Name, req.Description, req.Host, req.Port, req.Database, req.Username, req.Password, req.Charset) + + tx := this.NullTx() + + nodeId, err := models.SharedDBNodeDAO.CreateDBNode(tx, req.IsOn, req.Name, req.Description, req.Host, req.Port, req.Database, req.Username, req.Password, req.Charset) if err != nil { return nil, err } @@ -34,7 +37,10 @@ func (this *DBNodeService) UpdateDBNode(ctx context.Context, req *pb.UpdateDBNod if err != nil { return nil, err } - err = models.SharedDBNodeDAO.UpdateNode(req.NodeId, req.IsOn, req.Name, req.Description, req.Host, req.Port, req.Database, req.Username, req.Password, req.Charset) + + tx := this.NullTx() + + err = models.SharedDBNodeDAO.UpdateNode(tx, req.NodeId, req.IsOn, req.Name, req.Description, req.Host, req.Port, req.Database, req.Username, req.Password, req.Charset) if err != nil { return nil, err } @@ -48,7 +54,10 @@ func (this *DBNodeService) DeleteDBNode(ctx context.Context, req *pb.DeleteDBNod if err != nil { return nil, err } - err = models.SharedDBNodeDAO.DisableDBNode(req.NodeId) + + tx := this.NullTx() + + err = models.SharedDBNodeDAO.DisableDBNode(tx, req.NodeId) if err != nil { return nil, err } @@ -62,7 +71,10 @@ func (this *DBNodeService) CountAllEnabledDBNodes(ctx context.Context, req *pb.C if err != nil { return nil, err } - count, err := models.SharedDBNodeDAO.CountAllEnabledNodes() + + tx := this.NullTx() + + count, err := models.SharedDBNodeDAO.CountAllEnabledNodes(tx) if err != nil { return nil, err } @@ -77,7 +89,9 @@ func (this *DBNodeService) ListEnabledDBNodes(ctx context.Context, req *pb.ListE return nil, err } - nodes, err := models.SharedDBNodeDAO.ListEnabledNodes(req.Offset, req.Size) + tx := this.NullTx() + + nodes, err := models.SharedDBNodeDAO.ListEnabledNodes(tx, req.Offset, req.Size) if err != nil { return nil, err } @@ -108,7 +122,9 @@ func (this *DBNodeService) FindEnabledDBNode(ctx context.Context, req *pb.FindEn return nil, err } - node, err := models.SharedDBNodeDAO.FindEnabledDBNode(req.NodeId) + tx := this.NullTx() + + node, err := models.SharedDBNodeDAO.FindEnabledDBNode(tx, req.NodeId) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_dns.go b/internal/rpc/services/service_dns.go index ea05233e..76b654ed 100644 --- a/internal/rpc/services/service_dns.go +++ b/internal/rpc/services/service_dns.go @@ -9,6 +9,7 @@ import ( // DNS相关服务 type DNSService struct { + BaseService } // 查找问题 @@ -21,12 +22,14 @@ func (this *DNSService) FindAllDNSIssues(ctx context.Context, req *pb.FindAllDNS result := []*pb.DNSIssue{} - clusters, err := models.SharedNodeClusterDAO.FindAllEnabledClustersHaveDNSDomain() + tx := this.NullTx() + + clusters, err := models.SharedNodeClusterDAO.FindAllEnabledClustersHaveDNSDomain(tx) if err != nil { return nil, err } for _, cluster := range clusters { - issues, err := models.SharedNodeClusterDAO.CheckClusterDNS(cluster) + issues, err := models.SharedNodeClusterDAO.CheckClusterDNS(tx, cluster) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_dns_domain.go b/internal/rpc/services/service_dns_domain.go index 0ebaf63e..5bdc929c 100644 --- a/internal/rpc/services/service_dns_domain.go +++ b/internal/rpc/services/service_dns_domain.go @@ -26,8 +26,10 @@ func (this *DNSDomainService) CreateDNSDomain(ctx context.Context, req *pb.Creat return nil, err } + tx := this.NullTx() + // 查询Provider - provider, err := models.SharedDNSProviderDAO.FindEnabledDNSProvider(req.DnsProviderId) + provider, err := models.SharedDNSProviderDAO.FindEnabledDNSProvider(tx, req.DnsProviderId) if err != nil { return nil, err } @@ -39,7 +41,7 @@ func (this *DNSDomainService) CreateDNSDomain(ctx context.Context, req *pb.Creat return nil, err } - domainId, err := models.SharedDNSDomainDAO.CreateDomain(adminId, userId, req.DnsProviderId, req.Name) + domainId, err := models.SharedDNSDomainDAO.CreateDomain(tx, adminId, userId, req.DnsProviderId, req.Name) if err != nil { return nil, err } @@ -65,7 +67,7 @@ func (this *DNSDomainService) CreateDNSDomain(ctx context.Context, req *pb.Creat if err != nil { return } - err = models.SharedDNSDomainDAO.UpdateDomainRoutes(domainId, routesJSON) + err = models.SharedDNSDomainDAO.UpdateDomainRoutes(tx, domainId, routesJSON) if err != nil { return } @@ -78,7 +80,7 @@ func (this *DNSDomainService) CreateDNSDomain(ctx context.Context, req *pb.Creat if err != nil { return } - err = models.SharedDNSDomainDAO.UpdateDomainRecords(domainId, recordsJSON) + err = models.SharedDNSDomainDAO.UpdateDomainRecords(tx, domainId, recordsJSON) if err != nil { return } @@ -95,7 +97,9 @@ func (this *DNSDomainService) UpdateDNSDomain(ctx context.Context, req *pb.Updat return nil, err } - err = models.SharedDNSDomainDAO.UpdateDomain(req.DnsDomainId, req.Name, req.IsOn) + tx := this.NullTx() + + err = models.SharedDNSDomainDAO.UpdateDomain(tx, req.DnsDomainId, req.Name, req.IsOn) if err != nil { return nil, err } @@ -110,7 +114,9 @@ func (this *DNSDomainService) DeleteDNSDomain(ctx context.Context, req *pb.Delet return nil, err } - err = models.SharedDNSDomainDAO.DisableDNSDomain(req.DnsDomainId) + tx := this.NullTx() + + err = models.SharedDNSDomainDAO.DisableDNSDomain(tx, req.DnsDomainId) if err != nil { return nil, err } @@ -125,7 +131,9 @@ func (this *DNSDomainService) FindEnabledDNSDomain(ctx context.Context, req *pb. return nil, err } - domain, err := models.SharedDNSDomainDAO.FindEnabledDNSDomain(req.DnsDomainId) + tx := this.NullTx() + + domain, err := models.SharedDNSDomainDAO.FindEnabledDNSDomain(tx, req.DnsDomainId) if err != nil { return nil, err } @@ -145,7 +153,9 @@ func (this *DNSDomainService) FindEnabledBasicDNSDomain(ctx context.Context, req return nil, err } - domain, err := models.SharedDNSDomainDAO.FindEnabledDNSDomain(req.DnsDomainId) + tx := this.NullTx() + + domain, err := models.SharedDNSDomainDAO.FindEnabledDNSDomain(tx, req.DnsDomainId) if err != nil { return nil, err } @@ -169,7 +179,9 @@ func (this *DNSDomainService) CountAllEnabledDNSDomainsWithDNSProviderId(ctx con return nil, err } - count, err := models.SharedDNSDomainDAO.CountAllEnabledDomainsWithProviderId(req.DnsProviderId) + tx := this.NullTx() + + count, err := models.SharedDNSDomainDAO.CountAllEnabledDomainsWithProviderId(tx, req.DnsProviderId) if err != nil { return nil, err } @@ -184,7 +196,9 @@ func (this *DNSDomainService) FindAllEnabledDNSDomainsWithDNSProviderId(ctx cont return nil, err } - domains, err := models.SharedDNSDomainDAO.FindAllEnabledDomainsWithProviderId(req.DnsProviderId) + tx := this.NullTx() + + domains, err := models.SharedDNSDomainDAO.FindAllEnabledDomainsWithProviderId(tx, req.DnsProviderId) if err != nil { return nil, err } @@ -209,7 +223,9 @@ func (this *DNSDomainService) FindAllEnabledBasicDNSDomainsWithDNSProviderId(ctx return nil, err } - domains, err := models.SharedDNSDomainDAO.FindAllEnabledDomainsWithProviderId(req.DnsProviderId) + tx := this.NullTx() + + domains, err := models.SharedDNSDomainDAO.FindAllEnabledDomainsWithProviderId(tx, req.DnsProviderId) if err != nil { return nil, err } @@ -244,7 +260,9 @@ func (this *DNSDomainService) FindAllDNSDomainRoutes(ctx context.Context, req *p return nil, err } - routes, err := models.SharedDNSDomainDAO.FindDomainRoutes(req.DnsDomainId) + tx := this.NullTx() + + routes, err := models.SharedDNSDomainDAO.FindDomainRoutes(tx, req.DnsDomainId) if err != nil { return nil, err } @@ -268,7 +286,9 @@ func (this *DNSDomainService) ExistAvailableDomains(ctx context.Context, req *pb return nil, err } - exist, err := models.SharedDNSDomainDAO.ExistAvailableDomains() + tx := this.NullTx() + + exist, err := models.SharedDNSDomainDAO.ExistAvailableDomains(tx) if err != nil { return nil, err } @@ -295,8 +315,10 @@ func (this *DNSDomainService) convertDomainToPB(domain *models.DNSDomain) (*pb.D countServerRecords := 0 serversChanged := false + tx := this.NullTx() + // 检查是否所有的集群都已经被解析 - clusters, err := models.SharedNodeClusterDAO.FindAllEnabledClustersWithDNSDomainId(domainId) + clusters, err := models.SharedNodeClusterDAO.FindAllEnabledClustersWithDNSDomainId(tx, domainId) if err != nil { return nil, err } @@ -367,8 +389,10 @@ func (this *DNSDomainService) findClusterDNSChanges(cluster *models.NodeCluster, clusterDnsName := cluster.DnsName clusterDomain := clusterDnsName + "." + domainName + tx := this.NullTx() + // 节点域名 - nodes, err := models.SharedNodeDAO.FindAllEnabledNodesDNSWithClusterId(clusterId) + nodes, err := models.SharedNodeDAO.FindAllEnabledNodesDNSWithClusterId(tx, clusterId) if err != nil { return nil, nil, nil, 0, 0, false, false, err } @@ -385,7 +409,7 @@ func (this *DNSDomainService) findClusterDNSChanges(cluster *models.NodeCluster, // 新增的节点域名 nodeKeys := []string{} for _, node := range nodes { - ipAddr, err := models.SharedNodeIPAddressDAO.FindFirstNodeIPAddress(int64(node.Id)) + ipAddr, err := models.SharedNodeIPAddressDAO.FindFirstNodeIPAddress(tx, int64(node.Id)) if err != nil { return nil, nil, nil, 0, 0, false, false, err } @@ -434,7 +458,7 @@ func (this *DNSDomainService) findClusterDNSChanges(cluster *models.NodeCluster, } // 服务域名 - servers, err := models.SharedServerDAO.FindAllServersDNSWithClusterId(clusterId) + servers, err := models.SharedServerDAO.FindAllServersDNSWithClusterId(tx, clusterId) if err != nil { return nil, nil, nil, 0, 0, false, false, err } @@ -490,11 +514,13 @@ func (this *DNSDomainService) findClusterDNSChanges(cluster *models.NodeCluster, // 执行同步 func (this *DNSDomainService) syncClusterDNS(req *pb.SyncDNSDomainDataRequest) (*pb.SyncDNSDomainDataResponse, error) { + tx := this.NullTx() + // 查询集群信息 var err error clusters := []*models.NodeCluster{} if req.NodeClusterId > 0 { - cluster, err := models.SharedNodeClusterDAO.FindEnabledNodeCluster(req.NodeClusterId) + cluster, err := models.SharedNodeClusterDAO.FindEnabledNodeCluster(tx, req.NodeClusterId) if err != nil { return nil, err } @@ -514,14 +540,14 @@ func (this *DNSDomainService) syncClusterDNS(req *pb.SyncDNSDomainDataRequest) ( } clusters = append(clusters, cluster) } else { - clusters, err = models.SharedNodeClusterDAO.FindAllEnabledClustersWithDNSDomainId(req.DnsDomainId) + clusters, err = models.SharedNodeClusterDAO.FindAllEnabledClustersWithDNSDomainId(tx, req.DnsDomainId) if err != nil { return nil, err } } // 域名信息 - domain, err := models.SharedDNSDomainDAO.FindEnabledDNSDomain(req.DnsDomainId) + domain, err := models.SharedDNSDomainDAO.FindEnabledDNSDomain(tx, req.DnsDomainId) if err != nil { return nil, err } @@ -532,7 +558,7 @@ func (this *DNSDomainService) syncClusterDNS(req *pb.SyncDNSDomainDataRequest) ( domainName := domain.Name // 服务商信息 - provider, err := models.SharedDNSProviderDAO.FindEnabledDNSProvider(int64(domain.ProviderId)) + provider, err := models.SharedDNSProviderDAO.FindEnabledDNSProvider(tx, int64(domain.ProviderId)) if err != nil { return nil, err } @@ -566,14 +592,14 @@ func (this *DNSDomainService) syncClusterDNS(req *pb.SyncDNSDomainDataRequest) ( if err != nil { return nil, err } - err = models.SharedDNSDomainDAO.UpdateDomainRoutes(domainId, routesJSON) + err = models.SharedDNSDomainDAO.UpdateDomainRoutes(tx, domainId, routesJSON) if err != nil { return nil, err } // 检查集群设置 for _, cluster := range clusters { - issues, err := models.SharedNodeClusterDAO.CheckClusterDNS(cluster) + issues, err := models.SharedNodeClusterDAO.CheckClusterDNS(tx, cluster) if err != nil { return nil, err } @@ -591,7 +617,7 @@ func (this *DNSDomainService) syncClusterDNS(req *pb.SyncDNSDomainDataRequest) ( if err != nil { return nil, err } - err = models.SharedDNSDomainDAO.UpdateDomainRecords(domainId, recordsJSON) + err = models.SharedDNSDomainDAO.UpdateDomainRecords(tx, domainId, recordsJSON) if err != nil { return nil, err } @@ -639,7 +665,7 @@ func (this *DNSDomainService) syncClusterDNS(req *pb.SyncDNSDomainDataRequest) ( if err != nil { return nil, err } - err = models.SharedDNSDomainDAO.UpdateDomainRecords(domainId, recordsJSON) + err = models.SharedDNSDomainDAO.UpdateDomainRecords(tx, domainId, recordsJSON) if err != nil { return nil, err } @@ -657,7 +683,9 @@ func (this *DNSDomainService) ExistDNSDomainRecord(ctx context.Context, req *pb. return nil, err } - isOk, err := models.SharedDNSDomainDAO.ExistDomainRecord(req.DnsDomainId, req.Name, req.Type, req.Route, req.Value) + tx := this.NullTx() + + isOk, err := models.SharedDNSDomainDAO.ExistDomainRecord(tx, req.DnsDomainId, req.Name, req.Type, req.Route, req.Value) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_dns_provider.go b/internal/rpc/services/service_dns_provider.go index 20af9ef6..b3bc9023 100644 --- a/internal/rpc/services/service_dns_provider.go +++ b/internal/rpc/services/service_dns_provider.go @@ -21,7 +21,9 @@ func (this *DNSProviderService) CreateDNSProvider(ctx context.Context, req *pb.C return nil, err } - providerId, err := models.SharedDNSProviderDAO.CreateDNSProvider(adminId, userId, req.Type, req.Name, req.ApiParamsJSON) + tx := this.NullTx() + + providerId, err := models.SharedDNSProviderDAO.CreateDNSProvider(tx, adminId, userId, req.Type, req.Name, req.ApiParamsJSON) if err != nil { return nil, err } @@ -39,7 +41,9 @@ func (this *DNSProviderService) UpdateDNSProvider(ctx context.Context, req *pb.U // TODO 校验权限 - err = models.SharedDNSProviderDAO.UpdateDNSProvider(req.DnsProviderId, req.Name, req.ApiParamsJSON) + tx := this.NullTx() + + err = models.SharedDNSProviderDAO.UpdateDNSProvider(tx, req.DnsProviderId, req.Name, req.ApiParamsJSON) if err != nil { return nil, err } @@ -54,7 +58,9 @@ func (this *DNSProviderService) CountAllEnabledDNSProviders(ctx context.Context, return nil, err } - count, err := models.SharedDNSProviderDAO.CountAllEnabledDNSProviders(req.AdminId, req.UserId) + tx := this.NullTx() + + count, err := models.SharedDNSProviderDAO.CountAllEnabledDNSProviders(tx, req.AdminId, req.UserId) if err != nil { return nil, err } @@ -71,7 +77,9 @@ func (this *DNSProviderService) ListEnabledDNSProviders(ctx context.Context, req // TODO 校验权限 - providers, err := models.SharedDNSProviderDAO.ListEnabledDNSProviders(req.AdminId, req.UserId, req.Offset, req.Size) + tx := this.NullTx() + + providers, err := models.SharedDNSProviderDAO.ListEnabledDNSProviders(tx, req.AdminId, req.UserId, req.Offset, req.Size) if err != nil { return nil, err } @@ -99,7 +107,9 @@ func (this *DNSProviderService) FindAllEnabledDNSProviders(ctx context.Context, // TODO 校验权限 - providers, err := models.SharedDNSProviderDAO.FindAllEnabledDNSProviders(req.AdminId, req.UserId) + tx := this.NullTx() + + providers, err := models.SharedDNSProviderDAO.FindAllEnabledDNSProviders(tx, req.AdminId, req.UserId) if err != nil { return nil, err } @@ -127,7 +137,9 @@ func (this *DNSProviderService) DeleteDNSProvider(ctx context.Context, req *pb.D // TODO 校验权限 - err = models.SharedDNSProviderDAO.DisableDNSProvider(req.DnsProviderId) + tx := this.NullTx() + + err = models.SharedDNSProviderDAO.DisableDNSProvider(tx, req.DnsProviderId) if err != nil { return nil, err } @@ -142,7 +154,9 @@ func (this *DNSProviderService) FindEnabledDNSProvider(ctx context.Context, req return nil, err } - provider, err := models.SharedDNSProviderDAO.FindEnabledDNSProvider(req.DnsProviderId) + tx := this.NullTx() + + provider, err := models.SharedDNSProviderDAO.FindEnabledDNSProvider(tx, req.DnsProviderId) if err != nil { return nil, err } @@ -186,7 +200,9 @@ func (this *DNSProviderService) FindAllEnabledDNSProvidersWithType(ctx context.C return nil, err } - providers, err := models.SharedDNSProviderDAO.FindAllEnabledDNSProvidersWithType(req.ProviderTypeCode) + tx := this.NullTx() + + providers, err := models.SharedDNSProviderDAO.FindAllEnabledDNSProvidersWithType(tx, req.ProviderTypeCode) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_file.go b/internal/rpc/services/service_file.go index 834b8cbd..75848b6e 100644 --- a/internal/rpc/services/service_file.go +++ b/internal/rpc/services/service_file.go @@ -20,7 +20,9 @@ func (this *FileService) CreateFile(ctx context.Context, req *pb.CreateFileReque return nil, err } - fileId, err := models.SharedFileDAO.CreateFile("ipLibrary", "", req.Filename, req.Size) + tx := this.NullTx() + + fileId, err := models.SharedFileDAO.CreateFile(tx, "ipLibrary", "", req.Filename, req.Size) if err != nil { return nil, err } @@ -35,7 +37,9 @@ func (this *FileService) UpdateFileFinished(ctx context.Context, req *pb.UpdateF return nil, err } - err = models.SharedFileDAO.UpdateFileIsFinished(req.FileId) + tx := this.NullTx() + + err = models.SharedFileDAO.UpdateFileIsFinished(tx, req.FileId) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_file_chunk.go b/internal/rpc/services/service_file_chunk.go index ff399e83..3c665d84 100644 --- a/internal/rpc/services/service_file_chunk.go +++ b/internal/rpc/services/service_file_chunk.go @@ -9,6 +9,7 @@ import ( // 文件片段相关服务 type FileChunkService struct { + BaseService } // 创建文件片段 @@ -19,7 +20,9 @@ func (this *FileChunkService) CreateFileChunk(ctx context.Context, req *pb.Creat return nil, err } - chunkId, err := models.SharedFileChunkDAO.CreateFileChunk(req.FileId, req.Data) + tx := this.NullTx() + + chunkId, err := models.SharedFileChunkDAO.CreateFileChunk(tx, req.FileId, req.Data) if err != nil { return nil, err } @@ -34,7 +37,9 @@ func (this *FileChunkService) FindAllFileChunkIds(ctx context.Context, req *pb.F return nil, err } - chunkIds, err := models.SharedFileChunkDAO.FindAllFileChunkIds(req.FileId) + tx := this.NullTx() + + chunkIds, err := models.SharedFileChunkDAO.FindAllFileChunkIds(tx, req.FileId) if err != nil { return nil, err } @@ -49,7 +54,9 @@ func (this *FileChunkService) DownloadFileChunk(ctx context.Context, req *pb.Dow return nil, err } - chunk, err := models.SharedFileChunkDAO.FindFileChunk(req.FileChunkId) + tx := this.NullTx() + + chunk, err := models.SharedFileChunkDAO.FindFileChunk(tx, req.FileChunkId) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_http_access_log.go b/internal/rpc/services/service_http_access_log.go index f466dc05..48cb2b14 100644 --- a/internal/rpc/services/service_http_access_log.go +++ b/internal/rpc/services/service_http_access_log.go @@ -25,7 +25,9 @@ func (this *HTTPAccessLogService) CreateHTTPAccessLogs(ctx context.Context, req return &pb.CreateHTTPAccessLogsResponse{}, nil } - err = models.SharedHTTPAccessLogDAO.CreateHTTPAccessLogs(req.AccessLogs) + tx := this.NullTx() + + err = models.SharedHTTPAccessLogDAO.CreateHTTPAccessLogs(tx, req.AccessLogs) if err != nil { return nil, err } @@ -41,19 +43,21 @@ func (this *HTTPAccessLogService) ListHTTPAccessLogs(ctx context.Context, req *p return nil, err } + tx := this.NullTx() + // 检查服务ID if userId > 0 { if req.ServerId <= 0 { return nil, errors.New("invalid serverId") } - err = models.SharedServerDAO.CheckUserServer(req.ServerId, userId) + err = models.SharedServerDAO.CheckUserServer(tx, req.ServerId, userId) if err != nil { return nil, err } } - accessLogs, requestId, hasMore, err := models.SharedHTTPAccessLogDAO.ListAccessLogs(req.RequestId, req.Size, req.Day, req.ServerId, req.Reverse, req.HasError, req.FirewallPolicyId, req.FirewallRuleGroupId, req.FirewallRuleSetId) + accessLogs, requestId, hasMore, err := models.SharedHTTPAccessLogDAO.ListAccessLogs(tx, req.RequestId, req.Size, req.Day, req.ServerId, req.Reverse, req.HasError, req.FirewallPolicyId, req.FirewallRuleGroupId, req.FirewallRuleSetId) if err != nil { return nil, err } @@ -82,7 +86,9 @@ func (this *HTTPAccessLogService) FindHTTPAccessLog(ctx context.Context, req *pb return nil, err } - accessLog, err := models.SharedHTTPAccessLogDAO.FindAccessLogWithRequestId(req.RequestId) + tx := this.NullTx() + + accessLog, err := models.SharedHTTPAccessLogDAO.FindAccessLogWithRequestId(tx, req.RequestId) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_http_access_log_policy.go b/internal/rpc/services/service_http_access_log_policy.go index 80b2a617..651d949a 100644 --- a/internal/rpc/services/service_http_access_log_policy.go +++ b/internal/rpc/services/service_http_access_log_policy.go @@ -8,6 +8,7 @@ import ( ) type HTTPAccessLogPolicyService struct { + BaseService } // 获取所有可用策略 @@ -18,7 +19,9 @@ func (this *HTTPAccessLogPolicyService) FindAllEnabledHTTPAccessLogPolicies(ctx return nil, err } - policies, err := models.SharedHTTPAccessLogPolicyDAO.FindAllEnabledAccessLogPolicies() + tx := this.NullTx() + + policies, err := models.SharedHTTPAccessLogPolicyDAO.FindAllEnabledAccessLogPolicies(tx) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_http_cache_policy.go b/internal/rpc/services/service_http_cache_policy.go index 8c402e6e..1e6ceefb 100644 --- a/internal/rpc/services/service_http_cache_policy.go +++ b/internal/rpc/services/service_http_cache_policy.go @@ -20,7 +20,9 @@ func (this *HTTPCachePolicyService) FindAllEnabledHTTPCachePolicies(ctx context. return nil, err } - policies, err := models.SharedHTTPCachePolicyDAO.FindAllEnabledCachePolicies() + tx := this.NullTx() + + policies, err := models.SharedHTTPCachePolicyDAO.FindAllEnabledCachePolicies(tx) if err != nil { return nil, err } @@ -43,7 +45,9 @@ func (this *HTTPCachePolicyService) CreateHTTPCachePolicy(ctx context.Context, r return nil, err } - policyId, err := models.SharedHTTPCachePolicyDAO.CreateCachePolicy(req.IsOn, req.Name, req.Description, req.CapacityJSON, req.MaxKeys, req.MaxSizeJSON, req.Type, req.OptionsJSON) + tx := this.NullTx() + + policyId, err := models.SharedHTTPCachePolicyDAO.CreateCachePolicy(tx, req.IsOn, req.Name, req.Description, req.CapacityJSON, req.MaxKeys, req.MaxSizeJSON, req.Type, req.OptionsJSON) if err != nil { return nil, err } @@ -58,7 +62,9 @@ func (this *HTTPCachePolicyService) UpdateHTTPCachePolicy(ctx context.Context, r return nil, err } - err = models.SharedHTTPCachePolicyDAO.UpdateCachePolicy(req.HttpCachePolicyId, req.IsOn, req.Name, req.Description, req.CapacityJSON, req.MaxKeys, req.MaxSizeJSON, req.Type, req.OptionsJSON) + tx := this.NullTx() + + err = models.SharedHTTPCachePolicyDAO.UpdateCachePolicy(tx, req.HttpCachePolicyId, req.IsOn, req.Name, req.Description, req.CapacityJSON, req.MaxKeys, req.MaxSizeJSON, req.Type, req.OptionsJSON) if err != nil { return nil, err } @@ -74,7 +80,9 @@ func (this *HTTPCachePolicyService) DeleteHTTPCachePolicy(ctx context.Context, r return nil, err } - err = models.SharedHTTPCachePolicyDAO.DisableHTTPCachePolicy(req.HttpCachePolicyId) + tx := this.NullTx() + + err = models.SharedHTTPCachePolicyDAO.DisableHTTPCachePolicy(tx, req.HttpCachePolicyId) if err != nil { return nil, err } @@ -90,7 +98,9 @@ func (this *HTTPCachePolicyService) CountAllEnabledHTTPCachePolicies(ctx context return nil, err } - count, err := models.SharedHTTPCachePolicyDAO.CountAllEnabledHTTPCachePolicies() + tx := this.NullTx() + + count, err := models.SharedHTTPCachePolicyDAO.CountAllEnabledHTTPCachePolicies(tx) if err != nil { return nil, err } @@ -105,7 +115,9 @@ func (this *HTTPCachePolicyService) ListEnabledHTTPCachePolicies(ctx context.Con return nil, err } - cachePolicies, err := models.SharedHTTPCachePolicyDAO.ListEnabledHTTPCachePolicies(req.Offset, req.Size) + tx := this.NullTx() + + cachePolicies, err := models.SharedHTTPCachePolicyDAO.ListEnabledHTTPCachePolicies(tx, req.Offset, req.Size) if err != nil { return nil, err } @@ -124,7 +136,9 @@ func (this *HTTPCachePolicyService) FindEnabledHTTPCachePolicyConfig(ctx context return nil, err } - cachePolicy, err := models.SharedHTTPCachePolicyDAO.ComposeCachePolicy(req.HttpCachePolicyId) + tx := this.NullTx() + + cachePolicy, err := models.SharedHTTPCachePolicyDAO.ComposeCachePolicy(tx, req.HttpCachePolicyId) if err != nil { return nil, err } @@ -139,7 +153,9 @@ func (this *HTTPCachePolicyService) FindEnabledHTTPCachePolicy(ctx context.Conte return nil, err } - policy, err := models.SharedHTTPCachePolicyDAO.FindEnabledHTTPCachePolicy(req.HttpCachePolicyId) + tx := this.NullTx() + + policy, err := models.SharedHTTPCachePolicyDAO.FindEnabledHTTPCachePolicy(tx, req.HttpCachePolicyId) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_http_firewall_policy.go b/internal/rpc/services/service_http_firewall_policy.go index 7c603206..73050c40 100644 --- a/internal/rpc/services/service_http_firewall_policy.go +++ b/internal/rpc/services/service_http_firewall_policy.go @@ -24,7 +24,9 @@ func (this *HTTPFirewallPolicyService) FindAllEnabledHTTPFirewallPolicies(ctx co return nil, err } - policies, err := models.SharedHTTPFirewallPolicyDAO.FindAllEnabledFirewallPolicies() + tx := this.NullTx() + + policies, err := models.SharedHTTPFirewallPolicyDAO.FindAllEnabledFirewallPolicies(tx) if err != nil { return nil, err } @@ -52,7 +54,9 @@ func (this *HTTPFirewallPolicyService) CreateHTTPFirewallPolicy(ctx context.Cont return nil, err } - policyId, err := models.SharedHTTPFirewallPolicyDAO.CreateFirewallPolicy(req.IsOn, req.Name, req.Description, nil, nil) + tx := this.NullTx() + + policyId, err := models.SharedHTTPFirewallPolicyDAO.CreateFirewallPolicy(tx, req.IsOn, req.Name, req.Description, nil, nil) if err != nil { return nil, err } @@ -66,7 +70,7 @@ func (this *HTTPFirewallPolicyService) CreateHTTPFirewallPolicy(ctx context.Cont isOn := lists.ContainsString(req.HttpFirewallGroupCodes, group.Code) group.IsOn = isOn - groupId, err := models.SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(group) + groupId, err := models.SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(tx, group) if err != nil { return nil, err } @@ -81,7 +85,7 @@ func (this *HTTPFirewallPolicyService) CreateHTTPFirewallPolicy(ctx context.Cont isOn := lists.ContainsString(req.HttpFirewallGroupCodes, group.Code) group.IsOn = isOn - groupId, err := models.SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(group) + groupId, err := models.SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(tx, group) if err != nil { return nil, err } @@ -102,7 +106,7 @@ func (this *HTTPFirewallPolicyService) CreateHTTPFirewallPolicy(ctx context.Cont return nil, err } - err = models.SharedHTTPFirewallPolicyDAO.UpdateFirewallPolicyInboundAndOutbound(policyId, inboundConfigJSON, outboundConfigJSON) + err = models.SharedHTTPFirewallPolicyDAO.UpdateFirewallPolicyInboundAndOutbound(tx, policyId, inboundConfigJSON, outboundConfigJSON) if err != nil { return nil, err } @@ -120,8 +124,10 @@ func (this *HTTPFirewallPolicyService) UpdateHTTPFirewallPolicy(ctx context.Cont templatePolicy := firewallconfigs.HTTPFirewallTemplate() + tx := this.NullTx() + // 已经有的数据 - firewallPolicy, err := models.SharedHTTPFirewallPolicyDAO.ComposeFirewallPolicy(req.HttpFirewallPolicyId) + firewallPolicy, err := models.SharedHTTPFirewallPolicyDAO.ComposeFirewallPolicy(tx, req.HttpFirewallPolicyId) if err != nil { return nil, err } @@ -146,12 +152,12 @@ func (this *HTTPFirewallPolicyService) UpdateHTTPFirewallPolicy(ctx context.Cont if len(g.Code) > 0 { oldCodes = append(oldCodes, g.Code) if lists.ContainsString(req.FirewallGroupCodes, g.Code) { - err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupIsOn(g.Id, true) + err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupIsOn(tx, g.Id, true) if err != nil { return nil, err } } else { - err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupIsOn(g.Id, false) + err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupIsOn(tx, g.Id, false) if err != nil { return nil, err } @@ -164,12 +170,12 @@ func (this *HTTPFirewallPolicyService) UpdateHTTPFirewallPolicy(ctx context.Cont if len(g.Code) > 0 { oldCodes = append(oldCodes, g.Code) if lists.ContainsString(req.FirewallGroupCodes, g.Code) { - err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupIsOn(g.Id, true) + err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupIsOn(tx, g.Id, true) if err != nil { return nil, err } } else { - err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupIsOn(g.Id, false) + err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupIsOn(tx, g.Id, false) if err != nil { return nil, err } @@ -188,7 +194,7 @@ func (this *HTTPFirewallPolicyService) UpdateHTTPFirewallPolicy(ctx context.Cont isOn := lists.ContainsString(req.FirewallGroupCodes, group.Code) group.IsOn = isOn - groupId, err := models.SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(group) + groupId, err := models.SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(tx, group) if err != nil { return nil, err } @@ -207,7 +213,7 @@ func (this *HTTPFirewallPolicyService) UpdateHTTPFirewallPolicy(ctx context.Cont isOn := lists.ContainsString(req.FirewallGroupCodes, group.Code) group.IsOn = isOn - groupId, err := models.SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(group) + groupId, err := models.SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(tx, group) if err != nil { return nil, err } @@ -228,7 +234,7 @@ func (this *HTTPFirewallPolicyService) UpdateHTTPFirewallPolicy(ctx context.Cont return nil, err } - err = models.SharedHTTPFirewallPolicyDAO.UpdateFirewallPolicy(req.HttpFirewallPolicyId, req.IsOn, req.Name, req.Description, inboundConfigJSON, outboundConfigJSON, req.BlockOptionsJSON) + err = models.SharedHTTPFirewallPolicyDAO.UpdateFirewallPolicy(tx, req.HttpFirewallPolicyId, req.IsOn, req.Name, req.Description, inboundConfigJSON, outboundConfigJSON, req.BlockOptionsJSON) if err != nil { return nil, err } @@ -244,7 +250,9 @@ func (this *HTTPFirewallPolicyService) UpdateHTTPFirewallPolicyGroups(ctx contex return nil, err } - err = models.SharedHTTPFirewallPolicyDAO.UpdateFirewallPolicyInboundAndOutbound(req.HttpFirewallPolicyId, req.InboundJSON, req.OutboundJSON) + tx := this.NullTx() + + err = models.SharedHTTPFirewallPolicyDAO.UpdateFirewallPolicyInboundAndOutbound(tx, req.HttpFirewallPolicyId, req.InboundJSON, req.OutboundJSON) if err != nil { return nil, err } @@ -260,7 +268,9 @@ func (this *HTTPFirewallPolicyService) UpdateHTTPFirewallInboundConfig(ctx conte return nil, err } - err = models.SharedHTTPFirewallPolicyDAO.UpdateFirewallPolicyInbound(req.HttpFirewallPolicyId, req.InboundJSON) + tx := this.NullTx() + + err = models.SharedHTTPFirewallPolicyDAO.UpdateFirewallPolicyInbound(tx, req.HttpFirewallPolicyId, req.InboundJSON) if err != nil { return nil, err } @@ -276,7 +286,9 @@ func (this *HTTPFirewallPolicyService) CountAllEnabledHTTPFirewallPolicies(ctx c return nil, err } - count, err := models.SharedHTTPFirewallPolicyDAO.CountAllEnabledFirewallPolicies() + tx := this.NullTx() + + count, err := models.SharedHTTPFirewallPolicyDAO.CountAllEnabledFirewallPolicies(tx) if err != nil { return nil, err } @@ -291,7 +303,9 @@ func (this *HTTPFirewallPolicyService) ListEnabledHTTPFirewallPolicies(ctx conte return nil, err } - policies, err := models.SharedHTTPFirewallPolicyDAO.ListEnabledFirewallPolicies(req.Offset, req.Size) + tx := this.NullTx() + + policies, err := models.SharedHTTPFirewallPolicyDAO.ListEnabledFirewallPolicies(tx, req.Offset, req.Size) if err != nil { return nil, err } @@ -319,7 +333,9 @@ func (this *HTTPFirewallPolicyService) DeleteHTTPFirewallPolicy(ctx context.Cont return nil, err } - err = models.SharedHTTPFirewallPolicyDAO.DisableHTTPFirewallPolicy(req.HttpFirewallPolicyId) + tx := this.NullTx() + + err = models.SharedHTTPFirewallPolicyDAO.DisableHTTPFirewallPolicy(tx, req.HttpFirewallPolicyId) if err != nil { return nil, err } @@ -335,7 +351,9 @@ func (this *HTTPFirewallPolicyService) FindEnabledHTTPFirewallPolicyConfig(ctx c return nil, err } - config, err := models.SharedHTTPFirewallPolicyDAO.ComposeFirewallPolicy(req.HttpFirewallPolicyId) + tx := this.NullTx() + + config, err := models.SharedHTTPFirewallPolicyDAO.ComposeFirewallPolicy(tx, req.HttpFirewallPolicyId) if err != nil { return nil, err } @@ -359,7 +377,9 @@ func (this *HTTPFirewallPolicyService) FindEnabledHTTPFirewallPolicy(ctx context return nil, err } - policy, err := models.SharedHTTPFirewallPolicyDAO.FindEnabledHTTPFirewallPolicy(req.HttpFirewallPolicyId) + tx := this.NullTx() + + policy, err := models.SharedHTTPFirewallPolicyDAO.FindEnabledHTTPFirewallPolicy(tx, req.HttpFirewallPolicyId) if err != nil { return nil, err } @@ -385,7 +405,9 @@ func (this *HTTPFirewallPolicyService) ImportHTTPFirewallPolicy(ctx context.Cont // TODO 检查权限 - oldConfig, err := models.SharedHTTPFirewallPolicyDAO.ComposeFirewallPolicy(req.HttpFirewallPolicyId) + tx := this.NullTx() + + oldConfig, err := models.SharedHTTPFirewallPolicyDAO.ComposeFirewallPolicy(tx, req.HttpFirewallPolicyId) if err != nil { return nil, err } @@ -408,7 +430,7 @@ func (this *HTTPFirewallPolicyService) ImportHTTPFirewallPolicy(ctx context.Cont oldGroup := oldConfig.FindRuleGroupWithCode(g.Code) if oldGroup == nil { // 新创建分组 - groupId, err := models.SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(g) + groupId, err := models.SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(tx, g) if err != nil { return nil, err } @@ -419,7 +441,7 @@ func (this *HTTPFirewallPolicyService) ImportHTTPFirewallPolicy(ctx context.Cont } else { setRefs := []*firewallconfigs.HTTPFirewallRuleSetRef{} for _, set := range g.Sets { - setId, err := models.SharedHTTPFirewallRuleSetDAO.CreateOrUpdateSetFromConfig(set) + setId, err := models.SharedHTTPFirewallRuleSetDAO.CreateOrUpdateSetFromConfig(tx, set) if err != nil { return nil, err } @@ -432,18 +454,18 @@ func (this *HTTPFirewallPolicyService) ImportHTTPFirewallPolicy(ctx context.Cont if err != nil { return nil, err } - err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupIsOn(oldGroup.Id, true) + err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupIsOn(tx, oldGroup.Id, true) if err != nil { return nil, err } - err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupSets(oldGroup.Id, setsJSON) + err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupSets(tx, oldGroup.Id, setsJSON) if err != nil { return nil, err } } } else { // 没有代号的直接创建 - groupId, err := models.SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(g) + groupId, err := models.SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(tx, g) if err != nil { return nil, err } @@ -463,7 +485,7 @@ func (this *HTTPFirewallPolicyService) ImportHTTPFirewallPolicy(ctx context.Cont oldGroup := oldConfig.FindRuleGroupWithCode(g.Code) if oldGroup == nil { // 新创建分组 - groupId, err := models.SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(g) + groupId, err := models.SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(tx, g) if err != nil { return nil, err } @@ -474,7 +496,7 @@ func (this *HTTPFirewallPolicyService) ImportHTTPFirewallPolicy(ctx context.Cont } else { setRefs := []*firewallconfigs.HTTPFirewallRuleSetRef{} for _, set := range g.Sets { - setId, err := models.SharedHTTPFirewallRuleSetDAO.CreateOrUpdateSetFromConfig(set) + setId, err := models.SharedHTTPFirewallRuleSetDAO.CreateOrUpdateSetFromConfig(tx, set) if err != nil { return nil, err } @@ -487,18 +509,18 @@ func (this *HTTPFirewallPolicyService) ImportHTTPFirewallPolicy(ctx context.Cont if err != nil { return nil, err } - err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupIsOn(oldGroup.Id, true) + err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupIsOn(tx, oldGroup.Id, true) if err != nil { return nil, err } - err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupSets(oldGroup.Id, setsJSON) + err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupSets(tx, oldGroup.Id, setsJSON) if err != nil { return nil, err } } } else { // 没有代号的直接创建 - groupId, err := models.SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(g) + groupId, err := models.SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(tx, g) if err != nil { return nil, err } @@ -524,7 +546,7 @@ func (this *HTTPFirewallPolicyService) ImportHTTPFirewallPolicy(ctx context.Cont return nil, err } - err = models.SharedHTTPFirewallPolicyDAO.UpdateFirewallPolicyInboundAndOutbound(req.HttpFirewallPolicyId, inboundJSON, outboundJSON) + err = models.SharedHTTPFirewallPolicyDAO.UpdateFirewallPolicyInboundAndOutbound(tx, req.HttpFirewallPolicyId, inboundJSON, outboundJSON) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_http_firewall_rule_group.go b/internal/rpc/services/service_http_firewall_rule_group.go index 6cd3ca77..de53ce5e 100644 --- a/internal/rpc/services/service_http_firewall_rule_group.go +++ b/internal/rpc/services/service_http_firewall_rule_group.go @@ -21,7 +21,9 @@ func (this *HTTPFirewallRuleGroupService) UpdateHTTPFirewallRuleGroupIsOn(ctx co return nil, err } - err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupIsOn(req.FirewallRuleGroupId, req.IsOn) + tx := this.NullTx() + + err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupIsOn(tx, req.FirewallRuleGroupId, req.IsOn) if err != nil { return nil, err } @@ -37,7 +39,9 @@ func (this *HTTPFirewallRuleGroupService) CreateHTTPFirewallRuleGroup(ctx contex return nil, err } - groupId, err := models.SharedHTTPFirewallRuleGroupDAO.CreateGroup(req.IsOn, req.Name, req.Description) + tx := this.NullTx() + + groupId, err := models.SharedHTTPFirewallRuleGroupDAO.CreateGroup(tx, req.IsOn, req.Name, req.Description) if err != nil { return nil, err } @@ -52,7 +56,9 @@ func (this *HTTPFirewallRuleGroupService) UpdateHTTPFirewallRuleGroup(ctx contex return nil, err } - err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroup(req.FirewallRuleGroupId, req.IsOn, req.Name, req.Description) + tx := this.NullTx() + + err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroup(tx, req.FirewallRuleGroupId, req.IsOn, req.Name, req.Description) if err != nil { return nil, err } @@ -68,7 +74,9 @@ func (this *HTTPFirewallRuleGroupService) FindEnabledHTTPFirewallRuleGroupConfig return nil, err } - groupConfig, err := models.SharedHTTPFirewallRuleGroupDAO.ComposeFirewallRuleGroup(req.FirewallRuleGroupId) + tx := this.NullTx() + + groupConfig, err := models.SharedHTTPFirewallRuleGroupDAO.ComposeFirewallRuleGroup(tx, req.FirewallRuleGroupId) if err != nil { return nil, err } @@ -90,7 +98,9 @@ func (this *HTTPFirewallRuleGroupService) FindEnabledHTTPFirewallRuleGroup(ctx c return nil, err } - group, err := models.SharedHTTPFirewallRuleGroupDAO.FindEnabledHTTPFirewallRuleGroup(req.FirewallRuleGroupId) + tx := this.NullTx() + + group, err := models.SharedHTTPFirewallRuleGroupDAO.FindEnabledHTTPFirewallRuleGroup(tx, req.FirewallRuleGroupId) if err != nil { return nil, err } @@ -119,7 +129,9 @@ func (this *HTTPFirewallRuleGroupService) UpdateHTTPFirewallRuleGroupSets(ctx co return nil, err } - err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupSets(req.GetFirewallRuleGroupId(), req.FirewallRuleSetsJSON) + tx := this.NullTx() + + err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupSets(tx, req.GetFirewallRuleGroupId(), req.FirewallRuleSetsJSON) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_http_firewall_rule_set.go b/internal/rpc/services/service_http_firewall_rule_set.go index b582bc48..a37dd1ab 100644 --- a/internal/rpc/services/service_http_firewall_rule_set.go +++ b/internal/rpc/services/service_http_firewall_rule_set.go @@ -28,7 +28,9 @@ func (this *HTTPFirewallRuleSetService) CreateOrUpdateHTTPFirewallRuleSetFromCon return nil, err } - setId, err := models.SharedHTTPFirewallRuleSetDAO.CreateOrUpdateSetFromConfig(setConfig) + tx := this.NullTx() + + setId, err := models.SharedHTTPFirewallRuleSetDAO.CreateOrUpdateSetFromConfig(tx, setConfig) if err != nil { return nil, err } @@ -44,7 +46,9 @@ func (this *HTTPFirewallRuleSetService) UpdateHTTPFirewallRuleSetIsOn(ctx contex return nil, err } - err = models.SharedHTTPFirewallRuleSetDAO.UpdateRuleSetIsOn(req.FirewallRuleSetId, req.IsOn) + tx := this.NullTx() + + err = models.SharedHTTPFirewallRuleSetDAO.UpdateRuleSetIsOn(tx, req.FirewallRuleSetId, req.IsOn) if err != nil { return nil, err } @@ -60,7 +64,9 @@ func (this *HTTPFirewallRuleSetService) FindEnabledHTTPFirewallRuleSetConfig(ctx return nil, err } - config, err := models.SharedHTTPFirewallRuleSetDAO.ComposeFirewallRuleSet(req.FirewallRuleSetId) + tx := this.NullTx() + + config, err := models.SharedHTTPFirewallRuleSetDAO.ComposeFirewallRuleSet(tx, req.FirewallRuleSetId) if err != nil { return nil, err } @@ -82,7 +88,9 @@ func (this *HTTPFirewallRuleSetService) FindEnabledHTTPFirewallRuleSet(ctx conte return nil, err } - set, err := models.SharedHTTPFirewallRuleSetDAO.FindEnabledHTTPFirewallRuleSet(req.FirewallRuleSetId) + tx := this.NullTx() + + set, err := models.SharedHTTPFirewallRuleSetDAO.FindEnabledHTTPFirewallRuleSet(tx, req.FirewallRuleSetId) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_http_header.go b/internal/rpc/services/service_http_header.go index a445138c..17c1e83b 100644 --- a/internal/rpc/services/service_http_header.go +++ b/internal/rpc/services/service_http_header.go @@ -23,7 +23,9 @@ func (this *HTTPHeaderService) CreateHTTPHeader(ctx context.Context, req *pb.Cre // TODO 检查用户权限 } - headerId, err := models.SharedHTTPHeaderDAO.CreateHeader(req.Name, req.Value) + tx := this.NullTx() + + headerId, err := models.SharedHTTPHeaderDAO.CreateHeader(tx, req.Name, req.Value) if err != nil { return nil, err } @@ -43,7 +45,9 @@ func (this *HTTPHeaderService) UpdateHTTPHeader(ctx context.Context, req *pb.Upd // TODO 检查用户权限 } - err = models.SharedHTTPHeaderDAO.UpdateHeader(req.HeaderId, req.Name, req.Value) + tx := this.NullTx() + + err = models.SharedHTTPHeaderDAO.UpdateHeader(tx, req.HeaderId, req.Name, req.Value) if err != nil { return nil, err } @@ -63,7 +67,9 @@ func (this *HTTPHeaderService) FindEnabledHTTPHeaderConfig(ctx context.Context, // TODO 检查用户权限 } - config, err := models.SharedHTTPHeaderDAO.ComposeHeaderConfig(req.HeaderId) + tx := this.NullTx() + + config, err := models.SharedHTTPHeaderDAO.ComposeHeaderConfig(tx, req.HeaderId) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_http_header_policy.go b/internal/rpc/services/service_http_header_policy.go index e46c2ae0..977c9f16 100644 --- a/internal/rpc/services/service_http_header_policy.go +++ b/internal/rpc/services/service_http_header_policy.go @@ -18,7 +18,9 @@ func (this *HTTPHeaderPolicyService) FindEnabledHTTPHeaderPolicyConfig(ctx conte return nil, err } - config, err := models.SharedHTTPHeaderPolicyDAO.ComposeHeaderPolicyConfig(req.HeaderPolicyId) + tx := this.NullTx() + + config, err := models.SharedHTTPHeaderPolicyDAO.ComposeHeaderPolicyConfig(tx, req.HeaderPolicyId) if err != nil { return nil, err } @@ -38,7 +40,9 @@ func (this *HTTPHeaderPolicyService) CreateHTTPHeaderPolicy(ctx context.Context, return nil, err } - headerPolicyId, err := models.SharedHTTPHeaderPolicyDAO.CreateHeaderPolicy() + tx := this.NullTx() + + headerPolicyId, err := models.SharedHTTPHeaderPolicyDAO.CreateHeaderPolicy(tx) if err != nil { return nil, err } @@ -53,7 +57,9 @@ func (this *HTTPHeaderPolicyService) UpdateHTTPHeaderPolicyAddingHeaders(ctx con return nil, err } - err = models.SharedHTTPHeaderPolicyDAO.UpdateAddingHeaders(req.HeaderPolicyId, req.HeadersJSON) + tx := this.NullTx() + + err = models.SharedHTTPHeaderPolicyDAO.UpdateAddingHeaders(tx, req.HeaderPolicyId, req.HeadersJSON) if err != nil { return nil, err } @@ -68,7 +74,9 @@ func (this *HTTPHeaderPolicyService) UpdateHTTPHeaderPolicySettingHeaders(ctx co return nil, err } - err = models.SharedHTTPHeaderPolicyDAO.UpdateSettingHeaders(req.HeaderPolicyId, req.HeadersJSON) + tx := this.NullTx() + + err = models.SharedHTTPHeaderPolicyDAO.UpdateSettingHeaders(tx, req.HeaderPolicyId, req.HeadersJSON) if err != nil { return nil, err } @@ -83,7 +91,9 @@ func (this *HTTPHeaderPolicyService) UpdateHTTPHeaderPolicyAddingTrailers(ctx co return nil, err } - err = models.SharedHTTPHeaderPolicyDAO.UpdateAddingTrailers(req.HeaderPolicyId, req.HeadersJSON) + tx := this.NullTx() + + err = models.SharedHTTPHeaderPolicyDAO.UpdateAddingTrailers(tx, req.HeaderPolicyId, req.HeadersJSON) if err != nil { return nil, err } @@ -98,7 +108,9 @@ func (this *HTTPHeaderPolicyService) UpdateHTTPHeaderPolicyReplacingHeaders(ctx return nil, err } - err = models.SharedHTTPHeaderPolicyDAO.UpdateReplacingHeaders(req.HeaderPolicyId, req.HeadersJSON) + tx := this.NullTx() + + err = models.SharedHTTPHeaderPolicyDAO.UpdateReplacingHeaders(tx, req.HeaderPolicyId, req.HeadersJSON) if err != nil { return nil, err } @@ -113,7 +125,9 @@ func (this *HTTPHeaderPolicyService) UpdateHTTPHeaderPolicyDeletingHeaders(ctx c return nil, err } - err = models.SharedHTTPHeaderPolicyDAO.UpdateDeletingHeaders(req.HeaderPolicyId, req.HeaderNames) + tx := this.NullTx() + + err = models.SharedHTTPHeaderPolicyDAO.UpdateDeletingHeaders(tx, req.HeaderPolicyId, req.HeaderNames) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_http_location.go b/internal/rpc/services/service_http_location.go index b81fd635..8d9a8999 100644 --- a/internal/rpc/services/service_http_location.go +++ b/internal/rpc/services/service_http_location.go @@ -21,7 +21,9 @@ func (this *HTTPLocationService) CreateHTTPLocation(ctx context.Context, req *pb return nil, err } - locationId, err := models.SharedHTTPLocationDAO.CreateLocation(req.ParentId, req.Name, req.Pattern, req.Description, req.IsBreak) + tx := this.NullTx() + + locationId, err := models.SharedHTTPLocationDAO.CreateLocation(tx, req.ParentId, req.Name, req.Pattern, req.Description, req.IsBreak) if err != nil { return nil, err } @@ -37,7 +39,9 @@ func (this *HTTPLocationService) UpdateHTTPLocation(ctx context.Context, req *pb return nil, err } - err = models.SharedHTTPLocationDAO.UpdateLocation(req.LocationId, req.Name, req.Pattern, req.Description, req.IsOn, req.IsBreak) + tx := this.NullTx() + + err = models.SharedHTTPLocationDAO.UpdateLocation(tx, req.LocationId, req.Name, req.Pattern, req.Description, req.IsOn, req.IsBreak) if err != nil { return nil, err } @@ -53,7 +57,9 @@ func (this *HTTPLocationService) FindEnabledHTTPLocationConfig(ctx context.Conte return nil, err } - config, err := models.SharedHTTPLocationDAO.ComposeLocationConfig(req.LocationId) + tx := this.NullTx() + + config, err := models.SharedHTTPLocationDAO.ComposeLocationConfig(tx, req.LocationId) if err != nil { return nil, err } @@ -72,7 +78,9 @@ func (this *HTTPLocationService) DeleteHTTPLocation(ctx context.Context, req *pb return nil, err } - err = models.SharedHTTPLocationDAO.DisableHTTPLocation(req.LocationId) + tx := this.NullTx() + + err = models.SharedHTTPLocationDAO.DisableHTTPLocation(tx, req.LocationId) if err != nil { return nil, err } @@ -87,12 +95,14 @@ func (this *HTTPLocationService) FindAndInitHTTPLocationReverseProxyConfig(ctx c return nil, err } - reverseProxyRef, err := models.SharedHTTPLocationDAO.FindLocationReverseProxy(req.LocationId) + tx := this.NullTx() + + reverseProxyRef, err := models.SharedHTTPLocationDAO.FindLocationReverseProxy(tx, req.LocationId) if err != nil { return nil, err } if reverseProxyRef == nil || reverseProxyRef.ReverseProxyId <= 0 { - reverseProxyId, err := models.SharedReverseProxyDAO.CreateReverseProxy(adminId, userId, nil, nil, nil) + reverseProxyId, err := models.SharedReverseProxyDAO.CreateReverseProxy(tx, adminId, userId, nil, nil, nil) if err != nil { return nil, err } @@ -104,13 +114,13 @@ func (this *HTTPLocationService) FindAndInitHTTPLocationReverseProxyConfig(ctx c if err != nil { return nil, err } - err = models.SharedHTTPLocationDAO.UpdateLocationReverseProxy(req.LocationId, reverseProxyJSON) + err = models.SharedHTTPLocationDAO.UpdateLocationReverseProxy(tx, req.LocationId, reverseProxyJSON) if err != nil { return nil, err } } - reverseProxyConfig, err := models.SharedReverseProxyDAO.ComposeReverseProxyConfig(reverseProxyRef.ReverseProxyId) + reverseProxyConfig, err := models.SharedReverseProxyDAO.ComposeReverseProxyConfig(tx, reverseProxyRef.ReverseProxyId) if err != nil { return nil, err } @@ -138,23 +148,25 @@ func (this *HTTPLocationService) FindAndInitHTTPLocationWebConfig(ctx context.Co return nil, rpcutils.Wrap("ValidateRequest()", err) } - webId, err := models.SharedHTTPLocationDAO.FindLocationWebId(req.LocationId) + tx := this.NullTx() + + webId, err := models.SharedHTTPLocationDAO.FindLocationWebId(tx, req.LocationId) if err != nil { return nil, rpcutils.Wrap("FindLocationWebId()", err) } if webId <= 0 { - webId, err = models.SharedHTTPWebDAO.CreateWeb(adminId, userId, nil) + webId, err = models.SharedHTTPWebDAO.CreateWeb(tx, adminId, userId, nil) if err != nil { return nil, rpcutils.Wrap("CreateWeb()", err) } - err = models.SharedHTTPLocationDAO.UpdateLocationWeb(req.LocationId, webId) + err = models.SharedHTTPLocationDAO.UpdateLocationWeb(tx, req.LocationId, webId) if err != nil { return nil, rpcutils.Wrap("UpdateLocationWeb()", err) } } - config, err := models.SharedHTTPWebDAO.ComposeWebConfig(webId) + config, err := models.SharedHTTPWebDAO.ComposeWebConfig(tx, webId) if err != nil { return nil, rpcutils.Wrap("ComposeWebConfig()", err) } @@ -175,7 +187,9 @@ func (this *HTTPLocationService) UpdateHTTPLocationReverseProxy(ctx context.Cont return nil, err } - err = models.SharedHTTPLocationDAO.UpdateLocationReverseProxy(req.LocationId, req.ReverseProxyJSON) + tx := this.NullTx() + + err = models.SharedHTTPLocationDAO.UpdateLocationReverseProxy(tx, req.LocationId, req.ReverseProxyJSON) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_http_page.go b/internal/rpc/services/service_http_page.go index c6b5ee7b..618511c4 100644 --- a/internal/rpc/services/service_http_page.go +++ b/internal/rpc/services/service_http_page.go @@ -21,7 +21,9 @@ func (this *HTTPPageService) CreateHTTPPage(ctx context.Context, req *pb.CreateH return nil, err } - pageId, err := models.SharedHTTPPageDAO.CreatePage(req.StatusList, req.Url, types.Int(req.NewStatus)) + tx := this.NullTx() + + pageId, err := models.SharedHTTPPageDAO.CreatePage(tx, req.StatusList, req.Url, types.Int(req.NewStatus)) if err != nil { return nil, err } @@ -37,7 +39,9 @@ func (this *HTTPPageService) UpdateHTTPPage(ctx context.Context, req *pb.UpdateH return nil, err } - err = models.SharedHTTPPageDAO.UpdatePage(req.PageId, req.StatusList, req.Url, types.Int(req.NewStatus)) + tx := this.NullTx() + + err = models.SharedHTTPPageDAO.UpdatePage(tx, req.PageId, req.StatusList, req.Url, types.Int(req.NewStatus)) if err != nil { return nil, err } @@ -53,7 +57,9 @@ func (this *HTTPPageService) FindEnabledHTTPPageConfig(ctx context.Context, req return nil, err } - config, err := models.SharedHTTPPageDAO.ComposePageConfig(req.PageId) + tx := this.NullTx() + + config, err := models.SharedHTTPPageDAO.ComposePageConfig(tx, req.PageId) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_http_rewrite_rule.go b/internal/rpc/services/service_http_rewrite_rule.go index 1d5153ab..1d76ea28 100644 --- a/internal/rpc/services/service_http_rewrite_rule.go +++ b/internal/rpc/services/service_http_rewrite_rule.go @@ -20,7 +20,9 @@ func (this *HTTPRewriteRuleService) CreateHTTPRewriteRule(ctx context.Context, r return nil, err } - rewriteRuleId, err := models.SharedHTTPRewriteRuleDAO.CreateRewriteRule(req.Pattern, req.Replace, req.Mode, types.Int(req.RedirectStatus), req.IsBreak, req.ProxyHost, req.WithQuery, req.IsOn) + tx := this.NullTx() + + rewriteRuleId, err := models.SharedHTTPRewriteRuleDAO.CreateRewriteRule(tx, req.Pattern, req.Replace, req.Mode, types.Int(req.RedirectStatus), req.IsBreak, req.ProxyHost, req.WithQuery, req.IsOn) if err != nil { return nil, err } @@ -36,7 +38,9 @@ func (this *HTTPRewriteRuleService) UpdateHTTPRewriteRule(ctx context.Context, r return nil, err } - err = models.SharedHTTPRewriteRuleDAO.UpdateRewriteRule(req.RewriteRuleId, req.Pattern, req.Replace, req.Mode, types.Int(req.RedirectStatus), req.IsBreak, req.ProxyHost, req.WithQuery, req.IsOn) + tx := this.NullTx() + + err = models.SharedHTTPRewriteRuleDAO.UpdateRewriteRule(tx, req.RewriteRuleId, req.Pattern, req.Replace, req.Mode, types.Int(req.RedirectStatus), req.IsBreak, req.ProxyHost, req.WithQuery, req.IsOn) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_http_web.go b/internal/rpc/services/service_http_web.go index e5110410..2c33daeb 100644 --- a/internal/rpc/services/service_http_web.go +++ b/internal/rpc/services/service_http_web.go @@ -19,7 +19,9 @@ func (this *HTTPWebService) CreateHTTPWeb(ctx context.Context, req *pb.CreateHTT return nil, err } - webId, err := models.SharedHTTPWebDAO.CreateWeb(adminId, userId, req.RootJSON) + tx := this.NullTx() + + webId, err := models.SharedHTTPWebDAO.CreateWeb(tx, adminId, userId, req.RootJSON) if err != nil { return nil, err } @@ -39,7 +41,9 @@ func (this *HTTPWebService) FindEnabledHTTPWeb(ctx context.Context, req *pb.Find // TODO 检查用户权限 } - web, err := models.SharedHTTPWebDAO.FindEnabledHTTPWeb(req.WebId) + tx := this.NullTx() + + web, err := models.SharedHTTPWebDAO.FindEnabledHTTPWeb(tx, req.WebId) if err != nil { return nil, err } @@ -66,7 +70,9 @@ func (this *HTTPWebService) FindEnabledHTTPWebConfig(ctx context.Context, req *p // TODO 检查用户权限 } - config, err := models.SharedHTTPWebDAO.ComposeWebConfig(req.WebId) + tx := this.NullTx() + + config, err := models.SharedHTTPWebDAO.ComposeWebConfig(tx, req.WebId) if err != nil { return nil, err } @@ -90,7 +96,9 @@ func (this *HTTPWebService) UpdateHTTPWeb(ctx context.Context, req *pb.UpdateHTT // TODO 检查用户权限 } - err = models.SharedHTTPWebDAO.UpdateWeb(req.WebId, req.RootJSON) + tx := this.NullTx() + + err = models.SharedHTTPWebDAO.UpdateWeb(tx, req.WebId, req.RootJSON) if err != nil { return nil, err } @@ -110,7 +118,9 @@ func (this *HTTPWebService) UpdateHTTPWebGzip(ctx context.Context, req *pb.Updat // TODO 检查用户权限 } - err = models.SharedHTTPWebDAO.UpdateWebGzip(req.WebId, req.GzipJSON) + tx := this.NullTx() + + err = models.SharedHTTPWebDAO.UpdateWebGzip(tx, req.WebId, req.GzipJSON) if err != nil { return nil, err } @@ -130,7 +140,9 @@ func (this *HTTPWebService) UpdateHTTPWebCharset(ctx context.Context, req *pb.Up // TODO 检查用户权限 } - err = models.SharedHTTPWebDAO.UpdateWebCharset(req.WebId, req.CharsetJSON) + tx := this.NullTx() + + err = models.SharedHTTPWebDAO.UpdateWebCharset(tx, req.WebId, req.CharsetJSON) if err != nil { return nil, err } @@ -149,7 +161,9 @@ func (this *HTTPWebService) UpdateHTTPWebRequestHeader(ctx context.Context, req // TODO 检查用户权限 } - err = models.SharedHTTPWebDAO.UpdateWebRequestHeaderPolicy(req.WebId, req.HeaderJSON) + tx := this.NullTx() + + err = models.SharedHTTPWebDAO.UpdateWebRequestHeaderPolicy(tx, req.WebId, req.HeaderJSON) if err != nil { return nil, err } @@ -169,7 +183,9 @@ func (this *HTTPWebService) UpdateHTTPWebResponseHeader(ctx context.Context, req // TODO 检查用户权限 } - err = models.SharedHTTPWebDAO.UpdateWebResponseHeaderPolicy(req.WebId, req.HeaderJSON) + tx := this.NullTx() + + err = models.SharedHTTPWebDAO.UpdateWebResponseHeaderPolicy(tx, req.WebId, req.HeaderJSON) if err != nil { return nil, err } @@ -189,7 +205,9 @@ func (this *HTTPWebService) UpdateHTTPWebShutdown(ctx context.Context, req *pb.U // TODO 检查用户权限 } - err = models.SharedHTTPWebDAO.UpdateWebShutdown(req.WebId, req.ShutdownJSON) + tx := this.NullTx() + + err = models.SharedHTTPWebDAO.UpdateWebShutdown(tx, req.WebId, req.ShutdownJSON) if err != nil { return nil, err } @@ -208,7 +226,9 @@ func (this *HTTPWebService) UpdateHTTPWebPages(ctx context.Context, req *pb.Upda // TODO 检查用户权限 } - err = models.SharedHTTPWebDAO.UpdateWebPages(req.WebId, req.PagesJSON) + tx := this.NullTx() + + err = models.SharedHTTPWebDAO.UpdateWebPages(tx, req.WebId, req.PagesJSON) if err != nil { return nil, err } @@ -227,7 +247,9 @@ func (this *HTTPWebService) UpdateHTTPWebAccessLog(ctx context.Context, req *pb. // TODO 检查用户权限 } - err = models.SharedHTTPWebDAO.UpdateWebAccessLogConfig(req.WebId, req.AccessLogJSON) + tx := this.NullTx() + + err = models.SharedHTTPWebDAO.UpdateWebAccessLogConfig(tx, req.WebId, req.AccessLogJSON) if err != nil { return nil, err } @@ -246,7 +268,9 @@ func (this *HTTPWebService) UpdateHTTPWebStat(ctx context.Context, req *pb.Updat // TODO 检查用户权限 } - err = models.SharedHTTPWebDAO.UpdateWebStat(req.WebId, req.StatJSON) + tx := this.NullTx() + + err = models.SharedHTTPWebDAO.UpdateWebStat(tx, req.WebId, req.StatJSON) if err != nil { return nil, err } @@ -265,7 +289,9 @@ func (this *HTTPWebService) UpdateHTTPWebCache(ctx context.Context, req *pb.Upda // TODO 检查权限 } - err = models.SharedHTTPWebDAO.UpdateWebCache(req.WebId, req.CacheJSON) + tx := this.NullTx() + + err = models.SharedHTTPWebDAO.UpdateWebCache(tx, req.WebId, req.CacheJSON) if err != nil { return nil, err } @@ -285,7 +311,9 @@ func (this *HTTPWebService) UpdateHTTPWebFirewall(ctx context.Context, req *pb.U // TODO 检查用户权限 } - err = models.SharedHTTPWebDAO.UpdateWebFirewall(req.WebId, req.FirewallJSON) + tx := this.NullTx() + + err = models.SharedHTTPWebDAO.UpdateWebFirewall(tx, req.WebId, req.FirewallJSON) if err != nil { return nil, err } @@ -305,7 +333,9 @@ func (this *HTTPWebService) UpdateHTTPWebLocations(ctx context.Context, req *pb. // TODO 检查用户权限 } - err = models.SharedHTTPWebDAO.UpdateWebLocations(req.WebId, req.LocationsJSON) + tx := this.NullTx() + + err = models.SharedHTTPWebDAO.UpdateWebLocations(tx, req.WebId, req.LocationsJSON) if err != nil { return nil, err } @@ -323,7 +353,9 @@ func (this *HTTPWebService) UpdateHTTPWebRedirectToHTTPS(ctx context.Context, re // TODO 检查权限 - err = models.SharedHTTPWebDAO.UpdateWebRedirectToHTTPS(req.WebId, req.RedirectToHTTPSJSON) + tx := this.NullTx() + + err = models.SharedHTTPWebDAO.UpdateWebRedirectToHTTPS(tx, req.WebId, req.RedirectToHTTPSJSON) if err != nil { return nil, err } @@ -340,7 +372,9 @@ func (this *HTTPWebService) UpdateHTTPWebWebsocket(ctx context.Context, req *pb. // TODO 检查权限 - err = models.SharedHTTPWebDAO.UpdateWebsocket(req.WebId, req.WebsocketJSON) + tx := this.NullTx() + + err = models.SharedHTTPWebDAO.UpdateWebsocket(tx, req.WebId, req.WebsocketJSON) if err != nil { return nil, err } @@ -359,7 +393,9 @@ func (this *HTTPWebService) UpdateHTTPWebRewriteRules(ctx context.Context, req * // TODO 检查用户权限 } - err = models.SharedHTTPWebDAO.UpdateWebRewriteRules(req.WebId, req.RewriteRulesJSON) + tx := this.NullTx() + + err = models.SharedHTTPWebDAO.UpdateWebRewriteRules(tx, req.WebId, req.RewriteRulesJSON) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_http_websocket.go b/internal/rpc/services/service_http_websocket.go index dd7f9cb5..f34a6412 100644 --- a/internal/rpc/services/service_http_websocket.go +++ b/internal/rpc/services/service_http_websocket.go @@ -18,7 +18,9 @@ func (this *HTTPWebsocketService) CreateHTTPWebsocket(ctx context.Context, req * return nil, err } - websocketId, err := models.SharedHTTPWebsocketDAO.CreateWebsocket(req.HandshakeTimeoutJSON, req.AllowAllOrigins, req.AllowedOrigins, req.RequestSameOrigin, req.RequestOrigin) + tx := this.NullTx() + + websocketId, err := models.SharedHTTPWebsocketDAO.CreateWebsocket(tx, req.HandshakeTimeoutJSON, req.AllowAllOrigins, req.AllowedOrigins, req.RequestSameOrigin, req.RequestOrigin) if err != nil { return nil, err } @@ -35,7 +37,9 @@ func (this *HTTPWebsocketService) UpdateHTTPWebsocket(ctx context.Context, req * // TODO 用户不能修改别人的WebSocket设置 - err = models.SharedHTTPWebsocketDAO.UpdateWebsocket(req.WebsocketId, req.HandshakeTimeoutJSON, req.AllowAllOrigins, req.AllowedOrigins, req.RequestSameOrigin, req.RequestOrigin) + tx := this.NullTx() + + err = models.SharedHTTPWebsocketDAO.UpdateWebsocket(tx, req.WebsocketId, req.HandshakeTimeoutJSON, req.AllowAllOrigins, req.AllowedOrigins, req.RequestSameOrigin, req.RequestOrigin) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_ip_item.go b/internal/rpc/services/service_ip_item.go index f0ec107a..1bf1ef6d 100644 --- a/internal/rpc/services/service_ip_item.go +++ b/internal/rpc/services/service_ip_item.go @@ -20,13 +20,13 @@ func (this *IPItemService) CreateIPItem(ctx context.Context, req *pb.CreateIPIte return nil, err } - itemId, err := models.SharedIPItemDAO.CreateIPItem(req.IpListId, req.IpFrom, req.IpTo, req.ExpiredAt, req.Reason) + tx := this.NullTx() + + itemId, err := models.SharedIPItemDAO.CreateIPItem(tx, req.IpListId, req.IpFrom, req.IpTo, req.ExpiredAt, req.Reason) if err != nil { return nil, err } - - return &pb.CreateIPItemResponse{IpItemId: itemId}, nil } @@ -38,7 +38,9 @@ func (this *IPItemService) UpdateIPItem(ctx context.Context, req *pb.UpdateIPIte return nil, err } - err = models.SharedIPItemDAO.UpdateIPItem(req.IpItemId, req.IpFrom, req.IpTo, req.ExpiredAt, req.Reason) + tx := this.NullTx() + + err = models.SharedIPItemDAO.UpdateIPItem(tx, req.IpItemId, req.IpFrom, req.IpTo, req.ExpiredAt, req.Reason) if err != nil { return nil, err } @@ -53,7 +55,9 @@ func (this *IPItemService) DeleteIPItem(ctx context.Context, req *pb.DeleteIPIte return nil, err } - err = models.SharedIPItemDAO.DisableIPItem(req.IpItemId) + tx := this.NullTx() + + err = models.SharedIPItemDAO.DisableIPItem(tx, req.IpItemId) if err != nil { return nil, err } @@ -68,7 +72,9 @@ func (this *IPItemService) CountIPItemsWithListId(ctx context.Context, req *pb.C return nil, err } - count, err := models.SharedIPItemDAO.CountIPItemsWithListId(req.IpListId) + tx := this.NullTx() + + count, err := models.SharedIPItemDAO.CountIPItemsWithListId(tx, req.IpListId) if err != nil { return nil, err } @@ -83,7 +89,9 @@ func (this *IPItemService) ListIPItemsWithListId(ctx context.Context, req *pb.Li return nil, err } - items, err := models.SharedIPItemDAO.ListIPItemsWithListId(req.IpListId, req.Offset, req.Size) + tx := this.NullTx() + + items, err := models.SharedIPItemDAO.ListIPItemsWithListId(tx, req.IpListId, req.Offset, req.Size) if err != nil { return nil, err } @@ -110,7 +118,9 @@ func (this *IPItemService) FindEnabledIPItem(ctx context.Context, req *pb.FindEn return nil, err } - item, err := models.SharedIPItemDAO.FindEnabledIPItem(req.IpItemId) + tx := this.NullTx() + + item, err := models.SharedIPItemDAO.FindEnabledIPItem(tx, req.IpItemId) if err != nil { return nil, err } @@ -135,8 +145,10 @@ func (this *IPItemService) ListIPItemsAfterVersion(ctx context.Context, req *pb. return nil, err } + tx := this.NullTx() + result := []*pb.IPItem{} - items, err := models.SharedIPItemDAO.ListIPItemsAfterVersion(req.Version, req.Size) + items, err := models.SharedIPItemDAO.ListIPItemsAfterVersion(tx, req.Version, req.Size) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_ip_library.go b/internal/rpc/services/service_ip_library.go index 4a02db62..ab5dfb70 100644 --- a/internal/rpc/services/service_ip_library.go +++ b/internal/rpc/services/service_ip_library.go @@ -21,7 +21,9 @@ func (this *IPLibraryService) CreateIPLibrary(ctx context.Context, req *pb.Creat return nil, err } - ipLibraryId, err := models.SharedIPLibraryDAO.CreateIPLibrary(req.Type, req.FileId) + tx := this.NullTx() + + ipLibraryId, err := models.SharedIPLibraryDAO.CreateIPLibrary(tx, req.Type, req.FileId) if err != nil { return nil, err } @@ -39,7 +41,9 @@ func (this *IPLibraryService) FindEnabledIPLibrary(ctx context.Context, req *pb. return nil, err } - ipLibrary, err := models.SharedIPLibraryDAO.FindEnabledIPLibrary(req.IpLibraryId) + tx := this.NullTx() + + ipLibrary, err := models.SharedIPLibraryDAO.FindEnabledIPLibrary(tx, req.IpLibraryId) if err != nil { return nil, err } @@ -49,7 +53,7 @@ func (this *IPLibraryService) FindEnabledIPLibrary(ctx context.Context, req *pb. // 文件相关 var pbFile *pb.File = nil - file, err := models.SharedFileDAO.FindEnabledFile(int64(ipLibrary.FileId)) + file, err := models.SharedFileDAO.FindEnabledFile(tx, int64(ipLibrary.FileId)) if err != nil { return nil, err } @@ -79,7 +83,9 @@ func (this *IPLibraryService) FindLatestIPLibraryWithType(ctx context.Context, r return nil, err } - ipLibrary, err := models.SharedIPLibraryDAO.FindLatestIPLibraryWithType(req.Type) + tx := this.NullTx() + + ipLibrary, err := models.SharedIPLibraryDAO.FindLatestIPLibraryWithType(tx, req.Type) if err != nil { return nil, err } @@ -89,7 +95,7 @@ func (this *IPLibraryService) FindLatestIPLibraryWithType(ctx context.Context, r // 文件相关 var pbFile *pb.File = nil - file, err := models.SharedFileDAO.FindEnabledFile(int64(ipLibrary.FileId)) + file, err := models.SharedFileDAO.FindEnabledFile(tx, int64(ipLibrary.FileId)) if err != nil { return nil, err } @@ -119,7 +125,9 @@ func (this *IPLibraryService) FindAllEnabledIPLibrariesWithType(ctx context.Cont return nil, err } - ipLibraries, err := models.SharedIPLibraryDAO.FindAllEnabledIPLibrariesWithType(req.Type) + tx := this.NullTx() + + ipLibraries, err := models.SharedIPLibraryDAO.FindAllEnabledIPLibrariesWithType(tx, req.Type) if err != nil { return nil, err } @@ -127,7 +135,7 @@ func (this *IPLibraryService) FindAllEnabledIPLibrariesWithType(ctx context.Cont for _, library := range ipLibraries { // 文件相关 var pbFile *pb.File = nil - file, err := models.SharedFileDAO.FindEnabledFile(int64(library.FileId)) + file, err := models.SharedFileDAO.FindEnabledFile(tx, int64(library.FileId)) if err != nil { return nil, err } @@ -157,7 +165,9 @@ func (this *IPLibraryService) DeleteIPLibrary(ctx context.Context, req *pb.Delet return nil, err } - err = models.SharedIPLibraryDAO.DisableIPLibrary(req.IpLibraryId) + tx := this.NullTx() + + err = models.SharedIPLibraryDAO.DisableIPLibrary(tx, req.IpLibraryId) if err != nil { return nil, err } @@ -180,12 +190,14 @@ func (this *IPLibraryService) LookupIPRegion(ctx context.Context, req *pb.Lookup return &pb.LookupIPRegionResponse{Region: nil}, nil } - countryId, err := models.SharedRegionCountryDAO.FindCountryIdWithCountryName(result.Country) + tx := this.NullTx() + + countryId, err := models.SharedRegionCountryDAO.FindCountryIdWithCountryName(tx, result.Country) if err != nil { return nil, err } - provinceId, err := models.SharedRegionProvinceDAO.FindProvinceIdWithProvinceName(result.Province) + provinceId, err := models.SharedRegionProvinceDAO.FindProvinceIdWithProvinceName(tx, result.Province) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_ip_list.go b/internal/rpc/services/service_ip_list.go index 27698660..25981828 100644 --- a/internal/rpc/services/service_ip_list.go +++ b/internal/rpc/services/service_ip_list.go @@ -20,7 +20,9 @@ func (this *IPListService) CreateIPList(ctx context.Context, req *pb.CreateIPLis return nil, err } - listId, err := models.SharedIPListDAO.CreateIPList(req.Type, req.Name, req.Code, req.TimeoutJSON) + tx := this.NullTx() + + listId, err := models.SharedIPListDAO.CreateIPList(tx, req.Type, req.Name, req.Code, req.TimeoutJSON) if err != nil { return nil, err } @@ -35,7 +37,9 @@ func (this *IPListService) UpdateIPList(ctx context.Context, req *pb.UpdateIPLis return nil, err } - err = models.SharedIPListDAO.UpdateIPList(req.IpListId, req.Name, req.Code, req.TimeoutJSON) + tx := this.NullTx() + + err = models.SharedIPListDAO.UpdateIPList(tx, req.IpListId, req.Name, req.Code, req.TimeoutJSON) if err != nil { return nil, err } @@ -50,7 +54,9 @@ func (this *IPListService) FindEnabledIPList(ctx context.Context, req *pb.FindEn return nil, err } - list, err := models.SharedIPListDAO.FindEnabledIPList(req.IpListId) + tx := this.NullTx() + + list, err := models.SharedIPListDAO.FindEnabledIPList(tx, req.IpListId) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_log.go b/internal/rpc/services/service_log.go index 15f8abb0..e1a357b9 100644 --- a/internal/rpc/services/service_log.go +++ b/internal/rpc/services/service_log.go @@ -20,7 +20,9 @@ func (this *LogService) CreateLog(ctx context.Context, req *pb.CreateLogRequest) return nil, err } - err = models.SharedLogDAO.CreateLog(userType, userId, req.Level, req.Description, req.Action, req.Ip) + tx := this.NullTx() + + err = models.SharedLogDAO.CreateLog(tx, userType, userId, req.Level, req.Description, req.Action, req.Ip) if err != nil { return nil, err } @@ -35,7 +37,9 @@ func (this *LogService) CountLogs(ctx context.Context, req *pb.CountLogRequest) return nil, err } - count, err := models.SharedLogDAO.CountLogs(req.DayFrom, req.DayTo, req.Keyword, req.UserType) + tx := this.NullTx() + + count, err := models.SharedLogDAO.CountLogs(tx, req.DayFrom, req.DayTo, req.Keyword, req.UserType) if err != nil { return nil, err } @@ -50,7 +54,9 @@ func (this *LogService) ListLogs(ctx context.Context, req *pb.ListLogsRequest) ( return nil, err } - logs, err := models.SharedLogDAO.ListLogs(req.Offset, req.Size, req.DayFrom, req.DayTo, req.Keyword, req.UserType) + tx := this.NullTx() + + logs, err := models.SharedLogDAO.ListLogs(tx, req.Offset, req.Size, req.DayFrom, req.DayTo, req.Keyword, req.UserType) if err != nil { return nil, err } @@ -59,11 +65,11 @@ func (this *LogService) ListLogs(ctx context.Context, req *pb.ListLogsRequest) ( for _, log := range logs { userName := "" if log.AdminId > 0 { - userName, err = models.SharedAdminDAO.FindAdminFullname(int64(log.AdminId)) + userName, err = models.SharedAdminDAO.FindAdminFullname(tx, int64(log.AdminId)) } else if log.UserId > 0 { - userName, err = models.SharedUserDAO.FindUserFullname(int64(log.UserId)) + userName, err = models.SharedUserDAO.FindUserFullname(tx, int64(log.UserId)) } else if log.ProviderId > 0 { - userName, err = models.SharedProviderDAO.FindProviderName(int64(log.ProviderId)) + userName, err = models.SharedProviderDAO.FindProviderName(tx, int64(log.ProviderId)) } if err != nil { @@ -97,8 +103,10 @@ func (this *LogService) DeleteLogPermanently(ctx context.Context, req *pb.Delete // TODO 校验权限 + tx := this.NullTx() + // 执行物理删除 - err = models.SharedLogDAO.DeleteLogPermanently(req.LogId) + err = models.SharedLogDAO.DeleteLogPermanently(tx, req.LogId) if err != nil { return nil, err } @@ -115,9 +123,11 @@ func (this *LogService) DeleteLogsPermanently(ctx context.Context, req *pb.Delet // TODO 校验权限 + tx := this.NullTx() + // 执行物理删除 for _, logId := range req.LogIds { - err = models.SharedLogDAO.DeleteLogPermanently(logId) + err = models.SharedLogDAO.DeleteLogPermanently(tx, logId) if err != nil { return nil, err } @@ -135,13 +145,15 @@ func (this *LogService) CleanLogsPermanently(ctx context.Context, req *pb.CleanL // TODO 校验权限 + tx := this.NullTx() + if req.ClearAll { - err = models.SharedLogDAO.DeleteAllLogsPermanently() + err = models.SharedLogDAO.DeleteAllLogsPermanently(tx) if err != nil { return nil, err } } else if req.Days > 0 { - err = models.SharedLogDAO.DeleteLogsPermanentlyBeforeDays(int(req.Days)) + err = models.SharedLogDAO.DeleteLogsPermanentlyBeforeDays(tx, int(req.Days)) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_login.go b/internal/rpc/services/service_login.go index e091f90e..229028cd 100644 --- a/internal/rpc/services/service_login.go +++ b/internal/rpc/services/service_login.go @@ -20,7 +20,10 @@ func (this *LoginService) FindEnabledLogin(ctx context.Context, req *pb.FindEnab if err != nil { return nil, err } - login, err := models.SharedLoginDAO.FindEnabledLoginWithAdminId(req.AdminId, req.Type) + + tx := this.NullTx() + + login, err := models.SharedLoginDAO.FindEnabledLoginWithAdminId(tx, req.AdminId, req.Type) if err != nil { return nil, err } @@ -48,6 +51,8 @@ func (this *LoginService) UpdateLogin(ctx context.Context, req *pb.UpdateLoginRe return nil, errors.New("'login' should not be nil") } + tx := this.NullTx() + if req.Login.IsOn { params := maps.Map{} if len(req.Login.ParamsJSON) > 0 { @@ -56,12 +61,12 @@ func (this *LoginService) UpdateLogin(ctx context.Context, req *pb.UpdateLoginRe return nil, err } } - err = models.SharedLoginDAO.UpdateLogin(req.Login.AdminId, req.Login.Type, params, req.Login.IsOn) + err = models.SharedLoginDAO.UpdateLogin(tx, req.Login.AdminId, req.Login.Type, params, req.Login.IsOn) if err != nil { return nil, err } } else { - err = models.SharedLoginDAO.DisableLoginWithAdminId(req.Login.AdminId, req.Login.Type) + err = models.SharedLoginDAO.DisableLoginWithAdminId(tx, req.Login.AdminId, req.Login.Type) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_message.go b/internal/rpc/services/service_message.go index aa81cf73..c363d01b 100644 --- a/internal/rpc/services/service_message.go +++ b/internal/rpc/services/service_message.go @@ -19,7 +19,9 @@ func (this *MessageService) CountUnreadMessages(ctx context.Context, req *pb.Cou return nil, err } - count, err := models.SharedMessageDAO.CountUnreadMessages(adminId, userId) + tx := this.NullTx() + + count, err := models.SharedMessageDAO.CountUnreadMessages(tx, adminId, userId) if err != nil { return nil, err } @@ -34,7 +36,9 @@ func (this *MessageService) ListUnreadMessages(ctx context.Context, req *pb.List return nil, err } - messages, err := models.SharedMessageDAO.ListUnreadMessages(adminId, userId, req.Offset, req.Size) + tx := this.NullTx() + + messages, err := models.SharedMessageDAO.ListUnreadMessages(tx, adminId, userId, req.Offset, req.Size) if err != nil { return nil, err } @@ -44,7 +48,7 @@ func (this *MessageService) ListUnreadMessages(ctx context.Context, req *pb.List var pbNode *pb.Node = nil if message.ClusterId > 0 { - cluster, err := models.SharedNodeClusterDAO.FindEnabledNodeCluster(int64(message.ClusterId)) + cluster, err := models.SharedNodeClusterDAO.FindEnabledNodeCluster(tx, int64(message.ClusterId)) if err != nil { return nil, err } @@ -57,7 +61,7 @@ func (this *MessageService) ListUnreadMessages(ctx context.Context, req *pb.List } if message.NodeId > 0 { - node, err := models.SharedNodeDAO.FindEnabledNode(int64(message.NodeId)) + node, err := models.SharedNodeDAO.FindEnabledNode(tx, int64(message.NodeId)) if err != nil { return nil, err } @@ -93,8 +97,10 @@ func (this *MessageService) UpdateMessageRead(ctx context.Context, req *pb.Updat return nil, err } + tx := this.NullTx() + // 校验权限 - exists, err := models.SharedMessageDAO.CheckMessageUser(req.MessageId, adminId, userId) + exists, err := models.SharedMessageDAO.CheckMessageUser(tx, req.MessageId, adminId, userId) if err != nil { return nil, err } @@ -102,7 +108,7 @@ func (this *MessageService) UpdateMessageRead(ctx context.Context, req *pb.Updat return nil, this.PermissionError() } - err = models.SharedMessageDAO.UpdateMessageRead(req.MessageId, req.IsRead) + err = models.SharedMessageDAO.UpdateMessageRead(tx, req.MessageId, req.IsRead) if err != nil { return nil, err } @@ -117,9 +123,11 @@ func (this *MessageService) UpdateMessagesRead(ctx context.Context, req *pb.Upda return nil, err } + tx := this.NullTx() + // 校验权限 for _, messageId := range req.MessageIds { - exists, err := models.SharedMessageDAO.CheckMessageUser(messageId, adminId, userId) + exists, err := models.SharedMessageDAO.CheckMessageUser(tx, messageId, adminId, userId) if err != nil { return nil, err } @@ -127,7 +135,7 @@ func (this *MessageService) UpdateMessagesRead(ctx context.Context, req *pb.Upda return nil, this.PermissionError() } - err = models.SharedMessageDAO.UpdateMessageRead(messageId, req.IsRead) + err = models.SharedMessageDAO.UpdateMessageRead(tx, messageId, req.IsRead) if err != nil { return nil, err } @@ -144,7 +152,9 @@ func (this *MessageService) UpdateAllMessagesRead(ctx context.Context, req *pb.U return nil, err } - err = models.SharedMessageDAO.UpdateAllMessagesRead(adminId, userId) + tx := this.NullTx() + + err = models.SharedMessageDAO.UpdateAllMessagesRead(tx, adminId, userId) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_node.go b/internal/rpc/services/service_node.go index 7b707487..84d8cdfd 100644 --- a/internal/rpc/services/service_node.go +++ b/internal/rpc/services/service_node.go @@ -43,14 +43,16 @@ func (this *NodeService) CreateNode(ctx context.Context, req *pb.CreateNodeReque return nil, err } - nodeId, err := models.SharedNodeDAO.CreateNode(adminId, req.Name, req.NodeClusterId, req.GroupId, req.RegionId) + tx := this.NullTx() + + nodeId, err := models.SharedNodeDAO.CreateNode(tx, adminId, req.Name, req.NodeClusterId, req.GroupId, req.RegionId) if err != nil { return nil, err } // 增加认证相关 if req.Login != nil { - _, err = models.SharedNodeLoginDAO.CreateNodeLogin(nodeId, req.Login.Name, req.Login.Type, req.Login.Params) + _, err = models.SharedNodeLoginDAO.CreateNodeLogin(tx, nodeId, req.Login.Name, req.Login.Type, req.Login.Params) if err != nil { return nil, err } @@ -58,7 +60,7 @@ func (this *NodeService) CreateNode(ctx context.Context, req *pb.CreateNodeReque // 保存DNS相关 if req.DnsDomainId > 0 && len(req.DnsRoutes) > 0 { - err = models.SharedNodeDAO.UpdateNodeDNS(nodeId, map[int64][]string{ + err = models.SharedNodeDAO.UpdateNodeDNS(tx, nodeId, map[int64][]string{ req.DnsDomainId: req.DnsRoutes, }) if err != nil { @@ -87,21 +89,23 @@ func (this *NodeService) RegisterClusterNode(ctx context.Context, req *pb.Regist return nil, err } - adminId, err := models.SharedNodeClusterDAO.FindClusterAdminId(clusterId) + tx := this.NullTx() + + adminId, err := models.SharedNodeClusterDAO.FindClusterAdminId(tx, clusterId) if err != nil { return nil, err } - nodeId, err := models.SharedNodeDAO.CreateNode(adminId, req.Name, clusterId, 0, 0) + nodeId, err := models.SharedNodeDAO.CreateNode(tx, adminId, req.Name, clusterId, 0, 0) if err != nil { return nil, err } - err = models.SharedNodeDAO.UpdateNodeIsInstalled(nodeId, true) + err = models.SharedNodeDAO.UpdateNodeIsInstalled(tx, nodeId, true) if err != nil { return nil, err } - node, err := models.SharedNodeDAO.FindEnabledNode(nodeId) + node, err := models.SharedNodeDAO.FindEnabledNode(tx, nodeId) if err != nil { return nil, err } @@ -110,7 +114,7 @@ func (this *NodeService) RegisterClusterNode(ctx context.Context, req *pb.Regist } // 获取集群可以使用的所有API节点 - apiAddrs, err := models.SharedNodeClusterDAO.FindAllAPINodeAddrsWithCluster(clusterId) + apiAddrs, err := models.SharedNodeClusterDAO.FindAllAPINodeAddrsWithCluster(tx, clusterId) if err != nil { return nil, err } @@ -130,7 +134,9 @@ func (this *NodeService) CountAllEnabledNodes(ctx context.Context, req *pb.Count return nil, err } - count, err := models.SharedNodeDAO.CountAllEnabledNodes() + tx := this.NullTx() + + count, err := models.SharedNodeDAO.CountAllEnabledNodes(tx) if err != nil { return nil, err } @@ -144,7 +150,10 @@ func (this *NodeService) CountAllEnabledNodesMatch(ctx context.Context, req *pb. if err != nil { return nil, err } - count, err := models.SharedNodeDAO.CountAllEnabledNodesMatch(req.NodeClusterId, configutils.ToBoolState(req.InstallState), configutils.ToBoolState(req.ActiveState), req.Keyword, req.GroupId, req.RegionId) + + tx := this.NullTx() + + count, err := models.SharedNodeDAO.CountAllEnabledNodesMatch(tx, req.NodeClusterId, configutils.ToBoolState(req.InstallState), configutils.ToBoolState(req.ActiveState), req.Keyword, req.GroupId, req.RegionId) if err != nil { return nil, err } @@ -158,7 +167,9 @@ func (this *NodeService) ListEnabledNodesMatch(ctx context.Context, req *pb.List return nil, err } - clusterDNS, err := models.SharedNodeClusterDAO.FindClusterDNSInfo(req.NodeClusterId) + tx := this.NullTx() + + clusterDNS, err := models.SharedNodeClusterDAO.FindClusterDNSInfo(tx, req.NodeClusterId) if err != nil { return nil, err } @@ -168,21 +179,21 @@ func (this *NodeService) ListEnabledNodesMatch(ctx context.Context, req *pb.List if clusterDNS != nil { dnsDomainId = int64(clusterDNS.DnsDomainId) if clusterDNS.DnsDomainId > 0 { - domainRoutes, err = models.SharedDNSDomainDAO.FindDomainRoutes(dnsDomainId) + domainRoutes, err = models.SharedDNSDomainDAO.FindDomainRoutes(tx, dnsDomainId) if err != nil { return nil, err } } } - nodes, err := models.SharedNodeDAO.ListEnabledNodesMatch(req.Offset, req.Size, req.NodeClusterId, configutils.ToBoolState(req.InstallState), configutils.ToBoolState(req.ActiveState), req.Keyword, req.GroupId, req.RegionId) + nodes, err := models.SharedNodeDAO.ListEnabledNodesMatch(tx, req.Offset, req.Size, req.NodeClusterId, configutils.ToBoolState(req.InstallState), configutils.ToBoolState(req.ActiveState), req.Keyword, req.GroupId, req.RegionId) if err != nil { return nil, err } result := []*pb.Node{} for _, node := range nodes { // 集群信息 - clusterName, err := models.SharedNodeClusterDAO.FindNodeClusterName(int64(node.ClusterId)) + clusterName, err := models.SharedNodeClusterDAO.FindNodeClusterName(tx, int64(node.ClusterId)) if err != nil { return nil, err } @@ -207,7 +218,7 @@ func (this *NodeService) ListEnabledNodesMatch(ctx context.Context, req *pb.List // 分组信息 var pbGroup *pb.NodeGroup = nil if node.GroupId > 0 { - group, err := models.SharedNodeGroupDAO.FindEnabledNodeGroup(int64(node.GroupId)) + group, err := models.SharedNodeGroupDAO.FindEnabledNodeGroup(tx, int64(node.GroupId)) if err != nil { return nil, err } @@ -240,7 +251,7 @@ func (this *NodeService) ListEnabledNodesMatch(ctx context.Context, req *pb.List // 区域 var pbRegion *pb.NodeRegion = nil if node.RegionId > 0 { - region, err := models.SharedNodeRegionDAO.FindEnabledNodeRegion(int64(node.RegionId)) + region, err := models.SharedNodeRegionDAO.FindEnabledNodeRegion(tx, int64(node.RegionId)) if err != nil { return nil, err } @@ -289,7 +300,9 @@ func (this *NodeService) FindAllEnabledNodesWithClusterId(ctx context.Context, r // TODO 检查权限 } - nodes, err := models.SharedNodeDAO.FindAllEnabledNodesWithClusterId(req.NodeClusterId) + tx := this.NullTx() + + nodes, err := models.SharedNodeDAO.FindAllEnabledNodesWithClusterId(tx, req.NodeClusterId) if err != nil { return nil, err } @@ -323,7 +336,9 @@ func (this *NodeService) DeleteNode(ctx context.Context, req *pb.DeleteNodeReque return nil, err } - err = models.SharedNodeDAO.DisableNode(req.NodeId) + tx := this.NullTx() + + err = models.SharedNodeDAO.DisableNode(tx, req.NodeId) if err != nil { return nil, err } @@ -346,24 +361,26 @@ func (this *NodeService) UpdateNode(ctx context.Context, req *pb.UpdateNodeReque return nil, err } - err = models.SharedNodeDAO.UpdateNode(req.NodeId, req.Name, req.NodeClusterId, req.GroupId, req.RegionId, req.MaxCPU, req.IsOn) + tx := this.NullTx() + + err = models.SharedNodeDAO.UpdateNode(tx, req.NodeId, req.Name, req.NodeClusterId, req.GroupId, req.RegionId, req.MaxCPU, req.IsOn) if err != nil { return nil, err } if req.Login == nil { - err = models.SharedNodeLoginDAO.DisableNodeLogins(req.NodeId) + err = models.SharedNodeLoginDAO.DisableNodeLogins(tx, req.NodeId) if err != nil { return nil, err } } else { if req.Login.Id > 0 { - err = models.SharedNodeLoginDAO.UpdateNodeLogin(req.Login.Id, req.Login.Name, req.Login.Type, req.Login.Params) + err = models.SharedNodeLoginDAO.UpdateNodeLogin(tx, req.Login.Id, req.Login.Name, req.Login.Type, req.Login.Params) if err != nil { return nil, err } } else { - _, err = models.SharedNodeLoginDAO.CreateNodeLogin(req.NodeId, req.Login.Name, req.Login.Type, req.Login.Params) + _, err = models.SharedNodeLoginDAO.CreateNodeLogin(tx, req.NodeId, req.Login.Name, req.Login.Type, req.Login.Params) if err != nil { return nil, err } @@ -372,7 +389,7 @@ func (this *NodeService) UpdateNode(ctx context.Context, req *pb.UpdateNodeReque // 保存DNS相关 if req.DnsDomainId > 0 && len(req.DnsRoutes) > 0 { - err = models.SharedNodeDAO.UpdateNodeDNS(req.NodeId, map[int64][]string{ + err = models.SharedNodeDAO.UpdateNodeDNS(tx, req.NodeId, map[int64][]string{ req.DnsDomainId: req.DnsRoutes, }) if err != nil { @@ -399,7 +416,9 @@ func (this *NodeService) FindEnabledNode(ctx context.Context, req *pb.FindEnable return nil, err } - node, err := models.SharedNodeDAO.FindEnabledNode(req.NodeId) + tx := this.NullTx() + + node, err := models.SharedNodeDAO.FindEnabledNode(tx, req.NodeId) if err != nil { return nil, err } @@ -408,13 +427,13 @@ func (this *NodeService) FindEnabledNode(ctx context.Context, req *pb.FindEnable } // 集群信息 - clusterName, err := models.SharedNodeClusterDAO.FindNodeClusterName(int64(node.ClusterId)) + clusterName, err := models.SharedNodeClusterDAO.FindNodeClusterName(tx, int64(node.ClusterId)) if err != nil { return nil, err } // 认证信息 - login, err := models.SharedNodeLoginDAO.FindEnabledNodeLoginWithNodeId(req.NodeId) + login, err := models.SharedNodeLoginDAO.FindEnabledNodeLoginWithNodeId(tx, req.NodeId) if err != nil { return nil, err } @@ -448,7 +467,7 @@ func (this *NodeService) FindEnabledNode(ctx context.Context, req *pb.FindEnable // 分组信息 var pbGroup *pb.NodeGroup = nil if node.GroupId > 0 { - group, err := models.SharedNodeGroupDAO.FindEnabledNodeGroup(int64(node.GroupId)) + group, err := models.SharedNodeGroupDAO.FindEnabledNodeGroup(tx, int64(node.GroupId)) if err != nil { return nil, err } @@ -463,7 +482,7 @@ func (this *NodeService) FindEnabledNode(ctx context.Context, req *pb.FindEnable // 区域 var pbRegion *pb.NodeRegion = nil if node.RegionId > 0 { - region, err := models.SharedNodeRegionDAO.FindEnabledNodeRegion(int64(node.RegionId)) + region, err := models.SharedNodeRegionDAO.FindEnabledNodeRegion(tx, int64(node.RegionId)) if err != nil { return nil, err } @@ -509,8 +528,10 @@ func (this *NodeService) FindCurrentNodeConfig(ctx context.Context, req *pb.Find return nil, err } + tx := this.NullTx() + // 检查版本号 - currentVersion, err := models.SharedNodeDAO.FindNodeVersion(nodeId) + currentVersion, err := models.SharedNodeDAO.FindNodeVersion(tx, nodeId) if err != nil { return nil, err } @@ -518,7 +539,7 @@ func (this *NodeService) FindCurrentNodeConfig(ctx context.Context, req *pb.Find return &pb.FindCurrentNodeConfigResponse{IsChanged: false}, nil } - nodeConfig, err := models.SharedNodeDAO.ComposeNodeConfig(nodeId) + nodeConfig, err := models.SharedNodeDAO.ComposeNodeConfig(tx, nodeId) if err != nil { return nil, err } @@ -547,7 +568,9 @@ func (this *NodeService) UpdateNodeStatus(ctx context.Context, req *pb.UpdateNod return nil, errors.New("'nodeId' should be greater than 0") } - err = models.SharedNodeDAO.UpdateNodeStatus(nodeId, req.StatusJSON) + tx := this.NullTx() + + err = models.SharedNodeDAO.UpdateNodeStatus(tx, nodeId, req.StatusJSON) if err != nil { return nil, err } @@ -561,7 +584,9 @@ func (this *NodeService) SyncNodesVersionWithCluster(ctx context.Context, req *p return nil, err } - err = models.SharedNodeDAO.SyncNodeVersionsWithCluster(req.NodeClusterId) + tx := this.NullTx() + + err = models.SharedNodeDAO.SyncNodeVersionsWithCluster(tx, req.NodeClusterId) if err != nil { return nil, err } @@ -576,7 +601,9 @@ func (this *NodeService) UpdateNodeIsInstalled(ctx context.Context, req *pb.Upda return nil, err } - err = models.SharedNodeDAO.UpdateNodeIsInstalled(req.NodeId, req.IsInstalled) + tx := this.NullTx() + + err = models.SharedNodeDAO.UpdateNodeIsInstalled(tx, req.NodeId, req.IsInstalled) if err != nil { return nil, err } @@ -610,13 +637,15 @@ func (this *NodeService) UpgradeNode(ctx context.Context, req *pb.UpgradeNodeReq return nil, err } - err = models.SharedNodeDAO.UpdateNodeIsInstalled(req.NodeId, false) + tx := this.NullTx() + + err = models.SharedNodeDAO.UpdateNodeIsInstalled(tx, req.NodeId, false) if err != nil { return nil, err } // 检查状态 - installStatus, err := models.SharedNodeDAO.FindNodeInstallStatus(req.NodeId) + installStatus, err := models.SharedNodeDAO.FindNodeInstallStatus(tx, req.NodeId) if err != nil { return nil, err } @@ -625,7 +654,7 @@ func (this *NodeService) UpgradeNode(ctx context.Context, req *pb.UpgradeNodeReq } installStatus.IsOk = false installStatus.IsFinished = false - err = models.SharedNodeDAO.UpdateNodeInstallStatus(req.NodeId, installStatus) + err = models.SharedNodeDAO.UpdateNodeInstallStatus(tx, req.NodeId, installStatus) if err != nil { return nil, err } @@ -702,7 +731,9 @@ func (this *NodeService) UpdateNodeConnectedAPINodes(ctx context.Context, req *p return nil, err } - err = models.SharedNodeDAO.UpdateNodeConnectedAPINodes(nodeId, req.ApiNodeIds) + tx := this.NullTx() + + err = models.SharedNodeDAO.UpdateNodeConnectedAPINodes(tx, nodeId, req.ApiNodeIds) if err != nil { return nil, errors.Wrap(err) } @@ -718,7 +749,9 @@ func (this *NodeService) CountAllEnabledNodesWithGrantId(ctx context.Context, re return nil, err } - count, err := models.SharedNodeDAO.CountAllEnabledNodesWithGrantId(req.GrantId) + tx := this.NullTx() + + count, err := models.SharedNodeDAO.CountAllEnabledNodesWithGrantId(tx, req.GrantId) if err != nil { return nil, err } @@ -733,7 +766,9 @@ func (this *NodeService) FindAllEnabledNodesWithGrantId(ctx context.Context, req return nil, err } - nodes, err := models.SharedNodeDAO.FindAllEnabledNodesWithGrantId(req.GrantId) + tx := this.NullTx() + + nodes, err := models.SharedNodeDAO.FindAllEnabledNodesWithGrantId(tx, req.GrantId) if err != nil { return nil, err } @@ -741,7 +776,7 @@ func (this *NodeService) FindAllEnabledNodesWithGrantId(ctx context.Context, req result := []*pb.Node{} for _, node := range nodes { // 集群信息 - clusterName, err := models.SharedNodeClusterDAO.FindNodeClusterName(int64(node.ClusterId)) + clusterName, err := models.SharedNodeClusterDAO.FindNodeClusterName(tx, int64(node.ClusterId)) if err != nil { return nil, err } @@ -771,14 +806,16 @@ func (this *NodeService) FindAllNotInstalledNodesWithClusterId(ctx context.Conte return nil, err } - nodes, err := models.SharedNodeDAO.FindAllNotInstalledNodesWithClusterId(req.NodeClusterId) + tx := this.NullTx() + + nodes, err := models.SharedNodeDAO.FindAllNotInstalledNodesWithClusterId(tx, req.NodeClusterId) if err != nil { return nil, err } result := []*pb.Node{} for _, node := range nodes { // 认证信息 - login, err := models.SharedNodeLoginDAO.FindEnabledNodeLoginWithNodeId(int64(node.Id)) + login, err := models.SharedNodeLoginDAO.FindEnabledNodeLoginWithNodeId(tx, int64(node.Id)) if err != nil { return nil, err } @@ -793,7 +830,7 @@ func (this *NodeService) FindAllNotInstalledNodesWithClusterId(ctx context.Conte } // IP信息 - addresses, err := models.SharedNodeIPAddressDAO.FindAllEnabledAddressesWithNode(int64(node.Id)) + addresses, err := models.SharedNodeIPAddressDAO.FindAllEnabledAddressesWithNode(tx, int64(node.Id)) if err != nil { return nil, err } @@ -852,10 +889,12 @@ func (this *NodeService) CountAllUpgradeNodesWithClusterId(ctx context.Context, return nil, err } + tx := this.NullTx() + deployFiles := installers.SharedDeployManager.LoadFiles() total := int64(0) for _, deployFile := range deployFiles { - count, err := models.SharedNodeDAO.CountAllLowerVersionNodesWithClusterId(req.NodeClusterId, deployFile.OS, deployFile.Arch, deployFile.Version) + count, err := models.SharedNodeDAO.CountAllLowerVersionNodesWithClusterId(tx, req.NodeClusterId, deployFile.OS, deployFile.Arch, deployFile.Version) if err != nil { return nil, err } @@ -873,17 +912,19 @@ func (this *NodeService) FindAllUpgradeNodesWithClusterId(ctx context.Context, r return nil, err } + tx := this.NullTx() + // 获取当前能升级到的最新版本 deployFiles := installers.SharedDeployManager.LoadFiles() result := []*pb.FindAllUpgradeNodesWithClusterIdResponse_NodeUpgrade{} for _, deployFile := range deployFiles { - nodes, err := models.SharedNodeDAO.FindAllLowerVersionNodesWithClusterId(req.NodeClusterId, deployFile.OS, deployFile.Arch, deployFile.Version) + nodes, err := models.SharedNodeDAO.FindAllLowerVersionNodesWithClusterId(tx, req.NodeClusterId, deployFile.OS, deployFile.Arch, deployFile.Version) if err != nil { return nil, err } for _, node := range nodes { // 认证信息 - login, err := models.SharedNodeLoginDAO.FindEnabledNodeLoginWithNodeId(int64(node.Id)) + login, err := models.SharedNodeLoginDAO.FindEnabledNodeLoginWithNodeId(tx, int64(node.Id)) if err != nil { return nil, err } @@ -898,7 +939,7 @@ func (this *NodeService) FindAllUpgradeNodesWithClusterId(ctx context.Context, r } // IP信息 - addresses, err := models.SharedNodeIPAddressDAO.FindAllEnabledAddressesWithNode(int64(node.Id)) + addresses, err := models.SharedNodeIPAddressDAO.FindAllEnabledAddressesWithNode(tx, int64(node.Id)) if err != nil { return nil, err } @@ -977,7 +1018,9 @@ func (this *NodeService) FindNodeInstallStatus(ctx context.Context, req *pb.Find return nil, err } - installStatus, err := models.SharedNodeDAO.FindNodeInstallStatus(req.NodeId) + tx := this.NullTx() + + installStatus, err := models.SharedNodeDAO.FindNodeInstallStatus(tx, req.NodeId) if err != nil { return nil, err } @@ -1004,14 +1047,16 @@ func (this *NodeService) UpdateNodeLogin(ctx context.Context, req *pb.UpdateNode return nil, err } + tx := this.NullTx() + if req.Login.Id <= 0 { - _, err := models.SharedNodeLoginDAO.CreateNodeLogin(req.NodeId, req.Login.Name, req.Login.Type, req.Login.Params) + _, err := models.SharedNodeLoginDAO.CreateNodeLogin(tx, req.NodeId, req.Login.Name, req.Login.Type, req.Login.Params) if err != nil { return nil, err } } - err = models.SharedNodeLoginDAO.UpdateNodeLogin(req.Login.Id, req.Login.Name, req.Login.Type, req.Login.Params) + err = models.SharedNodeLoginDAO.UpdateNodeLogin(tx, req.Login.Id, req.Login.Name, req.Login.Type, req.Login.Params) return this.Success() } @@ -1024,7 +1069,9 @@ func (this *NodeService) CountAllEnabledNodesWithNodeGroupId(ctx context.Context return nil, err } - count, err := models.SharedNodeDAO.CountAllEnabledNodesWithGroupId(req.NodeGroupId) + tx := this.NullTx() + + count, err := models.SharedNodeDAO.CountAllEnabledNodesWithGroupId(tx, req.NodeGroupId) if err != nil { return nil, err } @@ -1039,7 +1086,9 @@ func (this *NodeService) FindAllEnabledNodesDNSWithClusterId(ctx context.Context return nil, err } - clusterDNS, err := models.SharedNodeClusterDAO.FindClusterDNSInfo(req.NodeClusterId) + tx := this.NullTx() + + clusterDNS, err := models.SharedNodeClusterDAO.FindClusterDNSInfo(tx, req.NodeClusterId) if err != nil { return nil, err } @@ -1048,18 +1097,18 @@ func (this *NodeService) FindAllEnabledNodesDNSWithClusterId(ctx context.Context } dnsDomainId := int64(clusterDNS.DnsDomainId) - routes, err := models.SharedDNSDomainDAO.FindDomainRoutes(dnsDomainId) + routes, err := models.SharedDNSDomainDAO.FindDomainRoutes(tx, dnsDomainId) if err != nil { return nil, err } - nodes, err := models.SharedNodeDAO.FindAllEnabledNodesDNSWithClusterId(req.NodeClusterId) + nodes, err := models.SharedNodeDAO.FindAllEnabledNodesDNSWithClusterId(tx, req.NodeClusterId) if err != nil { return nil, err } result := []*pb.NodeDNSInfo{} for _, node := range nodes { - ipAddr, err := models.SharedNodeIPAddressDAO.FindFirstNodeIPAddress(int64(node.Id)) + ipAddr, err := models.SharedNodeIPAddressDAO.FindFirstNodeIPAddress(tx, int64(node.Id)) if err != nil { return nil, err } @@ -1100,7 +1149,9 @@ func (this *NodeService) FindEnabledNodeDNS(ctx context.Context, req *pb.FindEna return nil, err } - node, err := models.SharedNodeDAO.FindEnabledNodeDNS(req.NodeId) + tx := this.NullTx() + + node, err := models.SharedNodeDAO.FindEnabledNodeDNS(tx, req.NodeId) if err != nil { return nil, err } @@ -1109,13 +1160,13 @@ func (this *NodeService) FindEnabledNodeDNS(ctx context.Context, req *pb.FindEna return &pb.FindEnabledNodeDNSResponse{Node: nil}, nil } - ipAddr, err := models.SharedNodeIPAddressDAO.FindFirstNodeIPAddress(int64(node.Id)) + ipAddr, err := models.SharedNodeIPAddressDAO.FindFirstNodeIPAddress(tx, int64(node.Id)) if err != nil { return nil, err } clusterId := int64(node.ClusterId) - clusterDNS, err := models.SharedNodeClusterDAO.FindClusterDNSInfo(clusterId) + clusterDNS, err := models.SharedNodeClusterDAO.FindClusterDNSInfo(tx, clusterId) if err != nil { return nil, err } @@ -1124,7 +1175,7 @@ func (this *NodeService) FindEnabledNodeDNS(ctx context.Context, req *pb.FindEna } dnsDomainId := int64(clusterDNS.DnsDomainId) - dnsDomainName, err := models.SharedDNSDomainDAO.FindDNSDomainName(dnsDomainId) + dnsDomainName, err := models.SharedDNSDomainDAO.FindDNSDomainName(tx, dnsDomainId) if err != nil { return nil, err } @@ -1137,7 +1188,7 @@ func (this *NodeService) FindEnabledNodeDNS(ctx context.Context, req *pb.FindEna } for _, routeCode := range routeCodes { - routeName, err := models.SharedDNSDomainDAO.FindDomainRouteName(dnsDomainId, routeCode) + routeName, err := models.SharedDNSDomainDAO.FindDomainRouteName(tx, dnsDomainId, routeCode) if err != nil { return nil, err } @@ -1170,7 +1221,9 @@ func (this *NodeService) UpdateNodeDNS(ctx context.Context, req *pb.UpdateNodeDN return nil, err } - node, err := models.SharedNodeDAO.FindEnabledNodeDNS(req.NodeId) + tx := this.NullTx() + + node, err := models.SharedNodeDAO.FindEnabledNodeDNS(tx, req.NodeId) if err != nil { return nil, err } @@ -1187,24 +1240,24 @@ func (this *NodeService) UpdateNodeDNS(ctx context.Context, req *pb.UpdateNodeDN routeCodeMap[req.DnsDomainId] = req.Routes } - err = models.SharedNodeDAO.UpdateNodeDNS(req.NodeId, routeCodeMap) + err = models.SharedNodeDAO.UpdateNodeDNS(tx, req.NodeId, routeCodeMap) if err != nil { return nil, err } // 修改IP if len(req.IpAddr) > 0 { - ipAddrId, err := models.SharedNodeIPAddressDAO.FindFirstNodeIPAddressId(req.NodeId) + ipAddrId, err := models.SharedNodeIPAddressDAO.FindFirstNodeIPAddressId(tx, req.NodeId) if err != nil { return nil, err } if ipAddrId > 0 { - err = models.SharedNodeIPAddressDAO.UpdateAddressIP(ipAddrId, req.IpAddr) + err = models.SharedNodeIPAddressDAO.UpdateAddressIP(tx, ipAddrId, req.IpAddr) if err != nil { return nil, err } } else { - _, err = models.SharedNodeIPAddressDAO.CreateAddress(req.NodeId, "DNS IP", req.IpAddr, true) + _, err = models.SharedNodeIPAddressDAO.CreateAddress(tx, req.NodeId, "DNS IP", req.IpAddr, true) if err != nil { return nil, err } @@ -1216,11 +1269,13 @@ func (this *NodeService) UpdateNodeDNS(ctx context.Context, req *pb.UpdateNodeDN // 自动同步DNS状态 func (this *NodeService) notifyNodeDNSChanged(nodeId int64) error { - clusterId, err := models.SharedNodeDAO.FindNodeClusterId(nodeId) + tx := this.NullTx() + + clusterId, err := models.SharedNodeDAO.FindNodeClusterId(tx, nodeId) if err != nil { return err } - dnsInfo, err := models.SharedNodeClusterDAO.FindClusterDNSInfo(clusterId) + dnsInfo, err := models.SharedNodeClusterDAO.FindClusterDNSInfo(tx, clusterId) if err != nil { return err } @@ -1248,7 +1303,7 @@ func (this *NodeService) notifyNodeDNSChanged(nodeId int64) error { return err } if !resp.IsOk { - err = models.SharedMessageDAO.CreateClusterMessage(clusterId, models.MessageTypeClusterDNSSyncFailed, models.LevelError, "集群DNS同步失败:"+resp.Error, nil) + err = models.SharedMessageDAO.CreateClusterMessage(tx, clusterId, models.MessageTypeClusterDNSSyncFailed, models.LevelError, "集群DNS同步失败:"+resp.Error, nil) if err != nil { logs.Println("[NODE_SERVICE]" + err.Error()) } @@ -1262,7 +1317,10 @@ func (this *NodeService) CountAllEnabledNodesWithNodeRegionId(ctx context.Contex if err != nil { return nil, err } - count, err := models.SharedNodeDAO.CountAllEnabledNodesWithRegionId(req.NodeRegionId) + + tx := this.NullTx() + + count, err := models.SharedNodeDAO.CountAllEnabledNodesWithRegionId(tx, req.NodeRegionId) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_node_cluster.go b/internal/rpc/services/service_node_cluster.go index 078d79cb..e43b49f0 100644 --- a/internal/rpc/services/service_node_cluster.go +++ b/internal/rpc/services/service_node_cluster.go @@ -24,7 +24,9 @@ func (this *NodeClusterService) CreateNodeCluster(ctx context.Context, req *pb.C return nil, err } - clusterId, err := models.SharedNodeClusterDAO.CreateCluster(adminId, req.Name, req.GrantId, req.InstallDir, req.DnsDomainId, req.DnsName, req.HttpCachePolicyId, req.HttpFirewallPolicyId) + tx := this.NullTx() + + clusterId, err := models.SharedNodeClusterDAO.CreateCluster(tx, adminId, req.Name, req.GrantId, req.InstallDir, req.DnsDomainId, req.DnsName, req.HttpCachePolicyId, req.HttpFirewallPolicyId) if err != nil { return nil, err } @@ -39,7 +41,9 @@ func (this *NodeClusterService) UpdateNodeCluster(ctx context.Context, req *pb.U return nil, err } - err = models.SharedNodeClusterDAO.UpdateCluster(req.NodeClusterId, req.Name, req.GrantId, req.InstallDir) + tx := this.NullTx() + + err = models.SharedNodeClusterDAO.UpdateCluster(tx, req.NodeClusterId, req.Name, req.GrantId, req.InstallDir) if err != nil { return nil, err } @@ -54,7 +58,9 @@ func (this *NodeClusterService) DeleteNodeCluster(ctx context.Context, req *pb.D return nil, err } - err = models.SharedNodeClusterDAO.DisableNodeCluster(req.NodeClusterId) + tx := this.NullTx() + + err = models.SharedNodeClusterDAO.DisableNodeCluster(tx, req.NodeClusterId) if err != nil { return nil, err } @@ -73,7 +79,9 @@ func (this *NodeClusterService) FindEnabledNodeCluster(ctx context.Context, req // TODO 检查用户是否有权限 } - cluster, err := models.SharedNodeClusterDAO.FindEnabledNodeCluster(req.NodeClusterId) + tx := this.NullTx() + + cluster, err := models.SharedNodeClusterDAO.FindEnabledNodeCluster(tx, req.NodeClusterId) if err != nil { return nil, err } @@ -105,7 +113,9 @@ func (this *NodeClusterService) FindAPINodesWithNodeCluster(ctx context.Context, return nil, err } - cluster, err := models.SharedNodeClusterDAO.FindEnabledNodeCluster(req.NodeClusterId) + tx := this.NullTx() + + cluster, err := models.SharedNodeClusterDAO.FindEnabledNodeCluster(tx, req.NodeClusterId) if err != nil { return nil, err } @@ -125,7 +135,7 @@ func (this *NodeClusterService) FindAPINodesWithNodeCluster(ctx context.Context, if len(apiNodeIds) > 0 { apiNodes := []*pb.APINode{} for _, apiNodeId := range apiNodeIds { - apiNode, err := models.SharedAPINodeDAO.FindEnabledAPINode(apiNodeId) + apiNode, err := models.SharedAPINodeDAO.FindEnabledAPINode(tx, apiNodeId) if err != nil { return nil, err } @@ -157,7 +167,9 @@ func (this *NodeClusterService) FindAllEnabledNodeClusters(ctx context.Context, return nil, err } - clusters, err := models.SharedNodeClusterDAO.FindAllEnableClusters() + tx := this.NullTx() + + clusters, err := models.SharedNodeClusterDAO.FindAllEnableClusters(tx) if err != nil { return nil, err } @@ -185,7 +197,9 @@ func (this *NodeClusterService) FindAllChangedNodeClusters(ctx context.Context, return nil, err } - clusterIds, err := models.SharedNodeDAO.FindChangedClusterIds() + tx := this.NullTx() + + clusterIds, err := models.SharedNodeDAO.FindChangedClusterIds(tx) if err != nil { return nil, err } @@ -196,7 +210,7 @@ func (this *NodeClusterService) FindAllChangedNodeClusters(ctx context.Context, } result := []*pb.NodeCluster{} for _, clusterId := range clusterIds { - cluster, err := models.SharedNodeClusterDAO.FindEnabledNodeCluster(clusterId) + cluster, err := models.SharedNodeClusterDAO.FindEnabledNodeCluster(tx, clusterId) if err != nil { return nil, err } @@ -221,7 +235,9 @@ func (this *NodeClusterService) CountAllEnabledNodeClusters(ctx context.Context, return nil, err } - count, err := models.SharedNodeClusterDAO.CountAllEnabledClusters(req.Keyword) + tx := this.NullTx() + + count, err := models.SharedNodeClusterDAO.CountAllEnabledClusters(tx, req.Keyword) if err != nil { return nil, err } @@ -236,7 +252,9 @@ func (this *NodeClusterService) ListEnabledNodeClusters(ctx context.Context, req return nil, err } - clusters, err := models.SharedNodeClusterDAO.ListEnabledClusters(req.Keyword, req.Offset, req.Size) + tx := this.NullTx() + + clusters, err := models.SharedNodeClusterDAO.ListEnabledClusters(tx, req.Keyword, req.Offset, req.Size) if err != nil { return nil, err } @@ -267,7 +285,9 @@ func (this *NodeClusterService) FindNodeClusterHealthCheckConfig(ctx context.Con return nil, err } - config, err := models.SharedNodeClusterDAO.FindClusterHealthCheckConfig(req.NodeClusterId) + tx := this.NullTx() + + config, err := models.SharedNodeClusterDAO.FindClusterHealthCheckConfig(tx, req.NodeClusterId) if err != nil { return nil, err } @@ -286,7 +306,9 @@ func (this *NodeClusterService) UpdateNodeClusterHealthCheck(ctx context.Context return nil, err } - err = models.SharedNodeClusterDAO.UpdateClusterHealthCheck(req.NodeClusterId, req.HealthCheckJSON) + tx := this.NullTx() + + err = models.SharedNodeClusterDAO.UpdateClusterHealthCheck(tx, req.NodeClusterId, req.HealthCheckJSON) if err != nil { return nil, err } @@ -330,7 +352,9 @@ func (this *NodeClusterService) CountAllEnabledNodeClustersWithGrantId(ctx conte return nil, err } - count, err := models.SharedNodeClusterDAO.CountAllEnabledClustersWithGrantId(req.GrantId) + tx := this.NullTx() + + count, err := models.SharedNodeClusterDAO.CountAllEnabledClustersWithGrantId(tx, req.GrantId) if err != nil { return nil, err } @@ -345,7 +369,9 @@ func (this *NodeClusterService) FindAllEnabledNodeClustersWithGrantId(ctx contex return nil, err } - clusters, err := models.SharedNodeClusterDAO.FindAllEnabledClustersWithGrantId(req.GrantId) + tx := this.NullTx() + + clusters, err := models.SharedNodeClusterDAO.FindAllEnabledClustersWithGrantId(tx, req.GrantId) if err != nil { return nil, err } @@ -371,7 +397,9 @@ func (this *NodeClusterService) FindEnabledNodeClusterDNS(ctx context.Context, r return nil, err } - dnsInfo, err := models.SharedNodeClusterDAO.FindClusterDNSInfo(req.NodeClusterId) + tx := this.NullTx() + + dnsInfo, err := models.SharedNodeClusterDAO.FindClusterDNSInfo(tx, req.NodeClusterId) if err != nil { return nil, err } @@ -398,7 +426,7 @@ func (this *NodeClusterService) FindEnabledNodeClusterDNS(ctx context.Context, r }, nil } - domain, err := models.SharedDNSDomainDAO.FindEnabledDNSDomain(int64(dnsInfo.DnsDomainId)) + domain, err := models.SharedDNSDomainDAO.FindEnabledDNSDomain(tx, int64(dnsInfo.DnsDomainId)) if err != nil { return nil, err } @@ -415,7 +443,7 @@ func (this *NodeClusterService) FindEnabledNodeClusterDNS(ctx context.Context, r IsOn: domain.IsOn == 1, } - provider, err := models.SharedDNSProviderDAO.FindEnabledDNSProvider(int64(domain.ProviderId)) + provider, err := models.SharedDNSProviderDAO.FindEnabledDNSProvider(tx, int64(domain.ProviderId)) if err != nil { return nil, err } @@ -447,7 +475,9 @@ func (this *NodeClusterService) CountAllEnabledNodeClustersWithDNSProviderId(ctx return nil, err } - count, err := models.SharedNodeClusterDAO.CountAllEnabledClustersWithDNSProviderId(req.DnsProviderId) + tx := this.NullTx() + + count, err := models.SharedNodeClusterDAO.CountAllEnabledClustersWithDNSProviderId(tx, req.DnsProviderId) if err != nil { return nil, err } @@ -462,7 +492,9 @@ func (this *NodeClusterService) CountAllEnabledNodeClustersWithDNSDomainId(ctx c return nil, err } - count, err := models.SharedNodeClusterDAO.CountAllEnabledClustersWithDNSDomainId(req.DnsDomainId) + tx := this.NullTx() + + count, err := models.SharedNodeClusterDAO.CountAllEnabledClustersWithDNSDomainId(tx, req.DnsDomainId) if err != nil { return nil, err } @@ -477,7 +509,9 @@ func (this *NodeClusterService) FindAllEnabledNodeClustersWithDNSDomainId(ctx co return nil, err } - clusters, err := models.SharedNodeClusterDAO.FindAllEnabledClustersWithDNSDomainId(req.DnsDomainId) + tx := this.NullTx() + + clusters, err := models.SharedNodeClusterDAO.FindAllEnabledClustersWithDNSDomainId(tx, req.DnsDomainId) if err != nil { return nil, err } @@ -502,7 +536,9 @@ func (this *NodeClusterService) CheckNodeClusterDNSName(ctx context.Context, req return nil, err } - exists, err := models.SharedNodeClusterDAO.ExistClusterDNSName(req.DnsName, req.NodeClusterId) + tx := this.NullTx() + + exists, err := models.SharedNodeClusterDAO.ExistClusterDNSName(tx, req.DnsName, req.NodeClusterId) if err != nil { return nil, err } @@ -517,7 +553,9 @@ func (this *NodeClusterService) UpdateNodeClusterDNS(ctx context.Context, req *p return nil, err } - err = models.SharedNodeClusterDAO.UpdateClusterDNS(req.NodeClusterId, req.DnsName, req.DnsDomainId, req.NodesAutoSync, req.ServersAutoSync) + tx := this.NullTx() + + err = models.SharedNodeClusterDAO.UpdateClusterDNS(tx, req.NodeClusterId, req.DnsName, req.DnsDomainId, req.NodesAutoSync, req.ServersAutoSync) if err != nil { return nil, err } @@ -532,7 +570,9 @@ func (this *NodeClusterService) CheckNodeClusterDNSChanges(ctx context.Context, return nil, err } - cluster, err := models.SharedNodeClusterDAO.FindClusterDNSInfo(req.NodeClusterId) + tx := this.NullTx() + + cluster, err := models.SharedNodeClusterDAO.FindClusterDNSInfo(tx, req.NodeClusterId) if err != nil { return nil, err } @@ -542,7 +582,7 @@ func (this *NodeClusterService) CheckNodeClusterDNSChanges(ctx context.Context, } domainId := int64(cluster.DnsDomainId) - domain, err := models.SharedDNSDomainDAO.FindEnabledDNSDomain(domainId) + domain, err := models.SharedDNSDomainDAO.FindEnabledDNSDomain(tx, domainId) if err != nil { return nil, err } @@ -572,7 +612,9 @@ func (this *NodeClusterService) FindEnabledNodeClusterTOA(ctx context.Context, r // TODO 检查权限 - config, err := models.SharedNodeClusterDAO.FindClusterTOAConfig(req.NodeClusterId) + tx := this.NullTx() + + config, err := models.SharedNodeClusterDAO.FindClusterTOAConfig(tx, req.NodeClusterId) if err != nil { return nil, err } @@ -592,13 +634,15 @@ func (this *NodeClusterService) UpdateNodeClusterTOA(ctx context.Context, req *p // TODO 检查权限 - err = models.SharedNodeClusterDAO.UpdateClusterTOA(req.NodeClusterId, req.ToaJSON) + tx := this.NullTx() + + err = models.SharedNodeClusterDAO.UpdateClusterTOA(tx, req.NodeClusterId, req.ToaJSON) if err != nil { return nil, err } // 增加节点版本号 - err = models.SharedNodeDAO.IncreaseAllNodesLatestVersionMatch(req.NodeClusterId) + err = models.SharedNodeDAO.IncreaseAllNodesLatestVersionMatch(tx, req.NodeClusterId) if err != nil { return nil, err } @@ -612,7 +656,10 @@ func (this *NodeClusterService) CountAllEnabledNodeClustersWithHTTPCachePolicyId if err != nil { return nil, err } - count, err := models.SharedNodeClusterDAO.CountAllEnabledNodeClustersWithHTTPCachePolicyId(req.HttpCachePolicyId) + + tx := this.NullTx() + + count, err := models.SharedNodeClusterDAO.CountAllEnabledNodeClustersWithHTTPCachePolicyId(tx, req.HttpCachePolicyId) if err != nil { return nil, err } @@ -625,8 +672,11 @@ func (this *NodeClusterService) FindAllEnabledNodeClustersWithHTTPCachePolicyId( if err != nil { return nil, err } + + tx := this.NullTx() + result := []*pb.NodeCluster{} - clusters, err := models.SharedNodeClusterDAO.FindAllEnabledNodeClustersWithHTTPCachePolicyId(req.HttpCachePolicyId) + clusters, err := models.SharedNodeClusterDAO.FindAllEnabledNodeClustersWithHTTPCachePolicyId(tx, req.HttpCachePolicyId) if err != nil { return nil, err } @@ -647,7 +697,10 @@ func (this *NodeClusterService) CountAllEnabledNodeClustersWithHTTPFirewallPolic if err != nil { return nil, err } - count, err := models.SharedNodeClusterDAO.CountAllEnabledNodeClustersWithHTTPFirewallPolicyId(req.HttpFirewallPolicyId) + + tx := this.NullTx() + + count, err := models.SharedNodeClusterDAO.CountAllEnabledNodeClustersWithHTTPFirewallPolicyId(tx, req.HttpFirewallPolicyId) if err != nil { return nil, err } @@ -660,8 +713,11 @@ func (this *NodeClusterService) FindAllEnabledNodeClustersWithHTTPFirewallPolicy if err != nil { return nil, err } + + tx := this.NullTx() + result := []*pb.NodeCluster{} - clusters, err := models.SharedNodeClusterDAO.FindAllEnabledNodeClustersWithHTTPFirewallPolicyId(req.HttpFirewallPolicyId) + clusters, err := models.SharedNodeClusterDAO.FindAllEnabledNodeClustersWithHTTPFirewallPolicyId(tx, req.HttpFirewallPolicyId) if err != nil { return nil, err } @@ -682,13 +738,16 @@ func (this *NodeClusterService) UpdateNodeClusterHTTPCachePolicyId(ctx context.C if err != nil { return nil, err } - err = models.SharedNodeClusterDAO.UpdateNodeClusterHTTPCachePolicyId(req.NodeClusterId, req.HttpCachePolicyId) + + tx := this.NullTx() + + err = models.SharedNodeClusterDAO.UpdateNodeClusterHTTPCachePolicyId(tx, req.NodeClusterId, req.HttpCachePolicyId) if err != nil { return nil, err } // 增加节点版本号 - err = models.SharedNodeDAO.IncreaseAllNodesLatestVersionMatch(req.NodeClusterId) + err = models.SharedNodeDAO.IncreaseAllNodesLatestVersionMatch(tx, req.NodeClusterId) if err != nil { return nil, err } @@ -702,13 +761,16 @@ func (this *NodeClusterService) UpdateNodeClusterHTTPFirewallPolicyId(ctx contex if err != nil { return nil, err } - err = models.SharedNodeClusterDAO.UpdateNodeClusterHTTPFirewallPolicyId(req.NodeClusterId, req.HttpFirewallPolicyId) + + tx := this.NullTx() + + err = models.SharedNodeClusterDAO.UpdateNodeClusterHTTPFirewallPolicyId(tx, req.NodeClusterId, req.HttpFirewallPolicyId) if err != nil { return nil, err } // 增加节点版本号 - err = models.SharedNodeDAO.IncreaseAllNodesLatestVersionMatch(req.NodeClusterId) + err = models.SharedNodeDAO.IncreaseAllNodesLatestVersionMatch(tx, req.NodeClusterId) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_node_grant.go b/internal/rpc/services/service_node_grant.go index 52e387bc..d94df653 100644 --- a/internal/rpc/services/service_node_grant.go +++ b/internal/rpc/services/service_node_grant.go @@ -18,7 +18,9 @@ func (this *NodeGrantService) CreateNodeGrant(ctx context.Context, req *pb.Creat return nil, err } - grantId, err := models.SharedNodeGrantDAO.CreateGrant(adminId, req.Name, req.Method, req.Username, req.Password, req.PrivateKey, req.Description, req.NodeId) + tx := this.NullTx() + + grantId, err := models.SharedNodeGrantDAO.CreateGrant(tx, adminId, req.Name, req.Method, req.Username, req.Password, req.PrivateKey, req.Description, req.NodeId) if err != nil { return nil, err } @@ -37,7 +39,9 @@ func (this *NodeGrantService) UpdateNodeGrant(ctx context.Context, req *pb.Updat return nil, errors.New("wrong grantId") } - err = models.SharedNodeGrantDAO.UpdateGrant(req.GrantId, req.Name, req.Method, req.Username, req.Password, req.PrivateKey, req.Description, req.NodeId) + tx := this.NullTx() + + err = models.SharedNodeGrantDAO.UpdateGrant(tx, req.GrantId, req.Name, req.Method, req.Username, req.Password, req.PrivateKey, req.Description, req.NodeId) return this.Success() } @@ -47,7 +51,9 @@ func (this *NodeGrantService) DisableNodeGrant(ctx context.Context, req *pb.Disa return nil, err } - err = models.SharedNodeGrantDAO.DisableNodeGrant(req.GrantId) + tx := this.NullTx() + + err = models.SharedNodeGrantDAO.DisableNodeGrant(tx, req.GrantId) return &pb.DisableNodeGrantResponse{}, err } @@ -56,7 +62,10 @@ func (this *NodeGrantService) CountAllEnabledNodeGrants(ctx context.Context, req if err != nil { return nil, err } - count, err := models.SharedNodeGrantDAO.CountAllEnabledGrants() + + tx := this.NullTx() + + count, err := models.SharedNodeGrantDAO.CountAllEnabledGrants(tx) if err != nil { return nil, err } @@ -68,7 +77,10 @@ func (this *NodeGrantService) ListEnabledNodeGrants(ctx context.Context, req *pb if err != nil { return nil, err } - grants, err := models.SharedNodeGrantDAO.ListEnabledGrants(req.Offset, req.Size) + + tx := this.NullTx() + + grants, err := models.SharedNodeGrantDAO.ListEnabledGrants(tx, req.Offset, req.Size) if err != nil { return nil, err } @@ -95,7 +107,7 @@ func (this *NodeGrantService) FindAllEnabledNodeGrants(ctx context.Context, req if err != nil { return nil, err } - grants, err := models.SharedNodeGrantDAO.FindAllEnabledGrants() + grants, err := models.SharedNodeGrantDAO.FindAllEnabledGrants(this.NullTx()) if err != nil { return nil, err } @@ -122,7 +134,7 @@ func (this *NodeGrantService) FindEnabledGrant(ctx context.Context, req *pb.Find return nil, err } - grant, err := models.SharedNodeGrantDAO.FindEnabledNodeGrant(req.GrantId) + grant, err := models.SharedNodeGrantDAO.FindEnabledNodeGrant(this.NullTx(), req.GrantId) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_node_group.go b/internal/rpc/services/service_node_group.go index a3c62ef7..f13716cf 100644 --- a/internal/rpc/services/service_node_group.go +++ b/internal/rpc/services/service_node_group.go @@ -20,7 +20,9 @@ func (this *NodeGroupService) CreateNodeGroup(ctx context.Context, req *pb.Creat return nil, err } - groupId, err := models.SharedNodeGroupDAO.CreateNodeGroup(req.NodeClusterId, req.Name) + tx := this.NullTx() + + groupId, err := models.SharedNodeGroupDAO.CreateNodeGroup(tx, req.NodeClusterId, req.Name) if err != nil { return nil, err } @@ -35,7 +37,9 @@ func (this *NodeGroupService) UpdateNodeGroup(ctx context.Context, req *pb.Updat return nil, err } - err = models.SharedNodeGroupDAO.UpdateNodeGroup(req.GroupId, req.Name) + tx := this.NullTx() + + err = models.SharedNodeGroupDAO.UpdateNodeGroup(tx, req.GroupId, req.Name) if err != nil { return nil, err } @@ -51,7 +55,9 @@ func (this *NodeGroupService) DeleteNodeGroup(ctx context.Context, req *pb.Delet return nil, err } - _, err = models.SharedNodeGroupDAO.DisableNodeGroup(req.GroupId) + tx := this.NullTx() + + _, err = models.SharedNodeGroupDAO.DisableNodeGroup(tx, req.GroupId) if err != nil { return nil, err } @@ -67,7 +73,9 @@ func (this *NodeGroupService) FindAllEnabledNodeGroupsWithClusterId(ctx context. return nil, err } - groups, err := models.SharedNodeGroupDAO.FindAllEnabledGroupsWithClusterId(req.NodeClusterId) + tx := this.NullTx() + + groups, err := models.SharedNodeGroupDAO.FindAllEnabledGroupsWithClusterId(tx, req.NodeClusterId) if err != nil { return nil, err } @@ -89,7 +97,9 @@ func (this *NodeGroupService) UpdateNodeGroupOrders(ctx context.Context, req *pb return nil, err } - err = models.SharedNodeGroupDAO.UpdateGroupOrders(req.GroupIds) + tx := this.NullTx() + + err = models.SharedNodeGroupDAO.UpdateGroupOrders(tx, req.GroupIds) if err != nil { return nil, err } @@ -104,7 +114,9 @@ func (this *NodeGroupService) FindEnabledNodeGroup(ctx context.Context, req *pb. return nil, err } - group, err := models.SharedNodeGroupDAO.FindEnabledNodeGroup(req.GroupId) + tx := this.NullTx() + + group, err := models.SharedNodeGroupDAO.FindEnabledNodeGroup(tx, req.GroupId) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_node_ip_address.go b/internal/rpc/services/service_node_ip_address.go index f7f92bca..038478e9 100644 --- a/internal/rpc/services/service_node_ip_address.go +++ b/internal/rpc/services/service_node_ip_address.go @@ -19,7 +19,9 @@ func (this *NodeIPAddressService) CreateNodeIPAddress(ctx context.Context, req * return nil, err } - addressId, err := models.SharedNodeIPAddressDAO.CreateAddress(req.NodeId, req.Name, req.Ip, req.CanAccess) + tx := this.NullTx() + + addressId, err := models.SharedNodeIPAddressDAO.CreateAddress(tx, req.NodeId, req.Name, req.Ip, req.CanAccess) if err != nil { return nil, err } @@ -35,7 +37,9 @@ func (this *NodeIPAddressService) UpdateNodeIPAddress(ctx context.Context, req * return nil, err } - err = models.SharedNodeIPAddressDAO.UpdateAddress(req.AddressId, req.Name, req.Ip, req.CanAccess) + tx := this.NullTx() + + err = models.SharedNodeIPAddressDAO.UpdateAddress(tx, req.AddressId, req.Name, req.Ip, req.CanAccess) if err != nil { return nil, err } @@ -51,7 +55,9 @@ func (this *NodeIPAddressService) UpdateNodeIPAddressNodeId(ctx context.Context, return nil, err } - err = models.SharedNodeIPAddressDAO.UpdateAddressNodeId(req.AddressId, req.NodeId) + tx := this.NullTx() + + err = models.SharedNodeIPAddressDAO.UpdateAddressNodeId(tx, req.AddressId, req.NodeId) if err != nil { return nil, err } @@ -67,7 +73,9 @@ func (this *NodeIPAddressService) DisableNodeIPAddress(ctx context.Context, req return nil, err } - err = models.SharedNodeIPAddressDAO.DisableAddress(req.AddressId) + tx := this.NullTx() + + err = models.SharedNodeIPAddressDAO.DisableAddress(tx, req.AddressId) if err != nil { return nil, err } @@ -83,7 +91,9 @@ func (this *NodeIPAddressService) DisableAllIPAddressesWithNodeId(ctx context.Co return nil, err } - err = models.SharedNodeIPAddressDAO.DisableAllAddressesWithNodeId(req.NodeId) + tx := this.NullTx() + + err = models.SharedNodeIPAddressDAO.DisableAllAddressesWithNodeId(tx, req.NodeId) if err != nil { return nil, err } @@ -99,7 +109,9 @@ func (this *NodeIPAddressService) FindEnabledNodeIPAddress(ctx context.Context, return nil, err } - address, err := models.SharedNodeIPAddressDAO.FindEnabledAddress(req.AddressId) + tx := this.NullTx() + + address, err := models.SharedNodeIPAddressDAO.FindEnabledAddress(tx, req.AddressId) if err != nil { return nil, err } @@ -129,7 +141,9 @@ func (this *NodeIPAddressService) FindAllEnabledIPAddressesWithNodeId(ctx contex return nil, err } - addresses, err := models.SharedNodeIPAddressDAO.FindAllEnabledAddressesWithNode(req.NodeId) + tx := this.NullTx() + + addresses, err := models.SharedNodeIPAddressDAO.FindAllEnabledAddressesWithNode(tx, req.NodeId) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_node_log.go b/internal/rpc/services/service_node_log.go index 0c3860d8..b9ebf54b 100644 --- a/internal/rpc/services/service_node_log.go +++ b/internal/rpc/services/service_node_log.go @@ -19,8 +19,10 @@ func (this *NodeLogService) CreateNodeLogs(ctx context.Context, req *pb.CreateNo return nil, err } + tx := this.NullTx() + for _, nodeLog := range req.NodeLogs { - err := models.SharedNodeLogDAO.CreateLog(nodeLog.Role, nodeLog.NodeId, nodeLog.Level, nodeLog.Tag, nodeLog.Description, nodeLog.CreatedAt) + err := models.SharedNodeLogDAO.CreateLog(tx, nodeLog.Role, nodeLog.NodeId, nodeLog.Level, nodeLog.Tag, nodeLog.Description, nodeLog.CreatedAt) if err != nil { return nil, err } @@ -35,7 +37,9 @@ func (this *NodeLogService) CountNodeLogs(ctx context.Context, req *pb.CountNode return nil, err } - count, err := models.SharedNodeLogDAO.CountNodeLogs(req.Role, req.NodeId) + tx := this.NullTx() + + count, err := models.SharedNodeLogDAO.CountNodeLogs(tx, req.Role, req.NodeId) if err != nil { return nil, err } @@ -49,7 +53,9 @@ func (this *NodeLogService) ListNodeLogs(ctx context.Context, req *pb.ListNodeLo return nil, err } - logs, err := models.SharedNodeLogDAO.ListNodeLogs(req.Role, req.NodeId, req.Offset, req.Size) + tx := this.NullTx() + + logs, err := models.SharedNodeLogDAO.ListNodeLogs(tx, req.Role, req.NodeId, req.Offset, req.Size) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_node_price_item.go b/internal/rpc/services/service_node_price_item.go index c96757a3..29c93b4a 100644 --- a/internal/rpc/services/service_node_price_item.go +++ b/internal/rpc/services/service_node_price_item.go @@ -18,7 +18,9 @@ func (this *NodePriceItemService) CreateNodePriceItem(ctx context.Context, req * return nil, err } - itemId, err := models.SharedNodePriceItemDAO.CreateItem(req.Name, req.Type, req.BitsFrom, req.BitsTo) + tx := this.NullTx() + + itemId, err := models.SharedNodePriceItemDAO.CreateItem(tx, req.Name, req.Type, req.BitsFrom, req.BitsTo) if err != nil { return nil, err } @@ -32,7 +34,9 @@ func (this *NodePriceItemService) UpdateNodePriceItem(ctx context.Context, req * return nil, err } - err = models.SharedNodePriceItemDAO.UpdateItem(req.NodePriceItemId, req.Name, req.BitsFrom, req.BitsTo) + tx := this.NullTx() + + err = models.SharedNodePriceItemDAO.UpdateItem(tx, req.NodePriceItemId, req.Name, req.BitsFrom, req.BitsTo) if err != nil { return nil, err } @@ -46,7 +50,9 @@ func (this *NodePriceItemService) DeleteNodePriceItem(ctx context.Context, req * return nil, err } - err = models.SharedNodePriceItemDAO.DisableNodePriceItem(req.NodePriceItemId) + tx := this.NullTx() + + err = models.SharedNodePriceItemDAO.DisableNodePriceItem(tx, req.NodePriceItemId) if err != nil { return nil, err } @@ -60,7 +66,9 @@ func (this *NodePriceItemService) FindAllEnabledNodePriceItems(ctx context.Conte return nil, err } - prices, err := models.SharedNodePriceItemDAO.FindAllEnabledRegionPrices(req.Type) + tx := this.NullTx() + + prices, err := models.SharedNodePriceItemDAO.FindAllEnabledRegionPrices(tx, req.Type) if err != nil { return nil, err } @@ -86,7 +94,9 @@ func (this *NodePriceItemService) FindAllEnabledAndOnNodePriceItems(ctx context. return nil, err } - prices, err := models.SharedNodePriceItemDAO.FindAllEnabledAndOnRegionPrices(req.Type) + tx := this.NullTx() + + prices, err := models.SharedNodePriceItemDAO.FindAllEnabledAndOnRegionPrices(tx, req.Type) if err != nil { return nil, err } @@ -112,7 +122,9 @@ func (this *NodePriceItemService) FindEnabledNodePriceItem(ctx context.Context, return nil, err } - price, err := models.SharedNodePriceItemDAO.FindEnabledNodePriceItem(req.NodePriceItemId) + tx := this.NullTx() + + price, err := models.SharedNodePriceItemDAO.FindEnabledNodePriceItem(tx, req.NodePriceItemId) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_node_region.go b/internal/rpc/services/service_node_region.go index 901b58fa..9f9f3d42 100644 --- a/internal/rpc/services/service_node_region.go +++ b/internal/rpc/services/service_node_region.go @@ -17,7 +17,10 @@ func (this *NodeRegionService) CreateNodeRegion(ctx context.Context, req *pb.Cre if err != nil { return nil, err } - regionId, err := models.SharedNodeRegionDAO.CreateRegion(adminId, req.Name, req.Description) + + tx := this.NullTx() + + regionId, err := models.SharedNodeRegionDAO.CreateRegion(tx, adminId, req.Name, req.Description) if err != nil { return nil, err } @@ -30,7 +33,10 @@ func (this *NodeRegionService) UpdateNodeRegion(ctx context.Context, req *pb.Upd if err != nil { return nil, err } - err = models.SharedNodeRegionDAO.UpdateRegion(req.NodeRegionId, req.Name, req.Description, req.IsOn) + + tx := this.NullTx() + + err = models.SharedNodeRegionDAO.UpdateRegion(tx, req.NodeRegionId, req.Name, req.Description, req.IsOn) if err != nil { return nil, err } @@ -43,7 +49,10 @@ func (this *NodeRegionService) DeleteNodeRegion(ctx context.Context, req *pb.Del if err != nil { return nil, err } - err = models.SharedNodeRegionDAO.DisableNodeRegion(req.NodeRegionId) + + tx := this.NullTx() + + err = models.SharedNodeRegionDAO.DisableNodeRegion(tx, req.NodeRegionId) if err != nil { return nil, err } @@ -56,7 +65,10 @@ func (this *NodeRegionService) FindAllEnabledNodeRegions(ctx context.Context, re if err != nil { return nil, err } - regions, err := models.SharedNodeRegionDAO.FindAllEnabledRegions() + + tx := this.NullTx() + + regions, err := models.SharedNodeRegionDAO.FindAllEnabledRegions(tx) if err != nil { return nil, err } @@ -79,7 +91,10 @@ func (this *NodeRegionService) FindAllEnabledAndOnNodeRegions(ctx context.Contex if err != nil { return nil, err } - regions, err := models.SharedNodeRegionDAO.FindAllEnabledAndOnRegions() + + tx := this.NullTx() + + regions, err := models.SharedNodeRegionDAO.FindAllEnabledAndOnRegions(tx) if err != nil { return nil, err } @@ -102,7 +117,10 @@ func (this *NodeRegionService) UpdateNodeRegionOrders(ctx context.Context, req * if err != nil { return nil, err } - err = models.SharedNodeRegionDAO.UpdateRegionOrders(req.NodeRegionIds) + + tx := this.NullTx() + + err = models.SharedNodeRegionDAO.UpdateRegionOrders(tx, req.NodeRegionIds) if err != nil { return nil, err } @@ -115,7 +133,10 @@ func (this *NodeRegionService) FindEnabledNodeRegion(ctx context.Context, req *p if err != nil { return nil, err } - region, err := models.SharedNodeRegionDAO.FindEnabledNodeRegion(req.NodeRegionId) + + tx := this.NullTx() + + region, err := models.SharedNodeRegionDAO.FindEnabledNodeRegion(tx, req.NodeRegionId) if err != nil { return nil, err } @@ -137,7 +158,10 @@ func (this *NodeRegionService) UpdateNodeRegionPrice(ctx context.Context, req *p if err != nil { return nil, err } - err = models.SharedNodeRegionDAO.UpdateRegionItemPrice(req.NodeRegionId, req.NodeItemId, req.Price) + + tx := this.NullTx() + + err = models.SharedNodeRegionDAO.UpdateRegionItemPrice(tx, req.NodeRegionId, req.NodeItemId, req.Price) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_node_stream.go b/internal/rpc/services/service_node_stream.go index c3e9db30..6d8585f7 100644 --- a/internal/rpc/services/service_node_stream.go +++ b/internal/rpc/services/service_node_stream.go @@ -96,23 +96,25 @@ func (this *NodeService) NodeStream(server pb.NodeService_NodeStreamServer) erro logs.Println("[RPC]accepted node '" + numberutils.FormatInt64(nodeId) + "' connection") + tx := this.NullTx() + // 标记为活跃状态 - oldIsActive, err := models.SharedNodeDAO.FindNodeActive(nodeId) + oldIsActive, err := models.SharedNodeDAO.FindNodeActive(tx, nodeId) if err != nil { return err } if !oldIsActive { - err = models.SharedNodeDAO.UpdateNodeActive(nodeId, true) + err = models.SharedNodeDAO.UpdateNodeActive(tx, nodeId, true) if err != nil { return err } // 发送恢复消息 - clusterId, err := models.SharedNodeDAO.FindNodeClusterId(nodeId) + clusterId, err := models.SharedNodeDAO.FindNodeClusterId(tx, nodeId) if err != nil { return err } - err = models.SharedMessageDAO.CreateNodeMessage(clusterId, nodeId, models.MessageTypeNodeActive, models.MessageLevelSuccess, "节点已经恢复在线", nil) + err = models.SharedMessageDAO.CreateNodeMessage(tx, clusterId, nodeId, models.MessageTypeNodeActive, models.MessageLevelSuccess, "节点已经恢复在线", nil) if err != nil { return err } @@ -160,7 +162,7 @@ func (this *NodeService) NodeStream(server pb.NodeService_NodeStreamServer) erro req, err := server.Recv() if err != nil { // 修改节点状态 - err1 := models.SharedNodeDAO.UpdateNodeIsActive(nodeId, false) + err1 := models.SharedNodeDAO.UpdateNodeIsActive(tx, nodeId, false) if err1 != nil { logs.Println(err1.Error()) } diff --git a/internal/rpc/services/service_origin.go b/internal/rpc/services/service_origin.go index 537107ac..0012ce8c 100644 --- a/internal/rpc/services/service_origin.go +++ b/internal/rpc/services/service_origin.go @@ -29,7 +29,10 @@ func (this *OriginService) CreateOrigin(ctx context.Context, req *pb.CreateOrigi "portRange": req.Addr.PortRange, "host": req.Addr.Host, } - originId, err := models.SharedOriginDAO.CreateOrigin(adminId, userId, req.Name, string(addrMap.AsJSON()), req.Description, req.Weight, req.IsOn) + + tx := this.NullTx() + + originId, err := models.SharedOriginDAO.CreateOrigin(tx, adminId, userId, req.Name, string(addrMap.AsJSON()), req.Description, req.Weight, req.IsOn) if err != nil { return nil, err } @@ -55,7 +58,10 @@ func (this *OriginService) UpdateOrigin(ctx context.Context, req *pb.UpdateOrigi "portRange": req.Addr.PortRange, "host": req.Addr.Host, } - err = models.SharedOriginDAO.UpdateOrigin(req.OriginId, req.Name, string(addrMap.AsJSON()), req.Description, req.Weight, req.IsOn) + + tx := this.NullTx() + + err = models.SharedOriginDAO.UpdateOrigin(tx, req.OriginId, req.Name, string(addrMap.AsJSON()), req.Description, req.Weight, req.IsOn) if err != nil { return nil, err } @@ -74,7 +80,9 @@ func (this *OriginService) FindEnabledOrigin(ctx context.Context, req *pb.FindEn // TODO 校验权限 } - origin, err := models.SharedOriginDAO.FindEnabledOrigin(req.OriginId) + tx := this.NullTx() + + origin, err := models.SharedOriginDAO.FindEnabledOrigin(tx, req.OriginId) if err != nil { return nil, err } @@ -113,7 +121,9 @@ func (this *OriginService) FindEnabledOriginConfig(ctx context.Context, req *pb. // TODO 校验权限 } - config, err := models.SharedOriginDAO.ComposeOriginConfig(req.OriginId) + tx := this.NullTx() + + config, err := models.SharedOriginDAO.ComposeOriginConfig(tx, req.OriginId) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_region_country.go b/internal/rpc/services/service_region_country.go index 6fcc4a0f..2d7ed203 100644 --- a/internal/rpc/services/service_region_country.go +++ b/internal/rpc/services/service_region_country.go @@ -10,6 +10,7 @@ import ( // 国家相关服务 type RegionCountryService struct { + BaseService } // 查找所有的国家列表 @@ -20,7 +21,9 @@ func (this *RegionCountryService) FindAllEnabledRegionCountries(ctx context.Cont return nil, err } - countries, err := models.SharedRegionCountryDAO.FindAllEnabledCountriesOrderByPinyin() + tx := this.NullTx() + + countries, err := models.SharedRegionCountryDAO.FindAllEnabledCountriesOrderByPinyin(tx) if err != nil { return nil, err } @@ -56,7 +59,9 @@ func (this *RegionCountryService) FindEnabledRegionCountry(ctx context.Context, return nil, err } - country, err := models.SharedRegionCountryDAO.FindEnabledRegionCountry(req.CountryId) + tx := this.NullTx() + + country, err := models.SharedRegionCountryDAO.FindEnabledRegionCountry(tx, req.CountryId) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_region_province.go b/internal/rpc/services/service_region_province.go index 403944ad..5caeef86 100644 --- a/internal/rpc/services/service_region_province.go +++ b/internal/rpc/services/service_region_province.go @@ -9,6 +9,7 @@ import ( // 省份相关服务 type RegionProvinceService struct { + BaseService } // 查找所有省份 @@ -19,7 +20,9 @@ func (this *RegionProvinceService) FindAllEnabledRegionProvincesWithCountryId(ct return nil, err } - provinces, err := models.SharedRegionProvinceDAO.FindAllEnabledProvincesWithCountryId(req.CountryId) + tx := this.NullTx() + + provinces, err := models.SharedRegionProvinceDAO.FindAllEnabledProvincesWithCountryId(tx, req.CountryId) if err != nil { return nil, err } @@ -44,7 +47,10 @@ func (this *RegionProvinceService) FindEnabledRegionProvince(ctx context.Context if err != nil { return nil, err } - province, err := models.SharedRegionProvinceDAO.FindEnabledRegionProvince(req.ProvinceId) + + tx := this.NullTx() + + province, err := models.SharedRegionProvinceDAO.FindEnabledRegionProvince(tx, req.ProvinceId) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_reverse_proxy.go b/internal/rpc/services/service_reverse_proxy.go index b7253525..786ff020 100644 --- a/internal/rpc/services/service_reverse_proxy.go +++ b/internal/rpc/services/service_reverse_proxy.go @@ -24,7 +24,9 @@ func (this *ReverseProxyService) CreateReverseProxy(ctx context.Context, req *pb // TODO 校验源站 } - reverseProxyId, err := models.SharedReverseProxyDAO.CreateReverseProxy(adminId, userId, req.SchedulingJSON, req.PrimaryOriginsJSON, req.BackupOriginsJSON) + tx := this.NullTx() + + reverseProxyId, err := models.SharedReverseProxyDAO.CreateReverseProxy(tx, adminId, userId, req.SchedulingJSON, req.PrimaryOriginsJSON, req.BackupOriginsJSON) if err != nil { return nil, err } @@ -44,7 +46,9 @@ func (this *ReverseProxyService) FindEnabledReverseProxy(ctx context.Context, re // TODO 检查权限 } - reverseProxy, err := models.SharedReverseProxyDAO.FindEnabledReverseProxy(req.ReverseProxyId) + tx := this.NullTx() + + reverseProxy, err := models.SharedReverseProxyDAO.FindEnabledReverseProxy(tx, req.ReverseProxyId) if err != nil { return nil, err } @@ -73,7 +77,9 @@ func (this *ReverseProxyService) FindEnabledReverseProxyConfig(ctx context.Conte // TODO 检查权限 } - config, err := models.SharedReverseProxyDAO.ComposeReverseProxyConfig(req.ReverseProxyId) + tx := this.NullTx() + + config, err := models.SharedReverseProxyDAO.ComposeReverseProxyConfig(tx, req.ReverseProxyId) if err != nil { return nil, err } @@ -98,7 +104,9 @@ func (this *ReverseProxyService) UpdateReverseProxyScheduling(ctx context.Contex // TODO 检查权限 } - err = models.SharedReverseProxyDAO.UpdateReverseProxyScheduling(req.ReverseProxyId, req.SchedulingJSON) + tx := this.NullTx() + + err = models.SharedReverseProxyDAO.UpdateReverseProxyScheduling(tx, req.ReverseProxyId, req.SchedulingJSON) if err != nil { return nil, err } @@ -118,7 +126,9 @@ func (this *ReverseProxyService) UpdateReverseProxyPrimaryOrigins(ctx context.Co // TODO 检查权限 } - err = models.SharedReverseProxyDAO.UpdateReverseProxyPrimaryOrigins(req.ReverseProxyId, req.OriginsJSON) + tx := this.NullTx() + + err = models.SharedReverseProxyDAO.UpdateReverseProxyPrimaryOrigins(tx, req.ReverseProxyId, req.OriginsJSON) if err != nil { return nil, err } @@ -138,7 +148,9 @@ func (this *ReverseProxyService) UpdateReverseProxyBackupOrigins(ctx context.Con // TODO 检查权限 } - err = models.SharedReverseProxyDAO.UpdateReverseProxyBackupOrigins(req.ReverseProxyId, req.OriginsJSON) + tx := this.NullTx() + + err = models.SharedReverseProxyDAO.UpdateReverseProxyBackupOrigins(tx, req.ReverseProxyId, req.OriginsJSON) if err != nil { return nil, err } @@ -158,7 +170,9 @@ func (this *ReverseProxyService) UpdateReverseProxy(ctx context.Context, req *pb // TODO 检查权限 } - err = models.SharedReverseProxyDAO.UpdateReverseProxy(req.ReverseProxyId, types.Int8(req.RequestHostType), req.RequestHost, req.RequestURI, req.StripPrefix, req.AutoFlush) + tx := this.NullTx() + + err = models.SharedReverseProxyDAO.UpdateReverseProxy(tx, req.ReverseProxyId, types.Int8(req.RequestHostType), req.RequestHost, req.RequestURI, req.StripPrefix, req.AutoFlush) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_server.go b/internal/rpc/services/service_server.go index a8990f8c..c5bd1768 100644 --- a/internal/rpc/services/service_server.go +++ b/internal/rpc/services/service_server.go @@ -25,6 +25,8 @@ func (this *ServerService) CreateServer(ctx context.Context, req *pb.CreateServe return nil, err } + tx := this.NullTx() + // 校验用户相关数据 if userId > 0 { // HTTPS @@ -35,7 +37,7 @@ func (this *ServerService) CreateServer(ctx context.Context, req *pb.CreateServe return nil, err } if httpsConfig.SSLPolicyRef != nil && httpsConfig.SSLPolicyRef.SSLPolicyId > 0 { - err := models.SharedSSLPolicyDAO.CheckUserPolicy(httpsConfig.SSLPolicyRef.SSLPolicyId, userId) + err := models.SharedSSLPolicyDAO.CheckUserPolicy(tx, httpsConfig.SSLPolicyRef.SSLPolicyId, userId) if err != nil { return nil, err } @@ -48,7 +50,7 @@ func (this *ServerService) CreateServer(ctx context.Context, req *pb.CreateServe serverNamesJSON := req.ServerNamesJON auditingServerNamesJSON := []byte("[]") if userId > 0 { - globalConfig, err := models.SharedSysSettingDAO.ReadGlobalConfig() + globalConfig, err := models.SharedSysSettingDAO.ReadGlobalConfig(tx) if err != nil { return nil, err } @@ -59,13 +61,13 @@ func (this *ServerService) CreateServer(ctx context.Context, req *pb.CreateServe } } - serverId, err := models.SharedServerDAO.CreateServer(req.AdminId, req.UserId, req.Type, req.Name, req.Description, serverNamesJSON, isAuditing, auditingServerNamesJSON, string(req.HttpJSON), string(req.HttpsJSON), string(req.TcpJSON), string(req.TlsJSON), string(req.UnixJSON), string(req.UdpJSON), req.WebId, req.ReverseProxyJSON, req.NodeClusterId, string(req.IncludeNodesJSON), string(req.ExcludeNodesJSON), req.GroupIds) + serverId, err := models.SharedServerDAO.CreateServer(tx, req.AdminId, req.UserId, req.Type, req.Name, req.Description, serverNamesJSON, isAuditing, auditingServerNamesJSON, string(req.HttpJSON), string(req.HttpsJSON), string(req.TcpJSON), string(req.TlsJSON), string(req.UnixJSON), string(req.UdpJSON), req.WebId, req.ReverseProxyJSON, req.NodeClusterId, string(req.IncludeNodesJSON), string(req.ExcludeNodesJSON), req.GroupIds) if err != nil { return nil, err } // 更新节点版本 - err = models.SharedNodeDAO.IncreaseAllNodesLatestVersionMatch(req.NodeClusterId) + err = models.SharedNodeDAO.IncreaseAllNodesLatestVersionMatch(tx, req.NodeClusterId) if err != nil { return nil, err } @@ -85,8 +87,10 @@ func (this *ServerService) UpdateServerBasic(ctx context.Context, req *pb.Update return nil, errors.New("invalid serverId") } + tx := this.NullTx() + // 查询老的节点信息 - server, err := models.SharedServerDAO.FindEnabledServer(req.ServerId) + server, err := models.SharedServerDAO.FindEnabledServer(tx, req.ServerId) if err != nil { return nil, err } @@ -94,7 +98,7 @@ func (this *ServerService) UpdateServerBasic(ctx context.Context, req *pb.Update return nil, errors.New("can not find server") } - err = models.SharedServerDAO.UpdateServerBasic(req.ServerId, req.Name, req.Description, req.NodeClusterId, req.IsOn, req.GroupIds) + err = models.SharedServerDAO.UpdateServerBasic(tx, req.ServerId, req.Name, req.Description, req.NodeClusterId, req.IsOn, req.GroupIds) if err != nil { return nil, err } @@ -112,14 +116,14 @@ func (this *ServerService) UpdateServerBasic(ctx context.Context, req *pb.Update // 更新老的节点版本 if req.NodeClusterId != int64(server.ClusterId) { - err = models.SharedNodeDAO.IncreaseAllNodesLatestVersionMatch(int64(server.ClusterId)) + err = models.SharedNodeDAO.IncreaseAllNodesLatestVersionMatch(tx, int64(server.ClusterId)) if err != nil { return nil, err } } // 更新新的节点版本 - err = models.SharedNodeDAO.IncreaseAllNodesLatestVersionMatch(req.NodeClusterId) + err = models.SharedNodeDAO.IncreaseAllNodesLatestVersionMatch(tx, req.NodeClusterId) if err != nil { return nil, err } @@ -133,13 +137,16 @@ func (this *ServerService) UpdateServerIsOn(ctx context.Context, req *pb.UpdateS if err != nil { return nil, err } + + tx := this.NullTx() + if userId > 0 { - err = models.SharedServerDAO.CheckUserServer(req.ServerId, userId) + err = models.SharedServerDAO.CheckUserServer(tx, req.ServerId, userId) if err != nil { return nil, err } } - err = models.SharedServerDAO.UpdateServerIsOn(req.ServerId, req.IsOn) + err = models.SharedServerDAO.UpdateServerIsOn(tx, req.ServerId, req.IsOn) if err != nil { return nil, err } @@ -154,15 +161,17 @@ func (this *ServerService) UpdateServerHTTP(ctx context.Context, req *pb.UpdateS return nil, err } + tx := this.NullTx() + if userId > 0 { - err = models.SharedServerDAO.CheckUserServer(req.ServerId, userId) + err = models.SharedServerDAO.CheckUserServer(tx, req.ServerId, userId) if err != nil { return nil, err } } // 修改配置 - err = models.SharedServerDAO.UpdateServerHTTP(req.ServerId, req.HttpJSON) + err = models.SharedServerDAO.UpdateServerHTTP(tx, req.ServerId, req.HttpJSON) if err != nil { return nil, err } @@ -178,15 +187,17 @@ func (this *ServerService) UpdateServerHTTPS(ctx context.Context, req *pb.Update return nil, err } + tx := this.NullTx() + if userId > 0 { - err = models.SharedServerDAO.CheckUserServer(req.ServerId, userId) + err = models.SharedServerDAO.CheckUserServer(tx, req.ServerId, userId) if err != nil { return nil, err } } // 修改配置 - err = models.SharedServerDAO.UpdateServerHTTPS(req.ServerId, req.HttpsJSON) + err = models.SharedServerDAO.UpdateServerHTTPS(tx, req.ServerId, req.HttpsJSON) if err != nil { return nil, err } @@ -206,8 +217,10 @@ func (this *ServerService) UpdateServerTCP(ctx context.Context, req *pb.UpdateSe return nil, errors.New("invalid serverId") } + tx := this.NullTx() + // 修改配置 - err = models.SharedServerDAO.UpdateServerTCP(req.ServerId, req.TcpJSON) + err = models.SharedServerDAO.UpdateServerTCP(tx, req.ServerId, req.TcpJSON) if err != nil { return nil, err } @@ -227,8 +240,10 @@ func (this *ServerService) UpdateServerTLS(ctx context.Context, req *pb.UpdateSe return nil, errors.New("invalid serverId") } + tx := this.NullTx() + // 修改配置 - err = models.SharedServerDAO.UpdateServerTLS(req.ServerId, req.TlsJSON) + err = models.SharedServerDAO.UpdateServerTLS(tx, req.ServerId, req.TlsJSON) if err != nil { return nil, err } @@ -248,8 +263,10 @@ func (this *ServerService) UpdateServerUnix(ctx context.Context, req *pb.UpdateS return nil, errors.New("invalid serverId") } + tx := this.NullTx() + // 修改配置 - err = models.SharedServerDAO.UpdateServerUnix(req.ServerId, req.UnixJSON) + err = models.SharedServerDAO.UpdateServerUnix(tx, req.ServerId, req.UnixJSON) if err != nil { return nil, err } @@ -269,8 +286,10 @@ func (this *ServerService) UpdateServerUDP(ctx context.Context, req *pb.UpdateSe return nil, errors.New("invalid serverId") } + tx := this.NullTx() + // 修改配置 - err = models.SharedServerDAO.UpdateServerUDP(req.ServerId, req.UdpJSON) + err = models.SharedServerDAO.UpdateServerUDP(tx, req.ServerId, req.UdpJSON) if err != nil { return nil, err } @@ -286,15 +305,17 @@ func (this *ServerService) UpdateServerWeb(ctx context.Context, req *pb.UpdateSe return nil, err } + tx := this.NullTx() + if userId > 0 { - err = models.SharedServerDAO.CheckUserServer(req.ServerId, userId) + err = models.SharedServerDAO.CheckUserServer(tx, req.ServerId, userId) if err != nil { return nil, err } } // 修改配置 - err = models.SharedServerDAO.UpdateServerWeb(req.ServerId, req.WebId) + err = models.SharedServerDAO.UpdateServerWeb(tx, req.ServerId, req.WebId) if err != nil { return nil, err } @@ -310,15 +331,17 @@ func (this *ServerService) UpdateServerReverseProxy(ctx context.Context, req *pb return nil, err } + tx := this.NullTx() + if userId > 0 { - err = models.SharedServerDAO.CheckUserServer(req.ServerId, userId) + err = models.SharedServerDAO.CheckUserServer(tx, req.ServerId, userId) if err != nil { return nil, err } } // 修改配置 - err = models.SharedServerDAO.UpdateServerReverseProxy(req.ServerId, req.ReverseProxyJSON) + err = models.SharedServerDAO.UpdateServerReverseProxy(tx, req.ServerId, req.ReverseProxyJSON) if err != nil { return nil, err } @@ -333,14 +356,16 @@ func (this *ServerService) FindServerNames(ctx context.Context, req *pb.FindServ return nil, err } + tx := this.NullTx() + if userId > 0 { - err = models.SharedServerDAO.CheckUserServer(req.ServerId, userId) + err = models.SharedServerDAO.CheckUserServer(tx, req.ServerId, userId) if err != nil { return nil, err } } - serverNamesJSON, isAuditing, auditingServerNamesJSON, auditingResultJSON, err := models.SharedServerDAO.FindServerNames(req.ServerId) + serverNamesJSON, isAuditing, auditingServerNamesJSON, auditingResultJSON, err := models.SharedServerDAO.FindServerNames(tx, req.ServerId) if err != nil { return nil, err } @@ -376,14 +401,16 @@ func (this *ServerService) UpdateServerNames(ctx context.Context, req *pb.Update return nil, errors.New("invalid serverId") } + tx := this.NullTx() + // 是否需要审核 if userId > 0 { - globalConfig, err := models.SharedSysSettingDAO.ReadGlobalConfig() + globalConfig, err := models.SharedSysSettingDAO.ReadGlobalConfig(tx) if err != nil { return nil, err } if globalConfig != nil && globalConfig.HTTPAll.DomainAuditingIsOn { - err = models.SharedServerDAO.UpdateAuditingServerNames(req.ServerId, true, req.ServerNamesJSON) + err = models.SharedServerDAO.UpdateAuditingServerNames(tx, req.ServerId, true, req.ServerNamesJSON) if err != nil { return nil, err } @@ -392,7 +419,7 @@ func (this *ServerService) UpdateServerNames(ctx context.Context, req *pb.Update } // 修改配置 - err = models.SharedServerDAO.UpdateServerNames(req.ServerId, req.ServerNamesJSON) + err = models.SharedServerDAO.UpdateServerNames(tx, req.ServerId, req.ServerNamesJSON) if err != nil { return nil, err } @@ -412,23 +439,25 @@ func (this *ServerService) UpdateServerNamesAuditing(ctx context.Context, req *p return nil, errors.New("'result' should not be nil") } - err = models.SharedServerDAO.UpdateServerAuditing(req.ServerId, req.AuditingResult) + tx := this.NullTx() + + err = models.SharedServerDAO.UpdateServerAuditing(tx, req.ServerId, req.AuditingResult) if err != nil { return nil, err } // 发送消息提醒 - _, userId, err := models.SharedServerDAO.FindServerAdminIdAndUserId(req.ServerId) + _, userId, err := models.SharedServerDAO.FindServerAdminIdAndUserId(tx, req.ServerId) if userId > 0 { if req.AuditingResult.IsOk { - err = models.SharedMessageDAO.CreateMessage(0, userId, models.MessageTypeServerNamesAuditingSuccess, models.LevelInfo, "服务域名审核通过", maps.Map{ + err = models.SharedMessageDAO.CreateMessage(tx, 0, userId, models.MessageTypeServerNamesAuditingSuccess, models.LevelInfo, "服务域名审核通过", maps.Map{ "serverId": req.ServerId, }.AsJSON()) if err != nil { return nil, err } } else { - err = models.SharedMessageDAO.CreateMessage(0, userId, models.MessageTypeServerNamesAuditingFailed, models.LevelError, "服务域名审核失败,原因:"+req.AuditingResult.Reason, maps.Map{ + err = models.SharedMessageDAO.CreateMessage(tx, 0, userId, models.MessageTypeServerNamesAuditingFailed, models.LevelError, "服务域名审核失败,原因:"+req.AuditingResult.Reason, maps.Map{ "serverId": req.ServerId, }.AsJSON()) if err != nil { @@ -455,7 +484,10 @@ func (this *ServerService) CountAllEnabledServersMatch(ctx context.Context, req if err != nil { return nil, err } - count, err := models.SharedServerDAO.CountAllEnabledServersMatch(req.GroupId, req.Keyword, req.UserId, req.ClusterId, types.Int8(req.AuditingFlag)) + + tx := this.NullTx() + + count, err := models.SharedServerDAO.CountAllEnabledServersMatch(tx, req.GroupId, req.Keyword, req.UserId, req.ClusterId, types.Int8(req.AuditingFlag)) if err != nil { return nil, err } @@ -470,13 +502,16 @@ func (this *ServerService) ListEnabledServersMatch(ctx context.Context, req *pb. if err != nil { return nil, err } - servers, err := models.SharedServerDAO.ListEnabledServersMatch(req.Offset, req.Size, req.GroupId, req.Keyword, req.UserId, req.ClusterId, req.AuditingFlag) + + tx := this.NullTx() + + servers, err := models.SharedServerDAO.ListEnabledServersMatch(tx, req.Offset, req.Size, req.GroupId, req.Keyword, req.UserId, req.ClusterId, req.AuditingFlag) if err != nil { return nil, err } result := []*pb.Server{} for _, server := range servers { - clusterName, err := models.SharedNodeClusterDAO.FindNodeClusterName(int64(server.ClusterId)) + clusterName, err := models.SharedNodeClusterDAO.FindNodeClusterName(tx, int64(server.ClusterId)) if err != nil { return nil, err } @@ -490,7 +525,7 @@ func (this *ServerService) ListEnabledServersMatch(ctx context.Context, req *pb. return nil, err } for _, groupId := range groupIds { - group, err := models.SharedServerGroupDAO.FindEnabledServerGroup(groupId) + group, err := models.SharedServerGroupDAO.FindEnabledServerGroup(tx, groupId) if err != nil { return nil, err } @@ -505,7 +540,7 @@ func (this *ServerService) ListEnabledServersMatch(ctx context.Context, req *pb. } // 用户 - user, err := models.SharedUserDAO.FindEnabledBasicUser(int64(server.UserId)) + user, err := models.SharedUserDAO.FindEnabledBasicUser(tx, int64(server.UserId)) if err != nil { return nil, err } @@ -569,15 +604,17 @@ func (this *ServerService) DeleteServer(ctx context.Context, req *pb.DeleteServe return nil, err } + tx := this.NullTx() + if userId > 0 { - err = models.SharedServerDAO.CheckUserServer(req.ServerId, userId) + err = models.SharedServerDAO.CheckUserServer(tx, req.ServerId, userId) if err != nil { return nil, err } } // 查找服务 - server, err := models.SharedServerDAO.FindEnabledServer(req.ServerId) + server, err := models.SharedServerDAO.FindEnabledServer(tx, req.ServerId) if err != nil { return nil, err } @@ -586,13 +623,13 @@ func (this *ServerService) DeleteServer(ctx context.Context, req *pb.DeleteServe } // 禁用服务 - err = models.SharedServerDAO.DisableServer(req.ServerId) + err = models.SharedServerDAO.DisableServer(tx, req.ServerId) if err != nil { return nil, err } // 更新节点版本 - err = models.SharedNodeDAO.IncreaseAllNodesLatestVersionMatch(int64(server.ClusterId)) + err = models.SharedNodeDAO.IncreaseAllNodesLatestVersionMatch(tx, int64(server.ClusterId)) if err != nil { return nil, err } @@ -608,15 +645,17 @@ func (this *ServerService) FindEnabledServer(ctx context.Context, req *pb.FindEn return nil, err } + tx := this.NullTx() + // 检查权限 if userId > 0 { - err = models.SharedServerDAO.CheckUserServer(req.ServerId, userId) + err = models.SharedServerDAO.CheckUserServer(tx, req.ServerId, userId) if err != nil { return nil, err } } - server, err := models.SharedServerDAO.FindEnabledServer(req.ServerId) + server, err := models.SharedServerDAO.FindEnabledServer(tx, req.ServerId) if err != nil { return nil, err } @@ -626,7 +665,7 @@ func (this *ServerService) FindEnabledServer(ctx context.Context, req *pb.FindEn } // 集群信息 - clusterName, err := models.SharedNodeClusterDAO.FindNodeClusterName(int64(server.ClusterId)) + clusterName, err := models.SharedNodeClusterDAO.FindNodeClusterName(tx, int64(server.ClusterId)) if err != nil { return nil, err } @@ -640,7 +679,7 @@ func (this *ServerService) FindEnabledServer(ctx context.Context, req *pb.FindEn return nil, err } for _, groupId := range groupIds { - group, err := models.SharedServerGroupDAO.FindEnabledServerGroup(groupId) + group, err := models.SharedServerGroupDAO.FindEnabledServerGroup(tx, groupId) if err != nil { return nil, err } @@ -657,7 +696,7 @@ func (this *ServerService) FindEnabledServer(ctx context.Context, req *pb.FindEn // 用户信息 var pbUser *pb.User = nil if server.UserId > 0 { - user, err := models.SharedUserDAO.FindEnabledBasicUser(int64(server.UserId)) + user, err := models.SharedUserDAO.FindEnabledBasicUser(tx, int64(server.UserId)) if err != nil { return nil, err } @@ -708,15 +747,17 @@ func (this *ServerService) FindEnabledServerConfig(ctx context.Context, req *pb. return nil, err } + tx := this.NullTx() + // 检查权限 if userId > 0 { - err = models.SharedServerDAO.CheckUserServer(req.ServerId, userId) + err = models.SharedServerDAO.CheckUserServer(tx, req.ServerId, userId) if err != nil { return nil, err } } - config, err := models.SharedServerDAO.ComposeServerConfig(req.ServerId) + config, err := models.SharedServerDAO.ComposeServerConfig(tx, req.ServerId) if err != nil { return nil, err } @@ -739,15 +780,17 @@ func (this *ServerService) FindEnabledServerType(ctx context.Context, req *pb.Fi return nil, err } + tx := this.NullTx() + // 检查权限 if userId > 0 { - err = models.SharedServerDAO.CheckUserServer(req.ServerId, userId) + err = models.SharedServerDAO.CheckUserServer(tx, req.ServerId, userId) if err != nil { return nil, err } } - serverType, err := models.SharedServerDAO.FindEnabledServerType(req.ServerId) + serverType, err := models.SharedServerDAO.FindEnabledServerType(tx, req.ServerId) if err != nil { return nil, err } @@ -763,13 +806,15 @@ func (this *ServerService) FindAndInitServerReverseProxyConfig(ctx context.Conte return nil, err } - reverseProxyRef, err := models.SharedServerDAO.FindReverseProxyRef(req.ServerId) + tx := this.NullTx() + + reverseProxyRef, err := models.SharedServerDAO.FindReverseProxyRef(tx, req.ServerId) if err != nil { return nil, err } if reverseProxyRef == nil { - reverseProxyId, err := models.SharedReverseProxyDAO.CreateReverseProxy(adminId, userId, nil, nil, nil) + reverseProxyId, err := models.SharedReverseProxyDAO.CreateReverseProxy(tx, adminId, userId, nil, nil, nil) if err != nil { return nil, err } @@ -782,13 +827,13 @@ func (this *ServerService) FindAndInitServerReverseProxyConfig(ctx context.Conte if err != nil { return nil, err } - err = models.SharedServerDAO.UpdateServerReverseProxy(req.ServerId, refJSON) + err = models.SharedServerDAO.UpdateServerReverseProxy(tx, req.ServerId, refJSON) if err != nil { return nil, err } } - reverseProxyConfig, err := models.SharedReverseProxyDAO.ComposeReverseProxyConfig(reverseProxyRef.ReverseProxyId) + reverseProxyConfig, err := models.SharedReverseProxyDAO.ComposeReverseProxyConfig(tx, reverseProxyRef.ReverseProxyId) if err != nil { return nil, err } @@ -814,26 +859,28 @@ func (this *ServerService) FindAndInitServerWebConfig(ctx context.Context, req * return nil, err } + tx := this.NullTx() + if userId > 0 { - err = models.SharedServerDAO.CheckUserServer(req.ServerId, userId) + err = models.SharedServerDAO.CheckUserServer(tx, req.ServerId, userId) if err != nil { return nil, err } } - webId, err := models.SharedServerDAO.FindServerWebId(req.ServerId) + webId, err := models.SharedServerDAO.FindServerWebId(tx, req.ServerId) if err != nil { return nil, err } if webId == 0 { - webId, err = models.SharedServerDAO.InitServerWeb(req.ServerId) + webId, err = models.SharedServerDAO.InitServerWeb(tx, req.ServerId) if err != nil { return nil, err } } - config, err := models.SharedHTTPWebDAO.ComposeWebConfig(webId) + config, err := models.SharedHTTPWebDAO.ComposeWebConfig(tx, webId) if err != nil { return nil, err } @@ -856,7 +903,9 @@ func (this *ServerService) CountAllEnabledServersWithSSLCertId(ctx context.Conte // TODO 校验权限 } - policyIds, err := models.SharedSSLPolicyDAO.FindAllEnabledPolicyIdsWithCertId(req.SslCertId) + tx := this.NullTx() + + policyIds, err := models.SharedSSLPolicyDAO.FindAllEnabledPolicyIdsWithCertId(tx, req.SslCertId) if err != nil { return nil, err } @@ -865,7 +914,7 @@ func (this *ServerService) CountAllEnabledServersWithSSLCertId(ctx context.Conte return this.SuccessCount(0) } - count, err := models.SharedServerDAO.CountAllEnabledServersWithSSLPolicyIds(policyIds) + count, err := models.SharedServerDAO.CountAllEnabledServersWithSSLPolicyIds(tx, policyIds) if err != nil { return nil, err } @@ -885,7 +934,9 @@ func (this *ServerService) FindAllEnabledServersWithSSLCertId(ctx context.Contex // TODO 校验权限 } - policyIds, err := models.SharedSSLPolicyDAO.FindAllEnabledPolicyIdsWithCertId(req.SslCertId) + tx := this.NullTx() + + policyIds, err := models.SharedSSLPolicyDAO.FindAllEnabledPolicyIdsWithCertId(tx, req.SslCertId) if err != nil { return nil, err } @@ -893,7 +944,7 @@ func (this *ServerService) FindAllEnabledServersWithSSLCertId(ctx context.Contex return &pb.FindAllEnabledServersWithSSLCertIdResponse{Servers: nil}, nil } - servers, err := models.SharedServerDAO.FindAllEnabledServersWithSSLPolicyIds(policyIds) + servers, err := models.SharedServerDAO.FindAllEnabledServersWithSSLPolicyIds(tx, policyIds) if err != nil { return nil, err } @@ -917,7 +968,9 @@ func (this *ServerService) CountAllEnabledServersWithNodeClusterId(ctx context.C return nil, err } - count, err := models.SharedServerDAO.CountAllEnabledServersWithNodeClusterId(req.NodeClusterId) + tx := this.NullTx() + + count, err := models.SharedServerDAO.CountAllEnabledServersWithNodeClusterId(tx, req.NodeClusterId) if err != nil { return nil, err } @@ -932,7 +985,9 @@ func (this *ServerService) CountAllEnabledServersWithGroupId(ctx context.Context return nil, err } - count, err := models.SharedServerDAO.CountAllEnabledServersWithGroupId(req.GroupId) + tx := this.NullTx() + + count, err := models.SharedServerDAO.CountAllEnabledServersWithGroupId(tx, req.GroupId) if err != nil { return nil, err } @@ -947,7 +1002,9 @@ func (this *ServerService) NotifyServersChange(ctx context.Context, req *pb.Noti return nil, err } - err = models.SharedSysEventDAO.CreateEvent(models.NewServerChangeEvent()) + tx := this.NullTx() + + err = models.SharedSysEventDAO.CreateEvent(tx, models.NewServerChangeEvent()) if err != nil { return nil, err } @@ -963,7 +1020,9 @@ func (this *ServerService) FindAllEnabledServersDNSWithClusterId(ctx context.Con return nil, err } - servers, err := models.SharedServerDAO.FindAllServersDNSWithClusterId(req.NodeClusterId) + tx := this.NullTx() + + servers, err := models.SharedServerDAO.FindAllServersDNSWithClusterId(tx, req.NodeClusterId) if err != nil { return nil, err } @@ -972,7 +1031,7 @@ func (this *ServerService) FindAllEnabledServersDNSWithClusterId(ctx context.Con // 如果子域名为空 if len(server.DnsName) == 0 { // 自动生成子域名 - dnsName, err := models.SharedServerDAO.GenerateServerDNSName(int64(server.Id)) + dnsName, err := models.SharedServerDAO.GenerateServerDNSName(tx, int64(server.Id)) if err != nil { return nil, err } @@ -997,25 +1056,27 @@ func (this *ServerService) FindEnabledServerDNS(ctx context.Context, req *pb.Fin return nil, err } - dnsName, err := models.SharedServerDAO.FindServerDNSName(req.ServerId) + tx := this.NullTx() + + dnsName, err := models.SharedServerDAO.FindServerDNSName(tx, req.ServerId) if err != nil { return nil, err } - clusterId, err := models.SharedServerDAO.FindServerClusterId(req.ServerId) + clusterId, err := models.SharedServerDAO.FindServerClusterId(tx, req.ServerId) if err != nil { return nil, err } var pbDomain *pb.DNSDomain = nil if clusterId > 0 { - clusterDNS, err := models.SharedNodeClusterDAO.FindClusterDNSInfo(clusterId) + clusterDNS, err := models.SharedNodeClusterDAO.FindClusterDNSInfo(tx, clusterId) if err != nil { return nil, err } if clusterDNS != nil { domainId := int64(clusterDNS.DnsDomainId) if domainId > 0 { - domain, err := models.SharedDNSDomainDAO.FindEnabledDNSDomain(domainId) + domain, err := models.SharedDNSDomainDAO.FindEnabledDNSDomain(tx, domainId) if err != nil { return nil, err } @@ -1037,11 +1098,13 @@ func (this *ServerService) FindEnabledServerDNS(ctx context.Context, req *pb.Fin // 自动同步DNS状态 func (this *ServerService) notifyServerDNSChanged(serverId int64) error { - clusterId, err := models.SharedServerDAO.FindServerClusterId(serverId) + tx := this.NullTx() + + clusterId, err := models.SharedServerDAO.FindServerClusterId(tx, serverId) if err != nil { return err } - dnsInfo, err := models.SharedNodeClusterDAO.FindClusterDNSInfo(clusterId) + dnsInfo, err := models.SharedNodeClusterDAO.FindClusterDNSInfo(tx, clusterId) if err != nil { return err } @@ -1069,7 +1132,7 @@ func (this *ServerService) notifyServerDNSChanged(serverId int64) error { return err } if !resp.IsOk { - err = models.SharedMessageDAO.CreateClusterMessage(clusterId, models.MessageTypeClusterDNSSyncFailed, models.LevelError, "集群DNS同步失败:"+resp.Error, nil) + err = models.SharedMessageDAO.CreateClusterMessage(tx, clusterId, models.MessageTypeClusterDNSSyncFailed, models.LevelError, "集群DNS同步失败:"+resp.Error, nil) if err != nil { logs.Println("[NODE_SERVICE]" + err.Error()) } @@ -1083,7 +1146,10 @@ func (this *ServerService) CheckUserServer(ctx context.Context, req *pb.CheckUse if err != nil { return nil, err } - err = models.SharedServerDAO.CheckUserServer(req.ServerId, userId) + + tx := this.NullTx() + + err = models.SharedServerDAO.CheckUserServer(tx, req.ServerId, userId) if err != nil { return nil, err } @@ -1097,7 +1163,9 @@ func (this *ServerService) FindAllEnabledServerNamesWithUserId(ctx context.Conte return nil, err } - servers, err := models.SharedServerDAO.FindAllEnabledServersWithUserId(req.UserId) + tx := this.NullTx() + + servers, err := models.SharedServerDAO.FindAllEnabledServersWithUserId(tx, req.UserId) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_server_daily_stat.go b/internal/rpc/services/service_server_daily_stat.go index 6efe1aa9..e348b274 100644 --- a/internal/rpc/services/service_server_daily_stat.go +++ b/internal/rpc/services/service_server_daily_stat.go @@ -18,7 +18,9 @@ func (this *ServerDailyStatService) UploadServerDailyStats(ctx context.Context, return nil, err } - err = models.SharedServerDailyStatDAO.SaveStats(req.Stats) + tx := this.NullTx() + + err = models.SharedServerDailyStatDAO.SaveStats(tx, req.Stats) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_server_group.go b/internal/rpc/services/service_server_group.go index 9289c05b..e37bd076 100644 --- a/internal/rpc/services/service_server_group.go +++ b/internal/rpc/services/service_server_group.go @@ -20,7 +20,9 @@ func (this *ServerGroupService) CreateServerGroup(ctx context.Context, req *pb.C return nil, err } - groupId, err := models.SharedServerGroupDAO.CreateGroup(req.Name) + tx := this.NullTx() + + groupId, err := models.SharedServerGroupDAO.CreateGroup(tx, req.Name) if err != nil { return nil, err } @@ -35,7 +37,9 @@ func (this *ServerGroupService) UpdateServerGroup(ctx context.Context, req *pb.U return nil, err } - err = models.SharedServerGroupDAO.UpdateGroup(req.GroupId, req.Name) + tx := this.NullTx() + + err = models.SharedServerGroupDAO.UpdateGroup(tx, req.GroupId, req.Name) if err != nil { return nil, err } @@ -51,7 +55,9 @@ func (this *ServerGroupService) DeleteServerGroup(ctx context.Context, req *pb.D return nil, err } - err = models.SharedServerGroupDAO.DisableServerGroup(req.GroupId) + tx := this.NullTx() + + err = models.SharedServerGroupDAO.DisableServerGroup(tx, req.GroupId) if err != nil { return nil, err } @@ -67,7 +73,9 @@ func (this *ServerGroupService) FindAllEnabledServerGroups(ctx context.Context, return nil, err } - groups, err := models.SharedServerGroupDAO.FindAllEnabledGroups() + tx := this.NullTx() + + groups, err := models.SharedServerGroupDAO.FindAllEnabledGroups(tx) if err != nil { return nil, err } @@ -89,7 +97,9 @@ func (this *ServerGroupService) UpdateServerGroupOrders(ctx context.Context, req return nil, err } - err = models.SharedServerGroupDAO.UpdateGroupOrders(req.GroupIds) + tx := this.NullTx() + + err = models.SharedServerGroupDAO.UpdateGroupOrders(tx, req.GroupIds) if err != nil { return nil, err } @@ -104,7 +114,9 @@ func (this *ServerGroupService) FindEnabledServerGroup(ctx context.Context, req return nil, err } - group, err := models.SharedServerGroupDAO.FindEnabledServerGroup(req.GroupId) + tx := this.NullTx() + + group, err := models.SharedServerGroupDAO.FindEnabledServerGroup(tx, req.GroupId) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_ssl_cert.go b/internal/rpc/services/service_ssl_cert.go index 9347279d..cf162e2a 100644 --- a/internal/rpc/services/service_ssl_cert.go +++ b/internal/rpc/services/service_ssl_cert.go @@ -21,7 +21,9 @@ func (this *SSLCertService) CreateSSLCert(ctx context.Context, req *pb.CreateSSL return nil, err } - certId, err := models.SharedSSLCertDAO.CreateCert(adminId, userId, req.IsOn, req.Name, req.Description, req.ServerName, req.IsCA, req.CertData, req.KeyData, req.TimeBeginAt, req.TimeEndAt, req.DnsNames, req.CommonNames) + tx := this.NullTx() + + certId, err := models.SharedSSLCertDAO.CreateCert(tx, adminId, userId, req.IsOn, req.Name, req.Description, req.ServerName, req.IsCA, req.CertData, req.KeyData, req.TimeBeginAt, req.TimeEndAt, req.DnsNames, req.CommonNames) if err != nil { return nil, err } @@ -37,15 +39,17 @@ func (this *SSLCertService) UpdateSSLCert(ctx context.Context, req *pb.UpdateSSL return nil, err } + tx := this.NullTx() + // 检查权限 if userId > 0 { - err := models.SharedSSLCertDAO.CheckUserCert(req.SslCertId, userId) + err := models.SharedSSLCertDAO.CheckUserCert(tx, req.SslCertId, userId) if err != nil { return nil, err } } - err = models.SharedSSLCertDAO.UpdateCert(req.SslCertId, req.IsOn, req.Name, req.Description, req.ServerName, req.IsCA, req.CertData, req.KeyData, req.TimeBeginAt, req.TimeEndAt, req.DnsNames, req.CommonNames) + err = models.SharedSSLCertDAO.UpdateCert(tx, req.SslCertId, req.IsOn, req.Name, req.Description, req.ServerName, req.IsCA, req.CertData, req.KeyData, req.TimeBeginAt, req.TimeEndAt, req.DnsNames, req.CommonNames) if err != nil { return nil, err } @@ -61,15 +65,17 @@ func (this *SSLCertService) FindEnabledSSLCertConfig(ctx context.Context, req *p return nil, err } + tx := this.NullTx() + // 检查权限 if userId > 0 { - err := models.SharedSSLCertDAO.CheckUserCert(req.SslCertId, userId) + err := models.SharedSSLCertDAO.CheckUserCert(tx, req.SslCertId, userId) if err != nil { return nil, err } } - config, err := models.SharedSSLCertDAO.ComposeCertConfig(req.SslCertId) + config, err := models.SharedSSLCertDAO.ComposeCertConfig(tx, req.SslCertId) if err != nil { return nil, err } @@ -89,21 +95,23 @@ func (this *SSLCertService) DeleteSSLCert(ctx context.Context, req *pb.DeleteSSL return nil, err } + tx := this.NullTx() + // 检查权限 if userId > 0 { - err := models.SharedSSLCertDAO.CheckUserCert(req.SslCertId, userId) + err := models.SharedSSLCertDAO.CheckUserCert(tx, req.SslCertId, userId) if err != nil { return nil, err } } - err = models.SharedSSLCertDAO.DisableSSLCert(req.SslCertId) + err = models.SharedSSLCertDAO.DisableSSLCert(tx, req.SslCertId) if err != nil { return nil, err } // 停止相关ACME任务 - err = models.SharedACMETaskDAO.DisableAllTasksWithCertId(req.SslCertId) + err = models.SharedACMETaskDAO.DisableAllTasksWithCertId(tx, req.SslCertId) if err != nil { return nil, err } @@ -119,7 +127,9 @@ func (this *SSLCertService) CountSSLCerts(ctx context.Context, req *pb.CountSSLC return nil, err } - count, err := models.SharedSSLCertDAO.CountCerts(req.IsCA, req.IsAvailable, req.IsExpired, int64(req.ExpiringDays), req.Keyword, req.UserId) + tx := this.NullTx() + + count, err := models.SharedSSLCertDAO.CountCerts(tx, req.IsCA, req.IsAvailable, req.IsExpired, int64(req.ExpiringDays), req.Keyword, req.UserId) if err != nil { return nil, err } @@ -135,14 +145,16 @@ func (this *SSLCertService) ListSSLCerts(ctx context.Context, req *pb.ListSSLCer return nil, err } - certIds, err := models.SharedSSLCertDAO.ListCertIds(req.IsCA, req.IsAvailable, req.IsExpired, int64(req.ExpiringDays), req.Keyword, req.UserId, req.Offset, req.Size) + tx := this.NullTx() + + certIds, err := models.SharedSSLCertDAO.ListCertIds(tx, req.IsCA, req.IsAvailable, req.IsExpired, int64(req.ExpiringDays), req.Keyword, req.UserId, req.Offset, req.Size) if err != nil { return nil, err } certConfigs := []*sslconfigs.SSLCertConfig{} for _, certId := range certIds { - certConfig, err := models.SharedSSLCertDAO.ComposeCertConfig(certId) + certConfig, err := models.SharedSSLCertDAO.ComposeCertConfig(tx, certId) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_ssl_policy.go b/internal/rpc/services/service_ssl_policy.go index b1f09aaa..f647c6f7 100644 --- a/internal/rpc/services/service_ssl_policy.go +++ b/internal/rpc/services/service_ssl_policy.go @@ -21,6 +21,8 @@ func (this *SSLPolicyService) CreateSSLPolicy(ctx context.Context, req *pb.Creat return nil, err } + tx := this.NullTx() + if userId > 0 { // 检查证书 if len(req.SslCertsJSON) > 0 { @@ -30,7 +32,7 @@ func (this *SSLPolicyService) CreateSSLPolicy(ctx context.Context, req *pb.Creat return nil, err } for _, certRef := range certRefs { - err = models.SharedSSLCertDAO.CheckUserCert(certRef.CertId, userId) + err = models.SharedSSLCertDAO.CheckUserCert(tx, certRef.CertId, userId) if err != nil { return nil, err } @@ -41,7 +43,7 @@ func (this *SSLPolicyService) CreateSSLPolicy(ctx context.Context, req *pb.Creat // TODO } - policyId, err := models.SharedSSLPolicyDAO.CreatePolicy(adminId, userId, req.Http2Enabled, req.MinVersion, req.SslCertsJSON, req.HstsJSON, req.ClientAuthType, req.ClientCACertsJSON, req.CipherSuitesIsOn, req.CipherSuites) + policyId, err := models.SharedSSLPolicyDAO.CreatePolicy(tx, adminId, userId, req.Http2Enabled, req.MinVersion, req.SslCertsJSON, req.HstsJSON, req.ClientAuthType, req.ClientCACertsJSON, req.CipherSuitesIsOn, req.CipherSuites) if err != nil { return nil, err } @@ -56,14 +58,17 @@ func (this *SSLPolicyService) UpdateSSLPolicy(ctx context.Context, req *pb.Updat if err != nil { return nil, err } + + tx := this.NullTx() + if userId > 0 { - err := models.SharedSSLPolicyDAO.CheckUserPolicy(req.SslPolicyId, userId) + err := models.SharedSSLPolicyDAO.CheckUserPolicy(tx, req.SslPolicyId, userId) if err != nil { return nil, err } } - err = models.SharedSSLPolicyDAO.UpdatePolicy(req.SslPolicyId, req.Http2Enabled, req.MinVersion, req.SslCertsJSON, req.HstsJSON, req.ClientAuthType, req.ClientCACertsJSON, req.CipherSuitesIsOn, req.CipherSuites) + err = models.SharedSSLPolicyDAO.UpdatePolicy(tx, req.SslPolicyId, req.Http2Enabled, req.MinVersion, req.SslCertsJSON, req.HstsJSON, req.ClientAuthType, req.ClientCACertsJSON, req.CipherSuitesIsOn, req.CipherSuites) if err != nil { return nil, err } @@ -80,7 +85,9 @@ func (this *SSLPolicyService) FindEnabledSSLPolicyConfig(ctx context.Context, re return nil, err } - config, err := models.SharedSSLPolicyDAO.ComposePolicyConfig(req.SslPolicyId) + tx := this.NullTx() + + config, err := models.SharedSSLPolicyDAO.ComposePolicyConfig(tx, req.SslPolicyId) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_sys_setting.go b/internal/rpc/services/service_sys_setting.go index fd691b1a..dc749e25 100644 --- a/internal/rpc/services/service_sys_setting.go +++ b/internal/rpc/services/service_sys_setting.go @@ -19,7 +19,9 @@ func (this *SysSettingService) UpdateSysSetting(ctx context.Context, req *pb.Upd return nil, err } - err = models.SharedSysSettingDAO.UpdateSetting(req.Code, req.ValueJSON) + tx := this.NullTx() + + err = models.SharedSysSettingDAO.UpdateSetting(tx, req.Code, req.ValueJSON) if err != nil { return nil, err } @@ -35,7 +37,9 @@ func (this *SysSettingService) ReadSysSetting(ctx context.Context, req *pb.ReadS return nil, err } - valueJSON, err := models.SharedSysSettingDAO.ReadSetting(req.Code) + tx := this.NullTx() + + valueJSON, err := models.SharedSysSettingDAO.ReadSetting(tx, req.Code) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_user.go b/internal/rpc/services/service_user.go index 1088465f..18d40b53 100644 --- a/internal/rpc/services/service_user.go +++ b/internal/rpc/services/service_user.go @@ -24,7 +24,9 @@ func (this *UserService) CreateUser(ctx context.Context, req *pb.CreateUserReque return nil, err } - userId, err := models.SharedUserDAO.CreateUser(req.Username, req.Password, req.Fullname, req.Mobile, req.Tel, req.Email, req.Remark, req.Source, req.NodeClusterId) + tx := this.NullTx() + + userId, err := models.SharedUserDAO.CreateUser(tx, req.Username, req.Password, req.Fullname, req.Mobile, req.Tel, req.Email, req.Remark, req.Source, req.NodeClusterId) if err != nil { return nil, err } @@ -38,28 +40,30 @@ func (this *UserService) UpdateUser(ctx context.Context, req *pb.UpdateUserReque return nil, err } - oldClusterId, err := models.SharedUserDAO.FindUserClusterId(req.UserId) + tx := this.NullTx() + + oldClusterId, err := models.SharedUserDAO.FindUserClusterId(tx, req.UserId) if err != nil { return nil, err } - err = models.SharedUserDAO.UpdateUser(req.UserId, req.Username, req.Password, req.Fullname, req.Mobile, req.Tel, req.Email, req.Remark, req.IsOn, req.NodeClusterId) + err = models.SharedUserDAO.UpdateUser(tx, req.UserId, req.Username, req.Password, req.Fullname, req.Mobile, req.Tel, req.Email, req.Remark, req.IsOn, req.NodeClusterId) if err != nil { return nil, err } if oldClusterId != req.NodeClusterId { - err = models.SharedServerDAO.UpdateUserServersClusterId(req.UserId, req.NodeClusterId) + err = models.SharedServerDAO.UpdateUserServersClusterId(tx, req.UserId, req.NodeClusterId) if err != nil { return nil, err } - err = models.SharedNodeDAO.IncreaseAllNodesLatestVersionMatch(oldClusterId) + err = models.SharedNodeDAO.IncreaseAllNodesLatestVersionMatch(tx, oldClusterId) if err != nil { return nil, err } - err = models.SharedNodeDAO.IncreaseAllNodesLatestVersionMatch(req.NodeClusterId) + err = models.SharedNodeDAO.IncreaseAllNodesLatestVersionMatch(tx, req.NodeClusterId) if err != nil { return nil, err } @@ -75,7 +79,9 @@ func (this *UserService) DeleteUser(ctx context.Context, req *pb.DeleteUserReque return nil, err } - _, err = models.SharedUserDAO.DisableUser(req.UserId) + tx := this.NullTx() + + _, err = models.SharedUserDAO.DisableUser(tx, req.UserId) if err != nil { return nil, err } @@ -89,7 +95,9 @@ func (this *UserService) CountAllEnabledUsers(ctx context.Context, req *pb.Count return nil, err } - count, err := models.SharedUserDAO.CountAllEnabledUsers(req.Keyword) + tx := this.NullTx() + + count, err := models.SharedUserDAO.CountAllEnabledUsers(tx, req.Keyword) if err != nil { return nil, err } @@ -103,7 +111,9 @@ func (this *UserService) ListEnabledUsers(ctx context.Context, req *pb.ListEnabl return nil, err } - users, err := models.SharedUserDAO.ListEnabledUsers(req.Keyword) + tx := this.NullTx() + + users, err := models.SharedUserDAO.ListEnabledUsers(tx, req.Keyword) if err != nil { return nil, err } @@ -113,7 +123,7 @@ func (this *UserService) ListEnabledUsers(ctx context.Context, req *pb.ListEnabl // 集群信息 var pbCluster *pb.NodeCluster = nil if user.ClusterId > 0 { - clusterName, err := models.SharedNodeClusterDAO.FindNodeClusterName(int64(user.ClusterId)) + clusterName, err := models.SharedNodeClusterDAO.FindNodeClusterName(tx, int64(user.ClusterId)) if err != nil { return nil, err } @@ -147,7 +157,9 @@ func (this *UserService) FindEnabledUser(ctx context.Context, req *pb.FindEnable return nil, err } - user, err := models.SharedUserDAO.FindEnabledUser(req.UserId) + tx := this.NullTx() + + user, err := models.SharedUserDAO.FindEnabledUser(tx, req.UserId) if err != nil { return nil, err } @@ -158,7 +170,7 @@ func (this *UserService) FindEnabledUser(ctx context.Context, req *pb.FindEnable // 集群信息 var pbCluster *pb.NodeCluster = nil if user.ClusterId > 0 { - clusterName, err := models.SharedNodeClusterDAO.FindNodeClusterName(int64(user.ClusterId)) + clusterName, err := models.SharedNodeClusterDAO.FindNodeClusterName(tx, int64(user.ClusterId)) if err != nil { return nil, err } @@ -194,7 +206,9 @@ func (this *UserService) CheckUserUsername(ctx context.Context, req *pb.CheckUse return nil, this.PermissionError() } - b, err := models.SharedUserDAO.ExistUser(req.UserId, req.Username) + tx := this.NullTx() + + b, err := models.SharedUserDAO.ExistUser(tx, req.UserId, req.Username) if err != nil { return nil, err } @@ -216,7 +230,9 @@ func (this *UserService) LoginUser(ctx context.Context, req *pb.LoginUserRequest }, nil } - userId, err := models.SharedUserDAO.CheckUserPassword(req.Username, req.Password) + tx := this.NullTx() + + userId, err := models.SharedUserDAO.CheckUserPassword(tx, req.Username, req.Password) if err != nil { utils.PrintError(err) return nil, err @@ -247,7 +263,9 @@ func (this *UserService) UpdateUserInfo(ctx context.Context, req *pb.UpdateUserI return nil, this.PermissionError() } - err = models.SharedUserDAO.UpdateUserInfo(req.UserId, req.Fullname) + tx := this.NullTx() + + err = models.SharedUserDAO.UpdateUserInfo(tx, req.UserId, req.Fullname) if err != nil { return nil, err } @@ -265,7 +283,9 @@ func (this *UserService) UpdateUserLogin(ctx context.Context, req *pb.UpdateUser return nil, this.PermissionError() } - err = models.SharedUserDAO.UpdateUserLogin(req.UserId, req.Username, req.Password) + tx := this.NullTx() + + err = models.SharedUserDAO.UpdateUserLogin(tx, req.UserId, req.Username, req.Password) if err != nil { return nil, err } @@ -283,34 +303,36 @@ func (this *UserService) ComposeUserDashboard(ctx context.Context, req *pb.Compo return nil, this.PermissionError() } + tx := this.NullTx() + // 网站数量 - countServers, err := models.SharedServerDAO.CountAllEnabledServersMatch(0, "", req.UserId, 0, configutils.BoolStateAll) + countServers, err := models.SharedServerDAO.CountAllEnabledServersMatch(tx, 0, "", req.UserId, 0, configutils.BoolStateAll) if err != nil { return nil, err } // 本月总流量 month := timeutil.Format("Ym") - monthlyTrafficBytes, err := models.SharedServerDailyStatDAO.SumUserMonthly(req.UserId, 0, month) + monthlyTrafficBytes, err := models.SharedServerDailyStatDAO.SumUserMonthly(tx, req.UserId, 0, month) if err != nil { return nil, err } // 本月带宽峰值 - monthlyPeekTrafficBytes, err := models.SharedServerDailyStatDAO.SumUserMonthlyPeek(req.UserId, 0, month) + monthlyPeekTrafficBytes, err := models.SharedServerDailyStatDAO.SumUserMonthlyPeek(tx, req.UserId, 0, month) if err != nil { return nil, err } // 今日总流量 day := timeutil.Format("Ymd") - dailyTrafficBytes, err := models.SharedServerDailyStatDAO.SumUserDaily(req.UserId, 0, day) + dailyTrafficBytes, err := models.SharedServerDailyStatDAO.SumUserDaily(tx, req.UserId, 0, day) if err != nil { return nil, err } // 今日带宽峰值 - dailyPeekTrafficBytes, err := models.SharedServerDailyStatDAO.SumUserDailyPeek(req.UserId, 0, day) + dailyPeekTrafficBytes, err := models.SharedServerDailyStatDAO.SumUserDailyPeek(tx, req.UserId, 0, day) if err != nil { return nil, err } @@ -322,12 +344,12 @@ func (this *UserService) ComposeUserDashboard(ctx context.Context, req *pb.Compo for i := 14; i >= 0; i-- { day := timeutil.Format("Ymd", time.Now().AddDate(0, 0, -i)) - dailyTrafficBytes, err := models.SharedServerDailyStatDAO.SumUserDaily(req.UserId, 0, day) + dailyTrafficBytes, err := models.SharedServerDailyStatDAO.SumUserDaily(tx, req.UserId, 0, day) if err != nil { return nil, err } - dailyPeekTrafficBytes, err := models.SharedServerDailyStatDAO.SumUserDailyPeek(req.UserId, 0, day) + dailyPeekTrafficBytes, err := models.SharedServerDailyStatDAO.SumUserDailyPeek(tx, req.UserId, 0, day) if err != nil { return nil, err } @@ -354,7 +376,9 @@ func (this *UserService) FindUserNodeClusterId(ctx context.Context, req *pb.Find return nil, err } - clusterId, err := models.SharedUserDAO.FindUserClusterId(req.UserId) + tx := this.NullTx() + + clusterId, err := models.SharedUserDAO.FindUserClusterId(tx, req.UserId) if err != nil { return nil, err } @@ -372,7 +396,10 @@ func (this *UserService) UpdateUserFeatures(ctx context.Context, req *pb.UpdateU if err != nil { return nil, err } - err = models.SharedUserDAO.UpdateUserFeatures(req.UserId, featuresJSON) + + tx := this.NullTx() + + err = models.SharedUserDAO.UpdateUserFeatures(tx, req.UserId, featuresJSON) if err != nil { return nil, err } @@ -391,7 +418,9 @@ func (this *UserService) FindUserFeatures(ctx context.Context, req *pb.FindUserF } } - features, err := models.SharedUserDAO.FindUserFeatures(req.UserId) + tx := this.NullTx() + + features, err := models.SharedUserDAO.FindUserFeatures(tx, req.UserId) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_user_access_key.go b/internal/rpc/services/service_user_access_key.go index 54cbdcd6..cdf50832 100644 --- a/internal/rpc/services/service_user_access_key.go +++ b/internal/rpc/services/service_user_access_key.go @@ -18,7 +18,9 @@ func (this *UserAccessKeyService) CreateUserAccessKey(ctx context.Context, req * return nil, err } - userAccessKeyId, err := models.SharedUserAccessKeyDAO.CreateAccessKey(req.UserId, req.Description) + tx := this.NullTx() + + userAccessKeyId, err := models.SharedUserAccessKeyDAO.CreateAccessKey(tx, req.UserId, req.Description) if err != nil { return nil, err } @@ -32,7 +34,9 @@ func (this *UserAccessKeyService) FindAllEnabledUserAccessKeys(ctx context.Conte return nil, err } - accessKeys, err := models.SharedUserAccessKeyDAO.FindAllEnabledAccessKeys(req.UserId) + tx := this.NullTx() + + accessKeys, err := models.SharedUserAccessKeyDAO.FindAllEnabledAccessKeys(tx, req.UserId) if err != nil { return nil, err } @@ -60,8 +64,10 @@ func (this *UserAccessKeyService) DeleteUserAccessKey(ctx context.Context, req * return nil, err } + tx := this.NullTx() + if userId > 0 { - ok, err := models.SharedUserAccessKeyDAO.CheckUserAccessKey(userId, req.UserAccessKeyId) + ok, err := models.SharedUserAccessKeyDAO.CheckUserAccessKey(tx, userId, req.UserAccessKeyId) if err != nil { return nil, err } @@ -70,7 +76,7 @@ func (this *UserAccessKeyService) DeleteUserAccessKey(ctx context.Context, req * } } - err = models.SharedUserAccessKeyDAO.DisableUserAccessKey(req.UserAccessKeyId) + err = models.SharedUserAccessKeyDAO.DisableUserAccessKey(tx, req.UserAccessKeyId) if err != nil { return nil, err } @@ -84,8 +90,10 @@ func (this *UserAccessKeyService) UpdateUserAccessKeyIsOn(ctx context.Context, r return nil, err } + tx := this.NullTx() + if userId > 0 { - ok, err := models.SharedUserAccessKeyDAO.CheckUserAccessKey(userId, req.UserAccessKeyId) + ok, err := models.SharedUserAccessKeyDAO.CheckUserAccessKey(tx, userId, req.UserAccessKeyId) if err != nil { return nil, err } @@ -94,7 +102,7 @@ func (this *UserAccessKeyService) UpdateUserAccessKeyIsOn(ctx context.Context, r } } - err = models.SharedUserAccessKeyDAO.UpdateAccessKeyIsOn(req.UserAccessKeyId, req.IsOn) + err = models.SharedUserAccessKeyDAO.UpdateAccessKeyIsOn(tx, req.UserAccessKeyId, req.IsOn) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_user_bill.go b/internal/rpc/services/service_user_bill.go index f963df6c..b0c421b8 100644 --- a/internal/rpc/services/service_user_bill.go +++ b/internal/rpc/services/service_user_bill.go @@ -29,7 +29,9 @@ func (this *UserBillService) GenerateAllUserBills(ctx context.Context, req *pb.G return nil, errors.New("invalid month '" + req.Month + "'") } - err = models.SharedUserBillDAO.GenerateBills(req.Month) + tx := this.NullTx() + + err = models.SharedUserBillDAO.GenerateBills(tx, req.Month) if err != nil { return nil, err } @@ -44,7 +46,9 @@ func (this *UserBillService) CountAllUserBills(ctx context.Context, req *pb.Coun return nil, err } - count, err := models.SharedUserBillDAO.CountAllUserBills(req.PaidFlag, req.UserId, req.Month) + tx := this.NullTx() + + count, err := models.SharedUserBillDAO.CountAllUserBills(tx, req.PaidFlag, req.UserId, req.Month) if err != nil { return nil, err } @@ -58,13 +62,15 @@ func (this *UserBillService) ListUserBills(ctx context.Context, req *pb.ListUser return nil, err } - bills, err := models.SharedUserBillDAO.ListUserBills(req.PaidFlag, req.UserId, req.Month, req.Offset, req.Size) + tx := this.NullTx() + + bills, err := models.SharedUserBillDAO.ListUserBills(tx, req.PaidFlag, req.UserId, req.Month, req.Offset, req.Size) if err != nil { return nil, err } result := []*pb.UserBill{} for _, bill := range bills { - userFullname, err := models.SharedUserDAO.FindUserFullname(int64(bill.UserId)) + userFullname, err := models.SharedUserDAO.FindUserFullname(tx, int64(bill.UserId)) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_user_node.go b/internal/rpc/services/service_user_node.go index fa6b4bcd..00c9ff21 100644 --- a/internal/rpc/services/service_user_node.go +++ b/internal/rpc/services/service_user_node.go @@ -20,7 +20,9 @@ func (this *UserNodeService) CreateUserNode(ctx context.Context, req *pb.CreateU return nil, err } - nodeId, err := models.SharedUserNodeDAO.CreateUserNode(req.Name, req.Description, req.HttpJSON, req.HttpsJSON, req.AccessAddrsJSON, req.IsOn) + tx := this.NullTx() + + nodeId, err := models.SharedUserNodeDAO.CreateUserNode(tx, req.Name, req.Description, req.HttpJSON, req.HttpsJSON, req.AccessAddrsJSON, req.IsOn) if err != nil { return nil, err } @@ -35,7 +37,9 @@ func (this *UserNodeService) UpdateUserNode(ctx context.Context, req *pb.UpdateU return nil, err } - err = models.SharedUserNodeDAO.UpdateUserNode(req.NodeId, req.Name, req.Description, req.HttpJSON, req.HttpsJSON, req.AccessAddrsJSON, req.IsOn) + tx := this.NullTx() + + err = models.SharedUserNodeDAO.UpdateUserNode(tx, req.NodeId, req.Name, req.Description, req.HttpJSON, req.HttpsJSON, req.AccessAddrsJSON, req.IsOn) if err != nil { return nil, err } @@ -50,7 +54,9 @@ func (this *UserNodeService) DeleteUserNode(ctx context.Context, req *pb.DeleteU return nil, err } - err = models.SharedUserNodeDAO.DisableUserNode(req.NodeId) + tx := this.NullTx() + + err = models.SharedUserNodeDAO.DisableUserNode(tx, req.NodeId) if err != nil { return nil, err } @@ -65,7 +71,9 @@ func (this *UserNodeService) FindAllEnabledUserNodes(ctx context.Context, req *p return nil, err } - nodes, err := models.SharedUserNodeDAO.FindAllEnabledUserNodes() + tx := this.NullTx() + + nodes, err := models.SharedUserNodeDAO.FindAllEnabledUserNodes(tx) if err != nil { return nil, err } @@ -101,7 +109,9 @@ func (this *UserNodeService) CountAllEnabledUserNodes(ctx context.Context, req * return nil, err } - count, err := models.SharedUserNodeDAO.CountAllEnabledUserNodes() + tx := this.NullTx() + + count, err := models.SharedUserNodeDAO.CountAllEnabledUserNodes(tx) if err != nil { return nil, err } @@ -116,7 +126,9 @@ func (this *UserNodeService) ListEnabledUserNodes(ctx context.Context, req *pb.L return nil, err } - nodes, err := models.SharedUserNodeDAO.ListEnabledUserNodes(req.Offset, req.Size) + tx := this.NullTx() + + nodes, err := models.SharedUserNodeDAO.ListEnabledUserNodes(tx, req.Offset, req.Size) if err != nil { return nil, err } @@ -152,7 +164,9 @@ func (this *UserNodeService) FindEnabledUserNode(ctx context.Context, req *pb.Fi return nil, err } - node, err := models.SharedUserNodeDAO.FindEnabledUserNode(req.NodeId) + tx := this.NullTx() + + node, err := models.SharedUserNodeDAO.FindEnabledUserNode(tx, req.NodeId) if err != nil { return nil, err } @@ -188,6 +202,8 @@ func (this *UserNodeService) FindCurrentUserNode(ctx context.Context, req *pb.Fi return nil, err } + tx := this.NullTx() + md, ok := metadata.FromIncomingContext(ctx) if !ok { return nil, errors.New("context: need 'nodeId'") @@ -197,7 +213,7 @@ func (this *UserNodeService) FindCurrentUserNode(ctx context.Context, req *pb.Fi return nil, errors.New("invalid 'nodeId'") } nodeId := nodeIds[0] - node, err := models.SharedUserNodeDAO.FindEnabledUserNodeWithUniqueId(nodeId) + node, err := models.SharedUserNodeDAO.FindEnabledUserNodeWithUniqueId(tx, nodeId) if err != nil { return nil, err } diff --git a/internal/rpc/services/sevice_http_gzip.go b/internal/rpc/services/sevice_http_gzip.go index 7c76344f..4adddbae 100644 --- a/internal/rpc/services/sevice_http_gzip.go +++ b/internal/rpc/services/sevice_http_gzip.go @@ -43,7 +43,9 @@ func (this *HTTPGzipService) CreateHTTPGzip(ctx context.Context, req *pb.CreateH } } - gzipId, err := models.SharedHTTPGzipDAO.CreateGzip(int(req.Level), minLengthJSON, maxLengthJSON, req.CondsJSON) + tx := this.NullTx() + + gzipId, err := models.SharedHTTPGzipDAO.CreateGzip(tx, int(req.Level), minLengthJSON, maxLengthJSON, req.CondsJSON) if err != nil { return nil, err } @@ -59,7 +61,9 @@ func (this *HTTPGzipService) FindEnabledHTTPGzipConfig(ctx context.Context, req return nil, err } - config, err := models.SharedHTTPGzipDAO.ComposeGzipConfig(req.GzipId) + tx := this.NullTx() + + config, err := models.SharedHTTPGzipDAO.ComposeGzipConfig(tx, req.GzipId) if err != nil { return nil, err } @@ -101,7 +105,9 @@ func (this *HTTPGzipService) UpdateHTTPGzip(ctx context.Context, req *pb.UpdateH } } - err = models.SharedHTTPGzipDAO.UpdateGzip(req.GzipId, int(req.Level), minLengthJSON, maxLengthJSON, req.CondsJSON) + tx := this.NullTx() + + err = models.SharedHTTPGzipDAO.UpdateGzip(tx, req.GzipId, int(req.Level), minLengthJSON, maxLengthJSON, req.CondsJSON) if err != nil { return nil, err } diff --git a/internal/rpc/utils/utils.go b/internal/rpc/utils/utils.go index 9536bbe3..59cd3ad8 100644 --- a/internal/rpc/utils/utils.go +++ b/internal/rpc/utils/utils.go @@ -68,7 +68,7 @@ func ValidateRequest(ctx context.Context, userTypes ...UserType) (userType UserT // 获取角色Node信息 // TODO 缓存节点ID相关信息 - apiToken, err := models.SharedApiTokenDAO.FindEnabledTokenWithNode(nodeId) + apiToken, err := models.SharedApiTokenDAO.FindEnabledTokenWithNode(nil, nodeId) if err != nil { utils.PrintError(err) return UserTypeNone, 0, err @@ -121,7 +121,7 @@ func ValidateRequest(ctx context.Context, userTypes ...UserType) (userType UserT switch apiToken.Role { case UserTypeNode: - nodeIntId, err := models.SharedNodeDAO.FindEnabledNodeIdWithUniqueId(nodeId) + nodeIntId, err := models.SharedNodeDAO.FindEnabledNodeIdWithUniqueId(nil, nodeId) if err != nil { return UserTypeNode, 0, errors.New("context: " + err.Error()) } @@ -130,7 +130,7 @@ func ValidateRequest(ctx context.Context, userTypes ...UserType) (userType UserT } nodeUserId = nodeIntId case UserTypeCluster: - clusterId, err := models.SharedNodeClusterDAO.FindEnabledClusterIdWithUniqueId(nodeId) + clusterId, err := models.SharedNodeClusterDAO.FindEnabledClusterIdWithUniqueId(nil, nodeId) if err != nil { return UserTypeCluster, 0, errors.New("context: " + err.Error()) } diff --git a/internal/setup/setup.go b/internal/setup/setup.go index ad5c65c7..668618f5 100644 --- a/internal/setup/setup.go +++ b/internal/setup/setup.go @@ -100,7 +100,7 @@ func (this *Setup) Run() error { // Admin节点信息 apiTokenDAO := models.NewApiTokenDAO() - token, err := apiTokenDAO.FindEnabledTokenWithRole("admin") + token, err := apiTokenDAO.FindEnabledTokenWithRole(nil, "admin") if err != nil { return err } @@ -112,7 +112,7 @@ func (this *Setup) Run() error { // 检查API节点 dao := models.NewAPINodeDAO() - apiNodeId, err := dao.FindEnabledAPINodeIdWithAddr(this.config.APINodeProtocol, this.config.APINodeHost, this.config.APINodePort) + apiNodeId, err := dao.FindEnabledAPINodeIdWithAddr(nil, this.config.APINodeProtocol, this.config.APINodeHost, this.config.APINodePort) if err != nil { return err } @@ -160,14 +160,14 @@ func (this *Setup) Run() error { } // 创建API节点 - nodeId, err := dao.CreateAPINode("默认API节点", "这是默认创建的第一个API节点", httpJSON, httpsJSON, false, nil, nil, addrsJSON, true) + nodeId, err := dao.CreateAPINode(nil, "默认API节点", "这是默认创建的第一个API节点", httpJSON, httpsJSON, false, nil, nil, addrsJSON, true) if err != nil { return errors.New("create api node in database failed: " + err.Error()) } apiNodeId = nodeId } - apiNode, err := dao.FindEnabledAPINode(apiNodeId) + apiNode, err := dao.FindEnabledAPINode(nil, apiNodeId) if err != nil { return err } diff --git a/internal/tasks/event_looper.go b/internal/tasks/event_looper.go index 900ed473..0c3fe416 100644 --- a/internal/tasks/event_looper.go +++ b/internal/tasks/event_looper.go @@ -34,12 +34,12 @@ func (this *EventLooper) Start() { func (this *EventLooper) loop() error { lockerKey := "eventLooper" - isOk, err := models.SharedSysLockerDAO.Lock(lockerKey, 3600) + isOk, err := models.SharedSysLockerDAO.Lock(nil, lockerKey, 3600) if err != nil { return err } defer func() { - err = models.SharedSysLockerDAO.Unlock(lockerKey) + err = models.SharedSysLockerDAO.Unlock(nil, lockerKey) if err != nil { logs.Println("[EVENT_LOOPER]" + err.Error()) } @@ -48,7 +48,7 @@ func (this *EventLooper) loop() error { return nil } - events, err := models.SharedSysEventDAO.FindEvents(100) + events, err := models.SharedSysEventDAO.FindEvents(nil, 100) if err != nil { return err } @@ -63,7 +63,7 @@ func (this *EventLooper) loop() error { logs.Println("[EVENT_LOOPER]" + err.Error()) continue } - err = models.SharedSysEventDAO.DeleteEvent(int64(eventOne.Id)) + err = models.SharedSysEventDAO.DeleteEvent(nil, int64(eventOne.Id)) if err != nil { return err } diff --git a/internal/tasks/health_check_cluster_task.go b/internal/tasks/health_check_cluster_task.go index 8bb8dc78..f814983c 100644 --- a/internal/tasks/health_check_cluster_task.go +++ b/internal/tasks/health_check_cluster_task.go @@ -89,7 +89,7 @@ func (this *HealthCheckClusterTask) loop(seconds int64) error { // 检查上次运行时间,防止重复运行 settingKey := models.SettingCodeClusterHealthCheck + "Loop" + numberutils.FormatInt64(this.clusterId) timestamp := time.Now().Unix() - c, err := models.SharedSysSettingDAO.CompareInt64Setting(settingKey, timestamp-seconds) + c, err := models.SharedSysSettingDAO.CompareInt64Setting(nil, settingKey, timestamp-seconds) if err != nil { return err } @@ -98,7 +98,7 @@ func (this *HealthCheckClusterTask) loop(seconds int64) error { } // 记录时间 - err = models.SharedSysSettingDAO.UpdateSetting(settingKey, []byte(numberutils.FormatInt64(timestamp))) + err = models.SharedSysSettingDAO.UpdateSetting(nil, settingKey, []byte(numberutils.FormatInt64(timestamp))) if err != nil { return err } @@ -131,7 +131,7 @@ func (this *HealthCheckClusterTask) loop(seconds int64) error { return err } message := "有" + numberutils.FormatInt(len(failedResults)) + "个节点在健康检查中出现问题" - err = models.NewMessageDAO().CreateClusterMessage(this.clusterId, models.MessageTypeHealthCheckFailed, models.MessageLevelError, message, failedResultsJSON) + err = models.NewMessageDAO().CreateClusterMessage(nil, this.clusterId, models.MessageTypeHealthCheckFailed, models.MessageLevelError, message, failedResultsJSON) if err != nil { return err } diff --git a/internal/tasks/health_check_executor.go b/internal/tasks/health_check_executor.go index faeeb492..ce937cc9 100644 --- a/internal/tasks/health_check_executor.go +++ b/internal/tasks/health_check_executor.go @@ -26,7 +26,7 @@ func NewHealthCheckExecutor(clusterId int64) *HealthCheckExecutor { } func (this *HealthCheckExecutor) Run() ([]*HealthCheckResult, error) { - cluster, err := models.NewNodeClusterDAO().FindEnabledNodeCluster(this.clusterId) + cluster, err := models.NewNodeClusterDAO().FindEnabledNodeCluster(nil, this.clusterId) if err != nil { return nil, err } @@ -44,7 +44,7 @@ func (this *HealthCheckExecutor) Run() ([]*HealthCheckResult, error) { } results := []*HealthCheckResult{} - nodes, err := models.NewNodeDAO().FindAllEnabledNodesWithClusterId(this.clusterId) + nodes, err := models.NewNodeDAO().FindAllEnabledNodesWithClusterId(nil, this.clusterId) if err != nil { return nil, err } @@ -56,7 +56,7 @@ func (this *HealthCheckExecutor) Run() ([]*HealthCheckResult, error) { Node: node, } - ipAddr, err := models.NewNodeIPAddressDAO().FindFirstNodeIPAddress(int64(node.Id)) + ipAddr, err := models.NewNodeIPAddressDAO().FindFirstNodeIPAddress(nil, int64(node.Id)) if err != nil { return nil, err } @@ -129,7 +129,7 @@ func (this *HealthCheckExecutor) Run() ([]*HealthCheckResult, error) { // 修改节点状态 if healthCheckConfig.AutoDown { - isChanged, err := models.SharedNodeDAO.UpdateNodeUp(int64(result.Node.Id), result.IsOk, healthCheckConfig.CountUp, healthCheckConfig.CountDown) + isChanged, err := models.SharedNodeDAO.UpdateNodeUp(nil, int64(result.Node.Id), result.IsOk, healthCheckConfig.CountUp, healthCheckConfig.CountDown) if err != nil { logs.Println("[HEALTH_CHECK]" + err.Error()) } else if isChanged { @@ -142,9 +142,9 @@ func (this *HealthCheckExecutor) Run() ([]*HealthCheckResult, error) { // 通知恢复或下线 if result.IsOk { - err = models.NewMessageDAO().CreateNodeMessage(this.clusterId, int64(result.Node.Id), models.MessageTypeHealthCheckNodeUp, models.MessageLevelSuccess, "健康检查成功,节点\""+result.Node.Name+"\"已恢复上线", nil) + err = models.NewMessageDAO().CreateNodeMessage(nil, this.clusterId, int64(result.Node.Id), models.MessageTypeHealthCheckNodeUp, models.MessageLevelSuccess, "健康检查成功,节点\""+result.Node.Name+"\"已恢复上线", nil) } else { - err = models.NewMessageDAO().CreateNodeMessage(this.clusterId, int64(result.Node.Id), models.MessageTypeHealthCheckNodeDown, models.MessageLevelError, "健康检查失败,节点\""+result.Node.Name+"\"已自动下线", nil) + err = models.NewMessageDAO().CreateNodeMessage(nil, this.clusterId, int64(result.Node.Id), models.MessageTypeHealthCheckNodeDown, models.MessageLevelError, "健康检查失败,节点\""+result.Node.Name+"\"已自动下线", nil) } } } diff --git a/internal/tasks/health_check_task.go b/internal/tasks/health_check_task.go index 526c3aa1..4933148b 100644 --- a/internal/tasks/health_check_task.go +++ b/internal/tasks/health_check_task.go @@ -46,7 +46,7 @@ func (this *HealthCheckTask) Run() { } func (this *HealthCheckTask) loop() error { - clusters, err := models.NewNodeClusterDAO().FindAllEnableClusters() + clusters, err := models.NewNodeClusterDAO().FindAllEnableClusters(nil) if err != nil { return err } diff --git a/internal/tasks/log_task.go b/internal/tasks/log_task.go index 4e8bfb35..88a1eb38 100644 --- a/internal/tasks/log_task.go +++ b/internal/tasks/log_task.go @@ -44,7 +44,7 @@ func (this *LogTask) loopClean(seconds int64) error { // 检查上次运行时间,防止重复运行 settingKey := "logTaskCleanLoop" timestamp := time.Now().Unix() - c, err := models.SharedSysSettingDAO.CompareInt64Setting(settingKey, timestamp-seconds) + c, err := models.SharedSysSettingDAO.CompareInt64Setting(nil, settingKey, timestamp-seconds) if err != nil { return err } @@ -53,13 +53,13 @@ func (this *LogTask) loopClean(seconds int64) error { } // 记录时间 - err = models.SharedSysSettingDAO.UpdateSetting(settingKey, []byte(numberutils.FormatInt64(timestamp))) + err = models.SharedSysSettingDAO.UpdateSetting(nil, settingKey, []byte(numberutils.FormatInt64(timestamp))) if err != nil { return err } configKey := "adminLogConfig" - valueJSON, err := models.SharedSysSettingDAO.ReadSetting(configKey) + valueJSON, err := models.SharedSysSettingDAO.ReadSetting(nil, configKey) if err != nil { return err } @@ -73,7 +73,7 @@ func (this *LogTask) loopClean(seconds int64) error { return err } if config.Days > 0 { - err = models.SharedLogDAO.DeleteLogsPermanentlyBeforeDays(config.Days) + err = models.SharedLogDAO.DeleteLogsPermanentlyBeforeDays(nil, config.Days) if err != nil { return err } @@ -95,7 +95,7 @@ func (this *LogTask) loopMonitor(seconds int64) error { // 检查上次运行时间,防止重复运行 settingKey := "logTaskMonitorLoop" timestamp := time.Now().Unix() - c, err := models.SharedSysSettingDAO.CompareInt64Setting(settingKey, timestamp-seconds) + c, err := models.SharedSysSettingDAO.CompareInt64Setting(nil, settingKey, timestamp-seconds) if err != nil { return err } @@ -104,13 +104,13 @@ func (this *LogTask) loopMonitor(seconds int64) error { } // 记录时间 - err = models.SharedSysSettingDAO.UpdateSetting(settingKey, []byte(numberutils.FormatInt64(timestamp))) + err = models.SharedSysSettingDAO.UpdateSetting(nil, settingKey, []byte(numberutils.FormatInt64(timestamp))) if err != nil { return err } configKey := "adminLogConfig" - valueJSON, err := models.SharedSysSettingDAO.ReadSetting(configKey) + valueJSON, err := models.SharedSysSettingDAO.ReadSetting(nil, configKey) if err != nil { return err } @@ -132,7 +132,7 @@ func (this *LogTask) loopMonitor(seconds int64) error { return err } if sumBytes > capacityBytes { - err := models.SharedMessageDAO.CreateMessage(0, 0, models.MessageTypeLogCapacityOverflow, models.MessageLevelError, "日志用量已经超出最大限制,当前的用量为"+this.formatBytes(sumBytes)+",而设置的最大容量为"+this.formatBytes(capacityBytes)+"。", nil) + err := models.SharedMessageDAO.CreateMessage(nil, 0, 0, models.MessageTypeLogCapacityOverflow, models.MessageLevelError, "日志用量已经超出最大限制,当前的用量为"+this.formatBytes(sumBytes)+",而设置的最大容量为"+this.formatBytes(capacityBytes)+"。", nil) if err != nil { return err } diff --git a/internal/tasks/message_task.go b/internal/tasks/message_task.go index e7aab11f..f9307ca7 100644 --- a/internal/tasks/message_task.go +++ b/internal/tasks/message_task.go @@ -37,5 +37,5 @@ func (this *MessageTask) Run() { // 单次运行 func (this *MessageTask) loop() error { dayTime := time.Now().AddDate(0, 0, -30) // TODO 这个30天应该可以在界面上设置 - return models.NewMessageDAO().DeleteMessagesBeforeDay(dayTime) + return models.NewMessageDAO().DeleteMessagesBeforeDay(nil, dayTime) } diff --git a/internal/tasks/node_log_cleaner_task.go b/internal/tasks/node_log_cleaner_task.go index 1e0d952d..4e83ac6f 100644 --- a/internal/tasks/node_log_cleaner_task.go +++ b/internal/tasks/node_log_cleaner_task.go @@ -36,5 +36,5 @@ func (this *NodeLogCleanerTask) Start() { func (this *NodeLogCleanerTask) loop() error { // TODO 30天这个数值改成可以设置 - return models.SharedNodeLogDAO.DeleteExpiredLogs(30) + return models.SharedNodeLogDAO.DeleteExpiredLogs(nil, 30) } diff --git a/internal/tasks/node_monitor_task.go b/internal/tasks/node_monitor_task.go index c32a9540..78ebd93a 100644 --- a/internal/tasks/node_monitor_task.go +++ b/internal/tasks/node_monitor_task.go @@ -42,7 +42,7 @@ func (this *NodeMonitorTask) loop() error { // 检查上次运行时间,防止重复运行 settingKey := models.SettingCodeNodeMonitor + "Loop" timestamp := time.Now().Unix() - c, err := models.SharedSysSettingDAO.CompareInt64Setting(settingKey, timestamp-int64(this.intervalSeconds)) + c, err := models.SharedSysSettingDAO.CompareInt64Setting(nil, settingKey, timestamp-int64(this.intervalSeconds)) if err != nil { return err } @@ -51,12 +51,12 @@ func (this *NodeMonitorTask) loop() error { } // 记录时间 - err = models.SharedSysSettingDAO.UpdateSetting(settingKey, []byte(numberutils.FormatInt64(timestamp))) + err = models.SharedSysSettingDAO.UpdateSetting(nil, settingKey, []byte(numberutils.FormatInt64(timestamp))) if err != nil { return err } - clusters, err := models.SharedNodeClusterDAO.FindAllEnableClusters() + clusters, err := models.SharedNodeClusterDAO.FindAllEnableClusters(nil) if err != nil { return err } @@ -74,18 +74,18 @@ func (this *NodeMonitorTask) monitorCluster(cluster *models.NodeCluster) error { clusterId := int64(cluster.Id) // 检查离线节点 - inactiveNodes, err := models.SharedNodeDAO.FindAllInactiveNodesWithClusterId(clusterId) + inactiveNodes, err := models.SharedNodeDAO.FindAllInactiveNodesWithClusterId(nil, clusterId) if err != nil { return err } for _, node := range inactiveNodes { - err = models.SharedMessageDAO.CreateNodeMessage(clusterId, int64(node.Id), models.MessageTypeNodeInactive, models.LevelError, "节点已处于离线状态", nil) + err = models.SharedMessageDAO.CreateNodeMessage(nil, clusterId, int64(node.Id), models.MessageTypeNodeInactive, models.LevelError, "节点已处于离线状态", nil) if err != nil { return err } // 修改在线状态 - err = models.SharedNodeDAO.UpdateNodeActive(int64(node.Id), false) + err = models.SharedNodeDAO.UpdateNodeActive(nil, int64(node.Id), false) if err != nil { return err } diff --git a/internal/tasks/ssl_cert_expire_check_executor.go b/internal/tasks/ssl_cert_expire_check_executor.go index cfbc5c2e..141f7775 100644 --- a/internal/tasks/ssl_cert_expire_check_executor.go +++ b/internal/tasks/ssl_cert_expire_check_executor.go @@ -42,7 +42,7 @@ func (this *SSLCertExpireCheckExecutor) loop(seconds int64) error { // 检查上次运行时间,防止重复运行 settingKey := "sslCertExpiringCheckLoop" timestamp := time.Now().Unix() - c, err := models.SharedSysSettingDAO.CompareInt64Setting(settingKey, timestamp-seconds) + c, err := models.SharedSysSettingDAO.CompareInt64Setting(nil, settingKey, timestamp-seconds) if err != nil { return err } @@ -51,7 +51,7 @@ func (this *SSLCertExpireCheckExecutor) loop(seconds int64) error { } // 记录时间 - err = models.SharedSysSettingDAO.UpdateSetting(settingKey, []byte(numberutils.FormatInt64(timestamp))) + err = models.SharedSysSettingDAO.UpdateSetting(nil, settingKey, []byte(numberutils.FormatInt64(timestamp))) if err != nil { return err } @@ -59,7 +59,7 @@ func (this *SSLCertExpireCheckExecutor) loop(seconds int64) error { // 查找需要自动更新的证书 // 30, 14 ... 是到期的天数 for _, days := range []int{30, 14, 7} { - certs, err := models.SharedSSLCertDAO.FindAllExpiringCerts(days) + certs, err := models.SharedSSLCertDAO.FindAllExpiringCerts(nil, days) if err != nil { return err } @@ -69,7 +69,7 @@ func (this *SSLCertExpireCheckExecutor) loop(seconds int64) error { // 是否有自动更新任务 if cert.AcmeTaskId > 0 { - task, err := models.SharedACMETaskDAO.FindEnabledACMETask(int64(cert.AcmeTaskId)) + task, err := models.SharedACMETaskDAO.FindEnabledACMETask(nil, int64(cert.AcmeTaskId)) if err != nil { return err } @@ -84,7 +84,7 @@ func (this *SSLCertExpireCheckExecutor) loop(seconds int64) error { msg += "请及时更新证书。" } - err = models.SharedMessageDAO.CreateMessage(int64(cert.AdminId), int64(cert.UserId), models.MessageTypeSSLCertExpiring, models.MessageLevelWarning, msg, maps.Map{ + err = models.SharedMessageDAO.CreateMessage(nil, int64(cert.AdminId), int64(cert.UserId), models.MessageTypeSSLCertExpiring, models.MessageLevelWarning, msg, maps.Map{ "certId": cert.Id, "acmeTaskId": cert.AcmeTaskId, }.AsJSON()) @@ -93,7 +93,7 @@ func (this *SSLCertExpireCheckExecutor) loop(seconds int64) error { } // 设置最后通知时间 - err = models.SharedSSLCertDAO.UpdateCertNotifiedAt(int64(cert.Id)) + err = models.SharedSSLCertDAO.UpdateCertNotifiedAt(nil, int64(cert.Id)) if err != nil { return err } @@ -102,7 +102,7 @@ func (this *SSLCertExpireCheckExecutor) loop(seconds int64) error { // 自动续期 for _, days := range []int{3, 2, 1} { - certs, err := models.SharedSSLCertDAO.FindAllExpiringCerts(days) + certs, err := models.SharedSSLCertDAO.FindAllExpiringCerts(nil, days) if err != nil { return err } @@ -112,36 +112,36 @@ func (this *SSLCertExpireCheckExecutor) loop(seconds int64) error { // 是否有自动更新任务 if cert.AcmeTaskId > 0 { - task, err := models.SharedACMETaskDAO.FindEnabledACMETask(int64(cert.AcmeTaskId)) + task, err := models.SharedACMETaskDAO.FindEnabledACMETask(nil, int64(cert.AcmeTaskId)) if err != nil { return err } if task != nil { if task.AutoRenew == 1 { - isOk, errMsg, _ := models.SharedACMETaskDAO.RunTask(int64(cert.AcmeTaskId)) + isOk, errMsg, _ := models.SharedACMETaskDAO.RunTask(nil, int64(cert.AcmeTaskId)) if isOk { // 发送成功通知 msg = "系统已成功为你自动更新了证书\"" + cert.Name + "\"(" + cert.DnsNames + ")。" - err = models.SharedMessageDAO.CreateMessage(int64(cert.AdminId), int64(cert.UserId), models.MessageTypeSSLCertACMETaskSuccess, models.MessageLevelSuccess, msg, maps.Map{ + err = models.SharedMessageDAO.CreateMessage(nil, int64(cert.AdminId), int64(cert.UserId), models.MessageTypeSSLCertACMETaskSuccess, models.MessageLevelSuccess, msg, maps.Map{ "certId": cert.Id, "acmeTaskId": cert.AcmeTaskId, }.AsJSON()) // 更新通知时间 - err = models.SharedSSLCertDAO.UpdateCertNotifiedAt(int64(cert.Id)) + err = models.SharedSSLCertDAO.UpdateCertNotifiedAt(nil, int64(cert.Id)) if err != nil { return err } } else { // 发送失败通知 msg = "系统在尝试自动更新证书\"" + cert.Name + "\"(" + cert.DnsNames + ")时发生错误:" + errMsg + "。请检查系统设置并修复错误。" - err = models.SharedMessageDAO.CreateMessage(int64(cert.AdminId), int64(cert.UserId), models.MessageTypeSSLCertACMETaskFailed, models.MessageLevelError, msg, maps.Map{ + err = models.SharedMessageDAO.CreateMessage(nil, int64(cert.AdminId), int64(cert.UserId), models.MessageTypeSSLCertACMETaskFailed, models.MessageLevelError, msg, maps.Map{ "certId": cert.Id, "acmeTaskId": cert.AcmeTaskId, }.AsJSON()) // 更新通知时间 - err = models.SharedSSLCertDAO.UpdateCertNotifiedAt(int64(cert.Id)) + err = models.SharedSSLCertDAO.UpdateCertNotifiedAt(nil, int64(cert.Id)) if err != nil { return err } @@ -158,7 +158,7 @@ func (this *SSLCertExpireCheckExecutor) loop(seconds int64) error { msg += "请及时更新证书。" } - err = models.SharedMessageDAO.CreateMessage(int64(cert.AdminId), int64(cert.UserId), models.MessageTypeSSLCertExpiring, models.MessageLevelWarning, msg, maps.Map{ + err = models.SharedMessageDAO.CreateMessage(nil, int64(cert.AdminId), int64(cert.UserId), models.MessageTypeSSLCertExpiring, models.MessageLevelWarning, msg, maps.Map{ "certId": cert.Id, "acmeTaskId": cert.AcmeTaskId, }.AsJSON()) @@ -167,7 +167,7 @@ func (this *SSLCertExpireCheckExecutor) loop(seconds int64) error { } // 设置最后通知时间 - err = models.SharedSSLCertDAO.UpdateCertNotifiedAt(int64(cert.Id)) + err = models.SharedSSLCertDAO.UpdateCertNotifiedAt(nil, int64(cert.Id)) if err != nil { return err } @@ -176,7 +176,7 @@ func (this *SSLCertExpireCheckExecutor) loop(seconds int64) error { // 当天过期 for _, days := range []int{0} { - certs, err := models.SharedSSLCertDAO.FindAllExpiringCerts(days) + certs, err := models.SharedSSLCertDAO.FindAllExpiringCerts(nil, days) if err != nil { return err } @@ -184,7 +184,7 @@ func (this *SSLCertExpireCheckExecutor) loop(seconds int64) error { // 发送消息 today := timeutil.Format("Y-m-d") msg := "SSL证书\"" + cert.Name + "\"(" + cert.DnsNames + ")在今天(" + today + ")过期,请及时更新证书,之后将不再重复提醒。" - err = models.SharedMessageDAO.CreateMessage(int64(cert.AdminId), int64(cert.UserId), models.MessageTypeSSLCertExpiring, models.MessageLevelWarning, msg, maps.Map{ + err = models.SharedMessageDAO.CreateMessage(nil, int64(cert.AdminId), int64(cert.UserId), models.MessageTypeSSLCertExpiring, models.MessageLevelWarning, msg, maps.Map{ "certId": cert.Id, "acmeTaskId": cert.AcmeTaskId, }.AsJSON()) @@ -193,7 +193,7 @@ func (this *SSLCertExpireCheckExecutor) loop(seconds int64) error { } // 设置最后通知时间 - err = models.SharedSSLCertDAO.UpdateCertNotifiedAt(int64(cert.Id)) + err = models.SharedSSLCertDAO.UpdateCertNotifiedAt(nil, int64(cert.Id)) if err != nil { return err }