From af5584c7683316ca652aa920b8c85c370ad87fad Mon Sep 17 00:00:00 2001 From: GoEdgeLab Date: Sun, 10 Jan 2021 17:34:35 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=9E=E7=8E=B0URL=E8=B7=B3=E8=BD=AC?= =?UTF-8?q?=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/db/models/api_node_dao_test.go | 6 +- internal/db/models/dns_domain_dao_test.go | 11 +- .../db/models/http_access_log_dao_test.go | 21 ++- internal/db/models/http_web_dao.go | 86 +++++++++ internal/db/models/http_web_dao_test.go | 44 ++++- internal/db/models/http_web_model.go | 2 + internal/db/models/ip_list_dao_test.go | 8 +- internal/db/models/message_dao_test.go | 9 +- internal/db/models/node_dao_test.go | 9 +- internal/db/models/origin_dao_test.go | 4 +- internal/db/models/reverse_proxy_dao_test.go | 4 +- .../db/models/server_daily_stat_dao_test.go | 9 +- internal/db/models/server_dao.go | 12 ++ internal/db/models/server_dao_test.go | 19 +- internal/db/models/sys_locker_dao_test.go | 7 +- internal/db/models/sys_setting_dao_test.go | 14 +- internal/db/models/user_bill_dao_test.go | 4 +- internal/rpc/services/service_http_web.go | 174 ++++++++++++++++-- 18 files changed, 381 insertions(+), 62 deletions(-) diff --git a/internal/db/models/api_node_dao_test.go b/internal/db/models/api_node_dao_test.go index 43c86d80..30018a9e 100644 --- a/internal/db/models/api_node_dao_test.go +++ b/internal/db/models/api_node_dao_test.go @@ -2,14 +2,16 @@ package models import ( _ "github.com/go-sql-driver/mysql" + "github.com/iwind/TeaGo/dbs" "runtime" "testing" ) func TestAPINodeDAO_FindEnabledAPINodeIdWithAddr(t *testing.T) { dao := NewAPINodeDAO() + var tx *dbs.Tx { - apiNodeId, err := dao.FindEnabledAPINodeIdWithAddr("http", "127.0.0.1", 123) + apiNodeId, err := dao.FindEnabledAPINodeIdWithAddr(tx, "http", "127.0.0.1", 123) if err != nil { t.Fatal(err) } @@ -17,7 +19,7 @@ func TestAPINodeDAO_FindEnabledAPINodeIdWithAddr(t *testing.T) { } { - apiNodeId, err := dao.FindEnabledAPINodeIdWithAddr("http", "127.0.0.1", 8003) + apiNodeId, err := dao.FindEnabledAPINodeIdWithAddr(tx, "http", "127.0.0.1", 8003) if err != nil { t.Fatal(err) } diff --git a/internal/db/models/dns_domain_dao_test.go b/internal/db/models/dns_domain_dao_test.go index 36ae0232..40848bbc 100644 --- a/internal/db/models/dns_domain_dao_test.go +++ b/internal/db/models/dns_domain_dao_test.go @@ -2,33 +2,36 @@ package models import ( _ "github.com/go-sql-driver/mysql" + "github.com/iwind/TeaGo/dbs" "testing" ) func TestDNSDomainDAO_ExistDomainRecord(t *testing.T) { + var tx *dbs.Tx + { - b, err := NewDNSDomainDAO().ExistDomainRecord(1, "mycluster", "A", "", "") + b, err := NewDNSDomainDAO().ExistDomainRecord(tx, 1, "mycluster", "A", "", "") if err != nil { t.Fatal(err) } t.Log(b) } { - b, err := NewDNSDomainDAO().ExistDomainRecord(2, "mycluster", "A", "", "") + b, err := NewDNSDomainDAO().ExistDomainRecord(tx, 2, "mycluster", "A", "", "") if err != nil { t.Fatal(err) } t.Log(b) } { - b, err := NewDNSDomainDAO().ExistDomainRecord(2, "mycluster", "MX", "", "") + b, err := NewDNSDomainDAO().ExistDomainRecord(tx, 2, "mycluster", "MX", "", "") if err != nil { t.Fatal(err) } t.Log(b) } { - b, err := NewDNSDomainDAO().ExistDomainRecord(2, "mycluster123", "A", "", "") + b, err := NewDNSDomainDAO().ExistDomainRecord(tx, 2, "mycluster123", "A", "", "") if err != nil { t.Fatal(err) } diff --git a/internal/db/models/http_access_log_dao_test.go b/internal/db/models/http_access_log_dao_test.go index f14096b7..a7cdfc91 100644 --- a/internal/db/models/http_access_log_dao_test.go +++ b/internal/db/models/http_access_log_dao_test.go @@ -4,12 +4,15 @@ import ( "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" _ "github.com/go-sql-driver/mysql" _ "github.com/iwind/TeaGo/bootstrap" + "github.com/iwind/TeaGo/dbs" timeutil "github.com/iwind/TeaGo/utils/time" "testing" "time" ) func TestCreateHTTPAccessLogs(t *testing.T) { + var tx *dbs.Tx + err := NewDBNodeInitializer().loop() if err != nil { t.Fatal(err) @@ -23,7 +26,7 @@ func TestCreateHTTPAccessLogs(t *testing.T) { } dao := randomAccessLogDAO() t.Log("dao:", dao) - err = SharedHTTPAccessLogDAO.CreateHTTPAccessLogsWithDAO(dao, []*pb.HTTPAccessLog{accessLog}) + err = SharedHTTPAccessLogDAO.CreateHTTPAccessLogsWithDAO(tx, dao, []*pb.HTTPAccessLog{accessLog}) if err != nil { t.Fatal(err) } @@ -31,12 +34,14 @@ func TestCreateHTTPAccessLogs(t *testing.T) { } func TestHTTPAccessLogDAO_ListAccessLogs(t *testing.T) { + var tx *dbs.Tx + err := NewDBNodeInitializer().loop() if err != nil { t.Fatal(err) } - accessLogs, requestId, hasMore, err := SharedHTTPAccessLogDAO.ListAccessLogs("", 10, timeutil.Format("Ymd"), 0, false, false, 0, 0, 0) + accessLogs, requestId, hasMore, err := SharedHTTPAccessLogDAO.ListAccessLogs(tx, "", 10, timeutil.Format("Ymd"), 0, false, false, 0, 0, 0) if err != nil { t.Fatal(err) } @@ -51,6 +56,8 @@ func TestHTTPAccessLogDAO_ListAccessLogs(t *testing.T) { } func TestHTTPAccessLogDAO_ListAccessLogs_Page(t *testing.T) { + var tx *dbs.Tx + err := NewDBNodeInitializer().loop() if err != nil { t.Fatal(err) @@ -61,7 +68,7 @@ func TestHTTPAccessLogDAO_ListAccessLogs_Page(t *testing.T) { times := 0 // 防止循环次数太多 for { before := time.Now() - accessLogs, requestId, hasMore, err := SharedHTTPAccessLogDAO.ListAccessLogs(lastRequestId, 2, timeutil.Format("Ymd"), 0, false, false, 0, 0, 0) + accessLogs, requestId, hasMore, err := SharedHTTPAccessLogDAO.ListAccessLogs(tx, lastRequestId, 2, timeutil.Format("Ymd"), 0, false, false, 0, 0, 0) cost := time.Since(before).Seconds() if err != nil { t.Fatal(err) @@ -84,13 +91,15 @@ func TestHTTPAccessLogDAO_ListAccessLogs_Page(t *testing.T) { } func TestHTTPAccessLogDAO_ListAccessLogs_Reverse(t *testing.T) { + var tx *dbs.Tx + err := NewDBNodeInitializer().loop() if err != nil { t.Fatal(err) } before := time.Now() - accessLogs, requestId, hasMore, err := SharedHTTPAccessLogDAO.ListAccessLogs("16023261176446590001000000000000003500000004", 2, timeutil.Format("Ymd"), 0, true, false, 0, 0, 0) + accessLogs, requestId, hasMore, err := SharedHTTPAccessLogDAO.ListAccessLogs(tx, "16023261176446590001000000000000003500000004", 2, timeutil.Format("Ymd"), 0, true, false, 0, 0, 0) cost := time.Since(before).Seconds() if err != nil { t.Fatal(err) @@ -103,6 +112,8 @@ func TestHTTPAccessLogDAO_ListAccessLogs_Reverse(t *testing.T) { } func TestHTTPAccessLogDAO_ListAccessLogs_Page_NotExists(t *testing.T) { + var tx *dbs.Tx + err := NewDBNodeInitializer().loop() if err != nil { t.Fatal(err) @@ -113,7 +124,7 @@ func TestHTTPAccessLogDAO_ListAccessLogs_Page_NotExists(t *testing.T) { times := 0 // 防止循环次数太多 for { before := time.Now() - accessLogs, requestId, hasMore, err := SharedHTTPAccessLogDAO.ListAccessLogs(lastRequestId, 2, timeutil.Format("Ymd", time.Now().AddDate(0, 0, 1)), 0, false, false, 0, 0, 0) + accessLogs, requestId, hasMore, err := SharedHTTPAccessLogDAO.ListAccessLogs(tx, lastRequestId, 2, timeutil.Format("Ymd", time.Now().AddDate(0, 0, 1)), 0, false, false, 0, 0, 0) cost := time.Since(before).Seconds() if err != nil { t.Fatal(err) diff --git a/internal/db/models/http_web_dao.go b/internal/db/models/http_web_dao.go index af33b8d4..75809ee5 100644 --- a/internal/db/models/http_web_dao.go +++ b/internal/db/models/http_web_dao.go @@ -322,6 +322,16 @@ func (this *HTTPWebDAO) ComposeWebConfig(tx *dbs.Tx, webId int64) (*serverconfig } } + // 主机跳转 + if IsNotNull(web.HostRedirects) { + redirects := []*serverconfigs.HTTPHostRedirectConfig{} + err = json.Unmarshal([]byte(web.HostRedirects), &redirects) + if err != nil { + return nil, err + } + config.HostRedirects = redirects + } + return config, nil } @@ -618,3 +628,79 @@ func (this *HTTPWebDAO) FindEnabledWebIdWithLocationId(tx *dbs.Tx, locationId in Reuse(false). // 由于我们在JSON_CONTAINS()直接使用了变量,所以不能重用 FindInt64Col(0) } + +// 查找使用此Web的Server +func (this *HTTPWebDAO) FindWebServerId(tx *dbs.Tx, webId int64) (serverId int64, err error) { + if webId <= 0 { + return 0, nil + } + serverId, err = SharedServerDAO.FindEnabledServerIdWithWebId(tx, webId) + if err != nil { + return + } + if serverId > 0 { + return + } + + // web在Location中的情况 + locationId, err := SharedHTTPLocationDAO.FindEnabledLocationIdWithWebId(tx, webId) + if err != nil { + return 0, err + } + if locationId == 0 { + return + } + webId, err = this.FindEnabledWebIdWithLocationId(tx, locationId) + if err != nil { + return + } + if webId <= 0 { + return + } + + // 第二轮查找 + return this.FindWebServerId(tx, webId) +} + +// 检查用户权限 +func (this *HTTPWebDAO) CheckUserWeb(tx *dbs.Tx, userId int64, webId int64) error { + serverId, err := this.FindWebServerId(tx, webId) + if err != nil { + return err + } + if serverId == 0 { + return ErrNotFound + } + return SharedServerDAO.CheckUserServer(tx, serverId, userId) +} + +// 设置主机跳转 +func (this *HTTPWebDAO) UpdateWebHostRedirects(tx *dbs.Tx, webId int64, hostRedirects []*serverconfigs.HTTPHostRedirectConfig) error { + if webId <= 0 { + return errors.New("invalid ") + } + if hostRedirects == nil { + hostRedirects = []*serverconfigs.HTTPHostRedirectConfig{} + } + hostRedirectsJSON, err := json.Marshal(hostRedirects) + if err != nil { + return err + } + _, err = this.Query(tx). + Pk(webId). + Set("hostRedirects", hostRedirectsJSON). + Update() + return err +} + +// 查找主机跳转 +func (this *HTTPWebDAO) FindWebHostRedirects(tx *dbs.Tx, webId int64) ([]byte, error) { + col, err := this.Query(tx). + Pk(webId). + Result("hostRedirects"). + FindStringCol("") + if err != nil { + return nil, err + } + return []byte(col), nil +} diff --git a/internal/db/models/http_web_dao_test.go b/internal/db/models/http_web_dao_test.go index c2dc50c1..d6863e00 100644 --- a/internal/db/models/http_web_dao_test.go +++ b/internal/db/models/http_web_dao_test.go @@ -7,15 +7,16 @@ import ( ) func TestHTTPWebDAO_UpdateWebShutdown(t *testing.T) { + var tx *dbs.Tx { - err := SharedHTTPWebDAO.UpdateWebShutdown(1, []byte("{}")) + err := SharedHTTPWebDAO.UpdateWebShutdown(tx, 1, []byte("{}")) if err != nil { t.Fatal(err) } } { - err := SharedHTTPWebDAO.UpdateWebShutdown(1, nil) + err := SharedHTTPWebDAO.UpdateWebShutdown(tx, 1, nil) if err != nil { t.Fatal(err) } @@ -27,15 +28,50 @@ func TestHTTPWebDAO_UpdateWebShutdown(t *testing.T) { func TestHTTPWebDAO_FindAllWebIdsWithHTTPFirewallPolicyId(t *testing.T) { dbs.NotifyReady() - webIds, err := SharedHTTPWebDAO.FindAllWebIdsWithHTTPFirewallPolicyId(9) + var tx *dbs.Tx + + webIds, err := SharedHTTPWebDAO.FindAllWebIdsWithHTTPFirewallPolicyId(tx, 9) if err != nil { t.Fatal(err) } t.Log("webIds:", webIds) - count, err := SharedServerDAO.CountEnabledServersWithWebIds(webIds) + count, err := SharedServerDAO.CountEnabledServersWithWebIds(tx, webIds) if err != nil { t.Fatal(err) } t.Log("count:", count) } + + +func TestHTTPWebDAO_FindWebServerId(t *testing.T) { + dbs.NotifyReady() + + var tx *dbs.Tx + + // server + { + serverId, err := SharedHTTPWebDAO.FindWebServerId(tx, 60) + if err != nil { + t.Fatal(err) + } + t.Log("serverId:", serverId) + } + + // location + { + serverId, err := SharedHTTPWebDAO.FindWebServerId(tx, 45) + if err != nil { + t.Fatal(err) + } + t.Log("serverId:", serverId) + } + + { + serverId, err := SharedHTTPWebDAO.FindWebServerId(tx, 100) + if err != nil { + t.Fatal(err) + } + t.Log("serverId:", serverId) + } +} \ No newline at end of file diff --git a/internal/db/models/http_web_model.go b/internal/db/models/http_web_model.go index a10be6f7..c30f8255 100644 --- a/internal/db/models/http_web_model.go +++ b/internal/db/models/http_web_model.go @@ -26,6 +26,7 @@ type HTTPWeb struct { Locations string `field:"locations"` // 路径规则配置 Websocket string `field:"websocket"` // Websocket设置 RewriteRules string `field:"rewriteRules"` // 重写规则配置 + HostRedirects string `field:"hostRedirects"` // 域名跳转 } type HTTPWebOperator struct { @@ -53,6 +54,7 @@ type HTTPWebOperator struct { Locations interface{} // 路径规则配置 Websocket interface{} // Websocket设置 RewriteRules interface{} // 重写规则配置 + HostRedirects interface{} // 域名跳转 } func NewHTTPWebOperator() *HTTPWebOperator { diff --git a/internal/db/models/ip_list_dao_test.go b/internal/db/models/ip_list_dao_test.go index 5f80a2be..1b832e83 100644 --- a/internal/db/models/ip_list_dao_test.go +++ b/internal/db/models/ip_list_dao_test.go @@ -10,8 +10,10 @@ import ( func TestIPListDAO_IncreaseVersion(t *testing.T) { dbs.NotifyReady() + var tx *dbs.Tx + dao := NewIPListDAO() - version, err := dao.IncreaseVersion() + version, err := dao.IncreaseVersion(tx) if err != nil { t.Fatal(err) } @@ -23,8 +25,10 @@ func BenchmarkIPListDAO_IncreaseVersion(b *testing.B) { dbs.NotifyReady() + var tx *dbs.Tx + dao := NewIPListDAO() for i := 0; i < b.N; i++ { - _, _ = dao.IncreaseVersion() + _, _ = dao.IncreaseVersion(tx) } } diff --git a/internal/db/models/message_dao_test.go b/internal/db/models/message_dao_test.go index 27539fa5..96e2ebc9 100644 --- a/internal/db/models/message_dao_test.go +++ b/internal/db/models/message_dao_test.go @@ -2,13 +2,16 @@ package models import ( _ "github.com/go-sql-driver/mysql" + "github.com/iwind/TeaGo/dbs" "testing" "time" ) func TestMessageDAO_CreateClusterMessage(t *testing.T) { + var tx *dbs.Tx + dao := NewMessageDAO() - err := dao.CreateClusterMessage(1, "test", "error", "123", []byte("456")) + err := dao.CreateClusterMessage(tx, 1, "test", "error", "123", []byte("456")) if err != nil { t.Fatal(err) } @@ -16,8 +19,10 @@ func TestMessageDAO_CreateClusterMessage(t *testing.T) { } func TestMessageDAO_DeleteMessagesBeforeDay(t *testing.T) { + var tx *dbs.Tx + dao := NewMessageDAO() - err := dao.DeleteMessagesBeforeDay(time.Now()) + err := dao.DeleteMessagesBeforeDay(tx, time.Now()) if err != nil { t.Fatal(err) } diff --git a/internal/db/models/node_dao_test.go b/internal/db/models/node_dao_test.go index c1078458..20d23959 100644 --- a/internal/db/models/node_dao_test.go +++ b/internal/db/models/node_dao_test.go @@ -7,7 +7,8 @@ import ( ) func TestNodeDAO_FindAllNodeIdsMatch(t *testing.T) { - nodeIds, err := SharedNodeDAO.FindAllNodeIdsMatch(1) + var tx *dbs.Tx + nodeIds, err := SharedNodeDAO.FindAllNodeIdsMatch(tx, 1) if err != nil { t.Fatal(err) } @@ -15,7 +16,8 @@ func TestNodeDAO_FindAllNodeIdsMatch(t *testing.T) { } func TestNodeDAO_FindChangedClusterIds(t *testing.T) { - clusterIds, err := SharedNodeDAO.FindChangedClusterIds() + var tx *dbs.Tx + clusterIds, err := SharedNodeDAO.FindChangedClusterIds(tx) if err != nil { t.Fatal(err) } @@ -24,7 +26,8 @@ func TestNodeDAO_FindChangedClusterIds(t *testing.T) { func TestNodeDAO_UpdateNodeUp(t *testing.T) { dbs.NotifyReady() - isChanged, err := SharedNodeDAO.UpdateNodeUp(57, false, 3, 3) + var tx *dbs.Tx + isChanged, err := SharedNodeDAO.UpdateNodeUp(tx, 57, false, 3, 3) if err != nil { t.Fatal(err) } diff --git a/internal/db/models/origin_dao_test.go b/internal/db/models/origin_dao_test.go index 9e75391f..5679d1ed 100644 --- a/internal/db/models/origin_dao_test.go +++ b/internal/db/models/origin_dao_test.go @@ -2,11 +2,13 @@ package models import ( _ "github.com/go-sql-driver/mysql" + "github.com/iwind/TeaGo/dbs" "testing" ) func TestOriginServerDAO_ComposeOriginConfig(t *testing.T) { - config, err := SharedOriginDAO.ComposeOriginConfig(1) + var tx *dbs.Tx + config, err := SharedOriginDAO.ComposeOriginConfig(tx, 1) if err != nil { t.Fatal(err) } diff --git a/internal/db/models/reverse_proxy_dao_test.go b/internal/db/models/reverse_proxy_dao_test.go index dbfe1f86..eb0d2444 100644 --- a/internal/db/models/reverse_proxy_dao_test.go +++ b/internal/db/models/reverse_proxy_dao_test.go @@ -2,11 +2,13 @@ package models import ( _ "github.com/go-sql-driver/mysql" + "github.com/iwind/TeaGo/dbs" "testing" ) func TestReverseProxyDAO_ComposeReverseProxyConfig(t *testing.T) { - config, err := SharedReverseProxyDAO.ComposeReverseProxyConfig(1) + var tx *dbs.Tx + config, err := SharedReverseProxyDAO.ComposeReverseProxyConfig(tx, 1) if err != nil { t.Fatal(err) } diff --git a/internal/db/models/server_daily_stat_dao_test.go b/internal/db/models/server_daily_stat_dao_test.go index 99c33f07..f1aaefa5 100644 --- a/internal/db/models/server_daily_stat_dao_test.go +++ b/internal/db/models/server_daily_stat_dao_test.go @@ -9,6 +9,7 @@ import ( ) func TestServerDailyStatDAO_SaveStats(t *testing.T) { + var tx *dbs.Tx stats := []*pb.ServerDailyStat{ { ServerId: 1, @@ -17,7 +18,7 @@ func TestServerDailyStatDAO_SaveStats(t *testing.T) { CreatedAt: 1607671488, }, } - err := NewServerDailyStatDAO().SaveStats(stats) + err := NewServerDailyStatDAO().SaveStats(tx, stats) if err != nil { t.Fatal(err) } @@ -25,6 +26,7 @@ func TestServerDailyStatDAO_SaveStats(t *testing.T) { } func TestServerDailyStatDAO_SaveStats2(t *testing.T) { + var tx *dbs.Tx stats := []*pb.ServerDailyStat{ { ServerId: 1, @@ -33,7 +35,7 @@ func TestServerDailyStatDAO_SaveStats2(t *testing.T) { CreatedAt: 1607671488, }, } - err := NewServerDailyStatDAO().SaveStats(stats) + err := NewServerDailyStatDAO().SaveStats(tx, stats) if err != nil { t.Fatal(err) } @@ -42,7 +44,8 @@ func TestServerDailyStatDAO_SaveStats2(t *testing.T) { func TestServerDailyStatDAO_SumUserMonthly(t *testing.T) { dbs.NotifyReady() - bytes, err := NewServerDailyStatDAO().SumUserMonthly(1, 1, timeutil.Format("Ym")) + var tx *dbs.Tx + bytes, err := NewServerDailyStatDAO().SumUserMonthly(tx, 1, 1, timeutil.Format("Ym")) if err != nil { t.Fatal(err) } diff --git a/internal/db/models/server_dao.go b/internal/db/models/server_dao.go index 079f8d89..4aa417bb 100644 --- a/internal/db/models/server_dao.go +++ b/internal/db/models/server_dao.go @@ -1077,6 +1077,18 @@ func (this *ServerDAO) FindAllEnabledServersWithUserId(tx *dbs.Tx, userId int64) return } +// 根据WebId查找ServerId +func (this *ServerDAO) FindEnabledServerIdWithWebId(tx *dbs.Tx, webId int64) (serverId int64, err error) { + if webId <= 0 { + return 0, nil + } + return this.Query(tx). + State(ServerStateEnabled). + Attr("webId", webId). + ResultPk(). + FindInt64Col(0) +} + // 生成DNS Name func (this *ServerDAO) genDNSName(tx *dbs.Tx) (string, error) { for { diff --git a/internal/db/models/server_dao_test.go b/internal/db/models/server_dao_test.go index c518e79f..505d8421 100644 --- a/internal/db/models/server_dao_test.go +++ b/internal/db/models/server_dao_test.go @@ -11,7 +11,8 @@ import ( func TestServerDAO_ComposeServerConfig(t *testing.T) { dbs.NotifyReady() - config, err := SharedServerDAO.ComposeServerConfig(1) + var tx *dbs.Tx + config, err := SharedServerDAO.ComposeServerConfig(tx, 1) if err != nil { t.Fatal(err) } @@ -20,7 +21,8 @@ func TestServerDAO_ComposeServerConfig(t *testing.T) { func TestServerDAO_ComposeServerConfig_AliasServerNames(t *testing.T) { dbs.NotifyReady() - config, err := SharedServerDAO.ComposeServerConfig(14) + var tx *dbs.Tx + config, err := SharedServerDAO.ComposeServerConfig(tx, 14) if err != nil { t.Fatal(err) } @@ -29,8 +31,8 @@ func TestServerDAO_ComposeServerConfig_AliasServerNames(t *testing.T) { func TestServerDAO_UpdateServerConfig(t *testing.T) { dbs.NotifyReady() - - config, err := SharedServerDAO.ComposeServerConfig(1) + var tx *dbs.Tx + config, err := SharedServerDAO.ComposeServerConfig(tx, 1) if err != nil { t.Fatal(err) } @@ -39,7 +41,7 @@ func TestServerDAO_UpdateServerConfig(t *testing.T) { if err != nil { t.Fatal(err) } - _, err = SharedServerDAO.UpdateServerConfig(1, configJSON, false) + _, err = SharedServerDAO.UpdateServerConfig(tx, 1, configJSON, false) if err != nil { t.Fatal(err) } @@ -58,7 +60,8 @@ func TestNewServerDAO_md5(t *testing.T) { func TestServerDAO_genDNSName(t *testing.T) { dbs.NotifyReady() - dnsName, err := SharedServerDAO.genDNSName() + var tx *dbs.Tx + dnsName, err := SharedServerDAO.genDNSName(tx) if err != nil { t.Fatal(err) } @@ -67,8 +70,8 @@ func TestServerDAO_genDNSName(t *testing.T) { func TestServerDAO_FindAllServerDNSNamesWithDNSDomainId(t *testing.T) { dbs.NotifyReady() - - dnsNames, err := SharedServerDAO.FindAllServerDNSNamesWithDNSDomainId(2) + var tx *dbs.Tx + dnsNames, err := SharedServerDAO.FindAllServerDNSNamesWithDNSDomainId(tx, 2) if err != nil { t.Fatal(err) } diff --git a/internal/db/models/sys_locker_dao_test.go b/internal/db/models/sys_locker_dao_test.go index 9f2faa4f..a37742a3 100644 --- a/internal/db/models/sys_locker_dao_test.go +++ b/internal/db/models/sys_locker_dao_test.go @@ -2,18 +2,21 @@ package models import ( _ "github.com/go-sql-driver/mysql" + "github.com/iwind/TeaGo/dbs" "testing" ) func TestSysLockerDAO_Lock(t *testing.T) { - isOk, err := SharedSysLockerDAO.Lock("test", 600) + var tx *dbs.Tx + + isOk, err := SharedSysLockerDAO.Lock(tx, "test", 600) if err != nil { t.Fatal(err) } t.Log(isOk) if isOk { - err = SharedSysLockerDAO.Unlock("test") + err = SharedSysLockerDAO.Unlock(tx, "test") if err != nil { t.Fatal(err) } diff --git a/internal/db/models/sys_setting_dao_test.go b/internal/db/models/sys_setting_dao_test.go index b4f18eb0..fa9b43ce 100644 --- a/internal/db/models/sys_setting_dao_test.go +++ b/internal/db/models/sys_setting_dao_test.go @@ -2,16 +2,18 @@ package models import ( _ "github.com/go-sql-driver/mysql" + "github.com/iwind/TeaGo/dbs" "testing" ) func TestSysSettingDAO_UpdateSetting(t *testing.T) { - err := NewSysSettingDAO().UpdateSetting("hello", []byte(`"world"`)) + var tx *dbs.Tx + err := NewSysSettingDAO().UpdateSetting(tx, "hello", []byte(`"world"`)) if err != nil { t.Fatal(err) } - value, err := NewSysSettingDAO().ReadSetting("hello") + value, err := NewSysSettingDAO().ReadSetting(tx, "hello") if err != nil { t.Fatal(err) } @@ -19,12 +21,13 @@ func TestSysSettingDAO_UpdateSetting(t *testing.T) { } func TestSysSettingDAO_UpdateSetting_Args(t *testing.T) { - err := NewSysSettingDAO().UpdateSetting("hello %d", []byte(`"world 123"`), 123) + var tx *dbs.Tx + err := NewSysSettingDAO().UpdateSetting(tx, "hello %d", []byte(`"world 123"`), 123) if err != nil { t.Fatal(err) } - value, err := NewSysSettingDAO().ReadSetting("hello %d", 123) + value, err := NewSysSettingDAO().ReadSetting(tx, "hello %d", 123) if err != nil { t.Fatal(err) } @@ -32,7 +35,8 @@ func TestSysSettingDAO_UpdateSetting_Args(t *testing.T) { } func TestSysSettingDAO_CompareInt64Setting(t *testing.T) { - i, err := NewSysSettingDAO().CompareInt64Setting("int64", 1024) + var tx *dbs.Tx + i, err := NewSysSettingDAO().CompareInt64Setting(tx, "int64", 1024) if err != nil { t.Fatal(err) } diff --git a/internal/db/models/user_bill_dao_test.go b/internal/db/models/user_bill_dao_test.go index ed0204e2..779ef415 100644 --- a/internal/db/models/user_bill_dao_test.go +++ b/internal/db/models/user_bill_dao_test.go @@ -10,7 +10,9 @@ import ( func TestUserBillDAO_GenerateBills(t *testing.T) { dbs.NotifyReady() - err := SharedUserBillDAO.GenerateBills(timeutil.Format("Ym")) + var tx *dbs.Tx + + err := SharedUserBillDAO.GenerateBills(tx, timeutil.Format("Ym")) if err != nil { t.Fatal(err) } diff --git a/internal/rpc/services/service_http_web.go b/internal/rpc/services/service_http_web.go index 2c33daeb..c24c3c2e 100644 --- a/internal/rpc/services/service_http_web.go +++ b/internal/rpc/services/service_http_web.go @@ -4,7 +4,10 @@ import ( "context" "encoding/json" "github.com/TeaOSLab/EdgeAPI/internal/db/models" + "github.com/TeaOSLab/EdgeAPI/internal/errors" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" + "github.com/iwind/TeaGo/dbs" ) type HTTPWebService struct { @@ -38,7 +41,11 @@ func (this *HTTPWebService) FindEnabledHTTPWeb(ctx context.Context, req *pb.Find } if userId > 0 { - // TODO 检查用户权限 + // 检查用户权限 + err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId) + if err != nil { + return nil, err + } } tx := this.NullTx() @@ -67,7 +74,11 @@ func (this *HTTPWebService) FindEnabledHTTPWebConfig(ctx context.Context, req *p } if userId > 0 { - // TODO 检查用户权限 + // 检查用户权限 + err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId) + if err != nil { + return nil, err + } } tx := this.NullTx() @@ -93,7 +104,11 @@ func (this *HTTPWebService) UpdateHTTPWeb(ctx context.Context, req *pb.UpdateHTT } if userId > 0 { - // TODO 检查用户权限 + // 检查用户权限 + err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId) + if err != nil { + return nil, err + } } tx := this.NullTx() @@ -115,7 +130,11 @@ func (this *HTTPWebService) UpdateHTTPWebGzip(ctx context.Context, req *pb.Updat } if userId > 0 { - // TODO 检查用户权限 + // 检查用户权限 + err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId) + if err != nil { + return nil, err + } } tx := this.NullTx() @@ -137,7 +156,11 @@ func (this *HTTPWebService) UpdateHTTPWebCharset(ctx context.Context, req *pb.Up } if userId > 0 { - // TODO 检查用户权限 + // 检查用户权限 + err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId) + if err != nil { + return nil, err + } } tx := this.NullTx() @@ -158,7 +181,11 @@ func (this *HTTPWebService) UpdateHTTPWebRequestHeader(ctx context.Context, req } if userId > 0 { - // TODO 检查用户权限 + // 检查用户权限 + err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId) + if err != nil { + return nil, err + } } tx := this.NullTx() @@ -180,7 +207,11 @@ func (this *HTTPWebService) UpdateHTTPWebResponseHeader(ctx context.Context, req } if userId > 0 { - // TODO 检查用户权限 + // 检查用户权限 + err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId) + if err != nil { + return nil, err + } } tx := this.NullTx() @@ -202,7 +233,11 @@ func (this *HTTPWebService) UpdateHTTPWebShutdown(ctx context.Context, req *pb.U } if userId > 0 { - // TODO 检查用户权限 + // 检查用户权限 + err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId) + if err != nil { + return nil, err + } } tx := this.NullTx() @@ -223,7 +258,11 @@ func (this *HTTPWebService) UpdateHTTPWebPages(ctx context.Context, req *pb.Upda } if userId > 0 { - // TODO 检查用户权限 + // 检查用户权限 + err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId) + if err != nil { + return nil, err + } } tx := this.NullTx() @@ -244,7 +283,11 @@ func (this *HTTPWebService) UpdateHTTPWebAccessLog(ctx context.Context, req *pb. } if userId > 0 { - // TODO 检查用户权限 + // 检查用户权限 + err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId) + if err != nil { + return nil, err + } } tx := this.NullTx() @@ -265,7 +308,11 @@ func (this *HTTPWebService) UpdateHTTPWebStat(ctx context.Context, req *pb.Updat } if userId > 0 { - // TODO 检查用户权限 + // 检查用户权限 + err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId) + if err != nil { + return nil, err + } } tx := this.NullTx() @@ -286,7 +333,11 @@ func (this *HTTPWebService) UpdateHTTPWebCache(ctx context.Context, req *pb.Upda } if userId > 0 { - // TODO 检查权限 + // 检查权限 + err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId) + if err != nil { + return nil, err + } } tx := this.NullTx() @@ -308,7 +359,11 @@ func (this *HTTPWebService) UpdateHTTPWebFirewall(ctx context.Context, req *pb.U } if userId > 0 { - // TODO 检查用户权限 + // 检查用户权限 + err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId) + if err != nil { + return nil, err + } } tx := this.NullTx() @@ -330,7 +385,11 @@ func (this *HTTPWebService) UpdateHTTPWebLocations(ctx context.Context, req *pb. } if userId > 0 { - // TODO 检查用户权限 + // 检查用户权限 + err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId) + if err != nil { + return nil, err + } } tx := this.NullTx() @@ -346,12 +405,18 @@ func (this *HTTPWebService) UpdateHTTPWebLocations(ctx context.Context, req *pb. // 更改跳转到HTTPS设置 func (this *HTTPWebService) UpdateHTTPWebRedirectToHTTPS(ctx context.Context, req *pb.UpdateHTTPWebRedirectToHTTPSRequest) (*pb.RPCSuccess, error) { // 校验请求 - _, _, err := this.ValidateAdminAndUser(ctx, 0, 0) + _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } - // TODO 检查权限 + // 检查权限 + if userId > 0 { + err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId) + if err != nil { + return nil, err + } + } tx := this.NullTx() @@ -365,12 +430,17 @@ func (this *HTTPWebService) UpdateHTTPWebRedirectToHTTPS(ctx context.Context, re // 更改Websocket设置 func (this *HTTPWebService) UpdateHTTPWebWebsocket(ctx context.Context, req *pb.UpdateHTTPWebWebsocketRequest) (*pb.RPCSuccess, error) { // 校验请求 - _, _, err := this.ValidateAdminAndUser(ctx, 0, 0) + _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } - // TODO 检查权限 + if userId > 0 { + err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId) + if err != nil { + return nil, err + } + } tx := this.NullTx() @@ -390,7 +460,10 @@ func (this *HTTPWebService) UpdateHTTPWebRewriteRules(ctx context.Context, req * } if userId > 0 { - // TODO 检查用户权限 + err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId) + if err != nil { + return nil, err + } } tx := this.NullTx() @@ -401,3 +474,66 @@ func (this *HTTPWebService) UpdateHTTPWebRewriteRules(ctx context.Context, req * } return this.Success() } + +// 更改主机跳转设置 +func (this *HTTPWebService) UpdateHTTPWebHostRedirects(ctx context.Context, req *pb.UpdateHTTPWebHostRedirectsRequest) (*pb.RPCSuccess, error) { + // 校验请求 + _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) + if err != nil { + return nil, err + } + + if userId > 0 { + err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId) + if err != nil { + return nil, err + } + } + + hostRedirects := []*serverconfigs.HTTPHostRedirectConfig{} + if len(req.HostRedirectsJSON) == 0 { + return nil, errors.New("'hostRedirectsJSON' should not be empty") + } + err = json.Unmarshal(req.HostRedirectsJSON, &hostRedirects) + if err != nil { + return nil, err + } + + // 校验 + for _, redirect := range hostRedirects { + err := redirect.Init() + if err != nil { + return nil, err + } + } + + var tx *dbs.Tx + err = models.SharedHTTPWebDAO.UpdateWebHostRedirects(tx, req.WebId, hostRedirects) + if err != nil { + return nil, err + } + return this.Success() +} + +// 查找主机跳转设置 +func (this *HTTPWebService) FindHTTPWebHostRedirects(ctx context.Context, req *pb.FindHTTPWebHostRedirectsRequest) (*pb.FindHTTPWebHostRedirectsResponse, error) { + // 校验请求 + _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) + if err != nil { + return nil, err + } + + if userId > 0 { + err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId) + if err != nil { + return nil, err + } + } + + var tx *dbs.Tx + redirectsJSON, err := models.SharedHTTPWebDAO.FindWebHostRedirects(tx, req.WebId) + if err != nil { + return nil, err + } + return &pb.FindHTTPWebHostRedirectsResponse{HostRedirectsJSON: redirectsJSON}, nil +}