实现URL跳转功能

This commit is contained in:
GoEdgeLab
2021-01-10 17:34:35 +08:00
parent 87828af844
commit af5584c768
18 changed files with 381 additions and 62 deletions

View File

@@ -2,14 +2,16 @@ package models
import ( import (
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/dbs"
"runtime" "runtime"
"testing" "testing"
) )
func TestAPINodeDAO_FindEnabledAPINodeIdWithAddr(t *testing.T) { func TestAPINodeDAO_FindEnabledAPINodeIdWithAddr(t *testing.T) {
dao := NewAPINodeDAO() dao := NewAPINodeDAO()
var tx *dbs.Tx
{ {
apiNodeId, err := dao.FindEnabledAPINodeIdWithAddr("http", "127.0.0.1", 123) apiNodeId, err := dao.FindEnabledAPINodeIdWithAddr(tx, "http", "127.0.0.1", 123)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -17,7 +19,7 @@ func TestAPINodeDAO_FindEnabledAPINodeIdWithAddr(t *testing.T) {
} }
{ {
apiNodeId, err := dao.FindEnabledAPINodeIdWithAddr("http", "127.0.0.1", 8003) apiNodeId, err := dao.FindEnabledAPINodeIdWithAddr(tx, "http", "127.0.0.1", 8003)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -2,33 +2,36 @@ package models
import ( import (
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/dbs"
"testing" "testing"
) )
func TestDNSDomainDAO_ExistDomainRecord(t *testing.T) { func TestDNSDomainDAO_ExistDomainRecord(t *testing.T) {
var tx *dbs.Tx
{ {
b, err := NewDNSDomainDAO().ExistDomainRecord(1, "mycluster", "A", "", "") b, err := NewDNSDomainDAO().ExistDomainRecord(tx, 1, "mycluster", "A", "", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Log(b) t.Log(b)
} }
{ {
b, err := NewDNSDomainDAO().ExistDomainRecord(2, "mycluster", "A", "", "") b, err := NewDNSDomainDAO().ExistDomainRecord(tx, 2, "mycluster", "A", "", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Log(b) t.Log(b)
} }
{ {
b, err := NewDNSDomainDAO().ExistDomainRecord(2, "mycluster", "MX", "", "") b, err := NewDNSDomainDAO().ExistDomainRecord(tx, 2, "mycluster", "MX", "", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Log(b) t.Log(b)
} }
{ {
b, err := NewDNSDomainDAO().ExistDomainRecord(2, "mycluster123", "A", "", "") b, err := NewDNSDomainDAO().ExistDomainRecord(tx, 2, "mycluster123", "A", "", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -4,12 +4,15 @@ import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
_ "github.com/iwind/TeaGo/bootstrap" _ "github.com/iwind/TeaGo/bootstrap"
"github.com/iwind/TeaGo/dbs"
timeutil "github.com/iwind/TeaGo/utils/time" timeutil "github.com/iwind/TeaGo/utils/time"
"testing" "testing"
"time" "time"
) )
func TestCreateHTTPAccessLogs(t *testing.T) { func TestCreateHTTPAccessLogs(t *testing.T) {
var tx *dbs.Tx
err := NewDBNodeInitializer().loop() err := NewDBNodeInitializer().loop()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -23,7 +26,7 @@ func TestCreateHTTPAccessLogs(t *testing.T) {
} }
dao := randomAccessLogDAO() dao := randomAccessLogDAO()
t.Log("dao:", dao) t.Log("dao:", dao)
err = SharedHTTPAccessLogDAO.CreateHTTPAccessLogsWithDAO(dao, []*pb.HTTPAccessLog{accessLog}) err = SharedHTTPAccessLogDAO.CreateHTTPAccessLogsWithDAO(tx, dao, []*pb.HTTPAccessLog{accessLog})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -31,12 +34,14 @@ func TestCreateHTTPAccessLogs(t *testing.T) {
} }
func TestHTTPAccessLogDAO_ListAccessLogs(t *testing.T) { func TestHTTPAccessLogDAO_ListAccessLogs(t *testing.T) {
var tx *dbs.Tx
err := NewDBNodeInitializer().loop() err := NewDBNodeInitializer().loop()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
accessLogs, requestId, hasMore, err := SharedHTTPAccessLogDAO.ListAccessLogs("", 10, timeutil.Format("Ymd"), 0, false, false, 0, 0, 0) accessLogs, requestId, hasMore, err := SharedHTTPAccessLogDAO.ListAccessLogs(tx, "", 10, timeutil.Format("Ymd"), 0, false, false, 0, 0, 0)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -51,6 +56,8 @@ func TestHTTPAccessLogDAO_ListAccessLogs(t *testing.T) {
} }
func TestHTTPAccessLogDAO_ListAccessLogs_Page(t *testing.T) { func TestHTTPAccessLogDAO_ListAccessLogs_Page(t *testing.T) {
var tx *dbs.Tx
err := NewDBNodeInitializer().loop() err := NewDBNodeInitializer().loop()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -61,7 +68,7 @@ func TestHTTPAccessLogDAO_ListAccessLogs_Page(t *testing.T) {
times := 0 // 防止循环次数太多 times := 0 // 防止循环次数太多
for { for {
before := time.Now() before := time.Now()
accessLogs, requestId, hasMore, err := SharedHTTPAccessLogDAO.ListAccessLogs(lastRequestId, 2, timeutil.Format("Ymd"), 0, false, false, 0, 0, 0) accessLogs, requestId, hasMore, err := SharedHTTPAccessLogDAO.ListAccessLogs(tx, lastRequestId, 2, timeutil.Format("Ymd"), 0, false, false, 0, 0, 0)
cost := time.Since(before).Seconds() cost := time.Since(before).Seconds()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -84,13 +91,15 @@ func TestHTTPAccessLogDAO_ListAccessLogs_Page(t *testing.T) {
} }
func TestHTTPAccessLogDAO_ListAccessLogs_Reverse(t *testing.T) { func TestHTTPAccessLogDAO_ListAccessLogs_Reverse(t *testing.T) {
var tx *dbs.Tx
err := NewDBNodeInitializer().loop() err := NewDBNodeInitializer().loop()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
before := time.Now() before := time.Now()
accessLogs, requestId, hasMore, err := SharedHTTPAccessLogDAO.ListAccessLogs("16023261176446590001000000000000003500000004", 2, timeutil.Format("Ymd"), 0, true, false, 0, 0, 0) accessLogs, requestId, hasMore, err := SharedHTTPAccessLogDAO.ListAccessLogs(tx, "16023261176446590001000000000000003500000004", 2, timeutil.Format("Ymd"), 0, true, false, 0, 0, 0)
cost := time.Since(before).Seconds() cost := time.Since(before).Seconds()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -103,6 +112,8 @@ func TestHTTPAccessLogDAO_ListAccessLogs_Reverse(t *testing.T) {
} }
func TestHTTPAccessLogDAO_ListAccessLogs_Page_NotExists(t *testing.T) { func TestHTTPAccessLogDAO_ListAccessLogs_Page_NotExists(t *testing.T) {
var tx *dbs.Tx
err := NewDBNodeInitializer().loop() err := NewDBNodeInitializer().loop()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -113,7 +124,7 @@ func TestHTTPAccessLogDAO_ListAccessLogs_Page_NotExists(t *testing.T) {
times := 0 // 防止循环次数太多 times := 0 // 防止循环次数太多
for { for {
before := time.Now() before := time.Now()
accessLogs, requestId, hasMore, err := SharedHTTPAccessLogDAO.ListAccessLogs(lastRequestId, 2, timeutil.Format("Ymd", time.Now().AddDate(0, 0, 1)), 0, false, false, 0, 0, 0) accessLogs, requestId, hasMore, err := SharedHTTPAccessLogDAO.ListAccessLogs(tx, lastRequestId, 2, timeutil.Format("Ymd", time.Now().AddDate(0, 0, 1)), 0, false, false, 0, 0, 0)
cost := time.Since(before).Seconds() cost := time.Since(before).Seconds()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View File

@@ -322,6 +322,16 @@ func (this *HTTPWebDAO) ComposeWebConfig(tx *dbs.Tx, webId int64) (*serverconfig
} }
} }
// 主机跳转
if IsNotNull(web.HostRedirects) {
redirects := []*serverconfigs.HTTPHostRedirectConfig{}
err = json.Unmarshal([]byte(web.HostRedirects), &redirects)
if err != nil {
return nil, err
}
config.HostRedirects = redirects
}
return config, nil return config, nil
} }
@@ -618,3 +628,79 @@ func (this *HTTPWebDAO) FindEnabledWebIdWithLocationId(tx *dbs.Tx, locationId in
Reuse(false). // 由于我们在JSON_CONTAINS()直接使用了变量,所以不能重用 Reuse(false). // 由于我们在JSON_CONTAINS()直接使用了变量,所以不能重用
FindInt64Col(0) FindInt64Col(0)
} }
// 查找使用此Web的Server
func (this *HTTPWebDAO) FindWebServerId(tx *dbs.Tx, webId int64) (serverId int64, err error) {
if webId <= 0 {
return 0, nil
}
serverId, err = SharedServerDAO.FindEnabledServerIdWithWebId(tx, webId)
if err != nil {
return
}
if serverId > 0 {
return
}
// web在Location中的情况
locationId, err := SharedHTTPLocationDAO.FindEnabledLocationIdWithWebId(tx, webId)
if err != nil {
return 0, err
}
if locationId == 0 {
return
}
webId, err = this.FindEnabledWebIdWithLocationId(tx, locationId)
if err != nil {
return
}
if webId <= 0 {
return
}
// 第二轮查找
return this.FindWebServerId(tx, webId)
}
// 检查用户权限
func (this *HTTPWebDAO) CheckUserWeb(tx *dbs.Tx, userId int64, webId int64) error {
serverId, err := this.FindWebServerId(tx, webId)
if err != nil {
return err
}
if serverId == 0 {
return ErrNotFound
}
return SharedServerDAO.CheckUserServer(tx, serverId, userId)
}
// 设置主机跳转
func (this *HTTPWebDAO) UpdateWebHostRedirects(tx *dbs.Tx, webId int64, hostRedirects []*serverconfigs.HTTPHostRedirectConfig) error {
if webId <= 0 {
return errors.New("invalid ")
}
if hostRedirects == nil {
hostRedirects = []*serverconfigs.HTTPHostRedirectConfig{}
}
hostRedirectsJSON, err := json.Marshal(hostRedirects)
if err != nil {
return err
}
_, err = this.Query(tx).
Pk(webId).
Set("hostRedirects", hostRedirectsJSON).
Update()
return err
}
// 查找主机跳转
func (this *HTTPWebDAO) FindWebHostRedirects(tx *dbs.Tx, webId int64) ([]byte, error) {
col, err := this.Query(tx).
Pk(webId).
Result("hostRedirects").
FindStringCol("")
if err != nil {
return nil, err
}
return []byte(col), nil
}

View File

@@ -7,15 +7,16 @@ import (
) )
func TestHTTPWebDAO_UpdateWebShutdown(t *testing.T) { func TestHTTPWebDAO_UpdateWebShutdown(t *testing.T) {
var tx *dbs.Tx
{ {
err := SharedHTTPWebDAO.UpdateWebShutdown(1, []byte("{}")) err := SharedHTTPWebDAO.UpdateWebShutdown(tx, 1, []byte("{}"))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }
{ {
err := SharedHTTPWebDAO.UpdateWebShutdown(1, nil) err := SharedHTTPWebDAO.UpdateWebShutdown(tx, 1, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -27,15 +28,50 @@ func TestHTTPWebDAO_UpdateWebShutdown(t *testing.T) {
func TestHTTPWebDAO_FindAllWebIdsWithHTTPFirewallPolicyId(t *testing.T) { func TestHTTPWebDAO_FindAllWebIdsWithHTTPFirewallPolicyId(t *testing.T) {
dbs.NotifyReady() dbs.NotifyReady()
webIds, err := SharedHTTPWebDAO.FindAllWebIdsWithHTTPFirewallPolicyId(9) var tx *dbs.Tx
webIds, err := SharedHTTPWebDAO.FindAllWebIdsWithHTTPFirewallPolicyId(tx, 9)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Log("webIds:", webIds) t.Log("webIds:", webIds)
count, err := SharedServerDAO.CountEnabledServersWithWebIds(webIds) count, err := SharedServerDAO.CountEnabledServersWithWebIds(tx, webIds)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Log("count:", count) t.Log("count:", count)
} }
func TestHTTPWebDAO_FindWebServerId(t *testing.T) {
dbs.NotifyReady()
var tx *dbs.Tx
// server
{
serverId, err := SharedHTTPWebDAO.FindWebServerId(tx, 60)
if err != nil {
t.Fatal(err)
}
t.Log("serverId:", serverId)
}
// location
{
serverId, err := SharedHTTPWebDAO.FindWebServerId(tx, 45)
if err != nil {
t.Fatal(err)
}
t.Log("serverId:", serverId)
}
{
serverId, err := SharedHTTPWebDAO.FindWebServerId(tx, 100)
if err != nil {
t.Fatal(err)
}
t.Log("serverId:", serverId)
}
}

View File

@@ -26,6 +26,7 @@ type HTTPWeb struct {
Locations string `field:"locations"` // 路径规则配置 Locations string `field:"locations"` // 路径规则配置
Websocket string `field:"websocket"` // Websocket设置 Websocket string `field:"websocket"` // Websocket设置
RewriteRules string `field:"rewriteRules"` // 重写规则配置 RewriteRules string `field:"rewriteRules"` // 重写规则配置
HostRedirects string `field:"hostRedirects"` // 域名跳转
} }
type HTTPWebOperator struct { type HTTPWebOperator struct {
@@ -53,6 +54,7 @@ type HTTPWebOperator struct {
Locations interface{} // 路径规则配置 Locations interface{} // 路径规则配置
Websocket interface{} // Websocket设置 Websocket interface{} // Websocket设置
RewriteRules interface{} // 重写规则配置 RewriteRules interface{} // 重写规则配置
HostRedirects interface{} // 域名跳转
} }
func NewHTTPWebOperator() *HTTPWebOperator { func NewHTTPWebOperator() *HTTPWebOperator {

View File

@@ -10,8 +10,10 @@ import (
func TestIPListDAO_IncreaseVersion(t *testing.T) { func TestIPListDAO_IncreaseVersion(t *testing.T) {
dbs.NotifyReady() dbs.NotifyReady()
var tx *dbs.Tx
dao := NewIPListDAO() dao := NewIPListDAO()
version, err := dao.IncreaseVersion() version, err := dao.IncreaseVersion(tx)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -23,8 +25,10 @@ func BenchmarkIPListDAO_IncreaseVersion(b *testing.B) {
dbs.NotifyReady() dbs.NotifyReady()
var tx *dbs.Tx
dao := NewIPListDAO() dao := NewIPListDAO()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
_, _ = dao.IncreaseVersion() _, _ = dao.IncreaseVersion(tx)
} }
} }

View File

@@ -2,13 +2,16 @@ package models
import ( import (
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/dbs"
"testing" "testing"
"time" "time"
) )
func TestMessageDAO_CreateClusterMessage(t *testing.T) { func TestMessageDAO_CreateClusterMessage(t *testing.T) {
var tx *dbs.Tx
dao := NewMessageDAO() dao := NewMessageDAO()
err := dao.CreateClusterMessage(1, "test", "error", "123", []byte("456")) err := dao.CreateClusterMessage(tx, 1, "test", "error", "123", []byte("456"))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -16,8 +19,10 @@ func TestMessageDAO_CreateClusterMessage(t *testing.T) {
} }
func TestMessageDAO_DeleteMessagesBeforeDay(t *testing.T) { func TestMessageDAO_DeleteMessagesBeforeDay(t *testing.T) {
var tx *dbs.Tx
dao := NewMessageDAO() dao := NewMessageDAO()
err := dao.DeleteMessagesBeforeDay(time.Now()) err := dao.DeleteMessagesBeforeDay(tx, time.Now())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -7,7 +7,8 @@ import (
) )
func TestNodeDAO_FindAllNodeIdsMatch(t *testing.T) { func TestNodeDAO_FindAllNodeIdsMatch(t *testing.T) {
nodeIds, err := SharedNodeDAO.FindAllNodeIdsMatch(1) var tx *dbs.Tx
nodeIds, err := SharedNodeDAO.FindAllNodeIdsMatch(tx, 1)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -15,7 +16,8 @@ func TestNodeDAO_FindAllNodeIdsMatch(t *testing.T) {
} }
func TestNodeDAO_FindChangedClusterIds(t *testing.T) { func TestNodeDAO_FindChangedClusterIds(t *testing.T) {
clusterIds, err := SharedNodeDAO.FindChangedClusterIds() var tx *dbs.Tx
clusterIds, err := SharedNodeDAO.FindChangedClusterIds(tx)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -24,7 +26,8 @@ func TestNodeDAO_FindChangedClusterIds(t *testing.T) {
func TestNodeDAO_UpdateNodeUp(t *testing.T) { func TestNodeDAO_UpdateNodeUp(t *testing.T) {
dbs.NotifyReady() dbs.NotifyReady()
isChanged, err := SharedNodeDAO.UpdateNodeUp(57, false, 3, 3) var tx *dbs.Tx
isChanged, err := SharedNodeDAO.UpdateNodeUp(tx, 57, false, 3, 3)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -2,11 +2,13 @@ package models
import ( import (
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/dbs"
"testing" "testing"
) )
func TestOriginServerDAO_ComposeOriginConfig(t *testing.T) { func TestOriginServerDAO_ComposeOriginConfig(t *testing.T) {
config, err := SharedOriginDAO.ComposeOriginConfig(1) var tx *dbs.Tx
config, err := SharedOriginDAO.ComposeOriginConfig(tx, 1)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -2,11 +2,13 @@ package models
import ( import (
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/dbs"
"testing" "testing"
) )
func TestReverseProxyDAO_ComposeReverseProxyConfig(t *testing.T) { func TestReverseProxyDAO_ComposeReverseProxyConfig(t *testing.T) {
config, err := SharedReverseProxyDAO.ComposeReverseProxyConfig(1) var tx *dbs.Tx
config, err := SharedReverseProxyDAO.ComposeReverseProxyConfig(tx, 1)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -9,6 +9,7 @@ import (
) )
func TestServerDailyStatDAO_SaveStats(t *testing.T) { func TestServerDailyStatDAO_SaveStats(t *testing.T) {
var tx *dbs.Tx
stats := []*pb.ServerDailyStat{ stats := []*pb.ServerDailyStat{
{ {
ServerId: 1, ServerId: 1,
@@ -17,7 +18,7 @@ func TestServerDailyStatDAO_SaveStats(t *testing.T) {
CreatedAt: 1607671488, CreatedAt: 1607671488,
}, },
} }
err := NewServerDailyStatDAO().SaveStats(stats) err := NewServerDailyStatDAO().SaveStats(tx, stats)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -25,6 +26,7 @@ func TestServerDailyStatDAO_SaveStats(t *testing.T) {
} }
func TestServerDailyStatDAO_SaveStats2(t *testing.T) { func TestServerDailyStatDAO_SaveStats2(t *testing.T) {
var tx *dbs.Tx
stats := []*pb.ServerDailyStat{ stats := []*pb.ServerDailyStat{
{ {
ServerId: 1, ServerId: 1,
@@ -33,7 +35,7 @@ func TestServerDailyStatDAO_SaveStats2(t *testing.T) {
CreatedAt: 1607671488, CreatedAt: 1607671488,
}, },
} }
err := NewServerDailyStatDAO().SaveStats(stats) err := NewServerDailyStatDAO().SaveStats(tx, stats)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -42,7 +44,8 @@ func TestServerDailyStatDAO_SaveStats2(t *testing.T) {
func TestServerDailyStatDAO_SumUserMonthly(t *testing.T) { func TestServerDailyStatDAO_SumUserMonthly(t *testing.T) {
dbs.NotifyReady() dbs.NotifyReady()
bytes, err := NewServerDailyStatDAO().SumUserMonthly(1, 1, timeutil.Format("Ym")) var tx *dbs.Tx
bytes, err := NewServerDailyStatDAO().SumUserMonthly(tx, 1, 1, timeutil.Format("Ym"))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -1077,6 +1077,18 @@ func (this *ServerDAO) FindAllEnabledServersWithUserId(tx *dbs.Tx, userId int64)
return return
} }
// 根据WebId查找ServerId
func (this *ServerDAO) FindEnabledServerIdWithWebId(tx *dbs.Tx, webId int64) (serverId int64, err error) {
if webId <= 0 {
return 0, nil
}
return this.Query(tx).
State(ServerStateEnabled).
Attr("webId", webId).
ResultPk().
FindInt64Col(0)
}
// 生成DNS Name // 生成DNS Name
func (this *ServerDAO) genDNSName(tx *dbs.Tx) (string, error) { func (this *ServerDAO) genDNSName(tx *dbs.Tx) (string, error) {
for { for {

View File

@@ -11,7 +11,8 @@ import (
func TestServerDAO_ComposeServerConfig(t *testing.T) { func TestServerDAO_ComposeServerConfig(t *testing.T) {
dbs.NotifyReady() dbs.NotifyReady()
config, err := SharedServerDAO.ComposeServerConfig(1) var tx *dbs.Tx
config, err := SharedServerDAO.ComposeServerConfig(tx, 1)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -20,7 +21,8 @@ func TestServerDAO_ComposeServerConfig(t *testing.T) {
func TestServerDAO_ComposeServerConfig_AliasServerNames(t *testing.T) { func TestServerDAO_ComposeServerConfig_AliasServerNames(t *testing.T) {
dbs.NotifyReady() dbs.NotifyReady()
config, err := SharedServerDAO.ComposeServerConfig(14) var tx *dbs.Tx
config, err := SharedServerDAO.ComposeServerConfig(tx, 14)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -29,8 +31,8 @@ func TestServerDAO_ComposeServerConfig_AliasServerNames(t *testing.T) {
func TestServerDAO_UpdateServerConfig(t *testing.T) { func TestServerDAO_UpdateServerConfig(t *testing.T) {
dbs.NotifyReady() dbs.NotifyReady()
var tx *dbs.Tx
config, err := SharedServerDAO.ComposeServerConfig(1) config, err := SharedServerDAO.ComposeServerConfig(tx, 1)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -39,7 +41,7 @@ func TestServerDAO_UpdateServerConfig(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
_, err = SharedServerDAO.UpdateServerConfig(1, configJSON, false) _, err = SharedServerDAO.UpdateServerConfig(tx, 1, configJSON, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -58,7 +60,8 @@ func TestNewServerDAO_md5(t *testing.T) {
func TestServerDAO_genDNSName(t *testing.T) { func TestServerDAO_genDNSName(t *testing.T) {
dbs.NotifyReady() dbs.NotifyReady()
dnsName, err := SharedServerDAO.genDNSName() var tx *dbs.Tx
dnsName, err := SharedServerDAO.genDNSName(tx)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -67,8 +70,8 @@ func TestServerDAO_genDNSName(t *testing.T) {
func TestServerDAO_FindAllServerDNSNamesWithDNSDomainId(t *testing.T) { func TestServerDAO_FindAllServerDNSNamesWithDNSDomainId(t *testing.T) {
dbs.NotifyReady() dbs.NotifyReady()
var tx *dbs.Tx
dnsNames, err := SharedServerDAO.FindAllServerDNSNamesWithDNSDomainId(2) dnsNames, err := SharedServerDAO.FindAllServerDNSNamesWithDNSDomainId(tx, 2)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -2,18 +2,21 @@ package models
import ( import (
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/dbs"
"testing" "testing"
) )
func TestSysLockerDAO_Lock(t *testing.T) { func TestSysLockerDAO_Lock(t *testing.T) {
isOk, err := SharedSysLockerDAO.Lock("test", 600) var tx *dbs.Tx
isOk, err := SharedSysLockerDAO.Lock(tx, "test", 600)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Log(isOk) t.Log(isOk)
if isOk { if isOk {
err = SharedSysLockerDAO.Unlock("test") err = SharedSysLockerDAO.Unlock(tx, "test")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -2,16 +2,18 @@ package models
import ( import (
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/dbs"
"testing" "testing"
) )
func TestSysSettingDAO_UpdateSetting(t *testing.T) { func TestSysSettingDAO_UpdateSetting(t *testing.T) {
err := NewSysSettingDAO().UpdateSetting("hello", []byte(`"world"`)) var tx *dbs.Tx
err := NewSysSettingDAO().UpdateSetting(tx, "hello", []byte(`"world"`))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
value, err := NewSysSettingDAO().ReadSetting("hello") value, err := NewSysSettingDAO().ReadSetting(tx, "hello")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -19,12 +21,13 @@ func TestSysSettingDAO_UpdateSetting(t *testing.T) {
} }
func TestSysSettingDAO_UpdateSetting_Args(t *testing.T) { func TestSysSettingDAO_UpdateSetting_Args(t *testing.T) {
err := NewSysSettingDAO().UpdateSetting("hello %d", []byte(`"world 123"`), 123) var tx *dbs.Tx
err := NewSysSettingDAO().UpdateSetting(tx, "hello %d", []byte(`"world 123"`), 123)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
value, err := NewSysSettingDAO().ReadSetting("hello %d", 123) value, err := NewSysSettingDAO().ReadSetting(tx, "hello %d", 123)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -32,7 +35,8 @@ func TestSysSettingDAO_UpdateSetting_Args(t *testing.T) {
} }
func TestSysSettingDAO_CompareInt64Setting(t *testing.T) { func TestSysSettingDAO_CompareInt64Setting(t *testing.T) {
i, err := NewSysSettingDAO().CompareInt64Setting("int64", 1024) var tx *dbs.Tx
i, err := NewSysSettingDAO().CompareInt64Setting(tx, "int64", 1024)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -10,7 +10,9 @@ import (
func TestUserBillDAO_GenerateBills(t *testing.T) { func TestUserBillDAO_GenerateBills(t *testing.T) {
dbs.NotifyReady() dbs.NotifyReady()
err := SharedUserBillDAO.GenerateBills(timeutil.Format("Ym")) var tx *dbs.Tx
err := SharedUserBillDAO.GenerateBills(tx, timeutil.Format("Ym"))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -4,7 +4,10 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"github.com/TeaOSLab/EdgeAPI/internal/db/models" "github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/iwind/TeaGo/dbs"
) )
type HTTPWebService struct { type HTTPWebService struct {
@@ -38,7 +41,11 @@ func (this *HTTPWebService) FindEnabledHTTPWeb(ctx context.Context, req *pb.Find
} }
if userId > 0 { if userId > 0 {
// TODO 检查用户权限 // 检查用户权限
err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId)
if err != nil {
return nil, err
}
} }
tx := this.NullTx() tx := this.NullTx()
@@ -67,7 +74,11 @@ func (this *HTTPWebService) FindEnabledHTTPWebConfig(ctx context.Context, req *p
} }
if userId > 0 { if userId > 0 {
// TODO 检查用户权限 // 检查用户权限
err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId)
if err != nil {
return nil, err
}
} }
tx := this.NullTx() tx := this.NullTx()
@@ -93,7 +104,11 @@ func (this *HTTPWebService) UpdateHTTPWeb(ctx context.Context, req *pb.UpdateHTT
} }
if userId > 0 { if userId > 0 {
// TODO 检查用户权限 // 检查用户权限
err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId)
if err != nil {
return nil, err
}
} }
tx := this.NullTx() tx := this.NullTx()
@@ -115,7 +130,11 @@ func (this *HTTPWebService) UpdateHTTPWebGzip(ctx context.Context, req *pb.Updat
} }
if userId > 0 { if userId > 0 {
// TODO 检查用户权限 // 检查用户权限
err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId)
if err != nil {
return nil, err
}
} }
tx := this.NullTx() tx := this.NullTx()
@@ -137,7 +156,11 @@ func (this *HTTPWebService) UpdateHTTPWebCharset(ctx context.Context, req *pb.Up
} }
if userId > 0 { if userId > 0 {
// TODO 检查用户权限 // 检查用户权限
err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId)
if err != nil {
return nil, err
}
} }
tx := this.NullTx() tx := this.NullTx()
@@ -158,7 +181,11 @@ func (this *HTTPWebService) UpdateHTTPWebRequestHeader(ctx context.Context, req
} }
if userId > 0 { if userId > 0 {
// TODO 检查用户权限 // 检查用户权限
err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId)
if err != nil {
return nil, err
}
} }
tx := this.NullTx() tx := this.NullTx()
@@ -180,7 +207,11 @@ func (this *HTTPWebService) UpdateHTTPWebResponseHeader(ctx context.Context, req
} }
if userId > 0 { if userId > 0 {
// TODO 检查用户权限 // 检查用户权限
err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId)
if err != nil {
return nil, err
}
} }
tx := this.NullTx() tx := this.NullTx()
@@ -202,7 +233,11 @@ func (this *HTTPWebService) UpdateHTTPWebShutdown(ctx context.Context, req *pb.U
} }
if userId > 0 { if userId > 0 {
// TODO 检查用户权限 // 检查用户权限
err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId)
if err != nil {
return nil, err
}
} }
tx := this.NullTx() tx := this.NullTx()
@@ -223,7 +258,11 @@ func (this *HTTPWebService) UpdateHTTPWebPages(ctx context.Context, req *pb.Upda
} }
if userId > 0 { if userId > 0 {
// TODO 检查用户权限 // 检查用户权限
err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId)
if err != nil {
return nil, err
}
} }
tx := this.NullTx() tx := this.NullTx()
@@ -244,7 +283,11 @@ func (this *HTTPWebService) UpdateHTTPWebAccessLog(ctx context.Context, req *pb.
} }
if userId > 0 { if userId > 0 {
// TODO 检查用户权限 // 检查用户权限
err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId)
if err != nil {
return nil, err
}
} }
tx := this.NullTx() tx := this.NullTx()
@@ -265,7 +308,11 @@ func (this *HTTPWebService) UpdateHTTPWebStat(ctx context.Context, req *pb.Updat
} }
if userId > 0 { if userId > 0 {
// TODO 检查用户权限 // 检查用户权限
err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId)
if err != nil {
return nil, err
}
} }
tx := this.NullTx() tx := this.NullTx()
@@ -286,7 +333,11 @@ func (this *HTTPWebService) UpdateHTTPWebCache(ctx context.Context, req *pb.Upda
} }
if userId > 0 { if userId > 0 {
// TODO 检查权限 // 检查权限
err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId)
if err != nil {
return nil, err
}
} }
tx := this.NullTx() tx := this.NullTx()
@@ -308,7 +359,11 @@ func (this *HTTPWebService) UpdateHTTPWebFirewall(ctx context.Context, req *pb.U
} }
if userId > 0 { if userId > 0 {
// TODO 检查用户权限 // 检查用户权限
err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId)
if err != nil {
return nil, err
}
} }
tx := this.NullTx() tx := this.NullTx()
@@ -330,7 +385,11 @@ func (this *HTTPWebService) UpdateHTTPWebLocations(ctx context.Context, req *pb.
} }
if userId > 0 { if userId > 0 {
// TODO 检查用户权限 // 检查用户权限
err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId)
if err != nil {
return nil, err
}
} }
tx := this.NullTx() tx := this.NullTx()
@@ -346,12 +405,18 @@ func (this *HTTPWebService) UpdateHTTPWebLocations(ctx context.Context, req *pb.
// 更改跳转到HTTPS设置 // 更改跳转到HTTPS设置
func (this *HTTPWebService) UpdateHTTPWebRedirectToHTTPS(ctx context.Context, req *pb.UpdateHTTPWebRedirectToHTTPSRequest) (*pb.RPCSuccess, error) { func (this *HTTPWebService) UpdateHTTPWebRedirectToHTTPS(ctx context.Context, req *pb.UpdateHTTPWebRedirectToHTTPSRequest) (*pb.RPCSuccess, error) {
// 校验请求 // 校验请求
_, _, err := this.ValidateAdminAndUser(ctx, 0, 0) _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// TODO 检查权限 // 检查权限
if userId > 0 {
err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId)
if err != nil {
return nil, err
}
}
tx := this.NullTx() tx := this.NullTx()
@@ -365,12 +430,17 @@ func (this *HTTPWebService) UpdateHTTPWebRedirectToHTTPS(ctx context.Context, re
// 更改Websocket设置 // 更改Websocket设置
func (this *HTTPWebService) UpdateHTTPWebWebsocket(ctx context.Context, req *pb.UpdateHTTPWebWebsocketRequest) (*pb.RPCSuccess, error) { func (this *HTTPWebService) UpdateHTTPWebWebsocket(ctx context.Context, req *pb.UpdateHTTPWebWebsocketRequest) (*pb.RPCSuccess, error) {
// 校验请求 // 校验请求
_, _, err := this.ValidateAdminAndUser(ctx, 0, 0) _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// TODO 检查权限 if userId > 0 {
err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId)
if err != nil {
return nil, err
}
}
tx := this.NullTx() tx := this.NullTx()
@@ -390,7 +460,10 @@ func (this *HTTPWebService) UpdateHTTPWebRewriteRules(ctx context.Context, req *
} }
if userId > 0 { if userId > 0 {
// TODO 检查用户权限 err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId)
if err != nil {
return nil, err
}
} }
tx := this.NullTx() tx := this.NullTx()
@@ -401,3 +474,66 @@ func (this *HTTPWebService) UpdateHTTPWebRewriteRules(ctx context.Context, req *
} }
return this.Success() return this.Success()
} }
// 更改主机跳转设置
func (this *HTTPWebService) UpdateHTTPWebHostRedirects(ctx context.Context, req *pb.UpdateHTTPWebHostRedirectsRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
if userId > 0 {
err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId)
if err != nil {
return nil, err
}
}
hostRedirects := []*serverconfigs.HTTPHostRedirectConfig{}
if len(req.HostRedirectsJSON) == 0 {
return nil, errors.New("'hostRedirectsJSON' should not be empty")
}
err = json.Unmarshal(req.HostRedirectsJSON, &hostRedirects)
if err != nil {
return nil, err
}
// 校验
for _, redirect := range hostRedirects {
err := redirect.Init()
if err != nil {
return nil, err
}
}
var tx *dbs.Tx
err = models.SharedHTTPWebDAO.UpdateWebHostRedirects(tx, req.WebId, hostRedirects)
if err != nil {
return nil, err
}
return this.Success()
}
// 查找主机跳转设置
func (this *HTTPWebService) FindHTTPWebHostRedirects(ctx context.Context, req *pb.FindHTTPWebHostRedirectsRequest) (*pb.FindHTTPWebHostRedirectsResponse, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
if userId > 0 {
err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId)
if err != nil {
return nil, err
}
}
var tx *dbs.Tx
redirectsJSON, err := models.SharedHTTPWebDAO.FindWebHostRedirects(tx, req.WebId)
if err != nil {
return nil, err
}
return &pb.FindHTTPWebHostRedirectsResponse{HostRedirectsJSON: redirectsJSON}, nil
}