diff --git a/server/go.mod b/server/go.mod index 341230d3..486ca00f 100644 --- a/server/go.mod +++ b/server/go.mod @@ -13,7 +13,7 @@ require ( github.com/go-sql-driver/mysql v1.7.1 github.com/golang-jwt/jwt/v5 v5.0.0 github.com/gorilla/websocket v1.5.0 - github.com/kanzihuang/vitess/go/vt/sqlparser v0.0.0-20231007020222-b91ee5ef3b31 + github.com/kanzihuang/vitess/go/vt/sqlparser v0.0.0-20231014104824-e3b9aa5415a4 github.com/lib/pq v1.10.9 github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230712084735-068dc2aee82d github.com/mojocn/base64Captcha v1.3.5 // 验证码 diff --git a/server/internal/db/api/db.go b/server/internal/db/api/db.go index 30a9edab..4b7732f7 100644 --- a/server/internal/db/api/db.go +++ b/server/internal/db/api/db.go @@ -2,13 +2,10 @@ package api import ( "fmt" - "io" - "mayfly-go/pkg/utils/collx" - "mayfly-go/pkg/utils/uniqueid" - "mayfly-go/pkg/ws" - + "github.com/gin-gonic/gin" + "github.com/kanzihuang/vitess/go/vt/sqlparser" "github.com/lib/pq" - + "io" "mayfly-go/internal/db/api/form" "mayfly-go/internal/db/api/vo" "mayfly-go/internal/db/application" @@ -21,13 +18,13 @@ import ( "mayfly-go/pkg/gormx" "mayfly-go/pkg/model" "mayfly-go/pkg/req" - "mayfly-go/pkg/sqlparser" + "mayfly-go/pkg/utils/collx" "mayfly-go/pkg/utils/stringx" + "mayfly-go/pkg/utils/uniqueid" + "mayfly-go/pkg/ws" "strconv" "strings" "time" - - "github.com/gin-gonic/gin" ) type Db struct { @@ -179,7 +176,7 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) { dbConn := d.DbApp.GetDbConnection(dbId, dbName) biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s") - rc.ReqParam = fmt.Sprintf("%s -> filename: %s", dbConn.Info.GetLogDesc(), filename) + rc.ReqParam = fmt.Sprintf("filename: %s -> %s", filename, dbConn.Info.GetLogDesc()) defer func() { var errInfo string @@ -190,7 +187,7 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) { errInfo = t } if len(errInfo) > 0 { - d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo))) + d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s]执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo))) } }() @@ -211,13 +208,8 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) { Terminated: true, }).WithCategory(progressCategory)) - var parser sqlparser.Parser - if dbConn.Info.Type == entity.DbTypeMysql { - parser = sqlparser.NewMysqlParser(file) - } else { - parser = sqlparser.NewPostgresParser(file) - } - + var sql string + tokenizer := sqlparser.NewReaderTokenizer(file, sqlparser.WithCacheInBuffer()) ticker := time.NewTicker(time.Second * 1) defer ticker.Stop() for { @@ -231,23 +223,29 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) { }).WithCategory(progressCategory)) default: } - err = parser.Next() + sql, err = sqlparser.SplitNext(tokenizer) if err == io.EOF { break } if err != nil { - d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s] 解析SQL失败: [%s]", filename, dbConn.Info.GetLogDesc(), err.Error()))) + d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本解析失败", fmt.Sprintf("[%s][%s] 解析SQL失败: [%s]", filename, dbConn.Info.GetLogDesc(), err.Error()))) return } - sql := parser.Current() const prefixUse = "use " - if strings.HasPrefix(sql, prefixUse) { - dbNameExec := strings.Trim(sql[len(prefixUse):], " `;\n") - if len(dbNameExec) > 0 { - dbConn = d.DbApp.GetDbConnection(dbId, dbNameExec) - biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s") - execReq.DbConn = dbConn + const prefixUSE = "USE " + if strings.HasPrefix(sql, prefixUSE) || strings.HasPrefix(sql, prefixUse) { + var stmt sqlparser.Statement + stmt, err = sqlparser.Parse(sql) + if err != nil { + d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本解析失败", fmt.Sprintf("[%s][%s] 解析SQL失败: [%s]", filename, dbConn.Info.GetLogDesc(), err.Error()))) } + stmtUse, ok := stmt.(*sqlparser.Use) + if !ok { + d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本解析失败", fmt.Sprintf("[%s][%s] 解析SQL失败: [%s]", filename, dbConn.Info.GetLogDesc(), sql))) + } + dbConn = d.DbApp.GetDbConnection(dbId, stmtUse.DBName.String()) + biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s") + execReq.DbConn = dbConn } // 需要记录执行记录 const maxRecordStatements = 64 @@ -264,7 +262,7 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) { } executedStatements++ } - d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.SuccessSysMsg("sql脚本执行成功", fmt.Sprintf("[%s]执行完成 -> %s", filename, dbConn.Info.GetLogDesc()))) + d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.SuccessSysMsg("sql脚本执行成功", fmt.Sprintf("sql脚本执行完成:%s", rc.ReqParam))) } // 数据库dump diff --git a/server/pkg/sqlparser/sqlparser.go b/server/pkg/sqlparser/sqlparser.go deleted file mode 100644 index 1215a5d6..00000000 --- a/server/pkg/sqlparser/sqlparser.go +++ /dev/null @@ -1,99 +0,0 @@ -package sqlparser - -import ( - "bufio" - "github.com/kanzihuang/vitess/go/vt/sqlparser" - "io" - "regexp" -) - -type Parser interface { - Next() error - Current() string -} - -var _ Parser = &MysqlParser{} -var _ Parser = &PostgresParser{} - -type MysqlParser struct { - tokenizer *sqlparser.Tokenizer - statement string -} - -func NewMysqlParser(reader io.Reader) *MysqlParser { - return &MysqlParser{ - tokenizer: sqlparser.NewReaderTokenizer(reader), - } -} - -func (parser *MysqlParser) Next() error { - statement, err := sqlparser.ParseNext(parser.tokenizer) - if err != nil { - parser.statement = "" - return err - } - parser.statement = sqlparser.String(statement) - return nil -} - -func (parser *MysqlParser) Current() string { - return parser.statement -} - -type PostgresParser struct { - scanner *bufio.Scanner - statement string -} - -func NewPostgresParser(reader io.Reader) *PostgresParser { - return &PostgresParser{ - scanner: splitSqls(reader), - } -} - -func (parser *PostgresParser) Next() error { - if !parser.scanner.Scan() { - return io.EOF - } - return nil -} - -func (parser *PostgresParser) Current() string { - return parser.scanner.Text() -} - -// 根据;\n切割sql -func splitSqls(r io.Reader) *bufio.Scanner { - scanner := bufio.NewScanner(r) - re := regexp.MustCompile(`\s*;\s*\n`) - - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, io.EOF - } - - match := re.FindIndex(data) - - if match != nil { - // 如果找到了";\n",判断是否为最后一行 - if match[1] == len(data) { - // 如果是最后一行,则返回完整的切片 - return len(data), data, nil - } - // 否则,返回到";\n"之后,并且包括";\n"本身 - return match[1], data[:match[1]], nil - } - - if atEOF { - return len(data), data, nil - } - - return 0, nil, nil - }) - - return scanner -} - -func SplitStatementToPieces(sql string) ([]string, error) { - return sqlparser.SplitStatementToPieces(sql) -} diff --git a/server/pkg/sqlparser/sqlparser_test.go b/server/pkg/sqlparser/sqlparser_test.go deleted file mode 100644 index 168e611f..00000000 --- a/server/pkg/sqlparser/sqlparser_test.go +++ /dev/null @@ -1,98 +0,0 @@ -package sqlparser - -import ( - "strings" - "testing" - - "github.com/kanzihuang/vitess/go/vt/sqlparser" - "github.com/stretchr/testify/require" -) - -func Test_ParseNext_WithCurrentDate(t *testing.T) { - tests := []struct { - name string - input string - want string - wantXwb1989 string - err string - }{ - { - name: "create table with current_timestamp", - input: "create table tbl (\n\tcreate_at datetime default current_timestamp()\n)", - // xwb1989/sqlparser 不支持 current_timestamp() - wantXwb1989: "create table tbl", - }, - { - name: "create table with current_date", - input: "create table tbl (\n\tcreate_at date default current_date()\n)", - // xwb1989/sqlparser 不支持 current_date() - wantXwb1989: "create table tbl", - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - token := sqlparser.NewReaderTokenizer(strings.NewReader(test.input)) - tree, err := sqlparser.ParseNext(token) - if len(test.err) > 0 { - require.Error(t, err) - require.Contains(t, err.Error(), test.err) - return - } - require.NoError(t, err) - if len(test.want) == 0 { - test.want = test.input - } - require.Equal(t, test.want, sqlparser.String(tree)) - }) - } - // for _, test := range tests { - // t.Run(test.name, func(t *testing.T) { - // token := sqlparser_xwb1989.NewTokenizer(strings.NewReader(test.input)) - // tree, err := sqlparser_xwb1989.ParseNext(token) - // if len(test.err) > 0 { - // require.Error(t, err) - // require.Contains(t, err.Error(), test.err) - // return - // } - // require.NoError(t, err) - // if len(test.want) == 0 { - // test.want = test.input - // } - // require.Equal(t, test.wantXwb1989, sqlparser_xwb1989.String(tree)) - // }) - // } -} - -func Test_SplitSqls(t *testing.T) { - tests := []struct { - name string - input string - want string - }{ - { - name: "create table with current_timestamp", - input: "create table tbl (\n\tcreate_at datetime default current_timestamp()\n)", - }, - { - name: "create table with current_date", - input: "create table tbl (\n\tcreate_at date default current_date()\n)", - }, - { - name: "select with ';\n'", - input: "select 'the first line;\nthe second line;\n'", - // SplitSqls split statements by ';\n' - want: "select 'the first line;\n", - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - scanner := splitSqls(strings.NewReader(test.input)) - require.True(t, scanner.Scan()) - got := scanner.Text() - if len(test.want) == 0 { - test.want = test.input - } - require.Equal(t, test.want, got) - }) - } -}