mirror of
https://gitee.com/dromara/mayfly-go
synced 2026-05-19 01:15:40 +08:00
419 lines
9.0 KiB
Go
419 lines
9.0 KiB
Go
package tokenizer
|
||
|
||
import (
|
||
"strings"
|
||
"unicode"
|
||
)
|
||
|
||
// DialectConfig 定义不同 SQL 方言的配置
|
||
type DialectConfig struct {
|
||
// 反引号作为标识符引号(MySQL)
|
||
BacktickAsIdentifier bool
|
||
// 双引号作为标识符引号(PostgreSQL)
|
||
DoubleQuoteAsIdentifier bool
|
||
// 支持 # 行注释(MySQL)
|
||
HashLineComment bool
|
||
// 支持 $tag$ 风格字符串/标识符(PostgreSQL)
|
||
DollarQuote bool
|
||
// 额外关键字集合(合并到标准关键字中)
|
||
ExtraKeywords map[string]bool
|
||
}
|
||
|
||
// Tokenizer 将 SQL 字符串拆分为 Token 序列
|
||
type Tokenizer struct {
|
||
sql string
|
||
pos int
|
||
length int
|
||
config DialectConfig
|
||
Tokens []Token
|
||
current int
|
||
}
|
||
|
||
// New 创建一个新的 Tokenizer
|
||
func New(sql string, config DialectConfig) *Tokenizer {
|
||
t := &Tokenizer{
|
||
sql: sql,
|
||
pos: 0,
|
||
length: len(sql),
|
||
config: config,
|
||
Tokens: make([]Token, 0),
|
||
}
|
||
t.tokenize()
|
||
// 追加 EOF token
|
||
t.Tokens = append(t.Tokens, Token{Type: TokenEOF, Value: "", Pos: t.length, End: t.length})
|
||
return t
|
||
}
|
||
|
||
// tokenize 执行词法分析
|
||
func (t *Tokenizer) tokenize() {
|
||
for t.pos < t.length {
|
||
ch := t.sql[t.pos]
|
||
|
||
// 跳过空白字符
|
||
if isWhitespace(ch) {
|
||
t.pos++
|
||
continue
|
||
}
|
||
|
||
// 行注释 --
|
||
if ch == '-' && t.pos+1 < t.length && t.sql[t.pos+1] == '-' {
|
||
t.skipLineComment()
|
||
continue
|
||
}
|
||
|
||
// MySQL # 行注释
|
||
if t.config.HashLineComment && ch == '#' {
|
||
t.skipLineComment()
|
||
continue
|
||
}
|
||
|
||
// 块注释 /* */
|
||
if ch == '/' && t.pos+1 < t.length && t.sql[t.pos+1] == '*' {
|
||
t.skipBlockComment()
|
||
continue
|
||
}
|
||
|
||
// 字符串字面量 '...'
|
||
if ch == '\'' {
|
||
t.readString('\'')
|
||
continue
|
||
}
|
||
|
||
// 双引号字符串 "..."(如果不作为标识符引号)
|
||
if ch == '"' && !t.config.DoubleQuoteAsIdentifier {
|
||
t.readString('"')
|
||
continue
|
||
}
|
||
|
||
// 双引号标识符 "..."(PostgreSQL)
|
||
if ch == '"' && t.config.DoubleQuoteAsIdentifier {
|
||
t.readQuotedIdentifier('"')
|
||
continue
|
||
}
|
||
|
||
// 反引号标识符 `...`(MySQL)
|
||
if ch == '`' && t.config.BacktickAsIdentifier {
|
||
t.readQuotedIdentifier('`')
|
||
continue
|
||
}
|
||
|
||
// PostgreSQL $tag$ ... $tag$
|
||
if t.config.DollarQuote && ch == '$' {
|
||
if t.readDollarQuote() {
|
||
continue
|
||
}
|
||
}
|
||
|
||
// 数字
|
||
if isDigit(ch) {
|
||
t.readNumber()
|
||
continue
|
||
}
|
||
|
||
// 标识符或关键字(字母、_、@ 开头)
|
||
if isIdentifierStart(ch) {
|
||
t.readIdentifierOrKeyword()
|
||
continue
|
||
}
|
||
|
||
// 运算符和标点符号
|
||
if isOperatorStart(ch) {
|
||
t.readOperator()
|
||
continue
|
||
}
|
||
|
||
if isPunctuation(ch) {
|
||
t.Tokens = append(t.Tokens, Token{
|
||
Type: TokenPunctuation,
|
||
Value: string(ch),
|
||
Pos: t.pos,
|
||
End: t.pos + 1,
|
||
})
|
||
t.pos++
|
||
continue
|
||
}
|
||
|
||
// 未知字符,跳过
|
||
t.pos++
|
||
}
|
||
}
|
||
|
||
// skipLineComment 跳过一个行注释(到行尾或 EOF)
|
||
func (t *Tokenizer) skipLineComment() {
|
||
for t.pos < t.length && t.sql[t.pos] != '\n' {
|
||
t.pos++
|
||
}
|
||
}
|
||
|
||
// skipBlockComment 跳过一个块注释 /* */
|
||
func (t *Tokenizer) skipBlockComment() {
|
||
t.pos += 2 // 跳过 /*
|
||
for t.pos < t.length {
|
||
if t.sql[t.pos] == '*' && t.pos+1 < t.length && t.sql[t.pos+1] == '/' {
|
||
t.pos += 2
|
||
return
|
||
}
|
||
t.pos++
|
||
}
|
||
}
|
||
|
||
// readString 读取一个单引号或双引号字符串字面量
|
||
func (t *Tokenizer) readString(quote byte) {
|
||
start := t.pos
|
||
t.pos++ // 跳过起始引号
|
||
for t.pos < t.length {
|
||
ch := t.sql[t.pos]
|
||
if ch == quote {
|
||
// 检查是否是转义(两个连续引号)
|
||
if t.pos+1 < t.length && t.sql[t.pos+1] == quote {
|
||
t.pos += 2
|
||
continue
|
||
}
|
||
t.pos++ // 跳过结束引号
|
||
break
|
||
}
|
||
// MySQL 风格转义 \'
|
||
if ch == '\\' && t.pos+1 < t.length {
|
||
t.pos += 2
|
||
continue
|
||
}
|
||
t.pos++
|
||
}
|
||
t.Tokens = append(t.Tokens, Token{
|
||
Type: TokenString,
|
||
Value: t.sql[start:t.pos],
|
||
Pos: start,
|
||
End: t.pos,
|
||
})
|
||
}
|
||
|
||
// readQuotedIdentifier 读取一个带引号的标识符(反引号或双引号)
|
||
func (t *Tokenizer) readQuotedIdentifier(quote byte) {
|
||
start := t.pos
|
||
t.pos++ // 跳过起始引号
|
||
for t.pos < t.length {
|
||
ch := t.sql[t.pos]
|
||
if ch == quote {
|
||
// 检查转义引号(如 "a""b" 或 ``a``b``)
|
||
if t.pos+1 < t.length && t.sql[t.pos+1] == quote {
|
||
t.pos += 2
|
||
continue
|
||
}
|
||
t.pos++ // 跳过结束引号
|
||
break
|
||
}
|
||
t.pos++
|
||
}
|
||
t.Tokens = append(t.Tokens, Token{
|
||
Type: TokenIdentifier,
|
||
Value: t.sql[start:t.pos],
|
||
Pos: start,
|
||
End: t.pos,
|
||
})
|
||
}
|
||
|
||
// readDollarQuote 读取 PostgreSQL $tag$ ... $tag$ 风格的引号内容
|
||
func (t *Tokenizer) readDollarQuote() bool {
|
||
start := t.pos
|
||
// 读取 $tag$
|
||
tagEnd := t.readDollarTag()
|
||
if tagEnd < 0 {
|
||
return false
|
||
}
|
||
tag := t.sql[start : tagEnd+1] // 包含 $tag$
|
||
// 查找结束标记(从 tag 之后开始)
|
||
searchPos := tagEnd + 1
|
||
for searchPos < t.length {
|
||
if strings.HasPrefix(t.sql[searchPos:], tag) {
|
||
searchPos += len(tag)
|
||
t.Tokens = append(t.Tokens, Token{
|
||
Type: TokenString,
|
||
Value: t.sql[start:searchPos],
|
||
Pos: start,
|
||
End: searchPos,
|
||
})
|
||
t.pos = searchPos
|
||
return true
|
||
}
|
||
searchPos++
|
||
}
|
||
// 未找到结束标记,回退
|
||
return false
|
||
}
|
||
|
||
// readDollarTag 读取 PostgreSQL $tag$ 中的 tag,返回结束 $ 的位置
|
||
func (t *Tokenizer) readDollarTag() int {
|
||
if t.sql[t.pos] != '$' {
|
||
return -1
|
||
}
|
||
pos := t.pos + 1
|
||
for pos < t.length {
|
||
ch := t.sql[pos]
|
||
if ch == '$' {
|
||
return pos
|
||
}
|
||
if !unicode.IsLetter(rune(ch)) && !unicode.IsDigit(rune(ch)) && ch != '_' {
|
||
return -1
|
||
}
|
||
pos++
|
||
}
|
||
return -1
|
||
}
|
||
|
||
// readNumber 读取一个数字(整数或浮点数)
|
||
func (t *Tokenizer) readNumber() {
|
||
start := t.pos
|
||
for t.pos < t.length && (isDigit(t.sql[t.pos]) || t.sql[t.pos] == '.') {
|
||
t.pos++
|
||
}
|
||
// 支持科学计数法 e.g. 1e10, 1.5E-3
|
||
if t.pos < t.length && (t.sql[t.pos] == 'e' || t.sql[t.pos] == 'E') {
|
||
t.pos++
|
||
if t.pos < t.length && (t.sql[t.pos] == '+' || t.sql[t.pos] == '-') {
|
||
t.pos++
|
||
}
|
||
for t.pos < t.length && isDigit(t.sql[t.pos]) {
|
||
t.pos++
|
||
}
|
||
}
|
||
t.Tokens = append(t.Tokens, Token{
|
||
Type: TokenNumber,
|
||
Value: t.sql[start:t.pos],
|
||
Pos: start,
|
||
End: t.pos,
|
||
})
|
||
}
|
||
|
||
// readIdentifierOrKeyword 读取标识符或关键字
|
||
func (t *Tokenizer) readIdentifierOrKeyword() {
|
||
start := t.pos
|
||
for t.pos < t.length && isIdentifierPart(t.sql[t.pos]) {
|
||
t.pos++
|
||
}
|
||
value := t.sql[start:t.pos]
|
||
upper := strings.ToUpper(value)
|
||
|
||
tokType := TokenIdentifier
|
||
if Keywords[upper] {
|
||
tokType = TokenKeyword
|
||
} else if t.config.ExtraKeywords[upper] {
|
||
tokType = TokenKeyword
|
||
}
|
||
|
||
t.Tokens = append(t.Tokens, Token{
|
||
Type: tokType,
|
||
Value: value,
|
||
Pos: start,
|
||
End: t.pos,
|
||
})
|
||
}
|
||
|
||
// readOperator 读取运算符
|
||
func (t *Tokenizer) readOperator() {
|
||
start := t.pos
|
||
// 尝试读取多字符运算符
|
||
if t.pos+1 < t.length {
|
||
two := t.sql[t.pos : t.pos+2]
|
||
if two == "<=" || two == ">=" || two == "<>" || two == "!=" ||
|
||
two == "||" || two == "::" || two == "->" || two == "->>" ||
|
||
two == "=>" || two == ".." {
|
||
// PostgreSQL :: 类型转换, -> JSON, ->> JSON text, => key-value, .. 范围
|
||
t.pos += 2
|
||
// 检查 ->>(三字符)
|
||
if two == "->" && t.pos < t.length && t.sql[t.pos] == '>' {
|
||
t.pos++
|
||
}
|
||
t.Tokens = append(t.Tokens, Token{
|
||
Type: TokenOperator,
|
||
Value: t.sql[start:t.pos],
|
||
Pos: start,
|
||
End: t.pos,
|
||
})
|
||
return
|
||
}
|
||
}
|
||
t.pos++
|
||
t.Tokens = append(t.Tokens, Token{
|
||
Type: TokenOperator,
|
||
Value: t.sql[start:t.pos],
|
||
Pos: start,
|
||
End: t.pos,
|
||
})
|
||
}
|
||
|
||
// Peek 预览当前 token(不移动位置)
|
||
func (t *Tokenizer) Peek() Token {
|
||
if t.current >= len(t.Tokens) {
|
||
return t.Tokens[len(t.Tokens)-1] // EOF
|
||
}
|
||
return t.Tokens[t.current]
|
||
}
|
||
|
||
// Next 返回当前 token 并移动到下一个
|
||
func (t *Tokenizer) Next() Token {
|
||
if t.current >= len(t.Tokens) {
|
||
return t.Tokens[len(t.Tokens)-1] // EOF
|
||
}
|
||
tok := t.Tokens[t.current]
|
||
t.current++
|
||
return tok
|
||
}
|
||
|
||
// Consume 消耗当前位置的 token(等同于 Next,为了可读性)
|
||
func (t *Tokenizer) Consume() Token {
|
||
return t.Next()
|
||
}
|
||
|
||
// Pos 返回当前 token 索引
|
||
func (t *Tokenizer) Pos() int {
|
||
return t.current
|
||
}
|
||
|
||
// SetPos 设置当前 token 索引
|
||
func (t *Tokenizer) SetPos(p int) {
|
||
t.current = p
|
||
}
|
||
|
||
// Length 返回 token 总数
|
||
func (t *Tokenizer) Length() int {
|
||
return len(t.Tokens)
|
||
}
|
||
|
||
// TokenAt 获取指定索引的 token
|
||
func (t *Tokenizer) TokenAt(idx int) Token {
|
||
if idx < 0 || idx >= len(t.Tokens) {
|
||
return t.Tokens[len(t.Tokens)-1]
|
||
}
|
||
return t.Tokens[idx]
|
||
}
|
||
|
||
// helper functions
|
||
func isWhitespace(ch byte) bool {
|
||
return ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r'
|
||
}
|
||
|
||
func isDigit(ch byte) bool {
|
||
return ch >= '0' && ch <= '9'
|
||
}
|
||
|
||
func isLetter(ch byte) bool {
|
||
return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z')
|
||
}
|
||
|
||
func isIdentifierStart(ch byte) bool {
|
||
return isLetter(ch) || ch == '_' || ch == '@'
|
||
}
|
||
|
||
func isIdentifierPart(ch byte) bool {
|
||
return isLetter(ch) || isDigit(ch) || ch == '_' || ch == '@' || ch == '$'
|
||
}
|
||
|
||
func isOperatorStart(ch byte) bool {
|
||
return ch == '+' || ch == '-' || ch == '*' || ch == '/' || ch == '=' ||
|
||
ch == '<' || ch == '>' || ch == '!' || ch == '|' || ch == ':' || ch == '~'
|
||
}
|
||
|
||
func isPunctuation(ch byte) bool {
|
||
return ch == '(' || ch == ')' || ch == ',' || ch == ';' || ch == '.'
|
||
}
|