mirror of
https://gitee.com/dromara/mayfly-go
synced 2025-11-03 16:00:25 +08:00
fix: 使用最新版 vitess sqlparser 解析 SQL 语句
解决 xwb1989/sqlparser 不支持 current_timestamp() 的问题
This commit is contained in:
@@ -13,6 +13,7 @@ require (
|
|||||||
github.com/go-sql-driver/mysql v1.7.1
|
github.com/go-sql-driver/mysql v1.7.1
|
||||||
github.com/golang-jwt/jwt/v5 v5.0.0
|
github.com/golang-jwt/jwt/v5 v5.0.0
|
||||||
github.com/gorilla/websocket v1.5.0
|
github.com/gorilla/websocket v1.5.0
|
||||||
|
github.com/kanzihuang/vitess/go/vt/sqlparser v0.0.0-20231007020222-b91ee5ef3b31
|
||||||
github.com/lib/pq v1.10.9
|
github.com/lib/pq v1.10.9
|
||||||
github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230712084735-068dc2aee82d
|
github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230712084735-068dc2aee82d
|
||||||
github.com/mojocn/base64Captcha v1.3.5 // 验证码
|
github.com/mojocn/base64Captcha v1.3.5 // 验证码
|
||||||
@@ -28,11 +29,13 @@ require (
|
|||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
// gorm
|
// gorm
|
||||||
gorm.io/driver/mysql v1.5.1
|
gorm.io/driver/mysql v1.5.1
|
||||||
gorm.io/driver/sqlite v1.5.4
|
|
||||||
gorm.io/gorm v1.25.4
|
gorm.io/gorm v1.25.4
|
||||||
)
|
)
|
||||||
|
|
||||||
require github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2
|
require (
|
||||||
|
github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2
|
||||||
|
gorm.io/driver/sqlite v1.5.1
|
||||||
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect
|
github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect
|
||||||
@@ -47,6 +50,7 @@ require (
|
|||||||
github.com/go-asn1-ber/asn1-ber v1.5.4 // indirect
|
github.com/go-asn1-ber/asn1-ber v1.5.4 // indirect
|
||||||
github.com/goccy/go-json v0.10.2 // indirect
|
github.com/goccy/go-json v0.10.2 // indirect
|
||||||
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
|
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
|
||||||
|
github.com/golang/glog v1.0.0 // indirect
|
||||||
github.com/golang/protobuf v1.5.3 // indirect
|
github.com/golang/protobuf v1.5.3 // indirect
|
||||||
github.com/golang/snappy v0.0.4 // indirect
|
github.com/golang/snappy v0.0.4 // indirect
|
||||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||||
@@ -55,15 +59,15 @@ require (
|
|||||||
github.com/klauspost/compress v1.16.5 // indirect
|
github.com/klauspost/compress v1.16.5 // indirect
|
||||||
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
|
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
|
||||||
github.com/kr/fs v0.1.0 // indirect
|
github.com/kr/fs v0.1.0 // indirect
|
||||||
github.com/kr/pretty v0.3.1 // indirect
|
|
||||||
github.com/leodido/go-urn v1.2.4 // indirect
|
github.com/leodido/go-urn v1.2.4 // indirect
|
||||||
github.com/mattn/go-isatty v0.0.19 // indirect
|
github.com/mattn/go-isatty v0.0.19 // indirect
|
||||||
github.com/mattn/go-sqlite3 v1.14.17 // indirect
|
github.com/mattn/go-sqlite3 v1.14.16 // indirect
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||||
github.com/montanaflynn/stats v0.7.0 // indirect
|
github.com/montanaflynn/stats v0.7.0 // indirect
|
||||||
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
|
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
|
github.com/spf13/pflag v1.0.5 // indirect
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||||
github.com/ugorji/go/codec v1.2.11 // indirect
|
github.com/ugorji/go/codec v1.2.11 // indirect
|
||||||
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
|
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
|
||||||
@@ -71,12 +75,15 @@ require (
|
|||||||
github.com/xdg-go/stringprep v1.0.4 // indirect
|
github.com/xdg-go/stringprep v1.0.4 // indirect
|
||||||
github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect
|
github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect
|
||||||
golang.org/x/arch v0.3.0 // indirect
|
golang.org/x/arch v0.3.0 // indirect
|
||||||
|
golang.org/x/exp v0.0.0-20230519143937-03e91628a987 // indirect
|
||||||
golang.org/x/image v0.0.0-20220302094943-723b81ca9867 // indirect
|
golang.org/x/image v0.0.0-20220302094943-723b81ca9867 // indirect
|
||||||
golang.org/x/net v0.16.0 // indirect
|
golang.org/x/net v0.16.0 // indirect
|
||||||
golang.org/x/sync v0.1.0 // indirect
|
golang.org/x/sync v0.1.0 // indirect
|
||||||
golang.org/x/sys v0.13.0 // indirect
|
golang.org/x/sys v0.13.0 // indirect
|
||||||
golang.org/x/text v0.13.0 // indirect
|
golang.org/x/text v0.13.0 // indirect
|
||||||
google.golang.org/appengine v1.6.7 // indirect
|
google.golang.org/appengine v1.6.7 // indirect
|
||||||
|
google.golang.org/genproto v0.0.0-20230131230820-1c016267d619 // indirect
|
||||||
|
google.golang.org/grpc v1.52.3 // indirect
|
||||||
google.golang.org/protobuf v1.31.0 // indirect
|
google.golang.org/protobuf v1.31.0 // indirect
|
||||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
|
vitess.io/vitess v0.17.3 // indirect
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
"io"
|
"io"
|
||||||
@@ -20,14 +19,13 @@ import (
|
|||||||
"mayfly-go/pkg/gormx"
|
"mayfly-go/pkg/gormx"
|
||||||
"mayfly-go/pkg/model"
|
"mayfly-go/pkg/model"
|
||||||
"mayfly-go/pkg/req"
|
"mayfly-go/pkg/req"
|
||||||
|
"mayfly-go/pkg/sqlparser"
|
||||||
"mayfly-go/pkg/utils/stringx"
|
"mayfly-go/pkg/utils/stringx"
|
||||||
"regexp"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/xwb1989/sqlparser"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Db struct {
|
type Db struct {
|
||||||
@@ -182,19 +180,17 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
|
|||||||
rc.ReqParam = fmt.Sprintf("%s -> filename: %s", dbConn.Info.GetLogDesc(), filename)
|
rc.ReqParam = fmt.Sprintf("%s -> filename: %s", dbConn.Info.GetLogDesc(), filename)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := recover(); err != nil {
|
var errInfo string
|
||||||
var errInfo string
|
switch t := recover().(type) {
|
||||||
switch t := err.(type) {
|
case biz.BizError:
|
||||||
case biz.BizError:
|
errInfo = t.Error()
|
||||||
errInfo = t.Error()
|
case *biz.BizError:
|
||||||
case *biz.BizError:
|
errInfo = t.Error()
|
||||||
errInfo = t.Error()
|
case string:
|
||||||
case string:
|
errInfo = t
|
||||||
errInfo = t
|
}
|
||||||
}
|
if len(errInfo) > 0 {
|
||||||
if len(errInfo) > 0 {
|
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo)))
|
||||||
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo)))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -206,19 +202,6 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
|
|||||||
LoginAccount: rc.LoginAccount,
|
LoginAccount: rc.LoginAccount,
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
|
||||||
if err := recover(); err != nil {
|
|
||||||
var errInfo string
|
|
||||||
switch t := err.(type) {
|
|
||||||
case error:
|
|
||||||
errInfo = t.Error()
|
|
||||||
}
|
|
||||||
if len(errInfo) > 0 {
|
|
||||||
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo)))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
progressId := uniqueid.IncrementID()
|
progressId := uniqueid.IncrementID()
|
||||||
executedStatements := 0
|
executedStatements := 0
|
||||||
defer ws.SendJsonMsg(rc.LoginAccount.Id, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
|
defer ws.SendJsonMsg(rc.LoginAccount.Id, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
|
||||||
@@ -228,10 +211,16 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
|
|||||||
Terminated: true,
|
Terminated: true,
|
||||||
}).WithCategory(progressCategory))
|
}).WithCategory(progressCategory))
|
||||||
|
|
||||||
|
var parser sqlparser.Parser
|
||||||
|
if dbConn.Info.Type == entity.DbTypeMysql {
|
||||||
|
parser = sqlparser.NewMysqlParser(file)
|
||||||
|
} else {
|
||||||
|
parser = sqlparser.NewPostgresParser(file)
|
||||||
|
}
|
||||||
|
|
||||||
ticker := time.NewTicker(time.Second * 1)
|
ticker := time.NewTicker(time.Second * 1)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
sqlScanner := SplitSqls(file)
|
for {
|
||||||
for sqlScanner.Scan() {
|
|
||||||
select {
|
select {
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
ws.SendJsonMsg(rc.LoginAccount.Id, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
|
ws.SendJsonMsg(rc.LoginAccount.Id, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
|
||||||
@@ -242,7 +231,15 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
|
|||||||
}).WithCategory(progressCategory))
|
}).WithCategory(progressCategory))
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
sql := sqlScanner.Text()
|
err = parser.Next()
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s] 解析SQL失败: [%s]", filename, dbConn.Info.GetLogDesc(), err.Error())))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sql := parser.Current()
|
||||||
const prefixUse = "use "
|
const prefixUse = "use "
|
||||||
if strings.HasPrefix(sql, prefixUse) {
|
if strings.HasPrefix(sql, prefixUse) {
|
||||||
dbNameExec := strings.Trim(sql[len(prefixUse):], " `;\n")
|
dbNameExec := strings.Trim(sql[len(prefixUse):], " `;\n")
|
||||||
@@ -553,35 +550,3 @@ func getDbName(g *gin.Context) string {
|
|||||||
biz.NotEmpty(db, "db不能为空")
|
biz.NotEmpty(db, "db不能为空")
|
||||||
return db
|
return db
|
||||||
}
|
}
|
||||||
|
|
||||||
// 根据;\n切割sql
|
|
||||||
func SplitSqls(r io.Reader) *bufio.Scanner {
|
|
||||||
scanner := bufio.NewScanner(r)
|
|
||||||
re := regexp.MustCompile(`\s*;\s*\n`)
|
|
||||||
|
|
||||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
|
||||||
if atEOF && len(data) == 0 {
|
|
||||||
return 0, nil, io.EOF
|
|
||||||
}
|
|
||||||
|
|
||||||
match := re.FindIndex(data)
|
|
||||||
|
|
||||||
if match != nil {
|
|
||||||
// 如果找到了";\n",判断是否为最后一行
|
|
||||||
if match[1] == len(data) {
|
|
||||||
// 如果是最后一行,则返回完整的切片
|
|
||||||
return len(data), data, nil
|
|
||||||
}
|
|
||||||
// 否则,返回到";\n"之后,并且包括";\n"本身
|
|
||||||
return match[1], data[:match[1]], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if atEOF {
|
|
||||||
return len(data), data, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0, nil, nil
|
|
||||||
})
|
|
||||||
|
|
||||||
return scanner
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package api
|
|||||||
import (
|
import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"mayfly-go/internal/db/domain/entity"
|
"mayfly-go/internal/db/domain/entity"
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -54,37 +53,3 @@ func Test_escapeSql(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_SplitSqls(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
input string
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "create table with current_timestamp",
|
|
||||||
input: "create table tbl (\n\tcreate_at datetime default current_timestamp()\n)",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "create table with current_date",
|
|
||||||
input: "create table tbl (\n\tcreate_at date default current_date()\n)",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "select with ';\n'",
|
|
||||||
input: "select 'the first line;\nthe second line;\n'",
|
|
||||||
// SplitSqls split statements by ';\n'
|
|
||||||
want: "select 'the first line;\n",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, test := range tests {
|
|
||||||
t.Run(test.name, func(t *testing.T) {
|
|
||||||
scanner := SplitSqls(strings.NewReader(test.input))
|
|
||||||
require.True(t, scanner.Scan())
|
|
||||||
got := scanner.Text()
|
|
||||||
if len(test.want) == 0 {
|
|
||||||
test.want = test.input
|
|
||||||
}
|
|
||||||
require.Equal(t, test.want, got)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,44 +0,0 @@
|
|||||||
package api
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"github.com/xwb1989/sqlparser"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func Test_ParseNext_WithCurrentDate(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
input string
|
|
||||||
want string
|
|
||||||
wantXwb1989 string
|
|
||||||
err string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "create table with current_timestamp",
|
|
||||||
input: "create table tbl (\n\tcreate_at datetime default current_timestamp()\n)",
|
|
||||||
// xwb1989/sqlparser 不支持 current_timestamp()
|
|
||||||
wantXwb1989: "create table tbl",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "create table with current_date",
|
|
||||||
input: "create table tbl (\n\tcreate_at date default current_date()\n)",
|
|
||||||
// xwb1989/sqlparser 不支持 current_date()
|
|
||||||
wantXwb1989: "create table tbl",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, test := range tests {
|
|
||||||
t.Run(test.name, func(t *testing.T) {
|
|
||||||
token := sqlparser.NewTokenizer(strings.NewReader(test.input))
|
|
||||||
tree, err := sqlparser.ParseNext(token)
|
|
||||||
if len(test.err) > 0 {
|
|
||||||
require.Error(t, err)
|
|
||||||
require.Contains(t, err.Error(), test.err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, test.wantXwb1989, sqlparser.String(tree))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -3,7 +3,7 @@ package application
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/xwb1989/sqlparser"
|
"github.com/kanzihuang/vitess/go/vt/sqlparser"
|
||||||
"mayfly-go/internal/db/config"
|
"mayfly-go/internal/db/config"
|
||||||
"mayfly-go/internal/db/domain/entity"
|
"mayfly-go/internal/db/domain/entity"
|
||||||
"mayfly-go/internal/db/domain/repository"
|
"mayfly-go/internal/db/domain/repository"
|
||||||
|
|||||||
99
server/pkg/sqlparser/sqlparser.go
Normal file
99
server/pkg/sqlparser/sqlparser.go
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
package sqlparser
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"github.com/kanzihuang/vitess/go/vt/sqlparser"
|
||||||
|
"io"
|
||||||
|
"regexp"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Parser interface {
|
||||||
|
Next() error
|
||||||
|
Current() string
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Parser = &MysqlParser{}
|
||||||
|
var _ Parser = &PostgresParser{}
|
||||||
|
|
||||||
|
type MysqlParser struct {
|
||||||
|
tokenizer *sqlparser.Tokenizer
|
||||||
|
statement string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMysqlParser(reader io.Reader) *MysqlParser {
|
||||||
|
return &MysqlParser{
|
||||||
|
tokenizer: sqlparser.NewReaderTokenizer(reader),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (parser *MysqlParser) Next() error {
|
||||||
|
statement, err := sqlparser.ParseNext(parser.tokenizer)
|
||||||
|
if err != nil {
|
||||||
|
parser.statement = ""
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
parser.statement = sqlparser.String(statement)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (parser *MysqlParser) Current() string {
|
||||||
|
return parser.statement
|
||||||
|
}
|
||||||
|
|
||||||
|
type PostgresParser struct {
|
||||||
|
scanner *bufio.Scanner
|
||||||
|
statement string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPostgresParser(reader io.Reader) *PostgresParser {
|
||||||
|
return &PostgresParser{
|
||||||
|
scanner: splitSqls(reader),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (parser *PostgresParser) Next() error {
|
||||||
|
if !parser.scanner.Scan() {
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (parser *PostgresParser) Current() string {
|
||||||
|
return parser.scanner.Text()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 根据;\n切割sql
|
||||||
|
func splitSqls(r io.Reader) *bufio.Scanner {
|
||||||
|
scanner := bufio.NewScanner(r)
|
||||||
|
re := regexp.MustCompile(`\s*;\s*\n`)
|
||||||
|
|
||||||
|
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||||
|
if atEOF && len(data) == 0 {
|
||||||
|
return 0, nil, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
match := re.FindIndex(data)
|
||||||
|
|
||||||
|
if match != nil {
|
||||||
|
// 如果找到了";\n",判断是否为最后一行
|
||||||
|
if match[1] == len(data) {
|
||||||
|
// 如果是最后一行,则返回完整的切片
|
||||||
|
return len(data), data, nil
|
||||||
|
}
|
||||||
|
// 否则,返回到";\n"之后,并且包括";\n"本身
|
||||||
|
return match[1], data[:match[1]], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if atEOF {
|
||||||
|
return len(data), data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0, nil, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
return scanner
|
||||||
|
}
|
||||||
|
|
||||||
|
func SplitStatementToPieces(sql string) ([]string, error) {
|
||||||
|
return sqlparser.SplitStatementToPieces(sql)
|
||||||
|
}
|
||||||
98
server/pkg/sqlparser/sqlparser_test.go
Normal file
98
server/pkg/sqlparser/sqlparser_test.go
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
package sqlparser
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/kanzihuang/vitess/go/vt/sqlparser"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
sqlparser_xwb1989 "github.com/xwb1989/sqlparser"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_ParseNext_WithCurrentDate(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
want string
|
||||||
|
wantXwb1989 string
|
||||||
|
err string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "create table with current_timestamp",
|
||||||
|
input: "create table tbl (\n\tcreate_at datetime default current_timestamp()\n)",
|
||||||
|
// xwb1989/sqlparser 不支持 current_timestamp()
|
||||||
|
wantXwb1989: "create table tbl",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "create table with current_date",
|
||||||
|
input: "create table tbl (\n\tcreate_at date default current_date()\n)",
|
||||||
|
// xwb1989/sqlparser 不支持 current_date()
|
||||||
|
wantXwb1989: "create table tbl",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
token := sqlparser.NewReaderTokenizer(strings.NewReader(test.input))
|
||||||
|
tree, err := sqlparser.ParseNext(token)
|
||||||
|
if len(test.err) > 0 {
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), test.err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
if len(test.want) == 0 {
|
||||||
|
test.want = test.input
|
||||||
|
}
|
||||||
|
require.Equal(t, test.want, sqlparser.String(tree))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
token := sqlparser_xwb1989.NewTokenizer(strings.NewReader(test.input))
|
||||||
|
tree, err := sqlparser_xwb1989.ParseNext(token)
|
||||||
|
if len(test.err) > 0 {
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), test.err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
if len(test.want) == 0 {
|
||||||
|
test.want = test.input
|
||||||
|
}
|
||||||
|
require.Equal(t, test.wantXwb1989, sqlparser_xwb1989.String(tree))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_SplitSqls(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "create table with current_timestamp",
|
||||||
|
input: "create table tbl (\n\tcreate_at datetime default current_timestamp()\n)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "create table with current_date",
|
||||||
|
input: "create table tbl (\n\tcreate_at date default current_date()\n)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "select with ';\n'",
|
||||||
|
input: "select 'the first line;\nthe second line;\n'",
|
||||||
|
// SplitSqls split statements by ';\n'
|
||||||
|
want: "select 'the first line;\n",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
scanner := splitSqls(strings.NewReader(test.input))
|
||||||
|
require.True(t, scanner.Scan())
|
||||||
|
got := scanner.Text()
|
||||||
|
if len(test.want) == 0 {
|
||||||
|
test.want = test.input
|
||||||
|
}
|
||||||
|
require.Equal(t, test.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user