mirror of
				https://gitee.com/dromara/mayfly-go
				synced 2025-11-04 00:10:25 +08:00 
			
		
		
		
	fix: 使用最新版 vitess sqlparser 解析 SQL 语句
解决 xwb1989/sqlparser 不支持 current_timestamp() 的问题
This commit is contained in:
		@@ -1,7 +1,6 @@
 | 
			
		||||
package api
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/lib/pq"
 | 
			
		||||
	"io"
 | 
			
		||||
@@ -20,14 +19,13 @@ import (
 | 
			
		||||
	"mayfly-go/pkg/gormx"
 | 
			
		||||
	"mayfly-go/pkg/model"
 | 
			
		||||
	"mayfly-go/pkg/req"
 | 
			
		||||
	"mayfly-go/pkg/sqlparser"
 | 
			
		||||
	"mayfly-go/pkg/utils/stringx"
 | 
			
		||||
	"regexp"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/xwb1989/sqlparser"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Db struct {
 | 
			
		||||
@@ -182,19 +180,17 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
 | 
			
		||||
	rc.ReqParam = fmt.Sprintf("%s -> filename: %s", dbConn.Info.GetLogDesc(), filename)
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if err := recover(); err != nil {
 | 
			
		||||
			var errInfo string
 | 
			
		||||
			switch t := err.(type) {
 | 
			
		||||
			case biz.BizError:
 | 
			
		||||
				errInfo = t.Error()
 | 
			
		||||
			case *biz.BizError:
 | 
			
		||||
				errInfo = t.Error()
 | 
			
		||||
			case string:
 | 
			
		||||
				errInfo = t
 | 
			
		||||
			}
 | 
			
		||||
			if len(errInfo) > 0 {
 | 
			
		||||
				d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo)))
 | 
			
		||||
			}
 | 
			
		||||
		var errInfo string
 | 
			
		||||
		switch t := recover().(type) {
 | 
			
		||||
		case biz.BizError:
 | 
			
		||||
			errInfo = t.Error()
 | 
			
		||||
		case *biz.BizError:
 | 
			
		||||
			errInfo = t.Error()
 | 
			
		||||
		case string:
 | 
			
		||||
			errInfo = t
 | 
			
		||||
		}
 | 
			
		||||
		if len(errInfo) > 0 {
 | 
			
		||||
			d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo)))
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
@@ -206,19 +202,6 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
 | 
			
		||||
		LoginAccount: rc.LoginAccount,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if err := recover(); err != nil {
 | 
			
		||||
			var errInfo string
 | 
			
		||||
			switch t := err.(type) {
 | 
			
		||||
			case error:
 | 
			
		||||
				errInfo = t.Error()
 | 
			
		||||
			}
 | 
			
		||||
			if len(errInfo) > 0 {
 | 
			
		||||
				d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo)))
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	progressId := uniqueid.IncrementID()
 | 
			
		||||
	executedStatements := 0
 | 
			
		||||
	defer ws.SendJsonMsg(rc.LoginAccount.Id, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
 | 
			
		||||
@@ -228,10 +211,16 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
 | 
			
		||||
		Terminated:         true,
 | 
			
		||||
	}).WithCategory(progressCategory))
 | 
			
		||||
 | 
			
		||||
	var parser sqlparser.Parser
 | 
			
		||||
	if dbConn.Info.Type == entity.DbTypeMysql {
 | 
			
		||||
		parser = sqlparser.NewMysqlParser(file)
 | 
			
		||||
	} else {
 | 
			
		||||
		parser = sqlparser.NewPostgresParser(file)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ticker := time.NewTicker(time.Second * 1)
 | 
			
		||||
	defer ticker.Stop()
 | 
			
		||||
	sqlScanner := SplitSqls(file)
 | 
			
		||||
	for sqlScanner.Scan() {
 | 
			
		||||
	for {
 | 
			
		||||
		select {
 | 
			
		||||
		case <-ticker.C:
 | 
			
		||||
			ws.SendJsonMsg(rc.LoginAccount.Id, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
 | 
			
		||||
@@ -242,7 +231,15 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
 | 
			
		||||
			}).WithCategory(progressCategory))
 | 
			
		||||
		default:
 | 
			
		||||
		}
 | 
			
		||||
		sql := sqlScanner.Text()
 | 
			
		||||
		err = parser.Next()
 | 
			
		||||
		if err == io.EOF {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s] 解析SQL失败: [%s]", filename, dbConn.Info.GetLogDesc(), err.Error())))
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		sql := parser.Current()
 | 
			
		||||
		const prefixUse = "use "
 | 
			
		||||
		if strings.HasPrefix(sql, prefixUse) {
 | 
			
		||||
			dbNameExec := strings.Trim(sql[len(prefixUse):], " `;\n")
 | 
			
		||||
@@ -553,35 +550,3 @@ func getDbName(g *gin.Context) string {
 | 
			
		||||
	biz.NotEmpty(db, "db不能为空")
 | 
			
		||||
	return db
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 根据;\n切割sql
 | 
			
		||||
func SplitSqls(r io.Reader) *bufio.Scanner {
 | 
			
		||||
	scanner := bufio.NewScanner(r)
 | 
			
		||||
	re := regexp.MustCompile(`\s*;\s*\n`)
 | 
			
		||||
 | 
			
		||||
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
			
		||||
		if atEOF && len(data) == 0 {
 | 
			
		||||
			return 0, nil, io.EOF
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		match := re.FindIndex(data)
 | 
			
		||||
 | 
			
		||||
		if match != nil {
 | 
			
		||||
			// 如果找到了";\n",判断是否为最后一行
 | 
			
		||||
			if match[1] == len(data) {
 | 
			
		||||
				// 如果是最后一行,则返回完整的切片
 | 
			
		||||
				return len(data), data, nil
 | 
			
		||||
			}
 | 
			
		||||
			// 否则,返回到";\n"之后,并且包括";\n"本身
 | 
			
		||||
			return match[1], data[:match[1]], nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if atEOF {
 | 
			
		||||
			return len(data), data, nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return 0, nil, nil
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	return scanner
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user