fix: 使用最新版 vitess sqlparser 解析 SQL 语句

解决 xwb1989/sqlparser 不支持 current_timestamp() 的问题
This commit is contained in:
kanzihuang
2023-09-26 22:47:19 +08:00
parent 7544288451
commit b4ddbbd38f
7 changed files with 239 additions and 149 deletions

View File

@@ -0,0 +1,99 @@
package sqlparser
import (
"bufio"
"github.com/kanzihuang/vitess/go/vt/sqlparser"
"io"
"regexp"
)
type Parser interface {
Next() error
Current() string
}
var _ Parser = &MysqlParser{}
var _ Parser = &PostgresParser{}
type MysqlParser struct {
tokenizer *sqlparser.Tokenizer
statement string
}
func NewMysqlParser(reader io.Reader) *MysqlParser {
return &MysqlParser{
tokenizer: sqlparser.NewReaderTokenizer(reader),
}
}
func (parser *MysqlParser) Next() error {
statement, err := sqlparser.ParseNext(parser.tokenizer)
if err != nil {
parser.statement = ""
return err
}
parser.statement = sqlparser.String(statement)
return nil
}
func (parser *MysqlParser) Current() string {
return parser.statement
}
type PostgresParser struct {
scanner *bufio.Scanner
statement string
}
func NewPostgresParser(reader io.Reader) *PostgresParser {
return &PostgresParser{
scanner: splitSqls(reader),
}
}
func (parser *PostgresParser) Next() error {
if !parser.scanner.Scan() {
return io.EOF
}
return nil
}
func (parser *PostgresParser) Current() string {
return parser.scanner.Text()
}
// 根据;\n切割sql
func splitSqls(r io.Reader) *bufio.Scanner {
scanner := bufio.NewScanner(r)
re := regexp.MustCompile(`\s*;\s*\n`)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, io.EOF
}
match := re.FindIndex(data)
if match != nil {
// 如果找到了";\n",判断是否为最后一行
if match[1] == len(data) {
// 如果是最后一行,则返回完整的切片
return len(data), data, nil
}
// 否则,返回到";\n"之后,并且包括";\n"本身
return match[1], data[:match[1]], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
return scanner
}
func SplitStatementToPieces(sql string) ([]string, error) {
return sqlparser.SplitStatementToPieces(sql)
}

View File

@@ -0,0 +1,98 @@
package sqlparser
import (
"github.com/kanzihuang/vitess/go/vt/sqlparser"
"github.com/stretchr/testify/require"
sqlparser_xwb1989 "github.com/xwb1989/sqlparser"
"strings"
"testing"
)
func Test_ParseNext_WithCurrentDate(t *testing.T) {
tests := []struct {
name string
input string
want string
wantXwb1989 string
err string
}{
{
name: "create table with current_timestamp",
input: "create table tbl (\n\tcreate_at datetime default current_timestamp()\n)",
// xwb1989/sqlparser 不支持 current_timestamp()
wantXwb1989: "create table tbl",
},
{
name: "create table with current_date",
input: "create table tbl (\n\tcreate_at date default current_date()\n)",
// xwb1989/sqlparser 不支持 current_date()
wantXwb1989: "create table tbl",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
token := sqlparser.NewReaderTokenizer(strings.NewReader(test.input))
tree, err := sqlparser.ParseNext(token)
if len(test.err) > 0 {
require.Error(t, err)
require.Contains(t, err.Error(), test.err)
return
}
require.NoError(t, err)
if len(test.want) == 0 {
test.want = test.input
}
require.Equal(t, test.want, sqlparser.String(tree))
})
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
token := sqlparser_xwb1989.NewTokenizer(strings.NewReader(test.input))
tree, err := sqlparser_xwb1989.ParseNext(token)
if len(test.err) > 0 {
require.Error(t, err)
require.Contains(t, err.Error(), test.err)
return
}
require.NoError(t, err)
if len(test.want) == 0 {
test.want = test.input
}
require.Equal(t, test.wantXwb1989, sqlparser_xwb1989.String(tree))
})
}
}
func Test_SplitSqls(t *testing.T) {
tests := []struct {
name string
input string
want string
}{
{
name: "create table with current_timestamp",
input: "create table tbl (\n\tcreate_at datetime default current_timestamp()\n)",
},
{
name: "create table with current_date",
input: "create table tbl (\n\tcreate_at date default current_date()\n)",
},
{
name: "select with ';\n'",
input: "select 'the first line;\nthe second line;\n'",
// SplitSqls split statements by ';\n'
want: "select 'the first line;\n",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
scanner := splitSqls(strings.NewReader(test.input))
require.True(t, scanner.Scan())
got := scanner.Text()
if len(test.want) == 0 {
test.want = test.input
}
require.Equal(t, test.want, got)
})
}
}