Files
mayfly-go/server/internal/common/utils/stmt.go
2024-10-28 12:13:41 +08:00

119 lines
3.1 KiB
Go

package utils
import (
"bufio"
"bytes"
"io"
"strings"
"unicode/utf8"
)
// StmtCallback stmt回调函数
type StmtCallback func(stmt string) error
// SplitStmts 语句切割(用于以;结尾为一条语句,并且去除// -- /**/等注释)主要由阿里通义灵码提供
func SplitStmts(r io.Reader, callback StmtCallback) error {
reader := bufio.NewReaderSize(r, 512*1024)
buffer := new(bytes.Buffer) // 使用 bytes.Buffer 来处理数据
var currentStatement bytes.Buffer
var inString bool
var inMultiLineComment bool
var inSingleLineComment bool
var stringDelimiter rune
var escapeNextChar bool // 用于处理转义符
for {
// 读取数据到缓冲区
data, err := reader.ReadBytes('\n') // 按行读取
if err == io.EOF && len(data) == 0 {
break
}
if err != nil && err != io.EOF {
return err
}
buffer.Write(data)
// 处理缓冲区中的数据
for buffer.Len() > 0 {
r, size := utf8.DecodeRune(buffer.Bytes())
if r == utf8.RuneError && size == 1 {
// 如果解码出错,说明数据不完整,继续读取更多数据
break
}
switch {
case inMultiLineComment:
if r == '*' && buffer.Len() >= 2 && buffer.Bytes()[1] == '/' {
inMultiLineComment = false
buffer.Next(2) // 跳过 '*/'
} else {
buffer.Next(size)
}
case inSingleLineComment:
if r == '\n' {
inSingleLineComment = false
}
buffer.Next(size)
case inString:
if escapeNextChar {
// 当前字符是转义后的字符,直接写入。如后一个为" 避免进入r==stringDelimiter判断被当做字符串结束符中断
currentStatement.WriteRune(r)
escapeNextChar = false
} else if r == '\\' {
// 当前字符是转义符,设置标志位并写入
escapeNextChar = true
currentStatement.WriteRune(r)
} else if r == stringDelimiter {
// 当前字符是字符串结束符,结束字符串处理
inString = false
currentStatement.WriteRune(r)
} else {
// 其他字符,直接写入
currentStatement.WriteRune(r)
}
buffer.Next(size)
case r == '/' && buffer.Len() >= 2 && buffer.Bytes()[1] == '*':
inMultiLineComment = true
buffer.Next(2) // 跳过 '/*'
case r == '-' && buffer.Len() >= 2 && buffer.Bytes()[1] == '-':
inSingleLineComment = true
buffer.Next(2) // 跳过 '--'
case r == '\'' || r == '"':
inString = true
stringDelimiter = r
currentStatement.WriteRune(r)
buffer.Next(size)
case r == ';' && !inString && !inMultiLineComment && !inSingleLineComment:
sql := strings.TrimSpace(currentStatement.String())
if sql != "" {
if err := callback(sql); err != nil {
return err
}
}
currentStatement.Reset()
buffer.Next(size)
default:
currentStatement.WriteRune(r)
buffer.Next(size)
}
}
// 如果读取到 EOF 并且缓冲区为空,退出循环
if err == io.EOF && buffer.Len() == 0 {
break
}
}
// 处理最后剩余的缓冲区
if currentStatement.Len() > 0 {
sql := strings.TrimSpace(currentStatement.String())
if sql != "" {
if err := callback(sql); err != nil {
return err
}
}
}
return nil
}