diff --git a/frontend/src/views/ops/db/component/table/DbTableData.vue b/frontend/src/views/ops/db/component/table/DbTableData.vue index 740d8ac3..4f4140b2 100644 --- a/frontend/src/views/ops/db/component/table/DbTableData.vue +++ b/frontend/src/views/ops/db/component/table/DbTableData.vue @@ -890,7 +890,7 @@ defineExpose({ font-weight: bold; position: absolute; top: -7px; - padding: 1px; + padding: 3px; } .column-right { diff --git a/server/internal/db/api/db.go b/server/internal/db/api/db.go index 9a38cd12..2418e062 100644 --- a/server/internal/db/api/db.go +++ b/server/internal/db/api/db.go @@ -262,7 +262,7 @@ func (d *Db) DumpSql(rc *req.Ctx) { } func (d *Db) TableInfos(rc *req.Ctx) { - res, err := d.getDbConn(rc).GetMetaData().GetTables() + res, err := d.getDbConn(rc).GetMetadata().GetTables() biz.ErrIsNilAppendErr(err, "获取表信息失败: %s") rc.ResData = res } @@ -270,7 +270,7 @@ func (d *Db) TableInfos(rc *req.Ctx) { func (d *Db) TableIndex(rc *req.Ctx) { tn := rc.Query("tableName") biz.NotEmpty(tn, "tableName不能为空") - res, err := d.getDbConn(rc).GetMetaData().GetTableIndex(tn) + res, err := d.getDbConn(rc).GetMetadata().GetTableIndex(tn) biz.ErrIsNilAppendErr(err, "获取表索引信息失败: %s") rc.ResData = res } @@ -281,7 +281,7 @@ func (d *Db) ColumnMA(rc *req.Ctx) { biz.NotEmpty(tn, "tableName不能为空") dbi := d.getDbConn(rc) - res, err := dbi.GetMetaData().GetColumns(tn) + res, err := dbi.GetMetadata().GetColumns(tn) biz.ErrIsNilAppendErr(err, "获取数据库列信息失败: %s") rc.ResData = res } @@ -290,7 +290,7 @@ func (d *Db) ColumnMA(rc *req.Ctx) { func (d *Db) HintTables(rc *req.Ctx) { dbi := d.getDbConn(rc) - metadata := dbi.GetMetaData() + metadata := dbi.GetMetadata() // 获取所有表 tables, err := metadata.GetTables() biz.ErrIsNil(err) @@ -331,18 +331,18 @@ func (d *Db) HintTables(rc *req.Ctx) { func (d *Db) GetTableDDL(rc *req.Ctx) { tn := rc.Query("tableName") biz.NotEmpty(tn, "tableName不能为空") - res, err := d.getDbConn(rc).GetMetaData().GetTableDDL(tn, false) + res, err := d.getDbConn(rc).GetMetadata().GetTableDDL(tn, false) biz.ErrIsNilAppendErr(err, "获取表ddl失败: %s") rc.ResData = res } func (d *Db) GetVersion(rc *req.Ctx) { - version := d.getDbConn(rc).GetMetaData().GetCompatibleDbVersion() + version := d.getDbConn(rc).GetMetadata().GetCompatibleDbVersion() rc.ResData = version } func (d *Db) GetSchemas(rc *req.Ctx) { - res, err := d.getDbConn(rc).GetMetaData().GetSchemas() + res, err := d.getDbConn(rc).GetMetadata().GetSchemas() biz.ErrIsNilAppendErr(err, "获取schemas失败: %s") rc.ResData = res } diff --git a/server/internal/db/api/db_instance.go b/server/internal/db/api/db_instance.go index 8493cd0a..c219ed80 100644 --- a/server/internal/db/api/db_instance.go +++ b/server/internal/db/api/db_instance.go @@ -125,7 +125,7 @@ func (d *Instance) GetDbServer(rc *req.Ctx) { instanceId := getInstanceId(rc) conn, err := d.DbApp.GetDbConnByInstanceId(instanceId) biz.ErrIsNil(err) - res, err := conn.GetMetaData().GetDbServer() + res, err := conn.GetMetadata().GetDbServer() biz.ErrIsNil(err) rc.ResData = res } diff --git a/server/internal/db/application/db.go b/server/internal/db/application/db.go index f6207977..485bcc3f 100644 --- a/server/internal/db/application/db.go +++ b/server/internal/db/application/db.go @@ -222,7 +222,6 @@ func (d *dbAppImpl) GetDbConnByInstanceId(instanceId uint64) (*dbi.DbConn, error } func (d *dbAppImpl) DumpDb(ctx context.Context, reqParam *dto.DumpDb) error { - log := dto.DefaultDumpLog if reqParam.Log != nil { log = reqParam.Log @@ -247,7 +246,7 @@ func (d *dbAppImpl) DumpDb(ctx context.Context, reqParam *dto.DumpDb) error { writer.WriteString("\n-- ----------------------------\n\n") // 获取目标元数据,仅生成sql,用于生成建表语句和插入数据,不能用于查询 - targetMeta := dbConn.GetMetaData() + targetDialect := dbConn.GetDialect() if reqParam.TargetDbType != "" && dbConn.Info.Type != reqParam.TargetDbType { // 创建一个假连接,仅用于调用方言生成sql,不做数据库连接操作 meta := dbi.GetMeta(reqParam.TargetDbType) @@ -256,10 +255,11 @@ func (d *dbAppImpl) DumpDb(ctx context.Context, reqParam *dto.DumpDb) error { Meta: meta, }} - targetMeta = meta.GetMetaData(dbConn) + targetDialect = meta.GetDialect(dbConn) } - srcMeta := dbConn.GetMetaData() + srcMeta := dbConn.GetMetadata() + srcDialect := dbConn.GetDialect() if len(tables) == 0 { log("获取可导出的表信息...") ti, err := srcMeta.GetTables() @@ -295,14 +295,14 @@ func (d *dbAppImpl) DumpDb(ctx context.Context, reqParam *dto.DumpDb) error { // 按表名排序 sort.Strings(tables) - quoteSchema := srcMeta.QuoteIdentifier(dbConn.Info.CurrentSchema()) - dumpHelper := targetMeta.GetDumpHelper() - dataHelper := targetMeta.GetDataHelper() + quoteSchema := srcDialect.QuoteIdentifier(dbConn.Info.CurrentSchema()) + dumpHelper := targetDialect.GetDumpHelper() + dataHelper := targetDialect.GetDataHelper() // 遍历获取每个表的信息 for _, tableName := range tables { log(fmt.Sprintf("获取表[%s]信息...", tableName)) - quoteTableName := targetMeta.QuoteIdentifier(tableName) + quoteTableName := targetDialect.QuoteIdentifier(tableName) // 查询表信息,主要是为了查询表注释 tbs, err := srcMeta.GetTables(tableName) @@ -323,7 +323,7 @@ func (d *dbAppImpl) DumpDb(ctx context.Context, reqParam *dto.DumpDb) error { if reqParam.DumpDDL { log(fmt.Sprintf("生成表[%s]DDL...", tableName)) writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- 表结构: %s \n-- ----------------------------\n", tableName)) - tbDdlArr := targetMeta.GenerateTableDDL(columnMap[tableName], tabInfo, true) + tbDdlArr := targetDialect.GenerateTableDDL(columnMap[tableName], tabInfo, true) for _, ddl := range tbDdlArr { writer.WriteString(ddl + ";\n") } @@ -338,7 +338,7 @@ func (d *dbAppImpl) DumpDb(ctx context.Context, reqParam *dto.DumpDb) error { // 获取列信息 quoteColNames := make([]string, 0) for _, col := range columnMap[tableName] { - quoteColNames = append(quoteColNames, targetMeta.QuoteIdentifier(col.ColumnName)) + quoteColNames = append(quoteColNames, targetDialect.QuoteIdentifier(col.ColumnName)) } _, _ = dbConn.WalkTableRows(ctx, tableName, func(row map[string]any, _ []*dbi.QueryColumn) error { @@ -367,7 +367,7 @@ func (d *dbAppImpl) DumpDb(ctx context.Context, reqParam *dto.DumpDb) error { // 最后添加索引 log(fmt.Sprintf("生成表[%s]索引...", tableName)) writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- 表索引: %s \n-- ----------------------------\n", tableName)) - sqlArr := targetMeta.GenerateIndexDDL(indexs, tabInfo) + sqlArr := targetDialect.GenerateIndexDDL(indexs, tabInfo) for _, sqlStr := range sqlArr { writer.WriteString(sqlStr + ";\n") } diff --git a/server/internal/db/application/db_data_sync.go b/server/internal/db/application/db_data_sync.go index 8c434949..447ecc16 100644 --- a/server/internal/db/application/db_data_sync.go +++ b/server/internal/db/application/db_data_sync.go @@ -164,7 +164,7 @@ func (app *dataSyncAppImpl) RunCronJob(ctx context.Context, id uint64) error { } else { updFieldValType = dbi.DataTypeNumber } - wrapUpdFieldVal := srcConn.GetMetaData().GetDataHelper().WrapValue(task.UpdFieldVal, updFieldValType) + wrapUpdFieldVal := srcConn.GetDialect().GetDataHelper().WrapValue(task.UpdFieldVal, updFieldValType) updSql = fmt.Sprintf("and %s > %s", task.UpdField, wrapUpdFieldVal) orderSql = "order by " + task.UpdField + " asc " @@ -221,7 +221,7 @@ func (app *dataSyncAppImpl) doDataSync(ctx context.Context, sql string, task *en } }() - srcMetaData := srcConn.GetMetaData() + srcDialect := srcConn.GetDialect() // task.FieldMap为json数组字符串 [{"src":"id","target":"id"}],转为map var fieldMap []map[string]string @@ -251,7 +251,7 @@ func (app *dataSyncAppImpl) doDataSync(ctx context.Context, sql string, task *en updFieldType = dbi.DataTypeString for _, column := range columns { if strings.EqualFold(column.Name, updFieldName) { - updFieldType = srcMetaData.GetDataHelper().GetDataType(column.Type) + updFieldType = srcDialect.GetDataHelper().GetDataType(column.Type) break } } @@ -260,7 +260,7 @@ func (app *dataSyncAppImpl) doDataSync(ctx context.Context, sql string, task *en total++ result = append(result, row) if total%batchSize == 0 { - if err := app.srcData2TargetDb(result, fieldMap, columns, updFieldType, updFieldName, task, srcMetaData, targetConn, targetDbTx); err != nil { + if err := app.srcData2TargetDb(result, fieldMap, columns, updFieldType, updFieldName, task, srcDialect, targetConn, targetDbTx); err != nil { return err } @@ -283,7 +283,7 @@ func (app *dataSyncAppImpl) doDataSync(ctx context.Context, sql string, task *en // 处理剩余的数据 if len(result) > 0 { - if err := app.srcData2TargetDb(result, fieldMap, queryColumns, updFieldType, updFieldName, task, srcMetaData, targetConn, targetDbTx); err != nil { + if err := app.srcData2TargetDb(result, fieldMap, queryColumns, updFieldType, updFieldName, task, srcDialect, targetConn, targetDbTx); err != nil { targetDbTx.Rollback() return syncLog, err } @@ -307,7 +307,7 @@ func (app *dataSyncAppImpl) doDataSync(ctx context.Context, sql string, task *en return syncLog, nil } -func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap []map[string]string, columns []*dbi.QueryColumn, updFieldType dbi.DataType, updFieldName string, task *entity.DataSyncTask, srcMetaData *dbi.MetaDataX, targetDbConn *dbi.DbConn, targetDbTx *sql.Tx) error { +func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap []map[string]string, columns []*dbi.QueryColumn, updFieldType dbi.DataType, updFieldName string, task *entity.DataSyncTask, srcDialect dbi.Dialect, targetDbConn *dbi.DbConn, targetDbTx *sql.Tx) error { // 遍历src字段列表,取出字段对应的类型 var srcColumnTypes = make(map[string]string) @@ -336,7 +336,7 @@ func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap [ if updFieldVal == "" || updFieldVal == nil { updFieldVal = srcRes[len(srcRes)-1][strings.ToLower(field)] } - task.UpdFieldVal = srcMetaData.GetDataHelper().FormatData(updFieldVal, updFieldType) + task.UpdFieldVal = srcDialect.GetDataHelper().FormatData(updFieldVal, updFieldType) } // 如果指定了更新字段,则以更新字段取值 @@ -351,12 +351,12 @@ func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap [ // 获取源库字段数组 srcColumns := make([]string, 0) srcFieldTypes := make(map[string]dbi.DataType) - targetMetaData := targetDbConn.GetMetaData() + targetDialect := targetDbConn.GetDialect() for _, item := range fieldMap { targetField := item["target"] srcField := item["target"] - srcFieldTypes[srcField] = srcMetaData.GetDataHelper().GetDataType(srcColumnTypes[item["src"]]) - targetWrapColumns = append(targetWrapColumns, targetMetaData.QuoteIdentifier(targetField)) + srcFieldTypes[srcField] = srcDialect.GetDataHelper().GetDataType(srcColumnTypes[item["src"]]) + targetWrapColumns = append(targetWrapColumns, targetDialect.QuoteIdentifier(targetField)) srcColumns = append(srcColumns, srcField) } @@ -366,14 +366,14 @@ func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap [ rawValue := make([]any, 0) for _, column := range srcColumns { // 某些情况,如oracle,需要转换时间类型的字符串为time类型 - res := srcMetaData.GetDataHelper().ParseData(record[column], srcFieldTypes[column]) + res := srcDialect.GetDataHelper().ParseData(record[column], srcFieldTypes[column]) rawValue = append(rawValue, res) } values = append(values, rawValue) } // 目标数据库执行sql批量插入 - _, err := targetDbConn.GetDialect().BatchInsert(targetDbTx, task.TargetTableName, targetWrapColumns, values, task.DuplicateStrategy) + _, err := targetDialect.BatchInsert(targetDbTx, task.TargetTableName, targetWrapColumns, values, task.DuplicateStrategy) if err != nil { return err } diff --git a/server/internal/db/application/db_instance.go b/server/internal/db/application/db_instance.go index d71dc407..3017c2a6 100644 --- a/server/internal/db/application/db_instance.go +++ b/server/internal/db/application/db_instance.go @@ -258,7 +258,7 @@ func (app *instanceAppImpl) getDatabases(instance *entity.DbInstance, ac *tagent } defer dbConn.Close() - return dbConn.GetMetaData().GetDbNames() + return dbConn.GetMetadata().GetDbNames() } func (app *instanceAppImpl) toDbInfoByAc(instance *entity.DbInstance, ac *tagentity.ResourceAuthCert, database string) *dbi.DbInfo { diff --git a/server/internal/db/application/db_sql_exec.go b/server/internal/db/application/db_sql_exec.go index ffa7fb27..f2c02871 100644 --- a/server/internal/db/application/db_sql_exec.go +++ b/server/internal/db/application/db_sql_exec.go @@ -390,7 +390,7 @@ func (d *dbSqlExecAppImpl) doUpdate(ctx context.Context, sqlExecParam *sqlExecPa } // 获取表主键列名,排除使用别名 - primaryKey, err := dbConn.GetMetaData().GetPrimaryKey(tableName) + primaryKey, err := dbConn.GetMetadata().GetPrimaryKey(tableName) if err != nil { logx.ErrorfContext(ctx, "Update SQL - 获取主键列失败: %s", err.Error()) return d.doExec(ctx, dbConn, sqlExecParam.Sql) diff --git a/server/internal/db/application/db_transfer.go b/server/internal/db/application/db_transfer.go index b214701b..a77f24b7 100644 --- a/server/internal/db/application/db_transfer.go +++ b/server/internal/db/application/db_transfer.go @@ -191,7 +191,7 @@ func (app *dbTransferAppImpl) Run(ctx context.Context, taskId uint64, logId uint } if app.IsRunning(taskId) { - logx.Panicf("[%d]该任务正在运行中...", taskId) + logx.Error("[%d]该任务正在运行中...", taskId) return } @@ -218,14 +218,14 @@ func (app *dbTransferAppImpl) Run(ctx context.Context, taskId uint64, logId uint // 获取迁移表信息 var tables []dbi.Table if task.CheckedKeys == "all" { - tables, err = srcConn.GetMetaData().GetTables() + tables, err = srcConn.GetMetadata().GetTables() if err != nil { app.EndTransfer(ctx, logId, taskId, "获取源表信息失败", err, nil) return } } else { tableNames := strings.Split(task.CheckedKeys, ",") - tables, err = srcConn.GetMetaData().GetTables(tableNames...) + tables, err = srcConn.GetMetadata().GetTables(tableNames...) if err != nil { app.EndTransfer(ctx, logId, taskId, "获取源表信息失败", err, nil) return @@ -353,9 +353,10 @@ func (app *dbTransferAppImpl) transferDbTables(ctx context.Context, logId uint64 if len(tableNames) == 0 { return errorx.NewBiz("没有需要迁移的表") } - srcMeta := srcConn.GetMetaData() + srcDialect := srcConn.GetDialect() + srcMetadata := srcConn.GetMetadata() // 查询源表列信息 - columns, err := srcMeta.GetColumns(tableNames...) + columns, err := srcMetadata.GetColumns(tableNames...) if err != nil { return errorx.NewBiz("获取源表列信息失败") } @@ -371,8 +372,8 @@ func (app *dbTransferAppImpl) transferDbTables(ctx context.Context, logId uint64 sort.Strings(sortTableNames) targetDialect := targetConn.GetDialect() - srcColumnHelper := srcMeta.GetColumnHelper() - targetColumnHelper := targetConn.GetMetaData().GetColumnHelper() + srcColumnHelper := srcDialect.GetColumnHelper() + targetColumnHelper := targetConn.GetDialect().GetColumnHelper() // 分组迁移 tableGroups := collx.ArraySplit[string](sortTableNames, 2) @@ -394,9 +395,13 @@ func (app *dbTransferAppImpl) transferDbTables(ctx context.Context, logId uint64 // 通过公共列信息生成目标库的建表语句,并执行目标库建表 app.Log(ctx, logId, fmt.Sprintf("开始创建目标表: 表名:%s", tbName)) - _, err := targetDialect.CreateTable(targetCols, tableMap[tbName], true) - if err != nil { - return errorx.NewBiz(fmt.Sprintf("创建目标表失败: 表名:%s, error: %s", tbName, err.Error())) + + sqlArr := targetDialect.GenerateTableDDL(targetCols, tableMap[tbName], true) + for _, sqlStr := range sqlArr { + _, err := targetConn.Exec(sqlStr) + if err != nil { + return errorx.NewBiz(fmt.Sprintf("创建目标表失败: 表名:%s, error: %s", tbName, err.Error())) + } } app.Log(ctx, logId, fmt.Sprintf("创建目标表成功: 表名:%s", tbName)) @@ -413,7 +418,7 @@ func (app *dbTransferAppImpl) transferDbTables(ctx context.Context, logId uint64 // 迁移索引信息 app.Log(ctx, logId, fmt.Sprintf("开始迁移索引: 表名:%s", tbName)) - err = app.transferIndex(ctx, tableMap[tbName], srcConn, targetDialect) + err = app.transferIndex(ctx, tableMap[tbName], srcConn, targetConn) if err != nil { return errorx.NewBiz(fmt.Sprintf("迁移索引失败: 表名:%s, error: %s", tbName, err.Error())) } @@ -432,8 +437,8 @@ func (app *dbTransferAppImpl) transferData(ctx context.Context, logId uint64, ta total := 0 // 总条数 batchSize := 1000 // 每次查询并迁移1000条数据 var err error - srcMeta := srcConn.GetMetaData() - srcConverter := srcMeta.GetDataHelper() + srcDialect := srcConn.GetDialect() + srcConverter := srcDialect.GetDataHelper() targetDialect := targetConn.GetDialect() logExtraKey := fmt.Sprintf("`%s` 当前已迁移数据量: ", tableName) @@ -485,15 +490,14 @@ func (app *dbTransferAppImpl) transfer2Target(taskId uint64, targetConn *dbi.DbC if err != nil { return err } - targetMeta := targetConn.GetMetaData() // 收集字段名 var columnNames []string for _, col := range targetColumns { - columnNames = append(columnNames, targetMeta.QuoteIdentifier(col.ColumnName)) + columnNames = append(columnNames, targetDialect.QuoteIdentifier(col.ColumnName)) } - dataHelper := targetMeta.GetDataHelper() + dataHelper := targetDialect.GetDataHelper() // 从目标库数据中取出源库字段对应的值 values := make([][]any, 0) @@ -501,7 +505,7 @@ func (app *dbTransferAppImpl) transfer2Target(taskId uint64, targetConn *dbi.DbC rawValue := make([]any, 0) for _, tc := range targetColumns { columnName := tc.ColumnName - val := record[targetMeta.RemoveQuote(columnName)] + val := record[targetDialect.RemoveQuote(columnName)] if !tc.Nullable { // 如果val是文本,则设置为空格字符 switch val.(type) { @@ -537,9 +541,9 @@ func (app *dbTransferAppImpl) transfer2Target(taskId uint64, targetConn *dbi.DbC return err } -func (app *dbTransferAppImpl) transferIndex(_ context.Context, tableInfo dbi.Table, srcConn *dbi.DbConn, targetDialect dbi.Dialect) error { +func (app *dbTransferAppImpl) transferIndex(ctx context.Context, tableInfo dbi.Table, srcConn *dbi.DbConn, targetConn *dbi.DbConn) error { // 查询源表索引信息 - indexs, err := srcConn.GetMetaData().GetTableIndex(tableInfo.TableName) + indexs, err := srcConn.GetMetadata().GetTableIndex(tableInfo.TableName) if err != nil { logx.Error("获取索引信息失败", err) return err @@ -549,7 +553,14 @@ func (app *dbTransferAppImpl) transferIndex(_ context.Context, tableInfo dbi.Tab } // 通过表名、索引信息生成建索引语句,并执行到目标表 - return targetDialect.CreateIndex(tableInfo, indexs) + sqlArr := targetConn.GetDialect().GenerateIndexDDL(indexs, tableInfo) + for _, sqlStr := range sqlArr { + _, err := targetConn.Exec(sqlStr) + if err != nil { + return err + } + } + return nil } func (d *dbTransferAppImpl) TimerDeleteTransferFile() { @@ -600,6 +611,8 @@ func (app *dbTransferAppImpl) Log(ctx context.Context, logId uint64, msg string, } func (app *dbTransferAppImpl) EndTransfer(ctx context.Context, logId uint64, taskId uint64, msg string, err error, extra map[string]any) { + app.MarkStop(taskId) + logType := sysentity.SyslogTypeSuccess transferState := entity.DbTransferTaskRunStateSuccess if err != nil { diff --git a/server/internal/db/dbm/dbi/conn.go b/server/internal/db/dbm/dbi/conn.go index 77617d8e..34a53b06 100644 --- a/server/internal/db/dbm/dbi/conn.go +++ b/server/internal/db/dbm/dbi/conn.go @@ -126,9 +126,9 @@ func (d *DbConn) GetDialect() Dialect { return d.Info.Meta.GetDialect(d) } -// GetMetaData 获取数据库MetaData -func (d *DbConn) GetMetaData() *MetaDataX { - return d.Info.Meta.GetMetaData(d) +// GetMetadata 获取数据库MetaData +func (d *DbConn) GetMetadata() Metadata { + return d.Info.Meta.GetMetadata(d) } // Stats 返回数据库连接状态 diff --git a/server/internal/db/dbm/dbi/db_info.go b/server/internal/db/dbm/dbi/db_info.go index 16db09d6..d2b80f13 100644 --- a/server/internal/db/dbm/dbi/db_info.go +++ b/server/internal/db/dbm/dbi/db_info.go @@ -72,7 +72,7 @@ func (dbInfo *DbInfo) Conn(meta Meta) (*DbConn, error) { database := dbInfo.Database // 如果数据库为空,则使用默认数据库进行连接 if database == "" { - database = meta.GetMetaData(&DbConn{Info: dbInfo}).DefaultDb() + database = meta.GetMetadata(&DbConn{Info: dbInfo}).GetDefaultDb() dbInfo.Database = database } diff --git a/server/internal/db/dbm/dbi/dialect.go b/server/internal/db/dbm/dbi/dialect.go index 78ac566e..1bd4f5ee 100644 --- a/server/internal/db/dbm/dbi/dialect.go +++ b/server/internal/db/dbm/dbi/dialect.go @@ -3,8 +3,13 @@ package dbi import ( "database/sql" "errors" + "fmt" + "io" "mayfly-go/internal/db/dbm/sqlparser" - "mayfly-go/internal/db/dbm/sqlparser/mysql" + "mayfly-go/internal/db/dbm/sqlparser/pgsql" + "strings" + + pq "gitee.com/liuzongyang/libpq" ) const ( @@ -23,41 +28,185 @@ type DbCopyTable struct { CopyData bool `json:"copyData"` // 是否复制数据 } -// -----------------------------------元数据接口定义------------------------------------------ -// 数据库方言 用于获取元信息接口、批量插入等各个数据库方言不一致的实现方式 -type Dialect interface { +// BaseDialect 基础dialect,在DefaultDialect 都有默认的实现方法 +type BaseDialect interface { + + // GetIdentifierQuoteString 用于引用 SQL 标识符(关键字)的字符串 + GetIdentifierQuoteString() string + + // QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be + // used as part of an SQL statement. For example: + // + // tblname := "my_table" + // data := "my_data" + // quoted := quoteIdentifier(tblname, '"') + // err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data) + // + // Any double quotes in name will be escaped. The quoted identifier will be + // case sensitive when used in a query. If the input string contains a zero + // byte, the result will be truncated immediately before it. + QuoteIdentifier(name string) string + + RemoveQuote(name string) string + + // QuoteEscape 引号转义,多用于sql注释转义,防止拼接sql报错,如: comment xx is '注''释' 最终注释文本为: 注'释 + QuoteEscape(str string) string + + // QuoteLiteral quotes a 'literal' (e.g. a parameter, often used to pass literal + // to DDL and other statements that do not accept parameters) to be used as part + // of an SQL statement. For example: + // + // exp_date := pq.QuoteLiteral("2023-01-05 15:00:00Z") + // err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", exp_date)) + // + // Any single quotes in name will be escaped. Any backslashes (i.e. "\") will be + // replaced by two backslashes (i.e. "\\") and the C-style escape identifier + QuoteLiteral(literal string) string // GetDbProgram 获取数据库程序模块,用于数据库备份与恢复 GetDbProgram() (DbProgram, error) + // GetColumnHelper + GetColumnHelper() ColumnHelper + + // GetDumpHeler + GetDumpHelper() DumpHelper + + // GetDataHelper 获取数据处理助手 用于解析格式化列数据等 + GetDataHelper() DataHelper + + // GetSQLParser 获取sql解析器 + GetSQLParser() sqlparser.SqlParser +} + +// -----------------------------------元数据接口定义------------------------------------------ +// Dialect 数据库方言 用于生成sql、批量插入等各个数据库方言不一致的实现方式 +type Dialect interface { + BaseDialect + // BatchInsert 批量insert数据 BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any, duplicateStrategy int) (int64, error) // CopyTable 拷贝表 CopyTable(copy *DbCopyTable) error - // CreateTable 创建表 - CreateTable(columns []Column, tableInfo Table, dropOldTable bool) (int, error) + // GenerateTableDDL 生成建表ddl + GenerateTableDDL(columns []Column, tableInfo Table, dropBeforeCreate bool) []string - // CreateIndex 创建索引 - CreateIndex(tableInfo Table, indexs []Index) error + // GenerateIndexDDL 生成索引ddl + GenerateIndexDDL(indexs []Index, tableInfo Table) []string // UpdateSequence 有些数据库迁移完数据之后,需要更新表自增序列为当前表最大值 UpdateSequence(tableName string, columns []Column) - - GetSQLParser() sqlparser.SqlParser } +// DefaultDialect 默认实现,若需要覆盖,则由各个数据库dialect实现去覆盖重写 type DefaultDialect struct { } +var _ (BaseDialect) = (*DefaultDialect)(nil) + +func (dd *DefaultDialect) GetIdentifierQuoteString() string { + return `"` +} + +func (dx *DefaultDialect) QuoteIdentifier(name string) string { + quoter := dx.GetIdentifierQuoteString() + // 兼容mssql + if quoter == "[" { + return fmt.Sprintf("[%s]", name) + } + + end := strings.IndexRune(name, 0) + if end > -1 { + name = name[:end] + } + return quoter + strings.Replace(name, quoter, quoter+quoter, -1) + quoter +} + +func (dx *DefaultDialect) RemoveQuote(name string) string { + quoter := dx.GetIdentifierQuoteString() + + // 兼容mssql + if quoter == "[" { + return strings.Trim(name, "[]") + } + + return strings.ReplaceAll(name, quoter, "") +} + +func (dd *DefaultDialect) QuoteEscape(str string) string { + return strings.Replace(str, `'`, `''`, -1) +} + +func (dd *DefaultDialect) QuoteLiteral(literal string) string { + return pq.QuoteLiteral(literal) +} + +func (dd *DefaultDialect) UpdateSequence(tableName string, columns []Column) {} + // GetDbProgram 获取数据库程序模块,用于数据库备份与恢复 func (dd *DefaultDialect) GetDbProgram() (DbProgram, error) { return nil, errors.New("not support db program") } -func (dd *DefaultDialect) UpdateSequence(tableName string, columns []Column) {} - -func (dd *DefaultDialect) GetSQLParser() sqlparser.SqlParser { - return new(mysql.MysqlParser) +func (dd *DefaultDialect) GetDumpHelper() DumpHelper { + return new(DefaultDumpHelper) +} + +func (dd *DefaultDialect) GetColumnHelper() ColumnHelper { + return new(DefaultColumnHelper) +} + +func (pd *DefaultDialect) GetSQLParser() sqlparser.SqlParser { + return new(pgsql.PgsqlParser) +} + +func (pd *DefaultDialect) GetDataHelper() DataHelper { + return nil +} + +// ColumnHelper 数据库迁移辅助方法 +type ColumnHelper interface { + // ToCommonColumn 数据库方言自带的列转换为公共列 + ToCommonColumn(dialectColumn *Column) + + // ToColumn 公共列转为各个数据库方言自带的列 + ToColumn(commonColumn *Column) + + // FixColumn 根据数据库类型修复字段长度、精度等 + FixColumn(column *Column) +} + +type DefaultColumnHelper struct { +} + +func (dd *DefaultColumnHelper) ToCommonColumn(dialectColumn *Column) {} + +func (dd *DefaultColumnHelper) ToColumn(commonColumn *Column) {} + +func (dd *DefaultColumnHelper) FixColumn(column *Column) {} + +// DumpHelper 导出辅助方法 +type DumpHelper interface { + BeforeInsert(writer io.Writer, tableName string) + + BeforeInsertSql(quoteSchema string, quoteTableName string) string + + AfterInsert(writer io.Writer, tableName string, columns []Column) +} + +type DefaultDumpHelper struct { +} + +func (dd *DefaultDumpHelper) BeforeInsert(writer io.Writer, tableName string) { + writer.Write([]byte("BEGIN;\n")) +} + +func (dd *DefaultDumpHelper) BeforeInsertSql(quoteSchema string, quoteTableName string) string { + return "" +} + +func (dd *DefaultDumpHelper) AfterInsert(writer io.Writer, tableName string, columns []Column) { + writer.Write([]byte("COMMIT;\n")) } diff --git a/server/internal/db/dbm/dbi/meta.go b/server/internal/db/dbm/dbi/meta.go index 9484300d..c2cc3d39 100644 --- a/server/internal/db/dbm/dbi/meta.go +++ b/server/internal/db/dbm/dbi/meta.go @@ -22,13 +22,13 @@ type DbVersion string // 数据库元信息,如获取sql.DB、Dialect等 type Meta interface { - // 根据数据库信息获取sql.DB + // GetSqlDb 根据数据库信息获取sql.DB GetSqlDb(*DbInfo) (*sql.DB, error) - // 获取数据库方言 + // GetDialect 获取数据库方言, 若一些接口(如 GetIdentifierQuoteString)不需要DbConn,则可以传nil GetDialect(*DbConn) Dialect - // 获取元数据信息接口 - // @param *DbConn 数据库连接, 若一些元数据接口(如 GetIdentifierQuoteString)不需要DbConn,则可以传nil - GetMetaData(*DbConn) *MetaDataX + // GetMetadata 获取元数据信息接口 + // @param *DbConn 数据库连接 + GetMetadata(*DbConn) Metadata } diff --git a/server/internal/db/dbm/dbi/metadata.go b/server/internal/db/dbm/dbi/metadata.go index d9e5ba99..1599276f 100644 --- a/server/internal/db/dbm/dbi/metadata.go +++ b/server/internal/db/dbm/dbi/metadata.go @@ -9,9 +9,8 @@ import ( "strings" ) -// 元数据接口(表、列、等元信息) -type MetaData interface { - BaseMetaData +// Metadata 元数据接口(表、列、等元信息) +type Metadata interface { // GetDbServer 获取数据库服务实例信息 GetDbServer() (*DbServer, error) @@ -19,6 +18,12 @@ type MetaData interface { // GetCompatibleDbVersion 获取兼容版本信息,如果有兼容版本,则需要实现对应版本的特殊方言处理器,以及前端的方言兼容版本 GetCompatibleDbVersion() DbVersion + // GetDefaultDb 获取默认库 + GetDefaultDb() string + + // GetSchemas + GetSchemas() ([]string, error) + // GetDbNames 获取数据库名称列表 GetDbNames() ([]string, error) @@ -36,17 +41,18 @@ type MetaData interface { // GetTableDDL 获取建表ddl GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) +} - // GenerateTableDDL 生成建表ddl - GenerateTableDDL(columns []Column, tableInfo Table, dropBeforeCreate bool) []string +// 默认实现,若需要覆盖,则由各个数据库MetaData实现去覆盖重写 +type DefaultMetadata struct { +} - // GenerateIndexDDL 生成索引ddl - GenerateIndexDDL(indexs []Index, tableInfo Table) []string +func (dd *DefaultMetadata) GetCompatibleDbVersion() DbVersion { + return "" +} - GetSchemas() ([]string, error) - - // GetDataHelper 获取数据处理助手 用于解析格式化列数据等 - GetDataHelper() DataHelper +func (dd *DefaultMetadata) GetDefaultDb() string { + return "" } // GenerateSQLStepFunc 生成insert sql的step函数,用于生成insert sql时,每生成100条sql时调用 diff --git a/server/internal/db/dbm/dbi/metadata_base.go b/server/internal/db/dbm/dbi/metadata_base.go deleted file mode 100644 index 1ce2b6ab..00000000 --- a/server/internal/db/dbm/dbi/metadata_base.go +++ /dev/null @@ -1,126 +0,0 @@ -package dbi - -import ( - "io" - "strings" - - pq "gitee.com/liuzongyang/libpq" - // "github.com/kanzihuang/vitess/go/vt/sqlparser" -) - -type BaseMetaData interface { - - // 默认库 - DefaultDb() string - - DbVersion() string - - // 用于引用 SQL 标识符(关键字)的字符串 - GetIdentifierQuoteString() string - - // 引号转义,多用于sql注释转义,防止拼接sql报错,如: comment xx is '注''释' 最终注释文本为: 注'释 - QuoteEscape(str string) string - - // QuoteLiteral quotes a 'literal' (e.g. a parameter, often used to pass literal - // to DDL and other statements that do not accept parameters) to be used as part - // of an SQL statement. For example: - // - // exp_date := pq.QuoteLiteral("2023-01-05 15:00:00Z") - // err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", exp_date)) - // - // Any single quotes in name will be escaped. Any backslashes (i.e. "\") will be - // replaced by two backslashes (i.e. "\\") and the C-style escape identifier - QuoteLiteral(literal string) string - - // GetSqlParserDialect() sqlparser.Dialect - - // GetColumnHelper - GetColumnHelper() ColumnHelper - - // GetDumpHeler - GetDumpHelper() DumpHelper -} - -// 默认实现,若需要覆盖,则由各个数据库MetaData实现去覆盖重写 -type DefaultMetaData struct { -} - -func (dd *DefaultMetaData) GetCompatibleDbVersion() DbVersion { - return "" -} -func (dd *DefaultMetaData) DefaultDb() string { - return "" -} - -func (dd *DefaultMetaData) DbVersion() string { - return "" -} - -func (dd *DefaultMetaData) GetIdentifierQuoteString() string { - return `"` -} - -func (dd *DefaultMetaData) QuoteEscape(str string) string { - return strings.Replace(str, `'`, `''`, -1) -} - -func (dd *DefaultMetaData) QuoteLiteral(literal string) string { - return pq.QuoteLiteral(literal) -} - -// func (dd *DefaultMetaData) GetSqlParserDialect() sqlparser.Dialect { -// return sqlparser.PostgresDialect{} -// } - -func (dd *DefaultMetaData) GetDumpHelper() DumpHelper { - return new(DefaultDumpHelper) -} - -func (dd *DefaultMetaData) GetColumnHelper() ColumnHelper { - return new(DefaultColumnHelper) -} - -// ColumnHelper 数据库迁移辅助方法 -type ColumnHelper interface { - // ToCommonColumn 数据库方言自带的列转换为公共列 - ToCommonColumn(dialectColumn *Column) - - // ToColumn 公共列转为各个数据库方言自带的列 - ToColumn(commonColumn *Column) - - // FixColumn 根据数据库类型修复字段长度、精度等 - FixColumn(column *Column) -} - -type DefaultColumnHelper struct { -} - -func (dd *DefaultColumnHelper) ToCommonColumn(dialectColumn *Column) {} - -func (dd *DefaultColumnHelper) ToColumn(commonColumn *Column) {} - -func (dd *DefaultColumnHelper) FixColumn(column *Column) {} - -// DumpHelper 导出辅助方法 -type DumpHelper interface { - BeforeInsert(writer io.Writer, tableName string) - - BeforeInsertSql(quoteSchema string, quoteTableName string) string - - AfterInsert(writer io.Writer, tableName string, columns []Column) -} - -type DefaultDumpHelper struct { -} - -func (dd *DefaultDumpHelper) BeforeInsert(writer io.Writer, tableName string) { - writer.Write([]byte("BEGIN;\n")) -} - -func (dd *DefaultDumpHelper) BeforeInsertSql(quoteSchema string, quoteTableName string) string { - return "" -} - -func (dd *DefaultDumpHelper) AfterInsert(writer io.Writer, tableName string, columns []Column) { - writer.Write([]byte("COMMIT;\n")) -} diff --git a/server/internal/db/dbm/dbi/metadatax.go b/server/internal/db/dbm/dbi/metadatax.go deleted file mode 100644 index d43c4b09..00000000 --- a/server/internal/db/dbm/dbi/metadatax.go +++ /dev/null @@ -1,59 +0,0 @@ -package dbi - -import ( - "fmt" - "strings" -) - -// 包装扩展MetaData,提供所有实现MetaData结构体的公共方法 -type MetaDataX struct { - MetaData -} - -func NewMetaDataX(metaData MetaData) *MetaDataX { - return &MetaDataX{metaData} -} - -func (md *MetaDataX) QuoteIdentifier(name string) string { - return QuoteIdentifier(md, name) -} - -func (md *MetaDataX) RemoveQuote(name string) string { - return RemoveQuote(md, name) -} - -// QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be -// used as part of an SQL statement. For example: -// -// tblname := "my_table" -// data := "my_data" -// quoted := quoteIdentifier(tblname, '"') -// err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data) -// -// Any double quotes in name will be escaped. The quoted identifier will be -// case sensitive when used in a query. If the input string contains a zero -// byte, the result will be truncated immediately before it. -func QuoteIdentifier(metadata MetaData, name string) string { - quoter := metadata.GetIdentifierQuoteString() - // 兼容mssql - if quoter == "[" { - return fmt.Sprintf("[%s]", name) - } - - end := strings.IndexRune(name, 0) - if end > -1 { - name = name[:end] - } - return quoter + strings.Replace(name, quoter, quoter+quoter, -1) + quoter -} - -func RemoveQuote(metadata MetaData, name string) string { - quoter := metadata.GetIdentifierQuoteString() - - // 兼容mssql - if quoter == "[" { - return strings.Trim(name, "[]") - } - - return strings.ReplaceAll(name, quoter, "") -} diff --git a/server/internal/db/dbm/dm/dialect.go b/server/internal/db/dbm/dm/dialect.go index f6fb96bc..967e5561 100644 --- a/server/internal/db/dbm/dm/dialect.go +++ b/server/internal/db/dbm/dm/dialect.go @@ -40,7 +40,7 @@ func (dd *DMDialect) batchInsertSimple(tx *sql.Tx, tableName string, columns []s identityInsert := fmt.Sprintf("set identity_insert \"%s\" on;", tableName) - sqlTemp := fmt.Sprintf("%s insert into %s (%s) values %s", identityInsert, dd.dc.GetMetaData().QuoteIdentifier(tableName), strings.Join(columns, ","), placeholder) + sqlTemp := fmt.Sprintf("%s insert into %s (%s) values %s", identityInsert, dd.QuoteIdentifier(tableName), strings.Join(columns, ","), placeholder) effRows := 0 // 设置允许填充自增列之后,显示指定列名可以插入自增列 for _, value := range values { @@ -60,17 +60,17 @@ func (dd *DMDialect) batchInsertMerge(tx *sql.Tx, tableName string, columns []st // 查询主键字段 uniqueCols := make([]string, 0) caseSqls := make([]string, 0) - metadata := dd.dc.GetMetaData() + metadata := dd.dc.GetMetadata() tableCols, _ := metadata.GetColumns(tableName) identityCols := make([]string, 0) for _, col := range tableCols { if col.IsPrimaryKey { uniqueCols = append(uniqueCols, col.ColumnName) - caseSqls = append(caseSqls, fmt.Sprintf("( T1.%s = T2.%s )", metadata.QuoteIdentifier(col.ColumnName), metadata.QuoteIdentifier(col.ColumnName))) + caseSqls = append(caseSqls, fmt.Sprintf("( T1.%s = T2.%s )", dd.QuoteIdentifier(col.ColumnName), dd.QuoteIdentifier(col.ColumnName))) } if col.IsIdentity { // 自增字段不放入insert内,即使是设置了identity_insert on也不起作用 - identityCols = append(identityCols, metadata.QuoteIdentifier(col.ColumnName)) + identityCols = append(identityCols, dd.QuoteIdentifier(col.ColumnName)) } } // 查询唯一索引涉及到的字段,并组装到match条件内 @@ -81,7 +81,7 @@ func (dd *DMDialect) batchInsertMerge(tx *sql.Tx, tableName string, columns []st tmp := make([]string, 0) for _, col := range cols { uniqueCols = append(uniqueCols, col) - tmp = append(tmp, fmt.Sprintf(" T1.%s = T2.%s ", metadata.QuoteIdentifier(col), metadata.QuoteIdentifier(col))) + tmp = append(tmp, fmt.Sprintf(" T1.%s = T2.%s ", dd.QuoteIdentifier(col), dd.QuoteIdentifier(col))) } caseSqls = append(caseSqls, fmt.Sprintf("( %s )", strings.Join(tmp, " AND "))) } @@ -94,7 +94,7 @@ func (dd *DMDialect) batchInsertMerge(tx *sql.Tx, tableName string, columns []st insertCols := make([]string, 0) for _, column := range columns { phs = append(phs, fmt.Sprintf("? %s", column)) - if !collx.ArrayContains(uniqueCols, metadata.RemoveQuote(column)) { + if !collx.ArrayContains(uniqueCols, dd.RemoveQuote(column)) { upds = append(upds, fmt.Sprintf("T1.%s = T2.%s", column, column)) } if !collx.ArrayContains(identityCols, column) { @@ -109,7 +109,7 @@ func (dd *DMDialect) batchInsertMerge(tx *sql.Tx, tableName string, columns []st } t2 := strings.Join(t2s, " UNION ALL ") - sqlTemp := "MERGE INTO " + metadata.QuoteIdentifier(tableName) + " T1 USING (" + t2 + ") T2 ON " + strings.Join(caseSqls, " OR ") + sqlTemp := "MERGE INTO " + dd.QuoteIdentifier(tableName) + " T1 USING (" + t2 + ") T2 ON " + strings.Join(caseSqls, " OR ") sqlTemp += "WHEN NOT MATCHED THEN INSERT (" + strings.Join(insertCols, ",") + ") VALUES (" + strings.Join(insertVals, ",") + ")" sqlTemp += "WHEN MATCHED THEN UPDATE SET " + strings.Join(upds, ",") @@ -124,7 +124,7 @@ func (dd *DMDialect) batchInsertMerge(tx *sql.Tx, tableName string, columns []st func (dd *DMDialect) CopyTable(copy *dbi.DbCopyTable) error { tableName := copy.TableName - metadata := dd.dc.GetMetaData() + metadata := dd.dc.GetMetadata() ddl, err := metadata.GetTableDDL(tableName, false) if err != nil { return err @@ -162,32 +162,119 @@ func (dd *DMDialect) CopyTable(copy *dbi.DbCopyTable) error { return err } -func (dd *DMDialect) CreateTable(columns []dbi.Column, tableInfo dbi.Table, dropOldTable bool) (int, error) { +func (dd *DMDialect) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table, dropBeforeCreate bool) []string { + tbName := dd.QuoteIdentifier(tableInfo.TableName) + sqlArr := make([]string, 0) - sqlArr := dd.dc.GetMetaData().GenerateTableDDL(columns, tableInfo, dropOldTable) - // 达梦需要分开执行sql - if len(sqlArr) > 0 { - for _, sqlStr := range sqlArr { - _, err := dd.dc.Exec(sqlStr) - if err != nil { - return 0, err + if dropBeforeCreate { + sqlArr = append(sqlArr, fmt.Sprintf("drop table if exists %s", tbName)) + } + // 组装建表语句 + createSql := fmt.Sprintf("create table %s (", tbName) + fields := make([]string, 0) + pks := make([]string, 0) + columnComments := make([]string, 0) + + for _, column := range columns { + if column.IsPrimaryKey { + pks = append(pks, dd.QuoteIdentifier(column.ColumnName)) + } + fields = append(fields, dd.genColumnBasicSql(column)) + if column.ColumnComment != "" { + comment := dd.QuoteEscape(column.ColumnComment) + columnComments = append(columnComments, fmt.Sprintf("comment on column %s.%s is '%s'", tbName, dd.QuoteIdentifier(column.ColumnName), comment)) + } + } + createSql += strings.Join(fields, ",\n") + if len(pks) > 0 { + createSql += fmt.Sprintf(",\n PRIMARY KEY (%s)", strings.Join(pks, ",")) + } + createSql += "\n)" + + tableCommentSql := "" + if tableInfo.TableComment != "" { + comment := dd.QuoteEscape(tableInfo.TableComment) + tableCommentSql = fmt.Sprintf("comment on table %s is '%s'", tbName, comment) + } + + sqlArr = append(sqlArr, createSql) + if tableCommentSql != "" { + sqlArr = append(sqlArr, tableCommentSql) + } + + if len(columnComments) > 0 { + sqlArr = append(sqlArr, columnComments...) + } + + return sqlArr +} + +func (dd *DMDialect) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Table) []string { + sqls := make([]string, 0) + for _, index := range indexs { + unique := "" + if index.IsUnique { + unique = "unique" + } + + // 取出列名,添加引号 + cols := strings.Split(index.ColumnName, ",") + colNames := make([]string, len(cols)) + for i, name := range cols { + colNames[i] = dd.QuoteIdentifier(name) + } + + sqls = append(sqls, fmt.Sprintf("create %s index %s on %s(%s)", unique, dd.QuoteIdentifier(index.IndexName), dd.QuoteIdentifier(tableInfo.TableName), strings.Join(colNames, ","))) + } + return sqls +} + +func (dd *DMDialect) GetDataHelper() dbi.DataHelper { + return new(DataHelper) +} + +func (dd *DMDialect) GetColumnHelper() dbi.ColumnHelper { + return new(ColumnHelper) +} + +func (dd *DMDialect) GetDumpHelper() dbi.DumpHelper { + return new(DumpHelper) +} + +func (dd *DMDialect) genColumnBasicSql(column dbi.Column) string { + colName := dd.QuoteIdentifier(column.ColumnName) + dataType := string(column.DataType) + + incr := "" + if column.IsIdentity { + incr = " IDENTITY" + } + + nullAble := "" + if !column.Nullable { + nullAble = " NOT NULL" + } + + defVal := "" // 默认值需要判断引号,如函数是不需要引号的 // 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值 + if column.ColumnDefault != "" && !strings.Contains(column.ColumnDefault, "(") { + // 哪些字段类型默认值需要加引号 + mark := false + if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, strings.ToLower(dataType)) { + // 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号 + if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) && + collx.ArrayAnyMatches([]string{"DATE", "TIME"}, strings.ToUpper(column.ColumnDefault)) { + mark = false + } else { + mark = true } } + if mark { + defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault) + } else { + defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault) + } } - return len(sqlArr), nil -} - -func (dd *DMDialect) CreateIndex(tableInfo dbi.Table, indexs []dbi.Index) error { - sqlArr := dd.dc.GetMetaData().GenerateIndexDDL(indexs, tableInfo) - // 达梦需要分开执行sql - if len(sqlArr) > 0 { - for _, sqlStr := range sqlArr { - _, err := dd.dc.Exec(sqlStr) - if err != nil { - return err - } - } - } - return nil + columnSql := fmt.Sprintf(" %s %s%s%s%s", colName, column.GetColumnType(), incr, nullAble, defVal) + return columnSql } diff --git a/server/internal/db/dbm/dm/meta.go b/server/internal/db/dbm/dm/meta.go index 002a5f76..19cf685b 100644 --- a/server/internal/db/dbm/dm/meta.go +++ b/server/internal/db/dbm/dm/meta.go @@ -44,8 +44,8 @@ func (dm *Meta) GetDialect(conn *dbi.DbConn) dbi.Dialect { return &DMDialect{dc: conn} } -func (dm *Meta) GetMetaData(conn *dbi.DbConn) *dbi.MetaDataX { - return dbi.NewMetaDataX(&DMMetaData{ +func (dm *Meta) GetMetadata(conn *dbi.DbConn) dbi.Metadata { + return &DMMetadata{ dc: conn, - }) + } } diff --git a/server/internal/db/dbm/dm/metadata.go b/server/internal/db/dbm/dm/metadata.go index 18e2e63b..c7c53fe2 100644 --- a/server/internal/db/dbm/dm/metadata.go +++ b/server/internal/db/dbm/dm/metadata.go @@ -21,13 +21,13 @@ const ( DM_COLUMN_MA_KEY = "DM_COLUMN_MA" ) -type DMMetaData struct { - dbi.DefaultMetaData +type DMMetadata struct { + dbi.DefaultMetadata dc *dbi.DbConn } -func (dd *DMMetaData) GetDbServer() (*dbi.DbServer, error) { +func (dd *DMMetadata) GetDbServer() (*dbi.DbServer, error) { _, res, err := dd.dc.Query("select * from v$instance") if err != nil { return nil, err @@ -38,7 +38,7 @@ func (dd *DMMetaData) GetDbServer() (*dbi.DbServer, error) { return ds, nil } -func (dd *DMMetaData) GetDbNames() ([]string, error) { +func (dd *DMMetadata) GetDbNames() ([]string, error) { _, res, err := dd.dc.Query("SELECT name AS DBNAME FROM v$database") if err != nil { return nil, err @@ -52,9 +52,10 @@ func (dd *DMMetaData) GetDbNames() ([]string, error) { return databases, nil } -func (dd *DMMetaData) GetTables(tableNames ...string) ([]dbi.Table, error) { +func (dd *DMMetadata) GetTables(tableNames ...string) ([]dbi.Table, error) { + dialect := dd.dc.GetDialect() names := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string { - return fmt.Sprintf("'%s'", dbi.RemoveQuote(dd, val)) + return fmt.Sprintf("'%s'", dialect.RemoveQuote(val)) }), ",") var res []map[string]any @@ -85,9 +86,10 @@ func (dd *DMMetaData) GetTables(tableNames ...string) ([]dbi.Table, error) { } // 获取列元信息, 如列名等 -func (dd *DMMetaData) GetColumns(tableNames ...string) ([]dbi.Column, error) { +func (dd *DMMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error) { + dialect := dd.dc.GetDialect() tableName := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string { - return fmt.Sprintf("'%s'", dbi.RemoveQuote(dd, val)) + return fmt.Sprintf("'%s'", dialect.RemoveQuote(val)) }), ",") _, res, err := dd.dc.Query(fmt.Sprintf(dbi.GetLocalSql(DM_META_FILE, DM_COLUMN_MA_KEY), tableName)) @@ -95,7 +97,7 @@ func (dd *DMMetaData) GetColumns(tableNames ...string) ([]dbi.Column, error) { return nil, err } - columnHelper := dd.dc.GetMetaData().GetColumnHelper() + columnHelper := dd.dc.GetDialect().GetColumnHelper() columns := make([]dbi.Column, 0) for _, re := range res { column := dbi.Column{ @@ -117,7 +119,7 @@ func (dd *DMMetaData) GetColumns(tableNames ...string) ([]dbi.Column, error) { return columns, nil } -func (dd *DMMetaData) GetPrimaryKey(tablename string) (string, error) { +func (dd *DMMetadata) GetPrimaryKey(tablename string) (string, error) { columns, err := dd.GetColumns(tablename) if err != nil { return "", err @@ -135,7 +137,7 @@ func (dd *DMMetaData) GetPrimaryKey(tablename string) (string, error) { } // 获取表索引信息 -func (dd *DMMetaData) GetTableIndex(tableName string) ([]dbi.Index, error) { +func (dd *DMMetadata) GetTableIndex(tableName string) ([]dbi.Index, error) { _, res, err := dd.dc.Query(fmt.Sprintf(dbi.GetLocalSql(DM_META_FILE, DM_INDEX_INFO_KEY), tableName)) if err != nil { return nil, err @@ -172,117 +174,8 @@ func (dd *DMMetaData) GetTableIndex(tableName string) ([]dbi.Index, error) { return result, nil } -func (dd *DMMetaData) genColumnBasicSql(column dbi.Column) string { - meta := dd.dc.GetMetaData() - colName := meta.QuoteIdentifier(column.ColumnName) - dataType := string(column.DataType) - - incr := "" - if column.IsIdentity { - incr = " IDENTITY" - } - - nullAble := "" - if !column.Nullable { - nullAble = " NOT NULL" - } - - defVal := "" // 默认值需要判断引号,如函数是不需要引号的 // 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值 - if column.ColumnDefault != "" && !strings.Contains(column.ColumnDefault, "(") { - // 哪些字段类型默认值需要加引号 - mark := false - if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, strings.ToLower(dataType)) { - // 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号 - if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) && - collx.ArrayAnyMatches([]string{"DATE", "TIME"}, strings.ToUpper(column.ColumnDefault)) { - mark = false - } else { - mark = true - } - } - if mark { - defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault) - } else { - defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault) - } - } - - columnSql := fmt.Sprintf(" %s %s%s%s%s", colName, column.GetColumnType(), incr, nullAble, defVal) - return columnSql - -} - -func (dd *DMMetaData) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Table) []string { - meta := dd.dc.GetMetaData() - sqls := make([]string, 0) - for _, index := range indexs { - unique := "" - if index.IsUnique { - unique = "unique" - } - - // 取出列名,添加引号 - cols := strings.Split(index.ColumnName, ",") - colNames := make([]string, len(cols)) - for i, name := range cols { - colNames[i] = meta.QuoteIdentifier(name) - } - - sqls = append(sqls, fmt.Sprintf("create %s index %s on %s(%s)", unique, meta.QuoteIdentifier(index.IndexName), meta.QuoteIdentifier(tableInfo.TableName), strings.Join(colNames, ","))) - } - return sqls -} - -func (dd *DMMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table, dropBeforeCreate bool) []string { - meta := dd.dc.GetMetaData() - tbName := meta.QuoteIdentifier(tableInfo.TableName) - sqlArr := make([]string, 0) - - if dropBeforeCreate { - sqlArr = append(sqlArr, fmt.Sprintf("drop table if exists %s", tbName)) - } - // 组装建表语句 - createSql := fmt.Sprintf("create table %s (", tbName) - fields := make([]string, 0) - pks := make([]string, 0) - columnComments := make([]string, 0) - - for _, column := range columns { - if column.IsPrimaryKey { - pks = append(pks, meta.QuoteIdentifier(column.ColumnName)) - } - fields = append(fields, dd.genColumnBasicSql(column)) - if column.ColumnComment != "" { - comment := meta.QuoteEscape(column.ColumnComment) - columnComments = append(columnComments, fmt.Sprintf("comment on column %s.%s is '%s'", tbName, meta.QuoteIdentifier(column.ColumnName), comment)) - } - } - createSql += strings.Join(fields, ",\n") - if len(pks) > 0 { - createSql += fmt.Sprintf(",\n PRIMARY KEY (%s)", strings.Join(pks, ",")) - } - createSql += "\n)" - - tableCommentSql := "" - if tableInfo.TableComment != "" { - comment := meta.QuoteEscape(tableInfo.TableComment) - tableCommentSql = fmt.Sprintf("comment on table %s is '%s'", tbName, comment) - } - - sqlArr = append(sqlArr, createSql) - if tableCommentSql != "" { - sqlArr = append(sqlArr, tableCommentSql) - } - - if len(columnComments) > 0 { - sqlArr = append(sqlArr, columnComments...) - } - - return sqlArr -} - // 获取建表ddl -func (dd *DMMetaData) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) { +func (dd *DMMetadata) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) { // 1.获取表信息 tbs, err := dd.GetTables(tableName) @@ -300,7 +193,8 @@ func (dd *DMMetaData) GetTableDDL(tableName string, dropBeforeCreate bool) (stri logx.Errorf("获取列信息失败, %s", tableName) return "", err } - tableDDLArr := dd.GenerateTableDDL(columns, *tableInfo, dropBeforeCreate) + dialect := dd.dc.GetDialect() + tableDDLArr := dialect.GenerateTableDDL(columns, *tableInfo, dropBeforeCreate) // 3.获取索引信息 indexs, err := dd.GetTableIndex(tableName) if err != nil { @@ -308,12 +202,12 @@ func (dd *DMMetaData) GetTableDDL(tableName string, dropBeforeCreate bool) (stri return "", err } // 组装返回 - tableDDLArr = append(tableDDLArr, dd.GenerateIndexDDL(indexs, *tableInfo)...) + tableDDLArr = append(tableDDLArr, dialect.GenerateIndexDDL(indexs, *tableInfo)...) return strings.Join(tableDDLArr, ";\n"), nil } // 获取DM当前连接的库可访问的schemaNames -func (dd *DMMetaData) GetSchemas() ([]string, error) { +func (dd *DMMetadata) GetSchemas() ([]string, error) { sql := dbi.GetLocalSql(DM_META_FILE, DM_DB_SCHEMAS) _, res, err := dd.dc.Query(sql) if err != nil { @@ -325,15 +219,3 @@ func (dd *DMMetaData) GetSchemas() ([]string, error) { } return schemaNames, nil } - -func (dd *DMMetaData) GetDataHelper() dbi.DataHelper { - return new(DataHelper) -} - -func (dd *DMMetaData) GetColumnHelper() dbi.ColumnHelper { - return new(ColumnHelper) -} - -func (dd *DMMetaData) GetDumpHelper() dbi.DumpHelper { - return new(DumpHelper) -} diff --git a/server/internal/db/dbm/mssql/dialect.go b/server/internal/db/dbm/mssql/dialect.go index c3b60f04..4c54eee8 100644 --- a/server/internal/db/dbm/mssql/dialect.go +++ b/server/internal/db/dbm/mssql/dialect.go @@ -63,12 +63,12 @@ func (md *MssqlDialect) batchInsertSimple(tx *sql.Tx, tableName string, columns return count, nil } - msMetadata := md.dc.GetMetaData() + msMetadata := md.dc.GetMetadata() schema := md.dc.Info.CurrentSchema() ignoreDupSql := "" if duplicateStrategy == dbi.DuplicateStrategyIgnore { // ALTER TABLE dbo.TEST ADD CONSTRAINT uniqueRows UNIQUE (ColA, ColB, ColC, ColD) WITH (IGNORE_DUP_KEY = ON) - indexs, _ := msMetadata.MetaData.(*MssqlMetaData).getTableIndexWithPK(tableName) + indexs, _ := msMetadata.(*MssqlMetadata).getTableIndexWithPK(tableName) // 收集唯一索引涉及到的字段 uniqueColumns := make([]string, 0) for _, index := range indexs { @@ -99,7 +99,7 @@ func (md *MssqlDialect) batchInsertSimple(tx *sql.Tx, tableName string, columns // 去除最后一个逗号 placeholder = strings.TrimSuffix(repeated, ",") - baseTable := fmt.Sprintf("%s.%s", msMetadata.QuoteIdentifier(schema), msMetadata.QuoteIdentifier(tableName)) + baseTable := fmt.Sprintf("%s.%s", md.QuoteIdentifier(schema), md.QuoteIdentifier(tableName)) sqlStr := fmt.Sprintf("insert into %s (%s) values %s", baseTable, strings.Join(columns, ","), placeholder) // 执行批量insert sql @@ -117,7 +117,7 @@ func (md *MssqlDialect) batchInsertSimple(tx *sql.Tx, tableName string, columns } func (md *MssqlDialect) batchInsertMerge(tx *sql.Tx, tableName string, columns []string, values [][]any, duplicateStrategy int) (int64, error) { - msMetadata := md.dc.GetMetaData() + msMetadata := md.dc.GetMetadata() schema := md.dc.Info.CurrentSchema() // 收集MERGE 语句的 ON 子句条件 @@ -136,7 +136,7 @@ func (md *MssqlDialect) batchInsertMerge(tx *sql.Tx, tableName string, columns [ } if col.IsPrimaryKey { pkCols = append(pkCols, col.ColumnName) - name := msMetadata.QuoteIdentifier(col.ColumnName) + name := md.QuoteIdentifier(col.ColumnName) caseSqls = append(caseSqls, fmt.Sprintf(" T1.%s = T2.%s ", name, name)) } } @@ -150,7 +150,7 @@ func (md *MssqlDialect) batchInsertMerge(tx *sql.Tx, tableName string, columns [ // 源数据占位sql phs := make([]string, 0) for _, column := range columns { - if !collx.ArrayContains(identityCols, msMetadata.RemoveQuote(column)) { + if !collx.ArrayContains(identityCols, md.RemoveQuote(column)) { upds = append(upds, fmt.Sprintf("T1.%s = T2.%s", column, column)) } insertCols = append(insertCols, column) @@ -168,7 +168,7 @@ func (md *MssqlDialect) batchInsertMerge(tx *sql.Tx, tableName string, columns [ } t2 := strings.Join(t2s, " UNION ALL ") - sqlTemp := "MERGE INTO " + msMetadata.QuoteIdentifier(schema) + "." + msMetadata.QuoteIdentifier(tableName) + " T1 USING (" + t2 + ") T2 ON " + strings.Join(caseSqls, " AND ") + sqlTemp := "MERGE INTO " + md.QuoteIdentifier(schema) + "." + md.QuoteIdentifier(tableName) + " T1 USING (" + t2 + ") T2 ON " + strings.Join(caseSqls, " AND ") sqlTemp += "WHEN NOT MATCHED THEN INSERT (" + strings.Join(insertCols, ",") + ") VALUES (" + strings.Join(insertVals, ",") + ") " sqlTemp += "WHEN MATCHED THEN UPDATE SET " + strings.Join(upds, ",") @@ -185,14 +185,14 @@ func (md *MssqlDialect) batchInsertMerge(tx *sql.Tx, tableName string, columns [ } func (md *MssqlDialect) CopyTable(copy *dbi.DbCopyTable) error { - msMetadata := md.dc.GetMetaData().MetaData.(*MssqlMetaData) + msMetadata := md.dc.GetMetadata() schema := md.dc.Info.CurrentSchema() // 生成新表名,为老表明+_copy_时间戳 newTableName := copy.TableName + "_copy_" + time.Now().Format("20060102150405") // 复制建表语句 - ddl, err := msMetadata.CopyTableDDL(copy.TableName, newTableName) + ddl, err := md.CopyTableDDL(copy.TableName, newTableName) if err != nil { return err } @@ -239,14 +239,180 @@ func (md *MssqlDialect) CopyTable(copy *dbi.DbCopyTable) error { return err } -func (md *MssqlDialect) CreateTable(columns []dbi.Column, tableInfo dbi.Table, dropOldTable bool) (int, error) { - sqlArr := md.dc.GetMetaData().GenerateTableDDL(columns, tableInfo, dropOldTable) - _, err := md.dc.Exec(strings.Join(sqlArr, ";")) - return len(sqlArr), err +func (md *MssqlDialect) CopyTableDDL(tableName string, newTableName string) (string, error) { + if newTableName == "" { + newTableName = tableName + } + metadata := md.dc.GetMetadata() + // 查询表名和表注释, 设置表注释 + tbs, err := metadata.GetTables(tableName) + if err != nil || len(tbs) < 1 { + logx.Errorf("获取表信息失败, %s", tableName) + return "", err + } + tabInfo := &dbi.Table{ + TableName: newTableName, + TableComment: tbs[0].TableComment, + } + + // 查询列信息 + columns, err := metadata.GetColumns(tableName) + if err != nil { + logx.Errorf("获取列信息失败, %s", tableName) + return "", err + } + sqlArr := md.GenerateTableDDL(columns, *tabInfo, true) + + // 设置索引 + indexs, err := metadata.GetTableIndex(tableName) + if err != nil { + logx.Errorf("获取索引信息失败, %s", tableName) + return strings.Join(sqlArr, ";"), err + } + sqlArr = append(sqlArr, md.GenerateIndexDDL(indexs, *tabInfo)...) + return strings.Join(sqlArr, ";"), nil } -func (md *MssqlDialect) CreateIndex(tableInfo dbi.Table, indexs []dbi.Index) error { - sqlArr := md.dc.GetMetaData().GenerateIndexDDL(indexs, tableInfo) - _, err := md.dc.Exec(strings.Join(sqlArr, ";")) - return err +// 获取建表ddl +func (md *MssqlDialect) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table, dropBeforeCreate bool) []string { + tbName := tableInfo.TableName + schemaName := md.dc.Info.CurrentSchema() + + sqlArr := make([]string, 0) + + // 删除表 + if dropBeforeCreate { + sqlArr = append(sqlArr, fmt.Sprintf("DROP TABLE IF EXISTS %s.%s", md.QuoteIdentifier(schemaName), md.QuoteIdentifier(tbName))) + } + + // 组装建表语句 + createSql := fmt.Sprintf("CREATE TABLE %s.%s (\n", md.QuoteIdentifier(schemaName), md.QuoteIdentifier(tbName)) + fields := make([]string, 0) + pks := make([]string, 0) + columnComments := make([]string, 0) + + for _, column := range columns { + if column.IsPrimaryKey { + pks = append(pks, md.QuoteIdentifier(column.ColumnName)) + } + fields = append(fields, md.genColumnBasicSql(column)) + commentTmp := "EXECUTE sp_addextendedproperty N'MS_Description', N'%s', N'SCHEMA', N'%s', N'TABLE', N'%s', N'COLUMN', N'%s'" + + // 防止注释内含有特殊字符串导致sql出错 + if column.ColumnComment != "" { + comment := md.QuoteEscape(column.ColumnComment) + columnComments = append(columnComments, fmt.Sprintf(commentTmp, comment, md.dc.Info.CurrentSchema(), tbName, column.ColumnName)) + } + } + + // create + createSql += strings.Join(fields, ",\n") + if len(pks) > 0 { + createSql += fmt.Sprintf(", \n PRIMARY KEY CLUSTERED (%s)", strings.Join(pks, ",")) + } + createSql += "\n)" + + // comment + tableCommentSql := "" + if tableInfo.TableComment != "" { + commentTmp := "EXECUTE sp_addextendedproperty N'MS_Description', N'%s', N'SCHEMA', N'%s', N'TABLE', N'%s'" + + tableCommentSql = fmt.Sprintf(commentTmp, md.QuoteEscape(tableInfo.TableComment), md.dc.Info.CurrentSchema(), tbName) + } + + sqlArr = append(sqlArr, createSql) + + if tableCommentSql != "" { + sqlArr = append(sqlArr, tableCommentSql) + } + if len(columnComments) > 0 { + sqlArr = append(sqlArr, columnComments...) + } + + return sqlArr +} + +// 获取建索引ddl +func (md *MssqlDialect) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Table) []string { + tbName := tableInfo.TableName + sqls := make([]string, 0) + comments := make([]string, 0) + for _, index := range indexs { + unique := "" + if index.IsUnique { + unique = "unique" + } + // 取出列名,添加引号 + cols := strings.Split(index.ColumnName, ",") + colNames := make([]string, len(cols)) + for i, name := range cols { + colNames[i] = md.QuoteIdentifier(name) + } + + sqls = append(sqls, fmt.Sprintf("create %s NONCLUSTERED index %s on %s.%s(%s)", unique, md.QuoteIdentifier(index.IndexName), md.QuoteIdentifier(md.dc.Info.CurrentSchema()), md.QuoteIdentifier(tbName), strings.Join(colNames, ","))) + if index.IndexComment != "" { + comment := md.QuoteEscape(index.IndexComment) + comments = append(comments, fmt.Sprintf("EXECUTE sp_addextendedproperty N'MS_Description', N'%s', N'SCHEMA', N'%s', N'TABLE', N'%s', N'INDEX', N'%s'", comment, md.dc.Info.CurrentSchema(), tbName, index.IndexName)) + } + } + if len(comments) > 0 { + sqls = append(sqls, comments...) + } + + return sqls +} + +func (md *MssqlDialect) GetIdentifierQuoteString() string { + return "[" +} + +func (md *MssqlDialect) GetDataHelper() dbi.DataHelper { + return new(DataHelper) +} + +func (md *MssqlDialect) GetColumnHelper() dbi.ColumnHelper { + return new(ColumnHelper) +} + +func (md *MssqlDialect) GetDumpHelper() dbi.DumpHelper { + return new(DumpHelper) +} + +func (md *MssqlDialect) genColumnBasicSql(column dbi.Column) string { + colName := md.QuoteIdentifier(column.ColumnName) + dataType := string(column.DataType) + + incr := "" + if column.IsIdentity { + incr = " IDENTITY(1,1)" + } + + nullAble := "" + if !column.Nullable { + nullAble = " NOT NULL" + } + + defVal := "" // 默认值需要判断引号,如函数是不需要引号的 // 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值 + if column.ColumnDefault != "" && !strings.Contains(column.ColumnDefault, "(") { + // 哪些字段类型默认值需要加引号 + mark := false + if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, dataType) { + // 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号 + if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) && + collx.ArrayAnyMatches([]string{"DATE", "TIME"}, strings.ToUpper(column.ColumnDefault)) { + mark = false + } else { + mark = true + } + } + + if mark { + defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault) + } else { + defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault) + } + } + + columnSql := fmt.Sprintf(" %s %s%s%s%s", colName, column.GetColumnType(), incr, nullAble, defVal) + return columnSql } diff --git a/server/internal/db/dbm/mssql/meta.go b/server/internal/db/dbm/mssql/meta.go index 406e3389..f84e3d7f 100644 --- a/server/internal/db/dbm/mssql/meta.go +++ b/server/internal/db/dbm/mssql/meta.go @@ -57,6 +57,6 @@ func (mm *Meta) GetDialect(conn *dbi.DbConn) dbi.Dialect { return &MssqlDialect{dc: conn} } -func (mm *Meta) GetMetaData(conn *dbi.DbConn) *dbi.MetaDataX { - return dbi.NewMetaDataX(&MssqlMetaData{dc: conn}) +func (mm *Meta) GetMetadata(conn *dbi.DbConn) dbi.Metadata { + return &MssqlMetadata{dc: conn} } diff --git a/server/internal/db/dbm/mssql/metadata.go b/server/internal/db/dbm/mssql/metadata.go index af2f5fe1..a6cac5f6 100644 --- a/server/internal/db/dbm/mssql/metadata.go +++ b/server/internal/db/dbm/mssql/metadata.go @@ -22,13 +22,13 @@ const ( MSSQL_COLUMN_MA_KEY = "MSSQL_COLUMN_MA" ) -type MssqlMetaData struct { - dbi.DefaultMetaData +type MssqlMetadata struct { + dbi.DefaultMetadata dc *dbi.DbConn } -func (md *MssqlMetaData) GetDbServer() (*dbi.DbServer, error) { +func (md *MssqlMetadata) GetDbServer() (*dbi.DbServer, error) { _, res, err := md.dc.Query("SELECT @@VERSION as version") if err != nil { return nil, err @@ -39,7 +39,7 @@ func (md *MssqlMetaData) GetDbServer() (*dbi.DbServer, error) { return ds, nil } -func (md *MssqlMetaData) GetDbNames() ([]string, error) { +func (md *MssqlMetadata) GetDbNames() ([]string, error) { _, res, err := md.dc.Query(dbi.GetLocalSql(MSSQL_META_FILE, MSSQL_DBS_KEY)) if err != nil { return nil, err @@ -54,11 +54,11 @@ func (md *MssqlMetaData) GetDbNames() ([]string, error) { } // 获取表基础元信息, 如表名等 -func (md *MssqlMetaData) GetTables(tableNames ...string) ([]dbi.Table, error) { - meta := md.dc.GetMetaData() +func (md *MssqlMetadata) GetTables(tableNames ...string) ([]dbi.Table, error) { + dialect := md.dc.GetDialect() schema := md.dc.Info.CurrentSchema() names := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string { - return fmt.Sprintf("'%s'", meta.RemoveQuote(val)) + return fmt.Sprintf("'%s'", dialect.RemoveQuote(val)) }), ",") var res []map[string]any @@ -89,11 +89,11 @@ func (md *MssqlMetaData) GetTables(tableNames ...string) ([]dbi.Table, error) { } // 获取列元信息, 如列名等 -func (md *MssqlMetaData) GetColumns(tableNames ...string) ([]dbi.Column, error) { - meta := md.dc.GetMetaData() - columnHelper := meta.GetColumnHelper() +func (md *MssqlMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error) { + dialect := md.dc.GetDialect() + columnHelper := dialect.GetColumnHelper() tableName := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string { - return fmt.Sprintf("'%s'", meta.RemoveQuote(val)) + return fmt.Sprintf("'%s'", dialect.RemoveQuote(val)) }), ",") _, res, err := md.dc.Query(fmt.Sprintf(dbi.GetLocalSql(MSSQL_META_FILE, MSSQL_COLUMN_MA_KEY), tableName), md.dc.Info.CurrentSchema()) @@ -126,7 +126,7 @@ func (md *MssqlMetaData) GetColumns(tableNames ...string) ([]dbi.Column, error) } // 获取表主键字段名,不存在主键标识则默认第一个字段 -func (md *MssqlMetaData) GetPrimaryKey(tablename string) (string, error) { +func (md *MssqlMetadata) GetPrimaryKey(tablename string) (string, error) { columns, err := md.GetColumns(tablename) if err != nil { return "", err @@ -145,7 +145,7 @@ func (md *MssqlMetaData) GetPrimaryKey(tablename string) (string, error) { } // 需要收集唯一键涉及的字段,所以需要查询出带主键的索引 -func (md *MssqlMetaData) getTableIndexWithPK(tableName string) ([]dbi.Index, error) { +func (md *MssqlMetadata) getTableIndexWithPK(tableName string) ([]dbi.Index, error) { _, res, err := md.dc.Query(dbi.GetLocalSql(MSSQL_META_FILE, MSSQL_INDEX_INFO_KEY), md.dc.Info.CurrentSchema(), tableName) if err != nil { return nil, err @@ -182,7 +182,7 @@ func (md *MssqlMetaData) getTableIndexWithPK(tableName string) ([]dbi.Index, err } // 获取表索引信息 -func (md *MssqlMetaData) GetTableIndex(tableName string) ([]dbi.Index, error) { +func (md *MssqlMetadata) GetTableIndex(tableName string) ([]dbi.Index, error) { indexs, _ := md.getTableIndexWithPK(tableName) result := make([]dbi.Index, 0) // 过滤掉主键索引 @@ -195,175 +195,8 @@ func (md *MssqlMetaData) GetTableIndex(tableName string) ([]dbi.Index, error) { return result, nil } -func (md *MssqlMetaData) CopyTableDDL(tableName string, newTableName string) (string, error) { - if newTableName == "" { - newTableName = tableName - } - meta := md.dc.GetMetaData() - // 查询表名和表注释, 设置表注释 - tbs, err := md.GetTables(tableName) - if err != nil || len(tbs) < 1 { - logx.Errorf("获取表信息失败, %s", tableName) - return "", err - } - tabInfo := &dbi.Table{ - TableName: newTableName, - TableComment: tbs[0].TableComment, - } - - // 查询列信息 - columns, err := md.GetColumns(tableName) - if err != nil { - logx.Errorf("获取列信息失败, %s", tableName) - return "", err - } - sqlArr := meta.GenerateTableDDL(columns, *tabInfo, true) - - // 设置索引 - indexs, err := md.GetTableIndex(tableName) - if err != nil { - logx.Errorf("获取索引信息失败, %s", tableName) - return strings.Join(sqlArr, ";"), err - } - sqlArr = append(sqlArr, meta.GenerateIndexDDL(indexs, *tabInfo)...) - return strings.Join(sqlArr, ";"), nil -} - -// 获取建索引ddl - -func (md *MssqlMetaData) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Table) []string { - tbName := tableInfo.TableName - meta := md.dc.GetMetaData() - sqls := make([]string, 0) - comments := make([]string, 0) - for _, index := range indexs { - unique := "" - if index.IsUnique { - unique = "unique" - } - // 取出列名,添加引号 - cols := strings.Split(index.ColumnName, ",") - colNames := make([]string, len(cols)) - for i, name := range cols { - colNames[i] = meta.QuoteIdentifier(name) - } - - sqls = append(sqls, fmt.Sprintf("create %s NONCLUSTERED index %s on %s.%s(%s)", unique, meta.QuoteIdentifier(index.IndexName), meta.QuoteIdentifier(md.dc.Info.CurrentSchema()), meta.QuoteIdentifier(tbName), strings.Join(colNames, ","))) - if index.IndexComment != "" { - comment := meta.QuoteEscape(index.IndexComment) - comments = append(comments, fmt.Sprintf("EXECUTE sp_addextendedproperty N'MS_Description', N'%s', N'SCHEMA', N'%s', N'TABLE', N'%s', N'INDEX', N'%s'", comment, md.dc.Info.CurrentSchema(), tbName, index.IndexName)) - } - } - if len(comments) > 0 { - sqls = append(sqls, comments...) - } - - return sqls -} - -func (md *MssqlMetaData) genColumnBasicSql(column dbi.Column) string { - meta := md.dc.GetMetaData() - colName := meta.QuoteIdentifier(column.ColumnName) - dataType := string(column.DataType) - - incr := "" - if column.IsIdentity { - incr = " IDENTITY(1,1)" - } - - nullAble := "" - if !column.Nullable { - nullAble = " NOT NULL" - } - - defVal := "" // 默认值需要判断引号,如函数是不需要引号的 // 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值 - if column.ColumnDefault != "" && !strings.Contains(column.ColumnDefault, "(") { - // 哪些字段类型默认值需要加引号 - mark := false - if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, dataType) { - // 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号 - if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) && - collx.ArrayAnyMatches([]string{"DATE", "TIME"}, strings.ToUpper(column.ColumnDefault)) { - mark = false - } else { - mark = true - } - } - - if mark { - defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault) - } else { - defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault) - } - } - - columnSql := fmt.Sprintf(" %s %s%s%s%s", colName, column.GetColumnType(), incr, nullAble, defVal) - return columnSql -} - // 获取建表ddl -func (md *MssqlMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table, dropBeforeCreate bool) []string { - tbName := tableInfo.TableName - schemaName := md.dc.Info.CurrentSchema() - meta := md.dc.GetMetaData() - - sqlArr := make([]string, 0) - - // 删除表 - if dropBeforeCreate { - sqlArr = append(sqlArr, fmt.Sprintf("DROP TABLE IF EXISTS %s.%s", meta.QuoteIdentifier(schemaName), meta.QuoteIdentifier(tbName))) - } - - // 组装建表语句 - createSql := fmt.Sprintf("CREATE TABLE %s.%s (\n", meta.QuoteIdentifier(schemaName), meta.QuoteIdentifier(tbName)) - fields := make([]string, 0) - pks := make([]string, 0) - columnComments := make([]string, 0) - - for _, column := range columns { - if column.IsPrimaryKey { - pks = append(pks, meta.QuoteIdentifier(column.ColumnName)) - } - fields = append(fields, md.genColumnBasicSql(column)) - commentTmp := "EXECUTE sp_addextendedproperty N'MS_Description', N'%s', N'SCHEMA', N'%s', N'TABLE', N'%s', N'COLUMN', N'%s'" - - // 防止注释内含有特殊字符串导致sql出错 - if column.ColumnComment != "" { - comment := meta.QuoteEscape(column.ColumnComment) - columnComments = append(columnComments, fmt.Sprintf(commentTmp, comment, md.dc.Info.CurrentSchema(), tbName, column.ColumnName)) - } - } - - // create - createSql += strings.Join(fields, ",\n") - if len(pks) > 0 { - createSql += fmt.Sprintf(", \n PRIMARY KEY CLUSTERED (%s)", strings.Join(pks, ",")) - } - createSql += "\n)" - - // comment - tableCommentSql := "" - if tableInfo.TableComment != "" { - commentTmp := "EXECUTE sp_addextendedproperty N'MS_Description', N'%s', N'SCHEMA', N'%s', N'TABLE', N'%s'" - - tableCommentSql = fmt.Sprintf(commentTmp, meta.QuoteEscape(tableInfo.TableComment), md.dc.Info.CurrentSchema(), tbName) - } - - sqlArr = append(sqlArr, createSql) - - if tableCommentSql != "" { - sqlArr = append(sqlArr, tableCommentSql) - } - if len(columnComments) > 0 { - sqlArr = append(sqlArr, columnComments...) - } - - return sqlArr -} - -// 获取建表ddl -func (md *MssqlMetaData) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) { - +func (md *MssqlMetadata) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) { // 1.获取表信息 tbs, err := md.GetTables(tableName) tableInfo := &dbi.Table{} @@ -380,7 +213,8 @@ func (md *MssqlMetaData) GetTableDDL(tableName string, dropBeforeCreate bool) (s logx.Errorf("获取列信息失败, %s", tableName) return "", err } - tableDDLArr := md.GenerateTableDDL(columns, *tableInfo, dropBeforeCreate) + dialect := md.dc.GetDialect() + tableDDLArr := dialect.GenerateTableDDL(columns, *tableInfo, dropBeforeCreate) // 3.获取索引信息 indexs, err := md.GetTableIndex(tableName) if err != nil { @@ -388,11 +222,11 @@ func (md *MssqlMetaData) GetTableDDL(tableName string, dropBeforeCreate bool) (s return "", err } // 组装返回 - tableDDLArr = append(tableDDLArr, md.GenerateIndexDDL(indexs, *tableInfo)...) + tableDDLArr = append(tableDDLArr, dialect.GenerateIndexDDL(indexs, *tableInfo)...) return strings.Join(tableDDLArr, ";\n"), nil } -func (md *MssqlMetaData) GetSchemas() ([]string, error) { +func (md *MssqlMetadata) GetSchemas() ([]string, error) { _, res, err := md.dc.Query(dbi.GetLocalSql(MSSQL_META_FILE, MSSQL_DB_SCHEMAS_KEY)) if err != nil { return nil, err @@ -404,19 +238,3 @@ func (md *MssqlMetaData) GetSchemas() ([]string, error) { } return schemas, nil } - -func (md *MssqlMetaData) GetIdentifierQuoteString() string { - return "[" -} - -func (md *MssqlMetaData) GetDataHelper() dbi.DataHelper { - return new(DataHelper) -} - -func (md *MssqlMetaData) GetColumnHelper() dbi.ColumnHelper { - return new(ColumnHelper) -} - -func (md *MssqlMetaData) GetDumpHelper() dbi.DumpHelper { - return new(DumpHelper) -} diff --git a/server/internal/db/dbm/mysql/dialect.go b/server/internal/db/dbm/mysql/dialect.go index 2117929e..04ed04e4 100644 --- a/server/internal/db/dbm/mysql/dialect.go +++ b/server/internal/db/dbm/mysql/dialect.go @@ -4,6 +4,9 @@ import ( "database/sql" "fmt" "mayfly-go/internal/db/dbm/dbi" + "mayfly-go/internal/db/dbm/sqlparser" + "mayfly-go/internal/db/dbm/sqlparser/mysql" + "mayfly-go/pkg/utils/collx" "strings" "time" ) @@ -41,7 +44,7 @@ func (md *MysqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []stri prefix = "replace into" } - sqlStr := fmt.Sprintf("%s %s (%s) values %s", prefix, md.dc.GetMetaData().QuoteIdentifier(tableName), strings.Join(columns, ","), placeholder) + sqlStr := fmt.Sprintf("%s %s (%s) values %s", prefix, md.QuoteIdentifier(tableName), strings.Join(columns, ","), placeholder) // 执行批量insert sql // 把二维数组转为一维数组 var args []any @@ -72,25 +75,137 @@ func (md *MysqlDialect) CopyTable(copy *dbi.DbCopyTable) error { return err } -func (md *MysqlDialect) CreateTable(columns []dbi.Column, tableInfo dbi.Table, dropOldTable bool) (int, error) { - sqlArr := md.dc.GetMetaData().GenerateTableDDL(columns, tableInfo, dropOldTable) - for _, sqlStr := range sqlArr { - _, err := md.dc.Exec(sqlStr) - if err != nil { - return 0, err - } +// 获取建表ddl +func (md *MysqlDialect) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table, dropBeforeCreate bool) []string { + sqlArr := make([]string, 0) + + if dropBeforeCreate { + sqlArr = append(sqlArr, fmt.Sprintf("DROP TABLE IF EXISTS %s", md.QuoteIdentifier(tableInfo.TableName))) } - return len(sqlArr), nil + + // 组装建表语句 + createSql := fmt.Sprintf("CREATE TABLE %s (\n", md.QuoteIdentifier(tableInfo.TableName)) + fields := make([]string, 0) + pks := make([]string, 0) + + for _, column := range columns { + if column.IsPrimaryKey { + pks = append(pks, column.ColumnName) + } + fields = append(fields, md.genColumnBasicSql(column)) + } + + // 建表ddl + createSql += strings.Join(fields, ",\n") + if len(pks) > 0 { + createSql += fmt.Sprintf(", \nPRIMARY KEY (%s)", strings.Join(pks, ",")) + } + createSql += "\n)" + + // 表注释 + if tableInfo.TableComment != "" { + createSql += fmt.Sprintf(" COMMENT '%s'", md.QuoteEscape(tableInfo.TableComment)) + } + + sqlArr = append(sqlArr, createSql) + + return sqlArr } -func (md *MysqlDialect) CreateIndex(tableInfo dbi.Table, indexs []dbi.Index) error { - meta := md.dc.GetMetaData() - sqlArr := meta.GenerateIndexDDL(indexs, tableInfo) - for _, sqlStr := range sqlArr { - _, err := md.dc.Exec(sqlStr) - if err != nil { - return err +// 获取建索引ddl +func (md *MysqlDialect) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Table) []string { + sqlArr := make([]string, 0) + for _, index := range indexs { + unique := "" + if index.IsUnique { + unique = "unique" + } + // 取出列名,添加引号 + cols := strings.Split(index.ColumnName, ",") + colNames := make([]string, len(cols)) + for i, name := range cols { + colNames[i] = md.QuoteIdentifier(name) + } + sqlTmp := "ALTER TABLE %s ADD %s INDEX %s(%s) USING BTREE" + sqlStr := fmt.Sprintf(sqlTmp, md.QuoteIdentifier(tableInfo.TableName), unique, md.QuoteIdentifier(index.IndexName), strings.Join(colNames, ",")) + comment := md.QuoteEscape(index.IndexComment) + if comment != "" { + sqlStr += fmt.Sprintf(" COMMENT '%s'", comment) + } + sqlArr = append(sqlArr, sqlStr) + } + return sqlArr +} + +func (md *MysqlDialect) genColumnBasicSql(column dbi.Column) string { + dataType := string(column.DataType) + + incr := "" + if column.IsIdentity { + incr = " AUTO_INCREMENT" + } + + nullAble := "" + if !column.Nullable { + nullAble = " NOT NULL" + } + columnType := column.GetColumnType() + if nullAble == "" && strings.Contains(columnType, "timestamp") { + nullAble = " NULL" + } + + defVal := "" // 默认值需要判断引号,如函数是不需要引号的 + if column.ColumnDefault != "" && + // 当默认值是字符串'NULL'时,不需要设置默认值 + column.ColumnDefault != "NULL" && + // 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值 + !strings.Contains(column.ColumnDefault, "(") { + // 哪些字段类型默认值需要加引号 + mark := false + if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, strings.ToLower(dataType)) { + // 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号 + if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) && + collx.ArrayAnyMatches([]string{"DATE", "TIME"}, strings.ToUpper(column.ColumnDefault)) { + mark = false + } else { + mark = true + } + } + if mark { + defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault) + } else { + defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault) } } - return nil + comment := "" + if column.ColumnComment != "" { + // 防止注释内含有特殊字符串导致sql出错 + commentStr := md.QuoteEscape(column.ColumnComment) + comment = fmt.Sprintf(" COMMENT '%s'", commentStr) + } + + columnSql := fmt.Sprintf(" %s %s%s%s%s%s", md.QuoteIdentifier(column.ColumnName), columnType, nullAble, incr, defVal, comment) + return columnSql +} + +func (md *MysqlDialect) GetIdentifierQuoteString() string { + return "`" +} + +func (md *MysqlDialect) QuoteLiteral(literal string) string { + literal = strings.ReplaceAll(literal, `\`, `\\`) + literal = strings.ReplaceAll(literal, `'`, `''`) + return "'" + literal + "'" +} + +func (md *MysqlDialect) GetDataHelper() dbi.DataHelper { + return new(DataHelper) +} + +func (md *MysqlDialect) GetColumnHelper() dbi.ColumnHelper { + return new(ColumnHelper) +} + +func (pd *MysqlDialect) GetSQLParser() sqlparser.SqlParser { + return new(mysql.MysqlParser) } diff --git a/server/internal/db/dbm/mysql/helper.go b/server/internal/db/dbm/mysql/helper.go index 48c28bc5..9c739806 100644 --- a/server/internal/db/dbm/mysql/helper.go +++ b/server/internal/db/dbm/mysql/helper.go @@ -147,7 +147,6 @@ func (dc *DataHelper) ParseData(dbColumnValue any, dataType dbi.DataType) any { } func (dc *DataHelper) WrapValue(dbColumnValue any, dataType dbi.DataType) string { - if dbColumnValue == nil { return "NULL" } diff --git a/server/internal/db/dbm/mysql/meta.go b/server/internal/db/dbm/mysql/meta.go index cd5d4799..93e592f1 100644 --- a/server/internal/db/dbm/mysql/meta.go +++ b/server/internal/db/dbm/mysql/meta.go @@ -43,6 +43,6 @@ func (mm *Meta) GetDialect(conn *dbi.DbConn) dbi.Dialect { return &MysqlDialect{dc: conn} } -func (mm *Meta) GetMetaData(conn *dbi.DbConn) *dbi.MetaDataX { - return dbi.NewMetaDataX(&MysqlMetaData{dc: conn}) +func (mm *Meta) GetMetadata(conn *dbi.DbConn) dbi.Metadata { + return &MysqlMetadata{dc: conn} } diff --git a/server/internal/db/dbm/mysql/metadata.go b/server/internal/db/dbm/mysql/metadata.go index 6930a51b..4e4d3b49 100644 --- a/server/internal/db/dbm/mysql/metadata.go +++ b/server/internal/db/dbm/mysql/metadata.go @@ -10,7 +10,6 @@ import ( "mayfly-go/pkg/utils/stringx" "strings" - // "github.com/kanzihuang/vitess/go/vt/sqlparser" "github.com/may-fly/cast" ) @@ -22,13 +21,13 @@ const ( MYSQL_COLUMN_MA_KEY = "MYSQL_COLUMN_MA" ) -type MysqlMetaData struct { - dbi.DefaultMetaData +type MysqlMetadata struct { + dbi.DefaultMetadata dc *dbi.DbConn } -func (md *MysqlMetaData) GetDbServer() (*dbi.DbServer, error) { +func (md *MysqlMetadata) GetDbServer() (*dbi.DbServer, error) { _, res, err := md.dc.Query("SELECT VERSION() version") if err != nil { return nil, err @@ -39,7 +38,7 @@ func (md *MysqlMetaData) GetDbServer() (*dbi.DbServer, error) { return ds, nil } -func (md *MysqlMetaData) GetDbNames() ([]string, error) { +func (md *MysqlMetadata) GetDbNames() ([]string, error) { _, res, err := md.dc.Query(dbi.GetLocalSql(MYSQL_META_FILE, MYSQL_DBS)) if err != nil { return nil, err @@ -52,10 +51,10 @@ func (md *MysqlMetaData) GetDbNames() ([]string, error) { return databases, nil } -func (md *MysqlMetaData) GetTables(tableNames ...string) ([]dbi.Table, error) { - meta := md.dc.GetMetaData() +func (md *MysqlMetadata) GetTables(tableNames ...string) ([]dbi.Table, error) { + dialect := md.dc.GetDialect() names := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string { - return fmt.Sprintf("'%s'", meta.RemoveQuote(val)) + return fmt.Sprintf("'%s'", dialect.RemoveQuote(val)) }), ",") var res []map[string]any @@ -86,11 +85,11 @@ func (md *MysqlMetaData) GetTables(tableNames ...string) ([]dbi.Table, error) { } // 获取列元信息, 如列名等 -func (md *MysqlMetaData) GetColumns(tableNames ...string) ([]dbi.Column, error) { - meta := md.dc.GetMetaData() - columnHelper := meta.GetColumnHelper() +func (md *MysqlMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error) { + dialect := md.dc.GetDialect() + columnHelper := dialect.GetColumnHelper() tableName := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string { - return fmt.Sprintf("'%s'", meta.RemoveQuote(val)) + return fmt.Sprintf("'%s'", dialect.RemoveQuote(val)) }), ",") _, res, err := md.dc.Query(fmt.Sprintf(dbi.GetLocalSql(MYSQL_META_FILE, MYSQL_COLUMN_MA_KEY), tableName)) @@ -122,7 +121,7 @@ func (md *MysqlMetaData) GetColumns(tableNames ...string) ([]dbi.Column, error) } // 获取表主键字段名,不存在主键标识则默认第一个字段 -func (md *MysqlMetaData) GetPrimaryKey(tablename string) (string, error) { +func (md *MysqlMetadata) GetPrimaryKey(tablename string) (string, error) { columns, err := md.GetColumns(tablename) if err != nil { return "", err @@ -141,7 +140,7 @@ func (md *MysqlMetaData) GetPrimaryKey(tablename string) (string, error) { } // 获取表索引信息 -func (md *MysqlMetaData) GetTableIndex(tableName string) ([]dbi.Index, error) { +func (md *MysqlMetadata) GetTableIndex(tableName string) ([]dbi.Index, error) { _, res, err := md.dc.Query(dbi.GetLocalSql(MYSQL_META_FILE, MYSQL_INDEX_INFO_KEY), tableName) if err != nil { return nil, err @@ -178,124 +177,8 @@ func (md *MysqlMetaData) GetTableIndex(tableName string) ([]dbi.Index, error) { return result, nil } -// 获取建索引ddl -func (md *MysqlMetaData) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Table) []string { - meta := md.dc.GetMetaData() - sqlArr := make([]string, 0) - for _, index := range indexs { - unique := "" - if index.IsUnique { - unique = "unique" - } - // 取出列名,添加引号 - cols := strings.Split(index.ColumnName, ",") - colNames := make([]string, len(cols)) - for i, name := range cols { - colNames[i] = meta.QuoteIdentifier(name) - } - sqlTmp := "ALTER TABLE %s ADD %s INDEX %s(%s) USING BTREE" - sqlStr := fmt.Sprintf(sqlTmp, meta.QuoteIdentifier(tableInfo.TableName), unique, meta.QuoteIdentifier(index.IndexName), strings.Join(colNames, ",")) - comment := meta.QuoteEscape(index.IndexComment) - if comment != "" { - sqlStr += fmt.Sprintf(" COMMENT '%s'", comment) - } - sqlArr = append(sqlArr, sqlStr) - } - return sqlArr -} - -func (md *MysqlMetaData) genColumnBasicSql(column dbi.Column) string { - meta := md.dc.GetMetaData() - dataType := string(column.DataType) - - incr := "" - if column.IsIdentity { - incr = " AUTO_INCREMENT" - } - - nullAble := "" - if !column.Nullable { - nullAble = " NOT NULL" - } - columnType := column.GetColumnType() - if nullAble == "" && strings.Contains(columnType, "timestamp") { - nullAble = " NULL" - } - - defVal := "" // 默认值需要判断引号,如函数是不需要引号的 - if column.ColumnDefault != "" && - // 当默认值是字符串'NULL'时,不需要设置默认值 - column.ColumnDefault != "NULL" && - // 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值 - !strings.Contains(column.ColumnDefault, "(") { - // 哪些字段类型默认值需要加引号 - mark := false - if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, strings.ToLower(dataType)) { - // 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号 - if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) && - collx.ArrayAnyMatches([]string{"DATE", "TIME"}, strings.ToUpper(column.ColumnDefault)) { - mark = false - } else { - mark = true - } - } - if mark { - defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault) - } else { - defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault) - } - } - comment := "" - if column.ColumnComment != "" { - // 防止注释内含有特殊字符串导致sql出错 - commentStr := meta.QuoteEscape(column.ColumnComment) - comment = fmt.Sprintf(" COMMENT '%s'", commentStr) - } - - columnSql := fmt.Sprintf(" %s %s%s%s%s%s", md.dc.GetMetaData().QuoteIdentifier(column.ColumnName), columnType, nullAble, incr, defVal, comment) - return columnSql -} - // 获取建表ddl -func (md *MysqlMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table, dropBeforeCreate bool) []string { - meta := md.dc.GetMetaData() - sqlArr := make([]string, 0) - - if dropBeforeCreate { - sqlArr = append(sqlArr, fmt.Sprintf("DROP TABLE IF EXISTS %s", meta.QuoteIdentifier(tableInfo.TableName))) - } - - // 组装建表语句 - createSql := fmt.Sprintf("CREATE TABLE %s (\n", meta.QuoteIdentifier(tableInfo.TableName)) - fields := make([]string, 0) - pks := make([]string, 0) - - for _, column := range columns { - if column.IsPrimaryKey { - pks = append(pks, column.ColumnName) - } - fields = append(fields, md.genColumnBasicSql(column)) - } - - // 建表ddl - createSql += strings.Join(fields, ",\n") - if len(pks) > 0 { - createSql += fmt.Sprintf(", \nPRIMARY KEY (%s)", strings.Join(pks, ",")) - } - createSql += "\n)" - - // 表注释 - if tableInfo.TableComment != "" { - createSql += fmt.Sprintf(" COMMENT '%s'", meta.QuoteEscape(tableInfo.TableComment)) - } - - sqlArr = append(sqlArr, createSql) - - return sqlArr -} - -// 获取建表ddl -func (md *MysqlMetaData) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) { +func (md *MysqlMetadata) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) { // 1.获取表信息 tbs, err := md.GetTables(tableName) tableInfo := &dbi.Table{} @@ -312,7 +195,9 @@ func (md *MysqlMetaData) GetTableDDL(tableName string, dropBeforeCreate bool) (s logx.Errorf("获取列信息失败, %s", tableName) return "", err } - tableDDLArr := md.GenerateTableDDL(columns, *tableInfo, dropBeforeCreate) + + dialect := md.dc.GetDialect() + tableDDLArr := dialect.GenerateTableDDL(columns, *tableInfo, dropBeforeCreate) // 3.获取索引信息 indexs, err := md.GetTableIndex(tableName) if err != nil { @@ -320,32 +205,10 @@ func (md *MysqlMetaData) GetTableDDL(tableName string, dropBeforeCreate bool) (s return "", err } // 组装返回 - tableDDLArr = append(tableDDLArr, md.GenerateIndexDDL(indexs, *tableInfo)...) + tableDDLArr = append(tableDDLArr, dialect.GenerateIndexDDL(indexs, *tableInfo)...) return strings.Join(tableDDLArr, ";\n"), nil } -func (md *MysqlMetaData) GetSchemas() ([]string, error) { +func (md *MysqlMetadata) GetSchemas() ([]string, error) { return nil, errors.New("不支持schema") } - -func (md *MysqlMetaData) GetIdentifierQuoteString() string { - return "`" -} - -func (md *MysqlMetaData) QuoteLiteral(literal string) string { - literal = strings.ReplaceAll(literal, `\`, `\\`) - literal = strings.ReplaceAll(literal, `'`, `''`) - return "'" + literal + "'" -} - -// func (md *MysqlMetaData) GetSqlParserDialect() sqlparser.Dialect { -// return sqlparser.MysqlDialect{} -// } - -func (md *MysqlMetaData) GetDataHelper() dbi.DataHelper { - return new(DataHelper) -} - -func (md *MysqlMetaData) GetColumnHelper() dbi.ColumnHelper { - return new(ColumnHelper) -} diff --git a/server/internal/db/dbm/oracle/dialect.go b/server/internal/db/dbm/oracle/dialect.go index c66c4785..0c68a1c3 100644 --- a/server/internal/db/dbm/oracle/dialect.go +++ b/server/internal/db/dbm/oracle/dialect.go @@ -38,7 +38,7 @@ func (od *OracleDialect) BatchInsert(tx *sql.Tx, tableName string, columns []str // 简单批量插入sql,无需判断键冲突策略 func (od *OracleDialect) batchInsertSimple(tableName string, columns []string, values [][]any, duplicateStrategy int, tx *sql.Tx) (int64, error) { - metadata := od.dc.GetMetaData() + metadata := od.dc.GetMetadata() // 忽略键冲突策略 ignore := "" if duplicateStrategy == dbi.DuplicateStrategyIgnore { @@ -66,7 +66,7 @@ func (od *OracleDialect) batchInsertSimple(tableName string, columns []string, v for i := 0; i < len(value); i++ { placeholder = append(placeholder, fmt.Sprintf(":%d", i+1)) } - sqlTemp := fmt.Sprintf("INSERT %s INTO %s (%s) VALUES (%s)", ignore, metadata.QuoteIdentifier(tableName), strings.Join(columns, ","), strings.Join(placeholder, ",")) + sqlTemp := fmt.Sprintf("INSERT %s INTO %s (%s) VALUES (%s)", ignore, od.QuoteIdentifier(tableName), strings.Join(columns, ","), strings.Join(placeholder, ",")) // oracle数据库为了兼容ignore主键冲突,只能一条条的执行insert res, err := od.dc.TxExec(tx, sqlTemp, value...) @@ -82,7 +82,7 @@ func (od *OracleDialect) batchInsertMergeSql(tableName string, columns []string, // 查询主键字段 uniqueCols := make([]string, 0) caseSqls := make([]string, 0) - metadata := od.dc.GetMetaData() + metadata := od.dc.GetMetadata() // 查询唯一索引涉及到的字段,并组装到match条件内 indexs, _ := metadata.GetTableIndex(tableName) if indexs != nil { @@ -94,7 +94,7 @@ func (od *OracleDialect) batchInsertMergeSql(tableName string, columns []string, if !collx.ArrayContains(uniqueCols, col) { uniqueCols = append(uniqueCols, col) } - tmp = append(tmp, fmt.Sprintf(" T1.%s = T2.%s ", metadata.QuoteIdentifier(col), metadata.QuoteIdentifier(col))) + tmp = append(tmp, fmt.Sprintf(" T1.%s = T2.%s ", od.QuoteIdentifier(col), od.QuoteIdentifier(col))) } caseSqls = append(caseSqls, fmt.Sprintf("( %s )", strings.Join(tmp, " AND "))) } @@ -111,7 +111,7 @@ func (od *OracleDialect) batchInsertMergeSql(tableName string, columns []string, upds := make([]string, 0) insertCols := make([]string, 0) for _, column := range columns { - if !collx.ArrayContains(uniqueCols, metadata.RemoveQuote(column)) { + if !collx.ArrayContains(uniqueCols, od.RemoveQuote(column)) { upds = append(upds, fmt.Sprintf("T1.%s = T2.%s", column, column)) } insertCols = append(insertCols, fmt.Sprintf("T1.%s", column)) @@ -132,7 +132,7 @@ func (od *OracleDialect) batchInsertMergeSql(tableName string, columns []string, t2 := strings.Join(t2s, " UNION ALL ") - sqlTemp := "MERGE INTO " + metadata.QuoteIdentifier(tableName) + " T1 USING (" + t2 + ") T2 ON (" + strings.Join(caseSqls, " OR ") + ") " + sqlTemp := "MERGE INTO " + od.QuoteIdentifier(tableName) + " T1 USING (" + t2 + ") T2 ON (" + strings.Join(caseSqls, " OR ") + ") " sqlTemp += "WHEN NOT MATCHED THEN INSERT (" + strings.Join(insertCols, ",") + ") VALUES (" + strings.Join(insertVals, ",") + ") " sqlTemp += "WHEN MATCHED THEN UPDATE SET " + strings.Join(upds, ",") @@ -155,22 +155,132 @@ func (od *OracleDialect) CopyTable(copy *dbi.DbCopyTable) error { return err } -func (od *OracleDialect) CreateTable(commonColumns []dbi.Column, tableInfo dbi.Table, dropOldTable bool) (int, error) { - meta := od.dc.GetMetaData() - sqlArr := meta.GenerateTableDDL(commonColumns, tableInfo, dropOldTable) - // 需要分开执行sql - for _, sqlStr := range sqlArr { - _, err := od.dc.Exec(sqlStr) - if err != nil { - return 0, err +// 获取建表ddl +func (od *OracleDialect) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table, dropBeforeCreate bool) []string { + quoteTableName := od.QuoteIdentifier(tableInfo.TableName) + sqlArr := make([]string, 0) + + if dropBeforeCreate { + dropSqlTmp := ` +declare + num number; +begin + select count(1) into num from user_tables where table_name = '%s' and owner = (SELECT sys_context('USERENV', 'CURRENT_SCHEMA') FROM dual) ; + if num > 0 then + execute immediate 'drop table "%s"' ; + end if; +end` + sqlArr = append(sqlArr, fmt.Sprintf(dropSqlTmp, tableInfo.TableName, tableInfo.TableName)) + } + + // 组装建表语句 + createSql := fmt.Sprintf("CREATE TABLE %s ( \n", quoteTableName) + fields := make([]string, 0) + pks := make([]string, 0) + columnComments := make([]string, 0) + // 把通用类型转换为达梦类型 + for _, column := range columns { + if column.IsPrimaryKey { + pks = append(pks, od.QuoteIdentifier(column.ColumnName)) + } + fields = append(fields, od.genColumnBasicSql(column)) + // 防止注释内含有特殊字符串导致sql出错 + if column.ColumnComment != "" { + comment := od.QuoteEscape(column.ColumnComment) + columnComments = append(columnComments, fmt.Sprintf("COMMENT ON COLUMN %s.%s IS '%s'", quoteTableName, od.QuoteIdentifier(column.ColumnName), comment)) } } - return len(sqlArr), nil + + // 建表 + createSql += strings.Join(fields, ",\n") + if len(pks) > 0 { + createSql += fmt.Sprintf(", \nPRIMARY KEY (%s)", strings.Join(pks, ",")) + } + createSql += "\n)" + sqlArr = append(sqlArr, createSql) + + // 表注释 + tableCommentSql := "" + if tableInfo.TableComment != "" { + tableCommentSql = fmt.Sprintf("COMMENT ON TABLE %s is '%s'", od.QuoteIdentifier(tableInfo.TableName), od.QuoteEscape(tableInfo.TableComment)) + sqlArr = append(sqlArr, tableCommentSql) + } + + // 列注释 + if len(columnComments) > 0 { + sqlArr = append(sqlArr, columnComments...) + } + otherSql := od.GenerateTableOtherDDL(tableInfo, quoteTableName, columns) + if len(otherSql) > 0 { + sqlArr = append(sqlArr, otherSql...) + } + + return sqlArr } -func (od *OracleDialect) CreateIndex(tableInfo dbi.Table, indexs []dbi.Index) error { - meta := od.dc.GetMetaData() - sqlArr := meta.GenerateIndexDDL(indexs, tableInfo) - _, err := od.dc.Exec(strings.Join(sqlArr, ";")) - return err +// 获取建索引ddl +func (od *OracleDialect) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Table) []string { + sqls := make([]string, 0) + comments := make([]string, 0) + + for _, index := range indexs { + unique := "" + if index.IsUnique { + unique = "unique" + } + + // 取出列名,添加引号 + cols := strings.Split(index.ColumnName, ",") + colNames := make([]string, len(cols)) + for i, name := range cols { + colNames[i] = od.QuoteIdentifier(name) + } + + sqls = append(sqls, fmt.Sprintf("CREATE %s INDEX %s ON %s(%s)", unique, od.QuoteIdentifier(index.IndexName), od.QuoteIdentifier(tableInfo.TableName), strings.Join(colNames, ","))) + } + + sqlArr := make([]string, 0) + + sqlArr = append(sqlArr, sqls...) + + if len(comments) > 0 { + sqlArr = append(sqlArr, comments...) + } + + return sqlArr +} + +// 11g及以下版本会设置自增序列 +func (od *OracleDialect) GenerateTableOtherDDL(tableInfo dbi.Table, quoteTableName string, columns []dbi.Column) []string { + return nil +} + +func (od *OracleDialect) GetDataHelper() dbi.DataHelper { + return new(DataHelper) +} + +func (od *OracleDialect) GetColumnHelper() dbi.ColumnHelper { + return new(ColumnHelper) +} + +func (od *OracleDialect) genColumnBasicSql(column dbi.Column) string { + colName := od.QuoteIdentifier(column.ColumnName) + + if column.IsIdentity { + // 如果是自增,不需要设置默认值和空值,自增列数据类型必须是number + return fmt.Sprintf(" %s NUMBER generated by default as IDENTITY", colName) + } + + nullAble := "" + if !column.Nullable { + nullAble = " NOT NULL" + } + + defVal := "" + if column.ColumnDefault != "" { + defVal = fmt.Sprintf(" DEFAULT %v", column.ColumnDefault) + } + + columnSql := fmt.Sprintf(" %s %s%s%s", colName, column.GetColumnType(), defVal, nullAble) + return columnSql } diff --git a/server/internal/db/dbm/oracle/meta.go b/server/internal/db/dbm/oracle/meta.go index 6ae8940e..d17739f6 100644 --- a/server/internal/db/dbm/oracle/meta.go +++ b/server/internal/db/dbm/oracle/meta.go @@ -87,7 +87,7 @@ func (om *Meta) GetDialect(conn *dbi.DbConn) dbi.Dialect { return &OracleDialect{dc: conn} } -func (om *Meta) GetMetaData(conn *dbi.DbConn) *dbi.MetaDataX { +func (om *Meta) GetMetadata(conn *dbi.DbConn) dbi.Metadata { // 查询数据库版本信息,以做兼容性处理 if conn.Info.Version == "" && !conn.Info.DefaultVersion { @@ -107,10 +107,10 @@ func (om *Meta) GetMetaData(conn *dbi.DbConn) *dbi.MetaDataX { } if conn.Info.Version == DbVersionOracle11 { - md := &OracleMetaData11{} + md := &OracleMetadata11{} md.dc = conn md.version = DbVersionOracle11 - return dbi.NewMetaDataX(md) + return md } - return dbi.NewMetaDataX(&OracleMetaData{dc: conn}) + return &OracleMetadata{dc: conn} } diff --git a/server/internal/db/dbm/oracle/metadata.go b/server/internal/db/dbm/oracle/metadata.go index a35fdfc4..a76bb0b8 100644 --- a/server/internal/db/dbm/oracle/metadata.go +++ b/server/internal/db/dbm/oracle/metadata.go @@ -21,19 +21,19 @@ const ( ORACLE_COLUMN_MA_KEY = "ORACLE_COLUMN_MA" ) -type OracleMetaData struct { - dbi.DefaultMetaData +type OracleMetadata struct { + dbi.DefaultMetadata dc *dbi.DbConn version dbi.DbVersion } -func (od *OracleMetaData11) GetCompatibleDbVersion() dbi.DbVersion { +func (od *OracleMetadata11) GetCompatibleDbVersion() dbi.DbVersion { return od.version } -func (od *OracleMetaData) GetDbServer() (*dbi.DbServer, error) { +func (od *OracleMetadata) GetDbServer() (*dbi.DbServer, error) { _, res, err := od.dc.Query("select * from v$instance") if err != nil { return nil, err @@ -44,7 +44,7 @@ func (od *OracleMetaData) GetDbServer() (*dbi.DbServer, error) { return ds, nil } -func (od *OracleMetaData) GetDbNames() ([]string, error) { +func (od *OracleMetadata) GetDbNames() ([]string, error) { _, res, err := od.dc.Query("SELECT name AS DBNAME FROM v$database") if err != nil { return nil, err @@ -58,10 +58,10 @@ func (od *OracleMetaData) GetDbNames() ([]string, error) { return databases, nil } -func (od *OracleMetaData) GetTables(tableNames ...string) ([]dbi.Table, error) { - meta := od.dc.GetMetaData() +func (od *OracleMetadata) GetTables(tableNames ...string) ([]dbi.Table, error) { + dialect := od.dc.GetDialect() names := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string { - return fmt.Sprintf("'%s'", meta.RemoveQuote(val)) + return fmt.Sprintf("'%s'", dialect.RemoveQuote(val)) }), ",") var res []map[string]any @@ -92,10 +92,10 @@ func (od *OracleMetaData) GetTables(tableNames ...string) ([]dbi.Table, error) { } // 获取列元信息, 如列名等 -func (od *OracleMetaData) GetColumns(tableNames ...string) ([]dbi.Column, error) { - meta := od.dc.GetMetaData() +func (od *OracleMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error) { + dialect := od.dc.GetDialect() tableName := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string { - return fmt.Sprintf("'%s'", meta.RemoveQuote(val)) + return fmt.Sprintf("'%s'", dialect.RemoveQuote(val)) }), ",") // 如果表数量超过了1000,需要分批查询 @@ -121,7 +121,7 @@ func (od *OracleMetaData) GetColumns(tableNames ...string) ([]dbi.Column, error) return nil, err } - columnHelper := meta.GetColumnHelper() + columnHelper := dialect.GetColumnHelper() columns := make([]dbi.Column, 0) for _, re := range res { column := dbi.Column{ @@ -144,7 +144,7 @@ func (od *OracleMetaData) GetColumns(tableNames ...string) ([]dbi.Column, error) return columns, nil } -func (od *OracleMetaData) GetPrimaryKey(tablename string) (string, error) { +func (od *OracleMetadata) GetPrimaryKey(tablename string) (string, error) { columns, err := od.GetColumns(tablename) if err != nil { return "", err @@ -162,7 +162,7 @@ func (od *OracleMetaData) GetPrimaryKey(tablename string) (string, error) { } // 获取表索引信息 -func (od *OracleMetaData) GetTableIndex(tableName string) ([]dbi.Index, error) { +func (od *OracleMetadata) GetTableIndex(tableName string) ([]dbi.Index, error) { _, res, err := od.dc.Query(fmt.Sprintf(dbi.GetLocalSql(ORACLE_META_FILE, ORACLE_INDEX_INFO_KEY), tableName)) if err != nil { return nil, err @@ -199,135 +199,9 @@ func (od *OracleMetaData) GetTableIndex(tableName string) ([]dbi.Index, error) { return result, nil } -// 获取建索引ddl -func (od *OracleMetaData) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Table) []string { - - meta := od.dc.GetMetaData() - sqls := make([]string, 0) - comments := make([]string, 0) - - for _, index := range indexs { - unique := "" - if index.IsUnique { - unique = "unique" - } - - // 取出列名,添加引号 - cols := strings.Split(index.ColumnName, ",") - colNames := make([]string, len(cols)) - for i, name := range cols { - colNames[i] = meta.QuoteIdentifier(name) - } - - sqls = append(sqls, fmt.Sprintf("CREATE %s INDEX %s ON %s(%s)", unique, meta.QuoteIdentifier(index.IndexName), meta.QuoteIdentifier(tableInfo.TableName), strings.Join(colNames, ","))) - } - - sqlArr := make([]string, 0) - - sqlArr = append(sqlArr, sqls...) - - if len(comments) > 0 { - sqlArr = append(sqlArr, comments...) - } - - return sqlArr -} - -func (od *OracleMetaData) genColumnBasicSql(column dbi.Column) string { - meta := od.dc.GetMetaData() - colName := meta.QuoteIdentifier(column.ColumnName) - - if column.IsIdentity { - // 如果是自增,不需要设置默认值和空值,自增列数据类型必须是number - return fmt.Sprintf(" %s NUMBER generated by default as IDENTITY", colName) - } - - nullAble := "" - if !column.Nullable { - nullAble = " NOT NULL" - } - - defVal := "" - if column.ColumnDefault != "" { - defVal = fmt.Sprintf(" DEFAULT %v", column.ColumnDefault) - } - - columnSql := fmt.Sprintf(" %s %s%s%s", colName, column.GetColumnType(), defVal, nullAble) - return columnSql -} - // 获取建表ddl -func (od *OracleMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table, dropBeforeCreate bool) []string { - meta := od.dc.GetMetaData() - quoteTableName := meta.QuoteIdentifier(tableInfo.TableName) - sqlArr := make([]string, 0) - - if dropBeforeCreate { - dropSqlTmp := ` -declare - num number; -begin - select count(1) into num from user_tables where table_name = '%s' and owner = (SELECT sys_context('USERENV', 'CURRENT_SCHEMA') FROM dual) ; - if num > 0 then - execute immediate 'drop table "%s"' ; - end if; -end` - sqlArr = append(sqlArr, fmt.Sprintf(dropSqlTmp, tableInfo.TableName, tableInfo.TableName)) - } - - // 组装建表语句 - createSql := fmt.Sprintf("CREATE TABLE %s ( \n", quoteTableName) - fields := make([]string, 0) - pks := make([]string, 0) - columnComments := make([]string, 0) - // 把通用类型转换为达梦类型 - for _, column := range columns { - if column.IsPrimaryKey { - pks = append(pks, meta.QuoteIdentifier(column.ColumnName)) - } - fields = append(fields, od.genColumnBasicSql(column)) - // 防止注释内含有特殊字符串导致sql出错 - if column.ColumnComment != "" { - comment := meta.QuoteEscape(column.ColumnComment) - columnComments = append(columnComments, fmt.Sprintf("COMMENT ON COLUMN %s.%s IS '%s'", quoteTableName, meta.QuoteIdentifier(column.ColumnName), comment)) - } - } - - // 建表 - createSql += strings.Join(fields, ",\n") - if len(pks) > 0 { - createSql += fmt.Sprintf(", \nPRIMARY KEY (%s)", strings.Join(pks, ",")) - } - createSql += "\n)" - sqlArr = append(sqlArr, createSql) - - // 表注释 - tableCommentSql := "" - if tableInfo.TableComment != "" { - tableCommentSql = fmt.Sprintf("COMMENT ON TABLE %s is '%s'", meta.QuoteIdentifier(tableInfo.TableName), meta.QuoteEscape(tableInfo.TableComment)) - sqlArr = append(sqlArr, tableCommentSql) - } - - // 列注释 - if len(columnComments) > 0 { - sqlArr = append(sqlArr, columnComments...) - } - otherSql := od.GenerateTableOtherDDL(tableInfo, quoteTableName, columns) - if len(otherSql) > 0 { - sqlArr = append(sqlArr, otherSql...) - } - - return sqlArr -} - -// 11g及以下版本会设置自增序列 -func (od *OracleMetaData) GenerateTableOtherDDL(tableInfo dbi.Table, quoteTableName string, columns []dbi.Column) []string { - return nil -} - -// 获取建表ddl -func (od *OracleMetaData) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) { - +func (od *OracleMetadata) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) { + dialect := od.dc.GetDialect() // 1.获取表信息 tbs, err := od.GetTables(tableName) tableInfo := &dbi.Table{} @@ -344,7 +218,7 @@ func (od *OracleMetaData) GetTableDDL(tableName string, dropBeforeCreate bool) ( logx.Errorf("获取列信息失败, %s", tableName) return "", err } - tableDDLArr := od.GenerateTableDDL(columns, *tableInfo, dropBeforeCreate) + tableDDLArr := dialect.GenerateTableDDL(columns, *tableInfo, dropBeforeCreate) // 3.获取索引信息 indexs, err := od.GetTableIndex(tableName) if err != nil { @@ -352,12 +226,12 @@ func (od *OracleMetaData) GetTableDDL(tableName string, dropBeforeCreate bool) ( return "", err } // 组装返回 - tableDDLArr = append(tableDDLArr, od.GenerateIndexDDL(indexs, *tableInfo)...) + tableDDLArr = append(tableDDLArr, dialect.GenerateIndexDDL(indexs, *tableInfo)...) return strings.Join(tableDDLArr, ";\n"), nil } // 获取DM当前连接的库可访问的schemaNames -func (od *OracleMetaData) GetSchemas() ([]string, error) { +func (od *OracleMetadata) GetSchemas() ([]string, error) { sql := dbi.GetLocalSql(ORACLE_META_FILE, ORACLE_DB_SCHEMAS) _, res, err := od.dc.Query(sql) if err != nil { @@ -369,11 +243,3 @@ func (od *OracleMetaData) GetSchemas() ([]string, error) { } return schemaNames, nil } - -func (od *OracleMetaData) GetDataHelper() dbi.DataHelper { - return new(DataHelper) -} - -func (od *OracleMetaData) GetColumnHelper() dbi.ColumnHelper { - return new(ColumnHelper) -} diff --git a/server/internal/db/dbm/oracle/metadata11.go b/server/internal/db/dbm/oracle/metadata11.go index f50508a7..6c0747bf 100644 --- a/server/internal/db/dbm/oracle/metadata11.go +++ b/server/internal/db/dbm/oracle/metadata11.go @@ -13,15 +13,15 @@ const ( ORACLE11_COLUMN_MA_KEY = "ORACLE11_COLUMN_MA" ) -type OracleMetaData11 struct { - OracleMetaData +type OracleMetadata11 struct { + OracleMetadata } // 获取列元信息, 如列名等 -func (od *OracleMetaData11) GetColumns(tableNames ...string) ([]dbi.Column, error) { - meta := od.dc.GetMetaData() +func (od *OracleMetadata11) GetColumns(tableNames ...string) ([]dbi.Column, error) { + dialect := od.dc.GetDialect() tableName := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string { - return fmt.Sprintf("'%s'", meta.RemoveQuote(val)) + return fmt.Sprintf("'%s'", dialect.RemoveQuote(val)) }), ",") // 如果表数量超过了1000,需要分批查询 @@ -47,7 +47,7 @@ func (od *OracleMetaData11) GetColumns(tableNames ...string) ([]dbi.Column, erro return nil, err } - columnHelper := meta.GetColumnHelper() + columnHelper := dialect.GetColumnHelper() columns := make([]dbi.Column, 0) for _, re := range res { column := dbi.Column{ @@ -70,9 +70,9 @@ func (od *OracleMetaData11) GetColumns(tableNames ...string) ([]dbi.Column, erro return columns, nil } -func (od *OracleMetaData11) genColumnBasicSql(column dbi.Column) string { - meta := od.dc.GetMetaData() - colName := meta.QuoteIdentifier(column.ColumnName) +func (od *OracleMetadata11) genColumnBasicSql(column dbi.Column) string { + dialect := od.dc.GetDialect() + colName := dialect.QuoteIdentifier(column.ColumnName) if column.IsIdentity { // 11g以前的版本 如果是自增,自增列数据类型必须是number,不需要设置默认值和空值,建表后设置自增序列 @@ -94,7 +94,7 @@ func (od *OracleMetaData11) genColumnBasicSql(column dbi.Column) string { } // 11g及以下版本会设置自增序列和触发器 -func (od *OracleMetaData11) GenerateTableOtherDDL(tableInfo dbi.Table, quoteTableName string, columns []dbi.Column) []string { +func (od *OracleMetadata11) GenerateTableOtherDDL(tableInfo dbi.Table, quoteTableName string, columns []dbi.Column) []string { result := make([]string, 0) for _, col := range columns { if col.IsIdentity { diff --git a/server/internal/db/dbm/postgres/dialect.go b/server/internal/db/dbm/postgres/dialect.go index 47cccea1..2e506a89 100644 --- a/server/internal/db/dbm/postgres/dialect.go +++ b/server/internal/db/dbm/postgres/dialect.go @@ -4,8 +4,6 @@ import ( "database/sql" "fmt" "mayfly-go/internal/db/dbm/dbi" - "mayfly-go/internal/db/dbm/sqlparser" - "mayfly-go/internal/db/dbm/sqlparser/pgsql" "mayfly-go/pkg/utils/anyx" "mayfly-go/pkg/utils/collx" "strings" @@ -52,7 +50,7 @@ func (pd *PgsqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []stri suffix = pd.pgsqlOnDuplicateStrategySql(duplicateStrategy, tableName, columns) } - sqlStr := fmt.Sprintf("insert into %s (%s) values %s %s", pd.dc.GetMetaData().QuoteIdentifier(tableName), strings.Join(columns, ","), strings.Join(placeholders, ", "), suffix) + sqlStr := fmt.Sprintf("insert into %s (%s) values %s %s", pd.QuoteIdentifier(tableName), strings.Join(columns, ","), strings.Join(placeholders, ", "), suffix) // 执行批量insert sql return pd.dc.TxExec(tx, sqlStr, args...) @@ -86,7 +84,7 @@ func (pd *PgsqlDialect) pgsqlOnDuplicateStrategySql(duplicateStrategy int, table // 高斯db唯一键冲突策略,使用ON DUPLICATE KEY UPDATE 参考:https://support.huaweicloud.com/distributed-devg-v3-gaussdb/gaussdb-12-0607.html#ZH-CN_TOPIC_0000001633948138 func (pd *PgsqlDialect) gaussOnDuplicateStrategySql(duplicateStrategy int, tableName string, columns []string) string { suffix := "" - metadata := pd.dc.GetMetaData() + metadata := pd.dc.GetMetadata() if duplicateStrategy == dbi.DuplicateStrategyIgnore { suffix = " \n ON DUPLICATE KEY UPDATE NOTHING" } else if duplicateStrategy == dbi.DuplicateStrategyUpdate { @@ -110,7 +108,7 @@ func (pd *PgsqlDialect) gaussOnDuplicateStrategySql(duplicateStrategy int, table suffix = " \n ON DUPLICATE KEY UPDATE " for i, col := range columns { // ON DUPLICATE KEY UPDATE语句不支持更新唯一键字段,所以得去掉 - if !collx.ArrayContains(uniqueColumns, metadata.RemoveQuote(strings.ToLower(col))) { + if !collx.ArrayContains(uniqueColumns, pd.RemoveQuote(strings.ToLower(col))) { suffix += fmt.Sprintf("%s = excluded.%s", col, col) if i < len(columns)-1 { suffix += ", " @@ -178,17 +176,101 @@ func (pd *PgsqlDialect) CopyTable(copy *dbi.DbCopyTable) error { return err } -func (pd *PgsqlDialect) CreateTable(commonColumns []dbi.Column, tableInfo dbi.Table, dropOldTable bool) (int, error) { - meta := pd.dc.GetMetaData() - sqlArr := meta.GenerateTableDDL(commonColumns, tableInfo, dropOldTable) - _, err := pd.dc.Exec(strings.Join(sqlArr, ";")) - return len(sqlArr), err +func (pd *PgsqlDialect) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table, dropBeforeCreate bool) []string { + quoteTableName := pd.QuoteIdentifier(tableInfo.TableName) + + sqlArr := make([]string, 0) + if dropBeforeCreate { + sqlArr = append(sqlArr, fmt.Sprintf("DROP TABLE IF EXISTS %s", quoteTableName)) + } + // 组装建表语句 + createSql := fmt.Sprintf("CREATE TABLE %s (\n", quoteTableName) + fields := make([]string, 0) + pks := make([]string, 0) + columnComments := make([]string, 0) + commentTmp := "comment on column %s.%s is '%s'" + + for _, column := range columns { + if column.IsPrimaryKey { + pks = append(pks, pd.QuoteIdentifier(column.ColumnName)) + } + + fields = append(fields, pd.genColumnBasicSql(column)) + + // 防止注释内含有特殊字符串导致sql出错 + if column.ColumnComment != "" { + comment := pd.QuoteEscape(column.ColumnComment) + columnComments = append(columnComments, fmt.Sprintf(commentTmp, quoteTableName, pd.QuoteIdentifier(column.ColumnName), comment)) + } + } + + createSql += strings.Join(fields, ",\n") + if len(pks) > 0 { + createSql += fmt.Sprintf(", \nPRIMARY KEY (%s)", strings.Join(pks, ",")) + } + createSql += "\n)" + + tableCommentSql := "" + if tableInfo.TableComment != "" { + commentTmp := "comment on table %s is '%s'" + tableCommentSql = fmt.Sprintf(commentTmp, quoteTableName, pd.QuoteEscape(tableInfo.TableComment)) + } + + // create + sqlArr = append(sqlArr, createSql) + + // table comment + if tableCommentSql != "" { + sqlArr = append(sqlArr, tableCommentSql) + } + // column comment + if len(columnComments) > 0 { + sqlArr = append(sqlArr, columnComments...) + } + + return sqlArr } -func (pd *PgsqlDialect) CreateIndex(tableInfo dbi.Table, indexs []dbi.Index) error { - sqlArr := pd.dc.GetMetaData().GenerateIndexDDL(indexs, tableInfo) - _, err := pd.dc.Exec(strings.Join(sqlArr, ";")) - return err +func (pd *PgsqlDialect) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Table) []string { + creates := make([]string, 0) + drops := make([]string, 0) + comments := make([]string, 0) + for _, index := range indexs { + unique := "" + if index.IsUnique { + unique = "unique" + } + + // 如果索引名存在,先删除索引 + drops = append(drops, fmt.Sprintf("drop index if exists %s.%s", pd.dc.Info.CurrentSchema(), index.IndexName)) + + // 取出列名,添加引号 + cols := strings.Split(index.ColumnName, ",") + colNames := make([]string, len(cols)) + for i, name := range cols { + colNames[i] = pd.QuoteIdentifier(name) + } + // 创建索引 + creates = append(creates, fmt.Sprintf("CREATE %s INDEX %s on %s.%s(%s)", unique, pd.QuoteIdentifier(index.IndexName), pd.QuoteIdentifier(pd.dc.Info.CurrentSchema()), pd.QuoteIdentifier(tableInfo.TableName), strings.Join(colNames, ","))) + if index.IndexComment != "" { + comment := pd.QuoteEscape(index.IndexComment) + comments = append(comments, fmt.Sprintf("COMMENT ON INDEX %s.%s IS '%s'", pd.dc.Info.CurrentSchema(), index.IndexName, comment)) + } + } + + sqlArr := make([]string, 0) + + if len(drops) > 0 { + sqlArr = append(sqlArr, drops...) + } + + if len(creates) > 0 { + sqlArr = append(sqlArr, creates...) + } + if len(comments) > 0 { + sqlArr = append(sqlArr, comments...) + } + return sqlArr } func (pd *PgsqlDialect) UpdateSequence(tableName string, columns []dbi.Column) { @@ -199,6 +281,77 @@ func (pd *PgsqlDialect) UpdateSequence(tableName string, columns []dbi.Column) { } } -func (pd *PgsqlDialect) GetSQLParser() sqlparser.SqlParser { - return new(pgsql.PgsqlParser) +func (pd *PgsqlDialect) GetDataHelper() dbi.DataHelper { + return new(DataHelper) +} + +func (pd *PgsqlDialect) GetColumnHelper() dbi.ColumnHelper { + return new(ColumnHelper) +} + +func (pd *PgsqlDialect) GetDumpHelper() dbi.DumpHelper { + return new(DumpHelper) +} + +func (pd *PgsqlDialect) genColumnBasicSql(column dbi.Column) string { + colName := pd.QuoteIdentifier(column.ColumnName) + dataType := string(column.DataType) + + // 如果数据类型是数字,则去掉长度 + if collx.ArrayAnyMatches([]string{"int"}, strings.ToLower(dataType)) { + column.NumPrecision = 0 + column.CharMaxLength = 0 + } + + // 如果是自增类型,需要转换为serial + if column.IsIdentity { + if dataType == "int4" { + column.DataType = "serial" + } else if dataType == "int2" { + column.DataType = "smallserial" + } else if dataType == "int8" { + column.DataType = "bigserial" + } else { + column.DataType = "bigserial" + } + + return fmt.Sprintf(" %s %s NOT NULL", colName, column.GetColumnType()) + } + + nullAble := "" + if !column.Nullable { + nullAble = " NOT NULL" + } + + defVal := "" // 默认值需要判断引号,如函数是不需要引号的 // 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值 + if column.ColumnDefault != "" && !strings.Contains(column.ColumnDefault, "(") { + mark := false + // 哪些字段类型默认值需要加引号 + if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, dataType) { + // 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号 + if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) && + collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(column.ColumnDefault)) { + mark = false + } else { + mark = true + } + } + // 如果数据类型是日期时间,则写死默认值函数 + if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) { + column.ColumnDefault = "CURRENT_TIMESTAMP" + } + + if mark { + defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault) + } else { + defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault) + } + } + + // 如果是varchar,长度翻倍,防止报错 + if collx.ArrayAnyMatches([]string{"char"}, strings.ToLower(dataType)) { + column.CharMaxLength = column.CharMaxLength * 2 + } + columnSql := fmt.Sprintf(" %s %s%s%s", colName, column.GetColumnType(), nullAble, defVal) + return columnSql } diff --git a/server/internal/db/dbm/postgres/meta.go b/server/internal/db/dbm/postgres/meta.go index 0462a15f..89d2a6c2 100644 --- a/server/internal/db/dbm/postgres/meta.go +++ b/server/internal/db/dbm/postgres/meta.go @@ -85,8 +85,8 @@ func (pm *Meta) GetDialect(conn *dbi.DbConn) dbi.Dialect { return &PgsqlDialect{dc: conn} } -func (pm *Meta) GetMetaData(conn *dbi.DbConn) *dbi.MetaDataX { - return dbi.NewMetaDataX(&PgsqlMetaData{dc: conn}) +func (pm *Meta) GetMetadata(conn *dbi.DbConn) dbi.Metadata { + return &PgsqlMetadata{dc: conn} } // pgsql dialer diff --git a/server/internal/db/dbm/postgres/metadata.go b/server/internal/db/dbm/postgres/metadata.go index 17abd11e..c5f97eff 100644 --- a/server/internal/db/dbm/postgres/metadata.go +++ b/server/internal/db/dbm/postgres/metadata.go @@ -2,7 +2,6 @@ package postgres import ( "fmt" - "io" "mayfly-go/internal/db/dbm/dbi" "mayfly-go/pkg/errorx" "mayfly-go/pkg/logx" @@ -21,13 +20,13 @@ const ( PGSQL_COLUMN_MA_KEY = "PGSQL_COLUMN_MA" ) -type PgsqlMetaData struct { - dbi.DefaultMetaData +type PgsqlMetadata struct { + dbi.DefaultMetadata dc *dbi.DbConn } -func (pd *PgsqlMetaData) GetDbServer() (*dbi.DbServer, error) { +func (pd *PgsqlMetadata) GetDbServer() (*dbi.DbServer, error) { _, res, err := pd.dc.Query("SELECT version() as server_version") if err != nil { return nil, err @@ -38,7 +37,7 @@ func (pd *PgsqlMetaData) GetDbServer() (*dbi.DbServer, error) { return ds, nil } -func (pd *PgsqlMetaData) GetDbNames() ([]string, error) { +func (pd *PgsqlMetadata) GetDbNames() ([]string, error) { _, res, err := pd.dc.Query("SELECT datname AS dbname FROM pg_database WHERE datistemplate = false AND has_database_privilege(datname, 'CONNECT')") if err != nil { return nil, err @@ -52,10 +51,10 @@ func (pd *PgsqlMetaData) GetDbNames() ([]string, error) { return databases, nil } -func (pd *PgsqlMetaData) GetTables(tableNames ...string) ([]dbi.Table, error) { - meta := pd.dc.GetMetaData() +func (pd *PgsqlMetadata) GetTables(tableNames ...string) ([]dbi.Table, error) { + dialect := pd.dc.GetDialect() names := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string { - return fmt.Sprintf("'%s'", meta.RemoveQuote(val)) + return fmt.Sprintf("'%s'", dialect.RemoveQuote(val)) }), ",") var res []map[string]any @@ -86,10 +85,10 @@ func (pd *PgsqlMetaData) GetTables(tableNames ...string) ([]dbi.Table, error) { } // 获取列元信息, 如列名等 -func (pd *PgsqlMetaData) GetColumns(tableNames ...string) ([]dbi.Column, error) { - meta := pd.dc.GetMetaData() +func (pd *PgsqlMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error) { + dialect := pd.dc.GetDialect() tableName := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string { - return fmt.Sprintf("'%s'", meta.RemoveQuote(val)) + return fmt.Sprintf("'%s'", dialect.RemoveQuote(val)) }), ",") _, res, err := pd.dc.Query(fmt.Sprintf(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_COLUMN_MA_KEY), tableName)) @@ -97,7 +96,7 @@ func (pd *PgsqlMetaData) GetColumns(tableNames ...string) ([]dbi.Column, error) return nil, err } - columnHelper := meta.GetColumnHelper() + columnHelper := dialect.GetColumnHelper() columns := make([]dbi.Column, 0) for _, re := range res { column := dbi.Column{ @@ -119,7 +118,7 @@ func (pd *PgsqlMetaData) GetColumns(tableNames ...string) ([]dbi.Column, error) return columns, nil } -func (pd *PgsqlMetaData) GetPrimaryKey(tablename string) (string, error) { +func (pd *PgsqlMetadata) GetPrimaryKey(tablename string) (string, error) { columns, err := pd.GetColumns(tablename) if err != nil { return "", err @@ -137,7 +136,7 @@ func (pd *PgsqlMetaData) GetPrimaryKey(tablename string) (string, error) { } // 获取表索引信息 -func (pd *PgsqlMetaData) GetTableIndex(tableName string) ([]dbi.Index, error) { +func (pd *PgsqlMetadata) GetTableIndex(tableName string) ([]dbi.Index, error) { _, res, err := pd.dc.Query(fmt.Sprintf(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_INDEX_INFO_KEY), tableName)) if err != nil { return nil, err @@ -174,172 +173,8 @@ func (pd *PgsqlMetaData) GetTableIndex(tableName string) ([]dbi.Index, error) { return result, nil } -func (pd *PgsqlMetaData) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Table) []string { - meta := pd.dc.GetMetaData() - creates := make([]string, 0) - drops := make([]string, 0) - comments := make([]string, 0) - for _, index := range indexs { - unique := "" - if index.IsUnique { - unique = "unique" - } - - // 如果索引名存在,先删除索引 - drops = append(drops, fmt.Sprintf("drop index if exists %s.%s", pd.dc.Info.CurrentSchema(), index.IndexName)) - - // 取出列名,添加引号 - cols := strings.Split(index.ColumnName, ",") - colNames := make([]string, len(cols)) - for i, name := range cols { - colNames[i] = meta.QuoteIdentifier(name) - } - // 创建索引 - creates = append(creates, fmt.Sprintf("CREATE %s INDEX %s on %s.%s(%s)", unique, meta.QuoteIdentifier(index.IndexName), meta.QuoteIdentifier(pd.dc.Info.CurrentSchema()), meta.QuoteIdentifier(tableInfo.TableName), strings.Join(colNames, ","))) - if index.IndexComment != "" { - comment := meta.QuoteEscape(index.IndexComment) - comments = append(comments, fmt.Sprintf("COMMENT ON INDEX %s.%s IS '%s'", pd.dc.Info.CurrentSchema(), index.IndexName, comment)) - } - } - - sqlArr := make([]string, 0) - - if len(drops) > 0 { - sqlArr = append(sqlArr, drops...) - } - - if len(creates) > 0 { - sqlArr = append(sqlArr, creates...) - } - if len(comments) > 0 { - sqlArr = append(sqlArr, comments...) - } - return sqlArr -} - -func (pd *PgsqlMetaData) genColumnBasicSql(column dbi.Column) string { - meta := pd.dc.GetMetaData() - colName := meta.QuoteIdentifier(column.ColumnName) - dataType := string(column.DataType) - - // 如果数据类型是数字,则去掉长度 - if collx.ArrayAnyMatches([]string{"int"}, strings.ToLower(dataType)) { - column.NumPrecision = 0 - column.CharMaxLength = 0 - } - - // 如果是自增类型,需要转换为serial - if column.IsIdentity { - if dataType == "int4" { - column.DataType = "serial" - } else if dataType == "int2" { - column.DataType = "smallserial" - } else if dataType == "int8" { - column.DataType = "bigserial" - } else { - column.DataType = "bigserial" - } - - return fmt.Sprintf(" %s %s NOT NULL", colName, column.GetColumnType()) - } - - nullAble := "" - if !column.Nullable { - nullAble = " NOT NULL" - } - - defVal := "" // 默认值需要判断引号,如函数是不需要引号的 // 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值 - if column.ColumnDefault != "" && !strings.Contains(column.ColumnDefault, "(") { - mark := false - // 哪些字段类型默认值需要加引号 - if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, dataType) { - // 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号 - if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) && - collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(column.ColumnDefault)) { - mark = false - } else { - mark = true - } - } - // 如果数据类型是日期时间,则写死默认值函数 - if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) { - column.ColumnDefault = "CURRENT_TIMESTAMP" - } - - if mark { - defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault) - } else { - defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault) - } - } - - // 如果是varchar,长度翻倍,防止报错 - if collx.ArrayAnyMatches([]string{"char"}, strings.ToLower(dataType)) { - column.CharMaxLength = column.CharMaxLength * 2 - } - columnSql := fmt.Sprintf(" %s %s%s%s", colName, column.GetColumnType(), nullAble, defVal) - return columnSql -} - -func (pd *PgsqlMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table, dropBeforeCreate bool) []string { - meta := pd.dc.GetMetaData() - quoteTableName := meta.QuoteIdentifier(tableInfo.TableName) - - sqlArr := make([]string, 0) - if dropBeforeCreate { - sqlArr = append(sqlArr, fmt.Sprintf("DROP TABLE IF EXISTS %s", quoteTableName)) - } - // 组装建表语句 - createSql := fmt.Sprintf("CREATE TABLE %s (\n", quoteTableName) - fields := make([]string, 0) - pks := make([]string, 0) - columnComments := make([]string, 0) - commentTmp := "comment on column %s.%s is '%s'" - - for _, column := range columns { - if column.IsPrimaryKey { - pks = append(pks, meta.QuoteIdentifier(column.ColumnName)) - } - - fields = append(fields, pd.genColumnBasicSql(column)) - - // 防止注释内含有特殊字符串导致sql出错 - if column.ColumnComment != "" { - comment := meta.QuoteEscape(column.ColumnComment) - columnComments = append(columnComments, fmt.Sprintf(commentTmp, quoteTableName, meta.QuoteIdentifier(column.ColumnName), comment)) - } - } - - createSql += strings.Join(fields, ",\n") - if len(pks) > 0 { - createSql += fmt.Sprintf(", \nPRIMARY KEY (%s)", strings.Join(pks, ",")) - } - createSql += "\n)" - - tableCommentSql := "" - if tableInfo.TableComment != "" { - commentTmp := "comment on table %s is '%s'" - tableCommentSql = fmt.Sprintf(commentTmp, quoteTableName, meta.QuoteEscape(tableInfo.TableComment)) - } - - // create - sqlArr = append(sqlArr, createSql) - - // table comment - if tableCommentSql != "" { - sqlArr = append(sqlArr, tableCommentSql) - } - // column comment - if len(columnComments) > 0 { - sqlArr = append(sqlArr, columnComments...) - } - - return sqlArr -} - // 获取建表ddl -func (pd *PgsqlMetaData) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) { - +func (pd *PgsqlMetadata) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) { // 1.获取表信息 tbs, err := pd.GetTables(tableName) tableInfo := &dbi.Table{} @@ -356,7 +191,8 @@ func (pd *PgsqlMetaData) GetTableDDL(tableName string, dropBeforeCreate bool) (s logx.Errorf("获取列信息失败, %s", tableName) return "", err } - tableDDLArr := pd.GenerateTableDDL(columns, *tableInfo, dropBeforeCreate) + dialect := pd.dc.GetDialect() + tableDDLArr := dialect.GenerateTableDDL(columns, *tableInfo, dropBeforeCreate) // 3.获取索引信息 indexs, err := pd.GetTableIndex(tableName) if err != nil { @@ -364,12 +200,12 @@ func (pd *PgsqlMetaData) GetTableDDL(tableName string, dropBeforeCreate bool) (s return "", err } // 组装返回 - tableDDLArr = append(tableDDLArr, pd.GenerateIndexDDL(indexs, *tableInfo)...) + tableDDLArr = append(tableDDLArr, dialect.GenerateIndexDDL(indexs, *tableInfo)...) return strings.Join(tableDDLArr, ";\n"), nil } // 获取pgsql当前连接的库可访问的schemaNames -func (pd *PgsqlMetaData) GetSchemas() ([]string, error) { +func (pd *PgsqlMetadata) GetSchemas() ([]string, error) { sql := dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_DB_SCHEMAS) _, res, err := pd.dc.Query(sql) if err != nil { @@ -382,7 +218,7 @@ func (pd *PgsqlMetaData) GetSchemas() ([]string, error) { return schemaNames, nil } -func (pd *PgsqlMetaData) DefaultDb() string { +func (pd *PgsqlMetadata) GetDefaultDb() string { switch pd.dc.Info.Type { case dbi.DbTypePostgres, dbi.DbTypeGauss: return "postgres" @@ -394,28 +230,3 @@ func (pd *PgsqlMetaData) DefaultDb() string { return "" } } - -func (pd *PgsqlMetaData) AfterDumpInsert(writer io.Writer, tableName string, columns []dbi.Column) { - - // 设置自增序列当前值 - for _, column := range columns { - if column.IsIdentity { - seq := fmt.Sprintf("SELECT setval('%s_%s_seq', (SELECT max(%s) FROM %s));\n", tableName, column.ColumnName, column.ColumnName, tableName) - writer.Write([]byte(seq)) - } - } - - writer.Write([]byte("COMMIT;\n")) -} - -func (pd *PgsqlMetaData) GetDataHelper() dbi.DataHelper { - return new(DataHelper) -} - -func (pd *PgsqlMetaData) GetColumnHelper() dbi.ColumnHelper { - return new(ColumnHelper) -} - -func (pd *PgsqlMetaData) GetDumpHelper() dbi.DumpHelper { - return new(DumpHelper) -} diff --git a/server/internal/db/dbm/sqlite/dialect.go b/server/internal/db/dbm/sqlite/dialect.go index 182048c8..7878696b 100644 --- a/server/internal/db/dbm/sqlite/dialect.go +++ b/server/internal/db/dbm/sqlite/dialect.go @@ -4,6 +4,7 @@ import ( "database/sql" "fmt" "mayfly-go/internal/db/dbm/dbi" + "mayfly-go/pkg/utils/collx" "strings" "time" ) @@ -35,7 +36,7 @@ func (sd *SqliteDialect) BatchInsert(tx *sql.Tx, tableName string, columns []str prefix = "insert or replace into" } - sqlStr := fmt.Sprintf("%s %s (%s) values %s", prefix, sd.dc.GetMetaData().QuoteIdentifier(tableName), strings.Join(columns, ","), placeholder) + sqlStr := fmt.Sprintf("%s %s (%s) values %s", prefix, sd.QuoteIdentifier(tableName), strings.Join(columns, ","), placeholder) // 把二维数组转为一维数组 var args []any @@ -55,7 +56,7 @@ func (sd *SqliteDialect) CopyTable(copy *dbi.DbCopyTable) error { // 生成新表名,为老表明+_copy_时间戳 newTableName := tableName + "_copy_" + time.Now().Format("20060102150405") - ddl, err := sd.dc.GetMetaData().GetTableDDL(tableName, false) + ddl, err := sd.dc.GetMetadata().GetTableDDL(tableName, false) if err != nil { return err } @@ -82,24 +83,101 @@ func (sd *SqliteDialect) CopyTable(copy *dbi.DbCopyTable) error { return err } -func (sd *SqliteDialect) CreateTable(columns []dbi.Column, tableInfo dbi.Table, dropOldTable bool) (int, error) { - sqlArr := sd.dc.GetMetaData().GenerateTableDDL(columns, tableInfo, dropOldTable) - for _, sqlStr := range sqlArr { - _, err := sd.dc.Exec(sqlStr) - if err != nil { - return 0, err - } +// 获取建表ddl +func (sd *SqliteDialect) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table, dropBeforeCreate bool) []string { + sqlArr := make([]string, 0) + tbName := sd.QuoteIdentifier(tableInfo.TableName) + if dropBeforeCreate { + sqlArr = append(sqlArr, fmt.Sprintf("DROP TABLE IF EXISTS %s", tbName)) } - return len(sqlArr), nil + // 组装建表语句 + createSql := fmt.Sprintf("CREATE TABLE %s (\n", tbName) + fields := make([]string, 0) + + // 把通用类型转换为达梦类型 + for _, column := range columns { + fields = append(fields, sd.genColumnBasicSql(column)) + } + createSql += strings.Join(fields, ",\n") + createSql += "\n)" + + sqlArr = append(sqlArr, createSql) + + return sqlArr } -func (sd *SqliteDialect) CreateIndex(tableInfo dbi.Table, indexs []dbi.Index) error { - sqlArr := sd.dc.GetMetaData().GenerateIndexDDL(indexs, tableInfo) - for _, sqlStr := range sqlArr { - _, err := sd.dc.Exec(sqlStr) - if err != nil { - return err +func (sd *SqliteDialect) genColumnBasicSql(column dbi.Column) string { + incr := "" + if column.IsIdentity { + incr = " AUTOINCREMENT" + } + + nullAble := "" + if !column.Nullable { + nullAble = " NOT NULL" + } + + quoteColumnName := sd.QuoteIdentifier(column.ColumnName) + + // 如果是主键,则直接返回,不判断默认值 + if column.IsPrimaryKey { + return fmt.Sprintf(" %s integer PRIMARY KEY%s%s", quoteColumnName, incr, nullAble) + } + + defVal := "" // 默认值需要判断引号,如函数是不需要引号的 // 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值 + if column.ColumnDefault != "" && !strings.Contains(column.ColumnDefault, "(") { + // 哪些字段类型默认值需要加引号 + mark := false + if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, strings.ToLower(string(column.DataType))) { + // 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号 + if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(string(column.DataType))) && + collx.ArrayAnyMatches([]string{"DATE", "TIME"}, strings.ToUpper(column.ColumnDefault)) { + mark = false + } else { + mark = true + } + } + if mark { + defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault) + } else { + defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault) } } - return nil + + return fmt.Sprintf(" %s %s%s%s", quoteColumnName, column.GetColumnType(), nullAble, defVal) +} + +// 获取建索引ddl +func (sd *SqliteDialect) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Table) []string { + sqls := make([]string, 0) + for _, index := range indexs { + unique := "" + if index.IsUnique { + unique = "unique" + } + // 取出列名,添加引号 + cols := strings.Split(index.ColumnName, ",") + colNames := make([]string, len(cols)) + for i, name := range cols { + colNames[i] = sd.QuoteIdentifier(name) + } + // 创建前尝试删除 + sqls = append(sqls, fmt.Sprintf("DROP INDEX IF EXISTS \"%s\"", index.IndexName)) + + sqlTmp := "CREATE %s INDEX %s ON %s (%s) " + sqls = append(sqls, fmt.Sprintf(sqlTmp, unique, sd.QuoteIdentifier(index.IndexName), sd.QuoteIdentifier(tableInfo.TableName), strings.Join(colNames, ","))) + } + return sqls +} + +func (sd *SqliteDialect) GetDataHelper() dbi.DataHelper { + return new(DataHelper) +} + +func (sd *SqliteDialect) GetColumnHelper() dbi.ColumnHelper { + return new(ColumnHelper) +} + +func (sd *SqliteDialect) GetDumpHelper() dbi.DumpHelper { + return new(DumpHelper) } diff --git a/server/internal/db/dbm/sqlite/meta.go b/server/internal/db/dbm/sqlite/meta.go index 396a5f1a..5b604b04 100644 --- a/server/internal/db/dbm/sqlite/meta.go +++ b/server/internal/db/dbm/sqlite/meta.go @@ -33,6 +33,6 @@ func (sm *Meta) GetDialect(conn *dbi.DbConn) dbi.Dialect { return &SqliteDialect{dc: conn} } -func (sm *Meta) GetMetaData(conn *dbi.DbConn) *dbi.MetaDataX { - return dbi.NewMetaDataX(&SqliteMetaData{dc: conn}) +func (sm *Meta) GetMetadata(conn *dbi.DbConn) dbi.Metadata { + return &SqliteMetadata{dc: conn} } diff --git a/server/internal/db/dbm/sqlite/metadata.go b/server/internal/db/dbm/sqlite/metadata.go index dd9f1587..b7afbb78 100644 --- a/server/internal/db/dbm/sqlite/metadata.go +++ b/server/internal/db/dbm/sqlite/metadata.go @@ -19,13 +19,13 @@ const ( SQLITE_INDEX_INFO_KEY = "SQLITE_INDEX_INFO" ) -type SqliteMetaData struct { - dbi.DefaultMetaData +type SqliteMetadata struct { + dbi.DefaultMetadata dc *dbi.DbConn } -func (sd *SqliteMetaData) GetDbServer() (*dbi.DbServer, error) { +func (sd *SqliteMetadata) GetDbServer() (*dbi.DbServer, error) { _, res, err := sd.dc.Query("SELECT SQLITE_VERSION() as version") if err != nil { return nil, err @@ -36,7 +36,7 @@ func (sd *SqliteMetaData) GetDbServer() (*dbi.DbServer, error) { return ds, nil } -func (sd *SqliteMetaData) GetDbNames() ([]string, error) { +func (sd *SqliteMetadata) GetDbNames() ([]string, error) { databases := make([]string, 0) _, res, err := sd.dc.Query("PRAGMA database_list") if err != nil { @@ -50,9 +50,10 @@ func (sd *SqliteMetaData) GetDbNames() ([]string, error) { } // 获取表基础元信息, 如表名等 -func (sd *SqliteMetaData) GetTables(tableNames ...string) ([]dbi.Table, error) { +func (sd *SqliteMetadata) GetTables(tableNames ...string) ([]dbi.Table, error) { + dialect := sd.dc.GetDialect() names := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string { - return fmt.Sprintf("'%s'", dbi.RemoveQuote(sd, val)) + return fmt.Sprintf("'%s'", dialect.RemoveQuote(val)) }), ",") var res []map[string]any @@ -86,7 +87,7 @@ func (sd *SqliteMetaData) GetTables(tableNames ...string) ([]dbi.Table, error) { // 如 decimal(10,2) 提取decimal, 10 ,2 // 如:text 提取text,null,null // 如:varchar(100) 提取varchar, 100 -func (sd *SqliteMetaData) getDataTypes(dataType string) (string, string, string) { +func (sd *SqliteMetadata) getDataTypes(dataType string) (string, string, string) { matches := dataTypeRegexp.FindStringSubmatch(dataType) if len(matches) == 0 { return dataType, "", "" @@ -95,9 +96,9 @@ func (sd *SqliteMetaData) getDataTypes(dataType string) (string, string, string) } // 获取列元信息, 如列名等 -func (sd *SqliteMetaData) GetColumns(tableNames ...string) ([]dbi.Column, error) { +func (sd *SqliteMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error) { columns := make([]dbi.Column, 0) - columnHelper := sd.dc.GetMetaData().GetColumnHelper() + columnHelper := sd.dc.GetDialect().GetColumnHelper() for i := 0; i < len(tableNames); i++ { tableName := tableNames[i] _, res, err := sd.dc.Query(fmt.Sprintf("PRAGMA table_info(%s)", tableName)) @@ -142,7 +143,7 @@ func (sd *SqliteMetaData) GetColumns(tableNames ...string) ([]dbi.Column, error) return columns, nil } -func (sd *SqliteMetaData) GetPrimaryKey(tableName string) (string, error) { +func (sd *SqliteMetadata) GetPrimaryKey(tableName string) (string, error) { _, res, err := sd.dc.Query(fmt.Sprintf("PRAGMA table_info(%s)", tableName)) if err != nil { return "", err @@ -173,7 +174,7 @@ func extractIndexFields(indexSQL string) string { } // 获取表索引信息 -func (sd *SqliteMetaData) GetTableIndex(tableName string) ([]dbi.Index, error) { +func (sd *SqliteMetadata) GetTableIndex(tableName string) ([]dbi.Index, error) { _, res, err := sd.dc.Query(fmt.Sprintf(dbi.GetLocalSql(SQLITE_META_FILE, SQLITE_INDEX_INFO_KEY), tableName)) if err != nil { return nil, err @@ -198,96 +199,8 @@ func (sd *SqliteMetaData) GetTableIndex(tableName string) ([]dbi.Index, error) { return indexs, nil } -// 获取建索引ddl -func (sd *SqliteMetaData) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Table) []string { - meta := sd.dc.GetMetaData() - sqls := make([]string, 0) - for _, index := range indexs { - unique := "" - if index.IsUnique { - unique = "unique" - } - // 取出列名,添加引号 - cols := strings.Split(index.ColumnName, ",") - colNames := make([]string, len(cols)) - for i, name := range cols { - colNames[i] = meta.QuoteIdentifier(name) - } - // 创建前尝试删除 - sqls = append(sqls, fmt.Sprintf("DROP INDEX IF EXISTS \"%s\"", index.IndexName)) - - sqlTmp := "CREATE %s INDEX %s ON %s (%s) " - sqls = append(sqls, fmt.Sprintf(sqlTmp, unique, meta.QuoteIdentifier(index.IndexName), meta.QuoteIdentifier(tableInfo.TableName), strings.Join(colNames, ","))) - } - return sqls -} - -func (sd *SqliteMetaData) genColumnBasicSql(column dbi.Column) string { - incr := "" - if column.IsIdentity { - incr = " AUTOINCREMENT" - } - - nullAble := "" - if !column.Nullable { - nullAble = " NOT NULL" - } - - quoteColumnName := sd.dc.GetMetaData().QuoteIdentifier(column.ColumnName) - - // 如果是主键,则直接返回,不判断默认值 - if column.IsPrimaryKey { - return fmt.Sprintf(" %s integer PRIMARY KEY%s%s", quoteColumnName, incr, nullAble) - } - - defVal := "" // 默认值需要判断引号,如函数是不需要引号的 // 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值 - if column.ColumnDefault != "" && !strings.Contains(column.ColumnDefault, "(") { - // 哪些字段类型默认值需要加引号 - mark := false - if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, strings.ToLower(string(column.DataType))) { - // 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号 - if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(string(column.DataType))) && - collx.ArrayAnyMatches([]string{"DATE", "TIME"}, strings.ToUpper(column.ColumnDefault)) { - mark = false - } else { - mark = true - } - } - if mark { - defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault) - } else { - defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault) - } - } - - return fmt.Sprintf(" %s %s%s%s", quoteColumnName, column.GetColumnType(), nullAble, defVal) -} - // 获取建表ddl -func (sd *SqliteMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table, dropBeforeCreate bool) []string { - sqlArr := make([]string, 0) - tbName := sd.dc.GetMetaData().QuoteIdentifier(tableInfo.TableName) - if dropBeforeCreate { - sqlArr = append(sqlArr, fmt.Sprintf("DROP TABLE IF EXISTS %s", tbName)) - } - // 组装建表语句 - createSql := fmt.Sprintf("CREATE TABLE %s (\n", tbName) - fields := make([]string, 0) - - // 把通用类型转换为达梦类型 - for _, column := range columns { - fields = append(fields, sd.genColumnBasicSql(column)) - } - createSql += strings.Join(fields, ",\n") - createSql += "\n)" - - sqlArr = append(sqlArr, createSql) - - return sqlArr -} - -// 获取建表ddl -func (sd *SqliteMetaData) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) { +func (sd *SqliteMetadata) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) { var builder strings.Builder if dropBeforeCreate { @@ -306,18 +219,6 @@ func (sd *SqliteMetaData) GetTableDDL(tableName string, dropBeforeCreate bool) ( return builder.String(), nil } -func (sd *SqliteMetaData) GetSchemas() ([]string, error) { +func (sd *SqliteMetadata) GetSchemas() ([]string, error) { return nil, nil } - -func (sd *SqliteMetaData) GetDataHelper() dbi.DataHelper { - return new(DataHelper) -} - -func (sd *SqliteMetaData) GetColumnHelper() dbi.ColumnHelper { - return new(ColumnHelper) -} - -func (sd *SqliteMetaData) GetDumpHelper() dbi.DumpHelper { - return new(DumpHelper) -} diff --git a/server/resources/data/mayfly-go.sqlite b/server/resources/data/mayfly-go.sqlite index af6d0b13..616b864a 100644 Binary files a/server/resources/data/mayfly-go.sqlite and b/server/resources/data/mayfly-go.sqlite differ