diff --git a/mayfly_go_web/src/views/ops/db/db.ts b/mayfly_go_web/src/views/ops/db/db.ts index 5a62c436..d8bb71e0 100644 --- a/mayfly_go_web/src/views/ops/db/db.ts +++ b/mayfly_go_web/src/views/ops/db/db.ts @@ -472,6 +472,9 @@ export class DbInst { // 初始化所有列信息,完善需要显示的列类型,包含长度等,如varchar(20) static initColumns(columns: any[]) { + if (!columns) { + return; + } for (let col of columns) { if (col.charMaxLength > 0) { col.showDataType = `${col.dataType}(${col.charMaxLength})`; diff --git a/mayfly_go_web/src/views/ops/db/dialect/dm_dialect.ts b/mayfly_go_web/src/views/ops/db/dialect/dm_dialect.ts index b2f19932..3ac10ef5 100644 --- a/mayfly_go_web/src/views/ops/db/dialect/dm_dialect.ts +++ b/mayfly_go_web/src/views/ops/db/dialect/dm_dialect.ts @@ -7,6 +7,7 @@ import { DuplicateStrategy, EditorCompletion, EditorCompletionItem, + QuoteEscape, IndexDefinition, RowDefinition, sqlColumnType, @@ -523,7 +524,7 @@ class DMDialect implements DbDialect { } // 列注释 if (item.remark) { - columCommentSql += ` comment on column "${data.tableName}"."${item.name}" is '${item.remark}'; `; + columCommentSql += ` comment on column "${data.tableName}"."${item.name}" is '${QuoteEscape(item.remark)}'; `; } }); // 建表 @@ -534,7 +535,7 @@ class DMDialect implements DbDialect { );`; // 表注释 if (data.tableComment) { - tableCommentSql = ` comment on table "${data.tableName}" is '${data.tableComment}'; `; + tableCommentSql = ` comment on table "${data.tableName}" is '${QuoteEscape(data.tableComment)}'; `; } return createSql + tableCommentSql + columCommentSql; @@ -569,7 +570,7 @@ class DMDialect implements DbDialect { changeData.add.forEach((a) => { modifySql += `ALTER TABLE ${dbTable} add COLUMN ${this.genColumnBasicSql(a)};`; if (a.remark) { - commentSql += `COMMENT ON COLUMN ${dbTable}.${this.quoteIdentifier(a.name)} IS '${a.remark}';`; + commentSql += `COMMENT ON COLUMN ${dbTable}.${this.quoteIdentifier(a.name)} IS '${QuoteEscape(a.remark)}';`; } if (a.pri) { priArr.add(`"${a.name}"`); @@ -579,7 +580,7 @@ class DMDialect implements DbDialect { if (changeData.upd.length > 0) { changeData.upd.forEach((a) => { - let cmtSql = `COMMENT ON COLUMN ${dbTable}.${this.quoteIdentifier(a.name)} IS '${a.remark}';`; + let cmtSql = `COMMENT ON COLUMN ${dbTable}.${this.quoteIdentifier(a.name)} IS '${QuoteEscape(a.remark)}';`; if (a.remark && a.oldName === a.name) { commentSql += cmtSql; } @@ -675,7 +676,7 @@ class DMDialect implements DbDialect { } if (tableData.oldTableComment !== tableData.tableComment) { let baseTable = `${this.quoteIdentifier(schema)}.${this.quoteIdentifier(tableData.tableName)}`; - sql += `COMMENT ON TABLE ${baseTable} IS '${tableData.tableComment}';`; + sql += `COMMENT ON TABLE ${baseTable} IS '${QuoteEscape(tableData.tableComment)}';`; } return sql; } diff --git a/mayfly_go_web/src/views/ops/db/dialect/index.ts b/mayfly_go_web/src/views/ops/db/dialect/index.ts index 18688570..adbccc9c 100644 --- a/mayfly_go_web/src/views/ops/db/dialect/index.ts +++ b/mayfly_go_web/src/views/ops/db/dialect/index.ts @@ -262,6 +262,18 @@ export const getDbDialect = (dbType?: string): DbDialect => { return dbType2DialectMap.get(dbType!) || mysqlDialect; }; +/** + * 引号转义,多用于sql注释转义,防止拼接sql报错,如: comment xx is '注''释' 最终注释文本为: 注'释 + * @author liuzongyang + * @since 2024/3/22 08:23 + */ +export const QuoteEscape = (str: string): string => { + if (!str) { + return ''; + } + return str.replace(/'/g, "''"); +}; + (function () { console.log('init register db dialect'); registerDbDialect(DbType.mysql, mysqlDialect); diff --git a/mayfly_go_web/src/views/ops/db/dialect/mssql_dialect.ts b/mayfly_go_web/src/views/ops/db/dialect/mssql_dialect.ts index 99878b38..efa264bf 100644 --- a/mayfly_go_web/src/views/ops/db/dialect/mssql_dialect.ts +++ b/mayfly_go_web/src/views/ops/db/dialect/mssql_dialect.ts @@ -7,6 +7,7 @@ import { DuplicateStrategy, EditorCompletion, EditorCompletionItem, + QuoteEscape, IndexDefinition, RowDefinition, } from './index'; @@ -225,7 +226,7 @@ class MssqlDialect implements DbDialect { item.name && fields.push(this.genColumnBasicSql(item)); item.remark && fieldComments.push( - `EXECUTE sp_addextendedproperty N'MS_Description', N'${item.remark}', N'SCHEMA', N'${schema}', N'TABLE', N'${data.tableName}', N'COLUMN', N'${item.name}'` + `EXECUTE sp_addextendedproperty N'MS_Description', N'${QuoteEscape(item.remark)}', N'SCHEMA', N'${schema}', N'TABLE', N'${data.tableName}', N'COLUMN', N'${item.name}'` ); if (item.pri) { pks.push(`${this.quoteIdentifier(item.name)}`); @@ -244,7 +245,7 @@ class MssqlDialect implements DbDialect { // 表注释 if (data.tableComment) { - createTable += ` EXECUTE sp_addextendedproperty N'MS_Description', N'${data.tableComment}', N'SCHEMA', N'${schema}', N'TABLE', N'${data.tableName}';`; + createTable += ` EXECUTE sp_addextendedproperty N'MS_Description', N'${QuoteEscape(data.tableComment)}', N'SCHEMA', N'${schema}', N'TABLE', N'${data.tableName}';`; } return createTable + createIndexSql + fieldComments.join(';'); @@ -268,7 +269,7 @@ class MssqlDialect implements DbDialect { sql.push(` CREATE ${a.unique ? 'UNIQUE' : ''} NONCLUSTERED INDEX ${this.quoteIdentifier(a.indexName)} on ${baseTable} (${columnNames.join(',')})`); if (a.indexComment) { indexComment.push( - `EXECUTE sp_addextendedproperty N'MS_Description', N'${a.indexComment}', N'SCHEMA', N'${schema}', N'TABLE', N'${data.tableName}', N'INDEX', N'${a.indexName}'` + `EXECUTE sp_addextendedproperty N'MS_Description', N'${QuoteEscape(a.indexComment)}', N'SCHEMA', N'${schema}', N'TABLE', N'${data.tableName}', N'INDEX', N'${a.indexName}'` ); } }); @@ -306,7 +307,7 @@ class MssqlDialect implements DbDialect { addArr.push(` ALTER TABLE ${baseTable} ADD ${this.genColumnBasicSql(a)}`); if (a.remark) { addCommentArr.push( - `EXECUTE sp_addextendedproperty N'MS_Description', N'${a.remark}', N'SCHEMA', N'${schema}', N'TABLE', N'${tableName}', N'COLUMN', N'${a.name}'` + `EXECUTE sp_addextendedproperty N'MS_Description', N'${QuoteEscape(a.remark)}', N'SCHEMA', N'${schema}', N'TABLE', N'${tableName}', N'COLUMN', N'${a.name}'` ); } }); @@ -315,7 +316,7 @@ class MssqlDialect implements DbDialect { if (changeData.upd.length > 0) { changeData.upd.forEach((a) => { if (a.oldName && a.name !== a.oldName) { - renameArr.push(` EXEC sp_rename '${baseTable}.${this.quoteIdentifier(a.oldName)}', '${a.name}', 'COLUMN' `); + renameArr.push(` EXEC sp_rename '${baseTable}.${this.quoteIdentifier(a.oldName)}', '${QuoteEscape(a.name)}', 'COLUMN' `); } else { updArr.push(` ALTER TABLE ${baseTable} ALTER COLUMN ${this.genColumnBasicSql(a)} `); } @@ -325,13 +326,13 @@ class MssqlDialect implements DbDialect { 'TABLE', N'${tableName}', 'COLUMN', N'${a.name}')) > 0) EXEC sp_updateextendedproperty -'MS_Description', N'${a.remark}', +'MS_Description', N'${QuoteEscape(a.remark)}', 'SCHEMA', N'${schema}', 'TABLE', N'${tableName}', 'COLUMN', N'${a.name}' ELSE EXEC sp_addextendedproperty -'MS_Description', N'${a.remark}', +'MS_Description', N'${QuoteEscape(a.remark)}', 'SCHEMA', N'${schema}', 'TABLE', N'${tableName}', 'COLUMN',N'${a.name}'`); @@ -367,7 +368,7 @@ ELSE ); if (a.indexComment) { commentArr.push( - ` EXEC sp_addextendedproperty N'MS_Description', N'${a.indexComment}', N'SCHEMA', N'${schema}', N'TABLE', N'${tableName}', N'INDEX', N'${a.indexName}' ` + ` EXEC sp_addextendedproperty N'MS_Description', N'${QuoteEscape(a.indexComment)}', N'SCHEMA', N'${schema}', N'TABLE', N'${tableName}', N'INDEX', N'${a.indexName}' ` ); } }; @@ -413,7 +414,7 @@ ELSE if (tableData.oldTableComment !== tableData.tableComment) { // 转义注释中的单引号和换行符 - let tableComment = tableData.tableComment.replaceAll(/'/g, '').replaceAll(/[\r\n]/g, ' '); + let tableComment = tableData.tableComment.replaceAll(/'/g, "'").replaceAll(/[\r\n]/g, ' '); sql += `IF ((SELECT COUNT(*) FROM fn_listextendedproperty('MS_Description', 'SCHEMA', N'${schema}', 'TABLE', N'${tableData.tableName}', NULL, NULL)) > 0) diff --git a/mayfly_go_web/src/views/ops/db/dialect/mysql_dialect.ts b/mayfly_go_web/src/views/ops/db/dialect/mysql_dialect.ts index 43d8475e..70537d31 100644 --- a/mayfly_go_web/src/views/ops/db/dialect/mysql_dialect.ts +++ b/mayfly_go_web/src/views/ops/db/dialect/mysql_dialect.ts @@ -7,6 +7,7 @@ import { DuplicateStrategy, EditorCompletion, EditorCompletionItem, + QuoteEscape, IndexDefinition, RowDefinition, } from './index'; @@ -208,7 +209,7 @@ class MysqlDialect implements DbDialect { let onUpdate = 'update_time' === cl.name ? ' ON UPDATE CURRENT_TIMESTAMP ' : ''; return ` ${this.quoteIdentifier(cl.name)} ${cl.type}${length} ${cl.notNull ? 'NOT NULL' : 'NULL'} ${ cl.auto_increment ? 'AUTO_INCREMENT' : '' - } ${defVal} ${onUpdate} comment '${cl.remark || ''}' `; + } ${defVal} ${onUpdate} comment '${QuoteEscape(cl.remark)}' `; } getCreateTableSql(data: any): string { // 创建表结构 @@ -224,14 +225,14 @@ class MysqlDialect implements DbDialect { return `CREATE TABLE ${data.tableName} ( ${fields.join(',')} ${pks ? `, PRIMARY KEY (${pks.join(',')})` : ''} - ) COMMENT='${data.tableComment}';`; + ) COMMENT='${QuoteEscape(data.tableComment)}';`; } getCreateIndexSql(data: any): string { // 创建索引 let sql = `ALTER TABLE ${data.tableName}`; data.indexs.res.forEach((a: any) => { - sql += ` ADD ${a.unique ? 'UNIQUE' : ''} INDEX ${a.indexName}(${a.columnNames.join(',')}) USING ${a.indexType} COMMENT '${a.indexComment}',`; + sql += ` ADD ${a.unique ? 'UNIQUE' : ''} INDEX ${a.indexName}(${a.columnNames.join(',')}) USING ${a.indexType} COMMENT '${QuoteEscape(a.indexComment)}',`; }); return sql.substring(0, sql.length - 1) + ';'; } @@ -312,9 +313,9 @@ class MysqlDialect implements DbDialect { sql += ','; } addIndexs.forEach((a) => { - sql += ` ADD ${a.unique ? 'UNIQUE' : ''} INDEX ${a.indexName}(${a.columnNames.join(',')}) USING ${a.indexType} COMMENT '${ + sql += ` ADD ${a.unique ? 'UNIQUE' : ''} INDEX ${a.indexName}(${a.columnNames.join(',')}) USING ${a.indexType} COMMENT '${QuoteEscape( a.indexComment - }',`; + )}',`; }); sql = sql.substring(0, sql.length - 1); } @@ -326,7 +327,7 @@ class MysqlDialect implements DbDialect { getModifyTableInfoSql(tableData: any): string { let sql = ''; if (tableData.tableComment !== tableData.oldTableComment) { - sql += `ALTER TABLE ${this.quoteIdentifier(tableData.db)}.${this.quoteIdentifier(tableData.oldTableName)} COMMENT '${tableData.tableComment}';`; + sql += `ALTER TABLE ${this.quoteIdentifier(tableData.db)}.${this.quoteIdentifier(tableData.oldTableName)} COMMENT '${QuoteEscape(tableData.tableComment)}';`; } if (tableData.tableName !== tableData.oldTableName) { diff --git a/mayfly_go_web/src/views/ops/db/dialect/oracle_dialect.ts b/mayfly_go_web/src/views/ops/db/dialect/oracle_dialect.ts index 4f18bc44..79c70e36 100644 --- a/mayfly_go_web/src/views/ops/db/dialect/oracle_dialect.ts +++ b/mayfly_go_web/src/views/ops/db/dialect/oracle_dialect.ts @@ -7,6 +7,7 @@ import { DuplicateStrategy, EditorCompletion, EditorCompletionItem, + QuoteEscape, IndexDefinition, RowDefinition, sqlColumnType, @@ -324,7 +325,7 @@ class OracleDialect implements DbDialect { item.name && fields.push(this.genColumnBasicSql(item, true)); // 列注释 if (item.remark) { - columCommentSql += ` COMMENT ON COLUMN ${dbTable}.${this.quoteIdentifier(item.name)} is '${item.remark}'; `; + columCommentSql += ` COMMENT ON COLUMN ${dbTable}.${this.quoteIdentifier(item.name)} is '${QuoteEscape(item.remark)}'; `; } // 主键 if (item.pri) { @@ -340,7 +341,7 @@ class OracleDialect implements DbDialect { createSql = `CREATE TABLE ${dbTable} ( ${fields.join(',')} ${prisql ? ',' + prisql : ''} ) ;`; // 表注释 if (data.tableComment) { - tableCommentSql = ` COMMENT ON TABLE ${dbTable} is '${data.tableComment}'; `; + tableCommentSql = ` COMMENT ON TABLE ${dbTable} is '${QuoteEscape(data.tableComment)}'; `; } return createSql + tableCommentSql + columCommentSql; @@ -379,7 +380,7 @@ class OracleDialect implements DbDialect { if (changeData.upd.length > 0) { changeData.upd.forEach((a) => { - let commentSql = `COMMENT ON COLUMN ${dbTable}.${this.quoteIdentifier(a.name)} IS '${a.remark}'`; + let commentSql = `COMMENT ON COLUMN ${dbTable}.${this.quoteIdentifier(a.name)} IS '${QuoteEscape(a.remark)}'`; if (a.remark && a.oldName === a.name) { commentArr.push(commentSql); } @@ -401,7 +402,7 @@ class OracleDialect implements DbDialect { changeData.add.forEach((a) => { modifyArr.push(` ADD (${this.genColumnBasicSql(a, false)})`); if (a.remark) { - commentArr.push(`COMMENT ON COLUMN ${dbTable}.${this.quoteIdentifier(a.name)} is '${a.remark}'`); + commentArr.push(`COMMENT ON COLUMN ${dbTable}.${this.quoteIdentifier(a.name)} is '${QuoteEscape(a.remark)}'`); } if (a.pri) { priArr.add(`"${a.name}"`); @@ -486,7 +487,7 @@ class OracleDialect implements DbDialect { let sql = ''; if (tableData.tableComment != tableData.oldTableComment) { let dbTable = `${this.quoteIdentifier(schema)}.${this.quoteIdentifier(tableData.oldTableName)}`; - sql = `COMMENT ON TABLE ${dbTable} is '${tableData.tableComment}';`; + sql = `COMMENT ON TABLE ${dbTable} is '${QuoteEscape(tableData.tableComment)}';`; } if (tableData.tableName != tableData.oldTableName) { let dbTable = `${this.quoteIdentifier(schema)}.${this.quoteIdentifier(tableData.oldTableName)}`; diff --git a/mayfly_go_web/src/views/ops/db/dialect/postgres_dialect.ts b/mayfly_go_web/src/views/ops/db/dialect/postgres_dialect.ts index e0414d45..528b7709 100644 --- a/mayfly_go_web/src/views/ops/db/dialect/postgres_dialect.ts +++ b/mayfly_go_web/src/views/ops/db/dialect/postgres_dialect.ts @@ -7,6 +7,7 @@ import { DuplicateStrategy, EditorCompletion, EditorCompletionItem, + QuoteEscape, IndexDefinition, RowDefinition, sqlColumnType, @@ -283,7 +284,7 @@ class PostgresqlDialect implements DbDialect { } // 列注释 if (item.remark) { - columCommentSql += ` comment on column ${data.tableName}.${item.name} is '${item.remark}'; `; + columCommentSql += ` comment on column ${data.tableName}.${item.name} is '${QuoteEscape(item.remark)}'; `; } }); // 建表 @@ -294,7 +295,7 @@ class PostgresqlDialect implements DbDialect { );`; // 表注释 if (data.tableComment) { - tableCommentSql = ` comment on table ${data.tableName} is '${data.tableComment}'; `; + tableCommentSql = ` comment on table ${data.tableName} is '${QuoteEscape(data.tableComment)}'; `; } return createSql + tableCommentSql + columCommentSql; @@ -312,7 +313,7 @@ class PostgresqlDialect implements DbDialect { let colArr = a.columnNames.map((a: string) => `${this.quoteIdentifier(a)}`); sql.push(`CREATE ${a.unique ? 'UNIQUE' : ''} INDEX ${this.quoteIdentifier(a.indexName)} on ${dbTable} (${colArr.join(',')})`); if (a.indexComment) { - sql.push(`COMMENT ON INDEX ${schema}.${this.quoteIdentifier(a.indexName)} IS '${a.indexComment}'`); + sql.push(`COMMENT ON INDEX ${schema}.${this.quoteIdentifier(a.indexName)} IS '${QuoteEscape(a.indexComment)}'`); } }); return sql.join(';'); @@ -334,14 +335,14 @@ class PostgresqlDialect implements DbDialect { changeData.add.forEach((a) => { modifySql += `alter table ${dbTable} add ${this.genColumnBasicSql(a)};`; if (a.remark) { - commentSql += `comment on column ${dbTable}.${this.quoteIdentifier(a.name)} is '${a.remark}';`; + commentSql += `comment on column ${dbTable}.${this.quoteIdentifier(a.name)} is '${QuoteEscape(a.remark)}';`; } }); } if (changeData.upd.length > 0) { changeData.upd.forEach((a) => { - let cmtSql = `comment on column ${dbTable}.${this.quoteIdentifier(a.name)} is '${a.remark}';`; + let cmtSql = `comment on column ${dbTable}.${this.quoteIdentifier(a.name)} is '${QuoteEscape(a.remark)}';`; if (a.remark && a.oldName === a.name) { commentSql += cmtSql; } @@ -412,7 +413,7 @@ class PostgresqlDialect implements DbDialect { let colArr = a.columnNames.map((a: string) => `${this.quoteIdentifier(a)}`); sql.push(`CREATE ${a.unique ? 'UNIQUE' : ''} INDEX ${this.quoteIdentifier(a.indexName)} on ${dbTable} (${colArr.join(',')})`); if (a.indexComment) { - sql.push(`COMMENT ON INDEX ${schema}.${this.quoteIdentifier(a.indexName)} IS '${a.indexComment}'`); + sql.push(`COMMENT ON INDEX ${schema}.${this.quoteIdentifier(a.indexName)} IS '${QuoteEscape(a.indexComment)}'`); } }); } @@ -428,7 +429,7 @@ class PostgresqlDialect implements DbDialect { let sql = ''; if (tableData.tableComment != tableData.oldTableComment) { let dbTable = `${this.quoteIdentifier(schema)}.${this.quoteIdentifier(tableData.oldTableName)}`; - sql = `COMMENT ON TABLE ${dbTable} is '${tableData.tableComment}';`; + sql = `COMMENT ON TABLE ${dbTable} is '${QuoteEscape(tableData.tableComment)}';`; } if (tableData.tableName != tableData.oldTableName) { let dbTable = `${this.quoteIdentifier(schema)}.${this.quoteIdentifier(tableData.oldTableName)}`; diff --git a/server/internal/db/api/db.go b/server/internal/db/api/db.go index fb57c404..244477d0 100644 --- a/server/internal/db/api/db.go +++ b/server/internal/db/api/db.go @@ -17,6 +17,7 @@ import ( tagapp "mayfly-go/internal/tag/application" tagentity "mayfly-go/internal/tag/domain/entity" "mayfly-go/pkg/biz" + "mayfly-go/pkg/errorx" "mayfly-go/pkg/logx" "mayfly-go/pkg/model" "mayfly-go/pkg/req" @@ -24,6 +25,7 @@ import ( "mayfly-go/pkg/utils/collx" "mayfly-go/pkg/utils/stringx" "mayfly-go/pkg/ws" + "sort" "strconv" "strings" "time" @@ -336,47 +338,93 @@ func (d *Db) dumpDb(ctx context.Context, writer *gzipWriter, dbId uint64, dbName } } - for _, table := range tables { - writer.TryFlush() - quotedTable := dbMeta.QuoteIdentifier(table) - if needStruct { - writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- 表结构: %s \n-- ----------------------------\n", table)) - writer.WriteString(fmt.Sprintf("DROP TABLE IF EXISTS %s;\n", quotedTable)) - ddl, err := dbMeta.GetTableDDL(table) - biz.ErrIsNil(err) - writer.WriteString(ddl + "\n") - } - if !needData { - continue - } - writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- 表记录: %s \n-- ----------------------------\n", table)) + // 查询列信息,后面生成建表ddl和insert都需要列信息 + columns, err := dbMeta.GetColumns(tables...) - // 达梦不支持begin语句 - if dbConn.Info.Type != dbi.DbTypeDM { - writer.WriteString("BEGIN;\n") + // 以表名分组,存放每个表的列信息 + columnMap := make(map[string][]dbi.Column) + for _, column := range columns { + columnMap[column.TableName] = append(columnMap[column.TableName], column) + } + + // 按表名排序 + sort.Strings(tables) + + quoteSchema := dbMeta.QuoteIdentifier(dbConn.Info.CurrentSchema()) + + // 遍历获取每个表的信息 + for _, tableName := range tables { + quoteTableName := dbMeta.QuoteIdentifier(tableName) + + writer.TryFlush() + // 查询表信息,主要是为了查询表注释 + tbs, err := dbMeta.GetTables(tableName) + biz.ErrIsNil(err) + if err != nil || tbs == nil || len(tbs) <= 0 { + panic(errorx.NewBiz(fmt.Sprintf("获取表信息失败:%s", tableName))) } - insertSql := "INSERT INTO %s VALUES (%s);\n" - dbConn.WalkTableRows(ctx, table, func(record map[string]any, columns []*dbi.QueryColumn) error { - var values []string - writer.TryFlush() - for _, column := range columns { - value := record[column.Name] - if value == nil { - values = append(values, "NULL") - continue - } - strValue, ok := value.(string) - if ok { - strValue = dbMeta.QuoteLiteral(strValue) - values = append(values, strValue) - } else { - values = append(values, anyx.ToString(value)) - } + tabInfo := dbi.Table{ + TableName: tableName, + TableComment: tbs[0].TableComment, + } + + // 生成表结构信息 + if needStruct { + writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- 表结构: %s \n-- ----------------------------\n", tableName)) + tbDdlArr := dbMeta.GenerateTableDDL(columnMap[tableName], tabInfo, true) + for _, ddl := range tbDdlArr { + writer.WriteString(ddl + ";\n") } - writer.WriteString(fmt.Sprintf(insertSql, quotedTable, strings.Join(values, ", "))) - return nil - }) - writer.WriteString("COMMIT;\n") + } + + // 生成insert sql,数据在索引前,加速insert + if needData { + writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- 表记录: %s \n-- ----------------------------\n", tableName)) + + dbMeta.BeforeDumpInsert(writer, quoteTableName) + + // 获取列信息 + quoteColNames := make([]string, 0) + for _, col := range columnMap[tableName] { + quoteColNames = append(quoteColNames, dbMeta.QuoteIdentifier(col.ColumnName)) + } + + converter := dbMeta.GetDataConverter() + _ = dbConn.WalkTableRows(context.Background(), quoteTableName, func(row map[string]any, _ []*dbi.QueryColumn) error { + rowValues := make([]string, len(columnMap[tableName])) + for i, col := range columnMap[tableName] { + rowValues[i] = converter.WrapValue(row[col.ColumnName], converter.GetDataType(string(col.DataType))) + } + + beforeInsert := dbMeta.BeforeDumpInsertSql(quoteSchema, quoteTableName) + insertSQL := fmt.Sprintf("%s INSERT INTO %s (%s) values(%s)", beforeInsert, quoteTableName, strings.Join(quoteColNames, ", "), strings.Join(rowValues, ", ")) + writer.WriteString(insertSQL + ";\n") + return nil + }) + + dbMeta.AfterDumpInsert(writer, tableName, columnMap[tableName]) + } + + indexs, err := dbMeta.GetTableIndex(tableName) + biz.ErrIsNil(err) + + // 过滤主键索引 + idxs := make([]dbi.Index, 0) + for _, idx := range indexs { + if !idx.IsPrimaryKey { + idxs = append(idxs, idx) + } + } + + if len(idxs) > 0 { + // 最后添加索引 + writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- 表索引: %s \n-- ----------------------------\n", tableName)) + sqlArr := dbMeta.GenerateIndexDDL(idxs, tabInfo) + for _, sqlStr := range sqlArr { + writer.WriteString(sqlStr + ";\n") + } + } + } } @@ -450,7 +498,7 @@ func (d *Db) HintTables(rc *req.Ctx) { func (d *Db) GetTableDDL(rc *req.Ctx) { tn := rc.Query("tableName") biz.NotEmpty(tn, "tableName不能为空") - res, err := d.getDbConn(rc).GetMetaData().GetTableDDL(tn) + res, err := d.getDbConn(rc).GetMetaData().GetTableDDL(tn, false) biz.ErrIsNilAppendErr(err, "获取表ddl失败: %s") rc.ResData = res } diff --git a/server/internal/db/api/db_transfer.go b/server/internal/db/api/db_transfer.go index 3c6cacce..073443e8 100644 --- a/server/internal/db/api/db_transfer.go +++ b/server/internal/db/api/db_transfer.go @@ -12,6 +12,7 @@ import ( "mayfly-go/pkg/req" "strconv" "strings" + "time" ) type DbTransferTask struct { @@ -53,13 +54,14 @@ func (d *DbTransferTask) DeleteTask(rc *req.Ctx) { } func (d *DbTransferTask) Run(rc *req.Ctx) { + start := time.Now() taskId := d.changeState(rc, entity.DbTransferTaskRunStateRunning) go d.DbTransferTask.Run(taskId, func(msg string, err error) { // 修改状态为停止 if err != nil { logx.Error(msg, err) } else { - logx.Info(fmt.Sprintf("执行迁移完成,%s", msg)) + logx.Info(fmt.Sprintf("执行迁移完成,%s, 耗时:%v", msg, time.Since(start))) } // 修改任务状态 task := new(entity.DbTransferTask) diff --git a/server/internal/db/api/gzip_writer.go b/server/internal/db/api/gzip_writer.go index dde21ad1..0b2e0c31 100644 --- a/server/internal/db/api/gzip_writer.go +++ b/server/internal/db/api/gzip_writer.go @@ -26,6 +26,18 @@ func (g *gzipWriter) WriteString(data string) { } } +func (g *gzipWriter) Write(p []byte) (n int, err error) { + if g.aborted { + return + } + + if _, err := g.writer.Write(p); err != nil { + g.aborted = true + biz.IsTrue(false, "数据库导出失败:%s", err) + } + return +} + func (g *gzipWriter) Close() { g.writer.Close() } diff --git a/server/internal/db/application/db_data_sync.go b/server/internal/db/application/db_data_sync.go index 061aea20..1c007660 100644 --- a/server/internal/db/application/db_data_sync.go +++ b/server/internal/db/application/db_data_sync.go @@ -151,20 +151,21 @@ func (app *dataSyncAppImpl) RunCronJob(id uint64) error { } task.UpdFieldVal = strings.Trim(task.UpdFieldVal, " ") - // 把UpdFieldVal尝试转为int,如果可以转为int,则不添加引号,否则添加引号 - if _, err = strconv.Atoi(task.UpdFieldVal); err != nil { - updSql = fmt.Sprintf("and %s > '%s'", task.UpdField, task.UpdFieldVal) - } else { - updSql = fmt.Sprintf("and %s > %s", task.UpdField, task.UpdFieldVal) - } - // 如果是oracle且数据类型是时间类型,则需要加上to_date函数 - if srcConn.Info.Type == dbi.DbTypeOracle { - // 用正则判断数据类型是时间 + // 判断UpdFieldVal数据类型 + var updFieldValType dbi.DataType + if _, err = strconv.Atoi(task.UpdFieldVal); err != nil { if dateTimeReg.MatchString(task.UpdFieldVal) { - updSql = fmt.Sprintf("and %s > to_date('%s','yyyy-mm-dd hh24:mi:ss')", task.UpdField, task.UpdFieldVal) + updFieldValType = dbi.DataTypeDateTime + } else { + updFieldValType = dbi.DataTypeString } + } else { + updFieldValType = dbi.DataTypeNumber } + wrapUpdFieldVal := srcConn.GetMetaData().GetDataConverter().WrapValue(task.UpdFieldVal, updFieldValType) + updSql = fmt.Sprintf("and %s > %s", task.UpdField, wrapUpdFieldVal) + orderSql = "order by " + task.UpdField + " asc " } // 正则判断DataSql是否以where .*结尾,如果是则不添加where 1 = 1 diff --git a/server/internal/db/application/db_transfer.go b/server/internal/db/application/db_transfer.go index 91f46f57..40851903 100644 --- a/server/internal/db/application/db_transfer.go +++ b/server/internal/db/application/db_transfer.go @@ -10,6 +10,8 @@ import ( "mayfly-go/pkg/gormx" "mayfly-go/pkg/logx" "mayfly-go/pkg/model" + "mayfly-go/pkg/utils/collx" + "sort" "strings" ) @@ -153,15 +155,21 @@ func (app *dbTransferAppImpl) transferTables(task *entity.DbTransferTask, srcCon end("获取源表列信息失败", err) return } + // 以表名分组,存放每个表的列信息 columnMap := make(map[string][]dbi.Column) for _, column := range columns { columnMap[column.TableName] = append(columnMap[column.TableName], column) } + // 以表名排序 + sortTableNames := collx.MapKeys(columnMap) + sort.Strings(sortTableNames) + ctx := context.Background() - for tbName, cols := range columnMap { + for _, tbName := range sortTableNames { + cols := columnMap[tbName] targetCols := make([]dbi.Column, 0) for _, col := range cols { colPtr := &col @@ -183,7 +191,7 @@ func (app *dbTransferAppImpl) transferTables(task *entity.DbTransferTask, srcCon // 迁移数据 logx.Infof("开始迁移数据: 表名:%s", tbName) - total, err := app.transferData(ctx, tbName, srcConn, srcDialect, targetConn, targetDialect) + total, err := app.transferData(ctx, tbName, targetCols, srcConn, srcDialect, targetConn, targetDialect) if err != nil { end(fmt.Sprintf("迁移数据失败: 表名:%s, error: %s", tbName, err.Error()), err) return @@ -216,27 +224,16 @@ func (app *dbTransferAppImpl) transferTables(task *entity.DbTransferTask, srcCon } } -func (app *dbTransferAppImpl) transferData(ctx context.Context, tableName string, srcConn *dbi.DbConn, srcDialect dbi.Dialect, targetConn *dbi.DbConn, targetDialect dbi.Dialect) (int, error) { +func (app *dbTransferAppImpl) transferData(ctx context.Context, tableName string, targetColumns []dbi.Column, srcConn *dbi.DbConn, srcDialect dbi.Dialect, targetConn *dbi.DbConn, targetDialect dbi.Dialect) (int, error) { result := make([]map[string]any, 0) total := 0 // 总条数 batchSize := 1000 // 每次查询并迁移1000条数据 - var queryColumns []*dbi.QueryColumn var err error srcMeta := srcConn.GetMetaData() srcConverter := srcMeta.GetDataConverter() // 游标查询源表数据,并批量插入目标表 err = srcConn.WalkTableRows(ctx, tableName, func(row map[string]any, columns []*dbi.QueryColumn) error { - if len(queryColumns) == 0 { - - for _, col := range columns { - queryColumns = append(queryColumns, &dbi.QueryColumn{ - Name: targetConn.GetMetaData().QuoteIdentifier(col.Name), - Type: col.Type, - }) - } - - } total++ rawValue := map[string]any{} for _, column := range columns { @@ -246,7 +243,7 @@ func (app *dbTransferAppImpl) transferData(ctx context.Context, tableName string } result = append(result, rawValue) if total%batchSize == 0 { - err = app.transfer2Target(targetConn, queryColumns, result, targetDialect, tableName) + err = app.transfer2Target(targetConn, targetColumns, result, targetDialect, tableName) if err != nil { logx.Error("批量插入目标表数据失败", err) return err @@ -257,7 +254,7 @@ func (app *dbTransferAppImpl) transferData(ctx context.Context, tableName string }) // 处理剩余的数据 if len(result) > 0 { - err = app.transfer2Target(targetConn, queryColumns, result, targetDialect, tableName) + err = app.transfer2Target(targetConn, targetColumns, result, targetDialect, tableName) if err != nil { logx.Error(fmt.Sprintf("批量插入目标表数据失败,表名:%s", tableName), err) return 0, err @@ -266,23 +263,36 @@ func (app *dbTransferAppImpl) transferData(ctx context.Context, tableName string return total, err } -func (app *dbTransferAppImpl) transfer2Target(targetConn *dbi.DbConn, cols []*dbi.QueryColumn, result []map[string]any, targetDialect dbi.Dialect, tbName string) error { +func (app *dbTransferAppImpl) transfer2Target(targetConn *dbi.DbConn, targetColumns []dbi.Column, result []map[string]any, targetDialect dbi.Dialect, tbName string) error { tx, err := targetConn.Begin() if err != nil { return err } + targetMeta := targetConn.GetMetaData() + // 收集字段名 var columnNames []string - for _, col := range cols { - columnNames = append(columnNames, col.Name) + for _, col := range targetColumns { + columnNames = append(columnNames, targetMeta.QuoteIdentifier(col.ColumnName)) } // 从目标库数据中取出源库字段对应的值 values := make([][]any, 0) for _, record := range result { rawValue := make([]any, 0) - for _, cn := range columnNames { - rawValue = append(rawValue, record[targetConn.GetMetaData().RemoveQuote(cn)]) + for _, tc := range targetColumns { + columnName := tc.ColumnName + val := record[targetMeta.RemoveQuote(columnName)] + if !tc.Nullable { + // 如果val是文本,则设置为空格字符 + switch val.(type) { + case string: + if val == "" { + val = " " + } + } + } + rawValue = append(rawValue, val) } values = append(values, rawValue) } @@ -312,6 +322,18 @@ func (app *dbTransferAppImpl) transferIndex(_ context.Context, tableInfo dbi.Tab return nil } + // 过滤主键索引 + idxs := make([]dbi.Index, 0) + for _, idx := range indexs { + if !idx.IsPrimaryKey { + idxs = append(idxs, idx) + } + } + + if len(idxs) == 0 { + return nil + } + // 通过表名、索引信息生成建索引语句,并执行到目标表 - return targetDialect.CreateIndex(tableInfo, indexs) + return targetDialect.CreateIndex(tableInfo, idxs) } diff --git a/server/internal/db/dbm/dbi/metadata.go b/server/internal/db/dbm/dbi/metadata.go index d0417a0c..76890fcd 100644 --- a/server/internal/db/dbm/dbi/metadata.go +++ b/server/internal/db/dbm/dbi/metadata.go @@ -25,6 +25,9 @@ type MetaData interface { // 获取指定表名的所有列元信息 GetColumns(tableNames ...string) ([]Column, error) + // 根据数据库类型修复字段长度、精度等 + FixColumn(column *Column) + // 获取表主键字段名,没有主键标识则默认第一个字段 GetPrimaryKey(tableName string) (string, error) @@ -32,7 +35,7 @@ type MetaData interface { GetTableIndex(tableName string) ([]Index, error) // 获取建表ddl - GetTableDDL(tableName string) (string, error) + GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) GenerateTableDDL(columns []Column, tableInfo Table, dropBeforeCreate bool) []string @@ -44,6 +47,9 @@ type MetaData interface { GetDataConverter() DataConverter } +// GenerateSQLStepFunc 生成insert sql的step函数,用于生成insert sql时,每生成100条sql时调用 +type GenerateSQLStepFunc func(sqlArr []string) + // 数据库服务实例信息 type DbServer struct { Version string `json:"version"` // 版本信息 @@ -100,6 +106,7 @@ type Index struct { IndexComment string `json:"indexComment"` // 备注 SeqInIndex int `json:"seqInIndex"` IsUnique bool `json:"isUnique"` + IsPrimaryKey bool `json:"isPrimaryKey"` // 是否是主键索引,某些情况需要判断并过滤掉主键索引 } type ColumnDataType string @@ -152,6 +159,13 @@ type DataConverter interface { // 根据数据类型解析数据为符合要求的指定类型等 ParseData(dbColumnValue any, dataType DataType) any + + // WrapValue 根据数据类型包装value,如: + // 1.数字型:不需要引号, + // 2.文本型:需要用引号包裹,单引号需要转义,换行符转义, + // 3.date型:需要格式化成对应的字符串,如:time:hh:mm:ss.SSS date: yyyy-mm-dd datetime: + // 4.特殊:oracle date型需要用函数包裹:to_timestamp('%s', 'yyyy-mm-dd hh24:mi:ss') + WrapValue(dbColumnValue any, dataType DataType) string } // ------------------------- 元数据sql操作 ------------------------- diff --git a/server/internal/db/dbm/dbi/metadata_base.go b/server/internal/db/dbm/dbi/metadata_base.go index 318d4f43..0f65f6fa 100644 --- a/server/internal/db/dbm/dbi/metadata_base.go +++ b/server/internal/db/dbm/dbi/metadata_base.go @@ -3,6 +3,8 @@ package dbi import ( pq "gitee.com/liuzongyang/libpq" "github.com/kanzihuang/vitess/go/vt/sqlparser" + "io" + "strings" ) type BaseMetaData interface { @@ -13,6 +15,9 @@ type BaseMetaData interface { // 用于引用 SQL 标识符(关键字)的字符串 GetIdentifierQuoteString() string + // 引号转义,多用于sql注释转义,防止拼接sql报错,如: comment xx is '注''释' 最终注释文本为: 注'释 + QuoteEscape(str string) string + // QuoteLiteral quotes a 'literal' (e.g. a parameter, often used to pass literal // to DDL and other statements that do not accept parameters) to be used as part // of an SQL statement. For example: @@ -25,6 +30,12 @@ type BaseMetaData interface { QuoteLiteral(literal string) string SqlParserDialect() sqlparser.Dialect + + BeforeDumpInsert(writer io.Writer, tableName string) + + BeforeDumpInsertSql(quoteSchema string, quoteTableName string) string + + AfterDumpInsert(writer io.Writer, tableName string, columns []Column) } // 默认实现,若需要覆盖,则由各个数据库MetaData实现去覆盖重写 @@ -39,6 +50,10 @@ func (dd *DefaultMetaData) GetIdentifierQuoteString() string { return `"` } +func (dd *DefaultMetaData) QuoteEscape(str string) string { + return strings.Replace(str, `'`, `''`, -1) +} + func (dd *DefaultMetaData) QuoteLiteral(literal string) string { return pq.QuoteLiteral(literal) } @@ -46,3 +61,13 @@ func (dd *DefaultMetaData) QuoteLiteral(literal string) string { func (dd *DefaultMetaData) SqlParserDialect() sqlparser.Dialect { return sqlparser.PostgresDialect{} } + +func (dd *DefaultMetaData) BeforeDumpInsert(writer io.Writer, tableName string) { + writer.Write([]byte("BEGIN;\n")) +} +func (dd *DefaultMetaData) BeforeDumpInsertSql(quoteSchema string, tableName string) string { + return "" +} +func (dd *DefaultMetaData) AfterDumpInsert(writer io.Writer, tableName string, columns []Column) { + writer.Write([]byte("COMMIT;\n")) +} diff --git a/server/internal/db/dbm/dbi/metasql/dm_meta.sql b/server/internal/db/dbm/dbi/metasql/dm_meta.sql index 23e66164..c762a9f9 100644 --- a/server/internal/db/dbm/dbi/metasql/dm_meta.sql +++ b/server/internal/db/dbm/dbi/metasql/dm_meta.sql @@ -52,7 +52,8 @@ WHERE a.owner = (SELECT SF_GET_SCHEMA_NAME_BY_ID(CURRENT_SCHID)) order by a.TABLE_NAME, a.index_name, c.column_position asc --------------------------------------- --DM_COLUMN_MA 表列信息 -select a.table_name as TABLE_NAME, +select a.owner, + a.table_name as TABLE_NAME, a.column_name as COLUMN_NAME, case when a.NULLABLE = 'Y' then 'YES' when a.NULLABLE = 'N' then 'NO' else 'NO' end as NULLABLE, a.data_type as DATA_TYPE, @@ -61,25 +62,21 @@ select a.table_name a.data_scale as NUM_SCALE, b.comments as COLUMN_COMMENT, a.data_default as COLUMN_DEFAULT, - case when t.COL_NAME = a.column_name then 1 else 0 end as IS_IDENTITY, + case when t.INFO2 & 0x01 = 0x01 then 1 else 0 end as IS_IDENTITY, case when t2.constraint_type = 'P' then 1 else 0 end as IS_PRIMARY_KEY from all_tab_columns a left join user_col_comments b on b.owner = (SELECT SF_GET_SCHEMA_NAME_BY_ID(CURRENT_SCHID)) and b.table_name = a.table_name and a.column_name = b.column_name - left join (select b.owner, b.TABLE_NAME, a.NAME as COL_NAME - from SYS.SYSCOLUMNS a, - SYS.all_tables b, - SYS.SYSOBJECTS c - where a.INFO2 & 0x01 = 0x01 - and a.ID = c.ID - and c.NAME = b.TABLE_NAME) t - on t.table_name = a.table_name and t.owner = a.owner + left join (select c1.*, c2.object_name, c2.owner + FROM SYS.SYSCOLUMNS c1 + join SYS.all_objects c2 on c1.id = c2.object_id and c2.object_type = 'TABLE') t + on t.object_name = a.table_name and t.owner = a.owner and t.NAME = a.column_name left join (select uc.OWNER, uic.column_name, uic.table_name, uc.constraint_type from user_ind_columns uic left join user_constraints uc on uic.index_name = uc.index_name) t2 - on t2.table_name = t.table_name and a.column_name = t2.column_name + on t2.table_name = t.object_name and a.column_name = t2.column_name and t2.OWNER = a.owner where a.owner = (SELECT SF_GET_SCHEMA_NAME_BY_ID(CURRENT_SCHID)) and a.table_name in (%s) order by a.table_name, diff --git a/server/internal/db/dbm/dbi/metasql/mssql_meta.sql b/server/internal/db/dbm/dbi/metasql/mssql_meta.sql index 8be20c40..ea7633eb 100644 --- a/server/internal/db/dbm/dbi/metasql/mssql_meta.sql +++ b/server/internal/db/dbm/dbi/metasql/mssql_meta.sql @@ -35,7 +35,7 @@ where ss.name = ? {{if .tableNames}} and t.name in ({{.tableNames}}) {{end}} -ORDER BY t.name DESC; +ORDER BY t.name ASC; --------------------------------------- --MSSQL_INDEX_INFO 索引信息 SELECT ind.name AS indexName, @@ -46,7 +46,11 @@ SELECT ind.name AS indexName, END AS indexType, IIF(ind.is_unique = 'true', 1, 0) AS isUnique, ic.key_ordinal AS seqInIndex, - idx.value AS indexComment + idx.value AS indexComment, + CASE + WHEN LEFT(ind.name, 3) = 'PK_' THEN 1 + ELSE 0 + END AS isPrimaryKey FROM sys.indexes ind LEFT JOIN sys.tables t on t.object_id = ind.object_id LEFT JOIN sys.schemas ss on t.schema_id = ss.schema_id diff --git a/server/internal/db/dbm/dbi/metasql/mysql_meta.sql b/server/internal/db/dbm/dbi/metasql/mysql_meta.sql index 8265e734..cb29c292 100644 --- a/server/internal/db/dbm/dbi/metasql/mysql_meta.sql +++ b/server/internal/db/dbm/dbi/metasql/mysql_meta.sql @@ -35,7 +35,8 @@ SELECT index_type indexType, IF(non_unique, 0, 1) isUnique, SEQ_IN_INDEX seqInIndex, - INDEX_COMMENT indexComment + INDEX_COMMENT indexComment, + index_name = 'PRIMARY' as isPrimaryKey FROM information_schema.STATISTICS WHERE diff --git a/server/internal/db/dbm/dbi/metasql/oracle_meta.sql b/server/internal/db/dbm/dbi/metasql/oracle_meta.sql index eafb7ff4..0477121a 100644 --- a/server/internal/db/dbm/dbi/metasql/oracle_meta.sql +++ b/server/internal/db/dbm/dbi/metasql/oracle_meta.sql @@ -38,7 +38,12 @@ SELECT ai.INDEX_NAME AS INDEX_NAME, WHERE aic.INDEX_OWNER = ai.OWNER AND aic.INDEX_NAME = ai.INDEX_NAME AND aic.TABLE_NAME = ai.TABLE_NAME - AND ROWNUM = 1) AS INDEX_COMMENT + AND ROWNUM = 1) AS INDEX_COMMENT, + CASE + WHEN ai.INDEX_NAME like 'PK_%%' THEN 1 + WHEN ai.INDEX_NAME like 'SYS_%%' THEN 1 + ELSE 0 + END AS IS_PRIMARY FROM ALL_INDEXES ai WHERE ai.OWNER = (SELECT sys_context('USERENV', 'CURRENT_SCHEMA') FROM DUAL) AND ai.table_name = '%s' @@ -50,14 +55,14 @@ SELECT a.TABLE_NAME as TABLE_NAME, when a.NULLABLE = 'Y' then 'YES' when a.NULLABLE = 'N' then 'NO' else 'NO' end as NULLABLE, - case - when a.DATA_PRECISION > 0 then a.DATA_TYPE - else (a.DATA_TYPE || '(' || a.DATA_LENGTH || ')') end as COLUMN_TYPE, + a.DATA_TYPE as DATA_TYPE, + a.DATA_LENGTH as CHAR_MAX_LENGTH, + a.DATA_PRECISION as NUM_PRECISION, + a.DATA_SCALE as NUM_SCALE, b.COMMENTS as COLUMN_COMMENT, a.DATA_DEFAULT as COLUMN_DEFAULT, - a.DATA_SCALE as NUM_SCALE, - CASE WHEN d.pri IS NOT NULL THEN 1 ELSE 0 END as IS_PRIMARY_KEY, - CASE WHEN a.IDENTITY_COLUMN = 'YES' THEN 1 ELSE 0 END as IS_IDENTITY + CASE WHEN d.pri IS NOT NULL THEN 1 ELSE 0 END as IS_PRIMARY_KEY, + CASE WHEN a.IDENTITY_COLUMN = 'YES' THEN 1 ELSE 0 END as IS_IDENTITY FROM ALL_TAB_COLUMNS a LEFT JOIN ALL_COL_COMMENTS b on a.OWNER = b.OWNER AND a.TABLE_NAME = b.TABLE_NAME AND a.COLUMN_NAME = b.COLUMN_NAME diff --git a/server/internal/db/dbm/dbi/metasql/pgsql_meta.sql b/server/internal/db/dbm/dbi/metasql/pgsql_meta.sql index d9281131..d6ebf186 100644 --- a/server/internal/db/dbm/dbi/metasql/pgsql_meta.sql +++ b/server/internal/db/dbm/dbi/metasql/pgsql_meta.sql @@ -35,29 +35,28 @@ where order by c.relname --------------------------------------- --PGSQL_INDEX_INFO 表索引信息 -SELECT - indexname AS "indexName", - 'BTREE' AS "IndexType", - case when indexdef like 'CREATE UNIQUE INDEX%%' then 1 else 0 end as "isUnique", - obj_description(b.oid, 'pg_class') AS "indexComment", - indexdef AS "indexDef", - c.attname AS "columnName", - c.attnum AS "seqInIndex" +SELECT indexname AS "indexName", + 'BTREE' AS "IndexType", + case when indexdef like 'CREATE UNIQUE INDEX%%' then 1 else 0 end as "isUnique", + obj_description(b.oid, 'pg_class') AS "indexComment", + indexdef AS "indexDef", + c.attname AS "columnName", + c.attnum AS "seqInIndex", + case when indexname like '%_pkey' then 1 else 0 end as "isPrimaryKey" FROM pg_indexes a - join pg_class b on a.indexname = b.relname - join pg_attribute c on b.oid = c.attrelid + join pg_class b on a.indexname = b.relname + join pg_attribute c on b.oid = c.attrelid WHERE a.schemaname = (select current_schema()) AND a.tablename = '%s'; --------------------------------------- --PGSQL_COLUMN_MA 表列信息 -SELECT a.*, - a.table_name AS "tableName", +SELECT a.table_name AS "tableName", a.column_name AS "columnName", a.is_nullable AS "nullable", a.udt_name AS "dataType", a.character_maximum_length AS "charMaxLength", a.numeric_precision AS "numPrecision", - a.column_default AS "columnDefault", + case when a.column_default like 'nextval%%' then null else a.column_default end AS "columnDefault", a.numeric_scale AS "numScale", case when a.column_default like 'nextval%%' then 1 else 0 end AS "isIdentity", case when b.column_name is not null then 1 else 0 end AS "isPrimaryKey", diff --git a/server/internal/db/dbm/dm/dialect.go b/server/internal/db/dbm/dm/dialect.go index 199b5862..8aab3581 100644 --- a/server/internal/db/dbm/dm/dialect.go +++ b/server/internal/db/dbm/dm/dialect.go @@ -127,7 +127,7 @@ func (dd *DMDialect) batchInsertMerge(tx *sql.Tx, tableName string, columns []st func (dd *DMDialect) CopyTable(copy *dbi.DbCopyTable) error { tableName := copy.TableName metadata := dd.dc.GetMetaData() - ddl, err := metadata.GetTableDDL(tableName) + ddl, err := metadata.GetTableDDL(tableName, false) if err != nil { return err } @@ -177,21 +177,14 @@ func (dd *DMDialect) ToCommonColumn(dialectColumn *dbi.Column) { func (dd *DMDialect) ToColumn(commonColumn *dbi.Column) { ctype := dmColumnTypeMap[commonColumn.DataType] + meta := dd.dc.GetMetaData() if ctype == "" { commonColumn.DataType = "VARCHAR" commonColumn.CharMaxLength = 2000 } else { commonColumn.DataType = dbi.ColumnDataType(ctype) - // 如果是date,不设长度 - if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(ctype)) { - commonColumn.CharMaxLength = 0 - commonColumn.NumPrecision = 0 - } else - // 如果是char且长度未设置,则默认长度2000 - if collx.ArrayAnyMatches([]string{"char"}, strings.ToLower(ctype)) && commonColumn.CharMaxLength == 0 { - commonColumn.CharMaxLength = 2000 - } + meta.FixColumn(commonColumn) } } @@ -212,7 +205,15 @@ func (dd *DMDialect) CreateTable(columns []dbi.Column, tableInfo dbi.Table, drop } func (dd *DMDialect) CreateIndex(tableInfo dbi.Table, indexs []dbi.Index) error { - sqls := dd.dc.GetMetaData().GenerateIndexDDL(indexs, tableInfo) - _, err := dd.dc.Exec(strings.Join(sqls, ";")) - return err + sqlArr := dd.dc.GetMetaData().GenerateIndexDDL(indexs, tableInfo) + // 达梦需要分开执行sql + if len(sqlArr) > 0 { + for _, sqlStr := range sqlArr { + _, err := dd.dc.Exec(sqlStr) + if err != nil { + return err + } + } + } + return nil } diff --git a/server/internal/db/dbm/dm/metadata.go b/server/internal/db/dbm/dm/metadata.go index 90e28e26..ad3b9507 100644 --- a/server/internal/db/dbm/dm/metadata.go +++ b/server/internal/db/dbm/dm/metadata.go @@ -2,6 +2,7 @@ package dm import ( "fmt" + "io" "mayfly-go/internal/db/dbm/dbi" "mayfly-go/pkg/errorx" "mayfly-go/pkg/logx" @@ -112,11 +113,24 @@ func (dd *DMMetaData) GetColumns(tableNames ...string) ([]dbi.Column, error) { NumPrecision: cast.ToInt(re["NUM_PRECISION"]), NumScale: cast.ToInt(re["NUM_SCALE"]), } + dd.FixColumn(&column) columns = append(columns, column) } return columns, nil } +func (dd *DMMetaData) FixColumn(column *dbi.Column) { + // 如果是date,不设长度 + if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(string(column.DataType))) { + column.CharMaxLength = 0 + column.NumPrecision = 0 + } else + // 如果是char且长度未设置,则默认长度2000 + if collx.ArrayAnyMatches([]string{"char"}, strings.ToLower(string(column.DataType))) && column.CharMaxLength == 0 { + column.CharMaxLength = 2000 + } +} + func (dd *DMMetaData) GetPrimaryKey(tablename string) (string, error) { columns, err := dd.GetColumns(tablename) if err != nil { @@ -150,6 +164,7 @@ func (dd *DMMetaData) GetTableIndex(tableName string) ([]dbi.Index, error) { IndexComment: cast.ToString(re["INDEX_COMMENT"]), IsUnique: cast.ToInt(re["IS_UNIQUE"]) == 1, SeqInIndex: cast.ToInt(re["SEQ_IN_INDEX"]), + IsPrimaryKey: false, }) } // 把查询结果以索引名分组,索引字段以逗号连接 @@ -206,7 +221,7 @@ func (dd *DMMetaData) genColumnBasicSql(column dbi.Column) string { } } - columnSql := fmt.Sprintf(" %s %s %s %s %s", colName, column.GetColumnType(), incr, nullAble, defVal) + columnSql := fmt.Sprintf(" %s %s%s%s%s", colName, column.GetColumnType(), incr, nullAble, defVal) return columnSql } @@ -215,27 +230,25 @@ func (dd *DMMetaData) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Table) meta := dd.dc.GetMetaData() sqls := make([]string, 0) for _, index := range indexs { - // 通过字段、表名拼接索引名 - columnName := strings.ReplaceAll(index.ColumnName, "-", "") - columnName = strings.ReplaceAll(columnName, "_", "") - colName := strings.ReplaceAll(columnName, ",", "_") - - keyType := "normal" unique := "" if index.IsUnique { - keyType = "unique" unique = "unique" } - indexName := fmt.Sprintf("%s_key_%s_%s", keyType, tableInfo.TableName, colName) - sqls = append(sqls, fmt.Sprintf("create %s index %s on %s(%s)", unique, indexName, meta.QuoteIdentifier(tableInfo.TableName), index.ColumnName)) + // 取出列名,添加引号 + cols := strings.Split(index.ColumnName, ",") + colNames := make([]string, len(cols)) + for i, name := range cols { + colNames[i] = meta.QuoteIdentifier(name) + } + + sqls = append(sqls, fmt.Sprintf("create %s index %s on %s(%s)", unique, index.IndexName, meta.QuoteIdentifier(tableInfo.TableName), strings.Join(colNames, ","))) } return sqls } func (dd *DMMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table, dropBeforeCreate bool) []string { meta := dd.dc.GetMetaData() - replacer := strings.NewReplacer(";", "", "'", "") tbName := meta.QuoteIdentifier(tableInfo.TableName) sqlArr := make([]string, 0) @@ -253,25 +266,21 @@ func (dd *DMMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table pks = append(pks, meta.QuoteIdentifier(column.ColumnName)) } fields = append(fields, dd.genColumnBasicSql(column)) - // 防止注释内含有特殊字符串导致sql出错 if column.ColumnComment != "" { - comment := replacer.Replace(column.ColumnComment) + comment := meta.QuoteEscape(column.ColumnComment) columnComments = append(columnComments, fmt.Sprintf("comment on column %s.%s is '%s'", tbName, meta.QuoteIdentifier(column.ColumnName), comment)) } } - createSql += strings.Join(fields, ",") + createSql += strings.Join(fields, ",\n") if len(pks) > 0 { - createSql += fmt.Sprintf(", PRIMARY KEY (%s)", strings.Join(pks, ",")) + createSql += fmt.Sprintf(",\n PRIMARY KEY (%s)", strings.Join(pks, ",")) } - createSql += ")" + createSql += "\n)" tableCommentSql := "" if tableInfo.TableComment != "" { - // 防止注释内含有特殊字符串导致sql出错 - comment := replacer.Replace(tableInfo.TableComment) - if comment != "" { - tableCommentSql = fmt.Sprintf(" comment on table %s is '%s'", tbName, comment) - } + comment := meta.QuoteEscape(tableInfo.TableComment) + tableCommentSql = fmt.Sprintf("comment on table %s is '%s'", tbName, comment) } sqlArr = append(sqlArr, createSql) @@ -287,13 +296,12 @@ func (dd *DMMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table } // 获取建表ddl -func (dd *DMMetaData) GetTableDDL(tableName string) (string, error) { +func (dd *DMMetaData) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) { // 1.获取表信息 tbs, err := dd.GetTables(tableName) tableInfo := &dbi.Table{} - if err != nil && len(tbs) > 0 { - + if err != nil || tbs == nil || len(tbs) <= 0 { logx.Errorf("获取表信息失败, %s", tableName) return "", err } @@ -306,7 +314,7 @@ func (dd *DMMetaData) GetTableDDL(tableName string) (string, error) { logx.Errorf("获取列信息失败, %s", tableName) return "", err } - tableDDLArr := dd.GenerateTableDDL(columns, *tableInfo, false) + tableDDLArr := dd.GenerateTableDDL(columns, *tableInfo, dropBeforeCreate) // 3.获取索引信息 indexs, err := dd.GetTableIndex(tableName) if err != nil { @@ -315,7 +323,7 @@ func (dd *DMMetaData) GetTableDDL(tableName string) (string, error) { } // 组装返回 tableDDLArr = append(tableDDLArr, dd.GenerateIndexDDL(indexs, *tableInfo)...) - return strings.Join(tableDDLArr, ";"), nil + return strings.Join(tableDDLArr, ";\n"), nil } // 获取DM当前连接的库可访问的schemaNames @@ -332,6 +340,18 @@ func (dd *DMMetaData) GetSchemas() ([]string, error) { return schemaNames, nil } +func (dd *DMMetaData) BeforeDumpInsert(writer io.Writer, tableName string) { + +} + +func (dd *DMMetaData) BeforeDumpInsertSql(quoteSchema string, tableName string) string { + return fmt.Sprintf("set identity_insert %s on;", tableName) +} + +func (dd *DMMetaData) AfterDumpInsert(writer io.Writer, tableName string, columns []dbi.Column) { + writer.Write([]byte("COMMIT;\n")) +} + func (dd *DMMetaData) GetDataConverter() dbi.DataConverter { return converter } @@ -406,7 +426,7 @@ var ( type DataConverter struct { } -func (dd *DataConverter) GetDataType(dbColumnType string) dbi.DataType { +func (dc *DataConverter) GetDataType(dbColumnType string) dbi.DataType { if numberRegexp.MatchString(dbColumnType) { return dbi.DataTypeNumber } @@ -422,23 +442,38 @@ func (dd *DataConverter) GetDataType(dbColumnType string) dbi.DataType { return dbi.DataTypeString } -func (dd *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) string { +func (dc *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) string { str := anyx.ToString(dbColumnValue) switch dataType { case dbi.DataTypeDateTime: // "2024-01-02T22:08:22.275697+08:00" - res, _ := time.Parse(time.RFC3339, str) + // 尝试用时间格式解析 + res, err := time.Parse(time.DateTime, str) + if err == nil { + return str + } + res, _ = time.Parse(time.RFC3339, str) return res.Format(time.DateTime) case dbi.DataTypeDate: // "2024-01-02T00:00:00+08:00" - res, _ := time.Parse(time.RFC3339, str) + // 尝试用时间格式解析 + res, err := time.Parse(time.DateOnly, str) + if err == nil { + return str + } + res, _ = time.Parse(time.RFC3339, str) return res.Format(time.DateOnly) case dbi.DataTypeTime: // "0000-01-01T22:08:22.275688+08:00" - res, _ := time.Parse(time.RFC3339, str) + // 尝试用时间格式解析 + res, err := time.Parse(time.TimeOnly, str) + if err == nil { + return str + } + res, _ = time.Parse(time.RFC3339, str) return res.Format(time.TimeOnly) } return str } -func (dd *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any { +func (dc *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any { // 如果dataType是datetime而dbColumnValue是string类型,则需要转换为time.Time类型 _, ok := dbColumnValue.(string) if ok { @@ -457,3 +492,24 @@ func (dd *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any } return dbColumnValue } + +func (dc *DataConverter) WrapValue(dbColumnValue any, dataType dbi.DataType) string { + if dbColumnValue == nil { + return "NULL" + } + switch dataType { + case dbi.DataTypeNumber: + return fmt.Sprintf("%v", dbColumnValue) + case dbi.DataTypeString: + val := fmt.Sprintf("%v", dbColumnValue) + // 转义单引号 + val = strings.Replace(val, `'`, `''`, -1) + val = strings.Replace(val, `\''`, `\'`, -1) + // 转义换行符 + val = strings.Replace(val, "\n", "\\n", -1) + return fmt.Sprintf("'%s'", val) + case dbi.DataTypeDate, dbi.DataTypeDateTime, dbi.DataTypeTime: + return fmt.Sprintf("'%s'", dc.FormatData(dbColumnValue, dataType)) + } + return fmt.Sprintf("'%s'", dbColumnValue) +} diff --git a/server/internal/db/dbm/mssql/dialect.go b/server/internal/db/dbm/mssql/dialect.go index dddc3bc2..b8da58ed 100644 --- a/server/internal/db/dbm/mssql/dialect.go +++ b/server/internal/db/dbm/mssql/dialect.go @@ -253,35 +253,14 @@ func (md *MssqlDialect) ToCommonColumn(dialectColumn *dbi.Column) { func (md *MssqlDialect) ToColumn(commonColumn *dbi.Column) { ctype := mssqlColumnTypeMap[commonColumn.DataType] + meta := md.dc.GetMetaData() if ctype == "" { commonColumn.DataType = "varchar" commonColumn.CharMaxLength = 2000 } else { commonColumn.DataType = dbi.ColumnDataType(ctype) - - if strings.Contains(strings.ToLower(ctype), "int") { - // 如果类型是数字,类型后不需要带长度 - commonColumn.CharMaxLength = 0 - commonColumn.NumPrecision = 0 - } else if collx.ArrayAnyMatches([]string{"float", "number", "decimal"}, strings.ToLower(ctype)) { - // 如果是float,最大长度为38 - if commonColumn.CharMaxLength > 38 { - commonColumn.CharMaxLength = 38 - } - if commonColumn.NumPrecision > 38 { - commonColumn.NumPrecision = 38 - } - } else if strings.Contains(strings.ToLower(ctype), "char") { - // 如果是字符串类型,长度最大4000,否则修改字段类型为text - if commonColumn.CharMaxLength > 4000 { - commonColumn.DataType = "text" - commonColumn.CharMaxLength = 0 - } - } else if strings.Contains(strings.ToLower(ctype), "text") { - // 如果是text,取消长度 - commonColumn.CharMaxLength = 0 - } + meta.FixColumn(commonColumn) } } diff --git a/server/internal/db/dbm/mssql/metadata.go b/server/internal/db/dbm/mssql/metadata.go index 8cda438a..9ab9c8d0 100644 --- a/server/internal/db/dbm/mssql/metadata.go +++ b/server/internal/db/dbm/mssql/metadata.go @@ -119,28 +119,32 @@ func (md *MssqlMetaData) GetColumns(tableNames ...string) ([]dbi.Column, error) NumScale: cast.ToInt(re["NUM_SCALE"]), } - dataType := strings.ToLower(string(column.DataType)) - - if collx.ArrayAnyMatches([]string{"date", "time"}, dataType) { - // 如果是datetime,精度取NumScale字段 - column.CharMaxLength = column.NumScale - } else if collx.ArrayAnyMatches([]string{"int", "bit", "real", "text", "xml"}, dataType) { - // 不显示长度的类型 - column.NumPrecision = 0 - column.CharMaxLength = 0 - } else if collx.ArrayAnyMatches([]string{"numeric", "decimal", "float"}, dataType) { - // 如果是num,长度取精度和小数位数 - column.CharMaxLength = 0 - } else if collx.ArrayAnyMatches([]string{"nvarchar", "nchar"}, dataType) { - // 如果是nvarchar,可视长度减半 - column.CharMaxLength = column.CharMaxLength / 2 - } + md.FixColumn(&column) columns = append(columns, column) } return columns, nil } +func (md *MssqlMetaData) FixColumn(column *dbi.Column) { + dataType := strings.ToLower(string(column.DataType)) + + if collx.ArrayAnyMatches([]string{"date", "time"}, dataType) { + // 如果是datetime,精度取NumScale字段 + column.CharMaxLength = column.NumScale + } else if collx.ArrayAnyMatches([]string{"int", "bit", "real", "text", "xml"}, dataType) { + // 不显示长度的类型 + column.NumPrecision = 0 + column.CharMaxLength = 0 + } else if collx.ArrayAnyMatches([]string{"numeric", "decimal", "float"}, dataType) { + // 如果是num,长度取精度和小数位数 + column.CharMaxLength = 0 + } else if collx.ArrayAnyMatches([]string{"nvarchar", "nchar"}, dataType) { + // 如果是nvarchar,可视长度减半 + column.CharMaxLength = column.CharMaxLength / 2 + } +} + // 获取表主键字段名,不存在主键标识则默认第一个字段 func (md *MssqlMetaData) GetPrimaryKey(tablename string) (string, error) { columns, err := md.GetColumns(tablename) @@ -174,6 +178,7 @@ func (md *MssqlMetaData) getTableIndexWithPK(tableName string) ([]dbi.Index, err IndexComment: cast.ToString(re["indexComment"]), IsUnique: cast.ToInt(re["isUnique"]) == 1, SeqInIndex: cast.ToInt(re["seqInIndex"]), + IsPrimaryKey: cast.ToInt(re["isPrimaryKey"]) == 1, }) } // 把查询结果以索引名分组,多个索引字段以逗号连接 @@ -199,10 +204,9 @@ func (md *MssqlMetaData) getTableIndexWithPK(tableName string) ([]dbi.Index, err func (md *MssqlMetaData) GetTableIndex(tableName string) ([]dbi.Index, error) { indexs, _ := md.getTableIndexWithPK(tableName) result := make([]dbi.Index, 0) - // 过滤掉主键索引,主键索引名为PK__开头的 + // 过滤掉主键索引 for _, v := range indexs { - in := v.IndexName - if strings.HasPrefix(in, "PK__") { + if v.IsPrimaryKey { continue } result = append(result, v) @@ -248,25 +252,25 @@ func (md *MssqlMetaData) CopyTableDDL(tableName string, newTableName string) (st func (md *MssqlMetaData) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Table) []string { tbName := tableInfo.TableName + meta := md.dc.GetMetaData() sqls := make([]string, 0) comments := make([]string, 0) for _, index := range indexs { - // 通过字段、表名拼接索引名 - columnName := strings.ReplaceAll(index.ColumnName, "-", "") - columnName = strings.ReplaceAll(columnName, "_", "") - colName := strings.ReplaceAll(columnName, ",", "_") - - keyType := "normal" unique := "" if index.IsUnique { - keyType = "unique" unique = "unique" } - indexName := fmt.Sprintf("%s_key_%s_%s", keyType, tbName, colName) + // 取出列名,添加引号 + cols := strings.Split(index.ColumnName, ",") + colNames := make([]string, len(cols)) + for i, name := range cols { + colNames[i] = meta.QuoteIdentifier(name) + } - sqls = append(sqls, fmt.Sprintf("create %s NONCLUSTERED index %s on %s.%s(%s)", unique, indexName, md.dc.Info.CurrentSchema(), tbName, index.ColumnName)) + sqls = append(sqls, fmt.Sprintf("create %s NONCLUSTERED index %s on %s.%s(%s)", unique, index.IndexName, md.dc.Info.CurrentSchema(), tbName, strings.Join(colNames, ","))) if index.IndexComment != "" { - comments = append(comments, fmt.Sprintf("EXECUTE sp_addextendedproperty N'MS_Description', N'%s', N'SCHEMA', N'%s', N'TABLE', N'%s', N'INDEX', N'%s'", index.IndexComment, md.dc.Info.CurrentSchema(), tbName, indexName)) + comment := meta.QuoteEscape(index.IndexComment) + comments = append(comments, fmt.Sprintf("EXECUTE sp_addextendedproperty N'MS_Description', N'%s', N'SCHEMA', N'%s', N'TABLE', N'%s', N'INDEX', N'%s'", comment, md.dc.Info.CurrentSchema(), tbName, index.IndexName)) } } if len(comments) > 0 { @@ -312,7 +316,7 @@ func (md *MssqlMetaData) genColumnBasicSql(column dbi.Column) string { } } - columnSql := fmt.Sprintf(" %s %s %s %s %s", colName, column.GetColumnType(), incr, nullAble, defVal) + columnSql := fmt.Sprintf(" %s %s%s%s%s", colName, column.GetColumnType(), incr, nullAble, defVal) return columnSql } @@ -320,7 +324,6 @@ func (md *MssqlMetaData) genColumnBasicSql(column dbi.Column) string { func (md *MssqlMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table, dropBeforeCreate bool) []string { tbName := tableInfo.TableName meta := md.dc.GetMetaData() - replacer := strings.NewReplacer(";", "", "'", "") sqlArr := make([]string, 0) @@ -344,7 +347,7 @@ func (md *MssqlMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Ta // 防止注释内含有特殊字符串导致sql出错 if column.ColumnComment != "" { - comment := replacer.Replace(column.ColumnComment) + comment := meta.QuoteEscape(column.ColumnComment) columnComments = append(columnComments, fmt.Sprintf(commentTmp, comment, md.dc.Info.CurrentSchema(), tbName, column.ColumnName)) } } @@ -360,7 +363,8 @@ func (md *MssqlMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Ta tableCommentSql := "" if tableInfo.TableComment != "" { commentTmp := "EXECUTE sp_addextendedproperty N'MS_Description', N'%s', N'SCHEMA', N'%s', N'TABLE', N'%s'" - tableCommentSql = fmt.Sprintf(commentTmp, replacer.Replace(tableInfo.TableComment), md.dc.Info.CurrentSchema(), tbName) + + tableCommentSql = fmt.Sprintf(commentTmp, meta.QuoteEscape(tableInfo.TableComment), md.dc.Info.CurrentSchema(), tbName) } sqlArr = append(sqlArr, createSql) @@ -376,13 +380,12 @@ func (md *MssqlMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Ta } // 获取建表ddl -func (md *MssqlMetaData) GetTableDDL(tableName string) (string, error) { +func (md *MssqlMetaData) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) { // 1.获取表信息 tbs, err := md.GetTables(tableName) tableInfo := &dbi.Table{} - if err != nil && len(tbs) > 0 { - + if err != nil || tbs == nil || len(tbs) <= 0 { logx.Errorf("获取表信息失败, %s", tableName) return "", err } @@ -395,7 +398,7 @@ func (md *MssqlMetaData) GetTableDDL(tableName string) (string, error) { logx.Errorf("获取列信息失败, %s", tableName) return "", err } - tableDDLArr := md.GenerateTableDDL(columns, *tableInfo, false) + tableDDLArr := md.GenerateTableDDL(columns, *tableInfo, dropBeforeCreate) // 3.获取索引信息 indexs, err := md.GetTableIndex(tableName) if err != nil { @@ -424,6 +427,10 @@ func (md *MssqlMetaData) GetIdentifierQuoteString() string { return "[" } +func (md *MssqlMetaData) BeforeDumpInsertSql(quoteSchema string, tableName string) string { + return fmt.Sprintf("set identity_insert %s.%s on ", quoteSchema, tableName) +} + func (md *MssqlMetaData) GetDataConverter() dbi.DataConverter { return converter } @@ -528,6 +535,26 @@ func (dc *DataConverter) GetDataType(dbColumnType string) dbi.DataType { } func (dc *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) string { + // 如果dataType是datetime而dbColumnValue是string类型,则需要根据类型格式化 + str, ok := dbColumnValue.(string) + if dataType == dbi.DataTypeDateTime && ok { + // 尝试用时间格式解析 + res, err := time.Parse(time.DateTime, str) + if err == nil { + return str + } + res, _ = time.Parse(time.RFC3339, str) + return res.Format(time.DateTime) + } + if dataType == dbi.DataTypeDate && ok { + // 尝试用时间格式解析 + res, _ := time.Parse(time.DateOnly, str) + return res.Format(time.DateOnly) + } + if dataType == dbi.DataTypeTime && ok { + res, _ := time.Parse(time.TimeOnly, str) + return res.Format(time.TimeOnly) + } return anyx.ToString(dbColumnValue) } @@ -548,3 +575,24 @@ func (dc *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any } return dbColumnValue } + +func (dc *DataConverter) WrapValue(dbColumnValue any, dataType dbi.DataType) string { + if dbColumnValue == nil { + return "NULL" + } + switch dataType { + case dbi.DataTypeNumber: + return fmt.Sprintf("%v", dbColumnValue) + case dbi.DataTypeString: + val := fmt.Sprintf("%v", dbColumnValue) + // 转义单引号 + val = strings.Replace(val, `'`, `''`, -1) + val = strings.Replace(val, `\''`, `\'`, -1) + // 转义换行符 + val = strings.Replace(val, "\n", "\\n", -1) + return fmt.Sprintf("'%s'", val) + case dbi.DataTypeDate, dbi.DataTypeDateTime, dbi.DataTypeTime: + return fmt.Sprintf("'%s'", dc.FormatData(dbColumnValue, dataType)) + } + return fmt.Sprintf("'%s'", dbColumnValue) +} diff --git a/server/internal/db/dbm/mysql/dialect.go b/server/internal/db/dbm/mysql/dialect.go index adc89f86..865ba8b3 100644 --- a/server/internal/db/dbm/mysql/dialect.go +++ b/server/internal/db/dbm/mysql/dialect.go @@ -97,16 +97,7 @@ func (md *MysqlDialect) ToColumn(column *dbi.Column) { column.CharMaxLength = 1000 } else { column.DataType = dbi.ColumnDataType(ctype) - // 如果是int整型,删除精度 - if strings.Contains(strings.ToLower(ctype), "int") { - column.NumScale = 0 - column.CharMaxLength = 0 - } else - // 如果是text,删除长度 - if strings.Contains(strings.ToLower(ctype), "text") { - column.CharMaxLength = 0 - column.NumPrecision = 0 - } + md.dc.GetMetaData().FixColumn(column) } } diff --git a/server/internal/db/dbm/mysql/metadata.go b/server/internal/db/dbm/mysql/metadata.go index f315e465..0339a4ad 100644 --- a/server/internal/db/dbm/mysql/metadata.go +++ b/server/internal/db/dbm/mysql/metadata.go @@ -117,11 +117,25 @@ func (md *MysqlMetaData) GetColumns(tableNames ...string) ([]dbi.Column, error) NumScale: cast.ToInt(re["numScale"]), } + md.FixColumn(&column) columns = append(columns, column) } return columns, nil } +func (md *MysqlMetaData) FixColumn(column *dbi.Column) { + // 如果是int整型,删除精度 + if strings.Contains(strings.ToLower(string(column.DataType)), "int") { + column.NumScale = 0 + column.CharMaxLength = 0 + } else + // 如果是text,删除长度 + if strings.Contains(strings.ToLower(string(column.DataType)), "text") { + column.CharMaxLength = 0 + column.NumPrecision = 0 + } +} + // 获取表主键字段名,不存在主键标识则默认第一个字段 func (md *MysqlMetaData) GetPrimaryKey(tablename string) (string, error) { columns, err := md.GetColumns(tablename) @@ -157,6 +171,7 @@ func (md *MysqlMetaData) GetTableIndex(tableName string) ([]dbi.Index, error) { IndexComment: cast.ToString(re["indexComment"]), IsUnique: cast.ToInt(re["isUnique"]) == 1, SeqInIndex: cast.ToInt(re["seqInIndex"]), + IsPrimaryKey: cast.ToInt(re["isPrimaryKey"]) == 1, }) } // 把查询结果以索引名分组,索引字段以逗号连接 @@ -183,27 +198,29 @@ func (md *MysqlMetaData) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Tabl meta := md.dc.GetMetaData() sqlArr := make([]string, 0) for _, index := range indexs { - // 通过字段、表名拼接索引名 - columnName := strings.ReplaceAll(index.ColumnName, "-", "") - columnName = strings.ReplaceAll(columnName, "_", "") - colName := strings.ReplaceAll(columnName, ",", "_") - - keyType := "normal" unique := "" if index.IsUnique { - keyType = "unique" unique = "unique" } - indexName := fmt.Sprintf("%s_key_%s_%s", keyType, tableInfo.TableName, colName) - sqlTmp := "ALTER TABLE %s ADD %s INDEX %s(%s) USING BTREE COMMENT '%s'" - replacer := strings.NewReplacer(";", "", "'", "") - sqlArr = append(sqlArr, fmt.Sprintf(sqlTmp, meta.QuoteIdentifier(tableInfo.TableName), unique, indexName, index.ColumnName, replacer.Replace(index.IndexComment))) + // 取出列名,添加引号 + cols := strings.Split(index.ColumnName, ",") + colNames := make([]string, len(cols)) + for i, name := range cols { + colNames[i] = meta.QuoteIdentifier(name) + } + sqlTmp := "ALTER TABLE %s ADD %s INDEX %s(%s) USING BTREE" + sqlStr := fmt.Sprintf(sqlTmp, meta.QuoteIdentifier(tableInfo.TableName), unique, index.IndexName, strings.Join(colNames, ",")) + comment := meta.QuoteEscape(index.IndexComment) + if comment != "" { + sqlStr += fmt.Sprintf(" COMMENT '%s'", comment) + } + sqlArr = append(sqlArr, sqlStr) } return sqlArr } func (md *MysqlMetaData) genColumnBasicSql(column dbi.Column) string { - replacer := strings.NewReplacer(";", "", "'", "") + meta := md.dc.GetMetaData() dataType := string(column.DataType) incr := "" @@ -238,11 +255,11 @@ func (md *MysqlMetaData) genColumnBasicSql(column dbi.Column) string { comment := "" if column.ColumnComment != "" { // 防止注释内含有特殊字符串导致sql出错 - commentStr := replacer.Replace(column.ColumnComment) + commentStr := meta.QuoteEscape(column.ColumnComment) comment = fmt.Sprintf(" COMMENT '%s'", commentStr) } - columnSql := fmt.Sprintf(" %s %s %s %s %s %s", md.dc.GetMetaData().QuoteIdentifier(column.ColumnName), column.GetColumnType(), nullAble, incr, defVal, comment) + columnSql := fmt.Sprintf(" %s %s%s%s%s%s", md.dc.GetMetaData().QuoteIdentifier(column.ColumnName), column.GetColumnType(), nullAble, incr, defVal, comment) return columnSql } @@ -272,12 +289,11 @@ func (md *MysqlMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Ta if len(pks) > 0 { createSql += fmt.Sprintf(", \nPRIMARY KEY (%s)", strings.Join(pks, ",")) } - createSql += fmt.Sprintf(") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 ") + createSql += "\n)" // 表注释 if tableInfo.TableComment != "" { - replacer := strings.NewReplacer(";", "", "'", "") - createSql += fmt.Sprintf(" COMMENT '%s'", replacer.Replace(tableInfo.TableComment)) + createSql += fmt.Sprintf(" COMMENT '%s'", meta.QuoteEscape(tableInfo.TableComment)) } sqlArr = append(sqlArr, createSql) @@ -286,11 +302,11 @@ func (md *MysqlMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Ta } // 获取建表ddl -func (md *MysqlMetaData) GetTableDDL(tableName string) (string, error) { +func (md *MysqlMetaData) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) { // 1.获取表信息 tbs, err := md.GetTables(tableName) tableInfo := &dbi.Table{} - if err != nil && len(tbs) > 0 { + if err != nil || tbs == nil || len(tbs) <= 0 { logx.Errorf("获取表信息失败, %s", tableName) return "", err } @@ -303,7 +319,7 @@ func (md *MysqlMetaData) GetTableDDL(tableName string) (string, error) { logx.Errorf("获取列信息失败, %s", tableName) return "", err } - tableDDLArr := md.GenerateTableDDL(columns, *tableInfo, false) + tableDDLArr := md.GenerateTableDDL(columns, *tableInfo, dropBeforeCreate) // 3.获取索引信息 indexs, err := md.GetTableIndex(tableName) if err != nil { @@ -426,6 +442,25 @@ func (dc *DataConverter) GetDataType(dbColumnType string) dbi.DataType { } func (dc *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) string { + // 如果dataType是datetime而dbColumnValue是string类型,则需要根据类型格式化 + str, ok := dbColumnValue.(string) + if dataType == dbi.DataTypeDateTime && ok { + // 尝试用时间格式解析 + res, err := time.Parse(time.DateTime, str) + if err == nil { + return str + } + res, _ = time.Parse(time.RFC3339, str) + return res.Format(time.DateTime) + } + if dataType == dbi.DataTypeDate && ok { + res, _ := time.Parse(time.DateOnly, str) + return res.Format(time.DateOnly) + } + if dataType == dbi.DataTypeTime && ok { + res, _ := time.Parse(time.TimeOnly, str) + return res.Format(time.TimeOnly) + } return anyx.ToString(dbColumnValue) } @@ -448,3 +483,26 @@ func (dc *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any } return dbColumnValue } + +func (dc *DataConverter) WrapValue(dbColumnValue any, dataType dbi.DataType) string { + + if dbColumnValue == nil { + return "NULL" + } + switch dataType { + case dbi.DataTypeNumber: + return fmt.Sprintf("%v", dbColumnValue) + case dbi.DataTypeString: + val := fmt.Sprintf("%v", dbColumnValue) + // 转义单引号 + val = strings.Replace(val, `'`, `''`, -1) + val = strings.Replace(val, `\''`, `\'`, -1) + // 转义换行符 + val = strings.Replace(val, "\n", "\\n", -1) + return fmt.Sprintf("'%s'", val) + case dbi.DataTypeDate, dbi.DataTypeDateTime, dbi.DataTypeTime: + // mysql时间类型无需格式化 + return fmt.Sprintf("'%s'", dbColumnValue) + } + return fmt.Sprintf("'%s'", dbColumnValue) +} diff --git a/server/internal/db/dbm/oracle/dialect.go b/server/internal/db/dbm/oracle/dialect.go index e3685055..af8dd2ac 100644 --- a/server/internal/db/dbm/oracle/dialect.go +++ b/server/internal/db/dbm/oracle/dialect.go @@ -178,17 +178,7 @@ func (od *OracleDialect) ToColumn(commonColumn *dbi.Column) { commonColumn.CharMaxLength = 2000 } else { commonColumn.DataType = dbi.ColumnDataType(ctype) - // 如果类型是数字,类型后不需要带长度 - if strings.Contains(strings.ToLower(ctype), "int") { - commonColumn.CharMaxLength = 0 - commonColumn.NumPrecision = 0 - } else if strings.Contains(strings.ToLower(ctype), "char") { - // 如果是字符串类型,长度最大4000,否则修改字段类型为clob - if commonColumn.CharMaxLength > 4000 { - commonColumn.DataType = "CLOB" - commonColumn.CharMaxLength = 0 - } - } + od.dc.GetMetaData().FixColumn(commonColumn) } } diff --git a/server/internal/db/dbm/oracle/metadata.go b/server/internal/db/dbm/oracle/metadata.go index d1ad0cd1..c0887abe 100644 --- a/server/internal/db/dbm/oracle/metadata.go +++ b/server/internal/db/dbm/oracle/metadata.go @@ -120,28 +120,47 @@ func (od *OracleMetaData) GetColumns(tableNames ...string) ([]dbi.Column, error) columns := make([]dbi.Column, 0) for _, re := range res { - defaultVal := cast.ToString(re["COLUMN_DEFAULT"]) - // 如果默认值包含.nextval,说明是序列,默认值为null - if strings.Contains(defaultVal, ".nextval") { - defaultVal = "" - } column := dbi.Column{ TableName: cast.ToString(re["TABLE_NAME"]), ColumnName: cast.ToString(re["COLUMN_NAME"]), DataType: dbi.ColumnDataType(cast.ToString(re["DATA_TYPE"])), + CharMaxLength: cast.ToInt(re["CHAR_MAX_LENGTH"]), ColumnComment: cast.ToString(re["COLUMN_COMMENT"]), Nullable: cast.ToString(re["NULLABLE"]) == "YES", IsPrimaryKey: cast.ToInt(re["IS_PRIMARY_KEY"]) == 1, IsIdentity: cast.ToInt(re["IS_IDENTITY"]) == 1, - ColumnDefault: defaultVal, + ColumnDefault: cast.ToString(re["COLUMN_DEFAULT"]), + NumPrecision: cast.ToInt(re["NUM_PRECISION"]), NumScale: cast.ToInt(re["NUM_SCALE"]), } + od.FixColumn(&column) columns = append(columns, column) } return columns, nil } +func (od *OracleMetaData) FixColumn(column *dbi.Column) { + // 如果默认值包含.nextval,说明是序列,默认值为null + if strings.Contains(column.ColumnDefault, ".nextval") { + column.ColumnDefault = "" + } + + // 统一处理一下数据类型的长度 + if collx.ArrayAnyMatches([]string{"date", "time", "lob", "int"}, strings.ToLower(string(column.DataType))) { + // 如果是不需要设置长度的类型 + column.CharMaxLength = 0 + column.NumPrecision = 0 + } else if strings.Contains(strings.ToLower(string(column.DataType)), "char") { + // 如果是字符串类型,长度最大4000,否则修改字段类型为clob + if column.CharMaxLength > 4000 { + column.DataType = "NCLOB" + column.CharMaxLength = 0 + column.NumPrecision = 0 + } + } +} + func (od *OracleMetaData) GetPrimaryKey(tablename string) (string, error) { columns, err := od.GetColumns(tablename) if err != nil { @@ -175,6 +194,7 @@ func (od *OracleMetaData) GetTableIndex(tableName string) ([]dbi.Index, error) { IndexComment: cast.ToString(re["INDEX_COMMENT"]), IsUnique: cast.ToInt(re["IS_UNIQUE"]) == 1, SeqInIndex: cast.ToInt(re["SEQ_IN_INDEX"]), + IsPrimaryKey: cast.ToInt(re["IS_PRIMARY"]) == 1, }) } // 把查询结果以索引名分组,索引字段以逗号连接 @@ -204,23 +224,19 @@ func (od *OracleMetaData) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Tab comments := make([]string, 0) for _, index := range indexs { - // 通过字段、表名拼接索引名 - columnName := strings.ReplaceAll(index.ColumnName, "-", "") - columnName = strings.ReplaceAll(columnName, "_", "") - colName := strings.ReplaceAll(columnName, ",", "_") - - keyType := "normal" unique := "" if index.IsUnique { - keyType = "unique" unique = "unique" } - indexName := fmt.Sprintf("%s_key_%s_%s", keyType, tableInfo.TableName, colName) - sqls = append(sqls, fmt.Sprintf("CREATE %s INDEX %s ON %s(%s)", unique, indexName, meta.QuoteIdentifier(tableInfo.TableName), index.ColumnName)) - if index.IndexComment != "" { - comments = append(comments, fmt.Sprintf("COMMENT ON INDEX %s IS '%s'", indexName, index.IndexComment)) + // 取出列名,添加引号 + cols := strings.Split(index.ColumnName, ",") + colNames := make([]string, len(cols)) + for i, name := range cols { + colNames[i] = meta.QuoteIdentifier(name) } + + sqls = append(sqls, fmt.Sprintf("CREATE %s INDEX %s ON %s(%s)", unique, index.IndexName, meta.QuoteIdentifier(tableInfo.TableName), strings.Join(colNames, ","))) } sqlArr := make([]string, 0) @@ -237,7 +253,6 @@ func (od *OracleMetaData) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Tab func (od *OracleMetaData) genColumnBasicSql(column dbi.Column) string { meta := od.dc.GetMetaData() colName := meta.QuoteIdentifier(column.ColumnName) - dataType := string(column.DataType) if column.IsIdentity { // 如果是自增,不需要设置默认值和空值,自增列数据类型必须是number @@ -249,47 +264,18 @@ func (od *OracleMetaData) genColumnBasicSql(column dbi.Column) string { nullAble = " NOT NULL" } - defVal := "" // 默认值需要判断引号,如函数是不需要引号的 + defVal := "" if column.ColumnDefault != "" { - mark := false - // 哪些字段类型默认值需要加引号 - if collx.ArrayAnyMatches([]string{"CHAR", "LONG", "DATE", "TIME", "CLOB", "BLOB", "BFILE"}, dataType) { - // 默认值是时间日期函数的必须要加引号 - val := strings.ToUpper(column.ColumnDefault) - if collx.ArrayAnyMatches([]string{"DATE", "TIMESTAMP"}, dataType) && val == "CURRENT_DATE" || val == "CURRENT_TIMESTAMP" { - mark = false - } else { - mark = true - } - if mark { - defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault) - } else { - defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault) - } - } else { - // 如果是数字,默认值提取数字 - if collx.ArrayAnyMatches([]string{"NUM", "INT"}, dataType) { - match := bracketsRegexp.FindStringSubmatch(dataType) - if len(match) > 1 { - length := cast.ToInt(match[1]) - defVal = fmt.Sprintf(" DEFAULT %d", length) - } else { - defVal = fmt.Sprintf(" DEFAULT 0") - } - } - - defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault) - } + defVal = fmt.Sprintf(" DEFAULT %v", column.ColumnDefault) } - columnSql := fmt.Sprintf(" %s %s %s %s", colName, column.GetColumnType(), defVal, nullAble) + columnSql := fmt.Sprintf(" %s %s%s%s", colName, column.GetColumnType(), defVal, nullAble) return columnSql } // 获取建表ddl func (od *OracleMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table, dropBeforeCreate bool) []string { meta := od.dc.GetMetaData() - replacer := strings.NewReplacer(";", "", "'", "") quoteTableName := meta.QuoteIdentifier(tableInfo.TableName) sqlArr := make([]string, 0) @@ -302,8 +288,7 @@ begin if num > 0 then execute immediate 'drop table "%s"' ; end if; -end; -` +end` sqlArr = append(sqlArr, fmt.Sprintf(dropSqlTmp, tableInfo.TableName, tableInfo.TableName)) } @@ -320,7 +305,7 @@ end; fields = append(fields, od.genColumnBasicSql(column)) // 防止注释内含有特殊字符串导致sql出错 if column.ColumnComment != "" { - comment := replacer.Replace(column.ColumnComment) + comment := meta.QuoteEscape(column.ColumnComment) columnComments = append(columnComments, fmt.Sprintf("COMMENT ON COLUMN %s.%s IS '%s'", quoteTableName, meta.QuoteIdentifier(column.ColumnName), comment)) } } @@ -336,7 +321,7 @@ end; // 表注释 tableCommentSql := "" if tableInfo.TableComment != "" { - tableCommentSql = fmt.Sprintf("COMMENT ON TABLE %s is '%s'", meta.QuoteIdentifier(tableInfo.TableName), replacer.Replace(tableInfo.TableComment)) + tableCommentSql = fmt.Sprintf("COMMENT ON TABLE %s is '%s'", meta.QuoteIdentifier(tableInfo.TableName), meta.QuoteEscape(tableInfo.TableComment)) sqlArr = append(sqlArr, tableCommentSql) } @@ -349,13 +334,12 @@ end; } // 获取建表ddl -func (od *OracleMetaData) GetTableDDL(tableName string) (string, error) { +func (od *OracleMetaData) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) { // 1.获取表信息 tbs, err := od.GetTables(tableName) tableInfo := &dbi.Table{} - if err != nil && len(tbs) > 0 { - + if err != nil || tbs == nil || len(tbs) <= 0 { logx.Errorf("获取表信息失败, %s", tableName) return "", err } @@ -368,7 +352,7 @@ func (od *OracleMetaData) GetTableDDL(tableName string) (string, error) { logx.Errorf("获取列信息失败, %s", tableName) return "", err } - tableDDLArr := od.GenerateTableDDL(columns, *tableInfo, false) + tableDDLArr := od.GenerateTableDDL(columns, *tableInfo, dropBeforeCreate) // 3.获取索引信息 indexs, err := od.GetTableIndex(tableName) if err != nil { @@ -476,7 +460,12 @@ func (dc *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) st switch dataType { // oracle把日期类型数据格式化输出 case dbi.DataTypeDateTime: // "2024-01-02T22:08:22.275697+08:00" - res, _ := time.Parse(time.RFC3339, str) + // 尝试用时间格式解析 + res, err := time.Parse(time.DateTime, str) + if err == nil { + return str + } + res, _ = time.Parse(time.RFC3339, str) return res.Format(time.DateTime) } return str @@ -490,3 +479,24 @@ func (dc *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any } return dbColumnValue } + +func (dc *DataConverter) WrapValue(dbColumnValue any, dataType dbi.DataType) string { + if dbColumnValue == nil { + return "NULL" + } + switch dataType { + case dbi.DataTypeNumber: + return fmt.Sprintf("%v", dbColumnValue) + case dbi.DataTypeString: + val := fmt.Sprintf("%v", dbColumnValue) + // 转义单引号 + val = strings.Replace(val, `'`, `''`, -1) + val = strings.Replace(val, `\''`, `\'`, -1) + // 转义换行符 + val = strings.Replace(val, "\n", "\\n", -1) + return fmt.Sprintf("'%s'", val) + case dbi.DataTypeDate, dbi.DataTypeDateTime, dbi.DataTypeTime: + return fmt.Sprintf("to_timestamp('%s', 'yyyy-mm-dd hh24:mi:ss')", dc.FormatData(dbColumnValue, dataType)) + } + return fmt.Sprintf("'%s'", dbColumnValue) +} diff --git a/server/internal/db/dbm/postgres/dialect.go b/server/internal/db/dbm/postgres/dialect.go index 2f727def..af498f8b 100644 --- a/server/internal/db/dbm/postgres/dialect.go +++ b/server/internal/db/dbm/postgres/dialect.go @@ -197,14 +197,7 @@ func (pd *PgsqlDialect) ToColumn(commonColumn *dbi.Column) { commonColumn.CharMaxLength = 2000 } else { commonColumn.DataType = dbi.ColumnDataType(ctype) - // 哪些字段可以指定长度 - if !collx.ArrayAnyMatches([]string{"char", "time", "bit", "num", "decimal"}, ctype) { - commonColumn.CharMaxLength = 0 - commonColumn.NumPrecision = 0 - } else if strings.Contains(strings.ToLower(ctype), "char") { - // 如果类型是文本,长度翻倍 - commonColumn.CharMaxLength = commonColumn.CharMaxLength * 2 - } + } } diff --git a/server/internal/db/dbm/postgres/metadata.go b/server/internal/db/dbm/postgres/metadata.go index 7e0bc489..f4f2555e 100644 --- a/server/internal/db/dbm/postgres/metadata.go +++ b/server/internal/db/dbm/postgres/metadata.go @@ -2,6 +2,7 @@ package postgres import ( "fmt" + "io" "mayfly-go/internal/db/dbm/dbi" "mayfly-go/pkg/errorx" "mayfly-go/pkg/logx" @@ -114,12 +115,31 @@ func (pd *PgsqlMetaData) GetColumns(tableNames ...string) ([]dbi.Column, error) NumPrecision: cast.ToInt(re["numPrecision"]), NumScale: cast.ToInt(re["numScale"]), } - + pd.FixColumn(&column) columns = append(columns, column) } return columns, nil } +func (pd *PgsqlMetaData) FixColumn(column *dbi.Column) { + dataType := strings.ToLower(string(column.DataType)) + // 哪些字段可以指定长度 + if !collx.ArrayAnyMatches([]string{"char", "time", "bit", "num", "decimal"}, dataType) { + column.CharMaxLength = 0 + column.NumPrecision = 0 + } else if strings.Contains(dataType, "char") { + // 如果类型是文本,长度翻倍 + column.CharMaxLength = column.CharMaxLength * 2 + } + // 如果默认值带冒号,如:'id'::varchar + if column.ColumnDefault != "" && strings.Contains(column.ColumnDefault, "::") && !strings.HasPrefix(column.ColumnDefault, "nextval") { + match := defaultValueRegexp.FindStringSubmatch(column.ColumnDefault) + if len(match) > 1 { + column.ColumnDefault = match[1] + } + } +} + func (pd *PgsqlMetaData) GetPrimaryKey(tablename string) (string, error) { columns, err := pd.GetColumns(tablename) if err != nil { @@ -153,6 +173,7 @@ func (pd *PgsqlMetaData) GetTableIndex(tableName string) ([]dbi.Index, error) { IndexComment: cast.ToString(re["indexComment"]), IsUnique: cast.ToInt(re["isUnique"]) == 1, SeqInIndex: cast.ToInt(re["seqInIndex"]), + IsPrimaryKey: cast.ToInt(re["isPrimaryKey"]) == 1, }) } // 把查询结果以索引名分组,索引字段以逗号连接 @@ -175,30 +196,30 @@ func (pd *PgsqlMetaData) GetTableIndex(tableName string) ([]dbi.Index, error) { } func (pd *PgsqlMetaData) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Table) []string { + meta := pd.dc.GetMetaData() creates := make([]string, 0) drops := make([]string, 0) comments := make([]string, 0) for _, index := range indexs { - // 通过字段、表名拼接索引名 - columnName := strings.ReplaceAll(index.ColumnName, "-", "") - columnName = strings.ReplaceAll(columnName, "_", "") - colName := strings.ReplaceAll(columnName, ",", "_") - - keyType := "normal" unique := "" if index.IsUnique { - keyType = "unique" unique = "unique" } - indexName := fmt.Sprintf("%s_key_%s_%s", keyType, tableInfo.TableName, colName) // 如果索引名存在,先删除索引 - drops = append(drops, fmt.Sprintf("drop index if exists %s.%s", pd.dc.Info.CurrentSchema(), indexName)) + drops = append(drops, fmt.Sprintf("drop index if exists %s.%s", pd.dc.Info.CurrentSchema(), index.IndexName)) + // 取出列名,添加引号 + cols := strings.Split(index.ColumnName, ",") + colNames := make([]string, len(cols)) + for i, name := range cols { + colNames[i] = meta.QuoteIdentifier(name) + } // 创建索引 - creates = append(creates, fmt.Sprintf("CREATE %s INDEX %s on %s.%s(%s)", unique, indexName, pd.dc.Info.CurrentSchema(), tableInfo.TableName, index.ColumnName)) + creates = append(creates, fmt.Sprintf("CREATE %s INDEX %s on %s.%s(%s)", unique, index.IndexName, pd.dc.Info.CurrentSchema(), tableInfo.TableName, strings.Join(colNames, ","))) if index.IndexComment != "" { - comments = append(comments, fmt.Sprintf("COMMENT ON INDEX %s.%s IS '%s'", pd.dc.Info.CurrentSchema(), indexName, index.IndexComment)) + comment := meta.QuoteEscape(index.IndexComment) + comments = append(comments, fmt.Sprintf("COMMENT ON INDEX %s.%s IS '%s'", pd.dc.Info.CurrentSchema(), index.IndexName, comment)) } } @@ -222,6 +243,12 @@ func (pd *PgsqlMetaData) genColumnBasicSql(column dbi.Column) string { colName := meta.QuoteIdentifier(column.ColumnName) dataType := string(column.DataType) + // 如果数据类型是数字,则去掉长度 + if collx.ArrayAnyMatches([]string{"int"}, strings.ToLower(dataType)) { + column.NumPrecision = 0 + column.CharMaxLength = 0 + } + // 如果是自增类型,需要转换为serial if column.IsIdentity { if dataType == "int4" { @@ -240,31 +267,16 @@ func (pd *PgsqlMetaData) genColumnBasicSql(column dbi.Column) string { nullAble := "" if !column.Nullable { nullAble = " NOT NULL" - // 如果字段不能为空,则设置默认值 - if column.ColumnDefault == "" { - if collx.ArrayAnyMatches([]string{"char", "text", "lob"}, strings.ToLower(dataType)) { - // 文本默认值为空字符串 - column.ColumnDefault = " " - } else if collx.ArrayAnyMatches([]string{"int", "num"}, strings.ToLower(dataType)) { - // 数字默认值为0 - column.ColumnDefault = "0" - } - } } defVal := "" // 默认值需要判断引号,如函数是不需要引号的 // 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值 if column.ColumnDefault != "" && !strings.Contains(column.ColumnDefault, "(") { - // 哪些字段类型默认值需要加引号 mark := false + // 哪些字段类型默认值需要加引号 if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, dataType) { - // 如果是文本类型,则默认值不能带括号 - if collx.ArrayAnyMatches([]string{"char", "text", "lob"}, dataType) { - column.ColumnDefault = "" - } - // 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号 if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) && - collx.ArrayAnyMatches([]string{"DATE", "TIME"}, strings.ToUpper(column.ColumnDefault)) { + collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(column.ColumnDefault)) { mark = false } else { mark = true @@ -275,30 +287,31 @@ func (pd *PgsqlMetaData) genColumnBasicSql(column dbi.Column) string { column.ColumnDefault = "CURRENT_TIMESTAMP" } - if column.ColumnDefault != "" { - if mark { - defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault) - } else { - defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault) - } + if mark { + defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault) + } else { + defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault) } } - columnSql := fmt.Sprintf(" %s %s %s %s ", colName, column.GetColumnType(), nullAble, defVal) + // 如果是varchar,长度翻倍,防止报错 + if collx.ArrayAnyMatches([]string{"char"}, strings.ToLower(dataType)) { + column.CharMaxLength = column.CharMaxLength * 2 + } + columnSql := fmt.Sprintf(" %s %s%s%s", colName, column.GetColumnType(), nullAble, defVal) return columnSql } func (pd *PgsqlMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table, dropBeforeCreate bool) []string { - meta := pd.dc.GetMetaData() - replacer := strings.NewReplacer(";", "", "'", "") + quoteTableName := meta.QuoteIdentifier(tableInfo.TableName) sqlArr := make([]string, 0) if dropBeforeCreate { - sqlArr = append(sqlArr, fmt.Sprintf("DROP TABLE IF EXISTS %s", meta.QuoteIdentifier(tableInfo.TableName))) + sqlArr = append(sqlArr, fmt.Sprintf("DROP TABLE IF EXISTS %s", quoteTableName)) } // 组装建表语句 - createSql := fmt.Sprintf("CREATE TABLE %s (\n", meta.QuoteIdentifier(tableInfo.TableName)) + createSql := fmt.Sprintf("CREATE TABLE %s (\n", quoteTableName) fields := make([]string, 0) pks := make([]string, 0) columnComments := make([]string, 0) @@ -313,8 +326,8 @@ func (pd *PgsqlMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Ta // 防止注释内含有特殊字符串导致sql出错 if column.ColumnComment != "" { - comment := replacer.Replace(column.ColumnComment) - columnComments = append(columnComments, fmt.Sprintf(commentTmp, column.TableName, column.ColumnName, comment)) + comment := meta.QuoteEscape(column.ColumnComment) + columnComments = append(columnComments, fmt.Sprintf(commentTmp, quoteTableName, meta.QuoteIdentifier(column.ColumnName), comment)) } } @@ -322,12 +335,12 @@ func (pd *PgsqlMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Ta if len(pks) > 0 { createSql += fmt.Sprintf(", \nPRIMARY KEY (%s)", strings.Join(pks, ",")) } - createSql += ")" + createSql += "\n)" tableCommentSql := "" if tableInfo.TableComment != "" { commentTmp := "comment on table %s is '%s'" - tableCommentSql = fmt.Sprintf(commentTmp, tableInfo.TableName, replacer.Replace(tableInfo.TableComment)) + tableCommentSql = fmt.Sprintf(commentTmp, quoteTableName, meta.QuoteEscape(tableInfo.TableComment)) } // create @@ -346,13 +359,12 @@ func (pd *PgsqlMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Ta } // 获取建表ddl -func (pd *PgsqlMetaData) GetTableDDL(tableName string) (string, error) { +func (pd *PgsqlMetaData) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) { // 1.获取表信息 tbs, err := pd.GetTables(tableName) tableInfo := &dbi.Table{} - if err != nil && len(tbs) > 0 { - + if err != nil || tbs == nil || len(tbs) <= 0 { logx.Errorf("获取表信息失败, %s", tableName) return "", err } @@ -365,7 +377,7 @@ func (pd *PgsqlMetaData) GetTableDDL(tableName string) (string, error) { logx.Errorf("获取列信息失败, %s", tableName) return "", err } - tableDDLArr := pd.GenerateTableDDL(columns, *tableInfo, false) + tableDDLArr := pd.GenerateTableDDL(columns, *tableInfo, dropBeforeCreate) // 3.获取索引信息 indexs, err := pd.GetTableIndex(tableName) if err != nil { @@ -404,6 +416,19 @@ func (pd *PgsqlMetaData) DefaultDb() string { } } +func (pd *PgsqlMetaData) AfterDumpInsert(writer io.Writer, tableName string, columns []dbi.Column) { + + // 设置自增序列当前值 + for _, column := range columns { + if column.IsIdentity { + seq := fmt.Sprintf("SELECT setval('%s_%s_seq', (SELECT max(%s) FROM %s));\n", tableName, column.ColumnName, column.ColumnName, tableName) + writer.Write([]byte(seq)) + } + } + + writer.Write([]byte("COMMIT;\n")) +} + func (pd *PgsqlMetaData) GetDataConverter() dbi.DataConverter { return converter } @@ -417,8 +442,9 @@ var ( dateRegexp = regexp.MustCompile(`(?i)date`) // 时间类型 timeRegexp = regexp.MustCompile(`(?i)time`) - // 定义正则表达式,匹配括号内的数字 - bracketsRegexp = regexp.MustCompile(`\((\d+)\)`) + + // 提取pg默认值, 如:'id'::varchar 提取id ; '-1'::integer 提取-1 + defaultValueRegexp = regexp.MustCompile(`'([^']*)'`) converter = new(DataConverter) @@ -497,13 +523,28 @@ func (dc *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) st str := fmt.Sprintf("%v", dbColumnValue) switch dataType { case dbi.DataTypeDateTime: // "2024-01-02T22:16:28.545377+08:00" - res, _ := time.Parse(time.RFC3339, str) + // 尝试用时间格式解析 + res, err := time.Parse(time.DateTime, str) + if err == nil { + return str + } + res, err = time.Parse(time.RFC3339, str) return res.Format(time.DateTime) case dbi.DataTypeDate: // "2024-01-02T00:00:00Z" - res, _ := time.Parse(time.RFC3339, str) + // 尝试用时间格式解析 + res, err := time.Parse(time.DateOnly, str) + if err == nil { + return str + } + res, _ = time.Parse(time.RFC3339, str) return res.Format(time.DateOnly) case dbi.DataTypeTime: // "0000-01-01T22:16:28.545075+08:00" - res, _ := time.Parse(time.RFC3339, str) + // 尝试用时间格式解析 + res, err := time.Parse(time.TimeOnly, str) + if err == nil { + return str + } + res, _ = time.Parse(time.RFC3339, str) return res.Format(time.TimeOnly) } return cast.ToString(dbColumnValue) @@ -526,3 +567,23 @@ func (dc *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any } return dbColumnValue } + +func (dc *DataConverter) WrapValue(dbColumnValue any, dataType dbi.DataType) string { + if dbColumnValue == nil { + return "NULL" + } + switch dataType { + case dbi.DataTypeNumber: + return fmt.Sprintf("%v", dbColumnValue) + case dbi.DataTypeString: + val := fmt.Sprintf("%v", dbColumnValue) + // 转义单引号 + val = strings.Replace(val, `'`, `''`, -1) + // 转义换行符 + val = strings.Replace(val, "\n", "\\n", -1) + return fmt.Sprintf("'%s'", val) + case dbi.DataTypeDate, dbi.DataTypeDateTime, dbi.DataTypeTime: + return fmt.Sprintf("'%s'", dc.FormatData(dbColumnValue, dataType)) + } + return fmt.Sprintf("'%s'", dbColumnValue) +} diff --git a/server/internal/db/dbm/sqlite/dialect.go b/server/internal/db/dbm/sqlite/dialect.go index ba217920..00c9b2fe 100644 --- a/server/internal/db/dbm/sqlite/dialect.go +++ b/server/internal/db/dbm/sqlite/dialect.go @@ -59,7 +59,7 @@ func (sd *SqliteDialect) CopyTable(copy *dbi.DbCopyTable) error { // 生成新表名,为老表明+_copy_时间戳 newTableName := tableName + "_copy_" + time.Now().Format("20060102150405") - ddl, err := sd.dc.GetMetaData().GetTableDDL(tableName) + ddl, err := sd.dc.GetMetaData().GetTableDDL(tableName, false) if err != nil { return err } @@ -103,6 +103,8 @@ func (sd *SqliteDialect) ToColumn(commonColumn *dbi.Column) { if ctype == "" { commonColumn.DataType = "nvarchar" commonColumn.CharMaxLength = 2000 + } else { + sd.dc.GetMetaData().FixColumn(commonColumn) } } diff --git a/server/internal/db/dbm/sqlite/metadata.go b/server/internal/db/dbm/sqlite/metadata.go index d334b69e..2884930c 100644 --- a/server/internal/db/dbm/sqlite/metadata.go +++ b/server/internal/db/dbm/sqlite/metadata.go @@ -3,6 +3,7 @@ package sqlite import ( "errors" "fmt" + "io" "mayfly-go/internal/db/dbm/dbi" "mayfly-go/pkg/logx" "mayfly-go/pkg/utils/anyx" @@ -138,12 +139,17 @@ func (sd *SqliteMetaData) GetColumns(tableNames ...string) ([]dbi.Column, error) } column.DataType = dbi.ColumnDataType(dataType) + sd.FixColumn(&column) + columns = append(columns, column) } } return columns, nil } +func (sd *SqliteMetaData) FixColumn(column *dbi.Column) { +} + func (sd *SqliteMetaData) GetPrimaryKey(tableName string) (string, error) { _, res, err := sd.dc.Query(fmt.Sprintf("PRAGMA table_info(%s)", tableName)) if err != nil { @@ -193,6 +199,7 @@ func (sd *SqliteMetaData) GetTableIndex(tableName string) ([]dbi.Index, error) { IndexComment: cast.ToString(re["indexComment"]), IsUnique: isUnique, SeqInIndex: 1, + IsPrimaryKey: false, }) } // 把查询结果以索引名分组,索引字段以逗号连接 @@ -201,22 +208,24 @@ func (sd *SqliteMetaData) GetTableIndex(tableName string) ([]dbi.Index, error) { // 获取建索引ddl func (sd *SqliteMetaData) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Table) []string { + meta := sd.dc.GetMetaData() sqls := make([]string, 0) for _, index := range indexs { - // 通过字段、表名拼接索引名 - columnName := strings.ReplaceAll(index.ColumnName, "-", "") - columnName = strings.ReplaceAll(columnName, "_", "") - colName := strings.ReplaceAll(columnName, ",", "_") - - keyType := "normal" unique := "" if index.IsUnique { - keyType = "unique" unique = "unique" } - indexName := fmt.Sprintf("%s_key_%s_%s", keyType, tableInfo.TableName, colName) + // 取出列名,添加引号 + cols := strings.Split(index.ColumnName, ",") + colNames := make([]string, len(cols)) + for i, name := range cols { + colNames[i] = meta.QuoteIdentifier(name) + } + // 创建前尝试删除 + sqls = append(sqls, fmt.Sprintf("DROP INDEX IF EXISTS \"%s\"", index.IndexName)) + sqlTmp := "CREATE %s INDEX %s ON \"%s\" (%s) " - sqls = append(sqls, fmt.Sprintf(sqlTmp, unique, indexName, tableInfo.TableName, index.ColumnName)) + sqls = append(sqls, fmt.Sprintf(sqlTmp, unique, index.IndexName, tableInfo.TableName, strings.Join(colNames, ","))) } return sqls } @@ -232,9 +241,11 @@ func (sd *SqliteMetaData) genColumnBasicSql(column dbi.Column) string { nullAble = " NOT NULL" } + quoteColumnName := sd.dc.GetMetaData().QuoteIdentifier(column.ColumnName) + // 如果是主键,则直接返回,不判断默认值 if column.IsPrimaryKey { - return fmt.Sprintf(" %s integer PRIMARY KEY %s %s", column.ColumnName, incr, nullAble) + return fmt.Sprintf(" %s integer PRIMARY KEY%s%s", quoteColumnName, incr, nullAble) } defVal := "" // 默认值需要判断引号,如函数是不需要引号的 // 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值 @@ -257,7 +268,7 @@ func (sd *SqliteMetaData) genColumnBasicSql(column dbi.Column) string { } } - return fmt.Sprintf(" %s %s %s %s", sd.dc.GetMetaData().QuoteIdentifier(column.ColumnName), column.GetColumnType(), nullAble, defVal) + return fmt.Sprintf(" %s %s%s%s", quoteColumnName, column.GetColumnType(), nullAble, defVal) } // 获取建表ddl @@ -275,8 +286,8 @@ func (sd *SqliteMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.T for _, column := range columns { fields = append(fields, sd.genColumnBasicSql(column)) } - createSql += strings.Join(fields, ",") - createSql += ") " + createSql += strings.Join(fields, ",\n") + createSql += "\n)" sqlArr = append(sqlArr, createSql) @@ -284,12 +295,18 @@ func (sd *SqliteMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.T } // 获取建表ddl -func (sd *SqliteMetaData) GetTableDDL(tableName string) (string, error) { +func (sd *SqliteMetaData) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) { + var builder strings.Builder + + if dropBeforeCreate { + builder.WriteString(fmt.Sprintf("DROP TABLE IF EXISTS %s; \n\n", tableName)) + } + _, res, err := sd.dc.Query("select sql from sqlite_master WHERE tbl_name=? order by type desc", tableName) if err != nil { return "", err } - var builder strings.Builder + for _, re := range res { builder.WriteString(cast.ToString(re["sql"]) + "; \n\n") } @@ -305,6 +322,11 @@ func (sd *SqliteMetaData) GetDataConverter() dbi.DataConverter { return converter } +func (sd *SqliteMetaData) BeforeDumpInsert(writer io.Writer, tableName string) { +} +func (sd *SqliteMetaData) AfterDumpInsert(writer io.Writer, tableName string, columns []dbi.Column) { +} + var ( // 数字类型 numberRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit|real`) @@ -388,13 +410,28 @@ func (dc *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) st str := anyx.ToString(dbColumnValue) switch dataType { case dbi.DataTypeDateTime: // "2024-01-02T22:08:22.275697+08:00" - res, _ := time.Parse(time.RFC3339, str) + // 尝试用时间格式解析 + res, err := time.Parse(time.DateTime, str) + if err == nil { + return str + } + res, _ = time.Parse(time.RFC3339, str) return res.Format(time.DateTime) case dbi.DataTypeDate: // "2024-01-02T00:00:00+08:00" - res, _ := time.Parse(time.RFC3339, str) + // 尝试用时间格式解析 + res, err := time.Parse(time.DateOnly, str) + if err == nil { + return str + } + res, _ = time.Parse(time.RFC3339, str) return res.Format(time.DateOnly) case dbi.DataTypeTime: // "0000-01-01T22:08:22.275688+08:00" - res, _ := time.Parse(time.RFC3339, str) + // 尝试用时间格式解析 + res, err := time.Parse(time.TimeOnly, str) + if err == nil { + return str + } + res, _ = time.Parse(time.RFC3339, str) return res.Format(time.TimeOnly) } return str @@ -403,3 +440,24 @@ func (dc *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) st func (dc *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any { return dbColumnValue } + +func (dc *DataConverter) WrapValue(dbColumnValue any, dataType dbi.DataType) string { + if dbColumnValue == nil { + return "NULL" + } + switch dataType { + case dbi.DataTypeNumber: + return fmt.Sprintf("%v", dbColumnValue) + case dbi.DataTypeString: + val := fmt.Sprintf("%v", dbColumnValue) + // 转义单引号 + val = strings.Replace(val, `'`, `''`, -1) + val = strings.Replace(val, `\''`, `\'`, -1) + // 转义换行符 + val = strings.Replace(val, "\n", "\\n", -1) + return fmt.Sprintf("'%s'", val) + case dbi.DataTypeDate, dbi.DataTypeDateTime, dbi.DataTypeTime: + return fmt.Sprintf("'%s'", dc.FormatData(dbColumnValue, dataType)) + } + return fmt.Sprintf("'%s'", dbColumnValue) +}