feature: 每个客户端独立处理后端发送的系统消息

This commit is contained in:
wanli
2023-10-18 15:24:29 +08:00
committed by kanzihuang
parent 361eafedae
commit ccfc6bd1df
26 changed files with 171 additions and 95 deletions

View File

@@ -28,6 +28,7 @@
"screenfull": "^6.0.2", "screenfull": "^6.0.2",
"sortablejs": "^1.15.0", "sortablejs": "^1.15.0",
"sql-formatter": "^12.1.2", "sql-formatter": "^12.1.2",
"uuid": "^9.0.1",
"vue": "^3.3.4", "vue": "^3.3.4",
"vue-clipboard3": "^1.0.1", "vue-clipboard3": "^1.0.1",
"vue-router": "^4.2.5", "vue-router": "^4.2.5",

View File

@@ -1,7 +1,7 @@
import router from '../router'; import router from '../router';
import Axios from 'axios'; import Axios from 'axios';
import config from './config'; import config from './config';
import { getToken } from './utils/storage'; import { getClientUuid, getToken, joinClientParams } from './utils/storage';
import { templateResolve } from './utils/string'; import { templateResolve } from './utils/string';
import { ElMessage } from 'element-plus'; import { ElMessage } from 'element-plus';
@@ -54,6 +54,7 @@ service.interceptors.request.use(
if (token) { if (token) {
// 设置token // 设置token
config.headers['Authorization'] = token; config.headers['Authorization'] = token;
config.headers['Client-Uuid'] = getClientUuid();
} }
return config; return config;
}, },
@@ -176,7 +177,7 @@ function del(url: string, params: any = null, headers: any = null, options: any
function getApiUrl(url: string) { function getApiUrl(url: string) {
// 只是返回api地址而不做请求用在上传组件之类的 // 只是返回api地址而不做请求用在上传组件之类的
return baseUrl + url + '?token=' + getToken(); return baseUrl + url + '?' + joinClientParams();
} }
export default { export default {

View File

@@ -1,9 +1,9 @@
import Config from './config'; import Config from './config';
import { ElNotification, NotificationHandle } from 'element-plus'; import { ElNotification, NotificationHandle } from 'element-plus';
import SocketBuilder from './SocketBuilder'; import SocketBuilder from './SocketBuilder';
import { getToken } from '@/common/utils/storage'; import { getToken, joinClientParams } from '@/common/utils/storage';
import { createVNode, reactive } from "vue"; import { createVNode, reactive } from 'vue';
import { buildProgressProps } from "@/components/progress-notify/progress-notify"; import { buildProgressProps } from '@/components/progress-notify/progress-notify';
import ProgressNotify from '/src/components/progress-notify/progress-notify.vue'; import ProgressNotify from '/src/components/progress-notify/progress-notify.vue';
export default { export default {
@@ -16,43 +16,40 @@ export default {
return null; return null;
} }
const messageTypes = { const messageTypes = {
0: "error", 0: 'error',
1: "success", 1: 'success',
2: "info", 2: 'info',
} };
const notifyMap: Map<Number, any> = new Map() const notifyMap: Map<Number, any> = new Map();
const sysMsgUrl = `${Config.baseWsUrl}/sysmsg?${joinClientParams()}`;
return SocketBuilder.builder(`${Config.baseWsUrl}/sysmsg?token=${token}`) return SocketBuilder.builder(sysMsgUrl)
.message((event: { data: string }) => { .message((event: { data: string }) => {
const message = JSON.parse(event.data); const message = JSON.parse(event.data);
const type = messageTypes[message.type] const type = messageTypes[message.type];
switch (message.category) { switch (message.category) {
case "execSqlFileProgress": case 'execSqlFileProgress':
const content = JSON.parse(message.msg) const content = JSON.parse(message.msg);
const id = content.id const id = content.id;
let progress = notifyMap.get(id) let progress = notifyMap.get(id);
if (content.terminated) { if (content.terminated) {
if (progress != undefined) { if (progress != undefined) {
progress.notification?.close() progress.notification?.close();
notifyMap.delete(id) notifyMap.delete(id);
progress = undefined progress = undefined;
} }
return return;
} }
if (progress == undefined) { if (progress == undefined) {
progress = { progress = {
props: reactive(buildProgressProps()), props: reactive(buildProgressProps()),
notification: undefined, notification: undefined,
} };
} }
progress.props.progress.title = content.title progress.props.progress.title = content.title;
progress.props.progress.executedStatements = content.executedStatements progress.props.progress.executedStatements = content.executedStatements;
if (!notifyMap.has(id)) { if (!notifyMap.has(id)) {
const vNodeMessage = createVNode( const vNodeMessage = createVNode(ProgressNotify, progress.props, null);
ProgressNotify,
progress.props,
null,
)
progress.notification = ElNotification({ progress.notification = ElNotification({
duration: 0, duration: 0,
title: message.title, title: message.title,
@@ -60,7 +57,7 @@ export default {
type: type, type: type,
showClose: false, showClose: false,
}); });
notifyMap.set(id, progress) notifyMap.set(id, progress);
} }
break; break;
default: default:

View File

@@ -1,6 +1,9 @@
import { v1 as uuidv1 } from 'uuid';
const TokenKey = 'token'; const TokenKey = 'token';
const UserKey = 'user'; const UserKey = 'user';
const TagViewsKey = 'tagViews'; const TagViewsKey = 'tagViews';
const ClientUuid = 'clientUuid'
// 获取请求token // 获取请求token
export function getToken(): string { export function getToken(): string {
@@ -48,6 +51,21 @@ export function removeTagViews() {
removeSession(TagViewsKey); removeSession(TagViewsKey);
} }
// 获取客户端UUID
export function getClientUuid(): string {
let uuid = getSession(ClientUuid)
if (uuid == null) {
uuid = uuidv1()
setSession(ClientUuid, uuid)
}
return uuid
}
// 组装客户端参数,包括 token 和 clientUuid
export function joinClientParams(): string {
return `token=${getToken()}&clientUuid=${getClientUuid()}`
}
// 1. localStorage // 1. localStorage
// 设置永久缓存 // 设置永久缓存
export function setLocal(key: string, val: any) { export function setLocal(key: string, val: any) {

View File

@@ -132,7 +132,7 @@ import { nextTick, onMounted, ref, toRefs, reactive, computed } from 'vue';
import { useRoute, useRouter } from 'vue-router'; import { useRoute, useRouter } from 'vue-router';
import { ElMessage } from 'element-plus'; import { ElMessage } from 'element-plus';
import { initRouter } from '@/router/index'; import { initRouter } from '@/router/index';
import { saveToken, saveUser } from '@/common/utils/storage'; import { saveToken, saveUser, setSession } from '@/common/utils/storage';
import { formatAxis } from '@/common/utils/format'; import { formatAxis } from '@/common/utils/format';
import openApi from '@/common/openApi'; import openApi from '@/common/openApi';
import { RsaEncrypt } from '@/common/rsa'; import { RsaEncrypt } from '@/common/rsa';
@@ -364,7 +364,7 @@ const loginResDeal = (loginRes: any) => {
useUserInfo().setUserInfo(userInfos); useUserInfo().setUserInfo(userInfos);
const token = loginRes.token; const token = loginRes.token;
// 如果不需要otp校验则该token即为accessToken否则为otp校验token // 如果不需要 otp校验则该token即为accessToken否则为otp校验token
if (loginRes.otp == -1) { if (loginRes.otp == -1) {
signInSuccess(token); signInSuccess(token);
return; return;
@@ -385,6 +385,7 @@ const signInSuccess = async (accessToken: string = '') => {
} }
// 存储 token 到浏览器缓存 // 存储 token 到浏览器缓存
saveToken(accessToken); saveToken(accessToken);
// 初始化路由 // 初始化路由
await initRouter(); await initRouter();

View File

@@ -172,7 +172,7 @@ import { ref, toRefs, reactive, onMounted, defineAsyncComponent } from 'vue';
import { ElMessage, ElMessageBox } from 'element-plus'; import { ElMessage, ElMessageBox } from 'element-plus';
import { dbApi } from './api'; import { dbApi } from './api';
import config from '@/common/config'; import config from '@/common/config';
import { getToken } from '@/common/utils/storage'; import { joinClientParams } from '@/common/utils/storage';
import { isTrue } from '@/common/assert'; import { isTrue } from '@/common/assert';
import { Search as SearchIcon } from '@element-plus/icons-vue'; import { Search as SearchIcon } from '@element-plus/icons-vue';
import { dateFormat } from '@/common/utils/date'; import { dateFormat } from '@/common/utils/date';
@@ -406,7 +406,7 @@ const dumpDbs = () => {
'href', 'href',
`${config.baseApiUrl}/dbs/${state.exportDialog.dbId}/dump?db=${state.exportDialog.value.join(',')}&type=${type}&extName=${ `${config.baseApiUrl}/dbs/${state.exportDialog.dbId}/dump?db=${state.exportDialog.value.join(',')}&type=${type}&extName=${
state.exportDialog.extName state.exportDialog.extName
}&token=${getToken()}` }&${joinClientParams()}`
); );
a.click(); a.click();
state.exportDialog.visible = false; state.exportDialog.visible = false;

View File

@@ -88,7 +88,7 @@
<script lang="ts" setup> <script lang="ts" setup>
import { nextTick, watch, onMounted, reactive, toRefs, ref, Ref } from 'vue'; import { nextTick, watch, onMounted, reactive, toRefs, ref, Ref } from 'vue';
import { getToken } from '@/common/utils/storage'; import { getToken, joinClientParams } from '@/common/utils/storage';
import { isTrue, notBlank } from '@/common/assert'; import { isTrue, notBlank } from '@/common/assert';
import { format as sqlFormatter } from 'sql-formatter'; import { format as sqlFormatter } from 'sql-formatter';
import config from '@/common/config'; import config from '@/common/config';
@@ -485,7 +485,7 @@ const execSqlFileSuccess = (res: any) => {
// 获取sql文件上传执行url // 获取sql文件上传执行url
const getUploadSqlFileUrl = () => { const getUploadSqlFileUrl = () => {
return `${config.baseApiUrl}/dbs/${state.ti.dbId}/exec-sql-file?db=${state.ti.db}`; return `${config.baseApiUrl}/dbs/${state.ti.dbId}/exec-sql-file?db=${state.ti.db}&${joinClientParams()}`;
}; };
const onDataSelectionChange = (datas: []) => { const onDataSelectionChange = (datas: []) => {

View File

@@ -124,7 +124,7 @@ import { formatByteSize } from '@/common/utils/format';
import { dbApi } from '../api'; import { dbApi } from '../api';
import SqlExecBox from '../component/SqlExecBox'; import SqlExecBox from '../component/SqlExecBox';
import config from '@/common/config'; import config from '@/common/config';
import { getToken } from '@/common/utils/storage'; import { joinClientParams } from '@/common/utils/storage';
import { isTrue } from '@/common/assert'; import { isTrue } from '@/common/assert';
const DbTableEdit = defineAsyncComponent(() => import('./DbTableEdit.vue')); const DbTableEdit = defineAsyncComponent(() => import('./DbTableEdit.vue'));
@@ -259,7 +259,7 @@ const dump = (db: string) => {
const a = document.createElement('a'); const a = document.createElement('a');
a.setAttribute( a.setAttribute(
'href', 'href',
`${config.baseApiUrl}/dbs/${props.dbId}/dump?db=${db}&type=${state.dumpInfo.type}&tables=${state.dumpInfo.tables.join(',')}&token=${getToken()}` `${config.baseApiUrl}/dbs/${props.dbId}/dump?db=${db}&type=${state.dumpInfo.type}&tables=${state.dumpInfo.tables.join(',')}&${joinClientParams()}`
); );
a.click(); a.click();
state.showDumpInfo = false; state.showDumpInfo = false;

View File

@@ -1,6 +1,6 @@
import Api from '@/common/Api'; import Api from '@/common/Api';
import config from '@/common/config'; import config from '@/common/config';
import { getToken } from '@/common/utils/storage'; import { joinClientParams } from '@/common/utils/storage';
export const machineApi = { export const machineApi = {
// 获取权限列表 // 获取权限列表
@@ -33,7 +33,7 @@ export const machineApi = {
cpFile: Api.newPost('/machines/{machineId}/files/{fileId}/cp'), cpFile: Api.newPost('/machines/{machineId}/files/{fileId}/cp'),
renameFile: Api.newPost('/machines/{machineId}/files/{fileId}/rename'), renameFile: Api.newPost('/machines/{machineId}/files/{fileId}/rename'),
mvFile: Api.newPost('/machines/{machineId}/files/{fileId}/mv'), mvFile: Api.newPost('/machines/{machineId}/files/{fileId}/mv'),
uploadFile: Api.newPost('/machines/{machineId}/files/{fileId}/upload?token={token}'), uploadFile: Api.newPost('/machines/{machineId}/files/{fileId}/upload?' + joinClientParams()),
fileContent: Api.newGet('/machines/{machineId}/files/{fileId}/read'), fileContent: Api.newGet('/machines/{machineId}/files/{fileId}/read'),
createFile: Api.newPost('/machines/{machineId}/files/{id}/create-file'), createFile: Api.newPost('/machines/{machineId}/files/{id}/create-file'),
// 修改文件内容 // 修改文件内容
@@ -63,5 +63,5 @@ export const cronJobApi = {
}; };
export function getMachineTerminalSocketUrl(machineId: any) { export function getMachineTerminalSocketUrl(machineId: any) {
return `${config.baseWsUrl}/machines/${machineId}/terminal?token=${getToken()}`; return `${config.baseWsUrl}/machines/${machineId}/terminal?${joinClientParams()}`;
} }

View File

@@ -274,7 +274,7 @@ import { ref, toRefs, reactive, onMounted, computed } from 'vue';
import { ElMessage, ElMessageBox, ElInput } from 'element-plus'; import { ElMessage, ElMessageBox, ElInput } from 'element-plus';
import { machineApi } from '../api'; import { machineApi } from '../api';
import { getToken } from '@/common/utils/storage'; import { joinClientParams } from '@/common/utils/storage';
import config from '@/common/config'; import config from '@/common/config';
import { isTrue } from '@/common/assert'; import { isTrue } from '@/common/assert';
import MachineFileContent from './MachineFileContent.vue'; import MachineFileContent from './MachineFileContent.vue';
@@ -607,7 +607,7 @@ const deleteFile = async (files: any) => {
const downloadFile = (data: any) => { const downloadFile = (data: any) => {
const a = document.createElement('a'); const a = document.createElement('a');
a.setAttribute('href', `${config.baseApiUrl}/machines/${props.machineId}/files/${props.fileId}/read?type=1&path=${data.path}&token=${token}`); a.setAttribute('href', `${config.baseApiUrl}/machines/${props.machineId}/files/${props.fileId}/read?type=1&path=${data.path}&${joinClientParams()}`);
a.click(); a.click();
}; };
@@ -628,7 +628,7 @@ function getFolder(e: any) {
// 上传操作 // 上传操作
machineApi.uploadFile machineApi.uploadFile
.request(form, { .request(form, {
url: `${config.baseApiUrl}/machines/${props.machineId}/files/${props.fileId}/upload-folder?token=${token}`, url: `${config.baseApiUrl}/machines/${props.machineId}/files/${props.fileId}/upload-folder?${joinClientParams()}`,
headers: { 'Content-Type': 'multipart/form-data; boundary=----WebKitFormBoundaryF1uyUD0tWdqmJqpl' }, headers: { 'Content-Type': 'multipart/form-data; boundary=----WebKitFormBoundaryF1uyUD0tWdqmJqpl' },
onUploadProgress: onUploadProgress, onUploadProgress: onUploadProgress,
baseURL: '', baseURL: '',
@@ -669,7 +669,7 @@ const getUploadFile = (content: any) => {
params.append('token', token); params.append('token', token);
machineApi.uploadFile machineApi.uploadFile
.request(params, { .request(params, {
url: `${config.baseApiUrl}/machines/${props.machineId}/files/${props.fileId}/upload?token=${token}`, url: `${config.baseApiUrl}/machines/${props.machineId}/files/${props.fileId}/upload?${joinClientParams()}`,
headers: { 'Content-Type': 'multipart/form-data; boundary=----WebKitFormBoundaryF1uyUD0tWdqmJqpl' }, headers: { 'Content-Type': 'multipart/form-data; boundary=----WebKitFormBoundaryF1uyUD0tWdqmJqpl' },
onUploadProgress: onUploadProgress, onUploadProgress: onUploadProgress,
baseURL: '', baseURL: '',

View File

@@ -179,7 +179,7 @@ import { dateFormat } from '@/common/utils/date';
import { storeToRefs } from 'pinia'; import { storeToRefs } from 'pinia';
import { useUserInfo } from '@/store/userInfo'; import { useUserInfo } from '@/store/userInfo';
import config from '@/common/config'; import config from '@/common/config';
import { getToken } from '@/common/utils/storage'; import { joinClientParams } from '@/common/utils/storage';
const { userInfo } = storeToRefs(useUserInfo()); const { userInfo } = storeToRefs(useUserInfo());
const state = reactive({ const state = reactive({
@@ -248,7 +248,7 @@ const bindOAuth2 = () => {
var iLeft = (window.screen.width - 10 - width) / 2; //获得窗口的水平位置; var iLeft = (window.screen.width - 10 - width) / 2; //获得窗口的水平位置;
// 小窗口打开oauth2鉴权 // 小窗口打开oauth2鉴权
let oauthWindow = window.open( let oauthWindow = window.open(
config.baseApiUrl + '/auth/oauth2/bind?token=' + getToken(), `${config.baseApiUrl}/auth/oauth2/bind?${joinClientParams()}`,
'oauth2', 'oauth2',
`height=${height},width=${width},top=${iTop},left=${iLeft},location=no` `height=${height},width=${width},top=${iTop},left=${iLeft},location=no`
); );

View File

@@ -1781,6 +1781,11 @@ uri-js@^4.2.2:
dependencies: dependencies:
punycode "^2.1.0" punycode "^2.1.0"
uuid@^9.0.1:
version "9.0.1"
resolved "https://registry.npmmirror.com/uuid/-/uuid-9.0.1.tgz#e188d4c8853cc722220392c424cd637f32293f30"
integrity sha512-b+1eJOlsR9K8HJpow9Ok3fiWOWSIcIzXodvv0rQjVoOVNpWMpxf1wZNpt4y9h10odCNrqnYp1OBzRktckBe3sA==
vite@^4.4.11: vite@^4.4.11:
version "4.4.11" version "4.4.11"
resolved "https://registry.npmmirror.com/vite/-/vite-4.4.11.tgz#babdb055b08c69cfc4c468072a2e6c9ca62102b0" resolved "https://registry.npmmirror.com/vite/-/vite-4.4.11.tgz#babdb055b08c69cfc4c468072a2e6c9ca62102b0"

View File

@@ -14,7 +14,7 @@ require (
github.com/go-sql-driver/mysql v1.7.1 github.com/go-sql-driver/mysql v1.7.1
github.com/golang-jwt/jwt/v5 v5.0.0 github.com/golang-jwt/jwt/v5 v5.0.0
github.com/gorilla/websocket v1.5.0 github.com/gorilla/websocket v1.5.0
github.com/kanzihuang/vitess/go/vt/sqlparser v0.0.0-20231014104824-e3b9aa5415a4 github.com/kanzihuang/vitess/go/vt/sqlparser v0.0.0-20231018071450-ac8d9f0167e9
github.com/lib/pq v1.10.9 github.com/lib/pq v1.10.9
github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230712084735-068dc2aee82d github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230712084735-068dc2aee82d
github.com/mojocn/base64Captcha v1.3.5 // github.com/mojocn/base64Captcha v1.3.5 //

View File

@@ -119,5 +119,5 @@ func (a *AccountLogin) OtpVerify(rc *req.Ctx) {
func (a *AccountLogin) Logout(rc *req.Ctx) { func (a *AccountLogin) Logout(rc *req.Ctx) {
req.GetPermissionCodeRegistery().Remove(rc.LoginAccount.Id) req.GetPermissionCodeRegistery().Remove(rc.LoginAccount.Id)
ws.CloseClient(rc.LoginAccount.Id) ws.CloseClient(rc.LoginAccount.ClientUuid)
} }

View File

@@ -121,7 +121,7 @@ func (d *Db) ExecSql(rc *req.Ctx) {
LoginAccount: rc.LoginAccount, LoginAccount: rc.LoginAccount,
} }
sqls, err := sqlparser.SplitStatementToPieces(sql) sqls, err := sqlparser.SplitStatementToPieces(sql, sqlparser.WithDialect(dbConn.Info.Type.Dialect()))
biz.ErrIsNil(err, "SQL解析错误,请检查您的执行SQL") biz.ErrIsNil(err, "SQL解析错误,请检查您的执行SQL")
isMulti := len(sqls) > 1 isMulti := len(sqls) > 1
var execResAll *application.DbSqlExecRes var execResAll *application.DbSqlExecRes
@@ -129,7 +129,7 @@ func (d *Db) ExecSql(rc *req.Ctx) {
progressId := uniqueid.IncrementID() progressId := uniqueid.IncrementID()
executedStatements := 0 executedStatements := 0
progressTitle := fmt.Sprintf("%s/%s", dbConn.Info.Name, dbConn.Info.Database) progressTitle := fmt.Sprintf("%s/%s", dbConn.Info.Name, dbConn.Info.Database)
defer ws.SendJsonMsg(rc.LoginAccount.Id, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{ defer ws.SendJsonMsg(rc.LoginAccount.ClientUuid, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
Id: progressId, Id: progressId,
Title: progressTitle, Title: progressTitle,
ExecutedStatements: executedStatements, ExecutedStatements: executedStatements,
@@ -140,7 +140,7 @@ func (d *Db) ExecSql(rc *req.Ctx) {
for _, s := range sqls { for _, s := range sqls {
select { select {
case <-ticker.C: case <-ticker.C:
ws.SendJsonMsg(rc.LoginAccount.Id, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{ ws.SendJsonMsg(rc.LoginAccount.ClientUuid, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
Id: progressId, Id: progressId,
Title: progressTitle, Title: progressTitle,
ExecutedStatements: executedStatements, ExecutedStatements: executedStatements,
@@ -222,11 +222,13 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
} }
var sql string var sql string
tokenizer := sqlparser.NewReaderTokenizer(file, sqlparser.WithCacheInBuffer())
tokenizer := sqlparser.NewReaderTokenizer(file,
sqlparser.WithCacheInBuffer(), sqlparser.WithDialect(dbConn.Info.Type.Dialect()))
progressId := uniqueid.IncrementID() progressId := uniqueid.IncrementID()
executedStatements := 0 executedStatements := 0
defer ws.SendJsonMsg(rc.LoginAccount.Id, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{ defer ws.SendJsonMsg(rc.LoginAccount.ClientUuid, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
Id: progressId, Id: progressId,
Title: filename, Title: filename,
ExecutedStatements: executedStatements, ExecutedStatements: executedStatements,
@@ -237,7 +239,7 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
ws.SendJsonMsg(rc.LoginAccount.Id, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{ ws.SendJsonMsg(rc.LoginAccount.ClientUuid, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
Id: progressId, Id: progressId,
Title: filename, Title: filename,
ExecutedStatements: executedStatements, ExecutedStatements: executedStatements,

View File

@@ -2,6 +2,7 @@ package entity
import ( import (
"fmt" "fmt"
"github.com/kanzihuang/vitess/go/vt/sqlparser"
"github.com/lib/pq" "github.com/lib/pq"
"strings" "strings"
) )
@@ -59,6 +60,17 @@ func (dbType DbType) StmtSelectDbName() string {
} }
} }
func (dbType DbType) Dialect() sqlparser.Dialect {
switch dbType {
case DbTypeMysql:
return sqlparser.MysqlDialect{}
case DbTypePostgres:
return sqlparser.PostgresDialect{}
default:
panic(fmt.Sprintf("invalid database type: %s", dbType))
}
}
// QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be // QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be
// used as part of an SQL statement. For example: // used as part of an SQL statement. For example:
// //

View File

@@ -40,5 +40,5 @@ func (a *msgAppImpl) CreateAndSend(la *model.LoginAccount, wmsg *dto.SysMsg) {
now := time.Now() now := time.Now()
msg := &entity.Msg{Type: 2, Msg: wmsg.Msg, RecipientId: int64(la.Id), CreateTime: &now, CreatorId: la.Id, Creator: la.Username} msg := &entity.Msg{Type: 2, Msg: wmsg.Msg, RecipientId: int64(la.Id), CreateTime: &now, CreatorId: la.Id, Creator: la.Username}
a.msgRepo.Insert(msg) a.msgRepo.Insert(msg)
ws.SendJsonMsg(la.Id, wmsg) ws.SendJsonMsg(la.ClientUuid, wmsg)
} }

View File

@@ -81,7 +81,12 @@ func (a *Account) ChangePassword(rc *req.Ctx) {
a.AccountApp.Update(updateAccount) a.AccountApp.Update(updateAccount)
// 赋值loginAccount 主要用于记录操作日志,因为操作日志保存请求上下文没有该信息不保存日志 // 赋值loginAccount 主要用于记录操作日志,因为操作日志保存请求上下文没有该信息不保存日志
rc.LoginAccount = &model.LoginAccount{Id: account.Id, Username: account.Username} if rc.LoginAccount == nil {
rc.LoginAccount = &model.LoginAccount{
Id: account.Id,
Username: account.Username,
}
}
} }
// 获取个人账号信息 // 获取个人账号信息

View File

@@ -37,5 +37,7 @@ func (s *System) ConnectWs(g *gin.Context) {
// 登录账号信息 // 登录账号信息
la := rc.LoginAccount la := rc.LoginAccount
ws.AddClient(la.Id, wsConn) if la != nil {
ws.AddClient(la.Id, la.ClientUuid, wsConn)
}
} }

View File

@@ -3,4 +3,7 @@ package model
type LoginAccount struct { type LoginAccount struct {
Id uint64 Id uint64
Username string Username string
// ClientUuid 客户端UUID
ClientUuid string
} }

View File

@@ -6,6 +6,7 @@ import (
"mayfly-go/pkg/biz" "mayfly-go/pkg/biz"
"mayfly-go/pkg/cache" "mayfly-go/pkg/cache"
"mayfly-go/pkg/config" "mayfly-go/pkg/config"
"mayfly-go/pkg/model"
"mayfly-go/pkg/rediscli" "mayfly-go/pkg/rediscli"
"mayfly-go/pkg/utils/stringx" "mayfly-go/pkg/utils/stringx"
"time" "time"
@@ -49,18 +50,28 @@ func PermissionHandler(rc *Ctx) error {
if tokenStr == "" { if tokenStr == "" {
return biz.PermissionErr return biz.PermissionErr
} }
loginAccount, err := ParseToken(tokenStr) userId, userName, err := ParseToken(tokenStr)
if err != nil || loginAccount == nil { if err != nil || userId == 0 {
return biz.PermissionErr return biz.PermissionErr
} }
// 权限不为nil并且permission code不为空则校验是否有权限code // 权限不为nil并且permission code不为空则校验是否有权限code
if permission != nil && permission.Code != "" { if permission != nil && permission.Code != "" {
if !permissionCodeRegistry.HasCode(loginAccount.Id, permission.Code) { if !permissionCodeRegistry.HasCode(userId, permission.Code) {
return biz.PermissionErr return biz.PermissionErr
} }
} }
clientUuid := rc.GinCtx.Request.Header.Get("Client-Uuid")
rc.LoginAccount = loginAccount // header不存在则从查询参数token中获取
if clientUuid == "" {
clientUuid = rc.GinCtx.Query("clientUuid")
}
if rc.LoginAccount == nil {
rc.LoginAccount = &model.LoginAccount{
Id: userId,
Username: userName,
ClientUuid: clientUuid,
}
}
return nil return nil
} }

View File

@@ -2,10 +2,8 @@ package req
import ( import (
"errors" "errors"
"mayfly-go/pkg/biz" "mayfly-go/pkg/biz"
"mayfly-go/pkg/config" "mayfly-go/pkg/config"
"mayfly-go/pkg/model"
"time" "time"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
@@ -28,9 +26,9 @@ func CreateToken(userId uint64, username string) string {
} }
// 解析token并返回登录者账号信息 // 解析token并返回登录者账号信息
func ParseToken(tokenStr string) (*model.LoginAccount, error) { func ParseToken(tokenStr string) (uint64, string, error) {
if tokenStr == "" { if tokenStr == "" {
return nil, errors.New("token error") return 0, "", errors.New("token error")
} }
// Parse token // Parse token
@@ -38,8 +36,8 @@ func ParseToken(tokenStr string) (*model.LoginAccount, error) {
return []byte(config.Conf.Jwt.Key), nil return []byte(config.Conf.Jwt.Key), nil
}) })
if err != nil || token == nil { if err != nil || token == nil {
return nil, err return 0, "", err
} }
i := token.Claims.(jwt.MapClaims) i := token.Claims.(jwt.MapClaims)
return &model.LoginAccount{Id: uint64(i["id"].(float64)), Username: i["username"].(string)}, nil return uint64(i["id"].(float64)), i["username"].(string), nil
} }

View File

@@ -3,6 +3,7 @@ package ws
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"mayfly-go/pkg/logx"
"mayfly-go/pkg/utils/stringx" "mayfly-go/pkg/utils/stringx"
"time" "time"
@@ -16,18 +17,20 @@ type UserId uint64
type ReadMsgHandlerFunc func([]byte) type ReadMsgHandlerFunc func([]byte)
type Client struct { type Client struct {
ClientId string // 标识ID ClientId string // 标识ID
UserId UserId // 用户ID UserId UserId // 用户ID
WsConn *websocket.Conn // 用户连接 ClientUuid string // 客户端UUID
WsConn *websocket.Conn // 用户连接
ReadMsgHander ReadMsgHandlerFunc // 读取消息处理函数 ReadMsgHander ReadMsgHandlerFunc // 读取消息处理函数
} }
func NewClient(userId UserId, socket *websocket.Conn) *Client { func NewClient(userId UserId, clientUuid string, socket *websocket.Conn) *Client {
cli := &Client{ cli := &Client{
ClientId: stringx.Rand(16), ClientId: stringx.Rand(16),
UserId: userId, UserId: userId,
WsConn: socket, ClientUuid: clientUuid,
WsConn: socket,
} }
return cli return cli
@@ -64,6 +67,8 @@ func (c *Client) Read() {
// 向客户端写入消息 // 向客户端写入消息
func (c *Client) WriteMsg(msg *Msg) error { func (c *Client) WriteMsg(msg *Msg) error {
logx.Debugf("发送消息: toUid=%v, data=%v", c.UserId, msg.Data)
if msg.Type == JsonMsg { if msg.Type == JsonMsg {
bytes, _ := json.Marshal(msg.Data) bytes, _ := json.Marshal(msg.Data)
return c.WsConn.WriteMessage(websocket.TextMessage, bytes) return c.WsConn.WriteMessage(websocket.TextMessage, bytes)

View File

@@ -11,7 +11,7 @@ const heartbeatInterval = 25 * time.Second
// 连接管理 // 连接管理
type ClientManager struct { type ClientManager struct {
ClientMap map[UserId]*Client // 全部的连接, key->userid, value->&client ClientMap map[string]*Client // 全部的连接, key->token, value->&client
RwLock sync.RWMutex // 读写锁 RwLock sync.RWMutex // 读写锁
ConnectChan chan *Client // 连接处理 ConnectChan chan *Client // 连接处理
@@ -21,7 +21,7 @@ type ClientManager struct {
func NewClientManager() (clientManager *ClientManager) { func NewClientManager() (clientManager *ClientManager) {
return &ClientManager{ return &ClientManager{
ClientMap: make(map[UserId]*Client), ClientMap: make(map[string]*Client),
ConnectChan: make(chan *Client, 10), ConnectChan: make(chan *Client, 10),
DisConnectChan: make(chan *Client, 10), DisConnectChan: make(chan *Client, 10),
MsgChan: make(chan *Msg, 100), MsgChan: make(chan *Msg, 100),
@@ -58,12 +58,12 @@ func (manager *ClientManager) CloseClient(client *Client) {
} }
// 根据用户id关闭客户端连接 // 根据用户id关闭客户端连接
func (manager *ClientManager) CloseByUid(uid UserId) { func (manager *ClientManager) CloseByClientUuid(clientUuid string) {
manager.CloseClient(manager.GetByUid(UserId(uid))) manager.CloseClient(manager.GetByClientUuid(clientUuid))
} }
// 获取所有的客户端 // 获取所有的客户端
func (manager *ClientManager) AllClient() map[UserId]*Client { func (manager *ClientManager) AllClient() map[string]*Client {
manager.RwLock.RLock() manager.RwLock.RLock()
defer manager.RwLock.RUnlock() defer manager.RwLock.RUnlock()
@@ -74,7 +74,19 @@ func (manager *ClientManager) AllClient() map[UserId]*Client {
func (manager *ClientManager) GetByUid(userId UserId) *Client { func (manager *ClientManager) GetByUid(userId UserId) *Client {
manager.RwLock.RLock() manager.RwLock.RLock()
defer manager.RwLock.RUnlock() defer manager.RwLock.RUnlock()
return manager.ClientMap[userId] for _, client := range manager.ClientMap {
if userId == client.UserId {
return client
}
}
return nil
}
// 通过userId获取
func (manager *ClientManager) GetByClientUuid(uuid string) *Client {
manager.RwLock.RLock()
defer manager.RwLock.RUnlock()
return manager.ClientMap[uuid]
} }
// 客户端数量 // 客户端数量
@@ -85,9 +97,8 @@ func (manager *ClientManager) Count() int {
} }
// 发送json数据给指定用户 // 发送json数据给指定用户
func (manager *ClientManager) SendJsonMsg(userId UserId, data any) { func (manager *ClientManager) SendJsonMsg(clientUuid string, data any) {
logx.Debugf("发送消息: toUid=%v, data=%v", userId, data) manager.MsgChan <- &Msg{ToClientUuid: clientUuid, Data: data, Type: JsonMsg}
manager.MsgChan <- &Msg{ToUserId: userId, Data: data, Type: JsonMsg}
} }
// 监听并发送给客户端信息 // 监听并发送给客户端信息
@@ -95,7 +106,7 @@ func (manager *ClientManager) WriteMessage() {
go func() { go func() {
for { for {
msg := <-manager.MsgChan msg := <-manager.MsgChan
if cli := manager.GetByUid(msg.ToUserId); cli != nil { if cli := manager.GetByClientUuid(msg.ToClientUuid); cli != nil {
if err := cli.WriteMsg(msg); err != nil { if err := cli.WriteMsg(msg); err != nil {
manager.CloseClient(cli) manager.CloseClient(cli)
} }
@@ -130,7 +141,7 @@ func (manager *ClientManager) HeartbeatTimer() {
// 处理建立连接 // 处理建立连接
func (manager *ClientManager) doConnect(client *Client) { func (manager *ClientManager) doConnect(client *Client) {
cli := manager.GetByUid(client.UserId) cli := manager.GetByClientUuid(client.ClientUuid)
if cli != nil { if cli != nil {
manager.doDisconnect(cli) manager.doDisconnect(cli)
} }
@@ -152,11 +163,11 @@ func (manager *ClientManager) doDisconnect(client *Client) {
func (manager *ClientManager) addClient2Map(client *Client) { func (manager *ClientManager) addClient2Map(client *Client) {
manager.RwLock.Lock() manager.RwLock.Lock()
defer manager.RwLock.Unlock() defer manager.RwLock.Unlock()
manager.ClientMap[client.UserId] = client manager.ClientMap[client.ClientUuid] = client
} }
func (manager *ClientManager) delClient4Map(client *Client) { func (manager *ClientManager) delClient4Map(client *Client) {
manager.RwLock.Lock() manager.RwLock.Lock()
defer manager.RwLock.Unlock() defer manager.RwLock.Unlock()
delete(manager.ClientMap, client.UserId) delete(manager.ClientMap, client.ClientUuid)
} }

View File

@@ -14,5 +14,6 @@ type Msg struct {
ToUserId UserId ToUserId UserId
Data any Data any
Type MsgType // 消息类型 Type MsgType // 消息类型
ToClientUuid string
} }

View File

@@ -21,18 +21,21 @@ func init() {
} }
// 添加ws客户端 // 添加ws客户端
func AddClient(userId uint64, conn *websocket.Conn) *Client { func AddClient(userId uint64, clientUuid string, conn *websocket.Conn) *Client {
cli := NewClient(UserId(userId), conn) if len(clientUuid) == 0 {
return nil
}
cli := NewClient(UserId(userId), clientUuid, conn)
cli.Read() cli.Read()
Manager.AddClient(cli) Manager.AddClient(cli)
return cli return cli
} }
func CloseClient(userid uint64) { func CloseClient(clientUuid string) {
Manager.CloseByUid(UserId(userid)) Manager.CloseByClientUuid(clientUuid)
} }
// 对指定用户发送json类型消息 // 对指定用户发送json类型消息
func SendJsonMsg(userId uint64, msg any) { func SendJsonMsg(clientUuid string, msg any) {
Manager.SendJsonMsg(UserId(userId), msg) Manager.SendJsonMsg(clientUuid, msg)
} }