mirror of
https://gitee.com/dromara/mayfly-go
synced 2025-11-02 07:20:24 +08:00
Merge pull request #71 from kanzihuang/fix-exec-postgres-sql-pullrequest
fix: 执行或导入 SQL 脚本支持 PostgreSQL
This commit is contained in:
@@ -13,7 +13,7 @@ require (
|
||||
github.com/go-sql-driver/mysql v1.7.1
|
||||
github.com/golang-jwt/jwt/v5 v5.0.0
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
github.com/kanzihuang/vitess/go/vt/sqlparser v0.0.0-20231007020222-b91ee5ef3b31
|
||||
github.com/kanzihuang/vitess/go/vt/sqlparser v0.0.0-20231014104824-e3b9aa5415a4
|
||||
github.com/lib/pq v1.10.9
|
||||
github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230712084735-068dc2aee82d
|
||||
github.com/mojocn/base64Captcha v1.3.5 // 验证码
|
||||
|
||||
@@ -2,13 +2,10 @@ package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"mayfly-go/pkg/utils/uniqueid"
|
||||
"mayfly-go/pkg/ws"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/kanzihuang/vitess/go/vt/sqlparser"
|
||||
"github.com/lib/pq"
|
||||
|
||||
"io"
|
||||
"mayfly-go/internal/db/api/form"
|
||||
"mayfly-go/internal/db/api/vo"
|
||||
"mayfly-go/internal/db/application"
|
||||
@@ -21,13 +18,13 @@ import (
|
||||
"mayfly-go/pkg/gormx"
|
||||
"mayfly-go/pkg/model"
|
||||
"mayfly-go/pkg/req"
|
||||
"mayfly-go/pkg/sqlparser"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"mayfly-go/pkg/utils/stringx"
|
||||
"mayfly-go/pkg/utils/uniqueid"
|
||||
"mayfly-go/pkg/ws"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Db struct {
|
||||
@@ -179,7 +176,7 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
|
||||
|
||||
dbConn := d.DbApp.GetDbConnection(dbId, dbName)
|
||||
biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s")
|
||||
rc.ReqParam = fmt.Sprintf("%s -> filename: %s", dbConn.Info.GetLogDesc(), filename)
|
||||
rc.ReqParam = fmt.Sprintf("filename: %s -> %s", filename, dbConn.Info.GetLogDesc())
|
||||
|
||||
defer func() {
|
||||
var errInfo string
|
||||
@@ -190,7 +187,7 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
|
||||
errInfo = t
|
||||
}
|
||||
if len(errInfo) > 0 {
|
||||
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo)))
|
||||
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s]执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo)))
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -211,13 +208,8 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
|
||||
Terminated: true,
|
||||
}).WithCategory(progressCategory))
|
||||
|
||||
var parser sqlparser.Parser
|
||||
if dbConn.Info.Type == entity.DbTypeMysql {
|
||||
parser = sqlparser.NewMysqlParser(file)
|
||||
} else {
|
||||
parser = sqlparser.NewPostgresParser(file)
|
||||
}
|
||||
|
||||
var sql string
|
||||
tokenizer := sqlparser.NewReaderTokenizer(file, sqlparser.WithCacheInBuffer())
|
||||
ticker := time.NewTicker(time.Second * 1)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
@@ -231,23 +223,29 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
|
||||
}).WithCategory(progressCategory))
|
||||
default:
|
||||
}
|
||||
err = parser.Next()
|
||||
sql, err = sqlparser.SplitNext(tokenizer)
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s] 解析SQL失败: [%s]", filename, dbConn.Info.GetLogDesc(), err.Error())))
|
||||
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本解析失败", fmt.Sprintf("[%s][%s] 解析SQL失败: [%s]", filename, dbConn.Info.GetLogDesc(), err.Error())))
|
||||
return
|
||||
}
|
||||
sql := parser.Current()
|
||||
const prefixUse = "use "
|
||||
if strings.HasPrefix(sql, prefixUse) {
|
||||
dbNameExec := strings.Trim(sql[len(prefixUse):], " `;\n")
|
||||
if len(dbNameExec) > 0 {
|
||||
dbConn = d.DbApp.GetDbConnection(dbId, dbNameExec)
|
||||
biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s")
|
||||
execReq.DbConn = dbConn
|
||||
const prefixUSE = "USE "
|
||||
if strings.HasPrefix(sql, prefixUSE) || strings.HasPrefix(sql, prefixUse) {
|
||||
var stmt sqlparser.Statement
|
||||
stmt, err = sqlparser.Parse(sql)
|
||||
if err != nil {
|
||||
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本解析失败", fmt.Sprintf("[%s][%s] 解析SQL失败: [%s]", filename, dbConn.Info.GetLogDesc(), err.Error())))
|
||||
}
|
||||
stmtUse, ok := stmt.(*sqlparser.Use)
|
||||
if !ok {
|
||||
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本解析失败", fmt.Sprintf("[%s][%s] 解析SQL失败: [%s]", filename, dbConn.Info.GetLogDesc(), sql)))
|
||||
}
|
||||
dbConn = d.DbApp.GetDbConnection(dbId, stmtUse.DBName.String())
|
||||
biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s")
|
||||
execReq.DbConn = dbConn
|
||||
}
|
||||
// 需要记录执行记录
|
||||
const maxRecordStatements = 64
|
||||
@@ -264,7 +262,7 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
|
||||
}
|
||||
executedStatements++
|
||||
}
|
||||
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.SuccessSysMsg("sql脚本执行成功", fmt.Sprintf("[%s]执行完成 -> %s", filename, dbConn.Info.GetLogDesc())))
|
||||
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.SuccessSysMsg("sql脚本执行成功", fmt.Sprintf("sql脚本执行完成:%s", rc.ReqParam)))
|
||||
}
|
||||
|
||||
// 数据库dump
|
||||
|
||||
@@ -1,99 +0,0 @@
|
||||
package sqlparser
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"github.com/kanzihuang/vitess/go/vt/sqlparser"
|
||||
"io"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
type Parser interface {
|
||||
Next() error
|
||||
Current() string
|
||||
}
|
||||
|
||||
var _ Parser = &MysqlParser{}
|
||||
var _ Parser = &PostgresParser{}
|
||||
|
||||
type MysqlParser struct {
|
||||
tokenizer *sqlparser.Tokenizer
|
||||
statement string
|
||||
}
|
||||
|
||||
func NewMysqlParser(reader io.Reader) *MysqlParser {
|
||||
return &MysqlParser{
|
||||
tokenizer: sqlparser.NewReaderTokenizer(reader),
|
||||
}
|
||||
}
|
||||
|
||||
func (parser *MysqlParser) Next() error {
|
||||
statement, err := sqlparser.ParseNext(parser.tokenizer)
|
||||
if err != nil {
|
||||
parser.statement = ""
|
||||
return err
|
||||
}
|
||||
parser.statement = sqlparser.String(statement)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (parser *MysqlParser) Current() string {
|
||||
return parser.statement
|
||||
}
|
||||
|
||||
type PostgresParser struct {
|
||||
scanner *bufio.Scanner
|
||||
statement string
|
||||
}
|
||||
|
||||
func NewPostgresParser(reader io.Reader) *PostgresParser {
|
||||
return &PostgresParser{
|
||||
scanner: splitSqls(reader),
|
||||
}
|
||||
}
|
||||
|
||||
func (parser *PostgresParser) Next() error {
|
||||
if !parser.scanner.Scan() {
|
||||
return io.EOF
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (parser *PostgresParser) Current() string {
|
||||
return parser.scanner.Text()
|
||||
}
|
||||
|
||||
// 根据;\n切割sql
|
||||
func splitSqls(r io.Reader) *bufio.Scanner {
|
||||
scanner := bufio.NewScanner(r)
|
||||
re := regexp.MustCompile(`\s*;\s*\n`)
|
||||
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, io.EOF
|
||||
}
|
||||
|
||||
match := re.FindIndex(data)
|
||||
|
||||
if match != nil {
|
||||
// 如果找到了";\n",判断是否为最后一行
|
||||
if match[1] == len(data) {
|
||||
// 如果是最后一行,则返回完整的切片
|
||||
return len(data), data, nil
|
||||
}
|
||||
// 否则,返回到";\n"之后,并且包括";\n"本身
|
||||
return match[1], data[:match[1]], nil
|
||||
}
|
||||
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
|
||||
return 0, nil, nil
|
||||
})
|
||||
|
||||
return scanner
|
||||
}
|
||||
|
||||
func SplitStatementToPieces(sql string) ([]string, error) {
|
||||
return sqlparser.SplitStatementToPieces(sql)
|
||||
}
|
||||
@@ -1,98 +0,0 @@
|
||||
package sqlparser
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/kanzihuang/vitess/go/vt/sqlparser"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_ParseNext_WithCurrentDate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
wantXwb1989 string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
name: "create table with current_timestamp",
|
||||
input: "create table tbl (\n\tcreate_at datetime default current_timestamp()\n)",
|
||||
// xwb1989/sqlparser 不支持 current_timestamp()
|
||||
wantXwb1989: "create table tbl",
|
||||
},
|
||||
{
|
||||
name: "create table with current_date",
|
||||
input: "create table tbl (\n\tcreate_at date default current_date()\n)",
|
||||
// xwb1989/sqlparser 不支持 current_date()
|
||||
wantXwb1989: "create table tbl",
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
token := sqlparser.NewReaderTokenizer(strings.NewReader(test.input))
|
||||
tree, err := sqlparser.ParseNext(token)
|
||||
if len(test.err) > 0 {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), test.err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
if len(test.want) == 0 {
|
||||
test.want = test.input
|
||||
}
|
||||
require.Equal(t, test.want, sqlparser.String(tree))
|
||||
})
|
||||
}
|
||||
// for _, test := range tests {
|
||||
// t.Run(test.name, func(t *testing.T) {
|
||||
// token := sqlparser_xwb1989.NewTokenizer(strings.NewReader(test.input))
|
||||
// tree, err := sqlparser_xwb1989.ParseNext(token)
|
||||
// if len(test.err) > 0 {
|
||||
// require.Error(t, err)
|
||||
// require.Contains(t, err.Error(), test.err)
|
||||
// return
|
||||
// }
|
||||
// require.NoError(t, err)
|
||||
// if len(test.want) == 0 {
|
||||
// test.want = test.input
|
||||
// }
|
||||
// require.Equal(t, test.wantXwb1989, sqlparser_xwb1989.String(tree))
|
||||
// })
|
||||
// }
|
||||
}
|
||||
|
||||
func Test_SplitSqls(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "create table with current_timestamp",
|
||||
input: "create table tbl (\n\tcreate_at datetime default current_timestamp()\n)",
|
||||
},
|
||||
{
|
||||
name: "create table with current_date",
|
||||
input: "create table tbl (\n\tcreate_at date default current_date()\n)",
|
||||
},
|
||||
{
|
||||
name: "select with ';\n'",
|
||||
input: "select 'the first line;\nthe second line;\n'",
|
||||
// SplitSqls split statements by ';\n'
|
||||
want: "select 'the first line;\n",
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
scanner := splitSqls(strings.NewReader(test.input))
|
||||
require.True(t, scanner.Scan())
|
||||
got := scanner.Text()
|
||||
if len(test.want) == 0 {
|
||||
test.want = test.input
|
||||
}
|
||||
require.Equal(t, test.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user