mirror of
https://gitee.com/dromara/mayfly-go
synced 2025-11-02 15:30:25 +08:00
refactor: 数据同步编辑页优化等
This commit is contained in:
@@ -33,7 +33,7 @@
|
||||
"splitpanes": "^3.1.5",
|
||||
"sql-formatter": "^14.0.0",
|
||||
"uuid": "^9.0.1",
|
||||
"vue": "^3.4.7",
|
||||
"vue": "^3.4.8",
|
||||
"vue-router": "^4.2.5",
|
||||
"xterm": "^5.3.0",
|
||||
"xterm-addon-fit": "^0.8.0",
|
||||
@@ -47,8 +47,8 @@
|
||||
"@types/sortablejs": "^1.15.3",
|
||||
"@typescript-eslint/eslint-plugin": "^6.7.4",
|
||||
"@typescript-eslint/parser": "^6.7.4",
|
||||
"@vitejs/plugin-vue": "^5.0.2",
|
||||
"@vue/compiler-sfc": "^3.4.7",
|
||||
"@vitejs/plugin-vue": "^5.0.3",
|
||||
"@vue/compiler-sfc": "^3.4.8",
|
||||
"dotenv": "^16.3.1",
|
||||
"eslint": "^8.35.0",
|
||||
"eslint-plugin-vue": "^9.19.2",
|
||||
|
||||
@@ -255,6 +255,7 @@ const NodeTypeDbInst = new NodeType(SqlExecNodeType.DbInst).withLoadNodesFunc((p
|
||||
|
||||
// 数据库节点
|
||||
const NodeTypeDb = new NodeType(SqlExecNodeType.Db)
|
||||
.withContextMenuItems([new ContextmenuItem('reloadTables', '刷新').withIcon('RefreshRight').withOnClick((data: any) => reloadNode(data.key))])
|
||||
.withLoadNodesFunc(async (parentNode: TagTreeNode) => {
|
||||
const params = parentNode.params;
|
||||
// pg类数据库会多一层schema
|
||||
@@ -280,6 +281,7 @@ const NodeTypeDb = new NodeType(SqlExecNodeType.Db)
|
||||
|
||||
// postgres schema模式
|
||||
const NodeTypePostgresScheam = new NodeType(SqlExecNodeType.PgSchema)
|
||||
.withContextMenuItems([new ContextmenuItem('reloadTables', '刷新').withIcon('RefreshRight').withOnClick((data: any) => reloadNode(data.key))])
|
||||
.withLoadNodesFunc(async (parentNode: TagTreeNode) => {
|
||||
const params = parentNode.params;
|
||||
return [
|
||||
@@ -292,7 +294,7 @@ const NodeTypePostgresScheam = new NodeType(SqlExecNodeType.PgSchema)
|
||||
// 数据库表菜单节点
|
||||
const NodeTypeTableMenu = new NodeType(SqlExecNodeType.TableMenu)
|
||||
.withContextMenuItems([
|
||||
new ContextmenuItem('reloadTables', '刷新').withIcon('RefreshRight').withOnClick((data: any) => reloadTables(data.key)),
|
||||
new ContextmenuItem('reloadTables', '刷新').withIcon('RefreshRight').withOnClick((data: any) => reloadNode(data.key)),
|
||||
|
||||
new ContextmenuItem('tablesOp', '表操作').withIcon('Setting').withOnClick((data: any) => {
|
||||
const params = data.params;
|
||||
@@ -596,7 +598,7 @@ const getSqlMenuNodeKey = (dbId: number, db: string) => {
|
||||
return `${dbId}.${db}.sql-menu`;
|
||||
};
|
||||
|
||||
const reloadTables = (nodeKey: string) => {
|
||||
const reloadNode = (nodeKey: string) => {
|
||||
state.reloadStatus = true;
|
||||
tagTreeRef.value.reloadNode(nodeKey);
|
||||
};
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
>
|
||||
<el-form :model="form" ref="dbForm" :rules="rules" label-width="auto">
|
||||
<el-tabs v-model="tabActiveName" style="height: 450px">
|
||||
<el-tab-pane label="基本信息" name="basic">
|
||||
<el-tab-pane label="基本信息" :name="basicTab">
|
||||
<el-form-item>
|
||||
<el-row>
|
||||
<el-col :span="11">
|
||||
@@ -108,7 +108,7 @@
|
||||
</el-form-item>
|
||||
</el-tab-pane>
|
||||
|
||||
<el-tab-pane label="字段映射" name="field">
|
||||
<el-tab-pane label="字段映射" :name="fieldTab" :disabled="!baseFieldCompleted">
|
||||
<el-form-item prop="fieldMap" label="字段映射" required>
|
||||
<el-table :data="form.fieldMap" :max-height="400" size="small">
|
||||
<el-table-column prop="src" label="源字段" :width="200" />
|
||||
@@ -128,7 +128,7 @@
|
||||
</el-form-item>
|
||||
</el-tab-pane>
|
||||
|
||||
<el-tab-pane label="sql预览" name="sqlPreview">
|
||||
<el-tab-pane label="sql预览" :name="sqlPreviewTab" :disabled="!baseFieldCompleted">
|
||||
<el-form-item prop="fieldMap" label="查询sql">
|
||||
<el-input type="textarea" v-model="state.previewDataSql" readonly :input-style="{ height: '190px' }" />
|
||||
</el-form-item>
|
||||
@@ -140,7 +140,41 @@
|
||||
</el-form>
|
||||
|
||||
<template #footer>
|
||||
<div class="dialog-footer">
|
||||
<div>
|
||||
<el-button
|
||||
v-if="tabActiveName != basicTab"
|
||||
@click="
|
||||
() => {
|
||||
switch (tabActiveName) {
|
||||
case fieldTab:
|
||||
tabActiveName = basicTab;
|
||||
break;
|
||||
case sqlPreviewTab:
|
||||
tabActiveName = fieldTab;
|
||||
break;
|
||||
}
|
||||
}
|
||||
"
|
||||
>上一步</el-button
|
||||
>
|
||||
<el-button
|
||||
v-if="tabActiveName != sqlPreviewTab"
|
||||
:disabled="!baseFieldCompleted"
|
||||
@click="
|
||||
() => {
|
||||
switch (tabActiveName) {
|
||||
case basicTab:
|
||||
tabActiveName = fieldTab;
|
||||
break;
|
||||
case fieldTab:
|
||||
tabActiveName = sqlPreviewTab;
|
||||
break;
|
||||
}
|
||||
}
|
||||
"
|
||||
>下一步</el-button
|
||||
>
|
||||
|
||||
<el-button @click="cancel()">取 消</el-button>
|
||||
<el-button type="primary" :loading="saveBtnLoading" @click="btnOk">确 定</el-button>
|
||||
</div>
|
||||
@@ -150,7 +184,7 @@
|
||||
</template>
|
||||
|
||||
<script lang="ts" setup>
|
||||
import { reactive, ref, toRefs, watch } from 'vue';
|
||||
import { reactive, ref, toRefs, watch, computed } from 'vue';
|
||||
import { dbApi } from './api';
|
||||
import { ElMessage } from 'element-plus';
|
||||
import DbSelectTree from '@/views/ops/db/component/DbSelectTree.vue';
|
||||
@@ -191,6 +225,10 @@ const rules = {
|
||||
|
||||
const dbForm: any = ref(null);
|
||||
|
||||
const basicTab = 'basic';
|
||||
const fieldTab = 'field';
|
||||
const sqlPreviewTab = 'sqlPreview';
|
||||
|
||||
type FormData = {
|
||||
id?: number;
|
||||
taskName?: string;
|
||||
@@ -235,34 +273,15 @@ const state = reactive({
|
||||
previewInsertSql: '',
|
||||
});
|
||||
|
||||
const onSelectSrcDb = async (params: any) => {
|
||||
// 初始化数据源
|
||||
params.databases = params.dbs; // 数据源里需要这个值
|
||||
state.srcDbInst = DbInst.getOrNewInst(params);
|
||||
registerDbCompletionItemProvider(params.id, params.db, params.dbs, params.type);
|
||||
};
|
||||
|
||||
const onSelectTargetDb = async (params: any) => {
|
||||
state.targetDbInst = DbInst.getOrNewInst(params);
|
||||
await loadDbTables(params.id, params.db);
|
||||
};
|
||||
|
||||
const loadDbTables = async (dbId: number, db: string) => {
|
||||
// 加载db下的表
|
||||
let data = await dbApi.tableInfos.request({ id: dbId, db });
|
||||
state.targetTableList = data;
|
||||
if (data && data.length > 0) {
|
||||
let names = data.map((a: any) => a.tableName);
|
||||
if (!names.includes(state.form.targetTableName)) {
|
||||
state.form.targetTableName = data[0].tableName;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const { tabActiveName, form, submitForm } = toRefs(state);
|
||||
|
||||
const { isFetching: saveBtnLoading, execute: saveExec } = dbApi.saveDatasyncTask.useApi(submitForm);
|
||||
|
||||
// 基础字段信息是否填写完整
|
||||
const baseFieldCompleted = computed(() => {
|
||||
return state.form.srcDbId && state.form.srcDbName && state.form.targetDbId && state.form.targetDbName && state.form.targetTableName;
|
||||
});
|
||||
|
||||
watch(dialogVisible, async (newValue: boolean) => {
|
||||
if (!newValue) {
|
||||
return;
|
||||
@@ -303,6 +322,10 @@ watch(dialogVisible, async (newValue: boolean) => {
|
||||
state.targetDbInst = DbInst.getOrNewInst(db);
|
||||
}
|
||||
|
||||
if (targetDbId && state.form.targetDbName) {
|
||||
await loadDbTables(targetDbId, state.form.targetDbName);
|
||||
}
|
||||
|
||||
// 注册sql代码提示
|
||||
if (srcDbId && srcDbName) {
|
||||
registerDbCompletionItemProvider(srcDbId, srcDbName, state.srcDbInst.databases, state.srcDbInst.type);
|
||||
@@ -311,21 +334,15 @@ watch(dialogVisible, async (newValue: boolean) => {
|
||||
|
||||
watch(tabActiveName, async (newValue: string) => {
|
||||
switch (newValue) {
|
||||
case 'field':
|
||||
case fieldTab:
|
||||
await handleGetSrcFields();
|
||||
await handleGetTargetFields();
|
||||
break;
|
||||
case 'dbConf':
|
||||
await handleGetTargetFields();
|
||||
if (state.form.targetDbId && state.form.targetDbName) {
|
||||
await loadDbTables(state.form.targetDbId, state.form.targetDbName);
|
||||
}
|
||||
break;
|
||||
case 'sqlPreview':
|
||||
case sqlPreviewTab:
|
||||
let srcDbDialect = getDbDialect(state.srcDbInst.type);
|
||||
let targetDbDialect = getDbDialect(state.targetDbInst.type);
|
||||
|
||||
let updField = srcDbDialect.wrapName(state.form.updField!);
|
||||
let updField = srcDbDialect.quoteIdentifier(state.form.updField!);
|
||||
state.previewDataSql = `SELECT * FROM (\n ${state.form.dataSql?.trim() || '请输入数据sql'} \n ) t \n where ${updField} > '${
|
||||
state.form.updFieldVal || ''
|
||||
}'`;
|
||||
@@ -339,19 +356,46 @@ watch(tabActiveName, async (newValue: string) => {
|
||||
});
|
||||
if (fields.size < (state.form.fieldMap?.length || 0)) {
|
||||
ElMessage.warning('字段映射中存在重复的目标字段,请检查');
|
||||
state.previewInsertSql = '';
|
||||
return;
|
||||
}
|
||||
|
||||
let fieldArr = state.form.fieldMap?.map((a: any) => targetDbDialect.wrapName(a.target)) || [];
|
||||
let fieldArr = state.form.fieldMap?.map((a: any) => targetDbDialect.quoteIdentifier(a.target)) || [];
|
||||
let placeholder = '?'.repeat(fieldArr.length).split('').join(',');
|
||||
|
||||
state.previewInsertSql = ` insert into ${targetDbDialect.wrapName(state.form.targetTableName!)}(${fieldArr.join(',')}) values (${placeholder});`;
|
||||
state.previewInsertSql = ` insert into ${targetDbDialect.quoteIdentifier(state.form.targetTableName!)}(${fieldArr.join(
|
||||
','
|
||||
)}) values (${placeholder});`;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
});
|
||||
|
||||
const onSelectSrcDb = async (params: any) => {
|
||||
// 初始化数据源
|
||||
params.databases = params.dbs; // 数据源里需要这个值
|
||||
state.srcDbInst = DbInst.getOrNewInst(params);
|
||||
registerDbCompletionItemProvider(params.id, params.db, params.dbs, params.type);
|
||||
};
|
||||
|
||||
const onSelectTargetDb = async (params: any) => {
|
||||
state.targetDbInst = DbInst.getOrNewInst(params);
|
||||
await loadDbTables(params.id, params.db);
|
||||
};
|
||||
|
||||
const loadDbTables = async (dbId: number, db: string) => {
|
||||
// 加载db下的表
|
||||
let data = await dbApi.tableInfos.request({ id: dbId, db });
|
||||
state.targetTableList = data;
|
||||
if (data && data.length > 0) {
|
||||
let names = data.map((a: any) => a.tableName);
|
||||
if (!names.includes(state.form.targetTableName)) {
|
||||
state.form.targetTableName = data[0].tableName;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const handleGetSrcFields = async () => {
|
||||
// 执行sql,获取字段信息
|
||||
if (!state.form.dataSql || !state.form.dataSql.trim()) {
|
||||
@@ -456,12 +500,8 @@ const cancel = () => {
|
||||
<style lang="scss">
|
||||
.sync-task-edit {
|
||||
.el-select {
|
||||
// width: 360px;
|
||||
width: 100%;
|
||||
}
|
||||
// .el-input__inner {
|
||||
// width: 100%; /* 将el-select内部输入框的宽度设置为100% */
|
||||
// }
|
||||
.task-sql {
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import * as monaco from 'monaco-editor/esm/vs/editor/editor.api';
|
||||
import { editor, languages, Position } from 'monaco-editor';
|
||||
|
||||
import { registerCompletionItemProvider } from '@/components/monaco/completionItemProvider';
|
||||
import {DbDialect, EditorCompletionItem, getDbDialect} from './dialect'
|
||||
import { DbDialect, EditorCompletionItem, getDbDialect } from './dialect';
|
||||
|
||||
const dbInstCache: Map<number, DbInst> = new Map();
|
||||
|
||||
@@ -104,7 +104,7 @@ export class DbInst {
|
||||
},
|
||||
kind: monaco.languages.CompletionItemKind.File,
|
||||
detail: tableComment,
|
||||
insertText: dbDialect.wrapName(tableName) + ' ',
|
||||
insertText: dbDialect.quoteIdentifier(tableName) + ' ',
|
||||
range,
|
||||
sortText: 300 + index + '',
|
||||
});
|
||||
@@ -113,7 +113,7 @@ export class DbInst {
|
||||
}
|
||||
|
||||
/** 加载列信息提示 */
|
||||
async loadTableColumnSuggestions(dbDialect: DbDialect,db: string, tableName: string, range: any) {
|
||||
async loadTableColumnSuggestions(dbDialect: DbDialect, db: string, tableName: string, range: any) {
|
||||
let dbHits = await this.loadDbHints(db);
|
||||
let columns = dbHits[tableName];
|
||||
let suggestions: languages.CompletionItem[] = [];
|
||||
@@ -128,7 +128,7 @@ export class DbInst {
|
||||
},
|
||||
kind: monaco.languages.CompletionItemKind.Property,
|
||||
detail: '', // 不显示detail, 否则选中时备注等会被遮挡
|
||||
insertText: dbDialect.wrapName(fieldName)+ ' ', // create_time
|
||||
insertText: dbDialect.quoteIdentifier(fieldName) + ' ', // create_time
|
||||
range,
|
||||
sortText: 100 + index + '', // 使用表字段声明顺序排序,排序需为字符串类型
|
||||
});
|
||||
@@ -287,7 +287,7 @@ export class DbInst {
|
||||
* @returns
|
||||
*/
|
||||
wrapName = (name: string) => {
|
||||
return getDbDialect(this.type).wrapName(name);
|
||||
return getDbDialect(this.type).quoteIdentifier(name);
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -618,7 +618,7 @@ export function registerDbCompletionItemProvider(dbId: number, db: string, dbs:
|
||||
description: 'schema',
|
||||
},
|
||||
kind: monaco.languages.CompletionItemKind.Folder,
|
||||
insertText: dbDialect.wrapName(a),
|
||||
insertText: dbDialect.quoteIdentifier(a),
|
||||
range,
|
||||
});
|
||||
});
|
||||
@@ -679,7 +679,7 @@ export function registerDbCompletionItemProvider(dbId: number, db: string, dbs:
|
||||
},
|
||||
kind: monaco.languages.CompletionItemKind.File,
|
||||
detail: tableComment,
|
||||
insertText: dbDialect.wrapName(tableName) + ' ',
|
||||
insertText: dbDialect.quoteIdentifier(tableName) + ' ',
|
||||
range,
|
||||
sortText: 300 + index + '',
|
||||
});
|
||||
|
||||
@@ -444,8 +444,8 @@ class DMDialect implements DbDialect {
|
||||
};
|
||||
}
|
||||
|
||||
wrapName = (name: string) => {
|
||||
return name;
|
||||
quoteIdentifier = (name: string) => {
|
||||
return `"${name}"`;
|
||||
};
|
||||
|
||||
matchType(text: string, arr: string[]): boolean {
|
||||
|
||||
@@ -143,10 +143,10 @@ export interface DbDialect {
|
||||
getDefaultIndex(): IndexDefinition;
|
||||
|
||||
/**
|
||||
* 包裹数据库表名、字段名等,避免使用关键字为字段名或表名时报错
|
||||
* 引用标识符,包裹数据库表名、字段名等,避免使用关键字为字段名或表名时报错
|
||||
* @param name 名称
|
||||
*/
|
||||
wrapName(name: string): string;
|
||||
quoteIdentifier(name: string): string;
|
||||
|
||||
/**
|
||||
* 生成创建表sql
|
||||
|
||||
@@ -112,7 +112,10 @@ class MysqlDialect implements DbDialect {
|
||||
}
|
||||
|
||||
getDefaultSelectSql(table: string, condition: string, orderBy: string, pageNum: number, limit: number) {
|
||||
return `SELECT * FROM ${this.wrapName(table)} ${condition ? 'WHERE ' + condition : ''} ${orderBy ? orderBy : ''} ${this.getPageSql(pageNum, limit)};`;
|
||||
return `SELECT * FROM ${this.quoteIdentifier(table)} ${condition ? 'WHERE ' + condition : ''} ${orderBy ? orderBy : ''} ${this.getPageSql(
|
||||
pageNum,
|
||||
limit
|
||||
)};`;
|
||||
}
|
||||
|
||||
getPageSql(pageNum: number, limit: number) {
|
||||
@@ -181,7 +184,7 @@ class MysqlDialect implements DbDialect {
|
||||
};
|
||||
}
|
||||
|
||||
wrapName = (name: string) => {
|
||||
quoteIdentifier = (name: string) => {
|
||||
return `\`${name}\``;
|
||||
};
|
||||
|
||||
|
||||
@@ -133,7 +133,10 @@ class PostgresqlDialect implements DbDialect {
|
||||
}
|
||||
|
||||
getDefaultSelectSql(table: string, condition: string, orderBy: string, pageNum: number, limit: number) {
|
||||
return `SELECT * FROM ${this.wrapName(table)} ${condition ? 'WHERE ' + condition : ''} ${orderBy ? orderBy : ''} ${this.getPageSql(pageNum, limit)};`;
|
||||
return `SELECT * FROM ${this.quoteIdentifier(table)} ${condition ? 'WHERE ' + condition : ''} ${orderBy ? orderBy : ''} ${this.getPageSql(
|
||||
pageNum,
|
||||
limit
|
||||
)};`;
|
||||
}
|
||||
|
||||
getPageSql(pageNum: number, limit: number) {
|
||||
@@ -202,7 +205,7 @@ class PostgresqlDialect implements DbDialect {
|
||||
};
|
||||
}
|
||||
|
||||
wrapName = (name: string) => {
|
||||
quoteIdentifier = (name: string) => {
|
||||
// 后端sql解析器暂不支持pgsql
|
||||
return name;
|
||||
};
|
||||
|
||||
@@ -6,6 +6,7 @@ require (
|
||||
gitee.com/chunanyong/dm v1.8.13
|
||||
gitee.com/liuzongyang/libpq v1.0.9
|
||||
github.com/buger/jsonparser v1.1.1
|
||||
github.com/emirpasic/gods v1.18.1
|
||||
github.com/gin-gonic/gin v1.9.1
|
||||
github.com/glebarez/sqlite v1.10.0
|
||||
github.com/go-gormigrate/gormigrate/v2 v2.1.0
|
||||
@@ -93,5 +94,3 @@ require (
|
||||
modernc.org/sqlite v1.23.1 // indirect
|
||||
vitess.io/vitess v0.17.3 // indirect
|
||||
)
|
||||
|
||||
require github.com/emirpasic/gods v1.18.1
|
||||
|
||||
@@ -277,7 +277,7 @@ func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap [
|
||||
for _, item := range fieldMap {
|
||||
targetField := item["target"]
|
||||
srcField := item["target"]
|
||||
targetWrapColumns = append(targetWrapColumns, targetDbConn.Info.Type.WrapName(targetField))
|
||||
targetWrapColumns = append(targetWrapColumns, targetDbConn.Info.Type.QuoteIdentifier(targetField))
|
||||
srcColumns = append(srcColumns, srcField)
|
||||
}
|
||||
|
||||
|
||||
@@ -25,37 +25,25 @@ func (dbType DbType) Equal(typ string) bool {
|
||||
return ToDbType(typ) == dbType
|
||||
}
|
||||
|
||||
// QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be
|
||||
// used as part of an SQL statement. For example:
|
||||
//
|
||||
// tblname := "my_table"
|
||||
// data := "my_data"
|
||||
// quoted := quoteIdentifier(tblname, '"')
|
||||
// err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data)
|
||||
//
|
||||
// Any double quotes in name will be escaped. The quoted identifier will be
|
||||
// case sensitive when used in a query. If the input string contains a zero
|
||||
// byte, the result will be truncated immediately before it.
|
||||
func (dbType DbType) QuoteIdentifier(name string) string {
|
||||
switch dbType {
|
||||
case DbTypeMysql, DbTypeMariadb:
|
||||
return quoteIdentifier(name, "`")
|
||||
case DbTypePostgres:
|
||||
return pq.QuoteIdentifier(name)
|
||||
return quoteIdentifier(name, `"`)
|
||||
default:
|
||||
panic(fmt.Sprintf("invalid database type: %s", dbType))
|
||||
}
|
||||
}
|
||||
|
||||
func (dbType DbType) MetaDbName() string {
|
||||
switch dbType {
|
||||
case DbTypeMysql, DbTypeMariadb:
|
||||
return ""
|
||||
case DbTypePostgres:
|
||||
return "postgres"
|
||||
case DbTypeDM:
|
||||
return ""
|
||||
default:
|
||||
panic(fmt.Sprintf("invalid database type: %s", dbType))
|
||||
}
|
||||
}
|
||||
|
||||
// 包装字段名,防止使用了数据库保留关键字
|
||||
func (dbType DbType) WrapName(name string) string {
|
||||
switch dbType {
|
||||
case DbTypeMysql, DbTypeMariadb:
|
||||
return fmt.Sprintf("`%s`", name)
|
||||
default:
|
||||
return fmt.Sprintf(`"%s"`, name)
|
||||
return quoteIdentifier(name, `"`)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,7 +56,20 @@ func (dbType DbType) QuoteLiteral(literal string) string {
|
||||
case DbTypePostgres:
|
||||
return pq.QuoteLiteral(literal)
|
||||
default:
|
||||
panic(fmt.Sprintf("invalid database type: %s", dbType))
|
||||
return pq.QuoteLiteral(literal)
|
||||
}
|
||||
}
|
||||
|
||||
func (dbType DbType) MetaDbName() string {
|
||||
switch dbType {
|
||||
case DbTypeMysql, DbTypeMariadb:
|
||||
return ""
|
||||
case DbTypePostgres:
|
||||
return "postgres"
|
||||
case DbTypeDM:
|
||||
return ""
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
@@ -78,24 +79,11 @@ func (dbType DbType) Dialect() sqlparser.Dialect {
|
||||
return sqlparser.MysqlDialect{}
|
||||
case DbTypePostgres:
|
||||
return sqlparser.PostgresDialect{}
|
||||
case DbTypeDM:
|
||||
return sqlparser.PostgresDialect{}
|
||||
default:
|
||||
panic(fmt.Sprintf("invalid database type: %s", dbType))
|
||||
return sqlparser.PostgresDialect{}
|
||||
}
|
||||
}
|
||||
|
||||
// QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be
|
||||
// used as part of an SQL statement. For example:
|
||||
//
|
||||
// tblname := "my_table"
|
||||
// data := "my_data"
|
||||
// quoted := pq.QuoteIdentifier(tblname)
|
||||
// err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data)
|
||||
//
|
||||
// Any double quotes in name will be escaped. The quoted identifier will be
|
||||
// case sensitive when used in a query. If the input string contains a zero
|
||||
// byte, the result will be truncated immediately before it.
|
||||
func quoteIdentifier(name, quoter string) string {
|
||||
end := strings.IndexRune(name, 0)
|
||||
if end > -1 {
|
||||
@@ -116,7 +104,7 @@ func (dbType DbType) StmtSetForeignKeyChecks(check bool) string {
|
||||
// not currently supported postgres
|
||||
return ""
|
||||
default:
|
||||
panic(fmt.Sprintf("invalid database type: %s", dbType))
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
@@ -128,6 +116,6 @@ func (dbType DbType) StmtUseDatabase(dbName string) string {
|
||||
// not currently supported postgres
|
||||
return ""
|
||||
default:
|
||||
panic(fmt.Sprintf("invalid database type: %s", dbType))
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,3 +50,34 @@ func Test_QuoteLiteral(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_quoteIdentifier(t *testing.T) {
|
||||
tests := []struct {
|
||||
dbType DbType
|
||||
sql string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
dbType: DbTypeMysql,
|
||||
sql: "`a`",
|
||||
},
|
||||
{
|
||||
dbType: DbTypeMysql,
|
||||
sql: "select table",
|
||||
},
|
||||
{
|
||||
dbType: DbTypePostgres,
|
||||
sql: "a",
|
||||
},
|
||||
{
|
||||
dbType: DbTypePostgres,
|
||||
sql: "table",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(string(tt.dbType)+"_"+tt.sql, func(t *testing.T) {
|
||||
got := tt.dbType.QuoteIdentifier(tt.sql)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -307,7 +307,7 @@ func (dd *DMDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string,
|
||||
// 去除最后一个逗号,占位符由括号包裹
|
||||
placeholder := fmt.Sprintf("(%s)", strings.TrimSuffix(repeated, ","))
|
||||
|
||||
sqlTemp := fmt.Sprintf("insert into %s (%s) values %s", dd.dc.Info.Type.WrapName(tableName), strings.Join(columns, ","), placeholder)
|
||||
sqlTemp := fmt.Sprintf("insert into %s (%s) values %s", dd.dc.Info.Type.QuoteIdentifier(tableName), strings.Join(columns, ","), placeholder)
|
||||
effRows := 0
|
||||
for _, value := range values {
|
||||
// 达梦数据库只能一条条的执行insert
|
||||
|
||||
@@ -236,7 +236,7 @@ func (md *MysqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []stri
|
||||
// 去除最后一个逗号
|
||||
placeholder = strings.TrimSuffix(repeated, ",")
|
||||
|
||||
sqlStr := fmt.Sprintf("insert into %s (%s) values %s", md.dc.Info.Type.WrapName(tableName), strings.Join(columns, ","), placeholder)
|
||||
sqlStr := fmt.Sprintf("insert into %s (%s) values %s", md.dc.Info.Type.QuoteIdentifier(tableName), strings.Join(columns, ","), placeholder)
|
||||
// 执行批量insert sql
|
||||
// 把二维数组转为一维数组
|
||||
var args []any
|
||||
|
||||
@@ -319,7 +319,7 @@ func (pd *PgsqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []stri
|
||||
placeholders = append(placeholders, "("+strings.Join(placeholder, ", ")+")")
|
||||
}
|
||||
|
||||
sqlStr := fmt.Sprintf("insert into %s (%s) values %s", pd.dc.Info.Type.WrapName(tableName), strings.Join(columns, ","), strings.Join(placeholders, ", "))
|
||||
sqlStr := fmt.Sprintf("insert into %s (%s) values %s", pd.dc.Info.Type.QuoteIdentifier(tableName), strings.Join(columns, ","), strings.Join(placeholders, ", "))
|
||||
// 执行批量insert sql
|
||||
|
||||
return pd.dc.TxExec(tx, sqlStr, args...)
|
||||
|
||||
Reference in New Issue
Block a user