diff --git a/mayfly_go_web/src/common/sockets.ts b/mayfly_go_web/src/common/sockets.ts index cb678605..db61ca91 100644 --- a/mayfly_go_web/src/common/sockets.ts +++ b/mayfly_go_web/src/common/sockets.ts @@ -1,7 +1,10 @@ import Config from './config'; -import { ElNotification } from 'element-plus'; +import { ElNotification, NotificationHandle } from 'element-plus'; import SocketBuilder from './SocketBuilder'; import { getToken } from '@/common/utils/storage'; +import { createVNode, reactive } from "vue"; +import { buildProgressProps } from "@/components/progress-notify/progress-notify"; +import ProgressNotify from '/src/components/progress-notify/progress-notify.vue'; export default { /** @@ -12,32 +15,63 @@ export default { if (!token) { return null; } + const messageTypes = { + 0: "error", + 1: "success", + 2: "info", + } + const notifyMap: Map = new Map() + return SocketBuilder.builder(`${Config.baseWsUrl}/sysmsg?token=${token}`) .message((event: { data: string }) => { const message = JSON.parse(event.data); - let mtype: string; - switch (message.type) { - case 0: - mtype = 'error'; - break; - case 2: - mtype = 'info'; - break; - case 1: - mtype = 'success'; + const type = messageTypes[message.type] + switch (message.category) { + case "execSqlFileProgress": + const content = JSON.parse(message.msg) + const id = content.id + let progress = notifyMap.get(id) + if (content.terminated) { + if (progress != undefined) { + progress.notification?.close() + notifyMap.delete(id) + progress = undefined + } + return + } + if (progress == undefined) { + progress = { + props: reactive(buildProgressProps()), + notification: undefined, + } + } + progress.props.progress.sqlFileName = content.sqlFileName + progress.props.progress.executedStatements = content.executedStatements + if (!notifyMap.has(id)) { + const vNodeMessage = createVNode( + ProgressNotify, + progress.props, + null, + ) + progress.notification = ElNotification({ + duration: 0, + title: message.title, + message: vNodeMessage, + type: type, + showClose: false, + }); + notifyMap.set(id, progress) + } break; default: - mtype = 'info'; + ElNotification({ + duration: 0, + title: message.title, + message: message.msg, + type: type, + }); + break; } - if (mtype == undefined) { - return; - } - ElNotification({ - duration: 0, - title: message.title, - message: message.msg, - type: mtype as any, - }); }) .open((event: any) => console.log(event)) .build(); diff --git a/mayfly_go_web/src/components/progress-notify/progress-notify.ts b/mayfly_go_web/src/components/progress-notify/progress-notify.ts new file mode 100644 index 00000000..396536cb --- /dev/null +++ b/mayfly_go_web/src/components/progress-notify/progress-notify.ts @@ -0,0 +1,14 @@ +export const buildProgressProps = () : any => { + return { + progress: { + sqlFileName: { + type: String + }, + executedStatements: { + type: Number + }, + }, + }; +} + + diff --git a/mayfly_go_web/src/components/progress-notify/progress-notify.vue b/mayfly_go_web/src/components/progress-notify/progress-notify.vue new file mode 100644 index 00000000..62d40f94 --- /dev/null +++ b/mayfly_go_web/src/components/progress-notify/progress-notify.vue @@ -0,0 +1,41 @@ + + \ No newline at end of file diff --git a/mayfly_go_web/src/router/index.ts b/mayfly_go_web/src/router/index.ts index 01a80fa6..74f70f50 100644 --- a/mayfly_go_web/src/router/index.ts +++ b/mayfly_go_web/src/router/index.ts @@ -206,7 +206,7 @@ router.beforeEach(async (to, from, next) => { if (SysWs) { SysWs.close(); - SysWs = null; + SysWs = undefined; } return; } diff --git a/server/go.mod b/server/go.mod index 348292e5..5ea46c5b 100644 --- a/server/go.mod +++ b/server/go.mod @@ -13,6 +13,7 @@ require ( github.com/go-sql-driver/mysql v1.7.1 github.com/golang-jwt/jwt/v5 v5.0.0 github.com/gorilla/websocket v1.5.0 + github.com/kanzihuang/vitess/go/vt/sqlparser v0.0.0-20231007020222-b91ee5ef3b31 github.com/lib/pq v1.10.9 github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230712084735-068dc2aee82d github.com/mojocn/base64Captcha v1.3.5 // 验证码 @@ -21,45 +22,52 @@ require ( github.com/pquerna/otp v1.4.0 github.com/redis/go-redis/v9 v9.2.1 github.com/robfig/cron/v3 v3.0.1 // 定时任务 - github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2 + github.com/stretchr/testify v1.8.4 go.mongodb.org/mongo-driver v1.12.1 // mongo golang.org/x/crypto v0.14.0 // ssh golang.org/x/oauth2 v0.13.0 gopkg.in/yaml.v3 v3.0.1 // gorm gorm.io/driver/mysql v1.5.1 - gorm.io/driver/sqlite v1.5.4 gorm.io/gorm v1.25.4 ) +require ( + github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2 + gorm.io/driver/sqlite v1.5.1 +) + require ( github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect github.com/bytedance/sonic v1.9.1 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/go-asn1-ber/asn1-ber v1.5.4 // indirect github.com/goccy/go-json v0.10.2 // indirect github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect + github.com/golang/glog v1.0.0 // indirect github.com/golang/protobuf v1.5.3 // indirect - github.com/golang/snappy v0.0.1 // indirect + github.com/golang/snappy v0.0.4 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/json-iterator/go v1.1.12 // indirect - github.com/klauspost/compress v1.13.6 // indirect + github.com/klauspost/compress v1.16.5 // indirect github.com/klauspost/cpuid/v2 v2.2.4 // indirect github.com/kr/fs v0.1.0 // indirect - github.com/kr/pretty v0.3.0 // indirect github.com/leodido/go-urn v1.2.4 // indirect github.com/mattn/go-isatty v0.0.19 // indirect - github.com/mattn/go-sqlite3 v1.14.17 // indirect + github.com/mattn/go-sqlite3 v1.14.16 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect - github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe // indirect + github.com/montanaflynn/stats v0.7.0 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.11 // indirect github.com/xdg-go/pbkdf2 v1.0.0 // indirect @@ -67,12 +75,15 @@ require ( github.com/xdg-go/stringprep v1.0.4 // indirect github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect golang.org/x/arch v0.3.0 // indirect + golang.org/x/exp v0.0.0-20230519143937-03e91628a987 // indirect golang.org/x/image v0.0.0-20220302094943-723b81ca9867 // indirect golang.org/x/net v0.16.0 // indirect golang.org/x/sync v0.1.0 // indirect golang.org/x/sys v0.13.0 // indirect golang.org/x/text v0.13.0 // indirect google.golang.org/appengine v1.6.7 // indirect + google.golang.org/genproto v0.0.0-20230131230820-1c016267d619 // indirect + google.golang.org/grpc v1.52.3 // indirect google.golang.org/protobuf v1.31.0 // indirect - gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect + vitess.io/vitess v0.17.3 // indirect ) diff --git a/server/internal/db/api/db.go b/server/internal/db/api/db.go index 4db5aa20..fb13518f 100644 --- a/server/internal/db/api/db.go +++ b/server/internal/db/api/db.go @@ -1,9 +1,12 @@ package api import ( - "bufio" "fmt" + "github.com/lib/pq" "io" + "mayfly-go/pkg/utils/uniqueid" + "mayfly-go/pkg/ws" + "mayfly-go/internal/db/api/form" "mayfly-go/internal/db/api/vo" "mayfly-go/internal/db/application" @@ -16,14 +19,13 @@ import ( "mayfly-go/pkg/gormx" "mayfly-go/pkg/model" "mayfly-go/pkg/req" + "mayfly-go/pkg/sqlparser" "mayfly-go/pkg/utils/stringx" - "regexp" "strconv" "strings" "time" "github.com/gin-gonic/gin" - "github.com/xwb1989/sqlparser" ) type Db struct { @@ -79,9 +81,7 @@ func (d *Db) DeleteDb(rc *req.Ctx) { } func (d *Db) getDbConnection(g *gin.Context) *application.DbConnection { - dbName := g.Query("db") - biz.NotEmpty(dbName, "db不能为空") - return d.DbApp.GetDbConnection(getDbId(g), dbName) + return d.DbApp.GetDbConnection(getDbId(g), getDbName(g)) } func (d *Db) TableInfos(rc *req.Ctx) { @@ -152,67 +152,119 @@ func (d *Db) ExecSql(rc *req.Ctx) { rc.ResData = colAndRes } +// progressCategory sql文件执行进度消息类型 +const progressCategory = "execSqlFileProgress" + +// progressMsg sql文件执行进度消息 +type progressMsg struct { + Id uint64 `json:"id"` + SqlFileName string `json:"sqlFileName"` + ExecutedStatements int `json:"executedStatements"` + Terminated bool `json:"terminated"` +} + // 执行sql文件 func (d *Db) ExecSqlFile(rc *req.Ctx) { g := rc.GinCtx - fileheader, err := g.FormFile("file") + multipart, err := g.Request.MultipartReader() biz.ErrIsNilAppendErr(err, "读取sql文件失败: %s") - - file, _ := fileheader.Open() - filename := fileheader.Filename + file, err := multipart.NextPart() + biz.ErrIsNilAppendErr(err, "读取sql文件失败: %s") + defer file.Close() + filename := file.FileName() dbId := getDbId(g) dbName := getDbName(g) - dbConn := d.getDbConnection(rc.GinCtx) + dbConn := d.DbApp.GetDbConnection(dbId, dbName) biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s") rc.ReqParam = fmt.Sprintf("%s -> filename: %s", dbConn.Info.GetLogDesc(), filename) - logExecRecord := true - // 如果执行sql文件大于该值则不记录sql执行记录 - if fileheader.Size > 50*1024 { - logExecRecord = false + defer func() { + var errInfo string + switch t := recover().(type) { + case biz.BizError: + errInfo = t.Error() + case *biz.BizError: + errInfo = t.Error() + case string: + errInfo = t + } + if len(errInfo) > 0 { + d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo))) + } + }() + + execReq := &application.DbSqlExecReq{ + DbId: dbId, + Db: dbName, + Remark: filename, + DbConn: dbConn, + LoginAccount: rc.LoginAccount, } - go func() { - defer func() { - if err := recover(); err != nil { - var errInfo string - switch t := err.(type) { - case error: - errInfo = t.Error() - } - if len(errInfo) > 0 { - d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo))) - } - } - }() + progressId := uniqueid.IncrementID() + executedStatements := 0 + defer ws.SendJsonMsg(rc.LoginAccount.Id, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{ + Id: progressId, + SqlFileName: filename, + ExecutedStatements: executedStatements, + Terminated: true, + }).WithCategory(progressCategory)) - execReq := &application.DbSqlExecReq{ - DbId: dbId, - Db: dbName, - Remark: fileheader.Filename, - DbConn: dbConn, - LoginAccount: rc.LoginAccount, + var parser sqlparser.Parser + if dbConn.Info.Type == entity.DbTypeMysql { + parser = sqlparser.NewMysqlParser(file) + } else { + parser = sqlparser.NewPostgresParser(file) + } + + ticker := time.NewTicker(time.Second * 1) + defer ticker.Stop() + for { + select { + case <-ticker.C: + ws.SendJsonMsg(rc.LoginAccount.Id, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{ + Id: progressId, + SqlFileName: filename, + ExecutedStatements: executedStatements, + Terminated: false, + }).WithCategory(progressCategory)) + default: } - - sqlScanner := SplitSqls(file) - for sqlScanner.Scan() { - sql := sqlScanner.Text() + err = parser.Next() + if err == io.EOF { + break + } + if err != nil { + d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s] 解析SQL失败: [%s]", filename, dbConn.Info.GetLogDesc(), err.Error()))) + return + } + sql := parser.Current() + const prefixUse = "use " + if strings.HasPrefix(sql, prefixUse) { + dbNameExec := strings.Trim(sql[len(prefixUse):], " `;\n") + if len(dbNameExec) > 0 { + dbConn = d.DbApp.GetDbConnection(dbId, dbNameExec) + biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s") + execReq.DbConn = dbConn + } + } + // 需要记录执行记录 + const maxRecordStatements = 64 + if executedStatements < maxRecordStatements { execReq.Sql = sql - // 需要记录执行记录 - if logExecRecord { - _, err = d.DbSqlExecApp.Exec(execReq) - } else { - _, err = dbConn.Exec(sql) - } - - if err != nil { - d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s] -> sql=[%s] 执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), sql, err.Error()))) - return - } + _, err = d.DbSqlExecApp.Exec(execReq) + } else { + _, err = dbConn.Exec(sql) } - d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.SuccessSysMsg("sql脚本执行成功", fmt.Sprintf("[%s]执行完成 -> %s", filename, dbConn.Info.GetLogDesc()))) - }() + + if err != nil { + d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s] -> sql=[%s] 执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), sql, err.Error()))) + return + } + executedStatements++ + } + d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.SuccessSysMsg("sql脚本执行成功", fmt.Sprintf("[%s]执行完成 -> %s", filename, dbConn.Info.GetLogDesc()))) } // 数据库dump @@ -276,23 +328,44 @@ func (d *Db) DumpSql(rc *req.Ctx) { rc.ReqParam = fmt.Sprintf("DB[id=%d, tag=%s, name=%s, databases=%s, tables=%s, dumpType=%s]", db.Id, db.TagPath, db.Name, dbNamesStr, tablesStr, dumpType) } +func escapeSql(dbType string, sql string) string { + if dbType == entity.DbTypePostgres { + return pq.QuoteLiteral(sql) + } else { + sql = strings.ReplaceAll(sql, `\`, `\\`) + sql = strings.ReplaceAll(sql, `'`, `''`) + return "'" + sql + "'" + } +} + +func quoteTable(dbType string, table string) string { + if dbType == entity.DbTypePostgres { + return "\"" + table + "\"" + } else { + return "`" + table + "`" + } +} + func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []string, needStruct bool, needData bool, switchDb bool) { dbConn := d.DbApp.GetDbConnection(dbId, dbName) - writer.WriteString("-- ----------------------------") + writer.WriteString("\n-- ----------------------------") writer.WriteString("\n-- 导出平台: mayfly-go") writer.WriteString(fmt.Sprintf("\n-- 导出时间: %s ", time.Now().Format("2006-01-02 15:04:05"))) writer.WriteString(fmt.Sprintf("\n-- 导出数据库: %s ", dbName)) writer.WriteString("\n-- ----------------------------\n") - writer.TryFlush() if switchDb { switch dbConn.Info.Type { case entity.DbTypeMysql: - writer.WriteString(fmt.Sprintf("use `%s`;\n", dbName)) + writer.WriteString(fmt.Sprintf("USE `%s`;\n", dbName)) default: - biz.IsTrue(false, "数据库类型必须为 %s", entity.DbTypeMysql) + biz.IsTrue(false, "同时导出多个数据库,数据库类型必须为 %s", entity.DbTypeMysql) } } + if dbConn.Info.Type == entity.DbTypeMysql { + writer.WriteString("\nSET FOREIGN_KEY_CHECKS = 0;\n") + } + dbMeta := dbConn.GetMeta() if len(tables) == 0 { ti := dbMeta.GetTableInfos() @@ -303,23 +376,22 @@ func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []str } for _, table := range tables { + writer.TryFlush() + quotedTable := quoteTable(dbConn.Info.Type, table) if needStruct { writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- 表结构: %s \n-- ----------------------------\n", table)) - writer.WriteString(fmt.Sprintf("DROP TABLE IF EXISTS `%s`;\n", table)) + writer.WriteString(fmt.Sprintf("DROP TABLE IF EXISTS %s;\n", quotedTable)) writer.WriteString(dbMeta.GetCreateTableDdl(table) + ";\n") } - if !needData { continue } - writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- 表记录: %s \n-- ----------------------------\n", table)) writer.WriteString("BEGIN;\n") - - insertSql := "INSERT INTO `%s` VALUES (%s);\n" - + insertSql := "INSERT INTO %s VALUES (%s);\n" dbMeta.WalkTableRecord(table, func(record map[string]any, columns []string) { var values []string + writer.TryFlush() for _, column := range columns { value := record[column] if value == nil { @@ -328,17 +400,18 @@ func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []str } strValue, ok := value.(string) if ok { - values = append(values, fmt.Sprintf("%#v", strValue)) + strValue = escapeSql(dbConn.Info.Type, strValue) + values = append(values, strValue) } else { values = append(values, stringx.AnyToStr(value)) } } - writer.WriteString(fmt.Sprintf(insertSql, table, strings.Join(values, ", "))) - writer.TryFlush() + writer.WriteString(fmt.Sprintf(insertSql, quotedTable, strings.Join(values, ", "))) }) - writer.WriteString("COMMIT;\n") - writer.TryFlush() + } + if dbConn.Info.Type == entity.DbTypeMysql { + writer.WriteString("\nSET FOREIGN_KEY_CHECKS = 1;\n") } } @@ -477,35 +550,3 @@ func getDbName(g *gin.Context) string { biz.NotEmpty(db, "db不能为空") return db } - -// 根据;\n切割sql -func SplitSqls(r io.Reader) *bufio.Scanner { - scanner := bufio.NewScanner(r) - re := regexp.MustCompile(`\s*;\s*\n`) - - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, io.EOF - } - - match := re.FindIndex(data) - - if match != nil { - // 如果找到了";\n",判断是否为最后一行 - if match[1] == len(data) { - // 如果是最后一行,则返回完整的切片 - return len(data), data, nil - } - // 否则,返回到";\n"之后,并且包括";\n"本身 - return match[1], data[:match[1]], nil - } - - if atEOF { - return len(data), data, nil - } - - return 0, nil, nil - }) - - return scanner -} diff --git a/server/internal/db/api/db_test.go b/server/internal/db/api/db_test.go new file mode 100644 index 00000000..ec6b9480 --- /dev/null +++ b/server/internal/db/api/db_test.go @@ -0,0 +1,55 @@ +package api + +import ( + "github.com/stretchr/testify/require" + "mayfly-go/internal/db/domain/entity" + "testing" +) + +func Test_escapeSql(t *testing.T) { + tests := []struct { + name string + dbType string + sql string + want string + }{ + { + dbType: entity.DbTypeMysql, + sql: "\\a\\b", + want: "'\\\\a\\\\b'", + }, + { + dbType: entity.DbTypeMysql, + sql: "'a'", + want: "'''a'''", + }, + { + name: "不间断空格", + dbType: entity.DbTypeMysql, + sql: "a\u00A0b", + want: "'a\u00A0b'", + }, + { + dbType: entity.DbTypePostgres, + sql: "\\a\\b", + want: " E'\\\\a\\\\b'", + }, + { + dbType: entity.DbTypePostgres, + sql: "'a'", + want: "'''a'''", + }, + { + name: "不间断空格", + dbType: entity.DbTypePostgres, + sql: "a\u00A0b", + want: "'a\u00A0b'", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := escapeSql(tt.dbType, tt.sql) + require.Equal(t, tt.want, got) + }) + } +} diff --git a/server/internal/db/application/db_sql_exec.go b/server/internal/db/application/db_sql_exec.go index 3c5ab4d6..1490bb50 100644 --- a/server/internal/db/application/db_sql_exec.go +++ b/server/internal/db/application/db_sql_exec.go @@ -3,6 +3,7 @@ package application import ( "encoding/json" "fmt" + "github.com/kanzihuang/vitess/go/vt/sqlparser" "mayfly-go/internal/db/config" "mayfly-go/internal/db/domain/entity" "mayfly-go/internal/db/domain/repository" @@ -10,8 +11,6 @@ import ( "mayfly-go/pkg/model" "strconv" "strings" - - "github.com/xwb1989/sqlparser" ) type DbSqlExecReq struct { diff --git a/server/internal/msg/application/dto/sys_msg.go b/server/internal/msg/application/dto/sys_msg.go index da94ee61..d328903b 100644 --- a/server/internal/msg/application/dto/sys_msg.go +++ b/server/internal/msg/application/dto/sys_msg.go @@ -11,7 +11,7 @@ const InfoSysMsgType = 2 // websocket消息 type SysMsg struct { Type int `json:"type"` // 消息类型 - Category int `json:"category"` // 消息类别 + Category string `json:"category"` // 消息类别 Title string `json:"title"` // 消息标题 Msg string `json:"msg"` // 消息内容 } @@ -21,7 +21,7 @@ func (sm *SysMsg) WithTitle(title string) *SysMsg { return sm } -func (sm *SysMsg) WithCategory(category int) *SysMsg { +func (sm *SysMsg) WithCategory(category string) *SysMsg { sm.Category = category return sm } @@ -32,7 +32,7 @@ func (sm *SysMsg) WithMsg(msg any) *SysMsg { } // 普通消息 -func NewSysMsg(title string, msg any) *SysMsg { +func InfoSysMsg(title string, msg any) *SysMsg { return &SysMsg{Type: InfoSysMsgType, Title: title, Msg: stringx.AnyToStr(msg)} } diff --git a/server/pkg/sqlparser/sqlparser.go b/server/pkg/sqlparser/sqlparser.go new file mode 100644 index 00000000..1215a5d6 --- /dev/null +++ b/server/pkg/sqlparser/sqlparser.go @@ -0,0 +1,99 @@ +package sqlparser + +import ( + "bufio" + "github.com/kanzihuang/vitess/go/vt/sqlparser" + "io" + "regexp" +) + +type Parser interface { + Next() error + Current() string +} + +var _ Parser = &MysqlParser{} +var _ Parser = &PostgresParser{} + +type MysqlParser struct { + tokenizer *sqlparser.Tokenizer + statement string +} + +func NewMysqlParser(reader io.Reader) *MysqlParser { + return &MysqlParser{ + tokenizer: sqlparser.NewReaderTokenizer(reader), + } +} + +func (parser *MysqlParser) Next() error { + statement, err := sqlparser.ParseNext(parser.tokenizer) + if err != nil { + parser.statement = "" + return err + } + parser.statement = sqlparser.String(statement) + return nil +} + +func (parser *MysqlParser) Current() string { + return parser.statement +} + +type PostgresParser struct { + scanner *bufio.Scanner + statement string +} + +func NewPostgresParser(reader io.Reader) *PostgresParser { + return &PostgresParser{ + scanner: splitSqls(reader), + } +} + +func (parser *PostgresParser) Next() error { + if !parser.scanner.Scan() { + return io.EOF + } + return nil +} + +func (parser *PostgresParser) Current() string { + return parser.scanner.Text() +} + +// 根据;\n切割sql +func splitSqls(r io.Reader) *bufio.Scanner { + scanner := bufio.NewScanner(r) + re := regexp.MustCompile(`\s*;\s*\n`) + + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, io.EOF + } + + match := re.FindIndex(data) + + if match != nil { + // 如果找到了";\n",判断是否为最后一行 + if match[1] == len(data) { + // 如果是最后一行,则返回完整的切片 + return len(data), data, nil + } + // 否则,返回到";\n"之后,并且包括";\n"本身 + return match[1], data[:match[1]], nil + } + + if atEOF { + return len(data), data, nil + } + + return 0, nil, nil + }) + + return scanner +} + +func SplitStatementToPieces(sql string) ([]string, error) { + return sqlparser.SplitStatementToPieces(sql) +} diff --git a/server/pkg/sqlparser/sqlparser_test.go b/server/pkg/sqlparser/sqlparser_test.go new file mode 100644 index 00000000..1f2b65fd --- /dev/null +++ b/server/pkg/sqlparser/sqlparser_test.go @@ -0,0 +1,98 @@ +package sqlparser + +import ( + "github.com/kanzihuang/vitess/go/vt/sqlparser" + "github.com/stretchr/testify/require" + sqlparser_xwb1989 "github.com/xwb1989/sqlparser" + "strings" + "testing" +) + +func Test_ParseNext_WithCurrentDate(t *testing.T) { + tests := []struct { + name string + input string + want string + wantXwb1989 string + err string + }{ + { + name: "create table with current_timestamp", + input: "create table tbl (\n\tcreate_at datetime default current_timestamp()\n)", + // xwb1989/sqlparser 不支持 current_timestamp() + wantXwb1989: "create table tbl", + }, + { + name: "create table with current_date", + input: "create table tbl (\n\tcreate_at date default current_date()\n)", + // xwb1989/sqlparser 不支持 current_date() + wantXwb1989: "create table tbl", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + token := sqlparser.NewReaderTokenizer(strings.NewReader(test.input)) + tree, err := sqlparser.ParseNext(token) + if len(test.err) > 0 { + require.Error(t, err) + require.Contains(t, err.Error(), test.err) + return + } + require.NoError(t, err) + if len(test.want) == 0 { + test.want = test.input + } + require.Equal(t, test.want, sqlparser.String(tree)) + }) + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + token := sqlparser_xwb1989.NewTokenizer(strings.NewReader(test.input)) + tree, err := sqlparser_xwb1989.ParseNext(token) + if len(test.err) > 0 { + require.Error(t, err) + require.Contains(t, err.Error(), test.err) + return + } + require.NoError(t, err) + if len(test.want) == 0 { + test.want = test.input + } + require.Equal(t, test.wantXwb1989, sqlparser_xwb1989.String(tree)) + }) + } +} + +func Test_SplitSqls(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + { + name: "create table with current_timestamp", + input: "create table tbl (\n\tcreate_at datetime default current_timestamp()\n)", + }, + { + name: "create table with current_date", + input: "create table tbl (\n\tcreate_at date default current_date()\n)", + }, + { + name: "select with ';\n'", + input: "select 'the first line;\nthe second line;\n'", + // SplitSqls split statements by ';\n' + want: "select 'the first line;\n", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + scanner := splitSqls(strings.NewReader(test.input)) + require.True(t, scanner.Scan()) + got := scanner.Text() + if len(test.want) == 0 { + test.want = test.input + } + require.Equal(t, test.want, got) + }) + } +} diff --git a/server/pkg/utils/uniqueid/uniqueid.go b/server/pkg/utils/uniqueid/uniqueid.go new file mode 100644 index 00000000..777a429b --- /dev/null +++ b/server/pkg/utils/uniqueid/uniqueid.go @@ -0,0 +1,9 @@ +package uniqueid + +import "sync/atomic" + +var id uint64 = 0 + +func IncrementID() uint64 { + return atomic.AddUint64(&id, 1) +} diff --git a/server/pkg/ws/client_manager.go b/server/pkg/ws/client_manager.go index aee950ca..c1b5f71a 100644 --- a/server/pkg/ws/client_manager.go +++ b/server/pkg/ws/client_manager.go @@ -7,7 +7,7 @@ import ( ) // 心跳间隔 -var heartbeatInterval = 25 * time.Second +const heartbeatInterval = 25 * time.Second // 连接管理 type ClientManager struct {