refactor: dbm

This commit is contained in:
meilin.huang
2024-12-08 13:04:23 +08:00
parent ebc89e056f
commit e56788af3e
152 changed files with 4273 additions and 3715 deletions

View File

@@ -11,7 +11,7 @@
},
"dependencies": {
"@element-plus/icons-vue": "^2.3.1",
"@vueuse/core": "^11.3.0",
"@vueuse/core": "^12.0.0",
"asciinema-player": "^3.8.1",
"axios": "^1.6.2",
"clipboard": "^2.0.11",
@@ -19,7 +19,7 @@
"crypto-js": "^4.2.0",
"dayjs": "^1.11.13",
"echarts": "^5.5.1",
"element-plus": "^2.8.8",
"element-plus": "^2.9.0",
"js-base64": "^3.7.7",
"jsencrypt": "^3.3.2",
"lodash": "^4.17.21",
@@ -28,16 +28,16 @@
"monaco-sql-languages": "^0.12.2",
"monaco-themes": "^0.4.4",
"nprogress": "^0.2.0",
"pinia": "^2.2.7",
"qrcode.vue": "^3.5.1",
"pinia": "^2.3.0",
"qrcode.vue": "^3.6.0",
"screenfull": "^6.0.2",
"sortablejs": "^1.15.3",
"sortablejs": "^1.15.6",
"splitpanes": "^3.1.5",
"sql-formatter": "^15.4.5",
"trzsz": "^1.1.5",
"uuid": "^9.0.1",
"vue": "^3.5.13",
"vue-i18n": "^10.0.4",
"vue-i18n": "^10.0.5",
"vue-router": "^4.5.0",
"xterm": "^5.3.0",
"xterm-addon-fit": "^0.8.0",
@@ -59,9 +59,9 @@
"eslint": "^8.35.0",
"eslint-plugin-vue": "^9.31.0",
"prettier": "^3.2.5",
"sass": "^1.81.0",
"sass": "^1.82.0",
"typescript": "^5.7.2",
"vite": "^6.0.1",
"vite": "^6.0.3",
"vue-eslint-parser": "^9.4.3"
},
"browserslist": [

View File

@@ -7,7 +7,7 @@ export default {
getPublicKey: () => request.get('/common/public-key'),
getConfigValue: (params: any) => request.get('/sys/configs/value', params),
getServerConf: () => request.get('/sys/configs/server'),
oauth2LoginConfig: () => request.get('/auth/oauth2-config'),
oauth2LoginConfig: () => request.get('/auth/oauth2/config'),
changePwd: (param: any) => request.post('/sys/accounts/change-pwd', param),
captcha: () => request.get('/sys/captcha'),
logout: () => request.post('/auth/accounts/logout'),

View File

@@ -0,0 +1,47 @@
import { buildProgressProps } from '@/components/progress-notify/progress-notify';
import syssocket from './syssocket';
import { h, reactive } from 'vue';
import { ElNotification } from 'element-plus';
import ProgressNotify from '@/components/progress-notify/progress-notify.vue';
export function initSysMsgs() {
registerDbSqlExecProgress();
}
const sqlExecNotifyMap: Map<string, any> = new Map();
function registerDbSqlExecProgress() {
syssocket.registerMsgHandler('execSqlFileProgress', function (message: any) {
const content = JSON.parse(message.msg);
const id = content.id;
let progress = sqlExecNotifyMap.get(id);
if (content.terminated) {
if (progress != undefined) {
progress.notification?.close();
sqlExecNotifyMap.delete(id);
progress = undefined;
}
return;
}
if (progress == undefined) {
progress = {
props: reactive(buildProgressProps()),
notification: undefined,
};
}
progress.props.progress.title = content.title;
progress.props.progress.executedStatements = content.executedStatements;
if (!sqlExecNotifyMap.has(id)) {
progress.notification = ElNotification({
duration: 0,
title: message.title,
message: h(ProgressNotify, progress.props),
type: syssocket.getMsgType(message.type),
showClose: false,
});
sqlExecNotifyMap.set(id, progress);
}
});
}

View File

@@ -1,9 +1,9 @@
import Config from './config';
import {ElNotification} from 'element-plus';
import SocketBuilder from './SocketBuilder';
import {getToken} from '@/common/utils/storage';
import { getToken } from '@/common/utils/storage';
import {joinClientParams} from './request';
import { joinClientParams } from './request';
import { ElNotification } from 'element-plus';
class SysSocket {
/**
@@ -23,7 +23,6 @@ class SysSocket {
0: 'error',
1: 'success',
2: 'info',
22: 'info',
};
/**
@@ -57,21 +56,16 @@ class SysSocket {
return;
}
// 默认通知处理
const type = this.getMsgType(message.type);
let msg = message.msg
let duration = 0
if (message.type == 22) {
let obj = JSON.parse(msg);
msg = `文件:${obj['title']} 执行成功: ${obj['executedStatements']}`
duration = 2000
}
let msg = message.msg;
let duration = 0;
ElNotification({
duration: duration,
title: message.title,
message: msg,
type: type,
});
console.log(message)
})
.open((event: any) => console.log(event))
.close(() => {

View File

@@ -1,5 +1,5 @@
<template>
<el-descriptions border size="small" :title="`${progress.title}`">
<el-descriptions border size="small" :title="`${props.progress.title}`">
<el-descriptions-item label="时间">{{ state.elapsedTime }}</el-descriptions-item>
<el-descriptions-item label="已处理">{{ progress.executedStatements }}</el-descriptions-item>
</el-descriptions>

View File

@@ -214,7 +214,7 @@ export default {
getDbNamesModeAssign: 'Specifying the db name',
ignore: 'Ignore',
replate: 'Replate',
replace: 'Replate',
running: 'Running',
waitRun: 'Wait Run',

View File

@@ -210,7 +210,7 @@ export default {
getDbNamesModeAssign: '指定库名',
ignore: '忽略',
replate: '替换',
replace: '替换',
running: '运行中',
waitRun: '待运行',

View File

@@ -18,11 +18,13 @@ import '@/theme/index.scss';
import '@/assets/font/font.css';
import '@/assets/iconfont/iconfont.js';
import { getThemeConfig } from './common/utils/storage';
import { initSysMsgs } from './common/sysmsgs';
const app = createApp(App);
registElSvgIcon(app);
directive(app);
initSysMsgs();
app.use(pinia).use(router).use(i18n).use(ElementPlus, { size: getThemeConfig()?.globalComponentSize }).mount('#app');

View File

@@ -9,13 +9,13 @@
v-model:selection-data="state.selectionData"
:columns="columns"
@page-num-change="
(args) => {
(args: any) => {
state.query.pageNum = args.pageNum;
search();
}
"
@page-size-change="
(args) => {
(args: any) => {
state.query.pageSize = args.pageNum;
search();
}
@@ -85,14 +85,14 @@ import { dbApi } from '@/views/ops/db/api';
import { getDbDialect } from '@/views/ops/db/dialect';
import PageTable from '@/components/pagetable/PageTable.vue';
import { TableColumn } from '@/components/pagetable';
import { ElMessage, ElMessageBox } from 'element-plus';
import { ElMessage } from 'element-plus';
import { hasPerms } from '@/components/auth/auth';
import TerminalLog from '@/components/terminal/TerminalLog.vue';
import DbSelectTree from '@/views/ops/db/component/DbSelectTree.vue';
import { getClientId } from '@/common/utils/storage';
import FileInfo from '@/components/file/FileInfo.vue';
import { DbTransferFileStatusEnum } from './enums';
import { useI18nDeleteConfirm, useI18nDeleteSuccessMsg, useI18nFormValidate, useI18nPleaseSelect, useI18nSaveSuccessMsg } from '@/hooks/useI18n';
import { useI18nDeleteConfirm, useI18nDeleteSuccessMsg, useI18nFormValidate, useI18nOperateSuccessMsg, useI18nPleaseSelect } from '@/hooks/useI18n';
import { useI18n } from 'vue-i18n';
const { t } = useI18n();
@@ -179,14 +179,13 @@ const state = reactive({
},
btnOk: async function () {
await useI18nFormValidate(runFormRef);
console.log(state.runDialog.runForm);
if (state.runDialog.runForm.targetDbType !== state.runDialog.runForm.dbType) {
ElMessage.warning(t('db.targetDbTypeSelectError', { dbType: state.runDialog.runForm.dbType }));
return false;
}
state.runDialog.runForm.clientId = getClientId();
await dbApi.dbTransferFileRun.request(state.runDialog.runForm);
useI18nSaveSuccessMsg();
useI18nOperateSuccessMsg();
state.runDialog.cancel();
await search();
},
@@ -228,9 +227,7 @@ const openLog = function (data: any) {
// 运行sql弹出选择需要运行的库默认运行当前数据库需要保证数据库类型与sql文件一致
const openRun = function (data: any) {
console.log(data);
state.runDialog.runForm = { id: data.id, dbType: data.fileDbType } as any;
console.log(state.runDialog.runForm);
state.runDialog.visible = true;
};

View File

@@ -205,7 +205,7 @@ const editInstance = async (data: any) => {
const deleteInstance = async () => {
try {
useI18nDeleteConfirm(state.selectionData.map((x: any) => x.name).join('、'));
await useI18nDeleteConfirm(state.selectionData.map((x: any) => x.name).join('、'));
await dbApi.deleteInstance.request({ id: state.selectionData.map((x: any) => x.id).join(',') });
useI18nDeleteSuccessMsg();
search();

View File

@@ -140,7 +140,7 @@
<el-table-column prop="src" :label="$t('db.srcField')" :width="200" />
<el-table-column prop="target" :label="$t('db.targetField')">
<template #default="scope">
<el-select v-model="scope.row.target">
<el-select v-model="scope.row.target" allow-create filterable>
<el-option
v-for="item in state.targetColumnList"
:key="item.columnName"
@@ -502,11 +502,13 @@ const handleGetSrcFields = async () => {
sql,
});
if (!res.columns) {
if (res.length && !res[0].columns) {
ElMessage.warning(t('db.notColumnSql'));
return;
}
let data = res[0];
let filedMap: any = {};
if (state.form.fieldMap && state.form.fieldMap.length > 0) {
state.form.fieldMap.forEach((a: any) => {
@@ -514,11 +516,11 @@ const handleGetSrcFields = async () => {
});
}
state.srcTableFields = res.columns.map((a: any) => a.name);
state.srcTableFields = data.columns.map((a: any) => a.name);
state.form.fieldMap = res.columns.map((a: any) => ({ src: a.name, target: filedMap[a.name] || '' }));
state.form.fieldMap = data.columns.map((a: any) => ({ src: a.name, target: filedMap[a.name] || '' }));
state.previewRes = res;
state.previewRes = data;
};
const handleGetTargetFields = async () => {

View File

@@ -104,7 +104,7 @@
</el-row>
<db-table-data
v-if="!dt.errorMsg"
:ref="(el) => (dt.dbTableRef = el)"
:ref="(el: any) => (dt.dbTableRef = el)"
:db-id="dbId"
:db="dbName"
:data="dt.data"
@@ -128,12 +128,12 @@
</template>
<script lang="ts" setup>
import { h, nextTick, onMounted, reactive, ref, toRefs, unref } from 'vue';
import { nextTick, onMounted, reactive, ref, toRefs, unref } from 'vue';
import { getToken } from '@/common/utils/storage';
import { notBlank } from '@/common/assert';
import { format as sqlFormatter } from 'sql-formatter';
import config from '@/common/config';
import { ElMessage, ElMessageBox, ElNotification } from 'element-plus';
import { ElMessage, ElMessageBox } from 'element-plus';
import * as monaco from 'monaco-editor/esm/vs/editor/editor.api';
import { editor } from 'monaco-editor';
@@ -144,9 +144,6 @@ import { dbApi } from '../../api';
import MonacoEditor from '@/components/monaco/MonacoEditor.vue';
import { joinClientParams } from '@/common/request';
import { buildProgressProps } from '@/components/progress-notify/progress-notify';
import ProgressNotify from '@/components/progress-notify/progress-notify.vue';
import syssocket from '@/common/syssocket';
import SvgIcon from '@/components/svgIcon/index.vue';
import { Pane, Splitpanes } from 'splitpanes';
import { useI18n } from 'vue-i18n';
@@ -593,44 +590,8 @@ const replaceSelection = (str: string, selection: any) => {
});
};
/**
* sql文件执行进度通知缓存
*/
const sqlExecNotifyMap: Map<string, any> = new Map();
const beforeUpload = (file: File) => {
ElMessage.success(t('db.scriptFileUploadRunning', { filename: file.name }));
syssocket.registerMsgHandler('execSqlFileProgress', function (message: any) {
const content = JSON.parse(message.msg);
const id = content.id;
let progress = sqlExecNotifyMap.get(id);
if (content.terminated) {
if (progress != undefined) {
progress.notification?.close();
sqlExecNotifyMap.delete(id);
progress = undefined;
}
return;
}
if (progress == undefined) {
progress = {
props: reactive(buildProgressProps()),
notification: undefined,
};
}
progress.props.progress.title = content.title;
progress.props.progress.executedStatements = content.executedStatements;
if (!sqlExecNotifyMap.has(id)) {
progress.notification = ElNotification({
duration: 0,
title: message.title,
message: h(ProgressNotify, progress.props),
type: syssocket.getMsgType(message.type),
showClose: false,
});
sqlExecNotifyMap.set(id, progress);
}
});
};
// 执行sql成功

View File

@@ -1,36 +1,54 @@
<template>
<div class="string-input-container w100" v-if="dataType == DataType.String || dataType == DataType.Number">
<el-input
:ref="(el: any) => focus && el?.focus()"
:ref="
(el: any) => {
nextTick(() => {
focus && el?.focus();
});
}
"
:disabled="disabled"
@blur="handleBlur"
:class="`w100 mb4 ${showEditorIcon ? 'string-input-container-show-icon' : ''}`"
size="small"
v-model="itemValue"
:placeholder="placeholder"
:placeholder="placeholder ?? $t('common.pleaseInput')"
/>
<SvgIcon v-if="showEditorIcon" @mousedown="openEditor" class="string-input-container-icon" name="FullScreen" :size="10" />
</div>
<el-date-picker
v-else-if="dataType == DataType.Date"
:ref="(el: any) => focus && el?.focus()"
:ref="
(el: any) => {
nextTick(() => {
focus && el?.focus();
});
}
"
:disabled="disabled"
@change="emit('blur')"
@change="handleBlur"
@blur="handleBlur"
class="edit-time-picker mb4"
popper-class="edit-time-picker-popper"
size="small"
v-model="itemValue"
:clearable="false"
type="Date"
type="date"
value-format="YYYY-MM-DD"
:placeholder="`date-${placeholder}`"
:placeholder="`date-${placeholder ?? $t('common.pleaseSelect')}`"
/>
<el-date-picker
v-else-if="dataType == DataType.DateTime"
:ref="(el: any) => focus && el?.focus()"
:ref="
(el: any) => {
nextTick(() => {
focus && el?.focus();
});
}
"
:disabled="disabled"
@change="handleBlur"
@blur="handleBlur"
@@ -41,12 +59,18 @@
:clearable="false"
type="datetime"
value-format="YYYY-MM-DD HH:mm:ss"
:placeholder="`datetime-${placeholder}`"
:placeholder="`datetime-${placeholder ?? $t('common.pleaseSelect')}`"
/>
<el-time-picker
v-else-if="dataType == DataType.Time"
:ref="(el: any) => focus && el?.focus()"
:ref="
(el: any) => {
nextTick(() => {
focus && el?.focus();
});
}
"
:disabled="disabled"
@change="handleBlur"
@blur="handleBlur"
@@ -56,12 +80,12 @@
v-model="itemValue"
:clearable="false"
value-format="HH:mm:ss"
:placeholder="`time-${placeholder}`"
:placeholder="`time-${placeholder ?? $t('common.pleaseSelect')}`"
/>
</template>
<script lang="ts" setup>
import { computed, ref, Ref } from 'vue';
import { computed, nextTick, ref, Ref } from 'vue';
import { ElInput, ElMessage } from 'element-plus';
import { DataType } from '../../dialect/index';
import SvgIcon from '@/components/svgIcon/index.vue';

View File

@@ -164,7 +164,7 @@ import SvgIcon from '@/components/svgIcon/index.vue';
import { exportCsv, exportFile } from '@/common/utils/export';
import { formatDate } from '@/common/utils/format';
import { useIntervalFn, useStorage } from '@vueuse/core';
import { ColumnTypeSubscript, compatibleMysql, DataType, DbDialect, getDbDialect } from '../../dialect/index';
import { ColumnTypeSubscript, DataType, DbDialect, getDbDialect } from '../../dialect/index';
import ColumnFormItem from './ColumnFormItem.vue';
import DbTableDataForm from './DbTableDataForm.vue';
import { useI18n } from 'vue-i18n';
@@ -454,24 +454,11 @@ onBeforeUnmount(() => {
endLoading();
});
const formatDataValues = (datas: any) => {
// mysql数据暂不做处理
if (compatibleMysql(getNowDbInst().type)) {
return;
}
for (let data of datas) {
for (let column of props.columns!) {
data[column.columnName] = getFormatTimeValue(dbDialect.getDataType(column.dataType), data[column.columnName]);
}
}
};
const setTableData = (datas: any) => {
tableRef.value?.scrollTo({ scrollLeft: 0, scrollTop: 0 });
selectionRowsMap.value.clear();
cellUpdateMap.value.clear();
formatDataValues(datas);
// formatDataValues(datas);
state.datas = datas;
setTableColumns(props.columns);
};
@@ -684,7 +671,8 @@ const onExportSql = async () => {
};
const onEnterEditMode = (rowData: any, column: any, rowIndex = 0, columnIndex = 0) => {
if (!state.table) {
// 不存在表,或者已经在编辑中,则不处理
if (!state.table || nowUpdateCell.value) {
return;
}
@@ -697,7 +685,7 @@ const onEnterEditMode = (rowData: any, column: any, rowIndex = 0, columnIndex =
};
const onExitEditMode = (rowData: any, column: any, rowIndex = 0) => {
if (!nowUpdateCell) {
if (!nowUpdateCell.value) {
return;
}
const oldValue = nowUpdateCell.value.oldValue;
@@ -788,32 +776,6 @@ const rowClass = (row: any) => {
return '';
};
/**
* 根据数据库返回的时间字段类型,获取格式化后的时间值
* @param dataType getDataType返回的数据类型
* @param originValue 原始值
* @return 格式化后的值
*/
const getFormatTimeValue = (dataType: DataType, originValue: string): string => {
if (!originValue || dataType === DataType.Number || dataType === DataType.String) {
return originValue;
}
// 把Z去掉
originValue = originValue.replace('Z', '');
switch (dataType) {
case DataType.Time:
return formatDate(originValue, 'HH:mm:ss');
case DataType.Date:
return formatDate(originValue, 'YYYY-MM-DD');
case DataType.DateTime:
return formatDate(originValue, 'YYYY-MM-DD HH:mm:ss');
default:
return originValue;
}
};
const scrollLeftValue = ref(0);
const onTableScroll = (param: any) => {
scrollLeftValue.value = param.scrollLeft;

View File

@@ -592,7 +592,7 @@ const onSelectByCondition = async () => {
*/
const onTableSortChange = async (sort: any) => {
const sortType = sort.order == 'desc' ? 'DESC' : 'ASC';
state.orderBy = `ORDER BY ${sort.columnName} ${sortType}`;
state.orderBy = `ORDER BY ${state.dbDialect.quoteIdentifier(sort.columnName)} ${sortType}`;
await onRefresh();
};

View File

@@ -79,7 +79,7 @@
</el-tab-pane>
<el-tab-pane :label="$t('db.index')" name="2">
<el-table :data="tableData.indexs.res" :max-height="tableData.height">
<el-table-column :prop="item.prop" :label="item.label" v-for="item in tableData.indexs.colNames" :key="item.prop">
<el-table-column :prop="item.prop" :label="$t(item.label)" v-for="item in tableData.indexs.colNames" :key="item.prop">
<template #default="scope">
<el-input v-if="item.prop === 'indexName'" size="small" disabled v-model="scope.row.indexName"></el-input>
@@ -241,7 +241,7 @@ const state = reactive({
colNames: [
{
prop: 'indexName',
label: 'db.indexName',
label: 'common.name',
},
{
prop: 'columnNames',

View File

@@ -209,8 +209,7 @@ class PostgresqlDialect implements DbDialect {
}
quoteIdentifier = (name: string) => {
// 后端sql解析器暂不支持pgsql
return name;
return `"${name}"`;
};
matchType(text: string, arr: string[]): boolean {

View File

@@ -22,7 +22,7 @@ export const DbSqlExecStatusEnum = {
export const DbDataSyncDuplicateStrategyEnum = {
None: EnumValue.of(-1, 'db.none'),
Ignore: EnumValue.of(1, 'db.ingore'),
Ignore: EnumValue.of(1, 'db.ignore'),
Replace: EnumValue.of(2, 'db.replace'),
};

View File

@@ -32,9 +32,9 @@ require (
github.com/tidwall/gjson v1.18.0
github.com/veops/go-ansiterm v0.0.5
go.mongodb.org/mongo-driver v1.16.0 // mongo
golang.org/x/crypto v0.29.0 // ssh
golang.org/x/crypto v0.30.0 // ssh
golang.org/x/oauth2 v0.23.0
golang.org/x/sync v0.9.0
golang.org/x/sync v0.10.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1
gopkg.in/yaml.v3 v3.0.1
// gorm
@@ -93,8 +93,8 @@ require (
golang.org/x/exp v0.0.0-20240904232852-e7e105dedf7e // indirect
golang.org/x/image v0.13.0 // indirect
golang.org/x/net v0.25.0 // indirect
golang.org/x/sys v0.27.0 // indirect
golang.org/x/text v0.20.0 // indirect
golang.org/x/sys v0.28.0 // indirect
golang.org/x/text v0.21.0 // indirect
google.golang.org/protobuf v1.34.1 // indirect
modernc.org/libc v1.22.5 // indirect
modernc.org/mathutil v1.5.0 // indirect

View File

@@ -6,15 +6,15 @@ import (
"mayfly-go/internal/auth/api/form"
"mayfly-go/internal/auth/config"
"mayfly-go/internal/auth/imsg"
"mayfly-go/internal/auth/pkg/captcha"
"mayfly-go/internal/auth/pkg/otp"
msgapp "mayfly-go/internal/msg/application"
sysapp "mayfly-go/internal/sys/application"
sysentity "mayfly-go/internal/sys/domain/entity"
"mayfly-go/pkg/biz"
"mayfly-go/pkg/cache"
"mayfly-go/pkg/captcha"
"mayfly-go/pkg/errorx"
"mayfly-go/pkg/model"
"mayfly-go/pkg/otp"
"mayfly-go/pkg/req"
"mayfly-go/pkg/utils/collx"
"mayfly-go/pkg/utils/cryptox"

View File

@@ -1,8 +1,8 @@
package api
import (
"mayfly-go/internal/auth/pkg/captcha"
"mayfly-go/pkg/biz"
"mayfly-go/pkg/captcha"
"mayfly-go/pkg/req"
"mayfly-go/pkg/utils/collx"
)

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"mayfly-go/internal/auth/config"
"mayfly-go/internal/auth/imsg"
"mayfly-go/internal/auth/pkg/otp"
msgapp "mayfly-go/internal/msg/application"
msgentity "mayfly-go/internal/msg/domain/entity"
sysapp "mayfly-go/internal/sys/application"
@@ -12,7 +13,6 @@ import (
"mayfly-go/pkg/biz"
"mayfly-go/pkg/cache"
"mayfly-go/pkg/i18n"
"mayfly-go/pkg/otp"
"mayfly-go/pkg/req"
"mayfly-go/pkg/utils/collx"
"mayfly-go/pkg/utils/jsonx"

View File

@@ -7,12 +7,12 @@ import (
"mayfly-go/internal/auth/api/form"
"mayfly-go/internal/auth/config"
"mayfly-go/internal/auth/imsg"
"mayfly-go/internal/auth/pkg/captcha"
msgapp "mayfly-go/internal/msg/application"
sysapp "mayfly-go/internal/sys/application"
sysentity "mayfly-go/internal/sys/domain/entity"
"mayfly-go/pkg/biz"
"mayfly-go/pkg/cache"
"mayfly-go/pkg/captcha"
"mayfly-go/pkg/errorx"
"mayfly-go/pkg/model"
"mayfly-go/pkg/req"

View File

@@ -11,5 +11,5 @@ type oauth2AccountRepoImpl struct {
}
func newAuthAccountRepo() repository.Oauth2Account {
return &oauth2AccountRepoImpl{base.RepoImpl[*entity.Oauth2Account]{M: new(entity.Oauth2Account)}}
return &oauth2AccountRepoImpl{}
}

View File

@@ -0,0 +1,37 @@
package router
import (
"mayfly-go/internal/auth/api"
"mayfly-go/internal/auth/imsg"
"mayfly-go/pkg/biz"
"mayfly-go/pkg/ioc"
"mayfly-go/pkg/req"
"github.com/gin-gonic/gin"
)
func InitAccount(router *gin.RouterGroup) {
accountLogin := new(api.AccountLogin)
biz.ErrIsNil(ioc.Inject(accountLogin))
ldapLogin := new(api.LdapLogin)
biz.ErrIsNil(ioc.Inject(ldapLogin))
rg := router.Group("/auth/accounts")
reqs := [...]*req.Conf{
// 用户账号密码登录
req.NewPost("/login", accountLogin.Login).Log(req.NewLogSaveI(imsg.LogAccountLogin)).DontNeedToken(),
req.NewGet("/refreshToken", accountLogin.RefreshToken).DontNeedToken(),
// 用户退出登录
req.NewPost("/logout", accountLogin.Logout),
// 用户otp双因素校验
req.NewPost("/otp-verify", accountLogin.OtpVerify).DontNeedToken(),
}
req.BatchSetGroup(rg, reqs[:])
}

View File

@@ -1,13 +1,13 @@
package router
import (
"mayfly-go/internal/sys/api"
"mayfly-go/internal/auth/api"
"mayfly-go/pkg/req"
"github.com/gin-gonic/gin"
)
func InitCaptchaRouter(router *gin.RouterGroup) {
func InitCaptcha(router *gin.RouterGroup) {
captcha := router.Group("sys/captcha")
{
req.NewGet("", api.GenerateCaptcha).DontNeedToken().Group(captcha)

View File

@@ -0,0 +1,25 @@
package router
import (
"mayfly-go/internal/auth/api"
"mayfly-go/internal/auth/imsg"
"mayfly-go/pkg/biz"
"mayfly-go/pkg/ioc"
"mayfly-go/pkg/req"
"github.com/gin-gonic/gin"
)
func InitLdap(router *gin.RouterGroup) {
ldapLogin := new(api.LdapLogin)
biz.ErrIsNil(ioc.Inject(ldapLogin))
rg := router.Group("/auth/ldap")
reqs := [...]*req.Conf{
req.NewGet("/enabled", ldapLogin.GetLdapEnabled).DontNeedToken(),
req.NewPost("/login", ldapLogin.Login).Log(req.NewLogSaveI(imsg.LogLdapLogin)).DontNeedToken(),
}
req.BatchSetGroup(rg, reqs[:])
}

View File

@@ -0,0 +1,37 @@
package router
import (
"mayfly-go/internal/auth/api"
"mayfly-go/internal/auth/imsg"
"mayfly-go/pkg/biz"
"mayfly-go/pkg/ioc"
"mayfly-go/pkg/req"
"github.com/gin-gonic/gin"
)
func InitOauth2(router *gin.RouterGroup) {
oauth2Login := new(api.Oauth2Login)
biz.ErrIsNil(ioc.Inject(oauth2Login))
rg := router.Group("/auth/oauth2")
reqs := [...]*req.Conf{
req.NewGet("/config", oauth2Login.Oauth2Config).DontNeedToken(),
// oauth2登录
req.NewGet("/login", oauth2Login.OAuth2Login).DontNeedToken(),
req.NewGet("/bind", oauth2Login.OAuth2Bind),
// oauth2回调地址
req.NewGet("/callback", oauth2Login.OAuth2Callback).Log(req.NewLogSaveI(imsg.LogOauth2Callback)).DontNeedToken(),
req.NewGet("/status", oauth2Login.Oauth2Status),
req.NewGet("/unbind", oauth2Login.Oauth2Unbind).Log(req.NewLogSaveI(imsg.LogOauth2Unbind)),
}
req.BatchSetGroup(rg, reqs[:])
}

View File

@@ -1,60 +1,10 @@
package router
import (
"mayfly-go/internal/auth/api"
"mayfly-go/internal/auth/imsg"
"mayfly-go/pkg/biz"
"mayfly-go/pkg/ioc"
"mayfly-go/pkg/req"
"github.com/gin-gonic/gin"
)
import "github.com/gin-gonic/gin"
func Init(router *gin.RouterGroup) {
accountLogin := new(api.AccountLogin)
biz.ErrIsNil(ioc.Inject(accountLogin))
ldapLogin := new(api.LdapLogin)
biz.ErrIsNil(ioc.Inject(ldapLogin))
oauth2Login := new(api.Oauth2Login)
biz.ErrIsNil(ioc.Inject(oauth2Login))
rg := router.Group("/auth")
reqs := [...]*req.Conf{
// 用户账号密码登录
req.NewPost("/accounts/login", accountLogin.Login).Log(req.NewLogSaveI(imsg.LogAccountLogin)).DontNeedToken(),
req.NewGet("/accounts/refreshToken", accountLogin.RefreshToken).DontNeedToken(),
// 用户退出登录
req.NewPost("/accounts/logout", accountLogin.Logout),
// 用户otp双因素校验
req.NewPost("/accounts/otp-verify", accountLogin.OtpVerify).DontNeedToken(),
/*--------oauth2登录相关----------*/
req.NewGet("/oauth2-config", oauth2Login.Oauth2Config).DontNeedToken(),
// oauth2登录
req.NewGet("/oauth2/login", oauth2Login.OAuth2Login).DontNeedToken(),
req.NewGet("/oauth2/bind", oauth2Login.OAuth2Bind),
// oauth2回调地址
req.NewGet("/oauth2/callback", oauth2Login.OAuth2Callback).Log(req.NewLogSaveI(imsg.LogOauth2Callback)).DontNeedToken(),
req.NewGet("/oauth2/status", oauth2Login.Oauth2Status),
req.NewGet("/oauth2/unbind", oauth2Login.Oauth2Unbind).Log(req.NewLogSaveI(imsg.LogOauth2Unbind)),
// LDAP 登录
req.NewGet("/ldap/enabled", ldapLogin.GetLdapEnabled).DontNeedToken(),
req.NewPost("/ldap/login", ldapLogin.Login).Log(req.NewLogSaveI(imsg.LogLdapLogin)).DontNeedToken(),
}
req.BatchSetGroup(rg, reqs[:])
InitCaptcha(router)
InitAccount(router)
InitOauth2(router)
InitLdap(router)
}

View File

@@ -9,7 +9,6 @@ import (
"mayfly-go/internal/db/application/dto"
"mayfly-go/internal/db/config"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/internal/db/dbm/sqlparser"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/imsg"
"mayfly-go/internal/event"
@@ -26,9 +25,7 @@ import (
"mayfly-go/pkg/utils/anyx"
"mayfly-go/pkg/utils/collx"
"mayfly-go/pkg/utils/cryptox"
"mayfly-go/pkg/utils/stringx"
"mayfly-go/pkg/utils/writerx"
"mayfly-go/pkg/ws"
"strings"
"time"
@@ -118,7 +115,7 @@ func (d *Db) ExecSql(rc *req.Ctx) {
rc.ReqParam = fmt.Sprintf("%s %s\n-> %s", dbConn.Info.GetLogDesc(), form.ExecId, sqlStr)
biz.NotEmpty(form.Sql, "sql cannot be empty")
execReq := &application.DbSqlExecReq{
execReq := &dto.DbSqlExecReq{
DbId: dbId,
Db: form.Db,
Remark: form.Remark,
@@ -163,47 +160,12 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.GetLoginAccount().Id, dbConn.Info.CodePath...), "%s")
rc.ReqParam = fmt.Sprintf("filename: %s -> %s", filename, dbConn.Info.GetLogDesc())
defer func() {
if err := recover(); err != nil {
errInfo := anyx.ToString(err)
if len(errInfo) > 300 {
errInfo = errInfo[:300] + "..."
}
d.MsgApp.CreateAndSend(rc.GetLoginAccount(), msgdto.ErrSysMsg(i18n.T(imsg.SqlScriptRunFail), fmt.Sprintf("[%s][%s] execution failure: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo)).WithClientId(clientId))
}
}()
executedStatements := 0
progressId := stringx.Rand(32)
laId := rc.GetLoginAccount().Id
defer ws.SendJsonMsg(ws.UserId(laId), clientId, msgdto.InfoSysMsg(i18n.T(imsg.SqlScripRunProgress), &progressMsg{
Id: progressId,
Title: filename,
ExecutedStatements: executedStatements,
Terminated: true,
}).WithCategory(progressCategory))
ticker := time.NewTicker(time.Second * 1)
defer ticker.Stop()
err = sqlparser.SQLSplit(file, func(sql string) error {
select {
case <-ticker.C:
ws.SendJsonMsg(ws.UserId(laId), clientId, msgdto.InfoSysMsg(i18n.T(imsg.SqlScripRunProgress), &progressMsg{
Id: progressId,
Title: filename,
ExecutedStatements: executedStatements,
Terminated: false,
}).WithCategory(progressCategory))
default:
}
executedStatements++
_, err = dbConn.Exec(sql)
return err
})
biz.ErrIsNilAppendErr(err, "%s")
d.MsgApp.CreateAndSend(rc.GetLoginAccount(), msgdto.SuccessSysMsg(i18n.T(imsg.SqlScriptRunSuccess), fmt.Sprintf("execution success: %s", rc.ReqParam)).WithClientId(clientId))
biz.ErrIsNil(d.DbSqlExecApp.ExecReader(rc.MetaCtx, &dto.SqlReaderExec{
Reader: file,
Filename: filename,
DbConn: dbConn,
ClientId: clientId,
}))
}
// 数据库dump

View File

@@ -2,28 +2,18 @@ package api
import (
"context"
"fmt"
"mayfly-go/internal/db/api/form"
"mayfly-go/internal/db/api/vo"
"mayfly-go/internal/db/application"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/internal/db/dbm/sqlparser"
"mayfly-go/internal/db/application/dto"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/imsg"
fileapp "mayfly-go/internal/file/application"
msgapp "mayfly-go/internal/msg/application"
msgdto "mayfly-go/internal/msg/application/dto"
tagapp "mayfly-go/internal/tag/application"
"mayfly-go/pkg/biz"
"mayfly-go/pkg/i18n"
"mayfly-go/pkg/model"
"mayfly-go/pkg/req"
"mayfly-go/pkg/utils/anyx"
"mayfly-go/pkg/utils/collx"
"mayfly-go/pkg/utils/stringx"
"mayfly-go/pkg/ws"
"strings"
"time"
"github.com/may-fly/cast"
)
@@ -120,7 +110,6 @@ func (d *DbTransferTask) FileDel(rc *req.Ctx) {
}
func (d *DbTransferTask) FileRun(rc *req.Ctx) {
fm := req.BindJsonAndValid(rc, &form.DbTransferFileRunForm{})
rc.ReqParam = fm
@@ -132,69 +121,15 @@ func (d *DbTransferTask) FileRun(rc *req.Ctx) {
biz.ErrIsNilAppendErr(err, "failed to connect to the target database: %s")
biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.GetLoginAccount().Id, targetDbConn.Info.CodePath...), "%s")
defer func() {
if err := recover(); err != nil {
errInfo := anyx.ToString(err)
if len(errInfo) > 300 {
errInfo = errInfo[:300] + "..."
}
d.MsgApp.CreateAndSend(rc.GetLoginAccount(), msgdto.ErrSysMsg(i18n.T(imsg.SqlScriptRunFail), fmt.Sprintf("[%s][%s] run failed: [%s]", tFile.FileKey, targetDbConn.Info.GetLogDesc(), errInfo)).WithClientId(fm.ClientId))
}
}()
go func() {
d.fileRun(rc.GetLoginAccount(), fm, tFile, targetDbConn)
}()
}
func (d *DbTransferTask) fileRun(la *model.LoginAccount, fm *form.DbTransferFileRunForm, tFile *entity.DbTransferFile, targetDbConn *dbi.DbConn) {
filename, reader, err := d.FileApp.GetReader(context.TODO(), tFile.FileKey)
executedStatements := 0
progressId := stringx.Rand(32)
laId := la.Id
ticker := time.NewTicker(time.Second * 1)
defer ticker.Stop()
defer func() {
if err := recover(); err != nil {
errInfo := anyx.ToString(err)
if len(errInfo) > 500 {
errInfo = errInfo[:500] + "..."
}
d.MsgApp.CreateAndSend(la, msgdto.ErrSysMsg(i18n.T(imsg.SqlScriptRunFail), errInfo).WithClientId(fm.ClientId))
}
biz.ErrIsNil(err)
go func() {
biz.ErrIsNil(d.DbSqlExecApp.ExecReader(rc.MetaCtx, &dto.SqlReaderExec{
Reader: reader,
Filename: filename,
DbConn: targetDbConn,
ClientId: fm.ClientId,
}))
}()
if err != nil {
biz.ErrIsNilAppendErr(err, "failed to connect to the target database: %s")
}
errSql := ""
err = sqlparser.SQLSplit(reader, func(sql string) error {
select {
case <-ticker.C:
ws.SendJsonMsg(ws.UserId(laId), fm.ClientId, msgdto.InfoSqlProgressMsg(i18n.T(imsg.SqlScripRunProgress), &progressMsg{
Id: progressId,
Title: filename,
ExecutedStatements: executedStatements,
Terminated: false,
}).WithCategory(progressCategory))
default:
}
executedStatements++
_, err = targetDbConn.Exec(sql)
if err != nil {
errSql = sql
}
return err
})
if err != nil {
biz.ErrIsNil(err, "[%s] execution failed: %s", errSql, err)
}
d.MsgApp.CreateAndSend(la, msgdto.SuccessSysMsg(i18n.T(imsg.SqlScriptRunSuccess), fmt.Sprintf("sql execution successfully: %s", filename)).WithClientId(fm.ClientId))
}

View File

@@ -214,7 +214,7 @@ func (d *dbAppImpl) GetDbConnByInstanceId(instanceId uint64) (*dbi.DbConn, error
return nil, errorx.NewBiz("failed to get database list")
}
if len(dbs) == 0 {
return nil, errorx.NewBiz("DB instance [%d] Database is not configured, please configure it first", instanceId)
return nil, errorx.NewBiz("DB instance [%d] database is not configured, please configure it first", instanceId)
}
// 使用该实例关联的已配置数据库中的第一个库进行连接并返回
@@ -227,6 +227,10 @@ func (d *dbAppImpl) DumpDb(ctx context.Context, reqParam *dto.DumpDb) error {
if reqParam.Log != nil {
log = reqParam.Log
}
progress := dto.DefaultDumpProgress
if reqParam.Progress != nil {
progress = reqParam.Progress
}
writer := writerx.NewStringWriter(reqParam.Writer)
defer writer.Close()
@@ -249,23 +253,16 @@ func (d *dbAppImpl) DumpDb(ctx context.Context, reqParam *dto.DumpDb) error {
// 获取目标元数据仅生成sql用于生成建表语句和插入数据不能用于查询
targetDialect := dbConn.GetDialect()
if reqParam.TargetDbType != "" && dbConn.Info.Type != reqParam.TargetDbType {
// 创建一个假连接仅用于调用方言生成sql不做数据库连接操作
meta := dbi.GetMeta(reqParam.TargetDbType)
dbConn := &dbi.DbConn{Info: &dbi.DbInfo{
Type: reqParam.TargetDbType,
Meta: meta,
}}
targetDialect = meta.GetDialect(dbConn)
targetDialect = dbi.GetDialect(reqParam.TargetDbType)
}
srcMeta := dbConn.GetMetadata()
srcDialect := dbConn.GetDialect()
if len(tables) == 0 {
log("Gets the table information that can be export...")
log("gets the table information that can be export...")
ti, err := srcMeta.GetTables()
if err != nil {
log(fmt.Sprintf("Failed to get table info %s", err.Error()))
log(fmt.Sprintf("failed to get table info %s", err.Error()))
}
biz.ErrIsNil(err)
tables = make([]string, len(ti))
@@ -275,105 +272,133 @@ func (d *dbAppImpl) DumpDb(ctx context.Context, reqParam *dto.DumpDb) error {
log(fmt.Sprintf("Get %d tables", len(tables)))
}
if len(tables) == 0 {
log("No table to export. End export")
log("no table to export. end export")
return errorx.NewBiz("there is no table to export")
}
log("Querying column information...")
log("querying column information...")
// 查询列信息后面生成建表ddl和insert都需要列信息
columns, err := srcMeta.GetColumns(tables...)
if err != nil {
log(fmt.Sprintf("Failed to query column information: %s", err.Error()))
log(fmt.Sprintf("failed to query column information: %s", err.Error()))
}
biz.ErrIsNil(err)
// 以表名分组,存放每个表的列信息
columnMap := make(map[string][]dbi.Column)
for _, column := range columns {
if err := dbi.ConvToTargetDbColumn(dbConn.Info.Type, cmp.Or(reqParam.TargetDbType, dbConn.Info.Type), targetDialect, &column); err != nil {
return err
}
columnMap[column.TableName] = append(columnMap[column.TableName], column)
}
// 按表名排序
sort.Strings(tables)
quoteSchema := srcDialect.QuoteIdentifier(dbConn.Info.CurrentSchema())
quoteSchema := srcDialect.Quoter().Quote(dbConn.Info.CurrentSchema())
dumpHelper := targetDialect.GetDumpHelper()
dataHelper := targetDialect.GetDataHelper()
targetSqlGenerator := targetDialect.GetSQLGenerator()
targetDialectQuote := targetDialect.Quoter().Quote
// 遍历获取每个表的信息
for _, tableName := range tables {
log(fmt.Sprintf("Get table [%s] information...", tableName))
quoteTableName := targetDialect.QuoteIdentifier(tableName)
log(fmt.Sprintf("get table [%s] information...", tableName))
quoteTableName := targetDialectQuote(tableName)
// 查询表信息,主要是为了查询表注释
tbs, err := srcMeta.GetTables(tableName)
if err != nil {
log(fmt.Sprintf("Failed to get table [%s] information: %s", tableName, err.Error()))
log(fmt.Sprintf("failed to get table [%s] information: %s", tableName, err.Error()))
return err
}
if len(tbs) <= 0 {
log(fmt.Sprintf("Failed to get table [%s] information: No table information was retrieved", tableName))
log(fmt.Sprintf("failed to get table [%s] information: No table information was retrieved", tableName))
return errorx.NewBiz(fmt.Sprintf("Failed to get table information: %s", tableName))
}
tabInfo := dbi.Table{
TableName: tableName,
TableComment: tbs[0].TableComment,
}
tableInfo := tbs[0]
columns := columnMap[tableName]
// 生成表结构信息
if reqParam.DumpDDL {
log(fmt.Sprintf("Generate table [%s] DDL...", tableName))
log(fmt.Sprintf("generate table [%s] DDL...", tableName))
writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- Table structure: %s \n-- ----------------------------\n", tableName))
tbDdlArr := targetDialect.GenerateTableDDL(columnMap[tableName], tabInfo, true)
tbDdlArr := targetSqlGenerator.GenTableDDL(tableInfo, columns, true)
for _, ddl := range tbDdlArr {
writer.WriteString(ddl + ";\n")
if _, err := writer.WriteString(ddl + ";\n"); err != nil {
return err
}
}
progress(tableName, dbi.StmtTypeDDL, len(tbDdlArr), true)
}
// 生成insert sql数据在索引前加速insert
if reqParam.DumpData {
log(fmt.Sprintf("Generate table [%s] DML...", tableName))
log(fmt.Sprintf("generate table [%s] DML...", tableName))
writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- Data: %s \n-- ----------------------------\n", tableName))
dumpHelper.BeforeInsert(writer, quoteTableName)
// 获取列信息
quoteColNames := make([]string, 0)
for _, col := range columnMap[tableName] {
quoteColNames = append(quoteColNames, targetDialect.QuoteIdentifier(col.ColumnName))
}
_, _ = dbConn.WalkTableRows(ctx, tableName, func(row map[string]any, _ []*dbi.QueryColumn) error {
rowValues := make([]string, len(columnMap[tableName]))
for i, col := range columnMap[tableName] {
rowValues[i] = dataHelper.WrapValue(row[col.ColumnName], dataHelper.GetDataType(string(col.DataType)))
dataCount := 0
rows := make([][]any, 0)
_, err = dbConn.WalkTableRows(ctx, tableName, func(row map[string]any, _ []*dbi.QueryColumn) error {
rowValues := make([]any, len(columns))
for i, col := range columns {
rowValues[i] = row[col.ColumnName]
}
rows = append(rows, rowValues)
dataCount++
if dataCount%500 != 0 {
return nil
}
beforeInsert := dumpHelper.BeforeInsertSql(quoteSchema, quoteTableName)
insertSQL := fmt.Sprintf("%s INSERT INTO %s (%s) values(%s)", beforeInsert, quoteTableName, strings.Join(quoteColNames, ", "), strings.Join(rowValues, ", "))
writer.WriteString(insertSQL + ";\n")
writer.WriteString(beforeInsert)
insertSql := targetSqlGenerator.GenInsert(tableName, columns, rows, dbi.DuplicateStrategyNone)
if _, err := writer.WriteString(strings.Join(insertSql, ";\n") + ";\n"); err != nil {
return err
}
progress(tableName, dbi.StmtTypeInsert, dataCount, false)
rows = make([][]any, 0)
return nil
})
dumpHelper.AfterInsert(writer, tableName, columnMap[tableName])
if err != nil {
return err
}
if len(rows) > 0 {
beforeInsert := dumpHelper.BeforeInsertSql(quoteSchema, quoteTableName)
writer.WriteString(beforeInsert)
insertSql := targetSqlGenerator.GenInsert(tableName, columns, rows, dbi.DuplicateStrategyNone)
if _, err := writer.WriteString(strings.Join(insertSql, ";\n") + ";\n"); err != nil {
return err
}
}
dumpHelper.AfterInsert(writer, tableName, columns)
progress(tableName, dbi.StmtTypeInsert, dataCount, true)
}
log(fmt.Sprintf("Get table [%s] index information...", tableName))
log(fmt.Sprintf("get table [%s] index information...", tableName))
indexs, err := srcMeta.GetTableIndex(tableName)
if err != nil {
log(fmt.Sprintf("Failed to get table [%s] index information: %s", tableName, err.Error()))
log(fmt.Sprintf("failed to get table [%s] index information: %s", tableName, err.Error()))
return err
}
if len(indexs) > 0 {
// 最后添加索引
log(fmt.Sprintf("Generate table [%s] index...", tableName))
log(fmt.Sprintf("generate table [%s] index...", tableName))
writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- Table Index: %s \n-- ----------------------------\n", tableName))
sqlArr := targetDialect.GenerateIndexDDL(indexs, tabInfo)
sqlArr := targetSqlGenerator.GenIndexDDL(tableInfo, indexs)
for _, sqlStr := range sqlArr {
writer.WriteString(sqlStr + ";\n")
if _, err := writer.WriteString(sqlStr + ";\n"); err != nil {
return err
}
}
progress(tableName, dbi.StmtTypeDDL, len(sqlArr), true)
}
}
return nil

View File

@@ -1,6 +1,7 @@
package application
import (
"cmp"
"context"
"database/sql"
"encoding/json"
@@ -14,8 +15,8 @@ import (
"mayfly-go/pkg/logx"
"mayfly-go/pkg/model"
"mayfly-go/pkg/scheduler"
"mayfly-go/pkg/utils/collx"
"regexp"
"strconv"
"strings"
"time"
@@ -145,28 +146,11 @@ func (app *dataSyncAppImpl) RunCronJob(ctx context.Context, id uint64) error {
updSql := ""
orderSql := ""
if task.UpdFieldVal != "0" && task.UpdFieldVal != "" && task.UpdField != "" {
srcConn, err := app.dbApp.GetDbConn(uint64(task.SrcDbId), task.SrcDbName)
if err != nil {
logx.ErrorfContext(ctx, "data source connection unavailable: %s", err.Error())
return
}
task.UpdFieldVal = strings.Trim(task.UpdFieldVal, " ")
// 判断UpdFieldVal数据类型
var updFieldValType dbi.DataType
if _, err = strconv.Atoi(task.UpdFieldVal); err != nil {
if dateTimeReg.MatchString(task.UpdFieldVal) || dateTimeIsoReg.MatchString(task.UpdFieldVal) {
updFieldValType = dbi.DataTypeDateTime
} else {
updFieldValType = dbi.DataTypeString
}
} else {
updFieldValType = dbi.DataTypeNumber
}
wrapUpdFieldVal := srcConn.GetDialect().GetDataHelper().WrapValue(task.UpdFieldVal, updFieldValType)
updSql = fmt.Sprintf("and %s > %s", task.UpdField, wrapUpdFieldVal)
updSql = fmt.Sprintf("and %s > %s", task.UpdField, strings.Trim(task.UpdFieldVal, " "))
orderSql = "order by " + task.UpdField + " asc "
}
// 正则判断DataSql是否以where .*结尾如果是则不添加where 1 = 1
@@ -221,15 +205,13 @@ func (app *dataSyncAppImpl) doDataSync(ctx context.Context, sql string, task *en
}
}()
srcDialect := srcConn.GetDialect()
// task.FieldMap为json数组字符串 [{"src":"id","target":"id"}]转为map
var fieldMap []map[string]string
err = json.Unmarshal([]byte(task.FieldMap), &fieldMap)
if err != nil {
return syncLog, errorx.NewBiz("there was an error parsing the field map json: %s", err.Error())
}
var updFieldType dbi.DataType
var updFieldType *dbi.DbDataType
// 记录本次同步数据总数
total := 0
@@ -243,15 +225,28 @@ func (app *dataSyncAppImpl) doDataSync(ctx context.Context, sql string, task *en
updFieldName = strings.Split(task.UpdField, ".")[1]
}
targetTableColumns, err := targetConn.GetMetadata().GetColumns(task.TargetTableName)
if err != nil {
return syncLog, errorx.NewBiz("failed to get target table columns: %s", err.Error())
}
targetColumnName2Column := collx.ArrayToMap(targetTableColumns, func(column dbi.Column) string {
return column.ColumnName
})
// 目标库对应的insert columns
targetInsertColumns := collx.ArrayMap[map[string]string, dbi.Column](fieldMap, func(val map[string]string) dbi.Column {
return targetColumnName2Column[val["target"]]
})
_, err = srcConn.WalkQueryRows(context.Background(), sql, func(row map[string]any, columns []*dbi.QueryColumn) error {
if len(queryColumns) == 0 {
queryColumns = columns
// 遍历columns 取task.UpdField的字段类型
updFieldType = dbi.DataTypeString
updFieldType = dbi.DefaultDbDataType
for _, column := range columns {
if strings.EqualFold(column.Name, updFieldName) {
updFieldType = srcDialect.GetDataHelper().GetDataType(column.Type)
updFieldType = column.DbDataType
break
}
}
@@ -260,7 +255,7 @@ func (app *dataSyncAppImpl) doDataSync(ctx context.Context, sql string, task *en
total++
result = append(result, row)
if total%batchSize == 0 {
if err := app.srcData2TargetDb(result, fieldMap, columns, updFieldType, updFieldName, task, srcDialect, targetConn, targetDbTx); err != nil {
if err := app.srcData2TargetDb(result, fieldMap, updFieldType, updFieldName, task, targetConn, targetDbTx, targetInsertColumns); err != nil {
return err
}
@@ -283,7 +278,7 @@ func (app *dataSyncAppImpl) doDataSync(ctx context.Context, sql string, task *en
// 处理剩余的数据
if len(result) > 0 {
if err := app.srcData2TargetDb(result, fieldMap, queryColumns, updFieldType, updFieldName, task, srcDialect, targetConn, targetDbTx); err != nil {
if err := app.srcData2TargetDb(result, fieldMap, updFieldType, updFieldName, task, targetConn, targetDbTx, targetInsertColumns); err != nil {
targetDbTx.Rollback()
return syncLog, err
}
@@ -291,7 +286,7 @@ func (app *dataSyncAppImpl) doDataSync(ctx context.Context, sql string, task *en
// 如果是mssql暂不手动提交事务否则报错 mssql: The COMMIT TRANSACTION request has no corresponding BEGIN TRANSACTION.
if err := targetDbTx.Commit(); err != nil {
if targetConn.Info.Type != dbi.DbTypeMssql {
if targetConn.Info.Type != dbi.ToDbType("mssql") {
return syncLog, errorx.NewBiz("data synchronization - The target database transaction failed to commit: %s", err.Error())
}
}
@@ -307,36 +302,38 @@ func (app *dataSyncAppImpl) doDataSync(ctx context.Context, sql string, task *en
return syncLog, nil
}
func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap []map[string]string, columns []*dbi.QueryColumn, updFieldType dbi.DataType, updFieldName string, task *entity.DataSyncTask, srcDialect dbi.Dialect, targetDbConn *dbi.DbConn, targetDbTx *sql.Tx) error {
// 遍历src字段列表取出字段对应的类型
var srcColumnTypes = make(map[string]string)
for _, column := range columns {
srcColumnTypes[column.Name] = column.Type
}
func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap []map[string]string, updFieldType *dbi.DbDataType, updFieldName string, task *entity.DataSyncTask, targetDbConn *dbi.DbConn, targetDbTx *sql.Tx, targetInsertColumns []dbi.Column) error {
// 遍历res组装数据
var data = make([]map[string]any, 0)
for _, record := range srcRes {
var rowData = make(map[string]any)
var targetData = make([]map[string]any, 0)
for _, srcData := range srcRes {
var data = make(map[string]any)
// 遍历字段映射, target字段的值为src字段取值
for _, item := range fieldMap {
srcField := item["src"]
targetField := item["target"]
// target字段的值为src字段取值
rowData[targetField] = record[srcField]
data[item["target"]] = srcData[item["src"]]
}
data = append(data, rowData)
targetData = append(targetData, data)
}
tragetValues := make([][]any, 0)
for _, item := range targetData {
var values = make([]any, 0)
for _, column := range targetInsertColumns {
values = append(values, item[column.ColumnName])
}
tragetValues = append(tragetValues, values)
}
// 执行插入
setUpdateFieldVal := func(field string) {
// 解决字段大小写问题
updFieldVal := srcRes[len(srcRes)-1][strings.ToUpper(field)]
if updFieldVal == "" || updFieldVal == nil {
updFieldVal = srcRes[len(srcRes)-1][strings.ToLower(field)]
}
task.UpdFieldVal = srcDialect.GetDataHelper().FormatData(updFieldVal, updFieldType)
task.UpdFieldVal = updFieldType.DataType.SQLValue(updFieldVal)
}
// 如果指定了更新字段,则以更新字段取值
@@ -346,36 +343,15 @@ func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap [
setUpdateFieldVal(updFieldName)
}
// 获取目标库字段数组
targetWrapColumns := make([]string, 0)
// 获取源库字段数组
srcColumns := make([]string, 0)
srcFieldTypes := make(map[string]dbi.DataType)
targetDialect := targetDbConn.GetDialect()
for _, item := range fieldMap {
targetField := item["target"]
srcField := item["target"]
srcFieldTypes[srcField] = srcDialect.GetDataHelper().GetDataType(srcColumnTypes[item["src"]])
targetWrapColumns = append(targetWrapColumns, targetDialect.QuoteIdentifier(targetField))
srcColumns = append(srcColumns, srcField)
}
// 目标数据中取出源库字段对应的值
values := make([][]any, 0)
for _, record := range data {
rawValue := make([]any, 0)
for _, column := range srcColumns {
// 某些情况如oracle需要转换时间类型的字符串为time类型
res := srcDialect.GetDataHelper().ParseData(record[column], srcFieldTypes[column])
rawValue = append(rawValue, res)
// 生成目标数据库批量插入sql并执行
sqls := targetDialect.GetSQLGenerator().GenInsert(task.TargetTableName, targetInsertColumns, tragetValues, cmp.Or(task.DuplicateStrategy, dbi.DuplicateStrategyNone))
for _, sql := range sqls {
_, err := targetDbTx.Exec(sql)
if err != nil {
return err
}
values = append(values, rawValue)
}
// 目标数据库执行sql批量插入
_, err := targetDialect.BatchInsert(targetDbTx, task.TargetTableName, targetWrapColumns, values, task.DuplicateStrategy)
if err != nil {
return err
}
// 运行过程中,判断状态是否为已关闭,是则结束运行,否则继续运行

View File

@@ -3,6 +3,7 @@ package application
import (
"context"
"fmt"
"mayfly-go/internal/db/application/dto"
"mayfly-go/internal/db/config"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/internal/db/dbm/sqlparser"
@@ -12,26 +13,21 @@ import (
"mayfly-go/internal/db/imsg"
flowapp "mayfly-go/internal/flow/application"
flowentity "mayfly-go/internal/flow/domain/entity"
msgapp "mayfly-go/internal/msg/application"
msgdto "mayfly-go/internal/msg/application/dto"
"mayfly-go/pkg/contextx"
"mayfly-go/pkg/errorx"
"mayfly-go/pkg/i18n"
"mayfly-go/pkg/logx"
"mayfly-go/pkg/model"
"mayfly-go/pkg/utils/anyx"
"mayfly-go/pkg/utils/collx"
"mayfly-go/pkg/utils/jsonx"
"os"
"mayfly-go/pkg/utils/stringx"
"mayfly-go/pkg/ws"
"strings"
)
type DbSqlExecReq struct {
DbId uint64
Db string
Sql string // 需要执行的sql支持多条
SqlFile *os.File // sql文件
Remark string // 执行备注
DbConn *dbi.DbConn
CheckFlow bool // 是否检查存储审批流程
}
type sqlExecParam struct {
DbConn *dbi.DbConn
Sql string // 执行的sql
@@ -41,18 +37,25 @@ type sqlExecParam struct {
SqlExecRecord *entity.DbSqlExec // sql执行记录
}
type DbSqlExecRes struct {
Sql string `json:"sql"` // 执行的sql
ErrorMsg string `json:"errorMsg"` // 若执行失败,则将失败内容记录到该字段
Columns []*dbi.QueryColumn `json:"columns"` // 响应的列信
Res []map[string]any `json:"res"` // 响应结果
// progressCategory sql文件执行进度消息类型
const progressCategory = "execSqlFileProgress"
// progressMsg sql文件执行进度消
type progressMsg struct {
Id string `json:"id"`
Title string `json:"title"`
ExecutedStatements int `json:"executedStatements"`
Terminated bool `json:"terminated"`
}
type DbSqlExec interface {
flowapp.FlowBizHandler
// 执行sql
Exec(ctx context.Context, execSqlReq *DbSqlExecReq) ([]*DbSqlExecRes, error)
Exec(ctx context.Context, execSqlReq *dto.DbSqlExecReq) ([]*dto.DbSqlExecRes, error)
// ExecReader 从reader中读取sql并执行
ExecReader(ctx context.Context, execReader *dto.SqlReaderExec) error
// 根据条件删除sql执行记录
DeleteBy(ctx context.Context, condition *entity.DbSqlExec) error
@@ -66,9 +69,10 @@ type dbSqlExecAppImpl struct {
dbSqlExecRepo repository.DbSqlExec `inject:"DbSqlExecRepo"`
flowProcdefApp flowapp.Procdef `inject:"ProcdefApp"`
msgApp msgapp.Msg `inject:"MsgApp"`
}
func createSqlExecRecord(ctx context.Context, execSqlReq *DbSqlExecReq, sql string) *entity.DbSqlExec {
func createSqlExecRecord(ctx context.Context, execSqlReq *dto.DbSqlExecReq, sql string) *entity.DbSqlExec {
dbSqlExecRecord := new(entity.DbSqlExec)
dbSqlExecRecord.DbId = execSqlReq.DbId
dbSqlExecRecord.Db = execSqlReq.Db
@@ -79,7 +83,7 @@ func createSqlExecRecord(ctx context.Context, execSqlReq *DbSqlExecReq, sql stri
return dbSqlExecRecord
}
func (d *dbSqlExecAppImpl) Exec(ctx context.Context, execSqlReq *DbSqlExecReq) ([]*DbSqlExecRes, error) {
func (d *dbSqlExecAppImpl) Exec(ctx context.Context, execSqlReq *dto.DbSqlExecReq) ([]*dto.DbSqlExecRes, error) {
dbConn := execSqlReq.DbConn
execSql := execSqlReq.Sql
sp := dbConn.GetDialect().GetSQLParser()
@@ -89,13 +93,13 @@ func (d *dbSqlExecAppImpl) Exec(ctx context.Context, execSqlReq *DbSqlExecReq) (
flowProcdef = d.flowProcdefApp.GetProcdefByCodePath(ctx, dbConn.Info.CodePath...)
}
allExecRes := make([]*DbSqlExecRes, 0)
allExecRes := make([]*dto.DbSqlExecRes, 0)
stmts, err := sp.Parse(execSql)
// sql解析失败则使用默认方式切割
if err != nil {
sqlparser.SQLSplit(strings.NewReader(execSql), func(oneSql string) error {
var execRes *DbSqlExecRes
var execRes *dto.DbSqlExecRes
var err error
dbSqlExecRecord := createSqlExecRecord(ctx, execSqlReq, oneSql)
@@ -120,7 +124,7 @@ func (d *dbSqlExecAppImpl) Exec(ctx context.Context, execSqlReq *DbSqlExecReq) (
// 执行错误
if err != nil {
if execRes == nil {
execRes = &DbSqlExecRes{Sql: oneSql}
execRes = &dto.DbSqlExecRes{Sql: oneSql}
}
execRes.ErrorMsg = err.Error()
} else {
@@ -133,7 +137,7 @@ func (d *dbSqlExecAppImpl) Exec(ctx context.Context, execSqlReq *DbSqlExecReq) (
}
for _, stmt := range stmts {
var execRes *DbSqlExecRes
var execRes *dto.DbSqlExecRes
var err error
sql := stmt.GetText()
@@ -172,7 +176,7 @@ func (d *dbSqlExecAppImpl) Exec(ctx context.Context, execSqlReq *DbSqlExecReq) (
if err != nil {
if execRes == nil {
execRes = &DbSqlExecRes{Sql: sql}
execRes = &dto.DbSqlExecRes{Sql: sql}
}
execRes.ErrorMsg = err.Error()
} else {
@@ -184,6 +188,64 @@ func (d *dbSqlExecAppImpl) Exec(ctx context.Context, execSqlReq *DbSqlExecReq) (
return allExecRes, nil
}
func (d *dbSqlExecAppImpl) ExecReader(ctx context.Context, execReader *dto.SqlReaderExec) error {
dbConn := execReader.DbConn
clientId := execReader.ClientId
filename := stringx.Truncate(execReader.Filename, 20, 10, "...")
la := contextx.GetLoginAccount(ctx)
needSendMsg := la != nil && clientId != ""
defer func() {
if err := recover(); err != nil {
errInfo := anyx.ToString(err)
logx.Errorf("exec sql reader error: %s", errInfo)
if needSendMsg {
errInfo = stringx.Truncate(errInfo, 300, 10, "...")
d.msgApp.CreateAndSend(la, msgdto.ErrSysMsg(i18n.T(imsg.SqlScriptRunFail), fmt.Sprintf("[%s][%s] execution failure: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo)).WithClientId(clientId))
}
}
}()
executedStatements := 0
progressId := stringx.Rand(32)
if needSendMsg {
defer ws.SendJsonMsg(ws.UserId(la.Id), clientId, msgdto.InfoSysMsg(i18n.T(imsg.SqlScripRunProgress), &progressMsg{
Id: progressId,
Title: filename,
ExecutedStatements: executedStatements,
Terminated: true,
}).WithCategory(progressCategory))
}
err := sqlparser.SQLSplit(execReader.Reader, func(sql string) error {
if executedStatements%50 == 0 {
if needSendMsg {
ws.SendJsonMsg(ws.UserId(la.Id), clientId, msgdto.InfoSysMsg(i18n.T(imsg.SqlScripRunProgress), &progressMsg{
Id: progressId,
Title: filename,
ExecutedStatements: executedStatements,
Terminated: false,
}).WithCategory(progressCategory))
}
}
executedStatements++
if _, err := dbConn.Exec(sql); err != nil {
return err
}
return nil
})
if err != nil {
return err
}
if needSendMsg {
d.msgApp.CreateAndSend(la, msgdto.SuccessSysMsg(i18n.T(imsg.SqlScriptRunSuccess), "execution success").WithClientId(clientId))
}
return nil
}
type FlowDbExecSqlBizForm struct {
DbId uint64 `json:"dbId"` // 库id
DbName string `json:"dbName"` // 库名
@@ -211,7 +273,7 @@ func (d *dbSqlExecAppImpl) FlowBizHandle(ctx context.Context, bizHandleParam *fl
return nil, err
}
execRes, err := d.Exec(contextx.NewLoginAccount(&model.LoginAccount{Id: procinst.CreatorId, Username: procinst.Creator}), &DbSqlExecReq{
execRes, err := d.Exec(contextx.NewLoginAccount(&model.LoginAccount{Id: procinst.CreatorId, Username: procinst.Creator}), &dto.DbSqlExecReq{
DbId: execSqlBizForm.DbId,
Db: execSqlBizForm.DbName,
Sql: execSqlBizForm.Sql,
@@ -257,7 +319,7 @@ func (d *dbSqlExecAppImpl) saveSqlExecLog(dbSqlExecRecord *entity.DbSqlExec, res
}
}
func (d *dbSqlExecAppImpl) doSelect(ctx context.Context, sqlExecParam *sqlExecParam) (*DbSqlExecRes, error) {
func (d *dbSqlExecAppImpl) doSelect(ctx context.Context, sqlExecParam *sqlExecParam) (*dto.DbSqlExecRes, error) {
maxCount := config.GetDbms().MaxResultSet
selectStmt := sqlExecParam.Stmt
selectSql := sqlExecParam.Sql
@@ -314,7 +376,7 @@ func (d *dbSqlExecAppImpl) doSelect(ctx context.Context, sqlExecParam *sqlExecPa
return d.doQuery(ctx, sqlExecParam.DbConn, selectSql)
}
func (d *dbSqlExecAppImpl) doOtherRead(ctx context.Context, sqlExecParam *sqlExecParam) (*DbSqlExecRes, error) {
func (d *dbSqlExecAppImpl) doOtherRead(ctx context.Context, sqlExecParam *sqlExecParam) (*dto.DbSqlExecRes, error) {
selectSql := sqlExecParam.Sql
sqlExecParam.SqlExecRecord.Type = entity.DbSqlExecTypeQuery
@@ -327,7 +389,7 @@ func (d *dbSqlExecAppImpl) doOtherRead(ctx context.Context, sqlExecParam *sqlExe
return d.doQuery(ctx, sqlExecParam.DbConn, selectSql)
}
func (d *dbSqlExecAppImpl) doExecDDL(ctx context.Context, sqlExecParam *sqlExecParam) (*DbSqlExecRes, error) {
func (d *dbSqlExecAppImpl) doExecDDL(ctx context.Context, sqlExecParam *sqlExecParam) (*dto.DbSqlExecRes, error) {
selectSql := sqlExecParam.Sql
sqlExecParam.SqlExecRecord.Type = entity.DbSqlExecTypeDDL
@@ -340,7 +402,7 @@ func (d *dbSqlExecAppImpl) doExecDDL(ctx context.Context, sqlExecParam *sqlExecP
return d.doExec(ctx, sqlExecParam.DbConn, selectSql)
}
func (d *dbSqlExecAppImpl) doUpdate(ctx context.Context, sqlExecParam *sqlExecParam) (*DbSqlExecRes, error) {
func (d *dbSqlExecAppImpl) doUpdate(ctx context.Context, sqlExecParam *sqlExecParam) (*dto.DbSqlExecRes, error) {
dbConn := sqlExecParam.DbConn
if procdef := sqlExecParam.Procdef; procdef != nil {
@@ -365,7 +427,7 @@ func (d *dbSqlExecAppImpl) doUpdate(ctx context.Context, sqlExecParam *sqlExecPa
tableSources := updatestmt.TableSources.TableSources
// 不支持多表更新记录旧值
if len(tableSources) != 1 {
logx.ErrorContext(ctx, "Update SQL - logging old values only supports single-table updates")
logx.ErrorContext(ctx, "update SQL - logging old values only supports single-table updates")
return d.doExec(ctx, dbConn, sqlExecParam.Sql)
}
@@ -379,21 +441,21 @@ func (d *dbSqlExecAppImpl) doUpdate(ctx context.Context, sqlExecParam *sqlExecPa
}
if tableName == "" {
logx.ErrorContext(ctx, "Update SQL - failed to get table name")
logx.ErrorContext(ctx, "update SQL - failed to get table name")
return d.doExec(ctx, dbConn, sqlExecParam.Sql)
}
execRecord.Table = tableName
whereStr := updatestmt.Where.GetText()
if whereStr == "" {
logx.ErrorContext(ctx, "Update SQL - there is no where condition")
logx.ErrorContext(ctx, "update SQL - there is no where condition")
return d.doExec(ctx, dbConn, sqlExecParam.Sql)
}
// 获取表主键列名,排除使用别名
primaryKey, err := dbConn.GetMetadata().GetPrimaryKey(tableName)
if err != nil {
logx.ErrorfContext(ctx, "Update SQL - failed to get primary key column: %s", err.Error())
logx.ErrorfContext(ctx, "update SQL - failed to get primary key column: %s", err.Error())
return d.doExec(ctx, dbConn, sqlExecParam.Sql)
}
@@ -417,12 +479,12 @@ func (d *dbSqlExecAppImpl) doUpdate(ctx context.Context, sqlExecParam *sqlExecPa
nowRec++
res = append(res, row)
if nowRec == maxRec {
return errorx.NewBiz(fmt.Sprintf("Update SQL - the maximum number of updated queries is exceeded: %d", maxRec))
return errorx.NewBiz(fmt.Sprintf("update SQL - the maximum number of updated queries is exceeded: %d", maxRec))
}
return nil
})
if err != nil {
logx.ErrorfContext(ctx, "Update SQL - failed to get the updated old value: %s", err.Error())
logx.ErrorfContext(ctx, "update SQL - failed to get the updated old value: %s", err.Error())
return d.doExec(ctx, dbConn, sqlExecParam.Sql)
}
execRecord.OldValue = jsonx.ToStr(res)
@@ -430,7 +492,7 @@ func (d *dbSqlExecAppImpl) doUpdate(ctx context.Context, sqlExecParam *sqlExecPa
return d.doExec(ctx, dbConn, sqlExecParam.Sql)
}
func (d *dbSqlExecAppImpl) doDelete(ctx context.Context, sqlExecParam *sqlExecParam) (*DbSqlExecRes, error) {
func (d *dbSqlExecAppImpl) doDelete(ctx context.Context, sqlExecParam *sqlExecParam) (*dto.DbSqlExecRes, error) {
if procdef := sqlExecParam.Procdef; procdef != nil {
if needStartProc := procdef.MatchCondition(DbSqlExecFlowBizType, collx.Kvs("stmtType", "delete")); needStartProc {
return nil, errorx.NewBizI(ctx, imsg.ErrNeedSubmitWorkTicket)
@@ -454,7 +516,7 @@ func (d *dbSqlExecAppImpl) doDelete(ctx context.Context, sqlExecParam *sqlExecPa
tableSources := deletestmt.TableSources.TableSources
// 不支持多表删除记录旧值
if len(tableSources) != 1 {
logx.ErrorContext(ctx, "Delete SQL - logging old values only supports single-table deletion")
logx.ErrorContext(ctx, "delete SQL - logging old values only supports single-table deletion")
return d.doExec(ctx, dbConn, sqlExecParam.Sql)
}
@@ -468,14 +530,14 @@ func (d *dbSqlExecAppImpl) doDelete(ctx context.Context, sqlExecParam *sqlExecPa
}
if tableName == "" {
logx.ErrorContext(ctx, "Delete SQL - failed to get table name")
logx.ErrorContext(ctx, "delete SQL - failed to get table name")
return d.doExec(ctx, dbConn, sqlExecParam.Sql)
}
execRecord.Table = tableName
whereStr := deletestmt.Where.GetText()
if whereStr == "" {
logx.ErrorContext(ctx, "Delete SQL - there is no where condition")
logx.ErrorContext(ctx, "delete SQL - there is no where condition")
return d.doExec(ctx, dbConn, sqlExecParam.Sql)
}
@@ -487,7 +549,7 @@ func (d *dbSqlExecAppImpl) doDelete(ctx context.Context, sqlExecParam *sqlExecPa
return d.doExec(ctx, dbConn, sqlExecParam.Sql)
}
func (d *dbSqlExecAppImpl) doInsert(ctx context.Context, sqlExecParam *sqlExecParam) (*DbSqlExecRes, error) {
func (d *dbSqlExecAppImpl) doInsert(ctx context.Context, sqlExecParam *sqlExecParam) (*dto.DbSqlExecRes, error) {
if procdef := sqlExecParam.Procdef; procdef != nil {
if needStartProc := procdef.MatchCondition(DbSqlExecFlowBizType, collx.Kvs("stmtType", "insert")); needStartProc {
return nil, errorx.NewBizI(ctx, imsg.ErrNeedSubmitWorkTicket)
@@ -513,19 +575,19 @@ func (d *dbSqlExecAppImpl) doInsert(ctx context.Context, sqlExecParam *sqlExecPa
return d.doExec(ctx, sqlExecParam.DbConn, sqlExecParam.Sql)
}
func (d *dbSqlExecAppImpl) doQuery(ctx context.Context, dbConn *dbi.DbConn, sql string) (*DbSqlExecRes, error) {
func (d *dbSqlExecAppImpl) doQuery(ctx context.Context, dbConn *dbi.DbConn, sql string) (*dto.DbSqlExecRes, error) {
cols, res, err := dbConn.QueryContext(ctx, sql)
if err != nil {
return nil, err
}
return &DbSqlExecRes{
return &dto.DbSqlExecRes{
Sql: sql,
Columns: cols,
Res: res,
}, nil
}
func (d *dbSqlExecAppImpl) doExec(ctx context.Context, dbConn *dbi.DbConn, sql string) (*DbSqlExecRes, error) {
func (d *dbSqlExecAppImpl) doExec(ctx context.Context, dbConn *dbi.DbConn, sql string) (*dto.DbSqlExecRes, error) {
rowsAffected, err := dbConn.ExecContext(ctx, sql)
if err != nil {
return nil, err
@@ -534,7 +596,7 @@ func (d *dbSqlExecAppImpl) doExec(ctx context.Context, dbConn *dbi.DbConn, sql s
res := make([]map[string]any, 0)
res = append(res, collx.Kvs("rowsAffected", rowsAffected))
return &DbSqlExecRes{
return &dto.DbSqlExecRes{
Columns: []*dbi.QueryColumn{
{Name: "rowsAffected", Type: "number"},
},

View File

@@ -3,10 +3,11 @@ package application
import (
"cmp"
"context"
"encoding/hex"
"fmt"
"io"
"mayfly-go/internal/db/application/dto"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/internal/db/dbm/sqlparser"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
fileapp "mayfly-go/internal/file/application"
@@ -21,7 +22,6 @@ import (
"mayfly-go/pkg/scheduler"
"mayfly-go/pkg/utils/collx"
"mayfly-go/pkg/utils/timex"
"sort"
"strings"
"time"
@@ -252,12 +252,90 @@ func (app *dbTransferAppImpl) transfer2Db(ctx context.Context, taskId uint64, lo
app.EndTransfer(ctx, logId, taskId, "failed to get target db connection", err, nil)
return
}
// 迁移表
if err = app.transferDbTables(ctx, logId, task, srcConn, targetConn, tables); err != nil {
ctx = context.Background()
tableNames := collx.ArrayMap(tables, func(t dbi.Table) string { return t.TableName })
// 分组迁移
tableGroups := collx.ArraySplit[string](tableNames, 2)
errGroup, _ := errgroup.WithContext(ctx)
for _, tables := range tableGroups {
errGroup.Go(func() error {
if !app.IsRunning(taskId) {
return errorx.NewBiz("transfer stopped")
}
currentDumpTable := tables[0]
pr, pw := io.Pipe()
go func() {
err := app.dbApp.DumpDb(ctx, &dto.DumpDb{
LogId: logId,
DbId: uint64(task.SrcDbId),
DbName: task.SrcDbName,
TargetDbType: dbi.DbType(task.TargetDbType),
Tables: tables,
DumpDDL: true,
DumpData: true,
Writer: pw,
Log: func(msg string) { // 记录日志
app.Log(ctx, logId, msg)
},
Progress: func(currentTable string, stmtType dbi.StmtType, stmtCount int, currentStmtTypeEnd bool) {
logExtraKey := fmt.Sprintf("`%s` amount of transfer data currently: ", currentDumpTable)
if stmtType == dbi.StmtTypeInsert {
app.logApp.SetExtra(logId, logExtraKey, stmtCount)
if currentStmtTypeEnd {
app.Log(ctx, logId, fmt.Sprintf("execute transfer table [%s] insert %d rows", currentDumpTable, stmtCount))
}
}
if currentDumpTable != currentTable {
currentDumpTable = currentTable
stmtCount = 0
// 置空当前表数据迁移量进度
app.logApp.SetExtra(logId, logExtraKey, nil)
}
},
})
if err != nil {
pr.CloseWithError(err)
return
}
}()
tx, err := targetConn.Begin()
if err != nil {
pw.CloseWithError(err)
app.EndTransfer(ctx, logId, taskId, "transfer table failed", err, nil)
return err
}
err = sqlparser.SQLSplit(pr, func(stmt string) error {
if _, err := targetConn.TxExecContext(ctx, tx, stmt); err != nil {
pw.CloseWithError(err)
return err
}
return nil
})
if err != nil {
tx.Rollback()
return err
}
_ = tx.Commit()
return nil
})
}
err = errGroup.Wait()
if err != nil {
app.EndTransfer(ctx, logId, taskId, "transfer table failed", err, nil)
return
}
app.EndTransfer(ctx, logId, taskId, fmt.Sprintf("execute transfer task [taskId = %d] complete, time: %v", taskId, time.Since(start)), nil, nil)
}
@@ -281,10 +359,7 @@ func (app *dbTransferAppImpl) transfer2File(ctx context.Context, taskId uint64,
}
// 从tables提取表名
tableNames := make([]string, 0)
for _, table := range tables {
tableNames = append(tableNames, table.TableName)
}
tableNames := collx.ArrayMap(tables, func(t dbi.Table) string { return t.TableName })
// 2、把源库数据迁移到文件
app.Log(ctx, logId, fmt.Sprintf("start transfer table data to files: %s", filename))
app.Log(ctx, logId, fmt.Sprintf("dialect type of target db file: %s", task.TargetFileDbType))
@@ -341,228 +416,6 @@ func (app *dbTransferAppImpl) Stop(ctx context.Context, taskId uint64) error {
return nil
}
// 迁移表
func (app *dbTransferAppImpl) transferDbTables(ctx context.Context, logId uint64, task *entity.DbTransferTask, srcConn *dbi.DbConn, targetConn *dbi.DbConn, tables []dbi.Table) error {
tableNames := make([]string, 0)
tableMap := make(map[string]dbi.Table) // 以表名分组,存放表信息
for _, table := range tables {
tableNames = append(tableNames, table.TableName)
tableMap[table.TableName] = table
}
if len(tableNames) == 0 {
return errorx.NewBiz("there are no tables to migrate")
}
srcDialect := srcConn.GetDialect()
srcMetadata := srcConn.GetMetadata()
// 查询源表列信息
columns, err := srcMetadata.GetColumns(tableNames...)
if err != nil {
return errorx.NewBiz("failed to get the source table column information")
}
// 以表名分组,存放每个表的列信息
columnMap := make(map[string][]dbi.Column)
for _, column := range columns {
columnMap[column.TableName] = append(columnMap[column.TableName], column)
}
// 以表名排序
sortTableNames := collx.MapKeys(columnMap)
sort.Strings(sortTableNames)
targetDialect := targetConn.GetDialect()
srcColumnHelper := srcDialect.GetColumnHelper()
targetColumnHelper := targetConn.GetDialect().GetColumnHelper()
// 分组迁移
tableGroups := collx.ArraySplit[string](sortTableNames, 2)
errGroup, _ := errgroup.WithContext(ctx)
for _, tables := range tableGroups {
errGroup.Go(func() error {
for _, tbName := range tables {
cols := columnMap[tbName]
targetCols := make([]dbi.Column, 0)
for _, col := range cols {
colPtr := &col
// 源库列转为公共列
srcColumnHelper.ToCommonColumn(colPtr)
// 公共列转为目标库列
targetColumnHelper.ToColumn(colPtr)
targetCols = append(targetCols, *colPtr)
}
// 通过公共列信息生成目标库的建表语句,并执行目标库建表
app.Log(ctx, logId, fmt.Sprintf("start creating the target table: %s", tbName))
sqlArr := targetDialect.GenerateTableDDL(targetCols, tableMap[tbName], true)
for _, sqlStr := range sqlArr {
_, err := targetConn.Exec(sqlStr)
if err != nil {
return errorx.NewBiz(fmt.Sprintf("failed to create target table: %s, error: %s", tbName, err.Error()))
}
}
app.Log(ctx, logId, fmt.Sprintf("target table created successfully: %s", tbName))
// 迁移数据
app.Log(ctx, logId, fmt.Sprintf("start transfer data: %s", tbName))
total, err := app.transferData(ctx, logId, task.Id, tbName, targetCols, srcConn, targetConn)
if err != nil {
return errorx.NewBiz(fmt.Sprintf("failed to transfer => table: %s, error: %s", tbName, err.Error()))
}
app.Log(ctx, logId, fmt.Sprintf("successfully transfer data => table: %s, data: %d entries", tbName, total))
// 有些数据库迁移完数据之后,需要更新表自增序列为当前表最大值
targetDialect.UpdateSequence(tbName, targetCols)
// 迁移索引信息
app.Log(ctx, logId, fmt.Sprintf("start transfer index => table: %s", tbName))
err = app.transferIndex(ctx, tableMap[tbName], srcConn, targetConn)
if err != nil {
return errorx.NewBiz(fmt.Sprintf("failed to transfer index => table: %s, error: %s", tbName, err.Error()))
}
app.Log(ctx, logId, fmt.Sprintf("successfully transfer index => table: %s", tbName))
}
return nil
})
}
return errGroup.Wait()
}
func (app *dbTransferAppImpl) transferData(ctx context.Context, logId uint64, taskId uint64, tableName string, targetColumns []dbi.Column, srcConn *dbi.DbConn, targetConn *dbi.DbConn) (int, error) {
result := make([]map[string]any, 0)
total := 0 // 总条数
batchSize := 1000 // 每次查询并迁移1000条数据
var err error
srcDialect := srcConn.GetDialect()
srcConverter := srcDialect.GetDataHelper()
targetDialect := targetConn.GetDialect()
logExtraKey := fmt.Sprintf("`%s` amount of transfer data currently: ", tableName)
// 游标查询源表数据,并批量插入目标表
_, err = srcConn.WalkTableRows(context.Background(), tableName, func(row map[string]any, columns []*dbi.QueryColumn) error {
total++
rawValue := map[string]any{}
for _, column := range columns {
// 某些情况如oracle需要转换时间类型的字符串为time类型
res := srcConverter.ParseData(row[column.Name], srcConverter.GetDataType(column.Type))
rawValue[column.Name] = res
}
result = append(result, rawValue)
if total%batchSize == 0 {
err = app.transfer2Target(taskId, targetConn, targetColumns, result, targetDialect, tableName)
if err != nil {
logx.ErrorfContext(ctx, "batch insert data to target table failed: %v", err)
return err
}
result = result[:0]
app.logApp.SetExtra(logId, logExtraKey, total)
}
return nil
})
if err != nil {
return total, err
}
// 处理剩余的数据
if len(result) > 0 {
err = app.transfer2Target(taskId, targetConn, targetColumns, result, targetDialect, tableName)
if err != nil {
logx.ErrorfContext(ctx, "batch insert data to target table failed => table: %s, error: %v", tableName, err)
return 0, err
}
}
// 置空当前表数据迁移量进度
app.logApp.SetExtra(logId, logExtraKey, nil)
return total, err
}
func (app *dbTransferAppImpl) transfer2Target(taskId uint64, targetConn *dbi.DbConn, targetColumns []dbi.Column, result []map[string]any, targetDialect dbi.Dialect, tbName string) error {
if !app.IsRunning(taskId) {
return errorx.NewBiz("transfer stopped")
}
tx, err := targetConn.Begin()
if err != nil {
return err
}
// 收集字段名
var columnNames []string
for _, col := range targetColumns {
columnNames = append(columnNames, targetDialect.QuoteIdentifier(col.ColumnName))
}
dataHelper := targetDialect.GetDataHelper()
// 从目标库数据中取出源库字段对应的值
values := make([][]any, 0)
for _, record := range result {
rawValue := make([]any, 0)
for _, tc := range targetColumns {
columnName := tc.ColumnName
val := record[targetDialect.RemoveQuote(columnName)]
if !tc.Nullable {
// 如果val是文本则设置为空格字符
switch val.(type) {
case string:
if val == "" {
val = " "
}
}
}
if dataHelper.GetDataType(string(tc.DataType)) == dbi.DataTypeBlob {
decodeBytes, err := hex.DecodeString(val.(string))
if err == nil {
val = decodeBytes
}
}
rawValue = append(rawValue, val)
}
values = append(values, rawValue)
}
// 批量插入
_, err = targetDialect.BatchInsert(tx, tbName, columnNames, values, -1)
defer func() {
if r := recover(); r != nil {
tx.Rollback()
logx.Errorf("batch insert data to target table failed: %v", r)
}
}()
_ = tx.Commit()
return err
}
func (app *dbTransferAppImpl) transferIndex(ctx context.Context, tableInfo dbi.Table, srcConn *dbi.DbConn, targetConn *dbi.DbConn) error {
// 查询源表索引信息
indexs, err := srcConn.GetMetadata().GetTableIndex(tableInfo.TableName)
if err != nil {
logx.Errorf("failed to get index information: %s", err)
return err
}
if len(indexs) == 0 {
return nil
}
// 通过表名、索引信息生成建索引语句,并执行到目标表
sqlArr := targetConn.GetDialect().GenerateIndexDDL(indexs, tableInfo)
for _, sqlStr := range sqlArr {
_, err := targetConn.Exec(sqlStr)
if err != nil {
return err
}
}
return nil
}
func (d *dbTransferAppImpl) TimerDeleteTransferFile() {
logx.Debug("start deleting transfer files periodically...")
scheduler.AddFun("@every 100m", func() {

View File

@@ -23,10 +23,16 @@ type DumpDb struct {
LogId uint64
Writer io.WriteCloser
Log func(msg string)
TargetDbType dbi.DbType
Log func(msg string)
Progress func(currentTable string, stmtType dbi.StmtType, stmtCount int, currentStmtTypeEnd bool) // dump进度
}
func DefaultDumpLog(msg string) {
}
func DefaultDumpProgress(currentTable string, stmtType dbi.StmtType, stmtCount int, currentStmtTypeEnd bool) {
}

View File

@@ -0,0 +1,31 @@
package dto
import (
"io"
"mayfly-go/internal/db/dbm/dbi"
)
type DbSqlExecReq struct {
DbId uint64
Db string
Sql string // 需要执行的sql支持多条
Remark string // 执行备注
DbConn *dbi.DbConn
CheckFlow bool // 是否检查存储审批流程
}
type DbSqlExecRes struct {
Sql string `json:"sql"` // 执行的sql
ErrorMsg string `json:"errorMsg"` // 若执行失败,则将失败内容记录到该字段
Columns []*dbi.QueryColumn `json:"columns"` // 响应的列信息
Res []map[string]any `json:"res"` // 响应结果
}
type SqlReaderExec struct {
DbConn *dbi.DbConn
Reader io.Reader
Filename string
ClientId string // 客户端id若存在则会向其发送执行进度消息
}

View File

@@ -0,0 +1,594 @@
package dbi
import (
"cmp"
"database/sql"
"database/sql/driver"
"encoding/hex"
"fmt"
"mayfly-go/pkg/utils/collx"
"strings"
"time"
"github.com/may-fly/cast"
)
var (
dbDataTypes = make(map[DbType]map[string]*DbDataType) // 列类型
)
// registerColumnDbDataTypes 注册数据库对应的数据类型
func registerColumnDbDataTypes(dbType DbType, cts ...*DbDataType) {
dbDataTypes[dbType] = collx.ArrayToMap(cts, func(ct *DbDataType) string {
return strings.ToLower(string(ct.Name))
})
}
func GetDbDataType(dbType DbType, databaseColumnType string) *DbDataType {
return cmp.Or(dbDataTypes[dbType][strings.ToLower(databaseColumnType)], DefaultDbDataType)
}
var DefaultDbDataType = NewDbDataType("string", DTString).WithCT(CTVarchar)
// 表的列信息
type Column struct {
TableName string `json:"tableName"` // 表名
ColumnName string `json:"columnName"` // 列名
DataType string `json:"dataType"` // 数据类型
ColumnComment string `json:"columnComment"` // 列备注
IsPrimaryKey bool `json:"isPrimaryKey"` // 是否为主键
IsIdentity bool `json:"isIdentity"` // 是否自增
ColumnDefault string `json:"columnDefault"` // 默认值
Nullable bool `json:"nullable"` // 是否可为null
CharMaxLength int `json:"charMaxLength"` // 字符最大长度
NumPrecision int `json:"numPrecision"` // 精度(总数字位数)
NumScale int `json:"numScale"` // 小数点位数
Extra collx.M `json:"extra"` // 其他额外信息
}
// GetColumnType 获取完整的列类型拼接数据类型与长度等。如varchar(2000)decimal(20,2)
func (c *Column) GetColumnType() string {
if c.CharMaxLength > 0 {
return fmt.Sprintf("%s(%d)", c.DataType, c.CharMaxLength)
}
if c.NumPrecision > 0 {
if c.NumScale > 0 {
return fmt.Sprintf("%s(%d,%d)", c.DataType, c.NumPrecision, c.NumScale)
} else {
return fmt.Sprintf("%s(%d)", c.DataType, c.NumPrecision)
}
}
return string(c.DataType)
}
// 数据库对应的数据类型
type DbDataType struct {
Name string // 类型名
DataType *DataType // 数据类型
fixColumnFunc func(column *Column) // 修复字段长度、精度等, 如mysql text会返回长度需要将其置为0等
/** 以下为异构数据迁移同步使用,可不赋值,无值则不支持迁移同步 */
CommonType CommonDbDataType // 对应的公共类型
}
// WithFixColumn 修复列信息函数,用于修复字段长度、精度等
func (ct *DbDataType) WithFixColumn(fixColumnFunc func(column *Column)) *DbDataType {
ct.fixColumnFunc = fixColumnFunc
return ct
}
// WithCT 对应的公共类型,主要用于异构数据库迁移同步时进行类型转换使用
func (ct *DbDataType) WithCT(cct CommonDbDataType) *DbDataType {
ct.CommonType = cct
return ct
}
// FixColumn 使用修复列信息函数进行列信息修复
func (ct *DbDataType) FixColumn(column *Column) {
if ct.fixColumnFunc != nil {
ct.fixColumnFunc(column)
}
}
func NewDbDataType(name string, dataType *DataType) *DbDataType {
return &DbDataType{
Name: name,
DataType: dataType,
}
}
func ClearCharMaxLength(column *Column) {
column.CharMaxLength = 0
column.NumPrecision = 0
}
func ClearNumScale(column *Column) {
column.NumScale = 0
column.CharMaxLength = 0
}
// DataType 数据类型, 对应于go类型如int int64等。可自定义其他类型
type DataType struct {
Name string // 类型名
Valuer func() Valuer // 获取值对应的处理者用于sql的scan、解析value等
SQLValue func(val any) string // 转换为sql字符串值用于insert等SQL语句的值转换
}
// Copy 拷贝一个同类型的datatype主要方便用于定制化修改Valuer或ToString
func (dt *DataType) Copy() *DataType {
return &DataType{
Name: dt.Name,
Valuer: dt.Valuer,
SQLValue: dt.SQLValue,
}
}
func (dt *DataType) WithValuer(valuerFunc func() Valuer) *DataType {
dt.Valuer = valuerFunc
return dt
}
func (dt *DataType) WithSQLValue(sqlvalueFunc func(val any) string) *DataType {
dt.SQLValue = sqlvalueFunc
return dt
}
const NULL = "NULL"
// SQLValueDefault 默认使用fmt转string
func SQLValueDefault(val any) string {
if val == nil {
return NULL
}
return fmt.Sprintf("'%v'", val)
}
// SQLValueNumeric 数字类型转string
func SQLValueNumeric(val any) string {
if val == nil {
return NULL
}
return fmt.Sprintf("%v", val)
}
func SQLValueString(val any) string {
if val == nil {
return NULL
}
strVal, ok := val.(string)
if !ok {
return fmt.Sprintf("%v", val)
}
return fmt.Sprintf("'%s'", strings.ReplaceAll(strings.ReplaceAll(strVal, "'", "''"), `\`, `\\`))
}
var (
DTBit = &DataType{
Name: "bit",
Valuer: ValuerBit,
SQLValue: SQLValueNumeric,
}
DTByte = &DataType{
Name: "uint8",
Valuer: ValuerByte,
SQLValue: SQLValueNumeric,
}
DTInt8 = &DataType{
Name: "int8",
Valuer: ValuerInt16,
SQLValue: SQLValueNumeric,
}
DTInt16 = &DataType{
Name: "int16",
Valuer: ValuerInt16,
SQLValue: SQLValueNumeric,
}
DTInt32 = &DataType{
Name: "int32",
Valuer: ValuerInt32,
SQLValue: SQLValueNumeric,
}
DTInt64 = &DataType{
Name: "int64",
Valuer: ValuerInt64,
SQLValue: SQLValueNumeric,
}
// 所有无符号类型都使用int64存储
DTUint64 = &DataType{
Name: "uint64",
Valuer: ValuerUint64,
SQLValue: SQLValueNumeric,
}
DTNumeric = &DataType{
Name: "numeric",
Valuer: ValuerFloat64,
SQLValue: SQLValueNumeric,
}
DTDecimal = &DataType{
Name: "decimal",
Valuer: ValuerString,
SQLValue: SQLValueNumeric,
}
DTString = &DataType{
Name: "string",
Valuer: ValuerString,
SQLValue: SQLValueString,
}
DTDate = &DataType{
Name: "date",
Valuer: ValuerDate,
SQLValue: SQLValueDefault,
}
DTTime = &DataType{
Name: "time",
Valuer: ValuerTime,
SQLValue: SQLValueDefault,
}
DTDateTime = &DataType{
Name: "datetime",
Valuer: ValuerDatetime,
SQLValue: SQLValueDefault,
}
DTBytes = &DataType{
Name: "bytes",
Valuer: ValuerBytes,
SQLValue: SQLValueDefault,
}
)
// Valuer 获取值对应的处理者用于sql row scan、解析value等
type Valuer interface {
// NewValuePtr 新建值对应的指针用于sql的row scan
NewValuePtr() any
// Value 获取对应的值人类可阅读的值不可原样返回ValuePtr指针类型需取出具体的值
Value() any
}
type DefaultValuer[T any] struct {
ValuePtr *T
}
func (s *DefaultValuer[T]) NewValuePtr() any {
var t T
s.ValuePtr = &t
return s.ValuePtr
}
// Valuer工厂函数
func ValuerString() Valuer {
return &stringValuer{
DefaultValuer: new(DefaultValuer[sql.NullString]),
}
}
func ValuerInt64() Valuer {
return &int64Valuer{
DefaultValuer: new(DefaultValuer[sql.NullInt64]),
}
}
func ValuerUint64() Valuer {
return &uint64Valuer{
DefaultValuer: new(DefaultValuer[[]byte]),
}
}
func ValuerInt32() Valuer {
return &int32Valuer{
DefaultValuer: new(DefaultValuer[sql.NullInt32]),
}
}
func ValuerInt16() Valuer {
return &int16Valuer{
DefaultValuer: new(DefaultValuer[sql.NullInt16]),
}
}
func ValuerByte() Valuer {
return &byteValuer{
DefaultValuer: new(DefaultValuer[sql.NullByte]),
}
}
func ValuerBit() Valuer {
return &bitValuer{
DefaultValuer: new(DefaultValuer[[]byte]),
}
}
func ValuerFloat64() Valuer {
return &float64Valuer{
DefaultValuer: new(DefaultValuer[sql.NullFloat64]),
}
}
func ValuerDatetime() Valuer {
return &datetimeValuer{
DefaultValuer: new(DefaultValuer[NullTime]),
}
}
func ValuerDate() Valuer {
return &dateValuer{
DefaultValuer: new(DefaultValuer[NullTime]),
}
}
func ValuerTime() Valuer {
return &timeValuer{
DefaultValuer: new(DefaultValuer[NullTime]),
}
}
func ValuerBytes() Valuer {
return &bytesValuer{
DefaultValuer: new(DefaultValuer[sql.RawBytes]),
}
}
// 默认 valuer
// string
type stringValuer struct {
*DefaultValuer[sql.NullString]
}
func (s *stringValuer) Value() any {
if s.ValuePtr.Valid {
return s.ValuePtr.String
}
return nil
}
// uint64
type uint64Valuer struct {
*DefaultValuer[[]byte]
}
func (s *uint64Valuer) Value() any {
valBytes := *s.ValuePtr
if valBytes == nil {
return nil
}
val := string(valBytes)
// 前端超过16位会丢失精度
if len(val) > 16 {
return val
}
return cast.ToUint64(val)
}
// int64
type int64Valuer struct {
*DefaultValuer[sql.NullInt64]
}
func (s *int64Valuer) Value() any {
if s.ValuePtr.Valid {
val := s.ValuePtr.Int64
// 前端超过16位会丢失精度
if val > 9999999999999999 {
return fmt.Sprintf("%d", val)
}
return val
}
return nil
}
// int32
type int32Valuer struct {
*DefaultValuer[sql.NullInt32]
}
func (s *int32Valuer) Value() any {
if s.ValuePtr.Valid {
return s.ValuePtr.Int32
}
return nil
}
// int16
type int16Valuer struct {
*DefaultValuer[sql.NullInt16]
}
func (s *int16Valuer) Value() any {
if s.ValuePtr.Valid {
return s.ValuePtr.Int16
}
return nil
}
// byteuint8
type byteValuer struct {
*DefaultValuer[sql.NullByte]
}
func (s *byteValuer) Value() any {
if s.ValuePtr.Valid {
return s.ValuePtr.Byte
}
return nil
}
// bit
type bitValuer struct {
*DefaultValuer[[]byte]
}
func (s *bitValuer) Value() any {
valBytes := *s.ValuePtr
if valBytes == nil {
return nil
}
return valBytes[0]
}
// float64
type float64Valuer struct {
*DefaultValuer[sql.NullFloat64]
}
func (s *float64Valuer) Value() any {
if s.ValuePtr.Valid {
return s.ValuePtr.Float64
}
return nil
}
// bytes
type bytesValuer struct {
*DefaultValuer[sql.RawBytes]
}
func (s *bytesValuer) Value() any {
val := s.ValuePtr
if *val == nil {
return nil
}
return hex.EncodeToString(*val)
}
// datetime
type datetimeValuer struct {
*DefaultValuer[NullTime]
}
func (s *datetimeValuer) NewValuePtr() any {
s.ValuePtr = &NullTime{
Layout: time.DateTime,
}
return s.ValuePtr
}
func (s *datetimeValuer) Value() any {
if s.ValuePtr.Valid {
return s.ValuePtr.Time
}
return nil
}
// date
type dateValuer struct {
*DefaultValuer[NullTime]
}
func (s *dateValuer) NewValuePtr() any {
s.ValuePtr = &NullTime{
Layout: time.DateOnly,
}
return s.ValuePtr
}
func (s *dateValuer) Value() any {
if s.ValuePtr.Valid {
return s.ValuePtr.Time
}
return nil
}
// time
type timeValuer struct {
*DefaultValuer[NullTime]
}
func (s *timeValuer) NewValuePtr() any {
s.ValuePtr = &NullTime{
Layout: time.TimeOnly,
}
return s.ValuePtr
}
func (s *timeValuer) Value() any {
if s.ValuePtr.Valid {
return s.ValuePtr.Time
}
return nil
}
// NullTime represents a time that may be null.
// NullTime implements the [Scanner] interface so
// it can be used as a scan destination, similar to [NullString].
type NullTime struct {
Time string
Valid bool // Valid is true if Time is not NULL
Layout string
}
var (
_ driver.Valuer = NullTime{}
)
// Scan implements the [Scanner] interface.
func (n *NullTime) Scan(value any) error {
if value == nil {
n.Time, n.Valid = "", false
return nil
}
n.Valid = true
time, err := convertTime(value, n.Layout)
if err != nil {
return err
}
n.Time = time
return nil
}
// Value implements the driver Valuer interface.
func (n NullTime) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.Time, nil
}
func convertTime(src interface{}, layout string) (string, error) {
switch s := src.(type) {
case string:
return s, nil
case []uint8:
return string(s), nil
case time.Time:
return s.Format(layout), nil
default:
return "", nil
}
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"fmt"
"mayfly-go/internal/machine/mcm"
"mayfly-go/pkg/errorx"
"mayfly-go/pkg/logx"
@@ -23,19 +24,33 @@ type DbConn struct {
// 执行数据库查询返回的列信息
type QueryColumn struct {
Name string `json:"name"` // 列名
Type string `json:"type"` // 类型
Type string `json:"type"` // 数据类型
SqlColType *sql.ColumnType `json:"-"`
DbDataType *DbDataType `json:"-"`
valuer Valuer `json:"-"`
}
func NewQueryColumn(colName string, col *sql.ColumnType) *QueryColumn {
func NewQueryColumn(colName string, columnType *DbDataType) *QueryColumn {
return &QueryColumn{
Name: col.Name(),
Type: col.DatabaseTypeName(),
SqlColType: col,
Name: colName,
Type: columnType.DataType.Name,
DbDataType: columnType,
valuer: columnType.DataType.Valuer(),
}
}
func (qc *QueryColumn) getValuePtr() any {
return qc.valuer.NewValuePtr()
}
func (qc *QueryColumn) value() any {
return qc.valuer.Value()
}
func (qc *QueryColumn) SQLValue(val any) any {
return qc.DbDataType.DataType.SQLValue(val)
}
func (d *DbConn) GetDb() *sql.DB {
return d.db
}
@@ -80,7 +95,7 @@ func (d *DbConn) Query2Struct(execSql string, dest any) error {
// WalkQueryRows 游标方式遍历查询结果集, walkFn返回error不为nil, 则跳出遍历并取消查询
func (d *DbConn) WalkQueryRows(ctx context.Context, querySql string, walkFn WalkQueryRowsFunc, args ...any) ([]*QueryColumn, error) {
return walkQueryRows(ctx, d.GetDialect(), d.db, querySql, walkFn, args...)
return d.walkQueryRows(ctx, querySql, walkFn, args...)
}
// WalkTableRows 游标方式遍历指定表的结果集, walkFn返回error不为nil, 则跳出遍历并取消查询
@@ -138,6 +153,11 @@ func (d *DbConn) GetMetadata() Metadata {
return d.Info.Meta.GetMetadata(d)
}
// GetDbDataType 获取定义的数据库数据类型
func (d *DbConn) GetDbDataType(dataType string) *DbDataType {
return GetDbDataType(d.Info.Type, dataType)
}
// Stats 返回数据库连接状态
func (d *DbConn) Stats(ctx context.Context, execSql string, args ...any) sql.DBStats {
return d.db.Stats()
@@ -149,8 +169,8 @@ func (d *DbConn) Close() {
if err := d.db.Close(); err != nil {
logx.Errorf("关闭数据库实例[%s]连接失败: %s", d.Id, err.Error())
}
// 如果是达梦并且使用了ssh隧道则需要手动将其关闭
if d.Info.Type == DbTypeDM && d.Info.SshTunnelMachineId > 0 {
// 如果是使用了自己实现的ssh隧道转发,则需要手动将其关闭
if d.Info.useSshTunnel {
mcm.CloseSshTunnelMachine(d.Info.SshTunnelMachineId, fmt.Sprintf("db:%d", d.Info.Id))
}
d.db = nil
@@ -158,11 +178,11 @@ func (d *DbConn) Close() {
}
// 游标方式遍历查询rows, walkFn error不为nil, 则跳出遍历
func walkQueryRows(ctx context.Context, dialect Dialect, db *sql.DB, selectSql string, walkFn WalkQueryRowsFunc, args ...any) ([]*QueryColumn, error) {
func (d *DbConn) walkQueryRows(ctx context.Context, selectSql string, walkFn WalkQueryRowsFunc, args ...any) ([]*QueryColumn, error) {
cancelCtx, cancelFunc := context.WithCancel(ctx)
defer cancelFunc()
rows, err := db.QueryContext(cancelCtx, selectSql, args...)
rows, err := d.db.QueryContext(cancelCtx, selectSql, args...)
if err != nil {
return nil, err
}
@@ -170,8 +190,6 @@ func walkQueryRows(ctx context.Context, dialect Dialect, db *sql.DB, selectSql s
// 后面的链接过来直接报错或拒绝,实际上也没有起效果
defer rows.Close()
columnHelper := dialect.GetColumnHelper()
colTypes, err := rows.ColumnTypes()
if err != nil {
return nil, err
@@ -188,9 +206,9 @@ func walkQueryRows(ctx context.Context, dialect Dialect, db *sql.DB, selectSql s
if colName == "" {
colName = fmt.Sprintf("<anonymous%d>", k+1)
}
qc := NewQueryColumn(colName, colType)
qc := NewQueryColumn(colName, d.GetDbDataType(colType.DatabaseTypeName()))
cols[k] = qc
scans[k] = columnHelper.GetScanDestPtr(qc)
scans[k] = qc.getValuePtr()
}
for rows.Next() {
@@ -201,8 +219,8 @@ func walkQueryRows(ctx context.Context, dialect Dialect, db *sql.DB, selectSql s
// 每行数据
rowData := make(map[string]any, lenCols)
// 把values中的数据复制到row中
for i, v := range scans {
rowData[cols[i].Name] = columnHelper.ConvertScanDestValue(v, cols[i])
for i := range scans {
rowData[cols[i].Name] = cols[i].value()
}
if err = walkFn(rowData, cols); err != nil {
logx.ErrorfContext(ctx, "[%s] cursor traversal query result set error, exit traversal: %s", selectSql, err.Error())

View File

@@ -11,19 +11,6 @@ import (
type DbType string
const (
DbTypeMysql DbType = "mysql"
DbTypeMariadb DbType = "mariadb"
DbTypePostgres DbType = "postgres"
DbTypeGauss DbType = "gauss"
DbTypeDM DbType = "dm"
DbTypeOracle DbType = "oracle"
DbTypeSqlite DbType = "sqlite"
DbTypeMssql DbType = "mssql"
DbTypeKingbaseEs DbType = "kingbaseEs"
DbTypeVastbase DbType = "vastbase"
)
func ToDbType(dbType string) DbType {
return DbType(dbType)
}
@@ -52,6 +39,7 @@ type DbInfo struct {
CodePath []string
SshTunnelMachineId int
useSshTunnel bool // 是否使用系统自己实现的ssh隧道连接,而非库自带的
Meta Meta
}
@@ -116,6 +104,7 @@ func (di *DbInfo) IfUseSshTunnelChangeIpPort() error {
}
di.Host = exposedIp
di.Port = exposedPort
di.useSshTunnel = true
}
return nil
}

View File

@@ -1,22 +1,12 @@
package dbi
import (
"database/sql"
"encoding/hex"
"errors"
"io"
"mayfly-go/internal/db/dbm/sqlparser"
"mayfly-go/internal/db/dbm/sqlparser/pgsql"
"reflect"
"strconv"
"strings"
pq "gitee.com/liuzongyang/libpq"
"github.com/may-fly/cast"
)
const DefaultQuoter = `"`
const (
// -1. 无操作
DuplicateStrategyNone = -1
@@ -36,47 +26,15 @@ type DbCopyTable struct {
// BaseDialect 基础dialect在DefaultDialect 都有默认的实现方法
type BaseDialect interface {
// QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be
// used as part of an SQL statement. For example:
//
// tblname := "my_table"
// data := "my_data"
// quoted := quoteIdentifier(tblname, '"')
// err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data)
//
// Any double quotes in name will be escaped. The quoted identifier will be
// case sensitive when used in a query. If the input string contains a zero
// byte, the result will be truncated immediately before it.
QuoteIdentifier(name string) string
RemoveQuote(name string) string
// QuoteEscape 引号转义多用于sql注释转义防止拼接sql报错 comment xx is '注''释' 最终注释文本为: 注'释
QuoteEscape(str string) string
// QuoteLiteral quotes a 'literal' (e.g. a parameter, often used to pass literal
// to DDL and other statements that do not accept parameters) to be used as part
// of an SQL statement. For example:
//
// exp_date := pq.QuoteLiteral("2023-01-05 15:00:00Z")
// err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", exp_date))
//
// Any single quotes in name will be escaped. Any backslashes (i.e. "\") will be
// replaced by two backslashes (i.e. "\\") and the C-style escape identifier
QuoteLiteral(literal string) string
// Quoter sql关键字引用处理如 table -> `table`、table -> "table"
Quoter() Quoter
// GetDbProgram 获取数据库程序模块,用于数据库备份与恢复
GetDbProgram() (DbProgram, error)
// GetColumnHelper
GetColumnHelper() ColumnHelper
// GetDumpHeler
GetDumpHelper() DumpHelper
// GetDataHelper 获取数据处理助手 用于解析格式化列数据等
GetDataHelper() DataHelper
// GetSQLParser 获取sql解析器
GetSQLParser() sqlparser.SqlParser
}
@@ -86,20 +44,11 @@ type BaseDialect interface {
type Dialect interface {
BaseDialect
// BatchInsert 批量insert数据
BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any, duplicateStrategy int) (int64, error)
// CopyTable 拷贝表
CopyTable(copy *DbCopyTable) error
// GenerateTableDDL 生成建表ddl
GenerateTableDDL(columns []Column, tableInfo Table, dropBeforeCreate bool) []string
// GenerateIndexDDL 生成索引ddl
GenerateIndexDDL(indexs []Index, tableInfo Table) []string
// UpdateSequence 有些数据库迁移完数据之后,需要更新表自增序列为当前表最大值
UpdateSequence(tableName string, columns []Column)
// GetSQLGenerator 获取sql生成器
GetSQLGenerator() SQLGenerator
}
// DefaultDialect 默认实现若需要覆盖则由各个数据库dialect实现去覆盖重写
@@ -108,29 +57,10 @@ type DefaultDialect struct {
var _ (BaseDialect) = (*DefaultDialect)(nil)
func (dx *DefaultDialect) QuoteIdentifier(name string) string {
end := strings.IndexRune(name, 0)
if end > -1 {
name = name[:end]
}
return DefaultQuoter + strings.Replace(name, DefaultQuoter, DefaultQuoter+DefaultQuoter, -1) + DefaultQuoter
func (dx *DefaultDialect) Quoter() Quoter {
return DefaultQuoter
}
func (dx *DefaultDialect) RemoveQuote(name string) string {
return strings.ReplaceAll(name, DefaultQuoter, "")
}
func (dd *DefaultDialect) QuoteEscape(str string) string {
return strings.Replace(str, `'`, `''`, -1)
}
func (dd *DefaultDialect) QuoteLiteral(literal string) string {
return pq.QuoteLiteral(literal)
}
func (dd *DefaultDialect) UpdateSequence(tableName string, columns []Column) {}
// GetDbProgram 获取数据库程序模块,用于数据库备份与恢复
func (dd *DefaultDialect) GetDbProgram() (DbProgram, error) {
return nil, errors.New("not support db program")
}
@@ -139,113 +69,10 @@ func (dd *DefaultDialect) GetDumpHelper() DumpHelper {
return new(DefaultDumpHelper)
}
func (dd *DefaultDialect) GetColumnHelper() ColumnHelper {
return new(DefaultColumnHelper)
}
func (pd *DefaultDialect) GetSQLParser() sqlparser.SqlParser {
return new(pgsql.PgsqlParser)
}
func (pd *DefaultDialect) GetDataHelper() DataHelper {
return nil
}
// ColumnHelper 数据库迁移辅助方法
type ColumnHelper interface {
// ToCommonColumn 数据库方言自带的列转换为公共列
ToCommonColumn(dialectColumn *Column)
// ToColumn 公共列转为各个数据库方言自带的列
ToColumn(commonColumn *Column)
// FixColumn 根据数据库类型修复字段长度、精度等
FixColumn(column *Column)
// GetScanDestPtr 获取scan列目标值指针用于在*sql.Rows.Scan()填充该值
GetScanDestPtr(*QueryColumn) any
// ConvertScanDestValue 将scan的填充的原始值data转为可阅读的值如[]byte -> string or number...
ConvertScanDestValue(data any, qc *QueryColumn) any
}
type DefaultColumnHelper struct {
}
func (dd *DefaultColumnHelper) ToCommonColumn(dialectColumn *Column) {}
func (dd *DefaultColumnHelper) ToColumn(commonColumn *Column) {}
func (dd *DefaultColumnHelper) FixColumn(column *Column) {}
func (dd *DefaultColumnHelper) GetScanDestPtr(qc *QueryColumn) any {
return &[]byte{}
}
func (dd *DefaultColumnHelper) ConvertScanDestValue(data any, qc *QueryColumn) any {
if data == nil {
return nil
}
colType := qc.SqlColType
// 列的数据库类型名
colDatabaseTypeName := strings.ToLower(colType.DatabaseTypeName())
stringV := ""
if slicePtr, ok := data.(*[]uint8); ok {
bytes := *slicePtr
if bytes == nil {
return nil
}
// 如果类型是bit则直接返回第一个字节即可
if strings.Contains(colDatabaseTypeName, "bit") {
return (bytes)[0]
}
if colDatabaseTypeName == "blob" {
return hex.EncodeToString(bytes)
}
// 把[]byte数据转成string
stringV = string(bytes)
} else {
stringV = cast.ToString(data)
}
if colType == nil || colType.ScanType() == nil {
return stringV
}
colScanType := strings.ToLower(colType.ScanType().Name())
if strings.Contains(colScanType, "int") {
// 如果长度超过16位则返回字符串因为前端js长度大于16会丢失精度
if len(stringV) > 16 {
return stringV
}
intV, _ := strconv.Atoi(stringV)
switch colType.ScanType().Kind() {
case reflect.Int8:
return int8(intV)
case reflect.Uint8:
return uint8(intV)
case reflect.Int64:
return int64(intV)
case reflect.Uint64:
return uint64(intV)
case reflect.Uint:
return uint(intV)
default:
return intV
}
}
if strings.Contains(colScanType, "float") || strings.Contains(colDatabaseTypeName, "decimal") {
floatV, _ := strconv.ParseFloat(stringV, 64)
return floatV
}
return stringV
}
// DumpHelper 导出辅助方法
type DumpHelper interface {
BeforeInsert(writer io.Writer, tableName string)
@@ -269,3 +96,14 @@ func (dd *DefaultDumpHelper) BeforeInsertSql(quoteSchema string, quoteTableName
func (dd *DefaultDumpHelper) AfterInsert(writer io.Writer, tableName string, columns []Column) {
writer.Write([]byte("COMMIT;\n"))
}
type SQLGenerator interface {
// GenTableDDL 生成建表语句
GenTableDDL(table Table, columns []Column, dropBeforeCreate bool) []string
// GenIndexDDL 生成索引语句
GenIndexDDL(table Table, indexs []Index) []string
// GenInsert 生成插入语句
GenInsert(tableName string, columns []Column, values [][]any, duplicateStrategy int) []string
}

View File

@@ -5,30 +5,59 @@ import (
)
var (
metas = make(map[DbType]Meta)
metas = make(map[DbType]Meta)
metaInit = make(map[DbType]bool)
)
// 注册数据库类型与dbmeta
func Register(dt DbType, meta Meta) {
metas[dt] = meta
}
// 根据数据库类型获取对应的Meta
func GetMeta(dt DbType) Meta {
return metas[dt]
}
type DbVersion string
// 数据库元信息如获取sql.DB、Dialect等
// Meta 数据库元信息如获取sql.DB、Dialect等
type Meta interface {
// GetSqlDb 根据数据库信息获取sql.DB
GetSqlDb(*DbInfo) (*sql.DB, error)
// GetDialect 获取数据库方言, 若一些接口如QuoteIdentifier不需要DbConn则可以传nil
// GetDialect 获取数据库方言, 若一些接口不需要DbConn则可以传nil
GetDialect(*DbConn) Dialect
// GetMetadata 获取元数据信息接口
// @param *DbConn 数据库连接
GetMetadata(*DbConn) Metadata
// GetDbDataTypes 获取所有数据库对应的数据类型
GetDbDataTypes() []*DbDataType
// GetCommonTypeConverter 获取公共类型转换器,用于迁移与同步
GetCommonTypeConverter() CommonTypeConverter
}
// 注册数据库类型与dbmeta
func Register(dt DbType, meta Meta) {
metas[dt] = meta
metaInit[dt] = false
}
// 根据数据库类型获取对应的Meta
func GetMeta(dt DbType) Meta {
// 未初始化,则进行初始化,如注册数据库类型等。防止未使用到的数据库都被注册
if inited := metaInit[dt]; !inited {
initMeta(dt, metas[dt])
}
return metas[dt]
}
// GetDialect 获取数据库方言如果dialect方法内需要用到dbConn的则不支持该方法
func GetDialect(dt DbType) Dialect {
// 创建一个假连接仅用于调用方言生成sql不做数据库连接操作
meta := GetMeta(dt)
dbConn := &DbConn{Info: &DbInfo{
Type: dt,
Meta: meta,
}}
return meta.GetDialect(dbConn)
}
// initMeta 初始化数据库类型,如注册数据库类型等
func initMeta(dt DbType, meta Meta) {
registerColumnDbDataTypes(dt, meta.GetDbDataTypes()...)
registerCommonTypeConverter(dt, meta.GetCommonTypeConverter())
}

View File

@@ -2,7 +2,6 @@ package dbi
import (
"embed"
"fmt"
"mayfly-go/pkg/biz"
"mayfly-go/pkg/utils/collx"
"mayfly-go/pkg/utils/stringx"
@@ -55,9 +54,6 @@ func (dd *DefaultMetadata) GetDefaultDb() string {
return ""
}
// GenerateSQLStepFunc 生成insert sql的step函数用于生成insert sql时每生成100条sql时调用
type GenerateSQLStepFunc func(sqlArr []string)
// 数据库服务实例信息
type DbServer struct {
Version string `json:"version"` // 版本信息
@@ -74,115 +70,16 @@ type Table struct {
IndexLength int64 `json:"indexLength"`
}
// 表的列信息
type Column struct {
TableName string `json:"tableName"` // 表名
ColumnName string `json:"columnName"` // 列名
DataType ColumnDataType `json:"dataType"` // 数据类型
ColumnComment string `json:"columnComment"` // 列备注
IsPrimaryKey bool `json:"isPrimaryKey"` // 是否为主键
IsIdentity bool `json:"isIdentity"` // 是否自增
ColumnDefault string `json:"columnDefault"` // 默认值
Nullable bool `json:"nullable"` // 是否可为null
CharMaxLength int `json:"charMaxLength"` // 字符最大长度
NumPrecision int `json:"numPrecision"` // 精度(总数字位数)
NumScale int `json:"numScale"` // 小数点位数
Extra collx.M `json:"extra"` // 其他额外信息
}
// 拼接数据类型与长度等。如varchar(2000)decimal(20,2)
func (c *Column) GetColumnType() string {
// 哪些mysql数据类型不需要添加字段长度
if collx.ArrayAnyMatches([]string{"int", "blob", "float", "double", "date", "year", "json"}, string(c.DataType)) {
return string(c.DataType)
}
if c.DataType == "timestamp" {
return "timestamp(6)"
}
if c.CharMaxLength > 0 {
return fmt.Sprintf("%s(%d)", c.DataType, c.CharMaxLength)
}
if c.NumPrecision > 0 {
if c.NumScale > 0 {
return fmt.Sprintf("%s(%d,%d)", c.DataType, c.NumPrecision, c.NumScale)
} else {
return fmt.Sprintf("%s(%d)", c.DataType, c.NumPrecision)
}
}
return string(c.DataType)
}
// 表索引信息
type Index struct {
IndexName string `json:"indexName"` // 索引名
ColumnName string `json:"columnName"` // 列名
IndexType string `json:"indexType"` // 索引类型
IndexComment string `json:"indexComment"` // 备注
SeqInIndex int `json:"seqInIndex"`
IsUnique bool `json:"isUnique"`
IsPrimaryKey bool `json:"isPrimaryKey"` // 是否是主键索引,某些情况需要判断并过滤掉主键索引
}
type ColumnDataType string
const (
CommonTypeVarchar ColumnDataType = "varchar"
CommonTypeChar ColumnDataType = "char"
CommonTypeText ColumnDataType = "text"
CommonTypeBlob ColumnDataType = "blob"
CommonTypeLongblob ColumnDataType = "longblob"
CommonTypeLongtext ColumnDataType = "longtext"
CommonTypeBinary ColumnDataType = "binary"
CommonTypeMediumblob ColumnDataType = "mediumblob"
CommonTypeMediumtext ColumnDataType = "mediumtext"
CommonTypeVarbinary ColumnDataType = "varbinary"
CommonTypeInt ColumnDataType = "int"
CommonTypeBit ColumnDataType = "bit"
CommonTypeSmallint ColumnDataType = "smallint"
CommonTypeTinyint ColumnDataType = "tinyint"
CommonTypeNumber ColumnDataType = "number"
CommonTypeBigint ColumnDataType = "bigint"
CommonTypeDatetime ColumnDataType = "datetime"
CommonTypeDate ColumnDataType = "date"
CommonTypeTime ColumnDataType = "time"
CommonTypeTimestamp ColumnDataType = "timestamp"
CommonTypeEnum ColumnDataType = "enum"
CommonTypeJSON ColumnDataType = "json"
)
type DataType string
const (
DataTypeString DataType = "string"
DataTypeNumber DataType = "number"
DataTypeDate DataType = "date"
DataTypeTime DataType = "time"
DataTypeDateTime DataType = "datetime"
DataTypeBlob DataType = "blob"
)
// 列数据处理帮助方法
type DataHelper interface {
// 获取数据对应的类型
// @param dbColumnType 数据库原始列类型如varchar等
GetDataType(dbColumnType string) DataType
// 根据数据类型格式化指定数据
FormatData(dbColumnValue any, dataType DataType) string
// 根据数据类型解析数据为符合要求的指定类型等
ParseData(dbColumnValue any, dataType DataType) any
// WrapValue 根据数据类型包装value
// 1.数字型:不需要引号,
// 2.文本型:需要用引号包裹,单引号需要转义,换行符转义,
// 3.date型需要格式化成对应的字符串timehh:mm:ss.SSS date: yyyy-mm-dd datetime:
// 4.特殊oracle date型需要用函数包裹to_timestamp('%s', 'yyyy-mm-dd hh24:mi:ss')
WrapValue(dbColumnValue any, dataType DataType) string
IndexName string `json:"indexName"` // 索引名
ColumnName string `json:"columnName"` // 列名
IndexType string `json:"indexType"` // 索引类型
IndexComment string `json:"indexComment"` // 备注
SeqInIndex int `json:"seqInIndex"`
IsUnique bool `json:"isUnique"`
IsPrimaryKey bool `json:"isPrimaryKey"` // 是否是主键索引,某些情况需要判断并过滤掉主键索引
Extra collx.M `json:"extra"` // 其他额外信息,如索引列的前缀长度等
}
// ------------------------- 元数据sql操作 -------------------------

View File

@@ -36,6 +36,7 @@ SELECT
IF(non_unique, 0, 1) isUnique,
SEQ_IN_INDEX seqInIndex,
INDEX_COMMENT indexComment,
SUB_PART subPart,
index_name = 'PRIMARY' as isPrimaryKey
FROM
information_schema.STATISTICS

View File

@@ -0,0 +1,190 @@
package dbi
import (
"mayfly-go/pkg/utils/collx"
"strings"
)
// Quoter represents a quoter to the SQL identifier. i.e. table name or column name
type Quoter struct {
Prefix byte
Suffix byte
IsReserved func(string) bool
}
var (
// AlwaysNoReserve always think it's not a reverse word
AlwaysNoReserve = func(string) bool { return false }
// AlwaysReserve always reverse the word
AlwaysReserve = func(string) bool { return true }
// DefaultQuoter represents the default quoter
DefaultQuote byte = '"'
// DefaultQuoter the default quoter
DefaultQuoter = Quoter{DefaultQuote, DefaultQuote, AlwaysReserve}
)
// IsEmpty return true if no prefix and suffix
func (q Quoter) IsEmpty() bool {
return q.Prefix == 0 && q.Suffix == 0
}
// Quote quote a string
func (q Quoter) Quote(s string) string {
var buf strings.Builder
_ = q.QuoteTo(&buf, s)
return buf.String()
}
// Strings quotes a slice of string
func (q Quoter) Quotes(s []string) []string {
return collx.ArrayMap[string, string](s, func(val string) string {
return q.Quote(val)
})
}
// QuoteTo quotes the identifier. if the quotes are [ and ]
//
// name -> [name]
// [name] -> [name]
// schema.name -> [schema].[name]
// [schema].name -> [schema].[name]
// schema.[name] -> [schema].[name]
// name AS a -> [name] AS a
// schema.name AS a -> [schema].[name] AS a
func (q Quoter) QuoteTo(buf *strings.Builder, value string) error {
var i int
for i < len(value) {
start := findStart(value, i)
if start > i {
if _, err := buf.WriteString(value[i:start]); err != nil {
return err
}
}
if start == len(value) {
return nil
}
nextEnd := findWord(value, start)
if err := q.quoteWordTo(buf, value[start:nextEnd]); err != nil {
return err
}
i = nextEnd
}
return nil
}
// Trim removes quotes from s
func (q Quoter) Trim(s string) string {
if len(s) < 2 {
return s
}
var buf strings.Builder
for i := 0; i < len(s); i++ {
switch {
case i == 0 && s[i] == q.Prefix:
case i == len(s)-1 && s[i] == q.Suffix:
case s[i] == q.Suffix && s[i+1] == '.':
case s[i] == q.Prefix && s[i-1] == '.':
default:
buf.WriteByte(s[i])
}
}
return buf.String()
}
// Join joins a slice with quoters
func (q Quoter) Join(a []string, separator string) string {
var b strings.Builder
_ = q.JoinWrite(&b, a, separator)
return b.String()
}
// JoinWrite writes quoted content to a builder
func (q Quoter) JoinWrite(b *strings.Builder, a []string, sep string) error {
if len(a) == 0 {
return nil
}
n := len(sep) * (len(a) - 1)
for i := 0; i < len(a); i++ {
n += len(a[i])
}
b.Grow(n)
for i, s := range a {
if i > 0 {
if _, err := b.WriteString(sep); err != nil {
return err
}
}
if err := q.QuoteTo(b, strings.TrimSpace(s)); err != nil {
return err
}
}
return nil
}
func (q Quoter) quoteWordTo(buf *strings.Builder, word string) error {
if (word[0] == q.Prefix && word[len(word)-1] == q.Suffix) ||
q.IsEmpty() || !q.IsReserved(word) || word == "*" {
if _, err := buf.WriteString(word); err != nil {
return err
}
return nil
}
if err := buf.WriteByte(q.Prefix); err != nil {
return err
}
if _, err := buf.WriteString(word); err != nil {
return err
}
return buf.WriteByte(q.Suffix)
}
func findWord(v string, start int) int {
for j := start; j < len(v); j++ {
switch v[j] {
case '.', ' ':
return j
}
}
return len(v)
}
func findStart(value string, start int) int {
if value[start] == '.' {
return start + 1
}
if value[start] != ' ' {
return start
}
k := -1
for j := start; j < len(value); j++ {
if value[j] != ' ' {
k = j
break
}
}
if k == -1 {
return len(value)
}
if k+1 < len(value) &&
(value[k] == 'A' || value[k] == 'a') &&
(value[k+1] == 'S' || value[k+1] == 's') {
k += 2
}
for j := k; j < len(value); j++ {
if value[j] != ' ' {
return j
}
}
return len(value)
}

View File

@@ -0,0 +1,106 @@
package dbi
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestQuoteTo(t *testing.T) {
var (
quoter = Quoter{'[', ']', AlwaysReserve}
kases = []struct {
expected string
value string
}{
{"[table]", "table"},
{"[table]", "[table]"},
{`[table].*`, `[table].*`},
{"[schema].[table]", "schema.table"},
{`["schema].[table"]`, `"schema.table"`},
{"[schema].[table] AS [table]", "schema.table AS table"},
{" [table]", " table"},
{" [table]", " table"},
{"[table] ", "table "},
{"[table] ", "table "},
{" [table] ", " table "},
{" [table] ", " table "},
}
)
for _, v := range kases {
t.Run(v.value, func(t *testing.T) {
buf := &strings.Builder{}
err := quoter.QuoteTo(buf, v.value)
assert.NoError(t, err)
assert.EqualValues(t, v.expected, buf.String())
})
}
}
func TestReversedQuoteTo(t *testing.T) {
var (
quoter = Quoter{'[', ']', func(s string) bool {
return s == "table"
}}
kases = []struct {
expected string
value string
}{
{"[table]", "table"},
{"[table].*", `[table].*`},
{`"table"`, `"table"`},
{"schema.[table]", "schema.table"},
{"[schema].[table]", `[schema].table`},
{"schema.[table]", `schema.[table]`},
{"[schema].[table]", `[schema].[table]`},
{`"schema.table"`, `"schema.table"`},
{"schema.[table] AS table1", "schema.table AS table1"},
}
)
for _, v := range kases {
t.Run(v.value, func(t *testing.T) {
buf := &strings.Builder{}
quoter.QuoteTo(buf, v.value)
assert.EqualValues(t, v.expected, buf.String())
})
}
}
func TestJoin(t *testing.T) {
cols := []string{"f1", "f2", "f3"}
quoter := Quoter{'[', ']', AlwaysReserve}
assert.EqualValues(t, "[a],[b]", quoter.Join([]string{"a", " b"}, ","))
assert.EqualValues(t, "[a].*,[b].[c]", quoter.Join([]string{"a.*", " b.c"}, ","))
assert.EqualValues(t, "[b] [a]", quoter.Join([]string{"b a"}, ","))
assert.EqualValues(t, "[f1], [f2], [f3]", quoter.Join(cols, ", "))
quoter.IsReserved = AlwaysNoReserve
assert.EqualValues(t, "f1, f2, f3", quoter.Join(cols, ", "))
}
func TestQuotes(t *testing.T) {
cols := []string{"f1", "f2", "t3.f3", "t4.*"}
quoter := Quoter{'[', ']', AlwaysReserve}
quotedCols := quoter.Quotes(cols)
assert.EqualValues(t, []string{"[f1]", "[f2]", "[t3].[f3]", "[t4].*"}, quotedCols)
}
func TestTrim(t *testing.T) {
kases := map[string]string{
"[table_name]": "table_name",
"[schema].[table_name]": "schema.table_name",
}
for src, dst := range kases {
assert.EqualValues(t, src, DefaultQuoter.Trim(src))
assert.EqualValues(t, dst, Quoter{'[', ']', AlwaysReserve}.Trim(src))
}
}

View File

@@ -0,0 +1,89 @@
package dbi
import (
"fmt"
"mayfly-go/pkg/logx"
"strings"
)
type StmtType string
const (
StmtTypeSelect StmtType = "select"
StmtTypeInsert StmtType = "insert"
StmtTypeUpdate StmtType = "update"
StmtTypeDelete StmtType = "delete"
StmtTypeDDL StmtType = "ddl"
)
// GenTableDDL 生成通用表DDL
func GenTableDDL(dialect Dialect, md Metadata, tableName string, dropBeforeCreate bool) (string, error) {
// 1.获取表信息
tbs, err := md.GetTables(tableName)
if len(tbs) == 0 {
logx.Errorf("get table error: %s", tableName)
return "", err
}
table := tbs[0]
// 2.获取列信息
columns, err := md.GetColumns(tableName)
if err != nil {
logx.Errorf("get columns error: %s", tableName)
return "", err
}
sqlGenerator := dialect.GetSQLGenerator()
tableDDLArr := sqlGenerator.GenTableDDL(table, columns, dropBeforeCreate)
// 3.获取索引信息
indexs, err := md.GetTableIndex(tableName)
if err != nil {
logx.Errorf("get indexs error: %s", tableName)
return "", err
}
// 组装返回
tableDDLArr = append(tableDDLArr, sqlGenerator.GenIndexDDL(table, indexs)...)
return strings.Join(tableDDLArr, ";\n"), nil
}
// GenCommonInsert 生成通用insert sql
//
// insert into table_name (column1, column2, ...) values (value1, value2, ...), (value1, value2, ...), ...
func GenCommonInsert(dialect Dialect, dbType DbType, tableName string, columns []Column, values [][]any) string {
quote := dialect.Quoter().Quote
columnStr, valuesStrs := GenInsertSqlColumnAndValues(dialect, dbType, columns, values)
// insert into table_name (column1, column2, ...) values (value1, value2, ...), (value1, value2, ...), ...
return fmt.Sprintf("INSERT INTO %s %s VALUES \n%s", quote(tableName), columnStr, strings.Join(valuesStrs, ",\n"))
}
// GenInsertSqlColumnAndValues 生成insert sql对应的 columes信息和values信息
//
// columnsStr -> (column1, column2, column3, ...)
// valuesStrs -> ['(value1, value2, value3, ...)', '(value1, value2, value3, ...)', ...]
func GenInsertSqlColumnAndValues(dialect Dialect, dbType DbType, columns []Column, values [][]any) (columnsStr string, valuesStrs []string) {
quote := dialect.Quoter().Quote
columnNames := make([]string, 0, len(columns))
columnTypes := make([]*DbDataType, len(columns))
strValueArr := make([]string, 0, len(values))
for i, column := range columns {
columnNames = append(columnNames, quote(column.ColumnName))
columnType := GetDbDataType(dbType, column.DataType)
columnTypes[i] = columnType
}
for _, value := range values {
vs := make([]string, 0, len(value))
for i, v := range value {
vs = append(vs, columnTypes[i].DataType.SQLValue(v))
}
strValueArr = append(strValueArr, fmt.Sprintf("(%s)", strings.Join(vs, ", ")))
}
return fmt.Sprintf("(%s)", strings.Join(columnNames, ", ")), strValueArr
}

View File

@@ -0,0 +1,154 @@
package dbi
import (
"fmt"
)
type CommonDbDataType int
// common column type enum
const (
CTVarchar CommonDbDataType = iota
CTChar
CTText
CTMediumtext
CTLongtext
CTBit // 1 bit
CTInt1 // 1字节 -128~127
CTInt2 // 2字节 -32768~32767
CTInt4 // 4字节 -2147483648~2147483647
CTInt8 // 8字节 -9223372036854775808~9223372036854775807
CTNumeric
CTDecimal
CTUnsignedInt8
CTUnsignedInt4
CTUnsignedInt2
CTUnsignedInt1
CTDate
CTTime
CTDateTime
CTTimestamp
CTBinary
CTVarbinary
CTMediumblob
CTBlob
CTLongblob
CTEnum
CTJSON
)
type CommonTypeConverter interface {
Varchar(*Column) *DbDataType
Char(*Column) *DbDataType
Text(*Column) *DbDataType
Mediumtext(*Column) *DbDataType
Longtext(*Column) *DbDataType
Bit(*Column) *DbDataType
Int1(*Column) *DbDataType
Int2(*Column) *DbDataType
Int4(*Column) *DbDataType
Int8(*Column) *DbDataType
Numeric(*Column) *DbDataType
Decimal(*Column) *DbDataType
UnsignedInt8(*Column) *DbDataType
UnsignedInt4(*Column) *DbDataType
UnsignedInt2(*Column) *DbDataType
UnsignedInt1(*Column) *DbDataType
Date(*Column) *DbDataType
Time(*Column) *DbDataType
Datetime(*Column) *DbDataType
Timestamp(*Column) *DbDataType
Binary(*Column) *DbDataType
Varbinary(*Column) *DbDataType
Mediumblob(*Column) *DbDataType
Blob(*Column) *DbDataType
Longblob(*Column) *DbDataType
Enum(*Column) *DbDataType
JSON(*Column) *DbDataType
}
var (
commonTypeConverters = make(map[DbType]map[CommonDbDataType]func(*Column) *DbDataType) // 公共列转换器
)
// registerCommonTypeConverter 注册公共列转换器
func registerCommonTypeConverter(dbType DbType, ctc CommonTypeConverter) {
if ctc == nil {
return
}
cts := make(map[CommonDbDataType]func(*Column) *DbDataType)
cts[CTVarchar] = ctc.Varchar
cts[CTChar] = ctc.Char
cts[CTText] = ctc.Text
cts[CTMediumtext] = ctc.Mediumtext
cts[CTLongtext] = ctc.Longtext
cts[CTBit] = ctc.Bit
cts[CTInt1] = ctc.Int1
cts[CTInt2] = ctc.Int2
cts[CTInt4] = ctc.Int4
cts[CTInt8] = ctc.Int8
cts[CTNumeric] = ctc.Numeric
cts[CTDecimal] = ctc.Decimal
cts[CTUnsignedInt8] = ctc.UnsignedInt8
cts[CTUnsignedInt4] = ctc.UnsignedInt4
cts[CTUnsignedInt2] = ctc.UnsignedInt2
cts[CTUnsignedInt1] = ctc.UnsignedInt1
cts[CTDate] = ctc.Date
cts[CTTime] = ctc.Time
cts[CTDateTime] = ctc.Datetime
cts[CTTimestamp] = ctc.Timestamp
cts[CTBinary] = ctc.Binary
cts[CTVarbinary] = ctc.Varbinary
cts[CTMediumblob] = ctc.Mediumblob
cts[CTBlob] = ctc.Blob
cts[CTLongblob] = ctc.Longblob
cts[CTEnum] = ctc.Enum
cts[CTJSON] = ctc.JSON
commonTypeConverters[dbType] = cts
}
// ConvToTargetDbColumn 转换至异构数据库对应的列信息
func ConvToTargetDbColumn(srcDbType DbType, targetDbType DbType, targetDialect Dialect, column *Column) error {
// 同类型数据库,不转换
if srcDbType == targetDbType {
return nil
}
srcMap := commonTypeConverters[srcDbType]
if srcMap == nil {
return fmt.Errorf("src database type [%s] not suport transfer", srcDbType)
}
targetMap := commonTypeConverters[targetDbType]
if targetMap == nil {
return fmt.Errorf("target database type [%s] not suport transfer", targetDbType)
}
srcDataType := GetDbDataType(srcDbType, column.DataType)
// 获取目标数据库的数据类型,并进行可能存在的列信息修复,如长度、精度等
targetDbDataType := targetMap[srcDataType.CommonType](column)
if targetDbDataType == nil {
return fmt.Errorf("target database type [%s] not suport transfer, src data type [%d]", targetDbType, srcDataType.CommonType)
}
// 替换为目标数据库的数据类型
column.DataType = targetDbDataType.Name
return nil
}

View File

@@ -0,0 +1,9 @@
package dbi
import (
"strings"
)
func QuoteEscape(str string) string {
return strings.Replace(str, `'`, `''`, -1)
}

View File

@@ -21,7 +21,7 @@ import (
var connCache = cache.NewTimedCache(consts.DbConnExpireTime, 5*time.Second).
WithUpdateAccessTime(true).
OnEvicted(func(key any, value any) {
logx.Info(fmt.Sprintf("删除db连接缓存 id = %s", key))
logx.Info(fmt.Sprintf("delete db conn cache, id = %s", key))
value.(*dbi.DbConn).Close()
})

View File

@@ -0,0 +1,106 @@
package dm
import (
"encoding/hex"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/utils/anyx"
"strings"
"gitee.com/chunanyong/dm"
)
var (
CHAR = dbi.NewDbDataType("CHAR", dbi.DTString).WithCT(dbi.CTChar)
VARCHAR = dbi.NewDbDataType("VARCHAR", dbi.DTString).WithCT(dbi.CTVarchar)
TEXT = dbi.NewDbDataType("TEXT", dbi.DTString).WithCT(dbi.CTText)
LONG = dbi.NewDbDataType("LONG", dbi.DTString).WithCT(dbi.CTText)
LONGVARCHAR = dbi.NewDbDataType("LONGVARCHAR", dbi.DTString).WithCT(dbi.CTLongtext)
IMAGE = dbi.NewDbDataType("IMAGE", dbi.DTString).WithCT(dbi.CTLongtext)
LONGVARBINARY = dbi.NewDbDataType("LONGVARBINARY", dbi.DTString).WithCT(dbi.CTLongtext)
CLOB = dbi.NewDbDataType("CLOB", dbi.DTString).WithCT(dbi.CTLongtext)
BLOB = dbi.NewDbDataType("BLOB", dbi.DTBytes).WithCT(dbi.CTBlob)
NUMERIC = dbi.NewDbDataType("NUMERIC", dbi.DTNumeric).WithCT(dbi.CTNumeric)
DECIMAL = dbi.NewDbDataType("DECIMAL", dbi.DTDecimal).WithCT(dbi.CTDecimal)
NUMBER = dbi.NewDbDataType("NUMBER", dbi.DTNumeric).WithCT(dbi.CTNumeric)
INTEGER = dbi.NewDbDataType("INTEGER", dbi.DTInt32).WithCT(dbi.CTInt4)
INT = dbi.NewDbDataType("INT", dbi.DTInt32).WithCT(dbi.CTInt4)
BIGINT = dbi.NewDbDataType("BIGINT", dbi.DTInt64).WithCT(dbi.CTInt8)
TINYINT = dbi.NewDbDataType("TINYINT", dbi.DTInt8).WithCT(dbi.CTInt1)
BYTE = dbi.NewDbDataType("BYTE", dbi.DTInt8).WithCT(dbi.CTInt1)
SMALLINT = dbi.NewDbDataType("SMALLINT", dbi.DTInt16).WithCT(dbi.CTInt2)
BIT = dbi.NewDbDataType("BIT", dbi.DTBit).WithCT(dbi.CTBit)
DOUBLE = dbi.NewDbDataType("DOUBLE", dbi.DTNumeric).WithCT(dbi.CTNumeric)
FLOAT = dbi.NewDbDataType("FLOAT", dbi.DTNumeric).WithCT(dbi.CTNumeric)
TIME = dbi.NewDbDataType("TIME", dbi.DTTime).WithCT(dbi.CTTime).WithFixColumn(dbi.ClearCharMaxLength)
DATE = dbi.NewDbDataType("DATE", dbi.DTDate).WithCT(dbi.CTDate).WithFixColumn(dbi.ClearCharMaxLength)
TIMESTAMP = dbi.NewDbDataType("TIMESTAMP", dbi.DTDateTime).WithCT(dbi.CTTimestamp).WithFixColumn(dbi.ClearCharMaxLength)
ST_CURVE = dbi.NewDbDataType("ST_CURVE", DTDmStruct).WithCT(dbi.CTVarchar) // 表示一条曲线,可以是圆弧、抛物线等
ST_LINESTRING = dbi.NewDbDataType("ST_LINESTRING", DTDmStruct).WithCT(dbi.CTVarchar) // 表示一条或多条连续的线段
ST_GEOMCOLLECTION = dbi.NewDbDataType("ST_GEOMCOLLECTION", DTDmStruct).WithCT(dbi.CTVarchar) // 表示一个几何对象集合,可以包含多个不同类型的几何对象
ST_GEOMETRY = dbi.NewDbDataType("ST_GEOMETRY", DTDmStruct).WithCT(dbi.CTVarchar) // 通用几何对象类型,可以表示点、线、面等任何几何形状
ST_MULTICURVE = dbi.NewDbDataType("ST_MULTICURVE", DTDmStruct).WithCT(dbi.CTVarchar) // 表示多个曲线的集合
ST_MULTILINESTRING = dbi.NewDbDataType("ST_MULTILINESTRING", DTDmStruct).WithCT(dbi.CTVarchar) // 表示多个线串的集合
ST_MULTIPOINT = dbi.NewDbDataType("ST_MULTIPOINT", DTDmStruct).WithCT(dbi.CTVarchar) // 表示多个点的集合
ST_MULTIPOLYGON = dbi.NewDbDataType("ST_MULTIPOLYGON", DTDmStruct).WithCT(dbi.CTVarchar) // 表示多个曲线的集合
ST_MULTISURFACE = dbi.NewDbDataType("ST_MULTISURFACE", DTDmStruct).WithCT(dbi.CTVarchar) // 表示多个表面的集合
ST_POINT = dbi.NewDbDataType("ST_POINT", DTDmStruct).WithCT(dbi.CTVarchar) // 表示一个点
ST_POLYGON = dbi.NewDbDataType("ST_POLYGON", DTDmStruct).WithCT(dbi.CTVarchar) //表示一个多边形
ST_SURFACE = dbi.NewDbDataType("ST_SURFACE", DTDmStruct).WithCT(dbi.CTVarchar) // 表示一个表面
)
var DTDmStruct = &dbi.DataType{
Name: "dm_struct",
Valuer: func() dbi.Valuer {
return &dmStructValuer{
DefaultValuer: new(dbi.DefaultValuer[dm.DmStruct]),
}
},
SQLValue: dbi.SQLValueString,
}
type dmStructValuer struct {
*dbi.DefaultValuer[dm.DmStruct]
}
func (s *dmStructValuer) Value() any {
if !s.ValuePtr.Valid {
return ""
}
return ParseDmStruct(s.ValuePtr)
}
func ParseDmStruct(dmStruct *dm.DmStruct) string {
if !dmStruct.Valid {
return ""
}
name, _ := dmStruct.GetSQLTypeName()
attributes, _ := dmStruct.GetAttributes()
arr := make([]string, len(attributes))
arr = append(arr, name, "(")
for i, v := range attributes {
if blb, ok1 := v.(*dm.DmBlob); ok1 {
if blb.Valid {
length, _ := blb.GetLength()
var dest = make([]byte, length)
_, _ = blb.Read(dest)
// 2进制转16进制字符串
hexStr := hex.EncodeToString(dest)
arr = append(arr, "0x", strings.ToUpper(hexStr))
}
} else {
arr = append(arr, anyx.ToString(v))
}
if i < len(attributes)-1 {
arr = append(arr, ",")
}
}
arr = append(arr, ")")
return strings.Join(arr, "")
}

View File

@@ -1,11 +1,8 @@
package dm
import (
"database/sql"
"fmt"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/logx"
"mayfly-go/pkg/utils/collx"
"mayfly-go/pkg/utils/stringx"
"strings"
"time"
@@ -19,109 +16,6 @@ type DMDialect struct {
dc *dbi.DbConn
}
func (dd *DMDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any, duplicateStrategy int) (int64, error) {
// 执行批量insert sql
// insert into "table_name" ("column1", "column2", ...) values (value1, value2, ...)
// 无需处理重复数据直接执行批量insert
if duplicateStrategy == dbi.DuplicateStrategyNone || duplicateStrategy == 0 {
return dd.batchInsertSimple(tx, tableName, columns, values)
} else { // 执行MERGE INTO语句
return dd.batchInsertMerge(tx, tableName, columns, values)
}
}
func (dd *DMDialect) batchInsertSimple(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) {
// 生成占位符字符串:如:(?,?)
// 重复字符串并用逗号连接
repeated := strings.Repeat("?,", len(columns))
// 去除最后一个逗号,占位符由括号包裹
placeholder := fmt.Sprintf("(%s)", strings.TrimSuffix(repeated, ","))
identityInsert := fmt.Sprintf("set identity_insert \"%s\" on;", tableName)
sqlTemp := fmt.Sprintf("%s insert into %s (%s) values %s", identityInsert, dd.QuoteIdentifier(tableName), strings.Join(columns, ","), placeholder)
effRows := 0
// 设置允许填充自增列之后,显示指定列名可以插入自增列
for _, value := range values {
// 达梦数据库只能一条条的执行insert
res, err := dd.dc.TxExec(tx, sqlTemp, value...)
if err != nil {
logx.Errorf("执行sql失败%s, sql: [ %s ]", err.Error(), sqlTemp)
return 0, err
}
effRows += int(res)
}
// 执行批量insert sql
return int64(effRows), nil
}
func (dd *DMDialect) batchInsertMerge(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) {
// 查询主键字段
uniqueCols := make([]string, 0)
caseSqls := make([]string, 0)
metadata := dd.dc.GetMetadata()
tableCols, _ := metadata.GetColumns(tableName)
identityCols := make([]string, 0)
for _, col := range tableCols {
if col.IsPrimaryKey {
uniqueCols = append(uniqueCols, col.ColumnName)
caseSqls = append(caseSqls, fmt.Sprintf("( T1.%s = T2.%s )", dd.QuoteIdentifier(col.ColumnName), dd.QuoteIdentifier(col.ColumnName)))
}
if col.IsIdentity {
// 自增字段不放入insert内即使是设置了identity_insert on也不起作用
identityCols = append(identityCols, dd.QuoteIdentifier(col.ColumnName))
}
}
// 查询唯一索引涉及到的字段并组装到match条件内
indexs, _ := metadata.GetTableIndex(tableName)
for _, index := range indexs {
if index.IsUnique {
cols := strings.Split(index.ColumnName, ",")
tmp := make([]string, 0)
for _, col := range cols {
uniqueCols = append(uniqueCols, col)
tmp = append(tmp, fmt.Sprintf(" T1.%s = T2.%s ", dd.QuoteIdentifier(col), dd.QuoteIdentifier(col)))
}
caseSqls = append(caseSqls, fmt.Sprintf("( %s )", strings.Join(tmp, " AND ")))
}
}
// 重复数据处理策略
phs := make([]string, 0)
insertVals := make([]string, 0)
upds := make([]string, 0)
insertCols := make([]string, 0)
for _, column := range columns {
phs = append(phs, fmt.Sprintf("? %s", column))
if !collx.ArrayContains(uniqueCols, dd.RemoveQuote(column)) {
upds = append(upds, fmt.Sprintf("T1.%s = T2.%s", column, column))
}
if !collx.ArrayContains(identityCols, column) {
insertCols = append(insertCols, column)
insertVals = append(insertVals, fmt.Sprintf("T2.%s", column))
}
}
t2s := make([]string, 0)
for i := 0; i < len(values); i++ {
t2s = append(t2s, fmt.Sprintf("SELECT %s FROM dual", strings.Join(phs, ",")))
}
t2 := strings.Join(t2s, " UNION ALL ")
sqlTemp := "MERGE INTO " + dd.QuoteIdentifier(tableName) + " T1 USING (" + t2 + ") T2 ON " + strings.Join(caseSqls, " OR ")
sqlTemp += "WHEN NOT MATCHED THEN INSERT (" + strings.Join(insertCols, ",") + ") VALUES (" + strings.Join(insertVals, ",") + ")"
sqlTemp += "WHEN MATCHED THEN UPDATE SET " + strings.Join(upds, ",")
// 把二维数组转为一维数组
var args []any
for _, v := range values {
args = append(args, v...)
}
return dd.dc.TxExec(tx, sqlTemp, args...)
}
func (dd *DMDialect) CopyTable(copy *dbi.DbCopyTable) error {
tableName := copy.TableName
metadata := dd.dc.GetMetadata()
@@ -162,119 +56,13 @@ func (dd *DMDialect) CopyTable(copy *dbi.DbCopyTable) error {
return err
}
func (dd *DMDialect) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table, dropBeforeCreate bool) []string {
tbName := dd.QuoteIdentifier(tableInfo.TableName)
sqlArr := make([]string, 0)
if dropBeforeCreate {
sqlArr = append(sqlArr, fmt.Sprintf("drop table if exists %s", tbName))
}
// 组装建表语句
createSql := fmt.Sprintf("create table %s (", tbName)
fields := make([]string, 0)
pks := make([]string, 0)
columnComments := make([]string, 0)
for _, column := range columns {
if column.IsPrimaryKey {
pks = append(pks, dd.QuoteIdentifier(column.ColumnName))
}
fields = append(fields, dd.genColumnBasicSql(column))
if column.ColumnComment != "" {
comment := dd.QuoteEscape(column.ColumnComment)
columnComments = append(columnComments, fmt.Sprintf("comment on column %s.%s is '%s'", tbName, dd.QuoteIdentifier(column.ColumnName), comment))
}
}
createSql += strings.Join(fields, ",\n")
if len(pks) > 0 {
createSql += fmt.Sprintf(",\n PRIMARY KEY (%s)", strings.Join(pks, ","))
}
createSql += "\n)"
tableCommentSql := ""
if tableInfo.TableComment != "" {
comment := dd.QuoteEscape(tableInfo.TableComment)
tableCommentSql = fmt.Sprintf("comment on table %s is '%s'", tbName, comment)
}
sqlArr = append(sqlArr, createSql)
if tableCommentSql != "" {
sqlArr = append(sqlArr, tableCommentSql)
}
if len(columnComments) > 0 {
sqlArr = append(sqlArr, columnComments...)
}
return sqlArr
}
func (dd *DMDialect) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Table) []string {
sqls := make([]string, 0)
for _, index := range indexs {
unique := ""
if index.IsUnique {
unique = "unique"
}
// 取出列名,添加引号
cols := strings.Split(index.ColumnName, ",")
colNames := make([]string, len(cols))
for i, name := range cols {
colNames[i] = dd.QuoteIdentifier(name)
}
sqls = append(sqls, fmt.Sprintf("create %s index %s on %s(%s)", unique, dd.QuoteIdentifier(index.IndexName), dd.QuoteIdentifier(tableInfo.TableName), strings.Join(colNames, ",")))
}
return sqls
}
func (dd *DMDialect) GetDataHelper() dbi.DataHelper {
return dataHelper
}
func (dd *DMDialect) GetColumnHelper() dbi.ColumnHelper {
return columnHelper
}
func (dd *DMDialect) GetDumpHelper() dbi.DumpHelper {
return new(DumpHelper)
}
func (dd *DMDialect) genColumnBasicSql(column dbi.Column) string {
colName := dd.QuoteIdentifier(column.ColumnName)
dataType := string(column.DataType)
incr := ""
if column.IsIdentity {
incr = " IDENTITY"
func (sd *DMDialect) GetSQLGenerator() dbi.SQLGenerator {
return &SQLGenerator{
Dialect: sd,
Metadata: sd.dc.GetMetadata(),
}
nullAble := ""
if !column.Nullable {
nullAble = " NOT NULL"
}
defVal := "" // 默认值需要判断引号,如函数是不需要引号的 // 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值
if column.ColumnDefault != "" && !strings.Contains(column.ColumnDefault, "(") {
// 哪些字段类型默认值需要加引号
mark := false
if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, strings.ToLower(dataType)) {
// 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号
if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) &&
collx.ArrayAnyMatches([]string{"DATE", "TIME"}, strings.ToUpper(column.ColumnDefault)) {
mark = false
} else {
mark = true
}
}
if mark {
defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault)
} else {
defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault)
}
}
columnSql := fmt.Sprintf(" %s %s%s%s%s", colName, column.GetColumnType(), incr, nullAble, defVal)
return columnSql
}

View File

@@ -1,289 +1,11 @@
package dm
import (
"encoding/hex"
"fmt"
"io"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/utils/anyx"
"mayfly-go/pkg/utils/collx"
"regexp"
"strings"
"time"
"gitee.com/chunanyong/dm"
)
var (
// 数字类型
numberRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`)
// 日期时间类型
datetimeRegexp = regexp.MustCompile(`(?i)datetime|timestamp`)
// 日期类型
dateRegexp = regexp.MustCompile(`(?i)date`)
// 时间类型
timeRegexp = regexp.MustCompile(`(?i)time`)
// 达梦数据类型 对应 公共数据类型
commonColumnTypeMap = map[string]dbi.ColumnDataType{
"CHAR": dbi.CommonTypeChar, // 字符数据类型
"VARCHAR": dbi.CommonTypeVarchar,
"TEXT": dbi.CommonTypeText,
"LONG": dbi.CommonTypeText,
"LONGVARCHAR": dbi.CommonTypeLongtext,
"IMAGE": dbi.CommonTypeLongtext,
"LONGVARBINARY": dbi.CommonTypeLongtext,
"BLOB": dbi.CommonTypeBlob,
"CLOB": dbi.CommonTypeText,
"NUMERIC": dbi.CommonTypeNumber, // 精确数值数据类型
"DECIMAL": dbi.CommonTypeNumber,
"NUMBER": dbi.CommonTypeNumber,
"INTEGER": dbi.CommonTypeInt,
"INT": dbi.CommonTypeInt,
"BIGINT": dbi.CommonTypeBigint,
"TINYINT": dbi.CommonTypeTinyint,
"BYTE": dbi.CommonTypeTinyint,
"SMALLINT": dbi.CommonTypeSmallint,
"BIT": dbi.CommonTypeTinyint,
"DOUBLE": dbi.CommonTypeNumber, // 近似数值类型
"FLOAT": dbi.CommonTypeNumber,
"DATE": dbi.CommonTypeDate, // 一般日期时间数据类型
"TIME": dbi.CommonTypeTime,
"TIMESTAMP": dbi.CommonTypeTimestamp,
}
// 公共数据类型 对应 达梦数据类型
dmColumnTypeMap = map[dbi.ColumnDataType]string{
dbi.CommonTypeVarchar: "VARCHAR",
dbi.CommonTypeChar: "CHAR",
dbi.CommonTypeText: "TEXT",
dbi.CommonTypeBlob: "BLOB",
dbi.CommonTypeLongblob: "TEXT",
dbi.CommonTypeLongtext: "TEXT",
dbi.CommonTypeBinary: "TEXT",
dbi.CommonTypeMediumblob: "TEXT",
dbi.CommonTypeMediumtext: "TEXT",
dbi.CommonTypeVarbinary: "TEXT",
dbi.CommonTypeInt: "INT",
dbi.CommonTypeSmallint: "SMALLINT",
dbi.CommonTypeTinyint: "TINYINT",
dbi.CommonTypeNumber: "NUMBER",
dbi.CommonTypeBigint: "BIGINT",
dbi.CommonTypeDatetime: "TIMESTAMP",
dbi.CommonTypeDate: "DATE",
dbi.CommonTypeTime: "DATE",
dbi.CommonTypeTimestamp: "TIMESTAMP",
dbi.CommonTypeEnum: "TEXT",
dbi.CommonTypeJSON: "TEXT",
}
dmStructTypes = map[string]bool{
"ST_CURVE": true, // 表示一条曲线,可以是圆弧、抛物线等
"ST_LINESTRING": true, // 表示一条或多条连续的线段
"ST_GEOMCOLLECTION": true, // 表示一个几何对象集合,可以包含多个不同类型的几何对象
"ST_GEOMETRY": true, // 通用几何对象类型,可以表示点、线、面等任何几何形状
"ST_MULTICURVE": true, // 表示多个曲线的集合
"ST_MULTILINESTRING": true, // 表示多个线串的集合
"ST_MULTIPOINT": true, // 表示多个点的集合
"ST_MULTIPOLYGON": true, // 表示多个多边形的集合
"ST_MULTISURFACE": true, // 表示多个表面的集合
"ST_POINT": true, // 表示一个点
"ST_POLYGON": true, // 表示一个多边形
"ST_SURFACE": true, // 表示一个表面,通常是一个多边形
}
dataHelper = &DataHelper{}
columnHelper = &ColumnHelper{}
)
func GetDataHelper() *DataHelper {
return dataHelper
}
type DataHelper struct {
}
func (dc *DataHelper) GetDataType(dbColumnType string) dbi.DataType {
if numberRegexp.MatchString(dbColumnType) {
return dbi.DataTypeNumber
}
if datetimeRegexp.MatchString(dbColumnType) {
return dbi.DataTypeDateTime
}
if dateRegexp.MatchString(dbColumnType) {
return dbi.DataTypeDate
}
if timeRegexp.MatchString(dbColumnType) {
return dbi.DataTypeTime
}
return dbi.DataTypeString
}
func (dc *DataHelper) FormatData(dbColumnValue any, dataType dbi.DataType) string {
str := anyx.ToString(dbColumnValue)
switch dataType {
case dbi.DataTypeDateTime: // "2024-01-02T22:08:22.275697+08:00"
// 尝试用时间格式解析
res, err := time.Parse(time.DateTime, str)
if err == nil {
return str
}
res, _ = time.Parse(time.RFC3339, str)
return res.Format(time.DateTime)
case dbi.DataTypeDate: // "2024-01-02T00:00:00+08:00"
// 尝试用时间格式解析
res, err := time.Parse(time.DateOnly, str)
if err == nil {
return str
}
res, _ = time.Parse(time.RFC3339, str)
return res.Format(time.DateOnly)
case dbi.DataTypeTime: // "0000-01-01T22:08:22.275688+08:00"
// 尝试用时间格式解析
res, err := time.Parse(time.TimeOnly, str)
if err == nil {
return str
}
res, _ = time.Parse(time.RFC3339, str)
return res.Format(time.TimeOnly)
}
return str
}
func (dc *DataHelper) ParseData(dbColumnValue any, dataType dbi.DataType) any {
// 如果dataType是datetime而dbColumnValue是string类型则需要转换为time.Time类型
_, ok := dbColumnValue.(string)
if ok {
if dataType == dbi.DataTypeDateTime {
res, _ := time.Parse(time.RFC3339, anyx.ToString(dbColumnValue))
return res
}
if dataType == dbi.DataTypeDate {
res, _ := time.Parse(time.DateOnly, anyx.ToString(dbColumnValue))
return res
}
if dataType == dbi.DataTypeTime {
res, _ := time.Parse(time.TimeOnly, anyx.ToString(dbColumnValue))
return res
}
}
return dbColumnValue
}
func (dc *DataHelper) WrapValue(dbColumnValue any, dataType dbi.DataType) string {
if dbColumnValue == nil {
return "NULL"
}
switch dataType {
case dbi.DataTypeNumber:
return fmt.Sprintf("%v", dbColumnValue)
case dbi.DataTypeString:
val := fmt.Sprintf("%v", dbColumnValue)
// 转义单引号
val = strings.Replace(val, `'`, `''`, -1)
val = strings.Replace(val, `\''`, `\'`, -1)
// 转义换行符
val = strings.Replace(val, "\n", "\\n", -1)
return fmt.Sprintf("'%s'", val)
case dbi.DataTypeDate, dbi.DataTypeDateTime, dbi.DataTypeTime:
return fmt.Sprintf("'%s'", dc.FormatData(dbColumnValue, dataType))
}
return fmt.Sprintf("'%s'", dbColumnValue)
}
type ColumnHelper struct {
dbi.DefaultColumnHelper
}
func (ch *ColumnHelper) ToCommonColumn(dialectColumn *dbi.Column) {
// 翻译为通用数据库类型
dataType := dialectColumn.DataType
t1 := commonColumnTypeMap[string(dataType)]
if t1 == "" {
dialectColumn.DataType = dbi.CommonTypeVarchar
dialectColumn.CharMaxLength = 2000
} else {
dialectColumn.DataType = t1
}
}
func (ch *ColumnHelper) ToColumn(commonColumn *dbi.Column) {
ctype := dmColumnTypeMap[commonColumn.DataType]
if ctype == "" {
commonColumn.DataType = "VARCHAR"
commonColumn.CharMaxLength = 2000
} else {
commonColumn.DataType = dbi.ColumnDataType(ctype)
ch.FixColumn(commonColumn)
}
}
func (ch *ColumnHelper) FixColumn(column *dbi.Column) {
// 如果是date不设长度
if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(string(column.DataType))) {
column.CharMaxLength = 0
column.NumPrecision = 0
} else
// 如果是char且长度未设置则默认长度2000
if collx.ArrayAnyMatches([]string{"char"}, strings.ToLower(string(column.DataType))) && column.CharMaxLength == 0 {
column.CharMaxLength = 2000
}
}
func (dd *ColumnHelper) GetScanDestPtr(qc *dbi.QueryColumn) any {
if dmStructTypes[strings.ToUpper(qc.Type)] {
return &dm.DmStruct{}
}
return dd.DefaultColumnHelper.GetScanDestPtr(qc)
}
func (dd *ColumnHelper) ConvertScanDestValue(data any, qc *dbi.QueryColumn) any {
if data == nil {
return nil
}
// 达梦特殊数据类型
if dmStruct, ok := data.(*dm.DmStruct); ok {
return ParseDmStruct(dmStruct)
}
return dd.DefaultColumnHelper.ConvertScanDestValue(data, qc)
}
func ParseDmStruct(dmStruct *dm.DmStruct) string {
if !dmStruct.Valid {
return ""
}
name, _ := dmStruct.GetSQLTypeName()
attributes, _ := dmStruct.GetAttributes()
arr := make([]string, len(attributes))
arr = append(arr, name, "(")
for i, v := range attributes {
if blb, ok1 := v.(*dm.DmBlob); ok1 {
if blb.Valid {
length, _ := blb.GetLength()
var dest = make([]byte, length)
_, _ = blb.Read(dest)
// 2进制转16进制字符串
hexStr := hex.EncodeToString(dest)
arr = append(arr, "0x", strings.ToUpper(hexStr))
}
} else {
arr = append(arr, anyx.ToString(v))
}
if i < len(attributes)-1 {
arr = append(arr, ",")
}
}
arr = append(arr, ")")
return strings.Join(arr, "")
}
type DumpHelper struct {
}

View File

@@ -4,14 +4,19 @@ import (
"database/sql"
"fmt"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/utils/collx"
"net/url"
"strings"
)
func init() {
dbi.Register(dbi.DbTypeDM, new(Meta))
dbi.Register(DbTypeDM, new(Meta))
}
const (
DbTypeDM dbi.DbType = "dm"
)
type Meta struct {
}
@@ -49,3 +54,17 @@ func (dm *Meta) GetMetadata(conn *dbi.DbConn) dbi.Metadata {
dc: conn,
}
}
func (sm *Meta) GetDbDataTypes() []*dbi.DbDataType {
return collx.AsArray[*dbi.DbDataType](CHAR, VARCHAR, TEXT, LONG, LONGVARCHAR, IMAGE, LONGVARBINARY, CLOB,
BLOB,
NUMERIC, DECIMAL, NUMBER, INTEGER, INT, BIGINT, TINYINT, BYTE, SMALLINT, BIT, DOUBLE, FLOAT,
TIME, DATE, TIMESTAMP,
ST_CURVE, ST_LINESTRING, ST_GEOMCOLLECTION, ST_GEOMETRY, ST_MULTICURVE, ST_MULTILINESTRING,
ST_MULTIPOINT, ST_MULTIPOLYGON, ST_MULTISURFACE, ST_POINT, ST_POLYGON, ST_SURFACE,
)
}
func (mm *Meta) GetCommonTypeConverter() dbi.CommonTypeConverter {
return &commonTypeConverter{}
}

View File

@@ -4,7 +4,6 @@ import (
"fmt"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/errorx"
"mayfly-go/pkg/logx"
"mayfly-go/pkg/utils/anyx"
"mayfly-go/pkg/utils/collx"
"mayfly-go/pkg/utils/stringx"
@@ -55,7 +54,7 @@ func (dd *DMMetadata) GetDbNames() ([]string, error) {
func (dd *DMMetadata) GetTables(tableNames ...string) ([]dbi.Table, error) {
dialect := dd.dc.GetDialect()
names := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string {
return fmt.Sprintf("'%s'", dialect.RemoveQuote(val))
return fmt.Sprintf("'%s'", dialect.Quoter().Trim(val))
}), ",")
var res []map[string]any
@@ -89,7 +88,7 @@ func (dd *DMMetadata) GetTables(tableNames ...string) ([]dbi.Table, error) {
func (dd *DMMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error) {
dialect := dd.dc.GetDialect()
tableName := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string {
return fmt.Sprintf("'%s'", dialect.RemoveQuote(val))
return fmt.Sprintf("'%s'", dialect.Quoter().Trim(val))
}), ",")
_, res, err := dd.dc.Query(fmt.Sprintf(dbi.GetLocalSql(DM_META_FILE, DM_COLUMN_MA_KEY), tableName))
@@ -97,13 +96,12 @@ func (dd *DMMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error) {
return nil, err
}
columnHelper := dd.dc.GetDialect().GetColumnHelper()
columns := make([]dbi.Column, 0)
for _, re := range res {
column := dbi.Column{
TableName: cast.ToString(re["TABLE_NAME"]),
ColumnName: cast.ToString(re["COLUMN_NAME"]),
DataType: dbi.ColumnDataType(anyx.ToString(re["DATA_TYPE"])),
DataType: anyx.ToString(re["DATA_TYPE"]),
CharMaxLength: cast.ToInt(re["CHAR_MAX_LENGTH"]),
ColumnComment: cast.ToString(re["COLUMN_COMMENT"]),
Nullable: cast.ToString(re["NULLABLE"]) == "YES",
@@ -113,7 +111,7 @@ func (dd *DMMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error) {
NumPrecision: cast.ToInt(re["NUM_PRECISION"]),
NumScale: cast.ToInt(re["NUM_SCALE"]),
}
columnHelper.FixColumn(&column)
dd.dc.GetDbDataType(column.DataType).FixColumn(&column)
columns = append(columns, column)
}
return columns, nil
@@ -176,34 +174,7 @@ func (dd *DMMetadata) GetTableIndex(tableName string) ([]dbi.Index, error) {
// 获取建表ddl
func (dd *DMMetadata) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) {
// 1.获取表信息
tbs, err := dd.GetTables(tableName)
tableInfo := &dbi.Table{}
if err != nil || tbs == nil || len(tbs) <= 0 {
logx.Errorf("获取表信息失败, %s", tableName)
return "", err
}
tableInfo.TableName = tbs[0].TableName
tableInfo.TableComment = tbs[0].TableComment
// 2.获取列信息
columns, err := dd.GetColumns(tableName)
if err != nil {
logx.Errorf("获取列信息失败, %s", tableName)
return "", err
}
dialect := dd.dc.GetDialect()
tableDDLArr := dialect.GenerateTableDDL(columns, *tableInfo, dropBeforeCreate)
// 3.获取索引信息
indexs, err := dd.GetTableIndex(tableName)
if err != nil {
logx.Errorf("获取索引信息失败, %s", tableName)
return "", err
}
// 组装返回
tableDDLArr = append(tableDDLArr, dialect.GenerateIndexDDL(indexs, *tableInfo)...)
return strings.Join(tableDDLArr, ";\n"), nil
return dbi.GenTableDDL(dd.dc.GetDialect(), dd, tableName, dropBeforeCreate)
}
// 获取DM当前连接的库可访问的schemaNames

View File

@@ -0,0 +1,195 @@
package dm
import (
"fmt"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/utils/collx"
"strings"
)
type SQLGenerator struct {
Dialect dbi.Dialect
Metadata dbi.Metadata
}
func (sg *SQLGenerator) GenTableDDL(table dbi.Table, columns []dbi.Column, dropBeforeCreate bool) []string {
quoter := sg.Dialect.Quoter()
quote := quoter.Quote
tbName := quote(table.TableName)
sqlArr := make([]string, 0)
if dropBeforeCreate {
sqlArr = append(sqlArr, fmt.Sprintf("drop table if exists %s", tbName))
}
// 组装建表语句
createSql := fmt.Sprintf("create table %s (", tbName)
fields := make([]string, 0)
pks := make([]string, 0)
columnComments := make([]string, 0)
for _, column := range columns {
if column.IsPrimaryKey {
pks = append(pks, quote(column.ColumnName))
}
fields = append(fields, sg.genColumnBasicSql(quoter, column))
if column.ColumnComment != "" {
comment := dbi.QuoteEscape(column.ColumnComment)
columnComments = append(columnComments, fmt.Sprintf("comment on column %s.%s is '%s'", tbName, quote(column.ColumnName), comment))
}
}
createSql += strings.Join(fields, ",\n")
if len(pks) > 0 {
createSql += fmt.Sprintf(",\n PRIMARY KEY (%s)", strings.Join(pks, ","))
}
createSql += "\n)"
tableCommentSql := ""
if table.TableComment != "" {
comment := dbi.QuoteEscape(table.TableComment)
tableCommentSql = fmt.Sprintf("comment on table %s is '%s'", tbName, comment)
}
sqlArr = append(sqlArr, createSql)
if tableCommentSql != "" {
sqlArr = append(sqlArr, tableCommentSql)
}
if len(columnComments) > 0 {
sqlArr = append(sqlArr, columnComments...)
}
return sqlArr
}
func (sg *SQLGenerator) GenIndexDDL(table dbi.Table, indexs []dbi.Index) []string {
quote := sg.Dialect.Quoter().Quote
sqls := make([]string, 0)
for _, index := range indexs {
unique := ""
if index.IsUnique {
unique = "unique"
}
// 取出列名,添加引号
cols := strings.Split(index.ColumnName, ",")
colNames := make([]string, len(cols))
for i, name := range cols {
colNames[i] = quote(name)
}
sqls = append(sqls, fmt.Sprintf("create %s index %s on %s(%s)", unique, quote(index.IndexName), quote(table.TableName), strings.Join(colNames, ",")))
}
return sqls
}
func (sg *SQLGenerator) GenInsert(tableName string, columns []dbi.Column, values [][]any, duplicateStrategy int) []string {
quoter := sg.Dialect.Quoter()
quote := quoter.Quote
if duplicateStrategy == dbi.DuplicateStrategyNone {
identityInsert := fmt.Sprintf("set identity_insert %s on;", quote(tableName))
// 达梦数据库只能一条条的执行insert语句所以这里需要将values拆分成多条insert语句
return collx.ArrayMap(values, func(value []any) string {
columnStr, valuesStrs := dbi.GenInsertSqlColumnAndValues(sg.Dialect, DbTypeDM, columns, [][]any{value})
return fmt.Sprintf("%s insert into %s (%s) values %s", identityInsert, quote(tableName), columnStr, strings.Join(valuesStrs, ",\n"))
})
}
// 查询主键字段
uniqueCols := make([]string, 0)
caseSqls := make([]string, 0)
metadata := sg.Metadata
tableCols, _ := metadata.GetColumns(tableName)
identityCols := make([]string, 0)
for _, col := range tableCols {
if col.IsPrimaryKey {
uniqueCols = append(uniqueCols, col.ColumnName)
caseSqls = append(caseSqls, fmt.Sprintf("( T1.%s = T2.%s )", quote(col.ColumnName), quote(col.ColumnName)))
}
if col.IsIdentity {
// 自增字段不放入insert内即使是设置了identity_insert on也不起作用
identityCols = append(identityCols, quote(col.ColumnName))
}
}
// 查询唯一索引涉及到的字段并组装到match条件内
indexs, _ := metadata.GetTableIndex(tableName)
for _, index := range indexs {
if index.IsUnique {
cols := strings.Split(index.ColumnName, ",")
tmp := make([]string, 0)
for _, col := range cols {
uniqueCols = append(uniqueCols, col)
tmp = append(tmp, fmt.Sprintf(" T1.%s = T2.%s ", quote(col), quote(col)))
}
caseSqls = append(caseSqls, fmt.Sprintf("( %s )", strings.Join(tmp, " AND ")))
}
}
// 重复数据处理策略
phs := make([]string, 0)
insertVals := make([]string, 0)
upds := make([]string, 0)
insertCols := make([]string, 0)
for _, column := range columns {
columnName := column.ColumnName
phs = append(phs, fmt.Sprintf("? %s", columnName))
if !collx.ArrayContains(uniqueCols, quoter.Trim(columnName)) {
upds = append(upds, fmt.Sprintf("T1.%s = T2.%s", columnName, columnName))
}
if !collx.ArrayContains(identityCols, columnName) {
insertCols = append(insertCols, columnName)
insertVals = append(insertVals, fmt.Sprintf("T2.%s", columnName))
}
}
t2s := make([]string, 0)
for i := 0; i < len(values); i++ {
t2s = append(t2s, fmt.Sprintf("SELECT %s FROM dual", strings.Join(phs, ",")))
}
t2 := strings.Join(t2s, " UNION ALL ")
sqlTemp := "MERGE INTO " + quote(tableName) + " T1 USING (" + t2 + ") T2 ON " + strings.Join(caseSqls, " OR ")
sqlTemp += "WHEN NOT MATCHED THEN INSERT (" + strings.Join(insertCols, ",") + ") VALUES (" + strings.Join(insertVals, ",") + ")"
sqlTemp += "WHEN MATCHED THEN UPDATE SET " + strings.Join(upds, ",")
return collx.AsArray(sqlTemp)
}
func (msg *SQLGenerator) genColumnBasicSql(quoter dbi.Quoter, column dbi.Column) string {
colName := quoter.Quote(column.ColumnName)
dataType := string(column.DataType)
incr := ""
if column.IsIdentity {
incr = " IDENTITY"
}
nullAble := ""
if !column.Nullable {
nullAble = " NOT NULL"
}
defVal := "" // 默认值需要判断引号,如函数是不需要引号的 // 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值
if column.ColumnDefault != "" && !strings.Contains(column.ColumnDefault, "(") {
// 哪些字段类型默认值需要加引号
mark := false
if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, strings.ToLower(dataType)) {
// 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号
if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) &&
collx.ArrayAnyMatches([]string{"DATE", "TIME"}, strings.ToUpper(column.ColumnDefault)) {
mark = false
} else {
mark = true
}
}
if mark {
defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault)
} else {
defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault)
}
}
columnSql := fmt.Sprintf(" %s %s%s%s%s", colName, column.GetColumnType(), incr, nullAble, defVal)
return columnSql
}

View File

@@ -0,0 +1,97 @@
package dm
import "mayfly-go/internal/db/dbm/dbi"
var _ dbi.CommonTypeConverter = (*commonTypeConverter)(nil)
type commonTypeConverter struct {
}
func (c *commonTypeConverter) Varchar(col *dbi.Column) *dbi.DbDataType {
return VARCHAR
}
func (c *commonTypeConverter) Char(col *dbi.Column) *dbi.DbDataType {
return CHAR
}
func (c *commonTypeConverter) Text(col *dbi.Column) *dbi.DbDataType {
return TEXT
}
func (c *commonTypeConverter) Mediumtext(col *dbi.Column) *dbi.DbDataType {
return TEXT
}
func (c *commonTypeConverter) Longtext(col *dbi.Column) *dbi.DbDataType {
return LONGVARCHAR
}
func (c *commonTypeConverter) Bit(col *dbi.Column) *dbi.DbDataType {
return BIT
}
func (c *commonTypeConverter) Int1(col *dbi.Column) *dbi.DbDataType {
return TINYINT
}
func (c *commonTypeConverter) Int2(col *dbi.Column) *dbi.DbDataType {
return SMALLINT
}
func (c *commonTypeConverter) Int4(col *dbi.Column) *dbi.DbDataType {
return INTEGER
}
func (c *commonTypeConverter) Int8(col *dbi.Column) *dbi.DbDataType {
return BIGINT
}
func (c *commonTypeConverter) Numeric(col *dbi.Column) *dbi.DbDataType {
return NUMBER
}
func (c *commonTypeConverter) Decimal(col *dbi.Column) *dbi.DbDataType {
return DECIMAL
}
func (c *commonTypeConverter) UnsignedInt8(col *dbi.Column) *dbi.DbDataType {
return BIGINT
}
func (c *commonTypeConverter) UnsignedInt4(col *dbi.Column) *dbi.DbDataType {
return INT
}
func (c *commonTypeConverter) UnsignedInt2(col *dbi.Column) *dbi.DbDataType {
return INT
}
func (c *commonTypeConverter) UnsignedInt1(col *dbi.Column) *dbi.DbDataType {
return INT
}
func (c *commonTypeConverter) Date(col *dbi.Column) *dbi.DbDataType {
return DATE
}
func (c *commonTypeConverter) Time(col *dbi.Column) *dbi.DbDataType {
return TIME
}
func (c *commonTypeConverter) Datetime(col *dbi.Column) *dbi.DbDataType {
return TIMESTAMP
}
func (c *commonTypeConverter) Timestamp(col *dbi.Column) *dbi.DbDataType {
return TIMESTAMP
}
func (c *commonTypeConverter) Binary(col *dbi.Column) *dbi.DbDataType {
return BLOB
}
func (c *commonTypeConverter) Varbinary(col *dbi.Column) *dbi.DbDataType {
return BLOB
}
func (c *commonTypeConverter) Mediumblob(col *dbi.Column) *dbi.DbDataType {
return BLOB
}
func (c *commonTypeConverter) Blob(col *dbi.Column) *dbi.DbDataType {
return BLOB
}
func (c *commonTypeConverter) Longblob(col *dbi.Column) *dbi.DbDataType {
return BLOB
}
func (c *commonTypeConverter) Enum(col *dbi.Column) *dbi.DbDataType {
return VARCHAR
}
func (c *commonTypeConverter) JSON(col *dbi.Column) *dbi.DbDataType {
return VARCHAR
}

View File

@@ -0,0 +1,50 @@
package mssql
import (
"fmt"
"mayfly-go/internal/db/dbm/dbi"
)
var (
DTOracleDate = dbi.DTDateTime.Copy().WithSQLValue(func(val any) string {
// oracle date型需要用函数包裹to_date('%s', 'yyyy-mm-dd hh24:mi:ss')
return fmt.Sprintf("to_date('%s', 'yyyy-mm-dd hh24:mi:ss')", val)
})
)
var (
Bigint = dbi.NewDbDataType("bigint", dbi.DTInt64).WithCT(dbi.CTInt8)
Numeric = dbi.NewDbDataType("numeric", dbi.DTNumeric).WithCT(dbi.CTNumeric)
Bit = dbi.NewDbDataType("bit", dbi.DTBit).WithCT(dbi.CTBit)
Smallint = dbi.NewDbDataType("smallint", dbi.DTInt16).WithCT(dbi.CTInt2)
Decimal = dbi.NewDbDataType("decimal", dbi.DTDecimal).WithCT(dbi.CTDecimal)
Smallmoney = dbi.NewDbDataType("smallmoney", dbi.DTDecimal).WithCT(dbi.CTDecimal)
Int = dbi.NewDbDataType("int", dbi.DTInt32).WithCT(dbi.CTInt4)
Tinyint = dbi.NewDbDataType("tinyint", dbi.DTInt8).WithCT(dbi.CTInt1)
Money = dbi.NewDbDataType("money", dbi.DTDecimal).WithCT(dbi.CTDecimal)
Float = dbi.NewDbDataType("float", dbi.DTNumeric).WithCT(dbi.CTNumeric)
Real = dbi.NewDbDataType("real", dbi.DTString).WithCT(dbi.CTVarchar)
Date = dbi.NewDbDataType("date", dbi.DTDate).WithCT(dbi.CTDate)
Datetimeoffset = dbi.NewDbDataType("datetimeoffset", dbi.DTDateTime).WithCT(dbi.CTDateTime)
Datetime2 = dbi.NewDbDataType("datetime2", dbi.DTDateTime).WithCT(dbi.CTDateTime)
Smalldatetime = dbi.NewDbDataType("smalldatetime", dbi.DTDateTime).WithCT(dbi.CTDateTime)
Datetime = dbi.NewDbDataType("datetime", dbi.DTDateTime).WithCT(dbi.CTDateTime)
Time = dbi.NewDbDataType("time", dbi.DTTime).WithCT(dbi.CTTime)
Char = dbi.NewDbDataType("char", dbi.DTString).WithCT(dbi.CTVarchar)
Varchar = dbi.NewDbDataType("varchar", dbi.DTString).WithCT(dbi.CTVarchar)
Text = dbi.NewDbDataType("text", dbi.DTString).WithCT(dbi.CTVarchar)
Nchar = dbi.NewDbDataType("nchar", dbi.DTString).WithCT(dbi.CTVarchar)
Nvarchar = dbi.NewDbDataType("nvarchar", dbi.DTString).WithCT(dbi.CTVarchar)
Ntext = dbi.NewDbDataType("ntext", dbi.DTString).WithCT(dbi.CTVarchar)
Binary = dbi.NewDbDataType("binary", dbi.DTBytes).WithCT(dbi.CTBinary)
Varbinary = dbi.NewDbDataType("varbinary", dbi.DTBytes).WithCT(dbi.CTBinary)
Cursor = dbi.NewDbDataType("cursor", dbi.DTString).WithCT(dbi.CTVarchar)
Rowversion = dbi.NewDbDataType("rowversion", dbi.DTBytes).WithCT(dbi.CTBinary)
Hierarchyid = dbi.NewDbDataType("hierarchyid", dbi.DTString).WithCT(dbi.CTVarchar)
Uniqueidentifier = dbi.NewDbDataType("uniqueidentifier", dbi.DTString).WithCT(dbi.CTVarchar)
Sql_variant = dbi.NewDbDataType("sql_variant", dbi.DTString).WithCT(dbi.CTVarchar)
Xml = dbi.NewDbDataType("xml", dbi.DTString).WithCT(dbi.CTVarchar)
Table = dbi.NewDbDataType("table", dbi.DTString).WithCT(dbi.CTVarchar)
Geometry = dbi.NewDbDataType("geometry", dbi.DTString).WithCT(dbi.CTVarchar)
Geography = dbi.NewDbDataType("geography", dbi.DTString).WithCT(dbi.CTVarchar)
)

View File

@@ -10,6 +10,14 @@ import (
"time"
)
var (
mssqlQuoter = dbi.Quoter{
Prefix: '[',
Suffix: ']',
IsReserved: dbi.AlwaysReserve,
}
)
type MssqlDialect struct {
dbi.DefaultDialect
@@ -99,7 +107,8 @@ func (md *MssqlDialect) batchInsertSimple(tx *sql.Tx, tableName string, columns
// 去除最后一个逗号
placeholder = strings.TrimSuffix(repeated, ",")
baseTable := fmt.Sprintf("%s.%s", md.QuoteIdentifier(schema), md.QuoteIdentifier(tableName))
quote := md.Quoter().Quote
baseTable := fmt.Sprintf("%s.%s", quote(schema), quote(tableName))
sqlStr := fmt.Sprintf("insert into %s (%s) values %s", baseTable, strings.Join(columns, ","), placeholder)
// 执行批量insert sql
@@ -119,6 +128,7 @@ func (md *MssqlDialect) batchInsertSimple(tx *sql.Tx, tableName string, columns
func (md *MssqlDialect) batchInsertMerge(tx *sql.Tx, tableName string, columns []string, values [][]any, duplicateStrategy int) (int64, error) {
msMetadata := md.dc.GetMetadata()
schema := md.dc.Info.CurrentSchema()
quote := md.Quoter().Quote
// 收集MERGE 语句的 ON 子句条件
caseSqls := make([]string, 0)
@@ -136,7 +146,7 @@ func (md *MssqlDialect) batchInsertMerge(tx *sql.Tx, tableName string, columns [
}
if col.IsPrimaryKey {
pkCols = append(pkCols, col.ColumnName)
name := md.QuoteIdentifier(col.ColumnName)
name := quote(col.ColumnName)
caseSqls = append(caseSqls, fmt.Sprintf(" T1.%s = T2.%s ", name, name))
}
}
@@ -150,7 +160,7 @@ func (md *MssqlDialect) batchInsertMerge(tx *sql.Tx, tableName string, columns [
// 源数据占位sql
phs := make([]string, 0)
for _, column := range columns {
if !collx.ArrayContains(identityCols, md.RemoveQuote(column)) {
if !collx.ArrayContains(identityCols, md.Quoter().Trim(column)) {
upds = append(upds, fmt.Sprintf("T1.%s = T2.%s", column, column))
}
insertCols = append(insertCols, column)
@@ -168,7 +178,7 @@ func (md *MssqlDialect) batchInsertMerge(tx *sql.Tx, tableName string, columns [
}
t2 := strings.Join(t2s, " UNION ALL ")
sqlTemp := "MERGE INTO " + md.QuoteIdentifier(schema) + "." + md.QuoteIdentifier(tableName) + " T1 USING (" + t2 + ") T2 ON " + strings.Join(caseSqls, " AND ")
sqlTemp := "MERGE INTO " + quote(schema) + "." + quote(tableName) + " T1 USING (" + t2 + ") T2 ON " + strings.Join(caseSqls, " AND ")
sqlTemp += "WHEN NOT MATCHED THEN INSERT (" + strings.Join(insertCols, ",") + ") VALUES (" + strings.Join(insertVals, ",") + ") "
sqlTemp += "WHEN MATCHED THEN UPDATE SET " + strings.Join(upds, ",")
@@ -247,7 +257,7 @@ func (md *MssqlDialect) CopyTableDDL(tableName string, newTableName string) (str
// 查询表名和表注释, 设置表注释
tbs, err := metadata.GetTables(tableName)
if err != nil || len(tbs) < 1 {
logx.Errorf("获取表信息失败, %s", tableName)
logx.Errorf("failed to get table, %s", tableName)
return "", err
}
tabInfo := &dbi.Table{
@@ -258,165 +268,31 @@ func (md *MssqlDialect) CopyTableDDL(tableName string, newTableName string) (str
// 查询列信息
columns, err := metadata.GetColumns(tableName)
if err != nil {
logx.Errorf("获取列信息失败, %s", tableName)
logx.Errorf("failed to get columns, %s", tableName)
return "", err
}
sqlArr := md.GenerateTableDDL(columns, *tabInfo, true)
sqlGener := md.GetSQLGenerator()
sqlArr := sqlGener.GenTableDDL(*tabInfo, columns, true)
// 设置索引
indexs, err := metadata.GetTableIndex(tableName)
if err != nil {
logx.Errorf("获取索引信息失败, %s", tableName)
logx.Errorf("failed to get indexs, %s", tableName)
return strings.Join(sqlArr, ";"), err
}
sqlArr = append(sqlArr, md.GenerateIndexDDL(indexs, *tabInfo)...)
sqlArr = append(sqlArr, sqlGener.GenIndexDDL(*tabInfo, indexs)...)
return strings.Join(sqlArr, ";"), nil
}
// 获取建表ddl
func (md *MssqlDialect) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table, dropBeforeCreate bool) []string {
tbName := tableInfo.TableName
schemaName := md.dc.Info.CurrentSchema()
sqlArr := make([]string, 0)
// 删除表
if dropBeforeCreate {
sqlArr = append(sqlArr, fmt.Sprintf("DROP TABLE IF EXISTS %s.%s", md.QuoteIdentifier(schemaName), md.QuoteIdentifier(tbName)))
}
// 组装建表语句
createSql := fmt.Sprintf("CREATE TABLE %s.%s (\n", md.QuoteIdentifier(schemaName), md.QuoteIdentifier(tbName))
fields := make([]string, 0)
pks := make([]string, 0)
columnComments := make([]string, 0)
for _, column := range columns {
if column.IsPrimaryKey {
pks = append(pks, md.QuoteIdentifier(column.ColumnName))
}
fields = append(fields, md.genColumnBasicSql(column))
commentTmp := "EXECUTE sp_addextendedproperty N'MS_Description', N'%s', N'SCHEMA', N'%s', N'TABLE', N'%s', N'COLUMN', N'%s'"
// 防止注释内含有特殊字符串导致sql出错
if column.ColumnComment != "" {
comment := md.QuoteEscape(column.ColumnComment)
columnComments = append(columnComments, fmt.Sprintf(commentTmp, comment, md.dc.Info.CurrentSchema(), tbName, column.ColumnName))
}
}
// create
createSql += strings.Join(fields, ",\n")
if len(pks) > 0 {
createSql += fmt.Sprintf(", \n PRIMARY KEY CLUSTERED (%s)", strings.Join(pks, ","))
}
createSql += "\n)"
// comment
tableCommentSql := ""
if tableInfo.TableComment != "" {
commentTmp := "EXECUTE sp_addextendedproperty N'MS_Description', N'%s', N'SCHEMA', N'%s', N'TABLE', N'%s'"
tableCommentSql = fmt.Sprintf(commentTmp, md.QuoteEscape(tableInfo.TableComment), md.dc.Info.CurrentSchema(), tbName)
}
sqlArr = append(sqlArr, createSql)
if tableCommentSql != "" {
sqlArr = append(sqlArr, tableCommentSql)
}
if len(columnComments) > 0 {
sqlArr = append(sqlArr, columnComments...)
}
return sqlArr
}
// 获取建索引ddl
func (md *MssqlDialect) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Table) []string {
tbName := tableInfo.TableName
sqls := make([]string, 0)
comments := make([]string, 0)
for _, index := range indexs {
unique := ""
if index.IsUnique {
unique = "unique"
}
// 取出列名,添加引号
cols := strings.Split(index.ColumnName, ",")
colNames := make([]string, len(cols))
for i, name := range cols {
colNames[i] = md.QuoteIdentifier(name)
}
sqls = append(sqls, fmt.Sprintf("create %s NONCLUSTERED index %s on %s.%s(%s)", unique, md.QuoteIdentifier(index.IndexName), md.QuoteIdentifier(md.dc.Info.CurrentSchema()), md.QuoteIdentifier(tbName), strings.Join(colNames, ",")))
if index.IndexComment != "" {
comment := md.QuoteEscape(index.IndexComment)
comments = append(comments, fmt.Sprintf("EXECUTE sp_addextendedproperty N'MS_Description', N'%s', N'SCHEMA', N'%s', N'TABLE', N'%s', N'INDEX', N'%s'", comment, md.dc.Info.CurrentSchema(), tbName, index.IndexName))
}
}
if len(comments) > 0 {
sqls = append(sqls, comments...)
}
return sqls
}
func (dx *MssqlDialect) QuoteIdentifier(name string) string {
return fmt.Sprintf("[%s]", name)
}
func (dx *MssqlDialect) RemoveQuote(name string) string {
return strings.Trim(name, "[]")
}
func (md *MssqlDialect) GetDataHelper() dbi.DataHelper {
return dataHelper
}
func (md *MssqlDialect) GetColumnHelper() dbi.ColumnHelper {
return columnHelper
func (dx *MssqlDialect) Quoter() dbi.Quoter {
return mssqlQuoter
}
func (md *MssqlDialect) GetDumpHelper() dbi.DumpHelper {
return new(DumpHelper)
}
func (md *MssqlDialect) genColumnBasicSql(column dbi.Column) string {
colName := md.QuoteIdentifier(column.ColumnName)
dataType := string(column.DataType)
incr := ""
if column.IsIdentity {
incr = " IDENTITY(1,1)"
}
nullAble := ""
if !column.Nullable {
nullAble = " NOT NULL"
}
defVal := "" // 默认值需要判断引号,如函数是不需要引号的 // 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值
if column.ColumnDefault != "" && !strings.Contains(column.ColumnDefault, "(") {
// 哪些字段类型默认值需要加引号
mark := false
if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, dataType) {
// 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号
if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) &&
collx.ArrayAnyMatches([]string{"DATE", "TIME"}, strings.ToUpper(column.ColumnDefault)) {
mark = false
} else {
mark = true
}
}
if mark {
defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault)
} else {
defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault)
}
}
columnSql := fmt.Sprintf(" %s %s%s%s%s", colName, column.GetColumnType(), incr, nullAble, defVal)
return columnSql
func (md *MssqlDialect) GetSQLGenerator() dbi.SQLGenerator {
return &SQLGenerator{dc: md.dc}
}

View File

@@ -4,218 +4,15 @@ import (
"fmt"
"io"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/utils/anyx"
"mayfly-go/pkg/utils/collx"
"regexp"
"strings"
"time"
)
var (
// 数字类型
numberRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`)
// 日期时间类型
datetimeRegexp = regexp.MustCompile(`(?i)datetime|timestamp`)
// 日期类型
dateRegexp = regexp.MustCompile(`(?i)date`)
// 时间类型
timeRegexp = regexp.MustCompile(`(?i)time`)
// mssql数据类型 对应 公共数据类型
commonColumnTypeMap = map[string]dbi.ColumnDataType{
"bigint": dbi.CommonTypeBigint,
"numeric": dbi.CommonTypeNumber,
"bit": dbi.CommonTypeInt,
"smallint": dbi.CommonTypeSmallint,
"decimal": dbi.CommonTypeNumber,
"smallmoney": dbi.CommonTypeNumber,
"int": dbi.CommonTypeInt,
"tinyint": dbi.CommonTypeSmallint, // mssql tinyint不支持负数
"money": dbi.CommonTypeNumber,
"float": dbi.CommonTypeNumber, // 近似数字
"real": dbi.CommonTypeVarchar,
"date": dbi.CommonTypeDate, // 日期和时间
"datetimeoffset": dbi.CommonTypeDatetime,
"datetime2": dbi.CommonTypeDatetime,
"smalldatetime": dbi.CommonTypeDatetime,
"datetime": dbi.CommonTypeDatetime,
"time": dbi.CommonTypeTime,
"char": dbi.CommonTypeChar, // 字符串
"varchar": dbi.CommonTypeVarchar,
"text": dbi.CommonTypeText,
"nchar": dbi.CommonTypeChar,
"nvarchar": dbi.CommonTypeVarchar,
"ntext": dbi.CommonTypeText,
"binary": dbi.CommonTypeBinary,
"varbinary": dbi.CommonTypeBinary,
"cursor": dbi.CommonTypeVarchar, // 其他
"rowversion": dbi.CommonTypeVarchar,
"hierarchyid": dbi.CommonTypeVarchar,
"uniqueidentifier": dbi.CommonTypeVarchar,
"sql_variant": dbi.CommonTypeVarchar,
"xml": dbi.CommonTypeText,
"table": dbi.CommonTypeText,
"geometry": dbi.CommonTypeText, // 空间几何类型
"geography": dbi.CommonTypeText, // 空间地理类型
}
// 公共数据类型 对应 mssql数据类型
mssqlColumnTypeMap = map[dbi.ColumnDataType]string{
dbi.CommonTypeVarchar: "nvarchar",
dbi.CommonTypeChar: "nchar",
dbi.CommonTypeText: "ntext",
dbi.CommonTypeBlob: "ntext",
dbi.CommonTypeLongblob: "ntext",
dbi.CommonTypeLongtext: "ntext",
dbi.CommonTypeBinary: "varbinary",
dbi.CommonTypeMediumblob: "ntext",
dbi.CommonTypeMediumtext: "ntext",
dbi.CommonTypeVarbinary: "varbinary",
dbi.CommonTypeInt: "int",
dbi.CommonTypeSmallint: "smallint",
dbi.CommonTypeTinyint: "smallint",
dbi.CommonTypeNumber: "decimal",
dbi.CommonTypeBigint: "bigint",
dbi.CommonTypeDatetime: "datetime2",
dbi.CommonTypeDate: "date",
dbi.CommonTypeTime: "time",
dbi.CommonTypeTimestamp: "timestamp",
dbi.CommonTypeEnum: "nvarchar",
dbi.CommonTypeJSON: "nvarchar",
}
dataHelper = &DataHelper{}
columnHelper = &ColumnHelper{}
)
func GetDataHelper() *DataHelper {
return dataHelper
}
type DataHelper struct {
}
func (dc *DataHelper) GetDataType(dbColumnType string) dbi.DataType {
if numberRegexp.MatchString(dbColumnType) {
return dbi.DataTypeNumber
}
// 日期时间类型
if datetimeRegexp.MatchString(dbColumnType) {
return dbi.DataTypeDateTime
}
// 日期类型
if dateRegexp.MatchString(dbColumnType) {
return dbi.DataTypeDate
}
// 时间类型
if timeRegexp.MatchString(dbColumnType) {
return dbi.DataTypeTime
}
return dbi.DataTypeString
}
func (dc *DataHelper) FormatData(dbColumnValue any, dataType dbi.DataType) string {
// 如果dataType是datetime而dbColumnValue是string类型则需要根据类型格式化
str, ok := dbColumnValue.(string)
if dataType == dbi.DataTypeDateTime && ok {
// 尝试用时间格式解析
res, err := time.Parse(time.DateTime, str)
if err == nil {
return str
}
res, _ = time.Parse(time.RFC3339, str)
return res.Format(time.DateTime)
}
if dataType == dbi.DataTypeDate && ok {
// 尝试用时间格式解析
res, _ := time.Parse(time.DateOnly, str)
return res.Format(time.DateOnly)
}
if dataType == dbi.DataTypeTime && ok {
res, _ := time.Parse(time.TimeOnly, str)
return res.Format(time.TimeOnly)
}
return anyx.ToString(dbColumnValue)
}
func (dc *DataHelper) ParseData(dbColumnValue any, dataType dbi.DataType) any {
// 如果dataType是datetime而dbColumnValue是string类型则需要转换为time.Time类型
_, ok := dbColumnValue.(string)
if dataType == dbi.DataTypeDateTime && ok {
res, _ := time.Parse(time.RFC3339, anyx.ToString(dbColumnValue))
return res
}
if dataType == dbi.DataTypeDate && ok {
res, _ := time.Parse(time.DateOnly, anyx.ToString(dbColumnValue))
return res
}
if dataType == dbi.DataTypeTime && ok {
res, _ := time.Parse(time.TimeOnly, anyx.ToString(dbColumnValue))
return res
}
return dbColumnValue
}
func (dc *DataHelper) WrapValue(dbColumnValue any, dataType dbi.DataType) string {
if dbColumnValue == nil {
return "NULL"
}
switch dataType {
case dbi.DataTypeNumber:
return fmt.Sprintf("%v", dbColumnValue)
case dbi.DataTypeString:
val := fmt.Sprintf("%v", dbColumnValue)
// 转义单引号
val = strings.Replace(val, `'`, `''`, -1)
val = strings.Replace(val, `\''`, `\'`, -1)
// 转义换行符
val = strings.Replace(val, "\n", "\\n", -1)
return fmt.Sprintf("'%s'", val)
case dbi.DataTypeDate, dbi.DataTypeDateTime, dbi.DataTypeTime:
return fmt.Sprintf("'%s'", dc.FormatData(dbColumnValue, dataType))
}
return fmt.Sprintf("'%s'", dbColumnValue)
}
type ColumnHelper struct {
dbi.DefaultColumnHelper
}
func (ch *ColumnHelper) ToCommonColumn(dialectColumn *dbi.Column) {
// 翻译为通用数据库类型
dataType := dialectColumn.DataType
t1 := commonColumnTypeMap[string(dataType)]
if t1 == "" {
dialectColumn.DataType = dbi.CommonTypeVarchar
dialectColumn.CharMaxLength = 2000
} else {
dialectColumn.DataType = t1
}
}
func (ch *ColumnHelper) ToColumn(commonColumn *dbi.Column) {
ctype := mssqlColumnTypeMap[commonColumn.DataType]
if ctype == "" {
commonColumn.DataType = "varchar"
commonColumn.CharMaxLength = 2000
} else {
commonColumn.DataType = dbi.ColumnDataType(ctype)
ch.FixColumn(commonColumn)
// 修复数据库迁移字段长度
dataType := string(commonColumn.DataType)
if collx.ArrayAnyMatches([]string{"nvarchar", "nchar"}, dataType) {
commonColumn.CharMaxLength = commonColumn.CharMaxLength * 2
}
if collx.ArrayAnyMatches([]string{"char"}, dataType) {
// char最大长度4000
if commonColumn.CharMaxLength >= 4000 {
commonColumn.DataType = "ntext"
commonColumn.CharMaxLength = 0
}
}
}
}
func (ch *ColumnHelper) FixColumn(column *dbi.Column) {

View File

@@ -4,6 +4,7 @@ import (
"database/sql"
"fmt"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/utils/collx"
"net/url"
"strings"
@@ -12,9 +13,13 @@ import (
func init() {
meta := new(Meta)
dbi.Register(dbi.DbTypeMssql, meta)
dbi.Register(DbTypeMssql, meta)
}
const (
DbTypeMssql dbi.DbType = "mssql"
)
type Meta struct {
}
@@ -60,3 +65,45 @@ func (mm *Meta) GetDialect(conn *dbi.DbConn) dbi.Dialect {
func (mm *Meta) GetMetadata(conn *dbi.DbConn) dbi.Metadata {
return &MssqlMetadata{dc: conn}
}
func (sm *Meta) GetDbDataTypes() []*dbi.DbDataType {
return collx.AsArray[*dbi.DbDataType](Bigint,
Numeric,
Bit,
Smallint,
Decimal,
Smallmoney,
Int,
Tinyint,
Money,
Float,
Real,
Date,
Datetimeoffset,
Datetime2,
Smalldatetime,
Datetime,
Time,
Char,
Varchar,
Text,
Nchar,
Nvarchar,
Ntext,
Binary,
Varbinary,
Cursor,
Rowversion,
Hierarchyid,
Uniqueidentifier,
Sql_variant,
Xml,
Table,
Geometry,
Geography,
)
}
func (sm *Meta) GetCommonTypeConverter() dbi.CommonTypeConverter {
return &commonTypeConverter{}
}

View File

@@ -4,7 +4,6 @@ import (
"fmt"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/errorx"
"mayfly-go/pkg/logx"
"mayfly-go/pkg/utils/anyx"
"mayfly-go/pkg/utils/collx"
"mayfly-go/pkg/utils/stringx"
@@ -58,7 +57,7 @@ func (md *MssqlMetadata) GetTables(tableNames ...string) ([]dbi.Table, error) {
dialect := md.dc.GetDialect()
schema := md.dc.Info.CurrentSchema()
names := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string {
return fmt.Sprintf("'%s'", dialect.RemoveQuote(val))
return fmt.Sprintf("'%s'", dialect.Quoter().Trim(val))
}), ",")
var res []map[string]any
@@ -91,9 +90,8 @@ func (md *MssqlMetadata) GetTables(tableNames ...string) ([]dbi.Table, error) {
// 获取列元信息, 如列名等
func (md *MssqlMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error) {
dialect := md.dc.GetDialect()
columnHelper := dialect.GetColumnHelper()
tableName := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string {
return fmt.Sprintf("'%s'", dialect.RemoveQuote(val))
return fmt.Sprintf("'%s'", dialect.Quoter().Trim(val))
}), ",")
_, res, err := md.dc.Query(fmt.Sprintf(dbi.GetLocalSql(MSSQL_META_FILE, MSSQL_COLUMN_MA_KEY), tableName), md.dc.Info.CurrentSchema())
@@ -107,7 +105,7 @@ func (md *MssqlMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error)
column := dbi.Column{
TableName: anyx.ToString(re["TABLE_NAME"]),
ColumnName: anyx.ToString(re["COLUMN_NAME"]),
DataType: dbi.ColumnDataType(anyx.ToString(re["DATA_TYPE"])),
DataType: anyx.ToString(re["DATA_TYPE"]),
CharMaxLength: cast.ToInt(re["CHAR_MAX_LENGTH"]),
ColumnComment: anyx.ToString(re["COLUMN_COMMENT"]),
Nullable: anyx.ToString(re["NULLABLE"]) == "YES",
@@ -118,8 +116,7 @@ func (md *MssqlMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error)
NumScale: cast.ToInt(re["NUM_SCALE"]),
}
columnHelper.FixColumn(&column)
md.dc.GetDbDataType(column.DataType).FixColumn(&column)
columns = append(columns, column)
}
return columns, nil
@@ -197,33 +194,7 @@ func (md *MssqlMetadata) GetTableIndex(tableName string) ([]dbi.Index, error) {
// 获取建表ddl
func (md *MssqlMetadata) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) {
// 1.获取表信息
tbs, err := md.GetTables(tableName)
tableInfo := &dbi.Table{}
if err != nil || tbs == nil || len(tbs) <= 0 {
logx.Errorf("获取表信息失败, %s", tableName)
return "", err
}
tableInfo.TableName = tbs[0].TableName
tableInfo.TableComment = tbs[0].TableComment
// 2.获取列信息
columns, err := md.GetColumns(tableName)
if err != nil {
logx.Errorf("获取列信息失败, %s", tableName)
return "", err
}
dialect := md.dc.GetDialect()
tableDDLArr := dialect.GenerateTableDDL(columns, *tableInfo, dropBeforeCreate)
// 3.获取索引信息
indexs, err := md.GetTableIndex(tableName)
if err != nil {
logx.Errorf("获取索引信息失败, %s", tableName)
return "", err
}
// 组装返回
tableDDLArr = append(tableDDLArr, dialect.GenerateIndexDDL(indexs, *tableInfo)...)
return strings.Join(tableDDLArr, ";\n"), nil
return dbi.GenTableDDL(md.dc.GetDialect(), md, tableName, dropBeforeCreate)
}
func (md *MssqlMetadata) GetSchemas() ([]string, error) {

View File

@@ -0,0 +1,146 @@
package mssql
import (
"fmt"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/utils/collx"
"strings"
)
type SQLGenerator struct {
dc *dbi.DbConn
}
func (sg *SQLGenerator) GenTableDDL(table dbi.Table, columns []dbi.Column, dropBeforeCreate bool) []string {
tbName := table.TableName
schemaName := sg.dc.Info.CurrentSchema()
quoter := sg.dc.GetDialect().Quoter()
quote := quoter.Quote
sqlArr := make([]string, 0)
// 删除表
if dropBeforeCreate {
sqlArr = append(sqlArr, fmt.Sprintf("DROP TABLE IF EXISTS %s.%s", quote(schemaName), quote(tbName)))
}
// 组装建表语句
createSql := fmt.Sprintf("CREATE TABLE %s.%s (\n", quote(schemaName), quote(tbName))
fields := make([]string, 0)
pks := make([]string, 0)
columnComments := make([]string, 0)
for _, column := range columns {
if column.IsPrimaryKey {
pks = append(pks, quote(column.ColumnName))
}
fields = append(fields, sg.genColumnBasicSql(quoter, column))
commentTmp := "EXECUTE sp_addextendedproperty N'MS_Description', N'%s', N'SCHEMA', N'%s', N'TABLE', N'%s', N'COLUMN', N'%s'"
// 防止注释内含有特殊字符串导致sql出错
if column.ColumnComment != "" {
comment := dbi.QuoteEscape(column.ColumnComment)
columnComments = append(columnComments, fmt.Sprintf(commentTmp, comment, sg.dc.Info.CurrentSchema(), tbName, column.ColumnName))
}
}
// create
createSql += strings.Join(fields, ",\n")
if len(pks) > 0 {
createSql += fmt.Sprintf(", \n PRIMARY KEY CLUSTERED (%s)", strings.Join(pks, ","))
}
createSql += "\n)"
// comment
tableCommentSql := ""
if table.TableComment != "" {
commentTmp := "EXECUTE sp_addextendedproperty N'MS_Description', N'%s', N'SCHEMA', N'%s', N'TABLE', N'%s'"
tableCommentSql = fmt.Sprintf(commentTmp, dbi.QuoteEscape(table.TableComment), sg.dc.Info.CurrentSchema(), tbName)
}
sqlArr = append(sqlArr, createSql)
if tableCommentSql != "" {
sqlArr = append(sqlArr, tableCommentSql)
}
if len(columnComments) > 0 {
sqlArr = append(sqlArr, columnComments...)
}
return sqlArr
}
func (sg *SQLGenerator) GenIndexDDL(table dbi.Table, indexs []dbi.Index) []string {
quote := sg.dc.GetDialect().Quoter().Quote
tbName := table.TableName
sqls := make([]string, 0)
comments := make([]string, 0)
for _, index := range indexs {
unique := ""
if index.IsUnique {
unique = "unique"
}
// 取出列名,添加引号
cols := strings.Split(index.ColumnName, ",")
colNames := make([]string, len(cols))
for i, name := range cols {
colNames[i] = quote(name)
}
sqls = append(sqls, fmt.Sprintf("create %s NONCLUSTERED index %s on %s.%s(%s)", unique, quote(index.IndexName), quote(sg.dc.Info.CurrentSchema()), quote(tbName), strings.Join(colNames, ",")))
if index.IndexComment != "" {
comment := dbi.QuoteEscape(index.IndexComment)
comments = append(comments, fmt.Sprintf("EXECUTE sp_addextendedproperty N'MS_Description', N'%s', N'SCHEMA', N'%s', N'TABLE', N'%s', N'INDEX', N'%s'", comment, sg.dc.Info.CurrentSchema(), tbName, index.IndexName))
}
}
if len(comments) > 0 {
sqls = append(sqls, comments...)
}
return sqls
}
func (sg *SQLGenerator) GenInsert(tableName string, columns []dbi.Column, values [][]any, duplicateStrategy int) []string {
return collx.AsArray("")
}
func (msg *SQLGenerator) genColumnBasicSql(quoter dbi.Quoter, column dbi.Column) string {
colName := quoter.Quote(column.ColumnName)
dataType := string(column.DataType)
incr := ""
if column.IsIdentity {
incr = " IDENTITY(1,1)"
}
nullAble := ""
if !column.Nullable {
nullAble = " NOT NULL"
}
defVal := "" // 默认值需要判断引号,如函数是不需要引号的 // 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值
if column.ColumnDefault != "" && !strings.Contains(column.ColumnDefault, "(") {
// 哪些字段类型默认值需要加引号
mark := false
if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, dataType) {
// 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号
if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) &&
collx.ArrayAnyMatches([]string{"DATE", "TIME"}, strings.ToUpper(column.ColumnDefault)) {
mark = false
} else {
mark = true
}
}
if mark {
defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault)
} else {
defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault)
}
}
columnSql := fmt.Sprintf(" %s %s%s%s%s", colName, column.GetColumnType(), incr, nullAble, defVal)
return columnSql
}

View File

@@ -0,0 +1,97 @@
package mssql
import "mayfly-go/internal/db/dbm/dbi"
var _ dbi.CommonTypeConverter = (*commonTypeConverter)(nil)
type commonTypeConverter struct {
}
func (c *commonTypeConverter) Varchar(col *dbi.Column) *dbi.DbDataType {
return Varchar
}
func (c *commonTypeConverter) Char(col *dbi.Column) *dbi.DbDataType {
return Char
}
func (c *commonTypeConverter) Text(col *dbi.Column) *dbi.DbDataType {
return Text
}
func (c *commonTypeConverter) Mediumtext(col *dbi.Column) *dbi.DbDataType {
return Text
}
func (c *commonTypeConverter) Longtext(col *dbi.Column) *dbi.DbDataType {
return Text
}
func (c *commonTypeConverter) Bit(col *dbi.Column) *dbi.DbDataType {
return Bit
}
func (c *commonTypeConverter) Int1(col *dbi.Column) *dbi.DbDataType {
return Tinyint
}
func (c *commonTypeConverter) Int2(col *dbi.Column) *dbi.DbDataType {
return Smallint
}
func (c *commonTypeConverter) Int4(col *dbi.Column) *dbi.DbDataType {
return Int
}
func (c *commonTypeConverter) Int8(col *dbi.Column) *dbi.DbDataType {
return Bigint
}
func (c *commonTypeConverter) Numeric(col *dbi.Column) *dbi.DbDataType {
return Numeric
}
func (c *commonTypeConverter) Decimal(col *dbi.Column) *dbi.DbDataType {
return Decimal
}
func (c *commonTypeConverter) UnsignedInt8(col *dbi.Column) *dbi.DbDataType {
return Bigint
}
func (c *commonTypeConverter) UnsignedInt4(col *dbi.Column) *dbi.DbDataType {
return Int
}
func (c *commonTypeConverter) UnsignedInt2(col *dbi.Column) *dbi.DbDataType {
return Smallint
}
func (c *commonTypeConverter) UnsignedInt1(col *dbi.Column) *dbi.DbDataType {
return Tinyint
}
func (c *commonTypeConverter) Date(col *dbi.Column) *dbi.DbDataType {
return Date
}
func (c *commonTypeConverter) Time(col *dbi.Column) *dbi.DbDataType {
return Time
}
func (c *commonTypeConverter) Datetime(col *dbi.Column) *dbi.DbDataType {
return Datetime
}
func (c *commonTypeConverter) Timestamp(col *dbi.Column) *dbi.DbDataType {
return Datetime
}
func (c *commonTypeConverter) Binary(col *dbi.Column) *dbi.DbDataType {
return Binary
}
func (c *commonTypeConverter) Varbinary(col *dbi.Column) *dbi.DbDataType {
return Varbinary
}
func (c *commonTypeConverter) Mediumblob(col *dbi.Column) *dbi.DbDataType {
return Binary
}
func (c *commonTypeConverter) Blob(col *dbi.Column) *dbi.DbDataType {
return Binary
}
func (c *commonTypeConverter) Longblob(col *dbi.Column) *dbi.DbDataType {
return Binary
}
func (c *commonTypeConverter) Enum(col *dbi.Column) *dbi.DbDataType {
return Varchar
}
func (c *commonTypeConverter) JSON(col *dbi.Column) *dbi.DbDataType {
return Text
}

View File

@@ -0,0 +1,48 @@
package mysql
import (
"mayfly-go/internal/db/dbm/dbi"
)
const (
IndexSubPartKey = "subPart"
)
var (
Bit = dbi.NewDbDataType("bit", dbi.DTBit).WithCT(dbi.CTBit)
Tinyint = dbi.NewDbDataType("tinyint", dbi.DTInt8).WithCT(dbi.CTInt1).WithFixColumn(dbi.ClearNumScale)
Smallint = dbi.NewDbDataType("smallint", dbi.DTInt16).WithCT(dbi.CTInt2).WithFixColumn(dbi.ClearNumScale)
Mediumint = dbi.NewDbDataType("mediumint", dbi.DTInt32).WithCT(dbi.CTInt4).WithFixColumn(dbi.ClearNumScale)
Int = dbi.NewDbDataType("int", dbi.DTInt32).WithCT(dbi.CTInt4).WithFixColumn(dbi.ClearNumScale)
Bigint = dbi.NewDbDataType("bigint", dbi.DTInt64).WithCT(dbi.CTInt8).WithFixColumn(dbi.ClearNumScale)
UnsignedBigint = dbi.NewDbDataType("unsigned bigint", dbi.DTUint64).WithCT(dbi.CTUnsignedInt8).WithFixColumn(dbi.ClearNumScale)
UnsignedInt = dbi.NewDbDataType("unsigned int", dbi.DTUint64).WithCT(dbi.CTUnsignedInt4).WithFixColumn(dbi.ClearNumScale)
UnsignedSmallint = dbi.NewDbDataType("unsigned smallint", dbi.DTInt32).WithCT(dbi.CTUnsignedInt2).WithFixColumn(dbi.ClearNumScale)
UnsignedMediumint = dbi.NewDbDataType("unsigned mediumint", dbi.DTInt64).WithCT(dbi.CTUnsignedInt4).WithFixColumn(dbi.ClearNumScale)
Decimal = dbi.NewDbDataType("decimal", dbi.DTDecimal).WithCT(dbi.CTDecimal)
Double = dbi.NewDbDataType("double", dbi.DTNumeric).WithCT(dbi.CTNumeric)
Float = dbi.NewDbDataType("float", dbi.DTNumeric).WithCT(dbi.CTNumeric)
Varchar = dbi.NewDbDataType("varchar", dbi.DTString).WithCT(dbi.CTVarchar)
Char = dbi.NewDbDataType("char", dbi.DTString).WithCT(dbi.CTChar)
Text = dbi.NewDbDataType("text", dbi.DTString).WithCT(dbi.CTText).WithFixColumn(dbi.ClearCharMaxLength)
Mediumtext = dbi.NewDbDataType("mediumtext", dbi.DTString).WithCT(dbi.CTMediumtext).WithFixColumn(dbi.ClearCharMaxLength)
Longtext = dbi.NewDbDataType("longtext", dbi.DTString).WithCT(dbi.CTLongtext).WithFixColumn(dbi.ClearCharMaxLength)
JSON = dbi.NewDbDataType("json", dbi.DTString).WithCT(dbi.CTJSON).WithFixColumn(dbi.ClearCharMaxLength)
Datetime = dbi.NewDbDataType("datetime", dbi.DTDateTime).WithCT(dbi.CTDateTime)
Date = dbi.NewDbDataType("date", dbi.DTDate).WithCT(dbi.CTDate)
Time = dbi.NewDbDataType("time", dbi.DTTime).WithCT(dbi.CTTime)
Timestamp = dbi.NewDbDataType("timestamp", dbi.DTDateTime).WithCT(dbi.CTTimestamp)
Enum = dbi.NewDbDataType("enum", dbi.DTString).WithCT(dbi.CTEnum)
Set = dbi.NewDbDataType("set", dbi.DTString).WithCT(dbi.CTVarchar)
Blob = dbi.NewDbDataType("blob", dbi.DTBytes).WithCT(dbi.CTBlob)
Mediumblob = dbi.NewDbDataType("mediumblob", dbi.DTBytes).WithCT(dbi.CTMediumblob)
Longblob = dbi.NewDbDataType("longblob", dbi.DTBytes).WithCT(dbi.CTLongblob)
Binary = dbi.NewDbDataType("binary", dbi.DTBytes).WithCT(dbi.CTBinary)
Varbinary = dbi.NewDbDataType("varbinary", dbi.DTBytes).WithCT(dbi.CTVarbinary)
)

View File

@@ -1,17 +1,20 @@
package mysql
import (
"database/sql"
"fmt"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/internal/db/dbm/sqlparser"
"mayfly-go/internal/db/dbm/sqlparser/mysql"
"mayfly-go/pkg/utils/collx"
"strings"
"time"
)
const Quoter = "`"
var (
mysqlQuoter = dbi.Quoter{
Prefix: '`',
Suffix: '`',
IsReserved: dbi.AlwaysReserve,
}
)
type MysqlDialect struct {
dbi.DefaultDialect
@@ -24,38 +27,6 @@ func (md *MysqlDialect) GetDbProgram() (dbi.DbProgram, error) {
return NewDbProgramMysql(md.dc), nil
}
func (md *MysqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any, duplicateStrategy int) (int64, error) {
// 生成占位符字符串:如:(?,?)
// 重复字符串并用逗号连接
repeated := strings.Repeat("?,", len(columns))
// 去除最后一个逗号,占位符由括号包裹
placeholder := fmt.Sprintf("(%s)", strings.TrimSuffix(repeated, ","))
// 执行批量insert sqlmysql支持批量insert语法
// insert into table_name (column1, column2, ...) values (value1, value2, ...), (value1, value2, ...), ...
// 重复占位符字符串n遍
repeated = strings.Repeat(placeholder+",", len(values))
// 去除最后一个逗号
placeholder = strings.TrimSuffix(repeated, ",")
prefix := "insert into"
if duplicateStrategy == 1 {
prefix = "insert ignore into"
} else if duplicateStrategy == 2 {
prefix = "replace into"
}
sqlStr := fmt.Sprintf("%s %s (%s) values %s", prefix, md.QuoteIdentifier(tableName), strings.Join(columns, ","), placeholder)
// 执行批量insert sql
// 把二维数组转为一维数组
var args []any
for _, v := range values {
args = append(args, v...)
}
return md.dc.TxExec(tx, sqlStr, args...)
}
func (md *MysqlDialect) CopyTable(copy *dbi.DbCopyTable) error {
tableName := copy.TableName
@@ -77,145 +48,14 @@ func (md *MysqlDialect) CopyTable(copy *dbi.DbCopyTable) error {
return err
}
// 获取建表ddl
func (md *MysqlDialect) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table, dropBeforeCreate bool) []string {
sqlArr := make([]string, 0)
if dropBeforeCreate {
sqlArr = append(sqlArr, fmt.Sprintf("DROP TABLE IF EXISTS %s", md.QuoteIdentifier(tableInfo.TableName)))
}
// 组装建表语句
createSql := fmt.Sprintf("CREATE TABLE %s (\n", md.QuoteIdentifier(tableInfo.TableName))
fields := make([]string, 0)
pks := make([]string, 0)
for _, column := range columns {
if column.IsPrimaryKey {
pks = append(pks, column.ColumnName)
}
fields = append(fields, md.genColumnBasicSql(column))
}
// 建表ddl
createSql += strings.Join(fields, ",\n")
if len(pks) > 0 {
createSql += fmt.Sprintf(", \nPRIMARY KEY (%s)", strings.Join(pks, ","))
}
createSql += "\n)"
// 表注释
if tableInfo.TableComment != "" {
createSql += fmt.Sprintf(" COMMENT '%s'", md.QuoteEscape(tableInfo.TableComment))
}
sqlArr = append(sqlArr, createSql)
return sqlArr
}
// 获取建索引ddl
func (md *MysqlDialect) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Table) []string {
sqlArr := make([]string, 0)
for _, index := range indexs {
unique := ""
if index.IsUnique {
unique = "unique"
}
// 取出列名,添加引号
cols := strings.Split(index.ColumnName, ",")
colNames := make([]string, len(cols))
for i, name := range cols {
colNames[i] = md.QuoteIdentifier(name)
}
sqlTmp := "ALTER TABLE %s ADD %s INDEX %s(%s) USING BTREE"
sqlStr := fmt.Sprintf(sqlTmp, md.QuoteIdentifier(tableInfo.TableName), unique, md.QuoteIdentifier(index.IndexName), strings.Join(colNames, ","))
comment := md.QuoteEscape(index.IndexComment)
if comment != "" {
sqlStr += fmt.Sprintf(" COMMENT '%s'", comment)
}
sqlArr = append(sqlArr, sqlStr)
}
return sqlArr
}
func (md *MysqlDialect) genColumnBasicSql(column dbi.Column) string {
dataType := string(column.DataType)
incr := ""
if column.IsIdentity {
incr = " AUTO_INCREMENT"
}
nullAble := ""
if !column.Nullable {
nullAble = " NOT NULL"
}
columnType := column.GetColumnType()
if nullAble == "" && strings.Contains(columnType, "timestamp") {
nullAble = " NULL"
}
defVal := "" // 默认值需要判断引号,如函数是不需要引号的
if column.ColumnDefault != "" &&
// 当默认值是字符串'NULL'时,不需要设置默认值
column.ColumnDefault != "NULL" &&
// 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值
!strings.Contains(column.ColumnDefault, "(") {
// 哪些字段类型默认值需要加引号
mark := false
if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, strings.ToLower(dataType)) {
// 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号
if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) &&
collx.ArrayAnyMatches([]string{"DATE", "TIME"}, strings.ToUpper(column.ColumnDefault)) {
mark = false
} else {
mark = true
}
}
if mark {
defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault)
} else {
defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault)
}
}
comment := ""
if column.ColumnComment != "" {
// 防止注释内含有特殊字符串导致sql出错
commentStr := md.QuoteEscape(column.ColumnComment)
comment = fmt.Sprintf(" COMMENT '%s'", commentStr)
}
columnSql := fmt.Sprintf(" %s %s%s%s%s%s", md.QuoteIdentifier(column.ColumnName), columnType, nullAble, incr, defVal, comment)
return columnSql
}
func (dx *MysqlDialect) QuoteIdentifier(name string) string {
end := strings.IndexRune(name, 0)
if end > -1 {
name = name[:end]
}
return Quoter + strings.Replace(name, Quoter, Quoter+Quoter, -1) + Quoter
}
func (dx *MysqlDialect) RemoveQuote(name string) string {
return strings.ReplaceAll(name, Quoter, "")
}
func (md *MysqlDialect) QuoteLiteral(literal string) string {
literal = strings.ReplaceAll(literal, `\`, `\\`)
literal = strings.ReplaceAll(literal, `'`, `''`)
return "'" + literal + "'"
}
func (md *MysqlDialect) GetDataHelper() dbi.DataHelper {
return dataHelper
}
func (md *MysqlDialect) GetColumnHelper() dbi.ColumnHelper {
return columnHelper
func (md *MysqlDialect) Quoter() dbi.Quoter {
return mysqlQuoter
}
func (pd *MysqlDialect) GetSQLParser() sqlparser.SqlParser {
return new(mysql.MysqlParser)
}
func (md *MysqlDialect) GetSQLGenerator() dbi.SQLGenerator {
return &SQLGenerator{Dialect: md}
}

View File

@@ -1,218 +1 @@
package mysql
import (
"fmt"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/utils/anyx"
"regexp"
"strings"
"time"
)
var (
// 数字类型
numberRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`)
// 日期时间类型
datetimeRegexp = regexp.MustCompile(`(?i)datetime|timestamp`)
// 日期类型
dateRegexp = regexp.MustCompile(`(?i)date`)
// 时间类型
timeRegexp = regexp.MustCompile(`(?i)time`)
blobRegexp = regexp.MustCompile(`(?i)blob`)
// mysql数据类型 映射 公共数据类型
commonColumnTypeMap = map[string]dbi.ColumnDataType{
"bigint": dbi.CommonTypeBigint,
"binary": dbi.CommonTypeBinary,
"blob": dbi.CommonTypeBlob,
"char": dbi.CommonTypeChar,
"datetime": dbi.CommonTypeDatetime,
"date": dbi.CommonTypeDate,
"decimal": dbi.CommonTypeNumber,
"double": dbi.CommonTypeNumber,
"enum": dbi.CommonTypeEnum,
"float": dbi.CommonTypeNumber,
"int": dbi.CommonTypeInt,
"json": dbi.CommonTypeJSON,
"longblob": dbi.CommonTypeLongblob,
"longtext": dbi.CommonTypeLongtext,
"mediumblob": dbi.CommonTypeBlob,
"mediumtext": dbi.CommonTypeMediumtext,
"bit": dbi.CommonTypeBit,
"set": dbi.CommonTypeVarchar,
"smallint": dbi.CommonTypeSmallint,
"text": dbi.CommonTypeText,
"time": dbi.CommonTypeTime,
"timestamp": dbi.CommonTypeTimestamp,
"tinyint": dbi.CommonTypeTinyint,
"varbinary": dbi.CommonTypeVarbinary,
"varchar": dbi.CommonTypeVarchar,
}
// 公共数据类型 映射 mysql数据类型
mysqlColumnTypeMap = map[dbi.ColumnDataType]string{
dbi.CommonTypeVarchar: "varchar",
dbi.CommonTypeChar: "char",
dbi.CommonTypeText: "text",
dbi.CommonTypeBlob: "blob",
dbi.CommonTypeLongblob: "longblob",
dbi.CommonTypeLongtext: "longtext",
dbi.CommonTypeBinary: "binary",
dbi.CommonTypeMediumblob: "blob",
dbi.CommonTypeMediumtext: "mediumtext",
dbi.CommonTypeVarbinary: "varbinary",
dbi.CommonTypeInt: "int",
dbi.CommonTypeBit: "bit",
dbi.CommonTypeSmallint: "smallint",
dbi.CommonTypeTinyint: "tinyint",
dbi.CommonTypeNumber: "decimal",
dbi.CommonTypeBigint: "bigint",
dbi.CommonTypeDatetime: "datetime",
dbi.CommonTypeDate: "date",
dbi.CommonTypeTime: "time",
dbi.CommonTypeTimestamp: "timestamp",
dbi.CommonTypeEnum: "enum",
dbi.CommonTypeJSON: "json",
}
dataHelper = &DataHelper{}
columnHelper = &ColumnHelper{}
)
func GetDataHelper() *DataHelper {
return dataHelper
}
type DataHelper struct {
}
func (dc *DataHelper) GetDataType(dbColumnType string) dbi.DataType {
if numberRegexp.MatchString(dbColumnType) {
return dbi.DataTypeNumber
}
// 日期时间类型
if datetimeRegexp.MatchString(dbColumnType) {
return dbi.DataTypeDateTime
}
// 日期类型
if dateRegexp.MatchString(dbColumnType) {
return dbi.DataTypeDate
}
// 时间类型
if timeRegexp.MatchString(dbColumnType) {
return dbi.DataTypeTime
}
// blob类型
if blobRegexp.MatchString(dbColumnType) {
return dbi.DataTypeBlob
}
return dbi.DataTypeString
}
func (dc *DataHelper) FormatData(dbColumnValue any, dataType dbi.DataType) string {
// 如果dataType是datetime而dbColumnValue是string类型则需要根据类型格式化
str, ok := dbColumnValue.(string)
if dataType == dbi.DataTypeDateTime && ok {
// 尝试用时间格式解析
res, err := time.Parse(time.DateTime, str)
if err == nil {
return str
}
res, _ = time.Parse(time.RFC3339, str)
return res.Format(time.DateTime)
}
if dataType == dbi.DataTypeDate && ok {
res, _ := time.Parse(time.DateOnly, str)
return res.Format(time.DateOnly)
}
if dataType == dbi.DataTypeTime && ok {
res, _ := time.Parse(time.TimeOnly, str)
return res.Format(time.TimeOnly)
}
return anyx.ToString(dbColumnValue)
}
func (dc *DataHelper) ParseData(dbColumnValue any, dataType dbi.DataType) any {
// 如果dataType是datetime而dbColumnValue是string类型则需要转换为time.Time类型
_, ok := dbColumnValue.(string)
if ok {
if dataType == dbi.DataTypeDateTime {
res, _ := time.Parse(time.DateTime, anyx.ToString(dbColumnValue))
return res
}
if dataType == dbi.DataTypeDate {
res, _ := time.Parse(time.DateOnly, anyx.ToString(dbColumnValue))
return res
}
if dataType == dbi.DataTypeTime {
res, _ := time.Parse(time.TimeOnly, anyx.ToString(dbColumnValue))
return res
}
}
return dbColumnValue
}
func (dc *DataHelper) WrapValue(dbColumnValue any, dataType dbi.DataType) string {
if dbColumnValue == nil {
return "NULL"
}
switch dataType {
case dbi.DataTypeNumber:
return fmt.Sprintf("%v", dbColumnValue)
case dbi.DataTypeString:
val := fmt.Sprintf("%v", dbColumnValue)
// 转义单引号
val = strings.Replace(val, `'`, `''`, -1)
val = strings.Replace(val, `\''`, `\'`, -1)
// 转义换行符
val = strings.Replace(val, "\n", "\\n", -1)
return fmt.Sprintf("'%s'", val)
case dbi.DataTypeDate, dbi.DataTypeDateTime, dbi.DataTypeTime:
// mysql时间类型无需格式化
return fmt.Sprintf("'%s'", dbColumnValue)
case dbi.DataTypeBlob:
return fmt.Sprintf("unhex('%s')", dbColumnValue)
}
return fmt.Sprintf("'%s'", dbColumnValue)
}
type ColumnHelper struct {
dbi.DefaultColumnHelper
}
func (ch *ColumnHelper) ToCommonColumn(dialectColumn *dbi.Column) {
dataType := dialectColumn.DataType
t1 := commonColumnTypeMap[string(dataType)]
commonColumnType := dbi.CommonTypeVarchar
if t1 != "" {
commonColumnType = t1
}
dialectColumn.DataType = commonColumnType
}
func (ch *ColumnHelper) ToColumn(column *dbi.Column) {
ctype := mysqlColumnTypeMap[column.DataType]
if ctype == "" {
column.DataType = "varchar"
column.CharMaxLength = 1000
} else {
column.DataType = dbi.ColumnDataType(ctype)
ch.FixColumn(column)
}
}
func (ch *ColumnHelper) FixColumn(column *dbi.Column) {
// 如果是int整型删除精度
if strings.Contains(strings.ToLower(string(column.DataType)), "int") {
column.NumScale = 0
column.CharMaxLength = 0
} else
// 如果是text删除长度
if strings.Contains(strings.ToLower(string(column.DataType)), "text") {
column.CharMaxLength = 0
column.NumPrecision = 0
}
}

View File

@@ -5,6 +5,7 @@ import (
"database/sql"
"fmt"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/utils/collx"
"net"
"github.com/go-sql-driver/mysql"
@@ -12,10 +13,15 @@ import (
func init() {
meta := new(Meta)
dbi.Register(dbi.DbTypeMysql, meta)
dbi.Register(dbi.DbTypeMariadb, meta)
dbi.Register(DbTypeMysql, meta)
dbi.Register(DbTypeMariadb, meta)
}
const (
DbTypeMysql dbi.DbType = "mysql"
DbTypeMariadb dbi.DbType = "mariadb"
)
type Meta struct {
}
@@ -31,7 +37,7 @@ func (mm *Meta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) {
})
}
// 设置dataSourceName -> 更多参数参考https://github.com/go-sql-driver/mysql#dsn-data-source-name
dsn := fmt.Sprintf("%s:%s@%s(%s:%d)/%s?timeout=8s", d.Username, d.Password, d.Network, d.Host, d.Port, d.Database)
dsn := fmt.Sprintf("%s:%s@%s(%s:%d)/%s?parseTime=true&timeout=8s", d.Username, d.Password, d.Network, d.Host, d.Port, d.Database)
if d.Params != "" {
dsn = fmt.Sprintf("%s&%s", dsn, d.Params)
}
@@ -46,3 +52,17 @@ func (mm *Meta) GetDialect(conn *dbi.DbConn) dbi.Dialect {
func (mm *Meta) GetMetadata(conn *dbi.DbConn) dbi.Metadata {
return &MysqlMetadata{dc: conn}
}
func (mm *Meta) GetDbDataTypes() []*dbi.DbDataType {
return collx.AsArray(
UnsignedBigint, Bigint, Tinyint, Smallint, Int, Bit, Float, Double, Decimal,
Varchar, Char, Text, Longtext, Mediumtext,
Datetime, Date, Time, Timestamp,
Enum, JSON, Set,
Binary, Blob, Longblob, Mediumblob, Varbinary,
)
}
func (mm *Meta) GetCommonTypeConverter() dbi.CommonTypeConverter {
return &commonTypeConverter{}
}

View File

@@ -5,7 +5,6 @@ import (
"fmt"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/errorx"
"mayfly-go/pkg/logx"
"mayfly-go/pkg/utils/collx"
"mayfly-go/pkg/utils/stringx"
"strings"
@@ -54,7 +53,7 @@ func (md *MysqlMetadata) GetDbNames() ([]string, error) {
func (md *MysqlMetadata) GetTables(tableNames ...string) ([]dbi.Table, error) {
dialect := md.dc.GetDialect()
names := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string {
return fmt.Sprintf("'%s'", dialect.RemoveQuote(val))
return fmt.Sprintf("'%s'", dialect.Quoter().Trim(val))
}), ",")
var res []map[string]any
@@ -87,9 +86,8 @@ func (md *MysqlMetadata) GetTables(tableNames ...string) ([]dbi.Table, error) {
// 获取列元信息, 如列名等
func (md *MysqlMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error) {
dialect := md.dc.GetDialect()
columnHelper := dialect.GetColumnHelper()
tableName := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string {
return fmt.Sprintf("'%s'", dialect.RemoveQuote(val))
return fmt.Sprintf("'%s'", dialect.Quoter().Trim(val))
}), ",")
_, res, err := md.dc.Query(fmt.Sprintf(dbi.GetLocalSql(MYSQL_META_FILE, MYSQL_COLUMN_MA_KEY), tableName))
@@ -103,7 +101,7 @@ func (md *MysqlMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error)
column := dbi.Column{
TableName: cast.ToString(re["tableName"]),
ColumnName: cast.ToString(re["columnName"]),
DataType: dbi.ColumnDataType(cast.ToString(re["dataType"])),
DataType: cast.ToString(re["dataType"]),
ColumnComment: cast.ToString(re["columnComment"]),
Nullable: cast.ToString(re["nullable"]) == "YES",
IsPrimaryKey: cast.ToInt(re["isPrimaryKey"]) == 1,
@@ -114,7 +112,7 @@ func (md *MysqlMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error)
NumScale: cast.ToInt(re["numScale"]),
}
columnHelper.FixColumn(&column)
md.dc.GetDbDataType(column.DataType).FixColumn(&column)
columns = append(columns, column)
}
return columns, nil
@@ -156,6 +154,7 @@ func (md *MysqlMetadata) GetTableIndex(tableName string) ([]dbi.Index, error) {
IsUnique: cast.ToInt(re["isUnique"]) == 1,
SeqInIndex: cast.ToInt(re["seqInIndex"]),
IsPrimaryKey: cast.ToInt(re["isPrimaryKey"]) == 1,
Extra: collx.Kvs(IndexSubPartKey, cast.ToInt(re[IndexSubPartKey])),
})
}
// 把查询结果以索引名分组,索引字段以逗号连接
@@ -179,34 +178,7 @@ func (md *MysqlMetadata) GetTableIndex(tableName string) ([]dbi.Index, error) {
// 获取建表ddl
func (md *MysqlMetadata) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) {
// 1.获取表信息
tbs, err := md.GetTables(tableName)
tableInfo := &dbi.Table{}
if err != nil || tbs == nil || len(tbs) <= 0 {
logx.Errorf("获取表信息失败, %s", tableName)
return "", err
}
tableInfo.TableName = tbs[0].TableName
tableInfo.TableComment = tbs[0].TableComment
// 2.获取列信息
columns, err := md.GetColumns(tableName)
if err != nil {
logx.Errorf("获取列信息失败, %s", tableName)
return "", err
}
dialect := md.dc.GetDialect()
tableDDLArr := dialect.GenerateTableDDL(columns, *tableInfo, dropBeforeCreate)
// 3.获取索引信息
indexs, err := md.GetTableIndex(tableName)
if err != nil {
logx.Errorf("获取索引信息失败, %s", tableName)
return "", err
}
// 组装返回
tableDDLArr = append(tableDDLArr, dialect.GenerateIndexDDL(indexs, *tableInfo)...)
return strings.Join(tableDDLArr, ";\n"), nil
return dbi.GenTableDDL(md.dc.GetDialect(), md, tableName, dropBeforeCreate)
}
func (md *MysqlMetadata) GetSchemas() ([]string, error) {

View File

@@ -56,9 +56,9 @@ func (svc *DbProgramMysql) getMysqlBin() *config.MysqlBin {
dbInfo := svc.dbInfo()
var mysqlBin *config.MysqlBin
switch dbInfo.Type {
case dbi.DbTypeMariadb:
case DbTypeMariadb:
mysqlBin = config.GetMysqlBin(config.ConfigKeyDbMariadbBin)
case dbi.DbTypeMysql:
case DbTypeMysql:
mysqlBin = config.GetMysqlBin(config.ConfigKeyDbMysqlBin)
default:
panic(fmt.Sprintf("不兼容 MySQL 的数据库类型: %v", dbInfo.Type))

View File

@@ -0,0 +1,147 @@
package mysql
import (
"fmt"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/utils/collx"
"strings"
"github.com/may-fly/cast"
)
type SQLGenerator struct {
Dialect dbi.Dialect
}
func (msg *SQLGenerator) GenTableDDL(table dbi.Table, columns []dbi.Column, dropBeforeCreate bool) []string {
sqlArr := make([]string, 0)
quoter := msg.Dialect.Quoter()
if dropBeforeCreate {
sqlArr = append(sqlArr, fmt.Sprintf("DROP TABLE IF EXISTS %s", quoter.Quote(table.TableName)))
}
// 组装建表语句
createSql := fmt.Sprintf("CREATE TABLE %s (\n", quoter.Quote(table.TableName))
fields := make([]string, 0)
pks := make([]string, 0)
for _, column := range columns {
if column.IsPrimaryKey {
pks = append(pks, column.ColumnName)
}
fields = append(fields, msg.genColumnBasicSql(quoter, column))
}
// 建表ddl
createSql += strings.Join(fields, ",\n")
if len(pks) > 0 {
createSql += fmt.Sprintf(", \nPRIMARY KEY (%s)", strings.Join(pks, ","))
}
createSql += "\n)"
// 表注释
if table.TableComment != "" {
createSql += fmt.Sprintf(" COMMENT '%s'", dbi.QuoteEscape(table.TableComment))
}
sqlArr = append(sqlArr, createSql)
return sqlArr
}
func (msg *SQLGenerator) GenIndexDDL(table dbi.Table, indexs []dbi.Index) []string {
sqlArr := make([]string, 0)
quoter := msg.Dialect.Quoter()
for _, index := range indexs {
unique := ""
if index.IsUnique {
unique = "unique"
}
// 取出列名,添加引号
colNames := quoter.Quotes(strings.Split(index.ColumnName, ","))
// 暂时先处理单个索引的情况,多个涉及获取索引时的合并等,以及前端调整等,后续完善
if subPart := cast.ToInt(index.Extra[IndexSubPartKey]); subPart > 0 && len(colNames) == 1 {
colNames[0] = fmt.Sprintf("%s(%d)", colNames[0], subPart)
}
sqlTmp := "ALTER TABLE %s ADD %s INDEX %s(%s) USING %s"
sqlStr := fmt.Sprintf(sqlTmp, quoter.Quote(table.TableName), unique, quoter.Quote(index.IndexName), strings.Join(colNames, ","), index.IndexType)
comment := dbi.QuoteEscape(index.IndexComment)
if comment != "" {
sqlStr += fmt.Sprintf(" COMMENT '%s'", comment)
}
sqlArr = append(sqlArr, sqlStr)
}
return sqlArr
}
func (msg *SQLGenerator) GenInsert(tableName string, columns []dbi.Column, values [][]any, duplicateStrategy int) []string {
if duplicateStrategy == dbi.DuplicateStrategyNone {
return collx.AsArray(dbi.GenCommonInsert(msg.Dialect, DbTypeMysql, tableName, columns, values))
}
prefix := "insert ignore into"
if duplicateStrategy == dbi.DuplicateStrategyUpdate {
prefix = "replace into"
}
quote := msg.Dialect.Quoter().Quote
columnStr, valuesStrs := dbi.GenInsertSqlColumnAndValues(msg.Dialect, DbTypeMysql, columns, values)
return collx.AsArray[string](fmt.Sprintf("%s %s %s VALUES \n%s", prefix, quote(tableName), columnStr, strings.Join(valuesStrs, ",\n")))
}
func (msg *SQLGenerator) genColumnBasicSql(quoter dbi.Quoter, column dbi.Column) string {
dataType := string(column.DataType)
incr := ""
if column.IsIdentity {
incr = " AUTO_INCREMENT"
}
nullAble := ""
if !column.Nullable {
nullAble = " NOT NULL"
}
columnType := column.GetColumnType()
if nullAble == "" && strings.Contains(columnType, "timestamp") {
nullAble = " NULL"
}
defVal := "" // 默认值需要判断引号,如函数是不需要引号的
if column.ColumnDefault != "" &&
// 当默认值是字符串'NULL'时,不需要设置默认值
column.ColumnDefault != "NULL" &&
// 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值
!strings.Contains(column.ColumnDefault, "(") {
// 哪些字段类型默认值需要加引号
mark := false
if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, strings.ToLower(dataType)) {
// 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号
if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) &&
collx.ArrayAnyMatches([]string{"DATE", "TIME"}, strings.ToUpper(column.ColumnDefault)) {
mark = false
} else {
mark = true
}
}
if mark {
defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault)
} else {
defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault)
}
}
comment := ""
if column.ColumnComment != "" {
// 防止注释内含有特殊字符串导致sql出错
commentStr := dbi.QuoteEscape(column.ColumnComment)
comment = fmt.Sprintf(" COMMENT '%s'", commentStr)
}
columnSql := fmt.Sprintf(" %s %s%s%s%s%s", quoter.Quote(column.ColumnName), columnType, nullAble, incr, defVal, comment)
return columnSql
}

View File

@@ -0,0 +1,108 @@
package mysql
import "mayfly-go/internal/db/dbm/dbi"
var _ dbi.CommonTypeConverter = (*commonTypeConverter)(nil)
type commonTypeConverter struct {
}
func (c *commonTypeConverter) Varchar(col *dbi.Column) *dbi.DbDataType {
// 如果字符长度大于16383则转为text类型
if col.CharMaxLength > 16383 {
col.CharMaxLength = 0
return Text
}
return Varchar
}
func (c *commonTypeConverter) Char(col *dbi.Column) *dbi.DbDataType {
return Char
}
func (c *commonTypeConverter) Text(col *dbi.Column) *dbi.DbDataType {
col.CharMaxLength = 0
col.NumPrecision = 0
return Text
}
func (c *commonTypeConverter) Mediumtext(col *dbi.Column) *dbi.DbDataType {
col.CharMaxLength = 0
col.NumPrecision = 0
return Mediumtext
}
func (c *commonTypeConverter) Longtext(col *dbi.Column) *dbi.DbDataType {
col.CharMaxLength = 0
col.NumPrecision = 0
return Longtext
}
func (c *commonTypeConverter) Bit(col *dbi.Column) *dbi.DbDataType {
return Bit
}
func (c *commonTypeConverter) Int1(col *dbi.Column) *dbi.DbDataType {
return Tinyint
}
func (c *commonTypeConverter) Int2(col *dbi.Column) *dbi.DbDataType {
return Smallint
}
func (c *commonTypeConverter) Int4(col *dbi.Column) *dbi.DbDataType {
return Int
}
func (c *commonTypeConverter) Int8(col *dbi.Column) *dbi.DbDataType {
return Bigint
}
func (c *commonTypeConverter) Numeric(col *dbi.Column) *dbi.DbDataType {
return Double
}
func (c *commonTypeConverter) Decimal(col *dbi.Column) *dbi.DbDataType {
return Decimal
}
func (c *commonTypeConverter) UnsignedInt8(col *dbi.Column) *dbi.DbDataType {
return UnsignedBigint
}
func (c *commonTypeConverter) UnsignedInt4(col *dbi.Column) *dbi.DbDataType {
return UnsignedInt
}
func (c *commonTypeConverter) UnsignedInt2(col *dbi.Column) *dbi.DbDataType {
return UnsignedMediumint
}
func (c *commonTypeConverter) UnsignedInt1(col *dbi.Column) *dbi.DbDataType {
return UnsignedSmallint
}
func (c *commonTypeConverter) Date(col *dbi.Column) *dbi.DbDataType {
return Date
}
func (c *commonTypeConverter) Time(col *dbi.Column) *dbi.DbDataType {
return Time
}
func (c *commonTypeConverter) Datetime(col *dbi.Column) *dbi.DbDataType {
return Datetime
}
func (c *commonTypeConverter) Timestamp(col *dbi.Column) *dbi.DbDataType {
return Timestamp
}
func (c *commonTypeConverter) Binary(col *dbi.Column) *dbi.DbDataType {
return Binary
}
func (c *commonTypeConverter) Varbinary(col *dbi.Column) *dbi.DbDataType {
return Varbinary
}
func (c *commonTypeConverter) Mediumblob(col *dbi.Column) *dbi.DbDataType {
return Mediumblob
}
func (c *commonTypeConverter) Blob(col *dbi.Column) *dbi.DbDataType {
return Blob
}
func (c *commonTypeConverter) Longblob(col *dbi.Column) *dbi.DbDataType {
return Longblob
}
func (c *commonTypeConverter) Enum(col *dbi.Column) *dbi.DbDataType {
return Enum
}
func (c *commonTypeConverter) JSON(col *dbi.Column) *dbi.DbDataType {
return JSON
}

View File

@@ -0,0 +1,45 @@
package oracle
import (
"fmt"
"mayfly-go/internal/db/dbm/dbi"
)
var (
DTOracleDate = dbi.DTDateTime.Copy().WithSQLValue(func(val any) string {
// oracle date型需要用函数包裹to_date('%s', 'yyyy-mm-dd hh24:mi:ss')
return fmt.Sprintf("to_date('%s', 'yyyy-mm-dd hh24:mi:ss')", val)
})
)
var (
CHAR = dbi.NewDbDataType("CHAR", dbi.DTString).WithCT(dbi.CTChar)
NCHAR = dbi.NewDbDataType("NCHAR", dbi.DTString).WithCT(dbi.CTChar)
VARCHAR2 = dbi.NewDbDataType("VARCHAR2", dbi.DTString).WithCT(dbi.CTVarchar)
NVARCHAR2 = dbi.NewDbDataType("NVARCHAR2", dbi.DTString).WithCT(dbi.CTVarchar)
TEXT = dbi.NewDbDataType("TEXT", dbi.DTString).WithCT(dbi.CTText)
LONG = dbi.NewDbDataType("LONG", dbi.DTString).WithCT(dbi.CTText)
LONGVARCHAR = dbi.NewDbDataType("LONGVARCHAR", dbi.DTString).WithCT(dbi.CTLongtext)
IMAGE = dbi.NewDbDataType("IMAGE", dbi.DTString).WithCT(dbi.CTLongtext)
LONGVARBINARY = dbi.NewDbDataType("LONGVARBINARY", dbi.DTString).WithCT(dbi.CTLongtext)
CLOB = dbi.NewDbDataType("CLOB", dbi.DTString).WithCT(dbi.CTLongtext)
BLOB = dbi.NewDbDataType("BLOB", dbi.DTBytes).WithCT(dbi.CTBlob)
DECIMAL = dbi.NewDbDataType("DECIMAL", dbi.DTDecimal).WithCT(dbi.CTDecimal)
NUMBER = dbi.NewDbDataType("NUMBER", dbi.DTNumeric).WithCT(dbi.CTNumeric)
INTEGER = dbi.NewDbDataType("INTEGER", dbi.DTInt32).WithCT(dbi.CTInt4)
INT = dbi.NewDbDataType("INT", dbi.DTInt32).WithCT(dbi.CTInt4)
BIGINT = dbi.NewDbDataType("BIGINT", dbi.DTInt64).WithCT(dbi.CTInt8)
TINYINT = dbi.NewDbDataType("TINYINT", dbi.DTInt8).WithCT(dbi.CTInt1)
BYTE = dbi.NewDbDataType("BYTE", dbi.DTInt8).WithCT(dbi.CTInt1)
SMALLINT = dbi.NewDbDataType("SMALLINT", dbi.DTInt16).WithCT(dbi.CTInt2)
BIT = dbi.NewDbDataType("BIT", dbi.DTBit).WithCT(dbi.CTBit)
DOUBLE = dbi.NewDbDataType("DOUBLE", dbi.DTNumeric).WithCT(dbi.CTNumeric)
FLOAT = dbi.NewDbDataType("FLOAT", dbi.DTNumeric).WithCT(dbi.CTNumeric)
TIME = dbi.NewDbDataType("TIME", DTOracleDate).WithCT(dbi.CTTime)
DATE = dbi.NewDbDataType("DATE", DTOracleDate).WithCT(dbi.CTDate)
TIMESTAMP = dbi.NewDbDataType("TIMESTAMP", DTOracleDate).WithCT(dbi.CTTimestamp)
)

View File

@@ -59,6 +59,8 @@ func (od *OracleDialect) batchInsertSimple(tableName string, columns []string, v
ignore = fmt.Sprintf("/*+ IGNORE_ROW_ON_DUPKEY_INDEX(%s(%s)) */", tableName, strings.Join(arr, ","))
}
}
quote := od.Quoter().Quote
effRows := 0
for _, value := range values {
// 拼接带占位符的sql oracle的占位符是:1,:2,:3....
@@ -66,7 +68,7 @@ func (od *OracleDialect) batchInsertSimple(tableName string, columns []string, v
for i := 0; i < len(value); i++ {
placeholder = append(placeholder, fmt.Sprintf(":%d", i+1))
}
sqlTemp := fmt.Sprintf("INSERT %s INTO %s (%s) VALUES (%s)", ignore, od.QuoteIdentifier(tableName), strings.Join(columns, ","), strings.Join(placeholder, ","))
sqlTemp := fmt.Sprintf("INSERT %s INTO %s (%s) VALUES (%s)", ignore, quote(tableName), strings.Join(columns, ","), strings.Join(placeholder, ","))
// oracle数据库为了兼容ignore主键冲突只能一条条的执行insert
res, err := od.dc.TxExec(tx, sqlTemp, value...)
@@ -83,6 +85,7 @@ func (od *OracleDialect) batchInsertMergeSql(tableName string, columns []string,
uniqueCols := make([]string, 0)
caseSqls := make([]string, 0)
metadata := od.dc.GetMetadata()
quote := od.Quoter().Quote
// 查询唯一索引涉及到的字段并组装到match条件内
indexs, _ := metadata.GetTableIndex(tableName)
if indexs != nil {
@@ -94,7 +97,7 @@ func (od *OracleDialect) batchInsertMergeSql(tableName string, columns []string,
if !collx.ArrayContains(uniqueCols, col) {
uniqueCols = append(uniqueCols, col)
}
tmp = append(tmp, fmt.Sprintf(" T1.%s = T2.%s ", od.QuoteIdentifier(col), od.QuoteIdentifier(col)))
tmp = append(tmp, fmt.Sprintf(" T1.%s = T2.%s ", quote(col), quote(col)))
}
caseSqls = append(caseSqls, fmt.Sprintf("( %s )", strings.Join(tmp, " AND ")))
}
@@ -110,8 +113,9 @@ func (od *OracleDialect) batchInsertMergeSql(tableName string, columns []string,
insertVals := make([]string, 0)
upds := make([]string, 0)
insertCols := make([]string, 0)
quoteTrim := od.Quoter().Trim
for _, column := range columns {
if !collx.ArrayContains(uniqueCols, od.RemoveQuote(column)) {
if !collx.ArrayContains(uniqueCols, quoteTrim(column)) {
upds = append(upds, fmt.Sprintf("T1.%s = T2.%s", column, column))
}
insertCols = append(insertCols, fmt.Sprintf("T1.%s", column))
@@ -132,7 +136,7 @@ func (od *OracleDialect) batchInsertMergeSql(tableName string, columns []string,
t2 := strings.Join(t2s, " UNION ALL ")
sqlTemp := "MERGE INTO " + od.QuoteIdentifier(tableName) + " T1 USING (" + t2 + ") T2 ON (" + strings.Join(caseSqls, " OR ") + ") "
sqlTemp := "MERGE INTO " + quote(tableName) + " T1 USING (" + t2 + ") T2 ON (" + strings.Join(caseSqls, " OR ") + ") "
sqlTemp += "WHEN NOT MATCHED THEN INSERT (" + strings.Join(insertCols, ",") + ") VALUES (" + strings.Join(insertVals, ",") + ") "
sqlTemp += "WHEN MATCHED THEN UPDATE SET " + strings.Join(upds, ",")
@@ -157,7 +161,8 @@ func (od *OracleDialect) CopyTable(copy *dbi.DbCopyTable) error {
// 获取建表ddl
func (od *OracleDialect) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table, dropBeforeCreate bool) []string {
quoteTableName := od.QuoteIdentifier(tableInfo.TableName)
quote := od.Quoter().Quote
quoteTableName := quote(tableInfo.TableName)
sqlArr := make([]string, 0)
if dropBeforeCreate {
@@ -181,13 +186,13 @@ end`
// 把通用类型转换为达梦类型
for _, column := range columns {
if column.IsPrimaryKey {
pks = append(pks, od.QuoteIdentifier(column.ColumnName))
pks = append(pks, quote(column.ColumnName))
}
fields = append(fields, od.genColumnBasicSql(column))
// 防止注释内含有特殊字符串导致sql出错
if column.ColumnComment != "" {
comment := od.QuoteEscape(column.ColumnComment)
columnComments = append(columnComments, fmt.Sprintf("COMMENT ON COLUMN %s.%s IS '%s'", quoteTableName, od.QuoteIdentifier(column.ColumnName), comment))
comment := dbi.QuoteEscape(column.ColumnComment)
columnComments = append(columnComments, fmt.Sprintf("COMMENT ON COLUMN %s.%s IS '%s'", quoteTableName, quote(column.ColumnName), comment))
}
}
@@ -202,7 +207,7 @@ end`
// 表注释
tableCommentSql := ""
if tableInfo.TableComment != "" {
tableCommentSql = fmt.Sprintf("COMMENT ON TABLE %s is '%s'", od.QuoteIdentifier(tableInfo.TableName), od.QuoteEscape(tableInfo.TableComment))
tableCommentSql = fmt.Sprintf("COMMENT ON TABLE %s is '%s'", quote(tableInfo.TableName), dbi.QuoteEscape(tableInfo.TableComment))
sqlArr = append(sqlArr, tableCommentSql)
}
@@ -222,7 +227,7 @@ end`
func (od *OracleDialect) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Table) []string {
sqls := make([]string, 0)
comments := make([]string, 0)
quote := od.Quoter().Quote
for _, index := range indexs {
unique := ""
if index.IsUnique {
@@ -233,10 +238,10 @@ func (od *OracleDialect) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Tabl
cols := strings.Split(index.ColumnName, ",")
colNames := make([]string, len(cols))
for i, name := range cols {
colNames[i] = od.QuoteIdentifier(name)
colNames[i] = quote(name)
}
sqls = append(sqls, fmt.Sprintf("CREATE %s INDEX %s ON %s(%s)", unique, od.QuoteIdentifier(index.IndexName), od.QuoteIdentifier(tableInfo.TableName), strings.Join(colNames, ",")))
sqls = append(sqls, fmt.Sprintf("CREATE %s INDEX %s ON %s(%s)", unique, quote(index.IndexName), quote(tableInfo.TableName), strings.Join(colNames, ",")))
}
sqlArr := make([]string, 0)
@@ -255,16 +260,15 @@ func (od *OracleDialect) GenerateTableOtherDDL(tableInfo dbi.Table, quoteTableNa
return nil
}
func (od *OracleDialect) GetDataHelper() dbi.DataHelper {
return dataHelper
}
func (od *OracleDialect) GetColumnHelper() dbi.ColumnHelper {
return columnHelper
func (od *OracleDialect) GetSQLGenerator() dbi.SQLGenerator {
return &SQLGenerator{
Dialect: od,
Metadata: od.dc.GetMetadata(),
}
}
func (od *OracleDialect) genColumnBasicSql(column dbi.Column) string {
colName := od.QuoteIdentifier(column.ColumnName)
colName := od.Quoter().Quote(column.ColumnName)
if column.IsIdentity {
// 如果是自增不需要设置默认值和空值自增列数据类型必须是number

View File

@@ -1,174 +1,16 @@
package oracle
import (
"fmt"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/utils/anyx"
"mayfly-go/pkg/utils/collx"
"regexp"
"strings"
"time"
"github.com/may-fly/cast"
)
var (
// 数字类型
numberTypeRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`)
dateTimeReg = regexp.MustCompile(`^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}$`)
dateTimeIsoReg = regexp.MustCompile(`^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}.*$`)
// 日期时间类型
datetimeTypeRegexp = regexp.MustCompile(`(?i)date|timestamp`)
// oracle数据类型 映射 公共数据类型
commonColumnTypeMap = map[string]dbi.ColumnDataType{
"CHAR": dbi.CommonTypeChar,
"NCHAR": dbi.CommonTypeChar,
"VARCHAR2": dbi.CommonTypeVarchar,
"NVARCHAR2": dbi.CommonTypeVarchar,
"NUMBER": dbi.CommonTypeNumber,
"INTEGER": dbi.CommonTypeInt,
"INT": dbi.CommonTypeInt,
"DECIMAL": dbi.CommonTypeNumber,
"FLOAT": dbi.CommonTypeNumber,
"REAL": dbi.CommonTypeNumber,
"BINARY_FLOAT": dbi.CommonTypeNumber,
"BINARY_DOUBLE": dbi.CommonTypeNumber,
"DATE": dbi.CommonTypeDate,
"TIMESTAMP": dbi.CommonTypeDatetime,
"LONG": dbi.CommonTypeLongtext,
"BLOB": dbi.CommonTypeLongtext,
"CLOB": dbi.CommonTypeLongtext,
"NCLOB": dbi.CommonTypeLongtext,
"BFILE": dbi.CommonTypeBinary,
}
// 公共数据类型 映射 oracle数据类型
oracleColumnTypeMap = map[dbi.ColumnDataType]string{
dbi.CommonTypeVarchar: "NVARCHAR2",
dbi.CommonTypeChar: "NCHAR",
dbi.CommonTypeText: "CLOB",
dbi.CommonTypeBlob: "CLOB",
dbi.CommonTypeLongblob: "CLOB",
dbi.CommonTypeLongtext: "CLOB",
dbi.CommonTypeBinary: "BFILE",
dbi.CommonTypeMediumblob: "CLOB",
dbi.CommonTypeMediumtext: "CLOB",
dbi.CommonTypeVarbinary: "BFILE",
dbi.CommonTypeInt: "INT",
dbi.CommonTypeSmallint: "INT",
dbi.CommonTypeTinyint: "INT",
dbi.CommonTypeNumber: "NUMBER",
dbi.CommonTypeBigint: "NUMBER",
dbi.CommonTypeDatetime: "DATE",
dbi.CommonTypeDate: "DATE",
dbi.CommonTypeTime: "DATE",
dbi.CommonTypeTimestamp: "TIMESTAMP",
dbi.CommonTypeEnum: "CLOB",
dbi.CommonTypeJSON: "CLOB",
}
dataHelper = &DataHelper{}
columnHelper = &ColumnHelper{}
)
func GetDataHelper() *DataHelper {
return dataHelper
}
type DataHelper struct {
}
func (dc *DataHelper) GetDataType(dbColumnType string) dbi.DataType {
if numberTypeRegexp.MatchString(dbColumnType) {
return dbi.DataTypeNumber
}
// 日期时间类型
if datetimeTypeRegexp.MatchString(dbColumnType) {
return dbi.DataTypeDateTime
}
return dbi.DataTypeString
}
func (dc *DataHelper) FormatData(dbColumnValue any, dataType dbi.DataType) string {
str := anyx.ToString(dbColumnValue)
if dateTimeReg.MatchString(str) || dateTimeIsoReg.MatchString(str) {
dataType = dbi.DataTypeDateTime
}
switch dataType {
// oracle把日期类型数据格式化输出
case dbi.DataTypeDateTime: // "2024-01-02T22:08:22.275697+08:00"
// 尝试用时间格式解析
res, err := time.Parse(time.DateTime, str)
if err == nil {
return str
}
res, _ = time.Parse(time.RFC3339, str)
return res.Format(time.DateTime)
}
return str
}
func (dc *DataHelper) ParseData(dbColumnValue any, dataType dbi.DataType) any {
// oracle把日期类型的数据转化为time类型
if dataType == dbi.DataTypeDateTime {
res, _ := time.Parse(time.RFC3339, cast.ToString(dbColumnValue))
return res
}
return dbColumnValue
}
func (dc *DataHelper) WrapValue(dbColumnValue any, dataType dbi.DataType) string {
if dbColumnValue == nil {
return "NULL"
}
switch dataType {
case dbi.DataTypeNumber:
return fmt.Sprintf("%v", dbColumnValue)
case dbi.DataTypeString:
val := fmt.Sprintf("%v", dbColumnValue)
// 转义单引号
val = strings.Replace(val, `'`, `''`, -1)
val = strings.Replace(val, `\''`, `\'`, -1)
// 转义换行符
val = strings.Replace(val, "\n", "\\n", -1)
return fmt.Sprintf("'%s'", val)
case dbi.DataTypeDate, dbi.DataTypeDateTime, dbi.DataTypeTime:
return fmt.Sprintf("to_date('%s', 'yyyy-mm-dd hh24:mi:ss')", dc.FormatData(dbColumnValue, dataType))
}
return fmt.Sprintf("'%s'", dbColumnValue)
}
type ColumnHelper struct {
dbi.DefaultColumnHelper
}
func (ch *ColumnHelper) ToCommonColumn(dialectColumn *dbi.Column) {
// 翻译为通用数据库类型
dataType := dialectColumn.DataType
t1 := commonColumnTypeMap[string(dataType)]
if t1 == "" {
dialectColumn.DataType = dbi.CommonTypeVarchar
dialectColumn.CharMaxLength = 2000
} else {
dialectColumn.DataType = t1
// 如果是number类型需要根据公共类型加上长度, 如 bigint 需要转换为number(19,0)
if strings.Contains(string(t1), "NUMBER") {
dialectColumn.CharMaxLength = 19
}
}
}
func (ch *ColumnHelper) ToColumn(commonColumn *dbi.Column) {
ctype := oracleColumnTypeMap[commonColumn.DataType]
if ctype == "" {
commonColumn.DataType = "NVARCHAR2"
commonColumn.CharMaxLength = 2000
} else {
commonColumn.DataType = dbi.ColumnDataType(ctype)
ch.FixColumn(commonColumn)
}
}
func (ch *ColumnHelper) FixColumn(column *dbi.Column) {

View File

@@ -4,6 +4,7 @@ import (
"database/sql"
"fmt"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/utils/collx"
"mayfly-go/pkg/utils/jsonx"
"strings"
@@ -12,11 +13,12 @@ import (
)
func init() {
dbi.Register(dbi.DbTypeOracle, new(Meta))
dbi.Register(DbTypeOracle, new(Meta))
}
const (
DbVersionOracle11 dbi.DbVersion = "11"
DbTypeOracle dbi.DbType = "oracle"
)
type Meta struct {
@@ -114,3 +116,37 @@ func (om *Meta) GetMetadata(conn *dbi.DbConn) dbi.Metadata {
}
return &OracleMetadata{dc: conn}
}
func (sm *Meta) GetDbDataTypes() []*dbi.DbDataType {
return collx.AsArray[*dbi.DbDataType](
CHAR,
NCHAR,
VARCHAR2,
NVARCHAR2,
TEXT,
LONG,
LONGVARCHAR,
IMAGE,
LONGVARBINARY,
CLOB,
BLOB,
DECIMAL,
NUMBER,
INTEGER,
INT,
BIGINT,
TINYINT,
BYTE,
SMALLINT,
BIT,
DOUBLE,
FLOAT,
TIME,
DATE,
TIMESTAMP,
)
}
func (mm *Meta) GetCommonTypeConverter() dbi.CommonTypeConverter {
return &commonTypeConverter{}
}

View File

@@ -4,7 +4,6 @@ import (
"fmt"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/errorx"
"mayfly-go/pkg/logx"
"mayfly-go/pkg/utils/collx"
"mayfly-go/pkg/utils/stringx"
"strings"
@@ -61,7 +60,7 @@ func (od *OracleMetadata) GetDbNames() ([]string, error) {
func (od *OracleMetadata) GetTables(tableNames ...string) ([]dbi.Table, error) {
dialect := od.dc.GetDialect()
names := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string {
return fmt.Sprintf("'%s'", dialect.RemoveQuote(val))
return fmt.Sprintf("'%s'", dialect.Quoter().Trim(val))
}), ",")
var res []map[string]any
@@ -95,7 +94,7 @@ func (od *OracleMetadata) GetTables(tableNames ...string) ([]dbi.Table, error) {
func (od *OracleMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error) {
dialect := od.dc.GetDialect()
tableName := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string {
return fmt.Sprintf("'%s'", dialect.RemoveQuote(val))
return fmt.Sprintf("'%s'", dialect.Quoter().Trim(val))
}), ",")
// 如果表数量超过了1000需要分批查询
@@ -121,13 +120,12 @@ func (od *OracleMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error)
return nil, err
}
columnHelper := dialect.GetColumnHelper()
columns := make([]dbi.Column, 0)
for _, re := range res {
column := dbi.Column{
TableName: cast.ToString(re["TABLE_NAME"]),
ColumnName: cast.ToString(re["COLUMN_NAME"]),
DataType: dbi.ColumnDataType(cast.ToString(re["DATA_TYPE"])),
DataType: cast.ToString(re["DATA_TYPE"]),
CharMaxLength: cast.ToInt(re["CHAR_MAX_LENGTH"]),
ColumnComment: cast.ToString(re["COLUMN_COMMENT"]),
Nullable: cast.ToString(re["NULLABLE"]) == "YES",
@@ -138,7 +136,7 @@ func (od *OracleMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error)
NumScale: cast.ToInt(re["NUM_SCALE"]),
}
columnHelper.FixColumn(&column)
od.dc.GetDbDataType(column.DataType).FixColumn(&column)
columns = append(columns, column)
}
return columns, nil
@@ -201,33 +199,7 @@ func (od *OracleMetadata) GetTableIndex(tableName string) ([]dbi.Index, error) {
// 获取建表ddl
func (od *OracleMetadata) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) {
dialect := od.dc.GetDialect()
// 1.获取表信息
tbs, err := od.GetTables(tableName)
tableInfo := &dbi.Table{}
if err != nil || tbs == nil || len(tbs) <= 0 {
logx.Errorf("获取表信息失败, %s", tableName)
return "", err
}
tableInfo.TableName = tbs[0].TableName
tableInfo.TableComment = tbs[0].TableComment
// 2.获取列信息
columns, err := od.GetColumns(tableName)
if err != nil {
logx.Errorf("获取列信息失败, %s", tableName)
return "", err
}
tableDDLArr := dialect.GenerateTableDDL(columns, *tableInfo, dropBeforeCreate)
// 3.获取索引信息
indexs, err := od.GetTableIndex(tableName)
if err != nil {
logx.Errorf("获取索引信息失败, %s", tableName)
return "", err
}
// 组装返回
tableDDLArr = append(tableDDLArr, dialect.GenerateIndexDDL(indexs, *tableInfo)...)
return strings.Join(tableDDLArr, ";\n"), nil
return dbi.GenTableDDL(od.dc.GetDialect(), od, tableName, dropBeforeCreate)
}
// 获取DM当前连接的库可访问的schemaNames

View File

@@ -21,7 +21,7 @@ type OracleMetadata11 struct {
func (od *OracleMetadata11) GetColumns(tableNames ...string) ([]dbi.Column, error) {
dialect := od.dc.GetDialect()
tableName := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string {
return fmt.Sprintf("'%s'", dialect.RemoveQuote(val))
return fmt.Sprintf("'%s'", dialect.Quoter().Trim(val))
}), ",")
// 如果表数量超过了1000需要分批查询
@@ -47,13 +47,12 @@ func (od *OracleMetadata11) GetColumns(tableNames ...string) ([]dbi.Column, erro
return nil, err
}
columnHelper := dialect.GetColumnHelper()
columns := make([]dbi.Column, 0)
for _, re := range res {
column := dbi.Column{
TableName: cast.ToString(re["TABLE_NAME"]),
ColumnName: cast.ToString(re["COLUMN_NAME"]),
DataType: dbi.ColumnDataType(cast.ToString(re["DATA_TYPE"])),
DataType: cast.ToString(re["DATA_TYPE"]),
CharMaxLength: cast.ToInt(re["CHAR_MAX_LENGTH"]),
ColumnComment: cast.ToString(re["COLUMN_COMMENT"]),
Nullable: cast.ToString(re["NULLABLE"]) == "YES",
@@ -64,7 +63,7 @@ func (od *OracleMetadata11) GetColumns(tableNames ...string) ([]dbi.Column, erro
NumScale: cast.ToInt(re["NUM_SCALE"]),
}
columnHelper.FixColumn(&column)
od.dc.GetDbDataType(column.DataType).FixColumn(&column)
columns = append(columns, column)
}
return columns, nil
@@ -72,7 +71,7 @@ func (od *OracleMetadata11) GetColumns(tableNames ...string) ([]dbi.Column, erro
func (od *OracleMetadata11) genColumnBasicSql(column dbi.Column) string {
dialect := od.dc.GetDialect()
colName := dialect.QuoteIdentifier(column.ColumnName)
colName := dialect.Quoter().Quote(column.ColumnName)
if column.IsIdentity {
// 11g以前的版本 如果是自增自增列数据类型必须是number不需要设置默认值和空值建表后设置自增序列

View File

@@ -0,0 +1,226 @@
package oracle
import (
"fmt"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/utils/collx"
"strings"
)
type SQLGenerator struct {
Dialect dbi.Dialect
Metadata dbi.Metadata
}
func (sg *SQLGenerator) GenTableDDL(table dbi.Table, columns []dbi.Column, dropBeforeCreate bool) []string {
quoter := sg.Dialect.Quoter()
quote := quoter.Quote
quoteTableName := quote(table.TableName)
sqlArr := make([]string, 0)
if dropBeforeCreate {
dropSqlTmp := `
declare
num number;
begin
select count(1) into num from user_tables where table_name = '%s' and owner = (SELECT sys_context('USERENV', 'CURRENT_SCHEMA') FROM dual) ;
if num > 0 then
execute immediate 'drop table "%s"' ;
end if;
end`
sqlArr = append(sqlArr, fmt.Sprintf(dropSqlTmp, table.TableName, table.TableName))
}
// 组装建表语句
createSql := fmt.Sprintf("CREATE TABLE %s ( \n", quoteTableName)
fields := make([]string, 0)
pks := make([]string, 0)
columnComments := make([]string, 0)
// 把通用类型转换为达梦类型
for _, column := range columns {
if column.IsPrimaryKey {
pks = append(pks, quote(column.ColumnName))
}
quote := quoter.Quote
fields = append(fields, sg.genColumnBasicSql(quoter, column))
// 防止注释内含有特殊字符串导致sql出错
if column.ColumnComment != "" {
comment := dbi.QuoteEscape(column.ColumnComment)
columnComments = append(columnComments, fmt.Sprintf("COMMENT ON COLUMN %s.%s IS '%s'", quoteTableName, quote(column.ColumnName), comment))
}
}
// 建表
createSql += strings.Join(fields, ",\n")
if len(pks) > 0 {
createSql += fmt.Sprintf(", \nPRIMARY KEY (%s)", strings.Join(pks, ","))
}
createSql += "\n)"
sqlArr = append(sqlArr, createSql)
// 表注释
tableCommentSql := ""
if table.TableComment != "" {
tableCommentSql = fmt.Sprintf("COMMENT ON TABLE %s is '%s'", quote(table.TableName), dbi.QuoteEscape(table.TableComment))
sqlArr = append(sqlArr, tableCommentSql)
}
// 列注释
if len(columnComments) > 0 {
sqlArr = append(sqlArr, columnComments...)
}
otherSql := sg.GenerateTableOtherDDL(table, quoteTableName, columns)
if len(otherSql) > 0 {
sqlArr = append(sqlArr, otherSql...)
}
return sqlArr
}
func (sg *SQLGenerator) GenIndexDDL(table dbi.Table, indexs []dbi.Index) []string {
sqls := make([]string, 0)
comments := make([]string, 0)
quote := sg.Dialect.Quoter().Quote
for _, index := range indexs {
unique := ""
if index.IsUnique {
unique = "unique"
}
// 取出列名,添加引号
cols := strings.Split(index.ColumnName, ",")
colNames := make([]string, len(cols))
for i, name := range cols {
colNames[i] = quote(name)
}
sqls = append(sqls, fmt.Sprintf("CREATE %s INDEX %s ON %s(%s)", unique, quote(index.IndexName), quote(table.TableName), strings.Join(colNames, ",")))
}
sqlArr := make([]string, 0)
sqlArr = append(sqlArr, sqls...)
if len(comments) > 0 {
sqlArr = append(sqlArr, comments...)
}
return sqlArr
}
func (sg *SQLGenerator) GenInsert(tableName string, columns []dbi.Column, values [][]any, duplicateStrategy int) []string {
quoter := sg.Dialect.Quoter()
quote := quoter.Quote
if duplicateStrategy == dbi.DuplicateStrategyNone {
identityInsert := fmt.Sprintf("set identity_insert %s on;", quote(tableName))
// 达梦数据库只能一条条的执行insert语句所以这里需要将values拆分成多条insert语句
return collx.ArrayMap(values, func(value []any) string {
columnStr, valuesStrs := dbi.GenInsertSqlColumnAndValues(sg.Dialect, DbTypeOracle, columns, [][]any{value})
return fmt.Sprintf("%s insert into %s (%s) values %s", identityInsert, quote(tableName), columnStr, strings.Join(valuesStrs, ",\n"))
})
}
// 查询主键字段
uniqueCols := make([]string, 0)
caseSqls := make([]string, 0)
metadata := sg.Metadata
tableCols, _ := metadata.GetColumns(tableName)
identityCols := make([]string, 0)
for _, col := range tableCols {
if col.IsPrimaryKey {
uniqueCols = append(uniqueCols, col.ColumnName)
caseSqls = append(caseSqls, fmt.Sprintf("( T1.%s = T2.%s )", quote(col.ColumnName), quote(col.ColumnName)))
}
if col.IsIdentity {
// 自增字段不放入insert内即使是设置了identity_insert on也不起作用
identityCols = append(identityCols, quote(col.ColumnName))
}
}
// 查询唯一索引涉及到的字段并组装到match条件内
indexs, _ := metadata.GetTableIndex(tableName)
for _, index := range indexs {
if index.IsUnique {
cols := strings.Split(index.ColumnName, ",")
tmp := make([]string, 0)
for _, col := range cols {
uniqueCols = append(uniqueCols, col)
tmp = append(tmp, fmt.Sprintf(" T1.%s = T2.%s ", quote(col), quote(col)))
}
caseSqls = append(caseSqls, fmt.Sprintf("( %s )", strings.Join(tmp, " AND ")))
}
}
// 重复数据处理策略
phs := make([]string, 0)
insertVals := make([]string, 0)
upds := make([]string, 0)
insertCols := make([]string, 0)
for _, column := range columns {
columnName := column.ColumnName
phs = append(phs, fmt.Sprintf("? %s", columnName))
if !collx.ArrayContains(uniqueCols, quoter.Trim(columnName)) {
upds = append(upds, fmt.Sprintf("T1.%s = T2.%s", columnName, columnName))
}
if !collx.ArrayContains(identityCols, columnName) {
insertCols = append(insertCols, columnName)
insertVals = append(insertVals, fmt.Sprintf("T2.%s", columnName))
}
}
t2s := make([]string, 0)
for i := 0; i < len(values); i++ {
t2s = append(t2s, fmt.Sprintf("SELECT %s FROM dual", strings.Join(phs, ",")))
}
t2 := strings.Join(t2s, " UNION ALL ")
sqlTemp := "MERGE INTO " + quote(tableName) + " T1 USING (" + t2 + ") T2 ON " + strings.Join(caseSqls, " OR ")
sqlTemp += "WHEN NOT MATCHED THEN INSERT (" + strings.Join(insertCols, ",") + ") VALUES (" + strings.Join(insertVals, ",") + ")"
sqlTemp += "WHEN MATCHED THEN UPDATE SET " + strings.Join(upds, ",")
return collx.AsArray(sqlTemp)
}
func (msg *SQLGenerator) genColumnBasicSql(quoter dbi.Quoter, column dbi.Column) string {
colName := quoter.Quote(column.ColumnName)
if column.IsIdentity {
// 如果是自增不需要设置默认值和空值自增列数据类型必须是number
return fmt.Sprintf(" %s NUMBER generated by default as IDENTITY", colName)
}
nullAble := ""
if !column.Nullable {
nullAble = " NOT NULL"
}
defVal := ""
if column.ColumnDefault != "" {
defVal = fmt.Sprintf(" DEFAULT %v", column.ColumnDefault)
}
columnSql := fmt.Sprintf(" %s %s%s%s", colName, column.GetColumnType(), defVal, nullAble)
return columnSql
}
// 11g及以下版本会设置自增序列
func (sg *SQLGenerator) GenerateTableOtherDDL(tableInfo dbi.Table, quoteTableName string, columns []dbi.Column) []string {
return nil
}
// 11g及以下版本会设置自增序列和触发器
func (sg *SQLGenerator) Oracle11GenerateTableOtherDDL(tableInfo dbi.Table, quoteTableName string, columns []dbi.Column) []string {
result := make([]string, 0)
for _, col := range columns {
if col.IsIdentity {
seqName := fmt.Sprintf("%s_%s_seq", tableInfo.TableName, col.ColumnName)
trgName := fmt.Sprintf("%s_%s_trg", tableInfo.TableName, col.ColumnName)
result = append(result, fmt.Sprintf("CREATE SEQUENCE %s START WITH 1 INCREMENT BY 1", seqName))
result = append(result, fmt.Sprintf("CREATE OR REPLACE TRIGGER %s BEFORE INSERT ON %s FOR EACH ROW WHEN (NEW.%s IS NULL) BEGIN SELECT %s.nextval INTO :new.%s FROM dual; END", trgName, quoteTableName, col.ColumnName, seqName, col.ColumnName))
}
}
return result
}

View File

@@ -0,0 +1,97 @@
package oracle
import "mayfly-go/internal/db/dbm/dbi"
var _ dbi.CommonTypeConverter = (*commonTypeConverter)(nil)
type commonTypeConverter struct {
}
func (c *commonTypeConverter) Varchar(col *dbi.Column) *dbi.DbDataType {
return VARCHAR2
}
func (c *commonTypeConverter) Char(col *dbi.Column) *dbi.DbDataType {
return CHAR
}
func (c *commonTypeConverter) Text(col *dbi.Column) *dbi.DbDataType {
return NVARCHAR2
}
func (c *commonTypeConverter) Mediumtext(col *dbi.Column) *dbi.DbDataType {
return NVARCHAR2
}
func (c *commonTypeConverter) Longtext(col *dbi.Column) *dbi.DbDataType {
return NVARCHAR2
}
func (c *commonTypeConverter) Bit(col *dbi.Column) *dbi.DbDataType {
return BIT
}
func (c *commonTypeConverter) Int1(col *dbi.Column) *dbi.DbDataType {
return TINYINT
}
func (c *commonTypeConverter) Int2(col *dbi.Column) *dbi.DbDataType {
return SMALLINT
}
func (c *commonTypeConverter) Int4(col *dbi.Column) *dbi.DbDataType {
return INTEGER
}
func (c *commonTypeConverter) Int8(col *dbi.Column) *dbi.DbDataType {
return BIGINT
}
func (c *commonTypeConverter) Numeric(col *dbi.Column) *dbi.DbDataType {
return NUMBER
}
func (c *commonTypeConverter) Decimal(col *dbi.Column) *dbi.DbDataType {
return DECIMAL
}
func (c *commonTypeConverter) UnsignedInt8(col *dbi.Column) *dbi.DbDataType {
return BIGINT
}
func (c *commonTypeConverter) UnsignedInt4(col *dbi.Column) *dbi.DbDataType {
return INT
}
func (c *commonTypeConverter) UnsignedInt2(col *dbi.Column) *dbi.DbDataType {
return INT
}
func (c *commonTypeConverter) UnsignedInt1(col *dbi.Column) *dbi.DbDataType {
return INT
}
func (c *commonTypeConverter) Date(col *dbi.Column) *dbi.DbDataType {
return DATE
}
func (c *commonTypeConverter) Time(col *dbi.Column) *dbi.DbDataType {
return TIME
}
func (c *commonTypeConverter) Datetime(col *dbi.Column) *dbi.DbDataType {
return TIMESTAMP
}
func (c *commonTypeConverter) Timestamp(col *dbi.Column) *dbi.DbDataType {
return TIMESTAMP
}
func (c *commonTypeConverter) Binary(col *dbi.Column) *dbi.DbDataType {
return BLOB
}
func (c *commonTypeConverter) Varbinary(col *dbi.Column) *dbi.DbDataType {
return BLOB
}
func (c *commonTypeConverter) Mediumblob(col *dbi.Column) *dbi.DbDataType {
return BLOB
}
func (c *commonTypeConverter) Blob(col *dbi.Column) *dbi.DbDataType {
return BLOB
}
func (c *commonTypeConverter) Longblob(col *dbi.Column) *dbi.DbDataType {
return BLOB
}
func (c *commonTypeConverter) Enum(col *dbi.Column) *dbi.DbDataType {
return NVARCHAR2
}
func (c *commonTypeConverter) JSON(col *dbi.Column) *dbi.DbDataType {
return NVARCHAR2
}

View File

@@ -0,0 +1,31 @@
package postgres
import (
"mayfly-go/internal/db/dbm/dbi"
)
var (
Bool = dbi.NewDbDataType("bool", dbi.DTBit).WithCT(dbi.CTBit).WithFixColumn(dbi.ClearNumScale)
Int2 = dbi.NewDbDataType("int2", dbi.DTInt16).WithCT(dbi.CTInt2).WithFixColumn(dbi.ClearNumScale)
Int4 = dbi.NewDbDataType("int4", dbi.DTInt32).WithCT(dbi.CTInt4).WithFixColumn(dbi.ClearNumScale)
Int8 = dbi.NewDbDataType("int8", dbi.DTInt64).WithCT(dbi.CTInt8).WithFixColumn(dbi.ClearNumScale)
Numeric = dbi.NewDbDataType("numeric", dbi.DTNumeric).WithCT(dbi.CTNumeric)
Decimal = dbi.NewDbDataType("decimal", dbi.DTDecimal).WithCT(dbi.CTDecimal)
Smallserial = dbi.NewDbDataType("smallserial", dbi.DTInt16).WithCT(dbi.CTInt2)
Serial = dbi.NewDbDataType("serial", dbi.DTInt32).WithCT(dbi.CTInt4)
Bigserial = dbi.NewDbDataType("bigserial", dbi.DTInt64).WithCT(dbi.CTInt8)
Largeserial = dbi.NewDbDataType("largeserial", dbi.DTInt64).WithCT(dbi.CTInt8)
Money = dbi.NewDbDataType("money", dbi.DTString).WithCT(dbi.CTVarchar)
Char = dbi.NewDbDataType("char", dbi.DTString).WithCT(dbi.CTChar)
Nchar = dbi.NewDbDataType("nchar", dbi.DTString).WithCT(dbi.CTVarchar)
Varchar = dbi.NewDbDataType("varchar", dbi.DTString).WithCT(dbi.CTVarchar)
Text = dbi.NewDbDataType("text", dbi.DTString).WithCT(dbi.CTText).WithFixColumn(dbi.ClearCharMaxLength)
Json = dbi.NewDbDataType("json", dbi.DTString).WithCT(dbi.CTJSON).WithFixColumn(dbi.ClearCharMaxLength)
Bytea = dbi.NewDbDataType("bytea", dbi.DTString).WithCT(dbi.CTBinary)
Date = dbi.NewDbDataType("date", dbi.DTDate).WithCT(dbi.CTDate).WithFixColumn(dbi.ClearCharMaxLength)
Time = dbi.NewDbDataType("time", dbi.DTTime).WithCT(dbi.CTTime).WithFixColumn(dbi.ClearCharMaxLength)
Timestamp = dbi.NewDbDataType("timestamp", dbi.DTDateTime).WithCT(dbi.CTDateTime).WithFixColumn(dbi.ClearCharMaxLength)
)

View File

@@ -1,12 +1,8 @@
package postgres
import (
"database/sql"
"fmt"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/utils/anyx"
"mayfly-go/pkg/utils/collx"
"strings"
"time"
"github.com/may-fly/cast"
@@ -18,107 +14,6 @@ type PgsqlDialect struct {
dc *dbi.DbConn
}
func (pd *PgsqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any, duplicateStrategy int) (int64, error) {
// 执行批量insert sql跟mysql一样 pg或高斯支持批量insert语法
// insert into table_name (column1, column2, ...) values (value1, value2, ...), (value1, value2, ...), ...
// 把二维数组转为一维数组
var args []any
for _, v := range values {
args = append(args, v...)
}
// 构建占位符字符串 "($1, $2, $3), ($4, $5, $6), ..." 用于指定参数
var placeholders []string
for i := 0; i < len(args); i += len(columns) {
var placeholder []string
for j := 0; j < len(columns); j++ {
placeholder = append(placeholder, fmt.Sprintf("$%d", i+j+1))
}
placeholders = append(placeholders, "("+strings.Join(placeholder, ", ")+")")
}
// 根据冲突策略生成后缀
suffix := ""
if pd.dc.Info.Type == dbi.DbTypeGauss {
// 高斯db使用ON DUPLICATE KEY UPDATE 语法参考 https://support.huaweicloud.com/distributed-devg-v3-gaussdb/gaussdb-12-0607.html#ZH-CN_TOPIC_0000001633948138
suffix = pd.gaussOnDuplicateStrategySql(duplicateStrategy, tableName, columns)
} else {
// pgsql 默认使用 on conflict 语法参考 http://www.postgres.cn/docs/12/sql-insert.html
// vastbase语法参考 https://docs.vastdata.com.cn/zh/docs/VastbaseE100Ver3.0.0/doc/SQL%E8%AF%AD%E6%B3%95/INSERT.html
// kingbase语法参考 https://help.kingbase.com.cn/v8/development/sql-plsql/sql/SQL_Statements_9.html#insert
suffix = pd.pgsqlOnDuplicateStrategySql(duplicateStrategy, tableName, columns)
}
sqlStr := fmt.Sprintf("insert into %s (%s) values %s %s", pd.QuoteIdentifier(tableName), strings.Join(columns, ","), strings.Join(placeholders, ", "), suffix)
// 执行批量insert sql
return pd.dc.TxExec(tx, sqlStr, args...)
}
// pgsql默认唯一键冲突策略
func (pd *PgsqlDialect) pgsqlOnDuplicateStrategySql(duplicateStrategy int, tableName string, columns []string) string {
suffix := ""
if duplicateStrategy == dbi.DuplicateStrategyIgnore {
suffix = " \n on conflict do nothing"
} else if duplicateStrategy == dbi.DuplicateStrategyUpdate {
// 生成 on conflict () do update set column1 = excluded.column1, column2 = excluded.column2, ...
var updateColumns []string
for _, col := range columns {
updateColumns = append(updateColumns, fmt.Sprintf("%s = excluded.%s", col, col))
}
// 查询唯一键名,拼接冲突sql
_, keyRes, _ := pd.dc.Query("SELECT constraint_name FROM information_schema.table_constraints WHERE constraint_schema = $1 AND table_name = $2 AND constraint_type in ('PRIMARY KEY', 'UNIQUE') ", pd.dc.Info.CurrentSchema(), tableName)
if len(keyRes) > 0 {
for _, re := range keyRes {
key := anyx.ToString(re["constraint_name"])
if key != "" {
suffix += fmt.Sprintf(" \n on conflict on constraint %s do update set %s \n", key, strings.Join(updateColumns, ", "))
}
}
}
}
return suffix
}
// 高斯db唯一键冲突策略,使用ON DUPLICATE KEY UPDATE 参考https://support.huaweicloud.com/distributed-devg-v3-gaussdb/gaussdb-12-0607.html#ZH-CN_TOPIC_0000001633948138
func (pd *PgsqlDialect) gaussOnDuplicateStrategySql(duplicateStrategy int, tableName string, columns []string) string {
suffix := ""
metadata := pd.dc.GetMetadata()
if duplicateStrategy == dbi.DuplicateStrategyIgnore {
suffix = " \n ON DUPLICATE KEY UPDATE NOTHING"
} else if duplicateStrategy == dbi.DuplicateStrategyUpdate {
// 查出表里的唯一键涉及的字段
var uniqueColumns []string
indexs, err := metadata.GetTableIndex(tableName)
if err == nil {
for _, index := range indexs {
if index.IsUnique {
cols := strings.Split(index.ColumnName, ",")
for _, col := range cols {
if !collx.ArrayContains(uniqueColumns, strings.ToLower(col)) {
uniqueColumns = append(uniqueColumns, strings.ToLower(col))
}
}
}
}
}
suffix = " \n ON DUPLICATE KEY UPDATE "
for i, col := range columns {
// ON DUPLICATE KEY UPDATE语句不支持更新唯一键字段所以得去掉
if !collx.ArrayContains(uniqueColumns, pd.RemoveQuote(strings.ToLower(col))) {
suffix += fmt.Sprintf("%s = excluded.%s", col, col)
if i < len(columns)-1 {
suffix += ", "
}
}
}
}
return suffix
}
func (pd *PgsqlDialect) CopyTable(copy *dbi.DbCopyTable) error {
tableName := copy.TableName
// 生成新表名,为老表明+_copy_时间戳
@@ -176,182 +71,13 @@ func (pd *PgsqlDialect) CopyTable(copy *dbi.DbCopyTable) error {
return err
}
func (pd *PgsqlDialect) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table, dropBeforeCreate bool) []string {
quoteTableName := pd.QuoteIdentifier(tableInfo.TableName)
sqlArr := make([]string, 0)
if dropBeforeCreate {
sqlArr = append(sqlArr, fmt.Sprintf("DROP TABLE IF EXISTS %s", quoteTableName))
}
// 组装建表语句
createSql := fmt.Sprintf("CREATE TABLE %s (\n", quoteTableName)
fields := make([]string, 0)
pks := make([]string, 0)
columnComments := make([]string, 0)
commentTmp := "comment on column %s.%s is '%s'"
for _, column := range columns {
if column.IsPrimaryKey {
pks = append(pks, pd.QuoteIdentifier(column.ColumnName))
}
fields = append(fields, pd.genColumnBasicSql(column))
// 防止注释内含有特殊字符串导致sql出错
if column.ColumnComment != "" {
comment := pd.QuoteEscape(column.ColumnComment)
columnComments = append(columnComments, fmt.Sprintf(commentTmp, quoteTableName, pd.QuoteIdentifier(column.ColumnName), comment))
}
}
createSql += strings.Join(fields, ",\n")
if len(pks) > 0 {
createSql += fmt.Sprintf(", \nPRIMARY KEY (%s)", strings.Join(pks, ","))
}
createSql += "\n)"
tableCommentSql := ""
if tableInfo.TableComment != "" {
commentTmp := "comment on table %s is '%s'"
tableCommentSql = fmt.Sprintf(commentTmp, quoteTableName, pd.QuoteEscape(tableInfo.TableComment))
}
// create
sqlArr = append(sqlArr, createSql)
// table comment
if tableCommentSql != "" {
sqlArr = append(sqlArr, tableCommentSql)
}
// column comment
if len(columnComments) > 0 {
sqlArr = append(sqlArr, columnComments...)
}
return sqlArr
}
func (pd *PgsqlDialect) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Table) []string {
creates := make([]string, 0)
drops := make([]string, 0)
comments := make([]string, 0)
for _, index := range indexs {
unique := ""
if index.IsUnique {
unique = "unique"
}
// 如果索引名存在,先删除索引
drops = append(drops, fmt.Sprintf("drop index if exists %s.%s", pd.dc.Info.CurrentSchema(), index.IndexName))
// 取出列名,添加引号
cols := strings.Split(index.ColumnName, ",")
colNames := make([]string, len(cols))
for i, name := range cols {
colNames[i] = pd.QuoteIdentifier(name)
}
// 创建索引
creates = append(creates, fmt.Sprintf("CREATE %s INDEX %s on %s.%s(%s)", unique, pd.QuoteIdentifier(index.IndexName), pd.QuoteIdentifier(pd.dc.Info.CurrentSchema()), pd.QuoteIdentifier(tableInfo.TableName), strings.Join(colNames, ",")))
if index.IndexComment != "" {
comment := pd.QuoteEscape(index.IndexComment)
comments = append(comments, fmt.Sprintf("COMMENT ON INDEX %s.%s IS '%s'", pd.dc.Info.CurrentSchema(), index.IndexName, comment))
}
}
sqlArr := make([]string, 0)
if len(drops) > 0 {
sqlArr = append(sqlArr, drops...)
}
if len(creates) > 0 {
sqlArr = append(sqlArr, creates...)
}
if len(comments) > 0 {
sqlArr = append(sqlArr, comments...)
}
return sqlArr
}
func (pd *PgsqlDialect) UpdateSequence(tableName string, columns []dbi.Column) {
for _, column := range columns {
if column.IsIdentity {
_, _ = pd.dc.Exec(fmt.Sprintf("select setval('%s_%s_seq', (SELECT max(%s) from %s))", tableName, column.ColumnName, column.ColumnName, tableName))
}
}
}
func (pd *PgsqlDialect) GetDataHelper() dbi.DataHelper {
return dataHelper
}
func (pd *PgsqlDialect) GetColumnHelper() dbi.ColumnHelper {
return columnHelper
}
func (pd *PgsqlDialect) GetDumpHelper() dbi.DumpHelper {
return new(DumpHelper)
}
func (pd *PgsqlDialect) genColumnBasicSql(column dbi.Column) string {
colName := pd.QuoteIdentifier(column.ColumnName)
dataType := string(column.DataType)
// 如果数据类型是数字,则去掉长度
if collx.ArrayAnyMatches([]string{"int"}, strings.ToLower(dataType)) {
column.NumPrecision = 0
column.CharMaxLength = 0
func (md *PgsqlDialect) GetSQLGenerator() dbi.SQLGenerator {
return &SQLGenerator{
dialect: md,
dc: md.dc,
}
// 如果是自增类型需要转换为serial
if column.IsIdentity {
if dataType == "int4" {
column.DataType = "serial"
} else if dataType == "int2" {
column.DataType = "smallserial"
} else if dataType == "int8" {
column.DataType = "bigserial"
} else {
column.DataType = "bigserial"
}
return fmt.Sprintf(" %s %s NOT NULL", colName, column.GetColumnType())
}
nullAble := ""
if !column.Nullable {
nullAble = " NOT NULL"
}
defVal := "" // 默认值需要判断引号,如函数是不需要引号的 // 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值
if column.ColumnDefault != "" && !strings.Contains(column.ColumnDefault, "(") {
mark := false
// 哪些字段类型默认值需要加引号
if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, dataType) {
// 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号
if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) &&
collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(column.ColumnDefault)) {
mark = false
} else {
mark = true
}
}
// 如果数据类型是日期时间,则写死默认值函数
if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) {
column.ColumnDefault = "CURRENT_TIMESTAMP"
}
if mark {
defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault)
} else {
defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault)
}
}
// 如果是varchar长度翻倍防止报错
if collx.ArrayAnyMatches([]string{"char"}, strings.ToLower(dataType)) {
column.CharMaxLength = column.CharMaxLength * 2
}
columnSql := fmt.Sprintf(" %s %s%s%s", colName, column.GetColumnType(), nullAble, defVal)
return columnSql
}

View File

@@ -4,212 +4,16 @@ import (
"fmt"
"io"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/utils/anyx"
"mayfly-go/pkg/utils/collx"
"regexp"
"strings"
"time"
"github.com/may-fly/cast"
)
var (
// 数字类型
numberRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`)
// 日期时间类型
datetimeRegexp = regexp.MustCompile(`(?i)datetime|timestamp`)
// 日期类型
dateRegexp = regexp.MustCompile(`(?i)date`)
// 时间类型
timeRegexp = regexp.MustCompile(`(?i)time`)
// 提取pg默认值 如:'id'::varchar 提取id '-1'::integer 提取-1
defaultValueRegexp = regexp.MustCompile(`'([^']*)'`)
// pgsql数据类型 映射 公共数据类型
commonColumnTypeMap = map[string]dbi.ColumnDataType{
"int2": dbi.CommonTypeSmallint,
"int4": dbi.CommonTypeInt,
"int8": dbi.CommonTypeBigint,
"numeric": dbi.CommonTypeNumber,
"decimal": dbi.CommonTypeNumber,
"smallserial": dbi.CommonTypeSmallint,
"serial": dbi.CommonTypeInt,
"bigserial": dbi.CommonTypeBigint,
"largeserial": dbi.CommonTypeBigint,
"money": dbi.CommonTypeNumber,
"bool": dbi.CommonTypeTinyint,
"char": dbi.CommonTypeChar,
"character": dbi.CommonTypeChar,
"nchar": dbi.CommonTypeChar,
"varchar": dbi.CommonTypeVarchar,
"text": dbi.CommonTypeText,
"bytea": dbi.CommonTypeText,
"date": dbi.CommonTypeDate,
"time": dbi.CommonTypeTime,
"timestamp": dbi.CommonTypeTimestamp,
}
// 公共数据类型 映射 pgsql数据类型
pgsqlColumnTypeMap = map[dbi.ColumnDataType]string{
dbi.CommonTypeVarchar: "varchar",
dbi.CommonTypeChar: "char",
dbi.CommonTypeText: "text",
dbi.CommonTypeBlob: "text",
dbi.CommonTypeLongblob: "text",
dbi.CommonTypeLongtext: "text",
dbi.CommonTypeBinary: "text",
dbi.CommonTypeMediumblob: "text",
dbi.CommonTypeMediumtext: "text",
dbi.CommonTypeVarbinary: "text",
dbi.CommonTypeInt: "int4",
dbi.CommonTypeSmallint: "int2",
dbi.CommonTypeTinyint: "int2",
dbi.CommonTypeNumber: "numeric",
dbi.CommonTypeBigint: "int8",
dbi.CommonTypeDatetime: "timestamp",
dbi.CommonTypeDate: "date",
dbi.CommonTypeTime: "time",
dbi.CommonTypeTimestamp: "timestamp",
dbi.CommonTypeEnum: "varchar(2000)",
dbi.CommonTypeJSON: "varchar(2000)",
}
dataHelper = &DataHelper{}
columnHelper = &ColumnHelper{}
)
func GetDataHelper() *DataHelper {
return dataHelper
}
type DataHelper struct {
}
func (dc *DataHelper) GetDataType(dbColumnType string) dbi.DataType {
if numberRegexp.MatchString(dbColumnType) {
return dbi.DataTypeNumber
}
// 日期时间类型
if datetimeRegexp.MatchString(dbColumnType) {
return dbi.DataTypeDateTime
}
// 日期类型
if dateRegexp.MatchString(dbColumnType) {
return dbi.DataTypeDate
}
// 时间类型
if timeRegexp.MatchString(dbColumnType) {
return dbi.DataTypeTime
}
return dbi.DataTypeString
}
func (dc *DataHelper) FormatData(dbColumnValue any, dataType dbi.DataType) string {
str := fmt.Sprintf("%v", dbColumnValue)
switch dataType {
case dbi.DataTypeDateTime: // "2024-01-02T22:16:28.545377+08:00"
// 尝试用时间格式解析
res, err := time.Parse(time.DateTime, str)
if err == nil {
return str
}
res, err = time.Parse(time.RFC3339, str)
return res.Format(time.DateTime)
case dbi.DataTypeDate: // "2024-01-02T00:00:00Z"
// 尝试用时间格式解析
res, err := time.Parse(time.DateOnly, str)
if err == nil {
return str
}
res, _ = time.Parse(time.RFC3339, str)
return res.Format(time.DateOnly)
case dbi.DataTypeTime: // "0000-01-01T22:16:28.545075+08:00"
// 尝试用时间格式解析
res, err := time.Parse(time.TimeOnly, str)
if err == nil {
return str
}
res, _ = time.Parse(time.RFC3339, str)
return res.Format(time.TimeOnly)
}
return cast.ToString(dbColumnValue)
}
func (dc *DataHelper) ParseData(dbColumnValue any, dataType dbi.DataType) any {
// 如果dataType是datetime而dbColumnValue是string类型则需要转换为time.Time类型
_, ok := dbColumnValue.(string)
if dataType == dbi.DataTypeDateTime && ok {
res, _ := time.Parse(time.RFC3339, anyx.ToString(dbColumnValue))
return res
}
if dataType == dbi.DataTypeDate && ok {
res, _ := time.Parse(time.DateOnly, anyx.ToString(dbColumnValue))
return res
}
if dataType == dbi.DataTypeTime && ok {
res, _ := time.Parse(time.TimeOnly, anyx.ToString(dbColumnValue))
return res
}
return dbColumnValue
}
func (dc *DataHelper) WrapValue(dbColumnValue any, dataType dbi.DataType) string {
if dbColumnValue == nil {
return "NULL"
}
switch dataType {
case dbi.DataTypeNumber:
return fmt.Sprintf("%v", dbColumnValue)
case dbi.DataTypeString:
val := fmt.Sprintf("%v", dbColumnValue)
// 转义单引号
val = strings.Replace(val, `'`, `''`, -1)
// 转义换行符
val = strings.Replace(val, "\n", "\\n", -1)
return fmt.Sprintf("'%s'", val)
case dbi.DataTypeDate, dbi.DataTypeDateTime, dbi.DataTypeTime:
return fmt.Sprintf("'%s'", dc.FormatData(dbColumnValue, dataType))
}
return fmt.Sprintf("'%s'", dbColumnValue)
}
type ColumnHelper struct {
dbi.DefaultColumnHelper
}
func (ch *ColumnHelper) ToCommonColumn(column *dbi.Column) {
// 翻译为通用数据库类型
dataType := column.DataType
t1 := commonColumnTypeMap[string(dataType)]
if t1 == "" {
column.DataType = dbi.CommonTypeVarchar
column.CharMaxLength = 2000
} else {
column.DataType = t1
}
}
func (ch *ColumnHelper) ToColumn(commonColumn *dbi.Column) {
ctype := pgsqlColumnTypeMap[commonColumn.DataType]
if ctype == "" {
commonColumn.DataType = "varchar"
commonColumn.CharMaxLength = 2000
} else {
commonColumn.DataType = dbi.ColumnDataType(ctype)
}
}
func (ch *ColumnHelper) FixColumn(column *dbi.Column) {
dataType := strings.ToLower(string(column.DataType))
// 哪些字段可以指定长度
if !collx.ArrayAnyMatches([]string{"char", "time", "bit", "num", "decimal"}, dataType) {
column.CharMaxLength = 0
column.NumPrecision = 0
} else if strings.Contains(dataType, "char") {
// 如果类型是文本,长度翻倍
column.CharMaxLength = column.CharMaxLength * 2
}
func FixColumnDefault(column *dbi.Column) {
// 如果默认值带冒号,如:'id'::varchar
if column.ColumnDefault != "" && strings.Contains(column.ColumnDefault, "::") && !strings.HasPrefix(column.ColumnDefault, "nextval") {
match := defaultValueRegexp.FindStringSubmatch(column.ColumnDefault)

View File

@@ -16,16 +16,23 @@ import (
func init() {
meta := new(Meta)
dbi.Register(dbi.DbTypePostgres, meta)
dbi.Register(dbi.DbTypeKingbaseEs, meta)
dbi.Register(dbi.DbTypeVastbase, meta)
dbi.Register(DbTypePostgres, meta)
dbi.Register(DbTypeKingbaseEs, meta)
dbi.Register(DbTypeVastbase, meta)
gauss := &Meta{
Param: "dbtype=gauss",
}
dbi.Register(dbi.DbTypeGauss, gauss)
dbi.Register(DbTypeGauss, gauss)
}
const (
DbTypePostgres dbi.DbType = "postgres"
DbTypeGauss dbi.DbType = "gauss"
DbTypeKingbaseEs dbi.DbType = "kingbaseEs"
DbTypeVastbase dbi.DbType = "vastbase"
)
type Meta struct {
Param string
}
@@ -89,6 +96,20 @@ func (pm *Meta) GetMetadata(conn *dbi.DbConn) dbi.Metadata {
return &PgsqlMetadata{dc: conn}
}
func (pm *Meta) GetDbDataTypes() []*dbi.DbDataType {
return collx.AsArray(
Bool, Int2, Int4, Int8, Numeric, Decimal, Smallserial, Serial, Bigserial, Largeserial,
Money,
Char, Nchar, Varchar, Text, Json,
Date, Time, Timestamp,
Bytea,
)
}
func (pm *Meta) GetCommonTypeConverter() dbi.CommonTypeConverter {
return &commonTypeConverter{}
}
// pgsql dialer
type PqSqlDialer struct {
sshTunnelMachineId int

View File

@@ -4,7 +4,6 @@ import (
"fmt"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/errorx"
"mayfly-go/pkg/logx"
"mayfly-go/pkg/utils/collx"
"mayfly-go/pkg/utils/stringx"
"strings"
@@ -54,7 +53,7 @@ func (pd *PgsqlMetadata) GetDbNames() ([]string, error) {
func (pd *PgsqlMetadata) GetTables(tableNames ...string) ([]dbi.Table, error) {
dialect := pd.dc.GetDialect()
names := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string {
return fmt.Sprintf("'%s'", dialect.RemoveQuote(val))
return fmt.Sprintf("'%s'", dialect.Quoter().Trim(val))
}), ",")
var res []map[string]any
@@ -88,7 +87,7 @@ func (pd *PgsqlMetadata) GetTables(tableNames ...string) ([]dbi.Table, error) {
func (pd *PgsqlMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error) {
dialect := pd.dc.GetDialect()
tableName := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string {
return fmt.Sprintf("'%s'", dialect.RemoveQuote(val))
return fmt.Sprintf("'%s'", dialect.Quoter().Trim(val))
}), ",")
_, res, err := pd.dc.Query(fmt.Sprintf(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_COLUMN_MA_KEY), tableName))
@@ -96,13 +95,12 @@ func (pd *PgsqlMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error)
return nil, err
}
columnHelper := dialect.GetColumnHelper()
columns := make([]dbi.Column, 0)
for _, re := range res {
column := dbi.Column{
TableName: cast.ToString(re["tableName"]),
ColumnName: cast.ToString(re["columnName"]),
DataType: dbi.ColumnDataType(cast.ToString(re["dataType"])),
DataType: cast.ToString(re["dataType"]),
CharMaxLength: cast.ToInt(re["charMaxLength"]),
ColumnComment: cast.ToString(re["columnComment"]),
Nullable: cast.ToString(re["nullable"]) == "YES",
@@ -112,7 +110,9 @@ func (pd *PgsqlMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error)
NumPrecision: cast.ToInt(re["numPrecision"]),
NumScale: cast.ToInt(re["numScale"]),
}
columnHelper.FixColumn(&column)
pd.dc.GetDbDataType(column.DataType).FixColumn(&column)
FixColumnDefault(&column)
columns = append(columns, column)
}
return columns, nil
@@ -164,7 +164,9 @@ func (pd *PgsqlMetadata) GetTableIndex(tableName string) ([]dbi.Index, error) {
// 索引字段已根据名称和顺序排序,故取最后一个即可
i := len(result) - 1
// 同索引字段以逗号连接
result[i].ColumnName = result[i].ColumnName + "," + v.ColumnName
if result[i].ColumnName != v.ColumnName {
result[i].ColumnName = result[i].ColumnName + "," + v.ColumnName
}
} else {
key = in
result = append(result, v)
@@ -175,33 +177,7 @@ func (pd *PgsqlMetadata) GetTableIndex(tableName string) ([]dbi.Index, error) {
// 获取建表ddl
func (pd *PgsqlMetadata) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) {
// 1.获取表信息
tbs, err := pd.GetTables(tableName)
tableInfo := &dbi.Table{}
if err != nil || tbs == nil || len(tbs) <= 0 {
logx.Errorf("获取表信息失败, %s", tableName)
return "", err
}
tableInfo.TableName = tbs[0].TableName
tableInfo.TableComment = tbs[0].TableComment
// 2.获取列信息
columns, err := pd.GetColumns(tableName)
if err != nil {
logx.Errorf("获取列信息失败, %s", tableName)
return "", err
}
dialect := pd.dc.GetDialect()
tableDDLArr := dialect.GenerateTableDDL(columns, *tableInfo, dropBeforeCreate)
// 3.获取索引信息
indexs, err := pd.GetTableIndex(tableName)
if err != nil {
logx.Errorf("获取索引信息失败, %s", tableName)
return "", err
}
// 组装返回
tableDDLArr = append(tableDDLArr, dialect.GenerateIndexDDL(indexs, *tableInfo)...)
return strings.Join(tableDDLArr, ";\n"), nil
return dbi.GenTableDDL(pd.dc.GetDialect(), pd, tableName, dropBeforeCreate)
}
// 获取pgsql当前连接的库可访问的schemaNames
@@ -220,11 +196,11 @@ func (pd *PgsqlMetadata) GetSchemas() ([]string, error) {
func (pd *PgsqlMetadata) GetDefaultDb() string {
switch pd.dc.Info.Type {
case dbi.DbTypePostgres, dbi.DbTypeGauss:
case DbTypePostgres, DbTypeGauss:
return "postgres"
case dbi.DbTypeKingbaseEs:
case DbTypeKingbaseEs:
return "security"
case dbi.DbTypeVastbase:
case DbTypeVastbase:
return "vastbase"
default:
return ""

View File

@@ -0,0 +1,263 @@
package postgres
import (
"fmt"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/utils/anyx"
"mayfly-go/pkg/utils/collx"
"strings"
)
type SQLGenerator struct {
dialect dbi.Dialect
dc *dbi.DbConn
}
func (msg *SQLGenerator) GenTableDDL(table dbi.Table, columns []dbi.Column, dropBeforeCreate bool) []string {
quoter := msg.dialect.Quoter()
quote := quoter.Quote
quoteTableName := quote(table.TableName)
sqlArr := make([]string, 0)
if dropBeforeCreate {
sqlArr = append(sqlArr, fmt.Sprintf("DROP TABLE IF EXISTS %s", quoteTableName))
}
// 组装建表语句
createSql := fmt.Sprintf("CREATE TABLE %s (\n", quoteTableName)
fields := make([]string, 0)
pks := make([]string, 0)
columnComments := make([]string, 0)
commentTmp := "COMMENT ON COLUMN %s.%s IS '%s'"
for _, column := range columns {
if column.IsPrimaryKey {
pks = append(pks, quote(column.ColumnName))
}
fields = append(fields, msg.genColumnBasicSql(quoter, column))
// 防止注释内含有特殊字符串导致sql出错
if column.ColumnComment != "" {
comment := dbi.QuoteEscape(column.ColumnComment)
columnComments = append(columnComments, fmt.Sprintf(commentTmp, quoteTableName, quote(column.ColumnName), comment))
}
}
createSql += strings.Join(fields, ",\n")
if len(pks) > 0 {
createSql += fmt.Sprintf(", \nPRIMARY KEY (%s)", strings.Join(pks, ","))
}
createSql += "\n)"
tableCommentSql := ""
if table.TableComment != "" {
commentTmp := "COMMENT ON TABLE %s IS '%s'"
tableCommentSql = fmt.Sprintf(commentTmp, quoteTableName, dbi.QuoteEscape(table.TableComment))
}
// create
sqlArr = append(sqlArr, createSql)
// table comment
if tableCommentSql != "" {
sqlArr = append(sqlArr, tableCommentSql)
}
// column comment
if len(columnComments) > 0 {
sqlArr = append(sqlArr, columnComments...)
}
return sqlArr
}
func (msg *SQLGenerator) GenIndexDDL(table dbi.Table, indexs []dbi.Index) []string {
quoter := msg.dialect.Quoter()
quote := quoter.Quote
creates := make([]string, 0)
drops := make([]string, 0)
comments := make([]string, 0)
for _, index := range indexs {
unique := ""
if index.IsUnique {
unique = " unique"
}
currentSchema := msg.dc.Info.CurrentSchema()
// 带上后缀. 避免后续判断
if currentSchema != "" {
currentSchema = quote(currentSchema) + "."
}
// 如果索引名存在,先删除索引
drops = append(drops, fmt.Sprintf("DROP INDEX IF EXISTS %s%s", currentSchema, index.IndexName))
// 取出列名,添加引号
cols := strings.Split(index.ColumnName, ",")
colNames := make([]string, len(cols))
for i, name := range cols {
colNames[i] = quote(name)
}
// 创建索引
creates = append(creates, fmt.Sprintf("CREATE%s INDEX %s ON %s%s(%s)", unique, quote(index.IndexName), currentSchema, quote(table.TableName), strings.Join(colNames, ",")))
if index.IndexComment != "" {
comment := dbi.QuoteEscape(index.IndexComment)
comments = append(comments, fmt.Sprintf("COMMENT ON INDEX %s%s IS '%s'", currentSchema, index.IndexName, comment))
}
}
sqlArr := make([]string, 0)
if len(drops) > 0 {
sqlArr = append(sqlArr, drops...)
}
if len(creates) > 0 {
sqlArr = append(sqlArr, creates...)
}
if len(comments) > 0 {
sqlArr = append(sqlArr, comments...)
}
return sqlArr
}
func (psg *SQLGenerator) GenInsert(tableName string, columns []dbi.Column, values [][]any, duplicateStrategy int) []string {
insertSql := dbi.GenCommonInsert(psg.dialect, psg.dc.Info.Type, tableName, columns, values)
// 根据冲突策略生成后缀
suffix := ""
if psg.dc.Info.Type == DbTypeGauss {
// 高斯db使用ON DUPLICATE KEY UPDATE 语法参考 https://support.huaweicloud.com/distributed-devg-v3-gaussdb/gaussdb-12-0607.html#ZH-CN_TOPIC_0000001633948138
suffix = psg.gaussOnDuplicateStrategySql(duplicateStrategy, tableName, columns)
} else {
// pgsql 默认使用 on conflict 语法参考 http://www.postgres.cn/docs/12/sql-insert.html
// vastbase语法参考 https://docs.vastdata.com.cn/zh/docs/VastbaseE100Ver3.0.0/doc/SQL%E8%AF%AD%E6%B3%95/INSERT.html
// kingbase语法参考 https://help.kingbase.com.cn/v8/development/sql-plsql/sql/SQL_Statements_9.html#insert
suffix = psg.pgsqlOnDuplicateStrategySql(duplicateStrategy, tableName, columns)
}
return collx.AsArray[string](insertSql + suffix)
}
// pgsql默认唯一键冲突策略
func (psg *SQLGenerator) pgsqlOnDuplicateStrategySql(duplicateStrategy int, tableName string, columns []dbi.Column) string {
suffix := ""
if duplicateStrategy == dbi.DuplicateStrategyIgnore {
suffix = " \n on conflict do nothing"
} else if duplicateStrategy == dbi.DuplicateStrategyUpdate {
// 生成 on conflict () do update set column1 = excluded.column1, column2 = excluded.column2, ...
var updateColumns []string
for _, col := range columns {
updateColumns = append(updateColumns, fmt.Sprintf("%s = excluded.%s", col.ColumnName, col.ColumnName))
}
// 查询唯一键名,拼接冲突sql
_, keyRes, _ := psg.dc.Query("SELECT constraint_name FROM information_schema.table_constraints WHERE constraint_schema = $1 AND table_name = $2 AND constraint_type in ('PRIMARY KEY', 'UNIQUE') ", psg.dc.Info.CurrentSchema(), tableName)
if len(keyRes) > 0 {
for _, re := range keyRes {
key := anyx.ToString(re["constraint_name"])
if key != "" {
suffix += fmt.Sprintf(" \n on conflict on constraint %s do update set %s \n", key, strings.Join(updateColumns, ", "))
}
}
}
}
return suffix
}
// 高斯db唯一键冲突策略,使用ON DUPLICATE KEY UPDATE 参考https://support.huaweicloud.com/distributed-devg-v3-gaussdb/gaussdb-12-0607.html#ZH-CN_TOPIC_0000001633948138
func (psg *SQLGenerator) gaussOnDuplicateStrategySql(duplicateStrategy int, tableName string, columns []dbi.Column) string {
suffix := ""
metadata := psg.dc.GetMetadata()
if duplicateStrategy == dbi.DuplicateStrategyIgnore {
suffix = " \n ON DUPLICATE KEY UPDATE NOTHING"
} else if duplicateStrategy == dbi.DuplicateStrategyUpdate {
// 查出表里的唯一键涉及的字段
var uniqueColumns []string
indexs, err := metadata.GetTableIndex(tableName)
if err == nil {
for _, index := range indexs {
if index.IsUnique {
cols := strings.Split(index.ColumnName, ",")
for _, col := range cols {
if !collx.ArrayContains(uniqueColumns, strings.ToLower(col)) {
uniqueColumns = append(uniqueColumns, strings.ToLower(col))
}
}
}
}
}
suffix = " \n ON DUPLICATE KEY UPDATE "
for i, col := range columns {
// ON DUPLICATE KEY UPDATE语句不支持更新唯一键字段所以得去掉
if !collx.ArrayContains(uniqueColumns, psg.dialect.Quoter().Trim(strings.ToLower(col.ColumnName))) {
suffix += fmt.Sprintf("%s = excluded.%s", col.ColumnName, col.ColumnName)
if i < len(columns)-1 {
suffix += ", "
}
}
}
}
return suffix
}
func (pd *SQLGenerator) genColumnBasicSql(quoter dbi.Quoter, column dbi.Column) string {
colName := quoter.Quote(column.ColumnName)
dataType := string(column.DataType)
// 如果数据类型是数字,则去掉长度
if collx.ArrayAnyMatches([]string{"int"}, strings.ToLower(dataType)) {
column.NumPrecision = 0
column.CharMaxLength = 0
}
// 如果是自增类型需要转换为serial
if column.IsIdentity {
if dataType == "int4" {
column.DataType = "serial"
} else if dataType == "int2" {
column.DataType = "smallserial"
} else if dataType == "int8" {
column.DataType = "bigserial"
} else {
column.DataType = "bigserial"
}
return fmt.Sprintf(" %s %s NOT NULL", colName, column.GetColumnType())
}
nullAble := ""
if !column.Nullable {
nullAble = " NOT NULL"
}
defVal := "" // 默认值需要判断引号,如函数是不需要引号的 // 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值
if column.ColumnDefault != "" && !strings.Contains(column.ColumnDefault, "(") {
mark := false
// 哪些字段类型默认值需要加引号
if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, dataType) {
// 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号
if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) &&
collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(column.ColumnDefault)) {
mark = false
} else {
mark = true
}
}
// 如果数据类型是日期时间,则写死默认值函数
if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) {
column.ColumnDefault = "CURRENT_TIMESTAMP"
}
if mark {
defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault)
} else {
defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault)
}
}
columnSql := fmt.Sprintf(" %s %s%s%s", colName, column.GetColumnType(), nullAble, defVal)
return columnSql
}

View File

@@ -0,0 +1,97 @@
package postgres
import "mayfly-go/internal/db/dbm/dbi"
var _ dbi.CommonTypeConverter = (*commonTypeConverter)(nil)
type commonTypeConverter struct {
}
func (c *commonTypeConverter) Varchar(col *dbi.Column) *dbi.DbDataType {
return Varchar
}
func (c *commonTypeConverter) Char(col *dbi.Column) *dbi.DbDataType {
return Char
}
func (c *commonTypeConverter) Text(col *dbi.Column) *dbi.DbDataType {
return Text
}
func (c *commonTypeConverter) Mediumtext(col *dbi.Column) *dbi.DbDataType {
return Text
}
func (c *commonTypeConverter) Longtext(col *dbi.Column) *dbi.DbDataType {
return Text
}
func (c *commonTypeConverter) Bit(col *dbi.Column) *dbi.DbDataType {
return Int2
}
func (c *commonTypeConverter) Int1(col *dbi.Column) *dbi.DbDataType {
return Int2
}
func (c *commonTypeConverter) Int2(col *dbi.Column) *dbi.DbDataType {
return Int2
}
func (c *commonTypeConverter) Int4(col *dbi.Column) *dbi.DbDataType {
return Int4
}
func (c *commonTypeConverter) Int8(col *dbi.Column) *dbi.DbDataType {
return Int8
}
func (c *commonTypeConverter) Numeric(col *dbi.Column) *dbi.DbDataType {
return Numeric
}
func (c *commonTypeConverter) Decimal(col *dbi.Column) *dbi.DbDataType {
return Decimal
}
func (c *commonTypeConverter) UnsignedInt8(col *dbi.Column) *dbi.DbDataType {
return Int8
}
func (c *commonTypeConverter) UnsignedInt4(col *dbi.Column) *dbi.DbDataType {
return Int4
}
func (c *commonTypeConverter) UnsignedInt2(col *dbi.Column) *dbi.DbDataType {
return Int2
}
func (c *commonTypeConverter) UnsignedInt1(col *dbi.Column) *dbi.DbDataType {
return Int2
}
func (c *commonTypeConverter) Date(col *dbi.Column) *dbi.DbDataType {
return Date
}
func (c *commonTypeConverter) Time(col *dbi.Column) *dbi.DbDataType {
return Time
}
func (c *commonTypeConverter) Datetime(col *dbi.Column) *dbi.DbDataType {
return Timestamp
}
func (c *commonTypeConverter) Timestamp(col *dbi.Column) *dbi.DbDataType {
return Timestamp
}
func (c *commonTypeConverter) Binary(col *dbi.Column) *dbi.DbDataType {
return Bytea
}
func (c *commonTypeConverter) Varbinary(col *dbi.Column) *dbi.DbDataType {
return Bytea
}
func (c *commonTypeConverter) Mediumblob(col *dbi.Column) *dbi.DbDataType {
return Bytea
}
func (c *commonTypeConverter) Blob(col *dbi.Column) *dbi.DbDataType {
return Bytea
}
func (c *commonTypeConverter) Longblob(col *dbi.Column) *dbi.DbDataType {
return Bytea
}
func (c *commonTypeConverter) Enum(col *dbi.Column) *dbi.DbDataType {
return Varchar
}
func (c *commonTypeConverter) JSON(col *dbi.Column) *dbi.DbDataType {
return Json
}

View File

@@ -0,0 +1,14 @@
package sqlite
import "mayfly-go/internal/db/dbm/dbi"
var (
Integer = dbi.NewDbDataType("integer", dbi.DTInt64).WithCT(dbi.CTInt8)
Real = dbi.NewDbDataType("real", dbi.DTNumeric).WithCT(dbi.CTNumeric)
Text = dbi.NewDbDataType("text", dbi.DTString).WithCT(dbi.CTText)
Blob = dbi.NewDbDataType("blob", dbi.DTBytes).WithCT(dbi.CTBlob)
DateTime = dbi.NewDbDataType("datetime", dbi.DTDateTime).WithCT(dbi.CTDateTime)
Date = dbi.NewDbDataType("date", dbi.DTDate).WithCT(dbi.CTDate)
Time = dbi.NewDbDataType("time", dbi.DTTime).WithCT(dbi.CTTime)
)

View File

@@ -1,10 +1,8 @@
package sqlite
import (
"database/sql"
"fmt"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/utils/collx"
"strings"
"time"
)
@@ -15,42 +13,6 @@ type SqliteDialect struct {
dc *dbi.DbConn
}
func (sd *SqliteDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any, duplicateStrategy int) (int64, error) {
_, _ = sd.dc.Exec("PRAGMA foreign_keys = false")
// 执行批量insert sql跟mysql一样 支持批量insert语法
// 生成占位符字符串:如:(?,?)
// 重复字符串并用逗号连接
repeated := strings.Repeat("?,", len(columns))
// 去除最后一个逗号,占位符由括号包裹
placeholder := fmt.Sprintf("(%s)", strings.TrimSuffix(repeated, ","))
// 重复占位符字符串n遍
repeated = strings.Repeat(placeholder+",", len(values))
// 去除最后一个逗号
placeholder = strings.TrimSuffix(repeated, ",")
prefix := "insert into"
if duplicateStrategy == 1 {
prefix = "insert or ignore into"
} else if duplicateStrategy == 2 {
prefix = "insert or replace into"
}
sqlStr := fmt.Sprintf("%s %s (%s) values %s", prefix, sd.QuoteIdentifier(tableName), strings.Join(columns, ","), placeholder)
// 把二维数组转为一维数组
var args []any
for _, v := range values {
args = append(args, v...)
}
exec, err := sd.dc.TxExec(tx, sqlStr, args...)
_, _ = sd.dc.Exec("PRAGMA foreign_keys = true;")
// 执行批量insert sql
return exec, err
}
func (sd *SqliteDialect) CopyTable(copy *dbi.DbCopyTable) error {
tableName := copy.TableName
@@ -83,101 +45,12 @@ func (sd *SqliteDialect) CopyTable(copy *dbi.DbCopyTable) error {
return err
}
// 获取建表ddl
func (sd *SqliteDialect) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table, dropBeforeCreate bool) []string {
sqlArr := make([]string, 0)
tbName := sd.QuoteIdentifier(tableInfo.TableName)
if dropBeforeCreate {
sqlArr = append(sqlArr, fmt.Sprintf("DROP TABLE IF EXISTS %s", tbName))
}
// 组装建表语句
createSql := fmt.Sprintf("CREATE TABLE %s (\n", tbName)
fields := make([]string, 0)
// 把通用类型转换为达梦类型
for _, column := range columns {
fields = append(fields, sd.genColumnBasicSql(column))
}
createSql += strings.Join(fields, ",\n")
createSql += "\n)"
sqlArr = append(sqlArr, createSql)
return sqlArr
}
func (sd *SqliteDialect) genColumnBasicSql(column dbi.Column) string {
incr := ""
if column.IsIdentity {
incr = " AUTOINCREMENT"
}
nullAble := ""
if !column.Nullable {
nullAble = " NOT NULL"
}
quoteColumnName := sd.QuoteIdentifier(column.ColumnName)
// 如果是主键,则直接返回,不判断默认值
if column.IsPrimaryKey {
return fmt.Sprintf(" %s integer PRIMARY KEY%s%s", quoteColumnName, incr, nullAble)
}
defVal := "" // 默认值需要判断引号,如函数是不需要引号的 // 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值
if column.ColumnDefault != "" && !strings.Contains(column.ColumnDefault, "(") {
// 哪些字段类型默认值需要加引号
mark := false
if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, strings.ToLower(string(column.DataType))) {
// 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号
if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(string(column.DataType))) &&
collx.ArrayAnyMatches([]string{"DATE", "TIME"}, strings.ToUpper(column.ColumnDefault)) {
mark = false
} else {
mark = true
}
}
if mark {
defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault)
} else {
defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault)
}
}
return fmt.Sprintf(" %s %s%s%s", quoteColumnName, column.GetColumnType(), nullAble, defVal)
}
// 获取建索引ddl
func (sd *SqliteDialect) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Table) []string {
sqls := make([]string, 0)
for _, index := range indexs {
unique := ""
if index.IsUnique {
unique = "unique"
}
// 取出列名,添加引号
cols := strings.Split(index.ColumnName, ",")
colNames := make([]string, len(cols))
for i, name := range cols {
colNames[i] = sd.QuoteIdentifier(name)
}
// 创建前尝试删除
sqls = append(sqls, fmt.Sprintf("DROP INDEX IF EXISTS \"%s\"", index.IndexName))
sqlTmp := "CREATE %s INDEX %s ON %s (%s) "
sqls = append(sqls, fmt.Sprintf(sqlTmp, unique, sd.QuoteIdentifier(index.IndexName), sd.QuoteIdentifier(tableInfo.TableName), strings.Join(colNames, ",")))
}
return sqls
}
func (sd *SqliteDialect) GetDataHelper() dbi.DataHelper {
return dataHelper
}
func (sd *SqliteDialect) GetColumnHelper() dbi.ColumnHelper {
return columnHelper
}
func (sd *SqliteDialect) GetDumpHelper() dbi.DumpHelper {
return new(DumpHelper)
}
func (sd *SqliteDialect) GetSQLGenerator() dbi.SQLGenerator {
return &SQLGenerator{
dialect: sd,
}
}

View File

@@ -1,180 +1,10 @@
package sqlite
import (
"fmt"
"io"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/utils/anyx"
"regexp"
"strings"
"time"
)
var (
// 数字类型
numberRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit|real`)
// 日期时间类型
datetimeRegexp = regexp.MustCompile(`(?i)datetime`)
dataTypeRegexp = regexp.MustCompile(`(\w+)\((\d*),?(\d*)\)`)
// sqlite数据类型 映射 公共数据类型
commonColumnTypeMap = map[string]dbi.ColumnDataType{
"int": dbi.CommonTypeInt,
"integer": dbi.CommonTypeInt,
"tinyint": dbi.CommonTypeTinyint,
"smallint": dbi.CommonTypeSmallint,
"mediumint": dbi.CommonTypeSmallint,
"bigint": dbi.CommonTypeBigint,
"int2": dbi.CommonTypeInt,
"int8": dbi.CommonTypeInt,
"character": dbi.CommonTypeChar,
"varchar": dbi.CommonTypeVarchar,
"varying character": dbi.CommonTypeVarchar,
"nchar": dbi.CommonTypeChar,
"native character": dbi.CommonTypeVarchar,
"nvarchar": dbi.CommonTypeVarchar,
"text": dbi.CommonTypeText,
"clob": dbi.CommonTypeBlob,
"blob": dbi.CommonTypeBlob,
"real": dbi.CommonTypeNumber,
"double": dbi.CommonTypeNumber,
"double precision": dbi.CommonTypeNumber,
"float": dbi.CommonTypeNumber,
"numeric": dbi.CommonTypeNumber,
"decimal": dbi.CommonTypeNumber,
"boolean": dbi.CommonTypeTinyint,
"date": dbi.CommonTypeDate,
"datetime": dbi.CommonTypeDatetime,
}
// 公共数据类型 映射 sqlite数据类型
sqliteColumnTypeMap = map[dbi.ColumnDataType]string{
dbi.CommonTypeVarchar: "nvarchar",
dbi.CommonTypeChar: "nchar",
dbi.CommonTypeText: "text",
dbi.CommonTypeBlob: "blob",
dbi.CommonTypeLongblob: "blob",
dbi.CommonTypeLongtext: "text",
dbi.CommonTypeBinary: "text",
dbi.CommonTypeMediumblob: "blob",
dbi.CommonTypeMediumtext: "text",
dbi.CommonTypeVarbinary: "text",
dbi.CommonTypeInt: "int",
dbi.CommonTypeSmallint: "smallint",
dbi.CommonTypeTinyint: "tinyint",
dbi.CommonTypeNumber: "number",
dbi.CommonTypeBigint: "bigint",
dbi.CommonTypeDatetime: "datetime",
dbi.CommonTypeDate: "date",
dbi.CommonTypeTime: "datetime",
dbi.CommonTypeTimestamp: "datetime",
dbi.CommonTypeEnum: "nvarchar(2000)",
dbi.CommonTypeJSON: "nvarchar(2000)",
}
dataHelper = &DataHelper{}
columnHelper = &ColumnHelper{}
)
func GetDataHelper() *DataHelper {
return dataHelper
}
type DataHelper struct {
}
func (dc *DataHelper) GetDataType(dbColumnType string) dbi.DataType {
if numberRegexp.MatchString(dbColumnType) {
return dbi.DataTypeNumber
}
if datetimeRegexp.MatchString(dbColumnType) {
return dbi.DataTypeDateTime
}
return dbi.DataTypeString
}
func (dc *DataHelper) FormatData(dbColumnValue any, dataType dbi.DataType) string {
str := anyx.ToString(dbColumnValue)
switch dataType {
case dbi.DataTypeDateTime: // "2024-01-02T22:08:22.275697+08:00"
// 尝试用时间格式解析
_, err := time.Parse(time.DateTime, str)
if err == nil {
return str
}
res, _ := time.Parse(time.RFC3339, str)
return res.Format(time.DateTime)
case dbi.DataTypeDate: // "2024-01-02T00:00:00+08:00"
// 尝试用时间格式解析
_, err := time.Parse(time.DateOnly, str)
if err == nil {
return str
}
res, _ := time.Parse(time.RFC3339, str)
return res.Format(time.DateOnly)
case dbi.DataTypeTime: // "0000-01-01T22:08:22.275688+08:00"
// 尝试用时间格式解析
_, err := time.Parse(time.TimeOnly, str)
if err == nil {
return str
}
res, _ := time.Parse(time.RFC3339, str)
return res.Format(time.TimeOnly)
}
return str
}
func (dc *DataHelper) ParseData(dbColumnValue any, dataType dbi.DataType) any {
return dbColumnValue
}
func (dc *DataHelper) WrapValue(dbColumnValue any, dataType dbi.DataType) string {
if dbColumnValue == nil {
return "NULL"
}
switch dataType {
case dbi.DataTypeNumber:
return fmt.Sprintf("%v", dbColumnValue)
case dbi.DataTypeString:
val := fmt.Sprintf("%v", dbColumnValue)
// 转义单引号
val = strings.Replace(val, `'`, `''`, -1)
val = strings.Replace(val, `\''`, `\'`, -1)
// 转义换行符
val = strings.Replace(val, "\n", "\\n", -1)
return fmt.Sprintf("'%s'", val)
case dbi.DataTypeDate, dbi.DataTypeDateTime, dbi.DataTypeTime:
return fmt.Sprintf("'%s'", dc.FormatData(dbColumnValue, dataType))
}
return fmt.Sprintf("'%s'", dbColumnValue)
}
type ColumnHelper struct {
dbi.DefaultColumnHelper
}
func (ch *ColumnHelper) ToCommonColumn(dialectColumn *dbi.Column) {
// 翻译为通用数据库类型
dataType := dialectColumn.DataType
t1 := commonColumnTypeMap[string(dataType)]
if t1 == "" {
dialectColumn.DataType = dbi.CommonTypeVarchar
dialectColumn.CharMaxLength = 2000
} else {
dialectColumn.DataType = t1
}
}
func (ch *ColumnHelper) ToColumn(commonColumn *dbi.Column) {
ctype := sqliteColumnTypeMap[commonColumn.DataType]
if ctype == "" {
commonColumn.DataType = "nvarchar"
commonColumn.CharMaxLength = 2000
} else {
ch.FixColumn(commonColumn)
}
}
type DumpHelper struct {
dbi.DefaultDumpHelper
}

View File

@@ -4,13 +4,18 @@ import (
"database/sql"
"errors"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/utils/collx"
"os"
)
func init() {
dbi.Register(dbi.DbTypeSqlite, new(Meta))
dbi.Register(DbTypeSqlite, new(Meta))
}
const (
DbTypeSqlite dbi.DbType = "sqlite"
)
type Meta struct {
}
@@ -36,3 +41,16 @@ func (sm *Meta) GetDialect(conn *dbi.DbConn) dbi.Dialect {
func (sm *Meta) GetMetadata(conn *dbi.DbConn) dbi.Metadata {
return &SqliteMetadata{dc: conn}
}
func (sm *Meta) GetDbDataTypes() []*dbi.DbDataType {
return collx.AsArray(
Integer, Real,
Text,
Blob,
DateTime, Date, Time,
)
}
func (sm *Meta) GetCommonTypeConverter() dbi.CommonTypeConverter {
return &commonTypeConverter{}
}

View File

@@ -19,6 +19,10 @@ const (
SQLITE_INDEX_INFO_KEY = "SQLITE_INDEX_INFO"
)
var (
dataTypeRegexp = regexp.MustCompile(`(\w+)\((\d*),?(\d*)\)`)
)
type SqliteMetadata struct {
dbi.DefaultMetadata
@@ -53,7 +57,7 @@ func (sd *SqliteMetadata) GetDbNames() ([]string, error) {
func (sd *SqliteMetadata) GetTables(tableNames ...string) ([]dbi.Table, error) {
dialect := sd.dc.GetDialect()
names := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string {
return fmt.Sprintf("'%s'", dialect.RemoveQuote(val))
return fmt.Sprintf("'%s'", dialect.Quoter().Trim(val))
}), ",")
var res []map[string]any
@@ -98,7 +102,6 @@ func (sd *SqliteMetadata) getDataTypes(dataType string) (string, string, string)
// 获取列元信息, 如列名等
func (sd *SqliteMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error) {
columns := make([]dbi.Column, 0)
columnHelper := sd.dc.GetDialect().GetColumnHelper()
for i := 0; i < len(tableNames); i++ {
tableName := tableNames[i]
_, res, err := sd.dc.Query(fmt.Sprintf("PRAGMA table_info(%s)", tableName))
@@ -134,9 +137,9 @@ func (sd *SqliteMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error)
} else {
column.CharMaxLength = cast.ToInt(length)
}
column.DataType = dbi.ColumnDataType(strings.ToLower(dataType))
columnHelper.FixColumn(&column)
column.DataType = strings.ToLower(dataType)
sd.dc.GetDbDataType(column.DataType).FixColumn(&column)
columns = append(columns, column)
}
}

View File

@@ -0,0 +1,123 @@
package sqlite
import (
"fmt"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/utils/collx"
"strings"
)
type SQLGenerator struct {
dialect dbi.Dialect
}
func (ssg *SQLGenerator) GenTableDDL(table dbi.Table, columns []dbi.Column, dropBeforeCreate bool) []string {
quoter := ssg.dialect.Quoter()
sqlArr := make([]string, 0)
tbName := ssg.dialect.Quoter().Quote(table.TableName)
if dropBeforeCreate {
sqlArr = append(sqlArr, fmt.Sprintf("DROP TABLE IF EXISTS %s", tbName))
}
// 组装建表语句
createSql := fmt.Sprintf("CREATE TABLE %s (\n", tbName)
fields := make([]string, 0)
// 把通用类型转换为达梦类型
for _, column := range columns {
fields = append(fields, ssg.genColumnBasicSql(quoter, column))
}
createSql += strings.Join(fields, ",\n")
createSql += "\n)"
sqlArr = append(sqlArr, createSql)
return sqlArr
}
func (ssg *SQLGenerator) GenIndexDDL(table dbi.Table, indexs []dbi.Index) []string {
quoter := ssg.dialect.Quoter()
quote := quoter.Quote
sqls := make([]string, 0)
for _, index := range indexs {
unique := ""
if index.IsUnique {
unique = "unique"
}
// 取出列名,添加引号
cols := strings.Split(index.ColumnName, ",")
colNames := make([]string, len(cols))
for i, name := range cols {
colNames[i] = quote(name)
}
// 创建前尝试删除
sqls = append(sqls, fmt.Sprintf("DROP INDEX IF EXISTS \"%s\"", index.IndexName))
sqlTmp := "CREATE %s INDEX %s ON %s (%s) "
sqls = append(sqls, fmt.Sprintf(sqlTmp, unique, quote(index.IndexName), quote(table.TableName), strings.Join(colNames, ",")))
}
return sqls
}
func (ssg *SQLGenerator) GenInsert(tableName string, columns []dbi.Column, values [][]any, duplicateStrategy int) []string {
if duplicateStrategy == dbi.DuplicateStrategyNone {
return collx.AsArray(dbi.GenCommonInsert(ssg.dialect, DbTypeSqlite, tableName, columns, values))
}
sqls := make([]string, 0)
sqls = append(sqls, "PRAGMA foreign_keys = false")
prefix := "insert or ignore into"
if duplicateStrategy == dbi.DuplicateStrategyUpdate {
prefix = "insert or replace into"
}
columnStr, valuesStrs := dbi.GenInsertSqlColumnAndValues(ssg.dialect, DbTypeSqlite, columns, values)
sqls = append(sqls, "PRAGMA foreign_keys = true")
sqls = append(sqls, fmt.Sprintf("%s %s %s VALUES \n%s", prefix, ssg.dialect.Quoter().Quote(tableName), columnStr, strings.Join(valuesStrs, ",\n")))
return sqls
}
func (ssg *SQLGenerator) genColumnBasicSql(quoter dbi.Quoter, column dbi.Column) string {
incr := ""
if column.IsIdentity {
incr = " AUTOINCREMENT"
}
nullAble := ""
if !column.Nullable {
nullAble = " NOT NULL"
}
quoteColumnName := quoter.Quote(column.ColumnName)
// 如果是主键,则直接返回,不判断默认值
if column.IsPrimaryKey {
return fmt.Sprintf(" %s integer PRIMARY KEY%s%s", quoteColumnName, incr, nullAble)
}
defVal := "" // 默认值需要判断引号,如函数是不需要引号的 // 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值
if column.ColumnDefault != "" && !strings.Contains(column.ColumnDefault, "(") {
// 哪些字段类型默认值需要加引号
mark := false
if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, strings.ToLower(string(column.DataType))) {
// 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号
if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(string(column.DataType))) &&
collx.ArrayAnyMatches([]string{"DATE", "TIME"}, strings.ToUpper(column.ColumnDefault)) {
mark = false
} else {
mark = true
}
}
if mark {
defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault)
} else {
defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault)
}
}
return fmt.Sprintf(" %s %s%s%s", quoteColumnName, column.GetColumnType(), nullAble, defVal)
}

View File

@@ -0,0 +1,97 @@
package sqlite
import "mayfly-go/internal/db/dbm/dbi"
var _ dbi.CommonTypeConverter = (*commonTypeConverter)(nil)
type commonTypeConverter struct {
}
func (c *commonTypeConverter) Varchar(col *dbi.Column) *dbi.DbDataType {
return Text
}
func (c *commonTypeConverter) Char(col *dbi.Column) *dbi.DbDataType {
return Text
}
func (c *commonTypeConverter) Text(col *dbi.Column) *dbi.DbDataType {
return Text
}
func (c *commonTypeConverter) Mediumtext(col *dbi.Column) *dbi.DbDataType {
return Text
}
func (c *commonTypeConverter) Longtext(col *dbi.Column) *dbi.DbDataType {
return Text
}
func (c *commonTypeConverter) Bit(col *dbi.Column) *dbi.DbDataType {
return Integer
}
func (c *commonTypeConverter) Int1(col *dbi.Column) *dbi.DbDataType {
return Integer
}
func (c *commonTypeConverter) Int2(col *dbi.Column) *dbi.DbDataType {
return Integer
}
func (c *commonTypeConverter) Int4(col *dbi.Column) *dbi.DbDataType {
return Integer
}
func (c *commonTypeConverter) Int8(col *dbi.Column) *dbi.DbDataType {
return Integer
}
func (c *commonTypeConverter) Numeric(col *dbi.Column) *dbi.DbDataType {
return Real
}
func (c *commonTypeConverter) Decimal(col *dbi.Column) *dbi.DbDataType {
return Real
}
func (c *commonTypeConverter) UnsignedInt8(col *dbi.Column) *dbi.DbDataType {
return Integer
}
func (c *commonTypeConverter) UnsignedInt4(col *dbi.Column) *dbi.DbDataType {
return Integer
}
func (c *commonTypeConverter) UnsignedInt2(col *dbi.Column) *dbi.DbDataType {
return Integer
}
func (c *commonTypeConverter) UnsignedInt1(col *dbi.Column) *dbi.DbDataType {
return Integer
}
func (c *commonTypeConverter) Date(col *dbi.Column) *dbi.DbDataType {
return Date
}
func (c *commonTypeConverter) Time(col *dbi.Column) *dbi.DbDataType {
return Time
}
func (c *commonTypeConverter) Datetime(col *dbi.Column) *dbi.DbDataType {
return DateTime
}
func (c *commonTypeConverter) Timestamp(col *dbi.Column) *dbi.DbDataType {
return DateTime
}
func (c *commonTypeConverter) Binary(col *dbi.Column) *dbi.DbDataType {
return Blob
}
func (c *commonTypeConverter) Varbinary(col *dbi.Column) *dbi.DbDataType {
return Blob
}
func (c *commonTypeConverter) Mediumblob(col *dbi.Column) *dbi.DbDataType {
return Blob
}
func (c *commonTypeConverter) Blob(col *dbi.Column) *dbi.DbDataType {
return Blob
}
func (c *commonTypeConverter) Longblob(col *dbi.Column) *dbi.DbDataType {
return Blob
}
func (c *commonTypeConverter) Enum(col *dbi.Column) *dbi.DbDataType {
return Text
}
func (c *commonTypeConverter) JSON(col *dbi.Column) *dbi.DbDataType {
return Text
}

View File

@@ -110,7 +110,7 @@ func (d *DbJobBaseImpl) setLastStatus(jobType DbJobType, status DbJobStatus, err
if err != nil {
result = fmt.Sprintf("%s: %v", result, err)
}
d.LastResult = stringx.TruncateStr(result, LastResultSize)
d.LastResult = stringx.Truncate(result, LastResultSize, LastResultSize, "")
d.LastTime = timex.NewNullTime(time.Now())
}

View File

@@ -16,7 +16,7 @@ var En = map[i18n.MsgId]string{
SqlScriptRunFail: "sql script failed to execute",
SqlScriptRunSuccess: "sql script executed successfully",
SqlScripRunProgress: "sql script execution progress",
SqlScripRunProgress: "sql execution progress",
DbDumpErr: "Database export failed",
ErrDbNameExist: "The database name already exists in this instance",
ErrDbNotAccess: "The operation permissions of database [{{.dbName}}] are not configured",

View File

@@ -16,7 +16,7 @@ var Zh_CN = map[i18n.MsgId]string{
SqlScriptRunFail: "sql脚本执行失败",
SqlScriptRunSuccess: "sql脚本执行成功",
SqlScripRunProgress: "sql脚本执行进度",
SqlScripRunProgress: "sql执行进度",
DbDumpErr: "数据库导出失败",
ErrDbNameExist: "该实例下数据库名已存在",
ErrDbNotAccess: "未配置数据库【{{.dbName}}】的操作权限",

View File

@@ -12,7 +12,7 @@ type dbRepoImpl struct {
}
func newDbRepo() repository.Db {
return &dbRepoImpl{base.RepoImpl[*entity.Db]{M: new(entity.Db)}}
return &dbRepoImpl{}
}
// 分页获取数据库信息列表

Some files were not shown because too many files have changed in this diff Show More