fix: 使用最新版 vitess sqlparser 解析 SQL 语句

解决 xwb1989/sqlparser 不支持 current_timestamp() 的问题
This commit is contained in:
kanzihuang
2023-09-26 22:47:19 +08:00
parent 7544288451
commit b4ddbbd38f
7 changed files with 239 additions and 149 deletions

View File

@@ -1,7 +1,6 @@
package api
import (
"bufio"
"fmt"
"github.com/lib/pq"
"io"
@@ -20,14 +19,13 @@ import (
"mayfly-go/pkg/gormx"
"mayfly-go/pkg/model"
"mayfly-go/pkg/req"
"mayfly-go/pkg/sqlparser"
"mayfly-go/pkg/utils/stringx"
"regexp"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/xwb1989/sqlparser"
)
type Db struct {
@@ -182,19 +180,17 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
rc.ReqParam = fmt.Sprintf("%s -> filename: %s", dbConn.Info.GetLogDesc(), filename)
defer func() {
if err := recover(); err != nil {
var errInfo string
switch t := err.(type) {
case biz.BizError:
errInfo = t.Error()
case *biz.BizError:
errInfo = t.Error()
case string:
errInfo = t
}
if len(errInfo) > 0 {
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo)))
}
var errInfo string
switch t := recover().(type) {
case biz.BizError:
errInfo = t.Error()
case *biz.BizError:
errInfo = t.Error()
case string:
errInfo = t
}
if len(errInfo) > 0 {
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo)))
}
}()
@@ -206,19 +202,6 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
LoginAccount: rc.LoginAccount,
}
defer func() {
if err := recover(); err != nil {
var errInfo string
switch t := err.(type) {
case error:
errInfo = t.Error()
}
if len(errInfo) > 0 {
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo)))
}
}
}()
progressId := uniqueid.IncrementID()
executedStatements := 0
defer ws.SendJsonMsg(rc.LoginAccount.Id, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
@@ -228,10 +211,16 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
Terminated: true,
}).WithCategory(progressCategory))
var parser sqlparser.Parser
if dbConn.Info.Type == entity.DbTypeMysql {
parser = sqlparser.NewMysqlParser(file)
} else {
parser = sqlparser.NewPostgresParser(file)
}
ticker := time.NewTicker(time.Second * 1)
defer ticker.Stop()
sqlScanner := SplitSqls(file)
for sqlScanner.Scan() {
for {
select {
case <-ticker.C:
ws.SendJsonMsg(rc.LoginAccount.Id, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
@@ -242,7 +231,15 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
}).WithCategory(progressCategory))
default:
}
sql := sqlScanner.Text()
err = parser.Next()
if err == io.EOF {
break
}
if err != nil {
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s] 解析SQL失败: [%s]", filename, dbConn.Info.GetLogDesc(), err.Error())))
return
}
sql := parser.Current()
const prefixUse = "use "
if strings.HasPrefix(sql, prefixUse) {
dbNameExec := strings.Trim(sql[len(prefixUse):], " `;\n")
@@ -553,35 +550,3 @@ func getDbName(g *gin.Context) string {
biz.NotEmpty(db, "db不能为空")
return db
}
// 根据;\n切割sql
func SplitSqls(r io.Reader) *bufio.Scanner {
scanner := bufio.NewScanner(r)
re := regexp.MustCompile(`\s*;\s*\n`)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, io.EOF
}
match := re.FindIndex(data)
if match != nil {
// 如果找到了";\n",判断是否为最后一行
if match[1] == len(data) {
// 如果是最后一行,则返回完整的切片
return len(data), data, nil
}
// 否则,返回到";\n"之后,并且包括";\n"本身
return match[1], data[:match[1]], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
return scanner
}