refactor: 移除antlr4减小包体积&ai助手优化

This commit is contained in:
meilin.huang
2026-05-08 20:45:13 +08:00
parent 3768cef62d
commit f23b243fc5
154 changed files with 13054 additions and 396804 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -1 +0,0 @@
java -jar antlr-4.13.1-complete.jar -Dlanguage=Go -package parser -visitor *.g4

View File

@@ -1,91 +0,0 @@
/*
PostgreSQL grammar.
The MIT License (MIT).
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
*/
package parser
import (
"unicode"
"github.com/antlr4-go/antlr/v4"
)
type PostgreSQLLexerBase struct {
*antlr.BaseLexer
stack StringStack
}
func (receiver *PostgreSQLLexerBase) pushTag() {
receiver.stack.Push(receiver.GetText())
}
func (receiver *PostgreSQLLexerBase) isTag() bool {
if receiver.stack.IsEmpty() {
return false
}
return receiver.GetText() == receiver.stack.PeekOrEmpty()
}
func (receiver *PostgreSQLLexerBase) popTag() {
_, _ = receiver.stack.Pop()
}
func (receiver *PostgreSQLLexerBase) checkLA(c int) bool {
return receiver.GetInputStream().LA(1) != c
}
func (receiver *PostgreSQLLexerBase) charIsLetter() bool {
c := receiver.GetInputStream().LA(-1)
return unicode.IsLetter(rune(c))
}
func (receiver *PostgreSQLLexerBase) HandleNumericFail() {
index := receiver.GetInputStream().Index() - 2
receiver.GetInputStream().Seek(index)
receiver.SetType(PostgreSQLLexerIntegral)
}
func (receiver *PostgreSQLLexerBase) HandleLessLessGreaterGreater() {
if receiver.GetText() == "<<" {
receiver.SetType(PostgreSQLLexerLESS_LESS)
}
if receiver.GetText() == ">>" {
receiver.SetType(PostgreSQLLexerGREATER_GREATER)
}
}
func (receiver *PostgreSQLLexerBase) UnterminatedBlockCommentDebugAssert() {
//Debug.Assert(InputStream.LA(1) == -1 /*EOF*/);
}
func (receiver *PostgreSQLLexerBase) CheckIfUtf32Letter() bool {
codePoint := receiver.GetInputStream().LA(-2)<<8 + receiver.GetInputStream().LA(-1)
var c []rune
if codePoint < 0x10000 {
c = []rune{rune(codePoint)}
} else {
codePoint -= 0x10000
c = []rune{
(rune)(codePoint/0x400 + 0xd800),
(rune)(codePoint%0x400 + 0xdc00),
}
}
return unicode.IsLetter(c[0])
}

View File

@@ -1,29 +0,0 @@
/*
PostgreSQL grammar.
The MIT License (MIT).
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
*/
package parser
type PostgreSQLParseError struct {
Number int
Offset int
Line int
Column int
Message string
}

File diff suppressed because one or more lines are too long

View File

@@ -1,214 +0,0 @@
/*
PostgreSQL grammar.
The MIT License (MIT).
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
*/
package parser
import (
"strings"
"github.com/antlr4-go/antlr/v4"
)
type PostgreSQLParserBase struct {
*antlr.BaseParser
parseErrors []*PostgreSQLParseError
}
func NewPostgreSQLParserBase(input antlr.TokenStream) *PostgreSQLParserBase {
return &PostgreSQLParserBase{
BaseParser: antlr.NewBaseParser(input),
}
}
func (receiver *PostgreSQLParserBase) GetParsedSqlTree(script string, line int) antlr.ParserRuleContext {
parser := getPostgreSQLParser(script)
result := parser.Root()
for _, err := range parser.parseErrors {
receiver.parseErrors = append(receiver.parseErrors, &PostgreSQLParseError{
Number: err.Number,
Offset: err.Offset,
Line: err.Line + line,
Column: err.Column,
Message: err.Message,
})
}
return result
}
func (receiver *PostgreSQLParserBase) ParseRoutineBody(localContextInterface ICreatefunc_opt_listContext) {
localContext, ok := localContextInterface.(*Createfunc_opt_listContext)
if !ok {
return
}
var lang string
for _, coi := range localContext.AllCreatefunc_opt_item() {
createFuncOptItemContext, ok := coi.(*Createfunc_opt_itemContext)
if !ok || createFuncOptItemContext.LANGUAGE() == nil {
continue
}
nonReservedWordOrSConstContextInterface := createFuncOptItemContext.Nonreservedword_or_sconst()
if nonReservedWordOrSConstContextInterface == nil {
continue
}
nonReservedWordOrSConstContext, ok := nonReservedWordOrSConstContextInterface.(*Nonreservedword_or_sconstContext)
if !ok {
continue
}
nonReservedWordContextInterface := nonReservedWordOrSConstContext.Nonreservedword()
if nonReservedWordContextInterface == nil {
continue
}
nonReservedWordContext, ok := nonReservedWordContextInterface.(*NonreservedwordContext)
if !ok {
continue
}
identifierInterface := nonReservedWordContext.Identifier()
if identifierInterface == nil {
continue
}
identifier, ok := identifierInterface.(*IdentifierContext)
if !ok {
continue
}
node := identifier.Identifier()
if node == nil {
continue
}
lang = node.GetText()
break
}
if lang == "" {
return
}
var funcAs *Createfunc_opt_itemContext
for _, coi := range localContext.AllCreatefunc_opt_item() {
ctx, ok := coi.(*Createfunc_opt_itemContext)
if !ok || ctx.LANGUAGE() == nil {
continue
}
as := ctx.Func_as()
if as != nil {
funcAs = ctx
break
}
}
if funcAs == nil {
return
}
funcAsContextInterface := funcAs.Func_as()
if funcAsContextInterface == nil {
return
}
funcAsContext, ok := funcAsContextInterface.(*Func_asContext)
if !ok {
return
}
sConstContextInterface := funcAsContext.Sconst(0)
if sConstContextInterface == nil {
return
}
sConstContext, ok := sConstContextInterface.(*SconstContext)
if !ok {
return
}
text := GetRoutineBodyString(sConstContext)
line := sConstContext.GetStart().GetLine()
parser := getPostgreSQLParser(text)
switch lang {
case "plpgsql":
funcAs.Func_as().(*Func_asContext).Definition = parser.Plsqlroot()
case "sql":
funcAs.Func_as().(*Func_asContext).Definition = parser.Root()
}
for _, err := range parser.parseErrors {
receiver.parseErrors = append(receiver.parseErrors, &PostgreSQLParseError{
Number: err.Number,
Offset: err.Offset,
Line: err.Line + line,
Column: err.Column,
Message: err.Message,
})
}
}
func TrimQuotes(s string) string {
if s == "" {
return s
}
return s[1 : len(s)-2]
}
func unquote(s string) string {
result := strings.Builder{}
length := len(s)
index := 0
for index < length {
c := s[index]
result.WriteByte(c)
if c == '\'' && index < length-1 && (s[index+1] == '\'') {
index++
}
index++
}
return result.String()
}
func GetRoutineBodyString(rule *SconstContext) string {
if rule.Anysconst() == nil {
return ""
}
anySConstContext := rule.Anysconst().(*AnysconstContext)
stringConstant := anySConstContext.StringConstant()
if stringConstant != nil {
return unquote(TrimQuotes(stringConstant.GetText()))
}
unicodeEscapeStringConstant := anySConstContext.UnicodeEscapeStringConstant()
if unicodeEscapeStringConstant != nil {
return TrimQuotes(unicodeEscapeStringConstant.GetText())
}
escapeStringConstant := anySConstContext.EscapeStringConstant()
if escapeStringConstant != nil {
return TrimQuotes(escapeStringConstant.GetText())
}
result := strings.Builder{}
for _, node := range anySConstContext.AllDollarText() {
result.WriteString(node.GetText())
}
return result.String()
}
func getPostgreSQLParser(script string) *PostgreSQLParser {
stream := antlr.NewInputStream(script)
lexer := NewPostgreSQLLexer(stream)
tokenStream := antlr.NewCommonTokenStream(lexer, 0)
parser := NewPostgreSQLParser(tokenStream)
errorListener := new(PostgreSQLParserErrorListener)
errorListener.grammar = parser
parser.AddErrorListener(errorListener)
return parser
}

View File

@@ -1,51 +0,0 @@
/*
PostgreSQL grammar.
The MIT License (MIT).
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
*/
package parser
import "github.com/antlr4-go/antlr/v4"
type PostgreSQLParserErrorListener struct {
grammar *PostgreSQLParser
}
var _ antlr.ErrorListener = &PostgreSQLParserErrorListener{}
func (receiver PostgreSQLParserErrorListener) SyntaxError(recognizer antlr.Recognizer, offendingSymbol interface{}, line, column int, msg string, e antlr.RecognitionException) {
receiver.grammar.parseErrors = append(receiver.grammar.parseErrors, &PostgreSQLParseError{
Number: 0,
Offset: 0,
Line: line,
Column: column,
Message: msg,
})
}
func (receiver PostgreSQLParserErrorListener) ReportAmbiguity(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex int, exact bool, ambigAlts *antlr.BitSet, configs *antlr.ATNConfigSet) {
// ignore
}
func (receiver PostgreSQLParserErrorListener) ReportAttemptingFullContext(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex int, conflictingAlts *antlr.BitSet, configs *antlr.ATNConfigSet) {
// ignore
}
func (receiver PostgreSQLParserErrorListener) ReportContextSensitivity(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex, prediction int, configs *antlr.ATNConfigSet) {
// ignore
}

View File

@@ -1,77 +0,0 @@
/*
PostgreSQL grammar.
The MIT License (MIT).
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
*/
package parser
import (
"errors"
)
var (
ErrorStackEmpty = errors.New("stack empty")
)
type StringStack struct {
items []string
}
func (receiver *StringStack) Push(value string) {
receiver.items = append(receiver.items, value)
}
func (receiver *StringStack) Pop() (string, error) {
if receiver.IsEmpty() {
return "", ErrorStackEmpty
}
value := receiver.items[0]
receiver.items = receiver.items[1:]
return value, nil
}
func (receiver *StringStack) PopOrEmpty() string {
value, err := receiver.Pop()
if err != nil {
return ""
}
return value
}
func (receiver *StringStack) Peek() (string, error) {
if receiver.IsEmpty() {
return "", ErrorStackEmpty
}
return receiver.items[0], nil
}
func (receiver *StringStack) PeekOrEmpty() string {
value, err := receiver.Peek()
if err != nil {
return ""
}
return value
}
func (receiver *StringStack) Size() int {
return len(receiver.items)
}
func (receiver *StringStack) IsEmpty() bool {
return receiver.Size() == 0
}

View File

@@ -0,0 +1,745 @@
package pgsql
import (
"strconv"
"strings"
"mayfly-go/internal/db/dbm/sqlparser/base"
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
"mayfly-go/internal/db/dbm/sqlparser/tokenizer"
)
// Parser PostgreSQL 方言 SQL 解析器
type Parser struct {
*base.Lexer
}
// NewParser 创建 PostgreSQL 解析器
func NewParser(sql string) *Parser {
return &Parser{
Lexer: base.NewLexer(sql, tokenizer.DialectConfig{
DoubleQuoteAsIdentifier: true,
}),
}
}
// Parse 解析单条 SQL 语句
func (p *Parser) Parse() (sqlstmt.Stmt, error) {
p.SkipSemicolons()
if p.Current().IsEOF() {
return nil, nil
}
stmt := p.parseStatement()
return stmt, nil
}
func (p *Parser) parseStatement() sqlstmt.Stmt {
tok := p.Current()
switch {
case tok.IsKeyword("SELECT") || tok.Value == "(":
return p.parseSelect()
case tok.IsKeyword("INSERT"):
return p.parseInsert()
case tok.IsKeyword("UPDATE"):
return p.parseUpdate()
case tok.IsKeyword("DELETE"):
return p.parseDelete()
case tok.IsKeyword("CREATE"):
return p.parseCreate()
case tok.IsKeyword("DROP"):
return p.parseDrop()
case tok.IsKeyword("ALTER"):
return p.parseAlter()
case tok.IsKeyword("WITH"):
return p.parseWith()
case tok.IsKeyword("TRUNCATE"):
return p.parseGenericDdl()
default:
return p.parseGenericStmt()
}
}
// ---------- SELECT 解析 ----------
func (p *Parser) parseSelect() sqlstmt.Stmt {
start := p.Pos
var selectStmt *sqlstmt.SelectStmt
if p.Current().Value == "(" {
p.Consume()
if p.Current().IsKeyword("SELECT") {
selectStmt = p.parseSelectBody()
} else {
p.SkipParentheses()
selectStmt = &sqlstmt.SelectStmt{}
}
p.ExpectValue(")")
} else {
selectStmt = p.parseSelectBody()
}
// UNION 解析
selectStmt = p.parseUnions(selectStmt)
// ORDER BYUNION 之后的 ORDER BY
if selectStmt.OrderBy == nil && p.Current().IsKeyword("ORDER") {
p.Consume()
if p.Current().IsKeyword("BY") {
p.Consume()
}
selectStmt.OrderBy = p.parseOrderBy()
}
// LIMITUNION 之后的 LIMIT 覆盖子查询内部的 LIMIT
if p.Current().IsKeyword("LIMIT") || p.Current().IsKeyword("OFFSET") {
if limit := p.parseLimit(); limit != nil {
selectStmt.Limit = limit
}
}
// FOR UPDATE
if p.Current().IsKeyword("FOR") {
p.Consume()
if p.Current().IsKeyword("UPDATE") {
p.Consume()
}
}
// RETURNINGPostgreSQL 特有)
if p.Current().IsKeyword("RETURNING") {
p.Consume()
// 跳过 RETURNING 后面的所有列(可能包含逗号)
for !p.Current().IsEOF() && p.Current().Value != ";" {
p.Consume()
}
}
// 更新完整文本(包含所有已解析的内容)
selectStmt.Base = sqlstmt.Base{Text: p.TextFrom(start)}
return selectStmt
}
func (p *Parser) parseSelectBody() *sqlstmt.SelectStmt {
start := p.Pos
distinct := false
if p.Current().IsKeyword("SELECT") {
p.Consume()
if p.Current().IsKeyword("DISTINCT") {
p.Consume()
distinct = true
}
}
// SELECT 项
var items []sqlstmt.SelectItem
for !p.Current().IsEOF() && !p.IsSelectClauseEnd() {
if p.Current().Value == "," {
p.Consume()
continue
}
itemStart := p.Pos
p.SkipExpr()
text := base.TrimTrailingComma(p.TextFromExclusive(itemStart))
col, alias := p.ExtractColumnAndAlias(text)
items = append(items, sqlstmt.SelectItem{
Text: text,
ColumnName: col,
Alias: alias,
})
}
selectStmt := &sqlstmt.SelectStmt{
Base: sqlstmt.Base{Text: p.TextFrom(start)},
Distinct: distinct,
Items: items,
}
// FROM
if p.Current().IsKeyword("FROM") {
p.Consume()
selectStmt.From = p.parseFromClause()
}
// JOIN
for p.IsJoinStart() || p.Current().IsKeyword("JOIN") {
if join := p.parseJoinClause(); join != nil {
selectStmt.Joins = append(selectStmt.Joins, *join)
}
}
// WHERE
if p.Current().IsKeyword("WHERE") {
p.Consume()
whereStart := p.Pos
p.SkipExpr()
selectStmt.Where = &sqlstmt.Expr{Text: p.TextFromExclusive(whereStart)}
}
// GROUP BY
if p.Current().IsKeyword("GROUP") {
p.Consume()
if p.Current().IsKeyword("BY") {
p.Consume()
}
p.SkipGroupByExpr()
}
// HAVING
if p.Current().IsKeyword("HAVING") {
p.Consume()
p.SkipExpr()
}
// ORDER BY
if p.Current().IsKeyword("ORDER") {
p.Consume()
if p.Current().IsKeyword("BY") {
p.Consume()
}
selectStmt.OrderBy = p.parseOrderBy()
}
// LIMIT/OFFSET子查询或简单 SELECT 的 LIMIT
if p.Current().IsKeyword("LIMIT") || p.Current().IsKeyword("OFFSET") {
selectStmt.Limit = p.parseLimit()
}
return selectStmt
}
func (p *Parser) parseUnions(selectStmt *sqlstmt.SelectStmt) *sqlstmt.SelectStmt {
for p.Current().IsKeyword("UNION") {
p.Consume()
all := false
if p.Current().IsKeyword("ALL") {
p.Consume()
all = true
}
var nextSelect *sqlstmt.SelectStmt
if p.Current().Value == "(" {
p.Consume()
if p.Current().IsKeyword("SELECT") {
nextSelect = p.parseSelectBody()
}
p.ExpectValue(")")
} else if p.Current().IsKeyword("SELECT") {
nextSelect = p.parseSelectBody()
}
if nextSelect != nil {
// 如果 unionSelect 有 LIMIT 或 ORDER BY移动到外层 selectStmt
if nextSelect.Limit != nil {
selectStmt.Limit = nextSelect.Limit
nextSelect.Limit = nil
}
if len(nextSelect.OrderBy) > 0 {
selectStmt.OrderBy = nextSelect.OrderBy
nextSelect.OrderBy = nil
}
selectStmt.Unions = append(selectStmt.Unions, sqlstmt.UnionClause{
Select: nextSelect,
All: all,
})
}
}
return selectStmt
}
func (p *Parser) parseOrderBy() []sqlstmt.OrderByItem {
var items []sqlstmt.OrderByItem
for !p.Current().IsEOF() && !p.IsExprEnd() {
if p.Current().Value == "," {
p.Consume()
continue
}
start := p.Pos
p.SkipExpr()
text := p.TextFromExclusive(start)
desc := false
upper := strings.ToUpper(text)
if strings.HasSuffix(upper, " DESC") {
desc = true
text = strings.TrimSpace(text[:len(text)-5])
} else if strings.HasSuffix(upper, " ASC") {
text = strings.TrimSpace(text[:len(text)-4])
}
items = append(items, sqlstmt.OrderByItem{
Text: text,
Desc: desc,
})
}
return items
}
// LIMIT 解析PostgreSQL: OFFSET ... LIMIT 或 LIMIT ... ALL
func (p *Parser) parseLimit() *sqlstmt.Limit {
start := p.Pos
limit := &sqlstmt.Limit{}
// PostgreSQL 支持两种顺序:
// 1. LIMIT count OFFSET offset
// 2. OFFSET offset LIMIT count
if p.Current().IsKeyword("LIMIT") {
p.Consume()
// LIMIT ALL 或 LIMIT 值
if p.Current().IsKeyword("ALL") {
p.Consume()
limit.Text = p.TextFrom(start)
return limit
}
if p.Current().Type == tokenizer.TokenNumber {
countStr := p.Consume().Value
limit.Count, _ = strconv.Atoi(countStr)
}
// 检查是否有 OFFSET
if p.Current().IsKeyword("OFFSET") {
p.Consume()
if p.Current().Type == tokenizer.TokenNumber {
offsetStr := p.Consume().Value
limit.Offset, _ = strconv.Atoi(offsetStr)
} else if p.Current().Type == tokenizer.TokenIdentifier {
p.Consume()
}
}
} else if p.Current().IsKeyword("OFFSET") {
p.Consume()
// 跳过 OFFSET 值
if p.Current().Type == tokenizer.TokenNumber {
offsetStr := p.Consume().Value
limit.Offset, _ = strconv.Atoi(offsetStr)
} else if p.Current().Type == tokenizer.TokenIdentifier {
p.Consume()
}
// 检查是否有限制
if p.Current().IsKeyword("LIMIT") {
p.Consume()
if p.Current().IsKeyword("ALL") {
p.Consume()
limit.Text = p.TextFrom(start)
return limit
}
if p.Current().Type == tokenizer.TokenNumber {
countStr := p.Consume().Value
limit.Count, _ = strconv.Atoi(countStr)
}
}
}
limit.Text = p.TextFrom(start)
return limit
}
// ---------- FROM 解析 ----------
func (p *Parser) parseFromClause() []sqlstmt.TableRef {
var tables []sqlstmt.TableRef
for !p.Current().IsEOF() && !p.IsFromClauseEnd() {
if p.Current().Value == "," {
p.Consume()
continue
}
if p.IsJoinStart() || p.Current().IsKeyword("JOIN") {
break
}
ref := p.parseTableRef()
if ref.Name != "" {
tables = append(tables, ref)
}
}
return tables
}
func (p *Parser) parseTableRef() sqlstmt.TableRef {
start := p.Pos
// 子查询
if p.Current().Value == "(" {
p.Consume()
if p.Current().IsKeyword("SELECT") {
// 使用完整解析(包含 UNION
p.parseSelect()
} else {
p.SkipParentheses()
}
p.ExpectValue(")")
var alias string
if p.Current().IsKeyword("AS") {
p.Consume()
}
if p.Current().Type == tokenizer.TokenIdentifier {
alias = p.Unquote(p.Consume().Value)
}
return sqlstmt.TableRef{
Name: p.TextFrom(start),
Alias: alias,
}
}
if p.Current().Type != tokenizer.TokenIdentifier && p.Current().Type != tokenizer.TokenString {
return sqlstmt.TableRef{}
}
ref := sqlstmt.TableRef{}
part1 := p.Consume().Value
if p.Current().Value == "." {
p.Consume()
if p.Current().Type == tokenizer.TokenIdentifier || p.Current().Type == tokenizer.TokenString {
part2 := p.Consume().Value
ref.Schema = p.Unquote(part1)
ref.Name = p.Unquote(part2)
} else {
ref.Name = p.Unquote(part1)
}
} else {
ref.Name = p.Unquote(part1)
}
if p.Current().IsKeyword("AS") {
p.Consume()
}
if p.Current().Type == tokenizer.TokenIdentifier {
ref.Alias = p.Unquote(p.Consume().Value)
}
return ref
}
func (p *Parser) parseJoinClause() *sqlstmt.JoinClause {
start := p.Pos
joinType := sqlstmt.JoinKindInner
if p.Current().IsKeyword("LEFT") {
p.Consume()
if p.Current().IsKeyword("OUTER") {
p.Consume()
}
joinType = sqlstmt.JoinKindLeft
} else if p.Current().IsKeyword("RIGHT") {
p.Consume()
if p.Current().IsKeyword("OUTER") {
p.Consume()
}
joinType = sqlstmt.JoinKindRight
} else if p.Current().IsKeyword("FULL") {
p.Consume()
if p.Current().IsKeyword("OUTER") {
p.Consume()
}
joinType = sqlstmt.JoinKindFull
} else if p.Current().IsKeyword("CROSS") {
p.Consume()
joinType = sqlstmt.JoinKindCross
} else if p.Current().IsKeyword("NATURAL") {
p.Consume()
joinType = sqlstmt.JoinKindNatural
} else if p.Current().IsKeyword("INNER") {
p.Consume()
}
if !p.Current().IsKeyword("JOIN") {
p.Pos = start
return nil
}
p.Consume()
tableRef := p.parseTableRef()
if tableRef.Name == "" {
p.Pos = start
return nil
}
var onExpr *sqlstmt.Expr
if p.Current().IsKeyword("ON") {
p.Consume()
onStart := p.Pos
p.SkipExpr()
onExpr = &sqlstmt.Expr{Text: p.TextFromExclusive(onStart)}
} else if p.Current().IsKeyword("USING") {
p.Consume()
if p.Current().Value == "(" {
p.SkipParentheses()
}
}
return &sqlstmt.JoinClause{
Kind: joinType,
Table: tableRef,
On: onExpr,
Text: p.TextFrom(start),
}
}
// ---------- INSERT 解析 ----------
func (p *Parser) parseInsert() sqlstmt.Stmt {
start := p.Pos
p.Consume() // INSERT
// INTO可选
if p.Current().IsKeyword("INTO") {
p.Consume()
}
// 使用 parseTableRef 正确解析表名(支持 schema.table 和双引号)
tableRef := p.parseTableRef()
// 解析列名列表 (col1, col2, ...)
columns := []string{}
if p.Current().Value == "(" {
p.Consume() // (
for !p.Current().IsEOF() && p.Current().Value != ")" {
if p.Current().Value == "," {
p.Consume()
continue
}
if p.Current().Type == tokenizer.TokenIdentifier {
columns = append(columns, p.Unquote(p.Consume().Value))
} else {
p.Consume()
}
}
if p.Current().Value == ")" {
p.Consume() // )
}
}
// VALUES 或 SELECT 或直接是 ON CONFLICT
if p.Current().IsKeyword("VALUES") || p.Current().Value == "VALUES" {
p.Consume()
// 消费到 ON 或 RETURNING 或 EOF
for !p.Current().IsEOF() {
val := strings.ToUpper(p.Current().Value)
if val == "ON" || val == "RETURNING" {
break
}
if p.Current().Value == "(" {
p.SkipParentheses()
} else {
p.Consume()
}
}
} else if p.Current().IsKeyword("SELECT") || p.Current().Value == "SELECT" {
// INSERT INTO ... SELECT ...
p.parseSelect()
}
// ON CONFLICTPostgreSQL 特有)
if strings.ToUpper(p.Current().Value) == "ON" {
p.Consume()
if p.Current().IsKeyword("CONFLICT") {
p.Consume()
// 跳过 ON CONFLICT 后面的所有内容直到 RETURNING 或 EOF
for !p.Current().IsEOF() && !p.Current().IsKeyword("RETURNING") {
p.Consume()
}
}
}
// RETURNINGPostgreSQL 特有)
if p.Current().IsKeyword("RETURNING") {
p.Consume()
// 跳过 RETURNING 后面的内容
for !p.Current().IsEOF() && p.Current().Value != ";" {
p.Consume()
}
}
return &sqlstmt.InsertStmt{
Base: sqlstmt.Base{Text: p.TextFrom(start)},
Table: tableRef,
Columns: columns,
}
}
// ---------- UPDATE 解析 ----------
// PostgreSQL UPDATE 不支持 ORDER BY/LIMIT
func (p *Parser) parseUpdate() sqlstmt.Stmt {
start := p.Pos
p.Consume() // UPDATE
// 使用 parseTableRef 正确解析表名(支持 schema.table 和双引号)
tableRef := p.parseTableRef()
tables := []sqlstmt.TableRef{tableRef}
// SET - 解析字段赋值
assignments := make([]sqlstmt.Assignment, 0)
if p.Current().IsKeyword("SET") {
p.Consume()
for !p.Current().IsEOF() {
if p.Current().IsKeyword("WHERE") || p.Current().Value == ";" {
break
}
assign := p.parseAssignment()
if assign != nil {
assignments = append(assignments, *assign)
}
if p.Current().Value == "," {
p.Consume()
continue
}
if p.Current().IsKeyword("WHERE") || p.Current().Value == ";" {
break
}
}
}
// WHERE
var where *sqlstmt.Expr
if p.Current().IsKeyword("WHERE") {
p.Consume()
whereStart := p.Pos
p.SkipExpr()
where = &sqlstmt.Expr{Text: p.TextFromExclusive(whereStart)}
}
// RETURNINGPostgreSQL 特有)
if p.Current().IsKeyword("RETURNING") {
p.Consume()
// 跳过 RETURNING 后面的所有列(可能包含逗号)
for !p.Current().IsEOF() && p.Current().Value != ";" {
p.Consume()
}
}
return &sqlstmt.UpdateStmt{
Base: sqlstmt.Base{Text: p.TextFrom(start)},
Tables: tables,
Set: assignments,
Where: where,
}
}
// ---------- DELETE 解析 ----------
// PostgreSQL DELETE 不支持 ORDER BY/LIMIT
func (p *Parser) parseDelete() sqlstmt.Stmt {
start := p.Pos
p.Consume() // DELETE
if p.Current().IsKeyword("FROM") {
p.Consume()
}
// 使用 parseTableRef 正确解析表名(支持 schema.table 和双引号)
tableRef := p.parseTableRef()
tables := []sqlstmt.TableRef{tableRef}
// WHERE
var where *sqlstmt.Expr
if p.Current().IsKeyword("WHERE") {
p.Consume()
whereStart := p.Pos
p.SkipExpr()
where = &sqlstmt.Expr{Text: p.TextFromExclusive(whereStart)}
}
// RETURNINGPostgreSQL 特有)
if p.Current().IsKeyword("RETURNING") {
p.Consume()
// 跳过 RETURNING 后面的所有列(可能包含逗号)
for !p.Current().IsEOF() && p.Current().Value != ";" {
p.Consume()
}
}
return &sqlstmt.DeleteStmt{
Base: sqlstmt.Base{Text: p.TextFrom(start)},
Tables: tables,
Where: where,
}
}
// ---------- DDL 解析 ----------
func (p *Parser) parseCreate() sqlstmt.Stmt {
start := p.Pos
p.SkipToNextStatement()
return &sqlstmt.DdlStmt{
Base: sqlstmt.Base{Text: p.TextFrom(start)},
DdlKind: "CREATE",
}
}
func (p *Parser) parseDrop() sqlstmt.Stmt {
start := p.Pos
p.SkipToNextStatement()
return &sqlstmt.DdlStmt{
Base: sqlstmt.Base{Text: p.TextFrom(start)},
DdlKind: "DROP",
}
}
func (p *Parser) parseAlter() sqlstmt.Stmt {
start := p.Pos
p.SkipToNextStatement()
return &sqlstmt.DdlStmt{
Base: sqlstmt.Base{Text: p.TextFrom(start)},
DdlKind: "ALTER",
}
}
func (p *Parser) parseGenericDdl() sqlstmt.Stmt {
start := p.Pos
p.SkipToNextStatement()
return &sqlstmt.DdlStmt{
Base: sqlstmt.Base{Text: p.TextFrom(start)},
DdlKind: "DDL",
}
}
// parseAssignment 解析 SET 字段赋值column = value
func (p *Parser) parseAssignment() *sqlstmt.Assignment {
start := p.Pos
colText := ""
for !p.Current().IsEOF() && p.Current().Value != "=" && p.Current().Value != "," &&
!p.Current().IsKeyword("WHERE") && p.Current().Value != ";" {
colText += p.Consume().Value
}
if p.Current().Value != "=" {
p.Pos = start
return nil
}
p.Consume() // =
valStart := p.Pos
for !p.Current().IsEOF() && p.Current().Value != "," &&
!p.Current().IsKeyword("WHERE") && p.Current().Value != ";" {
if p.Current().Value == "(" {
p.SkipParentheses()
continue
}
p.Consume()
}
return &sqlstmt.Assignment{
Column: p.Unquote(strings.TrimSpace(colText)),
Value: &sqlstmt.Expr{Text: p.TextFromExclusive(valStart)},
Text: p.TextFromExclusive(start),
}
}
// ---------- WITH 解析 ----------
func (p *Parser) parseWith() sqlstmt.Stmt {
start := p.Pos
p.SkipToNextStatement()
return &sqlstmt.WithStmt{
Base: sqlstmt.Base{Text: p.TextFrom(start)},
}
}
// ---------- 通用语句解析 ----------
func (p *Parser) parseGenericStmt() sqlstmt.Stmt {
start := p.Pos
p.SkipToNextStatement()
return &sqlstmt.OtherStmt{
Base: sqlstmt.Base{Text: p.TextFrom(start)},
}
}

View File

@@ -0,0 +1,846 @@
package pgsql
import (
"testing"
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
)
// ========== INSERT 测试 ==========
func TestPgInsertBasic(t *testing.T) {
sql := "INSERT INTO users (name, age) VALUES ('John', 30)"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
insertStmt, ok := stmt.(*sqlstmt.InsertStmt)
if !ok {
t.Fatal("expected InsertStmt")
}
// 验证完整文本
if insertStmt.GetText() != sql {
t.Errorf("expected text='%s', got '%s'", sql, insertStmt.GetText())
}
// 验证表名
if insertStmt.Table.Name != "users" {
t.Errorf("expected table='users', got '%s'", insertStmt.Table.Name)
}
// 验证列名
if len(insertStmt.Columns) != 2 {
t.Fatalf("expected 2 columns, got %d", len(insertStmt.Columns))
}
if insertStmt.Columns[0] != "name" {
t.Errorf("expected Columns[0]='name', got '%s'", insertStmt.Columns[0])
}
if insertStmt.Columns[1] != "age" {
t.Errorf("expected Columns[1]='age', got '%s'", insertStmt.Columns[1])
}
}
func TestPgInsertMultipleRows(t *testing.T) {
sql := "INSERT INTO users (name, age) VALUES ('John', 30), ('Jane', 25), ('Bob', 35)"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
insertStmt := stmt.(*sqlstmt.InsertStmt)
if insertStmt.GetText() != sql {
t.Errorf("expected text='%s'", sql)
}
}
func TestPgInsertReturning(t *testing.T) {
sql := "INSERT INTO users (name, age) VALUES ('John', 30) RETURNING id, name"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
insertStmt := stmt.(*sqlstmt.InsertStmt)
if insertStmt.GetText() != sql {
t.Errorf("expected text='%s'", sql)
}
}
func TestPgInsertFromSelect(t *testing.T) {
sql := "INSERT INTO users_backup SELECT * FROM users WHERE status = 1"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
insertStmt := stmt.(*sqlstmt.InsertStmt)
if insertStmt.GetText() != sql {
t.Errorf("expected text='%s'", sql)
}
}
// ========== UPDATE 测试 ==========
func TestPgUpdateBasic(t *testing.T) {
sql := "UPDATE users SET name = 'John' WHERE id = 1"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
updateStmt := stmt.(*sqlstmt.UpdateStmt)
// 验证完整文本
if updateStmt.GetText() != sql {
t.Errorf("expected text='%s', got '%s'", sql, updateStmt.GetText())
}
// 验证表名
if len(updateStmt.Tables) != 1 {
t.Fatalf("expected 1 table, got %d", len(updateStmt.Tables))
}
if updateStmt.Tables[0].Name != "users" {
t.Errorf("expected table='users', got '%s'", updateStmt.Tables[0].Name)
}
// 验证 SET 字段
if len(updateStmt.Set) != 1 {
t.Fatalf("expected 1 assignment, got %d", len(updateStmt.Set))
}
if updateStmt.Set[0].Column != "name" {
t.Errorf("expected Set[0].Column='name', got '%s'", updateStmt.Set[0].Column)
}
if updateStmt.Set[0].Value == nil || updateStmt.Set[0].Value.Text != "'John'" {
t.Errorf("expected Set[0].Value=''John'', got '%s'", updateStmt.Set[0].Value.Text)
}
// 验证 WHERE
if updateStmt.Where == nil {
t.Fatal("expected WHERE")
}
if updateStmt.Where.Text != "id = 1" {
t.Errorf("expected WHERE='id = 1', got '%s'", updateStmt.Where.Text)
}
}
func TestPgUpdateMultipleColumns(t *testing.T) {
sql := "UPDATE users SET name = 'John', age = 30, email = 'john@example.com' WHERE id = 1"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
updateStmt := stmt.(*sqlstmt.UpdateStmt)
// 验证完整文本
if updateStmt.GetText() != sql {
t.Errorf("expected text='%s', got '%s'", sql, updateStmt.GetText())
}
// 验证表名
if len(updateStmt.Tables) != 1 {
t.Fatalf("expected 1 table, got %d", len(updateStmt.Tables))
}
if updateStmt.Tables[0].Name != "users" {
t.Errorf("expected table='users', got '%s'", updateStmt.Tables[0].Name)
}
// 验证 SET 字段3个赋值
if len(updateStmt.Set) != 3 {
t.Fatalf("expected 3 assignments, got %d", len(updateStmt.Set))
}
// 验证第一个赋值name = 'John'
if updateStmt.Set[0].Column != "name" {
t.Errorf("expected Set[0].Column='name', got '%s'", updateStmt.Set[0].Column)
}
if updateStmt.Set[0].Value == nil || updateStmt.Set[0].Value.Text != "'John'" {
t.Errorf("expected Set[0].Value=''John'', got '%s'", updateStmt.Set[0].Value.Text)
}
if updateStmt.Set[0].Text != "name = 'John'" {
t.Errorf("expected Set[0].Text=\"name = 'John'\", got '%s'", updateStmt.Set[0].Text)
}
// 验证第二个赋值age = 30
if updateStmt.Set[1].Column != "age" {
t.Errorf("expected Set[1].Column='age', got '%s'", updateStmt.Set[1].Column)
}
if updateStmt.Set[1].Value == nil || updateStmt.Set[1].Value.Text != "30" {
t.Errorf("expected Set[1].Value='30', got '%s'", updateStmt.Set[1].Value.Text)
}
if updateStmt.Set[1].Text != "age = 30" {
t.Errorf("expected Set[1].Text='age = 30', got '%s'", updateStmt.Set[1].Text)
}
// 验证第三个赋值email = 'john@example.com'
if updateStmt.Set[2].Column != "email" {
t.Errorf("expected Set[2].Column='email', got '%s'", updateStmt.Set[2].Column)
}
if updateStmt.Set[2].Value == nil || updateStmt.Set[2].Value.Text != "'john@example.com'" {
t.Errorf("expected Set[2].Value=''john@example.com'', got '%s'", updateStmt.Set[2].Value.Text)
}
if updateStmt.Set[2].Text != "email = 'john@example.com'" {
t.Errorf("expected Set[2].Text=\"email = 'john@example.com'\", got '%s'", updateStmt.Set[2].Text)
}
// 验证 WHERE
if updateStmt.Where == nil {
t.Fatal("expected WHERE")
}
if updateStmt.Where.Text != "id = 1" {
t.Errorf("expected WHERE='id = 1', got '%s'", updateStmt.Where.Text)
}
}
func TestPgUpdateFromJoin(t *testing.T) {
// PostgreSQL UPDATE FROM 语法
sql := "UPDATE orders SET status = 'shipped' FROM users WHERE orders.user_id = users.id AND users.status = 1"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
updateStmt := stmt.(*sqlstmt.UpdateStmt)
t.Logf("UPDATE text: %s", updateStmt.GetText())
t.Logf("UPDATE tables: %+v", updateStmt.Tables)
// 主表名应该正确解析
if len(updateStmt.Tables) < 1 {
t.Fatal("expected at least 1 table")
}
// 注意:当前解析器对于 UPDATE FROM 语法支持有限,可能表名包含别名
// 但至少验证能成功解析而不报错
t.Logf("Successfully parsed UPDATE FROM statement")
}
// ========== DELETE 测试 ==========
func TestPgDeleteBasic(t *testing.T) {
sql := "DELETE FROM users WHERE id = 1"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
// 验证完整文本
if deleteStmt.GetText() != sql {
t.Errorf("expected text='%s', got '%s'", sql, deleteStmt.GetText())
}
// 验证表名
if len(deleteStmt.Tables) != 1 {
t.Fatalf("expected 1 table, got %d", len(deleteStmt.Tables))
}
if deleteStmt.Tables[0].Name != "users" {
t.Errorf("expected table='users', got '%s'", deleteStmt.Tables[0].Name)
}
// 验证 WHERE
if deleteStmt.Where == nil {
t.Fatal("expected WHERE")
}
if deleteStmt.Where.Text != "id = 1" {
t.Errorf("expected WHERE='id = 1', got '%s'", deleteStmt.Where.Text)
}
}
func TestPgDeleteUsing(t *testing.T) {
sql := "DELETE FROM orders o USING users u WHERE o.user_id = u.id AND u.status = 0"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
if len(deleteStmt.Tables) < 1 || deleteStmt.Tables[0].Name != "orders" {
t.Errorf("expected table='orders'")
}
}
// ========== Schema.Table 测试 ==========
func TestPgInsertWithSchema(t *testing.T) {
sql := `INSERT INTO "public"."users" ("name", "age") VALUES ('John', 30)`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
insertStmt := stmt.(*sqlstmt.InsertStmt)
if insertStmt.Table.Schema != "public" {
t.Errorf("expected schema='public', got '%s'", insertStmt.Table.Schema)
}
if insertStmt.Table.Name != "users" {
t.Errorf("expected table='users', got '%s'", insertStmt.Table.Name)
}
}
func TestPgUpdateWithSchema(t *testing.T) {
sql := `UPDATE "public"."t_db" SET "name" = 'fsdfds3' WHERE "id" = 5`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
updateStmt := stmt.(*sqlstmt.UpdateStmt)
t.Logf("UPDATE text: %s", updateStmt.GetText())
t.Logf("UPDATE tables: %+v", updateStmt.Tables)
t.Logf("UPDATE SET: %+v", updateStmt.Set)
t.Logf("UPDATE WHERE: %+v", updateStmt.Where)
if len(updateStmt.Tables) != 1 {
t.Fatalf("expected 1 table, got %d", len(updateStmt.Tables))
}
if updateStmt.Tables[0].Schema != "public" {
t.Errorf("expected schema='public', got '%s'", updateStmt.Tables[0].Schema)
}
if updateStmt.Tables[0].Name != "t_db" {
t.Errorf("expected table='t_db', got '%s'", updateStmt.Tables[0].Name)
}
// 验证 SET 字段
if len(updateStmt.Set) != 1 {
t.Fatalf("expected 1 assignment, got %d", len(updateStmt.Set))
}
if updateStmt.Set[0].Column != "name" {
t.Errorf("expected column='name', got '%s'", updateStmt.Set[0].Column)
}
if updateStmt.Set[0].Value == nil || updateStmt.Set[0].Value.Text != "'fsdfds3'" {
t.Errorf("expected value=''fsdfds3'', got '%s'", updateStmt.Set[0].Value.Text)
}
if updateStmt.Where == nil || updateStmt.Where.Text != `"id" = 5` {
t.Errorf("expected WHERE='\"id\" = 5', got '%s'", updateStmt.Where.Text)
}
}
func TestPgDeleteWithSchema(t *testing.T) {
sql := `DELETE FROM "public"."logs" WHERE "created_at" < '2024-01-01'`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
if len(deleteStmt.Tables) != 1 {
t.Fatalf("expected 1 table, got %d", len(deleteStmt.Tables))
}
if deleteStmt.Tables[0].Schema != "public" {
t.Errorf("expected schema='public', got '%s'", deleteStmt.Tables[0].Schema)
}
if deleteStmt.Tables[0].Name != "logs" {
t.Errorf("expected table='logs', got '%s'", deleteStmt.Tables[0].Name)
}
if deleteStmt.Where == nil || deleteStmt.Where.Text != `"created_at" < '2024-01-01'` {
t.Errorf("expected WHERE text")
}
}
// ========== DDL 测试 ==========
func TestPgDDLCreate(t *testing.T) {
sql := "CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(50), age INT)"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
ddlStmt := stmt.(*sqlstmt.DdlStmt)
if ddlStmt.DdlKind != "CREATE" {
t.Errorf("expected DdlKind='CREATE', got '%s'", ddlStmt.DdlKind)
}
if ddlStmt.GetText() != sql {
t.Errorf("expected text='%s'", sql)
}
}
func TestPgDDLDrop(t *testing.T) {
sql := "DROP TABLE users"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
ddlStmt := stmt.(*sqlstmt.DdlStmt)
if ddlStmt.DdlKind != "DROP" {
t.Errorf("expected DdlKind='DROP', got '%s'", ddlStmt.DdlKind)
}
}
func TestPgDDLAlter(t *testing.T) {
sql := "ALTER TABLE users ADD COLUMN email VARCHAR(255)"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
ddlStmt := stmt.(*sqlstmt.DdlStmt)
if ddlStmt.DdlKind != "ALTER" {
t.Errorf("expected DdlKind='ALTER', got '%s'", ddlStmt.DdlKind)
}
}
func TestPgDDLTruncate(t *testing.T) {
sql := "TRUNCATE TABLE users"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
if stmt == nil {
t.Fatalf("expected stmt not nil")
}
t.Logf("TRUNCATE stmt type: %T", stmt)
t.Logf("TRUNCATE text: %s", stmt.GetText())
}
// ========== 复杂 DML 测试 ==========
func TestPgComplexUpdateWithSubquery(t *testing.T) {
sql := "UPDATE users SET total_orders = (SELECT COUNT(*) FROM orders WHERE orders.user_id = users.id) WHERE EXISTS (SELECT 1 FROM orders WHERE orders.user_id = users.id)"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
updateStmt := stmt.(*sqlstmt.UpdateStmt)
if len(updateStmt.Tables) != 1 || updateStmt.Tables[0].Name != "users" {
t.Errorf("expected table='users'")
}
}
func TestPgComplexDeleteWithSubquery(t *testing.T) {
sql := "DELETE FROM users WHERE id NOT IN (SELECT DISTINCT user_id FROM orders)"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
if deleteStmt.Where == nil {
t.Fatal("expected WHERE")
}
}
func TestPgInsertOnConflict(t *testing.T) {
sql := "INSERT INTO users (id, name) VALUES (1, 'John') ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
insertStmt := stmt.(*sqlstmt.InsertStmt)
t.Logf("Actual text: %s", insertStmt.GetText())
if insertStmt.GetText() != sql {
t.Errorf("expected text='%s'", sql)
}
}
// ========== 复杂 INSERT 测试 ==========
func TestPgInsertWithDoubleQuotes(t *testing.T) {
// PostgreSQL 双引号标识符
sql := `INSERT INTO "users" ("name", "age", "email") VALUES ('John', 30, 'john@example.com')`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
insertStmt := stmt.(*sqlstmt.InsertStmt)
if insertStmt.GetText() != sql {
t.Errorf("expected text='%s'", sql)
}
}
func TestPgInsertWithSpecialChars(t *testing.T) {
// 包含特殊字符
sql := `INSERT INTO "logs" ("message", "level") VALUES ('Error: connection failed!', 'ERROR')`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
insertStmt := stmt.(*sqlstmt.InsertStmt)
t.Logf("INSERT text: %s", insertStmt.GetText())
}
func TestPgInsertReturningMultiple(t *testing.T) {
// RETURNING 多个字段
sql := `INSERT INTO "users" ("name", "email") VALUES ('John', 'john@example.com') RETURNING "id", "name", "created_at"`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
insertStmt := stmt.(*sqlstmt.InsertStmt)
t.Logf("INSERT RETURNING: %s", insertStmt.GetText())
}
func TestPgInsertFromSelectComplex(t *testing.T) {
// INSERT FROM SELECT 复杂查询
sql := `INSERT INTO "users_backup" SELECT * FROM "users" WHERE "status" = 1 AND "created_at" > '2024-01-01'`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
insertStmt := stmt.(*sqlstmt.InsertStmt)
t.Logf("INSERT FROM SELECT: %s", insertStmt.GetText())
}
func TestPgInsertOnConflictDoNothing(t *testing.T) {
// ON CONFLICT DO NOTHING
sql := `INSERT INTO "users" ("id", "name") VALUES (1, 'John') ON CONFLICT ("id") DO NOTHING`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
insertStmt := stmt.(*sqlstmt.InsertStmt)
t.Logf("INSERT ON CONFLICT DO NOTHING: %s", insertStmt.GetText())
}
// ========== 复杂 UPDATE 测试 ==========
func TestPgUpdateWithDoubleQuotes(t *testing.T) {
sql := `UPDATE "users" SET "name" = 'John', "age" = 30 WHERE "id" = 1`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
updateStmt := stmt.(*sqlstmt.UpdateStmt)
if len(updateStmt.Tables) != 1 || updateStmt.Tables[0].Name != "users" {
t.Errorf("expected table='users', got '%s'", updateStmt.Tables[0].Name)
}
if updateStmt.Where == nil || updateStmt.Where.Text != `"id" = 1` {
t.Errorf("expected WHERE")
}
}
func TestPgUpdateWithComplexWhere(t *testing.T) {
sql := `UPDATE "orders" SET "status" = 'cancelled' WHERE "status" = 'pending' AND "created_at" < '2024-01-01' AND ("amount" < 100 OR "user_id" IS NULL)`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
updateStmt := stmt.(*sqlstmt.UpdateStmt)
if updateStmt.Where == nil {
t.Fatal("expected WHERE")
}
t.Logf("Complex WHERE: %s", updateStmt.Where.Text)
}
func TestPgUpdateWithFunctions(t *testing.T) {
sql := `UPDATE "users" SET "updated_at" = NOW(), "login_count" = "login_count" + 1 WHERE "id" = 1`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
updateStmt := stmt.(*sqlstmt.UpdateStmt)
t.Logf("UPDATE with functions: %s", updateStmt.GetText())
}
func TestPgUpdateReturningComplex(t *testing.T) {
sql := `UPDATE "users" SET "status" = 0 WHERE "status" = 1 RETURNING "id", "name", "old_status"`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
updateStmt := stmt.(*sqlstmt.UpdateStmt)
if updateStmt.Where == nil || updateStmt.Where.Text != `"status" = 1` {
t.Errorf("expected WHERE")
}
}
func TestPgUpdateWithSubquery(t *testing.T) {
sql := `UPDATE "users" SET "total" = (SELECT SUM("amount") FROM "orders" WHERE "user_id" = "users"."id") WHERE "status" = 'active'`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
updateStmt := stmt.(*sqlstmt.UpdateStmt)
if len(updateStmt.Tables) != 1 || updateStmt.Tables[0].Name != "users" {
t.Errorf("expected table='users'")
}
}
func TestPgUpdateFromComplex(t *testing.T) {
// PostgreSQL UPDATE FROM 复杂场景
sql := `UPDATE "orders" o SET "status" = 'shipped' FROM "users" u WHERE o."user_id" = u."id" AND u."status" = 1 AND o."created_at" > '2024-01-01'`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
updateStmt := stmt.(*sqlstmt.UpdateStmt)
t.Logf("UPDATE FROM: %s", updateStmt.GetText())
}
// ========== 复杂 DELETE 测试 ==========
func TestPgDeleteWithDoubleQuotes(t *testing.T) {
sql := `DELETE FROM "users" WHERE "id" = 1`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
if len(deleteStmt.Tables) != 1 || deleteStmt.Tables[0].Name != "users" {
t.Errorf("expected table='users'")
}
}
func TestPgDeleteWithComplexWhere(t *testing.T) {
sql := `DELETE FROM "logs" WHERE "created_at" < '2024-01-01' AND ("level" = 'DEBUG' OR "level" = 'INFO')`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
if deleteStmt.Where == nil {
t.Fatal("expected WHERE")
}
t.Logf("Complex DELETE WHERE: %s", deleteStmt.Where.Text)
}
func TestPgDeleteReturningComplex(t *testing.T) {
sql := `DELETE FROM "users" WHERE "status" = 0 RETURNING "id", "name", "deleted_at"`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
if deleteStmt.Where == nil || deleteStmt.Where.Text != `"status" = 0` {
t.Errorf("expected WHERE")
}
}
func TestPgDeleteWithSubquery(t *testing.T) {
sql := `DELETE FROM "users" WHERE "id" NOT IN (SELECT DISTINCT "user_id" FROM "orders")`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
if deleteStmt.Where == nil {
t.Fatal("expected WHERE")
}
}
func TestPgDeleteUsingComplex(t *testing.T) {
// PostgreSQL DELETE USING 复杂场景
sql := `DELETE FROM "orders" o USING "users" u WHERE o."user_id" = u."id" AND u."status" = 0 AND o."created_at" < '2024-01-01'`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
t.Logf("DELETE USING: %s", deleteStmt.GetText())
}
// ========== 复杂 DDL 测试 ==========
func TestPgDDLCreateTableWithQuotes(t *testing.T) {
sql := `CREATE TABLE "users" ("id" INT PRIMARY KEY, "name" VARCHAR(50) NOT NULL, "email" VARCHAR(255) UNIQUE)`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
ddlStmt := stmt.(*sqlstmt.DdlStmt)
if ddlStmt.DdlKind != "CREATE" {
t.Errorf("expected DdlKind='CREATE'")
}
}
func TestPgDDLCreateTableWithSerial(t *testing.T) {
sql := `CREATE TABLE "orders" ("id" SERIAL PRIMARY KEY, "amount" DECIMAL(10,2), "created_at" TIMESTAMP DEFAULT NOW())`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
ddlStmt := stmt.(*sqlstmt.DdlStmt)
t.Logf("CREATE TABLE with SERIAL: %s", ddlStmt.GetText())
}
func TestPgDDLAlterTableAddColumn(t *testing.T) {
sql := `ALTER TABLE "users" ADD COLUMN "email" VARCHAR(255)`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
ddlStmt := stmt.(*sqlstmt.DdlStmt)
if ddlStmt.DdlKind != "ALTER" {
t.Errorf("expected DdlKind='ALTER'")
}
}
func TestPgDDLDropIfExists(t *testing.T) {
sql := `DROP TABLE IF EXISTS "users"`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
ddlStmt := stmt.(*sqlstmt.DdlStmt)
if ddlStmt.DdlKind != "DROP" {
t.Errorf("expected DdlKind='DROP'")
}
}
// ========== 注释风格测试 ==========
func TestPgDMLWithSingleLineComment(t *testing.T) {
// PostgreSQL 单行注释 --
sql := "-- 更新用户信息\nUPDATE \"users\" SET \"name\" = 'John' WHERE \"id\" = 1"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
updateStmt := stmt.(*sqlstmt.UpdateStmt)
if len(updateStmt.Tables) != 1 || updateStmt.Tables[0].Name != "users" {
t.Errorf("expected table='users'")
}
t.Logf("UPDATE with -- comment: %s", updateStmt.GetText())
}
func TestPgDMLWithMultiLineComment(t *testing.T) {
// 多行注释 /* */
sql := "/* 删除过期订单 */ DELETE FROM \"orders\" WHERE \"created_at\" < '2024-01-01'"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
if deleteStmt.Where == nil {
t.Fatal("expected WHERE")
}
t.Logf("DELETE with /* */ comment: %s", deleteStmt.GetText())
}
func TestPgDMLWithInlineComment(t *testing.T) {
// 行内注释
sql := `SELECT "id", "name" /* 用户名 */ FROM "users" WHERE "status" = 1`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.Items) != 2 {
t.Fatalf("expected 2 items")
}
t.Logf("SELECT with inline comment: %s", selectStmt.GetText())
}
func TestPgDMLWithMultipleComments(t *testing.T) {
// 多个注释
sql := `-- 查询活跃用户
/* 只查询最近注册的 */
SELECT "id", "name" FROM "users"
WHERE "status" = 1 AND "created_at" > '2024-01-01'
ORDER BY "created_at" DESC`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Where == nil {
t.Fatal("expected WHERE")
}
t.Logf("Multiple comments: %s", selectStmt.GetText())
}
func TestPgInsertWithComment(t *testing.T) {
sql := "-- 插入新用户\nINSERT INTO \"users\" (\"name\", \"email\") VALUES ('John', 'john@example.com')"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
insertStmt := stmt.(*sqlstmt.InsertStmt)
t.Logf("INSERT with comment: %s", insertStmt.GetText())
}
func TestPgUpdateWithComment(t *testing.T) {
sql := `/* 批量更新状态 */
UPDATE "orders" SET "status" = 'cancelled'
WHERE "status" = 'pending' -- 只更新待处理的订单
AND "created_at" < '2024-01-01'`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
updateStmt := stmt.(*sqlstmt.UpdateStmt)
if updateStmt.Where == nil {
t.Fatal("expected WHERE")
}
t.Logf("UPDATE with multiple comments: %s", updateStmt.GetText())
}

View File

@@ -0,0 +1,455 @@
package pgsql
import (
"testing"
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
)
// ========== 简单分页测试 ==========
func TestPgPaginationSimple(t *testing.T) {
sql := "SELECT * FROM users LIMIT 10 OFFSET 0"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
t.Logf("LIMIT text: %s", selectStmt.Limit.Text)
t.Logf("Offset: %d, Count: %d", selectStmt.Limit.Offset, selectStmt.Limit.Count)
if selectStmt.Limit.Offset != 0 {
t.Errorf("expected offset=0, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 10 {
t.Errorf("expected count=10, got %d", selectStmt.Limit.Count)
}
}
func TestPgPaginationWithOffset(t *testing.T) {
sql := "SELECT id, name FROM users LIMIT 10 OFFSET 20"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
if selectStmt.Limit.Offset != 20 {
t.Errorf("expected offset=20, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 10 {
t.Errorf("expected count=10, got %d", selectStmt.Limit.Count)
}
}
func TestPgPaginationOffsetFirst(t *testing.T) {
// PostgreSQL 支持 OFFSET 在 LIMIT 前面
sql := "SELECT * FROM products OFFSET 30 LIMIT 15"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
t.Logf("LIMIT text: %s", selectStmt.Limit.Text)
t.Logf("Offset: %d, Count: %d", selectStmt.Limit.Offset, selectStmt.Limit.Count)
if selectStmt.Limit.Offset != 30 {
t.Errorf("expected offset=30, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 15 {
t.Errorf("expected count=15, got %d", selectStmt.Limit.Count)
}
}
func TestPgPaginationOnlyLimit(t *testing.T) {
sql := "SELECT * FROM orders LIMIT 50"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
if selectStmt.Limit.Offset != 0 {
t.Errorf("expected offset=0, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 50 {
t.Errorf("expected count=50, got %d", selectStmt.Limit.Count)
}
}
func TestPgPaginationOnlyOffset(t *testing.T) {
sql := "SELECT * FROM users OFFSET 100"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
if selectStmt.Limit.Offset != 100 {
t.Errorf("expected offset=100, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 0 {
t.Errorf("expected count=0, got %d", selectStmt.Limit.Count)
}
}
func TestPgPaginationLimitAll(t *testing.T) {
// PostgreSQL 特有的 LIMIT ALL
sql := "SELECT * FROM users OFFSET 50 LIMIT ALL"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
t.Logf("LIMIT text: %s", selectStmt.Limit.Text)
if selectStmt.Limit.Offset != 50 {
t.Errorf("expected offset=50, got %d", selectStmt.Limit.Offset)
}
}
// ========== 复杂分页测试 ==========
func TestPgPaginationWithWhere(t *testing.T) {
sql := "SELECT id, name, email FROM users WHERE status = 1 AND age > 18 LIMIT 20 OFFSET 10"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Where == nil {
t.Fatal("expected WHERE clause")
}
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
if selectStmt.Limit.Offset != 10 {
t.Errorf("expected offset=10, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 20 {
t.Errorf("expected count=20, got %d", selectStmt.Limit.Count)
}
}
func TestPgPaginationWithOrderBy(t *testing.T) {
sql := "SELECT * FROM users ORDER BY created_at DESC, id ASC LIMIT 100 OFFSET 0"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.OrderBy) != 2 {
t.Fatalf("expected 2 ORDER BY clauses")
}
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
if selectStmt.Limit.Offset != 0 {
t.Errorf("expected offset=0, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 100 {
t.Errorf("expected count=100, got %d", selectStmt.Limit.Count)
}
}
func TestPgPaginationWithWhereOrderBy(t *testing.T) {
sql := "SELECT id, name FROM users WHERE status = 'active' ORDER BY score DESC LIMIT 10 OFFSET 50"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Where == nil || selectStmt.Where.Text != "status = 'active'" {
t.Errorf("expected WHERE='status = 'active''")
}
if len(selectStmt.OrderBy) != 1 || selectStmt.OrderBy[0].Text != "score" {
t.Errorf("expected ORDER BY score")
}
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
if selectStmt.Limit.Offset != 50 {
t.Errorf("expected offset=50, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 10 {
t.Errorf("expected count=10, got %d", selectStmt.Limit.Count)
}
}
func TestPgPaginationWithJoin(t *testing.T) {
sql := "SELECT u.id, u.name, o.amount FROM users u LEFT JOIN orders o ON u.id = o.user_id WHERE u.status = 1 ORDER BY o.created_at DESC LIMIT 20 OFFSET 100"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.Joins) != 1 {
t.Fatal("expected 1 JOIN")
}
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
if selectStmt.Limit.Offset != 100 {
t.Errorf("expected offset=100, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 20 {
t.Errorf("expected count=20, got %d", selectStmt.Limit.Count)
}
}
func TestPgPaginationWithGroupBy(t *testing.T) {
sql := "SELECT user_id, COUNT(*) as order_count FROM orders GROUP BY user_id ORDER BY order_count DESC LIMIT 10 OFFSET 0"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
if selectStmt.Limit.Offset != 0 {
t.Errorf("expected offset=0, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 10 {
t.Errorf("expected count=10, got %d", selectStmt.Limit.Count)
}
}
// ========== UNION 分页测试 ==========
func TestPgPaginationWithUnion(t *testing.T) {
sql := "SELECT id FROM users UNION SELECT id FROM admins ORDER BY 1 DESC LIMIT 50 OFFSET 0"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.Unions) != 1 {
t.Fatal("expected 1 UNION")
}
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause for UNION")
}
if selectStmt.Limit.Offset != 0 {
t.Errorf("expected offset=0, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 50 {
t.Errorf("expected count=50, got %d", selectStmt.Limit.Count)
}
}
func TestPgPaginationUnionAll(t *testing.T) {
sql := "SELECT name FROM products UNION ALL SELECT name FROM services LIMIT 30 OFFSET 10"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.Unions) != 1 {
t.Fatal("expected 1 UNION ALL")
}
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
if selectStmt.Limit.Offset != 10 {
t.Errorf("expected offset=10, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 30 {
t.Errorf("expected count=30, got %d", selectStmt.Limit.Count)
}
}
func TestPgMultipleUnionsPagination(t *testing.T) {
sql := "SELECT id FROM t1 UNION SELECT id FROM t2 UNION SELECT id FROM t3 ORDER BY 1 LIMIT 100 OFFSET 50"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.Unions) != 2 {
t.Fatalf("expected 2 UNIONs, got %d", len(selectStmt.Unions))
}
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
if selectStmt.Limit.Offset != 50 {
t.Errorf("expected offset=50, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 100 {
t.Errorf("expected count=100, got %d", selectStmt.Limit.Count)
}
}
// ========== 子查询分页测试 ==========
func TestPgPaginationWithSubquery(t *testing.T) {
sql := "SELECT * FROM (SELECT id, name FROM users WHERE status = 1 ORDER BY id LIMIT 100 OFFSET 0) AS tmp LIMIT 10 OFFSET 20"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.From) != 1 {
t.Fatal("expected 1 FROM table")
}
if selectStmt.Limit == nil {
t.Fatal("expected outer LIMIT clause")
}
if selectStmt.Limit.Offset != 20 {
t.Errorf("expected outer offset=20, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 10 {
t.Errorf("expected outer count=10, got %d", selectStmt.Limit.Count)
}
}
func TestPgNestedPagination(t *testing.T) {
// 嵌套分页查询
sql := "SELECT * FROM (SELECT * FROM users LIMIT 100 OFFSET 0) AS tmp LIMIT 5 OFFSET 10"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Limit == nil {
t.Fatal("expected outer LIMIT clause")
}
if selectStmt.Limit.Offset != 10 {
t.Errorf("expected outer offset=10, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 5 {
t.Errorf("expected outer count=5, got %d", selectStmt.Limit.Count)
}
}
func TestPgPaginationWithExists(t *testing.T) {
sql := "SELECT * FROM users u WHERE EXISTS (SELECT 1 FROM orders o WHERE o.user_id = u.id LIMIT 1) LIMIT 20 OFFSET 0"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Where == nil {
t.Fatal("expected WHERE clause")
}
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
if selectStmt.Limit.Offset != 0 {
t.Errorf("expected offset=0, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 20 {
t.Errorf("expected count=20, got %d", selectStmt.Limit.Count)
}
}
// ========== 大数据量分页测试 ==========
func TestPgLargeOffsetPagination(t *testing.T) {
sql := "SELECT * FROM logs ORDER BY id LIMIT 100 OFFSET 1000000"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
if selectStmt.Limit.Offset != 1000000 {
t.Errorf("expected offset=1000000, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 100 {
t.Errorf("expected count=100, got %d", selectStmt.Limit.Count)
}
}
// ========== FOR UPDATE 分页测试 ==========
func TestPgPaginationForUpdate(t *testing.T) {
sql := "SELECT * FROM users WHERE status = 0 ORDER BY id LIMIT 10 OFFSET 0 FOR UPDATE"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
if selectStmt.Limit.Offset != 0 {
t.Errorf("expected offset=0, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 10 {
t.Errorf("expected count=10, got %d", selectStmt.Limit.Count)
}
}

View File

@@ -0,0 +1,412 @@
package pgsql
import (
"strings"
"testing"
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
)
// ========== SELECT 测试 ==========
func TestPgSelectBasic(t *testing.T) {
sql := "SELECT id, name FROM users WHERE status = 1 ORDER BY id DESC LIMIT 10"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
// 验证完整文本
if selectStmt.GetText() != sql {
t.Errorf("expected text='%s', got '%s'", sql, selectStmt.GetText())
}
// 验证 SELECT 字段
if len(selectStmt.Items) != 2 {
t.Fatalf("expected 2 items, got %d", len(selectStmt.Items))
}
if selectStmt.Items[0].Text != "id" {
t.Errorf("expected Items[0]='id', got '%s'", selectStmt.Items[0].Text)
}
if selectStmt.Items[1].Text != "name" {
t.Errorf("expected Items[1]='name', got '%s'", selectStmt.Items[1].Text)
}
// 验证 FROM 表名
if len(selectStmt.From) != 1 {
t.Fatalf("expected 1 from table, got %d", len(selectStmt.From))
}
if selectStmt.From[0].Name != "users" {
t.Errorf("expected From[0].Name='users', got '%s'", selectStmt.From[0].Name)
}
// 验证 WHERE
if selectStmt.Where == nil {
t.Fatal("expected WHERE")
}
if selectStmt.Where.Text != "status = 1" {
t.Errorf("expected WHERE='status = 1', got '%s'", selectStmt.Where.Text)
}
// 验证 ORDER BY
if len(selectStmt.OrderBy) != 1 {
t.Fatalf("expected 1 OrderBy, got %d", len(selectStmt.OrderBy))
}
if selectStmt.OrderBy[0].Text != "id" {
t.Errorf("expected OrderBy[0].Text='id', got '%s'", selectStmt.OrderBy[0].Text)
}
if !selectStmt.OrderBy[0].Desc {
t.Error("expected OrderBy[0].Desc=true")
}
// 验证 LIMIT
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT")
}
if selectStmt.Limit.Text != "LIMIT 10" {
t.Errorf("expected LIMIT.Text='LIMIT 10', got '%s'", selectStmt.Limit.Text)
}
if selectStmt.Limit.Count != 10 {
t.Errorf("expected LIMIT.Count=10, got %d", selectStmt.Limit.Count)
}
}
func TestPgSelectDistinct(t *testing.T) {
sql := "SELECT DISTINCT id, name FROM users"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if !selectStmt.Distinct {
t.Error("expected Distinct=true")
}
}
func TestPgSelectStar(t *testing.T) {
sql := "SELECT * FROM users"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
// 验证完整文本
if selectStmt.GetText() != sql {
t.Errorf("expected text='%s', got '%s'", sql, selectStmt.GetText())
}
// 验证 SELECT *
if len(selectStmt.Items) != 1 {
t.Fatalf("expected 1 item, got %d", len(selectStmt.Items))
}
if selectStmt.Items[0].Text != "*" {
t.Errorf("expected Items[0]='*', got '%s'", selectStmt.Items[0].Text)
}
// 验证 FROM
if len(selectStmt.From) != 1 {
t.Fatalf("expected 1 from table, got %d", len(selectStmt.From))
}
if selectStmt.From[0].Name != "users" {
t.Errorf("expected From[0].Name='users', got '%s'", selectStmt.From[0].Name)
}
}
func TestPgSelectWithJoin(t *testing.T) {
sql := "SELECT u.id, o.amount FROM users u LEFT JOIN orders o ON u.id = o.user_id WHERE u.status = 1"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
// 验证完整文本
if selectStmt.GetText() != sql {
t.Errorf("expected text='%s', got '%s'", sql, selectStmt.GetText())
}
// 验证主表
if len(selectStmt.From) != 1 {
t.Fatalf("expected 1 from table, got %d", len(selectStmt.From))
}
if selectStmt.From[0].Name != "users" {
t.Errorf("expected From[0].Name='users', got '%s'", selectStmt.From[0].Name)
}
if selectStmt.From[0].Alias != "u" {
t.Errorf("expected From[0].Alias='u', got '%s'", selectStmt.From[0].Alias)
}
// 验证 JOIN
if len(selectStmt.Joins) != 1 {
t.Fatalf("expected 1 join, got %d", len(selectStmt.Joins))
}
if selectStmt.Joins[0].Table.Name != "orders" {
t.Errorf("expected Join[0].Table.Name='orders', got '%s'", selectStmt.Joins[0].Table.Name)
}
if selectStmt.Joins[0].Table.Alias != "o" {
t.Errorf("expected Join[0].Table.Alias='o', got '%s'", selectStmt.Joins[0].Table.Alias)
}
if selectStmt.Joins[0].Kind != sqlstmt.JoinKindLeft {
t.Errorf("expected Join[0].Kind=JoinKindLeft, got '%d'", selectStmt.Joins[0].Kind)
}
if selectStmt.Joins[0].On == nil {
t.Fatal("expected Join[0].On")
}
if selectStmt.Joins[0].On.Text != "u.id = o.user_id" {
t.Errorf("expected Join[0].On.Text='u.id = o.user_id', got '%s'", selectStmt.Joins[0].On.Text)
}
// 验证 WHERE
if selectStmt.Where == nil {
t.Fatal("expected WHERE")
}
if selectStmt.Where.Text != "u.status = 1" {
t.Errorf("expected WHERE='u.status = 1', got '%s'", selectStmt.Where.Text)
}
}
func TestPgSelectComplexWhere(t *testing.T) {
sql := "SELECT * FROM users WHERE status = 1 AND (age > 18 OR role = 'admin') ORDER BY name"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
// 验证完整文本
if selectStmt.GetText() != sql {
t.Errorf("expected text='%s', got '%s'", sql, selectStmt.GetText())
}
// 验证 FROM
if len(selectStmt.From) != 1 {
t.Fatalf("expected 1 from table, got %d", len(selectStmt.From))
}
if selectStmt.From[0].Name != "users" {
t.Errorf("expected From[0].Name='users', got '%s'", selectStmt.From[0].Name)
}
// 验证 WHERE
if selectStmt.Where == nil {
t.Fatal("expected WHERE")
}
expectedWhere := "status = 1 AND (age > 18 OR role = 'admin')"
if selectStmt.Where.Text != expectedWhere {
t.Errorf("expected WHERE='%s', got '%s'", expectedWhere, selectStmt.Where.Text)
}
// 验证 ORDER BY
if len(selectStmt.OrderBy) != 1 {
t.Fatalf("expected 1 OrderBy, got %d", len(selectStmt.OrderBy))
}
if selectStmt.OrderBy[0].Text != "name" {
t.Errorf("expected OrderBy[0].Text='name', got '%s'", selectStmt.OrderBy[0].Text)
}
}
func TestPgOffsetLimit(t *testing.T) {
sql := "SELECT * FROM users OFFSET 10 LIMIT 20"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
// 验证完整文本
if selectStmt.GetText() != sql {
t.Errorf("expected text='%s', got '%s'", sql, selectStmt.GetText())
}
// 验证 LIMIT/OFFSET
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT")
}
if selectStmt.Limit.Text != "OFFSET 10 LIMIT 20" {
t.Errorf("expected LIMIT.Text='OFFSET 10 LIMIT 20', got '%s'", selectStmt.Limit.Text)
}
if selectStmt.Limit.Count != 20 {
t.Errorf("expected LIMIT.Count=20, got %d", selectStmt.Limit.Count)
}
if selectStmt.Limit.Offset != 10 {
t.Errorf("expected LIMIT.Offset=10, got %d", selectStmt.Limit.Offset)
}
}
func TestPgLimitAll(t *testing.T) {
sql := "SELECT * FROM users LIMIT ALL"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Limit == nil || selectStmt.Limit.Text != "LIMIT ALL" {
t.Errorf("expected LIMIT text='LIMIT ALL'")
}
}
func TestPgUnion(t *testing.T) {
sql := "SELECT 1 UNION SELECT 2 UNION ALL SELECT 3"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
// 验证完整文本
if selectStmt.GetText() != sql {
t.Errorf("expected text='%s', got '%s'", sql, selectStmt.GetText())
}
// 验证 UNION 数量
if len(selectStmt.Unions) != 2 {
t.Fatalf("expected 2 unions, got %d", len(selectStmt.Unions))
}
// 验证第一个 UNIONDISTINCT
if selectStmt.Unions[0].All {
t.Error("expected Unions[0].All=false (DISTINCT)")
}
if selectStmt.Unions[0].Select == nil {
t.Fatal("expected Unions[0].Select")
}
// 验证第二个 UNIONALL
if !selectStmt.Unions[1].All {
t.Error("expected Unions[1].All=true")
}
if selectStmt.Unions[1].Select == nil {
t.Fatal("expected Unions[1].Select")
}
}
// ========== FOR UPDATE 测试 ==========
func TestPgForUpdate(t *testing.T) {
sql := "SELECT id FROM users WHERE id = 1 FOR UPDATE"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
// 验证完整文本
if selectStmt.GetText() != sql {
t.Errorf("expected text='%s', got '%s'", sql, selectStmt.GetText())
}
// 验证 FROM
if len(selectStmt.From) != 1 {
t.Fatalf("expected 1 from table, got %d", len(selectStmt.From))
}
if selectStmt.From[0].Name != "users" {
t.Errorf("expected From[0].Name='users', got '%s'", selectStmt.From[0].Name)
}
// 验证 WHERE
if selectStmt.Where == nil {
t.Fatal("expected WHERE")
}
if selectStmt.Where.Text != "id = 1" {
t.Errorf("expected WHERE='id = 1', got '%s'", selectStmt.Where.Text)
}
// 注意FOR UPDATE 标记可能在 Base.Text 中体现
// 只要完整文本包含 FOR UPDATE 即可
if !strings.Contains(selectStmt.GetText(), "FOR UPDATE") {
t.Error("expected text to contain 'FOR UPDATE'")
}
}
func TestPgForUpdateSkipLocked(t *testing.T) {
sql := "SELECT id FROM users WHERE status = 1 FOR UPDATE SKIP LOCKED"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Where == nil || selectStmt.Where.Text != "status = 1" {
t.Errorf("expected WHERE='status = 1'")
}
}
// ========== RETURNING 测试 ==========
func TestPgUpdateReturning(t *testing.T) {
sql := "UPDATE users SET name = 'John' WHERE id = 1 RETURNING id, name"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
if stmt == nil {
t.Fatalf("expected stmt not nil")
}
if stmt.GetText() != sql {
t.Errorf("expected text='%s'", sql)
}
}
func TestPgDeleteReturning(t *testing.T) {
sql := "DELETE FROM users WHERE id = 1 RETURNING *"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
if stmt == nil {
t.Fatalf("expected stmt not nil")
}
if stmt.GetText() != sql {
t.Errorf("expected text='%s'", sql)
}
}
// ========== DDL 测试 ==========
func TestPgDDL(t *testing.T) {
sqls := []string{
"CREATE TABLE users (id INT PRIMARY KEY)",
"DROP TABLE users",
"ALTER TABLE users ADD COLUMN email VARCHAR(255)",
}
for _, sql := range sqls {
t.Run(sql[:10], func(t *testing.T) {
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
if stmt == nil {
t.Fatalf("expected stmt not nil")
}
if stmt.GetText() != sql {
t.Errorf("expected text='%s'", sql)
}
})
}
}

View File

@@ -0,0 +1,368 @@
package pgsql
import (
"testing"
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
)
// ========== 子查询测试 ==========
func TestPgSubqueryInFrom(t *testing.T) {
sql := "SELECT * FROM (SELECT id, name FROM users WHERE status = 1) AS u WHERE u.id > 10"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
t.Logf("完整文本: %s", selectStmt.GetText())
t.Logf("FROM 表数量: %d", len(selectStmt.From))
if len(selectStmt.From) != 1 {
t.Fatalf("expected 1 table in FROM, got %d", len(selectStmt.From))
}
t.Logf("FROM[0] Name: %s", selectStmt.From[0].Name)
t.Logf("FROM[0] Alias: %s", selectStmt.From[0].Alias)
if selectStmt.From[0].Alias != "u" {
t.Errorf("expected alias='u', got '%s'", selectStmt.From[0].Alias)
}
if selectStmt.Where == nil {
t.Fatal("expected outer WHERE clause")
}
t.Logf("外层 WHERE: %s", selectStmt.Where.Text)
if selectStmt.Where.Text != "u.id > 10" {
t.Errorf("expected outer WHERE text='u.id > 10', got '%s'", selectStmt.Where.Text)
}
}
func TestPgSubqueryInWhere(t *testing.T) {
sql := "SELECT * FROM users WHERE id IN (SELECT user_id FROM orders WHERE amount > 100)"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
t.Logf("完整文本: %s", selectStmt.GetText())
t.Logf("WHERE: %s", selectStmt.Where.Text)
if selectStmt.Where == nil {
t.Fatal("expected WHERE clause")
}
expectedWhere := "id IN (SELECT user_id FROM orders WHERE amount > 100)"
if selectStmt.Where.Text != expectedWhere {
t.Errorf("expected WHERE text='%s', got '%s'", expectedWhere, selectStmt.Where.Text)
}
}
func TestPgSubqueryInSelect(t *testing.T) {
sql := "SELECT id, name, (SELECT COUNT(*) FROM orders WHERE orders.user_id = users.id) AS order_count FROM users"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
t.Logf("完整文本: %s", selectStmt.GetText())
t.Logf("SELECT 项数量: %d", len(selectStmt.Items))
for i, item := range selectStmt.Items {
t.Logf(" Item[%d]: Text='%s', ColumnName='%s', Alias='%s'", i, item.Text, item.ColumnName, item.Alias)
}
if len(selectStmt.Items) != 3 {
t.Fatalf("expected 3 items, got %d", len(selectStmt.Items))
}
if selectStmt.Items[2].Alias != "order_count" {
t.Errorf("expected item[2] alias='order_count', got '%s'", selectStmt.Items[2].Alias)
}
}
func TestPgNestedSubquery(t *testing.T) {
sql := "SELECT * FROM (SELECT * FROM (SELECT id FROM users WHERE status = 1) AS inner_u) AS outer_u WHERE outer_u.id > 5"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
t.Logf("完整文本: %s", selectStmt.GetText())
t.Logf("外层 WHERE: %s", selectStmt.Where.Text)
if selectStmt.Where == nil {
t.Fatal("expected outer WHERE")
}
if selectStmt.Where.Text != "outer_u.id > 5" {
t.Errorf("expected WHERE text='outer_u.id > 5', got '%s'", selectStmt.Where.Text)
}
}
func TestPgSubqueryWithUnion(t *testing.T) {
sql := "SELECT * FROM (SELECT id FROM users UNION SELECT id FROM admins) AS all_users WHERE all_users.id > 10"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
t.Logf("完整文本: %s", selectStmt.GetText())
if selectStmt.Where == nil {
t.Fatal("expected WHERE clause")
}
if selectStmt.Where.Text != "all_users.id > 10" {
t.Errorf("expected WHERE text='all_users.id > 10', got '%s'", selectStmt.Where.Text)
}
}
func TestPgCorrelatedSubquery(t *testing.T) {
sql := "SELECT u.id, u.name FROM users u WHERE u.id IN (SELECT o.user_id FROM orders o WHERE o.user_id = u.id AND o.amount > 50)"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
t.Logf("完整文本: %s", selectStmt.GetText())
t.Logf("WHERE: %s", selectStmt.Where.Text)
if selectStmt.Where == nil {
t.Fatal("expected WHERE clause")
}
expectedWhere := "u.id IN (SELECT o.user_id FROM orders o WHERE o.user_id = u.id AND o.amount > 50)"
if selectStmt.Where.Text != expectedWhere {
t.Errorf("expected WHERE text='%s', got '%s'", expectedWhere, selectStmt.Where.Text)
}
}
func TestPgSubqueryWithExists(t *testing.T) {
sql := "SELECT * FROM users WHERE EXISTS (SELECT 1 FROM orders WHERE orders.user_id = users.id)"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
t.Logf("完整文本: %s", selectStmt.GetText())
t.Logf("WHERE: %s", selectStmt.Where.Text)
if selectStmt.Where == nil {
t.Fatal("expected WHERE clause")
}
expectedWhere := "EXISTS (SELECT 1 FROM orders WHERE orders.user_id = users.id)"
if selectStmt.Where.Text != expectedWhere {
t.Errorf("expected WHERE text='%s', got '%s'", expectedWhere, selectStmt.Where.Text)
}
}
func TestPgMultipleSubqueries(t *testing.T) {
sql := "SELECT * FROM users WHERE id IN (SELECT user_id FROM orders) AND department_id IN (SELECT id FROM departments WHERE name = 'IT')"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
t.Logf("完整文本: %s", selectStmt.GetText())
t.Logf("WHERE: %s", selectStmt.Where.Text)
if selectStmt.Where == nil {
t.Fatal("expected WHERE clause")
}
expectedWhere := "id IN (SELECT user_id FROM orders) AND department_id IN (SELECT id FROM departments WHERE name = 'IT')"
if selectStmt.Where.Text != expectedWhere {
t.Errorf("expected WHERE text='%s', got '%s'", expectedWhere, selectStmt.Where.Text)
}
}
func TestPgSubqueryWithLimit(t *testing.T) {
sql := "SELECT * FROM (SELECT id, name FROM users ORDER BY created_at DESC LIMIT 10) AS top_users"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
t.Logf("完整文本: %s", selectStmt.GetText())
t.Logf("FROM 数量: %d", len(selectStmt.From))
if len(selectStmt.From) != 1 {
t.Fatalf("expected 1 table, got %d", len(selectStmt.From))
}
if selectStmt.From[0].Alias != "top_users" {
t.Errorf("expected alias='top_users', got '%s'", selectStmt.From[0].Alias)
}
}
func TestPgComplexSubquery(t *testing.T) {
sql := `SELECT
u.id,
u.name,
(SELECT COUNT(*) FROM orders o WHERE o.user_id = u.id) AS order_count,
(SELECT SUM(amount) FROM orders o WHERE o.user_id = u.id AND o.status = 'completed') AS total_amount
FROM users u
WHERE u.status = 1
AND u.id IN (SELECT user_id FROM user_groups WHERE group_id = 5)
ORDER BY u.name`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
t.Logf("完整文本: %s", selectStmt.GetText())
t.Logf("SELECT 项数量: %d", len(selectStmt.Items))
if len(selectStmt.Items) != 4 {
t.Fatalf("expected 4 items, got %d", len(selectStmt.Items))
}
t.Logf("Item[0]: %s", selectStmt.Items[0].Text)
t.Logf("Item[1]: %s", selectStmt.Items[1].Text)
t.Logf("Item[2]: %s (alias: %s)", selectStmt.Items[2].Text, selectStmt.Items[2].Alias)
t.Logf("Item[3]: %s (alias: %s)", selectStmt.Items[3].Text, selectStmt.Items[3].Alias)
if selectStmt.Where == nil {
t.Fatal("expected WHERE clause")
}
t.Logf("WHERE: %s", selectStmt.Where.Text)
if len(selectStmt.OrderBy) != 1 {
t.Fatalf("expected 1 order by, got %d", len(selectStmt.OrderBy))
}
if selectStmt.OrderBy[0].Text != "u.name" {
t.Errorf("expected order by text='u.name', got '%s'", selectStmt.OrderBy[0].Text)
}
}
// ========== JOIN 测试 ==========
func TestPgMultipleJoins(t *testing.T) {
sql := "SELECT u.id, o.amount, p.name FROM users u LEFT JOIN orders o ON u.id = o.user_id INNER JOIN products p ON o.product_id = p.id WHERE u.status = 1"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
t.Logf("完整文本: %s", selectStmt.GetText())
t.Logf("JOIN 数量: %d", len(selectStmt.Joins))
if len(selectStmt.Joins) != 2 {
t.Fatalf("expected 2 joins, got %d", len(selectStmt.Joins))
}
if selectStmt.Joins[0].Table.Name != "orders" {
t.Errorf("expected first join table='orders', got '%s'", selectStmt.Joins[0].Table.Name)
}
if selectStmt.Joins[1].Table.Name != "products" {
t.Errorf("expected second join table='products', got '%s'", selectStmt.Joins[1].Table.Name)
}
if selectStmt.Where == nil {
t.Fatal("expected WHERE clause")
}
if selectStmt.Where.Text != "u.status = 1" {
t.Errorf("expected WHERE text='u.status = 1', got '%s'", selectStmt.Where.Text)
}
}
func TestPgRightJoin(t *testing.T) {
sql := "SELECT * FROM users u RIGHT JOIN orders o ON u.id = o.user_id"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.Joins) != 1 {
t.Fatalf("expected 1 join, got %d", len(selectStmt.Joins))
}
}
func TestPgCrossJoin(t *testing.T) {
sql := "SELECT * FROM users CROSS JOIN roles"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.Joins) != 1 {
t.Fatalf("expected 1 join, got %d", len(selectStmt.Joins))
}
}
// ========== UNION 测试 ==========
func TestPgUnionWithOrderBy(t *testing.T) {
sql := "SELECT id FROM users UNION SELECT id FROM admins ORDER BY 1 DESC"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
t.Logf("完整文本: %s", selectStmt.GetText())
t.Logf("UNION 数量: %d", len(selectStmt.Unions))
if len(selectStmt.Unions) != 1 {
t.Fatalf("expected 1 union, got %d", len(selectStmt.Unions))
}
// ORDER BY 应该属于整个 UNION
if len(selectStmt.OrderBy) != 1 {
t.Fatalf("expected 1 order by, got %d", len(selectStmt.OrderBy))
}
if selectStmt.OrderBy[0].Text != "1" {
t.Errorf("expected order by text='1', got '%s'", selectStmt.OrderBy[0].Text)
}
if !selectStmt.OrderBy[0].Desc {
t.Error("expected order by DESC")
}
}
func TestPgUnionWithLimit(t *testing.T) {
sql := "SELECT id FROM users UNION ALL SELECT id FROM admins LIMIT 20"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
t.Logf("完整文本: %s", selectStmt.GetText())
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT")
}
if selectStmt.Limit.Text != "LIMIT 20" {
t.Errorf("expected LIMIT text='LIMIT 20', got '%s'", selectStmt.Limit.Text)
}
if selectStmt.Limit.Count != 20 {
t.Errorf("expected LIMIT count=20, got %d", selectStmt.Limit.Count)
}
}

View File

@@ -1,60 +1,17 @@
package pgsql
import (
"mayfly-go/internal/db/dbm/sqlparser/base"
pgparser "mayfly-go/internal/db/dbm/sqlparser/pgsql/antlr4"
"mayfly-go/pkg/logx"
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
"github.com/antlr4-go/antlr/v4"
"mayfly-go/pkg/gox"
"mayfly-go/pkg/logx"
)
func GetPgsqlParserTree(baseLine int, statement string) (antlr.ParseTree, *antlr.CommonTokenStream, error) {
lexer := pgparser.NewPostgreSQLLexer(antlr.NewInputStream(statement))
stream := antlr.NewCommonTokenStream(lexer, antlr.TokenDefaultChannel)
parser := pgparser.NewPostgreSQLParser(stream)
lexerErrorListener := &base.ParseErrorListener{
BaseLine: baseLine,
}
lexer.RemoveErrorListeners()
lexer.AddErrorListener(lexerErrorListener)
parserErrorListener := &base.ParseErrorListener{
BaseLine: baseLine,
}
parser.RemoveErrorListeners()
parser.AddErrorListener(parserErrorListener)
parser.BuildParseTrees = true
tree := parser.Root()
if lexerErrorListener.Err != nil {
return nil, nil, lexerErrorListener.Err
}
if parserErrorListener.Err != nil {
return nil, nil, parserErrorListener.Err
}
return tree, stream, nil
}
type PgsqlParser struct {
}
func (*PgsqlParser) Parse(stmt string) (stmts []sqlstmt.Stmt, err error) {
defer func() {
if e := recover(); e != nil {
logx.ErrorTrace("postgres sql parser err: ", e)
err = e.(error)
}
}()
tree, _, err := GetPgsqlParserTree(1, stmt)
if err != nil {
return nil, err
}
return tree.Accept(new(PgsqlVisitor)).([]sqlstmt.Stmt), nil
func (*PgsqlParser) Parse(stmt string) (sqlstmt.Stmt, error) {
defer gox.Recover(func(e error) {
logx.ErrorTrace("postgres sql parser err: ", e)
})
return NewParser(stmt).Parse()
}

View File

@@ -1,83 +0,0 @@
package pgsql
import (
"fmt"
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
"testing"
)
func TestParserSimpleSelect(t *testing.T) {
parser := new(PgsqlParser)
sql := `SELECT t.*,t.id as tid FROM mayfly.sys_login_log as t where t.id > 0 OFFSET 0 LIMIT 25; select * from tdb where id > 1`
stmts, err := parser.Parse(sql)
if err != nil {
t.Fatal(err)
}
stmt := stmts[0].(*sqlstmt.SimpleSelectStmt)
t.Log(stmt.QuerySpecification.Where.GetText())
fmt.Println(stmt.QuerySpecification.From.GetText())
t.Log(stmts)
}
func TestParserUnionSelect(t *testing.T) {
parser := new(PgsqlParser)
sql := `(select sum(t.age), t.id tid, t1.id2, t1.* from T_DB t join t_db_ins as t1 on t.id = t1.id2 where t.id = 1 AND t1.status=0 and t.id2='9' and t.name in ('name2', 'name3') order by t.id desc) union all (select * from t_db2) OFFSET 0 LIMIT 25;`
stmts, err := parser.Parse(sql)
if err != nil {
t.Fatal(err)
}
t.Log(stmts)
}
func TestParserSingleUpdate(t *testing.T) {
parser := new(PgsqlParser)
sql := `UPDATE test.t_sys_msg t
SET
recipient_id = 13,
t.creator = 'admin4'
WHERE
t.id = 1;`
stmts, err := parser.Parse(sql)
if err != nil {
t.Fatal(err)
}
t.Log(stmts)
}
func TestParserDelete(t *testing.T) {
parser := new(PgsqlParser)
sql := `Delete from t_sys_msg t
WHERE
t.id = 1;`
stmts, err := parser.Parse(sql)
if err != nil {
t.Fatal(err)
}
t.Log(stmts)
}
func TestParserInsert(t *testing.T) {
parser := new(PgsqlParser)
sql := `INSERT INTO
mayfly_go.t_sys_msg (
type,
msg,
recipient_id,
creator_id,
create_time,
is_deleted
)
VALUES
(1, 'hahaha', 2, 1, '2024-08-26 15:36:27', 0);`
stmts, err := parser.Parse(sql)
if err != nil {
t.Fatal(err)
}
t.Log(stmts)
}

View File

@@ -1,493 +0,0 @@
package pgsql
import (
"strings"
pgparser "mayfly-go/internal/db/dbm/sqlparser/pgsql/antlr4"
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
"github.com/spf13/cast"
)
type PgsqlVisitor struct {
*pgparser.BasePostgreSQLParserVisitor
}
func (v *PgsqlVisitor) VisitRoot(ctx *pgparser.RootContext) interface{} {
if sbc := ctx.Stmtblock(); sbc != nil {
return sbc.Accept(v)
}
return sqlstmt.NewNode(ctx.GetParser(), ctx)
}
func (v *PgsqlVisitor) VisitPlsqlroot(ctx *pgparser.PlsqlrootContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *PgsqlVisitor) VisitStmtblock(ctx *pgparser.StmtblockContext) interface{} {
if smc := ctx.Stmtmulti(); smc != nil {
return smc.Accept(v)
}
return sqlstmt.NewNode(ctx.GetParser(), ctx)
}
func (v *PgsqlVisitor) VisitStmtmulti(ctx *pgparser.StmtmultiContext) interface{} {
allSqlStatement := ctx.AllStmt()
stmts := make([]sqlstmt.Stmt, 0)
for _, sqlStatement := range allSqlStatement {
stmts = append(stmts, sqlStatement.Accept(v).(sqlstmt.Stmt))
}
return stmts
}
func (v *PgsqlVisitor) VisitStmt(ctx *pgparser.StmtContext) interface{} {
if selectstmtCtx := ctx.Selectstmt(); selectstmtCtx != nil {
return selectstmtCtx.Accept(v)
}
if updatestmtCtx := ctx.Updatestmt(); updatestmtCtx != nil {
return updatestmtCtx.Accept(v)
}
if deletestmtCtx := ctx.Deletestmt(); deletestmtCtx != nil {
return deletestmtCtx.Accept(v)
}
if insertstmtC := ctx.Insertstmt(); insertstmtC != nil {
return insertstmtC.Accept(v)
}
if c := ctx.Createdbstmt(); c != nil {
cds := new(sqlstmt.CreateDatabase)
cds.Node = sqlstmt.NewNode(c.GetParser(), c)
return cds
}
if c := ctx.Createtablespacestmt(); c != nil {
cds := new(sqlstmt.CreateTable)
cds.Node = sqlstmt.NewNode(c.GetParser(), c)
return cds
}
if c := ctx.Altertablestmt(); c != nil {
cds := new(sqlstmt.AlterTable)
cds.Node = sqlstmt.NewNode(c.GetParser(), c)
return cds
}
if c := ctx.Dropdbstmt(); c != nil {
cds := new(sqlstmt.DropDatabase)
cds.Node = sqlstmt.NewNode(c.GetParser(), c)
return cds
}
if c := ctx.Droptablespacestmt(); c != nil {
cds := new(sqlstmt.DropTable)
cds.Node = sqlstmt.NewNode(c.GetParser(), c)
return cds
}
if explain := ctx.Explainstmt(); explain != nil {
otherRead := new(sqlstmt.OtherReadStmt)
otherRead.Node = sqlstmt.NewNode(explain.GetParser(), explain)
return otherRead
}
if c := ctx.Variableshowstmt(); c != nil {
otherRead := new(sqlstmt.OtherReadStmt)
otherRead.Node = sqlstmt.NewNode(c.GetParser(), c)
return otherRead
}
return sqlstmt.NewNode(ctx.GetParser(), ctx)
}
func (v *PgsqlVisitor) VisitSelectstmt(ctx *pgparser.SelectstmtContext) interface{} {
if spnc := ctx.Select_no_parens(); spnc != nil {
return spnc.Accept(v)
}
selectstmt := new(sqlstmt.SelectStmt)
selectstmt.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return selectstmt
}
func (v *PgsqlVisitor) VisitSelect_with_parens(ctx *pgparser.Select_with_parensContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *PgsqlVisitor) VisitSelect_no_parens(ctx *pgparser.Select_no_parensContext) interface{} {
if c := ctx.Select_clause(); c == nil {
return sqlstmt.NewNode(ctx.GetParser(), ctx)
}
var limit *sqlstmt.Limit
if limitC := ctx.Select_limit(); limitC != nil {
limit = limitC.Accept(v).(*sqlstmt.Limit)
}
if limitC := ctx.Opt_select_limit(); limitC != nil {
limit = limitC.Accept(v).(*sqlstmt.Limit)
}
selectClause := ctx.Select_clause()
asis := selectClause.AllSimple_select_intersect()
// 简单查询
if len(asis) == 1 {
sss := new(sqlstmt.SimpleSelectStmt)
sss.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
sss.QuerySpecification = ctx.Select_clause().Accept(v).([]*sqlstmt.QuerySpecification)[0]
sss.QuerySpecification.Limit = limit
return sss
}
uss := new(sqlstmt.UnionSelectStmt)
uss.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
allUnion := selectClause.AllUNION()
// todo 赋值union信息
for _, union := range allUnion {
uss.UnionType = union.GetText()
}
// uss.QuerySpecifications = ctx.Select_clause().Accept(v).([]*sqlstmt.QuerySpecification)
uss.Limit = limit
return uss
}
func (v *PgsqlVisitor) VisitSelect_clause(ctx *pgparser.Select_clauseContext) interface{} {
qs := make([]*sqlstmt.QuerySpecification, 0)
for _, ssi := range ctx.AllSimple_select_intersect() {
qs = append(qs, ssi.Accept(v).(*sqlstmt.QuerySpecification))
}
return qs
}
func (v *PgsqlVisitor) VisitSimple_select_intersect(ctx *pgparser.Simple_select_intersectContext) interface{} {
// 只返回一个查询INTERSECT交集暂不支持
if spsc := ctx.AllSimple_select_pramary(); spsc != nil {
return spsc[0].Accept(v)
}
return sqlstmt.NewNode(ctx.GetParser(), ctx)
}
func (v *PgsqlVisitor) VisitSimple_select_pramary(ctx *pgparser.Simple_select_pramaryContext) interface{} {
qs := new(sqlstmt.QuerySpecification)
qs.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
if c := ctx.From_clause(); c != nil {
qs.From = c.Accept(v).(*sqlstmt.TableSources)
}
if c := ctx.Opt_target_list(); c != nil {
qs.SelectElements = c.Accept(v).(*sqlstmt.SelectElements)
}
if c := ctx.Target_list(); c != nil {
qs.SelectElements = c.Accept(v).(*sqlstmt.SelectElements)
}
if c := ctx.Where_clause(); c != nil && c.A_expr() != nil {
qs.Where = c.A_expr().Accept(v).(sqlstmt.IExpr)
}
return qs
}
func (v *PgsqlVisitor) VisitSelect_limit(ctx *pgparser.Select_limitContext) interface{} {
limit := new(sqlstmt.Limit)
limit.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
if lc := ctx.Limit_clause(); lc != nil {
if lv := lc.Select_limit_value(); lv != nil {
limit.RowCount = cast.ToInt(lv.GetText())
}
}
if oc := ctx.Offset_clause(); oc != nil {
if ov := oc.Select_offset_value(); ov != nil {
limit.Offset = cast.ToInt(ov.GetText())
}
}
return limit
}
func (v *PgsqlVisitor) VisitOpt_select_limit(ctx *pgparser.Opt_select_limitContext) interface{} {
if slc := ctx.Select_limit(); slc != nil {
return slc.Accept(v)
}
return nil
}
func (v *PgsqlVisitor) VisitFrom_clause(ctx *pgparser.From_clauseContext) interface{} {
if c := ctx.From_list(); c != nil {
return c.Accept(v)
}
return sqlstmt.NewNode(ctx.GetParser(), ctx)
}
func (v *PgsqlVisitor) VisitFrom_list(ctx *pgparser.From_listContext) interface{} {
ts := new(sqlstmt.TableSources)
ts.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
// ts.StartIndex = ctx.GetStart().GetStart()
// ts.StopIndex = ctx.GetStop().GetStop()
tableSources := make([]sqlstmt.ITableSource, 0)
allTableRefCtx := ctx.AllTable_ref()
for _, trc := range allTableRefCtx {
tableSources = append(tableSources, trc.Accept(v).(sqlstmt.ITableSource))
}
ts.TableSources = tableSources
return ts
}
func (v *PgsqlVisitor) VisitNon_ansi_join(ctx *pgparser.Non_ansi_joinContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *PgsqlVisitor) VisitTable_ref(ctx *pgparser.Table_refContext) interface{} {
tableSourceBase := new(sqlstmt.TableSourceBase)
tableSourceBase.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
atomTable := new(sqlstmt.AtomTableItem)
if c := ctx.Relation_expr(); c != nil {
tableName := new(sqlstmt.TableName)
if qn := c.Qualified_name(); qn != nil {
if qc := qn.Colid(); qc != nil {
if c := qn.Indirection(); c != nil {
tableName.Owner = qc.GetText()
tableName.Identifier = sqlstmt.NewIdentifierValue(c.GetText())
} else {
tableName.Identifier = sqlstmt.NewIdentifierValue(qc.Identifier().GetText())
}
}
}
atomTable.TableName = tableName
}
if c := ctx.Opt_alias_clause(); c != nil {
if aliasC := c.Table_alias_clause(); aliasC != nil {
atomTable.Alias = aliasC.Table_alias().GetText()
}
}
tableSourceBase.TableSourceItem = atomTable
return tableSourceBase
}
func (v *PgsqlVisitor) VisitOpt_target_list(ctx *pgparser.Opt_target_listContext) interface{} {
if c := ctx.Target_list(); c != nil {
return c.Accept(v)
}
ses := new(sqlstmt.SelectElements)
ses.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ses
}
func (v *PgsqlVisitor) VisitTarget_list(ctx *pgparser.Target_listContext) interface{} {
ses := new(sqlstmt.SelectElements)
ses.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
if tecs := ctx.AllTarget_el(); tecs != nil {
eles := make([]sqlstmt.ISelectElement, 0)
for _, tec := range tecs {
eles = append(eles, tec.Accept(v).(sqlstmt.ISelectElement))
}
ses.Elements = eles
}
if len(ses.Elements) == 1 && ses.Elements[0].GetText() == "*" {
ses.Star = "*"
}
return ses
}
// Visit a parse tree produced by PostgreSQLParser#target_label.
func (v *PgsqlVisitor) VisitTarget_label(ctx *pgparser.Target_labelContext) interface{} {
sce := new(sqlstmt.SelectColumnElement)
sce.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
columnName := new(sqlstmt.ColumnName)
if c := ctx.Collabel(); c != nil {
sce.Alias = c.GetText()
}
if c := ctx.Identifier(); c != nil {
sce.Alias = c.GetText()
}
if exprCtx := ctx.A_expr(); exprCtx != nil {
columnName.Node = sqlstmt.NewNode(ctx.GetParser(), exprCtx)
if aextrCtx := exprCtx.A_expr_qual(); aextrCtx != nil {
col := aextrCtx.GetText()
ownerAndColname := strings.Split(col, ".")
if len(ownerAndColname) == 2 {
columnName.Owner = ownerAndColname[0]
columnName.Identifier = sqlstmt.NewIdentifierValue(ownerAndColname[1])
} else {
columnName.Identifier = sqlstmt.NewIdentifierValue(col)
}
}
} else {
columnName.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
}
sce.ColumnName = columnName
return sce
}
// Visit a parse tree produced by PostgreSQLParser#target_star.
func (v *PgsqlVisitor) VisitTarget_star(ctx *pgparser.Target_starContext) interface{} {
sse := new(sqlstmt.SelectStarElement)
sse.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
sse.FullId = ctx.STAR().GetText()
return sse
}
func (v *PgsqlVisitor) VisitAlias_clause(ctx *pgparser.Alias_clauseContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *PgsqlVisitor) VisitOpt_alias_clause(ctx *pgparser.Opt_alias_clauseContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *PgsqlVisitor) VisitTable_alias_clause(ctx *pgparser.Table_alias_clauseContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *PgsqlVisitor) VisitFunc_alias_clause(ctx *pgparser.Func_alias_clauseContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *PgsqlVisitor) VisitRelation_expr(ctx *pgparser.Relation_exprContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *PgsqlVisitor) VisitRelation_expr_list(ctx *pgparser.Relation_expr_listContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *PgsqlVisitor) VisitUpdatestmt(ctx *pgparser.UpdatestmtContext) interface{} {
updateStmt := new(sqlstmt.UpdateStmt)
updateStmt.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
updateStmt.TableSources = v.GetTableSourcesByrelation_expr_opt_alias(ctx.Relation_expr_opt_alias())
updateStmt.UpdatedElements = ctx.Set_clause_list().Accept(v).([]*sqlstmt.UpdatedElement)
if ec := ctx.Where_or_current_clause().A_expr(); ec != nil {
updateStmt.Where = ec.Accept(v).(sqlstmt.IExpr)
}
return updateStmt
}
func (v *PgsqlVisitor) VisitSet_clause_list(ctx *pgparser.Set_clause_listContext) interface{} {
ues := make([]*sqlstmt.UpdatedElement, 0)
aucs := ctx.AllSet_clause()
for _, auc := range aucs {
ues = append(ues, auc.Accept(v).(*sqlstmt.UpdatedElement))
}
return ues
}
func (v *PgsqlVisitor) VisitSet_clause(ctx *pgparser.Set_clauseContext) interface{} {
updateEle := new(sqlstmt.UpdatedElement)
updateEle.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
updateEle.ColumnName = ctx.Set_target().Accept(v).(*sqlstmt.ColumnName)
if ac := ctx.A_expr(); ac != nil {
updateEle.Value = ac.Accept(v).(sqlstmt.IExpr)
}
return updateEle
}
func (v *PgsqlVisitor) VisitSet_target(ctx *pgparser.Set_targetContext) interface{} {
columnName := new(sqlstmt.ColumnName)
columnName.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
if ic := ctx.Opt_indirection(); ic != nil {
if ic.GetText() == "" {
columnName.Identifier = sqlstmt.NewIdentifierValue(ctx.Colid().GetText())
} else {
columnName.Owner = ctx.Colid().GetText()
columnName.Identifier = sqlstmt.NewIdentifierValue(ic.GetText())
}
} else {
columnName.Identifier = sqlstmt.NewIdentifierValue(ctx.Colid().GetText())
}
return columnName
}
func (v *PgsqlVisitor) VisitSet_target_list(ctx *pgparser.Set_target_listContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *PgsqlVisitor) VisitA_expr(ctx *pgparser.A_exprContext) interface{} {
expr := new(sqlstmt.Expr)
expr.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return expr
}
func (v *PgsqlVisitor) VisitA_expr_qual(ctx *pgparser.A_expr_qualContext) interface{} {
expr := new(sqlstmt.Expr)
expr.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return expr
}
func (v *PgsqlVisitor) VisitDeletestmt(ctx *pgparser.DeletestmtContext) interface{} {
deletestmt := new(sqlstmt.DeleteStmt)
deletestmt.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
deletestmt.TableSources = v.GetTableSourcesByrelation_expr_opt_alias(ctx.Relation_expr_opt_alias())
if ec := ctx.Where_or_current_clause().A_expr(); ec != nil {
deletestmt.Where = ec.Accept(v).(sqlstmt.IExpr)
}
return deletestmt
}
func (v *PgsqlVisitor) VisitInsertstmt(ctx *pgparser.InsertstmtContext) interface{} {
insertstmt := new(sqlstmt.InsertStmt)
insertstmt.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
insertstmt.TableName = ctx.Insert_target().Accept(v).(*sqlstmt.TableName)
return insertstmt
}
func (v *PgsqlVisitor) VisitInsert_target(ctx *pgparser.Insert_targetContext) interface{} {
tableName := new(sqlstmt.TableName)
tableName.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
table := ctx.GetText()
if strings.Contains(table, ".") {
tableAndOwner := strings.Split(table, ".")
tableName.Identifier = sqlstmt.NewIdentifierValue(tableAndOwner[1])
tableName.Owner = tableAndOwner[0]
} else {
tableName.Identifier = sqlstmt.NewIdentifierValue(table)
}
return tableName
}
func (v *PgsqlVisitor) VisitInsert_rest(ctx *pgparser.Insert_restContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *PgsqlVisitor) GetTableSourcesByrelation_expr_opt_alias(ctx pgparser.IRelation_expr_opt_aliasContext) *sqlstmt.TableSources {
tableSources := new(sqlstmt.TableSources)
tableSources.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
atomTable := new(sqlstmt.AtomTableItem)
atomTable.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
if c := ctx.Relation_expr(); c != nil {
tableName := new(sqlstmt.TableName)
if qn := c.Qualified_name(); qn != nil {
if qc := qn.Colid(); qc != nil {
if c := qn.Indirection(); c != nil {
tableName.Owner = qc.GetText()
tableName.Identifier = sqlstmt.NewIdentifierValue(c.GetText())
} else {
tableName.Identifier = sqlstmt.NewIdentifierValue(qc.Identifier().GetText())
}
}
}
atomTable.TableName = tableName
}
if c := ctx.Colid(); c != nil {
atomTable.Alias = c.GetText()
}
tableSourceBase := new(sqlstmt.TableSourceBase)
tableSourceBase.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
tableSourceBase.TableSourceItem = atomTable
tableSources.TableSources = []sqlstmt.ITableSource{tableSourceBase}
return tableSources
}