mirror of
https://gitee.com/dromara/mayfly-go
synced 2025-11-14 05:10:24 +08:00
feat: sql脚本执行调整
This commit is contained in:
@@ -2,18 +2,15 @@ package sqlparser
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
type DbDialect string
|
||||
|
||||
// const (
|
||||
// mysql DbDialect = "mysql"
|
||||
// pgsql DbDialect = "pgsql"
|
||||
// )
|
||||
|
||||
type SqlParser interface {
|
||||
|
||||
// sql解析
|
||||
@@ -21,63 +18,102 @@ type SqlParser interface {
|
||||
Parse(stmt string) ([]sqlstmt.Stmt, error)
|
||||
}
|
||||
|
||||
// var (
|
||||
// parsers = make(map[string]SqlParser)
|
||||
// )
|
||||
|
||||
// // 注册数据库类型与dbmeta
|
||||
// func Register(dialect string, parser SqlParser) {
|
||||
// parsers[dialect] = parser
|
||||
// }
|
||||
|
||||
// func getParser(dialect string) (SqlParser, error) {
|
||||
// parser, ok := parsers[dialect]
|
||||
// if !ok {
|
||||
// return nil, errors.New("不存在该parser")
|
||||
// }
|
||||
// return parser, nil
|
||||
// }
|
||||
|
||||
// 解析sql
|
||||
// @param dialect 方言
|
||||
// @param stmt sql语句
|
||||
// func Parse(dialect string, stmt string) ([]sqlstmt.Stmt, error) {
|
||||
// if parser, err := getParser(dialect); err != nil {
|
||||
// return nil, err
|
||||
// } else {
|
||||
// return parser.Parse(stmt)
|
||||
// }
|
||||
// }
|
||||
|
||||
var sqlSplitRegexp = regexp.MustCompile(`\s*;\s*\n`)
|
||||
|
||||
// SplitSqls 根据;\n切割sql
|
||||
func SplitSqls(r io.Reader) *bufio.Scanner {
|
||||
scanner := bufio.NewScanner(r)
|
||||
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, io.EOF
|
||||
}
|
||||
|
||||
match := sqlSplitRegexp.FindIndex(data)
|
||||
|
||||
if match != nil {
|
||||
// 如果找到了";\n",判断是否为最后一行
|
||||
if match[1] == len(data) {
|
||||
// 如果是最后一行,则返回完整的切片
|
||||
return len(data), data, nil
|
||||
}
|
||||
// 否则,返回到";\n"之后,并且包括";\n"本身
|
||||
return match[1], data[:match[1]], nil
|
||||
}
|
||||
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
|
||||
return 0, nil, nil
|
||||
})
|
||||
|
||||
return scanner
|
||||
// SQLSplit sql切割
|
||||
func SQLSplit(r io.Reader, callback SQLCallback) error {
|
||||
return parseSQL(r, callback)
|
||||
}
|
||||
|
||||
// SQLCallback 是解析出一条 SQL 语句后的回调函数
|
||||
type SQLCallback func(sql string) error
|
||||
|
||||
func parseSQL(r io.Reader, callback SQLCallback) error {
|
||||
reader := bufio.NewReaderSize(r, 512*1024)
|
||||
buffer := new(bytes.Buffer) // 使用 bytes.Buffer 来处理数据
|
||||
var currentStatement bytes.Buffer
|
||||
var inString bool
|
||||
var inMultiLineComment bool
|
||||
var inSingleLineComment bool
|
||||
var stringDelimiter rune
|
||||
|
||||
for {
|
||||
// 读取数据到缓冲区
|
||||
data, err := reader.ReadBytes('\n') // 按行读取
|
||||
if err == io.EOF && len(data) == 0 {
|
||||
break
|
||||
}
|
||||
if err != nil && err != io.EOF {
|
||||
return err
|
||||
}
|
||||
buffer.Write(data)
|
||||
|
||||
// 处理缓冲区中的数据
|
||||
for buffer.Len() > 0 {
|
||||
r, size := utf8.DecodeRune(buffer.Bytes())
|
||||
if r == utf8.RuneError && size == 1 {
|
||||
// 如果解码出错,说明数据不完整,继续读取更多数据
|
||||
break
|
||||
}
|
||||
|
||||
switch {
|
||||
case inMultiLineComment:
|
||||
if r == '*' && buffer.Len() >= 2 && buffer.Bytes()[1] == '/' {
|
||||
inMultiLineComment = false
|
||||
buffer.Next(2) // 跳过 '*/'
|
||||
} else {
|
||||
buffer.Next(size)
|
||||
}
|
||||
case inSingleLineComment:
|
||||
if r == '\n' {
|
||||
inSingleLineComment = false
|
||||
}
|
||||
buffer.Next(size)
|
||||
case inString:
|
||||
if r == stringDelimiter {
|
||||
inString = false
|
||||
}
|
||||
currentStatement.WriteRune(r)
|
||||
buffer.Next(size)
|
||||
case r == '/' && buffer.Len() >= 2 && buffer.Bytes()[1] == '*':
|
||||
inMultiLineComment = true
|
||||
buffer.Next(2) // 跳过 '/*'
|
||||
case r == '-' && buffer.Len() >= 2 && buffer.Bytes()[1] == '-':
|
||||
inSingleLineComment = true
|
||||
buffer.Next(2) // 跳过 '--'
|
||||
case r == '\'' || r == '"':
|
||||
inString = true
|
||||
stringDelimiter = r
|
||||
currentStatement.WriteRune(r)
|
||||
buffer.Next(size)
|
||||
case r == ';' && !inString && !inMultiLineComment && !inSingleLineComment:
|
||||
sql := strings.TrimSpace(currentStatement.String())
|
||||
if sql != "" {
|
||||
if err := callback(sql); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
currentStatement.Reset()
|
||||
buffer.Next(size)
|
||||
default:
|
||||
currentStatement.WriteRune(r)
|
||||
buffer.Next(size)
|
||||
}
|
||||
}
|
||||
|
||||
// 如果读取到 EOF 并且缓冲区为空,退出循环
|
||||
if err == io.EOF && buffer.Len() == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 处理最后剩余的缓冲区
|
||||
if currentStatement.Len() > 0 {
|
||||
sql := strings.TrimSpace(currentStatement.String())
|
||||
if sql != "" {
|
||||
if err := callback(sql); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user