refactor: websocket支持单用户多连接

This commit is contained in:
meilin.huang
2023-10-19 19:00:23 +08:00
parent 747ea6404d
commit 2b91bbe185
31 changed files with 365 additions and 263 deletions

View File

@@ -1,13 +1,13 @@
# 🌈mayfly-go
<p align="center">
<a href="https://gitee.com/objs/mayfly-go" target="_blank">
<img src="https://gitee.com/objs/mayfly-go/badge/star.svg?theme=white" alt="star"/>
<img src="https://gitee.com/objs/mayfly-go/badge/fork.svg" alt="fork"/>
<a href="https://gitee.com/dromara/mayfly-go" target="_blank">
<img src="https://gitee.com/dromara/mayfly-go/badge/star.svg?theme=white" alt="star"/>
<img src="https://gitee.com/dromara/mayfly-go/badge/fork.svg" alt="fork"/>
</a>
<a href="https://github.com/may-fly/mayfly-go" target="_blank">
<img src="https://img.shields.io/github/stars/may-fly/mayfly-go.svg?style=social" alt="github star"/>
<img src="https://img.shields.io/github/forks/may-fly/mayfly-go.svg?style=social" alt="github fork"/>
<a href="https://github.com/dromara/mayfly-go" target="_blank">
<img src="https://img.shields.io/github/stars/dromara/mayfly-go.svg?style=social" alt="github star"/>
<img src="https://img.shields.io/github/forks/dromara/mayfly-go.svg?style=social" alt="github fork"/>
</a>
<a href="https://hub.docker.com/r/mayflygo/mayfly-go/tags" target="_blank">
<img src="https://img.shields.io/docker/pulls/mayflygo/mayfly-go.svg?label=docker%20pulls&color=fac858" alt="docker pulls"/>
@@ -100,4 +100,4 @@ http://go.mayfly.run
#### 💌 支持作者
如果觉得项目不错,或者已经在使用了,希望你可以去 <a target="_blank" href="https://github.com/may-fly/mayfly-go">Github</a> 或者 <a target="_blank" href="https://gitee.com/objs/mayfly-go">Gitee</a> 帮我点个 ⭐ Star这将是对我极大的鼓励与支持。
如果觉得项目不错,或者已经在使用了,希望你可以去 <a target="_blank" href="https://github.com/dromara/mayfly-go">Github</a> 或者 <a target="_blank" href="https://gitee.com/dromara/mayfly-go">Gitee</a> 帮我点个 ⭐ Star这将是对我极大的鼓励与支持。

View File

@@ -15,7 +15,7 @@
"countup.js": "^2.7.0",
"cropperjs": "^1.5.11",
"echarts": "^5.4.0",
"element-plus": "^2.4.0",
"element-plus": "^2.4.1",
"jsencrypt": "^3.3.1",
"lodash": "^4.17.21",
"mitt": "^3.0.1",

View File

@@ -1,7 +1,7 @@
import router from '../router';
import Axios from 'axios';
import config from './config';
import { getClientUuid, getToken, joinClientParams } from './utils/storage';
import { getClientId, getToken } from './utils/storage';
import { templateResolve } from './utils/string';
import { ElMessage } from 'element-plus';
@@ -54,7 +54,7 @@ service.interceptors.request.use(
if (token) {
// 设置token
config.headers['Authorization'] = token;
config.headers['Client-Uuid'] = getClientUuid();
config.headers['ClientId'] = getClientId();
}
return config;
},
@@ -180,6 +180,11 @@ function getApiUrl(url: string) {
return baseUrl + url + '?' + joinClientParams();
}
// 组装客户端参数,包括 token 和 clientId
export function joinClientParams(): string {
return `token=${getToken()}&clientId=${getClientId()}`;
}
export default {
request,
get,

View File

@@ -1,76 +0,0 @@
import Config from './config';
import { ElNotification, NotificationHandle } from 'element-plus';
import SocketBuilder from './SocketBuilder';
import { getToken, joinClientParams } from '@/common/utils/storage';
import { createVNode, reactive } from 'vue';
import { buildProgressProps } from '@/components/progress-notify/progress-notify';
import ProgressNotify from '/src/components/progress-notify/progress-notify.vue';
export default {
/**
* 全局系统消息websocket
*/
sysMsgSocket() {
const token = getToken();
if (!token) {
return null;
}
const messageTypes = {
0: 'error',
1: 'success',
2: 'info',
};
const notifyMap: Map<Number, any> = new Map();
const sysMsgUrl = `${Config.baseWsUrl}/sysmsg?${joinClientParams()}`;
return SocketBuilder.builder(sysMsgUrl)
.message((event: { data: string }) => {
const message = JSON.parse(event.data);
const type = messageTypes[message.type];
switch (message.category) {
case 'execSqlFileProgress':
const content = JSON.parse(message.msg);
const id = content.id;
let progress = notifyMap.get(id);
if (content.terminated) {
if (progress != undefined) {
progress.notification?.close();
notifyMap.delete(id);
progress = undefined;
}
return;
}
if (progress == undefined) {
progress = {
props: reactive(buildProgressProps()),
notification: undefined,
};
}
progress.props.progress.title = content.title;
progress.props.progress.executedStatements = content.executedStatements;
if (!notifyMap.has(id)) {
const vNodeMessage = createVNode(ProgressNotify, progress.props, null);
progress.notification = ElNotification({
duration: 0,
title: message.title,
message: vNodeMessage,
type: type,
showClose: false,
});
notifyMap.set(id, progress);
}
break;
default:
ElNotification({
duration: 0,
title: message.title,
message: message.msg,
type: type,
});
break;
}
})
.open((event: any) => console.log(event))
.build();
},
};

View File

@@ -0,0 +1,99 @@
import Config from './config';
import { ElNotification } from 'element-plus';
import SocketBuilder from './SocketBuilder';
import { getToken } from '@/common/utils/storage';
import { joinClientParams } from './request';
class SysSocket {
/**
* socket连接
*/
socket: any;
/**
* key -> 消息类别value -> 消息对应的处理器函数
*/
categoryHandlers: Map<string, any> = new Map();
/**
* 消息类型
*/
messageTypes = {
0: 'error',
1: 'success',
2: 'info',
};
/**
* 初始化全局系统消息websocket
*/
init() {
// 存在则不需要重新建立连接
if (this.socket) {
return;
}
const token = getToken();
if (!token) {
return null;
}
const sysMsgUrl = `${Config.baseWsUrl}/sysmsg?${joinClientParams()}`;
this.socket = SocketBuilder.builder(sysMsgUrl)
.message((event: { data: string }) => {
const message = JSON.parse(event.data);
// 存在消息类别对应的处理器,则进行处理,否则进行默认通知处理
const handler = this.categoryHandlers.get(message.category);
if (handler) {
handler(message);
return;
}
const type = this.getMsgType(message.type);
ElNotification({
duration: 0,
title: message.title,
message: message.msg,
type: type,
});
})
.open((event: any) => console.log(event))
.close(() => {
console.log('close sys socket');
this.socket = null;
})
.build();
}
destory() {
this.socket.close();
this.socket = null;
this.categoryHandlers.clear();
}
/**
* 注册消息处理函数
*
* @param category 消息类别
* @param handlerFunc 消息处理函数
*/
registerMsgHandler(category: any, handlerFunc: any) {
if (this.categoryHandlers.has(category)) {
console.log(`${category}该类别消息处理器已存在...`);
return;
}
if (typeof handlerFunc != 'function') {
throw new Error('message handler需为函数');
}
this.categoryHandlers.set(category, handlerFunc);
}
getMsgType(msgType: any) {
return this.messageTypes[msgType];
}
}
// 全局系统消息websocket;
const sysSocket = new SysSocket();
export default sysSocket;

View File

@@ -1,9 +1,9 @@
import { v1 as uuidv1 } from 'uuid';
import { randomUuid } from './string';
const TokenKey = 'token';
const UserKey = 'user';
const TagViewsKey = 'tagViews';
const ClientUuid = 'clientUuid'
const ClientIdKey = 'clientId';
// 获取请求token
export function getToken(): string {
@@ -52,18 +52,13 @@ export function removeTagViews() {
}
// 获取客户端UUID
export function getClientUuid(): string {
let uuid = getSession(ClientUuid)
export function getClientId(): string {
let uuid = getSession(ClientIdKey);
if (uuid == null) {
uuid = uuidv1()
setSession(ClientUuid, uuid)
uuid = randomUuid();
setSession(ClientIdKey, uuid);
}
return uuid
}
// 组装客户端参数,包括 token 和 clientUuid
export function joinClientParams(): string {
return `token=${getToken()}&clientUuid=${getClientUuid()}`
return uuid;
}
// 1. localStorage

View File

@@ -1,3 +1,5 @@
import { v1 as uuidv1 } from 'uuid';
/**
* 模板字符串解析template = 'hahaha{name}_{id}' ,param = {name: 'hh', id: 1}
* 解析后为 hahahahh_1
@@ -129,3 +131,11 @@ export function getContentWidth(content: any): number {
// }
return flexWidth;
}
/**
*
* @returns uuid
*/
export function randomUuid() {
return uuidv1();
}

View File

@@ -6,7 +6,7 @@ import { templateResolve } from '@/common/utils/string';
import { NextLoading } from '@/common/utils/loading';
import { dynamicRoutes, staticRoutes, pathMatch } from './route';
import openApi from '@/common/openApi';
import sockets from '@/common/sockets';
import syssocket from '@/common/syssocket';
import pinia from '@/store/index';
import { useThemeConfig } from '@/store/themeConfig';
import { useUserInfo } from '@/store/userInfo';
@@ -179,7 +179,6 @@ export async function initRouter() {
}
}
let SysWs: any;
let loadRouter = false;
// 路由加载前
@@ -204,10 +203,7 @@ router.beforeEach(async (to, from, next) => {
resetRoute();
NProgress.done();
if (SysWs) {
SysWs.close();
SysWs = undefined;
}
syssocket.destory();
return;
}
if (token && to.path === '/login') {
@@ -217,9 +213,10 @@ router.beforeEach(async (to, from, next) => {
}
// 终端不需要连接系统websocket消息
if (!SysWs && to.path != '/machine/terminal') {
SysWs = sockets.sysMsgSocket();
if (to.path != '/machine/terminal') {
syssocket.init();
}
// 不存在路由避免刷新页面找不到路由并且未加载过避免token过期导致获取权限接口报权限不足无限获取则重新初始化路由
if (useRoutesList().routesList.length == 0 && !loadRouter) {
await initRouter();

View File

@@ -171,6 +171,8 @@ const changeDatabase = () => {
};
const getAllDatabase = async () => {
// 清空数据库列表,可能已经有选择库了
state.databaseList = [];
if (state.form.instanceId > 0) {
state.allDatabases = await dbApi.getAllDatabase.request({ instanceId: state.form.instanceId });
}

View File

@@ -172,7 +172,7 @@ import { ref, toRefs, reactive, onMounted, defineAsyncComponent } from 'vue';
import { ElMessage, ElMessageBox } from 'element-plus';
import { dbApi } from './api';
import config from '@/common/config';
import { joinClientParams } from '@/common/utils/storage';
import { joinClientParams } from '@/common/request';
import { isTrue } from '@/common/assert';
import { Search as SearchIcon } from '@element-plus/icons-vue';
import { dateFormat } from '@/common/utils/date';
@@ -355,7 +355,9 @@ const deleteDb = async () => {
await dbApi.deleteDb.request({ id: state.selectionData.map((x: any) => x.id).join(',') });
ElMessage.success('删除成功');
search();
} catch (err) {}
} catch (err) {
//
}
};
const onShowSqlExec = async (row: any) => {

View File

@@ -387,6 +387,9 @@ const addQueryTab = async (inst: any, db: string, sqlName: string = '') => {
dbs: inst.dbs,
};
state.tabs.set(label, tab);
// 注册当前sql编辑框提示词
registerDbCompletionItemProvider('sql', tab.dbId, tab.db, tab.params.dbs);
};
const onRemoveTab = (targetName: string) => {

View File

@@ -88,7 +88,7 @@
<script lang="ts" setup>
import { nextTick, watch, onMounted, reactive, toRefs, ref, Ref } from 'vue';
import { getToken, joinClientParams } from '@/common/utils/storage';
import { getToken } from '@/common/utils/storage';
import { isTrue, notBlank } from '@/common/assert';
import { format as sqlFormatter } from 'sql-formatter';
import config from '@/common/config';
@@ -104,6 +104,12 @@ import { dateStrFormat } from '@/common/utils/date';
import { dbApi } from '../../api';
import MonacoEditor from '@/components/monaco/MonacoEditor.vue';
import { joinClientParams } from '@/common/request';
import { createVNode } from 'vue';
import { buildProgressProps } from '@/components/progress-notify/progress-notify';
import ProgressNotify from '@/components/progress-notify/progress-notify.vue';
import { ElNotification } from 'element-plus';
import syssocket from '@/common/syssocket';
const emits = defineEmits(['saveSqlSuccess', 'deleteSqlSuccess']);
@@ -384,7 +390,9 @@ const deleteSql = async () => {
await dbApi.deleteDbSql.request({ id: dbId, db: db, name: sqlName });
ElMessage.success('删除成功');
emits('deleteSqlSuccess', dbId, db);
} catch (err) {}
} catch (err) {
//
}
};
/**
@@ -472,8 +480,45 @@ const exportData = () => {
);
};
/**
* sql文件执行进度通知缓存
*/
const sqlExecNotifyMap: Map<string, any> = new Map();
const beforeUpload = (file: File) => {
ElMessage.success(`'${file.name}' 正在上传执行, 请关注结果通知`);
syssocket.registerMsgHandler('execSqlFileProgress', function (message: any) {
const content = JSON.parse(message.msg);
const id = content.id;
let progress = sqlExecNotifyMap.get(id);
if (content.terminated) {
if (progress != undefined) {
progress.notification?.close();
sqlExecNotifyMap.delete(id);
progress = undefined;
}
return;
}
if (progress == undefined) {
progress = {
props: reactive(buildProgressProps()),
notification: undefined,
};
}
progress.props.progress.title = content.title;
progress.props.progress.executedStatements = content.executedStatements;
if (!sqlExecNotifyMap.has(id)) {
const vNodeMessage = createVNode(ProgressNotify, progress.props, null);
progress.notification = ElNotification({
duration: 0,
title: message.title,
message: vNodeMessage,
type: syssocket.getMsgType(message.type),
showClose: false,
});
sqlExecNotifyMap.set(id, progress);
}
});
};
// 执行sql成功

View File

@@ -345,7 +345,6 @@ export class DbInst {
// 获取该列中最长的数据(内容)
let maxWidthText = '';
let maxWidthValue;
// 获取该列中最长的数据(内容)
for (let i = 0; i < tableData.length; i++) {
let nowValue = tableData[i][prop];
@@ -356,7 +355,6 @@ export class DbInst {
let nowText = nowValue + '';
if (nowText.length > maxWidthText.length) {
maxWidthText = nowText;
maxWidthValue = nowValue;
}
}
const contentWidth: number = getTextWidth(maxWidthText) + 15;

View File

@@ -118,13 +118,13 @@
</template>
<script lang="ts" setup>
import { toRefs, reactive, watch, computed, onMounted, defineAsyncComponent, nextTick } from 'vue';
import { toRefs, reactive, watch, computed, onMounted, defineAsyncComponent } from 'vue';
import { ElMessageBox } from 'element-plus';
import { formatByteSize } from '@/common/utils/format';
import { dbApi } from '../api';
import SqlExecBox from '../component/SqlExecBox';
import config from '@/common/config';
import { joinClientParams } from '@/common/utils/storage';
import { joinClientParams } from '@/common/request';
import { isTrue } from '@/common/assert';
const DbTableEdit = defineAsyncComponent(() => import('./DbTableEdit.vue'));
@@ -209,7 +209,7 @@ onMounted(async () => {
getTables();
});
watch(props, async (newValue: any) => {
watch(props, async () => {
await getTables();
});
@@ -239,6 +239,7 @@ const getTables = async () => {
state.tables = [];
state.tables = await dbApi.tableInfos.request({ id: props.dbId, db: props.db });
} catch (e) {
//
} finally {
state.loading = false;
}
@@ -317,7 +318,9 @@ const dropTable = async (row: any) => {
state.tables = await dbApi.tableInfos.request({ id: props.dbId, db: props.db });
},
});
} catch (err) {}
} catch (err) {
//
}
};
// 打开编辑表

View File

@@ -1,6 +1,6 @@
import Api from '@/common/Api';
import config from '@/common/config';
import { joinClientParams } from '@/common/utils/storage';
import { joinClientParams } from '@/common/request';
export const machineApi = {
// 获取权限列表

View File

@@ -274,11 +274,12 @@ import { ref, toRefs, reactive, onMounted, computed } from 'vue';
import { ElMessage, ElMessageBox, ElInput } from 'element-plus';
import { machineApi } from '../api';
import { joinClientParams } from '@/common/utils/storage';
import { joinClientParams } from '@/common/request';
import config from '@/common/config';
import { isTrue } from '@/common/assert';
import MachineFileContent from './MachineFileContent.vue';
import { notBlank } from '@/common/assert';
import { getToken } from '@/common/utils/storage';
const props = defineProps({
machineId: { type: Number },

View File

@@ -179,7 +179,7 @@ import { dateFormat } from '@/common/utils/date';
import { storeToRefs } from 'pinia';
import { useUserInfo } from '@/store/userInfo';
import config from '@/common/config';
import { joinClientParams } from '@/common/utils/storage';
import { joinClientParams } from '@/common/request';
const { userInfo } = storeToRefs(useUserInfo());
const state = reactive({

View File

@@ -773,10 +773,10 @@ echarts@^5.4.0:
tslib "2.3.0"
zrender "5.4.0"
element-plus@^2.4.0:
version "2.4.0"
resolved "https://registry.npmmirror.com/element-plus/-/element-plus-2.4.0.tgz#e79249ac4c0a606d377c2f31ad553aa992286fe3"
integrity sha512-yJEa8LXkGOOgkfkeqMMEdeX/Dc8EH9qPcRuX91dlhSXxgCKKbp9tH3QFTOG99ibZsrN/Em62nh7ddvbc7I1frw==
element-plus@^2.4.1:
version "2.4.1"
resolved "https://registry.npmmirror.com/element-plus/-/element-plus-2.4.1.tgz#8a5faa69e856d82494b94d77fb485d9e727c8bc1"
integrity sha512-t7nl+vQlkBKVk1Ag6AufSDyFV8YIXxTFsaya4Nz/0tiRlcz65WPN4WMFeNURuFJleu1HLNtP4YyQKMuS7El8uA==
dependencies:
"@ctrl/tinycolor" "^3.4.1"
"@element-plus/icons-vue" "^2.0.6"

View File

@@ -119,5 +119,5 @@ func (a *AccountLogin) OtpVerify(rc *req.Ctx) {
func (a *AccountLogin) Logout(rc *req.Ctx) {
req.GetPermissionCodeRegistery().Remove(rc.LoginAccount.Id)
ws.CloseClient(rc.LoginAccount.ClientUuid)
ws.CloseClient(ws.UserId(rc.LoginAccount.Id))
}

View File

@@ -2,8 +2,6 @@ package api
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/kanzihuang/vitess/go/vt/sqlparser"
"io"
"mayfly-go/internal/db/api/form"
"mayfly-go/internal/db/api/vo"
@@ -15,15 +13,18 @@ import (
"mayfly-go/pkg/biz"
"mayfly-go/pkg/ginx"
"mayfly-go/pkg/gormx"
"mayfly-go/pkg/logx"
"mayfly-go/pkg/model"
"mayfly-go/pkg/req"
"mayfly-go/pkg/utils/collx"
"mayfly-go/pkg/utils/stringx"
"mayfly-go/pkg/utils/uniqueid"
"mayfly-go/pkg/ws"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/kanzihuang/vitess/go/vt/sqlparser"
)
type Db struct {
@@ -126,29 +127,7 @@ func (d *Db) ExecSql(rc *req.Ctx) {
isMulti := len(sqls) > 1
var execResAll *application.DbSqlExecRes
progressId := uniqueid.IncrementID()
executedStatements := 0
progressTitle := fmt.Sprintf("%s/%s", dbConn.Info.Name, dbConn.Info.Database)
defer ws.SendJsonMsg(rc.LoginAccount.ClientUuid, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
Id: progressId,
Title: progressTitle,
ExecutedStatements: executedStatements,
Terminated: true,
}).WithCategory(progressCategory))
ticker := time.NewTicker(time.Second * 1)
defer ticker.Stop()
for _, s := range sqls {
select {
case <-ticker.C:
ws.SendJsonMsg(rc.LoginAccount.ClientUuid, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
Id: progressId,
Title: progressTitle,
ExecutedStatements: executedStatements,
Terminated: false,
}).WithCategory(progressCategory))
default:
}
executedStatements++
s = stringx.TrimSpaceAndBr(s)
// 多条执行,如果有查询语句,则跳过
if isMulti && strings.HasPrefix(strings.ToLower(s), "select") {
@@ -178,7 +157,7 @@ const progressCategory = "execSqlFileProgress"
// progressMsg sql文件执行进度消息
type progressMsg struct {
Id uint64 `json:"id"`
Id string `json:"id"`
Title string `json:"title"`
ExecutedStatements int `json:"executedStatements"`
Terminated bool `json:"terminated"`
@@ -195,6 +174,7 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
filename := file.FileName()
dbId := getDbId(g)
dbName := getDbName(g)
clientId := g.Query("clientId")
dbConn := d.DbApp.GetDbConnection(dbId, dbName)
biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s")
@@ -209,7 +189,7 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
errInfo = t
}
if len(errInfo) > 0 {
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s]执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo)))
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s]执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo)).WithClientId(clientId))
}
}()
@@ -226,9 +206,9 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
tokenizer := sqlparser.NewReaderTokenizer(file,
sqlparser.WithCacheInBuffer(), sqlparser.WithDialect(dbConn.Info.Type.Dialect()))
progressId := uniqueid.IncrementID()
executedStatements := 0
defer ws.SendJsonMsg(rc.LoginAccount.ClientUuid, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
progressId := stringx.Rand(32)
defer ws.SendJsonMsg(ws.UserId(rc.LoginAccount.Id), clientId, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
Id: progressId,
Title: filename,
ExecutedStatements: executedStatements,
@@ -239,7 +219,7 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
for {
select {
case <-ticker.C:
ws.SendJsonMsg(rc.LoginAccount.ClientUuid, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
ws.SendJsonMsg(ws.UserId(rc.LoginAccount.Id), clientId, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
Id: progressId,
Title: filename,
ExecutedStatements: executedStatements,
@@ -252,21 +232,19 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
if err == io.EOF {
break
}
if err != nil {
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本解析失败", fmt.Sprintf("[%s][%s] 解析SQL失败: [%s]", filename, dbConn.Info.GetLogDesc(), err.Error())))
return
}
biz.ErrIsNilAppendErr(err, "%s")
const prefixUse = "use "
const prefixUSE = "USE "
if strings.HasPrefix(sql, prefixUSE) || strings.HasPrefix(sql, prefixUse) {
var stmt sqlparser.Statement
stmt, err = sqlparser.Parse(sql)
if err != nil {
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本解析失败", fmt.Sprintf("[%s][%s] 解析SQL失败: [%s]", filename, dbConn.Info.GetLogDesc(), err.Error())))
}
biz.ErrIsNilAppendErr(err, "%s")
stmtUse, ok := stmt.(*sqlparser.Use)
// 最终执行结果以数据库直接结果为准
if !ok {
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本解析失败", fmt.Sprintf("[%s][%s] 解析SQL失败: [%s]", filename, dbConn.Info.GetLogDesc(), sql)))
logx.Warnf("sql解析失败: %s", sql)
}
dbConn = d.DbApp.GetDbConnection(dbId, stmtUse.DBName.String())
biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s")
@@ -281,12 +259,9 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
_, err = dbConn.Exec(sql)
}
if err != nil {
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s] -> sql=[%s] 执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), sql, err.Error())))
return
biz.ErrIsNilAppendErr(err, "%s")
}
}
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.SuccessSysMsg("sql脚本执行成功", fmt.Sprintf("sql脚本执行完成%s", rc.ReqParam)))
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.SuccessSysMsg("sql脚本执行成功", fmt.Sprintf("sql脚本执行完成%s", rc.ReqParam)).WithClientId(clientId))
}
// 数据库dump

View File

@@ -14,6 +14,8 @@ type SysMsg struct {
Category string `json:"category"` // 消息类别
Title string `json:"title"` // 消息标题
Msg string `json:"msg"` // 消息内容
ClientId string
}
func (sm *SysMsg) WithTitle(title string) *SysMsg {
@@ -31,6 +33,11 @@ func (sm *SysMsg) WithMsg(msg any) *SysMsg {
return sm
}
func (sm *SysMsg) WithClientId(clientId string) *SysMsg {
sm.ClientId = clientId
return sm
}
// 普通消息
func InfoSysMsg(title string, msg any) *SysMsg {
return &SysMsg{Type: InfoSysMsgType, Title: title, Msg: stringx.AnyToStr(msg)}

View File

@@ -40,5 +40,5 @@ func (a *msgAppImpl) CreateAndSend(la *model.LoginAccount, wmsg *dto.SysMsg) {
now := time.Now()
msg := &entity.Msg{Type: 2, Msg: wmsg.Msg, RecipientId: int64(la.Id), CreateTime: &now, CreatorId: la.Id, Creator: la.Username}
a.msgRepo.Insert(msg)
ws.SendJsonMsg(la.ClientUuid, wmsg)
ws.SendJsonMsg(ws.UserId(la.Id), wmsg.ClientId, wmsg)
}

View File

@@ -26,18 +26,19 @@ func (s *System) ConnectWs(g *gin.Context) {
}
}()
if err != nil {
panic(biz.NewBizErr("升级websocket失败"))
}
biz.ErrIsNilAppendErr(err, "%s")
clientId := g.Query("clientId")
biz.NotEmpty(clientId, "clientId不能为空")
// 权限校验
rc := req.NewCtxWithGin(g)
if err = req.PermissionHandler(rc); err != nil {
panic(biz.NewBizErr("没有权限"))
panic("sys ws连接没有权限")
}
// 登录账号信息
la := rc.LoginAccount
if la != nil {
ws.AddClient(la.Id, la.ClientUuid, wsConn)
ws.AddClient(ws.UserId(la.Id), clientId, wsConn)
}
}

View File

@@ -3,7 +3,4 @@ package model
type LoginAccount struct {
Id uint64
Username string
// ClientUuid 客户端UUID
ClientUuid string
}

View File

@@ -60,16 +60,10 @@ func PermissionHandler(rc *Ctx) error {
return biz.PermissionErr
}
}
clientUuid := rc.GinCtx.Request.Header.Get("Client-Uuid")
// header不存在则从查询参数token中获取
if clientUuid == "" {
clientUuid = rc.GinCtx.Query("clientUuid")
}
if rc.LoginAccount == nil {
rc.LoginAccount = &model.LoginAccount{
Id: userId,
Username: userName,
ClientUuid: clientUuid,
}
}
return nil

View File

@@ -114,3 +114,14 @@ func ArrayReduce[T any, V any](arr []T, initialValue V, reducer func(V, T) V) V
}
return value
}
// 数组元素移除操作
func ArrayRemoveFunc[T any](arr []T, isDeleteFunc func(T) bool) []T {
var newArr []T
for _, a := range arr {
if !isDeleteFunc(a) {
newArr = append(newArr, a)
}
}
return newArr
}

View File

@@ -1,9 +0,0 @@
package uniqueid
import "sync/atomic"
var id uint64 = 0
func IncrementID() uint64 {
return atomic.AddUint64(&id, 1)
}

View File

@@ -4,7 +4,6 @@ import (
"encoding/json"
"errors"
"mayfly-go/pkg/logx"
"mayfly-go/pkg/utils/stringx"
"time"
"github.com/gorilla/websocket"
@@ -19,17 +18,15 @@ type ReadMsgHandlerFunc func([]byte)
type Client struct {
ClientId string // 标识ID
UserId UserId // 用户ID
ClientUuid string // 客户端UUID
WsConn *websocket.Conn // 用户连接
ReadMsgHander ReadMsgHandlerFunc // 读取消息处理函数
ReadMsgHandler ReadMsgHandlerFunc // 读取消息处理函数
}
func NewClient(userId UserId, clientUuid string, socket *websocket.Conn) *Client {
func NewClient(userId UserId, clientId string, socket *websocket.Conn) *Client {
cli := &Client{
ClientId: stringx.Rand(16),
ClientId: clientId,
UserId: userId,
ClientUuid: clientUuid,
WsConn: socket,
}
@@ -37,7 +34,7 @@ func NewClient(userId UserId, clientUuid string, socket *websocket.Conn) *Client
}
func (c *Client) WithReadHandlerFunc(readMsgHandlerFunc ReadMsgHandlerFunc) *Client {
c.ReadMsgHander = readMsgHandlerFunc
c.ReadMsgHandler = readMsgHandlerFunc
return c
}
@@ -58,8 +55,8 @@ func (c *Client) Read() {
return
}
}
if c.ReadMsgHander != nil {
c.ReadMsgHander(data)
if c.ReadMsgHandler != nil {
c.ReadMsgHandler(data)
}
}
}()
@@ -67,7 +64,7 @@ func (c *Client) Read() {
// 向客户端写入消息
func (c *Client) WriteMsg(msg *Msg) error {
logx.Debugf("发送消息: toUid=%v, data=%v", c.UserId, msg.Data)
logx.Debugf("发送消息: toUserId=%v, toClientId=%s, data=%v", c.UserId, c.ClientId, msg.Data)
if msg.Type == JsonMsg {
bytes, _ := json.Marshal(msg.Data)

View File

@@ -9,9 +9,28 @@ import (
// 心跳间隔
const heartbeatInterval = 25 * time.Second
// 单个用户的全部的连接, key->clientId, value->Client
type UserClients map[string]*Client
func (ucs UserClients) GetByCid(clientId string) *Client {
return ucs[clientId]
}
func (ucs UserClients) AddClient(client *Client) {
ucs[client.ClientId] = client
}
func (ucs UserClients) DeleteByCid(clientId string) {
delete(ucs, clientId)
}
func (ucs UserClients) Count() int {
return len(ucs)
}
// 连接管理
type ClientManager struct {
ClientMap map[string]*Client // 全部的连接, key->token, value->&client
UserClientsMap map[UserId]UserClients // 全部的用户连接, key->userid, value->UserClients
RwLock sync.RWMutex // 读写锁
ConnectChan chan *Client // 连接处理
@@ -21,7 +40,7 @@ type ClientManager struct {
func NewClientManager() (clientManager *ClientManager) {
return &ClientManager{
ClientMap: make(map[string]*Client),
UserClientsMap: make(map[UserId]UserClients),
ConnectChan: make(chan *Client, 10),
DisConnectChan: make(chan *Client, 10),
MsgChan: make(chan *Msg, 100),
@@ -58,47 +77,45 @@ func (manager *ClientManager) CloseClient(client *Client) {
}
// 根据用户id关闭客户端连接
func (manager *ClientManager) CloseByClientUuid(clientUuid string) {
manager.CloseClient(manager.GetByClientUuid(clientUuid))
func (manager *ClientManager) CloseByUid(userId UserId) {
for _, client := range manager.GetByUid(userId) {
manager.CloseClient(client)
}
}
// 获取所有的客户端
func (manager *ClientManager) AllClient() map[string]*Client {
func (manager *ClientManager) AllUserClient() map[UserId]UserClients {
manager.RwLock.RLock()
defer manager.RwLock.RUnlock()
return manager.ClientMap
return manager.UserClientsMap
}
// 通过userId获取
func (manager *ClientManager) GetByUid(userId UserId) *Client {
// 通过userId获取用户所有客户端信息
func (manager *ClientManager) GetByUid(userId UserId) UserClients {
manager.RwLock.RLock()
defer manager.RwLock.RUnlock()
for _, client := range manager.ClientMap {
if userId == client.UserId {
return client
return manager.UserClientsMap[userId]
}
// 通过userId和clientId获取客户端信息
func (manager *ClientManager) GetByUidAndCid(uid UserId, clientId string) *Client {
if clients := manager.GetByUid(uid); clients != nil {
return clients.GetByCid(clientId)
}
return nil
}
// 通过userId获取
func (manager *ClientManager) GetByClientUuid(uuid string) *Client {
manager.RwLock.RLock()
defer manager.RwLock.RUnlock()
return manager.ClientMap[uuid]
}
// 客户端数量
func (manager *ClientManager) Count() int {
manager.RwLock.RLock()
defer manager.RwLock.RUnlock()
return len(manager.ClientMap)
return len(manager.UserClientsMap)
}
// 发送json数据给指定用户
func (manager *ClientManager) SendJsonMsg(clientUuid string, data any) {
manager.MsgChan <- &Msg{ToClientUuid: clientUuid, Data: data, Type: JsonMsg}
func (manager *ClientManager) SendJsonMsg(userId UserId, clientId string, data any) {
manager.MsgChan <- &Msg{ToUserId: userId, ToClientId: clientId, Data: data, Type: JsonMsg}
}
// 监听并发送给客户端信息
@@ -106,10 +123,22 @@ func (manager *ClientManager) WriteMessage() {
go func() {
for {
msg := <-manager.MsgChan
if cli := manager.GetByClientUuid(msg.ToClientUuid); cli != nil {
if err := cli.WriteMsg(msg); err != nil {
manager.CloseClient(cli)
uid := msg.ToUserId
cid := msg.ToClientId
// 客户端id不为空则向指定客户端发送消息即可
if cid != "" {
cli := manager.GetByUidAndCid(uid, cid)
if cli != nil {
cli.WriteMsg(msg)
} else {
logx.Warnf("[uid=%v, cid=%s]的ws连接不存在", uid, cid)
}
continue
}
// cid为空则向该用户所有客户端发送该消息
for _, cli := range manager.GetByUid(uid) {
cli.WriteMsg(msg)
}
}
}()
@@ -123,15 +152,17 @@ func (manager *ClientManager) HeartbeatTimer() {
for {
<-ticker.C
//发送心跳
for userId, cli := range manager.AllClient() {
for userId, clis := range manager.AllUserClient() {
for _, cli := range clis {
if cli == nil || cli.WsConn == nil {
continue
}
if err := cli.Ping(); err != nil {
manager.CloseClient(cli)
logx.Debugf("WS发送心跳失败: %v 总连接数:%d", userId, Manager.Count())
logx.Debugf("WS发送心跳失败: uid=%v, cid=%s, usercount=%d", userId, cli.ClientId, Manager.Count())
} else {
logx.Debugf("WS发送心跳成功: uid=%v", userId)
logx.Debugf("WS发送心跳成功: uid=%v, cid=%s", userId, cli.ClientId)
}
}
}
}
@@ -141,12 +172,12 @@ func (manager *ClientManager) HeartbeatTimer() {
// 处理建立连接
func (manager *ClientManager) doConnect(client *Client) {
cli := manager.GetByClientUuid(client.ClientUuid)
cli := manager.GetByUidAndCid(client.UserId, client.ClientId)
if cli != nil {
manager.doDisconnect(cli)
}
manager.addClient2Map(client)
logx.Debugf("WS客户端已连接: uid=%d, count=%d", client.UserId, manager.Count())
manager.addUserClient2Map(client)
logx.Debugf("WS客户端已连接: uid=%d, cid=%s, usercount=%d", client.UserId, client.ClientId, manager.Count())
}
// 处理断开连接
@@ -156,18 +187,32 @@ func (manager *ClientManager) doDisconnect(client *Client) {
_ = client.WsConn.Close()
client.WsConn = nil
}
manager.delClient4Map(client)
logx.Debugf("WS客户端已断开: uid=%d, count=%d", client.UserId, Manager.Count())
manager.delUserClient4Map(client)
logx.Debugf("WS客户端已断开: uid=%d, cid=%s, usercount=%d", client.UserId, client.ClientId, Manager.Count())
}
func (manager *ClientManager) addClient2Map(client *Client) {
func (manager *ClientManager) addUserClient2Map(client *Client) {
manager.RwLock.Lock()
defer manager.RwLock.Unlock()
manager.ClientMap[client.ClientUuid] = client
userClients := manager.UserClientsMap[client.UserId]
if userClients == nil {
userClients = make(UserClients)
manager.UserClientsMap[client.UserId] = userClients
}
userClients.AddClient(client)
}
func (manager *ClientManager) delClient4Map(client *Client) {
func (manager *ClientManager) delUserClient4Map(client *Client) {
manager.RwLock.Lock()
defer manager.RwLock.Unlock()
delete(manager.ClientMap, client.ClientUuid)
userClients := manager.UserClientsMap[client.UserId]
if userClients != nil {
userClients.DeleteByCid(client.ClientId)
// 如果用户所有客户端都关闭则移除manager中的UserClientsMap值
if userClients.Count() == 0 {
delete(manager.UserClientsMap, client.UserId)
}
}
}

View File

@@ -11,9 +11,9 @@ const (
// 消息信息
type Msg struct {
ToUserId UserId
Data any
ToUserId UserId // 用户id
ToClientId string // 客户端id
Type MsgType // 消息类型
ToClientUuid string
Data any
}

View File

@@ -21,21 +21,21 @@ func init() {
}
// 添加ws客户端
func AddClient(userId uint64, clientUuid string, conn *websocket.Conn) *Client {
if len(clientUuid) == 0 {
func AddClient(userId UserId, clientId string, conn *websocket.Conn) *Client {
if len(clientId) == 0 {
return nil
}
cli := NewClient(UserId(userId), clientUuid, conn)
cli := NewClient(UserId(userId), clientId, conn)
cli.Read()
Manager.AddClient(cli)
return cli
}
func CloseClient(clientUuid string) {
Manager.CloseByClientUuid(clientUuid)
func CloseClient(uid UserId) {
Manager.CloseByUid(uid)
}
// 对指定用户发送json类型消息
func SendJsonMsg(clientUuid string, msg any) {
Manager.SendJsonMsg(clientUuid, msg)
func SendJsonMsg(userId UserId, clientId string, msg any) {
Manager.SendJsonMsg(userId, clientId, msg)
}