From b4ddbbd38f0b7771cc77fcf67a11bbc10f9370b6 Mon Sep 17 00:00:00 2001 From: kanzihuang Date: Tue, 26 Sep 2023 22:47:19 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BD=BF=E7=94=A8=E6=9C=80=E6=96=B0?= =?UTF-8?q?=E7=89=88=20vitess=20sqlparser=20=E8=A7=A3=E6=9E=90=20SQL=20?= =?UTF-8?q?=E8=AF=AD=E5=8F=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 解决 xwb1989/sqlparser 不支持 current_timestamp() 的问题 --- server/go.mod | 17 +++- server/internal/db/api/db.go | 93 ++++++----------- server/internal/db/api/db_test.go | 35 ------- server/internal/db/api/sqlparser_test.go | 44 --------- server/internal/db/application/db_sql_exec.go | 2 +- server/pkg/sqlparser/sqlparser.go | 99 +++++++++++++++++++ server/pkg/sqlparser/sqlparser_test.go | 98 ++++++++++++++++++ 7 files changed, 239 insertions(+), 149 deletions(-) delete mode 100644 server/internal/db/api/sqlparser_test.go create mode 100644 server/pkg/sqlparser/sqlparser.go create mode 100644 server/pkg/sqlparser/sqlparser_test.go diff --git a/server/go.mod b/server/go.mod index e5cdd97b..5ea46c5b 100644 --- a/server/go.mod +++ b/server/go.mod @@ -13,6 +13,7 @@ require ( github.com/go-sql-driver/mysql v1.7.1 github.com/golang-jwt/jwt/v5 v5.0.0 github.com/gorilla/websocket v1.5.0 + github.com/kanzihuang/vitess/go/vt/sqlparser v0.0.0-20231007020222-b91ee5ef3b31 github.com/lib/pq v1.10.9 github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230712084735-068dc2aee82d github.com/mojocn/base64Captcha v1.3.5 // 验证码 @@ -28,11 +29,13 @@ require ( gopkg.in/yaml.v3 v3.0.1 // gorm gorm.io/driver/mysql v1.5.1 - gorm.io/driver/sqlite v1.5.4 gorm.io/gorm v1.25.4 ) -require github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2 +require ( + github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2 + gorm.io/driver/sqlite v1.5.1 +) require ( github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect @@ -47,6 +50,7 @@ require ( github.com/go-asn1-ber/asn1-ber v1.5.4 // indirect github.com/goccy/go-json v0.10.2 // indirect github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect + github.com/golang/glog v1.0.0 // indirect github.com/golang/protobuf v1.5.3 // indirect github.com/golang/snappy v0.0.4 // indirect github.com/jinzhu/inflection v1.0.0 // indirect @@ -55,15 +59,15 @@ require ( github.com/klauspost/compress v1.16.5 // indirect github.com/klauspost/cpuid/v2 v2.2.4 // indirect github.com/kr/fs v0.1.0 // indirect - github.com/kr/pretty v0.3.1 // indirect github.com/leodido/go-urn v1.2.4 // indirect github.com/mattn/go-isatty v0.0.19 // indirect - github.com/mattn/go-sqlite3 v1.14.17 // indirect + github.com/mattn/go-sqlite3 v1.14.16 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/montanaflynn/stats v0.7.0 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.11 // indirect github.com/xdg-go/pbkdf2 v1.0.0 // indirect @@ -71,12 +75,15 @@ require ( github.com/xdg-go/stringprep v1.0.4 // indirect github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect golang.org/x/arch v0.3.0 // indirect + golang.org/x/exp v0.0.0-20230519143937-03e91628a987 // indirect golang.org/x/image v0.0.0-20220302094943-723b81ca9867 // indirect golang.org/x/net v0.16.0 // indirect golang.org/x/sync v0.1.0 // indirect golang.org/x/sys v0.13.0 // indirect golang.org/x/text v0.13.0 // indirect google.golang.org/appengine v1.6.7 // indirect + google.golang.org/genproto v0.0.0-20230131230820-1c016267d619 // indirect + google.golang.org/grpc v1.52.3 // indirect google.golang.org/protobuf v1.31.0 // indirect - gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect + vitess.io/vitess v0.17.3 // indirect ) diff --git a/server/internal/db/api/db.go b/server/internal/db/api/db.go index 4c820a95..fb13518f 100644 --- a/server/internal/db/api/db.go +++ b/server/internal/db/api/db.go @@ -1,7 +1,6 @@ package api import ( - "bufio" "fmt" "github.com/lib/pq" "io" @@ -20,14 +19,13 @@ import ( "mayfly-go/pkg/gormx" "mayfly-go/pkg/model" "mayfly-go/pkg/req" + "mayfly-go/pkg/sqlparser" "mayfly-go/pkg/utils/stringx" - "regexp" "strconv" "strings" "time" "github.com/gin-gonic/gin" - "github.com/xwb1989/sqlparser" ) type Db struct { @@ -182,19 +180,17 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) { rc.ReqParam = fmt.Sprintf("%s -> filename: %s", dbConn.Info.GetLogDesc(), filename) defer func() { - if err := recover(); err != nil { - var errInfo string - switch t := err.(type) { - case biz.BizError: - errInfo = t.Error() - case *biz.BizError: - errInfo = t.Error() - case string: - errInfo = t - } - if len(errInfo) > 0 { - d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo))) - } + var errInfo string + switch t := recover().(type) { + case biz.BizError: + errInfo = t.Error() + case *biz.BizError: + errInfo = t.Error() + case string: + errInfo = t + } + if len(errInfo) > 0 { + d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo))) } }() @@ -206,19 +202,6 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) { LoginAccount: rc.LoginAccount, } - defer func() { - if err := recover(); err != nil { - var errInfo string - switch t := err.(type) { - case error: - errInfo = t.Error() - } - if len(errInfo) > 0 { - d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo))) - } - } - }() - progressId := uniqueid.IncrementID() executedStatements := 0 defer ws.SendJsonMsg(rc.LoginAccount.Id, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{ @@ -228,10 +211,16 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) { Terminated: true, }).WithCategory(progressCategory)) + var parser sqlparser.Parser + if dbConn.Info.Type == entity.DbTypeMysql { + parser = sqlparser.NewMysqlParser(file) + } else { + parser = sqlparser.NewPostgresParser(file) + } + ticker := time.NewTicker(time.Second * 1) defer ticker.Stop() - sqlScanner := SplitSqls(file) - for sqlScanner.Scan() { + for { select { case <-ticker.C: ws.SendJsonMsg(rc.LoginAccount.Id, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{ @@ -242,7 +231,15 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) { }).WithCategory(progressCategory)) default: } - sql := sqlScanner.Text() + err = parser.Next() + if err == io.EOF { + break + } + if err != nil { + d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s] 解析SQL失败: [%s]", filename, dbConn.Info.GetLogDesc(), err.Error()))) + return + } + sql := parser.Current() const prefixUse = "use " if strings.HasPrefix(sql, prefixUse) { dbNameExec := strings.Trim(sql[len(prefixUse):], " `;\n") @@ -553,35 +550,3 @@ func getDbName(g *gin.Context) string { biz.NotEmpty(db, "db不能为空") return db } - -// 根据;\n切割sql -func SplitSqls(r io.Reader) *bufio.Scanner { - scanner := bufio.NewScanner(r) - re := regexp.MustCompile(`\s*;\s*\n`) - - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, io.EOF - } - - match := re.FindIndex(data) - - if match != nil { - // 如果找到了";\n",判断是否为最后一行 - if match[1] == len(data) { - // 如果是最后一行,则返回完整的切片 - return len(data), data, nil - } - // 否则,返回到";\n"之后,并且包括";\n"本身 - return match[1], data[:match[1]], nil - } - - if atEOF { - return len(data), data, nil - } - - return 0, nil, nil - }) - - return scanner -} diff --git a/server/internal/db/api/db_test.go b/server/internal/db/api/db_test.go index 4f8547ee..ec6b9480 100644 --- a/server/internal/db/api/db_test.go +++ b/server/internal/db/api/db_test.go @@ -3,7 +3,6 @@ package api import ( "github.com/stretchr/testify/require" "mayfly-go/internal/db/domain/entity" - "strings" "testing" ) @@ -54,37 +53,3 @@ func Test_escapeSql(t *testing.T) { }) } } - -func Test_SplitSqls(t *testing.T) { - tests := []struct { - name string - input string - want string - }{ - { - name: "create table with current_timestamp", - input: "create table tbl (\n\tcreate_at datetime default current_timestamp()\n)", - }, - { - name: "create table with current_date", - input: "create table tbl (\n\tcreate_at date default current_date()\n)", - }, - { - name: "select with ';\n'", - input: "select 'the first line;\nthe second line;\n'", - // SplitSqls split statements by ';\n' - want: "select 'the first line;\n", - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - scanner := SplitSqls(strings.NewReader(test.input)) - require.True(t, scanner.Scan()) - got := scanner.Text() - if len(test.want) == 0 { - test.want = test.input - } - require.Equal(t, test.want, got) - }) - } -} diff --git a/server/internal/db/api/sqlparser_test.go b/server/internal/db/api/sqlparser_test.go deleted file mode 100644 index 3c41c878..00000000 --- a/server/internal/db/api/sqlparser_test.go +++ /dev/null @@ -1,44 +0,0 @@ -package api - -import ( - "github.com/stretchr/testify/require" - "github.com/xwb1989/sqlparser" - "strings" - "testing" -) - -func Test_ParseNext_WithCurrentDate(t *testing.T) { - tests := []struct { - name string - input string - want string - wantXwb1989 string - err string - }{ - { - name: "create table with current_timestamp", - input: "create table tbl (\n\tcreate_at datetime default current_timestamp()\n)", - // xwb1989/sqlparser 不支持 current_timestamp() - wantXwb1989: "create table tbl", - }, - { - name: "create table with current_date", - input: "create table tbl (\n\tcreate_at date default current_date()\n)", - // xwb1989/sqlparser 不支持 current_date() - wantXwb1989: "create table tbl", - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - token := sqlparser.NewTokenizer(strings.NewReader(test.input)) - tree, err := sqlparser.ParseNext(token) - if len(test.err) > 0 { - require.Error(t, err) - require.Contains(t, err.Error(), test.err) - return - } - require.NoError(t, err) - require.Equal(t, test.wantXwb1989, sqlparser.String(tree)) - }) - } -} diff --git a/server/internal/db/application/db_sql_exec.go b/server/internal/db/application/db_sql_exec.go index 6f2d62bb..1490bb50 100644 --- a/server/internal/db/application/db_sql_exec.go +++ b/server/internal/db/application/db_sql_exec.go @@ -3,7 +3,7 @@ package application import ( "encoding/json" "fmt" - "github.com/xwb1989/sqlparser" + "github.com/kanzihuang/vitess/go/vt/sqlparser" "mayfly-go/internal/db/config" "mayfly-go/internal/db/domain/entity" "mayfly-go/internal/db/domain/repository" diff --git a/server/pkg/sqlparser/sqlparser.go b/server/pkg/sqlparser/sqlparser.go new file mode 100644 index 00000000..1215a5d6 --- /dev/null +++ b/server/pkg/sqlparser/sqlparser.go @@ -0,0 +1,99 @@ +package sqlparser + +import ( + "bufio" + "github.com/kanzihuang/vitess/go/vt/sqlparser" + "io" + "regexp" +) + +type Parser interface { + Next() error + Current() string +} + +var _ Parser = &MysqlParser{} +var _ Parser = &PostgresParser{} + +type MysqlParser struct { + tokenizer *sqlparser.Tokenizer + statement string +} + +func NewMysqlParser(reader io.Reader) *MysqlParser { + return &MysqlParser{ + tokenizer: sqlparser.NewReaderTokenizer(reader), + } +} + +func (parser *MysqlParser) Next() error { + statement, err := sqlparser.ParseNext(parser.tokenizer) + if err != nil { + parser.statement = "" + return err + } + parser.statement = sqlparser.String(statement) + return nil +} + +func (parser *MysqlParser) Current() string { + return parser.statement +} + +type PostgresParser struct { + scanner *bufio.Scanner + statement string +} + +func NewPostgresParser(reader io.Reader) *PostgresParser { + return &PostgresParser{ + scanner: splitSqls(reader), + } +} + +func (parser *PostgresParser) Next() error { + if !parser.scanner.Scan() { + return io.EOF + } + return nil +} + +func (parser *PostgresParser) Current() string { + return parser.scanner.Text() +} + +// 根据;\n切割sql +func splitSqls(r io.Reader) *bufio.Scanner { + scanner := bufio.NewScanner(r) + re := regexp.MustCompile(`\s*;\s*\n`) + + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, io.EOF + } + + match := re.FindIndex(data) + + if match != nil { + // 如果找到了";\n",判断是否为最后一行 + if match[1] == len(data) { + // 如果是最后一行,则返回完整的切片 + return len(data), data, nil + } + // 否则,返回到";\n"之后,并且包括";\n"本身 + return match[1], data[:match[1]], nil + } + + if atEOF { + return len(data), data, nil + } + + return 0, nil, nil + }) + + return scanner +} + +func SplitStatementToPieces(sql string) ([]string, error) { + return sqlparser.SplitStatementToPieces(sql) +} diff --git a/server/pkg/sqlparser/sqlparser_test.go b/server/pkg/sqlparser/sqlparser_test.go new file mode 100644 index 00000000..1f2b65fd --- /dev/null +++ b/server/pkg/sqlparser/sqlparser_test.go @@ -0,0 +1,98 @@ +package sqlparser + +import ( + "github.com/kanzihuang/vitess/go/vt/sqlparser" + "github.com/stretchr/testify/require" + sqlparser_xwb1989 "github.com/xwb1989/sqlparser" + "strings" + "testing" +) + +func Test_ParseNext_WithCurrentDate(t *testing.T) { + tests := []struct { + name string + input string + want string + wantXwb1989 string + err string + }{ + { + name: "create table with current_timestamp", + input: "create table tbl (\n\tcreate_at datetime default current_timestamp()\n)", + // xwb1989/sqlparser 不支持 current_timestamp() + wantXwb1989: "create table tbl", + }, + { + name: "create table with current_date", + input: "create table tbl (\n\tcreate_at date default current_date()\n)", + // xwb1989/sqlparser 不支持 current_date() + wantXwb1989: "create table tbl", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + token := sqlparser.NewReaderTokenizer(strings.NewReader(test.input)) + tree, err := sqlparser.ParseNext(token) + if len(test.err) > 0 { + require.Error(t, err) + require.Contains(t, err.Error(), test.err) + return + } + require.NoError(t, err) + if len(test.want) == 0 { + test.want = test.input + } + require.Equal(t, test.want, sqlparser.String(tree)) + }) + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + token := sqlparser_xwb1989.NewTokenizer(strings.NewReader(test.input)) + tree, err := sqlparser_xwb1989.ParseNext(token) + if len(test.err) > 0 { + require.Error(t, err) + require.Contains(t, err.Error(), test.err) + return + } + require.NoError(t, err) + if len(test.want) == 0 { + test.want = test.input + } + require.Equal(t, test.wantXwb1989, sqlparser_xwb1989.String(tree)) + }) + } +} + +func Test_SplitSqls(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + { + name: "create table with current_timestamp", + input: "create table tbl (\n\tcreate_at datetime default current_timestamp()\n)", + }, + { + name: "create table with current_date", + input: "create table tbl (\n\tcreate_at date default current_date()\n)", + }, + { + name: "select with ';\n'", + input: "select 'the first line;\nthe second line;\n'", + // SplitSqls split statements by ';\n' + want: "select 'the first line;\n", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + scanner := splitSqls(strings.NewReader(test.input)) + require.True(t, scanner.Scan()) + got := scanner.Text() + if len(test.want) == 0 { + test.want = test.input + } + require.Equal(t, test.want, got) + }) + } +}