mirror of
https://gitee.com/dromara/mayfly-go
synced 2025-11-02 15:30:25 +08:00
refactor: dbm
This commit is contained in:
@@ -11,7 +11,7 @@
|
||||
},
|
||||
"dependencies": {
|
||||
"@element-plus/icons-vue": "^2.3.1",
|
||||
"@vueuse/core": "^11.3.0",
|
||||
"@vueuse/core": "^12.0.0",
|
||||
"asciinema-player": "^3.8.1",
|
||||
"axios": "^1.6.2",
|
||||
"clipboard": "^2.0.11",
|
||||
@@ -19,7 +19,7 @@
|
||||
"crypto-js": "^4.2.0",
|
||||
"dayjs": "^1.11.13",
|
||||
"echarts": "^5.5.1",
|
||||
"element-plus": "^2.8.8",
|
||||
"element-plus": "^2.9.0",
|
||||
"js-base64": "^3.7.7",
|
||||
"jsencrypt": "^3.3.2",
|
||||
"lodash": "^4.17.21",
|
||||
@@ -28,16 +28,16 @@
|
||||
"monaco-sql-languages": "^0.12.2",
|
||||
"monaco-themes": "^0.4.4",
|
||||
"nprogress": "^0.2.0",
|
||||
"pinia": "^2.2.7",
|
||||
"qrcode.vue": "^3.5.1",
|
||||
"pinia": "^2.3.0",
|
||||
"qrcode.vue": "^3.6.0",
|
||||
"screenfull": "^6.0.2",
|
||||
"sortablejs": "^1.15.3",
|
||||
"sortablejs": "^1.15.6",
|
||||
"splitpanes": "^3.1.5",
|
||||
"sql-formatter": "^15.4.5",
|
||||
"trzsz": "^1.1.5",
|
||||
"uuid": "^9.0.1",
|
||||
"vue": "^3.5.13",
|
||||
"vue-i18n": "^10.0.4",
|
||||
"vue-i18n": "^10.0.5",
|
||||
"vue-router": "^4.5.0",
|
||||
"xterm": "^5.3.0",
|
||||
"xterm-addon-fit": "^0.8.0",
|
||||
@@ -59,9 +59,9 @@
|
||||
"eslint": "^8.35.0",
|
||||
"eslint-plugin-vue": "^9.31.0",
|
||||
"prettier": "^3.2.5",
|
||||
"sass": "^1.81.0",
|
||||
"sass": "^1.82.0",
|
||||
"typescript": "^5.7.2",
|
||||
"vite": "^6.0.1",
|
||||
"vite": "^6.0.3",
|
||||
"vue-eslint-parser": "^9.4.3"
|
||||
},
|
||||
"browserslist": [
|
||||
|
||||
@@ -7,7 +7,7 @@ export default {
|
||||
getPublicKey: () => request.get('/common/public-key'),
|
||||
getConfigValue: (params: any) => request.get('/sys/configs/value', params),
|
||||
getServerConf: () => request.get('/sys/configs/server'),
|
||||
oauth2LoginConfig: () => request.get('/auth/oauth2-config'),
|
||||
oauth2LoginConfig: () => request.get('/auth/oauth2/config'),
|
||||
changePwd: (param: any) => request.post('/sys/accounts/change-pwd', param),
|
||||
captcha: () => request.get('/sys/captcha'),
|
||||
logout: () => request.post('/auth/accounts/logout'),
|
||||
|
||||
47
frontend/src/common/sysmsgs.ts
Normal file
47
frontend/src/common/sysmsgs.ts
Normal file
@@ -0,0 +1,47 @@
|
||||
import { buildProgressProps } from '@/components/progress-notify/progress-notify';
|
||||
import syssocket from './syssocket';
|
||||
import { h, reactive } from 'vue';
|
||||
import { ElNotification } from 'element-plus';
|
||||
import ProgressNotify from '@/components/progress-notify/progress-notify.vue';
|
||||
|
||||
export function initSysMsgs() {
|
||||
registerDbSqlExecProgress();
|
||||
}
|
||||
|
||||
const sqlExecNotifyMap: Map<string, any> = new Map();
|
||||
|
||||
function registerDbSqlExecProgress() {
|
||||
syssocket.registerMsgHandler('execSqlFileProgress', function (message: any) {
|
||||
const content = JSON.parse(message.msg);
|
||||
const id = content.id;
|
||||
let progress = sqlExecNotifyMap.get(id);
|
||||
if (content.terminated) {
|
||||
if (progress != undefined) {
|
||||
progress.notification?.close();
|
||||
sqlExecNotifyMap.delete(id);
|
||||
progress = undefined;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (progress == undefined) {
|
||||
progress = {
|
||||
props: reactive(buildProgressProps()),
|
||||
notification: undefined,
|
||||
};
|
||||
}
|
||||
|
||||
progress.props.progress.title = content.title;
|
||||
progress.props.progress.executedStatements = content.executedStatements;
|
||||
if (!sqlExecNotifyMap.has(id)) {
|
||||
progress.notification = ElNotification({
|
||||
duration: 0,
|
||||
title: message.title,
|
||||
message: h(ProgressNotify, progress.props),
|
||||
type: syssocket.getMsgType(message.type),
|
||||
showClose: false,
|
||||
});
|
||||
sqlExecNotifyMap.set(id, progress);
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -1,9 +1,9 @@
|
||||
import Config from './config';
|
||||
import {ElNotification} from 'element-plus';
|
||||
import SocketBuilder from './SocketBuilder';
|
||||
import {getToken} from '@/common/utils/storage';
|
||||
import { getToken } from '@/common/utils/storage';
|
||||
|
||||
import {joinClientParams} from './request';
|
||||
import { joinClientParams } from './request';
|
||||
import { ElNotification } from 'element-plus';
|
||||
|
||||
class SysSocket {
|
||||
/**
|
||||
@@ -23,7 +23,6 @@ class SysSocket {
|
||||
0: 'error',
|
||||
1: 'success',
|
||||
2: 'info',
|
||||
22: 'info',
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -57,21 +56,16 @@ class SysSocket {
|
||||
return;
|
||||
}
|
||||
|
||||
// 默认通知处理
|
||||
const type = this.getMsgType(message.type);
|
||||
let msg = message.msg
|
||||
let duration = 0
|
||||
if (message.type == 22) {
|
||||
let obj = JSON.parse(msg);
|
||||
msg = `文件:${obj['title']} 执行成功: ${obj['executedStatements']} 条`
|
||||
duration = 2000
|
||||
}
|
||||
let msg = message.msg;
|
||||
let duration = 0;
|
||||
ElNotification({
|
||||
duration: duration,
|
||||
title: message.title,
|
||||
message: msg,
|
||||
type: type,
|
||||
});
|
||||
console.log(message)
|
||||
})
|
||||
.open((event: any) => console.log(event))
|
||||
.close(() => {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
<template>
|
||||
<el-descriptions border size="small" :title="`${progress.title}`">
|
||||
<el-descriptions border size="small" :title="`${props.progress.title}`">
|
||||
<el-descriptions-item label="时间">{{ state.elapsedTime }}</el-descriptions-item>
|
||||
<el-descriptions-item label="已处理">{{ progress.executedStatements }}</el-descriptions-item>
|
||||
</el-descriptions>
|
||||
|
||||
@@ -214,7 +214,7 @@ export default {
|
||||
getDbNamesModeAssign: 'Specifying the db name',
|
||||
|
||||
ignore: 'Ignore',
|
||||
replate: 'Replate',
|
||||
replace: 'Replate',
|
||||
|
||||
running: 'Running',
|
||||
waitRun: 'Wait Run',
|
||||
|
||||
@@ -210,7 +210,7 @@ export default {
|
||||
getDbNamesModeAssign: '指定库名',
|
||||
|
||||
ignore: '忽略',
|
||||
replate: '替换',
|
||||
replace: '替换',
|
||||
|
||||
running: '运行中',
|
||||
waitRun: '待运行',
|
||||
|
||||
@@ -18,11 +18,13 @@ import '@/theme/index.scss';
|
||||
import '@/assets/font/font.css';
|
||||
import '@/assets/iconfont/iconfont.js';
|
||||
import { getThemeConfig } from './common/utils/storage';
|
||||
import { initSysMsgs } from './common/sysmsgs';
|
||||
|
||||
const app = createApp(App);
|
||||
|
||||
registElSvgIcon(app);
|
||||
directive(app);
|
||||
initSysMsgs();
|
||||
|
||||
app.use(pinia).use(router).use(i18n).use(ElementPlus, { size: getThemeConfig()?.globalComponentSize }).mount('#app');
|
||||
|
||||
|
||||
@@ -9,13 +9,13 @@
|
||||
v-model:selection-data="state.selectionData"
|
||||
:columns="columns"
|
||||
@page-num-change="
|
||||
(args) => {
|
||||
(args: any) => {
|
||||
state.query.pageNum = args.pageNum;
|
||||
search();
|
||||
}
|
||||
"
|
||||
@page-size-change="
|
||||
(args) => {
|
||||
(args: any) => {
|
||||
state.query.pageSize = args.pageNum;
|
||||
search();
|
||||
}
|
||||
@@ -85,14 +85,14 @@ import { dbApi } from '@/views/ops/db/api';
|
||||
import { getDbDialect } from '@/views/ops/db/dialect';
|
||||
import PageTable from '@/components/pagetable/PageTable.vue';
|
||||
import { TableColumn } from '@/components/pagetable';
|
||||
import { ElMessage, ElMessageBox } from 'element-plus';
|
||||
import { ElMessage } from 'element-plus';
|
||||
import { hasPerms } from '@/components/auth/auth';
|
||||
import TerminalLog from '@/components/terminal/TerminalLog.vue';
|
||||
import DbSelectTree from '@/views/ops/db/component/DbSelectTree.vue';
|
||||
import { getClientId } from '@/common/utils/storage';
|
||||
import FileInfo from '@/components/file/FileInfo.vue';
|
||||
import { DbTransferFileStatusEnum } from './enums';
|
||||
import { useI18nDeleteConfirm, useI18nDeleteSuccessMsg, useI18nFormValidate, useI18nPleaseSelect, useI18nSaveSuccessMsg } from '@/hooks/useI18n';
|
||||
import { useI18nDeleteConfirm, useI18nDeleteSuccessMsg, useI18nFormValidate, useI18nOperateSuccessMsg, useI18nPleaseSelect } from '@/hooks/useI18n';
|
||||
import { useI18n } from 'vue-i18n';
|
||||
|
||||
const { t } = useI18n();
|
||||
@@ -179,14 +179,13 @@ const state = reactive({
|
||||
},
|
||||
btnOk: async function () {
|
||||
await useI18nFormValidate(runFormRef);
|
||||
console.log(state.runDialog.runForm);
|
||||
if (state.runDialog.runForm.targetDbType !== state.runDialog.runForm.dbType) {
|
||||
ElMessage.warning(t('db.targetDbTypeSelectError', { dbType: state.runDialog.runForm.dbType }));
|
||||
return false;
|
||||
}
|
||||
state.runDialog.runForm.clientId = getClientId();
|
||||
await dbApi.dbTransferFileRun.request(state.runDialog.runForm);
|
||||
useI18nSaveSuccessMsg();
|
||||
useI18nOperateSuccessMsg();
|
||||
state.runDialog.cancel();
|
||||
await search();
|
||||
},
|
||||
@@ -228,9 +227,7 @@ const openLog = function (data: any) {
|
||||
|
||||
// 运行sql,弹出选择需要运行的库,默认运行当前数据库,需要保证数据库类型与sql文件一致
|
||||
const openRun = function (data: any) {
|
||||
console.log(data);
|
||||
state.runDialog.runForm = { id: data.id, dbType: data.fileDbType } as any;
|
||||
console.log(state.runDialog.runForm);
|
||||
state.runDialog.visible = true;
|
||||
};
|
||||
|
||||
|
||||
@@ -205,7 +205,7 @@ const editInstance = async (data: any) => {
|
||||
|
||||
const deleteInstance = async () => {
|
||||
try {
|
||||
useI18nDeleteConfirm(state.selectionData.map((x: any) => x.name).join('、'));
|
||||
await useI18nDeleteConfirm(state.selectionData.map((x: any) => x.name).join('、'));
|
||||
await dbApi.deleteInstance.request({ id: state.selectionData.map((x: any) => x.id).join(',') });
|
||||
useI18nDeleteSuccessMsg();
|
||||
search();
|
||||
|
||||
@@ -140,7 +140,7 @@
|
||||
<el-table-column prop="src" :label="$t('db.srcField')" :width="200" />
|
||||
<el-table-column prop="target" :label="$t('db.targetField')">
|
||||
<template #default="scope">
|
||||
<el-select v-model="scope.row.target">
|
||||
<el-select v-model="scope.row.target" allow-create filterable>
|
||||
<el-option
|
||||
v-for="item in state.targetColumnList"
|
||||
:key="item.columnName"
|
||||
@@ -502,11 +502,13 @@ const handleGetSrcFields = async () => {
|
||||
sql,
|
||||
});
|
||||
|
||||
if (!res.columns) {
|
||||
if (res.length && !res[0].columns) {
|
||||
ElMessage.warning(t('db.notColumnSql'));
|
||||
return;
|
||||
}
|
||||
|
||||
let data = res[0];
|
||||
|
||||
let filedMap: any = {};
|
||||
if (state.form.fieldMap && state.form.fieldMap.length > 0) {
|
||||
state.form.fieldMap.forEach((a: any) => {
|
||||
@@ -514,11 +516,11 @@ const handleGetSrcFields = async () => {
|
||||
});
|
||||
}
|
||||
|
||||
state.srcTableFields = res.columns.map((a: any) => a.name);
|
||||
state.srcTableFields = data.columns.map((a: any) => a.name);
|
||||
|
||||
state.form.fieldMap = res.columns.map((a: any) => ({ src: a.name, target: filedMap[a.name] || '' }));
|
||||
state.form.fieldMap = data.columns.map((a: any) => ({ src: a.name, target: filedMap[a.name] || '' }));
|
||||
|
||||
state.previewRes = res;
|
||||
state.previewRes = data;
|
||||
};
|
||||
|
||||
const handleGetTargetFields = async () => {
|
||||
|
||||
@@ -104,7 +104,7 @@
|
||||
</el-row>
|
||||
<db-table-data
|
||||
v-if="!dt.errorMsg"
|
||||
:ref="(el) => (dt.dbTableRef = el)"
|
||||
:ref="(el: any) => (dt.dbTableRef = el)"
|
||||
:db-id="dbId"
|
||||
:db="dbName"
|
||||
:data="dt.data"
|
||||
@@ -128,12 +128,12 @@
|
||||
</template>
|
||||
|
||||
<script lang="ts" setup>
|
||||
import { h, nextTick, onMounted, reactive, ref, toRefs, unref } from 'vue';
|
||||
import { nextTick, onMounted, reactive, ref, toRefs, unref } from 'vue';
|
||||
import { getToken } from '@/common/utils/storage';
|
||||
import { notBlank } from '@/common/assert';
|
||||
import { format as sqlFormatter } from 'sql-formatter';
|
||||
import config from '@/common/config';
|
||||
import { ElMessage, ElMessageBox, ElNotification } from 'element-plus';
|
||||
import { ElMessage, ElMessageBox } from 'element-plus';
|
||||
|
||||
import * as monaco from 'monaco-editor/esm/vs/editor/editor.api';
|
||||
import { editor } from 'monaco-editor';
|
||||
@@ -144,9 +144,6 @@ import { dbApi } from '../../api';
|
||||
|
||||
import MonacoEditor from '@/components/monaco/MonacoEditor.vue';
|
||||
import { joinClientParams } from '@/common/request';
|
||||
import { buildProgressProps } from '@/components/progress-notify/progress-notify';
|
||||
import ProgressNotify from '@/components/progress-notify/progress-notify.vue';
|
||||
import syssocket from '@/common/syssocket';
|
||||
import SvgIcon from '@/components/svgIcon/index.vue';
|
||||
import { Pane, Splitpanes } from 'splitpanes';
|
||||
import { useI18n } from 'vue-i18n';
|
||||
@@ -593,44 +590,8 @@ const replaceSelection = (str: string, selection: any) => {
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* sql文件执行进度通知缓存
|
||||
*/
|
||||
const sqlExecNotifyMap: Map<string, any> = new Map();
|
||||
const beforeUpload = (file: File) => {
|
||||
ElMessage.success(t('db.scriptFileUploadRunning', { filename: file.name }));
|
||||
syssocket.registerMsgHandler('execSqlFileProgress', function (message: any) {
|
||||
const content = JSON.parse(message.msg);
|
||||
const id = content.id;
|
||||
let progress = sqlExecNotifyMap.get(id);
|
||||
if (content.terminated) {
|
||||
if (progress != undefined) {
|
||||
progress.notification?.close();
|
||||
sqlExecNotifyMap.delete(id);
|
||||
progress = undefined;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (progress == undefined) {
|
||||
progress = {
|
||||
props: reactive(buildProgressProps()),
|
||||
notification: undefined,
|
||||
};
|
||||
}
|
||||
progress.props.progress.title = content.title;
|
||||
progress.props.progress.executedStatements = content.executedStatements;
|
||||
if (!sqlExecNotifyMap.has(id)) {
|
||||
progress.notification = ElNotification({
|
||||
duration: 0,
|
||||
title: message.title,
|
||||
message: h(ProgressNotify, progress.props),
|
||||
type: syssocket.getMsgType(message.type),
|
||||
showClose: false,
|
||||
});
|
||||
sqlExecNotifyMap.set(id, progress);
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
// 执行sql成功
|
||||
|
||||
@@ -1,36 +1,54 @@
|
||||
<template>
|
||||
<div class="string-input-container w100" v-if="dataType == DataType.String || dataType == DataType.Number">
|
||||
<el-input
|
||||
:ref="(el: any) => focus && el?.focus()"
|
||||
:ref="
|
||||
(el: any) => {
|
||||
nextTick(() => {
|
||||
focus && el?.focus();
|
||||
});
|
||||
}
|
||||
"
|
||||
:disabled="disabled"
|
||||
@blur="handleBlur"
|
||||
:class="`w100 mb4 ${showEditorIcon ? 'string-input-container-show-icon' : ''}`"
|
||||
size="small"
|
||||
v-model="itemValue"
|
||||
:placeholder="placeholder"
|
||||
:placeholder="placeholder ?? $t('common.pleaseInput')"
|
||||
/>
|
||||
<SvgIcon v-if="showEditorIcon" @mousedown="openEditor" class="string-input-container-icon" name="FullScreen" :size="10" />
|
||||
</div>
|
||||
|
||||
<el-date-picker
|
||||
v-else-if="dataType == DataType.Date"
|
||||
:ref="(el: any) => focus && el?.focus()"
|
||||
:ref="
|
||||
(el: any) => {
|
||||
nextTick(() => {
|
||||
focus && el?.focus();
|
||||
});
|
||||
}
|
||||
"
|
||||
:disabled="disabled"
|
||||
@change="emit('blur')"
|
||||
@change="handleBlur"
|
||||
@blur="handleBlur"
|
||||
class="edit-time-picker mb4"
|
||||
popper-class="edit-time-picker-popper"
|
||||
size="small"
|
||||
v-model="itemValue"
|
||||
:clearable="false"
|
||||
type="Date"
|
||||
type="date"
|
||||
value-format="YYYY-MM-DD"
|
||||
:placeholder="`date-${placeholder}`"
|
||||
:placeholder="`date-${placeholder ?? $t('common.pleaseSelect')}`"
|
||||
/>
|
||||
|
||||
<el-date-picker
|
||||
v-else-if="dataType == DataType.DateTime"
|
||||
:ref="(el: any) => focus && el?.focus()"
|
||||
:ref="
|
||||
(el: any) => {
|
||||
nextTick(() => {
|
||||
focus && el?.focus();
|
||||
});
|
||||
}
|
||||
"
|
||||
:disabled="disabled"
|
||||
@change="handleBlur"
|
||||
@blur="handleBlur"
|
||||
@@ -41,12 +59,18 @@
|
||||
:clearable="false"
|
||||
type="datetime"
|
||||
value-format="YYYY-MM-DD HH:mm:ss"
|
||||
:placeholder="`datetime-${placeholder}`"
|
||||
:placeholder="`datetime-${placeholder ?? $t('common.pleaseSelect')}`"
|
||||
/>
|
||||
|
||||
<el-time-picker
|
||||
v-else-if="dataType == DataType.Time"
|
||||
:ref="(el: any) => focus && el?.focus()"
|
||||
:ref="
|
||||
(el: any) => {
|
||||
nextTick(() => {
|
||||
focus && el?.focus();
|
||||
});
|
||||
}
|
||||
"
|
||||
:disabled="disabled"
|
||||
@change="handleBlur"
|
||||
@blur="handleBlur"
|
||||
@@ -56,12 +80,12 @@
|
||||
v-model="itemValue"
|
||||
:clearable="false"
|
||||
value-format="HH:mm:ss"
|
||||
:placeholder="`time-${placeholder}`"
|
||||
:placeholder="`time-${placeholder ?? $t('common.pleaseSelect')}`"
|
||||
/>
|
||||
</template>
|
||||
|
||||
<script lang="ts" setup>
|
||||
import { computed, ref, Ref } from 'vue';
|
||||
import { computed, nextTick, ref, Ref } from 'vue';
|
||||
import { ElInput, ElMessage } from 'element-plus';
|
||||
import { DataType } from '../../dialect/index';
|
||||
import SvgIcon from '@/components/svgIcon/index.vue';
|
||||
|
||||
@@ -164,7 +164,7 @@ import SvgIcon from '@/components/svgIcon/index.vue';
|
||||
import { exportCsv, exportFile } from '@/common/utils/export';
|
||||
import { formatDate } from '@/common/utils/format';
|
||||
import { useIntervalFn, useStorage } from '@vueuse/core';
|
||||
import { ColumnTypeSubscript, compatibleMysql, DataType, DbDialect, getDbDialect } from '../../dialect/index';
|
||||
import { ColumnTypeSubscript, DataType, DbDialect, getDbDialect } from '../../dialect/index';
|
||||
import ColumnFormItem from './ColumnFormItem.vue';
|
||||
import DbTableDataForm from './DbTableDataForm.vue';
|
||||
import { useI18n } from 'vue-i18n';
|
||||
@@ -454,24 +454,11 @@ onBeforeUnmount(() => {
|
||||
endLoading();
|
||||
});
|
||||
|
||||
const formatDataValues = (datas: any) => {
|
||||
// mysql数据暂不做处理
|
||||
if (compatibleMysql(getNowDbInst().type)) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (let data of datas) {
|
||||
for (let column of props.columns!) {
|
||||
data[column.columnName] = getFormatTimeValue(dbDialect.getDataType(column.dataType), data[column.columnName]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const setTableData = (datas: any) => {
|
||||
tableRef.value?.scrollTo({ scrollLeft: 0, scrollTop: 0 });
|
||||
selectionRowsMap.value.clear();
|
||||
cellUpdateMap.value.clear();
|
||||
formatDataValues(datas);
|
||||
// formatDataValues(datas);
|
||||
state.datas = datas;
|
||||
setTableColumns(props.columns);
|
||||
};
|
||||
@@ -684,7 +671,8 @@ const onExportSql = async () => {
|
||||
};
|
||||
|
||||
const onEnterEditMode = (rowData: any, column: any, rowIndex = 0, columnIndex = 0) => {
|
||||
if (!state.table) {
|
||||
// 不存在表,或者已经在编辑中,则不处理
|
||||
if (!state.table || nowUpdateCell.value) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -697,7 +685,7 @@ const onEnterEditMode = (rowData: any, column: any, rowIndex = 0, columnIndex =
|
||||
};
|
||||
|
||||
const onExitEditMode = (rowData: any, column: any, rowIndex = 0) => {
|
||||
if (!nowUpdateCell) {
|
||||
if (!nowUpdateCell.value) {
|
||||
return;
|
||||
}
|
||||
const oldValue = nowUpdateCell.value.oldValue;
|
||||
@@ -788,32 +776,6 @@ const rowClass = (row: any) => {
|
||||
return '';
|
||||
};
|
||||
|
||||
/**
|
||||
* 根据数据库返回的时间字段类型,获取格式化后的时间值
|
||||
* @param dataType getDataType返回的数据类型
|
||||
* @param originValue 原始值
|
||||
* @return 格式化后的值
|
||||
*/
|
||||
const getFormatTimeValue = (dataType: DataType, originValue: string): string => {
|
||||
if (!originValue || dataType === DataType.Number || dataType === DataType.String) {
|
||||
return originValue;
|
||||
}
|
||||
|
||||
// 把Z去掉
|
||||
originValue = originValue.replace('Z', '');
|
||||
|
||||
switch (dataType) {
|
||||
case DataType.Time:
|
||||
return formatDate(originValue, 'HH:mm:ss');
|
||||
case DataType.Date:
|
||||
return formatDate(originValue, 'YYYY-MM-DD');
|
||||
case DataType.DateTime:
|
||||
return formatDate(originValue, 'YYYY-MM-DD HH:mm:ss');
|
||||
default:
|
||||
return originValue;
|
||||
}
|
||||
};
|
||||
|
||||
const scrollLeftValue = ref(0);
|
||||
const onTableScroll = (param: any) => {
|
||||
scrollLeftValue.value = param.scrollLeft;
|
||||
|
||||
@@ -592,7 +592,7 @@ const onSelectByCondition = async () => {
|
||||
*/
|
||||
const onTableSortChange = async (sort: any) => {
|
||||
const sortType = sort.order == 'desc' ? 'DESC' : 'ASC';
|
||||
state.orderBy = `ORDER BY ${sort.columnName} ${sortType}`;
|
||||
state.orderBy = `ORDER BY ${state.dbDialect.quoteIdentifier(sort.columnName)} ${sortType}`;
|
||||
await onRefresh();
|
||||
};
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@
|
||||
</el-tab-pane>
|
||||
<el-tab-pane :label="$t('db.index')" name="2">
|
||||
<el-table :data="tableData.indexs.res" :max-height="tableData.height">
|
||||
<el-table-column :prop="item.prop" :label="item.label" v-for="item in tableData.indexs.colNames" :key="item.prop">
|
||||
<el-table-column :prop="item.prop" :label="$t(item.label)" v-for="item in tableData.indexs.colNames" :key="item.prop">
|
||||
<template #default="scope">
|
||||
<el-input v-if="item.prop === 'indexName'" size="small" disabled v-model="scope.row.indexName"></el-input>
|
||||
|
||||
@@ -241,7 +241,7 @@ const state = reactive({
|
||||
colNames: [
|
||||
{
|
||||
prop: 'indexName',
|
||||
label: 'db.indexName',
|
||||
label: 'common.name',
|
||||
},
|
||||
{
|
||||
prop: 'columnNames',
|
||||
|
||||
@@ -209,8 +209,7 @@ class PostgresqlDialect implements DbDialect {
|
||||
}
|
||||
|
||||
quoteIdentifier = (name: string) => {
|
||||
// 后端sql解析器暂不支持pgsql
|
||||
return name;
|
||||
return `"${name}"`;
|
||||
};
|
||||
|
||||
matchType(text: string, arr: string[]): boolean {
|
||||
|
||||
@@ -22,7 +22,7 @@ export const DbSqlExecStatusEnum = {
|
||||
|
||||
export const DbDataSyncDuplicateStrategyEnum = {
|
||||
None: EnumValue.of(-1, 'db.none'),
|
||||
Ignore: EnumValue.of(1, 'db.ingore'),
|
||||
Ignore: EnumValue.of(1, 'db.ignore'),
|
||||
Replace: EnumValue.of(2, 'db.replace'),
|
||||
};
|
||||
|
||||
|
||||
@@ -32,9 +32,9 @@ require (
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/veops/go-ansiterm v0.0.5
|
||||
go.mongodb.org/mongo-driver v1.16.0 // mongo
|
||||
golang.org/x/crypto v0.29.0 // ssh
|
||||
golang.org/x/crypto v0.30.0 // ssh
|
||||
golang.org/x/oauth2 v0.23.0
|
||||
golang.org/x/sync v0.9.0
|
||||
golang.org/x/sync v0.10.0
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
// gorm
|
||||
@@ -93,8 +93,8 @@ require (
|
||||
golang.org/x/exp v0.0.0-20240904232852-e7e105dedf7e // indirect
|
||||
golang.org/x/image v0.13.0 // indirect
|
||||
golang.org/x/net v0.25.0 // indirect
|
||||
golang.org/x/sys v0.27.0 // indirect
|
||||
golang.org/x/text v0.20.0 // indirect
|
||||
golang.org/x/sys v0.28.0 // indirect
|
||||
golang.org/x/text v0.21.0 // indirect
|
||||
google.golang.org/protobuf v1.34.1 // indirect
|
||||
modernc.org/libc v1.22.5 // indirect
|
||||
modernc.org/mathutil v1.5.0 // indirect
|
||||
|
||||
@@ -6,15 +6,15 @@ import (
|
||||
"mayfly-go/internal/auth/api/form"
|
||||
"mayfly-go/internal/auth/config"
|
||||
"mayfly-go/internal/auth/imsg"
|
||||
"mayfly-go/internal/auth/pkg/captcha"
|
||||
"mayfly-go/internal/auth/pkg/otp"
|
||||
msgapp "mayfly-go/internal/msg/application"
|
||||
sysapp "mayfly-go/internal/sys/application"
|
||||
sysentity "mayfly-go/internal/sys/domain/entity"
|
||||
"mayfly-go/pkg/biz"
|
||||
"mayfly-go/pkg/cache"
|
||||
"mayfly-go/pkg/captcha"
|
||||
"mayfly-go/pkg/errorx"
|
||||
"mayfly-go/pkg/model"
|
||||
"mayfly-go/pkg/otp"
|
||||
"mayfly-go/pkg/req"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"mayfly-go/pkg/utils/cryptox"
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"mayfly-go/internal/auth/pkg/captcha"
|
||||
"mayfly-go/pkg/biz"
|
||||
"mayfly-go/pkg/captcha"
|
||||
"mayfly-go/pkg/req"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
)
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"mayfly-go/internal/auth/config"
|
||||
"mayfly-go/internal/auth/imsg"
|
||||
"mayfly-go/internal/auth/pkg/otp"
|
||||
msgapp "mayfly-go/internal/msg/application"
|
||||
msgentity "mayfly-go/internal/msg/domain/entity"
|
||||
sysapp "mayfly-go/internal/sys/application"
|
||||
@@ -12,7 +13,6 @@ import (
|
||||
"mayfly-go/pkg/biz"
|
||||
"mayfly-go/pkg/cache"
|
||||
"mayfly-go/pkg/i18n"
|
||||
"mayfly-go/pkg/otp"
|
||||
"mayfly-go/pkg/req"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"mayfly-go/pkg/utils/jsonx"
|
||||
|
||||
@@ -7,12 +7,12 @@ import (
|
||||
"mayfly-go/internal/auth/api/form"
|
||||
"mayfly-go/internal/auth/config"
|
||||
"mayfly-go/internal/auth/imsg"
|
||||
"mayfly-go/internal/auth/pkg/captcha"
|
||||
msgapp "mayfly-go/internal/msg/application"
|
||||
sysapp "mayfly-go/internal/sys/application"
|
||||
sysentity "mayfly-go/internal/sys/domain/entity"
|
||||
"mayfly-go/pkg/biz"
|
||||
"mayfly-go/pkg/cache"
|
||||
"mayfly-go/pkg/captcha"
|
||||
"mayfly-go/pkg/errorx"
|
||||
"mayfly-go/pkg/model"
|
||||
"mayfly-go/pkg/req"
|
||||
|
||||
@@ -11,5 +11,5 @@ type oauth2AccountRepoImpl struct {
|
||||
}
|
||||
|
||||
func newAuthAccountRepo() repository.Oauth2Account {
|
||||
return &oauth2AccountRepoImpl{base.RepoImpl[*entity.Oauth2Account]{M: new(entity.Oauth2Account)}}
|
||||
return &oauth2AccountRepoImpl{}
|
||||
}
|
||||
|
||||
37
server/internal/auth/router/account.go
Normal file
37
server/internal/auth/router/account.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"mayfly-go/internal/auth/api"
|
||||
"mayfly-go/internal/auth/imsg"
|
||||
"mayfly-go/pkg/biz"
|
||||
"mayfly-go/pkg/ioc"
|
||||
"mayfly-go/pkg/req"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func InitAccount(router *gin.RouterGroup) {
|
||||
accountLogin := new(api.AccountLogin)
|
||||
biz.ErrIsNil(ioc.Inject(accountLogin))
|
||||
|
||||
ldapLogin := new(api.LdapLogin)
|
||||
biz.ErrIsNil(ioc.Inject(ldapLogin))
|
||||
|
||||
rg := router.Group("/auth/accounts")
|
||||
|
||||
reqs := [...]*req.Conf{
|
||||
|
||||
// 用户账号密码登录
|
||||
req.NewPost("/login", accountLogin.Login).Log(req.NewLogSaveI(imsg.LogAccountLogin)).DontNeedToken(),
|
||||
|
||||
req.NewGet("/refreshToken", accountLogin.RefreshToken).DontNeedToken(),
|
||||
|
||||
// 用户退出登录
|
||||
req.NewPost("/logout", accountLogin.Logout),
|
||||
|
||||
// 用户otp双因素校验
|
||||
req.NewPost("/otp-verify", accountLogin.OtpVerify).DontNeedToken(),
|
||||
}
|
||||
|
||||
req.BatchSetGroup(rg, reqs[:])
|
||||
}
|
||||
@@ -1,13 +1,13 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"mayfly-go/internal/sys/api"
|
||||
"mayfly-go/internal/auth/api"
|
||||
"mayfly-go/pkg/req"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func InitCaptchaRouter(router *gin.RouterGroup) {
|
||||
func InitCaptcha(router *gin.RouterGroup) {
|
||||
captcha := router.Group("sys/captcha")
|
||||
{
|
||||
req.NewGet("", api.GenerateCaptcha).DontNeedToken().Group(captcha)
|
||||
25
server/internal/auth/router/ldap.go
Normal file
25
server/internal/auth/router/ldap.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"mayfly-go/internal/auth/api"
|
||||
"mayfly-go/internal/auth/imsg"
|
||||
"mayfly-go/pkg/biz"
|
||||
"mayfly-go/pkg/ioc"
|
||||
"mayfly-go/pkg/req"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func InitLdap(router *gin.RouterGroup) {
|
||||
ldapLogin := new(api.LdapLogin)
|
||||
biz.ErrIsNil(ioc.Inject(ldapLogin))
|
||||
|
||||
rg := router.Group("/auth/ldap")
|
||||
|
||||
reqs := [...]*req.Conf{
|
||||
req.NewGet("/enabled", ldapLogin.GetLdapEnabled).DontNeedToken(),
|
||||
req.NewPost("/login", ldapLogin.Login).Log(req.NewLogSaveI(imsg.LogLdapLogin)).DontNeedToken(),
|
||||
}
|
||||
|
||||
req.BatchSetGroup(rg, reqs[:])
|
||||
}
|
||||
37
server/internal/auth/router/oauth2.go
Normal file
37
server/internal/auth/router/oauth2.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"mayfly-go/internal/auth/api"
|
||||
"mayfly-go/internal/auth/imsg"
|
||||
"mayfly-go/pkg/biz"
|
||||
"mayfly-go/pkg/ioc"
|
||||
"mayfly-go/pkg/req"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func InitOauth2(router *gin.RouterGroup) {
|
||||
oauth2Login := new(api.Oauth2Login)
|
||||
biz.ErrIsNil(ioc.Inject(oauth2Login))
|
||||
|
||||
rg := router.Group("/auth/oauth2")
|
||||
|
||||
reqs := [...]*req.Conf{
|
||||
|
||||
req.NewGet("/config", oauth2Login.Oauth2Config).DontNeedToken(),
|
||||
|
||||
// oauth2登录
|
||||
req.NewGet("/login", oauth2Login.OAuth2Login).DontNeedToken(),
|
||||
|
||||
req.NewGet("/bind", oauth2Login.OAuth2Bind),
|
||||
|
||||
// oauth2回调地址
|
||||
req.NewGet("/callback", oauth2Login.OAuth2Callback).Log(req.NewLogSaveI(imsg.LogOauth2Callback)).DontNeedToken(),
|
||||
|
||||
req.NewGet("/status", oauth2Login.Oauth2Status),
|
||||
|
||||
req.NewGet("/unbind", oauth2Login.Oauth2Unbind).Log(req.NewLogSaveI(imsg.LogOauth2Unbind)),
|
||||
}
|
||||
|
||||
req.BatchSetGroup(rg, reqs[:])
|
||||
}
|
||||
@@ -1,60 +1,10 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"mayfly-go/internal/auth/api"
|
||||
"mayfly-go/internal/auth/imsg"
|
||||
"mayfly-go/pkg/biz"
|
||||
"mayfly-go/pkg/ioc"
|
||||
"mayfly-go/pkg/req"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
import "github.com/gin-gonic/gin"
|
||||
|
||||
func Init(router *gin.RouterGroup) {
|
||||
accountLogin := new(api.AccountLogin)
|
||||
biz.ErrIsNil(ioc.Inject(accountLogin))
|
||||
|
||||
ldapLogin := new(api.LdapLogin)
|
||||
biz.ErrIsNil(ioc.Inject(ldapLogin))
|
||||
|
||||
oauth2Login := new(api.Oauth2Login)
|
||||
biz.ErrIsNil(ioc.Inject(oauth2Login))
|
||||
|
||||
rg := router.Group("/auth")
|
||||
|
||||
reqs := [...]*req.Conf{
|
||||
|
||||
// 用户账号密码登录
|
||||
req.NewPost("/accounts/login", accountLogin.Login).Log(req.NewLogSaveI(imsg.LogAccountLogin)).DontNeedToken(),
|
||||
|
||||
req.NewGet("/accounts/refreshToken", accountLogin.RefreshToken).DontNeedToken(),
|
||||
|
||||
// 用户退出登录
|
||||
req.NewPost("/accounts/logout", accountLogin.Logout),
|
||||
|
||||
// 用户otp双因素校验
|
||||
req.NewPost("/accounts/otp-verify", accountLogin.OtpVerify).DontNeedToken(),
|
||||
|
||||
/*--------oauth2登录相关----------*/
|
||||
|
||||
req.NewGet("/oauth2-config", oauth2Login.Oauth2Config).DontNeedToken(),
|
||||
|
||||
// oauth2登录
|
||||
req.NewGet("/oauth2/login", oauth2Login.OAuth2Login).DontNeedToken(),
|
||||
|
||||
req.NewGet("/oauth2/bind", oauth2Login.OAuth2Bind),
|
||||
|
||||
// oauth2回调地址
|
||||
req.NewGet("/oauth2/callback", oauth2Login.OAuth2Callback).Log(req.NewLogSaveI(imsg.LogOauth2Callback)).DontNeedToken(),
|
||||
|
||||
req.NewGet("/oauth2/status", oauth2Login.Oauth2Status),
|
||||
|
||||
req.NewGet("/oauth2/unbind", oauth2Login.Oauth2Unbind).Log(req.NewLogSaveI(imsg.LogOauth2Unbind)),
|
||||
|
||||
// LDAP 登录
|
||||
req.NewGet("/ldap/enabled", ldapLogin.GetLdapEnabled).DontNeedToken(),
|
||||
req.NewPost("/ldap/login", ldapLogin.Login).Log(req.NewLogSaveI(imsg.LogLdapLogin)).DontNeedToken(),
|
||||
}
|
||||
|
||||
req.BatchSetGroup(rg, reqs[:])
|
||||
InitCaptcha(router)
|
||||
InitAccount(router)
|
||||
InitOauth2(router)
|
||||
InitLdap(router)
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"mayfly-go/internal/db/application/dto"
|
||||
"mayfly-go/internal/db/config"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/internal/db/dbm/sqlparser"
|
||||
"mayfly-go/internal/db/domain/entity"
|
||||
"mayfly-go/internal/db/imsg"
|
||||
"mayfly-go/internal/event"
|
||||
@@ -26,9 +25,7 @@ import (
|
||||
"mayfly-go/pkg/utils/anyx"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"mayfly-go/pkg/utils/cryptox"
|
||||
"mayfly-go/pkg/utils/stringx"
|
||||
"mayfly-go/pkg/utils/writerx"
|
||||
"mayfly-go/pkg/ws"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -118,7 +115,7 @@ func (d *Db) ExecSql(rc *req.Ctx) {
|
||||
rc.ReqParam = fmt.Sprintf("%s %s\n-> %s", dbConn.Info.GetLogDesc(), form.ExecId, sqlStr)
|
||||
biz.NotEmpty(form.Sql, "sql cannot be empty")
|
||||
|
||||
execReq := &application.DbSqlExecReq{
|
||||
execReq := &dto.DbSqlExecReq{
|
||||
DbId: dbId,
|
||||
Db: form.Db,
|
||||
Remark: form.Remark,
|
||||
@@ -163,47 +160,12 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
|
||||
biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.GetLoginAccount().Id, dbConn.Info.CodePath...), "%s")
|
||||
rc.ReqParam = fmt.Sprintf("filename: %s -> %s", filename, dbConn.Info.GetLogDesc())
|
||||
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
errInfo := anyx.ToString(err)
|
||||
if len(errInfo) > 300 {
|
||||
errInfo = errInfo[:300] + "..."
|
||||
}
|
||||
d.MsgApp.CreateAndSend(rc.GetLoginAccount(), msgdto.ErrSysMsg(i18n.T(imsg.SqlScriptRunFail), fmt.Sprintf("[%s][%s] execution failure: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo)).WithClientId(clientId))
|
||||
}
|
||||
}()
|
||||
|
||||
executedStatements := 0
|
||||
progressId := stringx.Rand(32)
|
||||
laId := rc.GetLoginAccount().Id
|
||||
defer ws.SendJsonMsg(ws.UserId(laId), clientId, msgdto.InfoSysMsg(i18n.T(imsg.SqlScripRunProgress), &progressMsg{
|
||||
Id: progressId,
|
||||
Title: filename,
|
||||
ExecutedStatements: executedStatements,
|
||||
Terminated: true,
|
||||
}).WithCategory(progressCategory))
|
||||
ticker := time.NewTicker(time.Second * 1)
|
||||
defer ticker.Stop()
|
||||
|
||||
err = sqlparser.SQLSplit(file, func(sql string) error {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
ws.SendJsonMsg(ws.UserId(laId), clientId, msgdto.InfoSysMsg(i18n.T(imsg.SqlScripRunProgress), &progressMsg{
|
||||
Id: progressId,
|
||||
Title: filename,
|
||||
ExecutedStatements: executedStatements,
|
||||
Terminated: false,
|
||||
}).WithCategory(progressCategory))
|
||||
default:
|
||||
}
|
||||
|
||||
executedStatements++
|
||||
_, err = dbConn.Exec(sql)
|
||||
return err
|
||||
})
|
||||
|
||||
biz.ErrIsNilAppendErr(err, "%s")
|
||||
d.MsgApp.CreateAndSend(rc.GetLoginAccount(), msgdto.SuccessSysMsg(i18n.T(imsg.SqlScriptRunSuccess), fmt.Sprintf("execution success: %s", rc.ReqParam)).WithClientId(clientId))
|
||||
biz.ErrIsNil(d.DbSqlExecApp.ExecReader(rc.MetaCtx, &dto.SqlReaderExec{
|
||||
Reader: file,
|
||||
Filename: filename,
|
||||
DbConn: dbConn,
|
||||
ClientId: clientId,
|
||||
}))
|
||||
}
|
||||
|
||||
// 数据库dump
|
||||
|
||||
@@ -2,28 +2,18 @@ package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/api/form"
|
||||
"mayfly-go/internal/db/api/vo"
|
||||
"mayfly-go/internal/db/application"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/internal/db/dbm/sqlparser"
|
||||
"mayfly-go/internal/db/application/dto"
|
||||
"mayfly-go/internal/db/domain/entity"
|
||||
"mayfly-go/internal/db/imsg"
|
||||
fileapp "mayfly-go/internal/file/application"
|
||||
msgapp "mayfly-go/internal/msg/application"
|
||||
msgdto "mayfly-go/internal/msg/application/dto"
|
||||
tagapp "mayfly-go/internal/tag/application"
|
||||
"mayfly-go/pkg/biz"
|
||||
"mayfly-go/pkg/i18n"
|
||||
"mayfly-go/pkg/model"
|
||||
"mayfly-go/pkg/req"
|
||||
"mayfly-go/pkg/utils/anyx"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"mayfly-go/pkg/utils/stringx"
|
||||
"mayfly-go/pkg/ws"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/may-fly/cast"
|
||||
)
|
||||
@@ -120,7 +110,6 @@ func (d *DbTransferTask) FileDel(rc *req.Ctx) {
|
||||
}
|
||||
|
||||
func (d *DbTransferTask) FileRun(rc *req.Ctx) {
|
||||
|
||||
fm := req.BindJsonAndValid(rc, &form.DbTransferFileRunForm{})
|
||||
|
||||
rc.ReqParam = fm
|
||||
@@ -132,69 +121,15 @@ func (d *DbTransferTask) FileRun(rc *req.Ctx) {
|
||||
biz.ErrIsNilAppendErr(err, "failed to connect to the target database: %s")
|
||||
biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.GetLoginAccount().Id, targetDbConn.Info.CodePath...), "%s")
|
||||
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
errInfo := anyx.ToString(err)
|
||||
if len(errInfo) > 300 {
|
||||
errInfo = errInfo[:300] + "..."
|
||||
}
|
||||
d.MsgApp.CreateAndSend(rc.GetLoginAccount(), msgdto.ErrSysMsg(i18n.T(imsg.SqlScriptRunFail), fmt.Sprintf("[%s][%s] run failed: [%s]", tFile.FileKey, targetDbConn.Info.GetLogDesc(), errInfo)).WithClientId(fm.ClientId))
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
d.fileRun(rc.GetLoginAccount(), fm, tFile, targetDbConn)
|
||||
}()
|
||||
|
||||
}
|
||||
|
||||
func (d *DbTransferTask) fileRun(la *model.LoginAccount, fm *form.DbTransferFileRunForm, tFile *entity.DbTransferFile, targetDbConn *dbi.DbConn) {
|
||||
filename, reader, err := d.FileApp.GetReader(context.TODO(), tFile.FileKey)
|
||||
|
||||
executedStatements := 0
|
||||
progressId := stringx.Rand(32)
|
||||
laId := la.Id
|
||||
|
||||
ticker := time.NewTicker(time.Second * 1)
|
||||
defer ticker.Stop()
|
||||
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
errInfo := anyx.ToString(err)
|
||||
if len(errInfo) > 500 {
|
||||
errInfo = errInfo[:500] + "..."
|
||||
}
|
||||
d.MsgApp.CreateAndSend(la, msgdto.ErrSysMsg(i18n.T(imsg.SqlScriptRunFail), errInfo).WithClientId(fm.ClientId))
|
||||
}
|
||||
biz.ErrIsNil(err)
|
||||
go func() {
|
||||
biz.ErrIsNil(d.DbSqlExecApp.ExecReader(rc.MetaCtx, &dto.SqlReaderExec{
|
||||
Reader: reader,
|
||||
Filename: filename,
|
||||
DbConn: targetDbConn,
|
||||
ClientId: fm.ClientId,
|
||||
}))
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
biz.ErrIsNilAppendErr(err, "failed to connect to the target database: %s")
|
||||
}
|
||||
|
||||
errSql := ""
|
||||
err = sqlparser.SQLSplit(reader, func(sql string) error {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
ws.SendJsonMsg(ws.UserId(laId), fm.ClientId, msgdto.InfoSqlProgressMsg(i18n.T(imsg.SqlScripRunProgress), &progressMsg{
|
||||
Id: progressId,
|
||||
Title: filename,
|
||||
ExecutedStatements: executedStatements,
|
||||
Terminated: false,
|
||||
}).WithCategory(progressCategory))
|
||||
default:
|
||||
}
|
||||
executedStatements++
|
||||
_, err = targetDbConn.Exec(sql)
|
||||
if err != nil {
|
||||
errSql = sql
|
||||
}
|
||||
return err
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
biz.ErrIsNil(err, "[%s] execution failed: %s", errSql, err)
|
||||
}
|
||||
|
||||
d.MsgApp.CreateAndSend(la, msgdto.SuccessSysMsg(i18n.T(imsg.SqlScriptRunSuccess), fmt.Sprintf("sql execution successfully: %s", filename)).WithClientId(fm.ClientId))
|
||||
}
|
||||
|
||||
@@ -214,7 +214,7 @@ func (d *dbAppImpl) GetDbConnByInstanceId(instanceId uint64) (*dbi.DbConn, error
|
||||
return nil, errorx.NewBiz("failed to get database list")
|
||||
}
|
||||
if len(dbs) == 0 {
|
||||
return nil, errorx.NewBiz("DB instance [%d] Database is not configured, please configure it first", instanceId)
|
||||
return nil, errorx.NewBiz("DB instance [%d] database is not configured, please configure it first", instanceId)
|
||||
}
|
||||
|
||||
// 使用该实例关联的已配置数据库中的第一个库进行连接并返回
|
||||
@@ -227,6 +227,10 @@ func (d *dbAppImpl) DumpDb(ctx context.Context, reqParam *dto.DumpDb) error {
|
||||
if reqParam.Log != nil {
|
||||
log = reqParam.Log
|
||||
}
|
||||
progress := dto.DefaultDumpProgress
|
||||
if reqParam.Progress != nil {
|
||||
progress = reqParam.Progress
|
||||
}
|
||||
|
||||
writer := writerx.NewStringWriter(reqParam.Writer)
|
||||
defer writer.Close()
|
||||
@@ -249,23 +253,16 @@ func (d *dbAppImpl) DumpDb(ctx context.Context, reqParam *dto.DumpDb) error {
|
||||
// 获取目标元数据,仅生成sql,用于生成建表语句和插入数据,不能用于查询
|
||||
targetDialect := dbConn.GetDialect()
|
||||
if reqParam.TargetDbType != "" && dbConn.Info.Type != reqParam.TargetDbType {
|
||||
// 创建一个假连接,仅用于调用方言生成sql,不做数据库连接操作
|
||||
meta := dbi.GetMeta(reqParam.TargetDbType)
|
||||
dbConn := &dbi.DbConn{Info: &dbi.DbInfo{
|
||||
Type: reqParam.TargetDbType,
|
||||
Meta: meta,
|
||||
}}
|
||||
|
||||
targetDialect = meta.GetDialect(dbConn)
|
||||
targetDialect = dbi.GetDialect(reqParam.TargetDbType)
|
||||
}
|
||||
|
||||
srcMeta := dbConn.GetMetadata()
|
||||
srcDialect := dbConn.GetDialect()
|
||||
if len(tables) == 0 {
|
||||
log("Gets the table information that can be export...")
|
||||
log("gets the table information that can be export...")
|
||||
ti, err := srcMeta.GetTables()
|
||||
if err != nil {
|
||||
log(fmt.Sprintf("Failed to get table info %s", err.Error()))
|
||||
log(fmt.Sprintf("failed to get table info %s", err.Error()))
|
||||
}
|
||||
biz.ErrIsNil(err)
|
||||
tables = make([]string, len(ti))
|
||||
@@ -275,105 +272,133 @@ func (d *dbAppImpl) DumpDb(ctx context.Context, reqParam *dto.DumpDb) error {
|
||||
log(fmt.Sprintf("Get %d tables", len(tables)))
|
||||
}
|
||||
if len(tables) == 0 {
|
||||
log("No table to export. End export")
|
||||
log("no table to export. end export")
|
||||
return errorx.NewBiz("there is no table to export")
|
||||
}
|
||||
|
||||
log("Querying column information...")
|
||||
log("querying column information...")
|
||||
// 查询列信息,后面生成建表ddl和insert都需要列信息
|
||||
columns, err := srcMeta.GetColumns(tables...)
|
||||
if err != nil {
|
||||
log(fmt.Sprintf("Failed to query column information: %s", err.Error()))
|
||||
log(fmt.Sprintf("failed to query column information: %s", err.Error()))
|
||||
}
|
||||
biz.ErrIsNil(err)
|
||||
|
||||
// 以表名分组,存放每个表的列信息
|
||||
columnMap := make(map[string][]dbi.Column)
|
||||
for _, column := range columns {
|
||||
if err := dbi.ConvToTargetDbColumn(dbConn.Info.Type, cmp.Or(reqParam.TargetDbType, dbConn.Info.Type), targetDialect, &column); err != nil {
|
||||
return err
|
||||
}
|
||||
columnMap[column.TableName] = append(columnMap[column.TableName], column)
|
||||
}
|
||||
|
||||
// 按表名排序
|
||||
sort.Strings(tables)
|
||||
|
||||
quoteSchema := srcDialect.QuoteIdentifier(dbConn.Info.CurrentSchema())
|
||||
quoteSchema := srcDialect.Quoter().Quote(dbConn.Info.CurrentSchema())
|
||||
dumpHelper := targetDialect.GetDumpHelper()
|
||||
dataHelper := targetDialect.GetDataHelper()
|
||||
|
||||
targetSqlGenerator := targetDialect.GetSQLGenerator()
|
||||
targetDialectQuote := targetDialect.Quoter().Quote
|
||||
// 遍历获取每个表的信息
|
||||
for _, tableName := range tables {
|
||||
log(fmt.Sprintf("Get table [%s] information...", tableName))
|
||||
quoteTableName := targetDialect.QuoteIdentifier(tableName)
|
||||
log(fmt.Sprintf("get table [%s] information...", tableName))
|
||||
quoteTableName := targetDialectQuote(tableName)
|
||||
|
||||
// 查询表信息,主要是为了查询表注释
|
||||
tbs, err := srcMeta.GetTables(tableName)
|
||||
if err != nil {
|
||||
log(fmt.Sprintf("Failed to get table [%s] information: %s", tableName, err.Error()))
|
||||
log(fmt.Sprintf("failed to get table [%s] information: %s", tableName, err.Error()))
|
||||
return err
|
||||
}
|
||||
if len(tbs) <= 0 {
|
||||
log(fmt.Sprintf("Failed to get table [%s] information: No table information was retrieved", tableName))
|
||||
log(fmt.Sprintf("failed to get table [%s] information: No table information was retrieved", tableName))
|
||||
return errorx.NewBiz(fmt.Sprintf("Failed to get table information: %s", tableName))
|
||||
}
|
||||
tabInfo := dbi.Table{
|
||||
TableName: tableName,
|
||||
TableComment: tbs[0].TableComment,
|
||||
}
|
||||
|
||||
tableInfo := tbs[0]
|
||||
columns := columnMap[tableName]
|
||||
|
||||
// 生成表结构信息
|
||||
if reqParam.DumpDDL {
|
||||
log(fmt.Sprintf("Generate table [%s] DDL...", tableName))
|
||||
log(fmt.Sprintf("generate table [%s] DDL...", tableName))
|
||||
writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- Table structure: %s \n-- ----------------------------\n", tableName))
|
||||
tbDdlArr := targetDialect.GenerateTableDDL(columnMap[tableName], tabInfo, true)
|
||||
tbDdlArr := targetSqlGenerator.GenTableDDL(tableInfo, columns, true)
|
||||
for _, ddl := range tbDdlArr {
|
||||
writer.WriteString(ddl + ";\n")
|
||||
if _, err := writer.WriteString(ddl + ";\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
progress(tableName, dbi.StmtTypeDDL, len(tbDdlArr), true)
|
||||
}
|
||||
|
||||
// 生成insert sql,数据在索引前,加速insert
|
||||
if reqParam.DumpData {
|
||||
log(fmt.Sprintf("Generate table [%s] DML...", tableName))
|
||||
log(fmt.Sprintf("generate table [%s] DML...", tableName))
|
||||
writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- Data: %s \n-- ----------------------------\n", tableName))
|
||||
|
||||
dumpHelper.BeforeInsert(writer, quoteTableName)
|
||||
// 获取列信息
|
||||
quoteColNames := make([]string, 0)
|
||||
for _, col := range columnMap[tableName] {
|
||||
quoteColNames = append(quoteColNames, targetDialect.QuoteIdentifier(col.ColumnName))
|
||||
}
|
||||
|
||||
_, _ = dbConn.WalkTableRows(ctx, tableName, func(row map[string]any, _ []*dbi.QueryColumn) error {
|
||||
rowValues := make([]string, len(columnMap[tableName]))
|
||||
for i, col := range columnMap[tableName] {
|
||||
rowValues[i] = dataHelper.WrapValue(row[col.ColumnName], dataHelper.GetDataType(string(col.DataType)))
|
||||
dataCount := 0
|
||||
rows := make([][]any, 0)
|
||||
_, err = dbConn.WalkTableRows(ctx, tableName, func(row map[string]any, _ []*dbi.QueryColumn) error {
|
||||
rowValues := make([]any, len(columns))
|
||||
for i, col := range columns {
|
||||
rowValues[i] = row[col.ColumnName]
|
||||
}
|
||||
rows = append(rows, rowValues)
|
||||
dataCount++
|
||||
if dataCount%500 != 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
beforeInsert := dumpHelper.BeforeInsertSql(quoteSchema, quoteTableName)
|
||||
insertSQL := fmt.Sprintf("%s INSERT INTO %s (%s) values(%s)", beforeInsert, quoteTableName, strings.Join(quoteColNames, ", "), strings.Join(rowValues, ", "))
|
||||
writer.WriteString(insertSQL + ";\n")
|
||||
writer.WriteString(beforeInsert)
|
||||
insertSql := targetSqlGenerator.GenInsert(tableName, columns, rows, dbi.DuplicateStrategyNone)
|
||||
if _, err := writer.WriteString(strings.Join(insertSql, ";\n") + ";\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
progress(tableName, dbi.StmtTypeInsert, dataCount, false)
|
||||
rows = make([][]any, 0)
|
||||
return nil
|
||||
})
|
||||
|
||||
dumpHelper.AfterInsert(writer, tableName, columnMap[tableName])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(rows) > 0 {
|
||||
beforeInsert := dumpHelper.BeforeInsertSql(quoteSchema, quoteTableName)
|
||||
writer.WriteString(beforeInsert)
|
||||
insertSql := targetSqlGenerator.GenInsert(tableName, columns, rows, dbi.DuplicateStrategyNone)
|
||||
if _, err := writer.WriteString(strings.Join(insertSql, ";\n") + ";\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
dumpHelper.AfterInsert(writer, tableName, columns)
|
||||
progress(tableName, dbi.StmtTypeInsert, dataCount, true)
|
||||
}
|
||||
|
||||
log(fmt.Sprintf("Get table [%s] index information...", tableName))
|
||||
log(fmt.Sprintf("get table [%s] index information...", tableName))
|
||||
indexs, err := srcMeta.GetTableIndex(tableName)
|
||||
if err != nil {
|
||||
log(fmt.Sprintf("Failed to get table [%s] index information: %s", tableName, err.Error()))
|
||||
log(fmt.Sprintf("failed to get table [%s] index information: %s", tableName, err.Error()))
|
||||
return err
|
||||
}
|
||||
|
||||
if len(indexs) > 0 {
|
||||
// 最后添加索引
|
||||
log(fmt.Sprintf("Generate table [%s] index...", tableName))
|
||||
log(fmt.Sprintf("generate table [%s] index...", tableName))
|
||||
writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- Table Index: %s \n-- ----------------------------\n", tableName))
|
||||
sqlArr := targetDialect.GenerateIndexDDL(indexs, tabInfo)
|
||||
sqlArr := targetSqlGenerator.GenIndexDDL(tableInfo, indexs)
|
||||
for _, sqlStr := range sqlArr {
|
||||
writer.WriteString(sqlStr + ";\n")
|
||||
if _, err := writer.WriteString(sqlStr + ";\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
progress(tableName, dbi.StmtTypeDDL, len(sqlArr), true)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package application
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
@@ -14,8 +15,8 @@ import (
|
||||
"mayfly-go/pkg/logx"
|
||||
"mayfly-go/pkg/model"
|
||||
"mayfly-go/pkg/scheduler"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -145,28 +146,11 @@ func (app *dataSyncAppImpl) RunCronJob(ctx context.Context, id uint64) error {
|
||||
updSql := ""
|
||||
orderSql := ""
|
||||
if task.UpdFieldVal != "0" && task.UpdFieldVal != "" && task.UpdField != "" {
|
||||
srcConn, err := app.dbApp.GetDbConn(uint64(task.SrcDbId), task.SrcDbName)
|
||||
if err != nil {
|
||||
logx.ErrorfContext(ctx, "data source connection unavailable: %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
task.UpdFieldVal = strings.Trim(task.UpdFieldVal, " ")
|
||||
|
||||
// 判断UpdFieldVal数据类型
|
||||
var updFieldValType dbi.DataType
|
||||
if _, err = strconv.Atoi(task.UpdFieldVal); err != nil {
|
||||
if dateTimeReg.MatchString(task.UpdFieldVal) || dateTimeIsoReg.MatchString(task.UpdFieldVal) {
|
||||
updFieldValType = dbi.DataTypeDateTime
|
||||
} else {
|
||||
updFieldValType = dbi.DataTypeString
|
||||
}
|
||||
} else {
|
||||
updFieldValType = dbi.DataTypeNumber
|
||||
}
|
||||
wrapUpdFieldVal := srcConn.GetDialect().GetDataHelper().WrapValue(task.UpdFieldVal, updFieldValType)
|
||||
updSql = fmt.Sprintf("and %s > %s", task.UpdField, wrapUpdFieldVal)
|
||||
|
||||
updSql = fmt.Sprintf("and %s > %s", task.UpdField, strings.Trim(task.UpdFieldVal, " "))
|
||||
orderSql = "order by " + task.UpdField + " asc "
|
||||
}
|
||||
// 正则判断DataSql是否以where .*结尾,如果是则不添加where 1 = 1
|
||||
@@ -221,15 +205,13 @@ func (app *dataSyncAppImpl) doDataSync(ctx context.Context, sql string, task *en
|
||||
}
|
||||
}()
|
||||
|
||||
srcDialect := srcConn.GetDialect()
|
||||
|
||||
// task.FieldMap为json数组字符串 [{"src":"id","target":"id"}],转为map
|
||||
var fieldMap []map[string]string
|
||||
err = json.Unmarshal([]byte(task.FieldMap), &fieldMap)
|
||||
if err != nil {
|
||||
return syncLog, errorx.NewBiz("there was an error parsing the field map json: %s", err.Error())
|
||||
}
|
||||
var updFieldType dbi.DataType
|
||||
var updFieldType *dbi.DbDataType
|
||||
|
||||
// 记录本次同步数据总数
|
||||
total := 0
|
||||
@@ -243,15 +225,28 @@ func (app *dataSyncAppImpl) doDataSync(ctx context.Context, sql string, task *en
|
||||
updFieldName = strings.Split(task.UpdField, ".")[1]
|
||||
}
|
||||
|
||||
targetTableColumns, err := targetConn.GetMetadata().GetColumns(task.TargetTableName)
|
||||
if err != nil {
|
||||
return syncLog, errorx.NewBiz("failed to get target table columns: %s", err.Error())
|
||||
}
|
||||
targetColumnName2Column := collx.ArrayToMap(targetTableColumns, func(column dbi.Column) string {
|
||||
return column.ColumnName
|
||||
})
|
||||
|
||||
// 目标库对应的insert columns
|
||||
targetInsertColumns := collx.ArrayMap[map[string]string, dbi.Column](fieldMap, func(val map[string]string) dbi.Column {
|
||||
return targetColumnName2Column[val["target"]]
|
||||
})
|
||||
|
||||
_, err = srcConn.WalkQueryRows(context.Background(), sql, func(row map[string]any, columns []*dbi.QueryColumn) error {
|
||||
if len(queryColumns) == 0 {
|
||||
queryColumns = columns
|
||||
|
||||
// 遍历columns 取task.UpdField的字段类型
|
||||
updFieldType = dbi.DataTypeString
|
||||
updFieldType = dbi.DefaultDbDataType
|
||||
for _, column := range columns {
|
||||
if strings.EqualFold(column.Name, updFieldName) {
|
||||
updFieldType = srcDialect.GetDataHelper().GetDataType(column.Type)
|
||||
updFieldType = column.DbDataType
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -260,7 +255,7 @@ func (app *dataSyncAppImpl) doDataSync(ctx context.Context, sql string, task *en
|
||||
total++
|
||||
result = append(result, row)
|
||||
if total%batchSize == 0 {
|
||||
if err := app.srcData2TargetDb(result, fieldMap, columns, updFieldType, updFieldName, task, srcDialect, targetConn, targetDbTx); err != nil {
|
||||
if err := app.srcData2TargetDb(result, fieldMap, updFieldType, updFieldName, task, targetConn, targetDbTx, targetInsertColumns); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -283,7 +278,7 @@ func (app *dataSyncAppImpl) doDataSync(ctx context.Context, sql string, task *en
|
||||
|
||||
// 处理剩余的数据
|
||||
if len(result) > 0 {
|
||||
if err := app.srcData2TargetDb(result, fieldMap, queryColumns, updFieldType, updFieldName, task, srcDialect, targetConn, targetDbTx); err != nil {
|
||||
if err := app.srcData2TargetDb(result, fieldMap, updFieldType, updFieldName, task, targetConn, targetDbTx, targetInsertColumns); err != nil {
|
||||
targetDbTx.Rollback()
|
||||
return syncLog, err
|
||||
}
|
||||
@@ -291,7 +286,7 @@ func (app *dataSyncAppImpl) doDataSync(ctx context.Context, sql string, task *en
|
||||
|
||||
// 如果是mssql,暂不手动提交事务,否则报错 mssql: The COMMIT TRANSACTION request has no corresponding BEGIN TRANSACTION.
|
||||
if err := targetDbTx.Commit(); err != nil {
|
||||
if targetConn.Info.Type != dbi.DbTypeMssql {
|
||||
if targetConn.Info.Type != dbi.ToDbType("mssql") {
|
||||
return syncLog, errorx.NewBiz("data synchronization - The target database transaction failed to commit: %s", err.Error())
|
||||
}
|
||||
}
|
||||
@@ -307,36 +302,38 @@ func (app *dataSyncAppImpl) doDataSync(ctx context.Context, sql string, task *en
|
||||
return syncLog, nil
|
||||
}
|
||||
|
||||
func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap []map[string]string, columns []*dbi.QueryColumn, updFieldType dbi.DataType, updFieldName string, task *entity.DataSyncTask, srcDialect dbi.Dialect, targetDbConn *dbi.DbConn, targetDbTx *sql.Tx) error {
|
||||
|
||||
// 遍历src字段列表,取出字段对应的类型
|
||||
var srcColumnTypes = make(map[string]string)
|
||||
for _, column := range columns {
|
||||
srcColumnTypes[column.Name] = column.Type
|
||||
}
|
||||
|
||||
func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap []map[string]string, updFieldType *dbi.DbDataType, updFieldName string, task *entity.DataSyncTask, targetDbConn *dbi.DbConn, targetDbTx *sql.Tx, targetInsertColumns []dbi.Column) error {
|
||||
// 遍历res,组装数据
|
||||
var data = make([]map[string]any, 0)
|
||||
for _, record := range srcRes {
|
||||
var rowData = make(map[string]any)
|
||||
var targetData = make([]map[string]any, 0)
|
||||
for _, srcData := range srcRes {
|
||||
var data = make(map[string]any)
|
||||
// 遍历字段映射, target字段的值为src字段取值
|
||||
for _, item := range fieldMap {
|
||||
srcField := item["src"]
|
||||
targetField := item["target"]
|
||||
// target字段的值为src字段取值
|
||||
rowData[targetField] = record[srcField]
|
||||
data[item["target"]] = srcData[item["src"]]
|
||||
}
|
||||
|
||||
data = append(data, rowData)
|
||||
targetData = append(targetData, data)
|
||||
}
|
||||
|
||||
tragetValues := make([][]any, 0)
|
||||
for _, item := range targetData {
|
||||
var values = make([]any, 0)
|
||||
for _, column := range targetInsertColumns {
|
||||
values = append(values, item[column.ColumnName])
|
||||
}
|
||||
tragetValues = append(tragetValues, values)
|
||||
}
|
||||
|
||||
// 执行插入
|
||||
|
||||
setUpdateFieldVal := func(field string) {
|
||||
// 解决字段大小写问题
|
||||
updFieldVal := srcRes[len(srcRes)-1][strings.ToUpper(field)]
|
||||
if updFieldVal == "" || updFieldVal == nil {
|
||||
updFieldVal = srcRes[len(srcRes)-1][strings.ToLower(field)]
|
||||
}
|
||||
task.UpdFieldVal = srcDialect.GetDataHelper().FormatData(updFieldVal, updFieldType)
|
||||
|
||||
task.UpdFieldVal = updFieldType.DataType.SQLValue(updFieldVal)
|
||||
}
|
||||
|
||||
// 如果指定了更新字段,则以更新字段取值
|
||||
@@ -346,36 +343,15 @@ func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap [
|
||||
setUpdateFieldVal(updFieldName)
|
||||
}
|
||||
|
||||
// 获取目标库字段数组
|
||||
targetWrapColumns := make([]string, 0)
|
||||
// 获取源库字段数组
|
||||
srcColumns := make([]string, 0)
|
||||
srcFieldTypes := make(map[string]dbi.DataType)
|
||||
targetDialect := targetDbConn.GetDialect()
|
||||
for _, item := range fieldMap {
|
||||
targetField := item["target"]
|
||||
srcField := item["target"]
|
||||
srcFieldTypes[srcField] = srcDialect.GetDataHelper().GetDataType(srcColumnTypes[item["src"]])
|
||||
targetWrapColumns = append(targetWrapColumns, targetDialect.QuoteIdentifier(targetField))
|
||||
srcColumns = append(srcColumns, srcField)
|
||||
}
|
||||
|
||||
// 从目标库数据中取出源库字段对应的值
|
||||
values := make([][]any, 0)
|
||||
for _, record := range data {
|
||||
rawValue := make([]any, 0)
|
||||
for _, column := range srcColumns {
|
||||
// 某些情况,如oracle,需要转换时间类型的字符串为time类型
|
||||
res := srcDialect.GetDataHelper().ParseData(record[column], srcFieldTypes[column])
|
||||
rawValue = append(rawValue, res)
|
||||
// 生成目标数据库批量插入sql,并执行
|
||||
sqls := targetDialect.GetSQLGenerator().GenInsert(task.TargetTableName, targetInsertColumns, tragetValues, cmp.Or(task.DuplicateStrategy, dbi.DuplicateStrategyNone))
|
||||
for _, sql := range sqls {
|
||||
_, err := targetDbTx.Exec(sql)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
values = append(values, rawValue)
|
||||
}
|
||||
|
||||
// 目标数据库执行sql批量插入
|
||||
_, err := targetDialect.BatchInsert(targetDbTx, task.TargetTableName, targetWrapColumns, values, task.DuplicateStrategy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 运行过程中,判断状态是否为已关闭,是则结束运行,否则继续运行
|
||||
|
||||
@@ -3,6 +3,7 @@ package application
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/application/dto"
|
||||
"mayfly-go/internal/db/config"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/internal/db/dbm/sqlparser"
|
||||
@@ -12,26 +13,21 @@ import (
|
||||
"mayfly-go/internal/db/imsg"
|
||||
flowapp "mayfly-go/internal/flow/application"
|
||||
flowentity "mayfly-go/internal/flow/domain/entity"
|
||||
msgapp "mayfly-go/internal/msg/application"
|
||||
msgdto "mayfly-go/internal/msg/application/dto"
|
||||
"mayfly-go/pkg/contextx"
|
||||
"mayfly-go/pkg/errorx"
|
||||
"mayfly-go/pkg/i18n"
|
||||
"mayfly-go/pkg/logx"
|
||||
"mayfly-go/pkg/model"
|
||||
"mayfly-go/pkg/utils/anyx"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"mayfly-go/pkg/utils/jsonx"
|
||||
"os"
|
||||
"mayfly-go/pkg/utils/stringx"
|
||||
"mayfly-go/pkg/ws"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type DbSqlExecReq struct {
|
||||
DbId uint64
|
||||
Db string
|
||||
Sql string // 需要执行的sql,支持多条
|
||||
SqlFile *os.File // sql文件
|
||||
Remark string // 执行备注
|
||||
DbConn *dbi.DbConn
|
||||
CheckFlow bool // 是否检查存储审批流程
|
||||
}
|
||||
|
||||
type sqlExecParam struct {
|
||||
DbConn *dbi.DbConn
|
||||
Sql string // 执行的sql
|
||||
@@ -41,18 +37,25 @@ type sqlExecParam struct {
|
||||
SqlExecRecord *entity.DbSqlExec // sql执行记录
|
||||
}
|
||||
|
||||
type DbSqlExecRes struct {
|
||||
Sql string `json:"sql"` // 执行的sql
|
||||
ErrorMsg string `json:"errorMsg"` // 若执行失败,则将失败内容记录到该字段
|
||||
Columns []*dbi.QueryColumn `json:"columns"` // 响应的列信息
|
||||
Res []map[string]any `json:"res"` // 响应结果
|
||||
// progressCategory sql文件执行进度消息类型
|
||||
const progressCategory = "execSqlFileProgress"
|
||||
|
||||
// progressMsg sql文件执行进度消息
|
||||
type progressMsg struct {
|
||||
Id string `json:"id"`
|
||||
Title string `json:"title"`
|
||||
ExecutedStatements int `json:"executedStatements"`
|
||||
Terminated bool `json:"terminated"`
|
||||
}
|
||||
|
||||
type DbSqlExec interface {
|
||||
flowapp.FlowBizHandler
|
||||
|
||||
// 执行sql
|
||||
Exec(ctx context.Context, execSqlReq *DbSqlExecReq) ([]*DbSqlExecRes, error)
|
||||
Exec(ctx context.Context, execSqlReq *dto.DbSqlExecReq) ([]*dto.DbSqlExecRes, error)
|
||||
|
||||
// ExecReader 从reader中读取sql并执行
|
||||
ExecReader(ctx context.Context, execReader *dto.SqlReaderExec) error
|
||||
|
||||
// 根据条件删除sql执行记录
|
||||
DeleteBy(ctx context.Context, condition *entity.DbSqlExec) error
|
||||
@@ -66,9 +69,10 @@ type dbSqlExecAppImpl struct {
|
||||
dbSqlExecRepo repository.DbSqlExec `inject:"DbSqlExecRepo"`
|
||||
|
||||
flowProcdefApp flowapp.Procdef `inject:"ProcdefApp"`
|
||||
msgApp msgapp.Msg `inject:"MsgApp"`
|
||||
}
|
||||
|
||||
func createSqlExecRecord(ctx context.Context, execSqlReq *DbSqlExecReq, sql string) *entity.DbSqlExec {
|
||||
func createSqlExecRecord(ctx context.Context, execSqlReq *dto.DbSqlExecReq, sql string) *entity.DbSqlExec {
|
||||
dbSqlExecRecord := new(entity.DbSqlExec)
|
||||
dbSqlExecRecord.DbId = execSqlReq.DbId
|
||||
dbSqlExecRecord.Db = execSqlReq.Db
|
||||
@@ -79,7 +83,7 @@ func createSqlExecRecord(ctx context.Context, execSqlReq *DbSqlExecReq, sql stri
|
||||
return dbSqlExecRecord
|
||||
}
|
||||
|
||||
func (d *dbSqlExecAppImpl) Exec(ctx context.Context, execSqlReq *DbSqlExecReq) ([]*DbSqlExecRes, error) {
|
||||
func (d *dbSqlExecAppImpl) Exec(ctx context.Context, execSqlReq *dto.DbSqlExecReq) ([]*dto.DbSqlExecRes, error) {
|
||||
dbConn := execSqlReq.DbConn
|
||||
execSql := execSqlReq.Sql
|
||||
sp := dbConn.GetDialect().GetSQLParser()
|
||||
@@ -89,13 +93,13 @@ func (d *dbSqlExecAppImpl) Exec(ctx context.Context, execSqlReq *DbSqlExecReq) (
|
||||
flowProcdef = d.flowProcdefApp.GetProcdefByCodePath(ctx, dbConn.Info.CodePath...)
|
||||
}
|
||||
|
||||
allExecRes := make([]*DbSqlExecRes, 0)
|
||||
allExecRes := make([]*dto.DbSqlExecRes, 0)
|
||||
|
||||
stmts, err := sp.Parse(execSql)
|
||||
// sql解析失败,则使用默认方式切割
|
||||
if err != nil {
|
||||
sqlparser.SQLSplit(strings.NewReader(execSql), func(oneSql string) error {
|
||||
var execRes *DbSqlExecRes
|
||||
var execRes *dto.DbSqlExecRes
|
||||
var err error
|
||||
|
||||
dbSqlExecRecord := createSqlExecRecord(ctx, execSqlReq, oneSql)
|
||||
@@ -120,7 +124,7 @@ func (d *dbSqlExecAppImpl) Exec(ctx context.Context, execSqlReq *DbSqlExecReq) (
|
||||
// 执行错误
|
||||
if err != nil {
|
||||
if execRes == nil {
|
||||
execRes = &DbSqlExecRes{Sql: oneSql}
|
||||
execRes = &dto.DbSqlExecRes{Sql: oneSql}
|
||||
}
|
||||
execRes.ErrorMsg = err.Error()
|
||||
} else {
|
||||
@@ -133,7 +137,7 @@ func (d *dbSqlExecAppImpl) Exec(ctx context.Context, execSqlReq *DbSqlExecReq) (
|
||||
}
|
||||
|
||||
for _, stmt := range stmts {
|
||||
var execRes *DbSqlExecRes
|
||||
var execRes *dto.DbSqlExecRes
|
||||
var err error
|
||||
|
||||
sql := stmt.GetText()
|
||||
@@ -172,7 +176,7 @@ func (d *dbSqlExecAppImpl) Exec(ctx context.Context, execSqlReq *DbSqlExecReq) (
|
||||
|
||||
if err != nil {
|
||||
if execRes == nil {
|
||||
execRes = &DbSqlExecRes{Sql: sql}
|
||||
execRes = &dto.DbSqlExecRes{Sql: sql}
|
||||
}
|
||||
execRes.ErrorMsg = err.Error()
|
||||
} else {
|
||||
@@ -184,6 +188,64 @@ func (d *dbSqlExecAppImpl) Exec(ctx context.Context, execSqlReq *DbSqlExecReq) (
|
||||
return allExecRes, nil
|
||||
}
|
||||
|
||||
func (d *dbSqlExecAppImpl) ExecReader(ctx context.Context, execReader *dto.SqlReaderExec) error {
|
||||
dbConn := execReader.DbConn
|
||||
|
||||
clientId := execReader.ClientId
|
||||
filename := stringx.Truncate(execReader.Filename, 20, 10, "...")
|
||||
la := contextx.GetLoginAccount(ctx)
|
||||
needSendMsg := la != nil && clientId != ""
|
||||
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
errInfo := anyx.ToString(err)
|
||||
logx.Errorf("exec sql reader error: %s", errInfo)
|
||||
if needSendMsg {
|
||||
errInfo = stringx.Truncate(errInfo, 300, 10, "...")
|
||||
d.msgApp.CreateAndSend(la, msgdto.ErrSysMsg(i18n.T(imsg.SqlScriptRunFail), fmt.Sprintf("[%s][%s] execution failure: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo)).WithClientId(clientId))
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
executedStatements := 0
|
||||
progressId := stringx.Rand(32)
|
||||
if needSendMsg {
|
||||
defer ws.SendJsonMsg(ws.UserId(la.Id), clientId, msgdto.InfoSysMsg(i18n.T(imsg.SqlScripRunProgress), &progressMsg{
|
||||
Id: progressId,
|
||||
Title: filename,
|
||||
ExecutedStatements: executedStatements,
|
||||
Terminated: true,
|
||||
}).WithCategory(progressCategory))
|
||||
}
|
||||
|
||||
err := sqlparser.SQLSplit(execReader.Reader, func(sql string) error {
|
||||
if executedStatements%50 == 0 {
|
||||
if needSendMsg {
|
||||
ws.SendJsonMsg(ws.UserId(la.Id), clientId, msgdto.InfoSysMsg(i18n.T(imsg.SqlScripRunProgress), &progressMsg{
|
||||
Id: progressId,
|
||||
Title: filename,
|
||||
ExecutedStatements: executedStatements,
|
||||
Terminated: false,
|
||||
}).WithCategory(progressCategory))
|
||||
}
|
||||
}
|
||||
|
||||
executedStatements++
|
||||
if _, err := dbConn.Exec(sql); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if needSendMsg {
|
||||
d.msgApp.CreateAndSend(la, msgdto.SuccessSysMsg(i18n.T(imsg.SqlScriptRunSuccess), "execution success").WithClientId(clientId))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type FlowDbExecSqlBizForm struct {
|
||||
DbId uint64 `json:"dbId"` // 库id
|
||||
DbName string `json:"dbName"` // 库名
|
||||
@@ -211,7 +273,7 @@ func (d *dbSqlExecAppImpl) FlowBizHandle(ctx context.Context, bizHandleParam *fl
|
||||
return nil, err
|
||||
}
|
||||
|
||||
execRes, err := d.Exec(contextx.NewLoginAccount(&model.LoginAccount{Id: procinst.CreatorId, Username: procinst.Creator}), &DbSqlExecReq{
|
||||
execRes, err := d.Exec(contextx.NewLoginAccount(&model.LoginAccount{Id: procinst.CreatorId, Username: procinst.Creator}), &dto.DbSqlExecReq{
|
||||
DbId: execSqlBizForm.DbId,
|
||||
Db: execSqlBizForm.DbName,
|
||||
Sql: execSqlBizForm.Sql,
|
||||
@@ -257,7 +319,7 @@ func (d *dbSqlExecAppImpl) saveSqlExecLog(dbSqlExecRecord *entity.DbSqlExec, res
|
||||
}
|
||||
}
|
||||
|
||||
func (d *dbSqlExecAppImpl) doSelect(ctx context.Context, sqlExecParam *sqlExecParam) (*DbSqlExecRes, error) {
|
||||
func (d *dbSqlExecAppImpl) doSelect(ctx context.Context, sqlExecParam *sqlExecParam) (*dto.DbSqlExecRes, error) {
|
||||
maxCount := config.GetDbms().MaxResultSet
|
||||
selectStmt := sqlExecParam.Stmt
|
||||
selectSql := sqlExecParam.Sql
|
||||
@@ -314,7 +376,7 @@ func (d *dbSqlExecAppImpl) doSelect(ctx context.Context, sqlExecParam *sqlExecPa
|
||||
return d.doQuery(ctx, sqlExecParam.DbConn, selectSql)
|
||||
}
|
||||
|
||||
func (d *dbSqlExecAppImpl) doOtherRead(ctx context.Context, sqlExecParam *sqlExecParam) (*DbSqlExecRes, error) {
|
||||
func (d *dbSqlExecAppImpl) doOtherRead(ctx context.Context, sqlExecParam *sqlExecParam) (*dto.DbSqlExecRes, error) {
|
||||
selectSql := sqlExecParam.Sql
|
||||
sqlExecParam.SqlExecRecord.Type = entity.DbSqlExecTypeQuery
|
||||
|
||||
@@ -327,7 +389,7 @@ func (d *dbSqlExecAppImpl) doOtherRead(ctx context.Context, sqlExecParam *sqlExe
|
||||
return d.doQuery(ctx, sqlExecParam.DbConn, selectSql)
|
||||
}
|
||||
|
||||
func (d *dbSqlExecAppImpl) doExecDDL(ctx context.Context, sqlExecParam *sqlExecParam) (*DbSqlExecRes, error) {
|
||||
func (d *dbSqlExecAppImpl) doExecDDL(ctx context.Context, sqlExecParam *sqlExecParam) (*dto.DbSqlExecRes, error) {
|
||||
selectSql := sqlExecParam.Sql
|
||||
sqlExecParam.SqlExecRecord.Type = entity.DbSqlExecTypeDDL
|
||||
|
||||
@@ -340,7 +402,7 @@ func (d *dbSqlExecAppImpl) doExecDDL(ctx context.Context, sqlExecParam *sqlExecP
|
||||
return d.doExec(ctx, sqlExecParam.DbConn, selectSql)
|
||||
}
|
||||
|
||||
func (d *dbSqlExecAppImpl) doUpdate(ctx context.Context, sqlExecParam *sqlExecParam) (*DbSqlExecRes, error) {
|
||||
func (d *dbSqlExecAppImpl) doUpdate(ctx context.Context, sqlExecParam *sqlExecParam) (*dto.DbSqlExecRes, error) {
|
||||
dbConn := sqlExecParam.DbConn
|
||||
|
||||
if procdef := sqlExecParam.Procdef; procdef != nil {
|
||||
@@ -365,7 +427,7 @@ func (d *dbSqlExecAppImpl) doUpdate(ctx context.Context, sqlExecParam *sqlExecPa
|
||||
tableSources := updatestmt.TableSources.TableSources
|
||||
// 不支持多表更新记录旧值
|
||||
if len(tableSources) != 1 {
|
||||
logx.ErrorContext(ctx, "Update SQL - logging old values only supports single-table updates")
|
||||
logx.ErrorContext(ctx, "update SQL - logging old values only supports single-table updates")
|
||||
return d.doExec(ctx, dbConn, sqlExecParam.Sql)
|
||||
}
|
||||
|
||||
@@ -379,21 +441,21 @@ func (d *dbSqlExecAppImpl) doUpdate(ctx context.Context, sqlExecParam *sqlExecPa
|
||||
}
|
||||
|
||||
if tableName == "" {
|
||||
logx.ErrorContext(ctx, "Update SQL - failed to get table name")
|
||||
logx.ErrorContext(ctx, "update SQL - failed to get table name")
|
||||
return d.doExec(ctx, dbConn, sqlExecParam.Sql)
|
||||
}
|
||||
execRecord.Table = tableName
|
||||
|
||||
whereStr := updatestmt.Where.GetText()
|
||||
if whereStr == "" {
|
||||
logx.ErrorContext(ctx, "Update SQL - there is no where condition")
|
||||
logx.ErrorContext(ctx, "update SQL - there is no where condition")
|
||||
return d.doExec(ctx, dbConn, sqlExecParam.Sql)
|
||||
}
|
||||
|
||||
// 获取表主键列名,排除使用别名
|
||||
primaryKey, err := dbConn.GetMetadata().GetPrimaryKey(tableName)
|
||||
if err != nil {
|
||||
logx.ErrorfContext(ctx, "Update SQL - failed to get primary key column: %s", err.Error())
|
||||
logx.ErrorfContext(ctx, "update SQL - failed to get primary key column: %s", err.Error())
|
||||
return d.doExec(ctx, dbConn, sqlExecParam.Sql)
|
||||
}
|
||||
|
||||
@@ -417,12 +479,12 @@ func (d *dbSqlExecAppImpl) doUpdate(ctx context.Context, sqlExecParam *sqlExecPa
|
||||
nowRec++
|
||||
res = append(res, row)
|
||||
if nowRec == maxRec {
|
||||
return errorx.NewBiz(fmt.Sprintf("Update SQL - the maximum number of updated queries is exceeded: %d", maxRec))
|
||||
return errorx.NewBiz(fmt.Sprintf("update SQL - the maximum number of updated queries is exceeded: %d", maxRec))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
logx.ErrorfContext(ctx, "Update SQL - failed to get the updated old value: %s", err.Error())
|
||||
logx.ErrorfContext(ctx, "update SQL - failed to get the updated old value: %s", err.Error())
|
||||
return d.doExec(ctx, dbConn, sqlExecParam.Sql)
|
||||
}
|
||||
execRecord.OldValue = jsonx.ToStr(res)
|
||||
@@ -430,7 +492,7 @@ func (d *dbSqlExecAppImpl) doUpdate(ctx context.Context, sqlExecParam *sqlExecPa
|
||||
return d.doExec(ctx, dbConn, sqlExecParam.Sql)
|
||||
}
|
||||
|
||||
func (d *dbSqlExecAppImpl) doDelete(ctx context.Context, sqlExecParam *sqlExecParam) (*DbSqlExecRes, error) {
|
||||
func (d *dbSqlExecAppImpl) doDelete(ctx context.Context, sqlExecParam *sqlExecParam) (*dto.DbSqlExecRes, error) {
|
||||
if procdef := sqlExecParam.Procdef; procdef != nil {
|
||||
if needStartProc := procdef.MatchCondition(DbSqlExecFlowBizType, collx.Kvs("stmtType", "delete")); needStartProc {
|
||||
return nil, errorx.NewBizI(ctx, imsg.ErrNeedSubmitWorkTicket)
|
||||
@@ -454,7 +516,7 @@ func (d *dbSqlExecAppImpl) doDelete(ctx context.Context, sqlExecParam *sqlExecPa
|
||||
tableSources := deletestmt.TableSources.TableSources
|
||||
// 不支持多表删除记录旧值
|
||||
if len(tableSources) != 1 {
|
||||
logx.ErrorContext(ctx, "Delete SQL - logging old values only supports single-table deletion")
|
||||
logx.ErrorContext(ctx, "delete SQL - logging old values only supports single-table deletion")
|
||||
return d.doExec(ctx, dbConn, sqlExecParam.Sql)
|
||||
}
|
||||
|
||||
@@ -468,14 +530,14 @@ func (d *dbSqlExecAppImpl) doDelete(ctx context.Context, sqlExecParam *sqlExecPa
|
||||
}
|
||||
|
||||
if tableName == "" {
|
||||
logx.ErrorContext(ctx, "Delete SQL - failed to get table name")
|
||||
logx.ErrorContext(ctx, "delete SQL - failed to get table name")
|
||||
return d.doExec(ctx, dbConn, sqlExecParam.Sql)
|
||||
}
|
||||
execRecord.Table = tableName
|
||||
|
||||
whereStr := deletestmt.Where.GetText()
|
||||
if whereStr == "" {
|
||||
logx.ErrorContext(ctx, "Delete SQL - there is no where condition")
|
||||
logx.ErrorContext(ctx, "delete SQL - there is no where condition")
|
||||
return d.doExec(ctx, dbConn, sqlExecParam.Sql)
|
||||
}
|
||||
|
||||
@@ -487,7 +549,7 @@ func (d *dbSqlExecAppImpl) doDelete(ctx context.Context, sqlExecParam *sqlExecPa
|
||||
return d.doExec(ctx, dbConn, sqlExecParam.Sql)
|
||||
}
|
||||
|
||||
func (d *dbSqlExecAppImpl) doInsert(ctx context.Context, sqlExecParam *sqlExecParam) (*DbSqlExecRes, error) {
|
||||
func (d *dbSqlExecAppImpl) doInsert(ctx context.Context, sqlExecParam *sqlExecParam) (*dto.DbSqlExecRes, error) {
|
||||
if procdef := sqlExecParam.Procdef; procdef != nil {
|
||||
if needStartProc := procdef.MatchCondition(DbSqlExecFlowBizType, collx.Kvs("stmtType", "insert")); needStartProc {
|
||||
return nil, errorx.NewBizI(ctx, imsg.ErrNeedSubmitWorkTicket)
|
||||
@@ -513,19 +575,19 @@ func (d *dbSqlExecAppImpl) doInsert(ctx context.Context, sqlExecParam *sqlExecPa
|
||||
return d.doExec(ctx, sqlExecParam.DbConn, sqlExecParam.Sql)
|
||||
}
|
||||
|
||||
func (d *dbSqlExecAppImpl) doQuery(ctx context.Context, dbConn *dbi.DbConn, sql string) (*DbSqlExecRes, error) {
|
||||
func (d *dbSqlExecAppImpl) doQuery(ctx context.Context, dbConn *dbi.DbConn, sql string) (*dto.DbSqlExecRes, error) {
|
||||
cols, res, err := dbConn.QueryContext(ctx, sql)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &DbSqlExecRes{
|
||||
return &dto.DbSqlExecRes{
|
||||
Sql: sql,
|
||||
Columns: cols,
|
||||
Res: res,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *dbSqlExecAppImpl) doExec(ctx context.Context, dbConn *dbi.DbConn, sql string) (*DbSqlExecRes, error) {
|
||||
func (d *dbSqlExecAppImpl) doExec(ctx context.Context, dbConn *dbi.DbConn, sql string) (*dto.DbSqlExecRes, error) {
|
||||
rowsAffected, err := dbConn.ExecContext(ctx, sql)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -534,7 +596,7 @@ func (d *dbSqlExecAppImpl) doExec(ctx context.Context, dbConn *dbi.DbConn, sql s
|
||||
res := make([]map[string]any, 0)
|
||||
res = append(res, collx.Kvs("rowsAffected", rowsAffected))
|
||||
|
||||
return &DbSqlExecRes{
|
||||
return &dto.DbSqlExecRes{
|
||||
Columns: []*dbi.QueryColumn{
|
||||
{Name: "rowsAffected", Type: "number"},
|
||||
},
|
||||
|
||||
@@ -3,10 +3,11 @@ package application
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"mayfly-go/internal/db/application/dto"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/internal/db/dbm/sqlparser"
|
||||
"mayfly-go/internal/db/domain/entity"
|
||||
"mayfly-go/internal/db/domain/repository"
|
||||
fileapp "mayfly-go/internal/file/application"
|
||||
@@ -21,7 +22,6 @@ import (
|
||||
"mayfly-go/pkg/scheduler"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"mayfly-go/pkg/utils/timex"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -252,12 +252,90 @@ func (app *dbTransferAppImpl) transfer2Db(ctx context.Context, taskId uint64, lo
|
||||
app.EndTransfer(ctx, logId, taskId, "failed to get target db connection", err, nil)
|
||||
return
|
||||
}
|
||||
// 迁移表
|
||||
if err = app.transferDbTables(ctx, logId, task, srcConn, targetConn, tables); err != nil {
|
||||
|
||||
ctx = context.Background()
|
||||
|
||||
tableNames := collx.ArrayMap(tables, func(t dbi.Table) string { return t.TableName })
|
||||
// 分组迁移
|
||||
tableGroups := collx.ArraySplit[string](tableNames, 2)
|
||||
errGroup, _ := errgroup.WithContext(ctx)
|
||||
|
||||
for _, tables := range tableGroups {
|
||||
errGroup.Go(func() error {
|
||||
if !app.IsRunning(taskId) {
|
||||
return errorx.NewBiz("transfer stopped")
|
||||
}
|
||||
|
||||
currentDumpTable := tables[0]
|
||||
pr, pw := io.Pipe()
|
||||
go func() {
|
||||
err := app.dbApp.DumpDb(ctx, &dto.DumpDb{
|
||||
LogId: logId,
|
||||
DbId: uint64(task.SrcDbId),
|
||||
DbName: task.SrcDbName,
|
||||
TargetDbType: dbi.DbType(task.TargetDbType),
|
||||
Tables: tables,
|
||||
DumpDDL: true,
|
||||
DumpData: true,
|
||||
Writer: pw,
|
||||
Log: func(msg string) { // 记录日志
|
||||
app.Log(ctx, logId, msg)
|
||||
},
|
||||
Progress: func(currentTable string, stmtType dbi.StmtType, stmtCount int, currentStmtTypeEnd bool) {
|
||||
logExtraKey := fmt.Sprintf("`%s` amount of transfer data currently: ", currentDumpTable)
|
||||
if stmtType == dbi.StmtTypeInsert {
|
||||
app.logApp.SetExtra(logId, logExtraKey, stmtCount)
|
||||
if currentStmtTypeEnd {
|
||||
app.Log(ctx, logId, fmt.Sprintf("execute transfer table [%s] insert %d rows", currentDumpTable, stmtCount))
|
||||
}
|
||||
}
|
||||
|
||||
if currentDumpTable != currentTable {
|
||||
currentDumpTable = currentTable
|
||||
stmtCount = 0
|
||||
// 置空当前表数据迁移量进度
|
||||
app.logApp.SetExtra(logId, logExtraKey, nil)
|
||||
}
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
pr.CloseWithError(err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
tx, err := targetConn.Begin()
|
||||
if err != nil {
|
||||
pw.CloseWithError(err)
|
||||
app.EndTransfer(ctx, logId, taskId, "transfer table failed", err, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
err = sqlparser.SQLSplit(pr, func(stmt string) error {
|
||||
if _, err := targetConn.TxExecContext(ctx, tx, stmt); err != nil {
|
||||
pw.CloseWithError(err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
_ = tx.Commit()
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
err = errGroup.Wait()
|
||||
|
||||
if err != nil {
|
||||
app.EndTransfer(ctx, logId, taskId, "transfer table failed", err, nil)
|
||||
return
|
||||
}
|
||||
|
||||
app.EndTransfer(ctx, logId, taskId, fmt.Sprintf("execute transfer task [taskId = %d] complete, time: %v", taskId, time.Since(start)), nil, nil)
|
||||
}
|
||||
|
||||
@@ -281,10 +359,7 @@ func (app *dbTransferAppImpl) transfer2File(ctx context.Context, taskId uint64,
|
||||
}
|
||||
|
||||
// 从tables提取表名
|
||||
tableNames := make([]string, 0)
|
||||
for _, table := range tables {
|
||||
tableNames = append(tableNames, table.TableName)
|
||||
}
|
||||
tableNames := collx.ArrayMap(tables, func(t dbi.Table) string { return t.TableName })
|
||||
// 2、把源库数据迁移到文件
|
||||
app.Log(ctx, logId, fmt.Sprintf("start transfer table data to files: %s", filename))
|
||||
app.Log(ctx, logId, fmt.Sprintf("dialect type of target db file: %s", task.TargetFileDbType))
|
||||
@@ -341,228 +416,6 @@ func (app *dbTransferAppImpl) Stop(ctx context.Context, taskId uint64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 迁移表
|
||||
func (app *dbTransferAppImpl) transferDbTables(ctx context.Context, logId uint64, task *entity.DbTransferTask, srcConn *dbi.DbConn, targetConn *dbi.DbConn, tables []dbi.Table) error {
|
||||
tableNames := make([]string, 0)
|
||||
tableMap := make(map[string]dbi.Table) // 以表名分组,存放表信息
|
||||
for _, table := range tables {
|
||||
tableNames = append(tableNames, table.TableName)
|
||||
tableMap[table.TableName] = table
|
||||
}
|
||||
|
||||
if len(tableNames) == 0 {
|
||||
return errorx.NewBiz("there are no tables to migrate")
|
||||
}
|
||||
srcDialect := srcConn.GetDialect()
|
||||
srcMetadata := srcConn.GetMetadata()
|
||||
// 查询源表列信息
|
||||
columns, err := srcMetadata.GetColumns(tableNames...)
|
||||
if err != nil {
|
||||
return errorx.NewBiz("failed to get the source table column information")
|
||||
}
|
||||
|
||||
// 以表名分组,存放每个表的列信息
|
||||
columnMap := make(map[string][]dbi.Column)
|
||||
for _, column := range columns {
|
||||
columnMap[column.TableName] = append(columnMap[column.TableName], column)
|
||||
}
|
||||
|
||||
// 以表名排序
|
||||
sortTableNames := collx.MapKeys(columnMap)
|
||||
sort.Strings(sortTableNames)
|
||||
|
||||
targetDialect := targetConn.GetDialect()
|
||||
srcColumnHelper := srcDialect.GetColumnHelper()
|
||||
targetColumnHelper := targetConn.GetDialect().GetColumnHelper()
|
||||
|
||||
// 分组迁移
|
||||
tableGroups := collx.ArraySplit[string](sortTableNames, 2)
|
||||
errGroup, _ := errgroup.WithContext(ctx)
|
||||
|
||||
for _, tables := range tableGroups {
|
||||
errGroup.Go(func() error {
|
||||
for _, tbName := range tables {
|
||||
cols := columnMap[tbName]
|
||||
targetCols := make([]dbi.Column, 0)
|
||||
for _, col := range cols {
|
||||
colPtr := &col
|
||||
// 源库列转为公共列
|
||||
srcColumnHelper.ToCommonColumn(colPtr)
|
||||
// 公共列转为目标库列
|
||||
targetColumnHelper.ToColumn(colPtr)
|
||||
targetCols = append(targetCols, *colPtr)
|
||||
}
|
||||
|
||||
// 通过公共列信息生成目标库的建表语句,并执行目标库建表
|
||||
app.Log(ctx, logId, fmt.Sprintf("start creating the target table: %s", tbName))
|
||||
|
||||
sqlArr := targetDialect.GenerateTableDDL(targetCols, tableMap[tbName], true)
|
||||
for _, sqlStr := range sqlArr {
|
||||
_, err := targetConn.Exec(sqlStr)
|
||||
if err != nil {
|
||||
return errorx.NewBiz(fmt.Sprintf("failed to create target table: %s, error: %s", tbName, err.Error()))
|
||||
}
|
||||
}
|
||||
app.Log(ctx, logId, fmt.Sprintf("target table created successfully: %s", tbName))
|
||||
|
||||
// 迁移数据
|
||||
app.Log(ctx, logId, fmt.Sprintf("start transfer data: %s", tbName))
|
||||
total, err := app.transferData(ctx, logId, task.Id, tbName, targetCols, srcConn, targetConn)
|
||||
if err != nil {
|
||||
return errorx.NewBiz(fmt.Sprintf("failed to transfer => table: %s, error: %s", tbName, err.Error()))
|
||||
}
|
||||
app.Log(ctx, logId, fmt.Sprintf("successfully transfer data => table: %s, data: %d entries", tbName, total))
|
||||
|
||||
// 有些数据库迁移完数据之后,需要更新表自增序列为当前表最大值
|
||||
targetDialect.UpdateSequence(tbName, targetCols)
|
||||
|
||||
// 迁移索引信息
|
||||
app.Log(ctx, logId, fmt.Sprintf("start transfer index => table: %s", tbName))
|
||||
err = app.transferIndex(ctx, tableMap[tbName], srcConn, targetConn)
|
||||
if err != nil {
|
||||
return errorx.NewBiz(fmt.Sprintf("failed to transfer index => table: %s, error: %s", tbName, err.Error()))
|
||||
}
|
||||
app.Log(ctx, logId, fmt.Sprintf("successfully transfer index => table: %s", tbName))
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
return errGroup.Wait()
|
||||
}
|
||||
|
||||
func (app *dbTransferAppImpl) transferData(ctx context.Context, logId uint64, taskId uint64, tableName string, targetColumns []dbi.Column, srcConn *dbi.DbConn, targetConn *dbi.DbConn) (int, error) {
|
||||
result := make([]map[string]any, 0)
|
||||
total := 0 // 总条数
|
||||
batchSize := 1000 // 每次查询并迁移1000条数据
|
||||
var err error
|
||||
srcDialect := srcConn.GetDialect()
|
||||
srcConverter := srcDialect.GetDataHelper()
|
||||
targetDialect := targetConn.GetDialect()
|
||||
logExtraKey := fmt.Sprintf("`%s` amount of transfer data currently: ", tableName)
|
||||
|
||||
// 游标查询源表数据,并批量插入目标表
|
||||
_, err = srcConn.WalkTableRows(context.Background(), tableName, func(row map[string]any, columns []*dbi.QueryColumn) error {
|
||||
total++
|
||||
rawValue := map[string]any{}
|
||||
for _, column := range columns {
|
||||
// 某些情况,如oracle,需要转换时间类型的字符串为time类型
|
||||
res := srcConverter.ParseData(row[column.Name], srcConverter.GetDataType(column.Type))
|
||||
rawValue[column.Name] = res
|
||||
}
|
||||
result = append(result, rawValue)
|
||||
if total%batchSize == 0 {
|
||||
err = app.transfer2Target(taskId, targetConn, targetColumns, result, targetDialect, tableName)
|
||||
if err != nil {
|
||||
logx.ErrorfContext(ctx, "batch insert data to target table failed: %v", err)
|
||||
return err
|
||||
}
|
||||
result = result[:0]
|
||||
app.logApp.SetExtra(logId, logExtraKey, total)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
|
||||
// 处理剩余的数据
|
||||
if len(result) > 0 {
|
||||
err = app.transfer2Target(taskId, targetConn, targetColumns, result, targetDialect, tableName)
|
||||
if err != nil {
|
||||
logx.ErrorfContext(ctx, "batch insert data to target table failed => table: %s, error: %v", tableName, err)
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
// 置空当前表数据迁移量进度
|
||||
app.logApp.SetExtra(logId, logExtraKey, nil)
|
||||
return total, err
|
||||
}
|
||||
|
||||
func (app *dbTransferAppImpl) transfer2Target(taskId uint64, targetConn *dbi.DbConn, targetColumns []dbi.Column, result []map[string]any, targetDialect dbi.Dialect, tbName string) error {
|
||||
if !app.IsRunning(taskId) {
|
||||
return errorx.NewBiz("transfer stopped")
|
||||
}
|
||||
|
||||
tx, err := targetConn.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 收集字段名
|
||||
var columnNames []string
|
||||
for _, col := range targetColumns {
|
||||
columnNames = append(columnNames, targetDialect.QuoteIdentifier(col.ColumnName))
|
||||
}
|
||||
|
||||
dataHelper := targetDialect.GetDataHelper()
|
||||
|
||||
// 从目标库数据中取出源库字段对应的值
|
||||
values := make([][]any, 0)
|
||||
for _, record := range result {
|
||||
rawValue := make([]any, 0)
|
||||
for _, tc := range targetColumns {
|
||||
columnName := tc.ColumnName
|
||||
val := record[targetDialect.RemoveQuote(columnName)]
|
||||
if !tc.Nullable {
|
||||
// 如果val是文本,则设置为空格字符
|
||||
switch val.(type) {
|
||||
case string:
|
||||
if val == "" {
|
||||
val = " "
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if dataHelper.GetDataType(string(tc.DataType)) == dbi.DataTypeBlob {
|
||||
decodeBytes, err := hex.DecodeString(val.(string))
|
||||
if err == nil {
|
||||
val = decodeBytes
|
||||
}
|
||||
}
|
||||
|
||||
rawValue = append(rawValue, val)
|
||||
}
|
||||
values = append(values, rawValue)
|
||||
}
|
||||
// 批量插入
|
||||
_, err = targetDialect.BatchInsert(tx, tbName, columnNames, values, -1)
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
tx.Rollback()
|
||||
logx.Errorf("batch insert data to target table failed: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
_ = tx.Commit()
|
||||
return err
|
||||
}
|
||||
|
||||
func (app *dbTransferAppImpl) transferIndex(ctx context.Context, tableInfo dbi.Table, srcConn *dbi.DbConn, targetConn *dbi.DbConn) error {
|
||||
// 查询源表索引信息
|
||||
indexs, err := srcConn.GetMetadata().GetTableIndex(tableInfo.TableName)
|
||||
if err != nil {
|
||||
logx.Errorf("failed to get index information: %s", err)
|
||||
return err
|
||||
}
|
||||
if len(indexs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 通过表名、索引信息生成建索引语句,并执行到目标表
|
||||
sqlArr := targetConn.GetDialect().GenerateIndexDDL(indexs, tableInfo)
|
||||
for _, sqlStr := range sqlArr {
|
||||
_, err := targetConn.Exec(sqlStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *dbTransferAppImpl) TimerDeleteTransferFile() {
|
||||
logx.Debug("start deleting transfer files periodically...")
|
||||
scheduler.AddFun("@every 100m", func() {
|
||||
|
||||
@@ -23,10 +23,16 @@ type DumpDb struct {
|
||||
LogId uint64
|
||||
|
||||
Writer io.WriteCloser
|
||||
Log func(msg string)
|
||||
TargetDbType dbi.DbType
|
||||
|
||||
Log func(msg string)
|
||||
Progress func(currentTable string, stmtType dbi.StmtType, stmtCount int, currentStmtTypeEnd bool) // dump进度
|
||||
}
|
||||
|
||||
func DefaultDumpLog(msg string) {
|
||||
|
||||
}
|
||||
|
||||
func DefaultDumpProgress(currentTable string, stmtType dbi.StmtType, stmtCount int, currentStmtTypeEnd bool) {
|
||||
|
||||
}
|
||||
|
||||
31
server/internal/db/application/dto/sql_exec.go
Normal file
31
server/internal/db/application/dto/sql_exec.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"io"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
)
|
||||
|
||||
type DbSqlExecReq struct {
|
||||
DbId uint64
|
||||
Db string
|
||||
Sql string // 需要执行的sql,支持多条
|
||||
Remark string // 执行备注
|
||||
DbConn *dbi.DbConn
|
||||
CheckFlow bool // 是否检查存储审批流程
|
||||
}
|
||||
|
||||
type DbSqlExecRes struct {
|
||||
Sql string `json:"sql"` // 执行的sql
|
||||
ErrorMsg string `json:"errorMsg"` // 若执行失败,则将失败内容记录到该字段
|
||||
Columns []*dbi.QueryColumn `json:"columns"` // 响应的列信息
|
||||
Res []map[string]any `json:"res"` // 响应结果
|
||||
}
|
||||
|
||||
type SqlReaderExec struct {
|
||||
DbConn *dbi.DbConn
|
||||
|
||||
Reader io.Reader
|
||||
Filename string
|
||||
|
||||
ClientId string // 客户端id,若存在则会向其发送执行进度消息
|
||||
}
|
||||
594
server/internal/db/dbm/dbi/column.go
Normal file
594
server/internal/db/dbm/dbi/column.go
Normal file
@@ -0,0 +1,594 @@
|
||||
package dbi
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/may-fly/cast"
|
||||
)
|
||||
|
||||
var (
|
||||
dbDataTypes = make(map[DbType]map[string]*DbDataType) // 列类型
|
||||
)
|
||||
|
||||
// registerColumnDbDataTypes 注册数据库对应的数据类型
|
||||
func registerColumnDbDataTypes(dbType DbType, cts ...*DbDataType) {
|
||||
dbDataTypes[dbType] = collx.ArrayToMap(cts, func(ct *DbDataType) string {
|
||||
return strings.ToLower(string(ct.Name))
|
||||
})
|
||||
}
|
||||
|
||||
func GetDbDataType(dbType DbType, databaseColumnType string) *DbDataType {
|
||||
return cmp.Or(dbDataTypes[dbType][strings.ToLower(databaseColumnType)], DefaultDbDataType)
|
||||
}
|
||||
|
||||
var DefaultDbDataType = NewDbDataType("string", DTString).WithCT(CTVarchar)
|
||||
|
||||
// 表的列信息
|
||||
type Column struct {
|
||||
TableName string `json:"tableName"` // 表名
|
||||
ColumnName string `json:"columnName"` // 列名
|
||||
DataType string `json:"dataType"` // 数据类型
|
||||
ColumnComment string `json:"columnComment"` // 列备注
|
||||
IsPrimaryKey bool `json:"isPrimaryKey"` // 是否为主键
|
||||
IsIdentity bool `json:"isIdentity"` // 是否自增
|
||||
ColumnDefault string `json:"columnDefault"` // 默认值
|
||||
Nullable bool `json:"nullable"` // 是否可为null
|
||||
CharMaxLength int `json:"charMaxLength"` // 字符最大长度
|
||||
NumPrecision int `json:"numPrecision"` // 精度(总数字位数)
|
||||
NumScale int `json:"numScale"` // 小数点位数
|
||||
Extra collx.M `json:"extra"` // 其他额外信息
|
||||
}
|
||||
|
||||
// GetColumnType 获取完整的列类型,拼接数据类型与长度等。如varchar(2000),decimal(20,2)
|
||||
func (c *Column) GetColumnType() string {
|
||||
if c.CharMaxLength > 0 {
|
||||
return fmt.Sprintf("%s(%d)", c.DataType, c.CharMaxLength)
|
||||
}
|
||||
if c.NumPrecision > 0 {
|
||||
if c.NumScale > 0 {
|
||||
return fmt.Sprintf("%s(%d,%d)", c.DataType, c.NumPrecision, c.NumScale)
|
||||
} else {
|
||||
return fmt.Sprintf("%s(%d)", c.DataType, c.NumPrecision)
|
||||
}
|
||||
}
|
||||
|
||||
return string(c.DataType)
|
||||
}
|
||||
|
||||
// 数据库对应的数据类型
|
||||
type DbDataType struct {
|
||||
Name string // 类型名
|
||||
|
||||
DataType *DataType // 数据类型
|
||||
|
||||
fixColumnFunc func(column *Column) // 修复字段长度、精度等, 如mysql text会返回长度,需要将其置为0等
|
||||
|
||||
/** 以下为异构数据迁移同步使用,可不赋值,无值则不支持迁移同步 */
|
||||
|
||||
CommonType CommonDbDataType // 对应的公共类型
|
||||
}
|
||||
|
||||
// WithFixColumn 修复列信息函数,用于修复字段长度、精度等
|
||||
func (ct *DbDataType) WithFixColumn(fixColumnFunc func(column *Column)) *DbDataType {
|
||||
ct.fixColumnFunc = fixColumnFunc
|
||||
return ct
|
||||
}
|
||||
|
||||
// WithCT 对应的公共类型,主要用于异构数据库迁移同步时进行类型转换使用
|
||||
func (ct *DbDataType) WithCT(cct CommonDbDataType) *DbDataType {
|
||||
ct.CommonType = cct
|
||||
return ct
|
||||
}
|
||||
|
||||
// FixColumn 使用修复列信息函数进行列信息修复
|
||||
func (ct *DbDataType) FixColumn(column *Column) {
|
||||
if ct.fixColumnFunc != nil {
|
||||
ct.fixColumnFunc(column)
|
||||
}
|
||||
}
|
||||
|
||||
func NewDbDataType(name string, dataType *DataType) *DbDataType {
|
||||
return &DbDataType{
|
||||
Name: name,
|
||||
DataType: dataType,
|
||||
}
|
||||
}
|
||||
|
||||
func ClearCharMaxLength(column *Column) {
|
||||
column.CharMaxLength = 0
|
||||
column.NumPrecision = 0
|
||||
}
|
||||
|
||||
func ClearNumScale(column *Column) {
|
||||
column.NumScale = 0
|
||||
column.CharMaxLength = 0
|
||||
}
|
||||
|
||||
// DataType 数据类型, 对应于go类型,如int int64等。可自定义其他类型
|
||||
type DataType struct {
|
||||
Name string // 类型名
|
||||
|
||||
Valuer func() Valuer // 获取值对应的处理者,用于sql的scan、解析value等
|
||||
|
||||
SQLValue func(val any) string // 转换为sql字符串值,用于insert等SQL语句的值转换
|
||||
}
|
||||
|
||||
// Copy 拷贝一个同类型的datatype,主要方便用于定制化修改Valuer或ToString
|
||||
func (dt *DataType) Copy() *DataType {
|
||||
return &DataType{
|
||||
Name: dt.Name,
|
||||
Valuer: dt.Valuer,
|
||||
SQLValue: dt.SQLValue,
|
||||
}
|
||||
}
|
||||
|
||||
func (dt *DataType) WithValuer(valuerFunc func() Valuer) *DataType {
|
||||
dt.Valuer = valuerFunc
|
||||
return dt
|
||||
}
|
||||
|
||||
func (dt *DataType) WithSQLValue(sqlvalueFunc func(val any) string) *DataType {
|
||||
dt.SQLValue = sqlvalueFunc
|
||||
return dt
|
||||
}
|
||||
|
||||
const NULL = "NULL"
|
||||
|
||||
// SQLValueDefault 默认使用fmt转string
|
||||
func SQLValueDefault(val any) string {
|
||||
if val == nil {
|
||||
return NULL
|
||||
}
|
||||
return fmt.Sprintf("'%v'", val)
|
||||
}
|
||||
|
||||
// SQLValueNumeric 数字类型转string
|
||||
func SQLValueNumeric(val any) string {
|
||||
if val == nil {
|
||||
return NULL
|
||||
}
|
||||
return fmt.Sprintf("%v", val)
|
||||
}
|
||||
|
||||
func SQLValueString(val any) string {
|
||||
if val == nil {
|
||||
return NULL
|
||||
}
|
||||
|
||||
strVal, ok := val.(string)
|
||||
if !ok {
|
||||
return fmt.Sprintf("%v", val)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("'%s'", strings.ReplaceAll(strings.ReplaceAll(strVal, "'", "''"), `\`, `\\`))
|
||||
}
|
||||
|
||||
var (
|
||||
DTBit = &DataType{
|
||||
Name: "bit",
|
||||
Valuer: ValuerBit,
|
||||
SQLValue: SQLValueNumeric,
|
||||
}
|
||||
|
||||
DTByte = &DataType{
|
||||
Name: "uint8",
|
||||
Valuer: ValuerByte,
|
||||
SQLValue: SQLValueNumeric,
|
||||
}
|
||||
|
||||
DTInt8 = &DataType{
|
||||
Name: "int8",
|
||||
Valuer: ValuerInt16,
|
||||
SQLValue: SQLValueNumeric,
|
||||
}
|
||||
|
||||
DTInt16 = &DataType{
|
||||
Name: "int16",
|
||||
Valuer: ValuerInt16,
|
||||
SQLValue: SQLValueNumeric,
|
||||
}
|
||||
|
||||
DTInt32 = &DataType{
|
||||
Name: "int32",
|
||||
Valuer: ValuerInt32,
|
||||
SQLValue: SQLValueNumeric,
|
||||
}
|
||||
|
||||
DTInt64 = &DataType{
|
||||
Name: "int64",
|
||||
Valuer: ValuerInt64,
|
||||
SQLValue: SQLValueNumeric,
|
||||
}
|
||||
|
||||
// 所有无符号类型,都使用int64存储
|
||||
DTUint64 = &DataType{
|
||||
Name: "uint64",
|
||||
Valuer: ValuerUint64,
|
||||
SQLValue: SQLValueNumeric,
|
||||
}
|
||||
|
||||
DTNumeric = &DataType{
|
||||
Name: "numeric",
|
||||
Valuer: ValuerFloat64,
|
||||
SQLValue: SQLValueNumeric,
|
||||
}
|
||||
|
||||
DTDecimal = &DataType{
|
||||
Name: "decimal",
|
||||
Valuer: ValuerString,
|
||||
SQLValue: SQLValueNumeric,
|
||||
}
|
||||
|
||||
DTString = &DataType{
|
||||
Name: "string",
|
||||
Valuer: ValuerString,
|
||||
SQLValue: SQLValueString,
|
||||
}
|
||||
|
||||
DTDate = &DataType{
|
||||
Name: "date",
|
||||
Valuer: ValuerDate,
|
||||
SQLValue: SQLValueDefault,
|
||||
}
|
||||
|
||||
DTTime = &DataType{
|
||||
Name: "time",
|
||||
Valuer: ValuerTime,
|
||||
SQLValue: SQLValueDefault,
|
||||
}
|
||||
|
||||
DTDateTime = &DataType{
|
||||
Name: "datetime",
|
||||
Valuer: ValuerDatetime,
|
||||
SQLValue: SQLValueDefault,
|
||||
}
|
||||
|
||||
DTBytes = &DataType{
|
||||
Name: "bytes",
|
||||
Valuer: ValuerBytes,
|
||||
SQLValue: SQLValueDefault,
|
||||
}
|
||||
)
|
||||
|
||||
// Valuer 获取值对应的处理者,用于sql row scan、解析value等
|
||||
type Valuer interface {
|
||||
|
||||
// NewValuePtr 新建值对应的指针,用于sql的row scan
|
||||
NewValuePtr() any
|
||||
|
||||
// Value 获取对应的值(人类可阅读的值),不可原样返回ValuePtr指针类型,需取出具体的值
|
||||
Value() any
|
||||
}
|
||||
|
||||
type DefaultValuer[T any] struct {
|
||||
ValuePtr *T
|
||||
}
|
||||
|
||||
func (s *DefaultValuer[T]) NewValuePtr() any {
|
||||
var t T
|
||||
s.ValuePtr = &t
|
||||
return s.ValuePtr
|
||||
}
|
||||
|
||||
// Valuer工厂函数
|
||||
|
||||
func ValuerString() Valuer {
|
||||
return &stringValuer{
|
||||
DefaultValuer: new(DefaultValuer[sql.NullString]),
|
||||
}
|
||||
}
|
||||
|
||||
func ValuerInt64() Valuer {
|
||||
return &int64Valuer{
|
||||
DefaultValuer: new(DefaultValuer[sql.NullInt64]),
|
||||
}
|
||||
}
|
||||
|
||||
func ValuerUint64() Valuer {
|
||||
return &uint64Valuer{
|
||||
DefaultValuer: new(DefaultValuer[[]byte]),
|
||||
}
|
||||
}
|
||||
|
||||
func ValuerInt32() Valuer {
|
||||
return &int32Valuer{
|
||||
DefaultValuer: new(DefaultValuer[sql.NullInt32]),
|
||||
}
|
||||
}
|
||||
|
||||
func ValuerInt16() Valuer {
|
||||
return &int16Valuer{
|
||||
DefaultValuer: new(DefaultValuer[sql.NullInt16]),
|
||||
}
|
||||
}
|
||||
|
||||
func ValuerByte() Valuer {
|
||||
return &byteValuer{
|
||||
DefaultValuer: new(DefaultValuer[sql.NullByte]),
|
||||
}
|
||||
}
|
||||
|
||||
func ValuerBit() Valuer {
|
||||
return &bitValuer{
|
||||
DefaultValuer: new(DefaultValuer[[]byte]),
|
||||
}
|
||||
}
|
||||
|
||||
func ValuerFloat64() Valuer {
|
||||
return &float64Valuer{
|
||||
DefaultValuer: new(DefaultValuer[sql.NullFloat64]),
|
||||
}
|
||||
}
|
||||
|
||||
func ValuerDatetime() Valuer {
|
||||
return &datetimeValuer{
|
||||
DefaultValuer: new(DefaultValuer[NullTime]),
|
||||
}
|
||||
}
|
||||
|
||||
func ValuerDate() Valuer {
|
||||
return &dateValuer{
|
||||
DefaultValuer: new(DefaultValuer[NullTime]),
|
||||
}
|
||||
}
|
||||
|
||||
func ValuerTime() Valuer {
|
||||
return &timeValuer{
|
||||
DefaultValuer: new(DefaultValuer[NullTime]),
|
||||
}
|
||||
}
|
||||
|
||||
func ValuerBytes() Valuer {
|
||||
return &bytesValuer{
|
||||
DefaultValuer: new(DefaultValuer[sql.RawBytes]),
|
||||
}
|
||||
}
|
||||
|
||||
// 默认 valuer
|
||||
|
||||
// string
|
||||
|
||||
type stringValuer struct {
|
||||
*DefaultValuer[sql.NullString]
|
||||
}
|
||||
|
||||
func (s *stringValuer) Value() any {
|
||||
if s.ValuePtr.Valid {
|
||||
return s.ValuePtr.String
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// uint64
|
||||
|
||||
type uint64Valuer struct {
|
||||
*DefaultValuer[[]byte]
|
||||
}
|
||||
|
||||
func (s *uint64Valuer) Value() any {
|
||||
valBytes := *s.ValuePtr
|
||||
if valBytes == nil {
|
||||
return nil
|
||||
}
|
||||
val := string(valBytes)
|
||||
// 前端超过16位会丢失精度
|
||||
if len(val) > 16 {
|
||||
return val
|
||||
}
|
||||
return cast.ToUint64(val)
|
||||
}
|
||||
|
||||
// int64
|
||||
|
||||
type int64Valuer struct {
|
||||
*DefaultValuer[sql.NullInt64]
|
||||
}
|
||||
|
||||
func (s *int64Valuer) Value() any {
|
||||
if s.ValuePtr.Valid {
|
||||
val := s.ValuePtr.Int64
|
||||
// 前端超过16位会丢失精度
|
||||
if val > 9999999999999999 {
|
||||
return fmt.Sprintf("%d", val)
|
||||
}
|
||||
return val
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// int32
|
||||
|
||||
type int32Valuer struct {
|
||||
*DefaultValuer[sql.NullInt32]
|
||||
}
|
||||
|
||||
func (s *int32Valuer) Value() any {
|
||||
if s.ValuePtr.Valid {
|
||||
return s.ValuePtr.Int32
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// int16
|
||||
|
||||
type int16Valuer struct {
|
||||
*DefaultValuer[sql.NullInt16]
|
||||
}
|
||||
|
||||
func (s *int16Valuer) Value() any {
|
||||
if s.ValuePtr.Valid {
|
||||
return s.ValuePtr.Int16
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// byte(uint8)
|
||||
|
||||
type byteValuer struct {
|
||||
*DefaultValuer[sql.NullByte]
|
||||
}
|
||||
|
||||
func (s *byteValuer) Value() any {
|
||||
if s.ValuePtr.Valid {
|
||||
return s.ValuePtr.Byte
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// bit
|
||||
|
||||
type bitValuer struct {
|
||||
*DefaultValuer[[]byte]
|
||||
}
|
||||
|
||||
func (s *bitValuer) Value() any {
|
||||
valBytes := *s.ValuePtr
|
||||
if valBytes == nil {
|
||||
return nil
|
||||
}
|
||||
return valBytes[0]
|
||||
}
|
||||
|
||||
// float64
|
||||
|
||||
type float64Valuer struct {
|
||||
*DefaultValuer[sql.NullFloat64]
|
||||
}
|
||||
|
||||
func (s *float64Valuer) Value() any {
|
||||
if s.ValuePtr.Valid {
|
||||
return s.ValuePtr.Float64
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// bytes
|
||||
|
||||
type bytesValuer struct {
|
||||
*DefaultValuer[sql.RawBytes]
|
||||
}
|
||||
|
||||
func (s *bytesValuer) Value() any {
|
||||
val := s.ValuePtr
|
||||
if *val == nil {
|
||||
return nil
|
||||
}
|
||||
return hex.EncodeToString(*val)
|
||||
}
|
||||
|
||||
// datetime
|
||||
|
||||
type datetimeValuer struct {
|
||||
*DefaultValuer[NullTime]
|
||||
}
|
||||
|
||||
func (s *datetimeValuer) NewValuePtr() any {
|
||||
s.ValuePtr = &NullTime{
|
||||
Layout: time.DateTime,
|
||||
}
|
||||
return s.ValuePtr
|
||||
}
|
||||
|
||||
func (s *datetimeValuer) Value() any {
|
||||
if s.ValuePtr.Valid {
|
||||
return s.ValuePtr.Time
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// date
|
||||
|
||||
type dateValuer struct {
|
||||
*DefaultValuer[NullTime]
|
||||
}
|
||||
|
||||
func (s *dateValuer) NewValuePtr() any {
|
||||
s.ValuePtr = &NullTime{
|
||||
Layout: time.DateOnly,
|
||||
}
|
||||
return s.ValuePtr
|
||||
}
|
||||
|
||||
func (s *dateValuer) Value() any {
|
||||
if s.ValuePtr.Valid {
|
||||
return s.ValuePtr.Time
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// time
|
||||
|
||||
type timeValuer struct {
|
||||
*DefaultValuer[NullTime]
|
||||
}
|
||||
|
||||
func (s *timeValuer) NewValuePtr() any {
|
||||
s.ValuePtr = &NullTime{
|
||||
Layout: time.TimeOnly,
|
||||
}
|
||||
return s.ValuePtr
|
||||
}
|
||||
|
||||
func (s *timeValuer) Value() any {
|
||||
if s.ValuePtr.Valid {
|
||||
return s.ValuePtr.Time
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NullTime represents a time that may be null.
|
||||
// NullTime implements the [Scanner] interface so
|
||||
// it can be used as a scan destination, similar to [NullString].
|
||||
type NullTime struct {
|
||||
Time string
|
||||
Valid bool // Valid is true if Time is not NULL
|
||||
Layout string
|
||||
}
|
||||
|
||||
var (
|
||||
_ driver.Valuer = NullTime{}
|
||||
)
|
||||
|
||||
// Scan implements the [Scanner] interface.
|
||||
func (n *NullTime) Scan(value any) error {
|
||||
if value == nil {
|
||||
n.Time, n.Valid = "", false
|
||||
return nil
|
||||
}
|
||||
|
||||
n.Valid = true
|
||||
time, err := convertTime(value, n.Layout)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
n.Time = time
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface.
|
||||
func (n NullTime) Value() (driver.Value, error) {
|
||||
if !n.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return n.Time, nil
|
||||
}
|
||||
|
||||
func convertTime(src interface{}, layout string) (string, error) {
|
||||
switch s := src.(type) {
|
||||
case string:
|
||||
return s, nil
|
||||
case []uint8:
|
||||
return string(s), nil
|
||||
case time.Time:
|
||||
return s.Format(layout), nil
|
||||
default:
|
||||
return "", nil
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"mayfly-go/internal/machine/mcm"
|
||||
"mayfly-go/pkg/errorx"
|
||||
"mayfly-go/pkg/logx"
|
||||
@@ -23,19 +24,33 @@ type DbConn struct {
|
||||
// 执行数据库查询返回的列信息
|
||||
type QueryColumn struct {
|
||||
Name string `json:"name"` // 列名
|
||||
Type string `json:"type"` // 类型
|
||||
Type string `json:"type"` // 数据类型
|
||||
|
||||
SqlColType *sql.ColumnType `json:"-"`
|
||||
DbDataType *DbDataType `json:"-"`
|
||||
valuer Valuer `json:"-"`
|
||||
}
|
||||
|
||||
func NewQueryColumn(colName string, col *sql.ColumnType) *QueryColumn {
|
||||
func NewQueryColumn(colName string, columnType *DbDataType) *QueryColumn {
|
||||
return &QueryColumn{
|
||||
Name: col.Name(),
|
||||
Type: col.DatabaseTypeName(),
|
||||
SqlColType: col,
|
||||
Name: colName,
|
||||
Type: columnType.DataType.Name,
|
||||
DbDataType: columnType,
|
||||
valuer: columnType.DataType.Valuer(),
|
||||
}
|
||||
}
|
||||
|
||||
func (qc *QueryColumn) getValuePtr() any {
|
||||
return qc.valuer.NewValuePtr()
|
||||
}
|
||||
|
||||
func (qc *QueryColumn) value() any {
|
||||
return qc.valuer.Value()
|
||||
}
|
||||
|
||||
func (qc *QueryColumn) SQLValue(val any) any {
|
||||
return qc.DbDataType.DataType.SQLValue(val)
|
||||
}
|
||||
|
||||
func (d *DbConn) GetDb() *sql.DB {
|
||||
return d.db
|
||||
}
|
||||
@@ -80,7 +95,7 @@ func (d *DbConn) Query2Struct(execSql string, dest any) error {
|
||||
|
||||
// WalkQueryRows 游标方式遍历查询结果集, walkFn返回error不为nil, 则跳出遍历并取消查询
|
||||
func (d *DbConn) WalkQueryRows(ctx context.Context, querySql string, walkFn WalkQueryRowsFunc, args ...any) ([]*QueryColumn, error) {
|
||||
return walkQueryRows(ctx, d.GetDialect(), d.db, querySql, walkFn, args...)
|
||||
return d.walkQueryRows(ctx, querySql, walkFn, args...)
|
||||
}
|
||||
|
||||
// WalkTableRows 游标方式遍历指定表的结果集, walkFn返回error不为nil, 则跳出遍历并取消查询
|
||||
@@ -138,6 +153,11 @@ func (d *DbConn) GetMetadata() Metadata {
|
||||
return d.Info.Meta.GetMetadata(d)
|
||||
}
|
||||
|
||||
// GetDbDataType 获取定义的数据库数据类型
|
||||
func (d *DbConn) GetDbDataType(dataType string) *DbDataType {
|
||||
return GetDbDataType(d.Info.Type, dataType)
|
||||
}
|
||||
|
||||
// Stats 返回数据库连接状态
|
||||
func (d *DbConn) Stats(ctx context.Context, execSql string, args ...any) sql.DBStats {
|
||||
return d.db.Stats()
|
||||
@@ -149,8 +169,8 @@ func (d *DbConn) Close() {
|
||||
if err := d.db.Close(); err != nil {
|
||||
logx.Errorf("关闭数据库实例[%s]连接失败: %s", d.Id, err.Error())
|
||||
}
|
||||
// 如果是达梦并且使用了ssh隧道,则需要手动将其关闭
|
||||
if d.Info.Type == DbTypeDM && d.Info.SshTunnelMachineId > 0 {
|
||||
// 如果是使用了自己实现的ssh隧道转发,则需要手动将其关闭
|
||||
if d.Info.useSshTunnel {
|
||||
mcm.CloseSshTunnelMachine(d.Info.SshTunnelMachineId, fmt.Sprintf("db:%d", d.Info.Id))
|
||||
}
|
||||
d.db = nil
|
||||
@@ -158,11 +178,11 @@ func (d *DbConn) Close() {
|
||||
}
|
||||
|
||||
// 游标方式遍历查询rows, walkFn error不为nil, 则跳出遍历
|
||||
func walkQueryRows(ctx context.Context, dialect Dialect, db *sql.DB, selectSql string, walkFn WalkQueryRowsFunc, args ...any) ([]*QueryColumn, error) {
|
||||
func (d *DbConn) walkQueryRows(ctx context.Context, selectSql string, walkFn WalkQueryRowsFunc, args ...any) ([]*QueryColumn, error) {
|
||||
cancelCtx, cancelFunc := context.WithCancel(ctx)
|
||||
defer cancelFunc()
|
||||
|
||||
rows, err := db.QueryContext(cancelCtx, selectSql, args...)
|
||||
rows, err := d.db.QueryContext(cancelCtx, selectSql, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -170,8 +190,6 @@ func walkQueryRows(ctx context.Context, dialect Dialect, db *sql.DB, selectSql s
|
||||
// 后面的链接过来直接报错或拒绝,实际上也没有起效果
|
||||
defer rows.Close()
|
||||
|
||||
columnHelper := dialect.GetColumnHelper()
|
||||
|
||||
colTypes, err := rows.ColumnTypes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -188,9 +206,9 @@ func walkQueryRows(ctx context.Context, dialect Dialect, db *sql.DB, selectSql s
|
||||
if colName == "" {
|
||||
colName = fmt.Sprintf("<anonymous%d>", k+1)
|
||||
}
|
||||
qc := NewQueryColumn(colName, colType)
|
||||
qc := NewQueryColumn(colName, d.GetDbDataType(colType.DatabaseTypeName()))
|
||||
cols[k] = qc
|
||||
scans[k] = columnHelper.GetScanDestPtr(qc)
|
||||
scans[k] = qc.getValuePtr()
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
@@ -201,8 +219,8 @@ func walkQueryRows(ctx context.Context, dialect Dialect, db *sql.DB, selectSql s
|
||||
// 每行数据
|
||||
rowData := make(map[string]any, lenCols)
|
||||
// 把values中的数据复制到row中
|
||||
for i, v := range scans {
|
||||
rowData[cols[i].Name] = columnHelper.ConvertScanDestValue(v, cols[i])
|
||||
for i := range scans {
|
||||
rowData[cols[i].Name] = cols[i].value()
|
||||
}
|
||||
if err = walkFn(rowData, cols); err != nil {
|
||||
logx.ErrorfContext(ctx, "[%s] cursor traversal query result set error, exit traversal: %s", selectSql, err.Error())
|
||||
|
||||
@@ -11,19 +11,6 @@ import (
|
||||
|
||||
type DbType string
|
||||
|
||||
const (
|
||||
DbTypeMysql DbType = "mysql"
|
||||
DbTypeMariadb DbType = "mariadb"
|
||||
DbTypePostgres DbType = "postgres"
|
||||
DbTypeGauss DbType = "gauss"
|
||||
DbTypeDM DbType = "dm"
|
||||
DbTypeOracle DbType = "oracle"
|
||||
DbTypeSqlite DbType = "sqlite"
|
||||
DbTypeMssql DbType = "mssql"
|
||||
DbTypeKingbaseEs DbType = "kingbaseEs"
|
||||
DbTypeVastbase DbType = "vastbase"
|
||||
)
|
||||
|
||||
func ToDbType(dbType string) DbType {
|
||||
return DbType(dbType)
|
||||
}
|
||||
@@ -52,6 +39,7 @@ type DbInfo struct {
|
||||
|
||||
CodePath []string
|
||||
SshTunnelMachineId int
|
||||
useSshTunnel bool // 是否使用系统自己实现的ssh隧道连接,而非库自带的
|
||||
|
||||
Meta Meta
|
||||
}
|
||||
@@ -116,6 +104,7 @@ func (di *DbInfo) IfUseSshTunnelChangeIpPort() error {
|
||||
}
|
||||
di.Host = exposedIp
|
||||
di.Port = exposedPort
|
||||
di.useSshTunnel = true
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,22 +1,12 @@
|
||||
package dbi
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"io"
|
||||
"mayfly-go/internal/db/dbm/sqlparser"
|
||||
"mayfly-go/internal/db/dbm/sqlparser/pgsql"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
pq "gitee.com/liuzongyang/libpq"
|
||||
"github.com/may-fly/cast"
|
||||
)
|
||||
|
||||
const DefaultQuoter = `"`
|
||||
|
||||
const (
|
||||
// -1. 无操作
|
||||
DuplicateStrategyNone = -1
|
||||
@@ -36,47 +26,15 @@ type DbCopyTable struct {
|
||||
// BaseDialect 基础dialect,在DefaultDialect 都有默认的实现方法
|
||||
type BaseDialect interface {
|
||||
|
||||
// QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be
|
||||
// used as part of an SQL statement. For example:
|
||||
//
|
||||
// tblname := "my_table"
|
||||
// data := "my_data"
|
||||
// quoted := quoteIdentifier(tblname, '"')
|
||||
// err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data)
|
||||
//
|
||||
// Any double quotes in name will be escaped. The quoted identifier will be
|
||||
// case sensitive when used in a query. If the input string contains a zero
|
||||
// byte, the result will be truncated immediately before it.
|
||||
QuoteIdentifier(name string) string
|
||||
|
||||
RemoveQuote(name string) string
|
||||
|
||||
// QuoteEscape 引号转义,多用于sql注释转义,防止拼接sql报错,如: comment xx is '注''释' 最终注释文本为: 注'释
|
||||
QuoteEscape(str string) string
|
||||
|
||||
// QuoteLiteral quotes a 'literal' (e.g. a parameter, often used to pass literal
|
||||
// to DDL and other statements that do not accept parameters) to be used as part
|
||||
// of an SQL statement. For example:
|
||||
//
|
||||
// exp_date := pq.QuoteLiteral("2023-01-05 15:00:00Z")
|
||||
// err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", exp_date))
|
||||
//
|
||||
// Any single quotes in name will be escaped. Any backslashes (i.e. "\") will be
|
||||
// replaced by two backslashes (i.e. "\\") and the C-style escape identifier
|
||||
QuoteLiteral(literal string) string
|
||||
// Quoter sql关键字引用处理,如 table -> `table`、table -> "table"
|
||||
Quoter() Quoter
|
||||
|
||||
// GetDbProgram 获取数据库程序模块,用于数据库备份与恢复
|
||||
GetDbProgram() (DbProgram, error)
|
||||
|
||||
// GetColumnHelper
|
||||
GetColumnHelper() ColumnHelper
|
||||
|
||||
// GetDumpHeler
|
||||
GetDumpHelper() DumpHelper
|
||||
|
||||
// GetDataHelper 获取数据处理助手 用于解析格式化列数据等
|
||||
GetDataHelper() DataHelper
|
||||
|
||||
// GetSQLParser 获取sql解析器
|
||||
GetSQLParser() sqlparser.SqlParser
|
||||
}
|
||||
@@ -86,20 +44,11 @@ type BaseDialect interface {
|
||||
type Dialect interface {
|
||||
BaseDialect
|
||||
|
||||
// BatchInsert 批量insert数据
|
||||
BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any, duplicateStrategy int) (int64, error)
|
||||
|
||||
// CopyTable 拷贝表
|
||||
CopyTable(copy *DbCopyTable) error
|
||||
|
||||
// GenerateTableDDL 生成建表ddl
|
||||
GenerateTableDDL(columns []Column, tableInfo Table, dropBeforeCreate bool) []string
|
||||
|
||||
// GenerateIndexDDL 生成索引ddl
|
||||
GenerateIndexDDL(indexs []Index, tableInfo Table) []string
|
||||
|
||||
// UpdateSequence 有些数据库迁移完数据之后,需要更新表自增序列为当前表最大值
|
||||
UpdateSequence(tableName string, columns []Column)
|
||||
// GetSQLGenerator 获取sql生成器
|
||||
GetSQLGenerator() SQLGenerator
|
||||
}
|
||||
|
||||
// DefaultDialect 默认实现,若需要覆盖,则由各个数据库dialect实现去覆盖重写
|
||||
@@ -108,29 +57,10 @@ type DefaultDialect struct {
|
||||
|
||||
var _ (BaseDialect) = (*DefaultDialect)(nil)
|
||||
|
||||
func (dx *DefaultDialect) QuoteIdentifier(name string) string {
|
||||
end := strings.IndexRune(name, 0)
|
||||
if end > -1 {
|
||||
name = name[:end]
|
||||
}
|
||||
return DefaultQuoter + strings.Replace(name, DefaultQuoter, DefaultQuoter+DefaultQuoter, -1) + DefaultQuoter
|
||||
func (dx *DefaultDialect) Quoter() Quoter {
|
||||
return DefaultQuoter
|
||||
}
|
||||
|
||||
func (dx *DefaultDialect) RemoveQuote(name string) string {
|
||||
return strings.ReplaceAll(name, DefaultQuoter, "")
|
||||
}
|
||||
|
||||
func (dd *DefaultDialect) QuoteEscape(str string) string {
|
||||
return strings.Replace(str, `'`, `''`, -1)
|
||||
}
|
||||
|
||||
func (dd *DefaultDialect) QuoteLiteral(literal string) string {
|
||||
return pq.QuoteLiteral(literal)
|
||||
}
|
||||
|
||||
func (dd *DefaultDialect) UpdateSequence(tableName string, columns []Column) {}
|
||||
|
||||
// GetDbProgram 获取数据库程序模块,用于数据库备份与恢复
|
||||
func (dd *DefaultDialect) GetDbProgram() (DbProgram, error) {
|
||||
return nil, errors.New("not support db program")
|
||||
}
|
||||
@@ -139,113 +69,10 @@ func (dd *DefaultDialect) GetDumpHelper() DumpHelper {
|
||||
return new(DefaultDumpHelper)
|
||||
}
|
||||
|
||||
func (dd *DefaultDialect) GetColumnHelper() ColumnHelper {
|
||||
return new(DefaultColumnHelper)
|
||||
}
|
||||
|
||||
func (pd *DefaultDialect) GetSQLParser() sqlparser.SqlParser {
|
||||
return new(pgsql.PgsqlParser)
|
||||
}
|
||||
|
||||
func (pd *DefaultDialect) GetDataHelper() DataHelper {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ColumnHelper 数据库迁移辅助方法
|
||||
type ColumnHelper interface {
|
||||
// ToCommonColumn 数据库方言自带的列转换为公共列
|
||||
ToCommonColumn(dialectColumn *Column)
|
||||
|
||||
// ToColumn 公共列转为各个数据库方言自带的列
|
||||
ToColumn(commonColumn *Column)
|
||||
|
||||
// FixColumn 根据数据库类型修复字段长度、精度等
|
||||
FixColumn(column *Column)
|
||||
|
||||
// GetScanDestPtr 获取scan列目标值指针,用于在*sql.Rows.Scan()填充该值
|
||||
GetScanDestPtr(*QueryColumn) any
|
||||
|
||||
// ConvertScanDestValue 将scan的填充的原始值data转为可阅读的值,如[]byte -> string or number...
|
||||
ConvertScanDestValue(data any, qc *QueryColumn) any
|
||||
}
|
||||
|
||||
type DefaultColumnHelper struct {
|
||||
}
|
||||
|
||||
func (dd *DefaultColumnHelper) ToCommonColumn(dialectColumn *Column) {}
|
||||
|
||||
func (dd *DefaultColumnHelper) ToColumn(commonColumn *Column) {}
|
||||
|
||||
func (dd *DefaultColumnHelper) FixColumn(column *Column) {}
|
||||
|
||||
func (dd *DefaultColumnHelper) GetScanDestPtr(qc *QueryColumn) any {
|
||||
return &[]byte{}
|
||||
}
|
||||
|
||||
func (dd *DefaultColumnHelper) ConvertScanDestValue(data any, qc *QueryColumn) any {
|
||||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
colType := qc.SqlColType
|
||||
// 列的数据库类型名
|
||||
colDatabaseTypeName := strings.ToLower(colType.DatabaseTypeName())
|
||||
|
||||
stringV := ""
|
||||
if slicePtr, ok := data.(*[]uint8); ok {
|
||||
bytes := *slicePtr
|
||||
if bytes == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 如果类型是bit,则直接返回第一个字节即可
|
||||
if strings.Contains(colDatabaseTypeName, "bit") {
|
||||
return (bytes)[0]
|
||||
}
|
||||
|
||||
if colDatabaseTypeName == "blob" {
|
||||
return hex.EncodeToString(bytes)
|
||||
}
|
||||
// 把[]byte数据转成string
|
||||
stringV = string(bytes)
|
||||
} else {
|
||||
stringV = cast.ToString(data)
|
||||
}
|
||||
|
||||
if colType == nil || colType.ScanType() == nil {
|
||||
return stringV
|
||||
}
|
||||
colScanType := strings.ToLower(colType.ScanType().Name())
|
||||
|
||||
if strings.Contains(colScanType, "int") {
|
||||
// 如果长度超过16位,则返回字符串,因为前端js长度大于16会丢失精度
|
||||
if len(stringV) > 16 {
|
||||
return stringV
|
||||
}
|
||||
intV, _ := strconv.Atoi(stringV)
|
||||
switch colType.ScanType().Kind() {
|
||||
case reflect.Int8:
|
||||
return int8(intV)
|
||||
case reflect.Uint8:
|
||||
return uint8(intV)
|
||||
case reflect.Int64:
|
||||
return int64(intV)
|
||||
case reflect.Uint64:
|
||||
return uint64(intV)
|
||||
case reflect.Uint:
|
||||
return uint(intV)
|
||||
default:
|
||||
return intV
|
||||
}
|
||||
}
|
||||
if strings.Contains(colScanType, "float") || strings.Contains(colDatabaseTypeName, "decimal") {
|
||||
floatV, _ := strconv.ParseFloat(stringV, 64)
|
||||
return floatV
|
||||
}
|
||||
|
||||
return stringV
|
||||
}
|
||||
|
||||
// DumpHelper 导出辅助方法
|
||||
type DumpHelper interface {
|
||||
BeforeInsert(writer io.Writer, tableName string)
|
||||
@@ -269,3 +96,14 @@ func (dd *DefaultDumpHelper) BeforeInsertSql(quoteSchema string, quoteTableName
|
||||
func (dd *DefaultDumpHelper) AfterInsert(writer io.Writer, tableName string, columns []Column) {
|
||||
writer.Write([]byte("COMMIT;\n"))
|
||||
}
|
||||
|
||||
type SQLGenerator interface {
|
||||
// GenTableDDL 生成建表语句
|
||||
GenTableDDL(table Table, columns []Column, dropBeforeCreate bool) []string
|
||||
|
||||
// GenIndexDDL 生成索引语句
|
||||
GenIndexDDL(table Table, indexs []Index) []string
|
||||
|
||||
// GenInsert 生成插入语句
|
||||
GenInsert(tableName string, columns []Column, values [][]any, duplicateStrategy int) []string
|
||||
}
|
||||
|
||||
@@ -5,30 +5,59 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
metas = make(map[DbType]Meta)
|
||||
metas = make(map[DbType]Meta)
|
||||
metaInit = make(map[DbType]bool)
|
||||
)
|
||||
|
||||
// 注册数据库类型与dbmeta
|
||||
func Register(dt DbType, meta Meta) {
|
||||
metas[dt] = meta
|
||||
}
|
||||
|
||||
// 根据数据库类型获取对应的Meta
|
||||
func GetMeta(dt DbType) Meta {
|
||||
return metas[dt]
|
||||
}
|
||||
|
||||
type DbVersion string
|
||||
|
||||
// 数据库元信息,如获取sql.DB、Dialect等
|
||||
// Meta 数据库元信息,如获取sql.DB、Dialect等
|
||||
type Meta interface {
|
||||
// GetSqlDb 根据数据库信息获取sql.DB
|
||||
GetSqlDb(*DbInfo) (*sql.DB, error)
|
||||
|
||||
// GetDialect 获取数据库方言, 若一些接口(如QuoteIdentifier)不需要DbConn,则可以传nil
|
||||
// GetDialect 获取数据库方言, 若一些接口不需要DbConn,则可以传nil
|
||||
GetDialect(*DbConn) Dialect
|
||||
|
||||
// GetMetadata 获取元数据信息接口
|
||||
// @param *DbConn 数据库连接
|
||||
GetMetadata(*DbConn) Metadata
|
||||
|
||||
// GetDbDataTypes 获取所有数据库对应的数据类型
|
||||
GetDbDataTypes() []*DbDataType
|
||||
|
||||
// GetCommonTypeConverter 获取公共类型转换器,用于迁移与同步
|
||||
GetCommonTypeConverter() CommonTypeConverter
|
||||
}
|
||||
|
||||
// 注册数据库类型与dbmeta
|
||||
func Register(dt DbType, meta Meta) {
|
||||
metas[dt] = meta
|
||||
metaInit[dt] = false
|
||||
}
|
||||
|
||||
// 根据数据库类型获取对应的Meta
|
||||
func GetMeta(dt DbType) Meta {
|
||||
// 未初始化,则进行初始化,如注册数据库类型等。防止未使用到的数据库都被注册
|
||||
if inited := metaInit[dt]; !inited {
|
||||
initMeta(dt, metas[dt])
|
||||
}
|
||||
return metas[dt]
|
||||
}
|
||||
|
||||
// GetDialect 获取数据库方言,如果dialect方法内需要用到dbConn的,则不支持该方法
|
||||
func GetDialect(dt DbType) Dialect {
|
||||
// 创建一个假连接,仅用于调用方言生成sql,不做数据库连接操作
|
||||
meta := GetMeta(dt)
|
||||
dbConn := &DbConn{Info: &DbInfo{
|
||||
Type: dt,
|
||||
Meta: meta,
|
||||
}}
|
||||
return meta.GetDialect(dbConn)
|
||||
}
|
||||
|
||||
// initMeta 初始化数据库类型,如注册数据库类型等
|
||||
func initMeta(dt DbType, meta Meta) {
|
||||
registerColumnDbDataTypes(dt, meta.GetDbDataTypes()...)
|
||||
registerCommonTypeConverter(dt, meta.GetCommonTypeConverter())
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package dbi
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"fmt"
|
||||
"mayfly-go/pkg/biz"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"mayfly-go/pkg/utils/stringx"
|
||||
@@ -55,9 +54,6 @@ func (dd *DefaultMetadata) GetDefaultDb() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// GenerateSQLStepFunc 生成insert sql的step函数,用于生成insert sql时,每生成100条sql时调用
|
||||
type GenerateSQLStepFunc func(sqlArr []string)
|
||||
|
||||
// 数据库服务实例信息
|
||||
type DbServer struct {
|
||||
Version string `json:"version"` // 版本信息
|
||||
@@ -74,115 +70,16 @@ type Table struct {
|
||||
IndexLength int64 `json:"indexLength"`
|
||||
}
|
||||
|
||||
// 表的列信息
|
||||
type Column struct {
|
||||
TableName string `json:"tableName"` // 表名
|
||||
ColumnName string `json:"columnName"` // 列名
|
||||
DataType ColumnDataType `json:"dataType"` // 数据类型
|
||||
ColumnComment string `json:"columnComment"` // 列备注
|
||||
IsPrimaryKey bool `json:"isPrimaryKey"` // 是否为主键
|
||||
IsIdentity bool `json:"isIdentity"` // 是否自增
|
||||
ColumnDefault string `json:"columnDefault"` // 默认值
|
||||
Nullable bool `json:"nullable"` // 是否可为null
|
||||
CharMaxLength int `json:"charMaxLength"` // 字符最大长度
|
||||
NumPrecision int `json:"numPrecision"` // 精度(总数字位数)
|
||||
NumScale int `json:"numScale"` // 小数点位数
|
||||
Extra collx.M `json:"extra"` // 其他额外信息
|
||||
}
|
||||
|
||||
// 拼接数据类型与长度等。如varchar(2000),decimal(20,2)
|
||||
func (c *Column) GetColumnType() string {
|
||||
// 哪些mysql数据类型不需要添加字段长度
|
||||
if collx.ArrayAnyMatches([]string{"int", "blob", "float", "double", "date", "year", "json"}, string(c.DataType)) {
|
||||
return string(c.DataType)
|
||||
}
|
||||
if c.DataType == "timestamp" {
|
||||
return "timestamp(6)"
|
||||
}
|
||||
if c.CharMaxLength > 0 {
|
||||
return fmt.Sprintf("%s(%d)", c.DataType, c.CharMaxLength)
|
||||
}
|
||||
if c.NumPrecision > 0 {
|
||||
if c.NumScale > 0 {
|
||||
return fmt.Sprintf("%s(%d,%d)", c.DataType, c.NumPrecision, c.NumScale)
|
||||
} else {
|
||||
return fmt.Sprintf("%s(%d)", c.DataType, c.NumPrecision)
|
||||
}
|
||||
}
|
||||
|
||||
return string(c.DataType)
|
||||
}
|
||||
|
||||
// 表索引信息
|
||||
type Index struct {
|
||||
IndexName string `json:"indexName"` // 索引名
|
||||
ColumnName string `json:"columnName"` // 列名
|
||||
IndexType string `json:"indexType"` // 索引类型
|
||||
IndexComment string `json:"indexComment"` // 备注
|
||||
SeqInIndex int `json:"seqInIndex"`
|
||||
IsUnique bool `json:"isUnique"`
|
||||
IsPrimaryKey bool `json:"isPrimaryKey"` // 是否是主键索引,某些情况需要判断并过滤掉主键索引
|
||||
}
|
||||
|
||||
type ColumnDataType string
|
||||
|
||||
const (
|
||||
CommonTypeVarchar ColumnDataType = "varchar"
|
||||
CommonTypeChar ColumnDataType = "char"
|
||||
CommonTypeText ColumnDataType = "text"
|
||||
CommonTypeBlob ColumnDataType = "blob"
|
||||
CommonTypeLongblob ColumnDataType = "longblob"
|
||||
CommonTypeLongtext ColumnDataType = "longtext"
|
||||
CommonTypeBinary ColumnDataType = "binary"
|
||||
CommonTypeMediumblob ColumnDataType = "mediumblob"
|
||||
CommonTypeMediumtext ColumnDataType = "mediumtext"
|
||||
CommonTypeVarbinary ColumnDataType = "varbinary"
|
||||
|
||||
CommonTypeInt ColumnDataType = "int"
|
||||
CommonTypeBit ColumnDataType = "bit"
|
||||
CommonTypeSmallint ColumnDataType = "smallint"
|
||||
CommonTypeTinyint ColumnDataType = "tinyint"
|
||||
CommonTypeNumber ColumnDataType = "number"
|
||||
CommonTypeBigint ColumnDataType = "bigint"
|
||||
|
||||
CommonTypeDatetime ColumnDataType = "datetime"
|
||||
CommonTypeDate ColumnDataType = "date"
|
||||
CommonTypeTime ColumnDataType = "time"
|
||||
CommonTypeTimestamp ColumnDataType = "timestamp"
|
||||
|
||||
CommonTypeEnum ColumnDataType = "enum"
|
||||
CommonTypeJSON ColumnDataType = "json"
|
||||
)
|
||||
|
||||
type DataType string
|
||||
|
||||
const (
|
||||
DataTypeString DataType = "string"
|
||||
DataTypeNumber DataType = "number"
|
||||
DataTypeDate DataType = "date"
|
||||
DataTypeTime DataType = "time"
|
||||
DataTypeDateTime DataType = "datetime"
|
||||
DataTypeBlob DataType = "blob"
|
||||
)
|
||||
|
||||
// 列数据处理帮助方法
|
||||
type DataHelper interface {
|
||||
// 获取数据对应的类型
|
||||
// @param dbColumnType 数据库原始列类型,如varchar等
|
||||
GetDataType(dbColumnType string) DataType
|
||||
|
||||
// 根据数据类型格式化指定数据
|
||||
FormatData(dbColumnValue any, dataType DataType) string
|
||||
|
||||
// 根据数据类型解析数据为符合要求的指定类型等
|
||||
ParseData(dbColumnValue any, dataType DataType) any
|
||||
|
||||
// WrapValue 根据数据类型包装value,如:
|
||||
// 1.数字型:不需要引号,
|
||||
// 2.文本型:需要用引号包裹,单引号需要转义,换行符转义,
|
||||
// 3.date型:需要格式化成对应的字符串,如:time:hh:mm:ss.SSS date: yyyy-mm-dd datetime:
|
||||
// 4.特殊:oracle date型需要用函数包裹:to_timestamp('%s', 'yyyy-mm-dd hh24:mi:ss')
|
||||
WrapValue(dbColumnValue any, dataType DataType) string
|
||||
IndexName string `json:"indexName"` // 索引名
|
||||
ColumnName string `json:"columnName"` // 列名
|
||||
IndexType string `json:"indexType"` // 索引类型
|
||||
IndexComment string `json:"indexComment"` // 备注
|
||||
SeqInIndex int `json:"seqInIndex"`
|
||||
IsUnique bool `json:"isUnique"`
|
||||
IsPrimaryKey bool `json:"isPrimaryKey"` // 是否是主键索引,某些情况需要判断并过滤掉主键索引
|
||||
Extra collx.M `json:"extra"` // 其他额外信息,如索引列的前缀长度等
|
||||
}
|
||||
|
||||
// ------------------------- 元数据sql操作 -------------------------
|
||||
|
||||
@@ -36,6 +36,7 @@ SELECT
|
||||
IF(non_unique, 0, 1) isUnique,
|
||||
SEQ_IN_INDEX seqInIndex,
|
||||
INDEX_COMMENT indexComment,
|
||||
SUB_PART subPart,
|
||||
index_name = 'PRIMARY' as isPrimaryKey
|
||||
FROM
|
||||
information_schema.STATISTICS
|
||||
|
||||
190
server/internal/db/dbm/dbi/quoter.go
Normal file
190
server/internal/db/dbm/dbi/quoter.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package dbi
|
||||
|
||||
import (
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Quoter represents a quoter to the SQL identifier. i.e. table name or column name
|
||||
type Quoter struct {
|
||||
Prefix byte
|
||||
Suffix byte
|
||||
IsReserved func(string) bool
|
||||
}
|
||||
|
||||
var (
|
||||
// AlwaysNoReserve always think it's not a reverse word
|
||||
AlwaysNoReserve = func(string) bool { return false }
|
||||
|
||||
// AlwaysReserve always reverse the word
|
||||
AlwaysReserve = func(string) bool { return true }
|
||||
|
||||
// DefaultQuoter represents the default quoter
|
||||
DefaultQuote byte = '"'
|
||||
|
||||
// DefaultQuoter the default quoter
|
||||
DefaultQuoter = Quoter{DefaultQuote, DefaultQuote, AlwaysReserve}
|
||||
)
|
||||
|
||||
// IsEmpty return true if no prefix and suffix
|
||||
func (q Quoter) IsEmpty() bool {
|
||||
return q.Prefix == 0 && q.Suffix == 0
|
||||
}
|
||||
|
||||
// Quote quote a string
|
||||
func (q Quoter) Quote(s string) string {
|
||||
var buf strings.Builder
|
||||
_ = q.QuoteTo(&buf, s)
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// Strings quotes a slice of string
|
||||
func (q Quoter) Quotes(s []string) []string {
|
||||
return collx.ArrayMap[string, string](s, func(val string) string {
|
||||
return q.Quote(val)
|
||||
})
|
||||
}
|
||||
|
||||
// QuoteTo quotes the identifier. if the quotes are [ and ]
|
||||
//
|
||||
// name -> [name]
|
||||
// [name] -> [name]
|
||||
// schema.name -> [schema].[name]
|
||||
// [schema].name -> [schema].[name]
|
||||
// schema.[name] -> [schema].[name]
|
||||
// name AS a -> [name] AS a
|
||||
// schema.name AS a -> [schema].[name] AS a
|
||||
func (q Quoter) QuoteTo(buf *strings.Builder, value string) error {
|
||||
var i int
|
||||
for i < len(value) {
|
||||
start := findStart(value, i)
|
||||
if start > i {
|
||||
if _, err := buf.WriteString(value[i:start]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if start == len(value) {
|
||||
return nil
|
||||
}
|
||||
|
||||
nextEnd := findWord(value, start)
|
||||
if err := q.quoteWordTo(buf, value[start:nextEnd]); err != nil {
|
||||
return err
|
||||
}
|
||||
i = nextEnd
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Trim removes quotes from s
|
||||
func (q Quoter) Trim(s string) string {
|
||||
if len(s) < 2 {
|
||||
return s
|
||||
}
|
||||
|
||||
var buf strings.Builder
|
||||
for i := 0; i < len(s); i++ {
|
||||
switch {
|
||||
case i == 0 && s[i] == q.Prefix:
|
||||
case i == len(s)-1 && s[i] == q.Suffix:
|
||||
case s[i] == q.Suffix && s[i+1] == '.':
|
||||
case s[i] == q.Prefix && s[i-1] == '.':
|
||||
default:
|
||||
buf.WriteByte(s[i])
|
||||
}
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// Join joins a slice with quoters
|
||||
func (q Quoter) Join(a []string, separator string) string {
|
||||
var b strings.Builder
|
||||
_ = q.JoinWrite(&b, a, separator)
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// JoinWrite writes quoted content to a builder
|
||||
func (q Quoter) JoinWrite(b *strings.Builder, a []string, sep string) error {
|
||||
if len(a) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
n := len(sep) * (len(a) - 1)
|
||||
for i := 0; i < len(a); i++ {
|
||||
n += len(a[i])
|
||||
}
|
||||
|
||||
b.Grow(n)
|
||||
for i, s := range a {
|
||||
if i > 0 {
|
||||
if _, err := b.WriteString(sep); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := q.QuoteTo(b, strings.TrimSpace(s)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q Quoter) quoteWordTo(buf *strings.Builder, word string) error {
|
||||
if (word[0] == q.Prefix && word[len(word)-1] == q.Suffix) ||
|
||||
q.IsEmpty() || !q.IsReserved(word) || word == "*" {
|
||||
if _, err := buf.WriteString(word); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := buf.WriteByte(q.Prefix); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := buf.WriteString(word); err != nil {
|
||||
return err
|
||||
}
|
||||
return buf.WriteByte(q.Suffix)
|
||||
}
|
||||
|
||||
func findWord(v string, start int) int {
|
||||
for j := start; j < len(v); j++ {
|
||||
switch v[j] {
|
||||
case '.', ' ':
|
||||
return j
|
||||
}
|
||||
}
|
||||
return len(v)
|
||||
}
|
||||
|
||||
func findStart(value string, start int) int {
|
||||
if value[start] == '.' {
|
||||
return start + 1
|
||||
}
|
||||
if value[start] != ' ' {
|
||||
return start
|
||||
}
|
||||
|
||||
k := -1
|
||||
for j := start; j < len(value); j++ {
|
||||
if value[j] != ' ' {
|
||||
k = j
|
||||
break
|
||||
}
|
||||
}
|
||||
if k == -1 {
|
||||
return len(value)
|
||||
}
|
||||
|
||||
if k+1 < len(value) &&
|
||||
(value[k] == 'A' || value[k] == 'a') &&
|
||||
(value[k+1] == 'S' || value[k+1] == 's') {
|
||||
k += 2
|
||||
}
|
||||
|
||||
for j := k; j < len(value); j++ {
|
||||
if value[j] != ' ' {
|
||||
return j
|
||||
}
|
||||
}
|
||||
return len(value)
|
||||
}
|
||||
106
server/internal/db/dbm/dbi/quoter_test.go
Normal file
106
server/internal/db/dbm/dbi/quoter_test.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package dbi
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestQuoteTo(t *testing.T) {
|
||||
var (
|
||||
quoter = Quoter{'[', ']', AlwaysReserve}
|
||||
kases = []struct {
|
||||
expected string
|
||||
value string
|
||||
}{
|
||||
{"[table]", "table"},
|
||||
{"[table]", "[table]"},
|
||||
{`[table].*`, `[table].*`},
|
||||
{"[schema].[table]", "schema.table"},
|
||||
{`["schema].[table"]`, `"schema.table"`},
|
||||
{"[schema].[table] AS [table]", "schema.table AS table"},
|
||||
{" [table]", " table"},
|
||||
{" [table]", " table"},
|
||||
{"[table] ", "table "},
|
||||
{"[table] ", "table "},
|
||||
{" [table] ", " table "},
|
||||
{" [table] ", " table "},
|
||||
}
|
||||
)
|
||||
|
||||
for _, v := range kases {
|
||||
t.Run(v.value, func(t *testing.T) {
|
||||
buf := &strings.Builder{}
|
||||
err := quoter.QuoteTo(buf, v.value)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, v.expected, buf.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReversedQuoteTo(t *testing.T) {
|
||||
var (
|
||||
quoter = Quoter{'[', ']', func(s string) bool {
|
||||
return s == "table"
|
||||
}}
|
||||
kases = []struct {
|
||||
expected string
|
||||
value string
|
||||
}{
|
||||
{"[table]", "table"},
|
||||
{"[table].*", `[table].*`},
|
||||
{`"table"`, `"table"`},
|
||||
{"schema.[table]", "schema.table"},
|
||||
{"[schema].[table]", `[schema].table`},
|
||||
{"schema.[table]", `schema.[table]`},
|
||||
{"[schema].[table]", `[schema].[table]`},
|
||||
{`"schema.table"`, `"schema.table"`},
|
||||
{"schema.[table] AS table1", "schema.table AS table1"},
|
||||
}
|
||||
)
|
||||
|
||||
for _, v := range kases {
|
||||
t.Run(v.value, func(t *testing.T) {
|
||||
buf := &strings.Builder{}
|
||||
quoter.QuoteTo(buf, v.value)
|
||||
assert.EqualValues(t, v.expected, buf.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJoin(t *testing.T) {
|
||||
cols := []string{"f1", "f2", "f3"}
|
||||
quoter := Quoter{'[', ']', AlwaysReserve}
|
||||
|
||||
assert.EqualValues(t, "[a],[b]", quoter.Join([]string{"a", " b"}, ","))
|
||||
|
||||
assert.EqualValues(t, "[a].*,[b].[c]", quoter.Join([]string{"a.*", " b.c"}, ","))
|
||||
|
||||
assert.EqualValues(t, "[b] [a]", quoter.Join([]string{"b a"}, ","))
|
||||
|
||||
assert.EqualValues(t, "[f1], [f2], [f3]", quoter.Join(cols, ", "))
|
||||
|
||||
quoter.IsReserved = AlwaysNoReserve
|
||||
assert.EqualValues(t, "f1, f2, f3", quoter.Join(cols, ", "))
|
||||
}
|
||||
|
||||
func TestQuotes(t *testing.T) {
|
||||
cols := []string{"f1", "f2", "t3.f3", "t4.*"}
|
||||
quoter := Quoter{'[', ']', AlwaysReserve}
|
||||
|
||||
quotedCols := quoter.Quotes(cols)
|
||||
assert.EqualValues(t, []string{"[f1]", "[f2]", "[t3].[f3]", "[t4].*"}, quotedCols)
|
||||
}
|
||||
|
||||
func TestTrim(t *testing.T) {
|
||||
kases := map[string]string{
|
||||
"[table_name]": "table_name",
|
||||
"[schema].[table_name]": "schema.table_name",
|
||||
}
|
||||
|
||||
for src, dst := range kases {
|
||||
assert.EqualValues(t, src, DefaultQuoter.Trim(src))
|
||||
assert.EqualValues(t, dst, Quoter{'[', ']', AlwaysReserve}.Trim(src))
|
||||
}
|
||||
}
|
||||
89
server/internal/db/dbm/dbi/stmt.go
Normal file
89
server/internal/db/dbm/dbi/stmt.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package dbi
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"mayfly-go/pkg/logx"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type StmtType string
|
||||
|
||||
const (
|
||||
StmtTypeSelect StmtType = "select"
|
||||
StmtTypeInsert StmtType = "insert"
|
||||
StmtTypeUpdate StmtType = "update"
|
||||
StmtTypeDelete StmtType = "delete"
|
||||
StmtTypeDDL StmtType = "ddl"
|
||||
)
|
||||
|
||||
// GenTableDDL 生成通用表DDL
|
||||
func GenTableDDL(dialect Dialect, md Metadata, tableName string, dropBeforeCreate bool) (string, error) {
|
||||
// 1.获取表信息
|
||||
tbs, err := md.GetTables(tableName)
|
||||
if len(tbs) == 0 {
|
||||
logx.Errorf("get table error: %s", tableName)
|
||||
return "", err
|
||||
}
|
||||
table := tbs[0]
|
||||
|
||||
// 2.获取列信息
|
||||
columns, err := md.GetColumns(tableName)
|
||||
if err != nil {
|
||||
logx.Errorf("get columns error: %s", tableName)
|
||||
return "", err
|
||||
}
|
||||
|
||||
sqlGenerator := dialect.GetSQLGenerator()
|
||||
|
||||
tableDDLArr := sqlGenerator.GenTableDDL(table, columns, dropBeforeCreate)
|
||||
// 3.获取索引信息
|
||||
indexs, err := md.GetTableIndex(tableName)
|
||||
if err != nil {
|
||||
logx.Errorf("get indexs error: %s", tableName)
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 组装返回
|
||||
tableDDLArr = append(tableDDLArr, sqlGenerator.GenIndexDDL(table, indexs)...)
|
||||
return strings.Join(tableDDLArr, ";\n"), nil
|
||||
}
|
||||
|
||||
// GenCommonInsert 生成通用insert sql
|
||||
//
|
||||
// insert into table_name (column1, column2, ...) values (value1, value2, ...), (value1, value2, ...), ...
|
||||
func GenCommonInsert(dialect Dialect, dbType DbType, tableName string, columns []Column, values [][]any) string {
|
||||
quote := dialect.Quoter().Quote
|
||||
columnStr, valuesStrs := GenInsertSqlColumnAndValues(dialect, dbType, columns, values)
|
||||
|
||||
// insert into table_name (column1, column2, ...) values (value1, value2, ...), (value1, value2, ...), ...
|
||||
return fmt.Sprintf("INSERT INTO %s %s VALUES \n%s", quote(tableName), columnStr, strings.Join(valuesStrs, ",\n"))
|
||||
}
|
||||
|
||||
// GenInsertSqlColumnAndValues 生成insert sql对应的 columes信息和values信息
|
||||
//
|
||||
// columnsStr -> (column1, column2, column3, ...)
|
||||
// valuesStrs -> ['(value1, value2, value3, ...)', '(value1, value2, value3, ...)', ...]
|
||||
func GenInsertSqlColumnAndValues(dialect Dialect, dbType DbType, columns []Column, values [][]any) (columnsStr string, valuesStrs []string) {
|
||||
quote := dialect.Quoter().Quote
|
||||
|
||||
columnNames := make([]string, 0, len(columns))
|
||||
columnTypes := make([]*DbDataType, len(columns))
|
||||
|
||||
strValueArr := make([]string, 0, len(values))
|
||||
|
||||
for i, column := range columns {
|
||||
columnNames = append(columnNames, quote(column.ColumnName))
|
||||
columnType := GetDbDataType(dbType, column.DataType)
|
||||
columnTypes[i] = columnType
|
||||
}
|
||||
|
||||
for _, value := range values {
|
||||
vs := make([]string, 0, len(value))
|
||||
for i, v := range value {
|
||||
vs = append(vs, columnTypes[i].DataType.SQLValue(v))
|
||||
}
|
||||
strValueArr = append(strValueArr, fmt.Sprintf("(%s)", strings.Join(vs, ", ")))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("(%s)", strings.Join(columnNames, ", ")), strValueArr
|
||||
}
|
||||
154
server/internal/db/dbm/dbi/transfer.go
Normal file
154
server/internal/db/dbm/dbi/transfer.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package dbi
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type CommonDbDataType int
|
||||
|
||||
// common column type enum
|
||||
const (
|
||||
CTVarchar CommonDbDataType = iota
|
||||
CTChar
|
||||
CTText
|
||||
CTMediumtext
|
||||
CTLongtext
|
||||
|
||||
CTBit // 1 bit
|
||||
CTInt1 // 1字节 -128~127
|
||||
CTInt2 // 2字节 -32768~32767
|
||||
CTInt4 // 4字节 -2147483648~2147483647
|
||||
CTInt8 // 8字节 -9223372036854775808~9223372036854775807
|
||||
CTNumeric
|
||||
CTDecimal
|
||||
|
||||
CTUnsignedInt8
|
||||
CTUnsignedInt4
|
||||
CTUnsignedInt2
|
||||
CTUnsignedInt1
|
||||
|
||||
CTDate
|
||||
CTTime
|
||||
CTDateTime
|
||||
CTTimestamp
|
||||
|
||||
CTBinary
|
||||
CTVarbinary
|
||||
CTMediumblob
|
||||
CTBlob
|
||||
CTLongblob
|
||||
|
||||
CTEnum
|
||||
CTJSON
|
||||
)
|
||||
|
||||
type CommonTypeConverter interface {
|
||||
Varchar(*Column) *DbDataType
|
||||
Char(*Column) *DbDataType
|
||||
Text(*Column) *DbDataType
|
||||
Mediumtext(*Column) *DbDataType
|
||||
Longtext(*Column) *DbDataType
|
||||
|
||||
Bit(*Column) *DbDataType
|
||||
Int1(*Column) *DbDataType
|
||||
Int2(*Column) *DbDataType
|
||||
Int4(*Column) *DbDataType
|
||||
Int8(*Column) *DbDataType
|
||||
Numeric(*Column) *DbDataType
|
||||
Decimal(*Column) *DbDataType
|
||||
|
||||
UnsignedInt8(*Column) *DbDataType
|
||||
UnsignedInt4(*Column) *DbDataType
|
||||
UnsignedInt2(*Column) *DbDataType
|
||||
UnsignedInt1(*Column) *DbDataType
|
||||
|
||||
Date(*Column) *DbDataType
|
||||
Time(*Column) *DbDataType
|
||||
Datetime(*Column) *DbDataType
|
||||
Timestamp(*Column) *DbDataType
|
||||
|
||||
Binary(*Column) *DbDataType
|
||||
Varbinary(*Column) *DbDataType
|
||||
Mediumblob(*Column) *DbDataType
|
||||
Blob(*Column) *DbDataType
|
||||
Longblob(*Column) *DbDataType
|
||||
|
||||
Enum(*Column) *DbDataType
|
||||
JSON(*Column) *DbDataType
|
||||
}
|
||||
|
||||
var (
|
||||
commonTypeConverters = make(map[DbType]map[CommonDbDataType]func(*Column) *DbDataType) // 公共列转换器
|
||||
)
|
||||
|
||||
// registerCommonTypeConverter 注册公共列转换器
|
||||
func registerCommonTypeConverter(dbType DbType, ctc CommonTypeConverter) {
|
||||
if ctc == nil {
|
||||
return
|
||||
}
|
||||
|
||||
cts := make(map[CommonDbDataType]func(*Column) *DbDataType)
|
||||
cts[CTVarchar] = ctc.Varchar
|
||||
cts[CTChar] = ctc.Char
|
||||
cts[CTText] = ctc.Text
|
||||
cts[CTMediumtext] = ctc.Mediumtext
|
||||
cts[CTLongtext] = ctc.Longtext
|
||||
|
||||
cts[CTBit] = ctc.Bit
|
||||
cts[CTInt1] = ctc.Int1
|
||||
cts[CTInt2] = ctc.Int2
|
||||
cts[CTInt4] = ctc.Int4
|
||||
cts[CTInt8] = ctc.Int8
|
||||
cts[CTNumeric] = ctc.Numeric
|
||||
cts[CTDecimal] = ctc.Decimal
|
||||
cts[CTUnsignedInt8] = ctc.UnsignedInt8
|
||||
cts[CTUnsignedInt4] = ctc.UnsignedInt4
|
||||
cts[CTUnsignedInt2] = ctc.UnsignedInt2
|
||||
cts[CTUnsignedInt1] = ctc.UnsignedInt1
|
||||
|
||||
cts[CTDate] = ctc.Date
|
||||
cts[CTTime] = ctc.Time
|
||||
cts[CTDateTime] = ctc.Datetime
|
||||
cts[CTTimestamp] = ctc.Timestamp
|
||||
|
||||
cts[CTBinary] = ctc.Binary
|
||||
cts[CTVarbinary] = ctc.Varbinary
|
||||
cts[CTMediumblob] = ctc.Mediumblob
|
||||
cts[CTBlob] = ctc.Blob
|
||||
cts[CTLongblob] = ctc.Longblob
|
||||
|
||||
cts[CTEnum] = ctc.Enum
|
||||
cts[CTJSON] = ctc.JSON
|
||||
|
||||
commonTypeConverters[dbType] = cts
|
||||
}
|
||||
|
||||
// ConvToTargetDbColumn 转换至异构数据库对应的列信息
|
||||
func ConvToTargetDbColumn(srcDbType DbType, targetDbType DbType, targetDialect Dialect, column *Column) error {
|
||||
// 同类型数据库,不转换
|
||||
if srcDbType == targetDbType {
|
||||
return nil
|
||||
}
|
||||
|
||||
srcMap := commonTypeConverters[srcDbType]
|
||||
if srcMap == nil {
|
||||
return fmt.Errorf("src database type [%s] not suport transfer", srcDbType)
|
||||
}
|
||||
|
||||
targetMap := commonTypeConverters[targetDbType]
|
||||
if targetMap == nil {
|
||||
return fmt.Errorf("target database type [%s] not suport transfer", targetDbType)
|
||||
}
|
||||
|
||||
srcDataType := GetDbDataType(srcDbType, column.DataType)
|
||||
|
||||
// 获取目标数据库的数据类型,并进行可能存在的列信息修复,如长度、精度等
|
||||
targetDbDataType := targetMap[srcDataType.CommonType](column)
|
||||
if targetDbDataType == nil {
|
||||
return fmt.Errorf("target database type [%s] not suport transfer, src data type [%d]", targetDbType, srcDataType.CommonType)
|
||||
}
|
||||
|
||||
// 替换为目标数据库的数据类型
|
||||
column.DataType = targetDbDataType.Name
|
||||
return nil
|
||||
}
|
||||
9
server/internal/db/dbm/dbi/utils.go
Normal file
9
server/internal/db/dbm/dbi/utils.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package dbi
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func QuoteEscape(str string) string {
|
||||
return strings.Replace(str, `'`, `''`, -1)
|
||||
}
|
||||
@@ -21,7 +21,7 @@ import (
|
||||
var connCache = cache.NewTimedCache(consts.DbConnExpireTime, 5*time.Second).
|
||||
WithUpdateAccessTime(true).
|
||||
OnEvicted(func(key any, value any) {
|
||||
logx.Info(fmt.Sprintf("删除db连接缓存 id = %s", key))
|
||||
logx.Info(fmt.Sprintf("delete db conn cache, id = %s", key))
|
||||
value.(*dbi.DbConn).Close()
|
||||
})
|
||||
|
||||
|
||||
106
server/internal/db/dbm/dm/column.go
Normal file
106
server/internal/db/dbm/dm/column.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package dm
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/pkg/utils/anyx"
|
||||
"strings"
|
||||
|
||||
"gitee.com/chunanyong/dm"
|
||||
)
|
||||
|
||||
var (
|
||||
CHAR = dbi.NewDbDataType("CHAR", dbi.DTString).WithCT(dbi.CTChar)
|
||||
VARCHAR = dbi.NewDbDataType("VARCHAR", dbi.DTString).WithCT(dbi.CTVarchar)
|
||||
TEXT = dbi.NewDbDataType("TEXT", dbi.DTString).WithCT(dbi.CTText)
|
||||
LONG = dbi.NewDbDataType("LONG", dbi.DTString).WithCT(dbi.CTText)
|
||||
LONGVARCHAR = dbi.NewDbDataType("LONGVARCHAR", dbi.DTString).WithCT(dbi.CTLongtext)
|
||||
IMAGE = dbi.NewDbDataType("IMAGE", dbi.DTString).WithCT(dbi.CTLongtext)
|
||||
LONGVARBINARY = dbi.NewDbDataType("LONGVARBINARY", dbi.DTString).WithCT(dbi.CTLongtext)
|
||||
CLOB = dbi.NewDbDataType("CLOB", dbi.DTString).WithCT(dbi.CTLongtext)
|
||||
|
||||
BLOB = dbi.NewDbDataType("BLOB", dbi.DTBytes).WithCT(dbi.CTBlob)
|
||||
|
||||
NUMERIC = dbi.NewDbDataType("NUMERIC", dbi.DTNumeric).WithCT(dbi.CTNumeric)
|
||||
DECIMAL = dbi.NewDbDataType("DECIMAL", dbi.DTDecimal).WithCT(dbi.CTDecimal)
|
||||
NUMBER = dbi.NewDbDataType("NUMBER", dbi.DTNumeric).WithCT(dbi.CTNumeric)
|
||||
INTEGER = dbi.NewDbDataType("INTEGER", dbi.DTInt32).WithCT(dbi.CTInt4)
|
||||
INT = dbi.NewDbDataType("INT", dbi.DTInt32).WithCT(dbi.CTInt4)
|
||||
BIGINT = dbi.NewDbDataType("BIGINT", dbi.DTInt64).WithCT(dbi.CTInt8)
|
||||
TINYINT = dbi.NewDbDataType("TINYINT", dbi.DTInt8).WithCT(dbi.CTInt1)
|
||||
BYTE = dbi.NewDbDataType("BYTE", dbi.DTInt8).WithCT(dbi.CTInt1)
|
||||
SMALLINT = dbi.NewDbDataType("SMALLINT", dbi.DTInt16).WithCT(dbi.CTInt2)
|
||||
BIT = dbi.NewDbDataType("BIT", dbi.DTBit).WithCT(dbi.CTBit)
|
||||
DOUBLE = dbi.NewDbDataType("DOUBLE", dbi.DTNumeric).WithCT(dbi.CTNumeric)
|
||||
FLOAT = dbi.NewDbDataType("FLOAT", dbi.DTNumeric).WithCT(dbi.CTNumeric)
|
||||
|
||||
TIME = dbi.NewDbDataType("TIME", dbi.DTTime).WithCT(dbi.CTTime).WithFixColumn(dbi.ClearCharMaxLength)
|
||||
DATE = dbi.NewDbDataType("DATE", dbi.DTDate).WithCT(dbi.CTDate).WithFixColumn(dbi.ClearCharMaxLength)
|
||||
TIMESTAMP = dbi.NewDbDataType("TIMESTAMP", dbi.DTDateTime).WithCT(dbi.CTTimestamp).WithFixColumn(dbi.ClearCharMaxLength)
|
||||
|
||||
ST_CURVE = dbi.NewDbDataType("ST_CURVE", DTDmStruct).WithCT(dbi.CTVarchar) // 表示一条曲线,可以是圆弧、抛物线等
|
||||
ST_LINESTRING = dbi.NewDbDataType("ST_LINESTRING", DTDmStruct).WithCT(dbi.CTVarchar) // 表示一条或多条连续的线段
|
||||
ST_GEOMCOLLECTION = dbi.NewDbDataType("ST_GEOMCOLLECTION", DTDmStruct).WithCT(dbi.CTVarchar) // 表示一个几何对象集合,可以包含多个不同类型的几何对象
|
||||
ST_GEOMETRY = dbi.NewDbDataType("ST_GEOMETRY", DTDmStruct).WithCT(dbi.CTVarchar) // 通用几何对象类型,可以表示点、线、面等任何几何形状
|
||||
ST_MULTICURVE = dbi.NewDbDataType("ST_MULTICURVE", DTDmStruct).WithCT(dbi.CTVarchar) // 表示多个曲线的集合
|
||||
ST_MULTILINESTRING = dbi.NewDbDataType("ST_MULTILINESTRING", DTDmStruct).WithCT(dbi.CTVarchar) // 表示多个线串的集合
|
||||
ST_MULTIPOINT = dbi.NewDbDataType("ST_MULTIPOINT", DTDmStruct).WithCT(dbi.CTVarchar) // 表示多个点的集合
|
||||
ST_MULTIPOLYGON = dbi.NewDbDataType("ST_MULTIPOLYGON", DTDmStruct).WithCT(dbi.CTVarchar) // 表示多个曲线的集合
|
||||
ST_MULTISURFACE = dbi.NewDbDataType("ST_MULTISURFACE", DTDmStruct).WithCT(dbi.CTVarchar) // 表示多个表面的集合
|
||||
ST_POINT = dbi.NewDbDataType("ST_POINT", DTDmStruct).WithCT(dbi.CTVarchar) // 表示一个点
|
||||
ST_POLYGON = dbi.NewDbDataType("ST_POLYGON", DTDmStruct).WithCT(dbi.CTVarchar) //表示一个多边形
|
||||
ST_SURFACE = dbi.NewDbDataType("ST_SURFACE", DTDmStruct).WithCT(dbi.CTVarchar) // 表示一个表面
|
||||
)
|
||||
|
||||
var DTDmStruct = &dbi.DataType{
|
||||
Name: "dm_struct",
|
||||
Valuer: func() dbi.Valuer {
|
||||
return &dmStructValuer{
|
||||
DefaultValuer: new(dbi.DefaultValuer[dm.DmStruct]),
|
||||
}
|
||||
},
|
||||
SQLValue: dbi.SQLValueString,
|
||||
}
|
||||
|
||||
type dmStructValuer struct {
|
||||
*dbi.DefaultValuer[dm.DmStruct]
|
||||
}
|
||||
|
||||
func (s *dmStructValuer) Value() any {
|
||||
if !s.ValuePtr.Valid {
|
||||
return ""
|
||||
}
|
||||
return ParseDmStruct(s.ValuePtr)
|
||||
}
|
||||
|
||||
func ParseDmStruct(dmStruct *dm.DmStruct) string {
|
||||
if !dmStruct.Valid {
|
||||
return ""
|
||||
}
|
||||
|
||||
name, _ := dmStruct.GetSQLTypeName()
|
||||
attributes, _ := dmStruct.GetAttributes()
|
||||
arr := make([]string, len(attributes))
|
||||
arr = append(arr, name, "(")
|
||||
|
||||
for i, v := range attributes {
|
||||
if blb, ok1 := v.(*dm.DmBlob); ok1 {
|
||||
if blb.Valid {
|
||||
length, _ := blb.GetLength()
|
||||
var dest = make([]byte, length)
|
||||
_, _ = blb.Read(dest)
|
||||
// 2进制转16进制字符串
|
||||
hexStr := hex.EncodeToString(dest)
|
||||
arr = append(arr, "0x", strings.ToUpper(hexStr))
|
||||
}
|
||||
} else {
|
||||
arr = append(arr, anyx.ToString(v))
|
||||
}
|
||||
if i < len(attributes)-1 {
|
||||
arr = append(arr, ",")
|
||||
}
|
||||
}
|
||||
|
||||
arr = append(arr, ")")
|
||||
return strings.Join(arr, "")
|
||||
}
|
||||
@@ -1,11 +1,8 @@
|
||||
package dm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/pkg/logx"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"mayfly-go/pkg/utils/stringx"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -19,109 +16,6 @@ type DMDialect struct {
|
||||
dc *dbi.DbConn
|
||||
}
|
||||
|
||||
func (dd *DMDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any, duplicateStrategy int) (int64, error) {
|
||||
// 执行批量insert sql
|
||||
// insert into "table_name" ("column1", "column2", ...) values (value1, value2, ...)
|
||||
|
||||
// 无需处理重复数据,直接执行批量insert
|
||||
if duplicateStrategy == dbi.DuplicateStrategyNone || duplicateStrategy == 0 {
|
||||
return dd.batchInsertSimple(tx, tableName, columns, values)
|
||||
} else { // 执行MERGE INTO语句
|
||||
return dd.batchInsertMerge(tx, tableName, columns, values)
|
||||
}
|
||||
}
|
||||
|
||||
func (dd *DMDialect) batchInsertSimple(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) {
|
||||
// 生成占位符字符串:如:(?,?)
|
||||
// 重复字符串并用逗号连接
|
||||
repeated := strings.Repeat("?,", len(columns))
|
||||
// 去除最后一个逗号,占位符由括号包裹
|
||||
placeholder := fmt.Sprintf("(%s)", strings.TrimSuffix(repeated, ","))
|
||||
|
||||
identityInsert := fmt.Sprintf("set identity_insert \"%s\" on;", tableName)
|
||||
|
||||
sqlTemp := fmt.Sprintf("%s insert into %s (%s) values %s", identityInsert, dd.QuoteIdentifier(tableName), strings.Join(columns, ","), placeholder)
|
||||
effRows := 0
|
||||
// 设置允许填充自增列之后,显示指定列名可以插入自增列
|
||||
for _, value := range values {
|
||||
// 达梦数据库只能一条条的执行insert
|
||||
res, err := dd.dc.TxExec(tx, sqlTemp, value...)
|
||||
if err != nil {
|
||||
logx.Errorf("执行sql失败:%s, sql: [ %s ]", err.Error(), sqlTemp)
|
||||
return 0, err
|
||||
}
|
||||
effRows += int(res)
|
||||
}
|
||||
// 执行批量insert sql
|
||||
return int64(effRows), nil
|
||||
}
|
||||
|
||||
func (dd *DMDialect) batchInsertMerge(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) {
|
||||
// 查询主键字段
|
||||
uniqueCols := make([]string, 0)
|
||||
caseSqls := make([]string, 0)
|
||||
metadata := dd.dc.GetMetadata()
|
||||
tableCols, _ := metadata.GetColumns(tableName)
|
||||
identityCols := make([]string, 0)
|
||||
for _, col := range tableCols {
|
||||
if col.IsPrimaryKey {
|
||||
uniqueCols = append(uniqueCols, col.ColumnName)
|
||||
caseSqls = append(caseSqls, fmt.Sprintf("( T1.%s = T2.%s )", dd.QuoteIdentifier(col.ColumnName), dd.QuoteIdentifier(col.ColumnName)))
|
||||
}
|
||||
if col.IsIdentity {
|
||||
// 自增字段不放入insert内,即使是设置了identity_insert on也不起作用
|
||||
identityCols = append(identityCols, dd.QuoteIdentifier(col.ColumnName))
|
||||
}
|
||||
}
|
||||
// 查询唯一索引涉及到的字段,并组装到match条件内
|
||||
indexs, _ := metadata.GetTableIndex(tableName)
|
||||
for _, index := range indexs {
|
||||
if index.IsUnique {
|
||||
cols := strings.Split(index.ColumnName, ",")
|
||||
tmp := make([]string, 0)
|
||||
for _, col := range cols {
|
||||
uniqueCols = append(uniqueCols, col)
|
||||
tmp = append(tmp, fmt.Sprintf(" T1.%s = T2.%s ", dd.QuoteIdentifier(col), dd.QuoteIdentifier(col)))
|
||||
}
|
||||
caseSqls = append(caseSqls, fmt.Sprintf("( %s )", strings.Join(tmp, " AND ")))
|
||||
}
|
||||
}
|
||||
|
||||
// 重复数据处理策略
|
||||
phs := make([]string, 0)
|
||||
insertVals := make([]string, 0)
|
||||
upds := make([]string, 0)
|
||||
insertCols := make([]string, 0)
|
||||
for _, column := range columns {
|
||||
phs = append(phs, fmt.Sprintf("? %s", column))
|
||||
if !collx.ArrayContains(uniqueCols, dd.RemoveQuote(column)) {
|
||||
upds = append(upds, fmt.Sprintf("T1.%s = T2.%s", column, column))
|
||||
}
|
||||
if !collx.ArrayContains(identityCols, column) {
|
||||
insertCols = append(insertCols, column)
|
||||
insertVals = append(insertVals, fmt.Sprintf("T2.%s", column))
|
||||
}
|
||||
|
||||
}
|
||||
t2s := make([]string, 0)
|
||||
for i := 0; i < len(values); i++ {
|
||||
t2s = append(t2s, fmt.Sprintf("SELECT %s FROM dual", strings.Join(phs, ",")))
|
||||
}
|
||||
t2 := strings.Join(t2s, " UNION ALL ")
|
||||
|
||||
sqlTemp := "MERGE INTO " + dd.QuoteIdentifier(tableName) + " T1 USING (" + t2 + ") T2 ON " + strings.Join(caseSqls, " OR ")
|
||||
sqlTemp += "WHEN NOT MATCHED THEN INSERT (" + strings.Join(insertCols, ",") + ") VALUES (" + strings.Join(insertVals, ",") + ")"
|
||||
sqlTemp += "WHEN MATCHED THEN UPDATE SET " + strings.Join(upds, ",")
|
||||
|
||||
// 把二维数组转为一维数组
|
||||
var args []any
|
||||
for _, v := range values {
|
||||
args = append(args, v...)
|
||||
}
|
||||
|
||||
return dd.dc.TxExec(tx, sqlTemp, args...)
|
||||
}
|
||||
|
||||
func (dd *DMDialect) CopyTable(copy *dbi.DbCopyTable) error {
|
||||
tableName := copy.TableName
|
||||
metadata := dd.dc.GetMetadata()
|
||||
@@ -162,119 +56,13 @@ func (dd *DMDialect) CopyTable(copy *dbi.DbCopyTable) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (dd *DMDialect) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table, dropBeforeCreate bool) []string {
|
||||
tbName := dd.QuoteIdentifier(tableInfo.TableName)
|
||||
sqlArr := make([]string, 0)
|
||||
|
||||
if dropBeforeCreate {
|
||||
sqlArr = append(sqlArr, fmt.Sprintf("drop table if exists %s", tbName))
|
||||
}
|
||||
// 组装建表语句
|
||||
createSql := fmt.Sprintf("create table %s (", tbName)
|
||||
fields := make([]string, 0)
|
||||
pks := make([]string, 0)
|
||||
columnComments := make([]string, 0)
|
||||
|
||||
for _, column := range columns {
|
||||
if column.IsPrimaryKey {
|
||||
pks = append(pks, dd.QuoteIdentifier(column.ColumnName))
|
||||
}
|
||||
fields = append(fields, dd.genColumnBasicSql(column))
|
||||
if column.ColumnComment != "" {
|
||||
comment := dd.QuoteEscape(column.ColumnComment)
|
||||
columnComments = append(columnComments, fmt.Sprintf("comment on column %s.%s is '%s'", tbName, dd.QuoteIdentifier(column.ColumnName), comment))
|
||||
}
|
||||
}
|
||||
createSql += strings.Join(fields, ",\n")
|
||||
if len(pks) > 0 {
|
||||
createSql += fmt.Sprintf(",\n PRIMARY KEY (%s)", strings.Join(pks, ","))
|
||||
}
|
||||
createSql += "\n)"
|
||||
|
||||
tableCommentSql := ""
|
||||
if tableInfo.TableComment != "" {
|
||||
comment := dd.QuoteEscape(tableInfo.TableComment)
|
||||
tableCommentSql = fmt.Sprintf("comment on table %s is '%s'", tbName, comment)
|
||||
}
|
||||
|
||||
sqlArr = append(sqlArr, createSql)
|
||||
if tableCommentSql != "" {
|
||||
sqlArr = append(sqlArr, tableCommentSql)
|
||||
}
|
||||
|
||||
if len(columnComments) > 0 {
|
||||
sqlArr = append(sqlArr, columnComments...)
|
||||
}
|
||||
|
||||
return sqlArr
|
||||
}
|
||||
|
||||
func (dd *DMDialect) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Table) []string {
|
||||
sqls := make([]string, 0)
|
||||
for _, index := range indexs {
|
||||
unique := ""
|
||||
if index.IsUnique {
|
||||
unique = "unique"
|
||||
}
|
||||
|
||||
// 取出列名,添加引号
|
||||
cols := strings.Split(index.ColumnName, ",")
|
||||
colNames := make([]string, len(cols))
|
||||
for i, name := range cols {
|
||||
colNames[i] = dd.QuoteIdentifier(name)
|
||||
}
|
||||
|
||||
sqls = append(sqls, fmt.Sprintf("create %s index %s on %s(%s)", unique, dd.QuoteIdentifier(index.IndexName), dd.QuoteIdentifier(tableInfo.TableName), strings.Join(colNames, ",")))
|
||||
}
|
||||
return sqls
|
||||
}
|
||||
|
||||
func (dd *DMDialect) GetDataHelper() dbi.DataHelper {
|
||||
return dataHelper
|
||||
}
|
||||
|
||||
func (dd *DMDialect) GetColumnHelper() dbi.ColumnHelper {
|
||||
return columnHelper
|
||||
}
|
||||
|
||||
func (dd *DMDialect) GetDumpHelper() dbi.DumpHelper {
|
||||
return new(DumpHelper)
|
||||
}
|
||||
|
||||
func (dd *DMDialect) genColumnBasicSql(column dbi.Column) string {
|
||||
colName := dd.QuoteIdentifier(column.ColumnName)
|
||||
dataType := string(column.DataType)
|
||||
|
||||
incr := ""
|
||||
if column.IsIdentity {
|
||||
incr = " IDENTITY"
|
||||
func (sd *DMDialect) GetSQLGenerator() dbi.SQLGenerator {
|
||||
return &SQLGenerator{
|
||||
Dialect: sd,
|
||||
Metadata: sd.dc.GetMetadata(),
|
||||
}
|
||||
|
||||
nullAble := ""
|
||||
if !column.Nullable {
|
||||
nullAble = " NOT NULL"
|
||||
}
|
||||
|
||||
defVal := "" // 默认值需要判断引号,如函数是不需要引号的 // 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值
|
||||
if column.ColumnDefault != "" && !strings.Contains(column.ColumnDefault, "(") {
|
||||
// 哪些字段类型默认值需要加引号
|
||||
mark := false
|
||||
if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, strings.ToLower(dataType)) {
|
||||
// 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号
|
||||
if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) &&
|
||||
collx.ArrayAnyMatches([]string{"DATE", "TIME"}, strings.ToUpper(column.ColumnDefault)) {
|
||||
mark = false
|
||||
} else {
|
||||
mark = true
|
||||
}
|
||||
}
|
||||
if mark {
|
||||
defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault)
|
||||
} else {
|
||||
defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault)
|
||||
}
|
||||
}
|
||||
|
||||
columnSql := fmt.Sprintf(" %s %s%s%s%s", colName, column.GetColumnType(), incr, nullAble, defVal)
|
||||
return columnSql
|
||||
}
|
||||
|
||||
@@ -1,289 +1,11 @@
|
||||
package dm
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/pkg/utils/anyx"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gitee.com/chunanyong/dm"
|
||||
)
|
||||
|
||||
var (
|
||||
// 数字类型
|
||||
numberRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`)
|
||||
// 日期时间类型
|
||||
datetimeRegexp = regexp.MustCompile(`(?i)datetime|timestamp`)
|
||||
// 日期类型
|
||||
dateRegexp = regexp.MustCompile(`(?i)date`)
|
||||
// 时间类型
|
||||
timeRegexp = regexp.MustCompile(`(?i)time`)
|
||||
|
||||
// 达梦数据类型 对应 公共数据类型
|
||||
commonColumnTypeMap = map[string]dbi.ColumnDataType{
|
||||
|
||||
"CHAR": dbi.CommonTypeChar, // 字符数据类型
|
||||
"VARCHAR": dbi.CommonTypeVarchar,
|
||||
"TEXT": dbi.CommonTypeText,
|
||||
"LONG": dbi.CommonTypeText,
|
||||
"LONGVARCHAR": dbi.CommonTypeLongtext,
|
||||
"IMAGE": dbi.CommonTypeLongtext,
|
||||
"LONGVARBINARY": dbi.CommonTypeLongtext,
|
||||
"BLOB": dbi.CommonTypeBlob,
|
||||
"CLOB": dbi.CommonTypeText,
|
||||
"NUMERIC": dbi.CommonTypeNumber, // 精确数值数据类型
|
||||
"DECIMAL": dbi.CommonTypeNumber,
|
||||
"NUMBER": dbi.CommonTypeNumber,
|
||||
"INTEGER": dbi.CommonTypeInt,
|
||||
"INT": dbi.CommonTypeInt,
|
||||
"BIGINT": dbi.CommonTypeBigint,
|
||||
"TINYINT": dbi.CommonTypeTinyint,
|
||||
"BYTE": dbi.CommonTypeTinyint,
|
||||
"SMALLINT": dbi.CommonTypeSmallint,
|
||||
"BIT": dbi.CommonTypeTinyint,
|
||||
"DOUBLE": dbi.CommonTypeNumber, // 近似数值类型
|
||||
"FLOAT": dbi.CommonTypeNumber,
|
||||
"DATE": dbi.CommonTypeDate, // 一般日期时间数据类型
|
||||
"TIME": dbi.CommonTypeTime,
|
||||
"TIMESTAMP": dbi.CommonTypeTimestamp,
|
||||
}
|
||||
|
||||
// 公共数据类型 对应 达梦数据类型
|
||||
dmColumnTypeMap = map[dbi.ColumnDataType]string{
|
||||
dbi.CommonTypeVarchar: "VARCHAR",
|
||||
dbi.CommonTypeChar: "CHAR",
|
||||
dbi.CommonTypeText: "TEXT",
|
||||
dbi.CommonTypeBlob: "BLOB",
|
||||
dbi.CommonTypeLongblob: "TEXT",
|
||||
dbi.CommonTypeLongtext: "TEXT",
|
||||
dbi.CommonTypeBinary: "TEXT",
|
||||
dbi.CommonTypeMediumblob: "TEXT",
|
||||
dbi.CommonTypeMediumtext: "TEXT",
|
||||
dbi.CommonTypeVarbinary: "TEXT",
|
||||
dbi.CommonTypeInt: "INT",
|
||||
dbi.CommonTypeSmallint: "SMALLINT",
|
||||
dbi.CommonTypeTinyint: "TINYINT",
|
||||
dbi.CommonTypeNumber: "NUMBER",
|
||||
dbi.CommonTypeBigint: "BIGINT",
|
||||
dbi.CommonTypeDatetime: "TIMESTAMP",
|
||||
dbi.CommonTypeDate: "DATE",
|
||||
dbi.CommonTypeTime: "DATE",
|
||||
dbi.CommonTypeTimestamp: "TIMESTAMP",
|
||||
dbi.CommonTypeEnum: "TEXT",
|
||||
dbi.CommonTypeJSON: "TEXT",
|
||||
}
|
||||
|
||||
dmStructTypes = map[string]bool{
|
||||
"ST_CURVE": true, // 表示一条曲线,可以是圆弧、抛物线等
|
||||
"ST_LINESTRING": true, // 表示一条或多条连续的线段
|
||||
"ST_GEOMCOLLECTION": true, // 表示一个几何对象集合,可以包含多个不同类型的几何对象
|
||||
"ST_GEOMETRY": true, // 通用几何对象类型,可以表示点、线、面等任何几何形状
|
||||
"ST_MULTICURVE": true, // 表示多个曲线的集合
|
||||
"ST_MULTILINESTRING": true, // 表示多个线串的集合
|
||||
"ST_MULTIPOINT": true, // 表示多个点的集合
|
||||
"ST_MULTIPOLYGON": true, // 表示多个多边形的集合
|
||||
"ST_MULTISURFACE": true, // 表示多个表面的集合
|
||||
"ST_POINT": true, // 表示一个点
|
||||
"ST_POLYGON": true, // 表示一个多边形
|
||||
"ST_SURFACE": true, // 表示一个表面,通常是一个多边形
|
||||
}
|
||||
|
||||
dataHelper = &DataHelper{}
|
||||
columnHelper = &ColumnHelper{}
|
||||
)
|
||||
|
||||
func GetDataHelper() *DataHelper {
|
||||
return dataHelper
|
||||
}
|
||||
|
||||
type DataHelper struct {
|
||||
}
|
||||
|
||||
func (dc *DataHelper) GetDataType(dbColumnType string) dbi.DataType {
|
||||
if numberRegexp.MatchString(dbColumnType) {
|
||||
return dbi.DataTypeNumber
|
||||
}
|
||||
if datetimeRegexp.MatchString(dbColumnType) {
|
||||
return dbi.DataTypeDateTime
|
||||
}
|
||||
if dateRegexp.MatchString(dbColumnType) {
|
||||
return dbi.DataTypeDate
|
||||
}
|
||||
if timeRegexp.MatchString(dbColumnType) {
|
||||
return dbi.DataTypeTime
|
||||
}
|
||||
return dbi.DataTypeString
|
||||
}
|
||||
|
||||
func (dc *DataHelper) FormatData(dbColumnValue any, dataType dbi.DataType) string {
|
||||
str := anyx.ToString(dbColumnValue)
|
||||
switch dataType {
|
||||
case dbi.DataTypeDateTime: // "2024-01-02T22:08:22.275697+08:00"
|
||||
// 尝试用时间格式解析
|
||||
res, err := time.Parse(time.DateTime, str)
|
||||
if err == nil {
|
||||
return str
|
||||
}
|
||||
res, _ = time.Parse(time.RFC3339, str)
|
||||
return res.Format(time.DateTime)
|
||||
case dbi.DataTypeDate: // "2024-01-02T00:00:00+08:00"
|
||||
// 尝试用时间格式解析
|
||||
res, err := time.Parse(time.DateOnly, str)
|
||||
if err == nil {
|
||||
return str
|
||||
}
|
||||
res, _ = time.Parse(time.RFC3339, str)
|
||||
return res.Format(time.DateOnly)
|
||||
case dbi.DataTypeTime: // "0000-01-01T22:08:22.275688+08:00"
|
||||
// 尝试用时间格式解析
|
||||
res, err := time.Parse(time.TimeOnly, str)
|
||||
if err == nil {
|
||||
return str
|
||||
}
|
||||
res, _ = time.Parse(time.RFC3339, str)
|
||||
return res.Format(time.TimeOnly)
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
func (dc *DataHelper) ParseData(dbColumnValue any, dataType dbi.DataType) any {
|
||||
// 如果dataType是datetime而dbColumnValue是string类型,则需要转换为time.Time类型
|
||||
_, ok := dbColumnValue.(string)
|
||||
if ok {
|
||||
if dataType == dbi.DataTypeDateTime {
|
||||
res, _ := time.Parse(time.RFC3339, anyx.ToString(dbColumnValue))
|
||||
return res
|
||||
}
|
||||
if dataType == dbi.DataTypeDate {
|
||||
res, _ := time.Parse(time.DateOnly, anyx.ToString(dbColumnValue))
|
||||
return res
|
||||
}
|
||||
if dataType == dbi.DataTypeTime {
|
||||
res, _ := time.Parse(time.TimeOnly, anyx.ToString(dbColumnValue))
|
||||
return res
|
||||
}
|
||||
}
|
||||
return dbColumnValue
|
||||
}
|
||||
|
||||
func (dc *DataHelper) WrapValue(dbColumnValue any, dataType dbi.DataType) string {
|
||||
if dbColumnValue == nil {
|
||||
return "NULL"
|
||||
}
|
||||
switch dataType {
|
||||
case dbi.DataTypeNumber:
|
||||
return fmt.Sprintf("%v", dbColumnValue)
|
||||
case dbi.DataTypeString:
|
||||
val := fmt.Sprintf("%v", dbColumnValue)
|
||||
// 转义单引号
|
||||
val = strings.Replace(val, `'`, `''`, -1)
|
||||
val = strings.Replace(val, `\''`, `\'`, -1)
|
||||
// 转义换行符
|
||||
val = strings.Replace(val, "\n", "\\n", -1)
|
||||
return fmt.Sprintf("'%s'", val)
|
||||
case dbi.DataTypeDate, dbi.DataTypeDateTime, dbi.DataTypeTime:
|
||||
return fmt.Sprintf("'%s'", dc.FormatData(dbColumnValue, dataType))
|
||||
}
|
||||
return fmt.Sprintf("'%s'", dbColumnValue)
|
||||
}
|
||||
|
||||
type ColumnHelper struct {
|
||||
dbi.DefaultColumnHelper
|
||||
}
|
||||
|
||||
func (ch *ColumnHelper) ToCommonColumn(dialectColumn *dbi.Column) {
|
||||
// 翻译为通用数据库类型
|
||||
dataType := dialectColumn.DataType
|
||||
t1 := commonColumnTypeMap[string(dataType)]
|
||||
if t1 == "" {
|
||||
dialectColumn.DataType = dbi.CommonTypeVarchar
|
||||
dialectColumn.CharMaxLength = 2000
|
||||
} else {
|
||||
dialectColumn.DataType = t1
|
||||
}
|
||||
}
|
||||
|
||||
func (ch *ColumnHelper) ToColumn(commonColumn *dbi.Column) {
|
||||
ctype := dmColumnTypeMap[commonColumn.DataType]
|
||||
|
||||
if ctype == "" {
|
||||
commonColumn.DataType = "VARCHAR"
|
||||
commonColumn.CharMaxLength = 2000
|
||||
} else {
|
||||
commonColumn.DataType = dbi.ColumnDataType(ctype)
|
||||
ch.FixColumn(commonColumn)
|
||||
}
|
||||
}
|
||||
|
||||
func (ch *ColumnHelper) FixColumn(column *dbi.Column) {
|
||||
// 如果是date,不设长度
|
||||
if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(string(column.DataType))) {
|
||||
column.CharMaxLength = 0
|
||||
column.NumPrecision = 0
|
||||
} else
|
||||
// 如果是char且长度未设置,则默认长度2000
|
||||
if collx.ArrayAnyMatches([]string{"char"}, strings.ToLower(string(column.DataType))) && column.CharMaxLength == 0 {
|
||||
column.CharMaxLength = 2000
|
||||
}
|
||||
}
|
||||
|
||||
func (dd *ColumnHelper) GetScanDestPtr(qc *dbi.QueryColumn) any {
|
||||
if dmStructTypes[strings.ToUpper(qc.Type)] {
|
||||
return &dm.DmStruct{}
|
||||
}
|
||||
return dd.DefaultColumnHelper.GetScanDestPtr(qc)
|
||||
}
|
||||
|
||||
func (dd *ColumnHelper) ConvertScanDestValue(data any, qc *dbi.QueryColumn) any {
|
||||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 达梦特殊数据类型
|
||||
if dmStruct, ok := data.(*dm.DmStruct); ok {
|
||||
return ParseDmStruct(dmStruct)
|
||||
}
|
||||
|
||||
return dd.DefaultColumnHelper.ConvertScanDestValue(data, qc)
|
||||
}
|
||||
|
||||
func ParseDmStruct(dmStruct *dm.DmStruct) string {
|
||||
if !dmStruct.Valid {
|
||||
return ""
|
||||
}
|
||||
|
||||
name, _ := dmStruct.GetSQLTypeName()
|
||||
attributes, _ := dmStruct.GetAttributes()
|
||||
arr := make([]string, len(attributes))
|
||||
arr = append(arr, name, "(")
|
||||
|
||||
for i, v := range attributes {
|
||||
if blb, ok1 := v.(*dm.DmBlob); ok1 {
|
||||
if blb.Valid {
|
||||
length, _ := blb.GetLength()
|
||||
var dest = make([]byte, length)
|
||||
_, _ = blb.Read(dest)
|
||||
// 2进制转16进制字符串
|
||||
hexStr := hex.EncodeToString(dest)
|
||||
arr = append(arr, "0x", strings.ToUpper(hexStr))
|
||||
}
|
||||
} else {
|
||||
arr = append(arr, anyx.ToString(v))
|
||||
}
|
||||
if i < len(attributes)-1 {
|
||||
arr = append(arr, ",")
|
||||
}
|
||||
}
|
||||
|
||||
arr = append(arr, ")")
|
||||
return strings.Join(arr, "")
|
||||
}
|
||||
|
||||
type DumpHelper struct {
|
||||
}
|
||||
|
||||
|
||||
@@ -4,14 +4,19 @@ import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func init() {
|
||||
dbi.Register(dbi.DbTypeDM, new(Meta))
|
||||
dbi.Register(DbTypeDM, new(Meta))
|
||||
}
|
||||
|
||||
const (
|
||||
DbTypeDM dbi.DbType = "dm"
|
||||
)
|
||||
|
||||
type Meta struct {
|
||||
}
|
||||
|
||||
@@ -49,3 +54,17 @@ func (dm *Meta) GetMetadata(conn *dbi.DbConn) dbi.Metadata {
|
||||
dc: conn,
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *Meta) GetDbDataTypes() []*dbi.DbDataType {
|
||||
return collx.AsArray[*dbi.DbDataType](CHAR, VARCHAR, TEXT, LONG, LONGVARCHAR, IMAGE, LONGVARBINARY, CLOB,
|
||||
BLOB,
|
||||
NUMERIC, DECIMAL, NUMBER, INTEGER, INT, BIGINT, TINYINT, BYTE, SMALLINT, BIT, DOUBLE, FLOAT,
|
||||
TIME, DATE, TIMESTAMP,
|
||||
ST_CURVE, ST_LINESTRING, ST_GEOMCOLLECTION, ST_GEOMETRY, ST_MULTICURVE, ST_MULTILINESTRING,
|
||||
ST_MULTIPOINT, ST_MULTIPOLYGON, ST_MULTISURFACE, ST_POINT, ST_POLYGON, ST_SURFACE,
|
||||
)
|
||||
}
|
||||
|
||||
func (mm *Meta) GetCommonTypeConverter() dbi.CommonTypeConverter {
|
||||
return &commonTypeConverter{}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/pkg/errorx"
|
||||
"mayfly-go/pkg/logx"
|
||||
"mayfly-go/pkg/utils/anyx"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"mayfly-go/pkg/utils/stringx"
|
||||
@@ -55,7 +54,7 @@ func (dd *DMMetadata) GetDbNames() ([]string, error) {
|
||||
func (dd *DMMetadata) GetTables(tableNames ...string) ([]dbi.Table, error) {
|
||||
dialect := dd.dc.GetDialect()
|
||||
names := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string {
|
||||
return fmt.Sprintf("'%s'", dialect.RemoveQuote(val))
|
||||
return fmt.Sprintf("'%s'", dialect.Quoter().Trim(val))
|
||||
}), ",")
|
||||
|
||||
var res []map[string]any
|
||||
@@ -89,7 +88,7 @@ func (dd *DMMetadata) GetTables(tableNames ...string) ([]dbi.Table, error) {
|
||||
func (dd *DMMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error) {
|
||||
dialect := dd.dc.GetDialect()
|
||||
tableName := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string {
|
||||
return fmt.Sprintf("'%s'", dialect.RemoveQuote(val))
|
||||
return fmt.Sprintf("'%s'", dialect.Quoter().Trim(val))
|
||||
}), ",")
|
||||
|
||||
_, res, err := dd.dc.Query(fmt.Sprintf(dbi.GetLocalSql(DM_META_FILE, DM_COLUMN_MA_KEY), tableName))
|
||||
@@ -97,13 +96,12 @@ func (dd *DMMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
columnHelper := dd.dc.GetDialect().GetColumnHelper()
|
||||
columns := make([]dbi.Column, 0)
|
||||
for _, re := range res {
|
||||
column := dbi.Column{
|
||||
TableName: cast.ToString(re["TABLE_NAME"]),
|
||||
ColumnName: cast.ToString(re["COLUMN_NAME"]),
|
||||
DataType: dbi.ColumnDataType(anyx.ToString(re["DATA_TYPE"])),
|
||||
DataType: anyx.ToString(re["DATA_TYPE"]),
|
||||
CharMaxLength: cast.ToInt(re["CHAR_MAX_LENGTH"]),
|
||||
ColumnComment: cast.ToString(re["COLUMN_COMMENT"]),
|
||||
Nullable: cast.ToString(re["NULLABLE"]) == "YES",
|
||||
@@ -113,7 +111,7 @@ func (dd *DMMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error) {
|
||||
NumPrecision: cast.ToInt(re["NUM_PRECISION"]),
|
||||
NumScale: cast.ToInt(re["NUM_SCALE"]),
|
||||
}
|
||||
columnHelper.FixColumn(&column)
|
||||
dd.dc.GetDbDataType(column.DataType).FixColumn(&column)
|
||||
columns = append(columns, column)
|
||||
}
|
||||
return columns, nil
|
||||
@@ -176,34 +174,7 @@ func (dd *DMMetadata) GetTableIndex(tableName string) ([]dbi.Index, error) {
|
||||
|
||||
// 获取建表ddl
|
||||
func (dd *DMMetadata) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) {
|
||||
|
||||
// 1.获取表信息
|
||||
tbs, err := dd.GetTables(tableName)
|
||||
tableInfo := &dbi.Table{}
|
||||
if err != nil || tbs == nil || len(tbs) <= 0 {
|
||||
logx.Errorf("获取表信息失败, %s", tableName)
|
||||
return "", err
|
||||
}
|
||||
tableInfo.TableName = tbs[0].TableName
|
||||
tableInfo.TableComment = tbs[0].TableComment
|
||||
|
||||
// 2.获取列信息
|
||||
columns, err := dd.GetColumns(tableName)
|
||||
if err != nil {
|
||||
logx.Errorf("获取列信息失败, %s", tableName)
|
||||
return "", err
|
||||
}
|
||||
dialect := dd.dc.GetDialect()
|
||||
tableDDLArr := dialect.GenerateTableDDL(columns, *tableInfo, dropBeforeCreate)
|
||||
// 3.获取索引信息
|
||||
indexs, err := dd.GetTableIndex(tableName)
|
||||
if err != nil {
|
||||
logx.Errorf("获取索引信息失败, %s", tableName)
|
||||
return "", err
|
||||
}
|
||||
// 组装返回
|
||||
tableDDLArr = append(tableDDLArr, dialect.GenerateIndexDDL(indexs, *tableInfo)...)
|
||||
return strings.Join(tableDDLArr, ";\n"), nil
|
||||
return dbi.GenTableDDL(dd.dc.GetDialect(), dd, tableName, dropBeforeCreate)
|
||||
}
|
||||
|
||||
// 获取DM当前连接的库可访问的schemaNames
|
||||
|
||||
195
server/internal/db/dbm/dm/sqlgen.go
Normal file
195
server/internal/db/dbm/dm/sqlgen.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package dm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type SQLGenerator struct {
|
||||
Dialect dbi.Dialect
|
||||
Metadata dbi.Metadata
|
||||
}
|
||||
|
||||
func (sg *SQLGenerator) GenTableDDL(table dbi.Table, columns []dbi.Column, dropBeforeCreate bool) []string {
|
||||
quoter := sg.Dialect.Quoter()
|
||||
quote := quoter.Quote
|
||||
tbName := quote(table.TableName)
|
||||
sqlArr := make([]string, 0)
|
||||
|
||||
if dropBeforeCreate {
|
||||
sqlArr = append(sqlArr, fmt.Sprintf("drop table if exists %s", tbName))
|
||||
}
|
||||
// 组装建表语句
|
||||
createSql := fmt.Sprintf("create table %s (", tbName)
|
||||
fields := make([]string, 0)
|
||||
pks := make([]string, 0)
|
||||
columnComments := make([]string, 0)
|
||||
|
||||
for _, column := range columns {
|
||||
if column.IsPrimaryKey {
|
||||
pks = append(pks, quote(column.ColumnName))
|
||||
}
|
||||
fields = append(fields, sg.genColumnBasicSql(quoter, column))
|
||||
if column.ColumnComment != "" {
|
||||
comment := dbi.QuoteEscape(column.ColumnComment)
|
||||
columnComments = append(columnComments, fmt.Sprintf("comment on column %s.%s is '%s'", tbName, quote(column.ColumnName), comment))
|
||||
}
|
||||
}
|
||||
createSql += strings.Join(fields, ",\n")
|
||||
if len(pks) > 0 {
|
||||
createSql += fmt.Sprintf(",\n PRIMARY KEY (%s)", strings.Join(pks, ","))
|
||||
}
|
||||
createSql += "\n)"
|
||||
|
||||
tableCommentSql := ""
|
||||
if table.TableComment != "" {
|
||||
comment := dbi.QuoteEscape(table.TableComment)
|
||||
tableCommentSql = fmt.Sprintf("comment on table %s is '%s'", tbName, comment)
|
||||
}
|
||||
|
||||
sqlArr = append(sqlArr, createSql)
|
||||
if tableCommentSql != "" {
|
||||
sqlArr = append(sqlArr, tableCommentSql)
|
||||
}
|
||||
|
||||
if len(columnComments) > 0 {
|
||||
sqlArr = append(sqlArr, columnComments...)
|
||||
}
|
||||
|
||||
return sqlArr
|
||||
}
|
||||
|
||||
func (sg *SQLGenerator) GenIndexDDL(table dbi.Table, indexs []dbi.Index) []string {
|
||||
quote := sg.Dialect.Quoter().Quote
|
||||
sqls := make([]string, 0)
|
||||
for _, index := range indexs {
|
||||
unique := ""
|
||||
if index.IsUnique {
|
||||
unique = "unique"
|
||||
}
|
||||
|
||||
// 取出列名,添加引号
|
||||
cols := strings.Split(index.ColumnName, ",")
|
||||
colNames := make([]string, len(cols))
|
||||
for i, name := range cols {
|
||||
colNames[i] = quote(name)
|
||||
}
|
||||
|
||||
sqls = append(sqls, fmt.Sprintf("create %s index %s on %s(%s)", unique, quote(index.IndexName), quote(table.TableName), strings.Join(colNames, ",")))
|
||||
}
|
||||
return sqls
|
||||
}
|
||||
|
||||
func (sg *SQLGenerator) GenInsert(tableName string, columns []dbi.Column, values [][]any, duplicateStrategy int) []string {
|
||||
quoter := sg.Dialect.Quoter()
|
||||
quote := quoter.Quote
|
||||
|
||||
if duplicateStrategy == dbi.DuplicateStrategyNone {
|
||||
identityInsert := fmt.Sprintf("set identity_insert %s on;", quote(tableName))
|
||||
|
||||
// 达梦数据库只能一条条的执行insert语句,所以这里需要将values拆分成多条insert语句
|
||||
return collx.ArrayMap(values, func(value []any) string {
|
||||
columnStr, valuesStrs := dbi.GenInsertSqlColumnAndValues(sg.Dialect, DbTypeDM, columns, [][]any{value})
|
||||
return fmt.Sprintf("%s insert into %s (%s) values %s", identityInsert, quote(tableName), columnStr, strings.Join(valuesStrs, ",\n"))
|
||||
})
|
||||
}
|
||||
|
||||
// 查询主键字段
|
||||
uniqueCols := make([]string, 0)
|
||||
caseSqls := make([]string, 0)
|
||||
metadata := sg.Metadata
|
||||
tableCols, _ := metadata.GetColumns(tableName)
|
||||
identityCols := make([]string, 0)
|
||||
for _, col := range tableCols {
|
||||
if col.IsPrimaryKey {
|
||||
uniqueCols = append(uniqueCols, col.ColumnName)
|
||||
caseSqls = append(caseSqls, fmt.Sprintf("( T1.%s = T2.%s )", quote(col.ColumnName), quote(col.ColumnName)))
|
||||
}
|
||||
if col.IsIdentity {
|
||||
// 自增字段不放入insert内,即使是设置了identity_insert on也不起作用
|
||||
identityCols = append(identityCols, quote(col.ColumnName))
|
||||
}
|
||||
}
|
||||
// 查询唯一索引涉及到的字段,并组装到match条件内
|
||||
indexs, _ := metadata.GetTableIndex(tableName)
|
||||
for _, index := range indexs {
|
||||
if index.IsUnique {
|
||||
cols := strings.Split(index.ColumnName, ",")
|
||||
tmp := make([]string, 0)
|
||||
for _, col := range cols {
|
||||
uniqueCols = append(uniqueCols, col)
|
||||
tmp = append(tmp, fmt.Sprintf(" T1.%s = T2.%s ", quote(col), quote(col)))
|
||||
}
|
||||
caseSqls = append(caseSqls, fmt.Sprintf("( %s )", strings.Join(tmp, " AND ")))
|
||||
}
|
||||
}
|
||||
|
||||
// 重复数据处理策略
|
||||
phs := make([]string, 0)
|
||||
insertVals := make([]string, 0)
|
||||
upds := make([]string, 0)
|
||||
insertCols := make([]string, 0)
|
||||
for _, column := range columns {
|
||||
columnName := column.ColumnName
|
||||
phs = append(phs, fmt.Sprintf("? %s", columnName))
|
||||
if !collx.ArrayContains(uniqueCols, quoter.Trim(columnName)) {
|
||||
upds = append(upds, fmt.Sprintf("T1.%s = T2.%s", columnName, columnName))
|
||||
}
|
||||
if !collx.ArrayContains(identityCols, columnName) {
|
||||
insertCols = append(insertCols, columnName)
|
||||
insertVals = append(insertVals, fmt.Sprintf("T2.%s", columnName))
|
||||
}
|
||||
|
||||
}
|
||||
t2s := make([]string, 0)
|
||||
for i := 0; i < len(values); i++ {
|
||||
t2s = append(t2s, fmt.Sprintf("SELECT %s FROM dual", strings.Join(phs, ",")))
|
||||
}
|
||||
t2 := strings.Join(t2s, " UNION ALL ")
|
||||
|
||||
sqlTemp := "MERGE INTO " + quote(tableName) + " T1 USING (" + t2 + ") T2 ON " + strings.Join(caseSqls, " OR ")
|
||||
sqlTemp += "WHEN NOT MATCHED THEN INSERT (" + strings.Join(insertCols, ",") + ") VALUES (" + strings.Join(insertVals, ",") + ")"
|
||||
sqlTemp += "WHEN MATCHED THEN UPDATE SET " + strings.Join(upds, ",")
|
||||
|
||||
return collx.AsArray(sqlTemp)
|
||||
}
|
||||
|
||||
func (msg *SQLGenerator) genColumnBasicSql(quoter dbi.Quoter, column dbi.Column) string {
|
||||
colName := quoter.Quote(column.ColumnName)
|
||||
dataType := string(column.DataType)
|
||||
|
||||
incr := ""
|
||||
if column.IsIdentity {
|
||||
incr = " IDENTITY"
|
||||
}
|
||||
|
||||
nullAble := ""
|
||||
if !column.Nullable {
|
||||
nullAble = " NOT NULL"
|
||||
}
|
||||
|
||||
defVal := "" // 默认值需要判断引号,如函数是不需要引号的 // 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值
|
||||
if column.ColumnDefault != "" && !strings.Contains(column.ColumnDefault, "(") {
|
||||
// 哪些字段类型默认值需要加引号
|
||||
mark := false
|
||||
if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, strings.ToLower(dataType)) {
|
||||
// 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号
|
||||
if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) &&
|
||||
collx.ArrayAnyMatches([]string{"DATE", "TIME"}, strings.ToUpper(column.ColumnDefault)) {
|
||||
mark = false
|
||||
} else {
|
||||
mark = true
|
||||
}
|
||||
}
|
||||
if mark {
|
||||
defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault)
|
||||
} else {
|
||||
defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault)
|
||||
}
|
||||
}
|
||||
|
||||
columnSql := fmt.Sprintf(" %s %s%s%s%s", colName, column.GetColumnType(), incr, nullAble, defVal)
|
||||
return columnSql
|
||||
}
|
||||
97
server/internal/db/dbm/dm/transfer.go
Normal file
97
server/internal/db/dbm/dm/transfer.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package dm
|
||||
|
||||
import "mayfly-go/internal/db/dbm/dbi"
|
||||
|
||||
var _ dbi.CommonTypeConverter = (*commonTypeConverter)(nil)
|
||||
|
||||
type commonTypeConverter struct {
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Varchar(col *dbi.Column) *dbi.DbDataType {
|
||||
return VARCHAR
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Char(col *dbi.Column) *dbi.DbDataType {
|
||||
return CHAR
|
||||
}
|
||||
func (c *commonTypeConverter) Text(col *dbi.Column) *dbi.DbDataType {
|
||||
return TEXT
|
||||
}
|
||||
func (c *commonTypeConverter) Mediumtext(col *dbi.Column) *dbi.DbDataType {
|
||||
return TEXT
|
||||
}
|
||||
func (c *commonTypeConverter) Longtext(col *dbi.Column) *dbi.DbDataType {
|
||||
return LONGVARCHAR
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Bit(col *dbi.Column) *dbi.DbDataType {
|
||||
return BIT
|
||||
}
|
||||
func (c *commonTypeConverter) Int1(col *dbi.Column) *dbi.DbDataType {
|
||||
return TINYINT
|
||||
}
|
||||
func (c *commonTypeConverter) Int2(col *dbi.Column) *dbi.DbDataType {
|
||||
return SMALLINT
|
||||
}
|
||||
func (c *commonTypeConverter) Int4(col *dbi.Column) *dbi.DbDataType {
|
||||
return INTEGER
|
||||
}
|
||||
func (c *commonTypeConverter) Int8(col *dbi.Column) *dbi.DbDataType {
|
||||
return BIGINT
|
||||
}
|
||||
func (c *commonTypeConverter) Numeric(col *dbi.Column) *dbi.DbDataType {
|
||||
return NUMBER
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Decimal(col *dbi.Column) *dbi.DbDataType {
|
||||
return DECIMAL
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) UnsignedInt8(col *dbi.Column) *dbi.DbDataType {
|
||||
return BIGINT
|
||||
}
|
||||
func (c *commonTypeConverter) UnsignedInt4(col *dbi.Column) *dbi.DbDataType {
|
||||
return INT
|
||||
}
|
||||
func (c *commonTypeConverter) UnsignedInt2(col *dbi.Column) *dbi.DbDataType {
|
||||
return INT
|
||||
}
|
||||
func (c *commonTypeConverter) UnsignedInt1(col *dbi.Column) *dbi.DbDataType {
|
||||
return INT
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Date(col *dbi.Column) *dbi.DbDataType {
|
||||
return DATE
|
||||
}
|
||||
func (c *commonTypeConverter) Time(col *dbi.Column) *dbi.DbDataType {
|
||||
return TIME
|
||||
}
|
||||
func (c *commonTypeConverter) Datetime(col *dbi.Column) *dbi.DbDataType {
|
||||
return TIMESTAMP
|
||||
}
|
||||
func (c *commonTypeConverter) Timestamp(col *dbi.Column) *dbi.DbDataType {
|
||||
return TIMESTAMP
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Binary(col *dbi.Column) *dbi.DbDataType {
|
||||
return BLOB
|
||||
}
|
||||
func (c *commonTypeConverter) Varbinary(col *dbi.Column) *dbi.DbDataType {
|
||||
return BLOB
|
||||
}
|
||||
func (c *commonTypeConverter) Mediumblob(col *dbi.Column) *dbi.DbDataType {
|
||||
return BLOB
|
||||
}
|
||||
func (c *commonTypeConverter) Blob(col *dbi.Column) *dbi.DbDataType {
|
||||
return BLOB
|
||||
}
|
||||
func (c *commonTypeConverter) Longblob(col *dbi.Column) *dbi.DbDataType {
|
||||
return BLOB
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Enum(col *dbi.Column) *dbi.DbDataType {
|
||||
return VARCHAR
|
||||
}
|
||||
func (c *commonTypeConverter) JSON(col *dbi.Column) *dbi.DbDataType {
|
||||
return VARCHAR
|
||||
}
|
||||
50
server/internal/db/dbm/mssql/column.go
Normal file
50
server/internal/db/dbm/mssql/column.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package mssql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
)
|
||||
|
||||
var (
|
||||
DTOracleDate = dbi.DTDateTime.Copy().WithSQLValue(func(val any) string {
|
||||
// oracle date型需要用函数包裹:to_date('%s', 'yyyy-mm-dd hh24:mi:ss')
|
||||
return fmt.Sprintf("to_date('%s', 'yyyy-mm-dd hh24:mi:ss')", val)
|
||||
})
|
||||
)
|
||||
|
||||
var (
|
||||
Bigint = dbi.NewDbDataType("bigint", dbi.DTInt64).WithCT(dbi.CTInt8)
|
||||
Numeric = dbi.NewDbDataType("numeric", dbi.DTNumeric).WithCT(dbi.CTNumeric)
|
||||
Bit = dbi.NewDbDataType("bit", dbi.DTBit).WithCT(dbi.CTBit)
|
||||
Smallint = dbi.NewDbDataType("smallint", dbi.DTInt16).WithCT(dbi.CTInt2)
|
||||
Decimal = dbi.NewDbDataType("decimal", dbi.DTDecimal).WithCT(dbi.CTDecimal)
|
||||
Smallmoney = dbi.NewDbDataType("smallmoney", dbi.DTDecimal).WithCT(dbi.CTDecimal)
|
||||
Int = dbi.NewDbDataType("int", dbi.DTInt32).WithCT(dbi.CTInt4)
|
||||
Tinyint = dbi.NewDbDataType("tinyint", dbi.DTInt8).WithCT(dbi.CTInt1)
|
||||
Money = dbi.NewDbDataType("money", dbi.DTDecimal).WithCT(dbi.CTDecimal)
|
||||
Float = dbi.NewDbDataType("float", dbi.DTNumeric).WithCT(dbi.CTNumeric)
|
||||
Real = dbi.NewDbDataType("real", dbi.DTString).WithCT(dbi.CTVarchar)
|
||||
Date = dbi.NewDbDataType("date", dbi.DTDate).WithCT(dbi.CTDate)
|
||||
Datetimeoffset = dbi.NewDbDataType("datetimeoffset", dbi.DTDateTime).WithCT(dbi.CTDateTime)
|
||||
Datetime2 = dbi.NewDbDataType("datetime2", dbi.DTDateTime).WithCT(dbi.CTDateTime)
|
||||
Smalldatetime = dbi.NewDbDataType("smalldatetime", dbi.DTDateTime).WithCT(dbi.CTDateTime)
|
||||
Datetime = dbi.NewDbDataType("datetime", dbi.DTDateTime).WithCT(dbi.CTDateTime)
|
||||
Time = dbi.NewDbDataType("time", dbi.DTTime).WithCT(dbi.CTTime)
|
||||
Char = dbi.NewDbDataType("char", dbi.DTString).WithCT(dbi.CTVarchar)
|
||||
Varchar = dbi.NewDbDataType("varchar", dbi.DTString).WithCT(dbi.CTVarchar)
|
||||
Text = dbi.NewDbDataType("text", dbi.DTString).WithCT(dbi.CTVarchar)
|
||||
Nchar = dbi.NewDbDataType("nchar", dbi.DTString).WithCT(dbi.CTVarchar)
|
||||
Nvarchar = dbi.NewDbDataType("nvarchar", dbi.DTString).WithCT(dbi.CTVarchar)
|
||||
Ntext = dbi.NewDbDataType("ntext", dbi.DTString).WithCT(dbi.CTVarchar)
|
||||
Binary = dbi.NewDbDataType("binary", dbi.DTBytes).WithCT(dbi.CTBinary)
|
||||
Varbinary = dbi.NewDbDataType("varbinary", dbi.DTBytes).WithCT(dbi.CTBinary)
|
||||
Cursor = dbi.NewDbDataType("cursor", dbi.DTString).WithCT(dbi.CTVarchar)
|
||||
Rowversion = dbi.NewDbDataType("rowversion", dbi.DTBytes).WithCT(dbi.CTBinary)
|
||||
Hierarchyid = dbi.NewDbDataType("hierarchyid", dbi.DTString).WithCT(dbi.CTVarchar)
|
||||
Uniqueidentifier = dbi.NewDbDataType("uniqueidentifier", dbi.DTString).WithCT(dbi.CTVarchar)
|
||||
Sql_variant = dbi.NewDbDataType("sql_variant", dbi.DTString).WithCT(dbi.CTVarchar)
|
||||
Xml = dbi.NewDbDataType("xml", dbi.DTString).WithCT(dbi.CTVarchar)
|
||||
Table = dbi.NewDbDataType("table", dbi.DTString).WithCT(dbi.CTVarchar)
|
||||
Geometry = dbi.NewDbDataType("geometry", dbi.DTString).WithCT(dbi.CTVarchar)
|
||||
Geography = dbi.NewDbDataType("geography", dbi.DTString).WithCT(dbi.CTVarchar)
|
||||
)
|
||||
@@ -10,6 +10,14 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
mssqlQuoter = dbi.Quoter{
|
||||
Prefix: '[',
|
||||
Suffix: ']',
|
||||
IsReserved: dbi.AlwaysReserve,
|
||||
}
|
||||
)
|
||||
|
||||
type MssqlDialect struct {
|
||||
dbi.DefaultDialect
|
||||
|
||||
@@ -99,7 +107,8 @@ func (md *MssqlDialect) batchInsertSimple(tx *sql.Tx, tableName string, columns
|
||||
// 去除最后一个逗号
|
||||
placeholder = strings.TrimSuffix(repeated, ",")
|
||||
|
||||
baseTable := fmt.Sprintf("%s.%s", md.QuoteIdentifier(schema), md.QuoteIdentifier(tableName))
|
||||
quote := md.Quoter().Quote
|
||||
baseTable := fmt.Sprintf("%s.%s", quote(schema), quote(tableName))
|
||||
|
||||
sqlStr := fmt.Sprintf("insert into %s (%s) values %s", baseTable, strings.Join(columns, ","), placeholder)
|
||||
// 执行批量insert sql
|
||||
@@ -119,6 +128,7 @@ func (md *MssqlDialect) batchInsertSimple(tx *sql.Tx, tableName string, columns
|
||||
func (md *MssqlDialect) batchInsertMerge(tx *sql.Tx, tableName string, columns []string, values [][]any, duplicateStrategy int) (int64, error) {
|
||||
msMetadata := md.dc.GetMetadata()
|
||||
schema := md.dc.Info.CurrentSchema()
|
||||
quote := md.Quoter().Quote
|
||||
|
||||
// 收集MERGE 语句的 ON 子句条件
|
||||
caseSqls := make([]string, 0)
|
||||
@@ -136,7 +146,7 @@ func (md *MssqlDialect) batchInsertMerge(tx *sql.Tx, tableName string, columns [
|
||||
}
|
||||
if col.IsPrimaryKey {
|
||||
pkCols = append(pkCols, col.ColumnName)
|
||||
name := md.QuoteIdentifier(col.ColumnName)
|
||||
name := quote(col.ColumnName)
|
||||
caseSqls = append(caseSqls, fmt.Sprintf(" T1.%s = T2.%s ", name, name))
|
||||
}
|
||||
}
|
||||
@@ -150,7 +160,7 @@ func (md *MssqlDialect) batchInsertMerge(tx *sql.Tx, tableName string, columns [
|
||||
// 源数据占位sql
|
||||
phs := make([]string, 0)
|
||||
for _, column := range columns {
|
||||
if !collx.ArrayContains(identityCols, md.RemoveQuote(column)) {
|
||||
if !collx.ArrayContains(identityCols, md.Quoter().Trim(column)) {
|
||||
upds = append(upds, fmt.Sprintf("T1.%s = T2.%s", column, column))
|
||||
}
|
||||
insertCols = append(insertCols, column)
|
||||
@@ -168,7 +178,7 @@ func (md *MssqlDialect) batchInsertMerge(tx *sql.Tx, tableName string, columns [
|
||||
}
|
||||
t2 := strings.Join(t2s, " UNION ALL ")
|
||||
|
||||
sqlTemp := "MERGE INTO " + md.QuoteIdentifier(schema) + "." + md.QuoteIdentifier(tableName) + " T1 USING (" + t2 + ") T2 ON " + strings.Join(caseSqls, " AND ")
|
||||
sqlTemp := "MERGE INTO " + quote(schema) + "." + quote(tableName) + " T1 USING (" + t2 + ") T2 ON " + strings.Join(caseSqls, " AND ")
|
||||
sqlTemp += "WHEN NOT MATCHED THEN INSERT (" + strings.Join(insertCols, ",") + ") VALUES (" + strings.Join(insertVals, ",") + ") "
|
||||
sqlTemp += "WHEN MATCHED THEN UPDATE SET " + strings.Join(upds, ",")
|
||||
|
||||
@@ -247,7 +257,7 @@ func (md *MssqlDialect) CopyTableDDL(tableName string, newTableName string) (str
|
||||
// 查询表名和表注释, 设置表注释
|
||||
tbs, err := metadata.GetTables(tableName)
|
||||
if err != nil || len(tbs) < 1 {
|
||||
logx.Errorf("获取表信息失败, %s", tableName)
|
||||
logx.Errorf("failed to get table, %s", tableName)
|
||||
return "", err
|
||||
}
|
||||
tabInfo := &dbi.Table{
|
||||
@@ -258,165 +268,31 @@ func (md *MssqlDialect) CopyTableDDL(tableName string, newTableName string) (str
|
||||
// 查询列信息
|
||||
columns, err := metadata.GetColumns(tableName)
|
||||
if err != nil {
|
||||
logx.Errorf("获取列信息失败, %s", tableName)
|
||||
logx.Errorf("failed to get columns, %s", tableName)
|
||||
return "", err
|
||||
}
|
||||
sqlArr := md.GenerateTableDDL(columns, *tabInfo, true)
|
||||
|
||||
sqlGener := md.GetSQLGenerator()
|
||||
sqlArr := sqlGener.GenTableDDL(*tabInfo, columns, true)
|
||||
|
||||
// 设置索引
|
||||
indexs, err := metadata.GetTableIndex(tableName)
|
||||
if err != nil {
|
||||
logx.Errorf("获取索引信息失败, %s", tableName)
|
||||
logx.Errorf("failed to get indexs, %s", tableName)
|
||||
return strings.Join(sqlArr, ";"), err
|
||||
}
|
||||
sqlArr = append(sqlArr, md.GenerateIndexDDL(indexs, *tabInfo)...)
|
||||
sqlArr = append(sqlArr, sqlGener.GenIndexDDL(*tabInfo, indexs)...)
|
||||
return strings.Join(sqlArr, ";"), nil
|
||||
}
|
||||
|
||||
// 获取建表ddl
|
||||
func (md *MssqlDialect) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table, dropBeforeCreate bool) []string {
|
||||
tbName := tableInfo.TableName
|
||||
schemaName := md.dc.Info.CurrentSchema()
|
||||
|
||||
sqlArr := make([]string, 0)
|
||||
|
||||
// 删除表
|
||||
if dropBeforeCreate {
|
||||
sqlArr = append(sqlArr, fmt.Sprintf("DROP TABLE IF EXISTS %s.%s", md.QuoteIdentifier(schemaName), md.QuoteIdentifier(tbName)))
|
||||
}
|
||||
|
||||
// 组装建表语句
|
||||
createSql := fmt.Sprintf("CREATE TABLE %s.%s (\n", md.QuoteIdentifier(schemaName), md.QuoteIdentifier(tbName))
|
||||
fields := make([]string, 0)
|
||||
pks := make([]string, 0)
|
||||
columnComments := make([]string, 0)
|
||||
|
||||
for _, column := range columns {
|
||||
if column.IsPrimaryKey {
|
||||
pks = append(pks, md.QuoteIdentifier(column.ColumnName))
|
||||
}
|
||||
fields = append(fields, md.genColumnBasicSql(column))
|
||||
commentTmp := "EXECUTE sp_addextendedproperty N'MS_Description', N'%s', N'SCHEMA', N'%s', N'TABLE', N'%s', N'COLUMN', N'%s'"
|
||||
|
||||
// 防止注释内含有特殊字符串导致sql出错
|
||||
if column.ColumnComment != "" {
|
||||
comment := md.QuoteEscape(column.ColumnComment)
|
||||
columnComments = append(columnComments, fmt.Sprintf(commentTmp, comment, md.dc.Info.CurrentSchema(), tbName, column.ColumnName))
|
||||
}
|
||||
}
|
||||
|
||||
// create
|
||||
createSql += strings.Join(fields, ",\n")
|
||||
if len(pks) > 0 {
|
||||
createSql += fmt.Sprintf(", \n PRIMARY KEY CLUSTERED (%s)", strings.Join(pks, ","))
|
||||
}
|
||||
createSql += "\n)"
|
||||
|
||||
// comment
|
||||
tableCommentSql := ""
|
||||
if tableInfo.TableComment != "" {
|
||||
commentTmp := "EXECUTE sp_addextendedproperty N'MS_Description', N'%s', N'SCHEMA', N'%s', N'TABLE', N'%s'"
|
||||
|
||||
tableCommentSql = fmt.Sprintf(commentTmp, md.QuoteEscape(tableInfo.TableComment), md.dc.Info.CurrentSchema(), tbName)
|
||||
}
|
||||
|
||||
sqlArr = append(sqlArr, createSql)
|
||||
|
||||
if tableCommentSql != "" {
|
||||
sqlArr = append(sqlArr, tableCommentSql)
|
||||
}
|
||||
if len(columnComments) > 0 {
|
||||
sqlArr = append(sqlArr, columnComments...)
|
||||
}
|
||||
|
||||
return sqlArr
|
||||
}
|
||||
|
||||
// 获取建索引ddl
|
||||
func (md *MssqlDialect) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Table) []string {
|
||||
tbName := tableInfo.TableName
|
||||
sqls := make([]string, 0)
|
||||
comments := make([]string, 0)
|
||||
for _, index := range indexs {
|
||||
unique := ""
|
||||
if index.IsUnique {
|
||||
unique = "unique"
|
||||
}
|
||||
// 取出列名,添加引号
|
||||
cols := strings.Split(index.ColumnName, ",")
|
||||
colNames := make([]string, len(cols))
|
||||
for i, name := range cols {
|
||||
colNames[i] = md.QuoteIdentifier(name)
|
||||
}
|
||||
|
||||
sqls = append(sqls, fmt.Sprintf("create %s NONCLUSTERED index %s on %s.%s(%s)", unique, md.QuoteIdentifier(index.IndexName), md.QuoteIdentifier(md.dc.Info.CurrentSchema()), md.QuoteIdentifier(tbName), strings.Join(colNames, ",")))
|
||||
if index.IndexComment != "" {
|
||||
comment := md.QuoteEscape(index.IndexComment)
|
||||
comments = append(comments, fmt.Sprintf("EXECUTE sp_addextendedproperty N'MS_Description', N'%s', N'SCHEMA', N'%s', N'TABLE', N'%s', N'INDEX', N'%s'", comment, md.dc.Info.CurrentSchema(), tbName, index.IndexName))
|
||||
}
|
||||
}
|
||||
if len(comments) > 0 {
|
||||
sqls = append(sqls, comments...)
|
||||
}
|
||||
|
||||
return sqls
|
||||
}
|
||||
|
||||
func (dx *MssqlDialect) QuoteIdentifier(name string) string {
|
||||
return fmt.Sprintf("[%s]", name)
|
||||
}
|
||||
|
||||
func (dx *MssqlDialect) RemoveQuote(name string) string {
|
||||
return strings.Trim(name, "[]")
|
||||
}
|
||||
|
||||
func (md *MssqlDialect) GetDataHelper() dbi.DataHelper {
|
||||
return dataHelper
|
||||
}
|
||||
|
||||
func (md *MssqlDialect) GetColumnHelper() dbi.ColumnHelper {
|
||||
return columnHelper
|
||||
func (dx *MssqlDialect) Quoter() dbi.Quoter {
|
||||
return mssqlQuoter
|
||||
}
|
||||
|
||||
func (md *MssqlDialect) GetDumpHelper() dbi.DumpHelper {
|
||||
return new(DumpHelper)
|
||||
}
|
||||
|
||||
func (md *MssqlDialect) genColumnBasicSql(column dbi.Column) string {
|
||||
colName := md.QuoteIdentifier(column.ColumnName)
|
||||
dataType := string(column.DataType)
|
||||
|
||||
incr := ""
|
||||
if column.IsIdentity {
|
||||
incr = " IDENTITY(1,1)"
|
||||
}
|
||||
|
||||
nullAble := ""
|
||||
if !column.Nullable {
|
||||
nullAble = " NOT NULL"
|
||||
}
|
||||
|
||||
defVal := "" // 默认值需要判断引号,如函数是不需要引号的 // 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值
|
||||
if column.ColumnDefault != "" && !strings.Contains(column.ColumnDefault, "(") {
|
||||
// 哪些字段类型默认值需要加引号
|
||||
mark := false
|
||||
if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, dataType) {
|
||||
// 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号
|
||||
if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) &&
|
||||
collx.ArrayAnyMatches([]string{"DATE", "TIME"}, strings.ToUpper(column.ColumnDefault)) {
|
||||
mark = false
|
||||
} else {
|
||||
mark = true
|
||||
}
|
||||
}
|
||||
|
||||
if mark {
|
||||
defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault)
|
||||
} else {
|
||||
defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault)
|
||||
}
|
||||
}
|
||||
|
||||
columnSql := fmt.Sprintf(" %s %s%s%s%s", colName, column.GetColumnType(), incr, nullAble, defVal)
|
||||
return columnSql
|
||||
func (md *MssqlDialect) GetSQLGenerator() dbi.SQLGenerator {
|
||||
return &SQLGenerator{dc: md.dc}
|
||||
}
|
||||
|
||||
@@ -4,218 +4,15 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/pkg/utils/anyx"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
// 数字类型
|
||||
numberRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`)
|
||||
// 日期时间类型
|
||||
datetimeRegexp = regexp.MustCompile(`(?i)datetime|timestamp`)
|
||||
// 日期类型
|
||||
dateRegexp = regexp.MustCompile(`(?i)date`)
|
||||
// 时间类型
|
||||
timeRegexp = regexp.MustCompile(`(?i)time`)
|
||||
|
||||
// mssql数据类型 对应 公共数据类型
|
||||
commonColumnTypeMap = map[string]dbi.ColumnDataType{
|
||||
"bigint": dbi.CommonTypeBigint,
|
||||
"numeric": dbi.CommonTypeNumber,
|
||||
"bit": dbi.CommonTypeInt,
|
||||
"smallint": dbi.CommonTypeSmallint,
|
||||
"decimal": dbi.CommonTypeNumber,
|
||||
"smallmoney": dbi.CommonTypeNumber,
|
||||
"int": dbi.CommonTypeInt,
|
||||
"tinyint": dbi.CommonTypeSmallint, // mssql tinyint不支持负数
|
||||
"money": dbi.CommonTypeNumber,
|
||||
"float": dbi.CommonTypeNumber, // 近似数字
|
||||
"real": dbi.CommonTypeVarchar,
|
||||
"date": dbi.CommonTypeDate, // 日期和时间
|
||||
"datetimeoffset": dbi.CommonTypeDatetime,
|
||||
"datetime2": dbi.CommonTypeDatetime,
|
||||
"smalldatetime": dbi.CommonTypeDatetime,
|
||||
"datetime": dbi.CommonTypeDatetime,
|
||||
"time": dbi.CommonTypeTime,
|
||||
"char": dbi.CommonTypeChar, // 字符串
|
||||
"varchar": dbi.CommonTypeVarchar,
|
||||
"text": dbi.CommonTypeText,
|
||||
"nchar": dbi.CommonTypeChar,
|
||||
"nvarchar": dbi.CommonTypeVarchar,
|
||||
"ntext": dbi.CommonTypeText,
|
||||
"binary": dbi.CommonTypeBinary,
|
||||
"varbinary": dbi.CommonTypeBinary,
|
||||
"cursor": dbi.CommonTypeVarchar, // 其他
|
||||
"rowversion": dbi.CommonTypeVarchar,
|
||||
"hierarchyid": dbi.CommonTypeVarchar,
|
||||
"uniqueidentifier": dbi.CommonTypeVarchar,
|
||||
"sql_variant": dbi.CommonTypeVarchar,
|
||||
"xml": dbi.CommonTypeText,
|
||||
"table": dbi.CommonTypeText,
|
||||
"geometry": dbi.CommonTypeText, // 空间几何类型
|
||||
"geography": dbi.CommonTypeText, // 空间地理类型
|
||||
}
|
||||
|
||||
// 公共数据类型 对应 mssql数据类型
|
||||
|
||||
mssqlColumnTypeMap = map[dbi.ColumnDataType]string{
|
||||
dbi.CommonTypeVarchar: "nvarchar",
|
||||
dbi.CommonTypeChar: "nchar",
|
||||
dbi.CommonTypeText: "ntext",
|
||||
dbi.CommonTypeBlob: "ntext",
|
||||
dbi.CommonTypeLongblob: "ntext",
|
||||
dbi.CommonTypeLongtext: "ntext",
|
||||
dbi.CommonTypeBinary: "varbinary",
|
||||
dbi.CommonTypeMediumblob: "ntext",
|
||||
dbi.CommonTypeMediumtext: "ntext",
|
||||
dbi.CommonTypeVarbinary: "varbinary",
|
||||
dbi.CommonTypeInt: "int",
|
||||
dbi.CommonTypeSmallint: "smallint",
|
||||
dbi.CommonTypeTinyint: "smallint",
|
||||
dbi.CommonTypeNumber: "decimal",
|
||||
dbi.CommonTypeBigint: "bigint",
|
||||
dbi.CommonTypeDatetime: "datetime2",
|
||||
dbi.CommonTypeDate: "date",
|
||||
dbi.CommonTypeTime: "time",
|
||||
dbi.CommonTypeTimestamp: "timestamp",
|
||||
dbi.CommonTypeEnum: "nvarchar",
|
||||
dbi.CommonTypeJSON: "nvarchar",
|
||||
}
|
||||
dataHelper = &DataHelper{}
|
||||
columnHelper = &ColumnHelper{}
|
||||
)
|
||||
|
||||
func GetDataHelper() *DataHelper {
|
||||
return dataHelper
|
||||
}
|
||||
|
||||
type DataHelper struct {
|
||||
}
|
||||
|
||||
func (dc *DataHelper) GetDataType(dbColumnType string) dbi.DataType {
|
||||
if numberRegexp.MatchString(dbColumnType) {
|
||||
return dbi.DataTypeNumber
|
||||
}
|
||||
// 日期时间类型
|
||||
if datetimeRegexp.MatchString(dbColumnType) {
|
||||
return dbi.DataTypeDateTime
|
||||
}
|
||||
// 日期类型
|
||||
if dateRegexp.MatchString(dbColumnType) {
|
||||
return dbi.DataTypeDate
|
||||
}
|
||||
// 时间类型
|
||||
if timeRegexp.MatchString(dbColumnType) {
|
||||
return dbi.DataTypeTime
|
||||
}
|
||||
return dbi.DataTypeString
|
||||
}
|
||||
|
||||
func (dc *DataHelper) FormatData(dbColumnValue any, dataType dbi.DataType) string {
|
||||
// 如果dataType是datetime而dbColumnValue是string类型,则需要根据类型格式化
|
||||
str, ok := dbColumnValue.(string)
|
||||
if dataType == dbi.DataTypeDateTime && ok {
|
||||
// 尝试用时间格式解析
|
||||
res, err := time.Parse(time.DateTime, str)
|
||||
if err == nil {
|
||||
return str
|
||||
}
|
||||
res, _ = time.Parse(time.RFC3339, str)
|
||||
return res.Format(time.DateTime)
|
||||
}
|
||||
if dataType == dbi.DataTypeDate && ok {
|
||||
// 尝试用时间格式解析
|
||||
res, _ := time.Parse(time.DateOnly, str)
|
||||
return res.Format(time.DateOnly)
|
||||
}
|
||||
if dataType == dbi.DataTypeTime && ok {
|
||||
res, _ := time.Parse(time.TimeOnly, str)
|
||||
return res.Format(time.TimeOnly)
|
||||
}
|
||||
return anyx.ToString(dbColumnValue)
|
||||
}
|
||||
|
||||
func (dc *DataHelper) ParseData(dbColumnValue any, dataType dbi.DataType) any {
|
||||
// 如果dataType是datetime而dbColumnValue是string类型,则需要转换为time.Time类型
|
||||
_, ok := dbColumnValue.(string)
|
||||
if dataType == dbi.DataTypeDateTime && ok {
|
||||
res, _ := time.Parse(time.RFC3339, anyx.ToString(dbColumnValue))
|
||||
return res
|
||||
}
|
||||
if dataType == dbi.DataTypeDate && ok {
|
||||
res, _ := time.Parse(time.DateOnly, anyx.ToString(dbColumnValue))
|
||||
return res
|
||||
}
|
||||
if dataType == dbi.DataTypeTime && ok {
|
||||
res, _ := time.Parse(time.TimeOnly, anyx.ToString(dbColumnValue))
|
||||
return res
|
||||
}
|
||||
return dbColumnValue
|
||||
}
|
||||
|
||||
func (dc *DataHelper) WrapValue(dbColumnValue any, dataType dbi.DataType) string {
|
||||
if dbColumnValue == nil {
|
||||
return "NULL"
|
||||
}
|
||||
switch dataType {
|
||||
case dbi.DataTypeNumber:
|
||||
return fmt.Sprintf("%v", dbColumnValue)
|
||||
case dbi.DataTypeString:
|
||||
val := fmt.Sprintf("%v", dbColumnValue)
|
||||
// 转义单引号
|
||||
val = strings.Replace(val, `'`, `''`, -1)
|
||||
val = strings.Replace(val, `\''`, `\'`, -1)
|
||||
// 转义换行符
|
||||
val = strings.Replace(val, "\n", "\\n", -1)
|
||||
return fmt.Sprintf("'%s'", val)
|
||||
case dbi.DataTypeDate, dbi.DataTypeDateTime, dbi.DataTypeTime:
|
||||
return fmt.Sprintf("'%s'", dc.FormatData(dbColumnValue, dataType))
|
||||
}
|
||||
return fmt.Sprintf("'%s'", dbColumnValue)
|
||||
}
|
||||
|
||||
type ColumnHelper struct {
|
||||
dbi.DefaultColumnHelper
|
||||
}
|
||||
|
||||
func (ch *ColumnHelper) ToCommonColumn(dialectColumn *dbi.Column) {
|
||||
// 翻译为通用数据库类型
|
||||
dataType := dialectColumn.DataType
|
||||
t1 := commonColumnTypeMap[string(dataType)]
|
||||
if t1 == "" {
|
||||
dialectColumn.DataType = dbi.CommonTypeVarchar
|
||||
dialectColumn.CharMaxLength = 2000
|
||||
} else {
|
||||
dialectColumn.DataType = t1
|
||||
}
|
||||
}
|
||||
|
||||
func (ch *ColumnHelper) ToColumn(commonColumn *dbi.Column) {
|
||||
ctype := mssqlColumnTypeMap[commonColumn.DataType]
|
||||
|
||||
if ctype == "" {
|
||||
commonColumn.DataType = "varchar"
|
||||
commonColumn.CharMaxLength = 2000
|
||||
} else {
|
||||
commonColumn.DataType = dbi.ColumnDataType(ctype)
|
||||
ch.FixColumn(commonColumn)
|
||||
// 修复数据库迁移字段长度
|
||||
dataType := string(commonColumn.DataType)
|
||||
if collx.ArrayAnyMatches([]string{"nvarchar", "nchar"}, dataType) {
|
||||
commonColumn.CharMaxLength = commonColumn.CharMaxLength * 2
|
||||
}
|
||||
|
||||
if collx.ArrayAnyMatches([]string{"char"}, dataType) {
|
||||
// char最大长度4000
|
||||
if commonColumn.CharMaxLength >= 4000 {
|
||||
commonColumn.DataType = "ntext"
|
||||
commonColumn.CharMaxLength = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ch *ColumnHelper) FixColumn(column *dbi.Column) {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
@@ -12,9 +13,13 @@ import (
|
||||
|
||||
func init() {
|
||||
meta := new(Meta)
|
||||
dbi.Register(dbi.DbTypeMssql, meta)
|
||||
dbi.Register(DbTypeMssql, meta)
|
||||
}
|
||||
|
||||
const (
|
||||
DbTypeMssql dbi.DbType = "mssql"
|
||||
)
|
||||
|
||||
type Meta struct {
|
||||
}
|
||||
|
||||
@@ -60,3 +65,45 @@ func (mm *Meta) GetDialect(conn *dbi.DbConn) dbi.Dialect {
|
||||
func (mm *Meta) GetMetadata(conn *dbi.DbConn) dbi.Metadata {
|
||||
return &MssqlMetadata{dc: conn}
|
||||
}
|
||||
|
||||
func (sm *Meta) GetDbDataTypes() []*dbi.DbDataType {
|
||||
return collx.AsArray[*dbi.DbDataType](Bigint,
|
||||
Numeric,
|
||||
Bit,
|
||||
Smallint,
|
||||
Decimal,
|
||||
Smallmoney,
|
||||
Int,
|
||||
Tinyint,
|
||||
Money,
|
||||
Float,
|
||||
Real,
|
||||
Date,
|
||||
Datetimeoffset,
|
||||
Datetime2,
|
||||
Smalldatetime,
|
||||
Datetime,
|
||||
Time,
|
||||
Char,
|
||||
Varchar,
|
||||
Text,
|
||||
Nchar,
|
||||
Nvarchar,
|
||||
Ntext,
|
||||
Binary,
|
||||
Varbinary,
|
||||
Cursor,
|
||||
Rowversion,
|
||||
Hierarchyid,
|
||||
Uniqueidentifier,
|
||||
Sql_variant,
|
||||
Xml,
|
||||
Table,
|
||||
Geometry,
|
||||
Geography,
|
||||
)
|
||||
}
|
||||
|
||||
func (sm *Meta) GetCommonTypeConverter() dbi.CommonTypeConverter {
|
||||
return &commonTypeConverter{}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/pkg/errorx"
|
||||
"mayfly-go/pkg/logx"
|
||||
"mayfly-go/pkg/utils/anyx"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"mayfly-go/pkg/utils/stringx"
|
||||
@@ -58,7 +57,7 @@ func (md *MssqlMetadata) GetTables(tableNames ...string) ([]dbi.Table, error) {
|
||||
dialect := md.dc.GetDialect()
|
||||
schema := md.dc.Info.CurrentSchema()
|
||||
names := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string {
|
||||
return fmt.Sprintf("'%s'", dialect.RemoveQuote(val))
|
||||
return fmt.Sprintf("'%s'", dialect.Quoter().Trim(val))
|
||||
}), ",")
|
||||
|
||||
var res []map[string]any
|
||||
@@ -91,9 +90,8 @@ func (md *MssqlMetadata) GetTables(tableNames ...string) ([]dbi.Table, error) {
|
||||
// 获取列元信息, 如列名等
|
||||
func (md *MssqlMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error) {
|
||||
dialect := md.dc.GetDialect()
|
||||
columnHelper := dialect.GetColumnHelper()
|
||||
tableName := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string {
|
||||
return fmt.Sprintf("'%s'", dialect.RemoveQuote(val))
|
||||
return fmt.Sprintf("'%s'", dialect.Quoter().Trim(val))
|
||||
}), ",")
|
||||
|
||||
_, res, err := md.dc.Query(fmt.Sprintf(dbi.GetLocalSql(MSSQL_META_FILE, MSSQL_COLUMN_MA_KEY), tableName), md.dc.Info.CurrentSchema())
|
||||
@@ -107,7 +105,7 @@ func (md *MssqlMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error)
|
||||
column := dbi.Column{
|
||||
TableName: anyx.ToString(re["TABLE_NAME"]),
|
||||
ColumnName: anyx.ToString(re["COLUMN_NAME"]),
|
||||
DataType: dbi.ColumnDataType(anyx.ToString(re["DATA_TYPE"])),
|
||||
DataType: anyx.ToString(re["DATA_TYPE"]),
|
||||
CharMaxLength: cast.ToInt(re["CHAR_MAX_LENGTH"]),
|
||||
ColumnComment: anyx.ToString(re["COLUMN_COMMENT"]),
|
||||
Nullable: anyx.ToString(re["NULLABLE"]) == "YES",
|
||||
@@ -118,8 +116,7 @@ func (md *MssqlMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error)
|
||||
NumScale: cast.ToInt(re["NUM_SCALE"]),
|
||||
}
|
||||
|
||||
columnHelper.FixColumn(&column)
|
||||
|
||||
md.dc.GetDbDataType(column.DataType).FixColumn(&column)
|
||||
columns = append(columns, column)
|
||||
}
|
||||
return columns, nil
|
||||
@@ -197,33 +194,7 @@ func (md *MssqlMetadata) GetTableIndex(tableName string) ([]dbi.Index, error) {
|
||||
|
||||
// 获取建表ddl
|
||||
func (md *MssqlMetadata) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) {
|
||||
// 1.获取表信息
|
||||
tbs, err := md.GetTables(tableName)
|
||||
tableInfo := &dbi.Table{}
|
||||
if err != nil || tbs == nil || len(tbs) <= 0 {
|
||||
logx.Errorf("获取表信息失败, %s", tableName)
|
||||
return "", err
|
||||
}
|
||||
tableInfo.TableName = tbs[0].TableName
|
||||
tableInfo.TableComment = tbs[0].TableComment
|
||||
|
||||
// 2.获取列信息
|
||||
columns, err := md.GetColumns(tableName)
|
||||
if err != nil {
|
||||
logx.Errorf("获取列信息失败, %s", tableName)
|
||||
return "", err
|
||||
}
|
||||
dialect := md.dc.GetDialect()
|
||||
tableDDLArr := dialect.GenerateTableDDL(columns, *tableInfo, dropBeforeCreate)
|
||||
// 3.获取索引信息
|
||||
indexs, err := md.GetTableIndex(tableName)
|
||||
if err != nil {
|
||||
logx.Errorf("获取索引信息失败, %s", tableName)
|
||||
return "", err
|
||||
}
|
||||
// 组装返回
|
||||
tableDDLArr = append(tableDDLArr, dialect.GenerateIndexDDL(indexs, *tableInfo)...)
|
||||
return strings.Join(tableDDLArr, ";\n"), nil
|
||||
return dbi.GenTableDDL(md.dc.GetDialect(), md, tableName, dropBeforeCreate)
|
||||
}
|
||||
|
||||
func (md *MssqlMetadata) GetSchemas() ([]string, error) {
|
||||
|
||||
146
server/internal/db/dbm/mssql/sqlgen.go
Normal file
146
server/internal/db/dbm/mssql/sqlgen.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package mssql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type SQLGenerator struct {
|
||||
dc *dbi.DbConn
|
||||
}
|
||||
|
||||
func (sg *SQLGenerator) GenTableDDL(table dbi.Table, columns []dbi.Column, dropBeforeCreate bool) []string {
|
||||
tbName := table.TableName
|
||||
schemaName := sg.dc.Info.CurrentSchema()
|
||||
|
||||
quoter := sg.dc.GetDialect().Quoter()
|
||||
quote := quoter.Quote
|
||||
|
||||
sqlArr := make([]string, 0)
|
||||
|
||||
// 删除表
|
||||
if dropBeforeCreate {
|
||||
sqlArr = append(sqlArr, fmt.Sprintf("DROP TABLE IF EXISTS %s.%s", quote(schemaName), quote(tbName)))
|
||||
}
|
||||
|
||||
// 组装建表语句
|
||||
createSql := fmt.Sprintf("CREATE TABLE %s.%s (\n", quote(schemaName), quote(tbName))
|
||||
fields := make([]string, 0)
|
||||
pks := make([]string, 0)
|
||||
columnComments := make([]string, 0)
|
||||
|
||||
for _, column := range columns {
|
||||
if column.IsPrimaryKey {
|
||||
pks = append(pks, quote(column.ColumnName))
|
||||
}
|
||||
fields = append(fields, sg.genColumnBasicSql(quoter, column))
|
||||
commentTmp := "EXECUTE sp_addextendedproperty N'MS_Description', N'%s', N'SCHEMA', N'%s', N'TABLE', N'%s', N'COLUMN', N'%s'"
|
||||
|
||||
// 防止注释内含有特殊字符串导致sql出错
|
||||
if column.ColumnComment != "" {
|
||||
comment := dbi.QuoteEscape(column.ColumnComment)
|
||||
columnComments = append(columnComments, fmt.Sprintf(commentTmp, comment, sg.dc.Info.CurrentSchema(), tbName, column.ColumnName))
|
||||
}
|
||||
}
|
||||
|
||||
// create
|
||||
createSql += strings.Join(fields, ",\n")
|
||||
if len(pks) > 0 {
|
||||
createSql += fmt.Sprintf(", \n PRIMARY KEY CLUSTERED (%s)", strings.Join(pks, ","))
|
||||
}
|
||||
createSql += "\n)"
|
||||
|
||||
// comment
|
||||
tableCommentSql := ""
|
||||
if table.TableComment != "" {
|
||||
commentTmp := "EXECUTE sp_addextendedproperty N'MS_Description', N'%s', N'SCHEMA', N'%s', N'TABLE', N'%s'"
|
||||
|
||||
tableCommentSql = fmt.Sprintf(commentTmp, dbi.QuoteEscape(table.TableComment), sg.dc.Info.CurrentSchema(), tbName)
|
||||
}
|
||||
|
||||
sqlArr = append(sqlArr, createSql)
|
||||
|
||||
if tableCommentSql != "" {
|
||||
sqlArr = append(sqlArr, tableCommentSql)
|
||||
}
|
||||
if len(columnComments) > 0 {
|
||||
sqlArr = append(sqlArr, columnComments...)
|
||||
}
|
||||
|
||||
return sqlArr
|
||||
}
|
||||
|
||||
func (sg *SQLGenerator) GenIndexDDL(table dbi.Table, indexs []dbi.Index) []string {
|
||||
quote := sg.dc.GetDialect().Quoter().Quote
|
||||
tbName := table.TableName
|
||||
sqls := make([]string, 0)
|
||||
comments := make([]string, 0)
|
||||
for _, index := range indexs {
|
||||
unique := ""
|
||||
if index.IsUnique {
|
||||
unique = "unique"
|
||||
}
|
||||
// 取出列名,添加引号
|
||||
cols := strings.Split(index.ColumnName, ",")
|
||||
colNames := make([]string, len(cols))
|
||||
for i, name := range cols {
|
||||
colNames[i] = quote(name)
|
||||
}
|
||||
|
||||
sqls = append(sqls, fmt.Sprintf("create %s NONCLUSTERED index %s on %s.%s(%s)", unique, quote(index.IndexName), quote(sg.dc.Info.CurrentSchema()), quote(tbName), strings.Join(colNames, ",")))
|
||||
if index.IndexComment != "" {
|
||||
comment := dbi.QuoteEscape(index.IndexComment)
|
||||
comments = append(comments, fmt.Sprintf("EXECUTE sp_addextendedproperty N'MS_Description', N'%s', N'SCHEMA', N'%s', N'TABLE', N'%s', N'INDEX', N'%s'", comment, sg.dc.Info.CurrentSchema(), tbName, index.IndexName))
|
||||
}
|
||||
}
|
||||
if len(comments) > 0 {
|
||||
sqls = append(sqls, comments...)
|
||||
}
|
||||
|
||||
return sqls
|
||||
}
|
||||
|
||||
func (sg *SQLGenerator) GenInsert(tableName string, columns []dbi.Column, values [][]any, duplicateStrategy int) []string {
|
||||
return collx.AsArray("")
|
||||
}
|
||||
|
||||
func (msg *SQLGenerator) genColumnBasicSql(quoter dbi.Quoter, column dbi.Column) string {
|
||||
colName := quoter.Quote(column.ColumnName)
|
||||
dataType := string(column.DataType)
|
||||
|
||||
incr := ""
|
||||
if column.IsIdentity {
|
||||
incr = " IDENTITY(1,1)"
|
||||
}
|
||||
|
||||
nullAble := ""
|
||||
if !column.Nullable {
|
||||
nullAble = " NOT NULL"
|
||||
}
|
||||
|
||||
defVal := "" // 默认值需要判断引号,如函数是不需要引号的 // 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值
|
||||
if column.ColumnDefault != "" && !strings.Contains(column.ColumnDefault, "(") {
|
||||
// 哪些字段类型默认值需要加引号
|
||||
mark := false
|
||||
if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, dataType) {
|
||||
// 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号
|
||||
if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) &&
|
||||
collx.ArrayAnyMatches([]string{"DATE", "TIME"}, strings.ToUpper(column.ColumnDefault)) {
|
||||
mark = false
|
||||
} else {
|
||||
mark = true
|
||||
}
|
||||
}
|
||||
|
||||
if mark {
|
||||
defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault)
|
||||
} else {
|
||||
defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault)
|
||||
}
|
||||
}
|
||||
|
||||
columnSql := fmt.Sprintf(" %s %s%s%s%s", colName, column.GetColumnType(), incr, nullAble, defVal)
|
||||
return columnSql
|
||||
}
|
||||
97
server/internal/db/dbm/mssql/transfer.go
Normal file
97
server/internal/db/dbm/mssql/transfer.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package mssql
|
||||
|
||||
import "mayfly-go/internal/db/dbm/dbi"
|
||||
|
||||
var _ dbi.CommonTypeConverter = (*commonTypeConverter)(nil)
|
||||
|
||||
type commonTypeConverter struct {
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Varchar(col *dbi.Column) *dbi.DbDataType {
|
||||
return Varchar
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Char(col *dbi.Column) *dbi.DbDataType {
|
||||
return Char
|
||||
}
|
||||
func (c *commonTypeConverter) Text(col *dbi.Column) *dbi.DbDataType {
|
||||
return Text
|
||||
}
|
||||
func (c *commonTypeConverter) Mediumtext(col *dbi.Column) *dbi.DbDataType {
|
||||
return Text
|
||||
}
|
||||
func (c *commonTypeConverter) Longtext(col *dbi.Column) *dbi.DbDataType {
|
||||
return Text
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Bit(col *dbi.Column) *dbi.DbDataType {
|
||||
return Bit
|
||||
}
|
||||
func (c *commonTypeConverter) Int1(col *dbi.Column) *dbi.DbDataType {
|
||||
return Tinyint
|
||||
}
|
||||
func (c *commonTypeConverter) Int2(col *dbi.Column) *dbi.DbDataType {
|
||||
return Smallint
|
||||
}
|
||||
func (c *commonTypeConverter) Int4(col *dbi.Column) *dbi.DbDataType {
|
||||
return Int
|
||||
}
|
||||
func (c *commonTypeConverter) Int8(col *dbi.Column) *dbi.DbDataType {
|
||||
return Bigint
|
||||
}
|
||||
func (c *commonTypeConverter) Numeric(col *dbi.Column) *dbi.DbDataType {
|
||||
return Numeric
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Decimal(col *dbi.Column) *dbi.DbDataType {
|
||||
return Decimal
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) UnsignedInt8(col *dbi.Column) *dbi.DbDataType {
|
||||
return Bigint
|
||||
}
|
||||
func (c *commonTypeConverter) UnsignedInt4(col *dbi.Column) *dbi.DbDataType {
|
||||
return Int
|
||||
}
|
||||
func (c *commonTypeConverter) UnsignedInt2(col *dbi.Column) *dbi.DbDataType {
|
||||
return Smallint
|
||||
}
|
||||
func (c *commonTypeConverter) UnsignedInt1(col *dbi.Column) *dbi.DbDataType {
|
||||
return Tinyint
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Date(col *dbi.Column) *dbi.DbDataType {
|
||||
return Date
|
||||
}
|
||||
func (c *commonTypeConverter) Time(col *dbi.Column) *dbi.DbDataType {
|
||||
return Time
|
||||
}
|
||||
func (c *commonTypeConverter) Datetime(col *dbi.Column) *dbi.DbDataType {
|
||||
return Datetime
|
||||
}
|
||||
func (c *commonTypeConverter) Timestamp(col *dbi.Column) *dbi.DbDataType {
|
||||
return Datetime
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Binary(col *dbi.Column) *dbi.DbDataType {
|
||||
return Binary
|
||||
}
|
||||
func (c *commonTypeConverter) Varbinary(col *dbi.Column) *dbi.DbDataType {
|
||||
return Varbinary
|
||||
}
|
||||
func (c *commonTypeConverter) Mediumblob(col *dbi.Column) *dbi.DbDataType {
|
||||
return Binary
|
||||
}
|
||||
func (c *commonTypeConverter) Blob(col *dbi.Column) *dbi.DbDataType {
|
||||
return Binary
|
||||
}
|
||||
func (c *commonTypeConverter) Longblob(col *dbi.Column) *dbi.DbDataType {
|
||||
return Binary
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Enum(col *dbi.Column) *dbi.DbDataType {
|
||||
return Varchar
|
||||
}
|
||||
func (c *commonTypeConverter) JSON(col *dbi.Column) *dbi.DbDataType {
|
||||
return Text
|
||||
}
|
||||
48
server/internal/db/dbm/mysql/column.go
Normal file
48
server/internal/db/dbm/mysql/column.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
)
|
||||
|
||||
const (
|
||||
IndexSubPartKey = "subPart"
|
||||
)
|
||||
|
||||
var (
|
||||
Bit = dbi.NewDbDataType("bit", dbi.DTBit).WithCT(dbi.CTBit)
|
||||
Tinyint = dbi.NewDbDataType("tinyint", dbi.DTInt8).WithCT(dbi.CTInt1).WithFixColumn(dbi.ClearNumScale)
|
||||
Smallint = dbi.NewDbDataType("smallint", dbi.DTInt16).WithCT(dbi.CTInt2).WithFixColumn(dbi.ClearNumScale)
|
||||
Mediumint = dbi.NewDbDataType("mediumint", dbi.DTInt32).WithCT(dbi.CTInt4).WithFixColumn(dbi.ClearNumScale)
|
||||
Int = dbi.NewDbDataType("int", dbi.DTInt32).WithCT(dbi.CTInt4).WithFixColumn(dbi.ClearNumScale)
|
||||
Bigint = dbi.NewDbDataType("bigint", dbi.DTInt64).WithCT(dbi.CTInt8).WithFixColumn(dbi.ClearNumScale)
|
||||
|
||||
UnsignedBigint = dbi.NewDbDataType("unsigned bigint", dbi.DTUint64).WithCT(dbi.CTUnsignedInt8).WithFixColumn(dbi.ClearNumScale)
|
||||
UnsignedInt = dbi.NewDbDataType("unsigned int", dbi.DTUint64).WithCT(dbi.CTUnsignedInt4).WithFixColumn(dbi.ClearNumScale)
|
||||
UnsignedSmallint = dbi.NewDbDataType("unsigned smallint", dbi.DTInt32).WithCT(dbi.CTUnsignedInt2).WithFixColumn(dbi.ClearNumScale)
|
||||
UnsignedMediumint = dbi.NewDbDataType("unsigned mediumint", dbi.DTInt64).WithCT(dbi.CTUnsignedInt4).WithFixColumn(dbi.ClearNumScale)
|
||||
|
||||
Decimal = dbi.NewDbDataType("decimal", dbi.DTDecimal).WithCT(dbi.CTDecimal)
|
||||
Double = dbi.NewDbDataType("double", dbi.DTNumeric).WithCT(dbi.CTNumeric)
|
||||
Float = dbi.NewDbDataType("float", dbi.DTNumeric).WithCT(dbi.CTNumeric)
|
||||
|
||||
Varchar = dbi.NewDbDataType("varchar", dbi.DTString).WithCT(dbi.CTVarchar)
|
||||
Char = dbi.NewDbDataType("char", dbi.DTString).WithCT(dbi.CTChar)
|
||||
Text = dbi.NewDbDataType("text", dbi.DTString).WithCT(dbi.CTText).WithFixColumn(dbi.ClearCharMaxLength)
|
||||
Mediumtext = dbi.NewDbDataType("mediumtext", dbi.DTString).WithCT(dbi.CTMediumtext).WithFixColumn(dbi.ClearCharMaxLength)
|
||||
Longtext = dbi.NewDbDataType("longtext", dbi.DTString).WithCT(dbi.CTLongtext).WithFixColumn(dbi.ClearCharMaxLength)
|
||||
JSON = dbi.NewDbDataType("json", dbi.DTString).WithCT(dbi.CTJSON).WithFixColumn(dbi.ClearCharMaxLength)
|
||||
|
||||
Datetime = dbi.NewDbDataType("datetime", dbi.DTDateTime).WithCT(dbi.CTDateTime)
|
||||
Date = dbi.NewDbDataType("date", dbi.DTDate).WithCT(dbi.CTDate)
|
||||
Time = dbi.NewDbDataType("time", dbi.DTTime).WithCT(dbi.CTTime)
|
||||
Timestamp = dbi.NewDbDataType("timestamp", dbi.DTDateTime).WithCT(dbi.CTTimestamp)
|
||||
|
||||
Enum = dbi.NewDbDataType("enum", dbi.DTString).WithCT(dbi.CTEnum)
|
||||
Set = dbi.NewDbDataType("set", dbi.DTString).WithCT(dbi.CTVarchar)
|
||||
|
||||
Blob = dbi.NewDbDataType("blob", dbi.DTBytes).WithCT(dbi.CTBlob)
|
||||
Mediumblob = dbi.NewDbDataType("mediumblob", dbi.DTBytes).WithCT(dbi.CTMediumblob)
|
||||
Longblob = dbi.NewDbDataType("longblob", dbi.DTBytes).WithCT(dbi.CTLongblob)
|
||||
Binary = dbi.NewDbDataType("binary", dbi.DTBytes).WithCT(dbi.CTBinary)
|
||||
Varbinary = dbi.NewDbDataType("varbinary", dbi.DTBytes).WithCT(dbi.CTVarbinary)
|
||||
)
|
||||
@@ -1,17 +1,20 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/internal/db/dbm/sqlparser"
|
||||
"mayfly-go/internal/db/dbm/sqlparser/mysql"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const Quoter = "`"
|
||||
var (
|
||||
mysqlQuoter = dbi.Quoter{
|
||||
Prefix: '`',
|
||||
Suffix: '`',
|
||||
IsReserved: dbi.AlwaysReserve,
|
||||
}
|
||||
)
|
||||
|
||||
type MysqlDialect struct {
|
||||
dbi.DefaultDialect
|
||||
@@ -24,38 +27,6 @@ func (md *MysqlDialect) GetDbProgram() (dbi.DbProgram, error) {
|
||||
return NewDbProgramMysql(md.dc), nil
|
||||
}
|
||||
|
||||
func (md *MysqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any, duplicateStrategy int) (int64, error) {
|
||||
// 生成占位符字符串:如:(?,?)
|
||||
// 重复字符串并用逗号连接
|
||||
repeated := strings.Repeat("?,", len(columns))
|
||||
// 去除最后一个逗号,占位符由括号包裹
|
||||
placeholder := fmt.Sprintf("(%s)", strings.TrimSuffix(repeated, ","))
|
||||
|
||||
// 执行批量insert sql,mysql支持批量insert语法
|
||||
// insert into table_name (column1, column2, ...) values (value1, value2, ...), (value1, value2, ...), ...
|
||||
|
||||
// 重复占位符字符串n遍
|
||||
repeated = strings.Repeat(placeholder+",", len(values))
|
||||
// 去除最后一个逗号
|
||||
placeholder = strings.TrimSuffix(repeated, ",")
|
||||
|
||||
prefix := "insert into"
|
||||
if duplicateStrategy == 1 {
|
||||
prefix = "insert ignore into"
|
||||
} else if duplicateStrategy == 2 {
|
||||
prefix = "replace into"
|
||||
}
|
||||
|
||||
sqlStr := fmt.Sprintf("%s %s (%s) values %s", prefix, md.QuoteIdentifier(tableName), strings.Join(columns, ","), placeholder)
|
||||
// 执行批量insert sql
|
||||
// 把二维数组转为一维数组
|
||||
var args []any
|
||||
for _, v := range values {
|
||||
args = append(args, v...)
|
||||
}
|
||||
return md.dc.TxExec(tx, sqlStr, args...)
|
||||
}
|
||||
|
||||
func (md *MysqlDialect) CopyTable(copy *dbi.DbCopyTable) error {
|
||||
tableName := copy.TableName
|
||||
|
||||
@@ -77,145 +48,14 @@ func (md *MysqlDialect) CopyTable(copy *dbi.DbCopyTable) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// 获取建表ddl
|
||||
func (md *MysqlDialect) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table, dropBeforeCreate bool) []string {
|
||||
sqlArr := make([]string, 0)
|
||||
|
||||
if dropBeforeCreate {
|
||||
sqlArr = append(sqlArr, fmt.Sprintf("DROP TABLE IF EXISTS %s", md.QuoteIdentifier(tableInfo.TableName)))
|
||||
}
|
||||
|
||||
// 组装建表语句
|
||||
createSql := fmt.Sprintf("CREATE TABLE %s (\n", md.QuoteIdentifier(tableInfo.TableName))
|
||||
fields := make([]string, 0)
|
||||
pks := make([]string, 0)
|
||||
|
||||
for _, column := range columns {
|
||||
if column.IsPrimaryKey {
|
||||
pks = append(pks, column.ColumnName)
|
||||
}
|
||||
fields = append(fields, md.genColumnBasicSql(column))
|
||||
}
|
||||
|
||||
// 建表ddl
|
||||
createSql += strings.Join(fields, ",\n")
|
||||
if len(pks) > 0 {
|
||||
createSql += fmt.Sprintf(", \nPRIMARY KEY (%s)", strings.Join(pks, ","))
|
||||
}
|
||||
createSql += "\n)"
|
||||
|
||||
// 表注释
|
||||
if tableInfo.TableComment != "" {
|
||||
createSql += fmt.Sprintf(" COMMENT '%s'", md.QuoteEscape(tableInfo.TableComment))
|
||||
}
|
||||
|
||||
sqlArr = append(sqlArr, createSql)
|
||||
|
||||
return sqlArr
|
||||
}
|
||||
|
||||
// 获取建索引ddl
|
||||
func (md *MysqlDialect) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Table) []string {
|
||||
sqlArr := make([]string, 0)
|
||||
for _, index := range indexs {
|
||||
unique := ""
|
||||
if index.IsUnique {
|
||||
unique = "unique"
|
||||
}
|
||||
// 取出列名,添加引号
|
||||
cols := strings.Split(index.ColumnName, ",")
|
||||
colNames := make([]string, len(cols))
|
||||
for i, name := range cols {
|
||||
colNames[i] = md.QuoteIdentifier(name)
|
||||
}
|
||||
sqlTmp := "ALTER TABLE %s ADD %s INDEX %s(%s) USING BTREE"
|
||||
sqlStr := fmt.Sprintf(sqlTmp, md.QuoteIdentifier(tableInfo.TableName), unique, md.QuoteIdentifier(index.IndexName), strings.Join(colNames, ","))
|
||||
comment := md.QuoteEscape(index.IndexComment)
|
||||
if comment != "" {
|
||||
sqlStr += fmt.Sprintf(" COMMENT '%s'", comment)
|
||||
}
|
||||
sqlArr = append(sqlArr, sqlStr)
|
||||
}
|
||||
return sqlArr
|
||||
}
|
||||
|
||||
func (md *MysqlDialect) genColumnBasicSql(column dbi.Column) string {
|
||||
dataType := string(column.DataType)
|
||||
|
||||
incr := ""
|
||||
if column.IsIdentity {
|
||||
incr = " AUTO_INCREMENT"
|
||||
}
|
||||
|
||||
nullAble := ""
|
||||
if !column.Nullable {
|
||||
nullAble = " NOT NULL"
|
||||
}
|
||||
columnType := column.GetColumnType()
|
||||
if nullAble == "" && strings.Contains(columnType, "timestamp") {
|
||||
nullAble = " NULL"
|
||||
}
|
||||
|
||||
defVal := "" // 默认值需要判断引号,如函数是不需要引号的
|
||||
if column.ColumnDefault != "" &&
|
||||
// 当默认值是字符串'NULL'时,不需要设置默认值
|
||||
column.ColumnDefault != "NULL" &&
|
||||
// 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值
|
||||
!strings.Contains(column.ColumnDefault, "(") {
|
||||
// 哪些字段类型默认值需要加引号
|
||||
mark := false
|
||||
if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, strings.ToLower(dataType)) {
|
||||
// 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号
|
||||
if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) &&
|
||||
collx.ArrayAnyMatches([]string{"DATE", "TIME"}, strings.ToUpper(column.ColumnDefault)) {
|
||||
mark = false
|
||||
} else {
|
||||
mark = true
|
||||
}
|
||||
}
|
||||
if mark {
|
||||
defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault)
|
||||
} else {
|
||||
defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault)
|
||||
}
|
||||
}
|
||||
comment := ""
|
||||
if column.ColumnComment != "" {
|
||||
// 防止注释内含有特殊字符串导致sql出错
|
||||
commentStr := md.QuoteEscape(column.ColumnComment)
|
||||
comment = fmt.Sprintf(" COMMENT '%s'", commentStr)
|
||||
}
|
||||
|
||||
columnSql := fmt.Sprintf(" %s %s%s%s%s%s", md.QuoteIdentifier(column.ColumnName), columnType, nullAble, incr, defVal, comment)
|
||||
return columnSql
|
||||
}
|
||||
|
||||
func (dx *MysqlDialect) QuoteIdentifier(name string) string {
|
||||
end := strings.IndexRune(name, 0)
|
||||
if end > -1 {
|
||||
name = name[:end]
|
||||
}
|
||||
return Quoter + strings.Replace(name, Quoter, Quoter+Quoter, -1) + Quoter
|
||||
}
|
||||
|
||||
func (dx *MysqlDialect) RemoveQuote(name string) string {
|
||||
return strings.ReplaceAll(name, Quoter, "")
|
||||
}
|
||||
|
||||
func (md *MysqlDialect) QuoteLiteral(literal string) string {
|
||||
literal = strings.ReplaceAll(literal, `\`, `\\`)
|
||||
literal = strings.ReplaceAll(literal, `'`, `''`)
|
||||
return "'" + literal + "'"
|
||||
}
|
||||
|
||||
func (md *MysqlDialect) GetDataHelper() dbi.DataHelper {
|
||||
return dataHelper
|
||||
}
|
||||
|
||||
func (md *MysqlDialect) GetColumnHelper() dbi.ColumnHelper {
|
||||
return columnHelper
|
||||
func (md *MysqlDialect) Quoter() dbi.Quoter {
|
||||
return mysqlQuoter
|
||||
}
|
||||
|
||||
func (pd *MysqlDialect) GetSQLParser() sqlparser.SqlParser {
|
||||
return new(mysql.MysqlParser)
|
||||
}
|
||||
|
||||
func (md *MysqlDialect) GetSQLGenerator() dbi.SQLGenerator {
|
||||
return &SQLGenerator{Dialect: md}
|
||||
}
|
||||
|
||||
@@ -1,218 +1 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/pkg/utils/anyx"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
// 数字类型
|
||||
numberRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`)
|
||||
// 日期时间类型
|
||||
datetimeRegexp = regexp.MustCompile(`(?i)datetime|timestamp`)
|
||||
// 日期类型
|
||||
dateRegexp = regexp.MustCompile(`(?i)date`)
|
||||
// 时间类型
|
||||
timeRegexp = regexp.MustCompile(`(?i)time`)
|
||||
|
||||
blobRegexp = regexp.MustCompile(`(?i)blob`)
|
||||
|
||||
// mysql数据类型 映射 公共数据类型
|
||||
commonColumnTypeMap = map[string]dbi.ColumnDataType{
|
||||
"bigint": dbi.CommonTypeBigint,
|
||||
"binary": dbi.CommonTypeBinary,
|
||||
"blob": dbi.CommonTypeBlob,
|
||||
"char": dbi.CommonTypeChar,
|
||||
"datetime": dbi.CommonTypeDatetime,
|
||||
"date": dbi.CommonTypeDate,
|
||||
"decimal": dbi.CommonTypeNumber,
|
||||
"double": dbi.CommonTypeNumber,
|
||||
"enum": dbi.CommonTypeEnum,
|
||||
"float": dbi.CommonTypeNumber,
|
||||
"int": dbi.CommonTypeInt,
|
||||
"json": dbi.CommonTypeJSON,
|
||||
"longblob": dbi.CommonTypeLongblob,
|
||||
"longtext": dbi.CommonTypeLongtext,
|
||||
"mediumblob": dbi.CommonTypeBlob,
|
||||
"mediumtext": dbi.CommonTypeMediumtext,
|
||||
"bit": dbi.CommonTypeBit,
|
||||
"set": dbi.CommonTypeVarchar,
|
||||
"smallint": dbi.CommonTypeSmallint,
|
||||
"text": dbi.CommonTypeText,
|
||||
"time": dbi.CommonTypeTime,
|
||||
"timestamp": dbi.CommonTypeTimestamp,
|
||||
"tinyint": dbi.CommonTypeTinyint,
|
||||
"varbinary": dbi.CommonTypeVarbinary,
|
||||
"varchar": dbi.CommonTypeVarchar,
|
||||
}
|
||||
|
||||
// 公共数据类型 映射 mysql数据类型
|
||||
mysqlColumnTypeMap = map[dbi.ColumnDataType]string{
|
||||
dbi.CommonTypeVarchar: "varchar",
|
||||
dbi.CommonTypeChar: "char",
|
||||
dbi.CommonTypeText: "text",
|
||||
dbi.CommonTypeBlob: "blob",
|
||||
dbi.CommonTypeLongblob: "longblob",
|
||||
dbi.CommonTypeLongtext: "longtext",
|
||||
dbi.CommonTypeBinary: "binary",
|
||||
dbi.CommonTypeMediumblob: "blob",
|
||||
dbi.CommonTypeMediumtext: "mediumtext",
|
||||
dbi.CommonTypeVarbinary: "varbinary",
|
||||
dbi.CommonTypeInt: "int",
|
||||
dbi.CommonTypeBit: "bit",
|
||||
dbi.CommonTypeSmallint: "smallint",
|
||||
dbi.CommonTypeTinyint: "tinyint",
|
||||
dbi.CommonTypeNumber: "decimal",
|
||||
dbi.CommonTypeBigint: "bigint",
|
||||
dbi.CommonTypeDatetime: "datetime",
|
||||
dbi.CommonTypeDate: "date",
|
||||
dbi.CommonTypeTime: "time",
|
||||
dbi.CommonTypeTimestamp: "timestamp",
|
||||
dbi.CommonTypeEnum: "enum",
|
||||
dbi.CommonTypeJSON: "json",
|
||||
}
|
||||
dataHelper = &DataHelper{}
|
||||
columnHelper = &ColumnHelper{}
|
||||
)
|
||||
|
||||
func GetDataHelper() *DataHelper {
|
||||
return dataHelper
|
||||
}
|
||||
|
||||
type DataHelper struct {
|
||||
}
|
||||
|
||||
func (dc *DataHelper) GetDataType(dbColumnType string) dbi.DataType {
|
||||
if numberRegexp.MatchString(dbColumnType) {
|
||||
return dbi.DataTypeNumber
|
||||
}
|
||||
// 日期时间类型
|
||||
if datetimeRegexp.MatchString(dbColumnType) {
|
||||
return dbi.DataTypeDateTime
|
||||
}
|
||||
// 日期类型
|
||||
if dateRegexp.MatchString(dbColumnType) {
|
||||
return dbi.DataTypeDate
|
||||
}
|
||||
// 时间类型
|
||||
if timeRegexp.MatchString(dbColumnType) {
|
||||
return dbi.DataTypeTime
|
||||
}
|
||||
// blob类型
|
||||
if blobRegexp.MatchString(dbColumnType) {
|
||||
return dbi.DataTypeBlob
|
||||
}
|
||||
return dbi.DataTypeString
|
||||
}
|
||||
|
||||
func (dc *DataHelper) FormatData(dbColumnValue any, dataType dbi.DataType) string {
|
||||
// 如果dataType是datetime而dbColumnValue是string类型,则需要根据类型格式化
|
||||
str, ok := dbColumnValue.(string)
|
||||
if dataType == dbi.DataTypeDateTime && ok {
|
||||
// 尝试用时间格式解析
|
||||
res, err := time.Parse(time.DateTime, str)
|
||||
if err == nil {
|
||||
return str
|
||||
}
|
||||
res, _ = time.Parse(time.RFC3339, str)
|
||||
return res.Format(time.DateTime)
|
||||
}
|
||||
if dataType == dbi.DataTypeDate && ok {
|
||||
res, _ := time.Parse(time.DateOnly, str)
|
||||
return res.Format(time.DateOnly)
|
||||
}
|
||||
if dataType == dbi.DataTypeTime && ok {
|
||||
res, _ := time.Parse(time.TimeOnly, str)
|
||||
return res.Format(time.TimeOnly)
|
||||
}
|
||||
return anyx.ToString(dbColumnValue)
|
||||
}
|
||||
|
||||
func (dc *DataHelper) ParseData(dbColumnValue any, dataType dbi.DataType) any {
|
||||
// 如果dataType是datetime而dbColumnValue是string类型,则需要转换为time.Time类型
|
||||
_, ok := dbColumnValue.(string)
|
||||
if ok {
|
||||
if dataType == dbi.DataTypeDateTime {
|
||||
res, _ := time.Parse(time.DateTime, anyx.ToString(dbColumnValue))
|
||||
return res
|
||||
}
|
||||
if dataType == dbi.DataTypeDate {
|
||||
res, _ := time.Parse(time.DateOnly, anyx.ToString(dbColumnValue))
|
||||
return res
|
||||
}
|
||||
if dataType == dbi.DataTypeTime {
|
||||
res, _ := time.Parse(time.TimeOnly, anyx.ToString(dbColumnValue))
|
||||
return res
|
||||
}
|
||||
}
|
||||
return dbColumnValue
|
||||
}
|
||||
|
||||
func (dc *DataHelper) WrapValue(dbColumnValue any, dataType dbi.DataType) string {
|
||||
if dbColumnValue == nil {
|
||||
return "NULL"
|
||||
}
|
||||
switch dataType {
|
||||
case dbi.DataTypeNumber:
|
||||
return fmt.Sprintf("%v", dbColumnValue)
|
||||
case dbi.DataTypeString:
|
||||
val := fmt.Sprintf("%v", dbColumnValue)
|
||||
// 转义单引号
|
||||
val = strings.Replace(val, `'`, `''`, -1)
|
||||
val = strings.Replace(val, `\''`, `\'`, -1)
|
||||
// 转义换行符
|
||||
val = strings.Replace(val, "\n", "\\n", -1)
|
||||
return fmt.Sprintf("'%s'", val)
|
||||
case dbi.DataTypeDate, dbi.DataTypeDateTime, dbi.DataTypeTime:
|
||||
// mysql时间类型无需格式化
|
||||
return fmt.Sprintf("'%s'", dbColumnValue)
|
||||
case dbi.DataTypeBlob:
|
||||
return fmt.Sprintf("unhex('%s')", dbColumnValue)
|
||||
}
|
||||
return fmt.Sprintf("'%s'", dbColumnValue)
|
||||
}
|
||||
|
||||
type ColumnHelper struct {
|
||||
dbi.DefaultColumnHelper
|
||||
}
|
||||
|
||||
func (ch *ColumnHelper) ToCommonColumn(dialectColumn *dbi.Column) {
|
||||
dataType := dialectColumn.DataType
|
||||
|
||||
t1 := commonColumnTypeMap[string(dataType)]
|
||||
commonColumnType := dbi.CommonTypeVarchar
|
||||
|
||||
if t1 != "" {
|
||||
commonColumnType = t1
|
||||
}
|
||||
|
||||
dialectColumn.DataType = commonColumnType
|
||||
}
|
||||
|
||||
func (ch *ColumnHelper) ToColumn(column *dbi.Column) {
|
||||
ctype := mysqlColumnTypeMap[column.DataType]
|
||||
if ctype == "" {
|
||||
column.DataType = "varchar"
|
||||
column.CharMaxLength = 1000
|
||||
} else {
|
||||
column.DataType = dbi.ColumnDataType(ctype)
|
||||
ch.FixColumn(column)
|
||||
}
|
||||
}
|
||||
|
||||
func (ch *ColumnHelper) FixColumn(column *dbi.Column) {
|
||||
// 如果是int整型,删除精度
|
||||
if strings.Contains(strings.ToLower(string(column.DataType)), "int") {
|
||||
column.NumScale = 0
|
||||
column.CharMaxLength = 0
|
||||
} else
|
||||
// 如果是text,删除长度
|
||||
if strings.Contains(strings.ToLower(string(column.DataType)), "text") {
|
||||
column.CharMaxLength = 0
|
||||
column.NumPrecision = 0
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"net"
|
||||
|
||||
"github.com/go-sql-driver/mysql"
|
||||
@@ -12,10 +13,15 @@ import (
|
||||
|
||||
func init() {
|
||||
meta := new(Meta)
|
||||
dbi.Register(dbi.DbTypeMysql, meta)
|
||||
dbi.Register(dbi.DbTypeMariadb, meta)
|
||||
dbi.Register(DbTypeMysql, meta)
|
||||
dbi.Register(DbTypeMariadb, meta)
|
||||
}
|
||||
|
||||
const (
|
||||
DbTypeMysql dbi.DbType = "mysql"
|
||||
DbTypeMariadb dbi.DbType = "mariadb"
|
||||
)
|
||||
|
||||
type Meta struct {
|
||||
}
|
||||
|
||||
@@ -31,7 +37,7 @@ func (mm *Meta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) {
|
||||
})
|
||||
}
|
||||
// 设置dataSourceName -> 更多参数参考:https://github.com/go-sql-driver/mysql#dsn-data-source-name
|
||||
dsn := fmt.Sprintf("%s:%s@%s(%s:%d)/%s?timeout=8s", d.Username, d.Password, d.Network, d.Host, d.Port, d.Database)
|
||||
dsn := fmt.Sprintf("%s:%s@%s(%s:%d)/%s?parseTime=true&timeout=8s", d.Username, d.Password, d.Network, d.Host, d.Port, d.Database)
|
||||
if d.Params != "" {
|
||||
dsn = fmt.Sprintf("%s&%s", dsn, d.Params)
|
||||
}
|
||||
@@ -46,3 +52,17 @@ func (mm *Meta) GetDialect(conn *dbi.DbConn) dbi.Dialect {
|
||||
func (mm *Meta) GetMetadata(conn *dbi.DbConn) dbi.Metadata {
|
||||
return &MysqlMetadata{dc: conn}
|
||||
}
|
||||
|
||||
func (mm *Meta) GetDbDataTypes() []*dbi.DbDataType {
|
||||
return collx.AsArray(
|
||||
UnsignedBigint, Bigint, Tinyint, Smallint, Int, Bit, Float, Double, Decimal,
|
||||
Varchar, Char, Text, Longtext, Mediumtext,
|
||||
Datetime, Date, Time, Timestamp,
|
||||
Enum, JSON, Set,
|
||||
Binary, Blob, Longblob, Mediumblob, Varbinary,
|
||||
)
|
||||
}
|
||||
|
||||
func (mm *Meta) GetCommonTypeConverter() dbi.CommonTypeConverter {
|
||||
return &commonTypeConverter{}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/pkg/errorx"
|
||||
"mayfly-go/pkg/logx"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"mayfly-go/pkg/utils/stringx"
|
||||
"strings"
|
||||
@@ -54,7 +53,7 @@ func (md *MysqlMetadata) GetDbNames() ([]string, error) {
|
||||
func (md *MysqlMetadata) GetTables(tableNames ...string) ([]dbi.Table, error) {
|
||||
dialect := md.dc.GetDialect()
|
||||
names := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string {
|
||||
return fmt.Sprintf("'%s'", dialect.RemoveQuote(val))
|
||||
return fmt.Sprintf("'%s'", dialect.Quoter().Trim(val))
|
||||
}), ",")
|
||||
|
||||
var res []map[string]any
|
||||
@@ -87,9 +86,8 @@ func (md *MysqlMetadata) GetTables(tableNames ...string) ([]dbi.Table, error) {
|
||||
// 获取列元信息, 如列名等
|
||||
func (md *MysqlMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error) {
|
||||
dialect := md.dc.GetDialect()
|
||||
columnHelper := dialect.GetColumnHelper()
|
||||
tableName := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string {
|
||||
return fmt.Sprintf("'%s'", dialect.RemoveQuote(val))
|
||||
return fmt.Sprintf("'%s'", dialect.Quoter().Trim(val))
|
||||
}), ",")
|
||||
|
||||
_, res, err := md.dc.Query(fmt.Sprintf(dbi.GetLocalSql(MYSQL_META_FILE, MYSQL_COLUMN_MA_KEY), tableName))
|
||||
@@ -103,7 +101,7 @@ func (md *MysqlMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error)
|
||||
column := dbi.Column{
|
||||
TableName: cast.ToString(re["tableName"]),
|
||||
ColumnName: cast.ToString(re["columnName"]),
|
||||
DataType: dbi.ColumnDataType(cast.ToString(re["dataType"])),
|
||||
DataType: cast.ToString(re["dataType"]),
|
||||
ColumnComment: cast.ToString(re["columnComment"]),
|
||||
Nullable: cast.ToString(re["nullable"]) == "YES",
|
||||
IsPrimaryKey: cast.ToInt(re["isPrimaryKey"]) == 1,
|
||||
@@ -114,7 +112,7 @@ func (md *MysqlMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error)
|
||||
NumScale: cast.ToInt(re["numScale"]),
|
||||
}
|
||||
|
||||
columnHelper.FixColumn(&column)
|
||||
md.dc.GetDbDataType(column.DataType).FixColumn(&column)
|
||||
columns = append(columns, column)
|
||||
}
|
||||
return columns, nil
|
||||
@@ -156,6 +154,7 @@ func (md *MysqlMetadata) GetTableIndex(tableName string) ([]dbi.Index, error) {
|
||||
IsUnique: cast.ToInt(re["isUnique"]) == 1,
|
||||
SeqInIndex: cast.ToInt(re["seqInIndex"]),
|
||||
IsPrimaryKey: cast.ToInt(re["isPrimaryKey"]) == 1,
|
||||
Extra: collx.Kvs(IndexSubPartKey, cast.ToInt(re[IndexSubPartKey])),
|
||||
})
|
||||
}
|
||||
// 把查询结果以索引名分组,索引字段以逗号连接
|
||||
@@ -179,34 +178,7 @@ func (md *MysqlMetadata) GetTableIndex(tableName string) ([]dbi.Index, error) {
|
||||
|
||||
// 获取建表ddl
|
||||
func (md *MysqlMetadata) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) {
|
||||
// 1.获取表信息
|
||||
tbs, err := md.GetTables(tableName)
|
||||
tableInfo := &dbi.Table{}
|
||||
if err != nil || tbs == nil || len(tbs) <= 0 {
|
||||
logx.Errorf("获取表信息失败, %s", tableName)
|
||||
return "", err
|
||||
}
|
||||
tableInfo.TableName = tbs[0].TableName
|
||||
tableInfo.TableComment = tbs[0].TableComment
|
||||
|
||||
// 2.获取列信息
|
||||
columns, err := md.GetColumns(tableName)
|
||||
if err != nil {
|
||||
logx.Errorf("获取列信息失败, %s", tableName)
|
||||
return "", err
|
||||
}
|
||||
|
||||
dialect := md.dc.GetDialect()
|
||||
tableDDLArr := dialect.GenerateTableDDL(columns, *tableInfo, dropBeforeCreate)
|
||||
// 3.获取索引信息
|
||||
indexs, err := md.GetTableIndex(tableName)
|
||||
if err != nil {
|
||||
logx.Errorf("获取索引信息失败, %s", tableName)
|
||||
return "", err
|
||||
}
|
||||
// 组装返回
|
||||
tableDDLArr = append(tableDDLArr, dialect.GenerateIndexDDL(indexs, *tableInfo)...)
|
||||
return strings.Join(tableDDLArr, ";\n"), nil
|
||||
return dbi.GenTableDDL(md.dc.GetDialect(), md, tableName, dropBeforeCreate)
|
||||
}
|
||||
|
||||
func (md *MysqlMetadata) GetSchemas() ([]string, error) {
|
||||
|
||||
@@ -56,9 +56,9 @@ func (svc *DbProgramMysql) getMysqlBin() *config.MysqlBin {
|
||||
dbInfo := svc.dbInfo()
|
||||
var mysqlBin *config.MysqlBin
|
||||
switch dbInfo.Type {
|
||||
case dbi.DbTypeMariadb:
|
||||
case DbTypeMariadb:
|
||||
mysqlBin = config.GetMysqlBin(config.ConfigKeyDbMariadbBin)
|
||||
case dbi.DbTypeMysql:
|
||||
case DbTypeMysql:
|
||||
mysqlBin = config.GetMysqlBin(config.ConfigKeyDbMysqlBin)
|
||||
default:
|
||||
panic(fmt.Sprintf("不兼容 MySQL 的数据库类型: %v", dbInfo.Type))
|
||||
|
||||
147
server/internal/db/dbm/mysql/sqlgen.go
Normal file
147
server/internal/db/dbm/mysql/sqlgen.go
Normal file
@@ -0,0 +1,147 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"strings"
|
||||
|
||||
"github.com/may-fly/cast"
|
||||
)
|
||||
|
||||
type SQLGenerator struct {
|
||||
Dialect dbi.Dialect
|
||||
}
|
||||
|
||||
func (msg *SQLGenerator) GenTableDDL(table dbi.Table, columns []dbi.Column, dropBeforeCreate bool) []string {
|
||||
sqlArr := make([]string, 0)
|
||||
quoter := msg.Dialect.Quoter()
|
||||
|
||||
if dropBeforeCreate {
|
||||
sqlArr = append(sqlArr, fmt.Sprintf("DROP TABLE IF EXISTS %s", quoter.Quote(table.TableName)))
|
||||
}
|
||||
|
||||
// 组装建表语句
|
||||
createSql := fmt.Sprintf("CREATE TABLE %s (\n", quoter.Quote(table.TableName))
|
||||
fields := make([]string, 0)
|
||||
pks := make([]string, 0)
|
||||
|
||||
for _, column := range columns {
|
||||
if column.IsPrimaryKey {
|
||||
pks = append(pks, column.ColumnName)
|
||||
}
|
||||
fields = append(fields, msg.genColumnBasicSql(quoter, column))
|
||||
}
|
||||
|
||||
// 建表ddl
|
||||
createSql += strings.Join(fields, ",\n")
|
||||
if len(pks) > 0 {
|
||||
createSql += fmt.Sprintf(", \nPRIMARY KEY (%s)", strings.Join(pks, ","))
|
||||
}
|
||||
createSql += "\n)"
|
||||
|
||||
// 表注释
|
||||
if table.TableComment != "" {
|
||||
createSql += fmt.Sprintf(" COMMENT '%s'", dbi.QuoteEscape(table.TableComment))
|
||||
}
|
||||
|
||||
sqlArr = append(sqlArr, createSql)
|
||||
|
||||
return sqlArr
|
||||
}
|
||||
|
||||
func (msg *SQLGenerator) GenIndexDDL(table dbi.Table, indexs []dbi.Index) []string {
|
||||
sqlArr := make([]string, 0)
|
||||
quoter := msg.Dialect.Quoter()
|
||||
|
||||
for _, index := range indexs {
|
||||
unique := ""
|
||||
if index.IsUnique {
|
||||
unique = "unique"
|
||||
}
|
||||
// 取出列名,添加引号
|
||||
colNames := quoter.Quotes(strings.Split(index.ColumnName, ","))
|
||||
|
||||
// 暂时先处理单个索引的情况,多个涉及获取索引时的合并等,以及前端调整等,后续完善
|
||||
if subPart := cast.ToInt(index.Extra[IndexSubPartKey]); subPart > 0 && len(colNames) == 1 {
|
||||
colNames[0] = fmt.Sprintf("%s(%d)", colNames[0], subPart)
|
||||
}
|
||||
|
||||
sqlTmp := "ALTER TABLE %s ADD %s INDEX %s(%s) USING %s"
|
||||
sqlStr := fmt.Sprintf(sqlTmp, quoter.Quote(table.TableName), unique, quoter.Quote(index.IndexName), strings.Join(colNames, ","), index.IndexType)
|
||||
comment := dbi.QuoteEscape(index.IndexComment)
|
||||
if comment != "" {
|
||||
sqlStr += fmt.Sprintf(" COMMENT '%s'", comment)
|
||||
}
|
||||
sqlArr = append(sqlArr, sqlStr)
|
||||
}
|
||||
|
||||
return sqlArr
|
||||
}
|
||||
|
||||
func (msg *SQLGenerator) GenInsert(tableName string, columns []dbi.Column, values [][]any, duplicateStrategy int) []string {
|
||||
if duplicateStrategy == dbi.DuplicateStrategyNone {
|
||||
return collx.AsArray(dbi.GenCommonInsert(msg.Dialect, DbTypeMysql, tableName, columns, values))
|
||||
}
|
||||
|
||||
prefix := "insert ignore into"
|
||||
if duplicateStrategy == dbi.DuplicateStrategyUpdate {
|
||||
prefix = "replace into"
|
||||
}
|
||||
|
||||
quote := msg.Dialect.Quoter().Quote
|
||||
columnStr, valuesStrs := dbi.GenInsertSqlColumnAndValues(msg.Dialect, DbTypeMysql, columns, values)
|
||||
|
||||
return collx.AsArray[string](fmt.Sprintf("%s %s %s VALUES \n%s", prefix, quote(tableName), columnStr, strings.Join(valuesStrs, ",\n")))
|
||||
}
|
||||
|
||||
func (msg *SQLGenerator) genColumnBasicSql(quoter dbi.Quoter, column dbi.Column) string {
|
||||
dataType := string(column.DataType)
|
||||
|
||||
incr := ""
|
||||
if column.IsIdentity {
|
||||
incr = " AUTO_INCREMENT"
|
||||
}
|
||||
|
||||
nullAble := ""
|
||||
if !column.Nullable {
|
||||
nullAble = " NOT NULL"
|
||||
}
|
||||
columnType := column.GetColumnType()
|
||||
if nullAble == "" && strings.Contains(columnType, "timestamp") {
|
||||
nullAble = " NULL"
|
||||
}
|
||||
|
||||
defVal := "" // 默认值需要判断引号,如函数是不需要引号的
|
||||
if column.ColumnDefault != "" &&
|
||||
// 当默认值是字符串'NULL'时,不需要设置默认值
|
||||
column.ColumnDefault != "NULL" &&
|
||||
// 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值
|
||||
!strings.Contains(column.ColumnDefault, "(") {
|
||||
// 哪些字段类型默认值需要加引号
|
||||
mark := false
|
||||
if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, strings.ToLower(dataType)) {
|
||||
// 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号
|
||||
if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) &&
|
||||
collx.ArrayAnyMatches([]string{"DATE", "TIME"}, strings.ToUpper(column.ColumnDefault)) {
|
||||
mark = false
|
||||
} else {
|
||||
mark = true
|
||||
}
|
||||
}
|
||||
if mark {
|
||||
defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault)
|
||||
} else {
|
||||
defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault)
|
||||
}
|
||||
}
|
||||
comment := ""
|
||||
if column.ColumnComment != "" {
|
||||
// 防止注释内含有特殊字符串导致sql出错
|
||||
commentStr := dbi.QuoteEscape(column.ColumnComment)
|
||||
comment = fmt.Sprintf(" COMMENT '%s'", commentStr)
|
||||
}
|
||||
|
||||
columnSql := fmt.Sprintf(" %s %s%s%s%s%s", quoter.Quote(column.ColumnName), columnType, nullAble, incr, defVal, comment)
|
||||
return columnSql
|
||||
}
|
||||
108
server/internal/db/dbm/mysql/transfer.go
Normal file
108
server/internal/db/dbm/mysql/transfer.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package mysql
|
||||
|
||||
import "mayfly-go/internal/db/dbm/dbi"
|
||||
|
||||
var _ dbi.CommonTypeConverter = (*commonTypeConverter)(nil)
|
||||
|
||||
type commonTypeConverter struct {
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Varchar(col *dbi.Column) *dbi.DbDataType {
|
||||
// 如果字符长度大于16383,则转为text类型
|
||||
if col.CharMaxLength > 16383 {
|
||||
col.CharMaxLength = 0
|
||||
return Text
|
||||
}
|
||||
return Varchar
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Char(col *dbi.Column) *dbi.DbDataType {
|
||||
return Char
|
||||
}
|
||||
func (c *commonTypeConverter) Text(col *dbi.Column) *dbi.DbDataType {
|
||||
col.CharMaxLength = 0
|
||||
col.NumPrecision = 0
|
||||
return Text
|
||||
}
|
||||
func (c *commonTypeConverter) Mediumtext(col *dbi.Column) *dbi.DbDataType {
|
||||
col.CharMaxLength = 0
|
||||
col.NumPrecision = 0
|
||||
return Mediumtext
|
||||
}
|
||||
func (c *commonTypeConverter) Longtext(col *dbi.Column) *dbi.DbDataType {
|
||||
col.CharMaxLength = 0
|
||||
col.NumPrecision = 0
|
||||
return Longtext
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Bit(col *dbi.Column) *dbi.DbDataType {
|
||||
return Bit
|
||||
}
|
||||
func (c *commonTypeConverter) Int1(col *dbi.Column) *dbi.DbDataType {
|
||||
return Tinyint
|
||||
}
|
||||
func (c *commonTypeConverter) Int2(col *dbi.Column) *dbi.DbDataType {
|
||||
return Smallint
|
||||
}
|
||||
func (c *commonTypeConverter) Int4(col *dbi.Column) *dbi.DbDataType {
|
||||
return Int
|
||||
}
|
||||
func (c *commonTypeConverter) Int8(col *dbi.Column) *dbi.DbDataType {
|
||||
return Bigint
|
||||
}
|
||||
func (c *commonTypeConverter) Numeric(col *dbi.Column) *dbi.DbDataType {
|
||||
return Double
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Decimal(col *dbi.Column) *dbi.DbDataType {
|
||||
return Decimal
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) UnsignedInt8(col *dbi.Column) *dbi.DbDataType {
|
||||
return UnsignedBigint
|
||||
}
|
||||
func (c *commonTypeConverter) UnsignedInt4(col *dbi.Column) *dbi.DbDataType {
|
||||
return UnsignedInt
|
||||
}
|
||||
func (c *commonTypeConverter) UnsignedInt2(col *dbi.Column) *dbi.DbDataType {
|
||||
return UnsignedMediumint
|
||||
}
|
||||
func (c *commonTypeConverter) UnsignedInt1(col *dbi.Column) *dbi.DbDataType {
|
||||
return UnsignedSmallint
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Date(col *dbi.Column) *dbi.DbDataType {
|
||||
return Date
|
||||
}
|
||||
func (c *commonTypeConverter) Time(col *dbi.Column) *dbi.DbDataType {
|
||||
return Time
|
||||
}
|
||||
func (c *commonTypeConverter) Datetime(col *dbi.Column) *dbi.DbDataType {
|
||||
return Datetime
|
||||
}
|
||||
func (c *commonTypeConverter) Timestamp(col *dbi.Column) *dbi.DbDataType {
|
||||
return Timestamp
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Binary(col *dbi.Column) *dbi.DbDataType {
|
||||
return Binary
|
||||
}
|
||||
func (c *commonTypeConverter) Varbinary(col *dbi.Column) *dbi.DbDataType {
|
||||
return Varbinary
|
||||
}
|
||||
func (c *commonTypeConverter) Mediumblob(col *dbi.Column) *dbi.DbDataType {
|
||||
return Mediumblob
|
||||
}
|
||||
func (c *commonTypeConverter) Blob(col *dbi.Column) *dbi.DbDataType {
|
||||
return Blob
|
||||
}
|
||||
func (c *commonTypeConverter) Longblob(col *dbi.Column) *dbi.DbDataType {
|
||||
return Longblob
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Enum(col *dbi.Column) *dbi.DbDataType {
|
||||
return Enum
|
||||
}
|
||||
func (c *commonTypeConverter) JSON(col *dbi.Column) *dbi.DbDataType {
|
||||
return JSON
|
||||
}
|
||||
45
server/internal/db/dbm/oracle/column.go
Normal file
45
server/internal/db/dbm/oracle/column.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package oracle
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
)
|
||||
|
||||
var (
|
||||
DTOracleDate = dbi.DTDateTime.Copy().WithSQLValue(func(val any) string {
|
||||
// oracle date型需要用函数包裹:to_date('%s', 'yyyy-mm-dd hh24:mi:ss')
|
||||
return fmt.Sprintf("to_date('%s', 'yyyy-mm-dd hh24:mi:ss')", val)
|
||||
})
|
||||
)
|
||||
|
||||
var (
|
||||
CHAR = dbi.NewDbDataType("CHAR", dbi.DTString).WithCT(dbi.CTChar)
|
||||
NCHAR = dbi.NewDbDataType("NCHAR", dbi.DTString).WithCT(dbi.CTChar)
|
||||
VARCHAR2 = dbi.NewDbDataType("VARCHAR2", dbi.DTString).WithCT(dbi.CTVarchar)
|
||||
NVARCHAR2 = dbi.NewDbDataType("NVARCHAR2", dbi.DTString).WithCT(dbi.CTVarchar)
|
||||
|
||||
TEXT = dbi.NewDbDataType("TEXT", dbi.DTString).WithCT(dbi.CTText)
|
||||
LONG = dbi.NewDbDataType("LONG", dbi.DTString).WithCT(dbi.CTText)
|
||||
LONGVARCHAR = dbi.NewDbDataType("LONGVARCHAR", dbi.DTString).WithCT(dbi.CTLongtext)
|
||||
IMAGE = dbi.NewDbDataType("IMAGE", dbi.DTString).WithCT(dbi.CTLongtext)
|
||||
LONGVARBINARY = dbi.NewDbDataType("LONGVARBINARY", dbi.DTString).WithCT(dbi.CTLongtext)
|
||||
CLOB = dbi.NewDbDataType("CLOB", dbi.DTString).WithCT(dbi.CTLongtext)
|
||||
|
||||
BLOB = dbi.NewDbDataType("BLOB", dbi.DTBytes).WithCT(dbi.CTBlob)
|
||||
|
||||
DECIMAL = dbi.NewDbDataType("DECIMAL", dbi.DTDecimal).WithCT(dbi.CTDecimal)
|
||||
NUMBER = dbi.NewDbDataType("NUMBER", dbi.DTNumeric).WithCT(dbi.CTNumeric)
|
||||
INTEGER = dbi.NewDbDataType("INTEGER", dbi.DTInt32).WithCT(dbi.CTInt4)
|
||||
INT = dbi.NewDbDataType("INT", dbi.DTInt32).WithCT(dbi.CTInt4)
|
||||
BIGINT = dbi.NewDbDataType("BIGINT", dbi.DTInt64).WithCT(dbi.CTInt8)
|
||||
TINYINT = dbi.NewDbDataType("TINYINT", dbi.DTInt8).WithCT(dbi.CTInt1)
|
||||
BYTE = dbi.NewDbDataType("BYTE", dbi.DTInt8).WithCT(dbi.CTInt1)
|
||||
SMALLINT = dbi.NewDbDataType("SMALLINT", dbi.DTInt16).WithCT(dbi.CTInt2)
|
||||
BIT = dbi.NewDbDataType("BIT", dbi.DTBit).WithCT(dbi.CTBit)
|
||||
DOUBLE = dbi.NewDbDataType("DOUBLE", dbi.DTNumeric).WithCT(dbi.CTNumeric)
|
||||
FLOAT = dbi.NewDbDataType("FLOAT", dbi.DTNumeric).WithCT(dbi.CTNumeric)
|
||||
|
||||
TIME = dbi.NewDbDataType("TIME", DTOracleDate).WithCT(dbi.CTTime)
|
||||
DATE = dbi.NewDbDataType("DATE", DTOracleDate).WithCT(dbi.CTDate)
|
||||
TIMESTAMP = dbi.NewDbDataType("TIMESTAMP", DTOracleDate).WithCT(dbi.CTTimestamp)
|
||||
)
|
||||
@@ -59,6 +59,8 @@ func (od *OracleDialect) batchInsertSimple(tableName string, columns []string, v
|
||||
ignore = fmt.Sprintf("/*+ IGNORE_ROW_ON_DUPKEY_INDEX(%s(%s)) */", tableName, strings.Join(arr, ","))
|
||||
}
|
||||
}
|
||||
|
||||
quote := od.Quoter().Quote
|
||||
effRows := 0
|
||||
for _, value := range values {
|
||||
// 拼接带占位符的sql oracle的占位符是:1,:2,:3....
|
||||
@@ -66,7 +68,7 @@ func (od *OracleDialect) batchInsertSimple(tableName string, columns []string, v
|
||||
for i := 0; i < len(value); i++ {
|
||||
placeholder = append(placeholder, fmt.Sprintf(":%d", i+1))
|
||||
}
|
||||
sqlTemp := fmt.Sprintf("INSERT %s INTO %s (%s) VALUES (%s)", ignore, od.QuoteIdentifier(tableName), strings.Join(columns, ","), strings.Join(placeholder, ","))
|
||||
sqlTemp := fmt.Sprintf("INSERT %s INTO %s (%s) VALUES (%s)", ignore, quote(tableName), strings.Join(columns, ","), strings.Join(placeholder, ","))
|
||||
|
||||
// oracle数据库为了兼容ignore主键冲突,只能一条条的执行insert
|
||||
res, err := od.dc.TxExec(tx, sqlTemp, value...)
|
||||
@@ -83,6 +85,7 @@ func (od *OracleDialect) batchInsertMergeSql(tableName string, columns []string,
|
||||
uniqueCols := make([]string, 0)
|
||||
caseSqls := make([]string, 0)
|
||||
metadata := od.dc.GetMetadata()
|
||||
quote := od.Quoter().Quote
|
||||
// 查询唯一索引涉及到的字段,并组装到match条件内
|
||||
indexs, _ := metadata.GetTableIndex(tableName)
|
||||
if indexs != nil {
|
||||
@@ -94,7 +97,7 @@ func (od *OracleDialect) batchInsertMergeSql(tableName string, columns []string,
|
||||
if !collx.ArrayContains(uniqueCols, col) {
|
||||
uniqueCols = append(uniqueCols, col)
|
||||
}
|
||||
tmp = append(tmp, fmt.Sprintf(" T1.%s = T2.%s ", od.QuoteIdentifier(col), od.QuoteIdentifier(col)))
|
||||
tmp = append(tmp, fmt.Sprintf(" T1.%s = T2.%s ", quote(col), quote(col)))
|
||||
}
|
||||
caseSqls = append(caseSqls, fmt.Sprintf("( %s )", strings.Join(tmp, " AND ")))
|
||||
}
|
||||
@@ -110,8 +113,9 @@ func (od *OracleDialect) batchInsertMergeSql(tableName string, columns []string,
|
||||
insertVals := make([]string, 0)
|
||||
upds := make([]string, 0)
|
||||
insertCols := make([]string, 0)
|
||||
quoteTrim := od.Quoter().Trim
|
||||
for _, column := range columns {
|
||||
if !collx.ArrayContains(uniqueCols, od.RemoveQuote(column)) {
|
||||
if !collx.ArrayContains(uniqueCols, quoteTrim(column)) {
|
||||
upds = append(upds, fmt.Sprintf("T1.%s = T2.%s", column, column))
|
||||
}
|
||||
insertCols = append(insertCols, fmt.Sprintf("T1.%s", column))
|
||||
@@ -132,7 +136,7 @@ func (od *OracleDialect) batchInsertMergeSql(tableName string, columns []string,
|
||||
|
||||
t2 := strings.Join(t2s, " UNION ALL ")
|
||||
|
||||
sqlTemp := "MERGE INTO " + od.QuoteIdentifier(tableName) + " T1 USING (" + t2 + ") T2 ON (" + strings.Join(caseSqls, " OR ") + ") "
|
||||
sqlTemp := "MERGE INTO " + quote(tableName) + " T1 USING (" + t2 + ") T2 ON (" + strings.Join(caseSqls, " OR ") + ") "
|
||||
sqlTemp += "WHEN NOT MATCHED THEN INSERT (" + strings.Join(insertCols, ",") + ") VALUES (" + strings.Join(insertVals, ",") + ") "
|
||||
sqlTemp += "WHEN MATCHED THEN UPDATE SET " + strings.Join(upds, ",")
|
||||
|
||||
@@ -157,7 +161,8 @@ func (od *OracleDialect) CopyTable(copy *dbi.DbCopyTable) error {
|
||||
|
||||
// 获取建表ddl
|
||||
func (od *OracleDialect) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table, dropBeforeCreate bool) []string {
|
||||
quoteTableName := od.QuoteIdentifier(tableInfo.TableName)
|
||||
quote := od.Quoter().Quote
|
||||
quoteTableName := quote(tableInfo.TableName)
|
||||
sqlArr := make([]string, 0)
|
||||
|
||||
if dropBeforeCreate {
|
||||
@@ -181,13 +186,13 @@ end`
|
||||
// 把通用类型转换为达梦类型
|
||||
for _, column := range columns {
|
||||
if column.IsPrimaryKey {
|
||||
pks = append(pks, od.QuoteIdentifier(column.ColumnName))
|
||||
pks = append(pks, quote(column.ColumnName))
|
||||
}
|
||||
fields = append(fields, od.genColumnBasicSql(column))
|
||||
// 防止注释内含有特殊字符串导致sql出错
|
||||
if column.ColumnComment != "" {
|
||||
comment := od.QuoteEscape(column.ColumnComment)
|
||||
columnComments = append(columnComments, fmt.Sprintf("COMMENT ON COLUMN %s.%s IS '%s'", quoteTableName, od.QuoteIdentifier(column.ColumnName), comment))
|
||||
comment := dbi.QuoteEscape(column.ColumnComment)
|
||||
columnComments = append(columnComments, fmt.Sprintf("COMMENT ON COLUMN %s.%s IS '%s'", quoteTableName, quote(column.ColumnName), comment))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -202,7 +207,7 @@ end`
|
||||
// 表注释
|
||||
tableCommentSql := ""
|
||||
if tableInfo.TableComment != "" {
|
||||
tableCommentSql = fmt.Sprintf("COMMENT ON TABLE %s is '%s'", od.QuoteIdentifier(tableInfo.TableName), od.QuoteEscape(tableInfo.TableComment))
|
||||
tableCommentSql = fmt.Sprintf("COMMENT ON TABLE %s is '%s'", quote(tableInfo.TableName), dbi.QuoteEscape(tableInfo.TableComment))
|
||||
sqlArr = append(sqlArr, tableCommentSql)
|
||||
}
|
||||
|
||||
@@ -222,7 +227,7 @@ end`
|
||||
func (od *OracleDialect) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Table) []string {
|
||||
sqls := make([]string, 0)
|
||||
comments := make([]string, 0)
|
||||
|
||||
quote := od.Quoter().Quote
|
||||
for _, index := range indexs {
|
||||
unique := ""
|
||||
if index.IsUnique {
|
||||
@@ -233,10 +238,10 @@ func (od *OracleDialect) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Tabl
|
||||
cols := strings.Split(index.ColumnName, ",")
|
||||
colNames := make([]string, len(cols))
|
||||
for i, name := range cols {
|
||||
colNames[i] = od.QuoteIdentifier(name)
|
||||
colNames[i] = quote(name)
|
||||
}
|
||||
|
||||
sqls = append(sqls, fmt.Sprintf("CREATE %s INDEX %s ON %s(%s)", unique, od.QuoteIdentifier(index.IndexName), od.QuoteIdentifier(tableInfo.TableName), strings.Join(colNames, ",")))
|
||||
sqls = append(sqls, fmt.Sprintf("CREATE %s INDEX %s ON %s(%s)", unique, quote(index.IndexName), quote(tableInfo.TableName), strings.Join(colNames, ",")))
|
||||
}
|
||||
|
||||
sqlArr := make([]string, 0)
|
||||
@@ -255,16 +260,15 @@ func (od *OracleDialect) GenerateTableOtherDDL(tableInfo dbi.Table, quoteTableNa
|
||||
return nil
|
||||
}
|
||||
|
||||
func (od *OracleDialect) GetDataHelper() dbi.DataHelper {
|
||||
return dataHelper
|
||||
}
|
||||
|
||||
func (od *OracleDialect) GetColumnHelper() dbi.ColumnHelper {
|
||||
return columnHelper
|
||||
func (od *OracleDialect) GetSQLGenerator() dbi.SQLGenerator {
|
||||
return &SQLGenerator{
|
||||
Dialect: od,
|
||||
Metadata: od.dc.GetMetadata(),
|
||||
}
|
||||
}
|
||||
|
||||
func (od *OracleDialect) genColumnBasicSql(column dbi.Column) string {
|
||||
colName := od.QuoteIdentifier(column.ColumnName)
|
||||
colName := od.Quoter().Quote(column.ColumnName)
|
||||
|
||||
if column.IsIdentity {
|
||||
// 如果是自增,不需要设置默认值和空值,自增列数据类型必须是number
|
||||
|
||||
@@ -1,174 +1,16 @@
|
||||
package oracle
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/pkg/utils/anyx"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/may-fly/cast"
|
||||
)
|
||||
|
||||
var (
|
||||
// 数字类型
|
||||
numberTypeRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`)
|
||||
dateTimeReg = regexp.MustCompile(`^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}$`)
|
||||
dateTimeIsoReg = regexp.MustCompile(`^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}.*$`)
|
||||
|
||||
// 日期时间类型
|
||||
datetimeTypeRegexp = regexp.MustCompile(`(?i)date|timestamp`)
|
||||
|
||||
// oracle数据类型 映射 公共数据类型
|
||||
commonColumnTypeMap = map[string]dbi.ColumnDataType{
|
||||
"CHAR": dbi.CommonTypeChar,
|
||||
"NCHAR": dbi.CommonTypeChar,
|
||||
"VARCHAR2": dbi.CommonTypeVarchar,
|
||||
"NVARCHAR2": dbi.CommonTypeVarchar,
|
||||
"NUMBER": dbi.CommonTypeNumber,
|
||||
"INTEGER": dbi.CommonTypeInt,
|
||||
"INT": dbi.CommonTypeInt,
|
||||
"DECIMAL": dbi.CommonTypeNumber,
|
||||
"FLOAT": dbi.CommonTypeNumber,
|
||||
"REAL": dbi.CommonTypeNumber,
|
||||
"BINARY_FLOAT": dbi.CommonTypeNumber,
|
||||
"BINARY_DOUBLE": dbi.CommonTypeNumber,
|
||||
"DATE": dbi.CommonTypeDate,
|
||||
"TIMESTAMP": dbi.CommonTypeDatetime,
|
||||
"LONG": dbi.CommonTypeLongtext,
|
||||
"BLOB": dbi.CommonTypeLongtext,
|
||||
"CLOB": dbi.CommonTypeLongtext,
|
||||
"NCLOB": dbi.CommonTypeLongtext,
|
||||
"BFILE": dbi.CommonTypeBinary,
|
||||
}
|
||||
|
||||
// 公共数据类型 映射 oracle数据类型
|
||||
oracleColumnTypeMap = map[dbi.ColumnDataType]string{
|
||||
dbi.CommonTypeVarchar: "NVARCHAR2",
|
||||
dbi.CommonTypeChar: "NCHAR",
|
||||
dbi.CommonTypeText: "CLOB",
|
||||
dbi.CommonTypeBlob: "CLOB",
|
||||
dbi.CommonTypeLongblob: "CLOB",
|
||||
dbi.CommonTypeLongtext: "CLOB",
|
||||
dbi.CommonTypeBinary: "BFILE",
|
||||
dbi.CommonTypeMediumblob: "CLOB",
|
||||
dbi.CommonTypeMediumtext: "CLOB",
|
||||
dbi.CommonTypeVarbinary: "BFILE",
|
||||
dbi.CommonTypeInt: "INT",
|
||||
dbi.CommonTypeSmallint: "INT",
|
||||
dbi.CommonTypeTinyint: "INT",
|
||||
dbi.CommonTypeNumber: "NUMBER",
|
||||
dbi.CommonTypeBigint: "NUMBER",
|
||||
dbi.CommonTypeDatetime: "DATE",
|
||||
dbi.CommonTypeDate: "DATE",
|
||||
dbi.CommonTypeTime: "DATE",
|
||||
dbi.CommonTypeTimestamp: "TIMESTAMP",
|
||||
dbi.CommonTypeEnum: "CLOB",
|
||||
dbi.CommonTypeJSON: "CLOB",
|
||||
}
|
||||
|
||||
dataHelper = &DataHelper{}
|
||||
columnHelper = &ColumnHelper{}
|
||||
)
|
||||
|
||||
func GetDataHelper() *DataHelper {
|
||||
return dataHelper
|
||||
}
|
||||
|
||||
type DataHelper struct {
|
||||
}
|
||||
|
||||
func (dc *DataHelper) GetDataType(dbColumnType string) dbi.DataType {
|
||||
if numberTypeRegexp.MatchString(dbColumnType) {
|
||||
return dbi.DataTypeNumber
|
||||
}
|
||||
// 日期时间类型
|
||||
if datetimeTypeRegexp.MatchString(dbColumnType) {
|
||||
return dbi.DataTypeDateTime
|
||||
}
|
||||
return dbi.DataTypeString
|
||||
}
|
||||
|
||||
func (dc *DataHelper) FormatData(dbColumnValue any, dataType dbi.DataType) string {
|
||||
str := anyx.ToString(dbColumnValue)
|
||||
if dateTimeReg.MatchString(str) || dateTimeIsoReg.MatchString(str) {
|
||||
dataType = dbi.DataTypeDateTime
|
||||
}
|
||||
switch dataType {
|
||||
// oracle把日期类型数据格式化输出
|
||||
case dbi.DataTypeDateTime: // "2024-01-02T22:08:22.275697+08:00"
|
||||
// 尝试用时间格式解析
|
||||
res, err := time.Parse(time.DateTime, str)
|
||||
if err == nil {
|
||||
return str
|
||||
}
|
||||
res, _ = time.Parse(time.RFC3339, str)
|
||||
return res.Format(time.DateTime)
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
func (dc *DataHelper) ParseData(dbColumnValue any, dataType dbi.DataType) any {
|
||||
// oracle把日期类型的数据转化为time类型
|
||||
if dataType == dbi.DataTypeDateTime {
|
||||
res, _ := time.Parse(time.RFC3339, cast.ToString(dbColumnValue))
|
||||
return res
|
||||
}
|
||||
return dbColumnValue
|
||||
}
|
||||
|
||||
func (dc *DataHelper) WrapValue(dbColumnValue any, dataType dbi.DataType) string {
|
||||
if dbColumnValue == nil {
|
||||
return "NULL"
|
||||
}
|
||||
switch dataType {
|
||||
case dbi.DataTypeNumber:
|
||||
return fmt.Sprintf("%v", dbColumnValue)
|
||||
case dbi.DataTypeString:
|
||||
val := fmt.Sprintf("%v", dbColumnValue)
|
||||
// 转义单引号
|
||||
val = strings.Replace(val, `'`, `''`, -1)
|
||||
val = strings.Replace(val, `\''`, `\'`, -1)
|
||||
// 转义换行符
|
||||
val = strings.Replace(val, "\n", "\\n", -1)
|
||||
return fmt.Sprintf("'%s'", val)
|
||||
case dbi.DataTypeDate, dbi.DataTypeDateTime, dbi.DataTypeTime:
|
||||
return fmt.Sprintf("to_date('%s', 'yyyy-mm-dd hh24:mi:ss')", dc.FormatData(dbColumnValue, dataType))
|
||||
}
|
||||
return fmt.Sprintf("'%s'", dbColumnValue)
|
||||
}
|
||||
|
||||
type ColumnHelper struct {
|
||||
dbi.DefaultColumnHelper
|
||||
}
|
||||
|
||||
func (ch *ColumnHelper) ToCommonColumn(dialectColumn *dbi.Column) {
|
||||
// 翻译为通用数据库类型
|
||||
dataType := dialectColumn.DataType
|
||||
t1 := commonColumnTypeMap[string(dataType)]
|
||||
if t1 == "" {
|
||||
dialectColumn.DataType = dbi.CommonTypeVarchar
|
||||
dialectColumn.CharMaxLength = 2000
|
||||
} else {
|
||||
dialectColumn.DataType = t1
|
||||
// 如果是number类型,需要根据公共类型加上长度, 如 bigint 需要转换为number(19,0)
|
||||
if strings.Contains(string(t1), "NUMBER") {
|
||||
dialectColumn.CharMaxLength = 19
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ch *ColumnHelper) ToColumn(commonColumn *dbi.Column) {
|
||||
ctype := oracleColumnTypeMap[commonColumn.DataType]
|
||||
if ctype == "" {
|
||||
commonColumn.DataType = "NVARCHAR2"
|
||||
commonColumn.CharMaxLength = 2000
|
||||
} else {
|
||||
commonColumn.DataType = dbi.ColumnDataType(ctype)
|
||||
ch.FixColumn(commonColumn)
|
||||
}
|
||||
}
|
||||
|
||||
func (ch *ColumnHelper) FixColumn(column *dbi.Column) {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"mayfly-go/pkg/utils/jsonx"
|
||||
"strings"
|
||||
|
||||
@@ -12,11 +13,12 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
dbi.Register(dbi.DbTypeOracle, new(Meta))
|
||||
dbi.Register(DbTypeOracle, new(Meta))
|
||||
}
|
||||
|
||||
const (
|
||||
DbVersionOracle11 dbi.DbVersion = "11"
|
||||
DbTypeOracle dbi.DbType = "oracle"
|
||||
)
|
||||
|
||||
type Meta struct {
|
||||
@@ -114,3 +116,37 @@ func (om *Meta) GetMetadata(conn *dbi.DbConn) dbi.Metadata {
|
||||
}
|
||||
return &OracleMetadata{dc: conn}
|
||||
}
|
||||
|
||||
func (sm *Meta) GetDbDataTypes() []*dbi.DbDataType {
|
||||
return collx.AsArray[*dbi.DbDataType](
|
||||
CHAR,
|
||||
NCHAR,
|
||||
VARCHAR2,
|
||||
NVARCHAR2,
|
||||
TEXT,
|
||||
LONG,
|
||||
LONGVARCHAR,
|
||||
IMAGE,
|
||||
LONGVARBINARY,
|
||||
CLOB,
|
||||
BLOB,
|
||||
DECIMAL,
|
||||
NUMBER,
|
||||
INTEGER,
|
||||
INT,
|
||||
BIGINT,
|
||||
TINYINT,
|
||||
BYTE,
|
||||
SMALLINT,
|
||||
BIT,
|
||||
DOUBLE,
|
||||
FLOAT,
|
||||
TIME,
|
||||
DATE,
|
||||
TIMESTAMP,
|
||||
)
|
||||
}
|
||||
|
||||
func (mm *Meta) GetCommonTypeConverter() dbi.CommonTypeConverter {
|
||||
return &commonTypeConverter{}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/pkg/errorx"
|
||||
"mayfly-go/pkg/logx"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"mayfly-go/pkg/utils/stringx"
|
||||
"strings"
|
||||
@@ -61,7 +60,7 @@ func (od *OracleMetadata) GetDbNames() ([]string, error) {
|
||||
func (od *OracleMetadata) GetTables(tableNames ...string) ([]dbi.Table, error) {
|
||||
dialect := od.dc.GetDialect()
|
||||
names := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string {
|
||||
return fmt.Sprintf("'%s'", dialect.RemoveQuote(val))
|
||||
return fmt.Sprintf("'%s'", dialect.Quoter().Trim(val))
|
||||
}), ",")
|
||||
|
||||
var res []map[string]any
|
||||
@@ -95,7 +94,7 @@ func (od *OracleMetadata) GetTables(tableNames ...string) ([]dbi.Table, error) {
|
||||
func (od *OracleMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error) {
|
||||
dialect := od.dc.GetDialect()
|
||||
tableName := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string {
|
||||
return fmt.Sprintf("'%s'", dialect.RemoveQuote(val))
|
||||
return fmt.Sprintf("'%s'", dialect.Quoter().Trim(val))
|
||||
}), ",")
|
||||
|
||||
// 如果表数量超过了1000,需要分批查询
|
||||
@@ -121,13 +120,12 @@ func (od *OracleMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
columnHelper := dialect.GetColumnHelper()
|
||||
columns := make([]dbi.Column, 0)
|
||||
for _, re := range res {
|
||||
column := dbi.Column{
|
||||
TableName: cast.ToString(re["TABLE_NAME"]),
|
||||
ColumnName: cast.ToString(re["COLUMN_NAME"]),
|
||||
DataType: dbi.ColumnDataType(cast.ToString(re["DATA_TYPE"])),
|
||||
DataType: cast.ToString(re["DATA_TYPE"]),
|
||||
CharMaxLength: cast.ToInt(re["CHAR_MAX_LENGTH"]),
|
||||
ColumnComment: cast.ToString(re["COLUMN_COMMENT"]),
|
||||
Nullable: cast.ToString(re["NULLABLE"]) == "YES",
|
||||
@@ -138,7 +136,7 @@ func (od *OracleMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error)
|
||||
NumScale: cast.ToInt(re["NUM_SCALE"]),
|
||||
}
|
||||
|
||||
columnHelper.FixColumn(&column)
|
||||
od.dc.GetDbDataType(column.DataType).FixColumn(&column)
|
||||
columns = append(columns, column)
|
||||
}
|
||||
return columns, nil
|
||||
@@ -201,33 +199,7 @@ func (od *OracleMetadata) GetTableIndex(tableName string) ([]dbi.Index, error) {
|
||||
|
||||
// 获取建表ddl
|
||||
func (od *OracleMetadata) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) {
|
||||
dialect := od.dc.GetDialect()
|
||||
// 1.获取表信息
|
||||
tbs, err := od.GetTables(tableName)
|
||||
tableInfo := &dbi.Table{}
|
||||
if err != nil || tbs == nil || len(tbs) <= 0 {
|
||||
logx.Errorf("获取表信息失败, %s", tableName)
|
||||
return "", err
|
||||
}
|
||||
tableInfo.TableName = tbs[0].TableName
|
||||
tableInfo.TableComment = tbs[0].TableComment
|
||||
|
||||
// 2.获取列信息
|
||||
columns, err := od.GetColumns(tableName)
|
||||
if err != nil {
|
||||
logx.Errorf("获取列信息失败, %s", tableName)
|
||||
return "", err
|
||||
}
|
||||
tableDDLArr := dialect.GenerateTableDDL(columns, *tableInfo, dropBeforeCreate)
|
||||
// 3.获取索引信息
|
||||
indexs, err := od.GetTableIndex(tableName)
|
||||
if err != nil {
|
||||
logx.Errorf("获取索引信息失败, %s", tableName)
|
||||
return "", err
|
||||
}
|
||||
// 组装返回
|
||||
tableDDLArr = append(tableDDLArr, dialect.GenerateIndexDDL(indexs, *tableInfo)...)
|
||||
return strings.Join(tableDDLArr, ";\n"), nil
|
||||
return dbi.GenTableDDL(od.dc.GetDialect(), od, tableName, dropBeforeCreate)
|
||||
}
|
||||
|
||||
// 获取DM当前连接的库可访问的schemaNames
|
||||
|
||||
@@ -21,7 +21,7 @@ type OracleMetadata11 struct {
|
||||
func (od *OracleMetadata11) GetColumns(tableNames ...string) ([]dbi.Column, error) {
|
||||
dialect := od.dc.GetDialect()
|
||||
tableName := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string {
|
||||
return fmt.Sprintf("'%s'", dialect.RemoveQuote(val))
|
||||
return fmt.Sprintf("'%s'", dialect.Quoter().Trim(val))
|
||||
}), ",")
|
||||
|
||||
// 如果表数量超过了1000,需要分批查询
|
||||
@@ -47,13 +47,12 @@ func (od *OracleMetadata11) GetColumns(tableNames ...string) ([]dbi.Column, erro
|
||||
return nil, err
|
||||
}
|
||||
|
||||
columnHelper := dialect.GetColumnHelper()
|
||||
columns := make([]dbi.Column, 0)
|
||||
for _, re := range res {
|
||||
column := dbi.Column{
|
||||
TableName: cast.ToString(re["TABLE_NAME"]),
|
||||
ColumnName: cast.ToString(re["COLUMN_NAME"]),
|
||||
DataType: dbi.ColumnDataType(cast.ToString(re["DATA_TYPE"])),
|
||||
DataType: cast.ToString(re["DATA_TYPE"]),
|
||||
CharMaxLength: cast.ToInt(re["CHAR_MAX_LENGTH"]),
|
||||
ColumnComment: cast.ToString(re["COLUMN_COMMENT"]),
|
||||
Nullable: cast.ToString(re["NULLABLE"]) == "YES",
|
||||
@@ -64,7 +63,7 @@ func (od *OracleMetadata11) GetColumns(tableNames ...string) ([]dbi.Column, erro
|
||||
NumScale: cast.ToInt(re["NUM_SCALE"]),
|
||||
}
|
||||
|
||||
columnHelper.FixColumn(&column)
|
||||
od.dc.GetDbDataType(column.DataType).FixColumn(&column)
|
||||
columns = append(columns, column)
|
||||
}
|
||||
return columns, nil
|
||||
@@ -72,7 +71,7 @@ func (od *OracleMetadata11) GetColumns(tableNames ...string) ([]dbi.Column, erro
|
||||
|
||||
func (od *OracleMetadata11) genColumnBasicSql(column dbi.Column) string {
|
||||
dialect := od.dc.GetDialect()
|
||||
colName := dialect.QuoteIdentifier(column.ColumnName)
|
||||
colName := dialect.Quoter().Quote(column.ColumnName)
|
||||
|
||||
if column.IsIdentity {
|
||||
// 11g以前的版本 如果是自增,自增列数据类型必须是number,不需要设置默认值和空值,建表后设置自增序列
|
||||
|
||||
226
server/internal/db/dbm/oracle/sqlgen.go
Normal file
226
server/internal/db/dbm/oracle/sqlgen.go
Normal file
@@ -0,0 +1,226 @@
|
||||
package oracle
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type SQLGenerator struct {
|
||||
Dialect dbi.Dialect
|
||||
Metadata dbi.Metadata
|
||||
}
|
||||
|
||||
func (sg *SQLGenerator) GenTableDDL(table dbi.Table, columns []dbi.Column, dropBeforeCreate bool) []string {
|
||||
quoter := sg.Dialect.Quoter()
|
||||
quote := quoter.Quote
|
||||
quoteTableName := quote(table.TableName)
|
||||
sqlArr := make([]string, 0)
|
||||
|
||||
if dropBeforeCreate {
|
||||
dropSqlTmp := `
|
||||
declare
|
||||
num number;
|
||||
begin
|
||||
select count(1) into num from user_tables where table_name = '%s' and owner = (SELECT sys_context('USERENV', 'CURRENT_SCHEMA') FROM dual) ;
|
||||
if num > 0 then
|
||||
execute immediate 'drop table "%s"' ;
|
||||
end if;
|
||||
end`
|
||||
sqlArr = append(sqlArr, fmt.Sprintf(dropSqlTmp, table.TableName, table.TableName))
|
||||
}
|
||||
|
||||
// 组装建表语句
|
||||
createSql := fmt.Sprintf("CREATE TABLE %s ( \n", quoteTableName)
|
||||
fields := make([]string, 0)
|
||||
pks := make([]string, 0)
|
||||
columnComments := make([]string, 0)
|
||||
// 把通用类型转换为达梦类型
|
||||
for _, column := range columns {
|
||||
if column.IsPrimaryKey {
|
||||
pks = append(pks, quote(column.ColumnName))
|
||||
}
|
||||
quote := quoter.Quote
|
||||
fields = append(fields, sg.genColumnBasicSql(quoter, column))
|
||||
// 防止注释内含有特殊字符串导致sql出错
|
||||
if column.ColumnComment != "" {
|
||||
comment := dbi.QuoteEscape(column.ColumnComment)
|
||||
columnComments = append(columnComments, fmt.Sprintf("COMMENT ON COLUMN %s.%s IS '%s'", quoteTableName, quote(column.ColumnName), comment))
|
||||
}
|
||||
}
|
||||
|
||||
// 建表
|
||||
createSql += strings.Join(fields, ",\n")
|
||||
if len(pks) > 0 {
|
||||
createSql += fmt.Sprintf(", \nPRIMARY KEY (%s)", strings.Join(pks, ","))
|
||||
}
|
||||
createSql += "\n)"
|
||||
sqlArr = append(sqlArr, createSql)
|
||||
|
||||
// 表注释
|
||||
tableCommentSql := ""
|
||||
if table.TableComment != "" {
|
||||
tableCommentSql = fmt.Sprintf("COMMENT ON TABLE %s is '%s'", quote(table.TableName), dbi.QuoteEscape(table.TableComment))
|
||||
sqlArr = append(sqlArr, tableCommentSql)
|
||||
}
|
||||
|
||||
// 列注释
|
||||
if len(columnComments) > 0 {
|
||||
sqlArr = append(sqlArr, columnComments...)
|
||||
}
|
||||
|
||||
otherSql := sg.GenerateTableOtherDDL(table, quoteTableName, columns)
|
||||
if len(otherSql) > 0 {
|
||||
sqlArr = append(sqlArr, otherSql...)
|
||||
}
|
||||
|
||||
return sqlArr
|
||||
}
|
||||
|
||||
func (sg *SQLGenerator) GenIndexDDL(table dbi.Table, indexs []dbi.Index) []string {
|
||||
sqls := make([]string, 0)
|
||||
comments := make([]string, 0)
|
||||
quote := sg.Dialect.Quoter().Quote
|
||||
for _, index := range indexs {
|
||||
unique := ""
|
||||
if index.IsUnique {
|
||||
unique = "unique"
|
||||
}
|
||||
|
||||
// 取出列名,添加引号
|
||||
cols := strings.Split(index.ColumnName, ",")
|
||||
colNames := make([]string, len(cols))
|
||||
for i, name := range cols {
|
||||
colNames[i] = quote(name)
|
||||
}
|
||||
|
||||
sqls = append(sqls, fmt.Sprintf("CREATE %s INDEX %s ON %s(%s)", unique, quote(index.IndexName), quote(table.TableName), strings.Join(colNames, ",")))
|
||||
}
|
||||
|
||||
sqlArr := make([]string, 0)
|
||||
|
||||
sqlArr = append(sqlArr, sqls...)
|
||||
|
||||
if len(comments) > 0 {
|
||||
sqlArr = append(sqlArr, comments...)
|
||||
}
|
||||
|
||||
return sqlArr
|
||||
}
|
||||
|
||||
func (sg *SQLGenerator) GenInsert(tableName string, columns []dbi.Column, values [][]any, duplicateStrategy int) []string {
|
||||
quoter := sg.Dialect.Quoter()
|
||||
quote := quoter.Quote
|
||||
|
||||
if duplicateStrategy == dbi.DuplicateStrategyNone {
|
||||
identityInsert := fmt.Sprintf("set identity_insert %s on;", quote(tableName))
|
||||
|
||||
// 达梦数据库只能一条条的执行insert语句,所以这里需要将values拆分成多条insert语句
|
||||
return collx.ArrayMap(values, func(value []any) string {
|
||||
columnStr, valuesStrs := dbi.GenInsertSqlColumnAndValues(sg.Dialect, DbTypeOracle, columns, [][]any{value})
|
||||
return fmt.Sprintf("%s insert into %s (%s) values %s", identityInsert, quote(tableName), columnStr, strings.Join(valuesStrs, ",\n"))
|
||||
})
|
||||
}
|
||||
|
||||
// 查询主键字段
|
||||
uniqueCols := make([]string, 0)
|
||||
caseSqls := make([]string, 0)
|
||||
metadata := sg.Metadata
|
||||
tableCols, _ := metadata.GetColumns(tableName)
|
||||
identityCols := make([]string, 0)
|
||||
for _, col := range tableCols {
|
||||
if col.IsPrimaryKey {
|
||||
uniqueCols = append(uniqueCols, col.ColumnName)
|
||||
caseSqls = append(caseSqls, fmt.Sprintf("( T1.%s = T2.%s )", quote(col.ColumnName), quote(col.ColumnName)))
|
||||
}
|
||||
if col.IsIdentity {
|
||||
// 自增字段不放入insert内,即使是设置了identity_insert on也不起作用
|
||||
identityCols = append(identityCols, quote(col.ColumnName))
|
||||
}
|
||||
}
|
||||
// 查询唯一索引涉及到的字段,并组装到match条件内
|
||||
indexs, _ := metadata.GetTableIndex(tableName)
|
||||
for _, index := range indexs {
|
||||
if index.IsUnique {
|
||||
cols := strings.Split(index.ColumnName, ",")
|
||||
tmp := make([]string, 0)
|
||||
for _, col := range cols {
|
||||
uniqueCols = append(uniqueCols, col)
|
||||
tmp = append(tmp, fmt.Sprintf(" T1.%s = T2.%s ", quote(col), quote(col)))
|
||||
}
|
||||
caseSqls = append(caseSqls, fmt.Sprintf("( %s )", strings.Join(tmp, " AND ")))
|
||||
}
|
||||
}
|
||||
|
||||
// 重复数据处理策略
|
||||
phs := make([]string, 0)
|
||||
insertVals := make([]string, 0)
|
||||
upds := make([]string, 0)
|
||||
insertCols := make([]string, 0)
|
||||
for _, column := range columns {
|
||||
columnName := column.ColumnName
|
||||
phs = append(phs, fmt.Sprintf("? %s", columnName))
|
||||
if !collx.ArrayContains(uniqueCols, quoter.Trim(columnName)) {
|
||||
upds = append(upds, fmt.Sprintf("T1.%s = T2.%s", columnName, columnName))
|
||||
}
|
||||
if !collx.ArrayContains(identityCols, columnName) {
|
||||
insertCols = append(insertCols, columnName)
|
||||
insertVals = append(insertVals, fmt.Sprintf("T2.%s", columnName))
|
||||
}
|
||||
|
||||
}
|
||||
t2s := make([]string, 0)
|
||||
for i := 0; i < len(values); i++ {
|
||||
t2s = append(t2s, fmt.Sprintf("SELECT %s FROM dual", strings.Join(phs, ",")))
|
||||
}
|
||||
t2 := strings.Join(t2s, " UNION ALL ")
|
||||
|
||||
sqlTemp := "MERGE INTO " + quote(tableName) + " T1 USING (" + t2 + ") T2 ON " + strings.Join(caseSqls, " OR ")
|
||||
sqlTemp += "WHEN NOT MATCHED THEN INSERT (" + strings.Join(insertCols, ",") + ") VALUES (" + strings.Join(insertVals, ",") + ")"
|
||||
sqlTemp += "WHEN MATCHED THEN UPDATE SET " + strings.Join(upds, ",")
|
||||
|
||||
return collx.AsArray(sqlTemp)
|
||||
}
|
||||
|
||||
func (msg *SQLGenerator) genColumnBasicSql(quoter dbi.Quoter, column dbi.Column) string {
|
||||
colName := quoter.Quote(column.ColumnName)
|
||||
|
||||
if column.IsIdentity {
|
||||
// 如果是自增,不需要设置默认值和空值,自增列数据类型必须是number
|
||||
return fmt.Sprintf(" %s NUMBER generated by default as IDENTITY", colName)
|
||||
}
|
||||
|
||||
nullAble := ""
|
||||
if !column.Nullable {
|
||||
nullAble = " NOT NULL"
|
||||
}
|
||||
|
||||
defVal := ""
|
||||
if column.ColumnDefault != "" {
|
||||
defVal = fmt.Sprintf(" DEFAULT %v", column.ColumnDefault)
|
||||
}
|
||||
|
||||
columnSql := fmt.Sprintf(" %s %s%s%s", colName, column.GetColumnType(), defVal, nullAble)
|
||||
return columnSql
|
||||
}
|
||||
|
||||
// 11g及以下版本会设置自增序列
|
||||
func (sg *SQLGenerator) GenerateTableOtherDDL(tableInfo dbi.Table, quoteTableName string, columns []dbi.Column) []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 11g及以下版本会设置自增序列和触发器
|
||||
func (sg *SQLGenerator) Oracle11GenerateTableOtherDDL(tableInfo dbi.Table, quoteTableName string, columns []dbi.Column) []string {
|
||||
result := make([]string, 0)
|
||||
for _, col := range columns {
|
||||
if col.IsIdentity {
|
||||
seqName := fmt.Sprintf("%s_%s_seq", tableInfo.TableName, col.ColumnName)
|
||||
trgName := fmt.Sprintf("%s_%s_trg", tableInfo.TableName, col.ColumnName)
|
||||
result = append(result, fmt.Sprintf("CREATE SEQUENCE %s START WITH 1 INCREMENT BY 1", seqName))
|
||||
result = append(result, fmt.Sprintf("CREATE OR REPLACE TRIGGER %s BEFORE INSERT ON %s FOR EACH ROW WHEN (NEW.%s IS NULL) BEGIN SELECT %s.nextval INTO :new.%s FROM dual; END", trgName, quoteTableName, col.ColumnName, seqName, col.ColumnName))
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
97
server/internal/db/dbm/oracle/transfer.go
Normal file
97
server/internal/db/dbm/oracle/transfer.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package oracle
|
||||
|
||||
import "mayfly-go/internal/db/dbm/dbi"
|
||||
|
||||
var _ dbi.CommonTypeConverter = (*commonTypeConverter)(nil)
|
||||
|
||||
type commonTypeConverter struct {
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Varchar(col *dbi.Column) *dbi.DbDataType {
|
||||
return VARCHAR2
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Char(col *dbi.Column) *dbi.DbDataType {
|
||||
return CHAR
|
||||
}
|
||||
func (c *commonTypeConverter) Text(col *dbi.Column) *dbi.DbDataType {
|
||||
return NVARCHAR2
|
||||
}
|
||||
func (c *commonTypeConverter) Mediumtext(col *dbi.Column) *dbi.DbDataType {
|
||||
return NVARCHAR2
|
||||
}
|
||||
func (c *commonTypeConverter) Longtext(col *dbi.Column) *dbi.DbDataType {
|
||||
return NVARCHAR2
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Bit(col *dbi.Column) *dbi.DbDataType {
|
||||
return BIT
|
||||
}
|
||||
func (c *commonTypeConverter) Int1(col *dbi.Column) *dbi.DbDataType {
|
||||
return TINYINT
|
||||
}
|
||||
func (c *commonTypeConverter) Int2(col *dbi.Column) *dbi.DbDataType {
|
||||
return SMALLINT
|
||||
}
|
||||
func (c *commonTypeConverter) Int4(col *dbi.Column) *dbi.DbDataType {
|
||||
return INTEGER
|
||||
}
|
||||
func (c *commonTypeConverter) Int8(col *dbi.Column) *dbi.DbDataType {
|
||||
return BIGINT
|
||||
}
|
||||
func (c *commonTypeConverter) Numeric(col *dbi.Column) *dbi.DbDataType {
|
||||
return NUMBER
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Decimal(col *dbi.Column) *dbi.DbDataType {
|
||||
return DECIMAL
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) UnsignedInt8(col *dbi.Column) *dbi.DbDataType {
|
||||
return BIGINT
|
||||
}
|
||||
func (c *commonTypeConverter) UnsignedInt4(col *dbi.Column) *dbi.DbDataType {
|
||||
return INT
|
||||
}
|
||||
func (c *commonTypeConverter) UnsignedInt2(col *dbi.Column) *dbi.DbDataType {
|
||||
return INT
|
||||
}
|
||||
func (c *commonTypeConverter) UnsignedInt1(col *dbi.Column) *dbi.DbDataType {
|
||||
return INT
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Date(col *dbi.Column) *dbi.DbDataType {
|
||||
return DATE
|
||||
}
|
||||
func (c *commonTypeConverter) Time(col *dbi.Column) *dbi.DbDataType {
|
||||
return TIME
|
||||
}
|
||||
func (c *commonTypeConverter) Datetime(col *dbi.Column) *dbi.DbDataType {
|
||||
return TIMESTAMP
|
||||
}
|
||||
func (c *commonTypeConverter) Timestamp(col *dbi.Column) *dbi.DbDataType {
|
||||
return TIMESTAMP
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Binary(col *dbi.Column) *dbi.DbDataType {
|
||||
return BLOB
|
||||
}
|
||||
func (c *commonTypeConverter) Varbinary(col *dbi.Column) *dbi.DbDataType {
|
||||
return BLOB
|
||||
}
|
||||
func (c *commonTypeConverter) Mediumblob(col *dbi.Column) *dbi.DbDataType {
|
||||
return BLOB
|
||||
}
|
||||
func (c *commonTypeConverter) Blob(col *dbi.Column) *dbi.DbDataType {
|
||||
return BLOB
|
||||
}
|
||||
func (c *commonTypeConverter) Longblob(col *dbi.Column) *dbi.DbDataType {
|
||||
return BLOB
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Enum(col *dbi.Column) *dbi.DbDataType {
|
||||
return NVARCHAR2
|
||||
}
|
||||
func (c *commonTypeConverter) JSON(col *dbi.Column) *dbi.DbDataType {
|
||||
return NVARCHAR2
|
||||
}
|
||||
31
server/internal/db/dbm/postgres/column.go
Normal file
31
server/internal/db/dbm/postgres/column.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
)
|
||||
|
||||
var (
|
||||
Bool = dbi.NewDbDataType("bool", dbi.DTBit).WithCT(dbi.CTBit).WithFixColumn(dbi.ClearNumScale)
|
||||
Int2 = dbi.NewDbDataType("int2", dbi.DTInt16).WithCT(dbi.CTInt2).WithFixColumn(dbi.ClearNumScale)
|
||||
Int4 = dbi.NewDbDataType("int4", dbi.DTInt32).WithCT(dbi.CTInt4).WithFixColumn(dbi.ClearNumScale)
|
||||
Int8 = dbi.NewDbDataType("int8", dbi.DTInt64).WithCT(dbi.CTInt8).WithFixColumn(dbi.ClearNumScale)
|
||||
Numeric = dbi.NewDbDataType("numeric", dbi.DTNumeric).WithCT(dbi.CTNumeric)
|
||||
Decimal = dbi.NewDbDataType("decimal", dbi.DTDecimal).WithCT(dbi.CTDecimal)
|
||||
Smallserial = dbi.NewDbDataType("smallserial", dbi.DTInt16).WithCT(dbi.CTInt2)
|
||||
Serial = dbi.NewDbDataType("serial", dbi.DTInt32).WithCT(dbi.CTInt4)
|
||||
Bigserial = dbi.NewDbDataType("bigserial", dbi.DTInt64).WithCT(dbi.CTInt8)
|
||||
Largeserial = dbi.NewDbDataType("largeserial", dbi.DTInt64).WithCT(dbi.CTInt8)
|
||||
|
||||
Money = dbi.NewDbDataType("money", dbi.DTString).WithCT(dbi.CTVarchar)
|
||||
|
||||
Char = dbi.NewDbDataType("char", dbi.DTString).WithCT(dbi.CTChar)
|
||||
Nchar = dbi.NewDbDataType("nchar", dbi.DTString).WithCT(dbi.CTVarchar)
|
||||
Varchar = dbi.NewDbDataType("varchar", dbi.DTString).WithCT(dbi.CTVarchar)
|
||||
Text = dbi.NewDbDataType("text", dbi.DTString).WithCT(dbi.CTText).WithFixColumn(dbi.ClearCharMaxLength)
|
||||
Json = dbi.NewDbDataType("json", dbi.DTString).WithCT(dbi.CTJSON).WithFixColumn(dbi.ClearCharMaxLength)
|
||||
Bytea = dbi.NewDbDataType("bytea", dbi.DTString).WithCT(dbi.CTBinary)
|
||||
|
||||
Date = dbi.NewDbDataType("date", dbi.DTDate).WithCT(dbi.CTDate).WithFixColumn(dbi.ClearCharMaxLength)
|
||||
Time = dbi.NewDbDataType("time", dbi.DTTime).WithCT(dbi.CTTime).WithFixColumn(dbi.ClearCharMaxLength)
|
||||
Timestamp = dbi.NewDbDataType("timestamp", dbi.DTDateTime).WithCT(dbi.CTDateTime).WithFixColumn(dbi.ClearCharMaxLength)
|
||||
)
|
||||
@@ -1,12 +1,8 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/pkg/utils/anyx"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/may-fly/cast"
|
||||
@@ -18,107 +14,6 @@ type PgsqlDialect struct {
|
||||
dc *dbi.DbConn
|
||||
}
|
||||
|
||||
func (pd *PgsqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any, duplicateStrategy int) (int64, error) {
|
||||
// 执行批量insert sql,跟mysql一样 pg或高斯支持批量insert语法
|
||||
// insert into table_name (column1, column2, ...) values (value1, value2, ...), (value1, value2, ...), ...
|
||||
|
||||
// 把二维数组转为一维数组
|
||||
var args []any
|
||||
for _, v := range values {
|
||||
args = append(args, v...)
|
||||
}
|
||||
|
||||
// 构建占位符字符串 "($1, $2, $3), ($4, $5, $6), ..." 用于指定参数
|
||||
var placeholders []string
|
||||
for i := 0; i < len(args); i += len(columns) {
|
||||
var placeholder []string
|
||||
for j := 0; j < len(columns); j++ {
|
||||
placeholder = append(placeholder, fmt.Sprintf("$%d", i+j+1))
|
||||
}
|
||||
placeholders = append(placeholders, "("+strings.Join(placeholder, ", ")+")")
|
||||
}
|
||||
|
||||
// 根据冲突策略生成后缀
|
||||
suffix := ""
|
||||
if pd.dc.Info.Type == dbi.DbTypeGauss {
|
||||
// 高斯db使用ON DUPLICATE KEY UPDATE 语法参考 https://support.huaweicloud.com/distributed-devg-v3-gaussdb/gaussdb-12-0607.html#ZH-CN_TOPIC_0000001633948138
|
||||
suffix = pd.gaussOnDuplicateStrategySql(duplicateStrategy, tableName, columns)
|
||||
} else {
|
||||
// pgsql 默认使用 on conflict 语法参考 http://www.postgres.cn/docs/12/sql-insert.html
|
||||
// vastbase语法参考 https://docs.vastdata.com.cn/zh/docs/VastbaseE100Ver3.0.0/doc/SQL%E8%AF%AD%E6%B3%95/INSERT.html
|
||||
// kingbase语法参考 https://help.kingbase.com.cn/v8/development/sql-plsql/sql/SQL_Statements_9.html#insert
|
||||
suffix = pd.pgsqlOnDuplicateStrategySql(duplicateStrategy, tableName, columns)
|
||||
}
|
||||
|
||||
sqlStr := fmt.Sprintf("insert into %s (%s) values %s %s", pd.QuoteIdentifier(tableName), strings.Join(columns, ","), strings.Join(placeholders, ", "), suffix)
|
||||
// 执行批量insert sql
|
||||
|
||||
return pd.dc.TxExec(tx, sqlStr, args...)
|
||||
}
|
||||
|
||||
// pgsql默认唯一键冲突策略
|
||||
func (pd *PgsqlDialect) pgsqlOnDuplicateStrategySql(duplicateStrategy int, tableName string, columns []string) string {
|
||||
suffix := ""
|
||||
if duplicateStrategy == dbi.DuplicateStrategyIgnore {
|
||||
suffix = " \n on conflict do nothing"
|
||||
} else if duplicateStrategy == dbi.DuplicateStrategyUpdate {
|
||||
// 生成 on conflict () do update set column1 = excluded.column1, column2 = excluded.column2, ...
|
||||
var updateColumns []string
|
||||
for _, col := range columns {
|
||||
updateColumns = append(updateColumns, fmt.Sprintf("%s = excluded.%s", col, col))
|
||||
}
|
||||
// 查询唯一键名,拼接冲突sql
|
||||
_, keyRes, _ := pd.dc.Query("SELECT constraint_name FROM information_schema.table_constraints WHERE constraint_schema = $1 AND table_name = $2 AND constraint_type in ('PRIMARY KEY', 'UNIQUE') ", pd.dc.Info.CurrentSchema(), tableName)
|
||||
if len(keyRes) > 0 {
|
||||
for _, re := range keyRes {
|
||||
key := anyx.ToString(re["constraint_name"])
|
||||
if key != "" {
|
||||
suffix += fmt.Sprintf(" \n on conflict on constraint %s do update set %s \n", key, strings.Join(updateColumns, ", "))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return suffix
|
||||
}
|
||||
|
||||
// 高斯db唯一键冲突策略,使用ON DUPLICATE KEY UPDATE 参考:https://support.huaweicloud.com/distributed-devg-v3-gaussdb/gaussdb-12-0607.html#ZH-CN_TOPIC_0000001633948138
|
||||
func (pd *PgsqlDialect) gaussOnDuplicateStrategySql(duplicateStrategy int, tableName string, columns []string) string {
|
||||
suffix := ""
|
||||
metadata := pd.dc.GetMetadata()
|
||||
if duplicateStrategy == dbi.DuplicateStrategyIgnore {
|
||||
suffix = " \n ON DUPLICATE KEY UPDATE NOTHING"
|
||||
} else if duplicateStrategy == dbi.DuplicateStrategyUpdate {
|
||||
|
||||
// 查出表里的唯一键涉及的字段
|
||||
var uniqueColumns []string
|
||||
indexs, err := metadata.GetTableIndex(tableName)
|
||||
if err == nil {
|
||||
for _, index := range indexs {
|
||||
if index.IsUnique {
|
||||
cols := strings.Split(index.ColumnName, ",")
|
||||
for _, col := range cols {
|
||||
if !collx.ArrayContains(uniqueColumns, strings.ToLower(col)) {
|
||||
uniqueColumns = append(uniqueColumns, strings.ToLower(col))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
suffix = " \n ON DUPLICATE KEY UPDATE "
|
||||
for i, col := range columns {
|
||||
// ON DUPLICATE KEY UPDATE语句不支持更新唯一键字段,所以得去掉
|
||||
if !collx.ArrayContains(uniqueColumns, pd.RemoveQuote(strings.ToLower(col))) {
|
||||
suffix += fmt.Sprintf("%s = excluded.%s", col, col)
|
||||
if i < len(columns)-1 {
|
||||
suffix += ", "
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return suffix
|
||||
}
|
||||
|
||||
func (pd *PgsqlDialect) CopyTable(copy *dbi.DbCopyTable) error {
|
||||
tableName := copy.TableName
|
||||
// 生成新表名,为老表明+_copy_时间戳
|
||||
@@ -176,182 +71,13 @@ func (pd *PgsqlDialect) CopyTable(copy *dbi.DbCopyTable) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (pd *PgsqlDialect) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table, dropBeforeCreate bool) []string {
|
||||
quoteTableName := pd.QuoteIdentifier(tableInfo.TableName)
|
||||
|
||||
sqlArr := make([]string, 0)
|
||||
if dropBeforeCreate {
|
||||
sqlArr = append(sqlArr, fmt.Sprintf("DROP TABLE IF EXISTS %s", quoteTableName))
|
||||
}
|
||||
// 组装建表语句
|
||||
createSql := fmt.Sprintf("CREATE TABLE %s (\n", quoteTableName)
|
||||
fields := make([]string, 0)
|
||||
pks := make([]string, 0)
|
||||
columnComments := make([]string, 0)
|
||||
commentTmp := "comment on column %s.%s is '%s'"
|
||||
|
||||
for _, column := range columns {
|
||||
if column.IsPrimaryKey {
|
||||
pks = append(pks, pd.QuoteIdentifier(column.ColumnName))
|
||||
}
|
||||
|
||||
fields = append(fields, pd.genColumnBasicSql(column))
|
||||
|
||||
// 防止注释内含有特殊字符串导致sql出错
|
||||
if column.ColumnComment != "" {
|
||||
comment := pd.QuoteEscape(column.ColumnComment)
|
||||
columnComments = append(columnComments, fmt.Sprintf(commentTmp, quoteTableName, pd.QuoteIdentifier(column.ColumnName), comment))
|
||||
}
|
||||
}
|
||||
|
||||
createSql += strings.Join(fields, ",\n")
|
||||
if len(pks) > 0 {
|
||||
createSql += fmt.Sprintf(", \nPRIMARY KEY (%s)", strings.Join(pks, ","))
|
||||
}
|
||||
createSql += "\n)"
|
||||
|
||||
tableCommentSql := ""
|
||||
if tableInfo.TableComment != "" {
|
||||
commentTmp := "comment on table %s is '%s'"
|
||||
tableCommentSql = fmt.Sprintf(commentTmp, quoteTableName, pd.QuoteEscape(tableInfo.TableComment))
|
||||
}
|
||||
|
||||
// create
|
||||
sqlArr = append(sqlArr, createSql)
|
||||
|
||||
// table comment
|
||||
if tableCommentSql != "" {
|
||||
sqlArr = append(sqlArr, tableCommentSql)
|
||||
}
|
||||
// column comment
|
||||
if len(columnComments) > 0 {
|
||||
sqlArr = append(sqlArr, columnComments...)
|
||||
}
|
||||
|
||||
return sqlArr
|
||||
}
|
||||
|
||||
func (pd *PgsqlDialect) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Table) []string {
|
||||
creates := make([]string, 0)
|
||||
drops := make([]string, 0)
|
||||
comments := make([]string, 0)
|
||||
for _, index := range indexs {
|
||||
unique := ""
|
||||
if index.IsUnique {
|
||||
unique = "unique"
|
||||
}
|
||||
|
||||
// 如果索引名存在,先删除索引
|
||||
drops = append(drops, fmt.Sprintf("drop index if exists %s.%s", pd.dc.Info.CurrentSchema(), index.IndexName))
|
||||
|
||||
// 取出列名,添加引号
|
||||
cols := strings.Split(index.ColumnName, ",")
|
||||
colNames := make([]string, len(cols))
|
||||
for i, name := range cols {
|
||||
colNames[i] = pd.QuoteIdentifier(name)
|
||||
}
|
||||
// 创建索引
|
||||
creates = append(creates, fmt.Sprintf("CREATE %s INDEX %s on %s.%s(%s)", unique, pd.QuoteIdentifier(index.IndexName), pd.QuoteIdentifier(pd.dc.Info.CurrentSchema()), pd.QuoteIdentifier(tableInfo.TableName), strings.Join(colNames, ",")))
|
||||
if index.IndexComment != "" {
|
||||
comment := pd.QuoteEscape(index.IndexComment)
|
||||
comments = append(comments, fmt.Sprintf("COMMENT ON INDEX %s.%s IS '%s'", pd.dc.Info.CurrentSchema(), index.IndexName, comment))
|
||||
}
|
||||
}
|
||||
|
||||
sqlArr := make([]string, 0)
|
||||
|
||||
if len(drops) > 0 {
|
||||
sqlArr = append(sqlArr, drops...)
|
||||
}
|
||||
|
||||
if len(creates) > 0 {
|
||||
sqlArr = append(sqlArr, creates...)
|
||||
}
|
||||
if len(comments) > 0 {
|
||||
sqlArr = append(sqlArr, comments...)
|
||||
}
|
||||
return sqlArr
|
||||
}
|
||||
|
||||
func (pd *PgsqlDialect) UpdateSequence(tableName string, columns []dbi.Column) {
|
||||
for _, column := range columns {
|
||||
if column.IsIdentity {
|
||||
_, _ = pd.dc.Exec(fmt.Sprintf("select setval('%s_%s_seq', (SELECT max(%s) from %s))", tableName, column.ColumnName, column.ColumnName, tableName))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (pd *PgsqlDialect) GetDataHelper() dbi.DataHelper {
|
||||
return dataHelper
|
||||
}
|
||||
|
||||
func (pd *PgsqlDialect) GetColumnHelper() dbi.ColumnHelper {
|
||||
return columnHelper
|
||||
}
|
||||
|
||||
func (pd *PgsqlDialect) GetDumpHelper() dbi.DumpHelper {
|
||||
return new(DumpHelper)
|
||||
}
|
||||
|
||||
func (pd *PgsqlDialect) genColumnBasicSql(column dbi.Column) string {
|
||||
colName := pd.QuoteIdentifier(column.ColumnName)
|
||||
dataType := string(column.DataType)
|
||||
|
||||
// 如果数据类型是数字,则去掉长度
|
||||
if collx.ArrayAnyMatches([]string{"int"}, strings.ToLower(dataType)) {
|
||||
column.NumPrecision = 0
|
||||
column.CharMaxLength = 0
|
||||
func (md *PgsqlDialect) GetSQLGenerator() dbi.SQLGenerator {
|
||||
return &SQLGenerator{
|
||||
dialect: md,
|
||||
dc: md.dc,
|
||||
}
|
||||
|
||||
// 如果是自增类型,需要转换为serial
|
||||
if column.IsIdentity {
|
||||
if dataType == "int4" {
|
||||
column.DataType = "serial"
|
||||
} else if dataType == "int2" {
|
||||
column.DataType = "smallserial"
|
||||
} else if dataType == "int8" {
|
||||
column.DataType = "bigserial"
|
||||
} else {
|
||||
column.DataType = "bigserial"
|
||||
}
|
||||
|
||||
return fmt.Sprintf(" %s %s NOT NULL", colName, column.GetColumnType())
|
||||
}
|
||||
|
||||
nullAble := ""
|
||||
if !column.Nullable {
|
||||
nullAble = " NOT NULL"
|
||||
}
|
||||
|
||||
defVal := "" // 默认值需要判断引号,如函数是不需要引号的 // 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值
|
||||
if column.ColumnDefault != "" && !strings.Contains(column.ColumnDefault, "(") {
|
||||
mark := false
|
||||
// 哪些字段类型默认值需要加引号
|
||||
if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, dataType) {
|
||||
// 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号
|
||||
if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) &&
|
||||
collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(column.ColumnDefault)) {
|
||||
mark = false
|
||||
} else {
|
||||
mark = true
|
||||
}
|
||||
}
|
||||
// 如果数据类型是日期时间,则写死默认值函数
|
||||
if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) {
|
||||
column.ColumnDefault = "CURRENT_TIMESTAMP"
|
||||
}
|
||||
|
||||
if mark {
|
||||
defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault)
|
||||
} else {
|
||||
defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault)
|
||||
}
|
||||
}
|
||||
|
||||
// 如果是varchar,长度翻倍,防止报错
|
||||
if collx.ArrayAnyMatches([]string{"char"}, strings.ToLower(dataType)) {
|
||||
column.CharMaxLength = column.CharMaxLength * 2
|
||||
}
|
||||
columnSql := fmt.Sprintf(" %s %s%s%s", colName, column.GetColumnType(), nullAble, defVal)
|
||||
return columnSql
|
||||
}
|
||||
|
||||
@@ -4,212 +4,16 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/pkg/utils/anyx"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/may-fly/cast"
|
||||
)
|
||||
|
||||
var (
|
||||
// 数字类型
|
||||
numberRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`)
|
||||
// 日期时间类型
|
||||
datetimeRegexp = regexp.MustCompile(`(?i)datetime|timestamp`)
|
||||
// 日期类型
|
||||
dateRegexp = regexp.MustCompile(`(?i)date`)
|
||||
// 时间类型
|
||||
timeRegexp = regexp.MustCompile(`(?i)time`)
|
||||
|
||||
// 提取pg默认值, 如:'id'::varchar 提取id ; '-1'::integer 提取-1
|
||||
defaultValueRegexp = regexp.MustCompile(`'([^']*)'`)
|
||||
|
||||
// pgsql数据类型 映射 公共数据类型
|
||||
commonColumnTypeMap = map[string]dbi.ColumnDataType{
|
||||
"int2": dbi.CommonTypeSmallint,
|
||||
"int4": dbi.CommonTypeInt,
|
||||
"int8": dbi.CommonTypeBigint,
|
||||
"numeric": dbi.CommonTypeNumber,
|
||||
"decimal": dbi.CommonTypeNumber,
|
||||
"smallserial": dbi.CommonTypeSmallint,
|
||||
"serial": dbi.CommonTypeInt,
|
||||
"bigserial": dbi.CommonTypeBigint,
|
||||
"largeserial": dbi.CommonTypeBigint,
|
||||
"money": dbi.CommonTypeNumber,
|
||||
"bool": dbi.CommonTypeTinyint,
|
||||
"char": dbi.CommonTypeChar,
|
||||
"character": dbi.CommonTypeChar,
|
||||
"nchar": dbi.CommonTypeChar,
|
||||
"varchar": dbi.CommonTypeVarchar,
|
||||
"text": dbi.CommonTypeText,
|
||||
"bytea": dbi.CommonTypeText,
|
||||
"date": dbi.CommonTypeDate,
|
||||
"time": dbi.CommonTypeTime,
|
||||
"timestamp": dbi.CommonTypeTimestamp,
|
||||
}
|
||||
// 公共数据类型 映射 pgsql数据类型
|
||||
pgsqlColumnTypeMap = map[dbi.ColumnDataType]string{
|
||||
dbi.CommonTypeVarchar: "varchar",
|
||||
dbi.CommonTypeChar: "char",
|
||||
dbi.CommonTypeText: "text",
|
||||
dbi.CommonTypeBlob: "text",
|
||||
dbi.CommonTypeLongblob: "text",
|
||||
dbi.CommonTypeLongtext: "text",
|
||||
dbi.CommonTypeBinary: "text",
|
||||
dbi.CommonTypeMediumblob: "text",
|
||||
dbi.CommonTypeMediumtext: "text",
|
||||
dbi.CommonTypeVarbinary: "text",
|
||||
dbi.CommonTypeInt: "int4",
|
||||
dbi.CommonTypeSmallint: "int2",
|
||||
dbi.CommonTypeTinyint: "int2",
|
||||
dbi.CommonTypeNumber: "numeric",
|
||||
dbi.CommonTypeBigint: "int8",
|
||||
dbi.CommonTypeDatetime: "timestamp",
|
||||
dbi.CommonTypeDate: "date",
|
||||
dbi.CommonTypeTime: "time",
|
||||
dbi.CommonTypeTimestamp: "timestamp",
|
||||
dbi.CommonTypeEnum: "varchar(2000)",
|
||||
dbi.CommonTypeJSON: "varchar(2000)",
|
||||
}
|
||||
dataHelper = &DataHelper{}
|
||||
columnHelper = &ColumnHelper{}
|
||||
)
|
||||
|
||||
func GetDataHelper() *DataHelper {
|
||||
return dataHelper
|
||||
}
|
||||
|
||||
type DataHelper struct {
|
||||
}
|
||||
|
||||
func (dc *DataHelper) GetDataType(dbColumnType string) dbi.DataType {
|
||||
if numberRegexp.MatchString(dbColumnType) {
|
||||
return dbi.DataTypeNumber
|
||||
}
|
||||
// 日期时间类型
|
||||
if datetimeRegexp.MatchString(dbColumnType) {
|
||||
return dbi.DataTypeDateTime
|
||||
}
|
||||
// 日期类型
|
||||
if dateRegexp.MatchString(dbColumnType) {
|
||||
return dbi.DataTypeDate
|
||||
}
|
||||
// 时间类型
|
||||
if timeRegexp.MatchString(dbColumnType) {
|
||||
return dbi.DataTypeTime
|
||||
}
|
||||
return dbi.DataTypeString
|
||||
}
|
||||
|
||||
func (dc *DataHelper) FormatData(dbColumnValue any, dataType dbi.DataType) string {
|
||||
str := fmt.Sprintf("%v", dbColumnValue)
|
||||
switch dataType {
|
||||
case dbi.DataTypeDateTime: // "2024-01-02T22:16:28.545377+08:00"
|
||||
// 尝试用时间格式解析
|
||||
res, err := time.Parse(time.DateTime, str)
|
||||
if err == nil {
|
||||
return str
|
||||
}
|
||||
res, err = time.Parse(time.RFC3339, str)
|
||||
return res.Format(time.DateTime)
|
||||
case dbi.DataTypeDate: // "2024-01-02T00:00:00Z"
|
||||
// 尝试用时间格式解析
|
||||
res, err := time.Parse(time.DateOnly, str)
|
||||
if err == nil {
|
||||
return str
|
||||
}
|
||||
res, _ = time.Parse(time.RFC3339, str)
|
||||
return res.Format(time.DateOnly)
|
||||
case dbi.DataTypeTime: // "0000-01-01T22:16:28.545075+08:00"
|
||||
// 尝试用时间格式解析
|
||||
res, err := time.Parse(time.TimeOnly, str)
|
||||
if err == nil {
|
||||
return str
|
||||
}
|
||||
res, _ = time.Parse(time.RFC3339, str)
|
||||
return res.Format(time.TimeOnly)
|
||||
}
|
||||
return cast.ToString(dbColumnValue)
|
||||
}
|
||||
|
||||
func (dc *DataHelper) ParseData(dbColumnValue any, dataType dbi.DataType) any {
|
||||
// 如果dataType是datetime而dbColumnValue是string类型,则需要转换为time.Time类型
|
||||
_, ok := dbColumnValue.(string)
|
||||
if dataType == dbi.DataTypeDateTime && ok {
|
||||
res, _ := time.Parse(time.RFC3339, anyx.ToString(dbColumnValue))
|
||||
return res
|
||||
}
|
||||
if dataType == dbi.DataTypeDate && ok {
|
||||
res, _ := time.Parse(time.DateOnly, anyx.ToString(dbColumnValue))
|
||||
return res
|
||||
}
|
||||
if dataType == dbi.DataTypeTime && ok {
|
||||
res, _ := time.Parse(time.TimeOnly, anyx.ToString(dbColumnValue))
|
||||
return res
|
||||
}
|
||||
return dbColumnValue
|
||||
}
|
||||
|
||||
func (dc *DataHelper) WrapValue(dbColumnValue any, dataType dbi.DataType) string {
|
||||
if dbColumnValue == nil {
|
||||
return "NULL"
|
||||
}
|
||||
switch dataType {
|
||||
case dbi.DataTypeNumber:
|
||||
return fmt.Sprintf("%v", dbColumnValue)
|
||||
case dbi.DataTypeString:
|
||||
val := fmt.Sprintf("%v", dbColumnValue)
|
||||
// 转义单引号
|
||||
val = strings.Replace(val, `'`, `''`, -1)
|
||||
// 转义换行符
|
||||
val = strings.Replace(val, "\n", "\\n", -1)
|
||||
return fmt.Sprintf("'%s'", val)
|
||||
case dbi.DataTypeDate, dbi.DataTypeDateTime, dbi.DataTypeTime:
|
||||
return fmt.Sprintf("'%s'", dc.FormatData(dbColumnValue, dataType))
|
||||
}
|
||||
return fmt.Sprintf("'%s'", dbColumnValue)
|
||||
}
|
||||
|
||||
type ColumnHelper struct {
|
||||
dbi.DefaultColumnHelper
|
||||
}
|
||||
|
||||
func (ch *ColumnHelper) ToCommonColumn(column *dbi.Column) {
|
||||
// 翻译为通用数据库类型
|
||||
dataType := column.DataType
|
||||
t1 := commonColumnTypeMap[string(dataType)]
|
||||
if t1 == "" {
|
||||
column.DataType = dbi.CommonTypeVarchar
|
||||
column.CharMaxLength = 2000
|
||||
} else {
|
||||
column.DataType = t1
|
||||
}
|
||||
}
|
||||
|
||||
func (ch *ColumnHelper) ToColumn(commonColumn *dbi.Column) {
|
||||
ctype := pgsqlColumnTypeMap[commonColumn.DataType]
|
||||
|
||||
if ctype == "" {
|
||||
commonColumn.DataType = "varchar"
|
||||
commonColumn.CharMaxLength = 2000
|
||||
} else {
|
||||
commonColumn.DataType = dbi.ColumnDataType(ctype)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func (ch *ColumnHelper) FixColumn(column *dbi.Column) {
|
||||
dataType := strings.ToLower(string(column.DataType))
|
||||
// 哪些字段可以指定长度
|
||||
if !collx.ArrayAnyMatches([]string{"char", "time", "bit", "num", "decimal"}, dataType) {
|
||||
column.CharMaxLength = 0
|
||||
column.NumPrecision = 0
|
||||
} else if strings.Contains(dataType, "char") {
|
||||
// 如果类型是文本,长度翻倍
|
||||
column.CharMaxLength = column.CharMaxLength * 2
|
||||
}
|
||||
func FixColumnDefault(column *dbi.Column) {
|
||||
// 如果默认值带冒号,如:'id'::varchar
|
||||
if column.ColumnDefault != "" && strings.Contains(column.ColumnDefault, "::") && !strings.HasPrefix(column.ColumnDefault, "nextval") {
|
||||
match := defaultValueRegexp.FindStringSubmatch(column.ColumnDefault)
|
||||
|
||||
@@ -16,16 +16,23 @@ import (
|
||||
|
||||
func init() {
|
||||
meta := new(Meta)
|
||||
dbi.Register(dbi.DbTypePostgres, meta)
|
||||
dbi.Register(dbi.DbTypeKingbaseEs, meta)
|
||||
dbi.Register(dbi.DbTypeVastbase, meta)
|
||||
dbi.Register(DbTypePostgres, meta)
|
||||
dbi.Register(DbTypeKingbaseEs, meta)
|
||||
dbi.Register(DbTypeVastbase, meta)
|
||||
|
||||
gauss := &Meta{
|
||||
Param: "dbtype=gauss",
|
||||
}
|
||||
dbi.Register(dbi.DbTypeGauss, gauss)
|
||||
dbi.Register(DbTypeGauss, gauss)
|
||||
}
|
||||
|
||||
const (
|
||||
DbTypePostgres dbi.DbType = "postgres"
|
||||
DbTypeGauss dbi.DbType = "gauss"
|
||||
DbTypeKingbaseEs dbi.DbType = "kingbaseEs"
|
||||
DbTypeVastbase dbi.DbType = "vastbase"
|
||||
)
|
||||
|
||||
type Meta struct {
|
||||
Param string
|
||||
}
|
||||
@@ -89,6 +96,20 @@ func (pm *Meta) GetMetadata(conn *dbi.DbConn) dbi.Metadata {
|
||||
return &PgsqlMetadata{dc: conn}
|
||||
}
|
||||
|
||||
func (pm *Meta) GetDbDataTypes() []*dbi.DbDataType {
|
||||
return collx.AsArray(
|
||||
Bool, Int2, Int4, Int8, Numeric, Decimal, Smallserial, Serial, Bigserial, Largeserial,
|
||||
Money,
|
||||
Char, Nchar, Varchar, Text, Json,
|
||||
Date, Time, Timestamp,
|
||||
Bytea,
|
||||
)
|
||||
}
|
||||
|
||||
func (pm *Meta) GetCommonTypeConverter() dbi.CommonTypeConverter {
|
||||
return &commonTypeConverter{}
|
||||
}
|
||||
|
||||
// pgsql dialer
|
||||
type PqSqlDialer struct {
|
||||
sshTunnelMachineId int
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/pkg/errorx"
|
||||
"mayfly-go/pkg/logx"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"mayfly-go/pkg/utils/stringx"
|
||||
"strings"
|
||||
@@ -54,7 +53,7 @@ func (pd *PgsqlMetadata) GetDbNames() ([]string, error) {
|
||||
func (pd *PgsqlMetadata) GetTables(tableNames ...string) ([]dbi.Table, error) {
|
||||
dialect := pd.dc.GetDialect()
|
||||
names := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string {
|
||||
return fmt.Sprintf("'%s'", dialect.RemoveQuote(val))
|
||||
return fmt.Sprintf("'%s'", dialect.Quoter().Trim(val))
|
||||
}), ",")
|
||||
|
||||
var res []map[string]any
|
||||
@@ -88,7 +87,7 @@ func (pd *PgsqlMetadata) GetTables(tableNames ...string) ([]dbi.Table, error) {
|
||||
func (pd *PgsqlMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error) {
|
||||
dialect := pd.dc.GetDialect()
|
||||
tableName := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string {
|
||||
return fmt.Sprintf("'%s'", dialect.RemoveQuote(val))
|
||||
return fmt.Sprintf("'%s'", dialect.Quoter().Trim(val))
|
||||
}), ",")
|
||||
|
||||
_, res, err := pd.dc.Query(fmt.Sprintf(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_COLUMN_MA_KEY), tableName))
|
||||
@@ -96,13 +95,12 @@ func (pd *PgsqlMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
columnHelper := dialect.GetColumnHelper()
|
||||
columns := make([]dbi.Column, 0)
|
||||
for _, re := range res {
|
||||
column := dbi.Column{
|
||||
TableName: cast.ToString(re["tableName"]),
|
||||
ColumnName: cast.ToString(re["columnName"]),
|
||||
DataType: dbi.ColumnDataType(cast.ToString(re["dataType"])),
|
||||
DataType: cast.ToString(re["dataType"]),
|
||||
CharMaxLength: cast.ToInt(re["charMaxLength"]),
|
||||
ColumnComment: cast.ToString(re["columnComment"]),
|
||||
Nullable: cast.ToString(re["nullable"]) == "YES",
|
||||
@@ -112,7 +110,9 @@ func (pd *PgsqlMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error)
|
||||
NumPrecision: cast.ToInt(re["numPrecision"]),
|
||||
NumScale: cast.ToInt(re["numScale"]),
|
||||
}
|
||||
columnHelper.FixColumn(&column)
|
||||
|
||||
pd.dc.GetDbDataType(column.DataType).FixColumn(&column)
|
||||
FixColumnDefault(&column)
|
||||
columns = append(columns, column)
|
||||
}
|
||||
return columns, nil
|
||||
@@ -164,7 +164,9 @@ func (pd *PgsqlMetadata) GetTableIndex(tableName string) ([]dbi.Index, error) {
|
||||
// 索引字段已根据名称和顺序排序,故取最后一个即可
|
||||
i := len(result) - 1
|
||||
// 同索引字段以逗号连接
|
||||
result[i].ColumnName = result[i].ColumnName + "," + v.ColumnName
|
||||
if result[i].ColumnName != v.ColumnName {
|
||||
result[i].ColumnName = result[i].ColumnName + "," + v.ColumnName
|
||||
}
|
||||
} else {
|
||||
key = in
|
||||
result = append(result, v)
|
||||
@@ -175,33 +177,7 @@ func (pd *PgsqlMetadata) GetTableIndex(tableName string) ([]dbi.Index, error) {
|
||||
|
||||
// 获取建表ddl
|
||||
func (pd *PgsqlMetadata) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) {
|
||||
// 1.获取表信息
|
||||
tbs, err := pd.GetTables(tableName)
|
||||
tableInfo := &dbi.Table{}
|
||||
if err != nil || tbs == nil || len(tbs) <= 0 {
|
||||
logx.Errorf("获取表信息失败, %s", tableName)
|
||||
return "", err
|
||||
}
|
||||
tableInfo.TableName = tbs[0].TableName
|
||||
tableInfo.TableComment = tbs[0].TableComment
|
||||
|
||||
// 2.获取列信息
|
||||
columns, err := pd.GetColumns(tableName)
|
||||
if err != nil {
|
||||
logx.Errorf("获取列信息失败, %s", tableName)
|
||||
return "", err
|
||||
}
|
||||
dialect := pd.dc.GetDialect()
|
||||
tableDDLArr := dialect.GenerateTableDDL(columns, *tableInfo, dropBeforeCreate)
|
||||
// 3.获取索引信息
|
||||
indexs, err := pd.GetTableIndex(tableName)
|
||||
if err != nil {
|
||||
logx.Errorf("获取索引信息失败, %s", tableName)
|
||||
return "", err
|
||||
}
|
||||
// 组装返回
|
||||
tableDDLArr = append(tableDDLArr, dialect.GenerateIndexDDL(indexs, *tableInfo)...)
|
||||
return strings.Join(tableDDLArr, ";\n"), nil
|
||||
return dbi.GenTableDDL(pd.dc.GetDialect(), pd, tableName, dropBeforeCreate)
|
||||
}
|
||||
|
||||
// 获取pgsql当前连接的库可访问的schemaNames
|
||||
@@ -220,11 +196,11 @@ func (pd *PgsqlMetadata) GetSchemas() ([]string, error) {
|
||||
|
||||
func (pd *PgsqlMetadata) GetDefaultDb() string {
|
||||
switch pd.dc.Info.Type {
|
||||
case dbi.DbTypePostgres, dbi.DbTypeGauss:
|
||||
case DbTypePostgres, DbTypeGauss:
|
||||
return "postgres"
|
||||
case dbi.DbTypeKingbaseEs:
|
||||
case DbTypeKingbaseEs:
|
||||
return "security"
|
||||
case dbi.DbTypeVastbase:
|
||||
case DbTypeVastbase:
|
||||
return "vastbase"
|
||||
default:
|
||||
return ""
|
||||
|
||||
263
server/internal/db/dbm/postgres/sqlgen.go
Normal file
263
server/internal/db/dbm/postgres/sqlgen.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/pkg/utils/anyx"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type SQLGenerator struct {
|
||||
dialect dbi.Dialect
|
||||
|
||||
dc *dbi.DbConn
|
||||
}
|
||||
|
||||
func (msg *SQLGenerator) GenTableDDL(table dbi.Table, columns []dbi.Column, dropBeforeCreate bool) []string {
|
||||
quoter := msg.dialect.Quoter()
|
||||
quote := quoter.Quote
|
||||
quoteTableName := quote(table.TableName)
|
||||
|
||||
sqlArr := make([]string, 0)
|
||||
if dropBeforeCreate {
|
||||
sqlArr = append(sqlArr, fmt.Sprintf("DROP TABLE IF EXISTS %s", quoteTableName))
|
||||
}
|
||||
// 组装建表语句
|
||||
createSql := fmt.Sprintf("CREATE TABLE %s (\n", quoteTableName)
|
||||
fields := make([]string, 0)
|
||||
pks := make([]string, 0)
|
||||
columnComments := make([]string, 0)
|
||||
commentTmp := "COMMENT ON COLUMN %s.%s IS '%s'"
|
||||
|
||||
for _, column := range columns {
|
||||
if column.IsPrimaryKey {
|
||||
pks = append(pks, quote(column.ColumnName))
|
||||
}
|
||||
|
||||
fields = append(fields, msg.genColumnBasicSql(quoter, column))
|
||||
|
||||
// 防止注释内含有特殊字符串导致sql出错
|
||||
if column.ColumnComment != "" {
|
||||
comment := dbi.QuoteEscape(column.ColumnComment)
|
||||
columnComments = append(columnComments, fmt.Sprintf(commentTmp, quoteTableName, quote(column.ColumnName), comment))
|
||||
}
|
||||
}
|
||||
|
||||
createSql += strings.Join(fields, ",\n")
|
||||
if len(pks) > 0 {
|
||||
createSql += fmt.Sprintf(", \nPRIMARY KEY (%s)", strings.Join(pks, ","))
|
||||
}
|
||||
createSql += "\n)"
|
||||
|
||||
tableCommentSql := ""
|
||||
if table.TableComment != "" {
|
||||
commentTmp := "COMMENT ON TABLE %s IS '%s'"
|
||||
tableCommentSql = fmt.Sprintf(commentTmp, quoteTableName, dbi.QuoteEscape(table.TableComment))
|
||||
}
|
||||
|
||||
// create
|
||||
sqlArr = append(sqlArr, createSql)
|
||||
|
||||
// table comment
|
||||
if tableCommentSql != "" {
|
||||
sqlArr = append(sqlArr, tableCommentSql)
|
||||
}
|
||||
// column comment
|
||||
if len(columnComments) > 0 {
|
||||
sqlArr = append(sqlArr, columnComments...)
|
||||
}
|
||||
|
||||
return sqlArr
|
||||
}
|
||||
|
||||
func (msg *SQLGenerator) GenIndexDDL(table dbi.Table, indexs []dbi.Index) []string {
|
||||
quoter := msg.dialect.Quoter()
|
||||
quote := quoter.Quote
|
||||
|
||||
creates := make([]string, 0)
|
||||
drops := make([]string, 0)
|
||||
comments := make([]string, 0)
|
||||
for _, index := range indexs {
|
||||
unique := ""
|
||||
if index.IsUnique {
|
||||
unique = " unique"
|
||||
}
|
||||
|
||||
currentSchema := msg.dc.Info.CurrentSchema()
|
||||
// 带上后缀. 避免后续判断
|
||||
if currentSchema != "" {
|
||||
currentSchema = quote(currentSchema) + "."
|
||||
}
|
||||
|
||||
// 如果索引名存在,先删除索引
|
||||
drops = append(drops, fmt.Sprintf("DROP INDEX IF EXISTS %s%s", currentSchema, index.IndexName))
|
||||
|
||||
// 取出列名,添加引号
|
||||
cols := strings.Split(index.ColumnName, ",")
|
||||
colNames := make([]string, len(cols))
|
||||
for i, name := range cols {
|
||||
colNames[i] = quote(name)
|
||||
}
|
||||
// 创建索引
|
||||
creates = append(creates, fmt.Sprintf("CREATE%s INDEX %s ON %s%s(%s)", unique, quote(index.IndexName), currentSchema, quote(table.TableName), strings.Join(colNames, ",")))
|
||||
if index.IndexComment != "" {
|
||||
comment := dbi.QuoteEscape(index.IndexComment)
|
||||
comments = append(comments, fmt.Sprintf("COMMENT ON INDEX %s%s IS '%s'", currentSchema, index.IndexName, comment))
|
||||
}
|
||||
}
|
||||
|
||||
sqlArr := make([]string, 0)
|
||||
|
||||
if len(drops) > 0 {
|
||||
sqlArr = append(sqlArr, drops...)
|
||||
}
|
||||
|
||||
if len(creates) > 0 {
|
||||
sqlArr = append(sqlArr, creates...)
|
||||
}
|
||||
if len(comments) > 0 {
|
||||
sqlArr = append(sqlArr, comments...)
|
||||
}
|
||||
return sqlArr
|
||||
}
|
||||
|
||||
func (psg *SQLGenerator) GenInsert(tableName string, columns []dbi.Column, values [][]any, duplicateStrategy int) []string {
|
||||
insertSql := dbi.GenCommonInsert(psg.dialect, psg.dc.Info.Type, tableName, columns, values)
|
||||
|
||||
// 根据冲突策略生成后缀
|
||||
suffix := ""
|
||||
if psg.dc.Info.Type == DbTypeGauss {
|
||||
// 高斯db使用ON DUPLICATE KEY UPDATE 语法参考 https://support.huaweicloud.com/distributed-devg-v3-gaussdb/gaussdb-12-0607.html#ZH-CN_TOPIC_0000001633948138
|
||||
suffix = psg.gaussOnDuplicateStrategySql(duplicateStrategy, tableName, columns)
|
||||
} else {
|
||||
// pgsql 默认使用 on conflict 语法参考 http://www.postgres.cn/docs/12/sql-insert.html
|
||||
// vastbase语法参考 https://docs.vastdata.com.cn/zh/docs/VastbaseE100Ver3.0.0/doc/SQL%E8%AF%AD%E6%B3%95/INSERT.html
|
||||
// kingbase语法参考 https://help.kingbase.com.cn/v8/development/sql-plsql/sql/SQL_Statements_9.html#insert
|
||||
suffix = psg.pgsqlOnDuplicateStrategySql(duplicateStrategy, tableName, columns)
|
||||
}
|
||||
|
||||
return collx.AsArray[string](insertSql + suffix)
|
||||
}
|
||||
|
||||
// pgsql默认唯一键冲突策略
|
||||
func (psg *SQLGenerator) pgsqlOnDuplicateStrategySql(duplicateStrategy int, tableName string, columns []dbi.Column) string {
|
||||
suffix := ""
|
||||
if duplicateStrategy == dbi.DuplicateStrategyIgnore {
|
||||
suffix = " \n on conflict do nothing"
|
||||
} else if duplicateStrategy == dbi.DuplicateStrategyUpdate {
|
||||
// 生成 on conflict () do update set column1 = excluded.column1, column2 = excluded.column2, ...
|
||||
var updateColumns []string
|
||||
for _, col := range columns {
|
||||
updateColumns = append(updateColumns, fmt.Sprintf("%s = excluded.%s", col.ColumnName, col.ColumnName))
|
||||
}
|
||||
// 查询唯一键名,拼接冲突sql
|
||||
_, keyRes, _ := psg.dc.Query("SELECT constraint_name FROM information_schema.table_constraints WHERE constraint_schema = $1 AND table_name = $2 AND constraint_type in ('PRIMARY KEY', 'UNIQUE') ", psg.dc.Info.CurrentSchema(), tableName)
|
||||
if len(keyRes) > 0 {
|
||||
for _, re := range keyRes {
|
||||
key := anyx.ToString(re["constraint_name"])
|
||||
if key != "" {
|
||||
suffix += fmt.Sprintf(" \n on conflict on constraint %s do update set %s \n", key, strings.Join(updateColumns, ", "))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return suffix
|
||||
}
|
||||
|
||||
// 高斯db唯一键冲突策略,使用ON DUPLICATE KEY UPDATE 参考:https://support.huaweicloud.com/distributed-devg-v3-gaussdb/gaussdb-12-0607.html#ZH-CN_TOPIC_0000001633948138
|
||||
func (psg *SQLGenerator) gaussOnDuplicateStrategySql(duplicateStrategy int, tableName string, columns []dbi.Column) string {
|
||||
suffix := ""
|
||||
metadata := psg.dc.GetMetadata()
|
||||
if duplicateStrategy == dbi.DuplicateStrategyIgnore {
|
||||
suffix = " \n ON DUPLICATE KEY UPDATE NOTHING"
|
||||
} else if duplicateStrategy == dbi.DuplicateStrategyUpdate {
|
||||
|
||||
// 查出表里的唯一键涉及的字段
|
||||
var uniqueColumns []string
|
||||
indexs, err := metadata.GetTableIndex(tableName)
|
||||
if err == nil {
|
||||
for _, index := range indexs {
|
||||
if index.IsUnique {
|
||||
cols := strings.Split(index.ColumnName, ",")
|
||||
for _, col := range cols {
|
||||
if !collx.ArrayContains(uniqueColumns, strings.ToLower(col)) {
|
||||
uniqueColumns = append(uniqueColumns, strings.ToLower(col))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
suffix = " \n ON DUPLICATE KEY UPDATE "
|
||||
for i, col := range columns {
|
||||
// ON DUPLICATE KEY UPDATE语句不支持更新唯一键字段,所以得去掉
|
||||
if !collx.ArrayContains(uniqueColumns, psg.dialect.Quoter().Trim(strings.ToLower(col.ColumnName))) {
|
||||
suffix += fmt.Sprintf("%s = excluded.%s", col.ColumnName, col.ColumnName)
|
||||
if i < len(columns)-1 {
|
||||
suffix += ", "
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return suffix
|
||||
}
|
||||
|
||||
func (pd *SQLGenerator) genColumnBasicSql(quoter dbi.Quoter, column dbi.Column) string {
|
||||
colName := quoter.Quote(column.ColumnName)
|
||||
dataType := string(column.DataType)
|
||||
|
||||
// 如果数据类型是数字,则去掉长度
|
||||
if collx.ArrayAnyMatches([]string{"int"}, strings.ToLower(dataType)) {
|
||||
column.NumPrecision = 0
|
||||
column.CharMaxLength = 0
|
||||
}
|
||||
|
||||
// 如果是自增类型,需要转换为serial
|
||||
if column.IsIdentity {
|
||||
if dataType == "int4" {
|
||||
column.DataType = "serial"
|
||||
} else if dataType == "int2" {
|
||||
column.DataType = "smallserial"
|
||||
} else if dataType == "int8" {
|
||||
column.DataType = "bigserial"
|
||||
} else {
|
||||
column.DataType = "bigserial"
|
||||
}
|
||||
|
||||
return fmt.Sprintf(" %s %s NOT NULL", colName, column.GetColumnType())
|
||||
}
|
||||
|
||||
nullAble := ""
|
||||
if !column.Nullable {
|
||||
nullAble = " NOT NULL"
|
||||
}
|
||||
|
||||
defVal := "" // 默认值需要判断引号,如函数是不需要引号的 // 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值
|
||||
if column.ColumnDefault != "" && !strings.Contains(column.ColumnDefault, "(") {
|
||||
mark := false
|
||||
// 哪些字段类型默认值需要加引号
|
||||
if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, dataType) {
|
||||
// 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号
|
||||
if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) &&
|
||||
collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(column.ColumnDefault)) {
|
||||
mark = false
|
||||
} else {
|
||||
mark = true
|
||||
}
|
||||
}
|
||||
// 如果数据类型是日期时间,则写死默认值函数
|
||||
if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) {
|
||||
column.ColumnDefault = "CURRENT_TIMESTAMP"
|
||||
}
|
||||
|
||||
if mark {
|
||||
defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault)
|
||||
} else {
|
||||
defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault)
|
||||
}
|
||||
}
|
||||
|
||||
columnSql := fmt.Sprintf(" %s %s%s%s", colName, column.GetColumnType(), nullAble, defVal)
|
||||
return columnSql
|
||||
}
|
||||
97
server/internal/db/dbm/postgres/transfer.go
Normal file
97
server/internal/db/dbm/postgres/transfer.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package postgres
|
||||
|
||||
import "mayfly-go/internal/db/dbm/dbi"
|
||||
|
||||
var _ dbi.CommonTypeConverter = (*commonTypeConverter)(nil)
|
||||
|
||||
type commonTypeConverter struct {
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Varchar(col *dbi.Column) *dbi.DbDataType {
|
||||
return Varchar
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Char(col *dbi.Column) *dbi.DbDataType {
|
||||
return Char
|
||||
}
|
||||
func (c *commonTypeConverter) Text(col *dbi.Column) *dbi.DbDataType {
|
||||
return Text
|
||||
}
|
||||
func (c *commonTypeConverter) Mediumtext(col *dbi.Column) *dbi.DbDataType {
|
||||
return Text
|
||||
}
|
||||
func (c *commonTypeConverter) Longtext(col *dbi.Column) *dbi.DbDataType {
|
||||
return Text
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Bit(col *dbi.Column) *dbi.DbDataType {
|
||||
return Int2
|
||||
}
|
||||
func (c *commonTypeConverter) Int1(col *dbi.Column) *dbi.DbDataType {
|
||||
return Int2
|
||||
}
|
||||
func (c *commonTypeConverter) Int2(col *dbi.Column) *dbi.DbDataType {
|
||||
return Int2
|
||||
}
|
||||
func (c *commonTypeConverter) Int4(col *dbi.Column) *dbi.DbDataType {
|
||||
return Int4
|
||||
}
|
||||
func (c *commonTypeConverter) Int8(col *dbi.Column) *dbi.DbDataType {
|
||||
return Int8
|
||||
}
|
||||
func (c *commonTypeConverter) Numeric(col *dbi.Column) *dbi.DbDataType {
|
||||
return Numeric
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Decimal(col *dbi.Column) *dbi.DbDataType {
|
||||
return Decimal
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) UnsignedInt8(col *dbi.Column) *dbi.DbDataType {
|
||||
return Int8
|
||||
}
|
||||
func (c *commonTypeConverter) UnsignedInt4(col *dbi.Column) *dbi.DbDataType {
|
||||
return Int4
|
||||
}
|
||||
func (c *commonTypeConverter) UnsignedInt2(col *dbi.Column) *dbi.DbDataType {
|
||||
return Int2
|
||||
}
|
||||
func (c *commonTypeConverter) UnsignedInt1(col *dbi.Column) *dbi.DbDataType {
|
||||
return Int2
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Date(col *dbi.Column) *dbi.DbDataType {
|
||||
return Date
|
||||
}
|
||||
func (c *commonTypeConverter) Time(col *dbi.Column) *dbi.DbDataType {
|
||||
return Time
|
||||
}
|
||||
func (c *commonTypeConverter) Datetime(col *dbi.Column) *dbi.DbDataType {
|
||||
return Timestamp
|
||||
}
|
||||
func (c *commonTypeConverter) Timestamp(col *dbi.Column) *dbi.DbDataType {
|
||||
return Timestamp
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Binary(col *dbi.Column) *dbi.DbDataType {
|
||||
return Bytea
|
||||
}
|
||||
func (c *commonTypeConverter) Varbinary(col *dbi.Column) *dbi.DbDataType {
|
||||
return Bytea
|
||||
}
|
||||
func (c *commonTypeConverter) Mediumblob(col *dbi.Column) *dbi.DbDataType {
|
||||
return Bytea
|
||||
}
|
||||
func (c *commonTypeConverter) Blob(col *dbi.Column) *dbi.DbDataType {
|
||||
return Bytea
|
||||
}
|
||||
func (c *commonTypeConverter) Longblob(col *dbi.Column) *dbi.DbDataType {
|
||||
return Bytea
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Enum(col *dbi.Column) *dbi.DbDataType {
|
||||
return Varchar
|
||||
}
|
||||
func (c *commonTypeConverter) JSON(col *dbi.Column) *dbi.DbDataType {
|
||||
return Json
|
||||
}
|
||||
14
server/internal/db/dbm/sqlite/column.go
Normal file
14
server/internal/db/dbm/sqlite/column.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package sqlite
|
||||
|
||||
import "mayfly-go/internal/db/dbm/dbi"
|
||||
|
||||
var (
|
||||
Integer = dbi.NewDbDataType("integer", dbi.DTInt64).WithCT(dbi.CTInt8)
|
||||
Real = dbi.NewDbDataType("real", dbi.DTNumeric).WithCT(dbi.CTNumeric)
|
||||
Text = dbi.NewDbDataType("text", dbi.DTString).WithCT(dbi.CTText)
|
||||
Blob = dbi.NewDbDataType("blob", dbi.DTBytes).WithCT(dbi.CTBlob)
|
||||
|
||||
DateTime = dbi.NewDbDataType("datetime", dbi.DTDateTime).WithCT(dbi.CTDateTime)
|
||||
Date = dbi.NewDbDataType("date", dbi.DTDate).WithCT(dbi.CTDate)
|
||||
Time = dbi.NewDbDataType("time", dbi.DTTime).WithCT(dbi.CTTime)
|
||||
)
|
||||
@@ -1,10 +1,8 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@@ -15,42 +13,6 @@ type SqliteDialect struct {
|
||||
dc *dbi.DbConn
|
||||
}
|
||||
|
||||
func (sd *SqliteDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any, duplicateStrategy int) (int64, error) {
|
||||
_, _ = sd.dc.Exec("PRAGMA foreign_keys = false")
|
||||
// 执行批量insert sql,跟mysql一样 支持批量insert语法
|
||||
// 生成占位符字符串:如:(?,?)
|
||||
// 重复字符串并用逗号连接
|
||||
repeated := strings.Repeat("?,", len(columns))
|
||||
// 去除最后一个逗号,占位符由括号包裹
|
||||
placeholder := fmt.Sprintf("(%s)", strings.TrimSuffix(repeated, ","))
|
||||
|
||||
// 重复占位符字符串n遍
|
||||
repeated = strings.Repeat(placeholder+",", len(values))
|
||||
// 去除最后一个逗号
|
||||
placeholder = strings.TrimSuffix(repeated, ",")
|
||||
|
||||
prefix := "insert into"
|
||||
if duplicateStrategy == 1 {
|
||||
prefix = "insert or ignore into"
|
||||
} else if duplicateStrategy == 2 {
|
||||
prefix = "insert or replace into"
|
||||
}
|
||||
|
||||
sqlStr := fmt.Sprintf("%s %s (%s) values %s", prefix, sd.QuoteIdentifier(tableName), strings.Join(columns, ","), placeholder)
|
||||
|
||||
// 把二维数组转为一维数组
|
||||
var args []any
|
||||
for _, v := range values {
|
||||
args = append(args, v...)
|
||||
}
|
||||
exec, err := sd.dc.TxExec(tx, sqlStr, args...)
|
||||
|
||||
_, _ = sd.dc.Exec("PRAGMA foreign_keys = true;")
|
||||
|
||||
// 执行批量insert sql
|
||||
return exec, err
|
||||
}
|
||||
|
||||
func (sd *SqliteDialect) CopyTable(copy *dbi.DbCopyTable) error {
|
||||
tableName := copy.TableName
|
||||
|
||||
@@ -83,101 +45,12 @@ func (sd *SqliteDialect) CopyTable(copy *dbi.DbCopyTable) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// 获取建表ddl
|
||||
func (sd *SqliteDialect) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table, dropBeforeCreate bool) []string {
|
||||
sqlArr := make([]string, 0)
|
||||
tbName := sd.QuoteIdentifier(tableInfo.TableName)
|
||||
if dropBeforeCreate {
|
||||
sqlArr = append(sqlArr, fmt.Sprintf("DROP TABLE IF EXISTS %s", tbName))
|
||||
}
|
||||
// 组装建表语句
|
||||
createSql := fmt.Sprintf("CREATE TABLE %s (\n", tbName)
|
||||
fields := make([]string, 0)
|
||||
|
||||
// 把通用类型转换为达梦类型
|
||||
for _, column := range columns {
|
||||
fields = append(fields, sd.genColumnBasicSql(column))
|
||||
}
|
||||
createSql += strings.Join(fields, ",\n")
|
||||
createSql += "\n)"
|
||||
|
||||
sqlArr = append(sqlArr, createSql)
|
||||
|
||||
return sqlArr
|
||||
}
|
||||
|
||||
func (sd *SqliteDialect) genColumnBasicSql(column dbi.Column) string {
|
||||
incr := ""
|
||||
if column.IsIdentity {
|
||||
incr = " AUTOINCREMENT"
|
||||
}
|
||||
|
||||
nullAble := ""
|
||||
if !column.Nullable {
|
||||
nullAble = " NOT NULL"
|
||||
}
|
||||
|
||||
quoteColumnName := sd.QuoteIdentifier(column.ColumnName)
|
||||
|
||||
// 如果是主键,则直接返回,不判断默认值
|
||||
if column.IsPrimaryKey {
|
||||
return fmt.Sprintf(" %s integer PRIMARY KEY%s%s", quoteColumnName, incr, nullAble)
|
||||
}
|
||||
|
||||
defVal := "" // 默认值需要判断引号,如函数是不需要引号的 // 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值
|
||||
if column.ColumnDefault != "" && !strings.Contains(column.ColumnDefault, "(") {
|
||||
// 哪些字段类型默认值需要加引号
|
||||
mark := false
|
||||
if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, strings.ToLower(string(column.DataType))) {
|
||||
// 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号
|
||||
if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(string(column.DataType))) &&
|
||||
collx.ArrayAnyMatches([]string{"DATE", "TIME"}, strings.ToUpper(column.ColumnDefault)) {
|
||||
mark = false
|
||||
} else {
|
||||
mark = true
|
||||
}
|
||||
}
|
||||
if mark {
|
||||
defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault)
|
||||
} else {
|
||||
defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault)
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf(" %s %s%s%s", quoteColumnName, column.GetColumnType(), nullAble, defVal)
|
||||
}
|
||||
|
||||
// 获取建索引ddl
|
||||
func (sd *SqliteDialect) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Table) []string {
|
||||
sqls := make([]string, 0)
|
||||
for _, index := range indexs {
|
||||
unique := ""
|
||||
if index.IsUnique {
|
||||
unique = "unique"
|
||||
}
|
||||
// 取出列名,添加引号
|
||||
cols := strings.Split(index.ColumnName, ",")
|
||||
colNames := make([]string, len(cols))
|
||||
for i, name := range cols {
|
||||
colNames[i] = sd.QuoteIdentifier(name)
|
||||
}
|
||||
// 创建前尝试删除
|
||||
sqls = append(sqls, fmt.Sprintf("DROP INDEX IF EXISTS \"%s\"", index.IndexName))
|
||||
|
||||
sqlTmp := "CREATE %s INDEX %s ON %s (%s) "
|
||||
sqls = append(sqls, fmt.Sprintf(sqlTmp, unique, sd.QuoteIdentifier(index.IndexName), sd.QuoteIdentifier(tableInfo.TableName), strings.Join(colNames, ",")))
|
||||
}
|
||||
return sqls
|
||||
}
|
||||
|
||||
func (sd *SqliteDialect) GetDataHelper() dbi.DataHelper {
|
||||
return dataHelper
|
||||
}
|
||||
|
||||
func (sd *SqliteDialect) GetColumnHelper() dbi.ColumnHelper {
|
||||
return columnHelper
|
||||
}
|
||||
|
||||
func (sd *SqliteDialect) GetDumpHelper() dbi.DumpHelper {
|
||||
return new(DumpHelper)
|
||||
}
|
||||
|
||||
func (sd *SqliteDialect) GetSQLGenerator() dbi.SQLGenerator {
|
||||
return &SQLGenerator{
|
||||
dialect: sd,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,180 +1,10 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/pkg/utils/anyx"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
// 数字类型
|
||||
numberRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit|real`)
|
||||
// 日期时间类型
|
||||
datetimeRegexp = regexp.MustCompile(`(?i)datetime`)
|
||||
|
||||
dataTypeRegexp = regexp.MustCompile(`(\w+)\((\d*),?(\d*)\)`)
|
||||
|
||||
// sqlite数据类型 映射 公共数据类型
|
||||
commonColumnTypeMap = map[string]dbi.ColumnDataType{
|
||||
"int": dbi.CommonTypeInt,
|
||||
"integer": dbi.CommonTypeInt,
|
||||
"tinyint": dbi.CommonTypeTinyint,
|
||||
"smallint": dbi.CommonTypeSmallint,
|
||||
"mediumint": dbi.CommonTypeSmallint,
|
||||
"bigint": dbi.CommonTypeBigint,
|
||||
"int2": dbi.CommonTypeInt,
|
||||
"int8": dbi.CommonTypeInt,
|
||||
"character": dbi.CommonTypeChar,
|
||||
"varchar": dbi.CommonTypeVarchar,
|
||||
"varying character": dbi.CommonTypeVarchar,
|
||||
"nchar": dbi.CommonTypeChar,
|
||||
"native character": dbi.CommonTypeVarchar,
|
||||
"nvarchar": dbi.CommonTypeVarchar,
|
||||
"text": dbi.CommonTypeText,
|
||||
"clob": dbi.CommonTypeBlob,
|
||||
"blob": dbi.CommonTypeBlob,
|
||||
"real": dbi.CommonTypeNumber,
|
||||
"double": dbi.CommonTypeNumber,
|
||||
"double precision": dbi.CommonTypeNumber,
|
||||
"float": dbi.CommonTypeNumber,
|
||||
"numeric": dbi.CommonTypeNumber,
|
||||
"decimal": dbi.CommonTypeNumber,
|
||||
"boolean": dbi.CommonTypeTinyint,
|
||||
"date": dbi.CommonTypeDate,
|
||||
"datetime": dbi.CommonTypeDatetime,
|
||||
}
|
||||
|
||||
// 公共数据类型 映射 sqlite数据类型
|
||||
sqliteColumnTypeMap = map[dbi.ColumnDataType]string{
|
||||
dbi.CommonTypeVarchar: "nvarchar",
|
||||
dbi.CommonTypeChar: "nchar",
|
||||
dbi.CommonTypeText: "text",
|
||||
dbi.CommonTypeBlob: "blob",
|
||||
dbi.CommonTypeLongblob: "blob",
|
||||
dbi.CommonTypeLongtext: "text",
|
||||
dbi.CommonTypeBinary: "text",
|
||||
dbi.CommonTypeMediumblob: "blob",
|
||||
dbi.CommonTypeMediumtext: "text",
|
||||
dbi.CommonTypeVarbinary: "text",
|
||||
dbi.CommonTypeInt: "int",
|
||||
dbi.CommonTypeSmallint: "smallint",
|
||||
dbi.CommonTypeTinyint: "tinyint",
|
||||
dbi.CommonTypeNumber: "number",
|
||||
dbi.CommonTypeBigint: "bigint",
|
||||
dbi.CommonTypeDatetime: "datetime",
|
||||
dbi.CommonTypeDate: "date",
|
||||
dbi.CommonTypeTime: "datetime",
|
||||
dbi.CommonTypeTimestamp: "datetime",
|
||||
dbi.CommonTypeEnum: "nvarchar(2000)",
|
||||
dbi.CommonTypeJSON: "nvarchar(2000)",
|
||||
}
|
||||
dataHelper = &DataHelper{}
|
||||
columnHelper = &ColumnHelper{}
|
||||
)
|
||||
|
||||
func GetDataHelper() *DataHelper {
|
||||
return dataHelper
|
||||
}
|
||||
|
||||
type DataHelper struct {
|
||||
}
|
||||
|
||||
func (dc *DataHelper) GetDataType(dbColumnType string) dbi.DataType {
|
||||
if numberRegexp.MatchString(dbColumnType) {
|
||||
return dbi.DataTypeNumber
|
||||
}
|
||||
if datetimeRegexp.MatchString(dbColumnType) {
|
||||
return dbi.DataTypeDateTime
|
||||
}
|
||||
return dbi.DataTypeString
|
||||
}
|
||||
|
||||
func (dc *DataHelper) FormatData(dbColumnValue any, dataType dbi.DataType) string {
|
||||
str := anyx.ToString(dbColumnValue)
|
||||
switch dataType {
|
||||
case dbi.DataTypeDateTime: // "2024-01-02T22:08:22.275697+08:00"
|
||||
// 尝试用时间格式解析
|
||||
_, err := time.Parse(time.DateTime, str)
|
||||
if err == nil {
|
||||
return str
|
||||
}
|
||||
res, _ := time.Parse(time.RFC3339, str)
|
||||
return res.Format(time.DateTime)
|
||||
case dbi.DataTypeDate: // "2024-01-02T00:00:00+08:00"
|
||||
// 尝试用时间格式解析
|
||||
_, err := time.Parse(time.DateOnly, str)
|
||||
if err == nil {
|
||||
return str
|
||||
}
|
||||
res, _ := time.Parse(time.RFC3339, str)
|
||||
return res.Format(time.DateOnly)
|
||||
case dbi.DataTypeTime: // "0000-01-01T22:08:22.275688+08:00"
|
||||
// 尝试用时间格式解析
|
||||
_, err := time.Parse(time.TimeOnly, str)
|
||||
if err == nil {
|
||||
return str
|
||||
}
|
||||
res, _ := time.Parse(time.RFC3339, str)
|
||||
return res.Format(time.TimeOnly)
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
func (dc *DataHelper) ParseData(dbColumnValue any, dataType dbi.DataType) any {
|
||||
return dbColumnValue
|
||||
}
|
||||
|
||||
func (dc *DataHelper) WrapValue(dbColumnValue any, dataType dbi.DataType) string {
|
||||
if dbColumnValue == nil {
|
||||
return "NULL"
|
||||
}
|
||||
switch dataType {
|
||||
case dbi.DataTypeNumber:
|
||||
return fmt.Sprintf("%v", dbColumnValue)
|
||||
case dbi.DataTypeString:
|
||||
val := fmt.Sprintf("%v", dbColumnValue)
|
||||
// 转义单引号
|
||||
val = strings.Replace(val, `'`, `''`, -1)
|
||||
val = strings.Replace(val, `\''`, `\'`, -1)
|
||||
// 转义换行符
|
||||
val = strings.Replace(val, "\n", "\\n", -1)
|
||||
return fmt.Sprintf("'%s'", val)
|
||||
case dbi.DataTypeDate, dbi.DataTypeDateTime, dbi.DataTypeTime:
|
||||
return fmt.Sprintf("'%s'", dc.FormatData(dbColumnValue, dataType))
|
||||
}
|
||||
return fmt.Sprintf("'%s'", dbColumnValue)
|
||||
}
|
||||
|
||||
type ColumnHelper struct {
|
||||
dbi.DefaultColumnHelper
|
||||
}
|
||||
|
||||
func (ch *ColumnHelper) ToCommonColumn(dialectColumn *dbi.Column) {
|
||||
// 翻译为通用数据库类型
|
||||
dataType := dialectColumn.DataType
|
||||
t1 := commonColumnTypeMap[string(dataType)]
|
||||
if t1 == "" {
|
||||
dialectColumn.DataType = dbi.CommonTypeVarchar
|
||||
dialectColumn.CharMaxLength = 2000
|
||||
} else {
|
||||
dialectColumn.DataType = t1
|
||||
}
|
||||
}
|
||||
|
||||
func (ch *ColumnHelper) ToColumn(commonColumn *dbi.Column) {
|
||||
ctype := sqliteColumnTypeMap[commonColumn.DataType]
|
||||
if ctype == "" {
|
||||
commonColumn.DataType = "nvarchar"
|
||||
commonColumn.CharMaxLength = 2000
|
||||
} else {
|
||||
ch.FixColumn(commonColumn)
|
||||
}
|
||||
}
|
||||
|
||||
type DumpHelper struct {
|
||||
dbi.DefaultDumpHelper
|
||||
}
|
||||
|
||||
@@ -4,13 +4,18 @@ import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"os"
|
||||
)
|
||||
|
||||
func init() {
|
||||
dbi.Register(dbi.DbTypeSqlite, new(Meta))
|
||||
dbi.Register(DbTypeSqlite, new(Meta))
|
||||
}
|
||||
|
||||
const (
|
||||
DbTypeSqlite dbi.DbType = "sqlite"
|
||||
)
|
||||
|
||||
type Meta struct {
|
||||
}
|
||||
|
||||
@@ -36,3 +41,16 @@ func (sm *Meta) GetDialect(conn *dbi.DbConn) dbi.Dialect {
|
||||
func (sm *Meta) GetMetadata(conn *dbi.DbConn) dbi.Metadata {
|
||||
return &SqliteMetadata{dc: conn}
|
||||
}
|
||||
|
||||
func (sm *Meta) GetDbDataTypes() []*dbi.DbDataType {
|
||||
return collx.AsArray(
|
||||
Integer, Real,
|
||||
Text,
|
||||
Blob,
|
||||
DateTime, Date, Time,
|
||||
)
|
||||
}
|
||||
|
||||
func (sm *Meta) GetCommonTypeConverter() dbi.CommonTypeConverter {
|
||||
return &commonTypeConverter{}
|
||||
}
|
||||
|
||||
@@ -19,6 +19,10 @@ const (
|
||||
SQLITE_INDEX_INFO_KEY = "SQLITE_INDEX_INFO"
|
||||
)
|
||||
|
||||
var (
|
||||
dataTypeRegexp = regexp.MustCompile(`(\w+)\((\d*),?(\d*)\)`)
|
||||
)
|
||||
|
||||
type SqliteMetadata struct {
|
||||
dbi.DefaultMetadata
|
||||
|
||||
@@ -53,7 +57,7 @@ func (sd *SqliteMetadata) GetDbNames() ([]string, error) {
|
||||
func (sd *SqliteMetadata) GetTables(tableNames ...string) ([]dbi.Table, error) {
|
||||
dialect := sd.dc.GetDialect()
|
||||
names := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string {
|
||||
return fmt.Sprintf("'%s'", dialect.RemoveQuote(val))
|
||||
return fmt.Sprintf("'%s'", dialect.Quoter().Trim(val))
|
||||
}), ",")
|
||||
|
||||
var res []map[string]any
|
||||
@@ -98,7 +102,6 @@ func (sd *SqliteMetadata) getDataTypes(dataType string) (string, string, string)
|
||||
// 获取列元信息, 如列名等
|
||||
func (sd *SqliteMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error) {
|
||||
columns := make([]dbi.Column, 0)
|
||||
columnHelper := sd.dc.GetDialect().GetColumnHelper()
|
||||
for i := 0; i < len(tableNames); i++ {
|
||||
tableName := tableNames[i]
|
||||
_, res, err := sd.dc.Query(fmt.Sprintf("PRAGMA table_info(%s)", tableName))
|
||||
@@ -134,9 +137,9 @@ func (sd *SqliteMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error)
|
||||
} else {
|
||||
column.CharMaxLength = cast.ToInt(length)
|
||||
}
|
||||
column.DataType = dbi.ColumnDataType(strings.ToLower(dataType))
|
||||
columnHelper.FixColumn(&column)
|
||||
column.DataType = strings.ToLower(dataType)
|
||||
|
||||
sd.dc.GetDbDataType(column.DataType).FixColumn(&column)
|
||||
columns = append(columns, column)
|
||||
}
|
||||
}
|
||||
|
||||
123
server/internal/db/dbm/sqlite/sqlgen.go
Normal file
123
server/internal/db/dbm/sqlite/sqlgen.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type SQLGenerator struct {
|
||||
dialect dbi.Dialect
|
||||
}
|
||||
|
||||
func (ssg *SQLGenerator) GenTableDDL(table dbi.Table, columns []dbi.Column, dropBeforeCreate bool) []string {
|
||||
quoter := ssg.dialect.Quoter()
|
||||
|
||||
sqlArr := make([]string, 0)
|
||||
tbName := ssg.dialect.Quoter().Quote(table.TableName)
|
||||
if dropBeforeCreate {
|
||||
sqlArr = append(sqlArr, fmt.Sprintf("DROP TABLE IF EXISTS %s", tbName))
|
||||
}
|
||||
// 组装建表语句
|
||||
createSql := fmt.Sprintf("CREATE TABLE %s (\n", tbName)
|
||||
fields := make([]string, 0)
|
||||
|
||||
// 把通用类型转换为达梦类型
|
||||
for _, column := range columns {
|
||||
fields = append(fields, ssg.genColumnBasicSql(quoter, column))
|
||||
}
|
||||
createSql += strings.Join(fields, ",\n")
|
||||
createSql += "\n)"
|
||||
|
||||
sqlArr = append(sqlArr, createSql)
|
||||
|
||||
return sqlArr
|
||||
}
|
||||
|
||||
func (ssg *SQLGenerator) GenIndexDDL(table dbi.Table, indexs []dbi.Index) []string {
|
||||
quoter := ssg.dialect.Quoter()
|
||||
quote := quoter.Quote
|
||||
|
||||
sqls := make([]string, 0)
|
||||
for _, index := range indexs {
|
||||
unique := ""
|
||||
if index.IsUnique {
|
||||
unique = "unique"
|
||||
}
|
||||
// 取出列名,添加引号
|
||||
cols := strings.Split(index.ColumnName, ",")
|
||||
colNames := make([]string, len(cols))
|
||||
for i, name := range cols {
|
||||
colNames[i] = quote(name)
|
||||
}
|
||||
// 创建前尝试删除
|
||||
sqls = append(sqls, fmt.Sprintf("DROP INDEX IF EXISTS \"%s\"", index.IndexName))
|
||||
|
||||
sqlTmp := "CREATE %s INDEX %s ON %s (%s) "
|
||||
sqls = append(sqls, fmt.Sprintf(sqlTmp, unique, quote(index.IndexName), quote(table.TableName), strings.Join(colNames, ",")))
|
||||
}
|
||||
|
||||
return sqls
|
||||
}
|
||||
|
||||
func (ssg *SQLGenerator) GenInsert(tableName string, columns []dbi.Column, values [][]any, duplicateStrategy int) []string {
|
||||
if duplicateStrategy == dbi.DuplicateStrategyNone {
|
||||
return collx.AsArray(dbi.GenCommonInsert(ssg.dialect, DbTypeSqlite, tableName, columns, values))
|
||||
}
|
||||
|
||||
sqls := make([]string, 0)
|
||||
sqls = append(sqls, "PRAGMA foreign_keys = false")
|
||||
|
||||
prefix := "insert or ignore into"
|
||||
if duplicateStrategy == dbi.DuplicateStrategyUpdate {
|
||||
prefix = "insert or replace into"
|
||||
}
|
||||
|
||||
columnStr, valuesStrs := dbi.GenInsertSqlColumnAndValues(ssg.dialect, DbTypeSqlite, columns, values)
|
||||
|
||||
sqls = append(sqls, "PRAGMA foreign_keys = true")
|
||||
sqls = append(sqls, fmt.Sprintf("%s %s %s VALUES \n%s", prefix, ssg.dialect.Quoter().Quote(tableName), columnStr, strings.Join(valuesStrs, ",\n")))
|
||||
return sqls
|
||||
}
|
||||
|
||||
func (ssg *SQLGenerator) genColumnBasicSql(quoter dbi.Quoter, column dbi.Column) string {
|
||||
incr := ""
|
||||
if column.IsIdentity {
|
||||
incr = " AUTOINCREMENT"
|
||||
}
|
||||
|
||||
nullAble := ""
|
||||
if !column.Nullable {
|
||||
nullAble = " NOT NULL"
|
||||
}
|
||||
|
||||
quoteColumnName := quoter.Quote(column.ColumnName)
|
||||
|
||||
// 如果是主键,则直接返回,不判断默认值
|
||||
if column.IsPrimaryKey {
|
||||
return fmt.Sprintf(" %s integer PRIMARY KEY%s%s", quoteColumnName, incr, nullAble)
|
||||
}
|
||||
|
||||
defVal := "" // 默认值需要判断引号,如函数是不需要引号的 // 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值
|
||||
if column.ColumnDefault != "" && !strings.Contains(column.ColumnDefault, "(") {
|
||||
// 哪些字段类型默认值需要加引号
|
||||
mark := false
|
||||
if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, strings.ToLower(string(column.DataType))) {
|
||||
// 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号
|
||||
if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(string(column.DataType))) &&
|
||||
collx.ArrayAnyMatches([]string{"DATE", "TIME"}, strings.ToUpper(column.ColumnDefault)) {
|
||||
mark = false
|
||||
} else {
|
||||
mark = true
|
||||
}
|
||||
}
|
||||
if mark {
|
||||
defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault)
|
||||
} else {
|
||||
defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault)
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf(" %s %s%s%s", quoteColumnName, column.GetColumnType(), nullAble, defVal)
|
||||
}
|
||||
97
server/internal/db/dbm/sqlite/transfer.go
Normal file
97
server/internal/db/dbm/sqlite/transfer.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package sqlite
|
||||
|
||||
import "mayfly-go/internal/db/dbm/dbi"
|
||||
|
||||
var _ dbi.CommonTypeConverter = (*commonTypeConverter)(nil)
|
||||
|
||||
type commonTypeConverter struct {
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Varchar(col *dbi.Column) *dbi.DbDataType {
|
||||
return Text
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Char(col *dbi.Column) *dbi.DbDataType {
|
||||
return Text
|
||||
}
|
||||
func (c *commonTypeConverter) Text(col *dbi.Column) *dbi.DbDataType {
|
||||
return Text
|
||||
}
|
||||
func (c *commonTypeConverter) Mediumtext(col *dbi.Column) *dbi.DbDataType {
|
||||
return Text
|
||||
}
|
||||
func (c *commonTypeConverter) Longtext(col *dbi.Column) *dbi.DbDataType {
|
||||
return Text
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Bit(col *dbi.Column) *dbi.DbDataType {
|
||||
return Integer
|
||||
}
|
||||
func (c *commonTypeConverter) Int1(col *dbi.Column) *dbi.DbDataType {
|
||||
return Integer
|
||||
}
|
||||
func (c *commonTypeConverter) Int2(col *dbi.Column) *dbi.DbDataType {
|
||||
return Integer
|
||||
}
|
||||
func (c *commonTypeConverter) Int4(col *dbi.Column) *dbi.DbDataType {
|
||||
return Integer
|
||||
}
|
||||
func (c *commonTypeConverter) Int8(col *dbi.Column) *dbi.DbDataType {
|
||||
return Integer
|
||||
}
|
||||
func (c *commonTypeConverter) Numeric(col *dbi.Column) *dbi.DbDataType {
|
||||
return Real
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Decimal(col *dbi.Column) *dbi.DbDataType {
|
||||
return Real
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) UnsignedInt8(col *dbi.Column) *dbi.DbDataType {
|
||||
return Integer
|
||||
}
|
||||
func (c *commonTypeConverter) UnsignedInt4(col *dbi.Column) *dbi.DbDataType {
|
||||
return Integer
|
||||
}
|
||||
func (c *commonTypeConverter) UnsignedInt2(col *dbi.Column) *dbi.DbDataType {
|
||||
return Integer
|
||||
}
|
||||
func (c *commonTypeConverter) UnsignedInt1(col *dbi.Column) *dbi.DbDataType {
|
||||
return Integer
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Date(col *dbi.Column) *dbi.DbDataType {
|
||||
return Date
|
||||
}
|
||||
func (c *commonTypeConverter) Time(col *dbi.Column) *dbi.DbDataType {
|
||||
return Time
|
||||
}
|
||||
func (c *commonTypeConverter) Datetime(col *dbi.Column) *dbi.DbDataType {
|
||||
return DateTime
|
||||
}
|
||||
func (c *commonTypeConverter) Timestamp(col *dbi.Column) *dbi.DbDataType {
|
||||
return DateTime
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Binary(col *dbi.Column) *dbi.DbDataType {
|
||||
return Blob
|
||||
}
|
||||
func (c *commonTypeConverter) Varbinary(col *dbi.Column) *dbi.DbDataType {
|
||||
return Blob
|
||||
}
|
||||
func (c *commonTypeConverter) Mediumblob(col *dbi.Column) *dbi.DbDataType {
|
||||
return Blob
|
||||
}
|
||||
func (c *commonTypeConverter) Blob(col *dbi.Column) *dbi.DbDataType {
|
||||
return Blob
|
||||
}
|
||||
func (c *commonTypeConverter) Longblob(col *dbi.Column) *dbi.DbDataType {
|
||||
return Blob
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Enum(col *dbi.Column) *dbi.DbDataType {
|
||||
return Text
|
||||
}
|
||||
func (c *commonTypeConverter) JSON(col *dbi.Column) *dbi.DbDataType {
|
||||
return Text
|
||||
}
|
||||
@@ -110,7 +110,7 @@ func (d *DbJobBaseImpl) setLastStatus(jobType DbJobType, status DbJobStatus, err
|
||||
if err != nil {
|
||||
result = fmt.Sprintf("%s: %v", result, err)
|
||||
}
|
||||
d.LastResult = stringx.TruncateStr(result, LastResultSize)
|
||||
d.LastResult = stringx.Truncate(result, LastResultSize, LastResultSize, "")
|
||||
d.LastTime = timex.NewNullTime(time.Now())
|
||||
}
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ var En = map[i18n.MsgId]string{
|
||||
|
||||
SqlScriptRunFail: "sql script failed to execute",
|
||||
SqlScriptRunSuccess: "sql script executed successfully",
|
||||
SqlScripRunProgress: "sql script execution progress",
|
||||
SqlScripRunProgress: "sql execution progress",
|
||||
DbDumpErr: "Database export failed",
|
||||
ErrDbNameExist: "The database name already exists in this instance",
|
||||
ErrDbNotAccess: "The operation permissions of database [{{.dbName}}] are not configured",
|
||||
|
||||
@@ -16,7 +16,7 @@ var Zh_CN = map[i18n.MsgId]string{
|
||||
|
||||
SqlScriptRunFail: "sql脚本执行失败",
|
||||
SqlScriptRunSuccess: "sql脚本执行成功",
|
||||
SqlScripRunProgress: "sql脚本执行进度",
|
||||
SqlScripRunProgress: "sql执行进度",
|
||||
DbDumpErr: "数据库导出失败",
|
||||
ErrDbNameExist: "该实例下数据库名已存在",
|
||||
ErrDbNotAccess: "未配置数据库【{{.dbName}}】的操作权限",
|
||||
|
||||
@@ -12,7 +12,7 @@ type dbRepoImpl struct {
|
||||
}
|
||||
|
||||
func newDbRepo() repository.Db {
|
||||
return &dbRepoImpl{base.RepoImpl[*entity.Db]{M: new(entity.Db)}}
|
||||
return &dbRepoImpl{}
|
||||
}
|
||||
|
||||
// 分页获取数据库信息列表
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user