diff --git a/mayfly_go_web/src/views/ops/db/SqlExec.vue b/mayfly_go_web/src/views/ops/db/SqlExec.vue index 963c58df..077f46da 100644 --- a/mayfly_go_web/src/views/ops/db/SqlExec.vue +++ b/mayfly_go_web/src/views/ops/db/SqlExec.vue @@ -175,7 +175,7 @@ import { dbApi } from './api'; import { dispposeCompletionItemProvider } from '@/components/monaco/completionItemProvider'; import SvgIcon from '@/components/svgIcon/index.vue'; import { ContextmenuItem } from '@/components/contextmenu'; -import { DbType, getDbDialect } from './dialect/index'; +import { getDbDialect, schemaDbTypes} from './dialect/index' import { sleep } from '@/common/utils/loading'; import { TagResourceTypeEnum } from '@/common/commonEnum'; import { Pane, Splitpanes } from 'splitpanes'; @@ -271,7 +271,7 @@ const NodeTypeDb = new NodeType(SqlExecNodeType.Db) const params = parentNode.params; params.parentKey = parentNode.key; // pg类数据库会多一层schema - if (params.type == DbType.postgresql || params.type === DbType.dm || params.type === DbType.oracle) { + if (schemaDbTypes.includes(params.type)) { const { id, db } = params; const schemaNames = await dbApi.pgSchemas.request({ id, db }); return schemaNames.map((sn: any) => { diff --git a/mayfly_go_web/src/views/ops/db/SyncTaskEdit.vue b/mayfly_go_web/src/views/ops/db/SyncTaskEdit.vue index 45d78cce..a1d41669 100644 --- a/mayfly_go_web/src/views/ops/db/SyncTaskEdit.vue +++ b/mayfly_go_web/src/views/ops/db/SyncTaskEdit.vue @@ -47,6 +47,7 @@ v-model:db-id="form.srcDbId" v-model:db-name="form.srcDbName" v-model:tag-path="form.srcTagPath" + v-model:db-type="form.srcDbType" @select-db="onSelectSrcDb" /> @@ -181,7 +182,7 @@ import { ElMessage } from 'element-plus'; import DbSelectTree from '@/views/ops/db/component/DbSelectTree.vue'; import MonacoEditor from '@/components/monaco/MonacoEditor.vue'; import { DbInst, registerDbCompletionItemProvider } from '@/views/ops/db/db'; -import { getDbDialect } from '@/views/ops/db/dialect'; +import {DbType, getDbDialect} from '@/views/ops/db/dialect' import CrontabInput from '@/components/crontab/CrontabInput.vue'; const props = defineProps({ @@ -227,6 +228,7 @@ type FormData = { taskCron: string; srcDbId?: number; srcDbName?: string; + srcDbType?: string; srcTagPath?: string; targetDbId?: number; targetDbName?: string; @@ -245,7 +247,7 @@ const basicFormData = { targetDbId: -1, dataSql: 'select * from', pageSize: 1000, - updField: 'id', + updField: '', updFieldVal: '0', fieldMap: [{ src: 'a', target: 'b' }], status: 1, @@ -302,6 +304,7 @@ watch(dialogVisible, async (newValue: boolean) => { // 初始化实例 db.databases = db.database?.split(' ').sort() || []; state.srcDbInst = DbInst.getOrNewInst(db); + state.form.srcDbType = state.srcDbInst.type } // 初始化target数据源 @@ -396,8 +399,8 @@ const handleGetSrcFields = async () => { } // 判断sql是否是查询语句 - if (!/^select/i.test(state.form.dataSql!)) { - let msg = 'sql语句错误,请输入查询语句'; + if (!/^select/i.test(state.form.dataSql.trim()!)) { + let msg = 'sql语句错误,请输入select语句'; ElMessage.warning(msg); return; } @@ -410,10 +413,16 @@ const handleGetSrcFields = async () => { } // 执行sql + // oracle的分页关键字不一样 + let limit = ' limit 1' + if(state.form.srcDbType === DbType.oracle){ + limit = ' where rownum <= 1' + } + const res = await dbApi.sqlExec.request({ id: state.form.srcDbId, db: state.form.srcDbName, - sql: state.form.dataSql.trim() + ' limit 1', + sql: `select * from (${state.form.dataSql}) t ${limit}` }); if (!res.columns) { diff --git a/mayfly_go_web/src/views/ops/db/component/DbSelectTree.vue b/mayfly_go_web/src/views/ops/db/component/DbSelectTree.vue index 92ff7ddc..48335b67 100644 --- a/mayfly_go_web/src/views/ops/db/component/DbSelectTree.vue +++ b/mayfly_go_web/src/views/ops/db/component/DbSelectTree.vue @@ -19,7 +19,7 @@ import { NodeType, TagTreeNode } from '@/views/ops/component/tag'; import { dbApi } from '@/views/ops/db/api'; import { sleep } from '@/common/utils/loading'; import SvgIcon from '@/components/svgIcon/index.vue'; -import { DbType, getDbDialect } from '@/views/ops/db/dialect'; +import { getDbDialect, mysqlDbTypes} from '@/views/ops/db/dialect' import TagTreeResourceSelect from '../../component/TagTreeResourceSelect.vue'; import { computed } from 'vue'; @@ -33,9 +33,12 @@ const props = defineProps({ tagPath: { type: String, }, + dbType: { + type: String, + }, }); -const emits = defineEmits(['update:dbName', 'update:tagPath', 'update:dbId', 'selectDb']); +const emits = defineEmits(['update:dbName', 'update:tagPath', 'update:dbId', 'update:dbType', 'selectDb']); /** * 树节点类型 @@ -88,7 +91,7 @@ const NodeTypeTagPath = new NodeType(TagTreeNode.TagPath).withLoadNodesFunc(asyn /** mysql类型的数据库,没有schema层 */ const mysqlType = (type: string) => { - return type === DbType.mysql; + return mysqlDbTypes.includes(type); }; // 数据库实例节点类型 @@ -150,6 +153,7 @@ const changeNode = (nodeData: TagTreeNode) => { emits('update:dbName', params.db); emits('update:dbId', params.id); emits('update:tagPath', params.tagPath); + emits('update:dbType', params.type); emits('selectDb', params); }; diff --git a/mayfly_go_web/src/views/ops/db/component/table/DbTablesOp.vue b/mayfly_go_web/src/views/ops/db/component/table/DbTablesOp.vue index cf321b67..de4eedfb 100644 --- a/mayfly_go_web/src/views/ops/db/component/table/DbTablesOp.vue +++ b/mayfly_go_web/src/views/ops/db/component/table/DbTablesOp.vue @@ -179,7 +179,6 @@ const state = reactive({ visible: false, activeName: '1', type: '', - enableEditTypes: [DbType.mysql, DbType.mariadb, DbType.postgresql, DbType.dm, DbType.oracle, DbType.sqlite], // 支持"编辑表"的数据库类型 data: { // 修改表时,传递修改数据 edit: false, diff --git a/mayfly_go_web/src/views/ops/db/dialect/index.ts b/mayfly_go_web/src/views/ops/db/dialect/index.ts index 35a0e696..1ea12f78 100644 --- a/mayfly_go_web/src/views/ops/db/dialect/index.ts +++ b/mayfly_go_web/src/views/ops/db/dialect/index.ts @@ -115,7 +115,13 @@ export const DbType = { sqlite: 'sqlite', }; -export const editDbTypes = [DbType.mysql, DbType.mariadb, DbType.postgresql, DbType.dm, DbType.oracle, DbType.sqlite]; +// mysql兼容的数据库 +export const mysqlDbTypes = [DbType.mysql, DbType.mariadb, DbType.sqlite]; + +// 有schema层的数据库 +export const schemaDbTypes = [DbType.postgresql, DbType.dm, DbType.oracle]; + +export const editDbTypes = [...mysqlDbTypes, ...schemaDbTypes]; export const compatibleMysql = (dbType: string): boolean => { switch (dbType) { diff --git a/server/internal/db/application/db_data_sync.go b/server/internal/db/application/db_data_sync.go index a3fbfc33..52c3ca3f 100644 --- a/server/internal/db/application/db_data_sync.go +++ b/server/internal/db/application/db_data_sync.go @@ -14,6 +14,9 @@ import ( "mayfly-go/pkg/logx" "mayfly-go/pkg/model" "mayfly-go/pkg/scheduler" + "regexp" + "strconv" + "strings" "time" ) @@ -44,6 +47,10 @@ type dataSyncAppImpl struct { dbDataSyncLogRepo repository.DataSyncLog `inject:"DbDataSyncLogRepo"` } +var ( + dateTimeReg = regexp.MustCompile(`^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}$`) +) + func (d *dataSyncAppImpl) InjectDbDataSyncTaskRepo(repo repository.DataSyncTask) { d.Repo = repo } @@ -123,7 +130,23 @@ func (app *dataSyncAppImpl) RunCronJob(id uint64) error { updSql := "" orderSql := "" if task.UpdFieldVal != "0" && task.UpdFieldVal != "" && task.UpdField != "" { - updSql = fmt.Sprintf("and %s > '%s'", task.UpdField, task.UpdFieldVal) + srcConn, _ := GetDbApp().GetDbConn(uint64(task.SrcDbId), task.SrcDbName) + + task.UpdFieldVal = strings.Trim(task.UpdFieldVal, " ") + // 把UpdFieldVal尝试转为int,如果可以转为int,则不添加引号,否则添加引号 + if _, err := strconv.Atoi(task.UpdFieldVal); err != nil { + updSql = fmt.Sprintf("and %s > '%s'", task.UpdField, task.UpdFieldVal) + } else { + updSql = fmt.Sprintf("and %s > %s", task.UpdField, task.UpdFieldVal) + } + + // 如果是oracle且数据类型是时间类型,则需要加上to_date函数 + if srcConn.Info.Type == dbi.DbTypeOracle { + // 用正则判断数据类型是时间 + if dateTimeReg.MatchString(task.UpdFieldVal) { + updSql = fmt.Sprintf("and %s > to_date('%s','yyyy-mm-dd hh24:mi:ss')", task.UpdField, task.UpdFieldVal) + } + } orderSql = "order by " + task.UpdField + " asc " } // 组装查询sql @@ -194,8 +217,8 @@ func (app *dataSyncAppImpl) doDataSync(sql string, task *entity.DataSyncTask) (* // 遍历columns 取task.UpdField的字段类型 updFieldType = dbi.DataTypeString for _, column := range columns { - if column.Name == task.UpdField { - updFieldType = srcDialect.GetDataType(column.Type) + if strings.ToLower(column.Name) == strings.ToLower(task.UpdField) { + updFieldType = srcDialect.GetDataConverter().GetDataType(column.Type) break } } @@ -204,7 +227,7 @@ func (app *dataSyncAppImpl) doDataSync(sql string, task *entity.DataSyncTask) (* total++ result = append(result, row) if total%batchSize == 0 { - if err := app.srcData2TargetDb(result, fieldMap, updFieldType, task, srcDialect, targetConn, targetDbTx); err != nil { + if err := app.srcData2TargetDb(result, fieldMap, columns, updFieldType, task, srcDialect, targetConn, targetDbTx); err != nil { return err } @@ -226,7 +249,7 @@ func (app *dataSyncAppImpl) doDataSync(sql string, task *entity.DataSyncTask) (* // 处理剩余的数据 if len(result) > 0 { - if err := app.srcData2TargetDb(result, fieldMap, updFieldType, task, srcDialect, targetConn, targetDbTx); err != nil { + if err := app.srcData2TargetDb(result, fieldMap, queryColumns, updFieldType, task, srcDialect, targetConn, targetDbTx); err != nil { targetDbTx.Rollback() return syncLog, err } @@ -246,10 +269,16 @@ func (app *dataSyncAppImpl) doDataSync(sql string, task *entity.DataSyncTask) (* return syncLog, nil } -func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap []map[string]string, updFieldType dbi.DataType, task *entity.DataSyncTask, srcDialect dbi.Dialect, targetDbConn *dbi.DbConn, targetDbTx *sql.Tx) error { - var data = make([]map[string]any, 0) +func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap []map[string]string, columns []*dbi.QueryColumn, updFieldType dbi.DataType, task *entity.DataSyncTask, srcDialect dbi.Dialect, targetDbConn *dbi.DbConn, targetDbTx *sql.Tx) error { - // 遍历res,组装插入sql + // 遍历src字段列表,取出字段对应的类型 + var srcColumnTypes = make(map[string]string) + for _, column := range columns { + srcColumnTypes[column.Name] = column.Type + } + + // 遍历res,组装数据 + var data = make([]map[string]any, 0) for _, record := range srcRes { var rowData = make(map[string]any) // 遍历字段映射, target字段的值为src字段取值 @@ -262,18 +291,23 @@ func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap [ data = append(data, rowData) } + // 解决字段大小写问题 + updFieldVal := srcRes[len(srcRes)-1][strings.ToUpper(task.UpdField)] + if updFieldVal == "" { + updFieldVal = srcRes[len(srcRes)-1][strings.ToLower(task.UpdField)] + } - updFieldVal := fmt.Sprintf("%v", srcRes[len(srcRes)-1][task.UpdField]) - updFieldVal = srcDialect.FormatStrData(updFieldVal, updFieldType) - task.UpdFieldVal = updFieldVal + task.UpdFieldVal = srcDialect.GetDataConverter().FormatData(updFieldVal, updFieldType) // 获取目标库字段数组 targetWrapColumns := make([]string, 0) // 获取源库字段数组 srcColumns := make([]string, 0) + srcFieldTypes := make(map[string]dbi.DataType) for _, item := range fieldMap { targetField := item["target"] srcField := item["target"] + srcFieldTypes[srcField] = srcDialect.GetDataConverter().GetDataType(srcColumnTypes[item["src"]]) targetWrapColumns = append(targetWrapColumns, targetDbConn.Info.Type.QuoteIdentifier(targetField)) srcColumns = append(srcColumns, srcField) } @@ -283,7 +317,9 @@ func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap [ for _, record := range data { rawValue := make([]any, 0) for _, column := range srcColumns { - rawValue = append(rawValue, record[column]) + // 某些情况,如oracle,需要转换时间类型的字符串为time类型 + res := srcDialect.GetDataConverter().ParseData(record[column], srcFieldTypes[column]) + rawValue = append(rawValue, res) } values = append(values, rawValue) } @@ -294,6 +330,12 @@ func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap [ return err } + // 运行完成一轮就记录一下修改字段最大值 + taskParam1 := new(entity.DataSyncTask) + taskParam1.Id = task.Id + taskParam1.UpdFieldVal = task.UpdFieldVal + _ = app.UpdateById(context.Background(), taskParam1) + // 运行过程中,判断状态是否为已关闭,是则结束运行,否则继续运行 taskParam, _ := app.GetById(new(entity.DataSyncTask), task.Id) if taskParam.RunningState == entity.DataSyncTaskRunStateStop { diff --git a/server/internal/db/application/db_sql_exec.go b/server/internal/db/application/db_sql_exec.go index eea3f61e..8a17226a 100644 --- a/server/internal/db/application/db_sql_exec.go +++ b/server/internal/db/application/db_sql_exec.go @@ -163,7 +163,9 @@ func doSelect(ctx context.Context, selectStmt *sqlparser.Select, execSqlReq *DbS len(strings.Split(selectExprsStr, ",")) > 1 { // 如果配置为0,则不校验分页参数 maxCount := config.GetDbQueryMaxCount() - if maxCount != 0 { + // 哪些数据库跳过校验 + skipped := dbi.DbTypeOracle == execSqlReq.DbConn.Info.Type + if maxCount != 0 && !skipped { limit := selectStmt.Limit if limit == nil { return nil, errorx.NewBiz("请完善分页信息后执行") diff --git a/server/internal/db/dbm/dbi/dialect.go b/server/internal/db/dbm/dbi/dialect.go index d4c780c4..21a3d3f8 100644 --- a/server/internal/db/dbm/dbi/dialect.go +++ b/server/internal/db/dbm/dbi/dialect.go @@ -66,9 +66,23 @@ type DbCopyTable struct { CopyData bool `json:"copyData"` // 是否复制数据 } +// 数据转换器 +type DataConverter interface { + // 获取数据对应的类型 + // @param dbColumnType 数据库原始列类型,如varchar等 + GetDataType(dbColumnType string) DataType + + // 根据数据类型格式化指定数据 + FormatData(dbColumnValue any, dataType DataType) string + + // 根据数据类型解析数据为符合要求的指定类型等 + ParseData(dbColumnValue any, dataType DataType) any +} + // -----------------------------------元数据接口定义------------------------------------------ // 数据库方言、元信息接口(表、列、获取表数据等元信息) type Dialect interface { + // 获取数据库服务实例信息 GetDbServer() (*DbServer, error) @@ -101,9 +115,7 @@ type Dialect interface { // 批量保存数据 BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) - GetDataType(dbColumnType string) DataType - - FormatStrData(dbColumnValue string, dataType DataType) string + GetDataConverter() DataConverter CopyTable(copy *DbCopyTable) error } diff --git a/server/internal/db/dbm/dm/dialect.go b/server/internal/db/dbm/dm/dialect.go index 690e0395..085d575c 100644 --- a/server/internal/db/dbm/dm/dialect.go +++ b/server/internal/db/dbm/dm/dialect.go @@ -255,24 +255,16 @@ func (dd *DMDialect) GetDbProgram() dbi.DbProgram { panic("implement me") } -func (dd *DMDialect) GetDataType(dbColumnType string) dbi.DataType { - if regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`).MatchString(dbColumnType) { - return dbi.DataTypeNumber - } +var ( + // 数字类型 + numberRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`) // 日期时间类型 - if regexp.MustCompile(`(?i)datetime|timestamp`).MatchString(dbColumnType) { - return dbi.DataTypeDateTime - } + datetimeRegexp = regexp.MustCompile(`(?i)datetime|timestamp`) // 日期类型 - if regexp.MustCompile(`(?i)date`).MatchString(dbColumnType) { - return dbi.DataTypeDate - } + dateRegexp = regexp.MustCompile(`(?i)date`) // 时间类型 - if regexp.MustCompile(`(?i)time`).MatchString(dbColumnType) { - return dbi.DataTypeTime - } - return dbi.DataTypeString -} + timeRegexp = regexp.MustCompile(`(?i)time`) +) func (dd *DMDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) { // 执行批量insert sql @@ -299,18 +291,46 @@ func (dd *DMDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, return int64(effRows), nil } -func (dd *DMDialect) FormatStrData(dbColumnValue string, dataType dbi.DataType) string { +type DataConverter struct { +} + +func (dd *DMDialect) GetDataConverter() dbi.DataConverter { + return new(DataConverter) +} + +func (dd *DataConverter) GetDataType(dbColumnType string) dbi.DataType { + if numberRegexp.MatchString(dbColumnType) { + return dbi.DataTypeNumber + } + if datetimeRegexp.MatchString(dbColumnType) { + return dbi.DataTypeDateTime + } + if dateRegexp.MatchString(dbColumnType) { + return dbi.DataTypeDate + } + if timeRegexp.MatchString(dbColumnType) { + return dbi.DataTypeTime + } + return dbi.DataTypeString +} + +func (dd *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) string { + str := anyx.ToString(dbColumnValue) switch dataType { case dbi.DataTypeDateTime: // "2024-01-02T22:08:22.275697+08:00" - res, _ := time.Parse(time.RFC3339, dbColumnValue) + res, _ := time.Parse(time.RFC3339, str) return res.Format(time.DateTime) case dbi.DataTypeDate: // "2024-01-02T00:00:00+08:00" - res, _ := time.Parse(time.RFC3339, dbColumnValue) + res, _ := time.Parse(time.RFC3339, str) return res.Format(time.DateOnly) case dbi.DataTypeTime: // "0000-01-01T22:08:22.275688+08:00" - res, _ := time.Parse(time.RFC3339, dbColumnValue) + res, _ := time.Parse(time.RFC3339, str) return res.Format(time.TimeOnly) } + return str +} + +func (dd *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any { return dbColumnValue } diff --git a/server/internal/db/dbm/mysql/dialect.go b/server/internal/db/dbm/mysql/dialect.go index 7d126078..5eb4632f 100644 --- a/server/internal/db/dbm/mysql/dialect.go +++ b/server/internal/db/dbm/mysql/dialect.go @@ -177,25 +177,6 @@ func (md *MysqlDialect) GetDbProgram() dbi.DbProgram { return NewDbProgramMysql(md.dc) } -func (md *MysqlDialect) GetDataType(dbColumnType string) dbi.DataType { - if regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`).MatchString(dbColumnType) { - return dbi.DataTypeNumber - } - // 日期时间类型 - if regexp.MustCompile(`(?i)datetime|timestamp`).MatchString(dbColumnType) { - return dbi.DataTypeDateTime - } - // 日期类型 - if regexp.MustCompile(`(?i)date`).MatchString(dbColumnType) { - return dbi.DataTypeDate - } - // 时间类型 - if regexp.MustCompile(`(?i)time`).MatchString(dbColumnType) { - return dbi.DataTypeTime - } - return dbi.DataTypeString -} - func (md *MysqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) { // 生成占位符字符串:如:(?,?) // 重复字符串并用逗号连接 @@ -221,8 +202,48 @@ func (md *MysqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []stri return md.dc.TxExec(tx, sqlStr, args...) } -func (md *MysqlDialect) FormatStrData(dbColumnValue string, dataType dbi.DataType) string { - // mysql不需要格式化时间日期等 +var ( + // 数字类型 + numberRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`) + // 日期时间类型 + datetimeRegexp = regexp.MustCompile(`(?i)datetime|timestamp`) + // 日期类型 + dateRegexp = regexp.MustCompile(`(?i)date`) + // 时间类型 + timeRegexp = regexp.MustCompile(`(?i)time`) +) + +type DataConverter struct { +} + +func (md *MysqlDialect) GetDataConverter() dbi.DataConverter { + return new(DataConverter) +} + +func (dc *DataConverter) GetDataType(dbColumnType string) dbi.DataType { + if numberRegexp.MatchString(dbColumnType) { + return dbi.DataTypeNumber + } + // 日期时间类型 + if datetimeRegexp.MatchString(dbColumnType) { + return dbi.DataTypeDateTime + } + // 日期类型 + if dateRegexp.MatchString(dbColumnType) { + return dbi.DataTypeDate + } + // 时间类型 + if timeRegexp.MatchString(dbColumnType) { + return dbi.DataTypeTime + } + return dbi.DataTypeString +} + +func (dc *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) string { + return anyx.ToString(dbColumnValue) +} + +func (dc *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any { return dbColumnValue } diff --git a/server/internal/db/dbm/oracle/dialect.go b/server/internal/db/dbm/oracle/dialect.go index 9e607688..8b01e0f7 100644 --- a/server/internal/db/dbm/oracle/dialect.go +++ b/server/internal/db/dbm/oracle/dialect.go @@ -8,7 +8,6 @@ import ( "mayfly-go/pkg/errorx" "mayfly-go/pkg/utils/anyx" "mayfly-go/pkg/utils/collx" - "reflect" "regexp" "strings" "time" @@ -257,25 +256,6 @@ func (od *OracleDialect) GetDbProgram() dbi.DbProgram { panic("implement me") } -func (od *OracleDialect) GetDataType(dbColumnType string) dbi.DataType { - if regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`).MatchString(dbColumnType) { - return dbi.DataTypeNumber - } - // 日期时间类型 - if regexp.MustCompile(`(?i)datetime|timestamp`).MatchString(dbColumnType) { - return dbi.DataTypeDateTime - } - // 日期类型 - if regexp.MustCompile(`(?i)date`).MatchString(dbColumnType) { - return dbi.DataTypeDate - } - // 时间类型 - if regexp.MustCompile(`(?i)time`).MatchString(dbColumnType) { - return dbi.DataTypeTime - } - return dbi.DataTypeString -} - func (od *OracleDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) { //INSERT ALL //INTO my_table(field_1,field_2) VALUES (value_1,value_2) @@ -301,20 +281,6 @@ func (od *OracleDialect) BatchInsert(tx *sql.Tx, tableName string, columns []str for i := 0; i < len(args); i += len(columns) { var placeholder []string for j := 0; j < len(columns); j++ { - // 判断字符串数据格式是时间"2023-06-25 10:40:10" 占位符需要变成 to_date(:x, 'fmt') - if reflect.TypeOf(args[i+j]) == reflect.TypeOf("") { - if regexp.MustCompile(`^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}$`).MatchString(args[i+j].(string)) { - placeholder = append(placeholder, fmt.Sprintf("to_date(:%d, 'yyyy-mm-dd hh24:mi:ss')", i+j+1)) - } else if regexp.MustCompile(`^\d{4}-\d{2}-\d{2}$`).MatchString(args[i+j].(string)) { - // 只有年月日的数据,oracle会自动补零时分秒,如:2024-01-02: to_date('2024-01-02','yyyy-mm-dd') 输出:2024-01-02 00:00:00 - placeholder = append(placeholder, fmt.Sprintf("to_date(:%d, 'yyyy-mm-dd')", i+j+1)) - } else if regexp.MustCompile(`^\d{2}:\d{2}:\d{2}$`).MatchString(args[i+j].(string)) { - // 只有时间的数据,oracle会拼接当前月份的年月日,如当前月份是2024-01: to_date('13:23:11','hh24:mi:ss') 输出:2024-01-01 13:23:11 - placeholder = append(placeholder, fmt.Sprintf("to_date(:%d, 'hh24:mi:ss')", i+j+1)) - } - continue - } - placeholder = append(placeholder, fmt.Sprintf(":%d", i+j+1)) } sqlArr = append(sqlArr, fmt.Sprintf("INTO %s (%s) VALUES (%s)", od.dc.Info.Type.QuoteIdentifier(tableName), strings.Join(columns, ","), strings.Join(placeholder, ","))) @@ -326,17 +292,47 @@ func (od *OracleDialect) BatchInsert(tx *sql.Tx, tableName string, columns []str return res, err } -func (od *OracleDialect) FormatStrData(dbColumnValue string, dataType dbi.DataType) string { +var ( + // 数字类型 + numberTypeRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`) + // 日期时间类型 + datetimeTypeRegexp = regexp.MustCompile(`(?i)date|timestamp`) +) + +type DataConverter struct { +} + +func (od *OracleDialect) GetDataConverter() dbi.DataConverter { + return new(DataConverter) +} + +func (dc *DataConverter) GetDataType(dbColumnType string) dbi.DataType { + if numberTypeRegexp.MatchString(dbColumnType) { + return dbi.DataTypeNumber + } + // 日期时间类型 + if datetimeTypeRegexp.MatchString(dbColumnType) { + return dbi.DataTypeDateTime + } + return dbi.DataTypeString +} + +func (dc *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) string { + str := anyx.ToString(dbColumnValue) switch dataType { + // oracle把日期类型数据格式化输出 case dbi.DataTypeDateTime: // "2024-01-02T22:08:22.275697+08:00" - res, _ := time.Parse(time.RFC3339, dbColumnValue) + res, _ := time.Parse(time.RFC3339, str) return res.Format(time.DateTime) - case dbi.DataTypeDate: // "2024-01-02T00:00:00+08:00" - res, _ := time.Parse(time.RFC3339, dbColumnValue) - return res.Format(time.DateOnly) - case dbi.DataTypeTime: // "0000-01-01T22:08:22.275688+08:00" - res, _ := time.Parse(time.RFC3339, dbColumnValue) - return res.Format(time.TimeOnly) + } + return str +} + +func (dc *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any { + // oracle把日期类型的数据转化为time类型 + if dataType == dbi.DataTypeDateTime { + res, _ := time.Parse(time.RFC3339, anyx.ConvString(dbColumnValue)) + return res } return dbColumnValue } @@ -345,7 +341,7 @@ func (od *OracleDialect) CopyTable(copy *dbi.DbCopyTable) error { // 生成新表名,为老表明+_copy_时间戳 newTableName := strings.ToUpper(copy.TableName + "_copy_" + time.Now().Format("20060102150405")) condition := "" - if copy.CopyData { + if !copy.CopyData { condition = " where 1 = 2" } _, err := od.dc.Exec(fmt.Sprintf("create table \"%s\" as select * from \"%s\" %s", newTableName, copy.TableName, condition)) diff --git a/server/internal/db/dbm/postgres/dialect.go b/server/internal/db/dbm/postgres/dialect.go index 64776721..646b66b6 100644 --- a/server/internal/db/dbm/postgres/dialect.go +++ b/server/internal/db/dbm/postgres/dialect.go @@ -26,8 +26,8 @@ type PgsqlDialect struct { dc *dbi.DbConn } -func (pd *PgsqlDialect) GetDbServer() (*dbi.DbServer, error) { - _, res, err := pd.dc.Query("SHOW server_version") +func (md *PgsqlDialect) GetDbServer() (*dbi.DbServer, error) { + _, res, err := md.dc.Query("SHOW server_version") if err != nil { return nil, err } @@ -37,8 +37,8 @@ func (pd *PgsqlDialect) GetDbServer() (*dbi.DbServer, error) { return ds, nil } -func (pd *PgsqlDialect) GetDbNames() ([]string, error) { - _, res, err := pd.dc.Query("SELECT datname AS dbname FROM pg_database WHERE datistemplate = false AND has_database_privilege(datname, 'CONNECT')") +func (md *PgsqlDialect) GetDbNames() ([]string, error) { + _, res, err := md.dc.Query("SELECT datname AS dbname FROM pg_database WHERE datistemplate = false AND has_database_privilege(datname, 'CONNECT')") if err != nil { return nil, err } @@ -52,8 +52,8 @@ func (pd *PgsqlDialect) GetDbNames() ([]string, error) { } // 获取表基础元信息, 如表名等 -func (pd *PgsqlDialect) GetTables() ([]dbi.Table, error) { - _, res, err := pd.dc.Query(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_INFO_KEY)) +func (md *PgsqlDialect) GetTables() ([]dbi.Table, error) { + _, res, err := md.dc.Query(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_INFO_KEY)) if err != nil { return nil, err } @@ -73,13 +73,13 @@ func (pd *PgsqlDialect) GetTables() ([]dbi.Table, error) { } // 获取列元信息, 如列名等 -func (pd *PgsqlDialect) GetColumns(tableNames ...string) ([]dbi.Column, error) { - dbType := pd.dc.Info.Type +func (md *PgsqlDialect) GetColumns(tableNames ...string) ([]dbi.Column, error) { + dbType := md.dc.Info.Type tableName := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string { return fmt.Sprintf("'%s'", dbType.RemoveQuote(val)) }), ",") - _, res, err := pd.dc.Query(fmt.Sprintf(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_COLUMN_MA_KEY), tableName)) + _, res, err := md.dc.Query(fmt.Sprintf(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_COLUMN_MA_KEY), tableName)) if err != nil { return nil, err } @@ -100,8 +100,8 @@ func (pd *PgsqlDialect) GetColumns(tableNames ...string) ([]dbi.Column, error) { return columns, nil } -func (pd *PgsqlDialect) GetPrimaryKey(tablename string) (string, error) { - columns, err := pd.GetColumns(tablename) +func (md *PgsqlDialect) GetPrimaryKey(tablename string) (string, error) { + columns, err := md.GetColumns(tablename) if err != nil { return "", err } @@ -118,8 +118,8 @@ func (pd *PgsqlDialect) GetPrimaryKey(tablename string) (string, error) { } // 获取表索引信息 -func (pd *PgsqlDialect) GetTableIndex(tableName string) ([]dbi.Index, error) { - _, res, err := pd.dc.Query(fmt.Sprintf(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_INDEX_INFO_KEY), tableName)) +func (md *PgsqlDialect) GetTableIndex(tableName string) ([]dbi.Index, error) { + _, res, err := md.dc.Query(fmt.Sprintf(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_INDEX_INFO_KEY), tableName)) if err != nil { return nil, err } @@ -155,17 +155,17 @@ func (pd *PgsqlDialect) GetTableIndex(tableName string) ([]dbi.Index, error) { } // 获取建表ddl -func (pd *PgsqlDialect) GetTableDDL(tableName string) (string, error) { - _, err := pd.dc.Exec(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_DDL_KEY)) +func (md *PgsqlDialect) GetTableDDL(tableName string) (string, error) { + _, err := md.dc.Exec(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_DDL_KEY)) if err != nil { return "", err } - _, schemaRes, _ := pd.dc.Query("select current_schema() as schema") + _, schemaRes, _ := md.dc.Query("select current_schema() as schema") schemaName := schemaRes[0]["schema"].(string) ddlSql := fmt.Sprintf("select showcreatetable('%s','%s') as sql", schemaName, tableName) - _, res, err := pd.dc.Query(ddlSql) + _, res, err := md.dc.Query(ddlSql) if err != nil { return "", err } @@ -173,14 +173,14 @@ func (pd *PgsqlDialect) GetTableDDL(tableName string) (string, error) { return res[0]["sql"].(string), nil } -func (pd *PgsqlDialect) WalkTableRecord(tableName string, walkFn dbi.WalkQueryRowsFunc) error { - return pd.dc.WalkQueryRows(context.Background(), fmt.Sprintf("SELECT * FROM %s", tableName), walkFn) +func (md *PgsqlDialect) WalkTableRecord(tableName string, walkFn dbi.WalkQueryRowsFunc) error { + return md.dc.WalkQueryRows(context.Background(), fmt.Sprintf("SELECT * FROM %s", tableName), walkFn) } // 获取pgsql当前连接的库可访问的schemaNames -func (pd *PgsqlDialect) GetSchemas() ([]string, error) { +func (md *PgsqlDialect) GetSchemas() ([]string, error) { sql := dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_DB_SCHEMAS) - _, res, err := pd.dc.Query(sql) + _, res, err := md.dc.Query(sql) if err != nil { return nil, err } @@ -192,30 +192,11 @@ func (pd *PgsqlDialect) GetSchemas() ([]string, error) { } // GetDbProgram 获取数据库程序模块,用于数据库备份与恢复 -func (pd *PgsqlDialect) GetDbProgram() dbi.DbProgram { +func (md *PgsqlDialect) GetDbProgram() dbi.DbProgram { panic("implement me") } -func (pd *PgsqlDialect) GetDataType(dbColumnType string) dbi.DataType { - if regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`).MatchString(dbColumnType) { - return dbi.DataTypeNumber - } - // 日期时间类型 - if regexp.MustCompile(`(?i)datetime|timestamp`).MatchString(dbColumnType) { - return dbi.DataTypeDateTime - } - // 日期类型 - if regexp.MustCompile(`(?i)date`).MatchString(dbColumnType) { - return dbi.DataTypeDate - } - // 时间类型 - if regexp.MustCompile(`(?i)time`).MatchString(dbColumnType) { - return dbi.DataTypeTime - } - return dbi.DataTypeString -} - -func (pd *PgsqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) { +func (md *PgsqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) { // 执行批量insert sql,跟mysql一样 pg或高斯支持批量insert语法 // insert into table_name (column1, column2, ...) values (value1, value2, ...), (value1, value2, ...), ... @@ -235,37 +216,79 @@ func (pd *PgsqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []stri placeholders = append(placeholders, "("+strings.Join(placeholder, ", ")+")") } - sqlStr := fmt.Sprintf("insert into %s (%s) values %s", pd.dc.Info.Type.QuoteIdentifier(tableName), strings.Join(columns, ","), strings.Join(placeholders, ", ")) + sqlStr := fmt.Sprintf("insert into %s (%s) values %s", md.dc.Info.Type.QuoteIdentifier(tableName), strings.Join(columns, ","), strings.Join(placeholders, ", ")) // 执行批量insert sql - return pd.dc.TxExec(tx, sqlStr, args...) + return md.dc.TxExec(tx, sqlStr, args...) } -func (pd *PgsqlDialect) FormatStrData(dbColumnValue string, dataType dbi.DataType) string { +var ( + // 数字类型 + numberRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`) + // 日期时间类型 + datetimeRegexp = regexp.MustCompile(`(?i)datetime|timestamp`) + // 日期类型 + dateRegexp = regexp.MustCompile(`(?i)date`) + // 时间类型 + timeRegexp = regexp.MustCompile(`(?i)time`) +) + +type DataConverter struct { +} + +func (md *PgsqlDialect) GetDataConverter() dbi.DataConverter { + return new(DataConverter) +} + +func (dc *DataConverter) GetDataType(dbColumnType string) dbi.DataType { + if numberRegexp.MatchString(dbColumnType) { + return dbi.DataTypeNumber + } + // 日期时间类型 + if datetimeRegexp.MatchString(dbColumnType) { + return dbi.DataTypeDateTime + } + // 日期类型 + if dateRegexp.MatchString(dbColumnType) { + return dbi.DataTypeDate + } + // 时间类型 + if timeRegexp.MatchString(dbColumnType) { + return dbi.DataTypeTime + } + return dbi.DataTypeString +} + +func (dc *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) string { + str := fmt.Sprintf("%v", dbColumnValue) switch dataType { case dbi.DataTypeDateTime: // "2024-01-02T22:16:28.545377+08:00" - res, _ := time.Parse(time.RFC3339, dbColumnValue) + res, _ := time.Parse(time.RFC3339, str) return res.Format(time.DateTime) case dbi.DataTypeDate: // "2024-01-02T00:00:00Z" - res, _ := time.Parse(time.RFC3339, dbColumnValue) + res, _ := time.Parse(time.RFC3339, str) return res.Format(time.DateOnly) case dbi.DataTypeTime: // "0000-01-01T22:16:28.545075+08:00" - res, _ := time.Parse(time.RFC3339, dbColumnValue) + res, _ := time.Parse(time.RFC3339, str) return res.Format(time.TimeOnly) } + return anyx.ConvString(dbColumnValue) +} + +func (dc *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any { return dbColumnValue } -func (pd *PgsqlDialect) IsGauss() bool { - return strings.Contains(pd.dc.Info.Params, "gauss") +func (md *PgsqlDialect) IsGauss() bool { + return strings.Contains(md.dc.Info.Params, "gauss") } -func (pd *PgsqlDialect) CopyTable(copy *dbi.DbCopyTable) error { +func (md *PgsqlDialect) CopyTable(copy *dbi.DbCopyTable) error { tableName := copy.TableName // 生成新表名,为老表明+_copy_时间戳 newTableName := tableName + "_copy_" + time.Now().Format("20060102150405") // 执行根据旧表创建新表 - _, err := pd.dc.Exec(fmt.Sprintf("create table %s (like %s)", newTableName, tableName)) + _, err := md.dc.Exec(fmt.Sprintf("create table %s (like %s)", newTableName, tableName)) if err != nil { return err } @@ -273,12 +296,12 @@ func (pd *PgsqlDialect) CopyTable(copy *dbi.DbCopyTable) error { // 复制数据 if copy.CopyData { go func() { - _, _ = pd.dc.Exec(fmt.Sprintf("insert into %s select * from %s", newTableName, tableName)) + _, _ = md.dc.Exec(fmt.Sprintf("insert into %s select * from %s", newTableName, tableName)) }() } // 查询旧表的自增字段名 重新设置新表的序列序列器 - _, res, err := pd.dc.Query(fmt.Sprintf("select column_name from information_schema.columns where table_name = '%s' and column_default like 'nextval%%'", tableName)) + _, res, err := md.dc.Query(fmt.Sprintf("select column_name from information_schema.columns where table_name = '%s' and column_default like 'nextval%%'", tableName)) if err != nil { return err } @@ -288,7 +311,7 @@ func (pd *PgsqlDialect) CopyTable(copy *dbi.DbCopyTable) error { if colName != "" { // 查询自增列当前最大值 - _, maxRes, err := pd.dc.Query(fmt.Sprintf("select max(%s) max_val from %s", colName, tableName)) + _, maxRes, err := md.dc.Query(fmt.Sprintf("select max(%s) max_val from %s", colName, tableName)) if err != nil { return err } @@ -304,12 +327,12 @@ func (pd *PgsqlDialect) CopyTable(copy *dbi.DbCopyTable) error { newSeqName := fmt.Sprintf("%s_%s_copy_seq", newTableName, colName) // 创建自增序列,当前最大值为旧表最大值 - _, err = pd.dc.Exec(fmt.Sprintf("CREATE SEQUENCE %s START %d INCREMENT 1", newSeqName, maxVal)) + _, err = md.dc.Exec(fmt.Sprintf("CREATE SEQUENCE %s START %d INCREMENT 1", newSeqName, maxVal)) if err != nil { return err } // 将新表的自增主键序列与主键列相关联 - _, err = pd.dc.Exec(fmt.Sprintf("alter table %s alter column %s set default nextval('%s')", newTableName, colName, newSeqName)) + _, err = md.dc.Exec(fmt.Sprintf("alter table %s alter column %s set default nextval('%s')", newTableName, colName, newSeqName)) if err != nil { return err } diff --git a/server/internal/db/dbm/sqlite/dialect.go b/server/internal/db/dbm/sqlite/dialect.go index c7ba27fd..95780094 100644 --- a/server/internal/db/dbm/sqlite/dialect.go +++ b/server/internal/db/dbm/sqlite/dialect.go @@ -196,16 +196,6 @@ func (sd *SqliteDialect) GetDbProgram() dbi.DbProgram { panic("implement me") } -func (sd *SqliteDialect) GetDataType(dbColumnType string) dbi.DataType { - if regexp.MustCompile(`(?i)int`).MatchString(dbColumnType) { - return dbi.DataTypeNumber - } - if regexp.MustCompile(`(?i)datetime`).MatchString(dbColumnType) { - return dbi.DataTypeDateTime - } - return dbi.DataTypeString -} - func (sd *SqliteDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) { // 执行批量insert sql,跟mysql一样 支持批量insert语法 // 生成占位符字符串:如:(?,?) @@ -231,18 +221,47 @@ func (sd *SqliteDialect) BatchInsert(tx *sql.Tx, tableName string, columns []str return sd.dc.TxExec(tx, sqlStr, args...) } -func (sd *SqliteDialect) FormatStrData(dbColumnValue string, dataType dbi.DataType) string { +var ( + // 数字类型 + numberRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit|real`) + // 日期时间类型 + datetimeRegexp = regexp.MustCompile(`(?i)datetime`) +) + +type DataConverter struct { +} + +func (sd *SqliteDialect) GetDataConverter() dbi.DataConverter { + return new(DataConverter) +} + +func (dc *DataConverter) GetDataType(dbColumnType string) dbi.DataType { + if numberRegexp.MatchString(dbColumnType) { + return dbi.DataTypeNumber + } + if datetimeRegexp.MatchString(dbColumnType) { + return dbi.DataTypeDateTime + } + return dbi.DataTypeString +} + +func (dc *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) string { + str := anyx.ToString(dbColumnValue) switch dataType { case dbi.DataTypeDateTime: // "2024-01-02T22:08:22.275697+08:00" - res, _ := time.Parse(time.RFC3339, dbColumnValue) + res, _ := time.Parse(time.RFC3339, str) return res.Format(time.DateTime) case dbi.DataTypeDate: // "2024-01-02T00:00:00+08:00" - res, _ := time.Parse(time.RFC3339, dbColumnValue) + res, _ := time.Parse(time.RFC3339, str) return res.Format(time.DateOnly) case dbi.DataTypeTime: // "0000-01-01T22:08:22.275688+08:00" - res, _ := time.Parse(time.RFC3339, dbColumnValue) + res, _ := time.Parse(time.RFC3339, str) return res.Format(time.TimeOnly) } + return str +} + +func (dc *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any { return dbColumnValue }