mirror of
https://gitee.com/dromara/mayfly-go
synced 2026-06-17 07:25:20 +08:00
refactor: 移除antlr4减小包体积&ai助手优化
This commit is contained in:
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
@@ -1 +0,0 @@
|
||||
java -jar antlr-4.13.1-complete.jar -Dlanguage=Go -package parser -visitor *.g4
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,91 +0,0 @@
|
||||
/*
|
||||
PostgreSQL grammar.
|
||||
The MIT License (MIT).
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
||||
*/
|
||||
|
||||
package parser
|
||||
|
||||
import (
|
||||
"unicode"
|
||||
|
||||
"github.com/antlr4-go/antlr/v4"
|
||||
)
|
||||
|
||||
type PostgreSQLLexerBase struct {
|
||||
*antlr.BaseLexer
|
||||
|
||||
stack StringStack
|
||||
}
|
||||
|
||||
func (receiver *PostgreSQLLexerBase) pushTag() {
|
||||
receiver.stack.Push(receiver.GetText())
|
||||
}
|
||||
|
||||
func (receiver *PostgreSQLLexerBase) isTag() bool {
|
||||
if receiver.stack.IsEmpty() {
|
||||
return false
|
||||
}
|
||||
return receiver.GetText() == receiver.stack.PeekOrEmpty()
|
||||
}
|
||||
|
||||
func (receiver *PostgreSQLLexerBase) popTag() {
|
||||
_, _ = receiver.stack.Pop()
|
||||
}
|
||||
|
||||
func (receiver *PostgreSQLLexerBase) checkLA(c int) bool {
|
||||
return receiver.GetInputStream().LA(1) != c
|
||||
}
|
||||
|
||||
func (receiver *PostgreSQLLexerBase) charIsLetter() bool {
|
||||
c := receiver.GetInputStream().LA(-1)
|
||||
return unicode.IsLetter(rune(c))
|
||||
}
|
||||
|
||||
func (receiver *PostgreSQLLexerBase) HandleNumericFail() {
|
||||
index := receiver.GetInputStream().Index() - 2
|
||||
receiver.GetInputStream().Seek(index)
|
||||
receiver.SetType(PostgreSQLLexerIntegral)
|
||||
}
|
||||
|
||||
func (receiver *PostgreSQLLexerBase) HandleLessLessGreaterGreater() {
|
||||
if receiver.GetText() == "<<" {
|
||||
receiver.SetType(PostgreSQLLexerLESS_LESS)
|
||||
}
|
||||
if receiver.GetText() == ">>" {
|
||||
receiver.SetType(PostgreSQLLexerGREATER_GREATER)
|
||||
}
|
||||
}
|
||||
|
||||
func (receiver *PostgreSQLLexerBase) UnterminatedBlockCommentDebugAssert() {
|
||||
//Debug.Assert(InputStream.LA(1) == -1 /*EOF*/);
|
||||
}
|
||||
|
||||
func (receiver *PostgreSQLLexerBase) CheckIfUtf32Letter() bool {
|
||||
codePoint := receiver.GetInputStream().LA(-2)<<8 + receiver.GetInputStream().LA(-1)
|
||||
var c []rune
|
||||
if codePoint < 0x10000 {
|
||||
c = []rune{rune(codePoint)}
|
||||
} else {
|
||||
codePoint -= 0x10000
|
||||
c = []rune{
|
||||
(rune)(codePoint/0x400 + 0xd800),
|
||||
(rune)(codePoint%0x400 + 0xdc00),
|
||||
}
|
||||
}
|
||||
return unicode.IsLetter(c[0])
|
||||
}
|
||||
@@ -1,29 +0,0 @@
|
||||
/*
|
||||
PostgreSQL grammar.
|
||||
The MIT License (MIT).
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
||||
*/
|
||||
|
||||
package parser
|
||||
|
||||
type PostgreSQLParseError struct {
|
||||
Number int
|
||||
Offset int
|
||||
Line int
|
||||
Column int
|
||||
Message string
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
@@ -1,214 +0,0 @@
|
||||
/*
|
||||
PostgreSQL grammar.
|
||||
The MIT License (MIT).
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
||||
*/
|
||||
|
||||
package parser
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/antlr4-go/antlr/v4"
|
||||
)
|
||||
|
||||
type PostgreSQLParserBase struct {
|
||||
*antlr.BaseParser
|
||||
|
||||
parseErrors []*PostgreSQLParseError
|
||||
}
|
||||
|
||||
func NewPostgreSQLParserBase(input antlr.TokenStream) *PostgreSQLParserBase {
|
||||
return &PostgreSQLParserBase{
|
||||
BaseParser: antlr.NewBaseParser(input),
|
||||
}
|
||||
}
|
||||
|
||||
func (receiver *PostgreSQLParserBase) GetParsedSqlTree(script string, line int) antlr.ParserRuleContext {
|
||||
parser := getPostgreSQLParser(script)
|
||||
result := parser.Root()
|
||||
for _, err := range parser.parseErrors {
|
||||
receiver.parseErrors = append(receiver.parseErrors, &PostgreSQLParseError{
|
||||
Number: err.Number,
|
||||
Offset: err.Offset,
|
||||
Line: err.Line + line,
|
||||
Column: err.Column,
|
||||
Message: err.Message,
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (receiver *PostgreSQLParserBase) ParseRoutineBody(localContextInterface ICreatefunc_opt_listContext) {
|
||||
localContext, ok := localContextInterface.(*Createfunc_opt_listContext)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var lang string
|
||||
for _, coi := range localContext.AllCreatefunc_opt_item() {
|
||||
createFuncOptItemContext, ok := coi.(*Createfunc_opt_itemContext)
|
||||
if !ok || createFuncOptItemContext.LANGUAGE() == nil {
|
||||
continue
|
||||
}
|
||||
nonReservedWordOrSConstContextInterface := createFuncOptItemContext.Nonreservedword_or_sconst()
|
||||
if nonReservedWordOrSConstContextInterface == nil {
|
||||
continue
|
||||
}
|
||||
nonReservedWordOrSConstContext, ok := nonReservedWordOrSConstContextInterface.(*Nonreservedword_or_sconstContext)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
nonReservedWordContextInterface := nonReservedWordOrSConstContext.Nonreservedword()
|
||||
if nonReservedWordContextInterface == nil {
|
||||
continue
|
||||
}
|
||||
nonReservedWordContext, ok := nonReservedWordContextInterface.(*NonreservedwordContext)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
identifierInterface := nonReservedWordContext.Identifier()
|
||||
if identifierInterface == nil {
|
||||
continue
|
||||
}
|
||||
identifier, ok := identifierInterface.(*IdentifierContext)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
node := identifier.Identifier()
|
||||
if node == nil {
|
||||
continue
|
||||
}
|
||||
lang = node.GetText()
|
||||
break
|
||||
}
|
||||
if lang == "" {
|
||||
return
|
||||
}
|
||||
|
||||
var funcAs *Createfunc_opt_itemContext
|
||||
for _, coi := range localContext.AllCreatefunc_opt_item() {
|
||||
ctx, ok := coi.(*Createfunc_opt_itemContext)
|
||||
if !ok || ctx.LANGUAGE() == nil {
|
||||
continue
|
||||
}
|
||||
as := ctx.Func_as()
|
||||
if as != nil {
|
||||
funcAs = ctx
|
||||
break
|
||||
}
|
||||
}
|
||||
if funcAs == nil {
|
||||
return
|
||||
}
|
||||
|
||||
funcAsContextInterface := funcAs.Func_as()
|
||||
if funcAsContextInterface == nil {
|
||||
return
|
||||
}
|
||||
funcAsContext, ok := funcAsContextInterface.(*Func_asContext)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
sConstContextInterface := funcAsContext.Sconst(0)
|
||||
if sConstContextInterface == nil {
|
||||
return
|
||||
}
|
||||
sConstContext, ok := sConstContextInterface.(*SconstContext)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
text := GetRoutineBodyString(sConstContext)
|
||||
line := sConstContext.GetStart().GetLine()
|
||||
parser := getPostgreSQLParser(text)
|
||||
switch lang {
|
||||
case "plpgsql":
|
||||
funcAs.Func_as().(*Func_asContext).Definition = parser.Plsqlroot()
|
||||
case "sql":
|
||||
funcAs.Func_as().(*Func_asContext).Definition = parser.Root()
|
||||
}
|
||||
for _, err := range parser.parseErrors {
|
||||
receiver.parseErrors = append(receiver.parseErrors, &PostgreSQLParseError{
|
||||
Number: err.Number,
|
||||
Offset: err.Offset,
|
||||
Line: err.Line + line,
|
||||
Column: err.Column,
|
||||
Message: err.Message,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TrimQuotes(s string) string {
|
||||
if s == "" {
|
||||
return s
|
||||
}
|
||||
return s[1 : len(s)-2]
|
||||
}
|
||||
|
||||
func unquote(s string) string {
|
||||
result := strings.Builder{}
|
||||
length := len(s)
|
||||
index := 0
|
||||
for index < length {
|
||||
c := s[index]
|
||||
result.WriteByte(c)
|
||||
if c == '\'' && index < length-1 && (s[index+1] == '\'') {
|
||||
index++
|
||||
}
|
||||
index++
|
||||
}
|
||||
return result.String()
|
||||
}
|
||||
|
||||
func GetRoutineBodyString(rule *SconstContext) string {
|
||||
if rule.Anysconst() == nil {
|
||||
return ""
|
||||
}
|
||||
anySConstContext := rule.Anysconst().(*AnysconstContext)
|
||||
|
||||
stringConstant := anySConstContext.StringConstant()
|
||||
if stringConstant != nil {
|
||||
return unquote(TrimQuotes(stringConstant.GetText()))
|
||||
}
|
||||
|
||||
unicodeEscapeStringConstant := anySConstContext.UnicodeEscapeStringConstant()
|
||||
if unicodeEscapeStringConstant != nil {
|
||||
return TrimQuotes(unicodeEscapeStringConstant.GetText())
|
||||
}
|
||||
|
||||
escapeStringConstant := anySConstContext.EscapeStringConstant()
|
||||
if escapeStringConstant != nil {
|
||||
return TrimQuotes(escapeStringConstant.GetText())
|
||||
}
|
||||
|
||||
result := strings.Builder{}
|
||||
for _, node := range anySConstContext.AllDollarText() {
|
||||
result.WriteString(node.GetText())
|
||||
}
|
||||
return result.String()
|
||||
}
|
||||
|
||||
func getPostgreSQLParser(script string) *PostgreSQLParser {
|
||||
stream := antlr.NewInputStream(script)
|
||||
lexer := NewPostgreSQLLexer(stream)
|
||||
tokenStream := antlr.NewCommonTokenStream(lexer, 0)
|
||||
parser := NewPostgreSQLParser(tokenStream)
|
||||
errorListener := new(PostgreSQLParserErrorListener)
|
||||
errorListener.grammar = parser
|
||||
parser.AddErrorListener(errorListener)
|
||||
return parser
|
||||
}
|
||||
@@ -1,51 +0,0 @@
|
||||
/*
|
||||
PostgreSQL grammar.
|
||||
The MIT License (MIT).
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
||||
*/
|
||||
|
||||
package parser
|
||||
|
||||
import "github.com/antlr4-go/antlr/v4"
|
||||
|
||||
type PostgreSQLParserErrorListener struct {
|
||||
grammar *PostgreSQLParser
|
||||
}
|
||||
|
||||
var _ antlr.ErrorListener = &PostgreSQLParserErrorListener{}
|
||||
|
||||
func (receiver PostgreSQLParserErrorListener) SyntaxError(recognizer antlr.Recognizer, offendingSymbol interface{}, line, column int, msg string, e antlr.RecognitionException) {
|
||||
receiver.grammar.parseErrors = append(receiver.grammar.parseErrors, &PostgreSQLParseError{
|
||||
Number: 0,
|
||||
Offset: 0,
|
||||
Line: line,
|
||||
Column: column,
|
||||
Message: msg,
|
||||
})
|
||||
}
|
||||
|
||||
func (receiver PostgreSQLParserErrorListener) ReportAmbiguity(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex int, exact bool, ambigAlts *antlr.BitSet, configs *antlr.ATNConfigSet) {
|
||||
// ignore
|
||||
}
|
||||
|
||||
func (receiver PostgreSQLParserErrorListener) ReportAttemptingFullContext(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex int, conflictingAlts *antlr.BitSet, configs *antlr.ATNConfigSet) {
|
||||
// ignore
|
||||
}
|
||||
|
||||
func (receiver PostgreSQLParserErrorListener) ReportContextSensitivity(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex, prediction int, configs *antlr.ATNConfigSet) {
|
||||
// ignore
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,77 +0,0 @@
|
||||
/*
|
||||
PostgreSQL grammar.
|
||||
The MIT License (MIT).
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
||||
*/
|
||||
|
||||
package parser
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrorStackEmpty = errors.New("stack empty")
|
||||
)
|
||||
|
||||
type StringStack struct {
|
||||
items []string
|
||||
}
|
||||
|
||||
func (receiver *StringStack) Push(value string) {
|
||||
receiver.items = append(receiver.items, value)
|
||||
}
|
||||
|
||||
func (receiver *StringStack) Pop() (string, error) {
|
||||
if receiver.IsEmpty() {
|
||||
return "", ErrorStackEmpty
|
||||
}
|
||||
value := receiver.items[0]
|
||||
receiver.items = receiver.items[1:]
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func (receiver *StringStack) PopOrEmpty() string {
|
||||
value, err := receiver.Pop()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func (receiver *StringStack) Peek() (string, error) {
|
||||
if receiver.IsEmpty() {
|
||||
return "", ErrorStackEmpty
|
||||
}
|
||||
return receiver.items[0], nil
|
||||
}
|
||||
|
||||
func (receiver *StringStack) PeekOrEmpty() string {
|
||||
value, err := receiver.Peek()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func (receiver *StringStack) Size() int {
|
||||
return len(receiver.items)
|
||||
}
|
||||
|
||||
func (receiver *StringStack) IsEmpty() bool {
|
||||
return receiver.Size() == 0
|
||||
}
|
||||
745
server/internal/db/dbm/sqlparser/pgsql/parser.go
Normal file
745
server/internal/db/dbm/sqlparser/pgsql/parser.go
Normal file
@@ -0,0 +1,745 @@
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"mayfly-go/internal/db/dbm/sqlparser/base"
|
||||
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
|
||||
"mayfly-go/internal/db/dbm/sqlparser/tokenizer"
|
||||
)
|
||||
|
||||
// Parser PostgreSQL 方言 SQL 解析器
|
||||
type Parser struct {
|
||||
*base.Lexer
|
||||
}
|
||||
|
||||
// NewParser 创建 PostgreSQL 解析器
|
||||
func NewParser(sql string) *Parser {
|
||||
return &Parser{
|
||||
Lexer: base.NewLexer(sql, tokenizer.DialectConfig{
|
||||
DoubleQuoteAsIdentifier: true,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
// Parse 解析单条 SQL 语句
|
||||
func (p *Parser) Parse() (sqlstmt.Stmt, error) {
|
||||
p.SkipSemicolons()
|
||||
if p.Current().IsEOF() {
|
||||
return nil, nil
|
||||
}
|
||||
stmt := p.parseStatement()
|
||||
return stmt, nil
|
||||
}
|
||||
|
||||
func (p *Parser) parseStatement() sqlstmt.Stmt {
|
||||
tok := p.Current()
|
||||
switch {
|
||||
case tok.IsKeyword("SELECT") || tok.Value == "(":
|
||||
return p.parseSelect()
|
||||
case tok.IsKeyword("INSERT"):
|
||||
return p.parseInsert()
|
||||
case tok.IsKeyword("UPDATE"):
|
||||
return p.parseUpdate()
|
||||
case tok.IsKeyword("DELETE"):
|
||||
return p.parseDelete()
|
||||
case tok.IsKeyword("CREATE"):
|
||||
return p.parseCreate()
|
||||
case tok.IsKeyword("DROP"):
|
||||
return p.parseDrop()
|
||||
case tok.IsKeyword("ALTER"):
|
||||
return p.parseAlter()
|
||||
case tok.IsKeyword("WITH"):
|
||||
return p.parseWith()
|
||||
case tok.IsKeyword("TRUNCATE"):
|
||||
return p.parseGenericDdl()
|
||||
default:
|
||||
return p.parseGenericStmt()
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- SELECT 解析 ----------
|
||||
|
||||
func (p *Parser) parseSelect() sqlstmt.Stmt {
|
||||
start := p.Pos
|
||||
var selectStmt *sqlstmt.SelectStmt
|
||||
|
||||
if p.Current().Value == "(" {
|
||||
p.Consume()
|
||||
if p.Current().IsKeyword("SELECT") {
|
||||
selectStmt = p.parseSelectBody()
|
||||
} else {
|
||||
p.SkipParentheses()
|
||||
selectStmt = &sqlstmt.SelectStmt{}
|
||||
}
|
||||
p.ExpectValue(")")
|
||||
} else {
|
||||
selectStmt = p.parseSelectBody()
|
||||
}
|
||||
|
||||
// UNION 解析
|
||||
selectStmt = p.parseUnions(selectStmt)
|
||||
|
||||
// ORDER BY(UNION 之后的 ORDER BY)
|
||||
if selectStmt.OrderBy == nil && p.Current().IsKeyword("ORDER") {
|
||||
p.Consume()
|
||||
if p.Current().IsKeyword("BY") {
|
||||
p.Consume()
|
||||
}
|
||||
selectStmt.OrderBy = p.parseOrderBy()
|
||||
}
|
||||
|
||||
// LIMIT(UNION 之后的 LIMIT 覆盖子查询内部的 LIMIT)
|
||||
if p.Current().IsKeyword("LIMIT") || p.Current().IsKeyword("OFFSET") {
|
||||
if limit := p.parseLimit(); limit != nil {
|
||||
selectStmt.Limit = limit
|
||||
}
|
||||
}
|
||||
|
||||
// FOR UPDATE
|
||||
if p.Current().IsKeyword("FOR") {
|
||||
p.Consume()
|
||||
if p.Current().IsKeyword("UPDATE") {
|
||||
p.Consume()
|
||||
}
|
||||
}
|
||||
|
||||
// RETURNING(PostgreSQL 特有)
|
||||
if p.Current().IsKeyword("RETURNING") {
|
||||
p.Consume()
|
||||
// 跳过 RETURNING 后面的所有列(可能包含逗号)
|
||||
for !p.Current().IsEOF() && p.Current().Value != ";" {
|
||||
p.Consume()
|
||||
}
|
||||
}
|
||||
|
||||
// 更新完整文本(包含所有已解析的内容)
|
||||
selectStmt.Base = sqlstmt.Base{Text: p.TextFrom(start)}
|
||||
return selectStmt
|
||||
}
|
||||
|
||||
func (p *Parser) parseSelectBody() *sqlstmt.SelectStmt {
|
||||
start := p.Pos
|
||||
|
||||
distinct := false
|
||||
if p.Current().IsKeyword("SELECT") {
|
||||
p.Consume()
|
||||
if p.Current().IsKeyword("DISTINCT") {
|
||||
p.Consume()
|
||||
distinct = true
|
||||
}
|
||||
}
|
||||
|
||||
// SELECT 项
|
||||
var items []sqlstmt.SelectItem
|
||||
for !p.Current().IsEOF() && !p.IsSelectClauseEnd() {
|
||||
if p.Current().Value == "," {
|
||||
p.Consume()
|
||||
continue
|
||||
}
|
||||
itemStart := p.Pos
|
||||
p.SkipExpr()
|
||||
text := base.TrimTrailingComma(p.TextFromExclusive(itemStart))
|
||||
col, alias := p.ExtractColumnAndAlias(text)
|
||||
items = append(items, sqlstmt.SelectItem{
|
||||
Text: text,
|
||||
ColumnName: col,
|
||||
Alias: alias,
|
||||
})
|
||||
}
|
||||
|
||||
selectStmt := &sqlstmt.SelectStmt{
|
||||
Base: sqlstmt.Base{Text: p.TextFrom(start)},
|
||||
Distinct: distinct,
|
||||
Items: items,
|
||||
}
|
||||
|
||||
// FROM
|
||||
if p.Current().IsKeyword("FROM") {
|
||||
p.Consume()
|
||||
selectStmt.From = p.parseFromClause()
|
||||
}
|
||||
|
||||
// JOIN
|
||||
for p.IsJoinStart() || p.Current().IsKeyword("JOIN") {
|
||||
if join := p.parseJoinClause(); join != nil {
|
||||
selectStmt.Joins = append(selectStmt.Joins, *join)
|
||||
}
|
||||
}
|
||||
|
||||
// WHERE
|
||||
if p.Current().IsKeyword("WHERE") {
|
||||
p.Consume()
|
||||
whereStart := p.Pos
|
||||
p.SkipExpr()
|
||||
selectStmt.Where = &sqlstmt.Expr{Text: p.TextFromExclusive(whereStart)}
|
||||
}
|
||||
|
||||
// GROUP BY
|
||||
if p.Current().IsKeyword("GROUP") {
|
||||
p.Consume()
|
||||
if p.Current().IsKeyword("BY") {
|
||||
p.Consume()
|
||||
}
|
||||
p.SkipGroupByExpr()
|
||||
}
|
||||
|
||||
// HAVING
|
||||
if p.Current().IsKeyword("HAVING") {
|
||||
p.Consume()
|
||||
p.SkipExpr()
|
||||
}
|
||||
|
||||
// ORDER BY
|
||||
if p.Current().IsKeyword("ORDER") {
|
||||
p.Consume()
|
||||
if p.Current().IsKeyword("BY") {
|
||||
p.Consume()
|
||||
}
|
||||
selectStmt.OrderBy = p.parseOrderBy()
|
||||
}
|
||||
|
||||
// LIMIT/OFFSET(子查询或简单 SELECT 的 LIMIT)
|
||||
if p.Current().IsKeyword("LIMIT") || p.Current().IsKeyword("OFFSET") {
|
||||
selectStmt.Limit = p.parseLimit()
|
||||
}
|
||||
|
||||
return selectStmt
|
||||
}
|
||||
|
||||
func (p *Parser) parseUnions(selectStmt *sqlstmt.SelectStmt) *sqlstmt.SelectStmt {
|
||||
for p.Current().IsKeyword("UNION") {
|
||||
p.Consume()
|
||||
all := false
|
||||
if p.Current().IsKeyword("ALL") {
|
||||
p.Consume()
|
||||
all = true
|
||||
}
|
||||
|
||||
var nextSelect *sqlstmt.SelectStmt
|
||||
if p.Current().Value == "(" {
|
||||
p.Consume()
|
||||
if p.Current().IsKeyword("SELECT") {
|
||||
nextSelect = p.parseSelectBody()
|
||||
}
|
||||
p.ExpectValue(")")
|
||||
} else if p.Current().IsKeyword("SELECT") {
|
||||
nextSelect = p.parseSelectBody()
|
||||
}
|
||||
|
||||
if nextSelect != nil {
|
||||
// 如果 unionSelect 有 LIMIT 或 ORDER BY,移动到外层 selectStmt
|
||||
if nextSelect.Limit != nil {
|
||||
selectStmt.Limit = nextSelect.Limit
|
||||
nextSelect.Limit = nil
|
||||
}
|
||||
if len(nextSelect.OrderBy) > 0 {
|
||||
selectStmt.OrderBy = nextSelect.OrderBy
|
||||
nextSelect.OrderBy = nil
|
||||
}
|
||||
selectStmt.Unions = append(selectStmt.Unions, sqlstmt.UnionClause{
|
||||
Select: nextSelect,
|
||||
All: all,
|
||||
})
|
||||
}
|
||||
}
|
||||
return selectStmt
|
||||
}
|
||||
|
||||
func (p *Parser) parseOrderBy() []sqlstmt.OrderByItem {
|
||||
var items []sqlstmt.OrderByItem
|
||||
for !p.Current().IsEOF() && !p.IsExprEnd() {
|
||||
if p.Current().Value == "," {
|
||||
p.Consume()
|
||||
continue
|
||||
}
|
||||
start := p.Pos
|
||||
p.SkipExpr()
|
||||
text := p.TextFromExclusive(start)
|
||||
|
||||
desc := false
|
||||
upper := strings.ToUpper(text)
|
||||
if strings.HasSuffix(upper, " DESC") {
|
||||
desc = true
|
||||
text = strings.TrimSpace(text[:len(text)-5])
|
||||
} else if strings.HasSuffix(upper, " ASC") {
|
||||
text = strings.TrimSpace(text[:len(text)-4])
|
||||
}
|
||||
|
||||
items = append(items, sqlstmt.OrderByItem{
|
||||
Text: text,
|
||||
Desc: desc,
|
||||
})
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
// LIMIT 解析(PostgreSQL: OFFSET ... LIMIT 或 LIMIT ... ALL)
|
||||
func (p *Parser) parseLimit() *sqlstmt.Limit {
|
||||
start := p.Pos
|
||||
limit := &sqlstmt.Limit{}
|
||||
|
||||
// PostgreSQL 支持两种顺序:
|
||||
// 1. LIMIT count OFFSET offset
|
||||
// 2. OFFSET offset LIMIT count
|
||||
if p.Current().IsKeyword("LIMIT") {
|
||||
p.Consume()
|
||||
// LIMIT ALL 或 LIMIT 值
|
||||
if p.Current().IsKeyword("ALL") {
|
||||
p.Consume()
|
||||
limit.Text = p.TextFrom(start)
|
||||
return limit
|
||||
}
|
||||
if p.Current().Type == tokenizer.TokenNumber {
|
||||
countStr := p.Consume().Value
|
||||
limit.Count, _ = strconv.Atoi(countStr)
|
||||
}
|
||||
// 检查是否有 OFFSET
|
||||
if p.Current().IsKeyword("OFFSET") {
|
||||
p.Consume()
|
||||
if p.Current().Type == tokenizer.TokenNumber {
|
||||
offsetStr := p.Consume().Value
|
||||
limit.Offset, _ = strconv.Atoi(offsetStr)
|
||||
} else if p.Current().Type == tokenizer.TokenIdentifier {
|
||||
p.Consume()
|
||||
}
|
||||
}
|
||||
} else if p.Current().IsKeyword("OFFSET") {
|
||||
p.Consume()
|
||||
// 跳过 OFFSET 值
|
||||
if p.Current().Type == tokenizer.TokenNumber {
|
||||
offsetStr := p.Consume().Value
|
||||
limit.Offset, _ = strconv.Atoi(offsetStr)
|
||||
} else if p.Current().Type == tokenizer.TokenIdentifier {
|
||||
p.Consume()
|
||||
}
|
||||
// 检查是否有限制
|
||||
if p.Current().IsKeyword("LIMIT") {
|
||||
p.Consume()
|
||||
if p.Current().IsKeyword("ALL") {
|
||||
p.Consume()
|
||||
limit.Text = p.TextFrom(start)
|
||||
return limit
|
||||
}
|
||||
if p.Current().Type == tokenizer.TokenNumber {
|
||||
countStr := p.Consume().Value
|
||||
limit.Count, _ = strconv.Atoi(countStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
limit.Text = p.TextFrom(start)
|
||||
return limit
|
||||
}
|
||||
|
||||
// ---------- FROM 解析 ----------
|
||||
|
||||
func (p *Parser) parseFromClause() []sqlstmt.TableRef {
|
||||
var tables []sqlstmt.TableRef
|
||||
for !p.Current().IsEOF() && !p.IsFromClauseEnd() {
|
||||
if p.Current().Value == "," {
|
||||
p.Consume()
|
||||
continue
|
||||
}
|
||||
if p.IsJoinStart() || p.Current().IsKeyword("JOIN") {
|
||||
break
|
||||
}
|
||||
ref := p.parseTableRef()
|
||||
if ref.Name != "" {
|
||||
tables = append(tables, ref)
|
||||
}
|
||||
}
|
||||
return tables
|
||||
}
|
||||
|
||||
func (p *Parser) parseTableRef() sqlstmt.TableRef {
|
||||
start := p.Pos
|
||||
|
||||
// 子查询
|
||||
if p.Current().Value == "(" {
|
||||
p.Consume()
|
||||
if p.Current().IsKeyword("SELECT") {
|
||||
// 使用完整解析(包含 UNION)
|
||||
p.parseSelect()
|
||||
} else {
|
||||
p.SkipParentheses()
|
||||
}
|
||||
p.ExpectValue(")")
|
||||
|
||||
var alias string
|
||||
if p.Current().IsKeyword("AS") {
|
||||
p.Consume()
|
||||
}
|
||||
if p.Current().Type == tokenizer.TokenIdentifier {
|
||||
alias = p.Unquote(p.Consume().Value)
|
||||
}
|
||||
return sqlstmt.TableRef{
|
||||
Name: p.TextFrom(start),
|
||||
Alias: alias,
|
||||
}
|
||||
}
|
||||
|
||||
if p.Current().Type != tokenizer.TokenIdentifier && p.Current().Type != tokenizer.TokenString {
|
||||
return sqlstmt.TableRef{}
|
||||
}
|
||||
|
||||
ref := sqlstmt.TableRef{}
|
||||
part1 := p.Consume().Value
|
||||
|
||||
if p.Current().Value == "." {
|
||||
p.Consume()
|
||||
if p.Current().Type == tokenizer.TokenIdentifier || p.Current().Type == tokenizer.TokenString {
|
||||
part2 := p.Consume().Value
|
||||
ref.Schema = p.Unquote(part1)
|
||||
ref.Name = p.Unquote(part2)
|
||||
} else {
|
||||
ref.Name = p.Unquote(part1)
|
||||
}
|
||||
} else {
|
||||
ref.Name = p.Unquote(part1)
|
||||
}
|
||||
|
||||
if p.Current().IsKeyword("AS") {
|
||||
p.Consume()
|
||||
}
|
||||
if p.Current().Type == tokenizer.TokenIdentifier {
|
||||
ref.Alias = p.Unquote(p.Consume().Value)
|
||||
}
|
||||
|
||||
return ref
|
||||
}
|
||||
|
||||
func (p *Parser) parseJoinClause() *sqlstmt.JoinClause {
|
||||
start := p.Pos
|
||||
joinType := sqlstmt.JoinKindInner
|
||||
if p.Current().IsKeyword("LEFT") {
|
||||
p.Consume()
|
||||
if p.Current().IsKeyword("OUTER") {
|
||||
p.Consume()
|
||||
}
|
||||
joinType = sqlstmt.JoinKindLeft
|
||||
} else if p.Current().IsKeyword("RIGHT") {
|
||||
p.Consume()
|
||||
if p.Current().IsKeyword("OUTER") {
|
||||
p.Consume()
|
||||
}
|
||||
joinType = sqlstmt.JoinKindRight
|
||||
} else if p.Current().IsKeyword("FULL") {
|
||||
p.Consume()
|
||||
if p.Current().IsKeyword("OUTER") {
|
||||
p.Consume()
|
||||
}
|
||||
joinType = sqlstmt.JoinKindFull
|
||||
} else if p.Current().IsKeyword("CROSS") {
|
||||
p.Consume()
|
||||
joinType = sqlstmt.JoinKindCross
|
||||
} else if p.Current().IsKeyword("NATURAL") {
|
||||
p.Consume()
|
||||
joinType = sqlstmt.JoinKindNatural
|
||||
} else if p.Current().IsKeyword("INNER") {
|
||||
p.Consume()
|
||||
}
|
||||
|
||||
if !p.Current().IsKeyword("JOIN") {
|
||||
p.Pos = start
|
||||
return nil
|
||||
}
|
||||
p.Consume()
|
||||
|
||||
tableRef := p.parseTableRef()
|
||||
if tableRef.Name == "" {
|
||||
p.Pos = start
|
||||
return nil
|
||||
}
|
||||
|
||||
var onExpr *sqlstmt.Expr
|
||||
if p.Current().IsKeyword("ON") {
|
||||
p.Consume()
|
||||
onStart := p.Pos
|
||||
p.SkipExpr()
|
||||
onExpr = &sqlstmt.Expr{Text: p.TextFromExclusive(onStart)}
|
||||
} else if p.Current().IsKeyword("USING") {
|
||||
p.Consume()
|
||||
if p.Current().Value == "(" {
|
||||
p.SkipParentheses()
|
||||
}
|
||||
}
|
||||
|
||||
return &sqlstmt.JoinClause{
|
||||
Kind: joinType,
|
||||
Table: tableRef,
|
||||
On: onExpr,
|
||||
Text: p.TextFrom(start),
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- INSERT 解析 ----------
|
||||
|
||||
func (p *Parser) parseInsert() sqlstmt.Stmt {
|
||||
start := p.Pos
|
||||
p.Consume() // INSERT
|
||||
|
||||
// INTO(可选)
|
||||
if p.Current().IsKeyword("INTO") {
|
||||
p.Consume()
|
||||
}
|
||||
|
||||
// 使用 parseTableRef 正确解析表名(支持 schema.table 和双引号)
|
||||
tableRef := p.parseTableRef()
|
||||
|
||||
// 解析列名列表 (col1, col2, ...)
|
||||
columns := []string{}
|
||||
if p.Current().Value == "(" {
|
||||
p.Consume() // (
|
||||
for !p.Current().IsEOF() && p.Current().Value != ")" {
|
||||
if p.Current().Value == "," {
|
||||
p.Consume()
|
||||
continue
|
||||
}
|
||||
if p.Current().Type == tokenizer.TokenIdentifier {
|
||||
columns = append(columns, p.Unquote(p.Consume().Value))
|
||||
} else {
|
||||
p.Consume()
|
||||
}
|
||||
}
|
||||
if p.Current().Value == ")" {
|
||||
p.Consume() // )
|
||||
}
|
||||
}
|
||||
|
||||
// VALUES 或 SELECT 或直接是 ON CONFLICT
|
||||
if p.Current().IsKeyword("VALUES") || p.Current().Value == "VALUES" {
|
||||
p.Consume()
|
||||
// 消费到 ON 或 RETURNING 或 EOF
|
||||
for !p.Current().IsEOF() {
|
||||
val := strings.ToUpper(p.Current().Value)
|
||||
if val == "ON" || val == "RETURNING" {
|
||||
break
|
||||
}
|
||||
if p.Current().Value == "(" {
|
||||
p.SkipParentheses()
|
||||
} else {
|
||||
p.Consume()
|
||||
}
|
||||
}
|
||||
} else if p.Current().IsKeyword("SELECT") || p.Current().Value == "SELECT" {
|
||||
// INSERT INTO ... SELECT ...
|
||||
p.parseSelect()
|
||||
}
|
||||
|
||||
// ON CONFLICT(PostgreSQL 特有)
|
||||
if strings.ToUpper(p.Current().Value) == "ON" {
|
||||
p.Consume()
|
||||
if p.Current().IsKeyword("CONFLICT") {
|
||||
p.Consume()
|
||||
// 跳过 ON CONFLICT 后面的所有内容直到 RETURNING 或 EOF
|
||||
for !p.Current().IsEOF() && !p.Current().IsKeyword("RETURNING") {
|
||||
p.Consume()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RETURNING(PostgreSQL 特有)
|
||||
if p.Current().IsKeyword("RETURNING") {
|
||||
p.Consume()
|
||||
// 跳过 RETURNING 后面的内容
|
||||
for !p.Current().IsEOF() && p.Current().Value != ";" {
|
||||
p.Consume()
|
||||
}
|
||||
}
|
||||
|
||||
return &sqlstmt.InsertStmt{
|
||||
Base: sqlstmt.Base{Text: p.TextFrom(start)},
|
||||
Table: tableRef,
|
||||
Columns: columns,
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- UPDATE 解析 ----------
|
||||
// PostgreSQL UPDATE 不支持 ORDER BY/LIMIT
|
||||
|
||||
func (p *Parser) parseUpdate() sqlstmt.Stmt {
|
||||
start := p.Pos
|
||||
p.Consume() // UPDATE
|
||||
|
||||
// 使用 parseTableRef 正确解析表名(支持 schema.table 和双引号)
|
||||
tableRef := p.parseTableRef()
|
||||
tables := []sqlstmt.TableRef{tableRef}
|
||||
|
||||
// SET - 解析字段赋值
|
||||
assignments := make([]sqlstmt.Assignment, 0)
|
||||
if p.Current().IsKeyword("SET") {
|
||||
p.Consume()
|
||||
for !p.Current().IsEOF() {
|
||||
if p.Current().IsKeyword("WHERE") || p.Current().Value == ";" {
|
||||
break
|
||||
}
|
||||
assign := p.parseAssignment()
|
||||
if assign != nil {
|
||||
assignments = append(assignments, *assign)
|
||||
}
|
||||
if p.Current().Value == "," {
|
||||
p.Consume()
|
||||
continue
|
||||
}
|
||||
if p.Current().IsKeyword("WHERE") || p.Current().Value == ";" {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WHERE
|
||||
var where *sqlstmt.Expr
|
||||
if p.Current().IsKeyword("WHERE") {
|
||||
p.Consume()
|
||||
whereStart := p.Pos
|
||||
p.SkipExpr()
|
||||
where = &sqlstmt.Expr{Text: p.TextFromExclusive(whereStart)}
|
||||
}
|
||||
|
||||
// RETURNING(PostgreSQL 特有)
|
||||
if p.Current().IsKeyword("RETURNING") {
|
||||
p.Consume()
|
||||
// 跳过 RETURNING 后面的所有列(可能包含逗号)
|
||||
for !p.Current().IsEOF() && p.Current().Value != ";" {
|
||||
p.Consume()
|
||||
}
|
||||
}
|
||||
|
||||
return &sqlstmt.UpdateStmt{
|
||||
Base: sqlstmt.Base{Text: p.TextFrom(start)},
|
||||
Tables: tables,
|
||||
Set: assignments,
|
||||
Where: where,
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- DELETE 解析 ----------
|
||||
// PostgreSQL DELETE 不支持 ORDER BY/LIMIT
|
||||
|
||||
func (p *Parser) parseDelete() sqlstmt.Stmt {
|
||||
start := p.Pos
|
||||
p.Consume() // DELETE
|
||||
|
||||
if p.Current().IsKeyword("FROM") {
|
||||
p.Consume()
|
||||
}
|
||||
|
||||
// 使用 parseTableRef 正确解析表名(支持 schema.table 和双引号)
|
||||
tableRef := p.parseTableRef()
|
||||
tables := []sqlstmt.TableRef{tableRef}
|
||||
|
||||
// WHERE
|
||||
var where *sqlstmt.Expr
|
||||
if p.Current().IsKeyword("WHERE") {
|
||||
p.Consume()
|
||||
whereStart := p.Pos
|
||||
p.SkipExpr()
|
||||
where = &sqlstmt.Expr{Text: p.TextFromExclusive(whereStart)}
|
||||
}
|
||||
|
||||
// RETURNING(PostgreSQL 特有)
|
||||
if p.Current().IsKeyword("RETURNING") {
|
||||
p.Consume()
|
||||
// 跳过 RETURNING 后面的所有列(可能包含逗号)
|
||||
for !p.Current().IsEOF() && p.Current().Value != ";" {
|
||||
p.Consume()
|
||||
}
|
||||
}
|
||||
|
||||
return &sqlstmt.DeleteStmt{
|
||||
Base: sqlstmt.Base{Text: p.TextFrom(start)},
|
||||
Tables: tables,
|
||||
Where: where,
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- DDL 解析 ----------
|
||||
|
||||
func (p *Parser) parseCreate() sqlstmt.Stmt {
|
||||
start := p.Pos
|
||||
p.SkipToNextStatement()
|
||||
return &sqlstmt.DdlStmt{
|
||||
Base: sqlstmt.Base{Text: p.TextFrom(start)},
|
||||
DdlKind: "CREATE",
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Parser) parseDrop() sqlstmt.Stmt {
|
||||
start := p.Pos
|
||||
p.SkipToNextStatement()
|
||||
return &sqlstmt.DdlStmt{
|
||||
Base: sqlstmt.Base{Text: p.TextFrom(start)},
|
||||
DdlKind: "DROP",
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Parser) parseAlter() sqlstmt.Stmt {
|
||||
start := p.Pos
|
||||
p.SkipToNextStatement()
|
||||
return &sqlstmt.DdlStmt{
|
||||
Base: sqlstmt.Base{Text: p.TextFrom(start)},
|
||||
DdlKind: "ALTER",
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Parser) parseGenericDdl() sqlstmt.Stmt {
|
||||
start := p.Pos
|
||||
p.SkipToNextStatement()
|
||||
return &sqlstmt.DdlStmt{
|
||||
Base: sqlstmt.Base{Text: p.TextFrom(start)},
|
||||
DdlKind: "DDL",
|
||||
}
|
||||
}
|
||||
|
||||
// parseAssignment 解析 SET 字段赋值:column = value
|
||||
func (p *Parser) parseAssignment() *sqlstmt.Assignment {
|
||||
start := p.Pos
|
||||
colText := ""
|
||||
for !p.Current().IsEOF() && p.Current().Value != "=" && p.Current().Value != "," &&
|
||||
!p.Current().IsKeyword("WHERE") && p.Current().Value != ";" {
|
||||
colText += p.Consume().Value
|
||||
}
|
||||
if p.Current().Value != "=" {
|
||||
p.Pos = start
|
||||
return nil
|
||||
}
|
||||
p.Consume() // =
|
||||
|
||||
valStart := p.Pos
|
||||
for !p.Current().IsEOF() && p.Current().Value != "," &&
|
||||
!p.Current().IsKeyword("WHERE") && p.Current().Value != ";" {
|
||||
if p.Current().Value == "(" {
|
||||
p.SkipParentheses()
|
||||
continue
|
||||
}
|
||||
p.Consume()
|
||||
}
|
||||
|
||||
return &sqlstmt.Assignment{
|
||||
Column: p.Unquote(strings.TrimSpace(colText)),
|
||||
Value: &sqlstmt.Expr{Text: p.TextFromExclusive(valStart)},
|
||||
Text: p.TextFromExclusive(start),
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- WITH 解析 ----------
|
||||
|
||||
func (p *Parser) parseWith() sqlstmt.Stmt {
|
||||
start := p.Pos
|
||||
p.SkipToNextStatement()
|
||||
return &sqlstmt.WithStmt{
|
||||
Base: sqlstmt.Base{Text: p.TextFrom(start)},
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- 通用语句解析 ----------
|
||||
|
||||
func (p *Parser) parseGenericStmt() sqlstmt.Stmt {
|
||||
start := p.Pos
|
||||
p.SkipToNextStatement()
|
||||
return &sqlstmt.OtherStmt{
|
||||
Base: sqlstmt.Base{Text: p.TextFrom(start)},
|
||||
}
|
||||
}
|
||||
846
server/internal/db/dbm/sqlparser/pgsql/parser_dml_test.go
Normal file
846
server/internal/db/dbm/sqlparser/pgsql/parser_dml_test.go
Normal file
@@ -0,0 +1,846 @@
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
|
||||
)
|
||||
|
||||
// ========== INSERT 测试 ==========
|
||||
|
||||
func TestPgInsertBasic(t *testing.T) {
|
||||
sql := "INSERT INTO users (name, age) VALUES ('John', 30)"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
insertStmt, ok := stmt.(*sqlstmt.InsertStmt)
|
||||
if !ok {
|
||||
t.Fatal("expected InsertStmt")
|
||||
}
|
||||
|
||||
// 验证完整文本
|
||||
if insertStmt.GetText() != sql {
|
||||
t.Errorf("expected text='%s', got '%s'", sql, insertStmt.GetText())
|
||||
}
|
||||
|
||||
// 验证表名
|
||||
if insertStmt.Table.Name != "users" {
|
||||
t.Errorf("expected table='users', got '%s'", insertStmt.Table.Name)
|
||||
}
|
||||
|
||||
// 验证列名
|
||||
if len(insertStmt.Columns) != 2 {
|
||||
t.Fatalf("expected 2 columns, got %d", len(insertStmt.Columns))
|
||||
}
|
||||
if insertStmt.Columns[0] != "name" {
|
||||
t.Errorf("expected Columns[0]='name', got '%s'", insertStmt.Columns[0])
|
||||
}
|
||||
if insertStmt.Columns[1] != "age" {
|
||||
t.Errorf("expected Columns[1]='age', got '%s'", insertStmt.Columns[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgInsertMultipleRows(t *testing.T) {
|
||||
sql := "INSERT INTO users (name, age) VALUES ('John', 30), ('Jane', 25), ('Bob', 35)"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
insertStmt := stmt.(*sqlstmt.InsertStmt)
|
||||
if insertStmt.GetText() != sql {
|
||||
t.Errorf("expected text='%s'", sql)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgInsertReturning(t *testing.T) {
|
||||
sql := "INSERT INTO users (name, age) VALUES ('John', 30) RETURNING id, name"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
insertStmt := stmt.(*sqlstmt.InsertStmt)
|
||||
if insertStmt.GetText() != sql {
|
||||
t.Errorf("expected text='%s'", sql)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgInsertFromSelect(t *testing.T) {
|
||||
sql := "INSERT INTO users_backup SELECT * FROM users WHERE status = 1"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
insertStmt := stmt.(*sqlstmt.InsertStmt)
|
||||
if insertStmt.GetText() != sql {
|
||||
t.Errorf("expected text='%s'", sql)
|
||||
}
|
||||
}
|
||||
|
||||
// ========== UPDATE 测试 ==========
|
||||
|
||||
func TestPgUpdateBasic(t *testing.T) {
|
||||
sql := "UPDATE users SET name = 'John' WHERE id = 1"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||||
|
||||
// 验证完整文本
|
||||
if updateStmt.GetText() != sql {
|
||||
t.Errorf("expected text='%s', got '%s'", sql, updateStmt.GetText())
|
||||
}
|
||||
|
||||
// 验证表名
|
||||
if len(updateStmt.Tables) != 1 {
|
||||
t.Fatalf("expected 1 table, got %d", len(updateStmt.Tables))
|
||||
}
|
||||
if updateStmt.Tables[0].Name != "users" {
|
||||
t.Errorf("expected table='users', got '%s'", updateStmt.Tables[0].Name)
|
||||
}
|
||||
|
||||
// 验证 SET 字段
|
||||
if len(updateStmt.Set) != 1 {
|
||||
t.Fatalf("expected 1 assignment, got %d", len(updateStmt.Set))
|
||||
}
|
||||
if updateStmt.Set[0].Column != "name" {
|
||||
t.Errorf("expected Set[0].Column='name', got '%s'", updateStmt.Set[0].Column)
|
||||
}
|
||||
if updateStmt.Set[0].Value == nil || updateStmt.Set[0].Value.Text != "'John'" {
|
||||
t.Errorf("expected Set[0].Value=''John'', got '%s'", updateStmt.Set[0].Value.Text)
|
||||
}
|
||||
|
||||
// 验证 WHERE
|
||||
if updateStmt.Where == nil {
|
||||
t.Fatal("expected WHERE")
|
||||
}
|
||||
if updateStmt.Where.Text != "id = 1" {
|
||||
t.Errorf("expected WHERE='id = 1', got '%s'", updateStmt.Where.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgUpdateMultipleColumns(t *testing.T) {
|
||||
sql := "UPDATE users SET name = 'John', age = 30, email = 'john@example.com' WHERE id = 1"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||||
|
||||
// 验证完整文本
|
||||
if updateStmt.GetText() != sql {
|
||||
t.Errorf("expected text='%s', got '%s'", sql, updateStmt.GetText())
|
||||
}
|
||||
|
||||
// 验证表名
|
||||
if len(updateStmt.Tables) != 1 {
|
||||
t.Fatalf("expected 1 table, got %d", len(updateStmt.Tables))
|
||||
}
|
||||
if updateStmt.Tables[0].Name != "users" {
|
||||
t.Errorf("expected table='users', got '%s'", updateStmt.Tables[0].Name)
|
||||
}
|
||||
|
||||
// 验证 SET 字段(3个赋值)
|
||||
if len(updateStmt.Set) != 3 {
|
||||
t.Fatalf("expected 3 assignments, got %d", len(updateStmt.Set))
|
||||
}
|
||||
|
||||
// 验证第一个赋值:name = 'John'
|
||||
if updateStmt.Set[0].Column != "name" {
|
||||
t.Errorf("expected Set[0].Column='name', got '%s'", updateStmt.Set[0].Column)
|
||||
}
|
||||
if updateStmt.Set[0].Value == nil || updateStmt.Set[0].Value.Text != "'John'" {
|
||||
t.Errorf("expected Set[0].Value=''John'', got '%s'", updateStmt.Set[0].Value.Text)
|
||||
}
|
||||
if updateStmt.Set[0].Text != "name = 'John'" {
|
||||
t.Errorf("expected Set[0].Text=\"name = 'John'\", got '%s'", updateStmt.Set[0].Text)
|
||||
}
|
||||
|
||||
// 验证第二个赋值:age = 30
|
||||
if updateStmt.Set[1].Column != "age" {
|
||||
t.Errorf("expected Set[1].Column='age', got '%s'", updateStmt.Set[1].Column)
|
||||
}
|
||||
if updateStmt.Set[1].Value == nil || updateStmt.Set[1].Value.Text != "30" {
|
||||
t.Errorf("expected Set[1].Value='30', got '%s'", updateStmt.Set[1].Value.Text)
|
||||
}
|
||||
if updateStmt.Set[1].Text != "age = 30" {
|
||||
t.Errorf("expected Set[1].Text='age = 30', got '%s'", updateStmt.Set[1].Text)
|
||||
}
|
||||
|
||||
// 验证第三个赋值:email = 'john@example.com'
|
||||
if updateStmt.Set[2].Column != "email" {
|
||||
t.Errorf("expected Set[2].Column='email', got '%s'", updateStmt.Set[2].Column)
|
||||
}
|
||||
if updateStmt.Set[2].Value == nil || updateStmt.Set[2].Value.Text != "'john@example.com'" {
|
||||
t.Errorf("expected Set[2].Value=''john@example.com'', got '%s'", updateStmt.Set[2].Value.Text)
|
||||
}
|
||||
if updateStmt.Set[2].Text != "email = 'john@example.com'" {
|
||||
t.Errorf("expected Set[2].Text=\"email = 'john@example.com'\", got '%s'", updateStmt.Set[2].Text)
|
||||
}
|
||||
|
||||
// 验证 WHERE
|
||||
if updateStmt.Where == nil {
|
||||
t.Fatal("expected WHERE")
|
||||
}
|
||||
if updateStmt.Where.Text != "id = 1" {
|
||||
t.Errorf("expected WHERE='id = 1', got '%s'", updateStmt.Where.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgUpdateFromJoin(t *testing.T) {
|
||||
// PostgreSQL UPDATE FROM 语法
|
||||
sql := "UPDATE orders SET status = 'shipped' FROM users WHERE orders.user_id = users.id AND users.status = 1"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||||
t.Logf("UPDATE text: %s", updateStmt.GetText())
|
||||
t.Logf("UPDATE tables: %+v", updateStmt.Tables)
|
||||
|
||||
// 主表名应该正确解析
|
||||
if len(updateStmt.Tables) < 1 {
|
||||
t.Fatal("expected at least 1 table")
|
||||
}
|
||||
// 注意:当前解析器对于 UPDATE FROM 语法支持有限,可能表名包含别名
|
||||
// 但至少验证能成功解析而不报错
|
||||
t.Logf("Successfully parsed UPDATE FROM statement")
|
||||
}
|
||||
|
||||
// ========== DELETE 测试 ==========
|
||||
|
||||
func TestPgDeleteBasic(t *testing.T) {
|
||||
sql := "DELETE FROM users WHERE id = 1"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
|
||||
|
||||
// 验证完整文本
|
||||
if deleteStmt.GetText() != sql {
|
||||
t.Errorf("expected text='%s', got '%s'", sql, deleteStmt.GetText())
|
||||
}
|
||||
|
||||
// 验证表名
|
||||
if len(deleteStmt.Tables) != 1 {
|
||||
t.Fatalf("expected 1 table, got %d", len(deleteStmt.Tables))
|
||||
}
|
||||
if deleteStmt.Tables[0].Name != "users" {
|
||||
t.Errorf("expected table='users', got '%s'", deleteStmt.Tables[0].Name)
|
||||
}
|
||||
|
||||
// 验证 WHERE
|
||||
if deleteStmt.Where == nil {
|
||||
t.Fatal("expected WHERE")
|
||||
}
|
||||
if deleteStmt.Where.Text != "id = 1" {
|
||||
t.Errorf("expected WHERE='id = 1', got '%s'", deleteStmt.Where.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgDeleteUsing(t *testing.T) {
|
||||
sql := "DELETE FROM orders o USING users u WHERE o.user_id = u.id AND u.status = 0"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
|
||||
if len(deleteStmt.Tables) < 1 || deleteStmt.Tables[0].Name != "orders" {
|
||||
t.Errorf("expected table='orders'")
|
||||
}
|
||||
}
|
||||
|
||||
// ========== Schema.Table 测试 ==========
|
||||
|
||||
func TestPgInsertWithSchema(t *testing.T) {
|
||||
sql := `INSERT INTO "public"."users" ("name", "age") VALUES ('John', 30)`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
insertStmt := stmt.(*sqlstmt.InsertStmt)
|
||||
if insertStmt.Table.Schema != "public" {
|
||||
t.Errorf("expected schema='public', got '%s'", insertStmt.Table.Schema)
|
||||
}
|
||||
if insertStmt.Table.Name != "users" {
|
||||
t.Errorf("expected table='users', got '%s'", insertStmt.Table.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgUpdateWithSchema(t *testing.T) {
|
||||
sql := `UPDATE "public"."t_db" SET "name" = 'fsdfds3' WHERE "id" = 5`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||||
|
||||
t.Logf("UPDATE text: %s", updateStmt.GetText())
|
||||
t.Logf("UPDATE tables: %+v", updateStmt.Tables)
|
||||
t.Logf("UPDATE SET: %+v", updateStmt.Set)
|
||||
t.Logf("UPDATE WHERE: %+v", updateStmt.Where)
|
||||
|
||||
if len(updateStmt.Tables) != 1 {
|
||||
t.Fatalf("expected 1 table, got %d", len(updateStmt.Tables))
|
||||
}
|
||||
if updateStmt.Tables[0].Schema != "public" {
|
||||
t.Errorf("expected schema='public', got '%s'", updateStmt.Tables[0].Schema)
|
||||
}
|
||||
if updateStmt.Tables[0].Name != "t_db" {
|
||||
t.Errorf("expected table='t_db', got '%s'", updateStmt.Tables[0].Name)
|
||||
}
|
||||
// 验证 SET 字段
|
||||
if len(updateStmt.Set) != 1 {
|
||||
t.Fatalf("expected 1 assignment, got %d", len(updateStmt.Set))
|
||||
}
|
||||
if updateStmt.Set[0].Column != "name" {
|
||||
t.Errorf("expected column='name', got '%s'", updateStmt.Set[0].Column)
|
||||
}
|
||||
if updateStmt.Set[0].Value == nil || updateStmt.Set[0].Value.Text != "'fsdfds3'" {
|
||||
t.Errorf("expected value=''fsdfds3'', got '%s'", updateStmt.Set[0].Value.Text)
|
||||
}
|
||||
if updateStmt.Where == nil || updateStmt.Where.Text != `"id" = 5` {
|
||||
t.Errorf("expected WHERE='\"id\" = 5', got '%s'", updateStmt.Where.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgDeleteWithSchema(t *testing.T) {
|
||||
sql := `DELETE FROM "public"."logs" WHERE "created_at" < '2024-01-01'`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
|
||||
if len(deleteStmt.Tables) != 1 {
|
||||
t.Fatalf("expected 1 table, got %d", len(deleteStmt.Tables))
|
||||
}
|
||||
if deleteStmt.Tables[0].Schema != "public" {
|
||||
t.Errorf("expected schema='public', got '%s'", deleteStmt.Tables[0].Schema)
|
||||
}
|
||||
if deleteStmt.Tables[0].Name != "logs" {
|
||||
t.Errorf("expected table='logs', got '%s'", deleteStmt.Tables[0].Name)
|
||||
}
|
||||
if deleteStmt.Where == nil || deleteStmt.Where.Text != `"created_at" < '2024-01-01'` {
|
||||
t.Errorf("expected WHERE text")
|
||||
}
|
||||
}
|
||||
|
||||
// ========== DDL 测试 ==========
|
||||
|
||||
func TestPgDDLCreate(t *testing.T) {
|
||||
sql := "CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(50), age INT)"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
ddlStmt := stmt.(*sqlstmt.DdlStmt)
|
||||
|
||||
if ddlStmt.DdlKind != "CREATE" {
|
||||
t.Errorf("expected DdlKind='CREATE', got '%s'", ddlStmt.DdlKind)
|
||||
}
|
||||
if ddlStmt.GetText() != sql {
|
||||
t.Errorf("expected text='%s'", sql)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgDDLDrop(t *testing.T) {
|
||||
sql := "DROP TABLE users"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
ddlStmt := stmt.(*sqlstmt.DdlStmt)
|
||||
|
||||
if ddlStmt.DdlKind != "DROP" {
|
||||
t.Errorf("expected DdlKind='DROP', got '%s'", ddlStmt.DdlKind)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgDDLAlter(t *testing.T) {
|
||||
sql := "ALTER TABLE users ADD COLUMN email VARCHAR(255)"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
ddlStmt := stmt.(*sqlstmt.DdlStmt)
|
||||
|
||||
if ddlStmt.DdlKind != "ALTER" {
|
||||
t.Errorf("expected DdlKind='ALTER', got '%s'", ddlStmt.DdlKind)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgDDLTruncate(t *testing.T) {
|
||||
sql := "TRUNCATE TABLE users"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
if stmt == nil {
|
||||
t.Fatalf("expected stmt not nil")
|
||||
}
|
||||
t.Logf("TRUNCATE stmt type: %T", stmt)
|
||||
t.Logf("TRUNCATE text: %s", stmt.GetText())
|
||||
}
|
||||
|
||||
// ========== 复杂 DML 测试 ==========
|
||||
|
||||
func TestPgComplexUpdateWithSubquery(t *testing.T) {
|
||||
sql := "UPDATE users SET total_orders = (SELECT COUNT(*) FROM orders WHERE orders.user_id = users.id) WHERE EXISTS (SELECT 1 FROM orders WHERE orders.user_id = users.id)"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||||
if len(updateStmt.Tables) != 1 || updateStmt.Tables[0].Name != "users" {
|
||||
t.Errorf("expected table='users'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgComplexDeleteWithSubquery(t *testing.T) {
|
||||
sql := "DELETE FROM users WHERE id NOT IN (SELECT DISTINCT user_id FROM orders)"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
|
||||
if deleteStmt.Where == nil {
|
||||
t.Fatal("expected WHERE")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgInsertOnConflict(t *testing.T) {
|
||||
sql := "INSERT INTO users (id, name) VALUES (1, 'John') ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
insertStmt := stmt.(*sqlstmt.InsertStmt)
|
||||
t.Logf("Actual text: %s", insertStmt.GetText())
|
||||
if insertStmt.GetText() != sql {
|
||||
t.Errorf("expected text='%s'", sql)
|
||||
}
|
||||
}
|
||||
|
||||
// ========== 复杂 INSERT 测试 ==========
|
||||
|
||||
func TestPgInsertWithDoubleQuotes(t *testing.T) {
|
||||
// PostgreSQL 双引号标识符
|
||||
sql := `INSERT INTO "users" ("name", "age", "email") VALUES ('John', 30, 'john@example.com')`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
insertStmt := stmt.(*sqlstmt.InsertStmt)
|
||||
if insertStmt.GetText() != sql {
|
||||
t.Errorf("expected text='%s'", sql)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgInsertWithSpecialChars(t *testing.T) {
|
||||
// 包含特殊字符
|
||||
sql := `INSERT INTO "logs" ("message", "level") VALUES ('Error: connection failed!', 'ERROR')`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
insertStmt := stmt.(*sqlstmt.InsertStmt)
|
||||
t.Logf("INSERT text: %s", insertStmt.GetText())
|
||||
}
|
||||
|
||||
func TestPgInsertReturningMultiple(t *testing.T) {
|
||||
// RETURNING 多个字段
|
||||
sql := `INSERT INTO "users" ("name", "email") VALUES ('John', 'john@example.com') RETURNING "id", "name", "created_at"`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
insertStmt := stmt.(*sqlstmt.InsertStmt)
|
||||
t.Logf("INSERT RETURNING: %s", insertStmt.GetText())
|
||||
}
|
||||
|
||||
func TestPgInsertFromSelectComplex(t *testing.T) {
|
||||
// INSERT FROM SELECT 复杂查询
|
||||
sql := `INSERT INTO "users_backup" SELECT * FROM "users" WHERE "status" = 1 AND "created_at" > '2024-01-01'`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
insertStmt := stmt.(*sqlstmt.InsertStmt)
|
||||
t.Logf("INSERT FROM SELECT: %s", insertStmt.GetText())
|
||||
}
|
||||
|
||||
func TestPgInsertOnConflictDoNothing(t *testing.T) {
|
||||
// ON CONFLICT DO NOTHING
|
||||
sql := `INSERT INTO "users" ("id", "name") VALUES (1, 'John') ON CONFLICT ("id") DO NOTHING`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
insertStmt := stmt.(*sqlstmt.InsertStmt)
|
||||
t.Logf("INSERT ON CONFLICT DO NOTHING: %s", insertStmt.GetText())
|
||||
}
|
||||
|
||||
// ========== 复杂 UPDATE 测试 ==========
|
||||
|
||||
func TestPgUpdateWithDoubleQuotes(t *testing.T) {
|
||||
sql := `UPDATE "users" SET "name" = 'John', "age" = 30 WHERE "id" = 1`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||||
if len(updateStmt.Tables) != 1 || updateStmt.Tables[0].Name != "users" {
|
||||
t.Errorf("expected table='users', got '%s'", updateStmt.Tables[0].Name)
|
||||
}
|
||||
if updateStmt.Where == nil || updateStmt.Where.Text != `"id" = 1` {
|
||||
t.Errorf("expected WHERE")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgUpdateWithComplexWhere(t *testing.T) {
|
||||
sql := `UPDATE "orders" SET "status" = 'cancelled' WHERE "status" = 'pending' AND "created_at" < '2024-01-01' AND ("amount" < 100 OR "user_id" IS NULL)`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||||
if updateStmt.Where == nil {
|
||||
t.Fatal("expected WHERE")
|
||||
}
|
||||
t.Logf("Complex WHERE: %s", updateStmt.Where.Text)
|
||||
}
|
||||
|
||||
func TestPgUpdateWithFunctions(t *testing.T) {
|
||||
sql := `UPDATE "users" SET "updated_at" = NOW(), "login_count" = "login_count" + 1 WHERE "id" = 1`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||||
t.Logf("UPDATE with functions: %s", updateStmt.GetText())
|
||||
}
|
||||
|
||||
func TestPgUpdateReturningComplex(t *testing.T) {
|
||||
sql := `UPDATE "users" SET "status" = 0 WHERE "status" = 1 RETURNING "id", "name", "old_status"`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||||
if updateStmt.Where == nil || updateStmt.Where.Text != `"status" = 1` {
|
||||
t.Errorf("expected WHERE")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgUpdateWithSubquery(t *testing.T) {
|
||||
sql := `UPDATE "users" SET "total" = (SELECT SUM("amount") FROM "orders" WHERE "user_id" = "users"."id") WHERE "status" = 'active'`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||||
if len(updateStmt.Tables) != 1 || updateStmt.Tables[0].Name != "users" {
|
||||
t.Errorf("expected table='users'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgUpdateFromComplex(t *testing.T) {
|
||||
// PostgreSQL UPDATE FROM 复杂场景
|
||||
sql := `UPDATE "orders" o SET "status" = 'shipped' FROM "users" u WHERE o."user_id" = u."id" AND u."status" = 1 AND o."created_at" > '2024-01-01'`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||||
t.Logf("UPDATE FROM: %s", updateStmt.GetText())
|
||||
}
|
||||
|
||||
// ========== 复杂 DELETE 测试 ==========
|
||||
|
||||
func TestPgDeleteWithDoubleQuotes(t *testing.T) {
|
||||
sql := `DELETE FROM "users" WHERE "id" = 1`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
|
||||
if len(deleteStmt.Tables) != 1 || deleteStmt.Tables[0].Name != "users" {
|
||||
t.Errorf("expected table='users'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgDeleteWithComplexWhere(t *testing.T) {
|
||||
sql := `DELETE FROM "logs" WHERE "created_at" < '2024-01-01' AND ("level" = 'DEBUG' OR "level" = 'INFO')`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
|
||||
if deleteStmt.Where == nil {
|
||||
t.Fatal("expected WHERE")
|
||||
}
|
||||
t.Logf("Complex DELETE WHERE: %s", deleteStmt.Where.Text)
|
||||
}
|
||||
|
||||
func TestPgDeleteReturningComplex(t *testing.T) {
|
||||
sql := `DELETE FROM "users" WHERE "status" = 0 RETURNING "id", "name", "deleted_at"`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
|
||||
if deleteStmt.Where == nil || deleteStmt.Where.Text != `"status" = 0` {
|
||||
t.Errorf("expected WHERE")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgDeleteWithSubquery(t *testing.T) {
|
||||
sql := `DELETE FROM "users" WHERE "id" NOT IN (SELECT DISTINCT "user_id" FROM "orders")`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
|
||||
if deleteStmt.Where == nil {
|
||||
t.Fatal("expected WHERE")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgDeleteUsingComplex(t *testing.T) {
|
||||
// PostgreSQL DELETE USING 复杂场景
|
||||
sql := `DELETE FROM "orders" o USING "users" u WHERE o."user_id" = u."id" AND u."status" = 0 AND o."created_at" < '2024-01-01'`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
|
||||
t.Logf("DELETE USING: %s", deleteStmt.GetText())
|
||||
}
|
||||
|
||||
// ========== 复杂 DDL 测试 ==========
|
||||
|
||||
func TestPgDDLCreateTableWithQuotes(t *testing.T) {
|
||||
sql := `CREATE TABLE "users" ("id" INT PRIMARY KEY, "name" VARCHAR(50) NOT NULL, "email" VARCHAR(255) UNIQUE)`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
ddlStmt := stmt.(*sqlstmt.DdlStmt)
|
||||
if ddlStmt.DdlKind != "CREATE" {
|
||||
t.Errorf("expected DdlKind='CREATE'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgDDLCreateTableWithSerial(t *testing.T) {
|
||||
sql := `CREATE TABLE "orders" ("id" SERIAL PRIMARY KEY, "amount" DECIMAL(10,2), "created_at" TIMESTAMP DEFAULT NOW())`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
ddlStmt := stmt.(*sqlstmt.DdlStmt)
|
||||
t.Logf("CREATE TABLE with SERIAL: %s", ddlStmt.GetText())
|
||||
}
|
||||
|
||||
func TestPgDDLAlterTableAddColumn(t *testing.T) {
|
||||
sql := `ALTER TABLE "users" ADD COLUMN "email" VARCHAR(255)`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
ddlStmt := stmt.(*sqlstmt.DdlStmt)
|
||||
if ddlStmt.DdlKind != "ALTER" {
|
||||
t.Errorf("expected DdlKind='ALTER'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgDDLDropIfExists(t *testing.T) {
|
||||
sql := `DROP TABLE IF EXISTS "users"`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
ddlStmt := stmt.(*sqlstmt.DdlStmt)
|
||||
if ddlStmt.DdlKind != "DROP" {
|
||||
t.Errorf("expected DdlKind='DROP'")
|
||||
}
|
||||
}
|
||||
|
||||
// ========== 注释风格测试 ==========
|
||||
|
||||
func TestPgDMLWithSingleLineComment(t *testing.T) {
|
||||
// PostgreSQL 单行注释 --
|
||||
sql := "-- 更新用户信息\nUPDATE \"users\" SET \"name\" = 'John' WHERE \"id\" = 1"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||||
if len(updateStmt.Tables) != 1 || updateStmt.Tables[0].Name != "users" {
|
||||
t.Errorf("expected table='users'")
|
||||
}
|
||||
t.Logf("UPDATE with -- comment: %s", updateStmt.GetText())
|
||||
}
|
||||
|
||||
func TestPgDMLWithMultiLineComment(t *testing.T) {
|
||||
// 多行注释 /* */
|
||||
sql := "/* 删除过期订单 */ DELETE FROM \"orders\" WHERE \"created_at\" < '2024-01-01'"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
|
||||
if deleteStmt.Where == nil {
|
||||
t.Fatal("expected WHERE")
|
||||
}
|
||||
t.Logf("DELETE with /* */ comment: %s", deleteStmt.GetText())
|
||||
}
|
||||
|
||||
func TestPgDMLWithInlineComment(t *testing.T) {
|
||||
// 行内注释
|
||||
sql := `SELECT "id", "name" /* 用户名 */ FROM "users" WHERE "status" = 1`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
if len(selectStmt.Items) != 2 {
|
||||
t.Fatalf("expected 2 items")
|
||||
}
|
||||
t.Logf("SELECT with inline comment: %s", selectStmt.GetText())
|
||||
}
|
||||
|
||||
func TestPgDMLWithMultipleComments(t *testing.T) {
|
||||
// 多个注释
|
||||
sql := `-- 查询活跃用户
|
||||
/* 只查询最近注册的 */
|
||||
SELECT "id", "name" FROM "users"
|
||||
WHERE "status" = 1 AND "created_at" > '2024-01-01'
|
||||
ORDER BY "created_at" DESC`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
if selectStmt.Where == nil {
|
||||
t.Fatal("expected WHERE")
|
||||
}
|
||||
t.Logf("Multiple comments: %s", selectStmt.GetText())
|
||||
}
|
||||
|
||||
func TestPgInsertWithComment(t *testing.T) {
|
||||
sql := "-- 插入新用户\nINSERT INTO \"users\" (\"name\", \"email\") VALUES ('John', 'john@example.com')"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
insertStmt := stmt.(*sqlstmt.InsertStmt)
|
||||
t.Logf("INSERT with comment: %s", insertStmt.GetText())
|
||||
}
|
||||
|
||||
func TestPgUpdateWithComment(t *testing.T) {
|
||||
sql := `/* 批量更新状态 */
|
||||
UPDATE "orders" SET "status" = 'cancelled'
|
||||
WHERE "status" = 'pending' -- 只更新待处理的订单
|
||||
AND "created_at" < '2024-01-01'`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||||
if updateStmt.Where == nil {
|
||||
t.Fatal("expected WHERE")
|
||||
}
|
||||
t.Logf("UPDATE with multiple comments: %s", updateStmt.GetText())
|
||||
}
|
||||
455
server/internal/db/dbm/sqlparser/pgsql/parser_pagination_test.go
Normal file
455
server/internal/db/dbm/sqlparser/pgsql/parser_pagination_test.go
Normal file
@@ -0,0 +1,455 @@
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
|
||||
)
|
||||
|
||||
// ========== 简单分页测试 ==========
|
||||
|
||||
func TestPgPaginationSimple(t *testing.T) {
|
||||
sql := "SELECT * FROM users LIMIT 10 OFFSET 0"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
t.Logf("LIMIT text: %s", selectStmt.Limit.Text)
|
||||
t.Logf("Offset: %d, Count: %d", selectStmt.Limit.Offset, selectStmt.Limit.Count)
|
||||
|
||||
if selectStmt.Limit.Offset != 0 {
|
||||
t.Errorf("expected offset=0, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 10 {
|
||||
t.Errorf("expected count=10, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgPaginationWithOffset(t *testing.T) {
|
||||
sql := "SELECT id, name FROM users LIMIT 10 OFFSET 20"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 20 {
|
||||
t.Errorf("expected offset=20, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 10 {
|
||||
t.Errorf("expected count=10, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgPaginationOffsetFirst(t *testing.T) {
|
||||
// PostgreSQL 支持 OFFSET 在 LIMIT 前面
|
||||
sql := "SELECT * FROM products OFFSET 30 LIMIT 15"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
t.Logf("LIMIT text: %s", selectStmt.Limit.Text)
|
||||
t.Logf("Offset: %d, Count: %d", selectStmt.Limit.Offset, selectStmt.Limit.Count)
|
||||
|
||||
if selectStmt.Limit.Offset != 30 {
|
||||
t.Errorf("expected offset=30, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 15 {
|
||||
t.Errorf("expected count=15, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgPaginationOnlyLimit(t *testing.T) {
|
||||
sql := "SELECT * FROM orders LIMIT 50"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 0 {
|
||||
t.Errorf("expected offset=0, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 50 {
|
||||
t.Errorf("expected count=50, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgPaginationOnlyOffset(t *testing.T) {
|
||||
sql := "SELECT * FROM users OFFSET 100"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 100 {
|
||||
t.Errorf("expected offset=100, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 0 {
|
||||
t.Errorf("expected count=0, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgPaginationLimitAll(t *testing.T) {
|
||||
// PostgreSQL 特有的 LIMIT ALL
|
||||
sql := "SELECT * FROM users OFFSET 50 LIMIT ALL"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
t.Logf("LIMIT text: %s", selectStmt.Limit.Text)
|
||||
if selectStmt.Limit.Offset != 50 {
|
||||
t.Errorf("expected offset=50, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
}
|
||||
|
||||
// ========== 复杂分页测试 ==========
|
||||
|
||||
func TestPgPaginationWithWhere(t *testing.T) {
|
||||
sql := "SELECT id, name, email FROM users WHERE status = 1 AND age > 18 LIMIT 20 OFFSET 10"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if selectStmt.Where == nil {
|
||||
t.Fatal("expected WHERE clause")
|
||||
}
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 10 {
|
||||
t.Errorf("expected offset=10, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 20 {
|
||||
t.Errorf("expected count=20, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgPaginationWithOrderBy(t *testing.T) {
|
||||
sql := "SELECT * FROM users ORDER BY created_at DESC, id ASC LIMIT 100 OFFSET 0"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if len(selectStmt.OrderBy) != 2 {
|
||||
t.Fatalf("expected 2 ORDER BY clauses")
|
||||
}
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 0 {
|
||||
t.Errorf("expected offset=0, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 100 {
|
||||
t.Errorf("expected count=100, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgPaginationWithWhereOrderBy(t *testing.T) {
|
||||
sql := "SELECT id, name FROM users WHERE status = 'active' ORDER BY score DESC LIMIT 10 OFFSET 50"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if selectStmt.Where == nil || selectStmt.Where.Text != "status = 'active'" {
|
||||
t.Errorf("expected WHERE='status = 'active''")
|
||||
}
|
||||
if len(selectStmt.OrderBy) != 1 || selectStmt.OrderBy[0].Text != "score" {
|
||||
t.Errorf("expected ORDER BY score")
|
||||
}
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 50 {
|
||||
t.Errorf("expected offset=50, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 10 {
|
||||
t.Errorf("expected count=10, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgPaginationWithJoin(t *testing.T) {
|
||||
sql := "SELECT u.id, u.name, o.amount FROM users u LEFT JOIN orders o ON u.id = o.user_id WHERE u.status = 1 ORDER BY o.created_at DESC LIMIT 20 OFFSET 100"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if len(selectStmt.Joins) != 1 {
|
||||
t.Fatal("expected 1 JOIN")
|
||||
}
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 100 {
|
||||
t.Errorf("expected offset=100, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 20 {
|
||||
t.Errorf("expected count=20, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgPaginationWithGroupBy(t *testing.T) {
|
||||
sql := "SELECT user_id, COUNT(*) as order_count FROM orders GROUP BY user_id ORDER BY order_count DESC LIMIT 10 OFFSET 0"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 0 {
|
||||
t.Errorf("expected offset=0, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 10 {
|
||||
t.Errorf("expected count=10, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
// ========== UNION 分页测试 ==========
|
||||
|
||||
func TestPgPaginationWithUnion(t *testing.T) {
|
||||
sql := "SELECT id FROM users UNION SELECT id FROM admins ORDER BY 1 DESC LIMIT 50 OFFSET 0"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if len(selectStmt.Unions) != 1 {
|
||||
t.Fatal("expected 1 UNION")
|
||||
}
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause for UNION")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 0 {
|
||||
t.Errorf("expected offset=0, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 50 {
|
||||
t.Errorf("expected count=50, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgPaginationUnionAll(t *testing.T) {
|
||||
sql := "SELECT name FROM products UNION ALL SELECT name FROM services LIMIT 30 OFFSET 10"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if len(selectStmt.Unions) != 1 {
|
||||
t.Fatal("expected 1 UNION ALL")
|
||||
}
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 10 {
|
||||
t.Errorf("expected offset=10, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 30 {
|
||||
t.Errorf("expected count=30, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgMultipleUnionsPagination(t *testing.T) {
|
||||
sql := "SELECT id FROM t1 UNION SELECT id FROM t2 UNION SELECT id FROM t3 ORDER BY 1 LIMIT 100 OFFSET 50"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if len(selectStmt.Unions) != 2 {
|
||||
t.Fatalf("expected 2 UNIONs, got %d", len(selectStmt.Unions))
|
||||
}
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 50 {
|
||||
t.Errorf("expected offset=50, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 100 {
|
||||
t.Errorf("expected count=100, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
// ========== 子查询分页测试 ==========
|
||||
|
||||
func TestPgPaginationWithSubquery(t *testing.T) {
|
||||
sql := "SELECT * FROM (SELECT id, name FROM users WHERE status = 1 ORDER BY id LIMIT 100 OFFSET 0) AS tmp LIMIT 10 OFFSET 20"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if len(selectStmt.From) != 1 {
|
||||
t.Fatal("expected 1 FROM table")
|
||||
}
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected outer LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 20 {
|
||||
t.Errorf("expected outer offset=20, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 10 {
|
||||
t.Errorf("expected outer count=10, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgNestedPagination(t *testing.T) {
|
||||
// 嵌套分页查询
|
||||
sql := "SELECT * FROM (SELECT * FROM users LIMIT 100 OFFSET 0) AS tmp LIMIT 5 OFFSET 10"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected outer LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 10 {
|
||||
t.Errorf("expected outer offset=10, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 5 {
|
||||
t.Errorf("expected outer count=5, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgPaginationWithExists(t *testing.T) {
|
||||
sql := "SELECT * FROM users u WHERE EXISTS (SELECT 1 FROM orders o WHERE o.user_id = u.id LIMIT 1) LIMIT 20 OFFSET 0"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if selectStmt.Where == nil {
|
||||
t.Fatal("expected WHERE clause")
|
||||
}
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 0 {
|
||||
t.Errorf("expected offset=0, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 20 {
|
||||
t.Errorf("expected count=20, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
// ========== 大数据量分页测试 ==========
|
||||
|
||||
func TestPgLargeOffsetPagination(t *testing.T) {
|
||||
sql := "SELECT * FROM logs ORDER BY id LIMIT 100 OFFSET 1000000"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 1000000 {
|
||||
t.Errorf("expected offset=1000000, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 100 {
|
||||
t.Errorf("expected count=100, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
// ========== FOR UPDATE 分页测试 ==========
|
||||
|
||||
func TestPgPaginationForUpdate(t *testing.T) {
|
||||
sql := "SELECT * FROM users WHERE status = 0 ORDER BY id LIMIT 10 OFFSET 0 FOR UPDATE"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 0 {
|
||||
t.Errorf("expected offset=0, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 10 {
|
||||
t.Errorf("expected count=10, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
412
server/internal/db/dbm/sqlparser/pgsql/parser_select_test.go
Normal file
412
server/internal/db/dbm/sqlparser/pgsql/parser_select_test.go
Normal file
@@ -0,0 +1,412 @@
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
|
||||
)
|
||||
|
||||
// ========== SELECT 测试 ==========
|
||||
|
||||
func TestPgSelectBasic(t *testing.T) {
|
||||
sql := "SELECT id, name FROM users WHERE status = 1 ORDER BY id DESC LIMIT 10"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
// 验证完整文本
|
||||
if selectStmt.GetText() != sql {
|
||||
t.Errorf("expected text='%s', got '%s'", sql, selectStmt.GetText())
|
||||
}
|
||||
|
||||
// 验证 SELECT 字段
|
||||
if len(selectStmt.Items) != 2 {
|
||||
t.Fatalf("expected 2 items, got %d", len(selectStmt.Items))
|
||||
}
|
||||
if selectStmt.Items[0].Text != "id" {
|
||||
t.Errorf("expected Items[0]='id', got '%s'", selectStmt.Items[0].Text)
|
||||
}
|
||||
if selectStmt.Items[1].Text != "name" {
|
||||
t.Errorf("expected Items[1]='name', got '%s'", selectStmt.Items[1].Text)
|
||||
}
|
||||
|
||||
// 验证 FROM 表名
|
||||
if len(selectStmt.From) != 1 {
|
||||
t.Fatalf("expected 1 from table, got %d", len(selectStmt.From))
|
||||
}
|
||||
if selectStmt.From[0].Name != "users" {
|
||||
t.Errorf("expected From[0].Name='users', got '%s'", selectStmt.From[0].Name)
|
||||
}
|
||||
|
||||
// 验证 WHERE
|
||||
if selectStmt.Where == nil {
|
||||
t.Fatal("expected WHERE")
|
||||
}
|
||||
if selectStmt.Where.Text != "status = 1" {
|
||||
t.Errorf("expected WHERE='status = 1', got '%s'", selectStmt.Where.Text)
|
||||
}
|
||||
|
||||
// 验证 ORDER BY
|
||||
if len(selectStmt.OrderBy) != 1 {
|
||||
t.Fatalf("expected 1 OrderBy, got %d", len(selectStmt.OrderBy))
|
||||
}
|
||||
if selectStmt.OrderBy[0].Text != "id" {
|
||||
t.Errorf("expected OrderBy[0].Text='id', got '%s'", selectStmt.OrderBy[0].Text)
|
||||
}
|
||||
if !selectStmt.OrderBy[0].Desc {
|
||||
t.Error("expected OrderBy[0].Desc=true")
|
||||
}
|
||||
|
||||
// 验证 LIMIT
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT")
|
||||
}
|
||||
if selectStmt.Limit.Text != "LIMIT 10" {
|
||||
t.Errorf("expected LIMIT.Text='LIMIT 10', got '%s'", selectStmt.Limit.Text)
|
||||
}
|
||||
if selectStmt.Limit.Count != 10 {
|
||||
t.Errorf("expected LIMIT.Count=10, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgSelectDistinct(t *testing.T) {
|
||||
sql := "SELECT DISTINCT id, name FROM users"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
if !selectStmt.Distinct {
|
||||
t.Error("expected Distinct=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgSelectStar(t *testing.T) {
|
||||
sql := "SELECT * FROM users"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
// 验证完整文本
|
||||
if selectStmt.GetText() != sql {
|
||||
t.Errorf("expected text='%s', got '%s'", sql, selectStmt.GetText())
|
||||
}
|
||||
|
||||
// 验证 SELECT *
|
||||
if len(selectStmt.Items) != 1 {
|
||||
t.Fatalf("expected 1 item, got %d", len(selectStmt.Items))
|
||||
}
|
||||
if selectStmt.Items[0].Text != "*" {
|
||||
t.Errorf("expected Items[0]='*', got '%s'", selectStmt.Items[0].Text)
|
||||
}
|
||||
|
||||
// 验证 FROM
|
||||
if len(selectStmt.From) != 1 {
|
||||
t.Fatalf("expected 1 from table, got %d", len(selectStmt.From))
|
||||
}
|
||||
if selectStmt.From[0].Name != "users" {
|
||||
t.Errorf("expected From[0].Name='users', got '%s'", selectStmt.From[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgSelectWithJoin(t *testing.T) {
|
||||
sql := "SELECT u.id, o.amount FROM users u LEFT JOIN orders o ON u.id = o.user_id WHERE u.status = 1"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
// 验证完整文本
|
||||
if selectStmt.GetText() != sql {
|
||||
t.Errorf("expected text='%s', got '%s'", sql, selectStmt.GetText())
|
||||
}
|
||||
|
||||
// 验证主表
|
||||
if len(selectStmt.From) != 1 {
|
||||
t.Fatalf("expected 1 from table, got %d", len(selectStmt.From))
|
||||
}
|
||||
if selectStmt.From[0].Name != "users" {
|
||||
t.Errorf("expected From[0].Name='users', got '%s'", selectStmt.From[0].Name)
|
||||
}
|
||||
if selectStmt.From[0].Alias != "u" {
|
||||
t.Errorf("expected From[0].Alias='u', got '%s'", selectStmt.From[0].Alias)
|
||||
}
|
||||
|
||||
// 验证 JOIN
|
||||
if len(selectStmt.Joins) != 1 {
|
||||
t.Fatalf("expected 1 join, got %d", len(selectStmt.Joins))
|
||||
}
|
||||
if selectStmt.Joins[0].Table.Name != "orders" {
|
||||
t.Errorf("expected Join[0].Table.Name='orders', got '%s'", selectStmt.Joins[0].Table.Name)
|
||||
}
|
||||
if selectStmt.Joins[0].Table.Alias != "o" {
|
||||
t.Errorf("expected Join[0].Table.Alias='o', got '%s'", selectStmt.Joins[0].Table.Alias)
|
||||
}
|
||||
if selectStmt.Joins[0].Kind != sqlstmt.JoinKindLeft {
|
||||
t.Errorf("expected Join[0].Kind=JoinKindLeft, got '%d'", selectStmt.Joins[0].Kind)
|
||||
}
|
||||
if selectStmt.Joins[0].On == nil {
|
||||
t.Fatal("expected Join[0].On")
|
||||
}
|
||||
if selectStmt.Joins[0].On.Text != "u.id = o.user_id" {
|
||||
t.Errorf("expected Join[0].On.Text='u.id = o.user_id', got '%s'", selectStmt.Joins[0].On.Text)
|
||||
}
|
||||
|
||||
// 验证 WHERE
|
||||
if selectStmt.Where == nil {
|
||||
t.Fatal("expected WHERE")
|
||||
}
|
||||
if selectStmt.Where.Text != "u.status = 1" {
|
||||
t.Errorf("expected WHERE='u.status = 1', got '%s'", selectStmt.Where.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgSelectComplexWhere(t *testing.T) {
|
||||
sql := "SELECT * FROM users WHERE status = 1 AND (age > 18 OR role = 'admin') ORDER BY name"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
// 验证完整文本
|
||||
if selectStmt.GetText() != sql {
|
||||
t.Errorf("expected text='%s', got '%s'", sql, selectStmt.GetText())
|
||||
}
|
||||
|
||||
// 验证 FROM
|
||||
if len(selectStmt.From) != 1 {
|
||||
t.Fatalf("expected 1 from table, got %d", len(selectStmt.From))
|
||||
}
|
||||
if selectStmt.From[0].Name != "users" {
|
||||
t.Errorf("expected From[0].Name='users', got '%s'", selectStmt.From[0].Name)
|
||||
}
|
||||
|
||||
// 验证 WHERE
|
||||
if selectStmt.Where == nil {
|
||||
t.Fatal("expected WHERE")
|
||||
}
|
||||
expectedWhere := "status = 1 AND (age > 18 OR role = 'admin')"
|
||||
if selectStmt.Where.Text != expectedWhere {
|
||||
t.Errorf("expected WHERE='%s', got '%s'", expectedWhere, selectStmt.Where.Text)
|
||||
}
|
||||
|
||||
// 验证 ORDER BY
|
||||
if len(selectStmt.OrderBy) != 1 {
|
||||
t.Fatalf("expected 1 OrderBy, got %d", len(selectStmt.OrderBy))
|
||||
}
|
||||
if selectStmt.OrderBy[0].Text != "name" {
|
||||
t.Errorf("expected OrderBy[0].Text='name', got '%s'", selectStmt.OrderBy[0].Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgOffsetLimit(t *testing.T) {
|
||||
sql := "SELECT * FROM users OFFSET 10 LIMIT 20"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
// 验证完整文本
|
||||
if selectStmt.GetText() != sql {
|
||||
t.Errorf("expected text='%s', got '%s'", sql, selectStmt.GetText())
|
||||
}
|
||||
|
||||
// 验证 LIMIT/OFFSET
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT")
|
||||
}
|
||||
if selectStmt.Limit.Text != "OFFSET 10 LIMIT 20" {
|
||||
t.Errorf("expected LIMIT.Text='OFFSET 10 LIMIT 20', got '%s'", selectStmt.Limit.Text)
|
||||
}
|
||||
if selectStmt.Limit.Count != 20 {
|
||||
t.Errorf("expected LIMIT.Count=20, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
if selectStmt.Limit.Offset != 10 {
|
||||
t.Errorf("expected LIMIT.Offset=10, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgLimitAll(t *testing.T) {
|
||||
sql := "SELECT * FROM users LIMIT ALL"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
if selectStmt.Limit == nil || selectStmt.Limit.Text != "LIMIT ALL" {
|
||||
t.Errorf("expected LIMIT text='LIMIT ALL'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgUnion(t *testing.T) {
|
||||
sql := "SELECT 1 UNION SELECT 2 UNION ALL SELECT 3"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
// 验证完整文本
|
||||
if selectStmt.GetText() != sql {
|
||||
t.Errorf("expected text='%s', got '%s'", sql, selectStmt.GetText())
|
||||
}
|
||||
|
||||
// 验证 UNION 数量
|
||||
if len(selectStmt.Unions) != 2 {
|
||||
t.Fatalf("expected 2 unions, got %d", len(selectStmt.Unions))
|
||||
}
|
||||
|
||||
// 验证第一个 UNION(DISTINCT)
|
||||
if selectStmt.Unions[0].All {
|
||||
t.Error("expected Unions[0].All=false (DISTINCT)")
|
||||
}
|
||||
if selectStmt.Unions[0].Select == nil {
|
||||
t.Fatal("expected Unions[0].Select")
|
||||
}
|
||||
|
||||
// 验证第二个 UNION(ALL)
|
||||
if !selectStmt.Unions[1].All {
|
||||
t.Error("expected Unions[1].All=true")
|
||||
}
|
||||
if selectStmt.Unions[1].Select == nil {
|
||||
t.Fatal("expected Unions[1].Select")
|
||||
}
|
||||
}
|
||||
|
||||
// ========== FOR UPDATE 测试 ==========
|
||||
|
||||
func TestPgForUpdate(t *testing.T) {
|
||||
sql := "SELECT id FROM users WHERE id = 1 FOR UPDATE"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
// 验证完整文本
|
||||
if selectStmt.GetText() != sql {
|
||||
t.Errorf("expected text='%s', got '%s'", sql, selectStmt.GetText())
|
||||
}
|
||||
|
||||
// 验证 FROM
|
||||
if len(selectStmt.From) != 1 {
|
||||
t.Fatalf("expected 1 from table, got %d", len(selectStmt.From))
|
||||
}
|
||||
if selectStmt.From[0].Name != "users" {
|
||||
t.Errorf("expected From[0].Name='users', got '%s'", selectStmt.From[0].Name)
|
||||
}
|
||||
|
||||
// 验证 WHERE
|
||||
if selectStmt.Where == nil {
|
||||
t.Fatal("expected WHERE")
|
||||
}
|
||||
if selectStmt.Where.Text != "id = 1" {
|
||||
t.Errorf("expected WHERE='id = 1', got '%s'", selectStmt.Where.Text)
|
||||
}
|
||||
|
||||
// 注意:FOR UPDATE 标记可能在 Base.Text 中体现
|
||||
// 只要完整文本包含 FOR UPDATE 即可
|
||||
if !strings.Contains(selectStmt.GetText(), "FOR UPDATE") {
|
||||
t.Error("expected text to contain 'FOR UPDATE'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgForUpdateSkipLocked(t *testing.T) {
|
||||
sql := "SELECT id FROM users WHERE status = 1 FOR UPDATE SKIP LOCKED"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
if selectStmt.Where == nil || selectStmt.Where.Text != "status = 1" {
|
||||
t.Errorf("expected WHERE='status = 1'")
|
||||
}
|
||||
}
|
||||
|
||||
// ========== RETURNING 测试 ==========
|
||||
|
||||
func TestPgUpdateReturning(t *testing.T) {
|
||||
sql := "UPDATE users SET name = 'John' WHERE id = 1 RETURNING id, name"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
if stmt == nil {
|
||||
t.Fatalf("expected stmt not nil")
|
||||
}
|
||||
if stmt.GetText() != sql {
|
||||
t.Errorf("expected text='%s'", sql)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgDeleteReturning(t *testing.T) {
|
||||
sql := "DELETE FROM users WHERE id = 1 RETURNING *"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
if stmt == nil {
|
||||
t.Fatalf("expected stmt not nil")
|
||||
}
|
||||
if stmt.GetText() != sql {
|
||||
t.Errorf("expected text='%s'", sql)
|
||||
}
|
||||
}
|
||||
|
||||
// ========== DDL 测试 ==========
|
||||
|
||||
func TestPgDDL(t *testing.T) {
|
||||
sqls := []string{
|
||||
"CREATE TABLE users (id INT PRIMARY KEY)",
|
||||
"DROP TABLE users",
|
||||
"ALTER TABLE users ADD COLUMN email VARCHAR(255)",
|
||||
}
|
||||
|
||||
for _, sql := range sqls {
|
||||
t.Run(sql[:10], func(t *testing.T) {
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
if stmt == nil {
|
||||
t.Fatalf("expected stmt not nil")
|
||||
}
|
||||
if stmt.GetText() != sql {
|
||||
t.Errorf("expected text='%s'", sql)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
368
server/internal/db/dbm/sqlparser/pgsql/parser_subquery_test.go
Normal file
368
server/internal/db/dbm/sqlparser/pgsql/parser_subquery_test.go
Normal file
@@ -0,0 +1,368 @@
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
|
||||
)
|
||||
|
||||
// ========== 子查询测试 ==========
|
||||
|
||||
func TestPgSubqueryInFrom(t *testing.T) {
|
||||
sql := "SELECT * FROM (SELECT id, name FROM users WHERE status = 1) AS u WHERE u.id > 10"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
t.Logf("完整文本: %s", selectStmt.GetText())
|
||||
t.Logf("FROM 表数量: %d", len(selectStmt.From))
|
||||
|
||||
if len(selectStmt.From) != 1 {
|
||||
t.Fatalf("expected 1 table in FROM, got %d", len(selectStmt.From))
|
||||
}
|
||||
|
||||
t.Logf("FROM[0] Name: %s", selectStmt.From[0].Name)
|
||||
t.Logf("FROM[0] Alias: %s", selectStmt.From[0].Alias)
|
||||
|
||||
if selectStmt.From[0].Alias != "u" {
|
||||
t.Errorf("expected alias='u', got '%s'", selectStmt.From[0].Alias)
|
||||
}
|
||||
|
||||
if selectStmt.Where == nil {
|
||||
t.Fatal("expected outer WHERE clause")
|
||||
}
|
||||
t.Logf("外层 WHERE: %s", selectStmt.Where.Text)
|
||||
if selectStmt.Where.Text != "u.id > 10" {
|
||||
t.Errorf("expected outer WHERE text='u.id > 10', got '%s'", selectStmt.Where.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgSubqueryInWhere(t *testing.T) {
|
||||
sql := "SELECT * FROM users WHERE id IN (SELECT user_id FROM orders WHERE amount > 100)"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
t.Logf("完整文本: %s", selectStmt.GetText())
|
||||
t.Logf("WHERE: %s", selectStmt.Where.Text)
|
||||
|
||||
if selectStmt.Where == nil {
|
||||
t.Fatal("expected WHERE clause")
|
||||
}
|
||||
expectedWhere := "id IN (SELECT user_id FROM orders WHERE amount > 100)"
|
||||
if selectStmt.Where.Text != expectedWhere {
|
||||
t.Errorf("expected WHERE text='%s', got '%s'", expectedWhere, selectStmt.Where.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgSubqueryInSelect(t *testing.T) {
|
||||
sql := "SELECT id, name, (SELECT COUNT(*) FROM orders WHERE orders.user_id = users.id) AS order_count FROM users"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
t.Logf("完整文本: %s", selectStmt.GetText())
|
||||
t.Logf("SELECT 项数量: %d", len(selectStmt.Items))
|
||||
|
||||
for i, item := range selectStmt.Items {
|
||||
t.Logf(" Item[%d]: Text='%s', ColumnName='%s', Alias='%s'", i, item.Text, item.ColumnName, item.Alias)
|
||||
}
|
||||
|
||||
if len(selectStmt.Items) != 3 {
|
||||
t.Fatalf("expected 3 items, got %d", len(selectStmt.Items))
|
||||
}
|
||||
|
||||
if selectStmt.Items[2].Alias != "order_count" {
|
||||
t.Errorf("expected item[2] alias='order_count', got '%s'", selectStmt.Items[2].Alias)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgNestedSubquery(t *testing.T) {
|
||||
sql := "SELECT * FROM (SELECT * FROM (SELECT id FROM users WHERE status = 1) AS inner_u) AS outer_u WHERE outer_u.id > 5"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
t.Logf("完整文本: %s", selectStmt.GetText())
|
||||
t.Logf("外层 WHERE: %s", selectStmt.Where.Text)
|
||||
|
||||
if selectStmt.Where == nil {
|
||||
t.Fatal("expected outer WHERE")
|
||||
}
|
||||
if selectStmt.Where.Text != "outer_u.id > 5" {
|
||||
t.Errorf("expected WHERE text='outer_u.id > 5', got '%s'", selectStmt.Where.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgSubqueryWithUnion(t *testing.T) {
|
||||
sql := "SELECT * FROM (SELECT id FROM users UNION SELECT id FROM admins) AS all_users WHERE all_users.id > 10"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
t.Logf("完整文本: %s", selectStmt.GetText())
|
||||
|
||||
if selectStmt.Where == nil {
|
||||
t.Fatal("expected WHERE clause")
|
||||
}
|
||||
if selectStmt.Where.Text != "all_users.id > 10" {
|
||||
t.Errorf("expected WHERE text='all_users.id > 10', got '%s'", selectStmt.Where.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgCorrelatedSubquery(t *testing.T) {
|
||||
sql := "SELECT u.id, u.name FROM users u WHERE u.id IN (SELECT o.user_id FROM orders o WHERE o.user_id = u.id AND o.amount > 50)"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
t.Logf("完整文本: %s", selectStmt.GetText())
|
||||
t.Logf("WHERE: %s", selectStmt.Where.Text)
|
||||
|
||||
if selectStmt.Where == nil {
|
||||
t.Fatal("expected WHERE clause")
|
||||
}
|
||||
expectedWhere := "u.id IN (SELECT o.user_id FROM orders o WHERE o.user_id = u.id AND o.amount > 50)"
|
||||
if selectStmt.Where.Text != expectedWhere {
|
||||
t.Errorf("expected WHERE text='%s', got '%s'", expectedWhere, selectStmt.Where.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgSubqueryWithExists(t *testing.T) {
|
||||
sql := "SELECT * FROM users WHERE EXISTS (SELECT 1 FROM orders WHERE orders.user_id = users.id)"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
t.Logf("完整文本: %s", selectStmt.GetText())
|
||||
t.Logf("WHERE: %s", selectStmt.Where.Text)
|
||||
|
||||
if selectStmt.Where == nil {
|
||||
t.Fatal("expected WHERE clause")
|
||||
}
|
||||
expectedWhere := "EXISTS (SELECT 1 FROM orders WHERE orders.user_id = users.id)"
|
||||
if selectStmt.Where.Text != expectedWhere {
|
||||
t.Errorf("expected WHERE text='%s', got '%s'", expectedWhere, selectStmt.Where.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgMultipleSubqueries(t *testing.T) {
|
||||
sql := "SELECT * FROM users WHERE id IN (SELECT user_id FROM orders) AND department_id IN (SELECT id FROM departments WHERE name = 'IT')"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
t.Logf("完整文本: %s", selectStmt.GetText())
|
||||
t.Logf("WHERE: %s", selectStmt.Where.Text)
|
||||
|
||||
if selectStmt.Where == nil {
|
||||
t.Fatal("expected WHERE clause")
|
||||
}
|
||||
expectedWhere := "id IN (SELECT user_id FROM orders) AND department_id IN (SELECT id FROM departments WHERE name = 'IT')"
|
||||
if selectStmt.Where.Text != expectedWhere {
|
||||
t.Errorf("expected WHERE text='%s', got '%s'", expectedWhere, selectStmt.Where.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgSubqueryWithLimit(t *testing.T) {
|
||||
sql := "SELECT * FROM (SELECT id, name FROM users ORDER BY created_at DESC LIMIT 10) AS top_users"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
t.Logf("完整文本: %s", selectStmt.GetText())
|
||||
t.Logf("FROM 数量: %d", len(selectStmt.From))
|
||||
|
||||
if len(selectStmt.From) != 1 {
|
||||
t.Fatalf("expected 1 table, got %d", len(selectStmt.From))
|
||||
}
|
||||
|
||||
if selectStmt.From[0].Alias != "top_users" {
|
||||
t.Errorf("expected alias='top_users', got '%s'", selectStmt.From[0].Alias)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgComplexSubquery(t *testing.T) {
|
||||
sql := `SELECT
|
||||
u.id,
|
||||
u.name,
|
||||
(SELECT COUNT(*) FROM orders o WHERE o.user_id = u.id) AS order_count,
|
||||
(SELECT SUM(amount) FROM orders o WHERE o.user_id = u.id AND o.status = 'completed') AS total_amount
|
||||
FROM users u
|
||||
WHERE u.status = 1
|
||||
AND u.id IN (SELECT user_id FROM user_groups WHERE group_id = 5)
|
||||
ORDER BY u.name`
|
||||
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
t.Logf("完整文本: %s", selectStmt.GetText())
|
||||
t.Logf("SELECT 项数量: %d", len(selectStmt.Items))
|
||||
|
||||
if len(selectStmt.Items) != 4 {
|
||||
t.Fatalf("expected 4 items, got %d", len(selectStmt.Items))
|
||||
}
|
||||
|
||||
t.Logf("Item[0]: %s", selectStmt.Items[0].Text)
|
||||
t.Logf("Item[1]: %s", selectStmt.Items[1].Text)
|
||||
t.Logf("Item[2]: %s (alias: %s)", selectStmt.Items[2].Text, selectStmt.Items[2].Alias)
|
||||
t.Logf("Item[3]: %s (alias: %s)", selectStmt.Items[3].Text, selectStmt.Items[3].Alias)
|
||||
|
||||
if selectStmt.Where == nil {
|
||||
t.Fatal("expected WHERE clause")
|
||||
}
|
||||
t.Logf("WHERE: %s", selectStmt.Where.Text)
|
||||
|
||||
if len(selectStmt.OrderBy) != 1 {
|
||||
t.Fatalf("expected 1 order by, got %d", len(selectStmt.OrderBy))
|
||||
}
|
||||
if selectStmt.OrderBy[0].Text != "u.name" {
|
||||
t.Errorf("expected order by text='u.name', got '%s'", selectStmt.OrderBy[0].Text)
|
||||
}
|
||||
}
|
||||
|
||||
// ========== JOIN 测试 ==========
|
||||
|
||||
func TestPgMultipleJoins(t *testing.T) {
|
||||
sql := "SELECT u.id, o.amount, p.name FROM users u LEFT JOIN orders o ON u.id = o.user_id INNER JOIN products p ON o.product_id = p.id WHERE u.status = 1"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
t.Logf("完整文本: %s", selectStmt.GetText())
|
||||
t.Logf("JOIN 数量: %d", len(selectStmt.Joins))
|
||||
|
||||
if len(selectStmt.Joins) != 2 {
|
||||
t.Fatalf("expected 2 joins, got %d", len(selectStmt.Joins))
|
||||
}
|
||||
|
||||
if selectStmt.Joins[0].Table.Name != "orders" {
|
||||
t.Errorf("expected first join table='orders', got '%s'", selectStmt.Joins[0].Table.Name)
|
||||
}
|
||||
if selectStmt.Joins[1].Table.Name != "products" {
|
||||
t.Errorf("expected second join table='products', got '%s'", selectStmt.Joins[1].Table.Name)
|
||||
}
|
||||
|
||||
if selectStmt.Where == nil {
|
||||
t.Fatal("expected WHERE clause")
|
||||
}
|
||||
if selectStmt.Where.Text != "u.status = 1" {
|
||||
t.Errorf("expected WHERE text='u.status = 1', got '%s'", selectStmt.Where.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgRightJoin(t *testing.T) {
|
||||
sql := "SELECT * FROM users u RIGHT JOIN orders o ON u.id = o.user_id"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if len(selectStmt.Joins) != 1 {
|
||||
t.Fatalf("expected 1 join, got %d", len(selectStmt.Joins))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgCrossJoin(t *testing.T) {
|
||||
sql := "SELECT * FROM users CROSS JOIN roles"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if len(selectStmt.Joins) != 1 {
|
||||
t.Fatalf("expected 1 join, got %d", len(selectStmt.Joins))
|
||||
}
|
||||
}
|
||||
|
||||
// ========== UNION 测试 ==========
|
||||
|
||||
func TestPgUnionWithOrderBy(t *testing.T) {
|
||||
sql := "SELECT id FROM users UNION SELECT id FROM admins ORDER BY 1 DESC"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
t.Logf("完整文本: %s", selectStmt.GetText())
|
||||
t.Logf("UNION 数量: %d", len(selectStmt.Unions))
|
||||
|
||||
if len(selectStmt.Unions) != 1 {
|
||||
t.Fatalf("expected 1 union, got %d", len(selectStmt.Unions))
|
||||
}
|
||||
|
||||
// ORDER BY 应该属于整个 UNION
|
||||
if len(selectStmt.OrderBy) != 1 {
|
||||
t.Fatalf("expected 1 order by, got %d", len(selectStmt.OrderBy))
|
||||
}
|
||||
if selectStmt.OrderBy[0].Text != "1" {
|
||||
t.Errorf("expected order by text='1', got '%s'", selectStmt.OrderBy[0].Text)
|
||||
}
|
||||
if !selectStmt.OrderBy[0].Desc {
|
||||
t.Error("expected order by DESC")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPgUnionWithLimit(t *testing.T) {
|
||||
sql := "SELECT id FROM users UNION ALL SELECT id FROM admins LIMIT 20"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
t.Logf("完整文本: %s", selectStmt.GetText())
|
||||
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT")
|
||||
}
|
||||
if selectStmt.Limit.Text != "LIMIT 20" {
|
||||
t.Errorf("expected LIMIT text='LIMIT 20', got '%s'", selectStmt.Limit.Text)
|
||||
}
|
||||
if selectStmt.Limit.Count != 20 {
|
||||
t.Errorf("expected LIMIT count=20, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
@@ -1,60 +1,17 @@
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"mayfly-go/internal/db/dbm/sqlparser/base"
|
||||
pgparser "mayfly-go/internal/db/dbm/sqlparser/pgsql/antlr4"
|
||||
"mayfly-go/pkg/logx"
|
||||
|
||||
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
|
||||
|
||||
"github.com/antlr4-go/antlr/v4"
|
||||
"mayfly-go/pkg/gox"
|
||||
"mayfly-go/pkg/logx"
|
||||
)
|
||||
|
||||
func GetPgsqlParserTree(baseLine int, statement string) (antlr.ParseTree, *antlr.CommonTokenStream, error) {
|
||||
lexer := pgparser.NewPostgreSQLLexer(antlr.NewInputStream(statement))
|
||||
stream := antlr.NewCommonTokenStream(lexer, antlr.TokenDefaultChannel)
|
||||
parser := pgparser.NewPostgreSQLParser(stream)
|
||||
|
||||
lexerErrorListener := &base.ParseErrorListener{
|
||||
BaseLine: baseLine,
|
||||
}
|
||||
lexer.RemoveErrorListeners()
|
||||
lexer.AddErrorListener(lexerErrorListener)
|
||||
|
||||
parserErrorListener := &base.ParseErrorListener{
|
||||
BaseLine: baseLine,
|
||||
}
|
||||
parser.RemoveErrorListeners()
|
||||
parser.AddErrorListener(parserErrorListener)
|
||||
parser.BuildParseTrees = true
|
||||
|
||||
tree := parser.Root()
|
||||
|
||||
if lexerErrorListener.Err != nil {
|
||||
return nil, nil, lexerErrorListener.Err
|
||||
}
|
||||
|
||||
if parserErrorListener.Err != nil {
|
||||
return nil, nil, parserErrorListener.Err
|
||||
}
|
||||
|
||||
return tree, stream, nil
|
||||
}
|
||||
|
||||
type PgsqlParser struct {
|
||||
}
|
||||
|
||||
func (*PgsqlParser) Parse(stmt string) (stmts []sqlstmt.Stmt, err error) {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
logx.ErrorTrace("postgres sql parser err: ", e)
|
||||
err = e.(error)
|
||||
}
|
||||
}()
|
||||
tree, _, err := GetPgsqlParserTree(1, stmt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tree.Accept(new(PgsqlVisitor)).([]sqlstmt.Stmt), nil
|
||||
func (*PgsqlParser) Parse(stmt string) (sqlstmt.Stmt, error) {
|
||||
defer gox.Recover(func(e error) {
|
||||
logx.ErrorTrace("postgres sql parser err: ", e)
|
||||
})
|
||||
return NewParser(stmt).Parse()
|
||||
}
|
||||
|
||||
@@ -1,83 +0,0 @@
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParserSimpleSelect(t *testing.T) {
|
||||
parser := new(PgsqlParser)
|
||||
|
||||
sql := `SELECT t.*,t.id as tid FROM mayfly.sys_login_log as t where t.id > 0 OFFSET 0 LIMIT 25; select * from tdb where id > 1`
|
||||
stmts, err := parser.Parse(sql)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
stmt := stmts[0].(*sqlstmt.SimpleSelectStmt)
|
||||
|
||||
t.Log(stmt.QuerySpecification.Where.GetText())
|
||||
fmt.Println(stmt.QuerySpecification.From.GetText())
|
||||
t.Log(stmts)
|
||||
}
|
||||
|
||||
func TestParserUnionSelect(t *testing.T) {
|
||||
parser := new(PgsqlParser)
|
||||
|
||||
sql := `(select sum(t.age), t.id tid, t1.id2, t1.* from T_DB t join t_db_ins as t1 on t.id = t1.id2 where t.id = 1 AND t1.status=0 and t.id2='9' and t.name in ('name2', 'name3') order by t.id desc) union all (select * from t_db2) OFFSET 0 LIMIT 25;`
|
||||
stmts, err := parser.Parse(sql)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(stmts)
|
||||
}
|
||||
|
||||
func TestParserSingleUpdate(t *testing.T) {
|
||||
parser := new(PgsqlParser)
|
||||
|
||||
sql := `UPDATE test.t_sys_msg t
|
||||
SET
|
||||
recipient_id = 13,
|
||||
t.creator = 'admin4'
|
||||
WHERE
|
||||
t.id = 1;`
|
||||
stmts, err := parser.Parse(sql)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(stmts)
|
||||
}
|
||||
|
||||
func TestParserDelete(t *testing.T) {
|
||||
parser := new(PgsqlParser)
|
||||
|
||||
sql := `Delete from t_sys_msg t
|
||||
WHERE
|
||||
t.id = 1;`
|
||||
stmts, err := parser.Parse(sql)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(stmts)
|
||||
}
|
||||
|
||||
func TestParserInsert(t *testing.T) {
|
||||
parser := new(PgsqlParser)
|
||||
|
||||
sql := `INSERT INTO
|
||||
mayfly_go.t_sys_msg (
|
||||
type,
|
||||
msg,
|
||||
recipient_id,
|
||||
creator_id,
|
||||
create_time,
|
||||
is_deleted
|
||||
)
|
||||
VALUES
|
||||
(1, 'hahaha', 2, 1, '2024-08-26 15:36:27', 0);`
|
||||
stmts, err := parser.Parse(sql)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(stmts)
|
||||
}
|
||||
@@ -1,493 +0,0 @@
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
pgparser "mayfly-go/internal/db/dbm/sqlparser/pgsql/antlr4"
|
||||
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
|
||||
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
type PgsqlVisitor struct {
|
||||
*pgparser.BasePostgreSQLParserVisitor
|
||||
}
|
||||
|
||||
func (v *PgsqlVisitor) VisitRoot(ctx *pgparser.RootContext) interface{} {
|
||||
if sbc := ctx.Stmtblock(); sbc != nil {
|
||||
return sbc.Accept(v)
|
||||
}
|
||||
return sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
}
|
||||
|
||||
func (v *PgsqlVisitor) VisitPlsqlroot(ctx *pgparser.PlsqlrootContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *PgsqlVisitor) VisitStmtblock(ctx *pgparser.StmtblockContext) interface{} {
|
||||
if smc := ctx.Stmtmulti(); smc != nil {
|
||||
return smc.Accept(v)
|
||||
}
|
||||
return sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
}
|
||||
|
||||
func (v *PgsqlVisitor) VisitStmtmulti(ctx *pgparser.StmtmultiContext) interface{} {
|
||||
allSqlStatement := ctx.AllStmt()
|
||||
stmts := make([]sqlstmt.Stmt, 0)
|
||||
for _, sqlStatement := range allSqlStatement {
|
||||
stmts = append(stmts, sqlStatement.Accept(v).(sqlstmt.Stmt))
|
||||
}
|
||||
return stmts
|
||||
}
|
||||
|
||||
func (v *PgsqlVisitor) VisitStmt(ctx *pgparser.StmtContext) interface{} {
|
||||
if selectstmtCtx := ctx.Selectstmt(); selectstmtCtx != nil {
|
||||
return selectstmtCtx.Accept(v)
|
||||
}
|
||||
if updatestmtCtx := ctx.Updatestmt(); updatestmtCtx != nil {
|
||||
return updatestmtCtx.Accept(v)
|
||||
}
|
||||
if deletestmtCtx := ctx.Deletestmt(); deletestmtCtx != nil {
|
||||
return deletestmtCtx.Accept(v)
|
||||
}
|
||||
if insertstmtC := ctx.Insertstmt(); insertstmtC != nil {
|
||||
return insertstmtC.Accept(v)
|
||||
}
|
||||
if c := ctx.Createdbstmt(); c != nil {
|
||||
cds := new(sqlstmt.CreateDatabase)
|
||||
cds.Node = sqlstmt.NewNode(c.GetParser(), c)
|
||||
return cds
|
||||
}
|
||||
if c := ctx.Createtablespacestmt(); c != nil {
|
||||
cds := new(sqlstmt.CreateTable)
|
||||
cds.Node = sqlstmt.NewNode(c.GetParser(), c)
|
||||
return cds
|
||||
}
|
||||
if c := ctx.Altertablestmt(); c != nil {
|
||||
cds := new(sqlstmt.AlterTable)
|
||||
cds.Node = sqlstmt.NewNode(c.GetParser(), c)
|
||||
return cds
|
||||
}
|
||||
if c := ctx.Dropdbstmt(); c != nil {
|
||||
cds := new(sqlstmt.DropDatabase)
|
||||
cds.Node = sqlstmt.NewNode(c.GetParser(), c)
|
||||
return cds
|
||||
}
|
||||
if c := ctx.Droptablespacestmt(); c != nil {
|
||||
cds := new(sqlstmt.DropTable)
|
||||
cds.Node = sqlstmt.NewNode(c.GetParser(), c)
|
||||
return cds
|
||||
}
|
||||
if explain := ctx.Explainstmt(); explain != nil {
|
||||
otherRead := new(sqlstmt.OtherReadStmt)
|
||||
otherRead.Node = sqlstmt.NewNode(explain.GetParser(), explain)
|
||||
return otherRead
|
||||
}
|
||||
if c := ctx.Variableshowstmt(); c != nil {
|
||||
otherRead := new(sqlstmt.OtherReadStmt)
|
||||
otherRead.Node = sqlstmt.NewNode(c.GetParser(), c)
|
||||
return otherRead
|
||||
}
|
||||
return sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
}
|
||||
|
||||
func (v *PgsqlVisitor) VisitSelectstmt(ctx *pgparser.SelectstmtContext) interface{} {
|
||||
if spnc := ctx.Select_no_parens(); spnc != nil {
|
||||
return spnc.Accept(v)
|
||||
}
|
||||
selectstmt := new(sqlstmt.SelectStmt)
|
||||
selectstmt.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return selectstmt
|
||||
}
|
||||
|
||||
func (v *PgsqlVisitor) VisitSelect_with_parens(ctx *pgparser.Select_with_parensContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *PgsqlVisitor) VisitSelect_no_parens(ctx *pgparser.Select_no_parensContext) interface{} {
|
||||
if c := ctx.Select_clause(); c == nil {
|
||||
return sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
}
|
||||
|
||||
var limit *sqlstmt.Limit
|
||||
if limitC := ctx.Select_limit(); limitC != nil {
|
||||
limit = limitC.Accept(v).(*sqlstmt.Limit)
|
||||
}
|
||||
if limitC := ctx.Opt_select_limit(); limitC != nil {
|
||||
limit = limitC.Accept(v).(*sqlstmt.Limit)
|
||||
}
|
||||
|
||||
selectClause := ctx.Select_clause()
|
||||
asis := selectClause.AllSimple_select_intersect()
|
||||
// 简单查询
|
||||
if len(asis) == 1 {
|
||||
sss := new(sqlstmt.SimpleSelectStmt)
|
||||
sss.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
sss.QuerySpecification = ctx.Select_clause().Accept(v).([]*sqlstmt.QuerySpecification)[0]
|
||||
sss.QuerySpecification.Limit = limit
|
||||
return sss
|
||||
}
|
||||
|
||||
uss := new(sqlstmt.UnionSelectStmt)
|
||||
uss.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
|
||||
allUnion := selectClause.AllUNION()
|
||||
// todo 赋值union信息
|
||||
for _, union := range allUnion {
|
||||
uss.UnionType = union.GetText()
|
||||
}
|
||||
// uss.QuerySpecifications = ctx.Select_clause().Accept(v).([]*sqlstmt.QuerySpecification)
|
||||
uss.Limit = limit
|
||||
return uss
|
||||
}
|
||||
|
||||
func (v *PgsqlVisitor) VisitSelect_clause(ctx *pgparser.Select_clauseContext) interface{} {
|
||||
qs := make([]*sqlstmt.QuerySpecification, 0)
|
||||
for _, ssi := range ctx.AllSimple_select_intersect() {
|
||||
qs = append(qs, ssi.Accept(v).(*sqlstmt.QuerySpecification))
|
||||
}
|
||||
|
||||
return qs
|
||||
}
|
||||
|
||||
func (v *PgsqlVisitor) VisitSimple_select_intersect(ctx *pgparser.Simple_select_intersectContext) interface{} {
|
||||
// 只返回一个查询,INTERSECT(交集)暂不支持
|
||||
if spsc := ctx.AllSimple_select_pramary(); spsc != nil {
|
||||
return spsc[0].Accept(v)
|
||||
}
|
||||
return sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
}
|
||||
|
||||
func (v *PgsqlVisitor) VisitSimple_select_pramary(ctx *pgparser.Simple_select_pramaryContext) interface{} {
|
||||
qs := new(sqlstmt.QuerySpecification)
|
||||
qs.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
|
||||
if c := ctx.From_clause(); c != nil {
|
||||
qs.From = c.Accept(v).(*sqlstmt.TableSources)
|
||||
}
|
||||
|
||||
if c := ctx.Opt_target_list(); c != nil {
|
||||
qs.SelectElements = c.Accept(v).(*sqlstmt.SelectElements)
|
||||
}
|
||||
|
||||
if c := ctx.Target_list(); c != nil {
|
||||
qs.SelectElements = c.Accept(v).(*sqlstmt.SelectElements)
|
||||
}
|
||||
|
||||
if c := ctx.Where_clause(); c != nil && c.A_expr() != nil {
|
||||
qs.Where = c.A_expr().Accept(v).(sqlstmt.IExpr)
|
||||
}
|
||||
|
||||
return qs
|
||||
}
|
||||
|
||||
func (v *PgsqlVisitor) VisitSelect_limit(ctx *pgparser.Select_limitContext) interface{} {
|
||||
limit := new(sqlstmt.Limit)
|
||||
limit.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
|
||||
if lc := ctx.Limit_clause(); lc != nil {
|
||||
if lv := lc.Select_limit_value(); lv != nil {
|
||||
limit.RowCount = cast.ToInt(lv.GetText())
|
||||
}
|
||||
}
|
||||
if oc := ctx.Offset_clause(); oc != nil {
|
||||
if ov := oc.Select_offset_value(); ov != nil {
|
||||
limit.Offset = cast.ToInt(ov.GetText())
|
||||
}
|
||||
}
|
||||
return limit
|
||||
}
|
||||
|
||||
func (v *PgsqlVisitor) VisitOpt_select_limit(ctx *pgparser.Opt_select_limitContext) interface{} {
|
||||
if slc := ctx.Select_limit(); slc != nil {
|
||||
return slc.Accept(v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *PgsqlVisitor) VisitFrom_clause(ctx *pgparser.From_clauseContext) interface{} {
|
||||
if c := ctx.From_list(); c != nil {
|
||||
return c.Accept(v)
|
||||
}
|
||||
return sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
}
|
||||
|
||||
func (v *PgsqlVisitor) VisitFrom_list(ctx *pgparser.From_listContext) interface{} {
|
||||
ts := new(sqlstmt.TableSources)
|
||||
ts.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
|
||||
// ts.StartIndex = ctx.GetStart().GetStart()
|
||||
// ts.StopIndex = ctx.GetStop().GetStop()
|
||||
|
||||
tableSources := make([]sqlstmt.ITableSource, 0)
|
||||
allTableRefCtx := ctx.AllTable_ref()
|
||||
for _, trc := range allTableRefCtx {
|
||||
tableSources = append(tableSources, trc.Accept(v).(sqlstmt.ITableSource))
|
||||
}
|
||||
|
||||
ts.TableSources = tableSources
|
||||
return ts
|
||||
}
|
||||
|
||||
func (v *PgsqlVisitor) VisitNon_ansi_join(ctx *pgparser.Non_ansi_joinContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *PgsqlVisitor) VisitTable_ref(ctx *pgparser.Table_refContext) interface{} {
|
||||
tableSourceBase := new(sqlstmt.TableSourceBase)
|
||||
tableSourceBase.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
|
||||
atomTable := new(sqlstmt.AtomTableItem)
|
||||
|
||||
if c := ctx.Relation_expr(); c != nil {
|
||||
tableName := new(sqlstmt.TableName)
|
||||
if qn := c.Qualified_name(); qn != nil {
|
||||
if qc := qn.Colid(); qc != nil {
|
||||
if c := qn.Indirection(); c != nil {
|
||||
tableName.Owner = qc.GetText()
|
||||
tableName.Identifier = sqlstmt.NewIdentifierValue(c.GetText())
|
||||
} else {
|
||||
tableName.Identifier = sqlstmt.NewIdentifierValue(qc.Identifier().GetText())
|
||||
}
|
||||
}
|
||||
}
|
||||
atomTable.TableName = tableName
|
||||
}
|
||||
if c := ctx.Opt_alias_clause(); c != nil {
|
||||
if aliasC := c.Table_alias_clause(); aliasC != nil {
|
||||
atomTable.Alias = aliasC.Table_alias().GetText()
|
||||
}
|
||||
}
|
||||
|
||||
tableSourceBase.TableSourceItem = atomTable
|
||||
return tableSourceBase
|
||||
}
|
||||
|
||||
func (v *PgsqlVisitor) VisitOpt_target_list(ctx *pgparser.Opt_target_listContext) interface{} {
|
||||
if c := ctx.Target_list(); c != nil {
|
||||
return c.Accept(v)
|
||||
}
|
||||
|
||||
ses := new(sqlstmt.SelectElements)
|
||||
ses.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ses
|
||||
}
|
||||
|
||||
func (v *PgsqlVisitor) VisitTarget_list(ctx *pgparser.Target_listContext) interface{} {
|
||||
ses := new(sqlstmt.SelectElements)
|
||||
ses.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
|
||||
if tecs := ctx.AllTarget_el(); tecs != nil {
|
||||
eles := make([]sqlstmt.ISelectElement, 0)
|
||||
for _, tec := range tecs {
|
||||
eles = append(eles, tec.Accept(v).(sqlstmt.ISelectElement))
|
||||
}
|
||||
ses.Elements = eles
|
||||
}
|
||||
if len(ses.Elements) == 1 && ses.Elements[0].GetText() == "*" {
|
||||
ses.Star = "*"
|
||||
}
|
||||
|
||||
return ses
|
||||
}
|
||||
|
||||
// Visit a parse tree produced by PostgreSQLParser#target_label.
|
||||
func (v *PgsqlVisitor) VisitTarget_label(ctx *pgparser.Target_labelContext) interface{} {
|
||||
sce := new(sqlstmt.SelectColumnElement)
|
||||
sce.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
|
||||
columnName := new(sqlstmt.ColumnName)
|
||||
|
||||
if c := ctx.Collabel(); c != nil {
|
||||
sce.Alias = c.GetText()
|
||||
}
|
||||
if c := ctx.Identifier(); c != nil {
|
||||
sce.Alias = c.GetText()
|
||||
}
|
||||
if exprCtx := ctx.A_expr(); exprCtx != nil {
|
||||
columnName.Node = sqlstmt.NewNode(ctx.GetParser(), exprCtx)
|
||||
if aextrCtx := exprCtx.A_expr_qual(); aextrCtx != nil {
|
||||
col := aextrCtx.GetText()
|
||||
ownerAndColname := strings.Split(col, ".")
|
||||
if len(ownerAndColname) == 2 {
|
||||
columnName.Owner = ownerAndColname[0]
|
||||
columnName.Identifier = sqlstmt.NewIdentifierValue(ownerAndColname[1])
|
||||
} else {
|
||||
columnName.Identifier = sqlstmt.NewIdentifierValue(col)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
columnName.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
}
|
||||
sce.ColumnName = columnName
|
||||
return sce
|
||||
}
|
||||
|
||||
// Visit a parse tree produced by PostgreSQLParser#target_star.
|
||||
func (v *PgsqlVisitor) VisitTarget_star(ctx *pgparser.Target_starContext) interface{} {
|
||||
sse := new(sqlstmt.SelectStarElement)
|
||||
sse.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
sse.FullId = ctx.STAR().GetText()
|
||||
return sse
|
||||
}
|
||||
|
||||
func (v *PgsqlVisitor) VisitAlias_clause(ctx *pgparser.Alias_clauseContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *PgsqlVisitor) VisitOpt_alias_clause(ctx *pgparser.Opt_alias_clauseContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *PgsqlVisitor) VisitTable_alias_clause(ctx *pgparser.Table_alias_clauseContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *PgsqlVisitor) VisitFunc_alias_clause(ctx *pgparser.Func_alias_clauseContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *PgsqlVisitor) VisitRelation_expr(ctx *pgparser.Relation_exprContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *PgsqlVisitor) VisitRelation_expr_list(ctx *pgparser.Relation_expr_listContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *PgsqlVisitor) VisitUpdatestmt(ctx *pgparser.UpdatestmtContext) interface{} {
|
||||
updateStmt := new(sqlstmt.UpdateStmt)
|
||||
updateStmt.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
|
||||
updateStmt.TableSources = v.GetTableSourcesByrelation_expr_opt_alias(ctx.Relation_expr_opt_alias())
|
||||
updateStmt.UpdatedElements = ctx.Set_clause_list().Accept(v).([]*sqlstmt.UpdatedElement)
|
||||
if ec := ctx.Where_or_current_clause().A_expr(); ec != nil {
|
||||
updateStmt.Where = ec.Accept(v).(sqlstmt.IExpr)
|
||||
}
|
||||
|
||||
return updateStmt
|
||||
}
|
||||
|
||||
func (v *PgsqlVisitor) VisitSet_clause_list(ctx *pgparser.Set_clause_listContext) interface{} {
|
||||
ues := make([]*sqlstmt.UpdatedElement, 0)
|
||||
aucs := ctx.AllSet_clause()
|
||||
for _, auc := range aucs {
|
||||
ues = append(ues, auc.Accept(v).(*sqlstmt.UpdatedElement))
|
||||
}
|
||||
return ues
|
||||
}
|
||||
|
||||
func (v *PgsqlVisitor) VisitSet_clause(ctx *pgparser.Set_clauseContext) interface{} {
|
||||
updateEle := new(sqlstmt.UpdatedElement)
|
||||
updateEle.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
|
||||
updateEle.ColumnName = ctx.Set_target().Accept(v).(*sqlstmt.ColumnName)
|
||||
if ac := ctx.A_expr(); ac != nil {
|
||||
updateEle.Value = ac.Accept(v).(sqlstmt.IExpr)
|
||||
}
|
||||
return updateEle
|
||||
}
|
||||
|
||||
func (v *PgsqlVisitor) VisitSet_target(ctx *pgparser.Set_targetContext) interface{} {
|
||||
columnName := new(sqlstmt.ColumnName)
|
||||
columnName.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
|
||||
if ic := ctx.Opt_indirection(); ic != nil {
|
||||
if ic.GetText() == "" {
|
||||
columnName.Identifier = sqlstmt.NewIdentifierValue(ctx.Colid().GetText())
|
||||
} else {
|
||||
columnName.Owner = ctx.Colid().GetText()
|
||||
columnName.Identifier = sqlstmt.NewIdentifierValue(ic.GetText())
|
||||
}
|
||||
} else {
|
||||
columnName.Identifier = sqlstmt.NewIdentifierValue(ctx.Colid().GetText())
|
||||
}
|
||||
return columnName
|
||||
}
|
||||
|
||||
func (v *PgsqlVisitor) VisitSet_target_list(ctx *pgparser.Set_target_listContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *PgsqlVisitor) VisitA_expr(ctx *pgparser.A_exprContext) interface{} {
|
||||
expr := new(sqlstmt.Expr)
|
||||
expr.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return expr
|
||||
}
|
||||
|
||||
func (v *PgsqlVisitor) VisitA_expr_qual(ctx *pgparser.A_expr_qualContext) interface{} {
|
||||
expr := new(sqlstmt.Expr)
|
||||
expr.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return expr
|
||||
}
|
||||
|
||||
func (v *PgsqlVisitor) VisitDeletestmt(ctx *pgparser.DeletestmtContext) interface{} {
|
||||
deletestmt := new(sqlstmt.DeleteStmt)
|
||||
deletestmt.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
|
||||
deletestmt.TableSources = v.GetTableSourcesByrelation_expr_opt_alias(ctx.Relation_expr_opt_alias())
|
||||
|
||||
if ec := ctx.Where_or_current_clause().A_expr(); ec != nil {
|
||||
deletestmt.Where = ec.Accept(v).(sqlstmt.IExpr)
|
||||
}
|
||||
return deletestmt
|
||||
}
|
||||
|
||||
func (v *PgsqlVisitor) VisitInsertstmt(ctx *pgparser.InsertstmtContext) interface{} {
|
||||
insertstmt := new(sqlstmt.InsertStmt)
|
||||
insertstmt.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
insertstmt.TableName = ctx.Insert_target().Accept(v).(*sqlstmt.TableName)
|
||||
return insertstmt
|
||||
}
|
||||
|
||||
func (v *PgsqlVisitor) VisitInsert_target(ctx *pgparser.Insert_targetContext) interface{} {
|
||||
tableName := new(sqlstmt.TableName)
|
||||
tableName.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
table := ctx.GetText()
|
||||
if strings.Contains(table, ".") {
|
||||
tableAndOwner := strings.Split(table, ".")
|
||||
tableName.Identifier = sqlstmt.NewIdentifierValue(tableAndOwner[1])
|
||||
tableName.Owner = tableAndOwner[0]
|
||||
} else {
|
||||
tableName.Identifier = sqlstmt.NewIdentifierValue(table)
|
||||
}
|
||||
return tableName
|
||||
}
|
||||
|
||||
func (v *PgsqlVisitor) VisitInsert_rest(ctx *pgparser.Insert_restContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *PgsqlVisitor) GetTableSourcesByrelation_expr_opt_alias(ctx pgparser.IRelation_expr_opt_aliasContext) *sqlstmt.TableSources {
|
||||
tableSources := new(sqlstmt.TableSources)
|
||||
tableSources.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
|
||||
atomTable := new(sqlstmt.AtomTableItem)
|
||||
atomTable.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
|
||||
if c := ctx.Relation_expr(); c != nil {
|
||||
tableName := new(sqlstmt.TableName)
|
||||
if qn := c.Qualified_name(); qn != nil {
|
||||
if qc := qn.Colid(); qc != nil {
|
||||
if c := qn.Indirection(); c != nil {
|
||||
tableName.Owner = qc.GetText()
|
||||
tableName.Identifier = sqlstmt.NewIdentifierValue(c.GetText())
|
||||
} else {
|
||||
tableName.Identifier = sqlstmt.NewIdentifierValue(qc.Identifier().GetText())
|
||||
}
|
||||
}
|
||||
}
|
||||
atomTable.TableName = tableName
|
||||
}
|
||||
|
||||
if c := ctx.Colid(); c != nil {
|
||||
atomTable.Alias = c.GetText()
|
||||
}
|
||||
|
||||
tableSourceBase := new(sqlstmt.TableSourceBase)
|
||||
tableSourceBase.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
tableSourceBase.TableSourceItem = atomTable
|
||||
|
||||
tableSources.TableSources = []sqlstmt.ITableSource{tableSourceBase}
|
||||
return tableSources
|
||||
}
|
||||
Reference in New Issue
Block a user