!110 feat: 支持各源数据库导出sql,数据库迁移部分bug修复

* feat: 各源数据库导出
* fix: 数据库迁移 bug修复
This commit is contained in:
zongyangleo
2024-03-26 09:05:28 +00:00
committed by Coder慌
parent 4b3ed1310d
commit 2acc295259
31 changed files with 821 additions and 424 deletions

View File

@@ -472,6 +472,9 @@ export class DbInst {
// 初始化所有列信息完善需要显示的列类型包含长度等如varchar(20)
static initColumns(columns: any[]) {
if (!columns) {
return;
}
for (let col of columns) {
if (col.charMaxLength > 0) {
col.showDataType = `${col.dataType}(${col.charMaxLength})`;

View File

@@ -7,6 +7,7 @@ import {
DuplicateStrategy,
EditorCompletion,
EditorCompletionItem,
QuoteEscape,
IndexDefinition,
RowDefinition,
sqlColumnType,
@@ -523,7 +524,7 @@ class DMDialect implements DbDialect {
}
// 列注释
if (item.remark) {
columCommentSql += ` comment on column "${data.tableName}"."${item.name}" is '${item.remark}'; `;
columCommentSql += ` comment on column "${data.tableName}"."${item.name}" is '${QuoteEscape(item.remark)}'; `;
}
});
// 建表
@@ -534,7 +535,7 @@ class DMDialect implements DbDialect {
);`;
// 表注释
if (data.tableComment) {
tableCommentSql = ` comment on table "${data.tableName}" is '${data.tableComment}'; `;
tableCommentSql = ` comment on table "${data.tableName}" is '${QuoteEscape(data.tableComment)}'; `;
}
return createSql + tableCommentSql + columCommentSql;
@@ -569,7 +570,7 @@ class DMDialect implements DbDialect {
changeData.add.forEach((a) => {
modifySql += `ALTER TABLE ${dbTable} add COLUMN ${this.genColumnBasicSql(a)};`;
if (a.remark) {
commentSql += `COMMENT ON COLUMN ${dbTable}.${this.quoteIdentifier(a.name)} IS '${a.remark}';`;
commentSql += `COMMENT ON COLUMN ${dbTable}.${this.quoteIdentifier(a.name)} IS '${QuoteEscape(a.remark)}';`;
}
if (a.pri) {
priArr.add(`"${a.name}"`);
@@ -579,7 +580,7 @@ class DMDialect implements DbDialect {
if (changeData.upd.length > 0) {
changeData.upd.forEach((a) => {
let cmtSql = `COMMENT ON COLUMN ${dbTable}.${this.quoteIdentifier(a.name)} IS '${a.remark}';`;
let cmtSql = `COMMENT ON COLUMN ${dbTable}.${this.quoteIdentifier(a.name)} IS '${QuoteEscape(a.remark)}';`;
if (a.remark && a.oldName === a.name) {
commentSql += cmtSql;
}
@@ -675,7 +676,7 @@ class DMDialect implements DbDialect {
}
if (tableData.oldTableComment !== tableData.tableComment) {
let baseTable = `${this.quoteIdentifier(schema)}.${this.quoteIdentifier(tableData.tableName)}`;
sql += `COMMENT ON TABLE ${baseTable} IS '${tableData.tableComment}';`;
sql += `COMMENT ON TABLE ${baseTable} IS '${QuoteEscape(tableData.tableComment)}';`;
}
return sql;
}

View File

@@ -262,6 +262,18 @@ export const getDbDialect = (dbType?: string): DbDialect => {
return dbType2DialectMap.get(dbType!) || mysqlDialect;
};
/**
* 引号转义多用于sql注释转义防止拼接sql报错 comment xx is '注''释' 最终注释文本为: 注'释
* @author liuzongyang
* @since 2024/3/22 08:23
*/
export const QuoteEscape = (str: string): string => {
if (!str) {
return '';
}
return str.replace(/'/g, "''");
};
(function () {
console.log('init register db dialect');
registerDbDialect(DbType.mysql, mysqlDialect);

View File

@@ -7,6 +7,7 @@ import {
DuplicateStrategy,
EditorCompletion,
EditorCompletionItem,
QuoteEscape,
IndexDefinition,
RowDefinition,
} from './index';
@@ -225,7 +226,7 @@ class MssqlDialect implements DbDialect {
item.name && fields.push(this.genColumnBasicSql(item));
item.remark &&
fieldComments.push(
`EXECUTE sp_addextendedproperty N'MS_Description', N'${item.remark}', N'SCHEMA', N'${schema}', N'TABLE', N'${data.tableName}', N'COLUMN', N'${item.name}'`
`EXECUTE sp_addextendedproperty N'MS_Description', N'${QuoteEscape(item.remark)}', N'SCHEMA', N'${schema}', N'TABLE', N'${data.tableName}', N'COLUMN', N'${item.name}'`
);
if (item.pri) {
pks.push(`${this.quoteIdentifier(item.name)}`);
@@ -244,7 +245,7 @@ class MssqlDialect implements DbDialect {
// 表注释
if (data.tableComment) {
createTable += ` EXECUTE sp_addextendedproperty N'MS_Description', N'${data.tableComment}', N'SCHEMA', N'${schema}', N'TABLE', N'${data.tableName}';`;
createTable += ` EXECUTE sp_addextendedproperty N'MS_Description', N'${QuoteEscape(data.tableComment)}', N'SCHEMA', N'${schema}', N'TABLE', N'${data.tableName}';`;
}
return createTable + createIndexSql + fieldComments.join(';');
@@ -268,7 +269,7 @@ class MssqlDialect implements DbDialect {
sql.push(` CREATE ${a.unique ? 'UNIQUE' : ''} NONCLUSTERED INDEX ${this.quoteIdentifier(a.indexName)} on ${baseTable} (${columnNames.join(',')})`);
if (a.indexComment) {
indexComment.push(
`EXECUTE sp_addextendedproperty N'MS_Description', N'${a.indexComment}', N'SCHEMA', N'${schema}', N'TABLE', N'${data.tableName}', N'INDEX', N'${a.indexName}'`
`EXECUTE sp_addextendedproperty N'MS_Description', N'${QuoteEscape(a.indexComment)}', N'SCHEMA', N'${schema}', N'TABLE', N'${data.tableName}', N'INDEX', N'${a.indexName}'`
);
}
});
@@ -306,7 +307,7 @@ class MssqlDialect implements DbDialect {
addArr.push(` ALTER TABLE ${baseTable} ADD ${this.genColumnBasicSql(a)}`);
if (a.remark) {
addCommentArr.push(
`EXECUTE sp_addextendedproperty N'MS_Description', N'${a.remark}', N'SCHEMA', N'${schema}', N'TABLE', N'${tableName}', N'COLUMN', N'${a.name}'`
`EXECUTE sp_addextendedproperty N'MS_Description', N'${QuoteEscape(a.remark)}', N'SCHEMA', N'${schema}', N'TABLE', N'${tableName}', N'COLUMN', N'${a.name}'`
);
}
});
@@ -315,7 +316,7 @@ class MssqlDialect implements DbDialect {
if (changeData.upd.length > 0) {
changeData.upd.forEach((a) => {
if (a.oldName && a.name !== a.oldName) {
renameArr.push(` EXEC sp_rename '${baseTable}.${this.quoteIdentifier(a.oldName)}', '${a.name}', 'COLUMN' `);
renameArr.push(` EXEC sp_rename '${baseTable}.${this.quoteIdentifier(a.oldName)}', '${QuoteEscape(a.name)}', 'COLUMN' `);
} else {
updArr.push(` ALTER TABLE ${baseTable} ALTER COLUMN ${this.genColumnBasicSql(a)} `);
}
@@ -325,13 +326,13 @@ class MssqlDialect implements DbDialect {
'TABLE', N'${tableName}',
'COLUMN', N'${a.name}')) > 0)
EXEC sp_updateextendedproperty
'MS_Description', N'${a.remark}',
'MS_Description', N'${QuoteEscape(a.remark)}',
'SCHEMA', N'${schema}',
'TABLE', N'${tableName}',
'COLUMN', N'${a.name}'
ELSE
EXEC sp_addextendedproperty
'MS_Description', N'${a.remark}',
'MS_Description', N'${QuoteEscape(a.remark)}',
'SCHEMA', N'${schema}',
'TABLE', N'${tableName}',
'COLUMN',N'${a.name}'`);
@@ -367,7 +368,7 @@ ELSE
);
if (a.indexComment) {
commentArr.push(
` EXEC sp_addextendedproperty N'MS_Description', N'${a.indexComment}', N'SCHEMA', N'${schema}', N'TABLE', N'${tableName}', N'INDEX', N'${a.indexName}' `
` EXEC sp_addextendedproperty N'MS_Description', N'${QuoteEscape(a.indexComment)}', N'SCHEMA', N'${schema}', N'TABLE', N'${tableName}', N'INDEX', N'${a.indexName}' `
);
}
};
@@ -413,7 +414,7 @@ ELSE
if (tableData.oldTableComment !== tableData.tableComment) {
// 转义注释中的单引号和换行符
let tableComment = tableData.tableComment.replaceAll(/'/g, '').replaceAll(/[\r\n]/g, ' ');
let tableComment = tableData.tableComment.replaceAll(/'/g, "'").replaceAll(/[\r\n]/g, ' ');
sql += `IF ((SELECT COUNT(*) FROM fn_listextendedproperty('MS_Description',
'SCHEMA', N'${schema}',
'TABLE', N'${tableData.tableName}', NULL, NULL)) > 0)

View File

@@ -7,6 +7,7 @@ import {
DuplicateStrategy,
EditorCompletion,
EditorCompletionItem,
QuoteEscape,
IndexDefinition,
RowDefinition,
} from './index';
@@ -208,7 +209,7 @@ class MysqlDialect implements DbDialect {
let onUpdate = 'update_time' === cl.name ? ' ON UPDATE CURRENT_TIMESTAMP ' : '';
return ` ${this.quoteIdentifier(cl.name)} ${cl.type}${length} ${cl.notNull ? 'NOT NULL' : 'NULL'} ${
cl.auto_increment ? 'AUTO_INCREMENT' : ''
} ${defVal} ${onUpdate} comment '${cl.remark || ''}' `;
} ${defVal} ${onUpdate} comment '${QuoteEscape(cl.remark)}' `;
}
getCreateTableSql(data: any): string {
// 创建表结构
@@ -224,14 +225,14 @@ class MysqlDialect implements DbDialect {
return `CREATE TABLE ${data.tableName}
( ${fields.join(',')}
${pks ? `, PRIMARY KEY (${pks.join(',')})` : ''}
) COMMENT='${data.tableComment}';`;
) COMMENT='${QuoteEscape(data.tableComment)}';`;
}
getCreateIndexSql(data: any): string {
// 创建索引
let sql = `ALTER TABLE ${data.tableName}`;
data.indexs.res.forEach((a: any) => {
sql += ` ADD ${a.unique ? 'UNIQUE' : ''} INDEX ${a.indexName}(${a.columnNames.join(',')}) USING ${a.indexType} COMMENT '${a.indexComment}',`;
sql += ` ADD ${a.unique ? 'UNIQUE' : ''} INDEX ${a.indexName}(${a.columnNames.join(',')}) USING ${a.indexType} COMMENT '${QuoteEscape(a.indexComment)}',`;
});
return sql.substring(0, sql.length - 1) + ';';
}
@@ -312,9 +313,9 @@ class MysqlDialect implements DbDialect {
sql += ',';
}
addIndexs.forEach((a) => {
sql += ` ADD ${a.unique ? 'UNIQUE' : ''} INDEX ${a.indexName}(${a.columnNames.join(',')}) USING ${a.indexType} COMMENT '${
sql += ` ADD ${a.unique ? 'UNIQUE' : ''} INDEX ${a.indexName}(${a.columnNames.join(',')}) USING ${a.indexType} COMMENT '${QuoteEscape(
a.indexComment
}',`;
)}',`;
});
sql = sql.substring(0, sql.length - 1);
}
@@ -326,7 +327,7 @@ class MysqlDialect implements DbDialect {
getModifyTableInfoSql(tableData: any): string {
let sql = '';
if (tableData.tableComment !== tableData.oldTableComment) {
sql += `ALTER TABLE ${this.quoteIdentifier(tableData.db)}.${this.quoteIdentifier(tableData.oldTableName)} COMMENT '${tableData.tableComment}';`;
sql += `ALTER TABLE ${this.quoteIdentifier(tableData.db)}.${this.quoteIdentifier(tableData.oldTableName)} COMMENT '${QuoteEscape(tableData.tableComment)}';`;
}
if (tableData.tableName !== tableData.oldTableName) {

View File

@@ -7,6 +7,7 @@ import {
DuplicateStrategy,
EditorCompletion,
EditorCompletionItem,
QuoteEscape,
IndexDefinition,
RowDefinition,
sqlColumnType,
@@ -324,7 +325,7 @@ class OracleDialect implements DbDialect {
item.name && fields.push(this.genColumnBasicSql(item, true));
// 列注释
if (item.remark) {
columCommentSql += ` COMMENT ON COLUMN ${dbTable}.${this.quoteIdentifier(item.name)} is '${item.remark}'; `;
columCommentSql += ` COMMENT ON COLUMN ${dbTable}.${this.quoteIdentifier(item.name)} is '${QuoteEscape(item.remark)}'; `;
}
// 主键
if (item.pri) {
@@ -340,7 +341,7 @@ class OracleDialect implements DbDialect {
createSql = `CREATE TABLE ${dbTable} ( ${fields.join(',')} ${prisql ? ',' + prisql : ''} ) ;`;
// 表注释
if (data.tableComment) {
tableCommentSql = ` COMMENT ON TABLE ${dbTable} is '${data.tableComment}'; `;
tableCommentSql = ` COMMENT ON TABLE ${dbTable} is '${QuoteEscape(data.tableComment)}'; `;
}
return createSql + tableCommentSql + columCommentSql;
@@ -379,7 +380,7 @@ class OracleDialect implements DbDialect {
if (changeData.upd.length > 0) {
changeData.upd.forEach((a) => {
let commentSql = `COMMENT ON COLUMN ${dbTable}.${this.quoteIdentifier(a.name)} IS '${a.remark}'`;
let commentSql = `COMMENT ON COLUMN ${dbTable}.${this.quoteIdentifier(a.name)} IS '${QuoteEscape(a.remark)}'`;
if (a.remark && a.oldName === a.name) {
commentArr.push(commentSql);
}
@@ -401,7 +402,7 @@ class OracleDialect implements DbDialect {
changeData.add.forEach((a) => {
modifyArr.push(` ADD (${this.genColumnBasicSql(a, false)})`);
if (a.remark) {
commentArr.push(`COMMENT ON COLUMN ${dbTable}.${this.quoteIdentifier(a.name)} is '${a.remark}'`);
commentArr.push(`COMMENT ON COLUMN ${dbTable}.${this.quoteIdentifier(a.name)} is '${QuoteEscape(a.remark)}'`);
}
if (a.pri) {
priArr.add(`"${a.name}"`);
@@ -486,7 +487,7 @@ class OracleDialect implements DbDialect {
let sql = '';
if (tableData.tableComment != tableData.oldTableComment) {
let dbTable = `${this.quoteIdentifier(schema)}.${this.quoteIdentifier(tableData.oldTableName)}`;
sql = `COMMENT ON TABLE ${dbTable} is '${tableData.tableComment}';`;
sql = `COMMENT ON TABLE ${dbTable} is '${QuoteEscape(tableData.tableComment)}';`;
}
if (tableData.tableName != tableData.oldTableName) {
let dbTable = `${this.quoteIdentifier(schema)}.${this.quoteIdentifier(tableData.oldTableName)}`;

View File

@@ -7,6 +7,7 @@ import {
DuplicateStrategy,
EditorCompletion,
EditorCompletionItem,
QuoteEscape,
IndexDefinition,
RowDefinition,
sqlColumnType,
@@ -283,7 +284,7 @@ class PostgresqlDialect implements DbDialect {
}
// 列注释
if (item.remark) {
columCommentSql += ` comment on column ${data.tableName}.${item.name} is '${item.remark}'; `;
columCommentSql += ` comment on column ${data.tableName}.${item.name} is '${QuoteEscape(item.remark)}'; `;
}
});
// 建表
@@ -294,7 +295,7 @@ class PostgresqlDialect implements DbDialect {
);`;
// 表注释
if (data.tableComment) {
tableCommentSql = ` comment on table ${data.tableName} is '${data.tableComment}'; `;
tableCommentSql = ` comment on table ${data.tableName} is '${QuoteEscape(data.tableComment)}'; `;
}
return createSql + tableCommentSql + columCommentSql;
@@ -312,7 +313,7 @@ class PostgresqlDialect implements DbDialect {
let colArr = a.columnNames.map((a: string) => `${this.quoteIdentifier(a)}`);
sql.push(`CREATE ${a.unique ? 'UNIQUE' : ''} INDEX ${this.quoteIdentifier(a.indexName)} on ${dbTable} (${colArr.join(',')})`);
if (a.indexComment) {
sql.push(`COMMENT ON INDEX ${schema}.${this.quoteIdentifier(a.indexName)} IS '${a.indexComment}'`);
sql.push(`COMMENT ON INDEX ${schema}.${this.quoteIdentifier(a.indexName)} IS '${QuoteEscape(a.indexComment)}'`);
}
});
return sql.join(';');
@@ -334,14 +335,14 @@ class PostgresqlDialect implements DbDialect {
changeData.add.forEach((a) => {
modifySql += `alter table ${dbTable} add ${this.genColumnBasicSql(a)};`;
if (a.remark) {
commentSql += `comment on column ${dbTable}.${this.quoteIdentifier(a.name)} is '${a.remark}';`;
commentSql += `comment on column ${dbTable}.${this.quoteIdentifier(a.name)} is '${QuoteEscape(a.remark)}';`;
}
});
}
if (changeData.upd.length > 0) {
changeData.upd.forEach((a) => {
let cmtSql = `comment on column ${dbTable}.${this.quoteIdentifier(a.name)} is '${a.remark}';`;
let cmtSql = `comment on column ${dbTable}.${this.quoteIdentifier(a.name)} is '${QuoteEscape(a.remark)}';`;
if (a.remark && a.oldName === a.name) {
commentSql += cmtSql;
}
@@ -412,7 +413,7 @@ class PostgresqlDialect implements DbDialect {
let colArr = a.columnNames.map((a: string) => `${this.quoteIdentifier(a)}`);
sql.push(`CREATE ${a.unique ? 'UNIQUE' : ''} INDEX ${this.quoteIdentifier(a.indexName)} on ${dbTable} (${colArr.join(',')})`);
if (a.indexComment) {
sql.push(`COMMENT ON INDEX ${schema}.${this.quoteIdentifier(a.indexName)} IS '${a.indexComment}'`);
sql.push(`COMMENT ON INDEX ${schema}.${this.quoteIdentifier(a.indexName)} IS '${QuoteEscape(a.indexComment)}'`);
}
});
}
@@ -428,7 +429,7 @@ class PostgresqlDialect implements DbDialect {
let sql = '';
if (tableData.tableComment != tableData.oldTableComment) {
let dbTable = `${this.quoteIdentifier(schema)}.${this.quoteIdentifier(tableData.oldTableName)}`;
sql = `COMMENT ON TABLE ${dbTable} is '${tableData.tableComment}';`;
sql = `COMMENT ON TABLE ${dbTable} is '${QuoteEscape(tableData.tableComment)}';`;
}
if (tableData.tableName != tableData.oldTableName) {
let dbTable = `${this.quoteIdentifier(schema)}.${this.quoteIdentifier(tableData.oldTableName)}`;

View File

@@ -17,6 +17,7 @@ import (
tagapp "mayfly-go/internal/tag/application"
tagentity "mayfly-go/internal/tag/domain/entity"
"mayfly-go/pkg/biz"
"mayfly-go/pkg/errorx"
"mayfly-go/pkg/logx"
"mayfly-go/pkg/model"
"mayfly-go/pkg/req"
@@ -24,6 +25,7 @@ import (
"mayfly-go/pkg/utils/collx"
"mayfly-go/pkg/utils/stringx"
"mayfly-go/pkg/ws"
"sort"
"strconv"
"strings"
"time"
@@ -336,47 +338,93 @@ func (d *Db) dumpDb(ctx context.Context, writer *gzipWriter, dbId uint64, dbName
}
}
for _, table := range tables {
writer.TryFlush()
quotedTable := dbMeta.QuoteIdentifier(table)
if needStruct {
writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- 表结构: %s \n-- ----------------------------\n", table))
writer.WriteString(fmt.Sprintf("DROP TABLE IF EXISTS %s;\n", quotedTable))
ddl, err := dbMeta.GetTableDDL(table)
biz.ErrIsNil(err)
writer.WriteString(ddl + "\n")
}
if !needData {
continue
}
writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- 表记录: %s \n-- ----------------------------\n", table))
// 查询列信息后面生成建表ddl和insert都需要列信息
columns, err := dbMeta.GetColumns(tables...)
// 达梦不支持begin语句
if dbConn.Info.Type != dbi.DbTypeDM {
writer.WriteString("BEGIN;\n")
// 以表名分组,存放每个表的列信息
columnMap := make(map[string][]dbi.Column)
for _, column := range columns {
columnMap[column.TableName] = append(columnMap[column.TableName], column)
}
// 按表名排序
sort.Strings(tables)
quoteSchema := dbMeta.QuoteIdentifier(dbConn.Info.CurrentSchema())
// 遍历获取每个表的信息
for _, tableName := range tables {
quoteTableName := dbMeta.QuoteIdentifier(tableName)
writer.TryFlush()
// 查询表信息,主要是为了查询表注释
tbs, err := dbMeta.GetTables(tableName)
biz.ErrIsNil(err)
if err != nil || tbs == nil || len(tbs) <= 0 {
panic(errorx.NewBiz(fmt.Sprintf("获取表信息失败:%s", tableName)))
}
insertSql := "INSERT INTO %s VALUES (%s);\n"
dbConn.WalkTableRows(ctx, table, func(record map[string]any, columns []*dbi.QueryColumn) error {
var values []string
writer.TryFlush()
for _, column := range columns {
value := record[column.Name]
if value == nil {
values = append(values, "NULL")
continue
}
strValue, ok := value.(string)
if ok {
strValue = dbMeta.QuoteLiteral(strValue)
values = append(values, strValue)
} else {
values = append(values, anyx.ToString(value))
}
tabInfo := dbi.Table{
TableName: tableName,
TableComment: tbs[0].TableComment,
}
// 生成表结构信息
if needStruct {
writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- 表结构: %s \n-- ----------------------------\n", tableName))
tbDdlArr := dbMeta.GenerateTableDDL(columnMap[tableName], tabInfo, true)
for _, ddl := range tbDdlArr {
writer.WriteString(ddl + ";\n")
}
writer.WriteString(fmt.Sprintf(insertSql, quotedTable, strings.Join(values, ", ")))
return nil
})
writer.WriteString("COMMIT;\n")
}
// 生成insert sql数据在索引前加速insert
if needData {
writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- 表记录: %s \n-- ----------------------------\n", tableName))
dbMeta.BeforeDumpInsert(writer, quoteTableName)
// 获取列信息
quoteColNames := make([]string, 0)
for _, col := range columnMap[tableName] {
quoteColNames = append(quoteColNames, dbMeta.QuoteIdentifier(col.ColumnName))
}
converter := dbMeta.GetDataConverter()
_ = dbConn.WalkTableRows(context.Background(), quoteTableName, func(row map[string]any, _ []*dbi.QueryColumn) error {
rowValues := make([]string, len(columnMap[tableName]))
for i, col := range columnMap[tableName] {
rowValues[i] = converter.WrapValue(row[col.ColumnName], converter.GetDataType(string(col.DataType)))
}
beforeInsert := dbMeta.BeforeDumpInsertSql(quoteSchema, quoteTableName)
insertSQL := fmt.Sprintf("%s INSERT INTO %s (%s) values(%s)", beforeInsert, quoteTableName, strings.Join(quoteColNames, ", "), strings.Join(rowValues, ", "))
writer.WriteString(insertSQL + ";\n")
return nil
})
dbMeta.AfterDumpInsert(writer, tableName, columnMap[tableName])
}
indexs, err := dbMeta.GetTableIndex(tableName)
biz.ErrIsNil(err)
// 过滤主键索引
idxs := make([]dbi.Index, 0)
for _, idx := range indexs {
if !idx.IsPrimaryKey {
idxs = append(idxs, idx)
}
}
if len(idxs) > 0 {
// 最后添加索引
writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- 表索引: %s \n-- ----------------------------\n", tableName))
sqlArr := dbMeta.GenerateIndexDDL(idxs, tabInfo)
for _, sqlStr := range sqlArr {
writer.WriteString(sqlStr + ";\n")
}
}
}
}
@@ -450,7 +498,7 @@ func (d *Db) HintTables(rc *req.Ctx) {
func (d *Db) GetTableDDL(rc *req.Ctx) {
tn := rc.Query("tableName")
biz.NotEmpty(tn, "tableName不能为空")
res, err := d.getDbConn(rc).GetMetaData().GetTableDDL(tn)
res, err := d.getDbConn(rc).GetMetaData().GetTableDDL(tn, false)
biz.ErrIsNilAppendErr(err, "获取表ddl失败: %s")
rc.ResData = res
}

View File

@@ -12,6 +12,7 @@ import (
"mayfly-go/pkg/req"
"strconv"
"strings"
"time"
)
type DbTransferTask struct {
@@ -53,13 +54,14 @@ func (d *DbTransferTask) DeleteTask(rc *req.Ctx) {
}
func (d *DbTransferTask) Run(rc *req.Ctx) {
start := time.Now()
taskId := d.changeState(rc, entity.DbTransferTaskRunStateRunning)
go d.DbTransferTask.Run(taskId, func(msg string, err error) {
// 修改状态为停止
if err != nil {
logx.Error(msg, err)
} else {
logx.Info(fmt.Sprintf("执行迁移完成,%s", msg))
logx.Info(fmt.Sprintf("执行迁移完成,%s, 耗时:%v", msg, time.Since(start)))
}
// 修改任务状态
task := new(entity.DbTransferTask)

View File

@@ -26,6 +26,18 @@ func (g *gzipWriter) WriteString(data string) {
}
}
func (g *gzipWriter) Write(p []byte) (n int, err error) {
if g.aborted {
return
}
if _, err := g.writer.Write(p); err != nil {
g.aborted = true
biz.IsTrue(false, "数据库导出失败:%s", err)
}
return
}
func (g *gzipWriter) Close() {
g.writer.Close()
}

View File

@@ -151,20 +151,21 @@ func (app *dataSyncAppImpl) RunCronJob(id uint64) error {
}
task.UpdFieldVal = strings.Trim(task.UpdFieldVal, " ")
// 把UpdFieldVal尝试转为int如果可以转为int则不添加引号否则添加引号
if _, err = strconv.Atoi(task.UpdFieldVal); err != nil {
updSql = fmt.Sprintf("and %s > '%s'", task.UpdField, task.UpdFieldVal)
} else {
updSql = fmt.Sprintf("and %s > %s", task.UpdField, task.UpdFieldVal)
}
// 如果是oracle且数据类型是时间类型则需要加上to_date函数
if srcConn.Info.Type == dbi.DbTypeOracle {
// 用正则判断数据类型是时间
// 判断UpdFieldVal数据类型
var updFieldValType dbi.DataType
if _, err = strconv.Atoi(task.UpdFieldVal); err != nil {
if dateTimeReg.MatchString(task.UpdFieldVal) {
updSql = fmt.Sprintf("and %s > to_date('%s','yyyy-mm-dd hh24:mi:ss')", task.UpdField, task.UpdFieldVal)
updFieldValType = dbi.DataTypeDateTime
} else {
updFieldValType = dbi.DataTypeString
}
} else {
updFieldValType = dbi.DataTypeNumber
}
wrapUpdFieldVal := srcConn.GetMetaData().GetDataConverter().WrapValue(task.UpdFieldVal, updFieldValType)
updSql = fmt.Sprintf("and %s > %s", task.UpdField, wrapUpdFieldVal)
orderSql = "order by " + task.UpdField + " asc "
}
// 正则判断DataSql是否以where .*结尾如果是则不添加where 1 = 1

View File

@@ -10,6 +10,8 @@ import (
"mayfly-go/pkg/gormx"
"mayfly-go/pkg/logx"
"mayfly-go/pkg/model"
"mayfly-go/pkg/utils/collx"
"sort"
"strings"
)
@@ -153,15 +155,21 @@ func (app *dbTransferAppImpl) transferTables(task *entity.DbTransferTask, srcCon
end("获取源表列信息失败", err)
return
}
// 以表名分组,存放每个表的列信息
columnMap := make(map[string][]dbi.Column)
for _, column := range columns {
columnMap[column.TableName] = append(columnMap[column.TableName], column)
}
// 以表名排序
sortTableNames := collx.MapKeys(columnMap)
sort.Strings(sortTableNames)
ctx := context.Background()
for tbName, cols := range columnMap {
for _, tbName := range sortTableNames {
cols := columnMap[tbName]
targetCols := make([]dbi.Column, 0)
for _, col := range cols {
colPtr := &col
@@ -183,7 +191,7 @@ func (app *dbTransferAppImpl) transferTables(task *entity.DbTransferTask, srcCon
// 迁移数据
logx.Infof("开始迁移数据: 表名:%s", tbName)
total, err := app.transferData(ctx, tbName, srcConn, srcDialect, targetConn, targetDialect)
total, err := app.transferData(ctx, tbName, targetCols, srcConn, srcDialect, targetConn, targetDialect)
if err != nil {
end(fmt.Sprintf("迁移数据失败: 表名:%s, error: %s", tbName, err.Error()), err)
return
@@ -216,27 +224,16 @@ func (app *dbTransferAppImpl) transferTables(task *entity.DbTransferTask, srcCon
}
}
func (app *dbTransferAppImpl) transferData(ctx context.Context, tableName string, srcConn *dbi.DbConn, srcDialect dbi.Dialect, targetConn *dbi.DbConn, targetDialect dbi.Dialect) (int, error) {
func (app *dbTransferAppImpl) transferData(ctx context.Context, tableName string, targetColumns []dbi.Column, srcConn *dbi.DbConn, srcDialect dbi.Dialect, targetConn *dbi.DbConn, targetDialect dbi.Dialect) (int, error) {
result := make([]map[string]any, 0)
total := 0 // 总条数
batchSize := 1000 // 每次查询并迁移1000条数据
var queryColumns []*dbi.QueryColumn
var err error
srcMeta := srcConn.GetMetaData()
srcConverter := srcMeta.GetDataConverter()
// 游标查询源表数据,并批量插入目标表
err = srcConn.WalkTableRows(ctx, tableName, func(row map[string]any, columns []*dbi.QueryColumn) error {
if len(queryColumns) == 0 {
for _, col := range columns {
queryColumns = append(queryColumns, &dbi.QueryColumn{
Name: targetConn.GetMetaData().QuoteIdentifier(col.Name),
Type: col.Type,
})
}
}
total++
rawValue := map[string]any{}
for _, column := range columns {
@@ -246,7 +243,7 @@ func (app *dbTransferAppImpl) transferData(ctx context.Context, tableName string
}
result = append(result, rawValue)
if total%batchSize == 0 {
err = app.transfer2Target(targetConn, queryColumns, result, targetDialect, tableName)
err = app.transfer2Target(targetConn, targetColumns, result, targetDialect, tableName)
if err != nil {
logx.Error("批量插入目标表数据失败", err)
return err
@@ -257,7 +254,7 @@ func (app *dbTransferAppImpl) transferData(ctx context.Context, tableName string
})
// 处理剩余的数据
if len(result) > 0 {
err = app.transfer2Target(targetConn, queryColumns, result, targetDialect, tableName)
err = app.transfer2Target(targetConn, targetColumns, result, targetDialect, tableName)
if err != nil {
logx.Error(fmt.Sprintf("批量插入目标表数据失败,表名:%s", tableName), err)
return 0, err
@@ -266,23 +263,36 @@ func (app *dbTransferAppImpl) transferData(ctx context.Context, tableName string
return total, err
}
func (app *dbTransferAppImpl) transfer2Target(targetConn *dbi.DbConn, cols []*dbi.QueryColumn, result []map[string]any, targetDialect dbi.Dialect, tbName string) error {
func (app *dbTransferAppImpl) transfer2Target(targetConn *dbi.DbConn, targetColumns []dbi.Column, result []map[string]any, targetDialect dbi.Dialect, tbName string) error {
tx, err := targetConn.Begin()
if err != nil {
return err
}
targetMeta := targetConn.GetMetaData()
// 收集字段名
var columnNames []string
for _, col := range cols {
columnNames = append(columnNames, col.Name)
for _, col := range targetColumns {
columnNames = append(columnNames, targetMeta.QuoteIdentifier(col.ColumnName))
}
// 从目标库数据中取出源库字段对应的值
values := make([][]any, 0)
for _, record := range result {
rawValue := make([]any, 0)
for _, cn := range columnNames {
rawValue = append(rawValue, record[targetConn.GetMetaData().RemoveQuote(cn)])
for _, tc := range targetColumns {
columnName := tc.ColumnName
val := record[targetMeta.RemoveQuote(columnName)]
if !tc.Nullable {
// 如果val是文本则设置为空格字符
switch val.(type) {
case string:
if val == "" {
val = " "
}
}
}
rawValue = append(rawValue, val)
}
values = append(values, rawValue)
}
@@ -312,6 +322,18 @@ func (app *dbTransferAppImpl) transferIndex(_ context.Context, tableInfo dbi.Tab
return nil
}
// 过滤主键索引
idxs := make([]dbi.Index, 0)
for _, idx := range indexs {
if !idx.IsPrimaryKey {
idxs = append(idxs, idx)
}
}
if len(idxs) == 0 {
return nil
}
// 通过表名、索引信息生成建索引语句,并执行到目标表
return targetDialect.CreateIndex(tableInfo, indexs)
return targetDialect.CreateIndex(tableInfo, idxs)
}

View File

@@ -25,6 +25,9 @@ type MetaData interface {
// 获取指定表名的所有列元信息
GetColumns(tableNames ...string) ([]Column, error)
// 根据数据库类型修复字段长度、精度等
FixColumn(column *Column)
// 获取表主键字段名,没有主键标识则默认第一个字段
GetPrimaryKey(tableName string) (string, error)
@@ -32,7 +35,7 @@ type MetaData interface {
GetTableIndex(tableName string) ([]Index, error)
// 获取建表ddl
GetTableDDL(tableName string) (string, error)
GetTableDDL(tableName string, dropBeforeCreate bool) (string, error)
GenerateTableDDL(columns []Column, tableInfo Table, dropBeforeCreate bool) []string
@@ -44,6 +47,9 @@ type MetaData interface {
GetDataConverter() DataConverter
}
// GenerateSQLStepFunc 生成insert sql的step函数用于生成insert sql时每生成100条sql时调用
type GenerateSQLStepFunc func(sqlArr []string)
// 数据库服务实例信息
type DbServer struct {
Version string `json:"version"` // 版本信息
@@ -100,6 +106,7 @@ type Index struct {
IndexComment string `json:"indexComment"` // 备注
SeqInIndex int `json:"seqInIndex"`
IsUnique bool `json:"isUnique"`
IsPrimaryKey bool `json:"isPrimaryKey"` // 是否是主键索引,某些情况需要判断并过滤掉主键索引
}
type ColumnDataType string
@@ -152,6 +159,13 @@ type DataConverter interface {
// 根据数据类型解析数据为符合要求的指定类型等
ParseData(dbColumnValue any, dataType DataType) any
// WrapValue 根据数据类型包装value
// 1.数字型:不需要引号,
// 2.文本型:需要用引号包裹,单引号需要转义,换行符转义,
// 3.date型需要格式化成对应的字符串timehh:mm:ss.SSS date: yyyy-mm-dd datetime:
// 4.特殊oracle date型需要用函数包裹to_timestamp('%s', 'yyyy-mm-dd hh24:mi:ss')
WrapValue(dbColumnValue any, dataType DataType) string
}
// ------------------------- 元数据sql操作 -------------------------

View File

@@ -3,6 +3,8 @@ package dbi
import (
pq "gitee.com/liuzongyang/libpq"
"github.com/kanzihuang/vitess/go/vt/sqlparser"
"io"
"strings"
)
type BaseMetaData interface {
@@ -13,6 +15,9 @@ type BaseMetaData interface {
// 用于引用 SQL 标识符(关键字)的字符串
GetIdentifierQuoteString() string
// 引号转义多用于sql注释转义防止拼接sql报错 comment xx is '注''释' 最终注释文本为: 注'释
QuoteEscape(str string) string
// QuoteLiteral quotes a 'literal' (e.g. a parameter, often used to pass literal
// to DDL and other statements that do not accept parameters) to be used as part
// of an SQL statement. For example:
@@ -25,6 +30,12 @@ type BaseMetaData interface {
QuoteLiteral(literal string) string
SqlParserDialect() sqlparser.Dialect
BeforeDumpInsert(writer io.Writer, tableName string)
BeforeDumpInsertSql(quoteSchema string, quoteTableName string) string
AfterDumpInsert(writer io.Writer, tableName string, columns []Column)
}
// 默认实现若需要覆盖则由各个数据库MetaData实现去覆盖重写
@@ -39,6 +50,10 @@ func (dd *DefaultMetaData) GetIdentifierQuoteString() string {
return `"`
}
func (dd *DefaultMetaData) QuoteEscape(str string) string {
return strings.Replace(str, `'`, `''`, -1)
}
func (dd *DefaultMetaData) QuoteLiteral(literal string) string {
return pq.QuoteLiteral(literal)
}
@@ -46,3 +61,13 @@ func (dd *DefaultMetaData) QuoteLiteral(literal string) string {
func (dd *DefaultMetaData) SqlParserDialect() sqlparser.Dialect {
return sqlparser.PostgresDialect{}
}
func (dd *DefaultMetaData) BeforeDumpInsert(writer io.Writer, tableName string) {
writer.Write([]byte("BEGIN;\n"))
}
func (dd *DefaultMetaData) BeforeDumpInsertSql(quoteSchema string, tableName string) string {
return ""
}
func (dd *DefaultMetaData) AfterDumpInsert(writer io.Writer, tableName string, columns []Column) {
writer.Write([]byte("COMMIT;\n"))
}

View File

@@ -52,7 +52,8 @@ WHERE a.owner = (SELECT SF_GET_SCHEMA_NAME_BY_ID(CURRENT_SCHID))
order by a.TABLE_NAME, a.index_name, c.column_position asc
---------------------------------------
--DM_COLUMN_MA 表列信息
select a.table_name as TABLE_NAME,
select a.owner,
a.table_name as TABLE_NAME,
a.column_name as COLUMN_NAME,
case when a.NULLABLE = 'Y' then 'YES' when a.NULLABLE = 'N' then 'NO' else 'NO' end as NULLABLE,
a.data_type as DATA_TYPE,
@@ -61,25 +62,21 @@ select a.table_name
a.data_scale as NUM_SCALE,
b.comments as COLUMN_COMMENT,
a.data_default as COLUMN_DEFAULT,
case when t.COL_NAME = a.column_name then 1 else 0 end as IS_IDENTITY,
case when t.INFO2 & 0x01 = 0x01 then 1 else 0 end as IS_IDENTITY,
case when t2.constraint_type = 'P' then 1 else 0 end as IS_PRIMARY_KEY
from all_tab_columns a
left join user_col_comments b
on b.owner = (SELECT SF_GET_SCHEMA_NAME_BY_ID(CURRENT_SCHID))
and b.table_name = a.table_name
and a.column_name = b.column_name
left join (select b.owner, b.TABLE_NAME, a.NAME as COL_NAME
from SYS.SYSCOLUMNS a,
SYS.all_tables b,
SYS.SYSOBJECTS c
where a.INFO2 & 0x01 = 0x01
and a.ID = c.ID
and c.NAME = b.TABLE_NAME) t
on t.table_name = a.table_name and t.owner = a.owner
left join (select c1.*, c2.object_name, c2.owner
FROM SYS.SYSCOLUMNS c1
join SYS.all_objects c2 on c1.id = c2.object_id and c2.object_type = 'TABLE') t
on t.object_name = a.table_name and t.owner = a.owner and t.NAME = a.column_name
left join (select uc.OWNER, uic.column_name, uic.table_name, uc.constraint_type
from user_ind_columns uic
left join user_constraints uc on uic.index_name = uc.index_name) t2
on t2.table_name = t.table_name and a.column_name = t2.column_name
on t2.table_name = t.object_name and a.column_name = t2.column_name and t2.OWNER = a.owner
where a.owner = (SELECT SF_GET_SCHEMA_NAME_BY_ID(CURRENT_SCHID))
and a.table_name in (%s)
order by a.table_name,

View File

@@ -35,7 +35,7 @@ where ss.name = ?
{{if .tableNames}}
and t.name in ({{.tableNames}})
{{end}}
ORDER BY t.name DESC;
ORDER BY t.name ASC;
---------------------------------------
--MSSQL_INDEX_INFO 索引信息
SELECT ind.name AS indexName,
@@ -46,7 +46,11 @@ SELECT ind.name AS indexName,
END AS indexType,
IIF(ind.is_unique = 'true', 1, 0) AS isUnique,
ic.key_ordinal AS seqInIndex,
idx.value AS indexComment
idx.value AS indexComment,
CASE
WHEN LEFT(ind.name, 3) = 'PK_' THEN 1
ELSE 0
END AS isPrimaryKey
FROM sys.indexes ind
LEFT JOIN sys.tables t on t.object_id = ind.object_id
LEFT JOIN sys.schemas ss on t.schema_id = ss.schema_id

View File

@@ -35,7 +35,8 @@ SELECT
index_type indexType,
IF(non_unique, 0, 1) isUnique,
SEQ_IN_INDEX seqInIndex,
INDEX_COMMENT indexComment
INDEX_COMMENT indexComment,
index_name = 'PRIMARY' as isPrimaryKey
FROM
information_schema.STATISTICS
WHERE

View File

@@ -38,7 +38,12 @@ SELECT ai.INDEX_NAME AS INDEX_NAME,
WHERE aic.INDEX_OWNER = ai.OWNER
AND aic.INDEX_NAME = ai.INDEX_NAME
AND aic.TABLE_NAME = ai.TABLE_NAME
AND ROWNUM = 1) AS INDEX_COMMENT
AND ROWNUM = 1) AS INDEX_COMMENT,
CASE
WHEN ai.INDEX_NAME like 'PK_%%' THEN 1
WHEN ai.INDEX_NAME like 'SYS_%%' THEN 1
ELSE 0
END AS IS_PRIMARY
FROM ALL_INDEXES ai
WHERE ai.OWNER = (SELECT sys_context('USERENV', 'CURRENT_SCHEMA') FROM DUAL)
AND ai.table_name = '%s'
@@ -50,14 +55,14 @@ SELECT a.TABLE_NAME as TABLE_NAME,
when a.NULLABLE = 'Y' then 'YES'
when a.NULLABLE = 'N' then 'NO'
else 'NO' end as NULLABLE,
case
when a.DATA_PRECISION > 0 then a.DATA_TYPE
else (a.DATA_TYPE || '(' || a.DATA_LENGTH || ')') end as COLUMN_TYPE,
a.DATA_TYPE as DATA_TYPE,
a.DATA_LENGTH as CHAR_MAX_LENGTH,
a.DATA_PRECISION as NUM_PRECISION,
a.DATA_SCALE as NUM_SCALE,
b.COMMENTS as COLUMN_COMMENT,
a.DATA_DEFAULT as COLUMN_DEFAULT,
a.DATA_SCALE as NUM_SCALE,
CASE WHEN d.pri IS NOT NULL THEN 1 ELSE 0 END as IS_PRIMARY_KEY,
CASE WHEN a.IDENTITY_COLUMN = 'YES' THEN 1 ELSE 0 END as IS_IDENTITY
CASE WHEN d.pri IS NOT NULL THEN 1 ELSE 0 END as IS_PRIMARY_KEY,
CASE WHEN a.IDENTITY_COLUMN = 'YES' THEN 1 ELSE 0 END as IS_IDENTITY
FROM ALL_TAB_COLUMNS a
LEFT JOIN ALL_COL_COMMENTS b
on a.OWNER = b.OWNER AND a.TABLE_NAME = b.TABLE_NAME AND a.COLUMN_NAME = b.COLUMN_NAME

View File

@@ -35,29 +35,28 @@ where
order by c.relname
---------------------------------------
--PGSQL_INDEX_INFO 表索引信息
SELECT
indexname AS "indexName",
'BTREE' AS "IndexType",
case when indexdef like 'CREATE UNIQUE INDEX%%' then 1 else 0 end as "isUnique",
obj_description(b.oid, 'pg_class') AS "indexComment",
indexdef AS "indexDef",
c.attname AS "columnName",
c.attnum AS "seqInIndex"
SELECT indexname AS "indexName",
'BTREE' AS "IndexType",
case when indexdef like 'CREATE UNIQUE INDEX%%' then 1 else 0 end as "isUnique",
obj_description(b.oid, 'pg_class') AS "indexComment",
indexdef AS "indexDef",
c.attname AS "columnName",
c.attnum AS "seqInIndex",
case when indexname like '%_pkey' then 1 else 0 end as "isPrimaryKey"
FROM pg_indexes a
join pg_class b on a.indexname = b.relname
join pg_attribute c on b.oid = c.attrelid
join pg_class b on a.indexname = b.relname
join pg_attribute c on b.oid = c.attrelid
WHERE a.schemaname = (select current_schema())
AND a.tablename = '%s';
---------------------------------------
--PGSQL_COLUMN_MA 表列信息
SELECT a.*,
a.table_name AS "tableName",
SELECT a.table_name AS "tableName",
a.column_name AS "columnName",
a.is_nullable AS "nullable",
a.udt_name AS "dataType",
a.character_maximum_length AS "charMaxLength",
a.numeric_precision AS "numPrecision",
a.column_default AS "columnDefault",
case when a.column_default like 'nextval%%' then null else a.column_default end AS "columnDefault",
a.numeric_scale AS "numScale",
case when a.column_default like 'nextval%%' then 1 else 0 end AS "isIdentity",
case when b.column_name is not null then 1 else 0 end AS "isPrimaryKey",

View File

@@ -127,7 +127,7 @@ func (dd *DMDialect) batchInsertMerge(tx *sql.Tx, tableName string, columns []st
func (dd *DMDialect) CopyTable(copy *dbi.DbCopyTable) error {
tableName := copy.TableName
metadata := dd.dc.GetMetaData()
ddl, err := metadata.GetTableDDL(tableName)
ddl, err := metadata.GetTableDDL(tableName, false)
if err != nil {
return err
}
@@ -177,21 +177,14 @@ func (dd *DMDialect) ToCommonColumn(dialectColumn *dbi.Column) {
func (dd *DMDialect) ToColumn(commonColumn *dbi.Column) {
ctype := dmColumnTypeMap[commonColumn.DataType]
meta := dd.dc.GetMetaData()
if ctype == "" {
commonColumn.DataType = "VARCHAR"
commonColumn.CharMaxLength = 2000
} else {
commonColumn.DataType = dbi.ColumnDataType(ctype)
// 如果是date不设长度
if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(ctype)) {
commonColumn.CharMaxLength = 0
commonColumn.NumPrecision = 0
} else
// 如果是char且长度未设置则默认长度2000
if collx.ArrayAnyMatches([]string{"char"}, strings.ToLower(ctype)) && commonColumn.CharMaxLength == 0 {
commonColumn.CharMaxLength = 2000
}
meta.FixColumn(commonColumn)
}
}
@@ -212,7 +205,15 @@ func (dd *DMDialect) CreateTable(columns []dbi.Column, tableInfo dbi.Table, drop
}
func (dd *DMDialect) CreateIndex(tableInfo dbi.Table, indexs []dbi.Index) error {
sqls := dd.dc.GetMetaData().GenerateIndexDDL(indexs, tableInfo)
_, err := dd.dc.Exec(strings.Join(sqls, ";"))
return err
sqlArr := dd.dc.GetMetaData().GenerateIndexDDL(indexs, tableInfo)
// 达梦需要分开执行sql
if len(sqlArr) > 0 {
for _, sqlStr := range sqlArr {
_, err := dd.dc.Exec(sqlStr)
if err != nil {
return err
}
}
}
return nil
}

View File

@@ -2,6 +2,7 @@ package dm
import (
"fmt"
"io"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/errorx"
"mayfly-go/pkg/logx"
@@ -112,11 +113,24 @@ func (dd *DMMetaData) GetColumns(tableNames ...string) ([]dbi.Column, error) {
NumPrecision: cast.ToInt(re["NUM_PRECISION"]),
NumScale: cast.ToInt(re["NUM_SCALE"]),
}
dd.FixColumn(&column)
columns = append(columns, column)
}
return columns, nil
}
func (dd *DMMetaData) FixColumn(column *dbi.Column) {
// 如果是date不设长度
if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(string(column.DataType))) {
column.CharMaxLength = 0
column.NumPrecision = 0
} else
// 如果是char且长度未设置则默认长度2000
if collx.ArrayAnyMatches([]string{"char"}, strings.ToLower(string(column.DataType))) && column.CharMaxLength == 0 {
column.CharMaxLength = 2000
}
}
func (dd *DMMetaData) GetPrimaryKey(tablename string) (string, error) {
columns, err := dd.GetColumns(tablename)
if err != nil {
@@ -150,6 +164,7 @@ func (dd *DMMetaData) GetTableIndex(tableName string) ([]dbi.Index, error) {
IndexComment: cast.ToString(re["INDEX_COMMENT"]),
IsUnique: cast.ToInt(re["IS_UNIQUE"]) == 1,
SeqInIndex: cast.ToInt(re["SEQ_IN_INDEX"]),
IsPrimaryKey: false,
})
}
// 把查询结果以索引名分组,索引字段以逗号连接
@@ -206,7 +221,7 @@ func (dd *DMMetaData) genColumnBasicSql(column dbi.Column) string {
}
}
columnSql := fmt.Sprintf(" %s %s %s %s %s", colName, column.GetColumnType(), incr, nullAble, defVal)
columnSql := fmt.Sprintf(" %s %s%s%s%s", colName, column.GetColumnType(), incr, nullAble, defVal)
return columnSql
}
@@ -215,27 +230,25 @@ func (dd *DMMetaData) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Table)
meta := dd.dc.GetMetaData()
sqls := make([]string, 0)
for _, index := range indexs {
// 通过字段、表名拼接索引名
columnName := strings.ReplaceAll(index.ColumnName, "-", "")
columnName = strings.ReplaceAll(columnName, "_", "")
colName := strings.ReplaceAll(columnName, ",", "_")
keyType := "normal"
unique := ""
if index.IsUnique {
keyType = "unique"
unique = "unique"
}
indexName := fmt.Sprintf("%s_key_%s_%s", keyType, tableInfo.TableName, colName)
sqls = append(sqls, fmt.Sprintf("create %s index %s on %s(%s)", unique, indexName, meta.QuoteIdentifier(tableInfo.TableName), index.ColumnName))
// 取出列名,添加引号
cols := strings.Split(index.ColumnName, ",")
colNames := make([]string, len(cols))
for i, name := range cols {
colNames[i] = meta.QuoteIdentifier(name)
}
sqls = append(sqls, fmt.Sprintf("create %s index %s on %s(%s)", unique, index.IndexName, meta.QuoteIdentifier(tableInfo.TableName), strings.Join(colNames, ",")))
}
return sqls
}
func (dd *DMMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table, dropBeforeCreate bool) []string {
meta := dd.dc.GetMetaData()
replacer := strings.NewReplacer(";", "", "'", "")
tbName := meta.QuoteIdentifier(tableInfo.TableName)
sqlArr := make([]string, 0)
@@ -253,25 +266,21 @@ func (dd *DMMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table
pks = append(pks, meta.QuoteIdentifier(column.ColumnName))
}
fields = append(fields, dd.genColumnBasicSql(column))
// 防止注释内含有特殊字符串导致sql出错
if column.ColumnComment != "" {
comment := replacer.Replace(column.ColumnComment)
comment := meta.QuoteEscape(column.ColumnComment)
columnComments = append(columnComments, fmt.Sprintf("comment on column %s.%s is '%s'", tbName, meta.QuoteIdentifier(column.ColumnName), comment))
}
}
createSql += strings.Join(fields, ",")
createSql += strings.Join(fields, ",\n")
if len(pks) > 0 {
createSql += fmt.Sprintf(", PRIMARY KEY (%s)", strings.Join(pks, ","))
createSql += fmt.Sprintf(",\n PRIMARY KEY (%s)", strings.Join(pks, ","))
}
createSql += ")"
createSql += "\n)"
tableCommentSql := ""
if tableInfo.TableComment != "" {
// 防止注释内含有特殊字符串导致sql出错
comment := replacer.Replace(tableInfo.TableComment)
if comment != "" {
tableCommentSql = fmt.Sprintf(" comment on table %s is '%s'", tbName, comment)
}
comment := meta.QuoteEscape(tableInfo.TableComment)
tableCommentSql = fmt.Sprintf("comment on table %s is '%s'", tbName, comment)
}
sqlArr = append(sqlArr, createSql)
@@ -287,13 +296,12 @@ func (dd *DMMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table
}
// 获取建表ddl
func (dd *DMMetaData) GetTableDDL(tableName string) (string, error) {
func (dd *DMMetaData) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) {
// 1.获取表信息
tbs, err := dd.GetTables(tableName)
tableInfo := &dbi.Table{}
if err != nil && len(tbs) > 0 {
if err != nil || tbs == nil || len(tbs) <= 0 {
logx.Errorf("获取表信息失败, %s", tableName)
return "", err
}
@@ -306,7 +314,7 @@ func (dd *DMMetaData) GetTableDDL(tableName string) (string, error) {
logx.Errorf("获取列信息失败, %s", tableName)
return "", err
}
tableDDLArr := dd.GenerateTableDDL(columns, *tableInfo, false)
tableDDLArr := dd.GenerateTableDDL(columns, *tableInfo, dropBeforeCreate)
// 3.获取索引信息
indexs, err := dd.GetTableIndex(tableName)
if err != nil {
@@ -315,7 +323,7 @@ func (dd *DMMetaData) GetTableDDL(tableName string) (string, error) {
}
// 组装返回
tableDDLArr = append(tableDDLArr, dd.GenerateIndexDDL(indexs, *tableInfo)...)
return strings.Join(tableDDLArr, ";"), nil
return strings.Join(tableDDLArr, ";\n"), nil
}
// 获取DM当前连接的库可访问的schemaNames
@@ -332,6 +340,18 @@ func (dd *DMMetaData) GetSchemas() ([]string, error) {
return schemaNames, nil
}
func (dd *DMMetaData) BeforeDumpInsert(writer io.Writer, tableName string) {
}
func (dd *DMMetaData) BeforeDumpInsertSql(quoteSchema string, tableName string) string {
return fmt.Sprintf("set identity_insert %s on;", tableName)
}
func (dd *DMMetaData) AfterDumpInsert(writer io.Writer, tableName string, columns []dbi.Column) {
writer.Write([]byte("COMMIT;\n"))
}
func (dd *DMMetaData) GetDataConverter() dbi.DataConverter {
return converter
}
@@ -406,7 +426,7 @@ var (
type DataConverter struct {
}
func (dd *DataConverter) GetDataType(dbColumnType string) dbi.DataType {
func (dc *DataConverter) GetDataType(dbColumnType string) dbi.DataType {
if numberRegexp.MatchString(dbColumnType) {
return dbi.DataTypeNumber
}
@@ -422,23 +442,38 @@ func (dd *DataConverter) GetDataType(dbColumnType string) dbi.DataType {
return dbi.DataTypeString
}
func (dd *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) string {
func (dc *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) string {
str := anyx.ToString(dbColumnValue)
switch dataType {
case dbi.DataTypeDateTime: // "2024-01-02T22:08:22.275697+08:00"
res, _ := time.Parse(time.RFC3339, str)
// 尝试用时间格式解析
res, err := time.Parse(time.DateTime, str)
if err == nil {
return str
}
res, _ = time.Parse(time.RFC3339, str)
return res.Format(time.DateTime)
case dbi.DataTypeDate: // "2024-01-02T00:00:00+08:00"
res, _ := time.Parse(time.RFC3339, str)
// 尝试用时间格式解析
res, err := time.Parse(time.DateOnly, str)
if err == nil {
return str
}
res, _ = time.Parse(time.RFC3339, str)
return res.Format(time.DateOnly)
case dbi.DataTypeTime: // "0000-01-01T22:08:22.275688+08:00"
res, _ := time.Parse(time.RFC3339, str)
// 尝试用时间格式解析
res, err := time.Parse(time.TimeOnly, str)
if err == nil {
return str
}
res, _ = time.Parse(time.RFC3339, str)
return res.Format(time.TimeOnly)
}
return str
}
func (dd *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any {
func (dc *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any {
// 如果dataType是datetime而dbColumnValue是string类型则需要转换为time.Time类型
_, ok := dbColumnValue.(string)
if ok {
@@ -457,3 +492,24 @@ func (dd *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any
}
return dbColumnValue
}
func (dc *DataConverter) WrapValue(dbColumnValue any, dataType dbi.DataType) string {
if dbColumnValue == nil {
return "NULL"
}
switch dataType {
case dbi.DataTypeNumber:
return fmt.Sprintf("%v", dbColumnValue)
case dbi.DataTypeString:
val := fmt.Sprintf("%v", dbColumnValue)
// 转义单引号
val = strings.Replace(val, `'`, `''`, -1)
val = strings.Replace(val, `\''`, `\'`, -1)
// 转义换行符
val = strings.Replace(val, "\n", "\\n", -1)
return fmt.Sprintf("'%s'", val)
case dbi.DataTypeDate, dbi.DataTypeDateTime, dbi.DataTypeTime:
return fmt.Sprintf("'%s'", dc.FormatData(dbColumnValue, dataType))
}
return fmt.Sprintf("'%s'", dbColumnValue)
}

View File

@@ -253,35 +253,14 @@ func (md *MssqlDialect) ToCommonColumn(dialectColumn *dbi.Column) {
func (md *MssqlDialect) ToColumn(commonColumn *dbi.Column) {
ctype := mssqlColumnTypeMap[commonColumn.DataType]
meta := md.dc.GetMetaData()
if ctype == "" {
commonColumn.DataType = "varchar"
commonColumn.CharMaxLength = 2000
} else {
commonColumn.DataType = dbi.ColumnDataType(ctype)
if strings.Contains(strings.ToLower(ctype), "int") {
// 如果类型是数字,类型后不需要带长度
commonColumn.CharMaxLength = 0
commonColumn.NumPrecision = 0
} else if collx.ArrayAnyMatches([]string{"float", "number", "decimal"}, strings.ToLower(ctype)) {
// 如果是float最大长度为38
if commonColumn.CharMaxLength > 38 {
commonColumn.CharMaxLength = 38
}
if commonColumn.NumPrecision > 38 {
commonColumn.NumPrecision = 38
}
} else if strings.Contains(strings.ToLower(ctype), "char") {
// 如果是字符串类型长度最大4000否则修改字段类型为text
if commonColumn.CharMaxLength > 4000 {
commonColumn.DataType = "text"
commonColumn.CharMaxLength = 0
}
} else if strings.Contains(strings.ToLower(ctype), "text") {
// 如果是text取消长度
commonColumn.CharMaxLength = 0
}
meta.FixColumn(commonColumn)
}
}

View File

@@ -119,28 +119,32 @@ func (md *MssqlMetaData) GetColumns(tableNames ...string) ([]dbi.Column, error)
NumScale: cast.ToInt(re["NUM_SCALE"]),
}
dataType := strings.ToLower(string(column.DataType))
if collx.ArrayAnyMatches([]string{"date", "time"}, dataType) {
// 如果是datetime精度取NumScale字段
column.CharMaxLength = column.NumScale
} else if collx.ArrayAnyMatches([]string{"int", "bit", "real", "text", "xml"}, dataType) {
// 不显示长度的类型
column.NumPrecision = 0
column.CharMaxLength = 0
} else if collx.ArrayAnyMatches([]string{"numeric", "decimal", "float"}, dataType) {
// 如果是num长度取精度和小数位数
column.CharMaxLength = 0
} else if collx.ArrayAnyMatches([]string{"nvarchar", "nchar"}, dataType) {
// 如果是nvarchar可视长度减半
column.CharMaxLength = column.CharMaxLength / 2
}
md.FixColumn(&column)
columns = append(columns, column)
}
return columns, nil
}
func (md *MssqlMetaData) FixColumn(column *dbi.Column) {
dataType := strings.ToLower(string(column.DataType))
if collx.ArrayAnyMatches([]string{"date", "time"}, dataType) {
// 如果是datetime精度取NumScale字段
column.CharMaxLength = column.NumScale
} else if collx.ArrayAnyMatches([]string{"int", "bit", "real", "text", "xml"}, dataType) {
// 不显示长度的类型
column.NumPrecision = 0
column.CharMaxLength = 0
} else if collx.ArrayAnyMatches([]string{"numeric", "decimal", "float"}, dataType) {
// 如果是num长度取精度和小数位数
column.CharMaxLength = 0
} else if collx.ArrayAnyMatches([]string{"nvarchar", "nchar"}, dataType) {
// 如果是nvarchar可视长度减半
column.CharMaxLength = column.CharMaxLength / 2
}
}
// 获取表主键字段名,不存在主键标识则默认第一个字段
func (md *MssqlMetaData) GetPrimaryKey(tablename string) (string, error) {
columns, err := md.GetColumns(tablename)
@@ -174,6 +178,7 @@ func (md *MssqlMetaData) getTableIndexWithPK(tableName string) ([]dbi.Index, err
IndexComment: cast.ToString(re["indexComment"]),
IsUnique: cast.ToInt(re["isUnique"]) == 1,
SeqInIndex: cast.ToInt(re["seqInIndex"]),
IsPrimaryKey: cast.ToInt(re["isPrimaryKey"]) == 1,
})
}
// 把查询结果以索引名分组,多个索引字段以逗号连接
@@ -199,10 +204,9 @@ func (md *MssqlMetaData) getTableIndexWithPK(tableName string) ([]dbi.Index, err
func (md *MssqlMetaData) GetTableIndex(tableName string) ([]dbi.Index, error) {
indexs, _ := md.getTableIndexWithPK(tableName)
result := make([]dbi.Index, 0)
// 过滤掉主键索引,主键索引名为PK__开头的
// 过滤掉主键索引
for _, v := range indexs {
in := v.IndexName
if strings.HasPrefix(in, "PK__") {
if v.IsPrimaryKey {
continue
}
result = append(result, v)
@@ -248,25 +252,25 @@ func (md *MssqlMetaData) CopyTableDDL(tableName string, newTableName string) (st
func (md *MssqlMetaData) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Table) []string {
tbName := tableInfo.TableName
meta := md.dc.GetMetaData()
sqls := make([]string, 0)
comments := make([]string, 0)
for _, index := range indexs {
// 通过字段、表名拼接索引名
columnName := strings.ReplaceAll(index.ColumnName, "-", "")
columnName = strings.ReplaceAll(columnName, "_", "")
colName := strings.ReplaceAll(columnName, ",", "_")
keyType := "normal"
unique := ""
if index.IsUnique {
keyType = "unique"
unique = "unique"
}
indexName := fmt.Sprintf("%s_key_%s_%s", keyType, tbName, colName)
// 取出列名,添加引号
cols := strings.Split(index.ColumnName, ",")
colNames := make([]string, len(cols))
for i, name := range cols {
colNames[i] = meta.QuoteIdentifier(name)
}
sqls = append(sqls, fmt.Sprintf("create %s NONCLUSTERED index %s on %s.%s(%s)", unique, indexName, md.dc.Info.CurrentSchema(), tbName, index.ColumnName))
sqls = append(sqls, fmt.Sprintf("create %s NONCLUSTERED index %s on %s.%s(%s)", unique, index.IndexName, md.dc.Info.CurrentSchema(), tbName, strings.Join(colNames, ",")))
if index.IndexComment != "" {
comments = append(comments, fmt.Sprintf("EXECUTE sp_addextendedproperty N'MS_Description', N'%s', N'SCHEMA', N'%s', N'TABLE', N'%s', N'INDEX', N'%s'", index.IndexComment, md.dc.Info.CurrentSchema(), tbName, indexName))
comment := meta.QuoteEscape(index.IndexComment)
comments = append(comments, fmt.Sprintf("EXECUTE sp_addextendedproperty N'MS_Description', N'%s', N'SCHEMA', N'%s', N'TABLE', N'%s', N'INDEX', N'%s'", comment, md.dc.Info.CurrentSchema(), tbName, index.IndexName))
}
}
if len(comments) > 0 {
@@ -312,7 +316,7 @@ func (md *MssqlMetaData) genColumnBasicSql(column dbi.Column) string {
}
}
columnSql := fmt.Sprintf(" %s %s %s %s %s", colName, column.GetColumnType(), incr, nullAble, defVal)
columnSql := fmt.Sprintf(" %s %s%s%s%s", colName, column.GetColumnType(), incr, nullAble, defVal)
return columnSql
}
@@ -320,7 +324,6 @@ func (md *MssqlMetaData) genColumnBasicSql(column dbi.Column) string {
func (md *MssqlMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table, dropBeforeCreate bool) []string {
tbName := tableInfo.TableName
meta := md.dc.GetMetaData()
replacer := strings.NewReplacer(";", "", "'", "")
sqlArr := make([]string, 0)
@@ -344,7 +347,7 @@ func (md *MssqlMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Ta
// 防止注释内含有特殊字符串导致sql出错
if column.ColumnComment != "" {
comment := replacer.Replace(column.ColumnComment)
comment := meta.QuoteEscape(column.ColumnComment)
columnComments = append(columnComments, fmt.Sprintf(commentTmp, comment, md.dc.Info.CurrentSchema(), tbName, column.ColumnName))
}
}
@@ -360,7 +363,8 @@ func (md *MssqlMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Ta
tableCommentSql := ""
if tableInfo.TableComment != "" {
commentTmp := "EXECUTE sp_addextendedproperty N'MS_Description', N'%s', N'SCHEMA', N'%s', N'TABLE', N'%s'"
tableCommentSql = fmt.Sprintf(commentTmp, replacer.Replace(tableInfo.TableComment), md.dc.Info.CurrentSchema(), tbName)
tableCommentSql = fmt.Sprintf(commentTmp, meta.QuoteEscape(tableInfo.TableComment), md.dc.Info.CurrentSchema(), tbName)
}
sqlArr = append(sqlArr, createSql)
@@ -376,13 +380,12 @@ func (md *MssqlMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Ta
}
// 获取建表ddl
func (md *MssqlMetaData) GetTableDDL(tableName string) (string, error) {
func (md *MssqlMetaData) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) {
// 1.获取表信息
tbs, err := md.GetTables(tableName)
tableInfo := &dbi.Table{}
if err != nil && len(tbs) > 0 {
if err != nil || tbs == nil || len(tbs) <= 0 {
logx.Errorf("获取表信息失败, %s", tableName)
return "", err
}
@@ -395,7 +398,7 @@ func (md *MssqlMetaData) GetTableDDL(tableName string) (string, error) {
logx.Errorf("获取列信息失败, %s", tableName)
return "", err
}
tableDDLArr := md.GenerateTableDDL(columns, *tableInfo, false)
tableDDLArr := md.GenerateTableDDL(columns, *tableInfo, dropBeforeCreate)
// 3.获取索引信息
indexs, err := md.GetTableIndex(tableName)
if err != nil {
@@ -424,6 +427,10 @@ func (md *MssqlMetaData) GetIdentifierQuoteString() string {
return "["
}
func (md *MssqlMetaData) BeforeDumpInsertSql(quoteSchema string, tableName string) string {
return fmt.Sprintf("set identity_insert %s.%s on ", quoteSchema, tableName)
}
func (md *MssqlMetaData) GetDataConverter() dbi.DataConverter {
return converter
}
@@ -528,6 +535,26 @@ func (dc *DataConverter) GetDataType(dbColumnType string) dbi.DataType {
}
func (dc *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) string {
// 如果dataType是datetime而dbColumnValue是string类型则需要根据类型格式化
str, ok := dbColumnValue.(string)
if dataType == dbi.DataTypeDateTime && ok {
// 尝试用时间格式解析
res, err := time.Parse(time.DateTime, str)
if err == nil {
return str
}
res, _ = time.Parse(time.RFC3339, str)
return res.Format(time.DateTime)
}
if dataType == dbi.DataTypeDate && ok {
// 尝试用时间格式解析
res, _ := time.Parse(time.DateOnly, str)
return res.Format(time.DateOnly)
}
if dataType == dbi.DataTypeTime && ok {
res, _ := time.Parse(time.TimeOnly, str)
return res.Format(time.TimeOnly)
}
return anyx.ToString(dbColumnValue)
}
@@ -548,3 +575,24 @@ func (dc *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any
}
return dbColumnValue
}
func (dc *DataConverter) WrapValue(dbColumnValue any, dataType dbi.DataType) string {
if dbColumnValue == nil {
return "NULL"
}
switch dataType {
case dbi.DataTypeNumber:
return fmt.Sprintf("%v", dbColumnValue)
case dbi.DataTypeString:
val := fmt.Sprintf("%v", dbColumnValue)
// 转义单引号
val = strings.Replace(val, `'`, `''`, -1)
val = strings.Replace(val, `\''`, `\'`, -1)
// 转义换行符
val = strings.Replace(val, "\n", "\\n", -1)
return fmt.Sprintf("'%s'", val)
case dbi.DataTypeDate, dbi.DataTypeDateTime, dbi.DataTypeTime:
return fmt.Sprintf("'%s'", dc.FormatData(dbColumnValue, dataType))
}
return fmt.Sprintf("'%s'", dbColumnValue)
}

View File

@@ -97,16 +97,7 @@ func (md *MysqlDialect) ToColumn(column *dbi.Column) {
column.CharMaxLength = 1000
} else {
column.DataType = dbi.ColumnDataType(ctype)
// 如果是int整型删除精度
if strings.Contains(strings.ToLower(ctype), "int") {
column.NumScale = 0
column.CharMaxLength = 0
} else
// 如果是text删除长度
if strings.Contains(strings.ToLower(ctype), "text") {
column.CharMaxLength = 0
column.NumPrecision = 0
}
md.dc.GetMetaData().FixColumn(column)
}
}

View File

@@ -117,11 +117,25 @@ func (md *MysqlMetaData) GetColumns(tableNames ...string) ([]dbi.Column, error)
NumScale: cast.ToInt(re["numScale"]),
}
md.FixColumn(&column)
columns = append(columns, column)
}
return columns, nil
}
func (md *MysqlMetaData) FixColumn(column *dbi.Column) {
// 如果是int整型删除精度
if strings.Contains(strings.ToLower(string(column.DataType)), "int") {
column.NumScale = 0
column.CharMaxLength = 0
} else
// 如果是text删除长度
if strings.Contains(strings.ToLower(string(column.DataType)), "text") {
column.CharMaxLength = 0
column.NumPrecision = 0
}
}
// 获取表主键字段名,不存在主键标识则默认第一个字段
func (md *MysqlMetaData) GetPrimaryKey(tablename string) (string, error) {
columns, err := md.GetColumns(tablename)
@@ -157,6 +171,7 @@ func (md *MysqlMetaData) GetTableIndex(tableName string) ([]dbi.Index, error) {
IndexComment: cast.ToString(re["indexComment"]),
IsUnique: cast.ToInt(re["isUnique"]) == 1,
SeqInIndex: cast.ToInt(re["seqInIndex"]),
IsPrimaryKey: cast.ToInt(re["isPrimaryKey"]) == 1,
})
}
// 把查询结果以索引名分组,索引字段以逗号连接
@@ -183,27 +198,29 @@ func (md *MysqlMetaData) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Tabl
meta := md.dc.GetMetaData()
sqlArr := make([]string, 0)
for _, index := range indexs {
// 通过字段、表名拼接索引名
columnName := strings.ReplaceAll(index.ColumnName, "-", "")
columnName = strings.ReplaceAll(columnName, "_", "")
colName := strings.ReplaceAll(columnName, ",", "_")
keyType := "normal"
unique := ""
if index.IsUnique {
keyType = "unique"
unique = "unique"
}
indexName := fmt.Sprintf("%s_key_%s_%s", keyType, tableInfo.TableName, colName)
sqlTmp := "ALTER TABLE %s ADD %s INDEX %s(%s) USING BTREE COMMENT '%s'"
replacer := strings.NewReplacer(";", "", "'", "")
sqlArr = append(sqlArr, fmt.Sprintf(sqlTmp, meta.QuoteIdentifier(tableInfo.TableName), unique, indexName, index.ColumnName, replacer.Replace(index.IndexComment)))
// 取出列名,添加引号
cols := strings.Split(index.ColumnName, ",")
colNames := make([]string, len(cols))
for i, name := range cols {
colNames[i] = meta.QuoteIdentifier(name)
}
sqlTmp := "ALTER TABLE %s ADD %s INDEX %s(%s) USING BTREE"
sqlStr := fmt.Sprintf(sqlTmp, meta.QuoteIdentifier(tableInfo.TableName), unique, index.IndexName, strings.Join(colNames, ","))
comment := meta.QuoteEscape(index.IndexComment)
if comment != "" {
sqlStr += fmt.Sprintf(" COMMENT '%s'", comment)
}
sqlArr = append(sqlArr, sqlStr)
}
return sqlArr
}
func (md *MysqlMetaData) genColumnBasicSql(column dbi.Column) string {
replacer := strings.NewReplacer(";", "", "'", "")
meta := md.dc.GetMetaData()
dataType := string(column.DataType)
incr := ""
@@ -238,11 +255,11 @@ func (md *MysqlMetaData) genColumnBasicSql(column dbi.Column) string {
comment := ""
if column.ColumnComment != "" {
// 防止注释内含有特殊字符串导致sql出错
commentStr := replacer.Replace(column.ColumnComment)
commentStr := meta.QuoteEscape(column.ColumnComment)
comment = fmt.Sprintf(" COMMENT '%s'", commentStr)
}
columnSql := fmt.Sprintf(" %s %s %s %s %s %s", md.dc.GetMetaData().QuoteIdentifier(column.ColumnName), column.GetColumnType(), nullAble, incr, defVal, comment)
columnSql := fmt.Sprintf(" %s %s%s%s%s%s", md.dc.GetMetaData().QuoteIdentifier(column.ColumnName), column.GetColumnType(), nullAble, incr, defVal, comment)
return columnSql
}
@@ -272,12 +289,11 @@ func (md *MysqlMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Ta
if len(pks) > 0 {
createSql += fmt.Sprintf(", \nPRIMARY KEY (%s)", strings.Join(pks, ","))
}
createSql += fmt.Sprintf(") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 ")
createSql += "\n)"
// 表注释
if tableInfo.TableComment != "" {
replacer := strings.NewReplacer(";", "", "'", "")
createSql += fmt.Sprintf(" COMMENT '%s'", replacer.Replace(tableInfo.TableComment))
createSql += fmt.Sprintf(" COMMENT '%s'", meta.QuoteEscape(tableInfo.TableComment))
}
sqlArr = append(sqlArr, createSql)
@@ -286,11 +302,11 @@ func (md *MysqlMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Ta
}
// 获取建表ddl
func (md *MysqlMetaData) GetTableDDL(tableName string) (string, error) {
func (md *MysqlMetaData) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) {
// 1.获取表信息
tbs, err := md.GetTables(tableName)
tableInfo := &dbi.Table{}
if err != nil && len(tbs) > 0 {
if err != nil || tbs == nil || len(tbs) <= 0 {
logx.Errorf("获取表信息失败, %s", tableName)
return "", err
}
@@ -303,7 +319,7 @@ func (md *MysqlMetaData) GetTableDDL(tableName string) (string, error) {
logx.Errorf("获取列信息失败, %s", tableName)
return "", err
}
tableDDLArr := md.GenerateTableDDL(columns, *tableInfo, false)
tableDDLArr := md.GenerateTableDDL(columns, *tableInfo, dropBeforeCreate)
// 3.获取索引信息
indexs, err := md.GetTableIndex(tableName)
if err != nil {
@@ -426,6 +442,25 @@ func (dc *DataConverter) GetDataType(dbColumnType string) dbi.DataType {
}
func (dc *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) string {
// 如果dataType是datetime而dbColumnValue是string类型则需要根据类型格式化
str, ok := dbColumnValue.(string)
if dataType == dbi.DataTypeDateTime && ok {
// 尝试用时间格式解析
res, err := time.Parse(time.DateTime, str)
if err == nil {
return str
}
res, _ = time.Parse(time.RFC3339, str)
return res.Format(time.DateTime)
}
if dataType == dbi.DataTypeDate && ok {
res, _ := time.Parse(time.DateOnly, str)
return res.Format(time.DateOnly)
}
if dataType == dbi.DataTypeTime && ok {
res, _ := time.Parse(time.TimeOnly, str)
return res.Format(time.TimeOnly)
}
return anyx.ToString(dbColumnValue)
}
@@ -448,3 +483,26 @@ func (dc *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any
}
return dbColumnValue
}
func (dc *DataConverter) WrapValue(dbColumnValue any, dataType dbi.DataType) string {
if dbColumnValue == nil {
return "NULL"
}
switch dataType {
case dbi.DataTypeNumber:
return fmt.Sprintf("%v", dbColumnValue)
case dbi.DataTypeString:
val := fmt.Sprintf("%v", dbColumnValue)
// 转义单引号
val = strings.Replace(val, `'`, `''`, -1)
val = strings.Replace(val, `\''`, `\'`, -1)
// 转义换行符
val = strings.Replace(val, "\n", "\\n", -1)
return fmt.Sprintf("'%s'", val)
case dbi.DataTypeDate, dbi.DataTypeDateTime, dbi.DataTypeTime:
// mysql时间类型无需格式化
return fmt.Sprintf("'%s'", dbColumnValue)
}
return fmt.Sprintf("'%s'", dbColumnValue)
}

View File

@@ -178,17 +178,7 @@ func (od *OracleDialect) ToColumn(commonColumn *dbi.Column) {
commonColumn.CharMaxLength = 2000
} else {
commonColumn.DataType = dbi.ColumnDataType(ctype)
// 如果类型是数字,类型后不需要带长度
if strings.Contains(strings.ToLower(ctype), "int") {
commonColumn.CharMaxLength = 0
commonColumn.NumPrecision = 0
} else if strings.Contains(strings.ToLower(ctype), "char") {
// 如果是字符串类型长度最大4000否则修改字段类型为clob
if commonColumn.CharMaxLength > 4000 {
commonColumn.DataType = "CLOB"
commonColumn.CharMaxLength = 0
}
}
od.dc.GetMetaData().FixColumn(commonColumn)
}
}

View File

@@ -120,28 +120,47 @@ func (od *OracleMetaData) GetColumns(tableNames ...string) ([]dbi.Column, error)
columns := make([]dbi.Column, 0)
for _, re := range res {
defaultVal := cast.ToString(re["COLUMN_DEFAULT"])
// 如果默认值包含.nextval说明是序列默认值为null
if strings.Contains(defaultVal, ".nextval") {
defaultVal = ""
}
column := dbi.Column{
TableName: cast.ToString(re["TABLE_NAME"]),
ColumnName: cast.ToString(re["COLUMN_NAME"]),
DataType: dbi.ColumnDataType(cast.ToString(re["DATA_TYPE"])),
CharMaxLength: cast.ToInt(re["CHAR_MAX_LENGTH"]),
ColumnComment: cast.ToString(re["COLUMN_COMMENT"]),
Nullable: cast.ToString(re["NULLABLE"]) == "YES",
IsPrimaryKey: cast.ToInt(re["IS_PRIMARY_KEY"]) == 1,
IsIdentity: cast.ToInt(re["IS_IDENTITY"]) == 1,
ColumnDefault: defaultVal,
ColumnDefault: cast.ToString(re["COLUMN_DEFAULT"]),
NumPrecision: cast.ToInt(re["NUM_PRECISION"]),
NumScale: cast.ToInt(re["NUM_SCALE"]),
}
od.FixColumn(&column)
columns = append(columns, column)
}
return columns, nil
}
func (od *OracleMetaData) FixColumn(column *dbi.Column) {
// 如果默认值包含.nextval说明是序列默认值为null
if strings.Contains(column.ColumnDefault, ".nextval") {
column.ColumnDefault = ""
}
// 统一处理一下数据类型的长度
if collx.ArrayAnyMatches([]string{"date", "time", "lob", "int"}, strings.ToLower(string(column.DataType))) {
// 如果是不需要设置长度的类型
column.CharMaxLength = 0
column.NumPrecision = 0
} else if strings.Contains(strings.ToLower(string(column.DataType)), "char") {
// 如果是字符串类型长度最大4000否则修改字段类型为clob
if column.CharMaxLength > 4000 {
column.DataType = "NCLOB"
column.CharMaxLength = 0
column.NumPrecision = 0
}
}
}
func (od *OracleMetaData) GetPrimaryKey(tablename string) (string, error) {
columns, err := od.GetColumns(tablename)
if err != nil {
@@ -175,6 +194,7 @@ func (od *OracleMetaData) GetTableIndex(tableName string) ([]dbi.Index, error) {
IndexComment: cast.ToString(re["INDEX_COMMENT"]),
IsUnique: cast.ToInt(re["IS_UNIQUE"]) == 1,
SeqInIndex: cast.ToInt(re["SEQ_IN_INDEX"]),
IsPrimaryKey: cast.ToInt(re["IS_PRIMARY"]) == 1,
})
}
// 把查询结果以索引名分组,索引字段以逗号连接
@@ -204,23 +224,19 @@ func (od *OracleMetaData) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Tab
comments := make([]string, 0)
for _, index := range indexs {
// 通过字段、表名拼接索引名
columnName := strings.ReplaceAll(index.ColumnName, "-", "")
columnName = strings.ReplaceAll(columnName, "_", "")
colName := strings.ReplaceAll(columnName, ",", "_")
keyType := "normal"
unique := ""
if index.IsUnique {
keyType = "unique"
unique = "unique"
}
indexName := fmt.Sprintf("%s_key_%s_%s", keyType, tableInfo.TableName, colName)
sqls = append(sqls, fmt.Sprintf("CREATE %s INDEX %s ON %s(%s)", unique, indexName, meta.QuoteIdentifier(tableInfo.TableName), index.ColumnName))
if index.IndexComment != "" {
comments = append(comments, fmt.Sprintf("COMMENT ON INDEX %s IS '%s'", indexName, index.IndexComment))
// 取出列名,添加引号
cols := strings.Split(index.ColumnName, ",")
colNames := make([]string, len(cols))
for i, name := range cols {
colNames[i] = meta.QuoteIdentifier(name)
}
sqls = append(sqls, fmt.Sprintf("CREATE %s INDEX %s ON %s(%s)", unique, index.IndexName, meta.QuoteIdentifier(tableInfo.TableName), strings.Join(colNames, ",")))
}
sqlArr := make([]string, 0)
@@ -237,7 +253,6 @@ func (od *OracleMetaData) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Tab
func (od *OracleMetaData) genColumnBasicSql(column dbi.Column) string {
meta := od.dc.GetMetaData()
colName := meta.QuoteIdentifier(column.ColumnName)
dataType := string(column.DataType)
if column.IsIdentity {
// 如果是自增不需要设置默认值和空值自增列数据类型必须是number
@@ -249,47 +264,18 @@ func (od *OracleMetaData) genColumnBasicSql(column dbi.Column) string {
nullAble = " NOT NULL"
}
defVal := "" // 默认值需要判断引号,如函数是不需要引号的
defVal := ""
if column.ColumnDefault != "" {
mark := false
// 哪些字段类型默认值需要加引号
if collx.ArrayAnyMatches([]string{"CHAR", "LONG", "DATE", "TIME", "CLOB", "BLOB", "BFILE"}, dataType) {
// 默认值是时间日期函数的必须要加引号
val := strings.ToUpper(column.ColumnDefault)
if collx.ArrayAnyMatches([]string{"DATE", "TIMESTAMP"}, dataType) && val == "CURRENT_DATE" || val == "CURRENT_TIMESTAMP" {
mark = false
} else {
mark = true
}
if mark {
defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault)
} else {
defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault)
}
} else {
// 如果是数字,默认值提取数字
if collx.ArrayAnyMatches([]string{"NUM", "INT"}, dataType) {
match := bracketsRegexp.FindStringSubmatch(dataType)
if len(match) > 1 {
length := cast.ToInt(match[1])
defVal = fmt.Sprintf(" DEFAULT %d", length)
} else {
defVal = fmt.Sprintf(" DEFAULT 0")
}
}
defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault)
}
defVal = fmt.Sprintf(" DEFAULT %v", column.ColumnDefault)
}
columnSql := fmt.Sprintf(" %s %s %s %s", colName, column.GetColumnType(), defVal, nullAble)
columnSql := fmt.Sprintf(" %s %s%s%s", colName, column.GetColumnType(), defVal, nullAble)
return columnSql
}
// 获取建表ddl
func (od *OracleMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table, dropBeforeCreate bool) []string {
meta := od.dc.GetMetaData()
replacer := strings.NewReplacer(";", "", "'", "")
quoteTableName := meta.QuoteIdentifier(tableInfo.TableName)
sqlArr := make([]string, 0)
@@ -302,8 +288,7 @@ begin
if num > 0 then
execute immediate 'drop table "%s"' ;
end if;
end;
`
end`
sqlArr = append(sqlArr, fmt.Sprintf(dropSqlTmp, tableInfo.TableName, tableInfo.TableName))
}
@@ -320,7 +305,7 @@ end;
fields = append(fields, od.genColumnBasicSql(column))
// 防止注释内含有特殊字符串导致sql出错
if column.ColumnComment != "" {
comment := replacer.Replace(column.ColumnComment)
comment := meta.QuoteEscape(column.ColumnComment)
columnComments = append(columnComments, fmt.Sprintf("COMMENT ON COLUMN %s.%s IS '%s'", quoteTableName, meta.QuoteIdentifier(column.ColumnName), comment))
}
}
@@ -336,7 +321,7 @@ end;
// 表注释
tableCommentSql := ""
if tableInfo.TableComment != "" {
tableCommentSql = fmt.Sprintf("COMMENT ON TABLE %s is '%s'", meta.QuoteIdentifier(tableInfo.TableName), replacer.Replace(tableInfo.TableComment))
tableCommentSql = fmt.Sprintf("COMMENT ON TABLE %s is '%s'", meta.QuoteIdentifier(tableInfo.TableName), meta.QuoteEscape(tableInfo.TableComment))
sqlArr = append(sqlArr, tableCommentSql)
}
@@ -349,13 +334,12 @@ end;
}
// 获取建表ddl
func (od *OracleMetaData) GetTableDDL(tableName string) (string, error) {
func (od *OracleMetaData) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) {
// 1.获取表信息
tbs, err := od.GetTables(tableName)
tableInfo := &dbi.Table{}
if err != nil && len(tbs) > 0 {
if err != nil || tbs == nil || len(tbs) <= 0 {
logx.Errorf("获取表信息失败, %s", tableName)
return "", err
}
@@ -368,7 +352,7 @@ func (od *OracleMetaData) GetTableDDL(tableName string) (string, error) {
logx.Errorf("获取列信息失败, %s", tableName)
return "", err
}
tableDDLArr := od.GenerateTableDDL(columns, *tableInfo, false)
tableDDLArr := od.GenerateTableDDL(columns, *tableInfo, dropBeforeCreate)
// 3.获取索引信息
indexs, err := od.GetTableIndex(tableName)
if err != nil {
@@ -476,7 +460,12 @@ func (dc *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) st
switch dataType {
// oracle把日期类型数据格式化输出
case dbi.DataTypeDateTime: // "2024-01-02T22:08:22.275697+08:00"
res, _ := time.Parse(time.RFC3339, str)
// 尝试用时间格式解析
res, err := time.Parse(time.DateTime, str)
if err == nil {
return str
}
res, _ = time.Parse(time.RFC3339, str)
return res.Format(time.DateTime)
}
return str
@@ -490,3 +479,24 @@ func (dc *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any
}
return dbColumnValue
}
func (dc *DataConverter) WrapValue(dbColumnValue any, dataType dbi.DataType) string {
if dbColumnValue == nil {
return "NULL"
}
switch dataType {
case dbi.DataTypeNumber:
return fmt.Sprintf("%v", dbColumnValue)
case dbi.DataTypeString:
val := fmt.Sprintf("%v", dbColumnValue)
// 转义单引号
val = strings.Replace(val, `'`, `''`, -1)
val = strings.Replace(val, `\''`, `\'`, -1)
// 转义换行符
val = strings.Replace(val, "\n", "\\n", -1)
return fmt.Sprintf("'%s'", val)
case dbi.DataTypeDate, dbi.DataTypeDateTime, dbi.DataTypeTime:
return fmt.Sprintf("to_timestamp('%s', 'yyyy-mm-dd hh24:mi:ss')", dc.FormatData(dbColumnValue, dataType))
}
return fmt.Sprintf("'%s'", dbColumnValue)
}

View File

@@ -197,14 +197,7 @@ func (pd *PgsqlDialect) ToColumn(commonColumn *dbi.Column) {
commonColumn.CharMaxLength = 2000
} else {
commonColumn.DataType = dbi.ColumnDataType(ctype)
// 哪些字段可以指定长度
if !collx.ArrayAnyMatches([]string{"char", "time", "bit", "num", "decimal"}, ctype) {
commonColumn.CharMaxLength = 0
commonColumn.NumPrecision = 0
} else if strings.Contains(strings.ToLower(ctype), "char") {
// 如果类型是文本,长度翻倍
commonColumn.CharMaxLength = commonColumn.CharMaxLength * 2
}
}
}

View File

@@ -2,6 +2,7 @@ package postgres
import (
"fmt"
"io"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/errorx"
"mayfly-go/pkg/logx"
@@ -114,12 +115,31 @@ func (pd *PgsqlMetaData) GetColumns(tableNames ...string) ([]dbi.Column, error)
NumPrecision: cast.ToInt(re["numPrecision"]),
NumScale: cast.ToInt(re["numScale"]),
}
pd.FixColumn(&column)
columns = append(columns, column)
}
return columns, nil
}
func (pd *PgsqlMetaData) FixColumn(column *dbi.Column) {
dataType := strings.ToLower(string(column.DataType))
// 哪些字段可以指定长度
if !collx.ArrayAnyMatches([]string{"char", "time", "bit", "num", "decimal"}, dataType) {
column.CharMaxLength = 0
column.NumPrecision = 0
} else if strings.Contains(dataType, "char") {
// 如果类型是文本,长度翻倍
column.CharMaxLength = column.CharMaxLength * 2
}
// 如果默认值带冒号,如:'id'::varchar
if column.ColumnDefault != "" && strings.Contains(column.ColumnDefault, "::") && !strings.HasPrefix(column.ColumnDefault, "nextval") {
match := defaultValueRegexp.FindStringSubmatch(column.ColumnDefault)
if len(match) > 1 {
column.ColumnDefault = match[1]
}
}
}
func (pd *PgsqlMetaData) GetPrimaryKey(tablename string) (string, error) {
columns, err := pd.GetColumns(tablename)
if err != nil {
@@ -153,6 +173,7 @@ func (pd *PgsqlMetaData) GetTableIndex(tableName string) ([]dbi.Index, error) {
IndexComment: cast.ToString(re["indexComment"]),
IsUnique: cast.ToInt(re["isUnique"]) == 1,
SeqInIndex: cast.ToInt(re["seqInIndex"]),
IsPrimaryKey: cast.ToInt(re["isPrimaryKey"]) == 1,
})
}
// 把查询结果以索引名分组,索引字段以逗号连接
@@ -175,30 +196,30 @@ func (pd *PgsqlMetaData) GetTableIndex(tableName string) ([]dbi.Index, error) {
}
func (pd *PgsqlMetaData) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Table) []string {
meta := pd.dc.GetMetaData()
creates := make([]string, 0)
drops := make([]string, 0)
comments := make([]string, 0)
for _, index := range indexs {
// 通过字段、表名拼接索引名
columnName := strings.ReplaceAll(index.ColumnName, "-", "")
columnName = strings.ReplaceAll(columnName, "_", "")
colName := strings.ReplaceAll(columnName, ",", "_")
keyType := "normal"
unique := ""
if index.IsUnique {
keyType = "unique"
unique = "unique"
}
indexName := fmt.Sprintf("%s_key_%s_%s", keyType, tableInfo.TableName, colName)
// 如果索引名存在,先删除索引
drops = append(drops, fmt.Sprintf("drop index if exists %s.%s", pd.dc.Info.CurrentSchema(), indexName))
drops = append(drops, fmt.Sprintf("drop index if exists %s.%s", pd.dc.Info.CurrentSchema(), index.IndexName))
// 取出列名,添加引号
cols := strings.Split(index.ColumnName, ",")
colNames := make([]string, len(cols))
for i, name := range cols {
colNames[i] = meta.QuoteIdentifier(name)
}
// 创建索引
creates = append(creates, fmt.Sprintf("CREATE %s INDEX %s on %s.%s(%s)", unique, indexName, pd.dc.Info.CurrentSchema(), tableInfo.TableName, index.ColumnName))
creates = append(creates, fmt.Sprintf("CREATE %s INDEX %s on %s.%s(%s)", unique, index.IndexName, pd.dc.Info.CurrentSchema(), tableInfo.TableName, strings.Join(colNames, ",")))
if index.IndexComment != "" {
comments = append(comments, fmt.Sprintf("COMMENT ON INDEX %s.%s IS '%s'", pd.dc.Info.CurrentSchema(), indexName, index.IndexComment))
comment := meta.QuoteEscape(index.IndexComment)
comments = append(comments, fmt.Sprintf("COMMENT ON INDEX %s.%s IS '%s'", pd.dc.Info.CurrentSchema(), index.IndexName, comment))
}
}
@@ -222,6 +243,12 @@ func (pd *PgsqlMetaData) genColumnBasicSql(column dbi.Column) string {
colName := meta.QuoteIdentifier(column.ColumnName)
dataType := string(column.DataType)
// 如果数据类型是数字,则去掉长度
if collx.ArrayAnyMatches([]string{"int"}, strings.ToLower(dataType)) {
column.NumPrecision = 0
column.CharMaxLength = 0
}
// 如果是自增类型需要转换为serial
if column.IsIdentity {
if dataType == "int4" {
@@ -240,31 +267,16 @@ func (pd *PgsqlMetaData) genColumnBasicSql(column dbi.Column) string {
nullAble := ""
if !column.Nullable {
nullAble = " NOT NULL"
// 如果字段不能为空,则设置默认值
if column.ColumnDefault == "" {
if collx.ArrayAnyMatches([]string{"char", "text", "lob"}, strings.ToLower(dataType)) {
// 文本默认值为空字符串
column.ColumnDefault = " "
} else if collx.ArrayAnyMatches([]string{"int", "num"}, strings.ToLower(dataType)) {
// 数字默认值为0
column.ColumnDefault = "0"
}
}
}
defVal := "" // 默认值需要判断引号,如函数是不需要引号的 // 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值
if column.ColumnDefault != "" && !strings.Contains(column.ColumnDefault, "(") {
// 哪些字段类型默认值需要加引号
mark := false
// 哪些字段类型默认值需要加引号
if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, dataType) {
// 如果是文本类型,则默认值不能带括号
if collx.ArrayAnyMatches([]string{"char", "text", "lob"}, dataType) {
column.ColumnDefault = ""
}
// 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号
if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) &&
collx.ArrayAnyMatches([]string{"DATE", "TIME"}, strings.ToUpper(column.ColumnDefault)) {
collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(column.ColumnDefault)) {
mark = false
} else {
mark = true
@@ -275,30 +287,31 @@ func (pd *PgsqlMetaData) genColumnBasicSql(column dbi.Column) string {
column.ColumnDefault = "CURRENT_TIMESTAMP"
}
if column.ColumnDefault != "" {
if mark {
defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault)
} else {
defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault)
}
if mark {
defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault)
} else {
defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault)
}
}
columnSql := fmt.Sprintf(" %s %s %s %s ", colName, column.GetColumnType(), nullAble, defVal)
// 如果是varchar长度翻倍防止报错
if collx.ArrayAnyMatches([]string{"char"}, strings.ToLower(dataType)) {
column.CharMaxLength = column.CharMaxLength * 2
}
columnSql := fmt.Sprintf(" %s %s%s%s", colName, column.GetColumnType(), nullAble, defVal)
return columnSql
}
func (pd *PgsqlMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table, dropBeforeCreate bool) []string {
meta := pd.dc.GetMetaData()
replacer := strings.NewReplacer(";", "", "'", "")
quoteTableName := meta.QuoteIdentifier(tableInfo.TableName)
sqlArr := make([]string, 0)
if dropBeforeCreate {
sqlArr = append(sqlArr, fmt.Sprintf("DROP TABLE IF EXISTS %s", meta.QuoteIdentifier(tableInfo.TableName)))
sqlArr = append(sqlArr, fmt.Sprintf("DROP TABLE IF EXISTS %s", quoteTableName))
}
// 组装建表语句
createSql := fmt.Sprintf("CREATE TABLE %s (\n", meta.QuoteIdentifier(tableInfo.TableName))
createSql := fmt.Sprintf("CREATE TABLE %s (\n", quoteTableName)
fields := make([]string, 0)
pks := make([]string, 0)
columnComments := make([]string, 0)
@@ -313,8 +326,8 @@ func (pd *PgsqlMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Ta
// 防止注释内含有特殊字符串导致sql出错
if column.ColumnComment != "" {
comment := replacer.Replace(column.ColumnComment)
columnComments = append(columnComments, fmt.Sprintf(commentTmp, column.TableName, column.ColumnName, comment))
comment := meta.QuoteEscape(column.ColumnComment)
columnComments = append(columnComments, fmt.Sprintf(commentTmp, quoteTableName, meta.QuoteIdentifier(column.ColumnName), comment))
}
}
@@ -322,12 +335,12 @@ func (pd *PgsqlMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Ta
if len(pks) > 0 {
createSql += fmt.Sprintf(", \nPRIMARY KEY (%s)", strings.Join(pks, ","))
}
createSql += ")"
createSql += "\n)"
tableCommentSql := ""
if tableInfo.TableComment != "" {
commentTmp := "comment on table %s is '%s'"
tableCommentSql = fmt.Sprintf(commentTmp, tableInfo.TableName, replacer.Replace(tableInfo.TableComment))
tableCommentSql = fmt.Sprintf(commentTmp, quoteTableName, meta.QuoteEscape(tableInfo.TableComment))
}
// create
@@ -346,13 +359,12 @@ func (pd *PgsqlMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Ta
}
// 获取建表ddl
func (pd *PgsqlMetaData) GetTableDDL(tableName string) (string, error) {
func (pd *PgsqlMetaData) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) {
// 1.获取表信息
tbs, err := pd.GetTables(tableName)
tableInfo := &dbi.Table{}
if err != nil && len(tbs) > 0 {
if err != nil || tbs == nil || len(tbs) <= 0 {
logx.Errorf("获取表信息失败, %s", tableName)
return "", err
}
@@ -365,7 +377,7 @@ func (pd *PgsqlMetaData) GetTableDDL(tableName string) (string, error) {
logx.Errorf("获取列信息失败, %s", tableName)
return "", err
}
tableDDLArr := pd.GenerateTableDDL(columns, *tableInfo, false)
tableDDLArr := pd.GenerateTableDDL(columns, *tableInfo, dropBeforeCreate)
// 3.获取索引信息
indexs, err := pd.GetTableIndex(tableName)
if err != nil {
@@ -404,6 +416,19 @@ func (pd *PgsqlMetaData) DefaultDb() string {
}
}
func (pd *PgsqlMetaData) AfterDumpInsert(writer io.Writer, tableName string, columns []dbi.Column) {
// 设置自增序列当前值
for _, column := range columns {
if column.IsIdentity {
seq := fmt.Sprintf("SELECT setval('%s_%s_seq', (SELECT max(%s) FROM %s));\n", tableName, column.ColumnName, column.ColumnName, tableName)
writer.Write([]byte(seq))
}
}
writer.Write([]byte("COMMIT;\n"))
}
func (pd *PgsqlMetaData) GetDataConverter() dbi.DataConverter {
return converter
}
@@ -417,8 +442,9 @@ var (
dateRegexp = regexp.MustCompile(`(?i)date`)
// 时间类型
timeRegexp = regexp.MustCompile(`(?i)time`)
// 定义正则表达式,匹配括号内的数字
bracketsRegexp = regexp.MustCompile(`\((\d+)\)`)
// 提取pg默认值 如:'id'::varchar 提取id '-1'::integer 提取-1
defaultValueRegexp = regexp.MustCompile(`'([^']*)'`)
converter = new(DataConverter)
@@ -497,13 +523,28 @@ func (dc *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) st
str := fmt.Sprintf("%v", dbColumnValue)
switch dataType {
case dbi.DataTypeDateTime: // "2024-01-02T22:16:28.545377+08:00"
res, _ := time.Parse(time.RFC3339, str)
// 尝试用时间格式解析
res, err := time.Parse(time.DateTime, str)
if err == nil {
return str
}
res, err = time.Parse(time.RFC3339, str)
return res.Format(time.DateTime)
case dbi.DataTypeDate: // "2024-01-02T00:00:00Z"
res, _ := time.Parse(time.RFC3339, str)
// 尝试用时间格式解析
res, err := time.Parse(time.DateOnly, str)
if err == nil {
return str
}
res, _ = time.Parse(time.RFC3339, str)
return res.Format(time.DateOnly)
case dbi.DataTypeTime: // "0000-01-01T22:16:28.545075+08:00"
res, _ := time.Parse(time.RFC3339, str)
// 尝试用时间格式解析
res, err := time.Parse(time.TimeOnly, str)
if err == nil {
return str
}
res, _ = time.Parse(time.RFC3339, str)
return res.Format(time.TimeOnly)
}
return cast.ToString(dbColumnValue)
@@ -526,3 +567,23 @@ func (dc *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any
}
return dbColumnValue
}
func (dc *DataConverter) WrapValue(dbColumnValue any, dataType dbi.DataType) string {
if dbColumnValue == nil {
return "NULL"
}
switch dataType {
case dbi.DataTypeNumber:
return fmt.Sprintf("%v", dbColumnValue)
case dbi.DataTypeString:
val := fmt.Sprintf("%v", dbColumnValue)
// 转义单引号
val = strings.Replace(val, `'`, `''`, -1)
// 转义换行符
val = strings.Replace(val, "\n", "\\n", -1)
return fmt.Sprintf("'%s'", val)
case dbi.DataTypeDate, dbi.DataTypeDateTime, dbi.DataTypeTime:
return fmt.Sprintf("'%s'", dc.FormatData(dbColumnValue, dataType))
}
return fmt.Sprintf("'%s'", dbColumnValue)
}

View File

@@ -59,7 +59,7 @@ func (sd *SqliteDialect) CopyTable(copy *dbi.DbCopyTable) error {
// 生成新表名,为老表明+_copy_时间戳
newTableName := tableName + "_copy_" + time.Now().Format("20060102150405")
ddl, err := sd.dc.GetMetaData().GetTableDDL(tableName)
ddl, err := sd.dc.GetMetaData().GetTableDDL(tableName, false)
if err != nil {
return err
}
@@ -103,6 +103,8 @@ func (sd *SqliteDialect) ToColumn(commonColumn *dbi.Column) {
if ctype == "" {
commonColumn.DataType = "nvarchar"
commonColumn.CharMaxLength = 2000
} else {
sd.dc.GetMetaData().FixColumn(commonColumn)
}
}

View File

@@ -3,6 +3,7 @@ package sqlite
import (
"errors"
"fmt"
"io"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/logx"
"mayfly-go/pkg/utils/anyx"
@@ -138,12 +139,17 @@ func (sd *SqliteMetaData) GetColumns(tableNames ...string) ([]dbi.Column, error)
}
column.DataType = dbi.ColumnDataType(dataType)
sd.FixColumn(&column)
columns = append(columns, column)
}
}
return columns, nil
}
func (sd *SqliteMetaData) FixColumn(column *dbi.Column) {
}
func (sd *SqliteMetaData) GetPrimaryKey(tableName string) (string, error) {
_, res, err := sd.dc.Query(fmt.Sprintf("PRAGMA table_info(%s)", tableName))
if err != nil {
@@ -193,6 +199,7 @@ func (sd *SqliteMetaData) GetTableIndex(tableName string) ([]dbi.Index, error) {
IndexComment: cast.ToString(re["indexComment"]),
IsUnique: isUnique,
SeqInIndex: 1,
IsPrimaryKey: false,
})
}
// 把查询结果以索引名分组,索引字段以逗号连接
@@ -201,22 +208,24 @@ func (sd *SqliteMetaData) GetTableIndex(tableName string) ([]dbi.Index, error) {
// 获取建索引ddl
func (sd *SqliteMetaData) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Table) []string {
meta := sd.dc.GetMetaData()
sqls := make([]string, 0)
for _, index := range indexs {
// 通过字段、表名拼接索引名
columnName := strings.ReplaceAll(index.ColumnName, "-", "")
columnName = strings.ReplaceAll(columnName, "_", "")
colName := strings.ReplaceAll(columnName, ",", "_")
keyType := "normal"
unique := ""
if index.IsUnique {
keyType = "unique"
unique = "unique"
}
indexName := fmt.Sprintf("%s_key_%s_%s", keyType, tableInfo.TableName, colName)
// 取出列名,添加引号
cols := strings.Split(index.ColumnName, ",")
colNames := make([]string, len(cols))
for i, name := range cols {
colNames[i] = meta.QuoteIdentifier(name)
}
// 创建前尝试删除
sqls = append(sqls, fmt.Sprintf("DROP INDEX IF EXISTS \"%s\"", index.IndexName))
sqlTmp := "CREATE %s INDEX %s ON \"%s\" (%s) "
sqls = append(sqls, fmt.Sprintf(sqlTmp, unique, indexName, tableInfo.TableName, index.ColumnName))
sqls = append(sqls, fmt.Sprintf(sqlTmp, unique, index.IndexName, tableInfo.TableName, strings.Join(colNames, ",")))
}
return sqls
}
@@ -232,9 +241,11 @@ func (sd *SqliteMetaData) genColumnBasicSql(column dbi.Column) string {
nullAble = " NOT NULL"
}
quoteColumnName := sd.dc.GetMetaData().QuoteIdentifier(column.ColumnName)
// 如果是主键,则直接返回,不判断默认值
if column.IsPrimaryKey {
return fmt.Sprintf(" %s integer PRIMARY KEY %s %s", column.ColumnName, incr, nullAble)
return fmt.Sprintf(" %s integer PRIMARY KEY%s%s", quoteColumnName, incr, nullAble)
}
defVal := "" // 默认值需要判断引号,如函数是不需要引号的 // 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值
@@ -257,7 +268,7 @@ func (sd *SqliteMetaData) genColumnBasicSql(column dbi.Column) string {
}
}
return fmt.Sprintf(" %s %s %s %s", sd.dc.GetMetaData().QuoteIdentifier(column.ColumnName), column.GetColumnType(), nullAble, defVal)
return fmt.Sprintf(" %s %s%s%s", quoteColumnName, column.GetColumnType(), nullAble, defVal)
}
// 获取建表ddl
@@ -275,8 +286,8 @@ func (sd *SqliteMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.T
for _, column := range columns {
fields = append(fields, sd.genColumnBasicSql(column))
}
createSql += strings.Join(fields, ",")
createSql += ") "
createSql += strings.Join(fields, ",\n")
createSql += "\n)"
sqlArr = append(sqlArr, createSql)
@@ -284,12 +295,18 @@ func (sd *SqliteMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.T
}
// 获取建表ddl
func (sd *SqliteMetaData) GetTableDDL(tableName string) (string, error) {
func (sd *SqliteMetaData) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) {
var builder strings.Builder
if dropBeforeCreate {
builder.WriteString(fmt.Sprintf("DROP TABLE IF EXISTS %s; \n\n", tableName))
}
_, res, err := sd.dc.Query("select sql from sqlite_master WHERE tbl_name=? order by type desc", tableName)
if err != nil {
return "", err
}
var builder strings.Builder
for _, re := range res {
builder.WriteString(cast.ToString(re["sql"]) + "; \n\n")
}
@@ -305,6 +322,11 @@ func (sd *SqliteMetaData) GetDataConverter() dbi.DataConverter {
return converter
}
func (sd *SqliteMetaData) BeforeDumpInsert(writer io.Writer, tableName string) {
}
func (sd *SqliteMetaData) AfterDumpInsert(writer io.Writer, tableName string, columns []dbi.Column) {
}
var (
// 数字类型
numberRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit|real`)
@@ -388,13 +410,28 @@ func (dc *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) st
str := anyx.ToString(dbColumnValue)
switch dataType {
case dbi.DataTypeDateTime: // "2024-01-02T22:08:22.275697+08:00"
res, _ := time.Parse(time.RFC3339, str)
// 尝试用时间格式解析
res, err := time.Parse(time.DateTime, str)
if err == nil {
return str
}
res, _ = time.Parse(time.RFC3339, str)
return res.Format(time.DateTime)
case dbi.DataTypeDate: // "2024-01-02T00:00:00+08:00"
res, _ := time.Parse(time.RFC3339, str)
// 尝试用时间格式解析
res, err := time.Parse(time.DateOnly, str)
if err == nil {
return str
}
res, _ = time.Parse(time.RFC3339, str)
return res.Format(time.DateOnly)
case dbi.DataTypeTime: // "0000-01-01T22:08:22.275688+08:00"
res, _ := time.Parse(time.RFC3339, str)
// 尝试用时间格式解析
res, err := time.Parse(time.TimeOnly, str)
if err == nil {
return str
}
res, _ = time.Parse(time.RFC3339, str)
return res.Format(time.TimeOnly)
}
return str
@@ -403,3 +440,24 @@ func (dc *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) st
func (dc *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any {
return dbColumnValue
}
func (dc *DataConverter) WrapValue(dbColumnValue any, dataType dbi.DataType) string {
if dbColumnValue == nil {
return "NULL"
}
switch dataType {
case dbi.DataTypeNumber:
return fmt.Sprintf("%v", dbColumnValue)
case dbi.DataTypeString:
val := fmt.Sprintf("%v", dbColumnValue)
// 转义单引号
val = strings.Replace(val, `'`, `''`, -1)
val = strings.Replace(val, `\''`, `\'`, -1)
// 转义换行符
val = strings.Replace(val, "\n", "\\n", -1)
return fmt.Sprintf("'%s'", val)
case dbi.DataTypeDate, dbi.DataTypeDateTime, dbi.DataTypeTime:
return fmt.Sprintf("'%s'", dc.FormatData(dbColumnValue, dataType))
}
return fmt.Sprintf("'%s'", dbColumnValue)
}