mirror of
https://gitee.com/dromara/mayfly-go
synced 2025-11-02 15:30:25 +08:00
Merge pull request #73 from kanzihuang/feat-notify
feature: 每个客户端独立处理后端发送的系统消息
This commit is contained in:
@@ -28,6 +28,7 @@
|
||||
"screenfull": "^6.0.2",
|
||||
"sortablejs": "^1.15.0",
|
||||
"sql-formatter": "^12.1.2",
|
||||
"uuid": "^9.0.1",
|
||||
"vue": "^3.3.4",
|
||||
"vue-clipboard3": "^1.0.1",
|
||||
"vue-router": "^4.2.5",
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import router from '../router';
|
||||
import Axios from 'axios';
|
||||
import config from './config';
|
||||
import { getToken } from './utils/storage';
|
||||
import { getClientUuid, getToken, joinClientParams } from './utils/storage';
|
||||
import { templateResolve } from './utils/string';
|
||||
import { ElMessage } from 'element-plus';
|
||||
|
||||
@@ -54,6 +54,7 @@ service.interceptors.request.use(
|
||||
if (token) {
|
||||
// 设置token
|
||||
config.headers['Authorization'] = token;
|
||||
config.headers['Client-Uuid'] = getClientUuid();
|
||||
}
|
||||
return config;
|
||||
},
|
||||
@@ -176,7 +177,7 @@ function del(url: string, params: any = null, headers: any = null, options: any
|
||||
|
||||
function getApiUrl(url: string) {
|
||||
// 只是返回api地址而不做请求,用在上传组件之类的
|
||||
return baseUrl + url + '?token=' + getToken();
|
||||
return baseUrl + url + '?' + joinClientParams();
|
||||
}
|
||||
|
||||
export default {
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import Config from './config';
|
||||
import { ElNotification, NotificationHandle } from 'element-plus';
|
||||
import SocketBuilder from './SocketBuilder';
|
||||
import { getToken } from '@/common/utils/storage';
|
||||
import { createVNode, reactive } from "vue";
|
||||
import { buildProgressProps } from "@/components/progress-notify/progress-notify";
|
||||
import { getToken, joinClientParams } from '@/common/utils/storage';
|
||||
import { createVNode, reactive } from 'vue';
|
||||
import { buildProgressProps } from '@/components/progress-notify/progress-notify';
|
||||
import ProgressNotify from '/src/components/progress-notify/progress-notify.vue';
|
||||
|
||||
export default {
|
||||
@@ -16,43 +16,40 @@ export default {
|
||||
return null;
|
||||
}
|
||||
const messageTypes = {
|
||||
0: "error",
|
||||
1: "success",
|
||||
2: "info",
|
||||
}
|
||||
const notifyMap: Map<Number, any> = new Map()
|
||||
0: 'error',
|
||||
1: 'success',
|
||||
2: 'info',
|
||||
};
|
||||
const notifyMap: Map<Number, any> = new Map();
|
||||
const sysMsgUrl = `${Config.baseWsUrl}/sysmsg?${joinClientParams()}`;
|
||||
|
||||
return SocketBuilder.builder(`${Config.baseWsUrl}/sysmsg?token=${token}`)
|
||||
return SocketBuilder.builder(sysMsgUrl)
|
||||
.message((event: { data: string }) => {
|
||||
const message = JSON.parse(event.data);
|
||||
const type = messageTypes[message.type]
|
||||
const type = messageTypes[message.type];
|
||||
switch (message.category) {
|
||||
case "execSqlFileProgress":
|
||||
const content = JSON.parse(message.msg)
|
||||
const id = content.id
|
||||
let progress = notifyMap.get(id)
|
||||
case 'execSqlFileProgress':
|
||||
const content = JSON.parse(message.msg);
|
||||
const id = content.id;
|
||||
let progress = notifyMap.get(id);
|
||||
if (content.terminated) {
|
||||
if (progress != undefined) {
|
||||
progress.notification?.close()
|
||||
notifyMap.delete(id)
|
||||
progress = undefined
|
||||
progress.notification?.close();
|
||||
notifyMap.delete(id);
|
||||
progress = undefined;
|
||||
}
|
||||
return
|
||||
return;
|
||||
}
|
||||
if (progress == undefined) {
|
||||
progress = {
|
||||
props: reactive(buildProgressProps()),
|
||||
notification: undefined,
|
||||
}
|
||||
};
|
||||
}
|
||||
progress.props.progress.sqlFileName = content.sqlFileName
|
||||
progress.props.progress.executedStatements = content.executedStatements
|
||||
progress.props.progress.title = content.title;
|
||||
progress.props.progress.executedStatements = content.executedStatements;
|
||||
if (!notifyMap.has(id)) {
|
||||
const vNodeMessage = createVNode(
|
||||
ProgressNotify,
|
||||
progress.props,
|
||||
null,
|
||||
)
|
||||
const vNodeMessage = createVNode(ProgressNotify, progress.props, null);
|
||||
progress.notification = ElNotification({
|
||||
duration: 0,
|
||||
title: message.title,
|
||||
@@ -60,7 +57,7 @@ export default {
|
||||
type: type,
|
||||
showClose: false,
|
||||
});
|
||||
notifyMap.set(id, progress)
|
||||
notifyMap.set(id, progress);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
import { v1 as uuidv1 } from 'uuid';
|
||||
|
||||
const TokenKey = 'token';
|
||||
const UserKey = 'user';
|
||||
const TagViewsKey = 'tagViews';
|
||||
const ClientUuid = 'clientUuid'
|
||||
|
||||
// 获取请求token
|
||||
export function getToken(): string {
|
||||
@@ -48,6 +51,21 @@ export function removeTagViews() {
|
||||
removeSession(TagViewsKey);
|
||||
}
|
||||
|
||||
// 获取客户端UUID
|
||||
export function getClientUuid(): string {
|
||||
let uuid = getSession(ClientUuid)
|
||||
if (uuid == null) {
|
||||
uuid = uuidv1()
|
||||
setSession(ClientUuid, uuid)
|
||||
}
|
||||
return uuid
|
||||
}
|
||||
|
||||
// 组装客户端参数,包括 token 和 clientUuid
|
||||
export function joinClientParams(): string {
|
||||
return `token=${getToken()}&clientUuid=${getClientUuid()}`
|
||||
}
|
||||
|
||||
// 1. localStorage
|
||||
// 设置永久缓存
|
||||
export function setLocal(key: string, val: any) {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
export const buildProgressProps = (): any => {
|
||||
return {
|
||||
progress: {
|
||||
sqlFileName: {
|
||||
title: {
|
||||
type: String,
|
||||
},
|
||||
executedStatements: {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
<template>
|
||||
<el-descriptions border size="small" :title="`${progress.sqlFileName}`">
|
||||
<el-descriptions border size="small" :title="`${progress.title}`">
|
||||
<el-descriptions-item label="时间">{{ state.elapsedTime }}</el-descriptions-item>
|
||||
<el-descriptions-item label="已处理">{{ progress.executedStatements }}</el-descriptions-item>
|
||||
</el-descriptions>
|
||||
|
||||
@@ -132,7 +132,7 @@ import { nextTick, onMounted, ref, toRefs, reactive, computed } from 'vue';
|
||||
import { useRoute, useRouter } from 'vue-router';
|
||||
import { ElMessage } from 'element-plus';
|
||||
import { initRouter } from '@/router/index';
|
||||
import { saveToken, saveUser } from '@/common/utils/storage';
|
||||
import { saveToken, saveUser, setSession } from '@/common/utils/storage';
|
||||
import { formatAxis } from '@/common/utils/format';
|
||||
import openApi from '@/common/openApi';
|
||||
import { RsaEncrypt } from '@/common/rsa';
|
||||
@@ -364,7 +364,7 @@ const loginResDeal = (loginRes: any) => {
|
||||
useUserInfo().setUserInfo(userInfos);
|
||||
|
||||
const token = loginRes.token;
|
||||
// 如果不需要otp校验,则该token即为accessToken,否则为otp校验token
|
||||
// 如果不需要 otp校验,则该token即为accessToken,否则为otp校验token
|
||||
if (loginRes.otp == -1) {
|
||||
signInSuccess(token);
|
||||
return;
|
||||
@@ -385,6 +385,7 @@ const signInSuccess = async (accessToken: string = '') => {
|
||||
}
|
||||
// 存储 token 到浏览器缓存
|
||||
saveToken(accessToken);
|
||||
|
||||
// 初始化路由
|
||||
await initRouter();
|
||||
|
||||
|
||||
@@ -172,7 +172,7 @@ import { ref, toRefs, reactive, onMounted, defineAsyncComponent } from 'vue';
|
||||
import { ElMessage, ElMessageBox } from 'element-plus';
|
||||
import { dbApi } from './api';
|
||||
import config from '@/common/config';
|
||||
import { getToken } from '@/common/utils/storage';
|
||||
import { joinClientParams } from '@/common/utils/storage';
|
||||
import { isTrue } from '@/common/assert';
|
||||
import { Search as SearchIcon } from '@element-plus/icons-vue';
|
||||
import { dateFormat } from '@/common/utils/date';
|
||||
@@ -406,7 +406,7 @@ const dumpDbs = () => {
|
||||
'href',
|
||||
`${config.baseApiUrl}/dbs/${state.exportDialog.dbId}/dump?db=${state.exportDialog.value.join(',')}&type=${type}&extName=${
|
||||
state.exportDialog.extName
|
||||
}&token=${getToken()}`
|
||||
}&${joinClientParams()}`
|
||||
);
|
||||
a.click();
|
||||
state.exportDialog.visible = false;
|
||||
|
||||
@@ -88,7 +88,7 @@
|
||||
|
||||
<script lang="ts" setup>
|
||||
import { nextTick, watch, onMounted, reactive, toRefs, ref, Ref } from 'vue';
|
||||
import { getToken } from '@/common/utils/storage';
|
||||
import { getToken, joinClientParams } from '@/common/utils/storage';
|
||||
import { isTrue, notBlank } from '@/common/assert';
|
||||
import { format as sqlFormatter } from 'sql-formatter';
|
||||
import config from '@/common/config';
|
||||
@@ -485,7 +485,7 @@ const execSqlFileSuccess = (res: any) => {
|
||||
|
||||
// 获取sql文件上传执行url
|
||||
const getUploadSqlFileUrl = () => {
|
||||
return `${config.baseApiUrl}/dbs/${state.ti.dbId}/exec-sql-file?db=${state.ti.db}`;
|
||||
return `${config.baseApiUrl}/dbs/${state.ti.dbId}/exec-sql-file?db=${state.ti.db}&${joinClientParams()}`;
|
||||
};
|
||||
|
||||
const onDataSelectionChange = (datas: []) => {
|
||||
|
||||
@@ -124,7 +124,7 @@ import { formatByteSize } from '@/common/utils/format';
|
||||
import { dbApi } from '../api';
|
||||
import SqlExecBox from '../component/SqlExecBox';
|
||||
import config from '@/common/config';
|
||||
import { getToken } from '@/common/utils/storage';
|
||||
import { joinClientParams } from '@/common/utils/storage';
|
||||
import { isTrue } from '@/common/assert';
|
||||
|
||||
const DbTableEdit = defineAsyncComponent(() => import('./DbTableEdit.vue'));
|
||||
@@ -259,7 +259,7 @@ const dump = (db: string) => {
|
||||
const a = document.createElement('a');
|
||||
a.setAttribute(
|
||||
'href',
|
||||
`${config.baseApiUrl}/dbs/${props.dbId}/dump?db=${db}&type=${state.dumpInfo.type}&tables=${state.dumpInfo.tables.join(',')}&token=${getToken()}`
|
||||
`${config.baseApiUrl}/dbs/${props.dbId}/dump?db=${db}&type=${state.dumpInfo.type}&tables=${state.dumpInfo.tables.join(',')}&${joinClientParams()}`
|
||||
);
|
||||
a.click();
|
||||
state.showDumpInfo = false;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import Api from '@/common/Api';
|
||||
import config from '@/common/config';
|
||||
import { getToken } from '@/common/utils/storage';
|
||||
import { joinClientParams } from '@/common/utils/storage';
|
||||
|
||||
export const machineApi = {
|
||||
// 获取权限列表
|
||||
@@ -33,7 +33,7 @@ export const machineApi = {
|
||||
cpFile: Api.newPost('/machines/{machineId}/files/{fileId}/cp'),
|
||||
renameFile: Api.newPost('/machines/{machineId}/files/{fileId}/rename'),
|
||||
mvFile: Api.newPost('/machines/{machineId}/files/{fileId}/mv'),
|
||||
uploadFile: Api.newPost('/machines/{machineId}/files/{fileId}/upload?token={token}'),
|
||||
uploadFile: Api.newPost('/machines/{machineId}/files/{fileId}/upload?' + joinClientParams()),
|
||||
fileContent: Api.newGet('/machines/{machineId}/files/{fileId}/read'),
|
||||
createFile: Api.newPost('/machines/{machineId}/files/{id}/create-file'),
|
||||
// 修改文件内容
|
||||
@@ -63,5 +63,5 @@ export const cronJobApi = {
|
||||
};
|
||||
|
||||
export function getMachineTerminalSocketUrl(machineId: any) {
|
||||
return `${config.baseWsUrl}/machines/${machineId}/terminal?token=${getToken()}`;
|
||||
return `${config.baseWsUrl}/machines/${machineId}/terminal?${joinClientParams()}`;
|
||||
}
|
||||
|
||||
@@ -274,7 +274,7 @@ import { ref, toRefs, reactive, onMounted, computed } from 'vue';
|
||||
import { ElMessage, ElMessageBox, ElInput } from 'element-plus';
|
||||
import { machineApi } from '../api';
|
||||
|
||||
import { getToken } from '@/common/utils/storage';
|
||||
import { joinClientParams } from '@/common/utils/storage';
|
||||
import config from '@/common/config';
|
||||
import { isTrue } from '@/common/assert';
|
||||
import MachineFileContent from './MachineFileContent.vue';
|
||||
@@ -607,7 +607,7 @@ const deleteFile = async (files: any) => {
|
||||
|
||||
const downloadFile = (data: any) => {
|
||||
const a = document.createElement('a');
|
||||
a.setAttribute('href', `${config.baseApiUrl}/machines/${props.machineId}/files/${props.fileId}/read?type=1&path=${data.path}&token=${token}`);
|
||||
a.setAttribute('href', `${config.baseApiUrl}/machines/${props.machineId}/files/${props.fileId}/read?type=1&path=${data.path}&${joinClientParams()}`);
|
||||
a.click();
|
||||
};
|
||||
|
||||
@@ -628,7 +628,7 @@ function getFolder(e: any) {
|
||||
// 上传操作
|
||||
machineApi.uploadFile
|
||||
.request(form, {
|
||||
url: `${config.baseApiUrl}/machines/${props.machineId}/files/${props.fileId}/upload-folder?token=${token}`,
|
||||
url: `${config.baseApiUrl}/machines/${props.machineId}/files/${props.fileId}/upload-folder?${joinClientParams()}`,
|
||||
headers: { 'Content-Type': 'multipart/form-data; boundary=----WebKitFormBoundaryF1uyUD0tWdqmJqpl' },
|
||||
onUploadProgress: onUploadProgress,
|
||||
baseURL: '',
|
||||
@@ -669,7 +669,7 @@ const getUploadFile = (content: any) => {
|
||||
params.append('token', token);
|
||||
machineApi.uploadFile
|
||||
.request(params, {
|
||||
url: `${config.baseApiUrl}/machines/${props.machineId}/files/${props.fileId}/upload?token=${token}`,
|
||||
url: `${config.baseApiUrl}/machines/${props.machineId}/files/${props.fileId}/upload?${joinClientParams()}`,
|
||||
headers: { 'Content-Type': 'multipart/form-data; boundary=----WebKitFormBoundaryF1uyUD0tWdqmJqpl' },
|
||||
onUploadProgress: onUploadProgress,
|
||||
baseURL: '',
|
||||
|
||||
@@ -179,7 +179,7 @@ import { dateFormat } from '@/common/utils/date';
|
||||
import { storeToRefs } from 'pinia';
|
||||
import { useUserInfo } from '@/store/userInfo';
|
||||
import config from '@/common/config';
|
||||
import { getToken } from '@/common/utils/storage';
|
||||
import { joinClientParams } from '@/common/utils/storage';
|
||||
|
||||
const { userInfo } = storeToRefs(useUserInfo());
|
||||
const state = reactive({
|
||||
@@ -248,7 +248,7 @@ const bindOAuth2 = () => {
|
||||
var iLeft = (window.screen.width - 10 - width) / 2; //获得窗口的水平位置;
|
||||
// 小窗口打开oauth2鉴权
|
||||
let oauthWindow = window.open(
|
||||
config.baseApiUrl + '/auth/oauth2/bind?token=' + getToken(),
|
||||
`${config.baseApiUrl}/auth/oauth2/bind?${joinClientParams()}`,
|
||||
'oauth2',
|
||||
`height=${height},width=${width},top=${iTop},left=${iLeft},location=no`
|
||||
);
|
||||
|
||||
@@ -1781,6 +1781,11 @@ uri-js@^4.2.2:
|
||||
dependencies:
|
||||
punycode "^2.1.0"
|
||||
|
||||
uuid@^9.0.1:
|
||||
version "9.0.1"
|
||||
resolved "https://registry.npmmirror.com/uuid/-/uuid-9.0.1.tgz#e188d4c8853cc722220392c424cd637f32293f30"
|
||||
integrity sha512-b+1eJOlsR9K8HJpow9Ok3fiWOWSIcIzXodvv0rQjVoOVNpWMpxf1wZNpt4y9h10odCNrqnYp1OBzRktckBe3sA==
|
||||
|
||||
vite@^4.4.11:
|
||||
version "4.4.11"
|
||||
resolved "https://registry.npmmirror.com/vite/-/vite-4.4.11.tgz#babdb055b08c69cfc4c468072a2e6c9ca62102b0"
|
||||
|
||||
@@ -14,7 +14,7 @@ require (
|
||||
github.com/go-sql-driver/mysql v1.7.1
|
||||
github.com/golang-jwt/jwt/v5 v5.0.0
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
github.com/kanzihuang/vitess/go/vt/sqlparser v0.0.0-20231014104824-e3b9aa5415a4
|
||||
github.com/kanzihuang/vitess/go/vt/sqlparser v0.0.0-20231018071450-ac8d9f0167e9
|
||||
github.com/lib/pq v1.10.9
|
||||
github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230712084735-068dc2aee82d
|
||||
github.com/mojocn/base64Captcha v1.3.5 // 验证码
|
||||
|
||||
@@ -119,5 +119,5 @@ func (a *AccountLogin) OtpVerify(rc *req.Ctx) {
|
||||
|
||||
func (a *AccountLogin) Logout(rc *req.Ctx) {
|
||||
req.GetPermissionCodeRegistery().Remove(rc.LoginAccount.Id)
|
||||
ws.CloseClient(rc.LoginAccount.Id)
|
||||
ws.CloseClient(rc.LoginAccount.ClientUuid)
|
||||
}
|
||||
|
||||
@@ -121,11 +121,34 @@ func (d *Db) ExecSql(rc *req.Ctx) {
|
||||
LoginAccount: rc.LoginAccount,
|
||||
}
|
||||
|
||||
sqls, err := sqlparser.SplitStatementToPieces(sql)
|
||||
sqls, err := sqlparser.SplitStatementToPieces(sql, sqlparser.WithDialect(dbConn.Info.Type.Dialect()))
|
||||
biz.ErrIsNil(err, "SQL解析错误,请检查您的执行SQL")
|
||||
isMulti := len(sqls) > 1
|
||||
var execResAll *application.DbSqlExecRes
|
||||
|
||||
progressId := uniqueid.IncrementID()
|
||||
executedStatements := 0
|
||||
progressTitle := fmt.Sprintf("%s/%s", dbConn.Info.Name, dbConn.Info.Database)
|
||||
defer ws.SendJsonMsg(rc.LoginAccount.ClientUuid, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
|
||||
Id: progressId,
|
||||
Title: progressTitle,
|
||||
ExecutedStatements: executedStatements,
|
||||
Terminated: true,
|
||||
}).WithCategory(progressCategory))
|
||||
ticker := time.NewTicker(time.Second * 1)
|
||||
defer ticker.Stop()
|
||||
for _, s := range sqls {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
ws.SendJsonMsg(rc.LoginAccount.ClientUuid, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
|
||||
Id: progressId,
|
||||
Title: progressTitle,
|
||||
ExecutedStatements: executedStatements,
|
||||
Terminated: false,
|
||||
}).WithCategory(progressCategory))
|
||||
default:
|
||||
}
|
||||
executedStatements++
|
||||
s = stringx.TrimSpaceAndBr(s)
|
||||
// 多条执行,如果有查询语句,则跳过
|
||||
if isMulti && strings.HasPrefix(strings.ToLower(s), "select") {
|
||||
@@ -156,7 +179,7 @@ const progressCategory = "execSqlFileProgress"
|
||||
// progressMsg sql文件执行进度消息
|
||||
type progressMsg struct {
|
||||
Id uint64 `json:"id"`
|
||||
SqlFileName string `json:"sqlFileName"`
|
||||
Title string `json:"title"`
|
||||
ExecutedStatements int `json:"executedStatements"`
|
||||
Terminated bool `json:"terminated"`
|
||||
}
|
||||
@@ -198,30 +221,33 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
|
||||
LoginAccount: rc.LoginAccount,
|
||||
}
|
||||
|
||||
var sql string
|
||||
|
||||
tokenizer := sqlparser.NewReaderTokenizer(file,
|
||||
sqlparser.WithCacheInBuffer(), sqlparser.WithDialect(dbConn.Info.Type.Dialect()))
|
||||
|
||||
progressId := uniqueid.IncrementID()
|
||||
executedStatements := 0
|
||||
defer ws.SendJsonMsg(rc.LoginAccount.Id, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
|
||||
defer ws.SendJsonMsg(rc.LoginAccount.ClientUuid, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
|
||||
Id: progressId,
|
||||
SqlFileName: filename,
|
||||
Title: filename,
|
||||
ExecutedStatements: executedStatements,
|
||||
Terminated: true,
|
||||
}).WithCategory(progressCategory))
|
||||
|
||||
var sql string
|
||||
tokenizer := sqlparser.NewReaderTokenizer(file, sqlparser.WithCacheInBuffer())
|
||||
ticker := time.NewTicker(time.Second * 1)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
ws.SendJsonMsg(rc.LoginAccount.Id, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
|
||||
ws.SendJsonMsg(rc.LoginAccount.ClientUuid, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
|
||||
Id: progressId,
|
||||
SqlFileName: filename,
|
||||
Title: filename,
|
||||
ExecutedStatements: executedStatements,
|
||||
Terminated: false,
|
||||
}).WithCategory(progressCategory))
|
||||
default:
|
||||
}
|
||||
executedStatements++
|
||||
sql, err = sqlparser.SplitNext(tokenizer)
|
||||
if err == io.EOF {
|
||||
break
|
||||
@@ -259,7 +285,6 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
|
||||
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s] -> sql=[%s] 执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), sql, err.Error())))
|
||||
return
|
||||
}
|
||||
executedStatements++
|
||||
}
|
||||
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.SuccessSysMsg("sql脚本执行成功", fmt.Sprintf("sql脚本执行完成:%s", rc.ReqParam)))
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package entity
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/kanzihuang/vitess/go/vt/sqlparser"
|
||||
"github.com/lib/pq"
|
||||
"strings"
|
||||
)
|
||||
@@ -59,6 +60,17 @@ func (dbType DbType) StmtSelectDbName() string {
|
||||
}
|
||||
}
|
||||
|
||||
func (dbType DbType) Dialect() sqlparser.Dialect {
|
||||
switch dbType {
|
||||
case DbTypeMysql:
|
||||
return sqlparser.MysqlDialect{}
|
||||
case DbTypePostgres:
|
||||
return sqlparser.PostgresDialect{}
|
||||
default:
|
||||
panic(fmt.Sprintf("invalid database type: %s", dbType))
|
||||
}
|
||||
}
|
||||
|
||||
// QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be
|
||||
// used as part of an SQL statement. For example:
|
||||
//
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
package entity
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_escapeSql(t *testing.T) {
|
||||
func Test_QuoteLiteral(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
dbType DbType
|
||||
sql string
|
||||
want string
|
||||
@@ -24,7 +22,6 @@ func Test_escapeSql(t *testing.T) {
|
||||
want: "'''a'''",
|
||||
},
|
||||
{
|
||||
name: "不间断空格",
|
||||
dbType: DbTypeMysql,
|
||||
sql: "a\u00A0b",
|
||||
want: "'a\u00A0b'",
|
||||
@@ -40,14 +37,13 @@ func Test_escapeSql(t *testing.T) {
|
||||
want: "'''a'''",
|
||||
},
|
||||
{
|
||||
name: "不间断空格",
|
||||
dbType: DbTypePostgres,
|
||||
sql: "a\u00A0b",
|
||||
want: "'a\u00A0b'",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Run(string(tt.dbType)+"_"+tt.sql, func(t *testing.T) {
|
||||
got := tt.dbType.QuoteLiteral(tt.sql)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
|
||||
@@ -40,5 +40,5 @@ func (a *msgAppImpl) CreateAndSend(la *model.LoginAccount, wmsg *dto.SysMsg) {
|
||||
now := time.Now()
|
||||
msg := &entity.Msg{Type: 2, Msg: wmsg.Msg, RecipientId: int64(la.Id), CreateTime: &now, CreatorId: la.Id, Creator: la.Username}
|
||||
a.msgRepo.Insert(msg)
|
||||
ws.SendJsonMsg(la.Id, wmsg)
|
||||
ws.SendJsonMsg(la.ClientUuid, wmsg)
|
||||
}
|
||||
|
||||
@@ -81,7 +81,12 @@ func (a *Account) ChangePassword(rc *req.Ctx) {
|
||||
a.AccountApp.Update(updateAccount)
|
||||
|
||||
// 赋值loginAccount 主要用于记录操作日志,因为操作日志保存请求上下文没有该信息不保存日志
|
||||
rc.LoginAccount = &model.LoginAccount{Id: account.Id, Username: account.Username}
|
||||
if rc.LoginAccount == nil {
|
||||
rc.LoginAccount = &model.LoginAccount{
|
||||
Id: account.Id,
|
||||
Username: account.Username,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 获取个人账号信息
|
||||
|
||||
@@ -37,5 +37,7 @@ func (s *System) ConnectWs(g *gin.Context) {
|
||||
|
||||
// 登录账号信息
|
||||
la := rc.LoginAccount
|
||||
ws.AddClient(la.Id, wsConn)
|
||||
if la != nil {
|
||||
ws.AddClient(la.Id, la.ClientUuid, wsConn)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,4 +3,7 @@ package model
|
||||
type LoginAccount struct {
|
||||
Id uint64
|
||||
Username string
|
||||
|
||||
// ClientUuid 客户端UUID
|
||||
ClientUuid string
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"mayfly-go/pkg/biz"
|
||||
"mayfly-go/pkg/cache"
|
||||
"mayfly-go/pkg/config"
|
||||
"mayfly-go/pkg/model"
|
||||
"mayfly-go/pkg/rediscli"
|
||||
"mayfly-go/pkg/utils/stringx"
|
||||
"time"
|
||||
@@ -49,18 +50,28 @@ func PermissionHandler(rc *Ctx) error {
|
||||
if tokenStr == "" {
|
||||
return biz.PermissionErr
|
||||
}
|
||||
loginAccount, err := ParseToken(tokenStr)
|
||||
if err != nil || loginAccount == nil {
|
||||
userId, userName, err := ParseToken(tokenStr)
|
||||
if err != nil || userId == 0 {
|
||||
return biz.PermissionErr
|
||||
}
|
||||
// 权限不为nil,并且permission code不为空,则校验是否有权限code
|
||||
if permission != nil && permission.Code != "" {
|
||||
if !permissionCodeRegistry.HasCode(loginAccount.Id, permission.Code) {
|
||||
if !permissionCodeRegistry.HasCode(userId, permission.Code) {
|
||||
return biz.PermissionErr
|
||||
}
|
||||
}
|
||||
|
||||
rc.LoginAccount = loginAccount
|
||||
clientUuid := rc.GinCtx.Request.Header.Get("Client-Uuid")
|
||||
// header不存在则从查询参数token中获取
|
||||
if clientUuid == "" {
|
||||
clientUuid = rc.GinCtx.Query("clientUuid")
|
||||
}
|
||||
if rc.LoginAccount == nil {
|
||||
rc.LoginAccount = &model.LoginAccount{
|
||||
Id: userId,
|
||||
Username: userName,
|
||||
ClientUuid: clientUuid,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -2,10 +2,8 @@ package req
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"mayfly-go/pkg/biz"
|
||||
"mayfly-go/pkg/config"
|
||||
"mayfly-go/pkg/model"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
@@ -28,9 +26,9 @@ func CreateToken(userId uint64, username string) string {
|
||||
}
|
||||
|
||||
// 解析token,并返回登录者账号信息
|
||||
func ParseToken(tokenStr string) (*model.LoginAccount, error) {
|
||||
func ParseToken(tokenStr string) (uint64, string, error) {
|
||||
if tokenStr == "" {
|
||||
return nil, errors.New("token error")
|
||||
return 0, "", errors.New("token error")
|
||||
}
|
||||
|
||||
// Parse token
|
||||
@@ -38,8 +36,8 @@ func ParseToken(tokenStr string) (*model.LoginAccount, error) {
|
||||
return []byte(config.Conf.Jwt.Key), nil
|
||||
})
|
||||
if err != nil || token == nil {
|
||||
return nil, err
|
||||
return 0, "", err
|
||||
}
|
||||
i := token.Claims.(jwt.MapClaims)
|
||||
return &model.LoginAccount{Id: uint64(i["id"].(float64)), Username: i["username"].(string)}, nil
|
||||
return uint64(i["id"].(float64)), i["username"].(string), nil
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package ws
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"mayfly-go/pkg/logx"
|
||||
"mayfly-go/pkg/utils/stringx"
|
||||
"time"
|
||||
|
||||
@@ -16,18 +17,20 @@ type UserId uint64
|
||||
type ReadMsgHandlerFunc func([]byte)
|
||||
|
||||
type Client struct {
|
||||
ClientId string // 标识ID
|
||||
UserId UserId // 用户ID
|
||||
WsConn *websocket.Conn // 用户连接
|
||||
ClientId string // 标识ID
|
||||
UserId UserId // 用户ID
|
||||
ClientUuid string // 客户端UUID
|
||||
WsConn *websocket.Conn // 用户连接
|
||||
|
||||
ReadMsgHander ReadMsgHandlerFunc // 读取消息处理函数
|
||||
}
|
||||
|
||||
func NewClient(userId UserId, socket *websocket.Conn) *Client {
|
||||
func NewClient(userId UserId, clientUuid string, socket *websocket.Conn) *Client {
|
||||
cli := &Client{
|
||||
ClientId: stringx.Rand(16),
|
||||
UserId: userId,
|
||||
WsConn: socket,
|
||||
ClientId: stringx.Rand(16),
|
||||
UserId: userId,
|
||||
ClientUuid: clientUuid,
|
||||
WsConn: socket,
|
||||
}
|
||||
|
||||
return cli
|
||||
@@ -64,6 +67,8 @@ func (c *Client) Read() {
|
||||
|
||||
// 向客户端写入消息
|
||||
func (c *Client) WriteMsg(msg *Msg) error {
|
||||
logx.Debugf("发送消息: toUid=%v, data=%v", c.UserId, msg.Data)
|
||||
|
||||
if msg.Type == JsonMsg {
|
||||
bytes, _ := json.Marshal(msg.Data)
|
||||
return c.WsConn.WriteMessage(websocket.TextMessage, bytes)
|
||||
|
||||
@@ -11,7 +11,7 @@ const heartbeatInterval = 25 * time.Second
|
||||
|
||||
// 连接管理
|
||||
type ClientManager struct {
|
||||
ClientMap map[UserId]*Client // 全部的连接, key->userid, value->&client
|
||||
ClientMap map[string]*Client // 全部的连接, key->token, value->&client
|
||||
RwLock sync.RWMutex // 读写锁
|
||||
|
||||
ConnectChan chan *Client // 连接处理
|
||||
@@ -21,7 +21,7 @@ type ClientManager struct {
|
||||
|
||||
func NewClientManager() (clientManager *ClientManager) {
|
||||
return &ClientManager{
|
||||
ClientMap: make(map[UserId]*Client),
|
||||
ClientMap: make(map[string]*Client),
|
||||
ConnectChan: make(chan *Client, 10),
|
||||
DisConnectChan: make(chan *Client, 10),
|
||||
MsgChan: make(chan *Msg, 100),
|
||||
@@ -58,12 +58,12 @@ func (manager *ClientManager) CloseClient(client *Client) {
|
||||
}
|
||||
|
||||
// 根据用户id关闭客户端连接
|
||||
func (manager *ClientManager) CloseByUid(uid UserId) {
|
||||
manager.CloseClient(manager.GetByUid(UserId(uid)))
|
||||
func (manager *ClientManager) CloseByClientUuid(clientUuid string) {
|
||||
manager.CloseClient(manager.GetByClientUuid(clientUuid))
|
||||
}
|
||||
|
||||
// 获取所有的客户端
|
||||
func (manager *ClientManager) AllClient() map[UserId]*Client {
|
||||
func (manager *ClientManager) AllClient() map[string]*Client {
|
||||
manager.RwLock.RLock()
|
||||
defer manager.RwLock.RUnlock()
|
||||
|
||||
@@ -74,7 +74,19 @@ func (manager *ClientManager) AllClient() map[UserId]*Client {
|
||||
func (manager *ClientManager) GetByUid(userId UserId) *Client {
|
||||
manager.RwLock.RLock()
|
||||
defer manager.RwLock.RUnlock()
|
||||
return manager.ClientMap[userId]
|
||||
for _, client := range manager.ClientMap {
|
||||
if userId == client.UserId {
|
||||
return client
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 通过userId获取
|
||||
func (manager *ClientManager) GetByClientUuid(uuid string) *Client {
|
||||
manager.RwLock.RLock()
|
||||
defer manager.RwLock.RUnlock()
|
||||
return manager.ClientMap[uuid]
|
||||
}
|
||||
|
||||
// 客户端数量
|
||||
@@ -85,9 +97,8 @@ func (manager *ClientManager) Count() int {
|
||||
}
|
||||
|
||||
// 发送json数据给指定用户
|
||||
func (manager *ClientManager) SendJsonMsg(userId UserId, data any) {
|
||||
logx.Debugf("发送消息: toUid=%v, data=%v", userId, data)
|
||||
manager.MsgChan <- &Msg{ToUserId: userId, Data: data, Type: JsonMsg}
|
||||
func (manager *ClientManager) SendJsonMsg(clientUuid string, data any) {
|
||||
manager.MsgChan <- &Msg{ToClientUuid: clientUuid, Data: data, Type: JsonMsg}
|
||||
}
|
||||
|
||||
// 监听并发送给客户端信息
|
||||
@@ -95,7 +106,7 @@ func (manager *ClientManager) WriteMessage() {
|
||||
go func() {
|
||||
for {
|
||||
msg := <-manager.MsgChan
|
||||
if cli := manager.GetByUid(msg.ToUserId); cli != nil {
|
||||
if cli := manager.GetByClientUuid(msg.ToClientUuid); cli != nil {
|
||||
if err := cli.WriteMsg(msg); err != nil {
|
||||
manager.CloseClient(cli)
|
||||
}
|
||||
@@ -130,7 +141,7 @@ func (manager *ClientManager) HeartbeatTimer() {
|
||||
|
||||
// 处理建立连接
|
||||
func (manager *ClientManager) doConnect(client *Client) {
|
||||
cli := manager.GetByUid(client.UserId)
|
||||
cli := manager.GetByClientUuid(client.ClientUuid)
|
||||
if cli != nil {
|
||||
manager.doDisconnect(cli)
|
||||
}
|
||||
@@ -152,11 +163,11 @@ func (manager *ClientManager) doDisconnect(client *Client) {
|
||||
func (manager *ClientManager) addClient2Map(client *Client) {
|
||||
manager.RwLock.Lock()
|
||||
defer manager.RwLock.Unlock()
|
||||
manager.ClientMap[client.UserId] = client
|
||||
manager.ClientMap[client.ClientUuid] = client
|
||||
}
|
||||
|
||||
func (manager *ClientManager) delClient4Map(client *Client) {
|
||||
manager.RwLock.Lock()
|
||||
defer manager.RwLock.Unlock()
|
||||
delete(manager.ClientMap, client.UserId)
|
||||
delete(manager.ClientMap, client.ClientUuid)
|
||||
}
|
||||
|
||||
@@ -14,5 +14,6 @@ type Msg struct {
|
||||
ToUserId UserId
|
||||
Data any
|
||||
|
||||
Type MsgType // 消息类型
|
||||
Type MsgType // 消息类型
|
||||
ToClientUuid string
|
||||
}
|
||||
|
||||
@@ -21,18 +21,21 @@ func init() {
|
||||
}
|
||||
|
||||
// 添加ws客户端
|
||||
func AddClient(userId uint64, conn *websocket.Conn) *Client {
|
||||
cli := NewClient(UserId(userId), conn)
|
||||
func AddClient(userId uint64, clientUuid string, conn *websocket.Conn) *Client {
|
||||
if len(clientUuid) == 0 {
|
||||
return nil
|
||||
}
|
||||
cli := NewClient(UserId(userId), clientUuid, conn)
|
||||
cli.Read()
|
||||
Manager.AddClient(cli)
|
||||
return cli
|
||||
}
|
||||
|
||||
func CloseClient(userid uint64) {
|
||||
Manager.CloseByUid(UserId(userid))
|
||||
func CloseClient(clientUuid string) {
|
||||
Manager.CloseByClientUuid(clientUuid)
|
||||
}
|
||||
|
||||
// 对指定用户发送json类型消息
|
||||
func SendJsonMsg(userId uint64, msg any) {
|
||||
Manager.SendJsonMsg(UserId(userId), msg)
|
||||
func SendJsonMsg(clientUuid string, msg any) {
|
||||
Manager.SendJsonMsg(clientUuid, msg)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user