refactor: websocket支持单用户多连接

This commit is contained in:
meilin.huang
2023-10-19 19:00:23 +08:00
parent 747ea6404d
commit 2b91bbe185
31 changed files with 365 additions and 263 deletions

View File

@@ -1,13 +1,13 @@
# 🌈mayfly-go # 🌈mayfly-go
<p align="center"> <p align="center">
<a href="https://gitee.com/objs/mayfly-go" target="_blank"> <a href="https://gitee.com/dromara/mayfly-go" target="_blank">
<img src="https://gitee.com/objs/mayfly-go/badge/star.svg?theme=white" alt="star"/> <img src="https://gitee.com/dromara/mayfly-go/badge/star.svg?theme=white" alt="star"/>
<img src="https://gitee.com/objs/mayfly-go/badge/fork.svg" alt="fork"/> <img src="https://gitee.com/dromara/mayfly-go/badge/fork.svg" alt="fork"/>
</a> </a>
<a href="https://github.com/may-fly/mayfly-go" target="_blank"> <a href="https://github.com/dromara/mayfly-go" target="_blank">
<img src="https://img.shields.io/github/stars/may-fly/mayfly-go.svg?style=social" alt="github star"/> <img src="https://img.shields.io/github/stars/dromara/mayfly-go.svg?style=social" alt="github star"/>
<img src="https://img.shields.io/github/forks/may-fly/mayfly-go.svg?style=social" alt="github fork"/> <img src="https://img.shields.io/github/forks/dromara/mayfly-go.svg?style=social" alt="github fork"/>
</a> </a>
<a href="https://hub.docker.com/r/mayflygo/mayfly-go/tags" target="_blank"> <a href="https://hub.docker.com/r/mayflygo/mayfly-go/tags" target="_blank">
<img src="https://img.shields.io/docker/pulls/mayflygo/mayfly-go.svg?label=docker%20pulls&color=fac858" alt="docker pulls"/> <img src="https://img.shields.io/docker/pulls/mayflygo/mayfly-go.svg?label=docker%20pulls&color=fac858" alt="docker pulls"/>
@@ -100,4 +100,4 @@ http://go.mayfly.run
#### 💌 支持作者 #### 💌 支持作者
如果觉得项目不错,或者已经在使用了,希望你可以去 <a target="_blank" href="https://github.com/may-fly/mayfly-go">Github</a> 或者 <a target="_blank" href="https://gitee.com/objs/mayfly-go">Gitee</a> 帮我点个 ⭐ Star这将是对我极大的鼓励与支持。 如果觉得项目不错,或者已经在使用了,希望你可以去 <a target="_blank" href="https://github.com/dromara/mayfly-go">Github</a> 或者 <a target="_blank" href="https://gitee.com/dromara/mayfly-go">Gitee</a> 帮我点个 ⭐ Star这将是对我极大的鼓励与支持。

View File

@@ -15,7 +15,7 @@
"countup.js": "^2.7.0", "countup.js": "^2.7.0",
"cropperjs": "^1.5.11", "cropperjs": "^1.5.11",
"echarts": "^5.4.0", "echarts": "^5.4.0",
"element-plus": "^2.4.0", "element-plus": "^2.4.1",
"jsencrypt": "^3.3.1", "jsencrypt": "^3.3.1",
"lodash": "^4.17.21", "lodash": "^4.17.21",
"mitt": "^3.0.1", "mitt": "^3.0.1",

View File

@@ -1,7 +1,7 @@
import router from '../router'; import router from '../router';
import Axios from 'axios'; import Axios from 'axios';
import config from './config'; import config from './config';
import { getClientUuid, getToken, joinClientParams } from './utils/storage'; import { getClientId, getToken } from './utils/storage';
import { templateResolve } from './utils/string'; import { templateResolve } from './utils/string';
import { ElMessage } from 'element-plus'; import { ElMessage } from 'element-plus';
@@ -54,7 +54,7 @@ service.interceptors.request.use(
if (token) { if (token) {
// 设置token // 设置token
config.headers['Authorization'] = token; config.headers['Authorization'] = token;
config.headers['Client-Uuid'] = getClientUuid(); config.headers['ClientId'] = getClientId();
} }
return config; return config;
}, },
@@ -180,6 +180,11 @@ function getApiUrl(url: string) {
return baseUrl + url + '?' + joinClientParams(); return baseUrl + url + '?' + joinClientParams();
} }
// 组装客户端参数,包括 token 和 clientId
export function joinClientParams(): string {
return `token=${getToken()}&clientId=${getClientId()}`;
}
export default { export default {
request, request,
get, get,

View File

@@ -1,76 +0,0 @@
import Config from './config';
import { ElNotification, NotificationHandle } from 'element-plus';
import SocketBuilder from './SocketBuilder';
import { getToken, joinClientParams } from '@/common/utils/storage';
import { createVNode, reactive } from 'vue';
import { buildProgressProps } from '@/components/progress-notify/progress-notify';
import ProgressNotify from '/src/components/progress-notify/progress-notify.vue';
export default {
/**
* 全局系统消息websocket
*/
sysMsgSocket() {
const token = getToken();
if (!token) {
return null;
}
const messageTypes = {
0: 'error',
1: 'success',
2: 'info',
};
const notifyMap: Map<Number, any> = new Map();
const sysMsgUrl = `${Config.baseWsUrl}/sysmsg?${joinClientParams()}`;
return SocketBuilder.builder(sysMsgUrl)
.message((event: { data: string }) => {
const message = JSON.parse(event.data);
const type = messageTypes[message.type];
switch (message.category) {
case 'execSqlFileProgress':
const content = JSON.parse(message.msg);
const id = content.id;
let progress = notifyMap.get(id);
if (content.terminated) {
if (progress != undefined) {
progress.notification?.close();
notifyMap.delete(id);
progress = undefined;
}
return;
}
if (progress == undefined) {
progress = {
props: reactive(buildProgressProps()),
notification: undefined,
};
}
progress.props.progress.title = content.title;
progress.props.progress.executedStatements = content.executedStatements;
if (!notifyMap.has(id)) {
const vNodeMessage = createVNode(ProgressNotify, progress.props, null);
progress.notification = ElNotification({
duration: 0,
title: message.title,
message: vNodeMessage,
type: type,
showClose: false,
});
notifyMap.set(id, progress);
}
break;
default:
ElNotification({
duration: 0,
title: message.title,
message: message.msg,
type: type,
});
break;
}
})
.open((event: any) => console.log(event))
.build();
},
};

View File

@@ -0,0 +1,99 @@
import Config from './config';
import { ElNotification } from 'element-plus';
import SocketBuilder from './SocketBuilder';
import { getToken } from '@/common/utils/storage';
import { joinClientParams } from './request';
class SysSocket {
/**
* socket连接
*/
socket: any;
/**
* key -> 消息类别value -> 消息对应的处理器函数
*/
categoryHandlers: Map<string, any> = new Map();
/**
* 消息类型
*/
messageTypes = {
0: 'error',
1: 'success',
2: 'info',
};
/**
* 初始化全局系统消息websocket
*/
init() {
// 存在则不需要重新建立连接
if (this.socket) {
return;
}
const token = getToken();
if (!token) {
return null;
}
const sysMsgUrl = `${Config.baseWsUrl}/sysmsg?${joinClientParams()}`;
this.socket = SocketBuilder.builder(sysMsgUrl)
.message((event: { data: string }) => {
const message = JSON.parse(event.data);
// 存在消息类别对应的处理器,则进行处理,否则进行默认通知处理
const handler = this.categoryHandlers.get(message.category);
if (handler) {
handler(message);
return;
}
const type = this.getMsgType(message.type);
ElNotification({
duration: 0,
title: message.title,
message: message.msg,
type: type,
});
})
.open((event: any) => console.log(event))
.close(() => {
console.log('close sys socket');
this.socket = null;
})
.build();
}
destory() {
this.socket.close();
this.socket = null;
this.categoryHandlers.clear();
}
/**
* 注册消息处理函数
*
* @param category 消息类别
* @param handlerFunc 消息处理函数
*/
registerMsgHandler(category: any, handlerFunc: any) {
if (this.categoryHandlers.has(category)) {
console.log(`${category}该类别消息处理器已存在...`);
return;
}
if (typeof handlerFunc != 'function') {
throw new Error('message handler需为函数');
}
this.categoryHandlers.set(category, handlerFunc);
}
getMsgType(msgType: any) {
return this.messageTypes[msgType];
}
}
// 全局系统消息websocket;
const sysSocket = new SysSocket();
export default sysSocket;

View File

@@ -1,9 +1,9 @@
import { v1 as uuidv1 } from 'uuid'; import { randomUuid } from './string';
const TokenKey = 'token'; const TokenKey = 'token';
const UserKey = 'user'; const UserKey = 'user';
const TagViewsKey = 'tagViews'; const TagViewsKey = 'tagViews';
const ClientUuid = 'clientUuid' const ClientIdKey = 'clientId';
// 获取请求token // 获取请求token
export function getToken(): string { export function getToken(): string {
@@ -52,18 +52,13 @@ export function removeTagViews() {
} }
// 获取客户端UUID // 获取客户端UUID
export function getClientUuid(): string { export function getClientId(): string {
let uuid = getSession(ClientUuid) let uuid = getSession(ClientIdKey);
if (uuid == null) { if (uuid == null) {
uuid = uuidv1() uuid = randomUuid();
setSession(ClientUuid, uuid) setSession(ClientIdKey, uuid);
} }
return uuid return uuid;
}
// 组装客户端参数,包括 token 和 clientUuid
export function joinClientParams(): string {
return `token=${getToken()}&clientUuid=${getClientUuid()}`
} }
// 1. localStorage // 1. localStorage

View File

@@ -1,3 +1,5 @@
import { v1 as uuidv1 } from 'uuid';
/** /**
* 模板字符串解析template = 'hahaha{name}_{id}' ,param = {name: 'hh', id: 1} * 模板字符串解析template = 'hahaha{name}_{id}' ,param = {name: 'hh', id: 1}
* 解析后为 hahahahh_1 * 解析后为 hahahahh_1
@@ -129,3 +131,11 @@ export function getContentWidth(content: any): number {
// } // }
return flexWidth; return flexWidth;
} }
/**
*
* @returns uuid
*/
export function randomUuid() {
return uuidv1();
}

View File

@@ -6,7 +6,7 @@ import { templateResolve } from '@/common/utils/string';
import { NextLoading } from '@/common/utils/loading'; import { NextLoading } from '@/common/utils/loading';
import { dynamicRoutes, staticRoutes, pathMatch } from './route'; import { dynamicRoutes, staticRoutes, pathMatch } from './route';
import openApi from '@/common/openApi'; import openApi from '@/common/openApi';
import sockets from '@/common/sockets'; import syssocket from '@/common/syssocket';
import pinia from '@/store/index'; import pinia from '@/store/index';
import { useThemeConfig } from '@/store/themeConfig'; import { useThemeConfig } from '@/store/themeConfig';
import { useUserInfo } from '@/store/userInfo'; import { useUserInfo } from '@/store/userInfo';
@@ -179,7 +179,6 @@ export async function initRouter() {
} }
} }
let SysWs: any;
let loadRouter = false; let loadRouter = false;
// 路由加载前 // 路由加载前
@@ -204,10 +203,7 @@ router.beforeEach(async (to, from, next) => {
resetRoute(); resetRoute();
NProgress.done(); NProgress.done();
if (SysWs) { syssocket.destory();
SysWs.close();
SysWs = undefined;
}
return; return;
} }
if (token && to.path === '/login') { if (token && to.path === '/login') {
@@ -217,9 +213,10 @@ router.beforeEach(async (to, from, next) => {
} }
// 终端不需要连接系统websocket消息 // 终端不需要连接系统websocket消息
if (!SysWs && to.path != '/machine/terminal') { if (to.path != '/machine/terminal') {
SysWs = sockets.sysMsgSocket(); syssocket.init();
} }
// 不存在路由避免刷新页面找不到路由并且未加载过避免token过期导致获取权限接口报权限不足无限获取则重新初始化路由 // 不存在路由避免刷新页面找不到路由并且未加载过避免token过期导致获取权限接口报权限不足无限获取则重新初始化路由
if (useRoutesList().routesList.length == 0 && !loadRouter) { if (useRoutesList().routesList.length == 0 && !loadRouter) {
await initRouter(); await initRouter();

View File

@@ -171,6 +171,8 @@ const changeDatabase = () => {
}; };
const getAllDatabase = async () => { const getAllDatabase = async () => {
// 清空数据库列表,可能已经有选择库了
state.databaseList = [];
if (state.form.instanceId > 0) { if (state.form.instanceId > 0) {
state.allDatabases = await dbApi.getAllDatabase.request({ instanceId: state.form.instanceId }); state.allDatabases = await dbApi.getAllDatabase.request({ instanceId: state.form.instanceId });
} }

View File

@@ -172,7 +172,7 @@ import { ref, toRefs, reactive, onMounted, defineAsyncComponent } from 'vue';
import { ElMessage, ElMessageBox } from 'element-plus'; import { ElMessage, ElMessageBox } from 'element-plus';
import { dbApi } from './api'; import { dbApi } from './api';
import config from '@/common/config'; import config from '@/common/config';
import { joinClientParams } from '@/common/utils/storage'; import { joinClientParams } from '@/common/request';
import { isTrue } from '@/common/assert'; import { isTrue } from '@/common/assert';
import { Search as SearchIcon } from '@element-plus/icons-vue'; import { Search as SearchIcon } from '@element-plus/icons-vue';
import { dateFormat } from '@/common/utils/date'; import { dateFormat } from '@/common/utils/date';
@@ -355,7 +355,9 @@ const deleteDb = async () => {
await dbApi.deleteDb.request({ id: state.selectionData.map((x: any) => x.id).join(',') }); await dbApi.deleteDb.request({ id: state.selectionData.map((x: any) => x.id).join(',') });
ElMessage.success('删除成功'); ElMessage.success('删除成功');
search(); search();
} catch (err) {} } catch (err) {
//
}
}; };
const onShowSqlExec = async (row: any) => { const onShowSqlExec = async (row: any) => {

View File

@@ -387,6 +387,9 @@ const addQueryTab = async (inst: any, db: string, sqlName: string = '') => {
dbs: inst.dbs, dbs: inst.dbs,
}; };
state.tabs.set(label, tab); state.tabs.set(label, tab);
// 注册当前sql编辑框提示词
registerDbCompletionItemProvider('sql', tab.dbId, tab.db, tab.params.dbs);
}; };
const onRemoveTab = (targetName: string) => { const onRemoveTab = (targetName: string) => {

View File

@@ -88,7 +88,7 @@
<script lang="ts" setup> <script lang="ts" setup>
import { nextTick, watch, onMounted, reactive, toRefs, ref, Ref } from 'vue'; import { nextTick, watch, onMounted, reactive, toRefs, ref, Ref } from 'vue';
import { getToken, joinClientParams } from '@/common/utils/storage'; import { getToken } from '@/common/utils/storage';
import { isTrue, notBlank } from '@/common/assert'; import { isTrue, notBlank } from '@/common/assert';
import { format as sqlFormatter } from 'sql-formatter'; import { format as sqlFormatter } from 'sql-formatter';
import config from '@/common/config'; import config from '@/common/config';
@@ -104,6 +104,12 @@ import { dateStrFormat } from '@/common/utils/date';
import { dbApi } from '../../api'; import { dbApi } from '../../api';
import MonacoEditor from '@/components/monaco/MonacoEditor.vue'; import MonacoEditor from '@/components/monaco/MonacoEditor.vue';
import { joinClientParams } from '@/common/request';
import { createVNode } from 'vue';
import { buildProgressProps } from '@/components/progress-notify/progress-notify';
import ProgressNotify from '@/components/progress-notify/progress-notify.vue';
import { ElNotification } from 'element-plus';
import syssocket from '@/common/syssocket';
const emits = defineEmits(['saveSqlSuccess', 'deleteSqlSuccess']); const emits = defineEmits(['saveSqlSuccess', 'deleteSqlSuccess']);
@@ -384,7 +390,9 @@ const deleteSql = async () => {
await dbApi.deleteDbSql.request({ id: dbId, db: db, name: sqlName }); await dbApi.deleteDbSql.request({ id: dbId, db: db, name: sqlName });
ElMessage.success('删除成功'); ElMessage.success('删除成功');
emits('deleteSqlSuccess', dbId, db); emits('deleteSqlSuccess', dbId, db);
} catch (err) {} } catch (err) {
//
}
}; };
/** /**
@@ -472,8 +480,45 @@ const exportData = () => {
); );
}; };
/**
* sql文件执行进度通知缓存
*/
const sqlExecNotifyMap: Map<string, any> = new Map();
const beforeUpload = (file: File) => { const beforeUpload = (file: File) => {
ElMessage.success(`'${file.name}' 正在上传执行, 请关注结果通知`); ElMessage.success(`'${file.name}' 正在上传执行, 请关注结果通知`);
syssocket.registerMsgHandler('execSqlFileProgress', function (message: any) {
const content = JSON.parse(message.msg);
const id = content.id;
let progress = sqlExecNotifyMap.get(id);
if (content.terminated) {
if (progress != undefined) {
progress.notification?.close();
sqlExecNotifyMap.delete(id);
progress = undefined;
}
return;
}
if (progress == undefined) {
progress = {
props: reactive(buildProgressProps()),
notification: undefined,
};
}
progress.props.progress.title = content.title;
progress.props.progress.executedStatements = content.executedStatements;
if (!sqlExecNotifyMap.has(id)) {
const vNodeMessage = createVNode(ProgressNotify, progress.props, null);
progress.notification = ElNotification({
duration: 0,
title: message.title,
message: vNodeMessage,
type: syssocket.getMsgType(message.type),
showClose: false,
});
sqlExecNotifyMap.set(id, progress);
}
});
}; };
// 执行sql成功 // 执行sql成功

View File

@@ -345,7 +345,6 @@ export class DbInst {
// 获取该列中最长的数据(内容) // 获取该列中最长的数据(内容)
let maxWidthText = ''; let maxWidthText = '';
let maxWidthValue;
// 获取该列中最长的数据(内容) // 获取该列中最长的数据(内容)
for (let i = 0; i < tableData.length; i++) { for (let i = 0; i < tableData.length; i++) {
let nowValue = tableData[i][prop]; let nowValue = tableData[i][prop];
@@ -356,7 +355,6 @@ export class DbInst {
let nowText = nowValue + ''; let nowText = nowValue + '';
if (nowText.length > maxWidthText.length) { if (nowText.length > maxWidthText.length) {
maxWidthText = nowText; maxWidthText = nowText;
maxWidthValue = nowValue;
} }
} }
const contentWidth: number = getTextWidth(maxWidthText) + 15; const contentWidth: number = getTextWidth(maxWidthText) + 15;

View File

@@ -118,13 +118,13 @@
</template> </template>
<script lang="ts" setup> <script lang="ts" setup>
import { toRefs, reactive, watch, computed, onMounted, defineAsyncComponent, nextTick } from 'vue'; import { toRefs, reactive, watch, computed, onMounted, defineAsyncComponent } from 'vue';
import { ElMessageBox } from 'element-plus'; import { ElMessageBox } from 'element-plus';
import { formatByteSize } from '@/common/utils/format'; import { formatByteSize } from '@/common/utils/format';
import { dbApi } from '../api'; import { dbApi } from '../api';
import SqlExecBox from '../component/SqlExecBox'; import SqlExecBox from '../component/SqlExecBox';
import config from '@/common/config'; import config from '@/common/config';
import { joinClientParams } from '@/common/utils/storage'; import { joinClientParams } from '@/common/request';
import { isTrue } from '@/common/assert'; import { isTrue } from '@/common/assert';
const DbTableEdit = defineAsyncComponent(() => import('./DbTableEdit.vue')); const DbTableEdit = defineAsyncComponent(() => import('./DbTableEdit.vue'));
@@ -209,7 +209,7 @@ onMounted(async () => {
getTables(); getTables();
}); });
watch(props, async (newValue: any) => { watch(props, async () => {
await getTables(); await getTables();
}); });
@@ -239,6 +239,7 @@ const getTables = async () => {
state.tables = []; state.tables = [];
state.tables = await dbApi.tableInfos.request({ id: props.dbId, db: props.db }); state.tables = await dbApi.tableInfos.request({ id: props.dbId, db: props.db });
} catch (e) { } catch (e) {
//
} finally { } finally {
state.loading = false; state.loading = false;
} }
@@ -317,7 +318,9 @@ const dropTable = async (row: any) => {
state.tables = await dbApi.tableInfos.request({ id: props.dbId, db: props.db }); state.tables = await dbApi.tableInfos.request({ id: props.dbId, db: props.db });
}, },
}); });
} catch (err) {} } catch (err) {
//
}
}; };
// 打开编辑表 // 打开编辑表

View File

@@ -1,6 +1,6 @@
import Api from '@/common/Api'; import Api from '@/common/Api';
import config from '@/common/config'; import config from '@/common/config';
import { joinClientParams } from '@/common/utils/storage'; import { joinClientParams } from '@/common/request';
export const machineApi = { export const machineApi = {
// 获取权限列表 // 获取权限列表

View File

@@ -274,11 +274,12 @@ import { ref, toRefs, reactive, onMounted, computed } from 'vue';
import { ElMessage, ElMessageBox, ElInput } from 'element-plus'; import { ElMessage, ElMessageBox, ElInput } from 'element-plus';
import { machineApi } from '../api'; import { machineApi } from '../api';
import { joinClientParams } from '@/common/utils/storage'; import { joinClientParams } from '@/common/request';
import config from '@/common/config'; import config from '@/common/config';
import { isTrue } from '@/common/assert'; import { isTrue } from '@/common/assert';
import MachineFileContent from './MachineFileContent.vue'; import MachineFileContent from './MachineFileContent.vue';
import { notBlank } from '@/common/assert'; import { notBlank } from '@/common/assert';
import { getToken } from '@/common/utils/storage';
const props = defineProps({ const props = defineProps({
machineId: { type: Number }, machineId: { type: Number },

View File

@@ -179,7 +179,7 @@ import { dateFormat } from '@/common/utils/date';
import { storeToRefs } from 'pinia'; import { storeToRefs } from 'pinia';
import { useUserInfo } from '@/store/userInfo'; import { useUserInfo } from '@/store/userInfo';
import config from '@/common/config'; import config from '@/common/config';
import { joinClientParams } from '@/common/utils/storage'; import { joinClientParams } from '@/common/request';
const { userInfo } = storeToRefs(useUserInfo()); const { userInfo } = storeToRefs(useUserInfo());
const state = reactive({ const state = reactive({

View File

@@ -773,10 +773,10 @@ echarts@^5.4.0:
tslib "2.3.0" tslib "2.3.0"
zrender "5.4.0" zrender "5.4.0"
element-plus@^2.4.0: element-plus@^2.4.1:
version "2.4.0" version "2.4.1"
resolved "https://registry.npmmirror.com/element-plus/-/element-plus-2.4.0.tgz#e79249ac4c0a606d377c2f31ad553aa992286fe3" resolved "https://registry.npmmirror.com/element-plus/-/element-plus-2.4.1.tgz#8a5faa69e856d82494b94d77fb485d9e727c8bc1"
integrity sha512-yJEa8LXkGOOgkfkeqMMEdeX/Dc8EH9qPcRuX91dlhSXxgCKKbp9tH3QFTOG99ibZsrN/Em62nh7ddvbc7I1frw== integrity sha512-t7nl+vQlkBKVk1Ag6AufSDyFV8YIXxTFsaya4Nz/0tiRlcz65WPN4WMFeNURuFJleu1HLNtP4YyQKMuS7El8uA==
dependencies: dependencies:
"@ctrl/tinycolor" "^3.4.1" "@ctrl/tinycolor" "^3.4.1"
"@element-plus/icons-vue" "^2.0.6" "@element-plus/icons-vue" "^2.0.6"

View File

@@ -119,5 +119,5 @@ func (a *AccountLogin) OtpVerify(rc *req.Ctx) {
func (a *AccountLogin) Logout(rc *req.Ctx) { func (a *AccountLogin) Logout(rc *req.Ctx) {
req.GetPermissionCodeRegistery().Remove(rc.LoginAccount.Id) req.GetPermissionCodeRegistery().Remove(rc.LoginAccount.Id)
ws.CloseClient(rc.LoginAccount.ClientUuid) ws.CloseClient(ws.UserId(rc.LoginAccount.Id))
} }

View File

@@ -2,8 +2,6 @@ package api
import ( import (
"fmt" "fmt"
"github.com/gin-gonic/gin"
"github.com/kanzihuang/vitess/go/vt/sqlparser"
"io" "io"
"mayfly-go/internal/db/api/form" "mayfly-go/internal/db/api/form"
"mayfly-go/internal/db/api/vo" "mayfly-go/internal/db/api/vo"
@@ -15,15 +13,18 @@ import (
"mayfly-go/pkg/biz" "mayfly-go/pkg/biz"
"mayfly-go/pkg/ginx" "mayfly-go/pkg/ginx"
"mayfly-go/pkg/gormx" "mayfly-go/pkg/gormx"
"mayfly-go/pkg/logx"
"mayfly-go/pkg/model" "mayfly-go/pkg/model"
"mayfly-go/pkg/req" "mayfly-go/pkg/req"
"mayfly-go/pkg/utils/collx" "mayfly-go/pkg/utils/collx"
"mayfly-go/pkg/utils/stringx" "mayfly-go/pkg/utils/stringx"
"mayfly-go/pkg/utils/uniqueid"
"mayfly-go/pkg/ws" "mayfly-go/pkg/ws"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/gin-gonic/gin"
"github.com/kanzihuang/vitess/go/vt/sqlparser"
) )
type Db struct { type Db struct {
@@ -126,29 +127,7 @@ func (d *Db) ExecSql(rc *req.Ctx) {
isMulti := len(sqls) > 1 isMulti := len(sqls) > 1
var execResAll *application.DbSqlExecRes var execResAll *application.DbSqlExecRes
progressId := uniqueid.IncrementID()
executedStatements := 0
progressTitle := fmt.Sprintf("%s/%s", dbConn.Info.Name, dbConn.Info.Database)
defer ws.SendJsonMsg(rc.LoginAccount.ClientUuid, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
Id: progressId,
Title: progressTitle,
ExecutedStatements: executedStatements,
Terminated: true,
}).WithCategory(progressCategory))
ticker := time.NewTicker(time.Second * 1)
defer ticker.Stop()
for _, s := range sqls { for _, s := range sqls {
select {
case <-ticker.C:
ws.SendJsonMsg(rc.LoginAccount.ClientUuid, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
Id: progressId,
Title: progressTitle,
ExecutedStatements: executedStatements,
Terminated: false,
}).WithCategory(progressCategory))
default:
}
executedStatements++
s = stringx.TrimSpaceAndBr(s) s = stringx.TrimSpaceAndBr(s)
// 多条执行,如果有查询语句,则跳过 // 多条执行,如果有查询语句,则跳过
if isMulti && strings.HasPrefix(strings.ToLower(s), "select") { if isMulti && strings.HasPrefix(strings.ToLower(s), "select") {
@@ -178,7 +157,7 @@ const progressCategory = "execSqlFileProgress"
// progressMsg sql文件执行进度消息 // progressMsg sql文件执行进度消息
type progressMsg struct { type progressMsg struct {
Id uint64 `json:"id"` Id string `json:"id"`
Title string `json:"title"` Title string `json:"title"`
ExecutedStatements int `json:"executedStatements"` ExecutedStatements int `json:"executedStatements"`
Terminated bool `json:"terminated"` Terminated bool `json:"terminated"`
@@ -195,6 +174,7 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
filename := file.FileName() filename := file.FileName()
dbId := getDbId(g) dbId := getDbId(g)
dbName := getDbName(g) dbName := getDbName(g)
clientId := g.Query("clientId")
dbConn := d.DbApp.GetDbConnection(dbId, dbName) dbConn := d.DbApp.GetDbConnection(dbId, dbName)
biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s") biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s")
@@ -209,7 +189,7 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
errInfo = t errInfo = t
} }
if len(errInfo) > 0 { if len(errInfo) > 0 {
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s]执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo))) d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s]执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo)).WithClientId(clientId))
} }
}() }()
@@ -226,9 +206,9 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
tokenizer := sqlparser.NewReaderTokenizer(file, tokenizer := sqlparser.NewReaderTokenizer(file,
sqlparser.WithCacheInBuffer(), sqlparser.WithDialect(dbConn.Info.Type.Dialect())) sqlparser.WithCacheInBuffer(), sqlparser.WithDialect(dbConn.Info.Type.Dialect()))
progressId := uniqueid.IncrementID()
executedStatements := 0 executedStatements := 0
defer ws.SendJsonMsg(rc.LoginAccount.ClientUuid, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{ progressId := stringx.Rand(32)
defer ws.SendJsonMsg(ws.UserId(rc.LoginAccount.Id), clientId, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
Id: progressId, Id: progressId,
Title: filename, Title: filename,
ExecutedStatements: executedStatements, ExecutedStatements: executedStatements,
@@ -239,7 +219,7 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
ws.SendJsonMsg(rc.LoginAccount.ClientUuid, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{ ws.SendJsonMsg(ws.UserId(rc.LoginAccount.Id), clientId, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
Id: progressId, Id: progressId,
Title: filename, Title: filename,
ExecutedStatements: executedStatements, ExecutedStatements: executedStatements,
@@ -252,21 +232,19 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
if err == io.EOF { if err == io.EOF {
break break
} }
if err != nil { biz.ErrIsNilAppendErr(err, "%s")
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本解析失败", fmt.Sprintf("[%s][%s] 解析SQL失败: [%s]", filename, dbConn.Info.GetLogDesc(), err.Error())))
return
}
const prefixUse = "use " const prefixUse = "use "
const prefixUSE = "USE " const prefixUSE = "USE "
if strings.HasPrefix(sql, prefixUSE) || strings.HasPrefix(sql, prefixUse) { if strings.HasPrefix(sql, prefixUSE) || strings.HasPrefix(sql, prefixUse) {
var stmt sqlparser.Statement var stmt sqlparser.Statement
stmt, err = sqlparser.Parse(sql) stmt, err = sqlparser.Parse(sql)
if err != nil { biz.ErrIsNilAppendErr(err, "%s")
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本解析失败", fmt.Sprintf("[%s][%s] 解析SQL失败: [%s]", filename, dbConn.Info.GetLogDesc(), err.Error())))
}
stmtUse, ok := stmt.(*sqlparser.Use) stmtUse, ok := stmt.(*sqlparser.Use)
// 最终执行结果以数据库直接结果为准
if !ok { if !ok {
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本解析失败", fmt.Sprintf("[%s][%s] 解析SQL失败: [%s]", filename, dbConn.Info.GetLogDesc(), sql))) logx.Warnf("sql解析失败: %s", sql)
} }
dbConn = d.DbApp.GetDbConnection(dbId, stmtUse.DBName.String()) dbConn = d.DbApp.GetDbConnection(dbId, stmtUse.DBName.String())
biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s") biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s")
@@ -281,12 +259,9 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
_, err = dbConn.Exec(sql) _, err = dbConn.Exec(sql)
} }
if err != nil { biz.ErrIsNilAppendErr(err, "%s")
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s] -> sql=[%s] 执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), sql, err.Error())))
return
}
} }
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.SuccessSysMsg("sql脚本执行成功", fmt.Sprintf("sql脚本执行完成%s", rc.ReqParam))) d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.SuccessSysMsg("sql脚本执行成功", fmt.Sprintf("sql脚本执行完成%s", rc.ReqParam)).WithClientId(clientId))
} }
// 数据库dump // 数据库dump

View File

@@ -14,6 +14,8 @@ type SysMsg struct {
Category string `json:"category"` // 消息类别 Category string `json:"category"` // 消息类别
Title string `json:"title"` // 消息标题 Title string `json:"title"` // 消息标题
Msg string `json:"msg"` // 消息内容 Msg string `json:"msg"` // 消息内容
ClientId string
} }
func (sm *SysMsg) WithTitle(title string) *SysMsg { func (sm *SysMsg) WithTitle(title string) *SysMsg {
@@ -31,6 +33,11 @@ func (sm *SysMsg) WithMsg(msg any) *SysMsg {
return sm return sm
} }
func (sm *SysMsg) WithClientId(clientId string) *SysMsg {
sm.ClientId = clientId
return sm
}
// 普通消息 // 普通消息
func InfoSysMsg(title string, msg any) *SysMsg { func InfoSysMsg(title string, msg any) *SysMsg {
return &SysMsg{Type: InfoSysMsgType, Title: title, Msg: stringx.AnyToStr(msg)} return &SysMsg{Type: InfoSysMsgType, Title: title, Msg: stringx.AnyToStr(msg)}

View File

@@ -40,5 +40,5 @@ func (a *msgAppImpl) CreateAndSend(la *model.LoginAccount, wmsg *dto.SysMsg) {
now := time.Now() now := time.Now()
msg := &entity.Msg{Type: 2, Msg: wmsg.Msg, RecipientId: int64(la.Id), CreateTime: &now, CreatorId: la.Id, Creator: la.Username} msg := &entity.Msg{Type: 2, Msg: wmsg.Msg, RecipientId: int64(la.Id), CreateTime: &now, CreatorId: la.Id, Creator: la.Username}
a.msgRepo.Insert(msg) a.msgRepo.Insert(msg)
ws.SendJsonMsg(la.ClientUuid, wmsg) ws.SendJsonMsg(ws.UserId(la.Id), wmsg.ClientId, wmsg)
} }

View File

@@ -26,18 +26,19 @@ func (s *System) ConnectWs(g *gin.Context) {
} }
}() }()
if err != nil { biz.ErrIsNilAppendErr(err, "%s")
panic(biz.NewBizErr("升级websocket失败")) clientId := g.Query("clientId")
} biz.NotEmpty(clientId, "clientId不能为空")
// 权限校验 // 权限校验
rc := req.NewCtxWithGin(g) rc := req.NewCtxWithGin(g)
if err = req.PermissionHandler(rc); err != nil { if err = req.PermissionHandler(rc); err != nil {
panic(biz.NewBizErr("没有权限")) panic("sys ws连接没有权限")
} }
// 登录账号信息 // 登录账号信息
la := rc.LoginAccount la := rc.LoginAccount
if la != nil { if la != nil {
ws.AddClient(la.Id, la.ClientUuid, wsConn) ws.AddClient(ws.UserId(la.Id), clientId, wsConn)
} }
} }

View File

@@ -3,7 +3,4 @@ package model
type LoginAccount struct { type LoginAccount struct {
Id uint64 Id uint64
Username string Username string
// ClientUuid 客户端UUID
ClientUuid string
} }

View File

@@ -60,16 +60,10 @@ func PermissionHandler(rc *Ctx) error {
return biz.PermissionErr return biz.PermissionErr
} }
} }
clientUuid := rc.GinCtx.Request.Header.Get("Client-Uuid")
// header不存在则从查询参数token中获取
if clientUuid == "" {
clientUuid = rc.GinCtx.Query("clientUuid")
}
if rc.LoginAccount == nil { if rc.LoginAccount == nil {
rc.LoginAccount = &model.LoginAccount{ rc.LoginAccount = &model.LoginAccount{
Id: userId, Id: userId,
Username: userName, Username: userName,
ClientUuid: clientUuid,
} }
} }
return nil return nil

View File

@@ -114,3 +114,14 @@ func ArrayReduce[T any, V any](arr []T, initialValue V, reducer func(V, T) V) V
} }
return value return value
} }
// 数组元素移除操作
func ArrayRemoveFunc[T any](arr []T, isDeleteFunc func(T) bool) []T {
var newArr []T
for _, a := range arr {
if !isDeleteFunc(a) {
newArr = append(newArr, a)
}
}
return newArr
}

View File

@@ -1,9 +0,0 @@
package uniqueid
import "sync/atomic"
var id uint64 = 0
func IncrementID() uint64 {
return atomic.AddUint64(&id, 1)
}

View File

@@ -4,7 +4,6 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"mayfly-go/pkg/logx" "mayfly-go/pkg/logx"
"mayfly-go/pkg/utils/stringx"
"time" "time"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
@@ -17,27 +16,25 @@ type UserId uint64
type ReadMsgHandlerFunc func([]byte) type ReadMsgHandlerFunc func([]byte)
type Client struct { type Client struct {
ClientId string // 标识ID ClientId string // 标识ID
UserId UserId // 用户ID UserId UserId // 用户ID
ClientUuid string // 客户端UUID WsConn *websocket.Conn // 用户连接
WsConn *websocket.Conn // 用户连接
ReadMsgHander ReadMsgHandlerFunc // 读取消息处理函数 ReadMsgHandler ReadMsgHandlerFunc // 读取消息处理函数
} }
func NewClient(userId UserId, clientUuid string, socket *websocket.Conn) *Client { func NewClient(userId UserId, clientId string, socket *websocket.Conn) *Client {
cli := &Client{ cli := &Client{
ClientId: stringx.Rand(16), ClientId: clientId,
UserId: userId, UserId: userId,
ClientUuid: clientUuid, WsConn: socket,
WsConn: socket,
} }
return cli return cli
} }
func (c *Client) WithReadHandlerFunc(readMsgHandlerFunc ReadMsgHandlerFunc) *Client { func (c *Client) WithReadHandlerFunc(readMsgHandlerFunc ReadMsgHandlerFunc) *Client {
c.ReadMsgHander = readMsgHandlerFunc c.ReadMsgHandler = readMsgHandlerFunc
return c return c
} }
@@ -58,8 +55,8 @@ func (c *Client) Read() {
return return
} }
} }
if c.ReadMsgHander != nil { if c.ReadMsgHandler != nil {
c.ReadMsgHander(data) c.ReadMsgHandler(data)
} }
} }
}() }()
@@ -67,7 +64,7 @@ func (c *Client) Read() {
// 向客户端写入消息 // 向客户端写入消息
func (c *Client) WriteMsg(msg *Msg) error { func (c *Client) WriteMsg(msg *Msg) error {
logx.Debugf("发送消息: toUid=%v, data=%v", c.UserId, msg.Data) logx.Debugf("发送消息: toUserId=%v, toClientId=%s, data=%v", c.UserId, c.ClientId, msg.Data)
if msg.Type == JsonMsg { if msg.Type == JsonMsg {
bytes, _ := json.Marshal(msg.Data) bytes, _ := json.Marshal(msg.Data)

View File

@@ -9,10 +9,29 @@ import (
// 心跳间隔 // 心跳间隔
const heartbeatInterval = 25 * time.Second const heartbeatInterval = 25 * time.Second
// 单个用户的全部的连接, key->clientId, value->Client
type UserClients map[string]*Client
func (ucs UserClients) GetByCid(clientId string) *Client {
return ucs[clientId]
}
func (ucs UserClients) AddClient(client *Client) {
ucs[client.ClientId] = client
}
func (ucs UserClients) DeleteByCid(clientId string) {
delete(ucs, clientId)
}
func (ucs UserClients) Count() int {
return len(ucs)
}
// 连接管理 // 连接管理
type ClientManager struct { type ClientManager struct {
ClientMap map[string]*Client // 全部的连接, key->token, value->&client UserClientsMap map[UserId]UserClients // 全部的用户连接, key->userid, value->UserClients
RwLock sync.RWMutex // 读写锁 RwLock sync.RWMutex // 读写锁
ConnectChan chan *Client // 连接处理 ConnectChan chan *Client // 连接处理
DisConnectChan chan *Client // 断开连接处理 DisConnectChan chan *Client // 断开连接处理
@@ -21,7 +40,7 @@ type ClientManager struct {
func NewClientManager() (clientManager *ClientManager) { func NewClientManager() (clientManager *ClientManager) {
return &ClientManager{ return &ClientManager{
ClientMap: make(map[string]*Client), UserClientsMap: make(map[UserId]UserClients),
ConnectChan: make(chan *Client, 10), ConnectChan: make(chan *Client, 10),
DisConnectChan: make(chan *Client, 10), DisConnectChan: make(chan *Client, 10),
MsgChan: make(chan *Msg, 100), MsgChan: make(chan *Msg, 100),
@@ -58,47 +77,45 @@ func (manager *ClientManager) CloseClient(client *Client) {
} }
// 根据用户id关闭客户端连接 // 根据用户id关闭客户端连接
func (manager *ClientManager) CloseByClientUuid(clientUuid string) { func (manager *ClientManager) CloseByUid(userId UserId) {
manager.CloseClient(manager.GetByClientUuid(clientUuid)) for _, client := range manager.GetByUid(userId) {
manager.CloseClient(client)
}
} }
// 获取所有的客户端 // 获取所有的客户端
func (manager *ClientManager) AllClient() map[string]*Client { func (manager *ClientManager) AllUserClient() map[UserId]UserClients {
manager.RwLock.RLock() manager.RwLock.RLock()
defer manager.RwLock.RUnlock() defer manager.RwLock.RUnlock()
return manager.ClientMap return manager.UserClientsMap
} }
// 通过userId获取 // 通过userId获取用户所有客户端信息
func (manager *ClientManager) GetByUid(userId UserId) *Client { func (manager *ClientManager) GetByUid(userId UserId) UserClients {
manager.RwLock.RLock() manager.RwLock.RLock()
defer manager.RwLock.RUnlock() defer manager.RwLock.RUnlock()
for _, client := range manager.ClientMap { return manager.UserClientsMap[userId]
if userId == client.UserId { }
return client
} // 通过userId和clientId获取客户端信息
func (manager *ClientManager) GetByUidAndCid(uid UserId, clientId string) *Client {
if clients := manager.GetByUid(uid); clients != nil {
return clients.GetByCid(clientId)
} }
return nil return nil
} }
// 通过userId获取
func (manager *ClientManager) GetByClientUuid(uuid string) *Client {
manager.RwLock.RLock()
defer manager.RwLock.RUnlock()
return manager.ClientMap[uuid]
}
// 客户端数量 // 客户端数量
func (manager *ClientManager) Count() int { func (manager *ClientManager) Count() int {
manager.RwLock.RLock() manager.RwLock.RLock()
defer manager.RwLock.RUnlock() defer manager.RwLock.RUnlock()
return len(manager.ClientMap) return len(manager.UserClientsMap)
} }
// 发送json数据给指定用户 // 发送json数据给指定用户
func (manager *ClientManager) SendJsonMsg(clientUuid string, data any) { func (manager *ClientManager) SendJsonMsg(userId UserId, clientId string, data any) {
manager.MsgChan <- &Msg{ToClientUuid: clientUuid, Data: data, Type: JsonMsg} manager.MsgChan <- &Msg{ToUserId: userId, ToClientId: clientId, Data: data, Type: JsonMsg}
} }
// 监听并发送给客户端信息 // 监听并发送给客户端信息
@@ -106,10 +123,22 @@ func (manager *ClientManager) WriteMessage() {
go func() { go func() {
for { for {
msg := <-manager.MsgChan msg := <-manager.MsgChan
if cli := manager.GetByClientUuid(msg.ToClientUuid); cli != nil { uid := msg.ToUserId
if err := cli.WriteMsg(msg); err != nil { cid := msg.ToClientId
manager.CloseClient(cli) // 客户端id不为空则向指定客户端发送消息即可
if cid != "" {
cli := manager.GetByUidAndCid(uid, cid)
if cli != nil {
cli.WriteMsg(msg)
} else {
logx.Warnf("[uid=%v, cid=%s]的ws连接不存在", uid, cid)
} }
continue
}
// cid为空则向该用户所有客户端发送该消息
for _, cli := range manager.GetByUid(uid) {
cli.WriteMsg(msg)
} }
} }
}() }()
@@ -123,15 +152,17 @@ func (manager *ClientManager) HeartbeatTimer() {
for { for {
<-ticker.C <-ticker.C
//发送心跳 //发送心跳
for userId, cli := range manager.AllClient() { for userId, clis := range manager.AllUserClient() {
if cli == nil || cli.WsConn == nil { for _, cli := range clis {
continue if cli == nil || cli.WsConn == nil {
} continue
if err := cli.Ping(); err != nil { }
manager.CloseClient(cli) if err := cli.Ping(); err != nil {
logx.Debugf("WS发送心跳失败: %v 总连接数:%d", userId, Manager.Count()) manager.CloseClient(cli)
} else { logx.Debugf("WS发送心跳失败: uid=%v, cid=%s, usercount=%d", userId, cli.ClientId, Manager.Count())
logx.Debugf("WS发送心跳成功: uid=%v", userId) } else {
logx.Debugf("WS发送心跳成功: uid=%v, cid=%s", userId, cli.ClientId)
}
} }
} }
} }
@@ -141,12 +172,12 @@ func (manager *ClientManager) HeartbeatTimer() {
// 处理建立连接 // 处理建立连接
func (manager *ClientManager) doConnect(client *Client) { func (manager *ClientManager) doConnect(client *Client) {
cli := manager.GetByClientUuid(client.ClientUuid) cli := manager.GetByUidAndCid(client.UserId, client.ClientId)
if cli != nil { if cli != nil {
manager.doDisconnect(cli) manager.doDisconnect(cli)
} }
manager.addClient2Map(client) manager.addUserClient2Map(client)
logx.Debugf("WS客户端已连接: uid=%d, count=%d", client.UserId, manager.Count()) logx.Debugf("WS客户端已连接: uid=%d, cid=%s, usercount=%d", client.UserId, client.ClientId, manager.Count())
} }
// 处理断开连接 // 处理断开连接
@@ -156,18 +187,32 @@ func (manager *ClientManager) doDisconnect(client *Client) {
_ = client.WsConn.Close() _ = client.WsConn.Close()
client.WsConn = nil client.WsConn = nil
} }
manager.delClient4Map(client) manager.delUserClient4Map(client)
logx.Debugf("WS客户端已断开: uid=%d, count=%d", client.UserId, Manager.Count()) logx.Debugf("WS客户端已断开: uid=%d, cid=%s, usercount=%d", client.UserId, client.ClientId, Manager.Count())
} }
func (manager *ClientManager) addClient2Map(client *Client) { func (manager *ClientManager) addUserClient2Map(client *Client) {
manager.RwLock.Lock() manager.RwLock.Lock()
defer manager.RwLock.Unlock() defer manager.RwLock.Unlock()
manager.ClientMap[client.ClientUuid] = client
userClients := manager.UserClientsMap[client.UserId]
if userClients == nil {
userClients = make(UserClients)
manager.UserClientsMap[client.UserId] = userClients
}
userClients.AddClient(client)
} }
func (manager *ClientManager) delClient4Map(client *Client) { func (manager *ClientManager) delUserClient4Map(client *Client) {
manager.RwLock.Lock() manager.RwLock.Lock()
defer manager.RwLock.Unlock() defer manager.RwLock.Unlock()
delete(manager.ClientMap, client.ClientUuid)
userClients := manager.UserClientsMap[client.UserId]
if userClients != nil {
userClients.DeleteByCid(client.ClientId)
// 如果用户所有客户端都关闭则移除manager中的UserClientsMap值
if userClients.Count() == 0 {
delete(manager.UserClientsMap, client.UserId)
}
}
} }

View File

@@ -11,9 +11,9 @@ const (
// 消息信息 // 消息信息
type Msg struct { type Msg struct {
ToUserId UserId ToUserId UserId // 用户id
Data any ToClientId string // 客户端id
Type MsgType // 消息类型 Type MsgType // 消息类型
ToClientUuid string Data any
} }

View File

@@ -21,21 +21,21 @@ func init() {
} }
// 添加ws客户端 // 添加ws客户端
func AddClient(userId uint64, clientUuid string, conn *websocket.Conn) *Client { func AddClient(userId UserId, clientId string, conn *websocket.Conn) *Client {
if len(clientUuid) == 0 { if len(clientId) == 0 {
return nil return nil
} }
cli := NewClient(UserId(userId), clientUuid, conn) cli := NewClient(UserId(userId), clientId, conn)
cli.Read() cli.Read()
Manager.AddClient(cli) Manager.AddClient(cli)
return cli return cli
} }
func CloseClient(clientUuid string) { func CloseClient(uid UserId) {
Manager.CloseByClientUuid(clientUuid) Manager.CloseByUid(uid)
} }
// 对指定用户发送json类型消息 // 对指定用户发送json类型消息
func SendJsonMsg(clientUuid string, msg any) { func SendJsonMsg(userId UserId, clientId string, msg any) {
Manager.SendJsonMsg(clientUuid, msg) Manager.SendJsonMsg(userId, clientId, msg)
} }