From 75442884511227f74cd7fce805e9feb51873a82e Mon Sep 17 00:00:00 2001 From: wanli Date: Tue, 10 Oct 2023 09:24:49 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=89=8D=E7=AB=AF=E6=98=BE=E7=A4=BA=20?= =?UTF-8?q?SQL=20=E6=96=87=E4=BB=B6=E6=89=A7=E8=A1=8C=E8=BF=9B=E5=BA=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mayfly_go_web/src/common/sockets.ts | 76 +++++-- .../progress-notify/progress-notify.ts | 14 ++ .../progress-notify/progress-notify.vue | 41 ++++ mayfly_go_web/src/router/index.ts | 2 +- server/go.mod | 14 +- server/internal/db/api/db.go | 202 ++++++++++++------ server/internal/db/api/db_test.go | 90 ++++++++ server/internal/db/api/sqlparser_test.go | 44 ++++ server/internal/db/application/db_sql_exec.go | 3 +- .../internal/msg/application/dto/sys_msg.go | 6 +- server/pkg/utils/uniqueid/uniqueid.go | 9 + server/pkg/ws/client_manager.go | 2 +- 12 files changed, 407 insertions(+), 96 deletions(-) create mode 100644 mayfly_go_web/src/components/progress-notify/progress-notify.ts create mode 100644 mayfly_go_web/src/components/progress-notify/progress-notify.vue create mode 100644 server/internal/db/api/db_test.go create mode 100644 server/internal/db/api/sqlparser_test.go create mode 100644 server/pkg/utils/uniqueid/uniqueid.go diff --git a/mayfly_go_web/src/common/sockets.ts b/mayfly_go_web/src/common/sockets.ts index cb678605..db61ca91 100644 --- a/mayfly_go_web/src/common/sockets.ts +++ b/mayfly_go_web/src/common/sockets.ts @@ -1,7 +1,10 @@ import Config from './config'; -import { ElNotification } from 'element-plus'; +import { ElNotification, NotificationHandle } from 'element-plus'; import SocketBuilder from './SocketBuilder'; import { getToken } from '@/common/utils/storage'; +import { createVNode, reactive } from "vue"; +import { buildProgressProps } from "@/components/progress-notify/progress-notify"; +import ProgressNotify from '/src/components/progress-notify/progress-notify.vue'; export default { /** @@ -12,32 +15,63 @@ export default { if (!token) { return null; } + const messageTypes = { + 0: "error", + 1: "success", + 2: "info", + } + const notifyMap: Map = new Map() + return SocketBuilder.builder(`${Config.baseWsUrl}/sysmsg?token=${token}`) .message((event: { data: string }) => { const message = JSON.parse(event.data); - let mtype: string; - switch (message.type) { - case 0: - mtype = 'error'; - break; - case 2: - mtype = 'info'; - break; - case 1: - mtype = 'success'; + const type = messageTypes[message.type] + switch (message.category) { + case "execSqlFileProgress": + const content = JSON.parse(message.msg) + const id = content.id + let progress = notifyMap.get(id) + if (content.terminated) { + if (progress != undefined) { + progress.notification?.close() + notifyMap.delete(id) + progress = undefined + } + return + } + if (progress == undefined) { + progress = { + props: reactive(buildProgressProps()), + notification: undefined, + } + } + progress.props.progress.sqlFileName = content.sqlFileName + progress.props.progress.executedStatements = content.executedStatements + if (!notifyMap.has(id)) { + const vNodeMessage = createVNode( + ProgressNotify, + progress.props, + null, + ) + progress.notification = ElNotification({ + duration: 0, + title: message.title, + message: vNodeMessage, + type: type, + showClose: false, + }); + notifyMap.set(id, progress) + } break; default: - mtype = 'info'; + ElNotification({ + duration: 0, + title: message.title, + message: message.msg, + type: type, + }); + break; } - if (mtype == undefined) { - return; - } - ElNotification({ - duration: 0, - title: message.title, - message: message.msg, - type: mtype as any, - }); }) .open((event: any) => console.log(event)) .build(); diff --git a/mayfly_go_web/src/components/progress-notify/progress-notify.ts b/mayfly_go_web/src/components/progress-notify/progress-notify.ts new file mode 100644 index 00000000..396536cb --- /dev/null +++ b/mayfly_go_web/src/components/progress-notify/progress-notify.ts @@ -0,0 +1,14 @@ +export const buildProgressProps = () : any => { + return { + progress: { + sqlFileName: { + type: String + }, + executedStatements: { + type: Number + }, + }, + }; +} + + diff --git a/mayfly_go_web/src/components/progress-notify/progress-notify.vue b/mayfly_go_web/src/components/progress-notify/progress-notify.vue new file mode 100644 index 00000000..62d40f94 --- /dev/null +++ b/mayfly_go_web/src/components/progress-notify/progress-notify.vue @@ -0,0 +1,41 @@ + + \ No newline at end of file diff --git a/mayfly_go_web/src/router/index.ts b/mayfly_go_web/src/router/index.ts index 01a80fa6..74f70f50 100644 --- a/mayfly_go_web/src/router/index.ts +++ b/mayfly_go_web/src/router/index.ts @@ -206,7 +206,7 @@ router.beforeEach(async (to, from, next) => { if (SysWs) { SysWs.close(); - SysWs = null; + SysWs = undefined; } return; } diff --git a/server/go.mod b/server/go.mod index 348292e5..e5cdd97b 100644 --- a/server/go.mod +++ b/server/go.mod @@ -21,7 +21,7 @@ require ( github.com/pquerna/otp v1.4.0 github.com/redis/go-redis/v9 v9.2.1 github.com/robfig/cron/v3 v3.0.1 // 定时任务 - github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2 + github.com/stretchr/testify v1.8.4 go.mongodb.org/mongo-driver v1.12.1 // mongo golang.org/x/crypto v0.14.0 // ssh golang.org/x/oauth2 v0.13.0 @@ -32,12 +32,15 @@ require ( gorm.io/gorm v1.25.4 ) +require github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2 + require ( github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect github.com/bytedance/sonic v1.9.1 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/gin-contrib/sse v0.1.0 // indirect @@ -45,21 +48,22 @@ require ( github.com/goccy/go-json v0.10.2 // indirect github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect github.com/golang/protobuf v1.5.3 // indirect - github.com/golang/snappy v0.0.1 // indirect + github.com/golang/snappy v0.0.4 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/json-iterator/go v1.1.12 // indirect - github.com/klauspost/compress v1.13.6 // indirect + github.com/klauspost/compress v1.16.5 // indirect github.com/klauspost/cpuid/v2 v2.2.4 // indirect github.com/kr/fs v0.1.0 // indirect - github.com/kr/pretty v0.3.0 // indirect + github.com/kr/pretty v0.3.1 // indirect github.com/leodido/go-urn v1.2.4 // indirect github.com/mattn/go-isatty v0.0.19 // indirect github.com/mattn/go-sqlite3 v1.14.17 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect - github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe // indirect + github.com/montanaflynn/stats v0.7.0 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.11 // indirect github.com/xdg-go/pbkdf2 v1.0.0 // indirect diff --git a/server/internal/db/api/db.go b/server/internal/db/api/db.go index 4db5aa20..4c820a95 100644 --- a/server/internal/db/api/db.go +++ b/server/internal/db/api/db.go @@ -3,7 +3,11 @@ package api import ( "bufio" "fmt" + "github.com/lib/pq" "io" + "mayfly-go/pkg/utils/uniqueid" + "mayfly-go/pkg/ws" + "mayfly-go/internal/db/api/form" "mayfly-go/internal/db/api/vo" "mayfly-go/internal/db/application" @@ -79,9 +83,7 @@ func (d *Db) DeleteDb(rc *req.Ctx) { } func (d *Db) getDbConnection(g *gin.Context) *application.DbConnection { - dbName := g.Query("db") - biz.NotEmpty(dbName, "db不能为空") - return d.DbApp.GetDbConnection(getDbId(g), dbName) + return d.DbApp.GetDbConnection(getDbId(g), getDbName(g)) } func (d *Db) TableInfos(rc *req.Ctx) { @@ -152,67 +154,120 @@ func (d *Db) ExecSql(rc *req.Ctx) { rc.ResData = colAndRes } +// progressCategory sql文件执行进度消息类型 +const progressCategory = "execSqlFileProgress" + +// progressMsg sql文件执行进度消息 +type progressMsg struct { + Id uint64 `json:"id"` + SqlFileName string `json:"sqlFileName"` + ExecutedStatements int `json:"executedStatements"` + Terminated bool `json:"terminated"` +} + // 执行sql文件 func (d *Db) ExecSqlFile(rc *req.Ctx) { g := rc.GinCtx - fileheader, err := g.FormFile("file") + multipart, err := g.Request.MultipartReader() biz.ErrIsNilAppendErr(err, "读取sql文件失败: %s") - - file, _ := fileheader.Open() - filename := fileheader.Filename + file, err := multipart.NextPart() + biz.ErrIsNilAppendErr(err, "读取sql文件失败: %s") + defer file.Close() + filename := file.FileName() dbId := getDbId(g) dbName := getDbName(g) - dbConn := d.getDbConnection(rc.GinCtx) + dbConn := d.DbApp.GetDbConnection(dbId, dbName) biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s") rc.ReqParam = fmt.Sprintf("%s -> filename: %s", dbConn.Info.GetLogDesc(), filename) - logExecRecord := true - // 如果执行sql文件大于该值则不记录sql执行记录 - if fileheader.Size > 50*1024 { - logExecRecord = false + defer func() { + if err := recover(); err != nil { + var errInfo string + switch t := err.(type) { + case biz.BizError: + errInfo = t.Error() + case *biz.BizError: + errInfo = t.Error() + case string: + errInfo = t + } + if len(errInfo) > 0 { + d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo))) + } + } + }() + + execReq := &application.DbSqlExecReq{ + DbId: dbId, + Db: dbName, + Remark: filename, + DbConn: dbConn, + LoginAccount: rc.LoginAccount, } - go func() { - defer func() { - if err := recover(); err != nil { - var errInfo string - switch t := err.(type) { - case error: - errInfo = t.Error() - } - if len(errInfo) > 0 { - d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo))) - } + defer func() { + if err := recover(); err != nil { + var errInfo string + switch t := err.(type) { + case error: + errInfo = t.Error() } - }() - - execReq := &application.DbSqlExecReq{ - DbId: dbId, - Db: dbName, - Remark: fileheader.Filename, - DbConn: dbConn, - LoginAccount: rc.LoginAccount, - } - - sqlScanner := SplitSqls(file) - for sqlScanner.Scan() { - sql := sqlScanner.Text() - execReq.Sql = sql - // 需要记录执行记录 - if logExecRecord { - _, err = d.DbSqlExecApp.Exec(execReq) - } else { - _, err = dbConn.Exec(sql) - } - - if err != nil { - d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s] -> sql=[%s] 执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), sql, err.Error()))) - return + if len(errInfo) > 0 { + d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo))) } } - d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.SuccessSysMsg("sql脚本执行成功", fmt.Sprintf("[%s]执行完成 -> %s", filename, dbConn.Info.GetLogDesc()))) }() + + progressId := uniqueid.IncrementID() + executedStatements := 0 + defer ws.SendJsonMsg(rc.LoginAccount.Id, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{ + Id: progressId, + SqlFileName: filename, + ExecutedStatements: executedStatements, + Terminated: true, + }).WithCategory(progressCategory)) + + ticker := time.NewTicker(time.Second * 1) + defer ticker.Stop() + sqlScanner := SplitSqls(file) + for sqlScanner.Scan() { + select { + case <-ticker.C: + ws.SendJsonMsg(rc.LoginAccount.Id, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{ + Id: progressId, + SqlFileName: filename, + ExecutedStatements: executedStatements, + Terminated: false, + }).WithCategory(progressCategory)) + default: + } + sql := sqlScanner.Text() + const prefixUse = "use " + if strings.HasPrefix(sql, prefixUse) { + dbNameExec := strings.Trim(sql[len(prefixUse):], " `;\n") + if len(dbNameExec) > 0 { + dbConn = d.DbApp.GetDbConnection(dbId, dbNameExec) + biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s") + execReq.DbConn = dbConn + } + } + // 需要记录执行记录 + const maxRecordStatements = 64 + if executedStatements < maxRecordStatements { + execReq.Sql = sql + _, err = d.DbSqlExecApp.Exec(execReq) + } else { + _, err = dbConn.Exec(sql) + } + + if err != nil { + d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s] -> sql=[%s] 执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), sql, err.Error()))) + return + } + executedStatements++ + } + d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.SuccessSysMsg("sql脚本执行成功", fmt.Sprintf("[%s]执行完成 -> %s", filename, dbConn.Info.GetLogDesc()))) } // 数据库dump @@ -276,23 +331,44 @@ func (d *Db) DumpSql(rc *req.Ctx) { rc.ReqParam = fmt.Sprintf("DB[id=%d, tag=%s, name=%s, databases=%s, tables=%s, dumpType=%s]", db.Id, db.TagPath, db.Name, dbNamesStr, tablesStr, dumpType) } +func escapeSql(dbType string, sql string) string { + if dbType == entity.DbTypePostgres { + return pq.QuoteLiteral(sql) + } else { + sql = strings.ReplaceAll(sql, `\`, `\\`) + sql = strings.ReplaceAll(sql, `'`, `''`) + return "'" + sql + "'" + } +} + +func quoteTable(dbType string, table string) string { + if dbType == entity.DbTypePostgres { + return "\"" + table + "\"" + } else { + return "`" + table + "`" + } +} + func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []string, needStruct bool, needData bool, switchDb bool) { dbConn := d.DbApp.GetDbConnection(dbId, dbName) - writer.WriteString("-- ----------------------------") + writer.WriteString("\n-- ----------------------------") writer.WriteString("\n-- 导出平台: mayfly-go") writer.WriteString(fmt.Sprintf("\n-- 导出时间: %s ", time.Now().Format("2006-01-02 15:04:05"))) writer.WriteString(fmt.Sprintf("\n-- 导出数据库: %s ", dbName)) writer.WriteString("\n-- ----------------------------\n") - writer.TryFlush() if switchDb { switch dbConn.Info.Type { case entity.DbTypeMysql: - writer.WriteString(fmt.Sprintf("use `%s`;\n", dbName)) + writer.WriteString(fmt.Sprintf("USE `%s`;\n", dbName)) default: - biz.IsTrue(false, "数据库类型必须为 %s", entity.DbTypeMysql) + biz.IsTrue(false, "同时导出多个数据库,数据库类型必须为 %s", entity.DbTypeMysql) } } + if dbConn.Info.Type == entity.DbTypeMysql { + writer.WriteString("\nSET FOREIGN_KEY_CHECKS = 0;\n") + } + dbMeta := dbConn.GetMeta() if len(tables) == 0 { ti := dbMeta.GetTableInfos() @@ -303,23 +379,22 @@ func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []str } for _, table := range tables { + writer.TryFlush() + quotedTable := quoteTable(dbConn.Info.Type, table) if needStruct { writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- 表结构: %s \n-- ----------------------------\n", table)) - writer.WriteString(fmt.Sprintf("DROP TABLE IF EXISTS `%s`;\n", table)) + writer.WriteString(fmt.Sprintf("DROP TABLE IF EXISTS %s;\n", quotedTable)) writer.WriteString(dbMeta.GetCreateTableDdl(table) + ";\n") } - if !needData { continue } - writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- 表记录: %s \n-- ----------------------------\n", table)) writer.WriteString("BEGIN;\n") - - insertSql := "INSERT INTO `%s` VALUES (%s);\n" - + insertSql := "INSERT INTO %s VALUES (%s);\n" dbMeta.WalkTableRecord(table, func(record map[string]any, columns []string) { var values []string + writer.TryFlush() for _, column := range columns { value := record[column] if value == nil { @@ -328,17 +403,18 @@ func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []str } strValue, ok := value.(string) if ok { - values = append(values, fmt.Sprintf("%#v", strValue)) + strValue = escapeSql(dbConn.Info.Type, strValue) + values = append(values, strValue) } else { values = append(values, stringx.AnyToStr(value)) } } - writer.WriteString(fmt.Sprintf(insertSql, table, strings.Join(values, ", "))) - writer.TryFlush() + writer.WriteString(fmt.Sprintf(insertSql, quotedTable, strings.Join(values, ", "))) }) - writer.WriteString("COMMIT;\n") - writer.TryFlush() + } + if dbConn.Info.Type == entity.DbTypeMysql { + writer.WriteString("\nSET FOREIGN_KEY_CHECKS = 1;\n") } } diff --git a/server/internal/db/api/db_test.go b/server/internal/db/api/db_test.go new file mode 100644 index 00000000..4f8547ee --- /dev/null +++ b/server/internal/db/api/db_test.go @@ -0,0 +1,90 @@ +package api + +import ( + "github.com/stretchr/testify/require" + "mayfly-go/internal/db/domain/entity" + "strings" + "testing" +) + +func Test_escapeSql(t *testing.T) { + tests := []struct { + name string + dbType string + sql string + want string + }{ + { + dbType: entity.DbTypeMysql, + sql: "\\a\\b", + want: "'\\\\a\\\\b'", + }, + { + dbType: entity.DbTypeMysql, + sql: "'a'", + want: "'''a'''", + }, + { + name: "不间断空格", + dbType: entity.DbTypeMysql, + sql: "a\u00A0b", + want: "'a\u00A0b'", + }, + { + dbType: entity.DbTypePostgres, + sql: "\\a\\b", + want: " E'\\\\a\\\\b'", + }, + { + dbType: entity.DbTypePostgres, + sql: "'a'", + want: "'''a'''", + }, + { + name: "不间断空格", + dbType: entity.DbTypePostgres, + sql: "a\u00A0b", + want: "'a\u00A0b'", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := escapeSql(tt.dbType, tt.sql) + require.Equal(t, tt.want, got) + }) + } +} + +func Test_SplitSqls(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + { + name: "create table with current_timestamp", + input: "create table tbl (\n\tcreate_at datetime default current_timestamp()\n)", + }, + { + name: "create table with current_date", + input: "create table tbl (\n\tcreate_at date default current_date()\n)", + }, + { + name: "select with ';\n'", + input: "select 'the first line;\nthe second line;\n'", + // SplitSqls split statements by ';\n' + want: "select 'the first line;\n", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + scanner := SplitSqls(strings.NewReader(test.input)) + require.True(t, scanner.Scan()) + got := scanner.Text() + if len(test.want) == 0 { + test.want = test.input + } + require.Equal(t, test.want, got) + }) + } +} diff --git a/server/internal/db/api/sqlparser_test.go b/server/internal/db/api/sqlparser_test.go new file mode 100644 index 00000000..3c41c878 --- /dev/null +++ b/server/internal/db/api/sqlparser_test.go @@ -0,0 +1,44 @@ +package api + +import ( + "github.com/stretchr/testify/require" + "github.com/xwb1989/sqlparser" + "strings" + "testing" +) + +func Test_ParseNext_WithCurrentDate(t *testing.T) { + tests := []struct { + name string + input string + want string + wantXwb1989 string + err string + }{ + { + name: "create table with current_timestamp", + input: "create table tbl (\n\tcreate_at datetime default current_timestamp()\n)", + // xwb1989/sqlparser 不支持 current_timestamp() + wantXwb1989: "create table tbl", + }, + { + name: "create table with current_date", + input: "create table tbl (\n\tcreate_at date default current_date()\n)", + // xwb1989/sqlparser 不支持 current_date() + wantXwb1989: "create table tbl", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + token := sqlparser.NewTokenizer(strings.NewReader(test.input)) + tree, err := sqlparser.ParseNext(token) + if len(test.err) > 0 { + require.Error(t, err) + require.Contains(t, err.Error(), test.err) + return + } + require.NoError(t, err) + require.Equal(t, test.wantXwb1989, sqlparser.String(tree)) + }) + } +} diff --git a/server/internal/db/application/db_sql_exec.go b/server/internal/db/application/db_sql_exec.go index 3c5ab4d6..6f2d62bb 100644 --- a/server/internal/db/application/db_sql_exec.go +++ b/server/internal/db/application/db_sql_exec.go @@ -3,6 +3,7 @@ package application import ( "encoding/json" "fmt" + "github.com/xwb1989/sqlparser" "mayfly-go/internal/db/config" "mayfly-go/internal/db/domain/entity" "mayfly-go/internal/db/domain/repository" @@ -10,8 +11,6 @@ import ( "mayfly-go/pkg/model" "strconv" "strings" - - "github.com/xwb1989/sqlparser" ) type DbSqlExecReq struct { diff --git a/server/internal/msg/application/dto/sys_msg.go b/server/internal/msg/application/dto/sys_msg.go index da94ee61..d328903b 100644 --- a/server/internal/msg/application/dto/sys_msg.go +++ b/server/internal/msg/application/dto/sys_msg.go @@ -11,7 +11,7 @@ const InfoSysMsgType = 2 // websocket消息 type SysMsg struct { Type int `json:"type"` // 消息类型 - Category int `json:"category"` // 消息类别 + Category string `json:"category"` // 消息类别 Title string `json:"title"` // 消息标题 Msg string `json:"msg"` // 消息内容 } @@ -21,7 +21,7 @@ func (sm *SysMsg) WithTitle(title string) *SysMsg { return sm } -func (sm *SysMsg) WithCategory(category int) *SysMsg { +func (sm *SysMsg) WithCategory(category string) *SysMsg { sm.Category = category return sm } @@ -32,7 +32,7 @@ func (sm *SysMsg) WithMsg(msg any) *SysMsg { } // 普通消息 -func NewSysMsg(title string, msg any) *SysMsg { +func InfoSysMsg(title string, msg any) *SysMsg { return &SysMsg{Type: InfoSysMsgType, Title: title, Msg: stringx.AnyToStr(msg)} } diff --git a/server/pkg/utils/uniqueid/uniqueid.go b/server/pkg/utils/uniqueid/uniqueid.go new file mode 100644 index 00000000..777a429b --- /dev/null +++ b/server/pkg/utils/uniqueid/uniqueid.go @@ -0,0 +1,9 @@ +package uniqueid + +import "sync/atomic" + +var id uint64 = 0 + +func IncrementID() uint64 { + return atomic.AddUint64(&id, 1) +} diff --git a/server/pkg/ws/client_manager.go b/server/pkg/ws/client_manager.go index aee950ca..c1b5f71a 100644 --- a/server/pkg/ws/client_manager.go +++ b/server/pkg/ws/client_manager.go @@ -7,7 +7,7 @@ import ( ) // 心跳间隔 -var heartbeatInterval = 25 * time.Second +const heartbeatInterval = 25 * time.Second // 连接管理 type ClientManager struct {