mirror of
				https://gitee.com/dromara/mayfly-go
				synced 2025-11-04 00:10:25 +08:00 
			
		
		
		
	fix: 执行或导入 SQL 脚本支持 PostgreSQL
This commit is contained in:
		@@ -13,7 +13,7 @@ require (
 | 
			
		||||
	github.com/go-sql-driver/mysql v1.7.1
 | 
			
		||||
	github.com/golang-jwt/jwt/v5 v5.0.0
 | 
			
		||||
	github.com/gorilla/websocket v1.5.0
 | 
			
		||||
	github.com/kanzihuang/vitess/go/vt/sqlparser v0.0.0-20231007020222-b91ee5ef3b31
 | 
			
		||||
	github.com/kanzihuang/vitess/go/vt/sqlparser v0.0.0-20231014104824-e3b9aa5415a4
 | 
			
		||||
	github.com/lib/pq v1.10.9
 | 
			
		||||
	github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230712084735-068dc2aee82d
 | 
			
		||||
	github.com/mojocn/base64Captcha v1.3.5 // 验证码
 | 
			
		||||
 
 | 
			
		||||
@@ -2,13 +2,10 @@ package api
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"mayfly-go/pkg/utils/collx"
 | 
			
		||||
	"mayfly-go/pkg/utils/uniqueid"
 | 
			
		||||
	"mayfly-go/pkg/ws"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/kanzihuang/vitess/go/vt/sqlparser"
 | 
			
		||||
	"github.com/lib/pq"
 | 
			
		||||
 | 
			
		||||
	"io"
 | 
			
		||||
	"mayfly-go/internal/db/api/form"
 | 
			
		||||
	"mayfly-go/internal/db/api/vo"
 | 
			
		||||
	"mayfly-go/internal/db/application"
 | 
			
		||||
@@ -21,13 +18,13 @@ import (
 | 
			
		||||
	"mayfly-go/pkg/gormx"
 | 
			
		||||
	"mayfly-go/pkg/model"
 | 
			
		||||
	"mayfly-go/pkg/req"
 | 
			
		||||
	"mayfly-go/pkg/sqlparser"
 | 
			
		||||
	"mayfly-go/pkg/utils/collx"
 | 
			
		||||
	"mayfly-go/pkg/utils/stringx"
 | 
			
		||||
	"mayfly-go/pkg/utils/uniqueid"
 | 
			
		||||
	"mayfly-go/pkg/ws"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Db struct {
 | 
			
		||||
@@ -179,7 +176,7 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
 | 
			
		||||
 | 
			
		||||
	dbConn := d.DbApp.GetDbConnection(dbId, dbName)
 | 
			
		||||
	biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s")
 | 
			
		||||
	rc.ReqParam = fmt.Sprintf("%s -> filename: %s", dbConn.Info.GetLogDesc(), filename)
 | 
			
		||||
	rc.ReqParam = fmt.Sprintf("filename: %s -> %s", filename, dbConn.Info.GetLogDesc())
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
		var errInfo string
 | 
			
		||||
@@ -190,7 +187,7 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
 | 
			
		||||
			errInfo = t
 | 
			
		||||
		}
 | 
			
		||||
		if len(errInfo) > 0 {
 | 
			
		||||
			d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo)))
 | 
			
		||||
			d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s]执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo)))
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
@@ -211,13 +208,8 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
 | 
			
		||||
		Terminated:         true,
 | 
			
		||||
	}).WithCategory(progressCategory))
 | 
			
		||||
 | 
			
		||||
	var parser sqlparser.Parser
 | 
			
		||||
	if dbConn.Info.Type == entity.DbTypeMysql {
 | 
			
		||||
		parser = sqlparser.NewMysqlParser(file)
 | 
			
		||||
	} else {
 | 
			
		||||
		parser = sqlparser.NewPostgresParser(file)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var sql string
 | 
			
		||||
	tokenizer := sqlparser.NewReaderTokenizer(file, sqlparser.WithCacheInBuffer())
 | 
			
		||||
	ticker := time.NewTicker(time.Second * 1)
 | 
			
		||||
	defer ticker.Stop()
 | 
			
		||||
	for {
 | 
			
		||||
@@ -231,24 +223,30 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
 | 
			
		||||
			}).WithCategory(progressCategory))
 | 
			
		||||
		default:
 | 
			
		||||
		}
 | 
			
		||||
		err = parser.Next()
 | 
			
		||||
		sql, err = sqlparser.SplitNext(tokenizer)
 | 
			
		||||
		if err == io.EOF {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s] 解析SQL失败: [%s]", filename, dbConn.Info.GetLogDesc(), err.Error())))
 | 
			
		||||
			d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本解析失败", fmt.Sprintf("[%s][%s] 解析SQL失败: [%s]", filename, dbConn.Info.GetLogDesc(), err.Error())))
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		sql := parser.Current()
 | 
			
		||||
		const prefixUse = "use "
 | 
			
		||||
		if strings.HasPrefix(sql, prefixUse) {
 | 
			
		||||
			dbNameExec := strings.Trim(sql[len(prefixUse):], " `;\n")
 | 
			
		||||
			if len(dbNameExec) > 0 {
 | 
			
		||||
				dbConn = d.DbApp.GetDbConnection(dbId, dbNameExec)
 | 
			
		||||
		const prefixUSE = "USE "
 | 
			
		||||
		if strings.HasPrefix(sql, prefixUSE) || strings.HasPrefix(sql, prefixUse) {
 | 
			
		||||
			var stmt sqlparser.Statement
 | 
			
		||||
			stmt, err = sqlparser.Parse(sql)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本解析失败", fmt.Sprintf("[%s][%s] 解析SQL失败: [%s]", filename, dbConn.Info.GetLogDesc(), err.Error())))
 | 
			
		||||
			}
 | 
			
		||||
			stmtUse, ok := stmt.(*sqlparser.Use)
 | 
			
		||||
			if !ok {
 | 
			
		||||
				d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本解析失败", fmt.Sprintf("[%s][%s] 解析SQL失败: [%s]", filename, dbConn.Info.GetLogDesc(), sql)))
 | 
			
		||||
			}
 | 
			
		||||
			dbConn = d.DbApp.GetDbConnection(dbId, stmtUse.DBName.String())
 | 
			
		||||
			biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s")
 | 
			
		||||
			execReq.DbConn = dbConn
 | 
			
		||||
		}
 | 
			
		||||
		}
 | 
			
		||||
		// 需要记录执行记录
 | 
			
		||||
		const maxRecordStatements = 64
 | 
			
		||||
		if executedStatements < maxRecordStatements {
 | 
			
		||||
@@ -264,7 +262,7 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
 | 
			
		||||
		}
 | 
			
		||||
		executedStatements++
 | 
			
		||||
	}
 | 
			
		||||
	d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.SuccessSysMsg("sql脚本执行成功", fmt.Sprintf("[%s]执行完成 -> %s", filename, dbConn.Info.GetLogDesc())))
 | 
			
		||||
	d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.SuccessSysMsg("sql脚本执行成功", fmt.Sprintf("sql脚本执行完成:%s", rc.ReqParam)))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 数据库dump
 | 
			
		||||
 
 | 
			
		||||
@@ -1,99 +0,0 @@
 | 
			
		||||
package sqlparser
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"github.com/kanzihuang/vitess/go/vt/sqlparser"
 | 
			
		||||
	"io"
 | 
			
		||||
	"regexp"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Parser interface {
 | 
			
		||||
	Next() error
 | 
			
		||||
	Current() string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var _ Parser = &MysqlParser{}
 | 
			
		||||
var _ Parser = &PostgresParser{}
 | 
			
		||||
 | 
			
		||||
type MysqlParser struct {
 | 
			
		||||
	tokenizer *sqlparser.Tokenizer
 | 
			
		||||
	statement string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewMysqlParser(reader io.Reader) *MysqlParser {
 | 
			
		||||
	return &MysqlParser{
 | 
			
		||||
		tokenizer: sqlparser.NewReaderTokenizer(reader),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (parser *MysqlParser) Next() error {
 | 
			
		||||
	statement, err := sqlparser.ParseNext(parser.tokenizer)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		parser.statement = ""
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	parser.statement = sqlparser.String(statement)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (parser *MysqlParser) Current() string {
 | 
			
		||||
	return parser.statement
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type PostgresParser struct {
 | 
			
		||||
	scanner   *bufio.Scanner
 | 
			
		||||
	statement string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewPostgresParser(reader io.Reader) *PostgresParser {
 | 
			
		||||
	return &PostgresParser{
 | 
			
		||||
		scanner: splitSqls(reader),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (parser *PostgresParser) Next() error {
 | 
			
		||||
	if !parser.scanner.Scan() {
 | 
			
		||||
		return io.EOF
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (parser *PostgresParser) Current() string {
 | 
			
		||||
	return parser.scanner.Text()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 根据;\n切割sql
 | 
			
		||||
func splitSqls(r io.Reader) *bufio.Scanner {
 | 
			
		||||
	scanner := bufio.NewScanner(r)
 | 
			
		||||
	re := regexp.MustCompile(`\s*;\s*\n`)
 | 
			
		||||
 | 
			
		||||
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
			
		||||
		if atEOF && len(data) == 0 {
 | 
			
		||||
			return 0, nil, io.EOF
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		match := re.FindIndex(data)
 | 
			
		||||
 | 
			
		||||
		if match != nil {
 | 
			
		||||
			// 如果找到了";\n",判断是否为最后一行
 | 
			
		||||
			if match[1] == len(data) {
 | 
			
		||||
				// 如果是最后一行,则返回完整的切片
 | 
			
		||||
				return len(data), data, nil
 | 
			
		||||
			}
 | 
			
		||||
			// 否则,返回到";\n"之后,并且包括";\n"本身
 | 
			
		||||
			return match[1], data[:match[1]], nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if atEOF {
 | 
			
		||||
			return len(data), data, nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return 0, nil, nil
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	return scanner
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SplitStatementToPieces(sql string) ([]string, error) {
 | 
			
		||||
	return sqlparser.SplitStatementToPieces(sql)
 | 
			
		||||
}
 | 
			
		||||
@@ -1,98 +0,0 @@
 | 
			
		||||
package sqlparser
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"strings"
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
	"github.com/kanzihuang/vitess/go/vt/sqlparser"
 | 
			
		||||
	"github.com/stretchr/testify/require"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func Test_ParseNext_WithCurrentDate(t *testing.T) {
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name        string
 | 
			
		||||
		input       string
 | 
			
		||||
		want        string
 | 
			
		||||
		wantXwb1989 string
 | 
			
		||||
		err         string
 | 
			
		||||
	}{
 | 
			
		||||
		{
 | 
			
		||||
			name:  "create table with current_timestamp",
 | 
			
		||||
			input: "create table tbl (\n\tcreate_at datetime default current_timestamp()\n)",
 | 
			
		||||
			// xwb1989/sqlparser 不支持 current_timestamp()
 | 
			
		||||
			wantXwb1989: "create table tbl",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:  "create table with current_date",
 | 
			
		||||
			input: "create table tbl (\n\tcreate_at date default current_date()\n)",
 | 
			
		||||
			// xwb1989/sqlparser 不支持 current_date()
 | 
			
		||||
			wantXwb1989: "create table tbl",
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
	for _, test := range tests {
 | 
			
		||||
		t.Run(test.name, func(t *testing.T) {
 | 
			
		||||
			token := sqlparser.NewReaderTokenizer(strings.NewReader(test.input))
 | 
			
		||||
			tree, err := sqlparser.ParseNext(token)
 | 
			
		||||
			if len(test.err) > 0 {
 | 
			
		||||
				require.Error(t, err)
 | 
			
		||||
				require.Contains(t, err.Error(), test.err)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			require.NoError(t, err)
 | 
			
		||||
			if len(test.want) == 0 {
 | 
			
		||||
				test.want = test.input
 | 
			
		||||
			}
 | 
			
		||||
			require.Equal(t, test.want, sqlparser.String(tree))
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
	// for _, test := range tests {
 | 
			
		||||
	// 	t.Run(test.name, func(t *testing.T) {
 | 
			
		||||
	// 		token := sqlparser_xwb1989.NewTokenizer(strings.NewReader(test.input))
 | 
			
		||||
	// 		tree, err := sqlparser_xwb1989.ParseNext(token)
 | 
			
		||||
	// 		if len(test.err) > 0 {
 | 
			
		||||
	// 			require.Error(t, err)
 | 
			
		||||
	// 			require.Contains(t, err.Error(), test.err)
 | 
			
		||||
	// 			return
 | 
			
		||||
	// 		}
 | 
			
		||||
	// 		require.NoError(t, err)
 | 
			
		||||
	// 		if len(test.want) == 0 {
 | 
			
		||||
	// 			test.want = test.input
 | 
			
		||||
	// 		}
 | 
			
		||||
	// 		require.Equal(t, test.wantXwb1989, sqlparser_xwb1989.String(tree))
 | 
			
		||||
	// 	})
 | 
			
		||||
	// }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Test_SplitSqls(t *testing.T) {
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name  string
 | 
			
		||||
		input string
 | 
			
		||||
		want  string
 | 
			
		||||
	}{
 | 
			
		||||
		{
 | 
			
		||||
			name:  "create table with current_timestamp",
 | 
			
		||||
			input: "create table tbl (\n\tcreate_at datetime default current_timestamp()\n)",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:  "create table with current_date",
 | 
			
		||||
			input: "create table tbl (\n\tcreate_at date default current_date()\n)",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:  "select with ';\n'",
 | 
			
		||||
			input: "select 'the first line;\nthe second line;\n'",
 | 
			
		||||
			// SplitSqls split statements by ';\n'
 | 
			
		||||
			want: "select 'the first line;\n",
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
	for _, test := range tests {
 | 
			
		||||
		t.Run(test.name, func(t *testing.T) {
 | 
			
		||||
			scanner := splitSqls(strings.NewReader(test.input))
 | 
			
		||||
			require.True(t, scanner.Scan())
 | 
			
		||||
			got := scanner.Text()
 | 
			
		||||
			if len(test.want) == 0 {
 | 
			
		||||
				test.want = test.input
 | 
			
		||||
			}
 | 
			
		||||
			require.Equal(t, test.want, got)
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
		Reference in New Issue
	
	Block a user