实现URL跳转功能

This commit is contained in:
GoEdgeLab
2021-01-10 17:34:35 +08:00
parent 87828af844
commit af5584c768
18 changed files with 381 additions and 62 deletions

View File

@@ -2,14 +2,16 @@ package models
import (
_ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/dbs"
"runtime"
"testing"
)
func TestAPINodeDAO_FindEnabledAPINodeIdWithAddr(t *testing.T) {
dao := NewAPINodeDAO()
var tx *dbs.Tx
{
apiNodeId, err := dao.FindEnabledAPINodeIdWithAddr("http", "127.0.0.1", 123)
apiNodeId, err := dao.FindEnabledAPINodeIdWithAddr(tx, "http", "127.0.0.1", 123)
if err != nil {
t.Fatal(err)
}
@@ -17,7 +19,7 @@ func TestAPINodeDAO_FindEnabledAPINodeIdWithAddr(t *testing.T) {
}
{
apiNodeId, err := dao.FindEnabledAPINodeIdWithAddr("http", "127.0.0.1", 8003)
apiNodeId, err := dao.FindEnabledAPINodeIdWithAddr(tx, "http", "127.0.0.1", 8003)
if err != nil {
t.Fatal(err)
}

View File

@@ -2,33 +2,36 @@ package models
import (
_ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/dbs"
"testing"
)
func TestDNSDomainDAO_ExistDomainRecord(t *testing.T) {
var tx *dbs.Tx
{
b, err := NewDNSDomainDAO().ExistDomainRecord(1, "mycluster", "A", "", "")
b, err := NewDNSDomainDAO().ExistDomainRecord(tx, 1, "mycluster", "A", "", "")
if err != nil {
t.Fatal(err)
}
t.Log(b)
}
{
b, err := NewDNSDomainDAO().ExistDomainRecord(2, "mycluster", "A", "", "")
b, err := NewDNSDomainDAO().ExistDomainRecord(tx, 2, "mycluster", "A", "", "")
if err != nil {
t.Fatal(err)
}
t.Log(b)
}
{
b, err := NewDNSDomainDAO().ExistDomainRecord(2, "mycluster", "MX", "", "")
b, err := NewDNSDomainDAO().ExistDomainRecord(tx, 2, "mycluster", "MX", "", "")
if err != nil {
t.Fatal(err)
}
t.Log(b)
}
{
b, err := NewDNSDomainDAO().ExistDomainRecord(2, "mycluster123", "A", "", "")
b, err := NewDNSDomainDAO().ExistDomainRecord(tx, 2, "mycluster123", "A", "", "")
if err != nil {
t.Fatal(err)
}

View File

@@ -4,12 +4,15 @@ import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
_ "github.com/go-sql-driver/mysql"
_ "github.com/iwind/TeaGo/bootstrap"
"github.com/iwind/TeaGo/dbs"
timeutil "github.com/iwind/TeaGo/utils/time"
"testing"
"time"
)
func TestCreateHTTPAccessLogs(t *testing.T) {
var tx *dbs.Tx
err := NewDBNodeInitializer().loop()
if err != nil {
t.Fatal(err)
@@ -23,7 +26,7 @@ func TestCreateHTTPAccessLogs(t *testing.T) {
}
dao := randomAccessLogDAO()
t.Log("dao:", dao)
err = SharedHTTPAccessLogDAO.CreateHTTPAccessLogsWithDAO(dao, []*pb.HTTPAccessLog{accessLog})
err = SharedHTTPAccessLogDAO.CreateHTTPAccessLogsWithDAO(tx, dao, []*pb.HTTPAccessLog{accessLog})
if err != nil {
t.Fatal(err)
}
@@ -31,12 +34,14 @@ func TestCreateHTTPAccessLogs(t *testing.T) {
}
func TestHTTPAccessLogDAO_ListAccessLogs(t *testing.T) {
var tx *dbs.Tx
err := NewDBNodeInitializer().loop()
if err != nil {
t.Fatal(err)
}
accessLogs, requestId, hasMore, err := SharedHTTPAccessLogDAO.ListAccessLogs("", 10, timeutil.Format("Ymd"), 0, false, false, 0, 0, 0)
accessLogs, requestId, hasMore, err := SharedHTTPAccessLogDAO.ListAccessLogs(tx, "", 10, timeutil.Format("Ymd"), 0, false, false, 0, 0, 0)
if err != nil {
t.Fatal(err)
}
@@ -51,6 +56,8 @@ func TestHTTPAccessLogDAO_ListAccessLogs(t *testing.T) {
}
func TestHTTPAccessLogDAO_ListAccessLogs_Page(t *testing.T) {
var tx *dbs.Tx
err := NewDBNodeInitializer().loop()
if err != nil {
t.Fatal(err)
@@ -61,7 +68,7 @@ func TestHTTPAccessLogDAO_ListAccessLogs_Page(t *testing.T) {
times := 0 // 防止循环次数太多
for {
before := time.Now()
accessLogs, requestId, hasMore, err := SharedHTTPAccessLogDAO.ListAccessLogs(lastRequestId, 2, timeutil.Format("Ymd"), 0, false, false, 0, 0, 0)
accessLogs, requestId, hasMore, err := SharedHTTPAccessLogDAO.ListAccessLogs(tx, lastRequestId, 2, timeutil.Format("Ymd"), 0, false, false, 0, 0, 0)
cost := time.Since(before).Seconds()
if err != nil {
t.Fatal(err)
@@ -84,13 +91,15 @@ func TestHTTPAccessLogDAO_ListAccessLogs_Page(t *testing.T) {
}
func TestHTTPAccessLogDAO_ListAccessLogs_Reverse(t *testing.T) {
var tx *dbs.Tx
err := NewDBNodeInitializer().loop()
if err != nil {
t.Fatal(err)
}
before := time.Now()
accessLogs, requestId, hasMore, err := SharedHTTPAccessLogDAO.ListAccessLogs("16023261176446590001000000000000003500000004", 2, timeutil.Format("Ymd"), 0, true, false, 0, 0, 0)
accessLogs, requestId, hasMore, err := SharedHTTPAccessLogDAO.ListAccessLogs(tx, "16023261176446590001000000000000003500000004", 2, timeutil.Format("Ymd"), 0, true, false, 0, 0, 0)
cost := time.Since(before).Seconds()
if err != nil {
t.Fatal(err)
@@ -103,6 +112,8 @@ func TestHTTPAccessLogDAO_ListAccessLogs_Reverse(t *testing.T) {
}
func TestHTTPAccessLogDAO_ListAccessLogs_Page_NotExists(t *testing.T) {
var tx *dbs.Tx
err := NewDBNodeInitializer().loop()
if err != nil {
t.Fatal(err)
@@ -113,7 +124,7 @@ func TestHTTPAccessLogDAO_ListAccessLogs_Page_NotExists(t *testing.T) {
times := 0 // 防止循环次数太多
for {
before := time.Now()
accessLogs, requestId, hasMore, err := SharedHTTPAccessLogDAO.ListAccessLogs(lastRequestId, 2, timeutil.Format("Ymd", time.Now().AddDate(0, 0, 1)), 0, false, false, 0, 0, 0)
accessLogs, requestId, hasMore, err := SharedHTTPAccessLogDAO.ListAccessLogs(tx, lastRequestId, 2, timeutil.Format("Ymd", time.Now().AddDate(0, 0, 1)), 0, false, false, 0, 0, 0)
cost := time.Since(before).Seconds()
if err != nil {
t.Fatal(err)

View File

@@ -322,6 +322,16 @@ func (this *HTTPWebDAO) ComposeWebConfig(tx *dbs.Tx, webId int64) (*serverconfig
}
}
// 主机跳转
if IsNotNull(web.HostRedirects) {
redirects := []*serverconfigs.HTTPHostRedirectConfig{}
err = json.Unmarshal([]byte(web.HostRedirects), &redirects)
if err != nil {
return nil, err
}
config.HostRedirects = redirects
}
return config, nil
}
@@ -618,3 +628,79 @@ func (this *HTTPWebDAO) FindEnabledWebIdWithLocationId(tx *dbs.Tx, locationId in
Reuse(false). // 由于我们在JSON_CONTAINS()直接使用了变量,所以不能重用
FindInt64Col(0)
}
// 查找使用此Web的Server
func (this *HTTPWebDAO) FindWebServerId(tx *dbs.Tx, webId int64) (serverId int64, err error) {
if webId <= 0 {
return 0, nil
}
serverId, err = SharedServerDAO.FindEnabledServerIdWithWebId(tx, webId)
if err != nil {
return
}
if serverId > 0 {
return
}
// web在Location中的情况
locationId, err := SharedHTTPLocationDAO.FindEnabledLocationIdWithWebId(tx, webId)
if err != nil {
return 0, err
}
if locationId == 0 {
return
}
webId, err = this.FindEnabledWebIdWithLocationId(tx, locationId)
if err != nil {
return
}
if webId <= 0 {
return
}
// 第二轮查找
return this.FindWebServerId(tx, webId)
}
// 检查用户权限
func (this *HTTPWebDAO) CheckUserWeb(tx *dbs.Tx, userId int64, webId int64) error {
serverId, err := this.FindWebServerId(tx, webId)
if err != nil {
return err
}
if serverId == 0 {
return ErrNotFound
}
return SharedServerDAO.CheckUserServer(tx, serverId, userId)
}
// 设置主机跳转
func (this *HTTPWebDAO) UpdateWebHostRedirects(tx *dbs.Tx, webId int64, hostRedirects []*serverconfigs.HTTPHostRedirectConfig) error {
if webId <= 0 {
return errors.New("invalid ")
}
if hostRedirects == nil {
hostRedirects = []*serverconfigs.HTTPHostRedirectConfig{}
}
hostRedirectsJSON, err := json.Marshal(hostRedirects)
if err != nil {
return err
}
_, err = this.Query(tx).
Pk(webId).
Set("hostRedirects", hostRedirectsJSON).
Update()
return err
}
// 查找主机跳转
func (this *HTTPWebDAO) FindWebHostRedirects(tx *dbs.Tx, webId int64) ([]byte, error) {
col, err := this.Query(tx).
Pk(webId).
Result("hostRedirects").
FindStringCol("")
if err != nil {
return nil, err
}
return []byte(col), nil
}

View File

@@ -7,15 +7,16 @@ import (
)
func TestHTTPWebDAO_UpdateWebShutdown(t *testing.T) {
var tx *dbs.Tx
{
err := SharedHTTPWebDAO.UpdateWebShutdown(1, []byte("{}"))
err := SharedHTTPWebDAO.UpdateWebShutdown(tx, 1, []byte("{}"))
if err != nil {
t.Fatal(err)
}
}
{
err := SharedHTTPWebDAO.UpdateWebShutdown(1, nil)
err := SharedHTTPWebDAO.UpdateWebShutdown(tx, 1, nil)
if err != nil {
t.Fatal(err)
}
@@ -27,15 +28,50 @@ func TestHTTPWebDAO_UpdateWebShutdown(t *testing.T) {
func TestHTTPWebDAO_FindAllWebIdsWithHTTPFirewallPolicyId(t *testing.T) {
dbs.NotifyReady()
webIds, err := SharedHTTPWebDAO.FindAllWebIdsWithHTTPFirewallPolicyId(9)
var tx *dbs.Tx
webIds, err := SharedHTTPWebDAO.FindAllWebIdsWithHTTPFirewallPolicyId(tx, 9)
if err != nil {
t.Fatal(err)
}
t.Log("webIds:", webIds)
count, err := SharedServerDAO.CountEnabledServersWithWebIds(webIds)
count, err := SharedServerDAO.CountEnabledServersWithWebIds(tx, webIds)
if err != nil {
t.Fatal(err)
}
t.Log("count:", count)
}
func TestHTTPWebDAO_FindWebServerId(t *testing.T) {
dbs.NotifyReady()
var tx *dbs.Tx
// server
{
serverId, err := SharedHTTPWebDAO.FindWebServerId(tx, 60)
if err != nil {
t.Fatal(err)
}
t.Log("serverId:", serverId)
}
// location
{
serverId, err := SharedHTTPWebDAO.FindWebServerId(tx, 45)
if err != nil {
t.Fatal(err)
}
t.Log("serverId:", serverId)
}
{
serverId, err := SharedHTTPWebDAO.FindWebServerId(tx, 100)
if err != nil {
t.Fatal(err)
}
t.Log("serverId:", serverId)
}
}

View File

@@ -26,6 +26,7 @@ type HTTPWeb struct {
Locations string `field:"locations"` // 路径规则配置
Websocket string `field:"websocket"` // Websocket设置
RewriteRules string `field:"rewriteRules"` // 重写规则配置
HostRedirects string `field:"hostRedirects"` // 域名跳转
}
type HTTPWebOperator struct {
@@ -53,6 +54,7 @@ type HTTPWebOperator struct {
Locations interface{} // 路径规则配置
Websocket interface{} // Websocket设置
RewriteRules interface{} // 重写规则配置
HostRedirects interface{} // 域名跳转
}
func NewHTTPWebOperator() *HTTPWebOperator {

View File

@@ -10,8 +10,10 @@ import (
func TestIPListDAO_IncreaseVersion(t *testing.T) {
dbs.NotifyReady()
var tx *dbs.Tx
dao := NewIPListDAO()
version, err := dao.IncreaseVersion()
version, err := dao.IncreaseVersion(tx)
if err != nil {
t.Fatal(err)
}
@@ -23,8 +25,10 @@ func BenchmarkIPListDAO_IncreaseVersion(b *testing.B) {
dbs.NotifyReady()
var tx *dbs.Tx
dao := NewIPListDAO()
for i := 0; i < b.N; i++ {
_, _ = dao.IncreaseVersion()
_, _ = dao.IncreaseVersion(tx)
}
}

View File

@@ -2,13 +2,16 @@ package models
import (
_ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/dbs"
"testing"
"time"
)
func TestMessageDAO_CreateClusterMessage(t *testing.T) {
var tx *dbs.Tx
dao := NewMessageDAO()
err := dao.CreateClusterMessage(1, "test", "error", "123", []byte("456"))
err := dao.CreateClusterMessage(tx, 1, "test", "error", "123", []byte("456"))
if err != nil {
t.Fatal(err)
}
@@ -16,8 +19,10 @@ func TestMessageDAO_CreateClusterMessage(t *testing.T) {
}
func TestMessageDAO_DeleteMessagesBeforeDay(t *testing.T) {
var tx *dbs.Tx
dao := NewMessageDAO()
err := dao.DeleteMessagesBeforeDay(time.Now())
err := dao.DeleteMessagesBeforeDay(tx, time.Now())
if err != nil {
t.Fatal(err)
}

View File

@@ -7,7 +7,8 @@ import (
)
func TestNodeDAO_FindAllNodeIdsMatch(t *testing.T) {
nodeIds, err := SharedNodeDAO.FindAllNodeIdsMatch(1)
var tx *dbs.Tx
nodeIds, err := SharedNodeDAO.FindAllNodeIdsMatch(tx, 1)
if err != nil {
t.Fatal(err)
}
@@ -15,7 +16,8 @@ func TestNodeDAO_FindAllNodeIdsMatch(t *testing.T) {
}
func TestNodeDAO_FindChangedClusterIds(t *testing.T) {
clusterIds, err := SharedNodeDAO.FindChangedClusterIds()
var tx *dbs.Tx
clusterIds, err := SharedNodeDAO.FindChangedClusterIds(tx)
if err != nil {
t.Fatal(err)
}
@@ -24,7 +26,8 @@ func TestNodeDAO_FindChangedClusterIds(t *testing.T) {
func TestNodeDAO_UpdateNodeUp(t *testing.T) {
dbs.NotifyReady()
isChanged, err := SharedNodeDAO.UpdateNodeUp(57, false, 3, 3)
var tx *dbs.Tx
isChanged, err := SharedNodeDAO.UpdateNodeUp(tx, 57, false, 3, 3)
if err != nil {
t.Fatal(err)
}

View File

@@ -2,11 +2,13 @@ package models
import (
_ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/dbs"
"testing"
)
func TestOriginServerDAO_ComposeOriginConfig(t *testing.T) {
config, err := SharedOriginDAO.ComposeOriginConfig(1)
var tx *dbs.Tx
config, err := SharedOriginDAO.ComposeOriginConfig(tx, 1)
if err != nil {
t.Fatal(err)
}

View File

@@ -2,11 +2,13 @@ package models
import (
_ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/dbs"
"testing"
)
func TestReverseProxyDAO_ComposeReverseProxyConfig(t *testing.T) {
config, err := SharedReverseProxyDAO.ComposeReverseProxyConfig(1)
var tx *dbs.Tx
config, err := SharedReverseProxyDAO.ComposeReverseProxyConfig(tx, 1)
if err != nil {
t.Fatal(err)
}

View File

@@ -9,6 +9,7 @@ import (
)
func TestServerDailyStatDAO_SaveStats(t *testing.T) {
var tx *dbs.Tx
stats := []*pb.ServerDailyStat{
{
ServerId: 1,
@@ -17,7 +18,7 @@ func TestServerDailyStatDAO_SaveStats(t *testing.T) {
CreatedAt: 1607671488,
},
}
err := NewServerDailyStatDAO().SaveStats(stats)
err := NewServerDailyStatDAO().SaveStats(tx, stats)
if err != nil {
t.Fatal(err)
}
@@ -25,6 +26,7 @@ func TestServerDailyStatDAO_SaveStats(t *testing.T) {
}
func TestServerDailyStatDAO_SaveStats2(t *testing.T) {
var tx *dbs.Tx
stats := []*pb.ServerDailyStat{
{
ServerId: 1,
@@ -33,7 +35,7 @@ func TestServerDailyStatDAO_SaveStats2(t *testing.T) {
CreatedAt: 1607671488,
},
}
err := NewServerDailyStatDAO().SaveStats(stats)
err := NewServerDailyStatDAO().SaveStats(tx, stats)
if err != nil {
t.Fatal(err)
}
@@ -42,7 +44,8 @@ func TestServerDailyStatDAO_SaveStats2(t *testing.T) {
func TestServerDailyStatDAO_SumUserMonthly(t *testing.T) {
dbs.NotifyReady()
bytes, err := NewServerDailyStatDAO().SumUserMonthly(1, 1, timeutil.Format("Ym"))
var tx *dbs.Tx
bytes, err := NewServerDailyStatDAO().SumUserMonthly(tx, 1, 1, timeutil.Format("Ym"))
if err != nil {
t.Fatal(err)
}

View File

@@ -1077,6 +1077,18 @@ func (this *ServerDAO) FindAllEnabledServersWithUserId(tx *dbs.Tx, userId int64)
return
}
// 根据WebId查找ServerId
func (this *ServerDAO) FindEnabledServerIdWithWebId(tx *dbs.Tx, webId int64) (serverId int64, err error) {
if webId <= 0 {
return 0, nil
}
return this.Query(tx).
State(ServerStateEnabled).
Attr("webId", webId).
ResultPk().
FindInt64Col(0)
}
// 生成DNS Name
func (this *ServerDAO) genDNSName(tx *dbs.Tx) (string, error) {
for {

View File

@@ -11,7 +11,8 @@ import (
func TestServerDAO_ComposeServerConfig(t *testing.T) {
dbs.NotifyReady()
config, err := SharedServerDAO.ComposeServerConfig(1)
var tx *dbs.Tx
config, err := SharedServerDAO.ComposeServerConfig(tx, 1)
if err != nil {
t.Fatal(err)
}
@@ -20,7 +21,8 @@ func TestServerDAO_ComposeServerConfig(t *testing.T) {
func TestServerDAO_ComposeServerConfig_AliasServerNames(t *testing.T) {
dbs.NotifyReady()
config, err := SharedServerDAO.ComposeServerConfig(14)
var tx *dbs.Tx
config, err := SharedServerDAO.ComposeServerConfig(tx, 14)
if err != nil {
t.Fatal(err)
}
@@ -29,8 +31,8 @@ func TestServerDAO_ComposeServerConfig_AliasServerNames(t *testing.T) {
func TestServerDAO_UpdateServerConfig(t *testing.T) {
dbs.NotifyReady()
config, err := SharedServerDAO.ComposeServerConfig(1)
var tx *dbs.Tx
config, err := SharedServerDAO.ComposeServerConfig(tx, 1)
if err != nil {
t.Fatal(err)
}
@@ -39,7 +41,7 @@ func TestServerDAO_UpdateServerConfig(t *testing.T) {
if err != nil {
t.Fatal(err)
}
_, err = SharedServerDAO.UpdateServerConfig(1, configJSON, false)
_, err = SharedServerDAO.UpdateServerConfig(tx, 1, configJSON, false)
if err != nil {
t.Fatal(err)
}
@@ -58,7 +60,8 @@ func TestNewServerDAO_md5(t *testing.T) {
func TestServerDAO_genDNSName(t *testing.T) {
dbs.NotifyReady()
dnsName, err := SharedServerDAO.genDNSName()
var tx *dbs.Tx
dnsName, err := SharedServerDAO.genDNSName(tx)
if err != nil {
t.Fatal(err)
}
@@ -67,8 +70,8 @@ func TestServerDAO_genDNSName(t *testing.T) {
func TestServerDAO_FindAllServerDNSNamesWithDNSDomainId(t *testing.T) {
dbs.NotifyReady()
dnsNames, err := SharedServerDAO.FindAllServerDNSNamesWithDNSDomainId(2)
var tx *dbs.Tx
dnsNames, err := SharedServerDAO.FindAllServerDNSNamesWithDNSDomainId(tx, 2)
if err != nil {
t.Fatal(err)
}

View File

@@ -2,18 +2,21 @@ package models
import (
_ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/dbs"
"testing"
)
func TestSysLockerDAO_Lock(t *testing.T) {
isOk, err := SharedSysLockerDAO.Lock("test", 600)
var tx *dbs.Tx
isOk, err := SharedSysLockerDAO.Lock(tx, "test", 600)
if err != nil {
t.Fatal(err)
}
t.Log(isOk)
if isOk {
err = SharedSysLockerDAO.Unlock("test")
err = SharedSysLockerDAO.Unlock(tx, "test")
if err != nil {
t.Fatal(err)
}

View File

@@ -2,16 +2,18 @@ package models
import (
_ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/dbs"
"testing"
)
func TestSysSettingDAO_UpdateSetting(t *testing.T) {
err := NewSysSettingDAO().UpdateSetting("hello", []byte(`"world"`))
var tx *dbs.Tx
err := NewSysSettingDAO().UpdateSetting(tx, "hello", []byte(`"world"`))
if err != nil {
t.Fatal(err)
}
value, err := NewSysSettingDAO().ReadSetting("hello")
value, err := NewSysSettingDAO().ReadSetting(tx, "hello")
if err != nil {
t.Fatal(err)
}
@@ -19,12 +21,13 @@ func TestSysSettingDAO_UpdateSetting(t *testing.T) {
}
func TestSysSettingDAO_UpdateSetting_Args(t *testing.T) {
err := NewSysSettingDAO().UpdateSetting("hello %d", []byte(`"world 123"`), 123)
var tx *dbs.Tx
err := NewSysSettingDAO().UpdateSetting(tx, "hello %d", []byte(`"world 123"`), 123)
if err != nil {
t.Fatal(err)
}
value, err := NewSysSettingDAO().ReadSetting("hello %d", 123)
value, err := NewSysSettingDAO().ReadSetting(tx, "hello %d", 123)
if err != nil {
t.Fatal(err)
}
@@ -32,7 +35,8 @@ func TestSysSettingDAO_UpdateSetting_Args(t *testing.T) {
}
func TestSysSettingDAO_CompareInt64Setting(t *testing.T) {
i, err := NewSysSettingDAO().CompareInt64Setting("int64", 1024)
var tx *dbs.Tx
i, err := NewSysSettingDAO().CompareInt64Setting(tx, "int64", 1024)
if err != nil {
t.Fatal(err)
}

View File

@@ -10,7 +10,9 @@ import (
func TestUserBillDAO_GenerateBills(t *testing.T) {
dbs.NotifyReady()
err := SharedUserBillDAO.GenerateBills(timeutil.Format("Ym"))
var tx *dbs.Tx
err := SharedUserBillDAO.GenerateBills(tx, timeutil.Format("Ym"))
if err != nil {
t.Fatal(err)
}

View File

@@ -4,7 +4,10 @@ import (
"context"
"encoding/json"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/iwind/TeaGo/dbs"
)
type HTTPWebService struct {
@@ -38,7 +41,11 @@ func (this *HTTPWebService) FindEnabledHTTPWeb(ctx context.Context, req *pb.Find
}
if userId > 0 {
// TODO 检查用户权限
// 检查用户权限
err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId)
if err != nil {
return nil, err
}
}
tx := this.NullTx()
@@ -67,7 +74,11 @@ func (this *HTTPWebService) FindEnabledHTTPWebConfig(ctx context.Context, req *p
}
if userId > 0 {
// TODO 检查用户权限
// 检查用户权限
err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId)
if err != nil {
return nil, err
}
}
tx := this.NullTx()
@@ -93,7 +104,11 @@ func (this *HTTPWebService) UpdateHTTPWeb(ctx context.Context, req *pb.UpdateHTT
}
if userId > 0 {
// TODO 检查用户权限
// 检查用户权限
err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId)
if err != nil {
return nil, err
}
}
tx := this.NullTx()
@@ -115,7 +130,11 @@ func (this *HTTPWebService) UpdateHTTPWebGzip(ctx context.Context, req *pb.Updat
}
if userId > 0 {
// TODO 检查用户权限
// 检查用户权限
err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId)
if err != nil {
return nil, err
}
}
tx := this.NullTx()
@@ -137,7 +156,11 @@ func (this *HTTPWebService) UpdateHTTPWebCharset(ctx context.Context, req *pb.Up
}
if userId > 0 {
// TODO 检查用户权限
// 检查用户权限
err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId)
if err != nil {
return nil, err
}
}
tx := this.NullTx()
@@ -158,7 +181,11 @@ func (this *HTTPWebService) UpdateHTTPWebRequestHeader(ctx context.Context, req
}
if userId > 0 {
// TODO 检查用户权限
// 检查用户权限
err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId)
if err != nil {
return nil, err
}
}
tx := this.NullTx()
@@ -180,7 +207,11 @@ func (this *HTTPWebService) UpdateHTTPWebResponseHeader(ctx context.Context, req
}
if userId > 0 {
// TODO 检查用户权限
// 检查用户权限
err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId)
if err != nil {
return nil, err
}
}
tx := this.NullTx()
@@ -202,7 +233,11 @@ func (this *HTTPWebService) UpdateHTTPWebShutdown(ctx context.Context, req *pb.U
}
if userId > 0 {
// TODO 检查用户权限
// 检查用户权限
err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId)
if err != nil {
return nil, err
}
}
tx := this.NullTx()
@@ -223,7 +258,11 @@ func (this *HTTPWebService) UpdateHTTPWebPages(ctx context.Context, req *pb.Upda
}
if userId > 0 {
// TODO 检查用户权限
// 检查用户权限
err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId)
if err != nil {
return nil, err
}
}
tx := this.NullTx()
@@ -244,7 +283,11 @@ func (this *HTTPWebService) UpdateHTTPWebAccessLog(ctx context.Context, req *pb.
}
if userId > 0 {
// TODO 检查用户权限
// 检查用户权限
err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId)
if err != nil {
return nil, err
}
}
tx := this.NullTx()
@@ -265,7 +308,11 @@ func (this *HTTPWebService) UpdateHTTPWebStat(ctx context.Context, req *pb.Updat
}
if userId > 0 {
// TODO 检查用户权限
// 检查用户权限
err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId)
if err != nil {
return nil, err
}
}
tx := this.NullTx()
@@ -286,7 +333,11 @@ func (this *HTTPWebService) UpdateHTTPWebCache(ctx context.Context, req *pb.Upda
}
if userId > 0 {
// TODO 检查权限
// 检查权限
err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId)
if err != nil {
return nil, err
}
}
tx := this.NullTx()
@@ -308,7 +359,11 @@ func (this *HTTPWebService) UpdateHTTPWebFirewall(ctx context.Context, req *pb.U
}
if userId > 0 {
// TODO 检查用户权限
// 检查用户权限
err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId)
if err != nil {
return nil, err
}
}
tx := this.NullTx()
@@ -330,7 +385,11 @@ func (this *HTTPWebService) UpdateHTTPWebLocations(ctx context.Context, req *pb.
}
if userId > 0 {
// TODO 检查用户权限
// 检查用户权限
err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId)
if err != nil {
return nil, err
}
}
tx := this.NullTx()
@@ -346,12 +405,18 @@ func (this *HTTPWebService) UpdateHTTPWebLocations(ctx context.Context, req *pb.
// 更改跳转到HTTPS设置
func (this *HTTPWebService) UpdateHTTPWebRedirectToHTTPS(ctx context.Context, req *pb.UpdateHTTPWebRedirectToHTTPSRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, _, err := this.ValidateAdminAndUser(ctx, 0, 0)
_, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
// TODO 检查权限
// 检查权限
if userId > 0 {
err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId)
if err != nil {
return nil, err
}
}
tx := this.NullTx()
@@ -365,12 +430,17 @@ func (this *HTTPWebService) UpdateHTTPWebRedirectToHTTPS(ctx context.Context, re
// 更改Websocket设置
func (this *HTTPWebService) UpdateHTTPWebWebsocket(ctx context.Context, req *pb.UpdateHTTPWebWebsocketRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, _, err := this.ValidateAdminAndUser(ctx, 0, 0)
_, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
// TODO 检查权限
if userId > 0 {
err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId)
if err != nil {
return nil, err
}
}
tx := this.NullTx()
@@ -390,7 +460,10 @@ func (this *HTTPWebService) UpdateHTTPWebRewriteRules(ctx context.Context, req *
}
if userId > 0 {
// TODO 检查用户权限
err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId)
if err != nil {
return nil, err
}
}
tx := this.NullTx()
@@ -401,3 +474,66 @@ func (this *HTTPWebService) UpdateHTTPWebRewriteRules(ctx context.Context, req *
}
return this.Success()
}
// 更改主机跳转设置
func (this *HTTPWebService) UpdateHTTPWebHostRedirects(ctx context.Context, req *pb.UpdateHTTPWebHostRedirectsRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
if userId > 0 {
err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId)
if err != nil {
return nil, err
}
}
hostRedirects := []*serverconfigs.HTTPHostRedirectConfig{}
if len(req.HostRedirectsJSON) == 0 {
return nil, errors.New("'hostRedirectsJSON' should not be empty")
}
err = json.Unmarshal(req.HostRedirectsJSON, &hostRedirects)
if err != nil {
return nil, err
}
// 校验
for _, redirect := range hostRedirects {
err := redirect.Init()
if err != nil {
return nil, err
}
}
var tx *dbs.Tx
err = models.SharedHTTPWebDAO.UpdateWebHostRedirects(tx, req.WebId, hostRedirects)
if err != nil {
return nil, err
}
return this.Success()
}
// 查找主机跳转设置
func (this *HTTPWebService) FindHTTPWebHostRedirects(ctx context.Context, req *pb.FindHTTPWebHostRedirectsRequest) (*pb.FindHTTPWebHostRedirectsResponse, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
if userId > 0 {
err = models.SharedHTTPWebDAO.CheckUserWeb(nil, userId, req.WebId)
if err != nil {
return nil, err
}
}
var tx *dbs.Tx
redirectsJSON, err := models.SharedHTTPWebDAO.FindWebHostRedirects(tx, req.WebId)
if err != nil {
return nil, err
}
return &pb.FindHTTPWebHostRedirectsResponse{HostRedirectsJSON: redirectsJSON}, nil
}