feat: sql脚本执行调整

This commit is contained in:
meilin.huang
2024-10-18 12:32:53 +08:00
parent e135e4ce64
commit a726927a28
8 changed files with 161 additions and 187 deletions

View File

@@ -2,18 +2,15 @@ package sqlparser
import (
"bufio"
"bytes"
"io"
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
"regexp"
"strings"
"unicode/utf8"
)
type DbDialect string
// const (
// mysql DbDialect = "mysql"
// pgsql DbDialect = "pgsql"
// )
type SqlParser interface {
// sql解析
@@ -21,63 +18,102 @@ type SqlParser interface {
Parse(stmt string) ([]sqlstmt.Stmt, error)
}
// var (
// parsers = make(map[string]SqlParser)
// )
// // 注册数据库类型与dbmeta
// func Register(dialect string, parser SqlParser) {
// parsers[dialect] = parser
// }
// func getParser(dialect string) (SqlParser, error) {
// parser, ok := parsers[dialect]
// if !ok {
// return nil, errors.New("不存在该parser")
// }
// return parser, nil
// }
// 解析sql
// @param dialect 方言
// @param stmt sql语句
// func Parse(dialect string, stmt string) ([]sqlstmt.Stmt, error) {
// if parser, err := getParser(dialect); err != nil {
// return nil, err
// } else {
// return parser.Parse(stmt)
// }
// }
var sqlSplitRegexp = regexp.MustCompile(`\s*;\s*\n`)
// SplitSqls 根据;\n切割sql
func SplitSqls(r io.Reader) *bufio.Scanner {
scanner := bufio.NewScanner(r)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, io.EOF
}
match := sqlSplitRegexp.FindIndex(data)
if match != nil {
// 如果找到了";\n",判断是否为最后一行
if match[1] == len(data) {
// 如果是最后一行,则返回完整的切片
return len(data), data, nil
}
// 否则,返回到";\n"之后,并且包括";\n"本身
return match[1], data[:match[1]], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
return scanner
// SQLSplit sql切割
func SQLSplit(r io.Reader, callback SQLCallback) error {
return parseSQL(r, callback)
}
// SQLCallback 是解析出一条 SQL 语句后的回调函数
type SQLCallback func(sql string) error
func parseSQL(r io.Reader, callback SQLCallback) error {
reader := bufio.NewReaderSize(r, 512*1024)
buffer := new(bytes.Buffer) // 使用 bytes.Buffer 来处理数据
var currentStatement bytes.Buffer
var inString bool
var inMultiLineComment bool
var inSingleLineComment bool
var stringDelimiter rune
for {
// 读取数据到缓冲区
data, err := reader.ReadBytes('\n') // 按行读取
if err == io.EOF && len(data) == 0 {
break
}
if err != nil && err != io.EOF {
return err
}
buffer.Write(data)
// 处理缓冲区中的数据
for buffer.Len() > 0 {
r, size := utf8.DecodeRune(buffer.Bytes())
if r == utf8.RuneError && size == 1 {
// 如果解码出错,说明数据不完整,继续读取更多数据
break
}
switch {
case inMultiLineComment:
if r == '*' && buffer.Len() >= 2 && buffer.Bytes()[1] == '/' {
inMultiLineComment = false
buffer.Next(2) // 跳过 '*/'
} else {
buffer.Next(size)
}
case inSingleLineComment:
if r == '\n' {
inSingleLineComment = false
}
buffer.Next(size)
case inString:
if r == stringDelimiter {
inString = false
}
currentStatement.WriteRune(r)
buffer.Next(size)
case r == '/' && buffer.Len() >= 2 && buffer.Bytes()[1] == '*':
inMultiLineComment = true
buffer.Next(2) // 跳过 '/*'
case r == '-' && buffer.Len() >= 2 && buffer.Bytes()[1] == '-':
inSingleLineComment = true
buffer.Next(2) // 跳过 '--'
case r == '\'' || r == '"':
inString = true
stringDelimiter = r
currentStatement.WriteRune(r)
buffer.Next(size)
case r == ';' && !inString && !inMultiLineComment && !inSingleLineComment:
sql := strings.TrimSpace(currentStatement.String())
if sql != "" {
if err := callback(sql); err != nil {
return err
}
}
currentStatement.Reset()
buffer.Next(size)
default:
currentStatement.WriteRune(r)
buffer.Next(size)
}
}
// 如果读取到 EOF 并且缓冲区为空,退出循环
if err == io.EOF && buffer.Len() == 0 {
break
}
}
// 处理最后剩余的缓冲区
if currentStatement.Len() > 0 {
sql := strings.TrimSpace(currentStatement.String())
if sql != "" {
if err := callback(sql); err != nil {
return err
}
}
}
return nil
}