diff --git a/mayfly_go_web/package.json b/mayfly_go_web/package.json index 8eca716d..046e8354 100644 --- a/mayfly_go_web/package.json +++ b/mayfly_go_web/package.json @@ -56,7 +56,7 @@ "prettier": "^3.2.5", "sass": "^1.69.0", "typescript": "^5.3.2", - "vite": "^5.1.5", + "vite": "^5.1.6", "vue-eslint-parser": "^9.4.2" }, "browserslist": [ diff --git a/server/internal/db/api/db.go b/server/internal/db/api/db.go index b1c06af5..f431c11c 100644 --- a/server/internal/db/api/db.go +++ b/server/internal/db/api/db.go @@ -113,11 +113,10 @@ func (d *Db) ExecSql(rc *req.Ctx) { DbConn: dbConn, } - // 比前端超时时间稍微快一点,可以提示到前端 ctx, cancel := context.WithTimeout(rc.MetaCtx, time.Duration(config.GetDbms().SqlExecTl)*time.Second) defer cancel() - sqls, err := sqlparser.SplitStatementToPieces(sql, sqlparser.WithDialect(dbConn.Info.Type.Dialect())) + sqls, err := sqlparser.SplitStatementToPieces(sql, sqlparser.WithDialect(dbConn.GetMetaData().SqlParserDialect())) biz.ErrIsNil(err, "SQL解析错误,请检查您的执行SQL") isMulti := len(sqls) > 1 var execResAll *application.DbSqlExecRes @@ -198,7 +197,7 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) { var sql string tokenizer := sqlparser.NewReaderTokenizer(file, - sqlparser.WithCacheInBuffer(), sqlparser.WithDialect(dbConn.Info.Type.Dialect())) + sqlparser.WithCacheInBuffer(), sqlparser.WithDialect(dbConn.GetMetaData().SqlParserDialect())) executedStatements := 0 progressId := stringx.Rand(32) @@ -327,12 +326,9 @@ func (d *Db) dumpDb(ctx context.Context, writer *gzipWriter, dbId uint64, dbName writer.WriteString(fmt.Sprintf("\n-- 导出数据库: %s ", dbName)) writer.WriteString("\n-- ----------------------------\n\n") - writer.WriteString(dbConn.Info.Type.StmtUseDatabase(dbName)) - writer.WriteString(dbConn.Info.Type.StmtSetForeignKeyChecks(false)) - - dbMeta := dbConn.GetDialect() + dbMeta := dbConn.GetMetaData() if len(tables) == 0 { - ti, err := dbMeta.GetMetaData().GetTables() + ti, err := dbMeta.GetTables() biz.ErrIsNil(err) tables = make([]string, len(ti)) for i, table := range ti { @@ -342,11 +338,11 @@ func (d *Db) dumpDb(ctx context.Context, writer *gzipWriter, dbId uint64, dbName for _, table := range tables { writer.TryFlush() - quotedTable := dbConn.Info.Type.QuoteIdentifier(table) + quotedTable := dbMeta.QuoteIdentifier(table) if needStruct { writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- 表结构: %s \n-- ----------------------------\n", table)) writer.WriteString(fmt.Sprintf("DROP TABLE IF EXISTS %s;\n", quotedTable)) - ddl, err := dbMeta.GetMetaData().GetTableDDL(table) + ddl, err := dbMeta.GetTableDDL(table) biz.ErrIsNil(err) writer.WriteString(ddl + "\n") } @@ -371,7 +367,7 @@ func (d *Db) dumpDb(ctx context.Context, writer *gzipWriter, dbId uint64, dbName } strValue, ok := value.(string) if ok { - strValue = dbConn.Info.Type.QuoteLiteral(strValue) + strValue = dbMeta.QuoteLiteral(strValue) values = append(values, strValue) } else { values = append(values, anyx.ToString(value)) @@ -382,11 +378,10 @@ func (d *Db) dumpDb(ctx context.Context, writer *gzipWriter, dbId uint64, dbName }) writer.WriteString("COMMIT;\n") } - writer.WriteString(dbConn.Info.Type.StmtSetForeignKeyChecks(true)) } func (d *Db) TableInfos(rc *req.Ctx) { - res, err := d.getDbConn(rc).GetDialect().GetMetaData().GetTables() + res, err := d.getDbConn(rc).GetMetaData().GetTables() biz.ErrIsNilAppendErr(err, "获取表信息失败: %s") rc.ResData = res } @@ -394,7 +389,7 @@ func (d *Db) TableInfos(rc *req.Ctx) { func (d *Db) TableIndex(rc *req.Ctx) { tn := rc.Query("tableName") biz.NotEmpty(tn, "tableName不能为空") - res, err := d.getDbConn(rc).GetDialect().GetMetaData().GetTableIndex(tn) + res, err := d.getDbConn(rc).GetMetaData().GetTableIndex(tn) biz.ErrIsNilAppendErr(err, "获取表索引信息失败: %s") rc.ResData = res } @@ -405,7 +400,7 @@ func (d *Db) ColumnMA(rc *req.Ctx) { biz.NotEmpty(tn, "tableName不能为空") dbi := d.getDbConn(rc) - res, err := dbi.GetDialect().GetMetaData().GetColumns(tn) + res, err := dbi.GetMetaData().GetColumns(tn) biz.ErrIsNilAppendErr(err, "获取数据库列信息失败: %s") rc.ResData = res } @@ -414,9 +409,9 @@ func (d *Db) ColumnMA(rc *req.Ctx) { func (d *Db) HintTables(rc *req.Ctx) { dbi := d.getDbConn(rc) - dm := dbi.GetDialect() + metadata := dbi.GetMetaData() // 获取所有表 - tables, err := dm.GetMetaData().GetTables() + tables, err := metadata.GetTables() biz.ErrIsNil(err) tableNames := make([]string, 0) for _, v := range tables { @@ -432,7 +427,7 @@ func (d *Db) HintTables(rc *req.Ctx) { } // 获取所有表下的所有列信息 - columnMds, err := dm.GetMetaData().GetColumns(tableNames...) + columnMds, err := metadata.GetColumns(tableNames...) biz.ErrIsNil(err) for _, v := range columnMds { tName := v.TableName @@ -455,13 +450,13 @@ func (d *Db) HintTables(rc *req.Ctx) { func (d *Db) GetTableDDL(rc *req.Ctx) { tn := rc.Query("tableName") biz.NotEmpty(tn, "tableName不能为空") - res, err := d.getDbConn(rc).GetDialect().GetMetaData().GetTableDDL(tn) + res, err := d.getDbConn(rc).GetMetaData().GetTableDDL(tn) biz.ErrIsNilAppendErr(err, "获取表ddl失败: %s") rc.ResData = res } func (d *Db) GetSchemas(rc *req.Ctx) { - res, err := d.getDbConn(rc).GetDialect().GetMetaData().GetSchemas() + res, err := d.getDbConn(rc).GetMetaData().GetSchemas() biz.ErrIsNilAppendErr(err, "获取schemas失败: %s") rc.ResData = res } diff --git a/server/internal/db/api/db_instance.go b/server/internal/db/api/db_instance.go index 6dea3be4..ade2dcc7 100644 --- a/server/internal/db/api/db_instance.go +++ b/server/internal/db/api/db_instance.go @@ -107,7 +107,7 @@ func (d *Instance) GetDbServer(rc *req.Ctx) { instanceId := getInstanceId(rc) conn, err := d.DbApp.GetDbConnByInstanceId(instanceId) biz.ErrIsNil(err) - res, err := conn.GetDialect().GetMetaData().GetDbServer() + res, err := conn.GetMetaData().GetDbServer() biz.ErrIsNil(err) rc.ResData = res } diff --git a/server/internal/db/application/db.go b/server/internal/db/application/db.go index 02b46677..5a1f9f74 100644 --- a/server/internal/db/application/db.go +++ b/server/internal/db/application/db.go @@ -152,18 +152,6 @@ func (d *dbAppImpl) GetDbConn(dbId uint64, dbName string) (*dbi.DbConn, error) { return nil, errorx.NewBiz("数据库实例不存在") } - checkDb := dbName - // 兼容pgsql/dm db/schema模式 - if dbi.DbTypePostgres.Equal(instance.Type) || dbi.DbTypeGauss.Equal(instance.Type) || dbi.DbTypeDM.Equal(instance.Type) || dbi.DbTypeOracle.Equal(instance.Type) || dbi.DbTypeMssql.Equal(instance.Type) || dbi.DbTypeKingbaseEs.Equal(instance.Type) || dbi.DbTypeVastbase.Equal(instance.Type) { - ss := strings.Split(dbName, "/") - if len(ss) > 1 { - checkDb = ss[0] - } - } - if !strings.Contains(" "+db.Database+" ", " "+checkDb+" ") { - return nil, errorx.NewBiz("未配置数据库【%s】的操作权限", dbName) - } - // 密码解密 if err := instance.PwdDecrypt(); err != nil { return nil, errorx.NewBiz(err.Error()) @@ -173,6 +161,11 @@ func (d *dbAppImpl) GetDbConn(dbId uint64, dbName string) (*dbi.DbConn, error) { di.FlowProcdefKey = *db.FlowProcdefKey } + checkDb := di.GetDatabase() + if !strings.Contains(" "+db.Database+" ", " "+checkDb+" ") { + return nil, errorx.NewBiz("未配置数据库【%s】的操作权限", dbName) + } + return di, nil }) } @@ -188,7 +181,7 @@ func (d *dbAppImpl) GetDbConnByInstanceId(instanceId uint64) (*dbi.DbConn, error return nil, errorx.NewBiz("获取数据库列表失败") } if len(dbs) == 0 { - return nil, errorx.NewBiz("该实例未配置数据库, 请先进行配置") + return nil, errorx.NewBiz("实例[%d]未配置数据库, 请先进行配置", instanceId) } // 使用该实例关联的已配置数据库中的第一个库进行连接并返回 diff --git a/server/internal/db/application/db_data_sync.go b/server/internal/db/application/db_data_sync.go index 713fa700..061aea20 100644 --- a/server/internal/db/application/db_data_sync.go +++ b/server/internal/db/application/db_data_sync.go @@ -218,7 +218,7 @@ func (app *dataSyncAppImpl) doDataSync(sql string, task *entity.DataSyncTask) (* } }() - srcDialect := srcConn.GetDialect() + srcMetaData := srcConn.GetMetaData() // task.FieldMap为json数组字符串 [{"src":"id","target":"id"}],转为map var fieldMap []map[string]string @@ -248,7 +248,7 @@ func (app *dataSyncAppImpl) doDataSync(sql string, task *entity.DataSyncTask) (* updFieldType = dbi.DataTypeString for _, column := range columns { if strings.EqualFold(column.Name, updFieldName) { - updFieldType = srcDialect.GetDataConverter().GetDataType(column.Type) + updFieldType = srcMetaData.GetDataConverter().GetDataType(column.Type) break } } @@ -257,7 +257,7 @@ func (app *dataSyncAppImpl) doDataSync(sql string, task *entity.DataSyncTask) (* total++ result = append(result, row) if total%batchSize == 0 { - if err := app.srcData2TargetDb(result, fieldMap, columns, updFieldType, updFieldName, task, srcDialect, targetConn, targetDbTx); err != nil { + if err := app.srcData2TargetDb(result, fieldMap, columns, updFieldType, updFieldName, task, srcMetaData, targetConn, targetDbTx); err != nil { return err } @@ -279,7 +279,7 @@ func (app *dataSyncAppImpl) doDataSync(sql string, task *entity.DataSyncTask) (* // 处理剩余的数据 if len(result) > 0 { - if err := app.srcData2TargetDb(result, fieldMap, queryColumns, updFieldType, updFieldName, task, srcDialect, targetConn, targetDbTx); err != nil { + if err := app.srcData2TargetDb(result, fieldMap, queryColumns, updFieldType, updFieldName, task, srcMetaData, targetConn, targetDbTx); err != nil { targetDbTx.Rollback() return syncLog, err } @@ -303,7 +303,7 @@ func (app *dataSyncAppImpl) doDataSync(sql string, task *entity.DataSyncTask) (* return syncLog, nil } -func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap []map[string]string, columns []*dbi.QueryColumn, updFieldType dbi.DataType, updFieldName string, task *entity.DataSyncTask, srcDialect dbi.Dialect, targetDbConn *dbi.DbConn, targetDbTx *sql.Tx) error { +func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap []map[string]string, columns []*dbi.QueryColumn, updFieldType dbi.DataType, updFieldName string, task *entity.DataSyncTask, srcMetaData *dbi.MetaDataX, targetDbConn *dbi.DbConn, targetDbTx *sql.Tx) error { // 遍历src字段列表,取出字段对应的类型 var srcColumnTypes = make(map[string]string) @@ -331,18 +331,19 @@ func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap [ updFieldVal = srcRes[len(srcRes)-1][strings.ToLower(updFieldName)] } - task.UpdFieldVal = srcDialect.GetDataConverter().FormatData(updFieldVal, updFieldType) + task.UpdFieldVal = srcMetaData.GetDataConverter().FormatData(updFieldVal, updFieldType) // 获取目标库字段数组 targetWrapColumns := make([]string, 0) // 获取源库字段数组 srcColumns := make([]string, 0) srcFieldTypes := make(map[string]dbi.DataType) + targetMetaData := targetDbConn.GetMetaData() for _, item := range fieldMap { targetField := item["target"] srcField := item["target"] - srcFieldTypes[srcField] = srcDialect.GetDataConverter().GetDataType(srcColumnTypes[item["src"]]) - targetWrapColumns = append(targetWrapColumns, targetDbConn.Info.Type.QuoteIdentifier(targetField)) + srcFieldTypes[srcField] = srcMetaData.GetDataConverter().GetDataType(srcColumnTypes[item["src"]]) + targetWrapColumns = append(targetWrapColumns, targetMetaData.QuoteIdentifier(targetField)) srcColumns = append(srcColumns, srcField) } @@ -352,7 +353,7 @@ func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap [ rawValue := make([]any, 0) for _, column := range srcColumns { // 某些情况,如oracle,需要转换时间类型的字符串为time类型 - res := srcDialect.GetDataConverter().ParseData(record[column], srcFieldTypes[column]) + res := srcMetaData.GetDataConverter().ParseData(record[column], srcFieldTypes[column]) rawValue = append(rawValue, res) } values = append(values, rawValue) diff --git a/server/internal/db/application/db_instance.go b/server/internal/db/application/db_instance.go index db55b6f8..a1025295 100644 --- a/server/internal/db/application/db_instance.go +++ b/server/internal/db/application/db_instance.go @@ -152,13 +152,12 @@ func (app *instanceAppImpl) Delete(ctx context.Context, instanceId uint64) error func (app *instanceAppImpl) GetDatabases(ed *entity.DbInstance) ([]string, error) { ed.Network = ed.GetNetwork() - metaDb := dbi.ToDbType(ed.Type).MetaDbName() - dbConn, err := dbm.Conn(toDbInfo(ed, 0, metaDb, "")) + dbConn, err := dbm.Conn(toDbInfo(ed, 0, "", "")) if err != nil { return nil, err } defer dbConn.Close() - return dbConn.GetDialect().GetMetaData().GetDbNames() + return dbConn.GetMetaData().GetDbNames() } diff --git a/server/internal/db/application/db_sql_exec.go b/server/internal/db/application/db_sql_exec.go index 6a414881..3c9591ac 100644 --- a/server/internal/db/application/db_sql_exec.go +++ b/server/internal/db/application/db_sql_exec.go @@ -278,7 +278,7 @@ func (d *dbSqlExecAppImpl) doUpdate(ctx context.Context, update *sqlparser.Updat } // 获取表主键列名,排除使用别名 - primaryKey, err := dbConn.GetDialect().GetMetaData().GetPrimaryKey(tableName) + primaryKey, err := dbConn.GetMetaData().GetPrimaryKey(tableName) if err != nil { return nil, errorx.NewBiz("获取表主键信息失败") } diff --git a/server/internal/db/dbm/dbi/conn.go b/server/internal/db/dbm/dbi/conn.go index ee080c52..717cd8c1 100644 --- a/server/internal/db/dbm/dbi/conn.go +++ b/server/internal/db/dbm/dbi/conn.go @@ -127,6 +127,11 @@ func (d *DbConn) GetDialect() Dialect { return d.Info.Meta.GetDialect(d) } +// 获取数据库MetaData +func (d *DbConn) GetMetaData() *MetaDataX { + return d.Info.Meta.GetMetaData(d) +} + // 返回数据库连接状态 func (d *DbConn) Stats(ctx context.Context, execSql string, args ...any) sql.DBStats { return d.db.Stats() diff --git a/server/internal/db/dbm/dbi/info.go b/server/internal/db/dbm/dbi/db_info.go similarity index 72% rename from server/internal/db/dbm/dbi/info.go rename to server/internal/db/dbm/dbi/db_info.go index 517aa518..8d10cbb9 100644 --- a/server/internal/db/dbm/dbi/info.go +++ b/server/internal/db/dbm/dbi/db_info.go @@ -9,6 +9,29 @@ import ( "strings" ) +type DbType string + +const ( + DbTypeMysql DbType = "mysql" + DbTypeMariadb DbType = "mariadb" + DbTypePostgres DbType = "postgres" + DbTypeGauss DbType = "gauss" + DbTypeDM DbType = "dm" + DbTypeOracle DbType = "oracle" + DbTypeSqlite DbType = "sqlite" + DbTypeMssql DbType = "mssql" + DbTypeKingbaseEs DbType = "kingbaseEs" + DbTypeVastbase DbType = "vastbase" +) + +func ToDbType(dbType string) DbType { + return DbType(dbType) +} + +func (dbType DbType) Equal(typ string) bool { + return ToDbType(typ) == dbType +} + type DbInfo struct { InstanceId uint64 // 实例id Id uint64 // dbId @@ -17,12 +40,12 @@ type DbInfo struct { Type DbType // 类型,mysql postgres等 Host string Port int - Extra string // 连接需要的其他额外参数(json字符串),如oracle数据库需要指定sid + Extra string // 连接需要的其他额外参数(json字符串),如oracle数据库需要指定sid等 Network string Username string Password string Params string - Database string + Database string // 若有schema的库则为'database/scheam'格式 FlowProcdefKey string // 流程定义key TagPath []string @@ -45,6 +68,11 @@ func (dbInfo *DbInfo) Conn(meta Meta) (*DbConn, error) { // 赋值Meta,方便后续获取dialect等 dbInfo.Meta = meta database := dbInfo.Database + // 如果数据库为空,则使用默认数据库进行连接 + if database == "" { + database = meta.GetMetaData(&DbConn{Info: dbInfo}).DefaultDb() + dbInfo.Database = database + } conn, err := meta.GetSqlDb(dbInfo) if err != nil { @@ -90,7 +118,7 @@ func (di *DbInfo) IfUseSshTunnelChangeIpPort() error { return nil } -// 获取当前库的schema +// 获取当前库的schema(兼容 database/schema模式) func (di *DbInfo) CurrentSchema() string { dbName := di.Database schema := "" @@ -101,6 +129,16 @@ func (di *DbInfo) CurrentSchema() string { return schema } +// 获取当前数据库(兼容 database/schema模式) +func (di *DbInfo) GetDatabase() string { + dbName := di.Database + ss := strings.Split(dbName, "/") + if len(ss) > 1 { + return ss[0] + } + return dbName +} + // 根据ssh tunnel机器id返回ssh tunnel func GetSshTunnel(sshTunnelMachineId int) (*mcm.SshTunnelMachine, error) { return machineapp.GetMachineApp().GetSshTunnelMachine(sshTunnelMachineId) diff --git a/server/internal/db/dbm/dbi/db_type.go b/server/internal/db/dbm/dbi/db_type.go deleted file mode 100644 index a6b0f86d..00000000 --- a/server/internal/db/dbm/dbi/db_type.go +++ /dev/null @@ -1,160 +0,0 @@ -package dbi - -import ( - "fmt" - "strings" - - pq "gitee.com/liuzongyang/libpq" - "github.com/kanzihuang/vitess/go/vt/sqlparser" -) - -type DbType string - -const ( - DbTypeMysql DbType = "mysql" - DbTypeMariadb DbType = "mariadb" - DbTypePostgres DbType = "postgres" - DbTypeGauss DbType = "gauss" - DbTypeDM DbType = "dm" - DbTypeOracle DbType = "oracle" - DbTypeSqlite DbType = "sqlite" - DbTypeMssql DbType = "mssql" - DbTypeKingbaseEs DbType = "kingbaseEs" - DbTypeVastbase DbType = "vastbase" -) - -func ToDbType(dbType string) DbType { - return DbType(dbType) -} - -func (dbType DbType) Equal(typ string) bool { - return ToDbType(typ) == dbType -} - -// QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be -// used as part of an SQL statement. For example: -// -// tblname := "my_table" -// data := "my_data" -// quoted := quoteIdentifier(tblname, '"') -// err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data) -// -// Any double quotes in name will be escaped. The quoted identifier will be -// case sensitive when used in a query. If the input string contains a zero -// byte, the result will be truncated immediately before it. -func (dbType DbType) QuoteIdentifier(name string) string { - switch dbType { - case DbTypeMysql, DbTypeMariadb: - return quoteIdentifier(name, "`") - case DbTypePostgres, DbTypeGauss, DbTypeKingbaseEs, DbTypeVastbase: - return quoteIdentifier(name, `"`) - case DbTypeMssql: - return fmt.Sprintf("[%s]", name) - default: - return quoteIdentifier(name, `"`) - } -} - -func (dbType DbType) RemoveQuote(name string) string { - switch dbType { - case DbTypeMysql, DbTypeMariadb: - return removeQuote(name, "`") - case DbTypePostgres, DbTypeGauss, DbTypeKingbaseEs, DbTypeVastbase: - return removeQuote(name, `"`) - case DbTypeMssql: - return strings.Trim(name, "[]") - default: - return removeQuote(name, `"`) - } -} - -func (dbType DbType) QuoteLiteral(literal string) string { - switch dbType { - case DbTypeMysql, DbTypeMariadb: - literal = strings.ReplaceAll(literal, `\`, `\\`) - literal = strings.ReplaceAll(literal, `'`, `''`) - return "'" + literal + "'" - case DbTypePostgres, DbTypeGauss, DbTypeKingbaseEs, DbTypeVastbase: - return pq.QuoteLiteral(literal) - default: - return pq.QuoteLiteral(literal) - } -} - -func (dbType DbType) MetaDbName() string { - switch dbType { - case DbTypeMysql, DbTypeMariadb: - return "" - case DbTypePostgres, DbTypeGauss: - return "postgres" - case DbTypeDM: - return "" - case DbTypeKingbaseEs: - return "security" - case DbTypeVastbase: - return "vastbase" - default: - return "" - } -} - -func (dbType DbType) Dialect() sqlparser.Dialect { - switch dbType { - case DbTypeMysql, DbTypeMariadb: - return sqlparser.MysqlDialect{} - case DbTypePostgres, DbTypeGauss, DbTypeKingbaseEs, DbTypeVastbase: - return sqlparser.PostgresDialect{} - default: - return sqlparser.PostgresDialect{} - } -} - -func quoteIdentifier(name, quoter string) string { - end := strings.IndexRune(name, 0) - if end > -1 { - name = name[:end] - } - return quoter + strings.Replace(name, quoter, quoter+quoter, -1) + quoter -} - -// 移除相关引号 -func removeQuote(name, quoter string) string { - return strings.ReplaceAll(name, quoter, "") -} - -func (dbType DbType) StmtSetForeignKeyChecks(check bool) string { - switch dbType { - case DbTypeMysql, DbTypeMariadb: - if check { - return "SET FOREIGN_KEY_CHECKS = 1;\n" - } else { - return "SET FOREIGN_KEY_CHECKS = 0;\n" - } - case DbTypePostgres, DbTypeGauss, DbTypeKingbaseEs, DbTypeVastbase: - // not currently supported postgres - return "" - default: - return "" - } -} - -func (dbType DbType) StmtUseDatabase(dbName string) string { - switch dbType { - case DbTypeMysql, DbTypeMariadb: - return fmt.Sprintf("USE %s;\n", dbType.QuoteIdentifier(dbName)) - case DbTypePostgres, DbTypeGauss, DbTypeKingbaseEs, DbTypeVastbase: - // not currently supported postgres - return "" - default: - return "" - } -} - -func (dbType DbType) SupportingBackup() bool { - switch dbType { - case DbTypeMysql, DbTypeMariadb: - return true - default: - return false - } -} diff --git a/server/internal/db/dbm/dbi/db_type_test.go b/server/internal/db/dbm/dbi/db_type_test.go deleted file mode 100644 index 197c092c..00000000 --- a/server/internal/db/dbm/dbi/db_type_test.go +++ /dev/null @@ -1,83 +0,0 @@ -package dbi - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -func Test_QuoteLiteral(t *testing.T) { - tests := []struct { - dbType DbType - sql string - want string - }{ - { - dbType: DbTypeMysql, - sql: "\\a\\b", - want: "'\\\\a\\\\b'", - }, - { - dbType: DbTypeMysql, - sql: "'a'", - want: "'''a'''", - }, - { - dbType: DbTypeMysql, - sql: "a\u00A0b", - want: "'a\u00A0b'", - }, - { - dbType: DbTypePostgres, - sql: "\\a\\b", - want: " E'\\\\a\\\\b'", - }, - { - dbType: DbTypePostgres, - sql: "'a'", - want: "'''a'''", - }, - { - dbType: DbTypePostgres, - sql: "a\u00A0b", - want: "'a\u00A0b'", - }, - } - for _, tt := range tests { - t.Run(string(tt.dbType)+"_"+tt.sql, func(t *testing.T) { - got := tt.dbType.QuoteLiteral(tt.sql) - require.Equal(t, tt.want, got) - }) - } -} - -func Test_quoteIdentifier(t *testing.T) { - tests := []struct { - dbType DbType - sql string - want string - }{ - { - dbType: DbTypeMysql, - sql: "`a`", - }, - { - dbType: DbTypeMysql, - sql: "select table", - }, - { - dbType: DbTypePostgres, - sql: "a", - }, - { - dbType: DbTypePostgres, - sql: "table", - }, - } - for _, tt := range tests { - t.Run(string(tt.dbType)+"_"+tt.sql, func(t *testing.T) { - got := tt.dbType.QuoteIdentifier(tt.sql) - require.Equal(t, tt.want, got) - }) - } -} diff --git a/server/internal/db/dbm/dbi/dialect.go b/server/internal/db/dbm/dbi/dialect.go index 89f74538..04471030 100644 --- a/server/internal/db/dbm/dbi/dialect.go +++ b/server/internal/db/dbm/dbi/dialect.go @@ -4,16 +4,6 @@ import ( "database/sql" ) -type DataType string - -const ( - DataTypeString DataType = "string" - DataTypeNumber DataType = "number" - DataTypeDate DataType = "date" - DataTypeTime DataType = "time" - DataTypeDateTime DataType = "datetime" -) - const ( // -1. 无操作 DuplicateStrategyNone = -1 @@ -30,34 +20,16 @@ type DbCopyTable struct { CopyData bool `json:"copyData"` // 是否复制数据 } -// 数据转换器 -type DataConverter interface { - // 获取数据对应的类型 - // @param dbColumnType 数据库原始列类型,如varchar等 - GetDataType(dbColumnType string) DataType - - // 根据数据类型格式化指定数据 - FormatData(dbColumnValue any, dataType DataType) string - - // 根据数据类型解析数据为符合要求的指定类型等 - ParseData(dbColumnValue any, dataType DataType) any -} - // -----------------------------------元数据接口定义------------------------------------------ // 数据库方言 用于获取元信息接口、批量插入等各个数据库方言不一致的实现方式 type Dialect interface { - // 获取元数据信息接口 - GetMetaData() MetaData - // GetDbProgram 获取数据库程序模块,用于数据库备份与恢复 GetDbProgram() (DbProgram, error) // 批量保存数据 BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any, duplicateStrategy int) (int64, error) - // 获取数据转换器用于解析格式化列数据等 - GetDataConverter() DataConverter - + // 拷贝表 CopyTable(copy *DbCopyTable) error } diff --git a/server/internal/db/dbm/dbi/meta.go b/server/internal/db/dbm/dbi/meta.go index f3ff44a1..a9e3051e 100644 --- a/server/internal/db/dbm/dbi/meta.go +++ b/server/internal/db/dbm/dbi/meta.go @@ -23,4 +23,8 @@ type Meta interface { // 获取数据库方言 GetDialect(*DbConn) Dialect + + // 获取元数据信息接口 + // @param *DbConn 数据库连接, 若一些元数据接口(如 GetIdentifierQuoteString)不需要DbConn,则可以传nil + GetMetaData(*DbConn) *MetaDataX } diff --git a/server/internal/db/dbm/dbi/metadata.go b/server/internal/db/dbm/dbi/metadata.go index a67b1784..554fea8e 100644 --- a/server/internal/db/dbm/dbi/metadata.go +++ b/server/internal/db/dbm/dbi/metadata.go @@ -10,6 +10,8 @@ import ( // 元数据接口(表、列、等元信息) type MetaData interface { + BaseMetaData + // 获取数据库服务实例信息 GetDbServer() (*DbServer, error) @@ -32,6 +34,9 @@ type MetaData interface { GetTableDDL(tableName string) (string, error) GetSchemas() ([]string, error) + + // 获取数据转换器用于解析格式化列数据等 + GetDataConverter() DataConverter } // 数据库服务实例信息 @@ -74,6 +79,29 @@ type Index struct { IsUnique bool `json:"isUnique"` } +type DataType string + +const ( + DataTypeString DataType = "string" + DataTypeNumber DataType = "number" + DataTypeDate DataType = "date" + DataTypeTime DataType = "time" + DataTypeDateTime DataType = "datetime" +) + +// 数据转换器 +type DataConverter interface { + // 获取数据对应的类型 + // @param dbColumnType 数据库原始列类型,如varchar等 + GetDataType(dbColumnType string) DataType + + // 根据数据类型格式化指定数据 + FormatData(dbColumnValue any, dataType DataType) string + + // 根据数据类型解析数据为符合要求的指定类型等 + ParseData(dbColumnValue any, dataType DataType) any +} + // ------------------------- 元数据sql操作 ------------------------- // //go:embed metasql/* diff --git a/server/internal/db/dbm/dbi/metadata_base.go b/server/internal/db/dbm/dbi/metadata_base.go new file mode 100644 index 00000000..318d4f43 --- /dev/null +++ b/server/internal/db/dbm/dbi/metadata_base.go @@ -0,0 +1,48 @@ +package dbi + +import ( + pq "gitee.com/liuzongyang/libpq" + "github.com/kanzihuang/vitess/go/vt/sqlparser" +) + +type BaseMetaData interface { + + // 默认库 + DefaultDb() string + + // 用于引用 SQL 标识符(关键字)的字符串 + GetIdentifierQuoteString() string + + // QuoteLiteral quotes a 'literal' (e.g. a parameter, often used to pass literal + // to DDL and other statements that do not accept parameters) to be used as part + // of an SQL statement. For example: + // + // exp_date := pq.QuoteLiteral("2023-01-05 15:00:00Z") + // err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", exp_date)) + // + // Any single quotes in name will be escaped. Any backslashes (i.e. "\") will be + // replaced by two backslashes (i.e. "\\") and the C-style escape identifier + QuoteLiteral(literal string) string + + SqlParserDialect() sqlparser.Dialect +} + +// 默认实现,若需要覆盖,则由各个数据库MetaData实现去覆盖重写 +type DefaultMetaData struct { +} + +func (dd *DefaultMetaData) DefaultDb() string { + return "" +} + +func (dd *DefaultMetaData) GetIdentifierQuoteString() string { + return `"` +} + +func (dd *DefaultMetaData) QuoteLiteral(literal string) string { + return pq.QuoteLiteral(literal) +} + +func (dd *DefaultMetaData) SqlParserDialect() sqlparser.Dialect { + return sqlparser.PostgresDialect{} +} diff --git a/server/internal/db/dbm/dbi/metadatax.go b/server/internal/db/dbm/dbi/metadatax.go new file mode 100644 index 00000000..d43c4b09 --- /dev/null +++ b/server/internal/db/dbm/dbi/metadatax.go @@ -0,0 +1,59 @@ +package dbi + +import ( + "fmt" + "strings" +) + +// 包装扩展MetaData,提供所有实现MetaData结构体的公共方法 +type MetaDataX struct { + MetaData +} + +func NewMetaDataX(metaData MetaData) *MetaDataX { + return &MetaDataX{metaData} +} + +func (md *MetaDataX) QuoteIdentifier(name string) string { + return QuoteIdentifier(md, name) +} + +func (md *MetaDataX) RemoveQuote(name string) string { + return RemoveQuote(md, name) +} + +// QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be +// used as part of an SQL statement. For example: +// +// tblname := "my_table" +// data := "my_data" +// quoted := quoteIdentifier(tblname, '"') +// err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data) +// +// Any double quotes in name will be escaped. The quoted identifier will be +// case sensitive when used in a query. If the input string contains a zero +// byte, the result will be truncated immediately before it. +func QuoteIdentifier(metadata MetaData, name string) string { + quoter := metadata.GetIdentifierQuoteString() + // 兼容mssql + if quoter == "[" { + return fmt.Sprintf("[%s]", name) + } + + end := strings.IndexRune(name, 0) + if end > -1 { + name = name[:end] + } + return quoter + strings.Replace(name, quoter, quoter+quoter, -1) + quoter +} + +func RemoveQuote(metadata MetaData, name string) string { + quoter := metadata.GetIdentifierQuoteString() + + // 兼容mssql + if quoter == "[" { + return strings.Trim(name, "[]") + } + + return strings.ReplaceAll(name, quoter, "") +} diff --git a/server/internal/db/dbm/dm/dialect.go b/server/internal/db/dbm/dm/dialect.go index 46d3ae52..30f3b325 100644 --- a/server/internal/db/dbm/dm/dialect.go +++ b/server/internal/db/dbm/dm/dialect.go @@ -5,10 +5,8 @@ import ( "fmt" "mayfly-go/internal/db/dbm/dbi" "mayfly-go/pkg/logx" - "mayfly-go/pkg/utils/anyx" "mayfly-go/pkg/utils/collx" "mayfly-go/pkg/utils/stringx" - "regexp" "strings" "time" @@ -21,52 +19,31 @@ type DMDialect struct { dc *dbi.DbConn } -func (dd *DMDialect) GetMetaData() dbi.MetaData { - return &DMMetaData{ - dc: dd.dc, - } -} - // GetDbProgram 获取数据库程序模块,用于数据库备份与恢复 func (dd *DMDialect) GetDbProgram() (dbi.DbProgram, error) { return nil, fmt.Errorf("该数据库类型不支持数据库备份与恢复: %v", dd.dc.Info.Type) } -var ( - // 数字类型 - numberRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`) - // 日期时间类型 - datetimeRegexp = regexp.MustCompile(`(?i)datetime|timestamp`) - // 日期类型 - dateRegexp = regexp.MustCompile(`(?i)date`) - // 时间类型 - timeRegexp = regexp.MustCompile(`(?i)time`) - - converter = new(DataConverter) -) - func (dd *DMDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any, duplicateStrategy int) (int64, error) { // 执行批量insert sql // insert into "table_name" ("column1", "column2", ...) values (value1, value2, ...) - dbType := dd.dc.Info.Type - // 无需处理重复数据,直接执行批量insert if duplicateStrategy == dbi.DuplicateStrategyNone || duplicateStrategy == 0 { - return dd.batchInsertSimple(dbType, tx, tableName, columns, values) + return dd.batchInsertSimple(tx, tableName, columns, values) } else { // 执行MERGE INTO语句 - return dd.batchInsertMerge(dbType, tx, tableName, columns, values) + return dd.batchInsertMerge(tx, tableName, columns, values) } } -func (dd *DMDialect) batchInsertSimple(dbType dbi.DbType, tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) { +func (dd *DMDialect) batchInsertSimple(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) { // 生成占位符字符串:如:(?,?) // 重复字符串并用逗号连接 repeated := strings.Repeat("?,", len(columns)) // 去除最后一个逗号,占位符由括号包裹 placeholder := fmt.Sprintf("(%s)", strings.TrimSuffix(repeated, ",")) - sqlTemp := fmt.Sprintf("insert into %s (%s) values %s", dbType.QuoteIdentifier(tableName), strings.Join(columns, ","), placeholder) + sqlTemp := fmt.Sprintf("insert into %s (%s) values %s", dd.dc.GetMetaData().QuoteIdentifier(tableName), strings.Join(columns, ","), placeholder) effRows := 0 for _, value := range values { // 达梦数据库只能一条条的执行insert @@ -80,21 +57,21 @@ func (dd *DMDialect) batchInsertSimple(dbType dbi.DbType, tx *sql.Tx, tableName return int64(effRows), nil } -func (dd *DMDialect) batchInsertMerge(dbType dbi.DbType, tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) { +func (dd *DMDialect) batchInsertMerge(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) { // 查询主键字段 uniqueCols := make([]string, 0) caseSqls := make([]string, 0) - metadata := dd.GetMetaData() + metadata := dd.dc.GetMetaData() tableCols, _ := metadata.GetColumns(tableName) identityCols := make([]string, 0) for _, col := range tableCols { if col.IsPrimaryKey { uniqueCols = append(uniqueCols, col.ColumnName) - caseSqls = append(caseSqls, fmt.Sprintf("( T1.%s = T2.%s )", dbType.QuoteIdentifier(col.ColumnName), dbType.QuoteIdentifier(col.ColumnName))) + caseSqls = append(caseSqls, fmt.Sprintf("( T1.%s = T2.%s )", metadata.QuoteIdentifier(col.ColumnName), metadata.QuoteIdentifier(col.ColumnName))) } if col.IsIdentity { // 自增字段不放入insert内,即使是设置了identity_insert on也不起作用 - identityCols = append(identityCols, dbType.QuoteIdentifier(col.ColumnName)) + identityCols = append(identityCols, metadata.QuoteIdentifier(col.ColumnName)) } } // 查询唯一索引涉及到的字段,并组装到match条件内 @@ -106,7 +83,7 @@ func (dd *DMDialect) batchInsertMerge(dbType dbi.DbType, tx *sql.Tx, tableName s tmp := make([]string, 0) for _, col := range cols { uniqueCols = append(uniqueCols, col) - tmp = append(tmp, fmt.Sprintf(" T1.%s = T2.%s ", dbType.QuoteIdentifier(col), dbType.QuoteIdentifier(col))) + tmp = append(tmp, fmt.Sprintf(" T1.%s = T2.%s ", metadata.QuoteIdentifier(col), metadata.QuoteIdentifier(col))) } caseSqls = append(caseSqls, fmt.Sprintf("( %s )", strings.Join(tmp, " AND "))) } @@ -120,7 +97,7 @@ func (dd *DMDialect) batchInsertMerge(dbType dbi.DbType, tx *sql.Tx, tableName s insertCols := make([]string, 0) for _, column := range columns { phs = append(phs, fmt.Sprintf("? %s", column)) - if !collx.ArrayContains(uniqueCols, dbType.RemoveQuote(column)) { + if !collx.ArrayContains(uniqueCols, metadata.RemoveQuote(column)) { upds = append(upds, fmt.Sprintf("T1.%s = T2.%s", column, column)) } if !collx.ArrayContains(identityCols, column) { @@ -135,7 +112,7 @@ func (dd *DMDialect) batchInsertMerge(dbType dbi.DbType, tx *sql.Tx, tableName s } t2 := strings.Join(t2s, " UNION ALL ") - sqlTemp := "MERGE INTO " + dbType.QuoteIdentifier(tableName) + " T1 USING (" + t2 + ") T2 ON " + strings.Join(caseSqls, " OR ") + sqlTemp := "MERGE INTO " + metadata.QuoteIdentifier(tableName) + " T1 USING (" + t2 + ") T2 ON " + strings.Join(caseSqls, " OR ") sqlTemp += "WHEN NOT MATCHED THEN INSERT (" + strings.Join(insertCols, ",") + ") VALUES (" + strings.Join(insertVals, ",") + ")" sqlTemp += "WHEN MATCHED THEN UPDATE SET " + strings.Join(upds, ",") @@ -148,68 +125,9 @@ func (dd *DMDialect) batchInsertMerge(dbType dbi.DbType, tx *sql.Tx, tableName s return dd.dc.TxExec(tx, sqlTemp, args...) } -func (dd *DMDialect) GetDataConverter() dbi.DataConverter { - return converter -} - -type DataConverter struct { -} - -func (dd *DataConverter) GetDataType(dbColumnType string) dbi.DataType { - if numberRegexp.MatchString(dbColumnType) { - return dbi.DataTypeNumber - } - if datetimeRegexp.MatchString(dbColumnType) { - return dbi.DataTypeDateTime - } - if dateRegexp.MatchString(dbColumnType) { - return dbi.DataTypeDate - } - if timeRegexp.MatchString(dbColumnType) { - return dbi.DataTypeTime - } - return dbi.DataTypeString -} - -func (dd *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) string { - str := anyx.ToString(dbColumnValue) - switch dataType { - case dbi.DataTypeDateTime: // "2024-01-02T22:08:22.275697+08:00" - res, _ := time.Parse(time.RFC3339, str) - return res.Format(time.DateTime) - case dbi.DataTypeDate: // "2024-01-02T00:00:00+08:00" - res, _ := time.Parse(time.RFC3339, str) - return res.Format(time.DateOnly) - case dbi.DataTypeTime: // "0000-01-01T22:08:22.275688+08:00" - res, _ := time.Parse(time.RFC3339, str) - return res.Format(time.TimeOnly) - } - return str -} - -func (dd *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any { - // 如果dataType是datetime而dbColumnValue是string类型,则需要转换为time.Time类型 - _, ok := dbColumnValue.(string) - if ok { - if dataType == dbi.DataTypeDateTime { - res, _ := time.Parse(time.RFC3339, anyx.ToString(dbColumnValue)) - return res - } - if dataType == dbi.DataTypeDate { - res, _ := time.Parse(time.DateOnly, anyx.ToString(dbColumnValue)) - return res - } - if dataType == dbi.DataTypeTime { - res, _ := time.Parse(time.TimeOnly, anyx.ToString(dbColumnValue)) - return res - } - } - return dbColumnValue -} - func (dd *DMDialect) CopyTable(copy *dbi.DbCopyTable) error { tableName := copy.TableName - metadata := dd.GetMetaData() + metadata := dd.dc.GetMetaData() ddl, err := metadata.GetTableDDL(tableName) if err != nil { return err @@ -221,7 +139,7 @@ func (dd *DMDialect) CopyTable(copy *dbi.DbCopyTable) error { ddl = strings.ReplaceAll(ddl, fmt.Sprintf("\"%s\"", strings.ToUpper(tableName)), fmt.Sprintf("\"%s\"", strings.ToUpper(newTableName))) // 去除空格换行 ddl = stringx.TrimSpaceAndBr(ddl) - sqls, err := sqlparser.SplitStatementToPieces(ddl, sqlparser.WithDialect(dd.dc.Info.Type.Dialect())) + sqls, err := sqlparser.SplitStatementToPieces(ddl, sqlparser.WithDialect(dd.dc.GetMetaData().SqlParserDialect())) for _, sql := range sqls { _, _ = dd.dc.Exec(sql) } diff --git a/server/internal/db/dbm/dm/meta.go b/server/internal/db/dbm/dm/meta.go index 7d49efd2..3904fb4f 100644 --- a/server/internal/db/dbm/dm/meta.go +++ b/server/internal/db/dbm/dm/meta.go @@ -41,5 +41,11 @@ func (md *DmMeta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) { } func (md *DmMeta) GetDialect(conn *dbi.DbConn) dbi.Dialect { - return &DMDialect{conn} + return &DMDialect{dc: conn} +} + +func (md *DmMeta) GetMetaData(conn *dbi.DbConn) *dbi.MetaDataX { + return dbi.NewMetaDataX(&DMMetaData{ + dc: conn, + }) } diff --git a/server/internal/db/dbm/dm/metadata.go b/server/internal/db/dbm/dm/metadata.go index bca9fe6a..0eb6de55 100644 --- a/server/internal/db/dbm/dm/metadata.go +++ b/server/internal/db/dbm/dm/metadata.go @@ -6,7 +6,9 @@ import ( "mayfly-go/pkg/errorx" "mayfly-go/pkg/utils/anyx" "mayfly-go/pkg/utils/collx" + "regexp" "strings" + "time" ) const ( @@ -18,6 +20,8 @@ const ( ) type DMMetaData struct { + dbi.DefaultMetaData + dc *dbi.DbConn } @@ -74,9 +78,8 @@ func (dd *DMMetaData) GetTables() ([]dbi.Table, error) { // 获取列元信息, 如列名等 func (dd *DMMetaData) GetColumns(tableNames ...string) ([]dbi.Column, error) { - dbType := dd.dc.Info.Type tableName := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string { - return fmt.Sprintf("'%s'", dbType.RemoveQuote(val)) + return fmt.Sprintf("'%s'", dbi.RemoveQuote(dd, val)) }), ",") _, res, err := dd.dc.Query(fmt.Sprintf(dbi.GetLocalSql(DM_META_FILE, DM_COLUMN_MA_KEY), tableName)) @@ -237,3 +240,75 @@ func (dd *DMMetaData) GetSchemas() ([]string, error) { } return schemaNames, nil } + +func (dd *DMMetaData) GetDataConverter() dbi.DataConverter { + return converter +} + +var ( + // 数字类型 + numberRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`) + // 日期时间类型 + datetimeRegexp = regexp.MustCompile(`(?i)datetime|timestamp`) + // 日期类型 + dateRegexp = regexp.MustCompile(`(?i)date`) + // 时间类型 + timeRegexp = regexp.MustCompile(`(?i)time`) + + converter = new(DataConverter) +) + +type DataConverter struct { +} + +func (dd *DataConverter) GetDataType(dbColumnType string) dbi.DataType { + if numberRegexp.MatchString(dbColumnType) { + return dbi.DataTypeNumber + } + if datetimeRegexp.MatchString(dbColumnType) { + return dbi.DataTypeDateTime + } + if dateRegexp.MatchString(dbColumnType) { + return dbi.DataTypeDate + } + if timeRegexp.MatchString(dbColumnType) { + return dbi.DataTypeTime + } + return dbi.DataTypeString +} + +func (dd *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) string { + str := anyx.ToString(dbColumnValue) + switch dataType { + case dbi.DataTypeDateTime: // "2024-01-02T22:08:22.275697+08:00" + res, _ := time.Parse(time.RFC3339, str) + return res.Format(time.DateTime) + case dbi.DataTypeDate: // "2024-01-02T00:00:00+08:00" + res, _ := time.Parse(time.RFC3339, str) + return res.Format(time.DateOnly) + case dbi.DataTypeTime: // "0000-01-01T22:08:22.275688+08:00" + res, _ := time.Parse(time.RFC3339, str) + return res.Format(time.TimeOnly) + } + return str +} + +func (dd *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any { + // 如果dataType是datetime而dbColumnValue是string类型,则需要转换为time.Time类型 + _, ok := dbColumnValue.(string) + if ok { + if dataType == dbi.DataTypeDateTime { + res, _ := time.Parse(time.RFC3339, anyx.ToString(dbColumnValue)) + return res + } + if dataType == dbi.DataTypeDate { + res, _ := time.Parse(time.DateOnly, anyx.ToString(dbColumnValue)) + return res + } + if dataType == dbi.DataTypeTime { + res, _ := time.Parse(time.TimeOnly, anyx.ToString(dbColumnValue)) + return res + } + } + return dbColumnValue +} diff --git a/server/internal/db/dbm/mssql/dialect.go b/server/internal/db/dbm/mssql/dialect.go index 3406e11a..fdec4cb2 100644 --- a/server/internal/db/dbm/mssql/dialect.go +++ b/server/internal/db/dbm/mssql/dialect.go @@ -5,9 +5,7 @@ import ( "fmt" "mayfly-go/internal/db/dbm/dbi" "mayfly-go/pkg/logx" - "mayfly-go/pkg/utils/anyx" "mayfly-go/pkg/utils/collx" - "regexp" "strings" "time" ) @@ -16,10 +14,6 @@ type MssqlDialect struct { dc *dbi.DbConn } -func (md *MssqlDialect) GetMetaData() dbi.MetaData { - return &MssqlMetaData{dc: md.dc} -} - // GetDbProgram 获取数据库程序模块,用于数据库备份与恢复 func (md *MssqlDialect) GetDbProgram() (dbi.DbProgram, error) { return nil, fmt.Errorf("该数据库类型不支持数据库备份与恢复: %v", md.dc.Info.Type) @@ -35,12 +29,12 @@ func (md *MssqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []stri } func (md *MssqlDialect) batchInsertSimple(tx *sql.Tx, tableName string, columns []string, values [][]any, duplicateStrategy int) (int64, error) { - msMetadata := md.GetMetaData().(*MssqlMetaData) + msMetadata := md.dc.GetMetaData() schema := md.dc.Info.CurrentSchema() ignoreDupSql := "" if duplicateStrategy == dbi.DuplicateStrategyIgnore { // ALTER TABLE dbo.TEST ADD CONSTRAINT uniqueRows UNIQUE (ColA, ColB, ColC, ColD) WITH (IGNORE_DUP_KEY = ON) - indexs, _ := msMetadata.getTableIndexWithPK(tableName) + indexs, _ := msMetadata.MetaData.(*MssqlMetaData).getTableIndexWithPK(tableName) // 收集唯一索引涉及到的字段 uniqueColumns := make([]string, 0) for _, index := range indexs { @@ -71,7 +65,7 @@ func (md *MssqlDialect) batchInsertSimple(tx *sql.Tx, tableName string, columns // 去除最后一个逗号 placeholder = strings.TrimSuffix(repeated, ",") - baseTable := fmt.Sprintf("%s.%s", md.dc.Info.Type.QuoteIdentifier(schema), md.dc.Info.Type.QuoteIdentifier(tableName)) + baseTable := fmt.Sprintf("%s.%s", msMetadata.QuoteIdentifier(schema), msMetadata.QuoteIdentifier(tableName)) sqlStr := fmt.Sprintf("insert into %s (%s) values %s", baseTable, strings.Join(columns, ","), placeholder) // 执行批量insert sql @@ -94,9 +88,8 @@ func (md *MssqlDialect) batchInsertSimple(tx *sql.Tx, tableName string, columns } func (md *MssqlDialect) batchInsertMerge(tx *sql.Tx, tableName string, columns []string, values [][]any, duplicateStrategy int) (int64, error) { - msMetadata := md.GetMetaData().(*MssqlMetaData) + msMetadata := md.dc.GetMetaData() schema := md.dc.Info.CurrentSchema() - dbType := md.dc.Info.Type // 收集MERGE 语句的 ON 子句条件 caseSqls := make([]string, 0) @@ -111,7 +104,7 @@ func (md *MssqlDialect) batchInsertMerge(tx *sql.Tx, tableName string, columns [ } if col.IsPrimaryKey { pkCols = append(pkCols, col.ColumnName) - name := dbType.QuoteIdentifier(col.ColumnName) + name := msMetadata.QuoteIdentifier(col.ColumnName) caseSqls = append(caseSqls, fmt.Sprintf(" T1.%s = T2.%s ", name, name)) } } @@ -125,7 +118,7 @@ func (md *MssqlDialect) batchInsertMerge(tx *sql.Tx, tableName string, columns [ // 源数据占位sql phs := make([]string, 0) for _, column := range columns { - if !collx.ArrayContains(identityCols, dbType.RemoveQuote(column)) { + if !collx.ArrayContains(identityCols, msMetadata.RemoveQuote(column)) { upds = append(upds, fmt.Sprintf("T1.%s = T2.%s", column, column)) } insertCols = append(insertCols, fmt.Sprintf("%s", column)) @@ -143,7 +136,7 @@ func (md *MssqlDialect) batchInsertMerge(tx *sql.Tx, tableName string, columns [ } t2 := strings.Join(t2s, " UNION ALL ") - sqlTemp := "MERGE INTO " + dbType.QuoteIdentifier(schema) + "." + dbType.QuoteIdentifier(tableName) + " T1 USING (" + t2 + ") T2 ON " + strings.Join(caseSqls, " AND ") + sqlTemp := "MERGE INTO " + msMetadata.QuoteIdentifier(schema) + "." + msMetadata.QuoteIdentifier(tableName) + " T1 USING (" + t2 + ") T2 ON " + strings.Join(caseSqls, " AND ") sqlTemp += "WHEN NOT MATCHED THEN INSERT (" + strings.Join(insertCols, ",") + ") VALUES (" + strings.Join(insertVals, ",") + ") " sqlTemp += "WHEN MATCHED THEN UPDATE SET " + strings.Join(upds, ",") @@ -159,69 +152,8 @@ func (md *MssqlDialect) batchInsertMerge(tx *sql.Tx, tableName string, columns [ return res, err } -func (md *MssqlDialect) GetDataConverter() dbi.DataConverter { - return converter -} - -var ( - // 数字类型 - numberRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`) - // 日期时间类型 - datetimeRegexp = regexp.MustCompile(`(?i)datetime|timestamp`) - // 日期类型 - dateRegexp = regexp.MustCompile(`(?i)date`) - // 时间类型 - timeRegexp = regexp.MustCompile(`(?i)time`) - - converter = new(DataConverter) -) - -type DataConverter struct { -} - -func (dc *DataConverter) GetDataType(dbColumnType string) dbi.DataType { - if numberRegexp.MatchString(dbColumnType) { - return dbi.DataTypeNumber - } - // 日期时间类型 - if datetimeRegexp.MatchString(dbColumnType) { - return dbi.DataTypeDateTime - } - // 日期类型 - if dateRegexp.MatchString(dbColumnType) { - return dbi.DataTypeDate - } - // 时间类型 - if timeRegexp.MatchString(dbColumnType) { - return dbi.DataTypeTime - } - return dbi.DataTypeString -} - -func (dc *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) string { - return anyx.ToString(dbColumnValue) -} - -func (dc *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any { - // 如果dataType是datetime而dbColumnValue是string类型,则需要转换为time.Time类型 - _, ok := dbColumnValue.(string) - if dataType == dbi.DataTypeDateTime && ok { - res, _ := time.Parse(time.RFC3339, anyx.ToString(dbColumnValue)) - return res - } - if dataType == dbi.DataTypeDate && ok { - res, _ := time.Parse(time.DateOnly, anyx.ToString(dbColumnValue)) - return res - } - if dataType == dbi.DataTypeTime && ok { - res, _ := time.Parse(time.TimeOnly, anyx.ToString(dbColumnValue)) - return res - } - return dbColumnValue -} - func (md *MssqlDialect) CopyTable(copy *dbi.DbCopyTable) error { - msMetadata := md.GetMetaData().(*MssqlMetaData) + msMetadata := md.dc.GetMetaData().MetaData.(*MssqlMetaData) schema := md.dc.Info.CurrentSchema() // 生成新表名,为老表明+_copy_时间戳 diff --git a/server/internal/db/dbm/mssql/meta.go b/server/internal/db/dbm/mssql/meta.go index 6a4dddc8..4d70b575 100644 --- a/server/internal/db/dbm/mssql/meta.go +++ b/server/internal/db/dbm/mssql/meta.go @@ -3,10 +3,11 @@ package mssql import ( "database/sql" "fmt" - _ "github.com/microsoft/go-mssqldb" "mayfly-go/internal/db/dbm/dbi" "net/url" "strings" + + _ "github.com/microsoft/go-mssqldb" ) func init() { @@ -53,5 +54,9 @@ func (md *MssqlMeta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) { } func (md *MssqlMeta) GetDialect(conn *dbi.DbConn) dbi.Dialect { - return &MssqlDialect{conn} + return &MssqlDialect{dc: conn} +} + +func (md *MssqlMeta) GetMetaData(conn *dbi.DbConn) *dbi.MetaDataX { + return dbi.NewMetaDataX(&MssqlMetaData{dc: conn}) } diff --git a/server/internal/db/dbm/mssql/metadata.go b/server/internal/db/dbm/mssql/metadata.go index 486b1128..22e787cf 100644 --- a/server/internal/db/dbm/mssql/metadata.go +++ b/server/internal/db/dbm/mssql/metadata.go @@ -6,7 +6,9 @@ import ( "mayfly-go/pkg/errorx" "mayfly-go/pkg/utils/anyx" "mayfly-go/pkg/utils/collx" + "regexp" "strings" + "time" ) const ( @@ -21,6 +23,8 @@ const ( ) type MssqlMetaData struct { + dbi.DefaultMetaData + dc *dbi.DbConn } @@ -73,9 +77,8 @@ func (md *MssqlMetaData) GetTables() ([]dbi.Table, error) { // 获取列元信息, 如列名等 func (md *MssqlMetaData) GetColumns(tableNames ...string) ([]dbi.Column, error) { - dbType := md.dc.Info.Type tableName := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string { - return fmt.Sprintf("'%s'", dbType.RemoveQuote(val)) + return fmt.Sprintf("'%s'", dbi.RemoveQuote(md, val)) }), ",") _, res, err := md.dc.Query(fmt.Sprintf(dbi.GetLocalSql(MSSQL_META_FILE, MSSQL_COLUMN_MA_KEY), tableName), md.dc.Info.CurrentSchema()) @@ -168,7 +171,7 @@ func (md *MssqlMetaData) GetTableIndex(tableName string) ([]dbi.Index, error) { return result, nil } -func (md MssqlMetaData) CopyTableDDL(tableName string, newTableName string) (string, error) { +func (md *MssqlMetaData) CopyTableDDL(tableName string, newTableName string) (string, error) { if newTableName == "" { newTableName = tableName } @@ -192,7 +195,7 @@ func (md MssqlMetaData) CopyTableDDL(tableName string, newTableName string) (str } } - baseTable := fmt.Sprintf("%s.%s", md.dc.Info.Type.QuoteIdentifier(md.dc.Info.CurrentSchema()), md.dc.Info.Type.QuoteIdentifier(newTableName)) + baseTable := fmt.Sprintf("%s.%s", dbi.QuoteIdentifier(md, md.dc.Info.CurrentSchema()), dbi.QuoteIdentifier(md, newTableName)) // 查询列信息 columns, err := md.GetColumns(tableName) @@ -271,3 +274,68 @@ func (md *MssqlMetaData) GetSchemas() ([]string, error) { } return schemas, nil } + +func (md *MssqlMetaData) GetIdentifierQuoteString() string { + return "[" +} + +func (md *MssqlMetaData) GetDataConverter() dbi.DataConverter { + return converter +} + +var ( + // 数字类型 + numberRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`) + // 日期时间类型 + datetimeRegexp = regexp.MustCompile(`(?i)datetime|timestamp`) + // 日期类型 + dateRegexp = regexp.MustCompile(`(?i)date`) + // 时间类型 + timeRegexp = regexp.MustCompile(`(?i)time`) + + converter = new(DataConverter) +) + +type DataConverter struct { +} + +func (dc *DataConverter) GetDataType(dbColumnType string) dbi.DataType { + if numberRegexp.MatchString(dbColumnType) { + return dbi.DataTypeNumber + } + // 日期时间类型 + if datetimeRegexp.MatchString(dbColumnType) { + return dbi.DataTypeDateTime + } + // 日期类型 + if dateRegexp.MatchString(dbColumnType) { + return dbi.DataTypeDate + } + // 时间类型 + if timeRegexp.MatchString(dbColumnType) { + return dbi.DataTypeTime + } + return dbi.DataTypeString +} + +func (dc *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) string { + return anyx.ToString(dbColumnValue) +} + +func (dc *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any { + // 如果dataType是datetime而dbColumnValue是string类型,则需要转换为time.Time类型 + _, ok := dbColumnValue.(string) + if dataType == dbi.DataTypeDateTime && ok { + res, _ := time.Parse(time.RFC3339, anyx.ToString(dbColumnValue)) + return res + } + if dataType == dbi.DataTypeDate && ok { + res, _ := time.Parse(time.DateOnly, anyx.ToString(dbColumnValue)) + return res + } + if dataType == dbi.DataTypeTime && ok { + res, _ := time.Parse(time.TimeOnly, anyx.ToString(dbColumnValue)) + return res + } + return dbColumnValue +} diff --git a/server/internal/db/dbm/mysql/dialect.go b/server/internal/db/dbm/mysql/dialect.go index 53e38a44..d342db79 100644 --- a/server/internal/db/dbm/mysql/dialect.go +++ b/server/internal/db/dbm/mysql/dialect.go @@ -4,8 +4,6 @@ import ( "database/sql" "fmt" "mayfly-go/internal/db/dbm/dbi" - "mayfly-go/pkg/utils/anyx" - "regexp" "strings" "time" ) @@ -14,10 +12,6 @@ type MysqlDialect struct { dc *dbi.DbConn } -func (md *MysqlDialect) GetMetaData() dbi.MetaData { - return &MysqlMetaData{dc: md.dc} -} - // GetDbProgram 获取数据库程序模块,用于数据库备份与恢复 func (md *MysqlDialect) GetDbProgram() (dbi.DbProgram, error) { return NewDbProgramMysql(md.dc), nil @@ -45,7 +39,7 @@ func (md *MysqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []stri prefix = "replace into" } - sqlStr := fmt.Sprintf("%s %s (%s) values %s", prefix, md.dc.Info.Type.QuoteIdentifier(tableName), strings.Join(columns, ","), placeholder) + sqlStr := fmt.Sprintf("%s %s (%s) values %s", prefix, md.dc.GetMetaData().QuoteIdentifier(tableName), strings.Join(columns, ","), placeholder) // 执行批量insert sql // 把二维数组转为一维数组 var args []any @@ -55,69 +49,6 @@ func (md *MysqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []stri return md.dc.TxExec(tx, sqlStr, args...) } -func (md *MysqlDialect) GetDataConverter() dbi.DataConverter { - return converter -} - -var ( - // 数字类型 - numberRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`) - // 日期时间类型 - datetimeRegexp = regexp.MustCompile(`(?i)datetime|timestamp`) - // 日期类型 - dateRegexp = regexp.MustCompile(`(?i)date`) - // 时间类型 - timeRegexp = regexp.MustCompile(`(?i)time`) - - converter = new(DataConverter) -) - -type DataConverter struct { -} - -func (dc *DataConverter) GetDataType(dbColumnType string) dbi.DataType { - if numberRegexp.MatchString(dbColumnType) { - return dbi.DataTypeNumber - } - // 日期时间类型 - if datetimeRegexp.MatchString(dbColumnType) { - return dbi.DataTypeDateTime - } - // 日期类型 - if dateRegexp.MatchString(dbColumnType) { - return dbi.DataTypeDate - } - // 时间类型 - if timeRegexp.MatchString(dbColumnType) { - return dbi.DataTypeTime - } - return dbi.DataTypeString -} - -func (dc *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) string { - return anyx.ToString(dbColumnValue) -} - -func (dc *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any { - // 如果dataType是datetime而dbColumnValue是string类型,则需要转换为time.Time类型 - _, ok := dbColumnValue.(string) - if ok { - if dataType == dbi.DataTypeDateTime { - res, _ := time.Parse(time.DateTime, anyx.ToString(dbColumnValue)) - return res - } - if dataType == dbi.DataTypeDate { - res, _ := time.Parse(time.DateOnly, anyx.ToString(dbColumnValue)) - return res - } - if dataType == dbi.DataTypeTime { - res, _ := time.Parse(time.TimeOnly, anyx.ToString(dbColumnValue)) - return res - } - } - return dbColumnValue -} - func (md *MysqlDialect) CopyTable(copy *dbi.DbCopyTable) error { tableName := copy.TableName diff --git a/server/internal/db/dbm/mysql/meta.go b/server/internal/db/dbm/mysql/meta.go index 0ce1719b..f72e4cfc 100644 --- a/server/internal/db/dbm/mysql/meta.go +++ b/server/internal/db/dbm/mysql/meta.go @@ -40,5 +40,9 @@ func (md *MysqlMeta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) { } func (md *MysqlMeta) GetDialect(conn *dbi.DbConn) dbi.Dialect { - return &MysqlDialect{conn} + return &MysqlDialect{dc: conn} +} + +func (md *MysqlMeta) GetMetaData(conn *dbi.DbConn) *dbi.MetaDataX { + return dbi.NewMetaDataX(&MysqlMetaData{dc: conn}) } diff --git a/server/internal/db/dbm/mysql/metadata.go b/server/internal/db/dbm/mysql/metadata.go index af02c4b7..f0463192 100644 --- a/server/internal/db/dbm/mysql/metadata.go +++ b/server/internal/db/dbm/mysql/metadata.go @@ -7,7 +7,11 @@ import ( "mayfly-go/pkg/errorx" "mayfly-go/pkg/utils/anyx" "mayfly-go/pkg/utils/collx" + "regexp" "strings" + "time" + + "github.com/kanzihuang/vitess/go/vt/sqlparser" ) const ( @@ -19,6 +23,8 @@ const ( ) type MysqlMetaData struct { + dbi.DefaultMetaData + dc *dbi.DbConn } @@ -69,9 +75,8 @@ func (md *MysqlMetaData) GetTables() ([]dbi.Table, error) { // 获取列元信息, 如列名等 func (md *MysqlMetaData) GetColumns(tableNames ...string) ([]dbi.Column, error) { - dbType := md.dc.Info.Type tableName := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string { - return fmt.Sprintf("'%s'", dbType.RemoveQuote(val)) + return fmt.Sprintf("'%s'", dbi.RemoveQuote(md, val)) }), ",") _, res, err := md.dc.Query(fmt.Sprintf(dbi.GetLocalSql(MYSQL_META_FILE, MYSQL_COLUMN_MA_KEY), tableName)) @@ -164,3 +169,80 @@ func (md *MysqlMetaData) GetTableDDL(tableName string) (string, error) { func (md *MysqlMetaData) GetSchemas() ([]string, error) { return nil, errors.New("不支持schema") } + +func (md *MysqlMetaData) GetIdentifierQuoteString() string { + return "`" +} + +func (md *MysqlMetaData) QuoteLiteral(literal string) string { + literal = strings.ReplaceAll(literal, `\`, `\\`) + literal = strings.ReplaceAll(literal, `'`, `''`) + return "'" + literal + "'" +} + +func (md *MysqlMetaData) SqlParserDialect() sqlparser.Dialect { + return sqlparser.MysqlDialect{} +} + +func (md *MysqlMetaData) GetDataConverter() dbi.DataConverter { + return converter +} + +var ( + // 数字类型 + numberRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`) + // 日期时间类型 + datetimeRegexp = regexp.MustCompile(`(?i)datetime|timestamp`) + // 日期类型 + dateRegexp = regexp.MustCompile(`(?i)date`) + // 时间类型 + timeRegexp = regexp.MustCompile(`(?i)time`) + + converter = new(DataConverter) +) + +type DataConverter struct { +} + +func (dc *DataConverter) GetDataType(dbColumnType string) dbi.DataType { + if numberRegexp.MatchString(dbColumnType) { + return dbi.DataTypeNumber + } + // 日期时间类型 + if datetimeRegexp.MatchString(dbColumnType) { + return dbi.DataTypeDateTime + } + // 日期类型 + if dateRegexp.MatchString(dbColumnType) { + return dbi.DataTypeDate + } + // 时间类型 + if timeRegexp.MatchString(dbColumnType) { + return dbi.DataTypeTime + } + return dbi.DataTypeString +} + +func (dc *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) string { + return anyx.ToString(dbColumnValue) +} + +func (dc *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any { + // 如果dataType是datetime而dbColumnValue是string类型,则需要转换为time.Time类型 + _, ok := dbColumnValue.(string) + if ok { + if dataType == dbi.DataTypeDateTime { + res, _ := time.Parse(time.DateTime, anyx.ToString(dbColumnValue)) + return res + } + if dataType == dbi.DataTypeDate { + res, _ := time.Parse(time.DateOnly, anyx.ToString(dbColumnValue)) + return res + } + if dataType == dbi.DataTypeTime { + res, _ := time.Parse(time.TimeOnly, anyx.ToString(dbColumnValue)) + return res + } + } + return dbColumnValue +} diff --git a/server/internal/db/dbm/oracle/dialect.go b/server/internal/db/dbm/oracle/dialect.go index fa2b84a4..bd542e33 100644 --- a/server/internal/db/dbm/oracle/dialect.go +++ b/server/internal/db/dbm/oracle/dialect.go @@ -5,9 +5,7 @@ import ( "fmt" "mayfly-go/internal/db/dbm/dbi" "mayfly-go/pkg/logx" - "mayfly-go/pkg/utils/anyx" "mayfly-go/pkg/utils/collx" - "regexp" "strings" "time" @@ -18,10 +16,6 @@ type OracleDialect struct { dc *dbi.DbConn } -func (od *OracleDialect) GetMetaData() dbi.MetaData { - return &OracleMetaData{dc: od.dc} -} - // GetDbProgram 获取数据库程序模块,用于数据库备份与恢复 func (od *OracleDialect) GetDbProgram() (dbi.DbProgram, error) { return nil, fmt.Errorf("该数据库类型不支持数据库备份与恢复: %v", od.dc.Info.Type) @@ -39,20 +33,20 @@ func (od *OracleDialect) BatchInsert(tx *sql.Tx, tableName string, columns []str } if duplicateStrategy == dbi.DuplicateStrategyNone || duplicateStrategy == 0 || duplicateStrategy == dbi.DuplicateStrategyIgnore { - return od.batchInsertSimple(od.dc.Info.Type, tableName, columns, values, duplicateStrategy, tx) + return od.batchInsertSimple(tableName, columns, values, duplicateStrategy, tx) } else { - return od.batchInsertMergeSql(od.dc.Info.Type, tableName, columns, values, args, tx) + return od.batchInsertMergeSql(tableName, columns, values, args, tx) } } // 简单批量插入sql,无需判断键冲突策略 -func (od *OracleDialect) batchInsertSimple(dbType dbi.DbType, tableName string, columns []string, values [][]any, duplicateStrategy int, tx *sql.Tx) (int64, error) { - +func (od *OracleDialect) batchInsertSimple(tableName string, columns []string, values [][]any, duplicateStrategy int, tx *sql.Tx) (int64, error) { + metadata := od.dc.GetMetaData() // 忽略键冲突策略 ignore := "" if duplicateStrategy == dbi.DuplicateStrategyIgnore { // 查出唯一索引涉及的字段 - indexs, _ := od.GetMetaData().GetTableIndex(tableName) + indexs, _ := metadata.GetTableIndex(tableName) if indexs != nil { arr := make([]string, 0) for _, index := range indexs { @@ -75,7 +69,7 @@ func (od *OracleDialect) batchInsertSimple(dbType dbi.DbType, tableName string, for i := 0; i < len(value); i++ { placeholder = append(placeholder, fmt.Sprintf(":%d", i+1)) } - sqlTemp := fmt.Sprintf("INSERT %s INTO %s (%s) VALUES (%s)", ignore, dbType.QuoteIdentifier(tableName), strings.Join(columns, ","), strings.Join(placeholder, ",")) + sqlTemp := fmt.Sprintf("INSERT %s INTO %s (%s) VALUES (%s)", ignore, metadata.QuoteIdentifier(tableName), strings.Join(columns, ","), strings.Join(placeholder, ",")) // oracle数据库为了兼容ignore主键冲突,只能一条条的执行insert res, err := od.dc.TxExec(tx, sqlTemp, value...) @@ -87,12 +81,13 @@ func (od *OracleDialect) batchInsertSimple(dbType dbi.DbType, tableName string, return int64(effRows), nil } -func (od *OracleDialect) batchInsertMergeSql(dbType dbi.DbType, tableName string, columns []string, values [][]any, args []any, tx *sql.Tx) (int64, error) { +func (od *OracleDialect) batchInsertMergeSql(tableName string, columns []string, values [][]any, args []any, tx *sql.Tx) (int64, error) { // 查询主键字段 uniqueCols := make([]string, 0) caseSqls := make([]string, 0) + metadata := od.dc.GetMetaData() // 查询唯一索引涉及到的字段,并组装到match条件内 - indexs, _ := od.GetMetaData().GetTableIndex(tableName) + indexs, _ := metadata.GetTableIndex(tableName) if indexs != nil { for _, index := range indexs { if index.IsUnique { @@ -102,7 +97,7 @@ func (od *OracleDialect) batchInsertMergeSql(dbType dbi.DbType, tableName string if !collx.ArrayContains(uniqueCols, col) { uniqueCols = append(uniqueCols, col) } - tmp = append(tmp, fmt.Sprintf(" T1.%s = T2.%s ", dbType.QuoteIdentifier(col), dbType.QuoteIdentifier(col))) + tmp = append(tmp, fmt.Sprintf(" T1.%s = T2.%s ", metadata.QuoteIdentifier(col), metadata.QuoteIdentifier(col))) } caseSqls = append(caseSqls, fmt.Sprintf("( %s )", strings.Join(tmp, " AND "))) } @@ -111,7 +106,7 @@ func (od *OracleDialect) batchInsertMergeSql(dbType dbi.DbType, tableName string // 如果caseSqls为空,则说明没有唯一键,直接使用简单批量插入 if len(caseSqls) == 0 { - return od.batchInsertSimple(dbType, tableName, columns, values, dbi.DuplicateStrategyNone, tx) + return od.batchInsertSimple(tableName, columns, values, dbi.DuplicateStrategyNone, tx) } // 重复数据处理策略 @@ -119,7 +114,7 @@ func (od *OracleDialect) batchInsertMergeSql(dbType dbi.DbType, tableName string upds := make([]string, 0) insertCols := make([]string, 0) for _, column := range columns { - if !collx.ArrayContains(uniqueCols, dbType.RemoveQuote(column)) { + if !collx.ArrayContains(uniqueCols, metadata.RemoveQuote(column)) { upds = append(upds, fmt.Sprintf("T1.%s = T2.%s", column, column)) } insertCols = append(insertCols, fmt.Sprintf("T1.%s", column)) @@ -140,7 +135,7 @@ func (od *OracleDialect) batchInsertMergeSql(dbType dbi.DbType, tableName string t2 := strings.Join(t2s, " UNION ALL ") - sqlTemp := "MERGE INTO " + dbType.QuoteIdentifier(tableName) + " T1 USING (" + t2 + ") T2 ON (" + strings.Join(caseSqls, " OR ") + ") " + sqlTemp := "MERGE INTO " + metadata.QuoteIdentifier(tableName) + " T1 USING (" + t2 + ") T2 ON (" + strings.Join(caseSqls, " OR ") + ") " sqlTemp += "WHEN NOT MATCHED THEN INSERT (" + strings.Join(insertCols, ",") + ") VALUES (" + strings.Join(insertVals, ",") + ") " sqlTemp += "WHEN MATCHED THEN UPDATE SET " + strings.Join(upds, ",") @@ -152,53 +147,6 @@ func (od *OracleDialect) batchInsertMergeSql(dbType dbi.DbType, tableName string return res, err } -func (od *OracleDialect) GetDataConverter() dbi.DataConverter { - return converter -} - -var ( - // 数字类型 - numberTypeRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`) - // 日期时间类型 - datetimeTypeRegexp = regexp.MustCompile(`(?i)date|timestamp`) - - converter = new(DataConverter) -) - -type DataConverter struct { -} - -func (dc *DataConverter) GetDataType(dbColumnType string) dbi.DataType { - if numberTypeRegexp.MatchString(dbColumnType) { - return dbi.DataTypeNumber - } - // 日期时间类型 - if datetimeTypeRegexp.MatchString(dbColumnType) { - return dbi.DataTypeDateTime - } - return dbi.DataTypeString -} - -func (dc *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) string { - str := anyx.ToString(dbColumnValue) - switch dataType { - // oracle把日期类型数据格式化输出 - case dbi.DataTypeDateTime: // "2024-01-02T22:08:22.275697+08:00" - res, _ := time.Parse(time.RFC3339, str) - return res.Format(time.DateTime) - } - return str -} - -func (dc *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any { - // oracle把日期类型的数据转化为time类型 - if dataType == dbi.DataTypeDateTime { - res, _ := time.Parse(time.RFC3339, anyx.ConvString(dbColumnValue)) - return res - } - return dbColumnValue -} - func (od *OracleDialect) CopyTable(copy *dbi.DbCopyTable) error { // 生成新表名,为老表明+_copy_时间戳 newTableName := strings.ToUpper(copy.TableName + "_copy_" + time.Now().Format("20060102150405")) diff --git a/server/internal/db/dbm/oracle/meta.go b/server/internal/db/dbm/oracle/meta.go index 013367b9..9a29ed30 100644 --- a/server/internal/db/dbm/oracle/meta.go +++ b/server/internal/db/dbm/oracle/meta.go @@ -78,6 +78,10 @@ func (md *OraMeta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) { return conn, err } -func (md *OraMeta) GetDialect(conn *dbi.DbConn) dbi.Dialect { - return &OracleDialect{conn} +func (om *OraMeta) GetDialect(conn *dbi.DbConn) dbi.Dialect { + return &OracleDialect{dc: conn} +} + +func (om *OraMeta) GetMetaData(conn *dbi.DbConn) *dbi.MetaDataX { + return dbi.NewMetaDataX(&OracleMetaData{dc: conn}) } diff --git a/server/internal/db/dbm/oracle/metadata.go b/server/internal/db/dbm/oracle/metadata.go index 76c9d795..79a576a8 100644 --- a/server/internal/db/dbm/oracle/metadata.go +++ b/server/internal/db/dbm/oracle/metadata.go @@ -6,7 +6,9 @@ import ( "mayfly-go/pkg/errorx" "mayfly-go/pkg/utils/anyx" "mayfly-go/pkg/utils/collx" + "regexp" "strings" + "time" ) // ---------------------------------- DM元数据 ----------------------------------- @@ -19,6 +21,8 @@ const ( ) type OracleMetaData struct { + dbi.DefaultMetaData + dc *dbi.DbConn } @@ -75,9 +79,8 @@ func (od *OracleMetaData) GetTables() ([]dbi.Table, error) { // 获取列元信息, 如列名等 func (od *OracleMetaData) GetColumns(tableNames ...string) ([]dbi.Column, error) { - dbType := od.dc.Info.Type tableName := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string { - return fmt.Sprintf("'%s'", dbType.RemoveQuote(val)) + return fmt.Sprintf("'%s'", dbi.RemoveQuote(od, val)) }), ",") // 如果表数量超过了1000,需要分批查询 @@ -259,3 +262,50 @@ func (od *OracleMetaData) GetSchemas() ([]string, error) { } return schemaNames, nil } + +func (od *OracleMetaData) GetDataConverter() dbi.DataConverter { + return converter +} + +var ( + // 数字类型 + numberTypeRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`) + // 日期时间类型 + datetimeTypeRegexp = regexp.MustCompile(`(?i)date|timestamp`) + + converter = new(DataConverter) +) + +type DataConverter struct { +} + +func (dc *DataConverter) GetDataType(dbColumnType string) dbi.DataType { + if numberTypeRegexp.MatchString(dbColumnType) { + return dbi.DataTypeNumber + } + // 日期时间类型 + if datetimeTypeRegexp.MatchString(dbColumnType) { + return dbi.DataTypeDateTime + } + return dbi.DataTypeString +} + +func (dc *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) string { + str := anyx.ToString(dbColumnValue) + switch dataType { + // oracle把日期类型数据格式化输出 + case dbi.DataTypeDateTime: // "2024-01-02T22:08:22.275697+08:00" + res, _ := time.Parse(time.RFC3339, str) + return res.Format(time.DateTime) + } + return str +} + +func (dc *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any { + // oracle把日期类型的数据转化为time类型 + if dataType == dbi.DataTypeDateTime { + res, _ := time.Parse(time.RFC3339, anyx.ConvString(dbColumnValue)) + return res + } + return dbColumnValue +} diff --git a/server/internal/db/dbm/postgres/dialect.go b/server/internal/db/dbm/postgres/dialect.go index a58e696d..e6222814 100644 --- a/server/internal/db/dbm/postgres/dialect.go +++ b/server/internal/db/dbm/postgres/dialect.go @@ -6,7 +6,6 @@ import ( "mayfly-go/internal/db/dbm/dbi" "mayfly-go/pkg/utils/anyx" "mayfly-go/pkg/utils/collx" - "regexp" "strings" "time" ) @@ -15,10 +14,6 @@ type PgsqlDialect struct { dc *dbi.DbConn } -func (pd *PgsqlDialect) GetMetaData() dbi.MetaData { - return &PgsqlMetaData{dc: pd.dc} -} - // GetDbProgram 获取数据库程序模块,用于数据库备份与恢复 func (pd *PgsqlDialect) GetDbProgram() (dbi.DbProgram, error) { return nil, fmt.Errorf("该数据库类型不支持数据库备份与恢复: %v", pd.dc.Info.Type) @@ -56,7 +51,7 @@ func (pd *PgsqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []stri suffix = pd.pgsqlOnDuplicateStrategySql(duplicateStrategy, tableName, columns) } - sqlStr := fmt.Sprintf("insert into %s (%s) values %s %s", pd.dc.Info.Type.QuoteIdentifier(tableName), strings.Join(columns, ","), strings.Join(placeholders, ", "), suffix) + sqlStr := fmt.Sprintf("insert into %s (%s) values %s %s", pd.dc.GetMetaData().QuoteIdentifier(tableName), strings.Join(columns, ","), strings.Join(placeholders, ", "), suffix) // 执行批量insert sql return pd.dc.TxExec(tx, sqlStr, args...) @@ -90,13 +85,14 @@ func (pd PgsqlDialect) pgsqlOnDuplicateStrategySql(duplicateStrategy int, tableN // 高斯db唯一键冲突策略,使用ON DUPLICATE KEY UPDATE 参考:https://support.huaweicloud.com/distributed-devg-v3-gaussdb/gaussdb-12-0607.html#ZH-CN_TOPIC_0000001633948138 func (pd PgsqlDialect) gaussOnDuplicateStrategySql(duplicateStrategy int, tableName string, columns []string) string { suffix := "" + metadata := pd.dc.GetMetaData() if duplicateStrategy == dbi.DuplicateStrategyIgnore { suffix = " \n ON DUPLICATE KEY UPDATE NOTHING" } else if duplicateStrategy == dbi.DuplicateStrategyUpdate { // 查出表里的唯一键涉及的字段 var uniqueColumns []string - indexs, err := pd.GetMetaData().GetTableIndex(tableName) + indexs, err := metadata.GetTableIndex(tableName) if err == nil { for _, index := range indexs { if index.IsUnique { @@ -113,7 +109,7 @@ func (pd PgsqlDialect) gaussOnDuplicateStrategySql(duplicateStrategy int, tableN suffix = " \n ON DUPLICATE KEY UPDATE " for i, col := range columns { // ON DUPLICATE KEY UPDATE语句不支持更新唯一键字段,所以得去掉 - if !collx.ArrayContains(uniqueColumns, pd.dc.Info.Type.RemoveQuote(strings.ToLower(col))) { + if !collx.ArrayContains(uniqueColumns, metadata.RemoveQuote(strings.ToLower(col))) { suffix += fmt.Sprintf("%s = excluded.%s", col, col) if i < len(columns)-1 { suffix += ", " @@ -135,79 +131,6 @@ func (pd *PgsqlDialect) currentSchema() string { return schema } -func (pd *PgsqlDialect) GetDataConverter() dbi.DataConverter { - return converter -} - -var ( - // 数字类型 - numberRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`) - // 日期时间类型 - datetimeRegexp = regexp.MustCompile(`(?i)datetime|timestamp`) - // 日期类型 - dateRegexp = regexp.MustCompile(`(?i)date`) - // 时间类型 - timeRegexp = regexp.MustCompile(`(?i)time`) - - converter = new(DataConverter) -) - -type DataConverter struct { -} - -func (dc *DataConverter) GetDataType(dbColumnType string) dbi.DataType { - if numberRegexp.MatchString(dbColumnType) { - return dbi.DataTypeNumber - } - // 日期时间类型 - if datetimeRegexp.MatchString(dbColumnType) { - return dbi.DataTypeDateTime - } - // 日期类型 - if dateRegexp.MatchString(dbColumnType) { - return dbi.DataTypeDate - } - // 时间类型 - if timeRegexp.MatchString(dbColumnType) { - return dbi.DataTypeTime - } - return dbi.DataTypeString -} - -func (dc *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) string { - str := fmt.Sprintf("%v", dbColumnValue) - switch dataType { - case dbi.DataTypeDateTime: // "2024-01-02T22:16:28.545377+08:00" - res, _ := time.Parse(time.RFC3339, str) - return res.Format(time.DateTime) - case dbi.DataTypeDate: // "2024-01-02T00:00:00Z" - res, _ := time.Parse(time.RFC3339, str) - return res.Format(time.DateOnly) - case dbi.DataTypeTime: // "0000-01-01T22:16:28.545075+08:00" - res, _ := time.Parse(time.RFC3339, str) - return res.Format(time.TimeOnly) - } - return anyx.ConvString(dbColumnValue) -} - -func (dc *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any { - // 如果dataType是datetime而dbColumnValue是string类型,则需要转换为time.Time类型 - _, ok := dbColumnValue.(string) - if dataType == dbi.DataTypeDateTime && ok { - res, _ := time.Parse(time.RFC3339, anyx.ToString(dbColumnValue)) - return res - } - if dataType == dbi.DataTypeDate && ok { - res, _ := time.Parse(time.DateOnly, anyx.ToString(dbColumnValue)) - return res - } - if dataType == dbi.DataTypeTime && ok { - res, _ := time.Parse(time.TimeOnly, anyx.ToString(dbColumnValue)) - return res - } - return dbColumnValue -} - func (pd *PgsqlDialect) IsGauss() bool { return strings.Contains(pd.dc.Info.Params, "gauss") } diff --git a/server/internal/db/dbm/postgres/meta.go b/server/internal/db/dbm/postgres/meta.go index 58b84911..e088e4f6 100644 --- a/server/internal/db/dbm/postgres/meta.go +++ b/server/internal/db/dbm/postgres/meta.go @@ -45,9 +45,6 @@ func (md *PostgresMeta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) { db := d.Database var dbParam string existSchema := false - if db == "" { - db = d.Type.MetaDbName() - } // postgres database可以使用db/schema表示,方便连接指定schema, 若不存在schema则使用默认schema ss := strings.Split(db, "/") if len(ss) > 1 { @@ -84,8 +81,12 @@ func (md *PostgresMeta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) { return sql.Open(driverName, dsn) } -func (md *PostgresMeta) GetDialect(conn *dbi.DbConn) dbi.Dialect { - return &PgsqlDialect{conn} +func (pm *PostgresMeta) GetDialect(conn *dbi.DbConn) dbi.Dialect { + return &PgsqlDialect{dc: conn} +} + +func (pm *PostgresMeta) GetMetaData(conn *dbi.DbConn) *dbi.MetaDataX { + return dbi.NewMetaDataX(&PgsqlMetaData{dc: conn}) } // pgsql dialer diff --git a/server/internal/db/dbm/postgres/metadata.go b/server/internal/db/dbm/postgres/metadata.go index e986eae0..67ce809b 100644 --- a/server/internal/db/dbm/postgres/metadata.go +++ b/server/internal/db/dbm/postgres/metadata.go @@ -6,7 +6,9 @@ import ( "mayfly-go/pkg/errorx" "mayfly-go/pkg/utils/anyx" "mayfly-go/pkg/utils/collx" + "regexp" "strings" + "time" ) const ( @@ -19,6 +21,8 @@ const ( ) type PgsqlMetaData struct { + dbi.DefaultMetaData + dc *dbi.DbConn } @@ -70,9 +74,8 @@ func (pd *PgsqlMetaData) GetTables() ([]dbi.Table, error) { // 获取列元信息, 如列名等 func (pd *PgsqlMetaData) GetColumns(tableNames ...string) ([]dbi.Column, error) { - dbType := pd.dc.Info.Type tableName := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string { - return fmt.Sprintf("'%s'", dbType.RemoveQuote(val)) + return fmt.Sprintf("'%s'", dbi.RemoveQuote(pd, val)) }), ",") _, res, err := pd.dc.Query(fmt.Sprintf(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_COLUMN_MA_KEY), tableName)) @@ -180,3 +183,89 @@ func (pd *PgsqlMetaData) GetSchemas() ([]string, error) { } return schemaNames, nil } + +func (pd *PgsqlMetaData) DefaultDb() string { + switch pd.dc.Info.Type { + case dbi.DbTypePostgres, dbi.DbTypeGauss: + return "postgres" + case dbi.DbTypeKingbaseEs: + return "security" + case dbi.DbTypeVastbase: + return "vastbase" + default: + return "" + } +} + +func (pd *PgsqlMetaData) GetDataConverter() dbi.DataConverter { + return converter +} + +var ( + // 数字类型 + numberRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`) + // 日期时间类型 + datetimeRegexp = regexp.MustCompile(`(?i)datetime|timestamp`) + // 日期类型 + dateRegexp = regexp.MustCompile(`(?i)date`) + // 时间类型 + timeRegexp = regexp.MustCompile(`(?i)time`) + + converter = new(DataConverter) +) + +type DataConverter struct { +} + +func (dc *DataConverter) GetDataType(dbColumnType string) dbi.DataType { + if numberRegexp.MatchString(dbColumnType) { + return dbi.DataTypeNumber + } + // 日期时间类型 + if datetimeRegexp.MatchString(dbColumnType) { + return dbi.DataTypeDateTime + } + // 日期类型 + if dateRegexp.MatchString(dbColumnType) { + return dbi.DataTypeDate + } + // 时间类型 + if timeRegexp.MatchString(dbColumnType) { + return dbi.DataTypeTime + } + return dbi.DataTypeString +} + +func (dc *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) string { + str := fmt.Sprintf("%v", dbColumnValue) + switch dataType { + case dbi.DataTypeDateTime: // "2024-01-02T22:16:28.545377+08:00" + res, _ := time.Parse(time.RFC3339, str) + return res.Format(time.DateTime) + case dbi.DataTypeDate: // "2024-01-02T00:00:00Z" + res, _ := time.Parse(time.RFC3339, str) + return res.Format(time.DateOnly) + case dbi.DataTypeTime: // "0000-01-01T22:16:28.545075+08:00" + res, _ := time.Parse(time.RFC3339, str) + return res.Format(time.TimeOnly) + } + return anyx.ConvString(dbColumnValue) +} + +func (dc *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any { + // 如果dataType是datetime而dbColumnValue是string类型,则需要转换为time.Time类型 + _, ok := dbColumnValue.(string) + if dataType == dbi.DataTypeDateTime && ok { + res, _ := time.Parse(time.RFC3339, anyx.ToString(dbColumnValue)) + return res + } + if dataType == dbi.DataTypeDate && ok { + res, _ := time.Parse(time.DateOnly, anyx.ToString(dbColumnValue)) + return res + } + if dataType == dbi.DataTypeTime && ok { + res, _ := time.Parse(time.TimeOnly, anyx.ToString(dbColumnValue)) + return res + } + return dbColumnValue +} diff --git a/server/internal/db/dbm/sqlite/dialect.go b/server/internal/db/dbm/sqlite/dialect.go index a7202fd6..a544f4c9 100644 --- a/server/internal/db/dbm/sqlite/dialect.go +++ b/server/internal/db/dbm/sqlite/dialect.go @@ -4,8 +4,6 @@ import ( "database/sql" "fmt" "mayfly-go/internal/db/dbm/dbi" - "mayfly-go/pkg/utils/anyx" - "regexp" "strings" "time" ) @@ -14,10 +12,6 @@ type SqliteDialect struct { dc *dbi.DbConn } -func (sd *SqliteDialect) GetMetaData() dbi.MetaData { - return &SqliteMetaData{dc: sd.dc} -} - // GetDbProgram 获取数据库程序模块,用于数据库备份与恢复 func (sd *SqliteDialect) GetDbProgram() (dbi.DbProgram, error) { return nil, fmt.Errorf("该数据库类型不支持数据库备份与恢复: %v", sd.dc.Info.Type) @@ -43,7 +37,7 @@ func (sd *SqliteDialect) BatchInsert(tx *sql.Tx, tableName string, columns []str prefix = "insert or replace into" } - sqlStr := fmt.Sprintf("%s %s (%s) values %s", prefix, sd.dc.Info.Type.QuoteIdentifier(tableName), strings.Join(columns, ","), placeholder) + sqlStr := fmt.Sprintf("%s %s (%s) values %s", prefix, sd.dc.GetMetaData().QuoteIdentifier(tableName), strings.Join(columns, ","), placeholder) // 把二维数组转为一维数组 var args []any @@ -55,58 +49,12 @@ func (sd *SqliteDialect) BatchInsert(tx *sql.Tx, tableName string, columns []str return sd.dc.TxExec(tx, sqlStr, args...) } -func (sd *SqliteDialect) GetDataConverter() dbi.DataConverter { - return converter -} - -var ( - // 数字类型 - numberRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit|real`) - // 日期时间类型 - datetimeRegexp = regexp.MustCompile(`(?i)datetime`) - - converter = new(DataConverter) -) - -type DataConverter struct { -} - -func (dc *DataConverter) GetDataType(dbColumnType string) dbi.DataType { - if numberRegexp.MatchString(dbColumnType) { - return dbi.DataTypeNumber - } - if datetimeRegexp.MatchString(dbColumnType) { - return dbi.DataTypeDateTime - } - return dbi.DataTypeString -} - -func (dc *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) string { - str := anyx.ToString(dbColumnValue) - switch dataType { - case dbi.DataTypeDateTime: // "2024-01-02T22:08:22.275697+08:00" - res, _ := time.Parse(time.RFC3339, str) - return res.Format(time.DateTime) - case dbi.DataTypeDate: // "2024-01-02T00:00:00+08:00" - res, _ := time.Parse(time.RFC3339, str) - return res.Format(time.DateOnly) - case dbi.DataTypeTime: // "0000-01-01T22:08:22.275688+08:00" - res, _ := time.Parse(time.RFC3339, str) - return res.Format(time.TimeOnly) - } - return str -} - -func (dc *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any { - return dbColumnValue -} - func (sd *SqliteDialect) CopyTable(copy *dbi.DbCopyTable) error { tableName := copy.TableName // 生成新表名,为老表明+_copy_时间戳 newTableName := tableName + "_copy_" + time.Now().Format("20060102150405") - ddl, err := sd.GetMetaData().GetTableDDL(tableName) + ddl, err := sd.dc.GetMetaData().GetTableDDL(tableName) if err != nil { return err } diff --git a/server/internal/db/dbm/sqlite/meta.go b/server/internal/db/dbm/sqlite/meta.go index b13034ab..855163b3 100644 --- a/server/internal/db/dbm/sqlite/meta.go +++ b/server/internal/db/dbm/sqlite/meta.go @@ -23,6 +23,10 @@ func (md *SqliteMeta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) { return sql.Open("sqlite", d.Host) } -func (md *SqliteMeta) GetDialect(conn *dbi.DbConn) dbi.Dialect { - return &SqliteDialect{conn} +func (sm *SqliteMeta) GetDialect(conn *dbi.DbConn) dbi.Dialect { + return &SqliteDialect{dc: conn} +} + +func (sm *SqliteMeta) GetMetaData(conn *dbi.DbConn) *dbi.MetaDataX { + return dbi.NewMetaDataX(&SqliteMetaData{dc: conn}) } diff --git a/server/internal/db/dbm/sqlite/metadata.go b/server/internal/db/dbm/sqlite/metadata.go index 8805940e..70f08025 100644 --- a/server/internal/db/dbm/sqlite/metadata.go +++ b/server/internal/db/dbm/sqlite/metadata.go @@ -8,6 +8,7 @@ import ( "mayfly-go/pkg/utils/anyx" "regexp" "strings" + "time" ) const ( @@ -17,6 +18,8 @@ const ( ) type SqliteMetaData struct { + dbi.DefaultMetaData + dc *dbi.DbConn } @@ -176,3 +179,49 @@ func (sd *SqliteMetaData) GetTableDDL(tableName string) (string, error) { func (sd *SqliteMetaData) GetSchemas() ([]string, error) { return nil, nil } + +func (sd *SqliteMetaData) GetDataConverter() dbi.DataConverter { + return converter +} + +var ( + // 数字类型 + numberRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit|real`) + // 日期时间类型 + datetimeRegexp = regexp.MustCompile(`(?i)datetime`) + + converter = new(DataConverter) +) + +type DataConverter struct { +} + +func (dc *DataConverter) GetDataType(dbColumnType string) dbi.DataType { + if numberRegexp.MatchString(dbColumnType) { + return dbi.DataTypeNumber + } + if datetimeRegexp.MatchString(dbColumnType) { + return dbi.DataTypeDateTime + } + return dbi.DataTypeString +} + +func (dc *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) string { + str := anyx.ToString(dbColumnValue) + switch dataType { + case dbi.DataTypeDateTime: // "2024-01-02T22:08:22.275697+08:00" + res, _ := time.Parse(time.RFC3339, str) + return res.Format(time.DateTime) + case dbi.DataTypeDate: // "2024-01-02T00:00:00+08:00" + res, _ := time.Parse(time.RFC3339, str) + return res.Format(time.DateOnly) + case dbi.DataTypeTime: // "0000-01-01T22:08:22.275688+08:00" + res, _ := time.Parse(time.RFC3339, str) + return res.Format(time.TimeOnly) + } + return str +} + +func (dc *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any { + return dbColumnValue +}