diff --git a/go.mod b/go.mod index b6e2981..7e45ca1 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,6 @@ require ( github.com/dop251/goja v0.0.0-20210804101310-32956a348b49 github.com/go-ole/go-ole v1.2.4 // indirect github.com/go-yaml/yaml v2.1.0+incompatible - github.com/golang/protobuf v1.5.2 github.com/iwind/TeaGo v0.0.0-20211026123858-7de7a21cad24 github.com/iwind/gofcgi v0.0.0-20210528023741-a92711d45f11 github.com/iwind/gosock v0.0.0-20210722083328-12b2d66abec3 diff --git a/internal/nodes/http_access_log_queue.go b/internal/nodes/http_access_log_queue.go index 8bbb2f7..9328b6f 100644 --- a/internal/nodes/http_access_log_queue.go +++ b/internal/nodes/http_access_log_queue.go @@ -5,7 +5,9 @@ import ( "github.com/TeaOSLab/EdgeNode/internal/remotelogs" "github.com/TeaOSLab/EdgeNode/internal/rpc" "github.com/TeaOSLab/EdgeNode/internal/utils" + "reflect" "strconv" + "strings" "time" ) @@ -100,8 +102,30 @@ Loop: _, err := this.rpcClient.HTTPAccessLogRPC().CreateHTTPAccessLogs(this.rpcClient.Context(), &pb.CreateHTTPAccessLogsRequest{HttpAccessLogs: accessLogs}) if err != nil { + // 是否包含了invalid UTF-8 + if strings.Contains(err.Error(), "string field contains invalid UTF-8") { + for _, accessLog := range accessLogs { + this.toValidUTF8(accessLog) + } + + // 重新提交 + _, err = this.rpcClient.HTTPAccessLogRPC().CreateHTTPAccessLogs(this.rpcClient.Context(), &pb.CreateHTTPAccessLogsRequest{HttpAccessLogs: accessLogs}) + return err + } + return err } return nil } + +func (this *HTTPAccessLogQueue) toValidUTF8(accessLog *pb.HTTPAccessLog) { + var v = reflect.Indirect(reflect.ValueOf(accessLog)) + var countFields = v.NumField() + for i := 0; i < countFields; i++ { + var field = v.Field(i) + if field.Kind() == reflect.String { + field.SetString(strings.ToValidUTF8(field.String(), "")) + } + } +} diff --git a/internal/nodes/http_access_log_queue_test.go b/internal/nodes/http_access_log_queue_test.go new file mode 100644 index 0000000..be0d9f3 --- /dev/null +++ b/internal/nodes/http_access_log_queue_test.go @@ -0,0 +1,122 @@ +// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package nodes + +import ( + "bytes" + "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" + "github.com/TeaOSLab/EdgeNode/internal/rpc" + _ "github.com/iwind/TeaGo/bootstrap" + "google.golang.org/grpc/status" + "reflect" + "runtime" + "strconv" + "strings" + "testing" + "time" +) + +func TestHTTPAccessLogQueue_Push(t *testing.T) { + // 发送到API + client, err := rpc.SharedRPC() + if err != nil { + t.Fatal(err) + } + + var requestId = 1_000_000 + + var utf8Bytes = []byte{} + for i := 0; i < 254; i++ { + utf8Bytes = append(utf8Bytes, uint8(i)) + } + + //bytes = []byte("真不错") + + //t.Log(strings.ToValidUTF8(string(utf8Bytes), "")) + _, err = client.HTTPAccessLogRPC().CreateHTTPAccessLogs(client.Context(), &pb.CreateHTTPAccessLogsRequest{HttpAccessLogs: []*pb.HTTPAccessLog{ + { + ServerId: 23, + RequestId: strconv.FormatInt(time.Now().Unix(), 10) + strconv.Itoa(requestId) + strconv.FormatInt(1, 10), + NodeId: 48, + Host: "www.hello.com", + RequestURI: string(utf8Bytes), + RequestPath: string(utf8Bytes), + Timestamp: time.Now().Unix(), + }, + }}) + if err != nil { + // 这里只是为了重现错误 + t.Logf("%#v, %s", err, err.Error()) + + statusErr, ok := status.FromError(err) + if ok { + t.Logf("%#v", statusErr) + } + return + } + t.Log("ok") +} + +func TestHTTPAccessLogQueue_Push2(t *testing.T) { + var utf8Bytes = []byte{} + for i := 0; i < 254; i++ { + utf8Bytes = append(utf8Bytes, uint8(i)) + } + + var accessLog = &pb.HTTPAccessLog{ + ServerId: 23, + RequestId: strconv.FormatInt(time.Now().Unix(), 10) + strconv.Itoa(1) + strconv.FormatInt(1, 10), + NodeId: 48, + Host: "www.hello.com", + RequestURI: string(utf8Bytes), + RequestPath: string(utf8Bytes), + Timestamp: time.Now().Unix(), + } + var v = reflect.Indirect(reflect.ValueOf(accessLog)) + var countFields = v.NumField() + for i := 0; i < countFields; i++ { + var field = v.Field(i) + if field.Kind() == reflect.String { + field.SetString(strings.ToValidUTF8(field.String(), "")) + } + } + + client, err := rpc.SharedRPC() + if err != nil { + t.Fatal(err) + } + _, err = client.HTTPAccessLogRPC().CreateHTTPAccessLogs(client.Context(), &pb.CreateHTTPAccessLogsRequest{HttpAccessLogs: []*pb.HTTPAccessLog{ + accessLog, + }}) + if err != nil { + t.Fatal(err) + } + t.Log("ok") +} + +func BenchmarkHTTPAccessLogQueue_ToValidUTF8(b *testing.B) { + runtime.GOMAXPROCS(1) + + var utf8Bytes = []byte{} + for i := 0; i < 254; i++ { + utf8Bytes = append(utf8Bytes, uint8(i)) + } + + for i := 0; i < b.N; i++ { + _ = bytes.ToValidUTF8(utf8Bytes, nil) + } +} + +func BenchmarkHTTPAccessLogQueue_ToValidUTF8String(b *testing.B) { + runtime.GOMAXPROCS(1) + + var utf8Bytes = []byte{} + for i := 0; i < 254; i++ { + utf8Bytes = append(utf8Bytes, uint8(i)) + } + + var s = string(utf8Bytes) + for i := 0; i < b.N; i++ { + _ = strings.ToValidUTF8(s, "") + } +} diff --git a/internal/nodes/listener_base_test.go b/internal/nodes/listener_base_test.go index 2b7fe1a..d34a9df 100644 --- a/internal/nodes/listener_base_test.go +++ b/internal/nodes/listener_base_test.go @@ -13,7 +13,7 @@ import ( func TestBaseListener_FindServer(t *testing.T) { sharedNodeConfig = &nodeconfigs.NodeConfig{} - var listener = &BaseListener{namedServers: map[string]*NamedServer{}} + var listener = &BaseListener{} listener.Group = &serverconfigs.ServerAddressGroup{} for i := 0; i < 1_000_000; i++ { var server = &serverconfigs.ServerConfig{ @@ -24,7 +24,7 @@ func TestBaseListener_FindServer(t *testing.T) { }, } _ = server.Init() - listener.Group.Servers = append(listener.Group.Servers, server) + listener.Group.Add(server) } var before = time.Now()