From 4b973b22a428a32e46e76fbd1def68a7ec1ad0ca Mon Sep 17 00:00:00 2001
From: "meilin.huang" <954537473@qq.com>
Date: Tue, 12 Sep 2023 20:54:07 +0800
Subject: [PATCH] =?UTF-8?q?refactor:=20=E7=B3=BB=E7=BB=9Fwebsocket?=
=?UTF-8?q?=E6=B6=88=E6=81=AF=E9=87=8D=E6=9E=84?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.../views/ops/machine/file/MachineFile.vue | 2 +-
server/internal/auth/api/account_login.go | 2 +
server/internal/db/api/db.go | 10 +-
server/internal/machine/api/machine_file.go | 8 +-
server/internal/msg/application/msg.go | 6 +-
server/internal/sys/api/system.go | 2 +-
server/pkg/ws/client.go | 84 +++++++++
server/pkg/ws/client_manager.go | 160 ++++++++++++++++++
server/pkg/ws/msg.go | 45 +++--
server/pkg/ws/ws.go | 59 ++-----
10 files changed, 303 insertions(+), 75 deletions(-)
create mode 100644 server/pkg/ws/client.go
create mode 100644 server/pkg/ws/client_manager.go
diff --git a/mayfly_go_web/src/views/ops/machine/file/MachineFile.vue b/mayfly_go_web/src/views/ops/machine/file/MachineFile.vue
index b9b3b0c5..47e380b0 100755
--- a/mayfly_go_web/src/views/ops/machine/file/MachineFile.vue
+++ b/mayfly_go_web/src/views/ops/machine/file/MachineFile.vue
@@ -190,7 +190,7 @@
-
rename: 双击文件名单元格后回车
+ rename: 双击文件名单元格修改后回车
操作
diff --git a/server/internal/auth/api/account_login.go b/server/internal/auth/api/account_login.go
index f3e40a82..788d26f7 100644
--- a/server/internal/auth/api/account_login.go
+++ b/server/internal/auth/api/account_login.go
@@ -17,6 +17,7 @@ import (
"mayfly-go/pkg/req"
"mayfly-go/pkg/utils/cryptox"
"mayfly-go/pkg/utils/jsonx"
+ "mayfly-go/pkg/ws"
"strconv"
"time"
)
@@ -118,4 +119,5 @@ func (a *AccountLogin) OtpVerify(rc *req.Ctx) {
func (a *AccountLogin) Logout(rc *req.Ctx) {
req.GetPermissionCodeRegistery().Remove(rc.LoginAccount.Id)
+ ws.CloseClient(rc.LoginAccount.Id)
}
diff --git a/server/internal/db/api/db.go b/server/internal/db/api/db.go
index 77282b8b..12244a06 100644
--- a/server/internal/db/api/db.go
+++ b/server/internal/db/api/db.go
@@ -182,7 +182,7 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
errInfo = t.Error()
}
if len(errInfo) > 0 {
- d.MsgApp.CreateAndSend(rc.LoginAccount, ws.ErrMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo)))
+ d.MsgApp.CreateAndSend(rc.LoginAccount, ws.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo)))
}
}
}()
@@ -202,7 +202,7 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
break
}
if err != nil {
- d.MsgApp.CreateAndSend(rc.LoginAccount, ws.ErrMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s] 解析SQL失败: [%s]", filename, dbConn.Info.GetLogDesc(), err.Error())))
+ d.MsgApp.CreateAndSend(rc.LoginAccount, ws.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s] 解析SQL失败: [%s]", filename, dbConn.Info.GetLogDesc(), err.Error())))
return
}
sql := sqlparser.String(stmt)
@@ -215,11 +215,11 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
}
if err != nil {
- d.MsgApp.CreateAndSend(rc.LoginAccount, ws.ErrMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s] -> sql=[%s] 执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), sql, err.Error())))
+ d.MsgApp.CreateAndSend(rc.LoginAccount, ws.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s] -> sql=[%s] 执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), sql, err.Error())))
return
}
}
- d.MsgApp.CreateAndSend(rc.LoginAccount, ws.SuccessMsg("sql脚本执行成功", fmt.Sprintf("[%s]执行完成 -> %s", filename, dbConn.Info.GetLogDesc())))
+ d.MsgApp.CreateAndSend(rc.LoginAccount, ws.SuccessSysMsg("sql脚本执行成功", fmt.Sprintf("[%s]执行完成 -> %s", filename, dbConn.Info.GetLogDesc())))
}()
}
@@ -275,7 +275,7 @@ func (d *Db) DumpSql(rc *req.Ctx) {
if len(msg) > 0 {
msg = "数据库导出失败: " + msg
writer.WriteString(msg)
- d.MsgApp.CreateAndSend(rc.LoginAccount, ws.ErrMsg("数据库导出失败", msg))
+ d.MsgApp.CreateAndSend(rc.LoginAccount, ws.ErrSysMsg("数据库导出失败", msg))
}
writer.Close()
}()
diff --git a/server/internal/machine/api/machine_file.go b/server/internal/machine/api/machine_file.go
index aacc336c..0f78b2c4 100644
--- a/server/internal/machine/api/machine_file.go
+++ b/server/internal/machine/api/machine_file.go
@@ -187,7 +187,7 @@ func (m *MachineFile) UploadFile(rc *req.Ctx) {
logx.Errorf("文件上传失败: %s", err)
switch t := err.(type) {
case biz.BizError:
- m.MsgApp.CreateAndSend(la, ws.ErrMsg("文件上传失败", fmt.Sprintf("执行文件上传失败:\n<-e errCode: %d, errMsg: %s", t.Code(), t.Error())))
+ m.MsgApp.CreateAndSend(la, ws.ErrSysMsg("文件上传失败", fmt.Sprintf("执行文件上传失败:\n<-e errCode: %d, errMsg: %s", t.Code(), t.Error())))
}
}
}()
@@ -196,7 +196,7 @@ func (m *MachineFile) UploadFile(rc *req.Ctx) {
rc.ReqParam = jsonx.Kvs("machine", mi, "path", fmt.Sprintf("%s/%s", path, fileheader.Filename))
biz.ErrIsNilAppendErr(err, "创建文件失败: %s")
// 保存消息并发送文件上传成功通知
- m.MsgApp.CreateAndSend(la, ws.SuccessMsg("文件上传成功", fmt.Sprintf("[%s]文件已成功上传至 %s[%s:%s]", fileheader.Filename, mi.Name, mi.Ip, path)))
+ m.MsgApp.CreateAndSend(la, ws.SuccessSysMsg("文件上传成功", fmt.Sprintf("[%s]文件已成功上传至 %s[%s:%s]", fileheader.Filename, mi.Name, mi.Ip, path)))
}
type FolderFile struct {
@@ -262,7 +262,7 @@ func (m *MachineFile) UploadFolder(rc *req.Ctx) {
logx.Errorf("文件上传失败: %s", err)
switch t := err.(type) {
case biz.BizError:
- m.MsgApp.CreateAndSend(la, ws.ErrMsg("文件上传失败", fmt.Sprintf("执行文件上传失败:\n<-e errCode: %d, errMsg: %s", t.Code(), t.Error())))
+ m.MsgApp.CreateAndSend(la, ws.ErrSysMsg("文件上传失败", fmt.Sprintf("执行文件上传失败:\n<-e errCode: %d, errMsg: %s", t.Code(), t.Error())))
}
}
}()
@@ -286,7 +286,7 @@ func (m *MachineFile) UploadFolder(rc *req.Ctx) {
// 等待所有协程执行完成
wg.Wait()
// 保存消息并发送文件上传成功通知
- m.MsgApp.CreateAndSend(rc.LoginAccount, ws.SuccessMsg("文件上传成功", fmt.Sprintf("[%s]文件夹已成功上传至 %s[%s:%s]", folderName, mi.Name, mi.Ip, basePath)))
+ m.MsgApp.CreateAndSend(rc.LoginAccount, ws.SuccessSysMsg("文件上传成功", fmt.Sprintf("[%s]文件夹已成功上传至 %s[%s:%s]", folderName, mi.Name, mi.Ip, basePath)))
}
func (m *MachineFile) RemoveFile(rc *req.Ctx) {
diff --git a/server/internal/msg/application/msg.go b/server/internal/msg/application/msg.go
index 29ee5b97..fd51ec9e 100644
--- a/server/internal/msg/application/msg.go
+++ b/server/internal/msg/application/msg.go
@@ -14,7 +14,7 @@ type Msg interface {
Create(msg *entity.Msg)
// 创建消息,并通过ws发送
- CreateAndSend(la *model.LoginAccount, msg *ws.Msg)
+ CreateAndSend(la *model.LoginAccount, msg *ws.SysMsg)
}
func newMsgApp(msgRepo repository.Msg) Msg {
@@ -35,9 +35,9 @@ func (a *msgAppImpl) Create(msg *entity.Msg) {
a.msgRepo.Insert(msg)
}
-func (a *msgAppImpl) CreateAndSend(la *model.LoginAccount, wmsg *ws.Msg) {
+func (a *msgAppImpl) CreateAndSend(la *model.LoginAccount, wmsg *ws.SysMsg) {
now := time.Now()
- msg := &entity.Msg{Type: 2, Msg: wmsg.Msg, RecipientId: int64(la.Id), CreateTime: &now, CreatorId: la.Id, Creator: la.Username}
+ msg := &entity.Msg{Type: 2, Msg: wmsg.SysMsg, RecipientId: int64(la.Id), CreateTime: &now, CreatorId: la.Id, Creator: la.Username}
a.msgRepo.Insert(msg)
ws.SendMsg(la.Id, wmsg)
}
diff --git a/server/internal/sys/api/system.go b/server/internal/sys/api/system.go
index c27d708c..44743d5d 100644
--- a/server/internal/sys/api/system.go
+++ b/server/internal/sys/api/system.go
@@ -37,5 +37,5 @@ func (s *System) ConnectWs(g *gin.Context) {
// 登录账号信息
la := rc.LoginAccount
- ws.Put(la.Id, wsConn)
+ ws.AddClient(la.Id, wsConn)
}
diff --git a/server/pkg/ws/client.go b/server/pkg/ws/client.go
new file mode 100644
index 00000000..c95a6b9e
--- /dev/null
+++ b/server/pkg/ws/client.go
@@ -0,0 +1,84 @@
+package ws
+
+import (
+ "encoding/json"
+ "errors"
+ "mayfly-go/pkg/utils/stringx"
+ "time"
+
+ "github.com/gorilla/websocket"
+)
+
+type UserId uint64
+
+// 客户端读取消息处理函数
+// @param msg
+type ReadMsgHandlerFunc func([]byte)
+
+type Client struct {
+ ClientId string // 标识ID
+ UserId UserId // 用户ID
+ WsConn *websocket.Conn // 用户连接
+
+ ReadMsgHander ReadMsgHandlerFunc // 读取消息处理函数
+}
+
+func (c *Client) Read() {
+ go func() {
+ for {
+ messageType, data, err := c.WsConn.ReadMessage()
+ if err != nil {
+ if messageType == -1 && websocket.IsCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) {
+ Manager.CloseClient(c)
+ return
+ }
+ if messageType != websocket.PingMessage {
+ return
+ }
+ }
+ if c.ReadMsgHander != nil {
+ c.ReadMsgHander(data)
+ }
+ }
+ }()
+}
+
+// 向客户端写入消息
+func (c *Client) WriteMsg(msg *Msg) error {
+ if msg.Type == JsonMsg {
+ bytes, _ := json.Marshal(msg.Data)
+ return c.WsConn.WriteMessage(websocket.TextMessage, bytes)
+ }
+
+ if msg.Type == BinaryMsg {
+ if byteData, ok := msg.Data.([]byte); ok {
+ return c.WsConn.WriteMessage(websocket.BinaryMessage, byteData)
+ } else {
+ return errors.New("该数据不为数组类型")
+ }
+ }
+
+ if msg.Type == TextMsg {
+ if strData, ok := msg.Data.(string); ok {
+ return c.WsConn.WriteMessage(websocket.TextMessage, []byte(strData))
+ } else {
+ return errors.New("该数据类型不为字符串")
+ }
+ }
+ return errors.New("不存在该消息类型, 无法发送")
+}
+
+// 向客户写入ping消息
+func (c *Client) Ping() error {
+ return c.WsConn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(time.Second))
+}
+
+func NewClient(userId UserId, socket *websocket.Conn) *Client {
+ cli := &Client{
+ ClientId: stringx.Rand(16),
+ UserId: userId,
+ WsConn: socket,
+ }
+
+ return cli
+}
diff --git a/server/pkg/ws/client_manager.go b/server/pkg/ws/client_manager.go
new file mode 100644
index 00000000..fd9772d3
--- /dev/null
+++ b/server/pkg/ws/client_manager.go
@@ -0,0 +1,160 @@
+package ws
+
+import (
+ "mayfly-go/pkg/logx"
+ "sync"
+ "time"
+)
+
+// 心跳间隔
+var heartbeatInterval = 25 * time.Second
+
+// 连接管理
+type ClientManager struct {
+ ClientMap map[UserId]*Client // 全部的连接, key->userid, value->&client
+ RwLock sync.RWMutex // 读写锁
+
+ ConnectChan chan *Client // 连接处理
+ DisConnectChan chan *Client // 断开连接处理
+ MsgChan chan *Msg // 消息信息channel通道
+}
+
+func NewClientManager() (clientManager *ClientManager) {
+ return &ClientManager{
+ ClientMap: make(map[UserId]*Client),
+ ConnectChan: make(chan *Client, 10),
+ DisConnectChan: make(chan *Client, 10),
+ MsgChan: make(chan *Msg, 100),
+ }
+}
+
+// 管道处理程序
+func (manager *ClientManager) Start() {
+ manager.HeartbeatTimer()
+ go manager.WriteMessage()
+ for {
+ select {
+ case client := <-manager.ConnectChan:
+ // 建立连接
+ manager.doConnect(client)
+ case conn := <-manager.DisConnectChan:
+ // 断开连接
+ manager.doDisconnect(conn)
+ }
+ }
+}
+
+// 添加客户端
+func (manager *ClientManager) AddClient(client *Client) {
+ manager.ConnectChan <- client
+}
+
+// 关闭客户端
+func (manager *ClientManager) CloseClient(client *Client) {
+ if client == nil {
+ return
+ }
+ manager.DisConnectChan <- client
+}
+
+// 根据用户id关闭客户端连接
+func (manager *ClientManager) CloseByUid(uid UserId) {
+ manager.CloseClient(manager.GetByUid(UserId(uid)))
+}
+
+// 获取所有的客户端
+func (manager *ClientManager) AllClient() map[UserId]*Client {
+ manager.RwLock.RLock()
+ defer manager.RwLock.RUnlock()
+
+ return manager.ClientMap
+}
+
+// 通过userId获取
+func (manager *ClientManager) GetByUid(userId UserId) *Client {
+ manager.RwLock.RLock()
+ defer manager.RwLock.RUnlock()
+ return manager.ClientMap[userId]
+}
+
+// 客户端数量
+func (manager *ClientManager) Count() int {
+ manager.RwLock.RLock()
+ defer manager.RwLock.RUnlock()
+ return len(manager.ClientMap)
+}
+
+// 发送json数据给指定用户
+func (manager *ClientManager) SendJsonMsg(userId UserId, data any) {
+ logx.Debugf("发送消息: toUid=%v, data=%v", userId, data)
+ manager.MsgChan <- &Msg{ToUserId: userId, Data: data, Type: JsonMsg}
+}
+
+// 监听并发送给客户端信息
+func (manager *ClientManager) WriteMessage() {
+ go func() {
+ for {
+ msg := <-manager.MsgChan
+ if cli := manager.GetByUid(msg.ToUserId); cli != nil {
+ if err := cli.WriteMsg(msg); err != nil {
+ manager.CloseClient(cli)
+ }
+ }
+ }
+ }()
+}
+
+// 启动定时器进行心跳检测
+func (manager *ClientManager) HeartbeatTimer() {
+ go func() {
+ ticker := time.NewTicker(heartbeatInterval)
+ defer ticker.Stop()
+ for {
+ <-ticker.C
+ //发送心跳
+ for userId, cli := range manager.AllClient() {
+ if cli == nil || cli.WsConn == nil {
+ continue
+ }
+ if err := cli.Ping(); err != nil {
+ manager.CloseClient(cli)
+ logx.Errorf("WS发送心跳失败: %v 总连接数:%d", userId, Manager.Count())
+ } else {
+ logx.Debugf("WS发送心跳成功: uid=%v", userId)
+ }
+ }
+ }
+
+ }()
+}
+
+// 处理建立连接
+func (manager *ClientManager) doConnect(client *Client) {
+ cli := manager.GetByUid(client.UserId)
+ if cli != nil {
+ manager.doDisconnect(cli)
+ }
+ manager.addClient2Map(client)
+ logx.Debugf("WS客户端已连接: uid=%d, count=%d", client.UserId, manager.Count())
+}
+
+// 处理断开连接
+func (manager *ClientManager) doDisconnect(client *Client) {
+ //关闭连接
+ _ = client.WsConn.Close()
+ client.WsConn = nil
+ manager.delClient4Map(client)
+ logx.Debugf("WS客户端已断开: uid=%d, count=%d", client.UserId, Manager.Count())
+}
+
+func (manager *ClientManager) addClient2Map(client *Client) {
+ manager.RwLock.Lock()
+ defer manager.RwLock.Unlock()
+ manager.ClientMap[client.UserId] = client
+}
+
+func (manager *ClientManager) delClient4Map(client *Client) {
+ manager.RwLock.Lock()
+ defer manager.RwLock.Unlock()
+ delete(manager.ClientMap, client.UserId)
+}
diff --git a/server/pkg/ws/msg.go b/server/pkg/ws/msg.go
index 3e3a20ab..e1abcd29 100644
--- a/server/pkg/ws/msg.go
+++ b/server/pkg/ws/msg.go
@@ -1,27 +1,46 @@
package ws
-const SuccessMsgType = 1
-const ErrorMsgType = 0
-const InfoMsgType = 2
+// 消息类型
+type MsgType uint8
+
+const (
+ JsonMsg MsgType = 1
+ TextMsg MsgType = 2
+ BinaryMsg MsgType = 3
+)
+
+// 消息信息
+type Msg struct {
+ ToUserId UserId
+ Data any
+
+ Type MsgType // 消息类型
+}
+
+// ************** 系统消息 **************
+
+const SuccessSysMsgType = 1
+const ErrorSysMsgType = 0
+const InfoSysMsgType = 2
// websocket消息
-type Msg struct {
- Type int `json:"type"` // 消息类型
- Title string `json:"title"` // 消息标题
- Msg string `json:"msg"` // 消息内容
+type SysMsg struct {
+ Type int `json:"type"` // 消息类型
+ Title string `json:"title"` // 消息标题
+ SysMsg string `json:"msg"` // 消息内容
}
// 普通消息
-func NewMsg(title, msg string) *Msg {
- return &Msg{Type: InfoMsgType, Title: title, Msg: msg}
+func NewSysMsg(title, msg string) *SysMsg {
+ return &SysMsg{Type: InfoSysMsgType, Title: title, SysMsg: msg}
}
// 成功消息
-func SuccessMsg(title, msg string) *Msg {
- return &Msg{Type: SuccessMsgType, Title: title, Msg: msg}
+func SuccessSysMsg(title, msg string) *SysMsg {
+ return &SysMsg{Type: SuccessSysMsgType, Title: title, SysMsg: msg}
}
// 错误消息
-func ErrMsg(title, msg string) *Msg {
- return &Msg{Type: ErrorMsgType, Title: title, Msg: msg}
+func ErrSysMsg(title, msg string) *SysMsg {
+ return &SysMsg{Type: ErrorSysMsgType, Title: title, SysMsg: msg}
}
diff --git a/server/pkg/ws/ws.go b/server/pkg/ws/ws.go
index f1be7a7f..39defecf 100644
--- a/server/pkg/ws/ws.go
+++ b/server/pkg/ws/ws.go
@@ -1,10 +1,7 @@
package ws
import (
- "encoding/json"
- "mayfly-go/pkg/logx"
"net/http"
- "time"
"github.com/gorilla/websocket"
)
@@ -17,58 +14,24 @@ var Upgrader = websocket.Upgrader{
},
}
-var conns = make(map[uint64]*websocket.Conn, 100)
+var Manager = NewClientManager() // 管理者
func init() {
- checkConn()
+ go Manager.Start()
}
-// 放置ws连接
-func Put(userId uint64, conn *websocket.Conn) {
- existConn := conns[userId]
- if existConn != nil {
- Delete(userId)
- }
-
- conn.SetCloseHandler(func(code int, text string) error {
- Delete(userId)
- return nil
- })
- conns[userId] = conn
+// 添加ws客户端
+func AddClient(userId uint64, conn *websocket.Conn) *Client {
+ cli := NewClient(UserId(userId), conn)
+ Manager.AddClient(cli)
+ return cli
}
-func checkConn() {
- heartbeat := time.Duration(60) * time.Second
- tick := time.NewTicker(heartbeat)
- go func() {
- for range tick.C {
- // 遍历所有连接,ping失败的则删除掉
- for uid, conn := range conns {
- err := conn.WriteControl(websocket.PingMessage, []byte("ping"), time.Now().Add(heartbeat/2))
- if err != nil {
- Delete(uid)
- return
- }
- }
- }
- }()
-}
-
-// 删除ws连接
-func Delete(userid uint64) {
- logx.Debugf("移除websocket连接: uid = %d", userid)
- conn := conns[userid]
- if conn != nil {
- conn.Close()
- delete(conns, userid)
- }
+func CloseClient(userid uint64) {
+ Manager.CloseByUid(UserId(userid))
}
// 对指定用户发送消息
-func SendMsg(userId uint64, msg *Msg) {
- conn := conns[userId]
- if conn != nil {
- bytes, _ := json.Marshal(msg)
- conn.WriteMessage(websocket.TextMessage, bytes)
- }
+func SendMsg(userId uint64, msg *SysMsg) {
+ Manager.SendJsonMsg(UserId(userId), msg)
}