mirror of
				https://gitee.com/dromara/mayfly-go
				synced 2025-11-04 08:20:25 +08:00 
			
		
		
		
	fix: 使用最新版 vitess sqlparser 解析 SQL 语句
解决 xwb1989/sqlparser 不支持 current_timestamp() 的问题
This commit is contained in:
		@@ -13,6 +13,7 @@ require (
 | 
				
			|||||||
	github.com/go-sql-driver/mysql v1.7.1
 | 
						github.com/go-sql-driver/mysql v1.7.1
 | 
				
			||||||
	github.com/golang-jwt/jwt/v5 v5.0.0
 | 
						github.com/golang-jwt/jwt/v5 v5.0.0
 | 
				
			||||||
	github.com/gorilla/websocket v1.5.0
 | 
						github.com/gorilla/websocket v1.5.0
 | 
				
			||||||
 | 
						github.com/kanzihuang/vitess/go/vt/sqlparser v0.0.0-20231007020222-b91ee5ef3b31
 | 
				
			||||||
	github.com/lib/pq v1.10.9
 | 
						github.com/lib/pq v1.10.9
 | 
				
			||||||
	github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230712084735-068dc2aee82d
 | 
						github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230712084735-068dc2aee82d
 | 
				
			||||||
	github.com/mojocn/base64Captcha v1.3.5 // 验证码
 | 
						github.com/mojocn/base64Captcha v1.3.5 // 验证码
 | 
				
			||||||
@@ -28,11 +29,13 @@ require (
 | 
				
			|||||||
	gopkg.in/yaml.v3 v3.0.1
 | 
						gopkg.in/yaml.v3 v3.0.1
 | 
				
			||||||
	// gorm
 | 
						// gorm
 | 
				
			||||||
	gorm.io/driver/mysql v1.5.1
 | 
						gorm.io/driver/mysql v1.5.1
 | 
				
			||||||
	gorm.io/driver/sqlite v1.5.4
 | 
					 | 
				
			||||||
	gorm.io/gorm v1.25.4
 | 
						gorm.io/gorm v1.25.4
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
require github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2
 | 
					require (
 | 
				
			||||||
 | 
						github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2
 | 
				
			||||||
 | 
						gorm.io/driver/sqlite v1.5.1
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
require (
 | 
					require (
 | 
				
			||||||
	github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect
 | 
						github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect
 | 
				
			||||||
@@ -47,6 +50,7 @@ require (
 | 
				
			|||||||
	github.com/go-asn1-ber/asn1-ber v1.5.4 // indirect
 | 
						github.com/go-asn1-ber/asn1-ber v1.5.4 // indirect
 | 
				
			||||||
	github.com/goccy/go-json v0.10.2 // indirect
 | 
						github.com/goccy/go-json v0.10.2 // indirect
 | 
				
			||||||
	github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
 | 
						github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
 | 
				
			||||||
 | 
						github.com/golang/glog v1.0.0 // indirect
 | 
				
			||||||
	github.com/golang/protobuf v1.5.3 // indirect
 | 
						github.com/golang/protobuf v1.5.3 // indirect
 | 
				
			||||||
	github.com/golang/snappy v0.0.4 // indirect
 | 
						github.com/golang/snappy v0.0.4 // indirect
 | 
				
			||||||
	github.com/jinzhu/inflection v1.0.0 // indirect
 | 
						github.com/jinzhu/inflection v1.0.0 // indirect
 | 
				
			||||||
@@ -55,15 +59,15 @@ require (
 | 
				
			|||||||
	github.com/klauspost/compress v1.16.5 // indirect
 | 
						github.com/klauspost/compress v1.16.5 // indirect
 | 
				
			||||||
	github.com/klauspost/cpuid/v2 v2.2.4 // indirect
 | 
						github.com/klauspost/cpuid/v2 v2.2.4 // indirect
 | 
				
			||||||
	github.com/kr/fs v0.1.0 // indirect
 | 
						github.com/kr/fs v0.1.0 // indirect
 | 
				
			||||||
	github.com/kr/pretty v0.3.1 // indirect
 | 
					 | 
				
			||||||
	github.com/leodido/go-urn v1.2.4 // indirect
 | 
						github.com/leodido/go-urn v1.2.4 // indirect
 | 
				
			||||||
	github.com/mattn/go-isatty v0.0.19 // indirect
 | 
						github.com/mattn/go-isatty v0.0.19 // indirect
 | 
				
			||||||
	github.com/mattn/go-sqlite3 v1.14.17 // indirect
 | 
						github.com/mattn/go-sqlite3 v1.14.16 // indirect
 | 
				
			||||||
	github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
 | 
						github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
 | 
				
			||||||
	github.com/modern-go/reflect2 v1.0.2 // indirect
 | 
						github.com/modern-go/reflect2 v1.0.2 // indirect
 | 
				
			||||||
	github.com/montanaflynn/stats v0.7.0 // indirect
 | 
						github.com/montanaflynn/stats v0.7.0 // indirect
 | 
				
			||||||
	github.com/pelletier/go-toml/v2 v2.0.8 // indirect
 | 
						github.com/pelletier/go-toml/v2 v2.0.8 // indirect
 | 
				
			||||||
	github.com/pmezard/go-difflib v1.0.0 // indirect
 | 
						github.com/pmezard/go-difflib v1.0.0 // indirect
 | 
				
			||||||
 | 
						github.com/spf13/pflag v1.0.5 // indirect
 | 
				
			||||||
	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
 | 
						github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
 | 
				
			||||||
	github.com/ugorji/go/codec v1.2.11 // indirect
 | 
						github.com/ugorji/go/codec v1.2.11 // indirect
 | 
				
			||||||
	github.com/xdg-go/pbkdf2 v1.0.0 // indirect
 | 
						github.com/xdg-go/pbkdf2 v1.0.0 // indirect
 | 
				
			||||||
@@ -71,12 +75,15 @@ require (
 | 
				
			|||||||
	github.com/xdg-go/stringprep v1.0.4 // indirect
 | 
						github.com/xdg-go/stringprep v1.0.4 // indirect
 | 
				
			||||||
	github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect
 | 
						github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect
 | 
				
			||||||
	golang.org/x/arch v0.3.0 // indirect
 | 
						golang.org/x/arch v0.3.0 // indirect
 | 
				
			||||||
 | 
						golang.org/x/exp v0.0.0-20230519143937-03e91628a987 // indirect
 | 
				
			||||||
	golang.org/x/image v0.0.0-20220302094943-723b81ca9867 // indirect
 | 
						golang.org/x/image v0.0.0-20220302094943-723b81ca9867 // indirect
 | 
				
			||||||
	golang.org/x/net v0.16.0 // indirect
 | 
						golang.org/x/net v0.16.0 // indirect
 | 
				
			||||||
	golang.org/x/sync v0.1.0 // indirect
 | 
						golang.org/x/sync v0.1.0 // indirect
 | 
				
			||||||
	golang.org/x/sys v0.13.0 // indirect
 | 
						golang.org/x/sys v0.13.0 // indirect
 | 
				
			||||||
	golang.org/x/text v0.13.0 // indirect
 | 
						golang.org/x/text v0.13.0 // indirect
 | 
				
			||||||
	google.golang.org/appengine v1.6.7 // indirect
 | 
						google.golang.org/appengine v1.6.7 // indirect
 | 
				
			||||||
 | 
						google.golang.org/genproto v0.0.0-20230131230820-1c016267d619 // indirect
 | 
				
			||||||
 | 
						google.golang.org/grpc v1.52.3 // indirect
 | 
				
			||||||
	google.golang.org/protobuf v1.31.0 // indirect
 | 
						google.golang.org/protobuf v1.31.0 // indirect
 | 
				
			||||||
	gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
 | 
						vitess.io/vitess v0.17.3 // indirect
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,7 +1,6 @@
 | 
				
			|||||||
package api
 | 
					package api
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"bufio"
 | 
					 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"github.com/lib/pq"
 | 
						"github.com/lib/pq"
 | 
				
			||||||
	"io"
 | 
						"io"
 | 
				
			||||||
@@ -20,14 +19,13 @@ import (
 | 
				
			|||||||
	"mayfly-go/pkg/gormx"
 | 
						"mayfly-go/pkg/gormx"
 | 
				
			||||||
	"mayfly-go/pkg/model"
 | 
						"mayfly-go/pkg/model"
 | 
				
			||||||
	"mayfly-go/pkg/req"
 | 
						"mayfly-go/pkg/req"
 | 
				
			||||||
 | 
						"mayfly-go/pkg/sqlparser"
 | 
				
			||||||
	"mayfly-go/pkg/utils/stringx"
 | 
						"mayfly-go/pkg/utils/stringx"
 | 
				
			||||||
	"regexp"
 | 
					 | 
				
			||||||
	"strconv"
 | 
						"strconv"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/gin-gonic/gin"
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
	"github.com/xwb1989/sqlparser"
 | 
					 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type Db struct {
 | 
					type Db struct {
 | 
				
			||||||
@@ -182,9 +180,8 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
 | 
				
			|||||||
	rc.ReqParam = fmt.Sprintf("%s -> filename: %s", dbConn.Info.GetLogDesc(), filename)
 | 
						rc.ReqParam = fmt.Sprintf("%s -> filename: %s", dbConn.Info.GetLogDesc(), filename)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	defer func() {
 | 
						defer func() {
 | 
				
			||||||
		if err := recover(); err != nil {
 | 
					 | 
				
			||||||
		var errInfo string
 | 
							var errInfo string
 | 
				
			||||||
			switch t := err.(type) {
 | 
							switch t := recover().(type) {
 | 
				
			||||||
		case biz.BizError:
 | 
							case biz.BizError:
 | 
				
			||||||
			errInfo = t.Error()
 | 
								errInfo = t.Error()
 | 
				
			||||||
		case *biz.BizError:
 | 
							case *biz.BizError:
 | 
				
			||||||
@@ -195,7 +192,6 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
 | 
				
			|||||||
		if len(errInfo) > 0 {
 | 
							if len(errInfo) > 0 {
 | 
				
			||||||
			d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo)))
 | 
								d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo)))
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	execReq := &application.DbSqlExecReq{
 | 
						execReq := &application.DbSqlExecReq{
 | 
				
			||||||
@@ -206,19 +202,6 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
 | 
				
			|||||||
		LoginAccount: rc.LoginAccount,
 | 
							LoginAccount: rc.LoginAccount,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	defer func() {
 | 
					 | 
				
			||||||
		if err := recover(); err != nil {
 | 
					 | 
				
			||||||
			var errInfo string
 | 
					 | 
				
			||||||
			switch t := err.(type) {
 | 
					 | 
				
			||||||
			case error:
 | 
					 | 
				
			||||||
				errInfo = t.Error()
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			if len(errInfo) > 0 {
 | 
					 | 
				
			||||||
				d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo)))
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	progressId := uniqueid.IncrementID()
 | 
						progressId := uniqueid.IncrementID()
 | 
				
			||||||
	executedStatements := 0
 | 
						executedStatements := 0
 | 
				
			||||||
	defer ws.SendJsonMsg(rc.LoginAccount.Id, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
 | 
						defer ws.SendJsonMsg(rc.LoginAccount.Id, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
 | 
				
			||||||
@@ -228,10 +211,16 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
 | 
				
			|||||||
		Terminated:         true,
 | 
							Terminated:         true,
 | 
				
			||||||
	}).WithCategory(progressCategory))
 | 
						}).WithCategory(progressCategory))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var parser sqlparser.Parser
 | 
				
			||||||
 | 
						if dbConn.Info.Type == entity.DbTypeMysql {
 | 
				
			||||||
 | 
							parser = sqlparser.NewMysqlParser(file)
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							parser = sqlparser.NewPostgresParser(file)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	ticker := time.NewTicker(time.Second * 1)
 | 
						ticker := time.NewTicker(time.Second * 1)
 | 
				
			||||||
	defer ticker.Stop()
 | 
						defer ticker.Stop()
 | 
				
			||||||
	sqlScanner := SplitSqls(file)
 | 
						for {
 | 
				
			||||||
	for sqlScanner.Scan() {
 | 
					 | 
				
			||||||
		select {
 | 
							select {
 | 
				
			||||||
		case <-ticker.C:
 | 
							case <-ticker.C:
 | 
				
			||||||
			ws.SendJsonMsg(rc.LoginAccount.Id, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
 | 
								ws.SendJsonMsg(rc.LoginAccount.Id, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
 | 
				
			||||||
@@ -242,7 +231,15 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
 | 
				
			|||||||
			}).WithCategory(progressCategory))
 | 
								}).WithCategory(progressCategory))
 | 
				
			||||||
		default:
 | 
							default:
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		sql := sqlScanner.Text()
 | 
							err = parser.Next()
 | 
				
			||||||
 | 
							if err == io.EOF {
 | 
				
			||||||
 | 
								break
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s] 解析SQL失败: [%s]", filename, dbConn.Info.GetLogDesc(), err.Error())))
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							sql := parser.Current()
 | 
				
			||||||
		const prefixUse = "use "
 | 
							const prefixUse = "use "
 | 
				
			||||||
		if strings.HasPrefix(sql, prefixUse) {
 | 
							if strings.HasPrefix(sql, prefixUse) {
 | 
				
			||||||
			dbNameExec := strings.Trim(sql[len(prefixUse):], " `;\n")
 | 
								dbNameExec := strings.Trim(sql[len(prefixUse):], " `;\n")
 | 
				
			||||||
@@ -553,35 +550,3 @@ func getDbName(g *gin.Context) string {
 | 
				
			|||||||
	biz.NotEmpty(db, "db不能为空")
 | 
						biz.NotEmpty(db, "db不能为空")
 | 
				
			||||||
	return db
 | 
						return db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
// 根据;\n切割sql
 | 
					 | 
				
			||||||
func SplitSqls(r io.Reader) *bufio.Scanner {
 | 
					 | 
				
			||||||
	scanner := bufio.NewScanner(r)
 | 
					 | 
				
			||||||
	re := regexp.MustCompile(`\s*;\s*\n`)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
					 | 
				
			||||||
		if atEOF && len(data) == 0 {
 | 
					 | 
				
			||||||
			return 0, nil, io.EOF
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		match := re.FindIndex(data)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		if match != nil {
 | 
					 | 
				
			||||||
			// 如果找到了";\n",判断是否为最后一行
 | 
					 | 
				
			||||||
			if match[1] == len(data) {
 | 
					 | 
				
			||||||
				// 如果是最后一行,则返回完整的切片
 | 
					 | 
				
			||||||
				return len(data), data, nil
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			// 否则,返回到";\n"之后,并且包括";\n"本身
 | 
					 | 
				
			||||||
			return match[1], data[:match[1]], nil
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		if atEOF {
 | 
					 | 
				
			||||||
			return len(data), data, nil
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		return 0, nil, nil
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return scanner
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 
 | 
				
			|||||||
@@ -3,7 +3,6 @@ package api
 | 
				
			|||||||
import (
 | 
					import (
 | 
				
			||||||
	"github.com/stretchr/testify/require"
 | 
						"github.com/stretchr/testify/require"
 | 
				
			||||||
	"mayfly-go/internal/db/domain/entity"
 | 
						"mayfly-go/internal/db/domain/entity"
 | 
				
			||||||
	"strings"
 | 
					 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -54,37 +53,3 @@ func Test_escapeSql(t *testing.T) {
 | 
				
			|||||||
		})
 | 
							})
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
func Test_SplitSqls(t *testing.T) {
 | 
					 | 
				
			||||||
	tests := []struct {
 | 
					 | 
				
			||||||
		name  string
 | 
					 | 
				
			||||||
		input string
 | 
					 | 
				
			||||||
		want  string
 | 
					 | 
				
			||||||
	}{
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			name:  "create table with current_timestamp",
 | 
					 | 
				
			||||||
			input: "create table tbl (\n\tcreate_at datetime default current_timestamp()\n)",
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			name:  "create table with current_date",
 | 
					 | 
				
			||||||
			input: "create table tbl (\n\tcreate_at date default current_date()\n)",
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			name:  "select with ';\n'",
 | 
					 | 
				
			||||||
			input: "select 'the first line;\nthe second line;\n'",
 | 
					 | 
				
			||||||
			// SplitSqls split statements by ';\n'
 | 
					 | 
				
			||||||
			want: "select 'the first line;\n",
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	for _, test := range tests {
 | 
					 | 
				
			||||||
		t.Run(test.name, func(t *testing.T) {
 | 
					 | 
				
			||||||
			scanner := SplitSqls(strings.NewReader(test.input))
 | 
					 | 
				
			||||||
			require.True(t, scanner.Scan())
 | 
					 | 
				
			||||||
			got := scanner.Text()
 | 
					 | 
				
			||||||
			if len(test.want) == 0 {
 | 
					 | 
				
			||||||
				test.want = test.input
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			require.Equal(t, test.want, got)
 | 
					 | 
				
			||||||
		})
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,44 +0,0 @@
 | 
				
			|||||||
package api
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import (
 | 
					 | 
				
			||||||
	"github.com/stretchr/testify/require"
 | 
					 | 
				
			||||||
	"github.com/xwb1989/sqlparser"
 | 
					 | 
				
			||||||
	"strings"
 | 
					 | 
				
			||||||
	"testing"
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func Test_ParseNext_WithCurrentDate(t *testing.T) {
 | 
					 | 
				
			||||||
	tests := []struct {
 | 
					 | 
				
			||||||
		name        string
 | 
					 | 
				
			||||||
		input       string
 | 
					 | 
				
			||||||
		want        string
 | 
					 | 
				
			||||||
		wantXwb1989 string
 | 
					 | 
				
			||||||
		err         string
 | 
					 | 
				
			||||||
	}{
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			name:  "create table with current_timestamp",
 | 
					 | 
				
			||||||
			input: "create table tbl (\n\tcreate_at datetime default current_timestamp()\n)",
 | 
					 | 
				
			||||||
			// xwb1989/sqlparser 不支持 current_timestamp()
 | 
					 | 
				
			||||||
			wantXwb1989: "create table tbl",
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			name:  "create table with current_date",
 | 
					 | 
				
			||||||
			input: "create table tbl (\n\tcreate_at date default current_date()\n)",
 | 
					 | 
				
			||||||
			// xwb1989/sqlparser 不支持 current_date()
 | 
					 | 
				
			||||||
			wantXwb1989: "create table tbl",
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	for _, test := range tests {
 | 
					 | 
				
			||||||
		t.Run(test.name, func(t *testing.T) {
 | 
					 | 
				
			||||||
			token := sqlparser.NewTokenizer(strings.NewReader(test.input))
 | 
					 | 
				
			||||||
			tree, err := sqlparser.ParseNext(token)
 | 
					 | 
				
			||||||
			if len(test.err) > 0 {
 | 
					 | 
				
			||||||
				require.Error(t, err)
 | 
					 | 
				
			||||||
				require.Contains(t, err.Error(), test.err)
 | 
					 | 
				
			||||||
				return
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			require.NoError(t, err)
 | 
					 | 
				
			||||||
			require.Equal(t, test.wantXwb1989, sqlparser.String(tree))
 | 
					 | 
				
			||||||
		})
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
@@ -3,7 +3,7 @@ package application
 | 
				
			|||||||
import (
 | 
					import (
 | 
				
			||||||
	"encoding/json"
 | 
						"encoding/json"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"github.com/xwb1989/sqlparser"
 | 
						"github.com/kanzihuang/vitess/go/vt/sqlparser"
 | 
				
			||||||
	"mayfly-go/internal/db/config"
 | 
						"mayfly-go/internal/db/config"
 | 
				
			||||||
	"mayfly-go/internal/db/domain/entity"
 | 
						"mayfly-go/internal/db/domain/entity"
 | 
				
			||||||
	"mayfly-go/internal/db/domain/repository"
 | 
						"mayfly-go/internal/db/domain/repository"
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										99
									
								
								server/pkg/sqlparser/sqlparser.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										99
									
								
								server/pkg/sqlparser/sqlparser.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,99 @@
 | 
				
			|||||||
 | 
					package sqlparser
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"bufio"
 | 
				
			||||||
 | 
						"github.com/kanzihuang/vitess/go/vt/sqlparser"
 | 
				
			||||||
 | 
						"io"
 | 
				
			||||||
 | 
						"regexp"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Parser interface {
 | 
				
			||||||
 | 
						Next() error
 | 
				
			||||||
 | 
						Current() string
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var _ Parser = &MysqlParser{}
 | 
				
			||||||
 | 
					var _ Parser = &PostgresParser{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type MysqlParser struct {
 | 
				
			||||||
 | 
						tokenizer *sqlparser.Tokenizer
 | 
				
			||||||
 | 
						statement string
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewMysqlParser(reader io.Reader) *MysqlParser {
 | 
				
			||||||
 | 
						return &MysqlParser{
 | 
				
			||||||
 | 
							tokenizer: sqlparser.NewReaderTokenizer(reader),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (parser *MysqlParser) Next() error {
 | 
				
			||||||
 | 
						statement, err := sqlparser.ParseNext(parser.tokenizer)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							parser.statement = ""
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						parser.statement = sqlparser.String(statement)
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (parser *MysqlParser) Current() string {
 | 
				
			||||||
 | 
						return parser.statement
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type PostgresParser struct {
 | 
				
			||||||
 | 
						scanner   *bufio.Scanner
 | 
				
			||||||
 | 
						statement string
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewPostgresParser(reader io.Reader) *PostgresParser {
 | 
				
			||||||
 | 
						return &PostgresParser{
 | 
				
			||||||
 | 
							scanner: splitSqls(reader),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (parser *PostgresParser) Next() error {
 | 
				
			||||||
 | 
						if !parser.scanner.Scan() {
 | 
				
			||||||
 | 
							return io.EOF
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (parser *PostgresParser) Current() string {
 | 
				
			||||||
 | 
						return parser.scanner.Text()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 根据;\n切割sql
 | 
				
			||||||
 | 
					func splitSqls(r io.Reader) *bufio.Scanner {
 | 
				
			||||||
 | 
						scanner := bufio.NewScanner(r)
 | 
				
			||||||
 | 
						re := regexp.MustCompile(`\s*;\s*\n`)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
				
			||||||
 | 
							if atEOF && len(data) == 0 {
 | 
				
			||||||
 | 
								return 0, nil, io.EOF
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							match := re.FindIndex(data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if match != nil {
 | 
				
			||||||
 | 
								// 如果找到了";\n",判断是否为最后一行
 | 
				
			||||||
 | 
								if match[1] == len(data) {
 | 
				
			||||||
 | 
									// 如果是最后一行,则返回完整的切片
 | 
				
			||||||
 | 
									return len(data), data, nil
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								// 否则,返回到";\n"之后,并且包括";\n"本身
 | 
				
			||||||
 | 
								return match[1], data[:match[1]], nil
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if atEOF {
 | 
				
			||||||
 | 
								return len(data), data, nil
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							return 0, nil, nil
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return scanner
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func SplitStatementToPieces(sql string) ([]string, error) {
 | 
				
			||||||
 | 
						return sqlparser.SplitStatementToPieces(sql)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										98
									
								
								server/pkg/sqlparser/sqlparser_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										98
									
								
								server/pkg/sqlparser/sqlparser_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,98 @@
 | 
				
			|||||||
 | 
					package sqlparser
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/kanzihuang/vitess/go/vt/sqlparser"
 | 
				
			||||||
 | 
						"github.com/stretchr/testify/require"
 | 
				
			||||||
 | 
						sqlparser_xwb1989 "github.com/xwb1989/sqlparser"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func Test_ParseNext_WithCurrentDate(t *testing.T) {
 | 
				
			||||||
 | 
						tests := []struct {
 | 
				
			||||||
 | 
							name        string
 | 
				
			||||||
 | 
							input       string
 | 
				
			||||||
 | 
							want        string
 | 
				
			||||||
 | 
							wantXwb1989 string
 | 
				
			||||||
 | 
							err         string
 | 
				
			||||||
 | 
						}{
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:  "create table with current_timestamp",
 | 
				
			||||||
 | 
								input: "create table tbl (\n\tcreate_at datetime default current_timestamp()\n)",
 | 
				
			||||||
 | 
								// xwb1989/sqlparser 不支持 current_timestamp()
 | 
				
			||||||
 | 
								wantXwb1989: "create table tbl",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:  "create table with current_date",
 | 
				
			||||||
 | 
								input: "create table tbl (\n\tcreate_at date default current_date()\n)",
 | 
				
			||||||
 | 
								// xwb1989/sqlparser 不支持 current_date()
 | 
				
			||||||
 | 
								wantXwb1989: "create table tbl",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for _, test := range tests {
 | 
				
			||||||
 | 
							t.Run(test.name, func(t *testing.T) {
 | 
				
			||||||
 | 
								token := sqlparser.NewReaderTokenizer(strings.NewReader(test.input))
 | 
				
			||||||
 | 
								tree, err := sqlparser.ParseNext(token)
 | 
				
			||||||
 | 
								if len(test.err) > 0 {
 | 
				
			||||||
 | 
									require.Error(t, err)
 | 
				
			||||||
 | 
									require.Contains(t, err.Error(), test.err)
 | 
				
			||||||
 | 
									return
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								require.NoError(t, err)
 | 
				
			||||||
 | 
								if len(test.want) == 0 {
 | 
				
			||||||
 | 
									test.want = test.input
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								require.Equal(t, test.want, sqlparser.String(tree))
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for _, test := range tests {
 | 
				
			||||||
 | 
							t.Run(test.name, func(t *testing.T) {
 | 
				
			||||||
 | 
								token := sqlparser_xwb1989.NewTokenizer(strings.NewReader(test.input))
 | 
				
			||||||
 | 
								tree, err := sqlparser_xwb1989.ParseNext(token)
 | 
				
			||||||
 | 
								if len(test.err) > 0 {
 | 
				
			||||||
 | 
									require.Error(t, err)
 | 
				
			||||||
 | 
									require.Contains(t, err.Error(), test.err)
 | 
				
			||||||
 | 
									return
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								require.NoError(t, err)
 | 
				
			||||||
 | 
								if len(test.want) == 0 {
 | 
				
			||||||
 | 
									test.want = test.input
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								require.Equal(t, test.wantXwb1989, sqlparser_xwb1989.String(tree))
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func Test_SplitSqls(t *testing.T) {
 | 
				
			||||||
 | 
						tests := []struct {
 | 
				
			||||||
 | 
							name  string
 | 
				
			||||||
 | 
							input string
 | 
				
			||||||
 | 
							want  string
 | 
				
			||||||
 | 
						}{
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:  "create table with current_timestamp",
 | 
				
			||||||
 | 
								input: "create table tbl (\n\tcreate_at datetime default current_timestamp()\n)",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:  "create table with current_date",
 | 
				
			||||||
 | 
								input: "create table tbl (\n\tcreate_at date default current_date()\n)",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:  "select with ';\n'",
 | 
				
			||||||
 | 
								input: "select 'the first line;\nthe second line;\n'",
 | 
				
			||||||
 | 
								// SplitSqls split statements by ';\n'
 | 
				
			||||||
 | 
								want: "select 'the first line;\n",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for _, test := range tests {
 | 
				
			||||||
 | 
							t.Run(test.name, func(t *testing.T) {
 | 
				
			||||||
 | 
								scanner := splitSqls(strings.NewReader(test.input))
 | 
				
			||||||
 | 
								require.True(t, scanner.Scan())
 | 
				
			||||||
 | 
								got := scanner.Text()
 | 
				
			||||||
 | 
								if len(test.want) == 0 {
 | 
				
			||||||
 | 
									test.want = test.input
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								require.Equal(t, test.want, got)
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
		Reference in New Issue
	
	Block a user