Files
mayfly-go/server/internal/db/dbm/sqlparser/tokenizer/tokenizer.go
2026-05-08 20:45:13 +08:00

419 lines
9.0 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package tokenizer
import (
"strings"
"unicode"
)
// DialectConfig 定义不同 SQL 方言的配置
type DialectConfig struct {
// 反引号作为标识符引号MySQL
BacktickAsIdentifier bool
// 双引号作为标识符引号PostgreSQL
DoubleQuoteAsIdentifier bool
// 支持 # 行注释MySQL
HashLineComment bool
// 支持 $tag$ 风格字符串/标识符PostgreSQL
DollarQuote bool
// 额外关键字集合(合并到标准关键字中)
ExtraKeywords map[string]bool
}
// Tokenizer 将 SQL 字符串拆分为 Token 序列
type Tokenizer struct {
sql string
pos int
length int
config DialectConfig
Tokens []Token
current int
}
// New 创建一个新的 Tokenizer
func New(sql string, config DialectConfig) *Tokenizer {
t := &Tokenizer{
sql: sql,
pos: 0,
length: len(sql),
config: config,
Tokens: make([]Token, 0),
}
t.tokenize()
// 追加 EOF token
t.Tokens = append(t.Tokens, Token{Type: TokenEOF, Value: "", Pos: t.length, End: t.length})
return t
}
// tokenize 执行词法分析
func (t *Tokenizer) tokenize() {
for t.pos < t.length {
ch := t.sql[t.pos]
// 跳过空白字符
if isWhitespace(ch) {
t.pos++
continue
}
// 行注释 --
if ch == '-' && t.pos+1 < t.length && t.sql[t.pos+1] == '-' {
t.skipLineComment()
continue
}
// MySQL # 行注释
if t.config.HashLineComment && ch == '#' {
t.skipLineComment()
continue
}
// 块注释 /* */
if ch == '/' && t.pos+1 < t.length && t.sql[t.pos+1] == '*' {
t.skipBlockComment()
continue
}
// 字符串字面量 '...'
if ch == '\'' {
t.readString('\'')
continue
}
// 双引号字符串 "..."(如果不作为标识符引号)
if ch == '"' && !t.config.DoubleQuoteAsIdentifier {
t.readString('"')
continue
}
// 双引号标识符 "..."PostgreSQL
if ch == '"' && t.config.DoubleQuoteAsIdentifier {
t.readQuotedIdentifier('"')
continue
}
// 反引号标识符 `...`MySQL
if ch == '`' && t.config.BacktickAsIdentifier {
t.readQuotedIdentifier('`')
continue
}
// PostgreSQL $tag$ ... $tag$
if t.config.DollarQuote && ch == '$' {
if t.readDollarQuote() {
continue
}
}
// 数字
if isDigit(ch) {
t.readNumber()
continue
}
// 标识符或关键字字母、_、@ 开头)
if isIdentifierStart(ch) {
t.readIdentifierOrKeyword()
continue
}
// 运算符和标点符号
if isOperatorStart(ch) {
t.readOperator()
continue
}
if isPunctuation(ch) {
t.Tokens = append(t.Tokens, Token{
Type: TokenPunctuation,
Value: string(ch),
Pos: t.pos,
End: t.pos + 1,
})
t.pos++
continue
}
// 未知字符,跳过
t.pos++
}
}
// skipLineComment 跳过一个行注释(到行尾或 EOF
func (t *Tokenizer) skipLineComment() {
for t.pos < t.length && t.sql[t.pos] != '\n' {
t.pos++
}
}
// skipBlockComment 跳过一个块注释 /* */
func (t *Tokenizer) skipBlockComment() {
t.pos += 2 // 跳过 /*
for t.pos < t.length {
if t.sql[t.pos] == '*' && t.pos+1 < t.length && t.sql[t.pos+1] == '/' {
t.pos += 2
return
}
t.pos++
}
}
// readString 读取一个单引号或双引号字符串字面量
func (t *Tokenizer) readString(quote byte) {
start := t.pos
t.pos++ // 跳过起始引号
for t.pos < t.length {
ch := t.sql[t.pos]
if ch == quote {
// 检查是否是转义(两个连续引号)
if t.pos+1 < t.length && t.sql[t.pos+1] == quote {
t.pos += 2
continue
}
t.pos++ // 跳过结束引号
break
}
// MySQL 风格转义 \'
if ch == '\\' && t.pos+1 < t.length {
t.pos += 2
continue
}
t.pos++
}
t.Tokens = append(t.Tokens, Token{
Type: TokenString,
Value: t.sql[start:t.pos],
Pos: start,
End: t.pos,
})
}
// readQuotedIdentifier 读取一个带引号的标识符(反引号或双引号)
func (t *Tokenizer) readQuotedIdentifier(quote byte) {
start := t.pos
t.pos++ // 跳过起始引号
for t.pos < t.length {
ch := t.sql[t.pos]
if ch == quote {
// 检查转义引号(如 "a""b" 或 ``a``b``
if t.pos+1 < t.length && t.sql[t.pos+1] == quote {
t.pos += 2
continue
}
t.pos++ // 跳过结束引号
break
}
t.pos++
}
t.Tokens = append(t.Tokens, Token{
Type: TokenIdentifier,
Value: t.sql[start:t.pos],
Pos: start,
End: t.pos,
})
}
// readDollarQuote 读取 PostgreSQL $tag$ ... $tag$ 风格的引号内容
func (t *Tokenizer) readDollarQuote() bool {
start := t.pos
// 读取 $tag$
tagEnd := t.readDollarTag()
if tagEnd < 0 {
return false
}
tag := t.sql[start : tagEnd+1] // 包含 $tag$
// 查找结束标记(从 tag 之后开始)
searchPos := tagEnd + 1
for searchPos < t.length {
if strings.HasPrefix(t.sql[searchPos:], tag) {
searchPos += len(tag)
t.Tokens = append(t.Tokens, Token{
Type: TokenString,
Value: t.sql[start:searchPos],
Pos: start,
End: searchPos,
})
t.pos = searchPos
return true
}
searchPos++
}
// 未找到结束标记,回退
return false
}
// readDollarTag 读取 PostgreSQL $tag$ 中的 tag返回结束 $ 的位置
func (t *Tokenizer) readDollarTag() int {
if t.sql[t.pos] != '$' {
return -1
}
pos := t.pos + 1
for pos < t.length {
ch := t.sql[pos]
if ch == '$' {
return pos
}
if !unicode.IsLetter(rune(ch)) && !unicode.IsDigit(rune(ch)) && ch != '_' {
return -1
}
pos++
}
return -1
}
// readNumber 读取一个数字(整数或浮点数)
func (t *Tokenizer) readNumber() {
start := t.pos
for t.pos < t.length && (isDigit(t.sql[t.pos]) || t.sql[t.pos] == '.') {
t.pos++
}
// 支持科学计数法 e.g. 1e10, 1.5E-3
if t.pos < t.length && (t.sql[t.pos] == 'e' || t.sql[t.pos] == 'E') {
t.pos++
if t.pos < t.length && (t.sql[t.pos] == '+' || t.sql[t.pos] == '-') {
t.pos++
}
for t.pos < t.length && isDigit(t.sql[t.pos]) {
t.pos++
}
}
t.Tokens = append(t.Tokens, Token{
Type: TokenNumber,
Value: t.sql[start:t.pos],
Pos: start,
End: t.pos,
})
}
// readIdentifierOrKeyword 读取标识符或关键字
func (t *Tokenizer) readIdentifierOrKeyword() {
start := t.pos
for t.pos < t.length && isIdentifierPart(t.sql[t.pos]) {
t.pos++
}
value := t.sql[start:t.pos]
upper := strings.ToUpper(value)
tokType := TokenIdentifier
if Keywords[upper] {
tokType = TokenKeyword
} else if t.config.ExtraKeywords[upper] {
tokType = TokenKeyword
}
t.Tokens = append(t.Tokens, Token{
Type: tokType,
Value: value,
Pos: start,
End: t.pos,
})
}
// readOperator 读取运算符
func (t *Tokenizer) readOperator() {
start := t.pos
// 尝试读取多字符运算符
if t.pos+1 < t.length {
two := t.sql[t.pos : t.pos+2]
if two == "<=" || two == ">=" || two == "<>" || two == "!=" ||
two == "||" || two == "::" || two == "->" || two == "->>" ||
two == "=>" || two == ".." {
// PostgreSQL :: 类型转换, -> JSON, ->> JSON text, => key-value, .. 范围
t.pos += 2
// 检查 ->>(三字符)
if two == "->" && t.pos < t.length && t.sql[t.pos] == '>' {
t.pos++
}
t.Tokens = append(t.Tokens, Token{
Type: TokenOperator,
Value: t.sql[start:t.pos],
Pos: start,
End: t.pos,
})
return
}
}
t.pos++
t.Tokens = append(t.Tokens, Token{
Type: TokenOperator,
Value: t.sql[start:t.pos],
Pos: start,
End: t.pos,
})
}
// Peek 预览当前 token不移动位置
func (t *Tokenizer) Peek() Token {
if t.current >= len(t.Tokens) {
return t.Tokens[len(t.Tokens)-1] // EOF
}
return t.Tokens[t.current]
}
// Next 返回当前 token 并移动到下一个
func (t *Tokenizer) Next() Token {
if t.current >= len(t.Tokens) {
return t.Tokens[len(t.Tokens)-1] // EOF
}
tok := t.Tokens[t.current]
t.current++
return tok
}
// Consume 消耗当前位置的 token等同于 Next为了可读性
func (t *Tokenizer) Consume() Token {
return t.Next()
}
// Pos 返回当前 token 索引
func (t *Tokenizer) Pos() int {
return t.current
}
// SetPos 设置当前 token 索引
func (t *Tokenizer) SetPos(p int) {
t.current = p
}
// Length 返回 token 总数
func (t *Tokenizer) Length() int {
return len(t.Tokens)
}
// TokenAt 获取指定索引的 token
func (t *Tokenizer) TokenAt(idx int) Token {
if idx < 0 || idx >= len(t.Tokens) {
return t.Tokens[len(t.Tokens)-1]
}
return t.Tokens[idx]
}
// helper functions
func isWhitespace(ch byte) bool {
return ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r'
}
func isDigit(ch byte) bool {
return ch >= '0' && ch <= '9'
}
func isLetter(ch byte) bool {
return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z')
}
func isIdentifierStart(ch byte) bool {
return isLetter(ch) || ch == '_' || ch == '@'
}
func isIdentifierPart(ch byte) bool {
return isLetter(ch) || isDigit(ch) || ch == '_' || ch == '@' || ch == '$'
}
func isOperatorStart(ch byte) bool {
return ch == '+' || ch == '-' || ch == '*' || ch == '/' || ch == '=' ||
ch == '<' || ch == '>' || ch == '!' || ch == '|' || ch == ':' || ch == '~'
}
func isPunctuation(ch byte) bool {
return ch == '(' || ch == ')' || ch == ',' || ch == ';' || ch == '.'
}