!91 fix: oracle数据同步 bug

* fix: oracle数据同步 bug
This commit is contained in:
zongyangleo
2024-01-24 08:29:16 +00:00
committed by Coder慌
parent e4d13f3377
commit bed95254d0
13 changed files with 335 additions and 182 deletions

View File

@@ -175,7 +175,7 @@ import { dbApi } from './api';
import { dispposeCompletionItemProvider } from '@/components/monaco/completionItemProvider'; import { dispposeCompletionItemProvider } from '@/components/monaco/completionItemProvider';
import SvgIcon from '@/components/svgIcon/index.vue'; import SvgIcon from '@/components/svgIcon/index.vue';
import { ContextmenuItem } from '@/components/contextmenu'; import { ContextmenuItem } from '@/components/contextmenu';
import { DbType, getDbDialect } from './dialect/index'; import { getDbDialect, schemaDbTypes} from './dialect/index'
import { sleep } from '@/common/utils/loading'; import { sleep } from '@/common/utils/loading';
import { TagResourceTypeEnum } from '@/common/commonEnum'; import { TagResourceTypeEnum } from '@/common/commonEnum';
import { Pane, Splitpanes } from 'splitpanes'; import { Pane, Splitpanes } from 'splitpanes';
@@ -271,7 +271,7 @@ const NodeTypeDb = new NodeType(SqlExecNodeType.Db)
const params = parentNode.params; const params = parentNode.params;
params.parentKey = parentNode.key; params.parentKey = parentNode.key;
// pg类数据库会多一层schema // pg类数据库会多一层schema
if (params.type == DbType.postgresql || params.type === DbType.dm || params.type === DbType.oracle) { if (schemaDbTypes.includes(params.type)) {
const { id, db } = params; const { id, db } = params;
const schemaNames = await dbApi.pgSchemas.request({ id, db }); const schemaNames = await dbApi.pgSchemas.request({ id, db });
return schemaNames.map((sn: any) => { return schemaNames.map((sn: any) => {

View File

@@ -47,6 +47,7 @@
v-model:db-id="form.srcDbId" v-model:db-id="form.srcDbId"
v-model:db-name="form.srcDbName" v-model:db-name="form.srcDbName"
v-model:tag-path="form.srcTagPath" v-model:tag-path="form.srcTagPath"
v-model:db-type="form.srcDbType"
@select-db="onSelectSrcDb" @select-db="onSelectSrcDb"
/> />
</el-form-item> </el-form-item>
@@ -181,7 +182,7 @@ import { ElMessage } from 'element-plus';
import DbSelectTree from '@/views/ops/db/component/DbSelectTree.vue'; import DbSelectTree from '@/views/ops/db/component/DbSelectTree.vue';
import MonacoEditor from '@/components/monaco/MonacoEditor.vue'; import MonacoEditor from '@/components/monaco/MonacoEditor.vue';
import { DbInst, registerDbCompletionItemProvider } from '@/views/ops/db/db'; import { DbInst, registerDbCompletionItemProvider } from '@/views/ops/db/db';
import { getDbDialect } from '@/views/ops/db/dialect'; import {DbType, getDbDialect} from '@/views/ops/db/dialect'
import CrontabInput from '@/components/crontab/CrontabInput.vue'; import CrontabInput from '@/components/crontab/CrontabInput.vue';
const props = defineProps({ const props = defineProps({
@@ -227,6 +228,7 @@ type FormData = {
taskCron: string; taskCron: string;
srcDbId?: number; srcDbId?: number;
srcDbName?: string; srcDbName?: string;
srcDbType?: string;
srcTagPath?: string; srcTagPath?: string;
targetDbId?: number; targetDbId?: number;
targetDbName?: string; targetDbName?: string;
@@ -245,7 +247,7 @@ const basicFormData = {
targetDbId: -1, targetDbId: -1,
dataSql: 'select * from', dataSql: 'select * from',
pageSize: 1000, pageSize: 1000,
updField: 'id', updField: '',
updFieldVal: '0', updFieldVal: '0',
fieldMap: [{ src: 'a', target: 'b' }], fieldMap: [{ src: 'a', target: 'b' }],
status: 1, status: 1,
@@ -302,6 +304,7 @@ watch(dialogVisible, async (newValue: boolean) => {
// 初始化实例 // 初始化实例
db.databases = db.database?.split(' ').sort() || []; db.databases = db.database?.split(' ').sort() || [];
state.srcDbInst = DbInst.getOrNewInst(db); state.srcDbInst = DbInst.getOrNewInst(db);
state.form.srcDbType = state.srcDbInst.type
} }
// 初始化target数据源 // 初始化target数据源
@@ -396,8 +399,8 @@ const handleGetSrcFields = async () => {
} }
// 判断sql是否是查询语句 // 判断sql是否是查询语句
if (!/^select/i.test(state.form.dataSql!)) { if (!/^select/i.test(state.form.dataSql.trim()!)) {
let msg = 'sql语句错误请输入查询语句'; let msg = 'sql语句错误请输入select语句';
ElMessage.warning(msg); ElMessage.warning(msg);
return; return;
} }
@@ -410,10 +413,16 @@ const handleGetSrcFields = async () => {
} }
// 执行sql // 执行sql
// oracle的分页关键字不一样
let limit = ' limit 1'
if(state.form.srcDbType === DbType.oracle){
limit = ' where rownum <= 1'
}
const res = await dbApi.sqlExec.request({ const res = await dbApi.sqlExec.request({
id: state.form.srcDbId, id: state.form.srcDbId,
db: state.form.srcDbName, db: state.form.srcDbName,
sql: state.form.dataSql.trim() + ' limit 1', sql: `select * from (${state.form.dataSql}) t ${limit}`
}); });
if (!res.columns) { if (!res.columns) {

View File

@@ -19,7 +19,7 @@ import { NodeType, TagTreeNode } from '@/views/ops/component/tag';
import { dbApi } from '@/views/ops/db/api'; import { dbApi } from '@/views/ops/db/api';
import { sleep } from '@/common/utils/loading'; import { sleep } from '@/common/utils/loading';
import SvgIcon from '@/components/svgIcon/index.vue'; import SvgIcon from '@/components/svgIcon/index.vue';
import { DbType, getDbDialect } from '@/views/ops/db/dialect'; import { getDbDialect, mysqlDbTypes} from '@/views/ops/db/dialect'
import TagTreeResourceSelect from '../../component/TagTreeResourceSelect.vue'; import TagTreeResourceSelect from '../../component/TagTreeResourceSelect.vue';
import { computed } from 'vue'; import { computed } from 'vue';
@@ -33,9 +33,12 @@ const props = defineProps({
tagPath: { tagPath: {
type: String, type: String,
}, },
dbType: {
type: String,
},
}); });
const emits = defineEmits(['update:dbName', 'update:tagPath', 'update:dbId', 'selectDb']); const emits = defineEmits(['update:dbName', 'update:tagPath', 'update:dbId', 'update:dbType', 'selectDb']);
/** /**
* 树节点类型 * 树节点类型
@@ -88,7 +91,7 @@ const NodeTypeTagPath = new NodeType(TagTreeNode.TagPath).withLoadNodesFunc(asyn
/** mysql类型的数据库没有schema层 */ /** mysql类型的数据库没有schema层 */
const mysqlType = (type: string) => { const mysqlType = (type: string) => {
return type === DbType.mysql; return mysqlDbTypes.includes(type);
}; };
// 数据库实例节点类型 // 数据库实例节点类型
@@ -150,6 +153,7 @@ const changeNode = (nodeData: TagTreeNode) => {
emits('update:dbName', params.db); emits('update:dbName', params.db);
emits('update:dbId', params.id); emits('update:dbId', params.id);
emits('update:tagPath', params.tagPath); emits('update:tagPath', params.tagPath);
emits('update:dbType', params.type);
emits('selectDb', params); emits('selectDb', params);
}; };
</script> </script>

View File

@@ -179,7 +179,6 @@ const state = reactive({
visible: false, visible: false,
activeName: '1', activeName: '1',
type: '', type: '',
enableEditTypes: [DbType.mysql, DbType.mariadb, DbType.postgresql, DbType.dm, DbType.oracle, DbType.sqlite], // 支持"编辑表"的数据库类型
data: { data: {
// 修改表时,传递修改数据 // 修改表时,传递修改数据
edit: false, edit: false,

View File

@@ -115,7 +115,13 @@ export const DbType = {
sqlite: 'sqlite', sqlite: 'sqlite',
}; };
export const editDbTypes = [DbType.mysql, DbType.mariadb, DbType.postgresql, DbType.dm, DbType.oracle, DbType.sqlite]; // mysql兼容的数据库
export const mysqlDbTypes = [DbType.mysql, DbType.mariadb, DbType.sqlite];
// 有schema层的数据库
export const schemaDbTypes = [DbType.postgresql, DbType.dm, DbType.oracle];
export const editDbTypes = [...mysqlDbTypes, ...schemaDbTypes];
export const compatibleMysql = (dbType: string): boolean => { export const compatibleMysql = (dbType: string): boolean => {
switch (dbType) { switch (dbType) {

View File

@@ -14,6 +14,9 @@ import (
"mayfly-go/pkg/logx" "mayfly-go/pkg/logx"
"mayfly-go/pkg/model" "mayfly-go/pkg/model"
"mayfly-go/pkg/scheduler" "mayfly-go/pkg/scheduler"
"regexp"
"strconv"
"strings"
"time" "time"
) )
@@ -44,6 +47,10 @@ type dataSyncAppImpl struct {
dbDataSyncLogRepo repository.DataSyncLog `inject:"DbDataSyncLogRepo"` dbDataSyncLogRepo repository.DataSyncLog `inject:"DbDataSyncLogRepo"`
} }
var (
dateTimeReg = regexp.MustCompile(`^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}$`)
)
func (d *dataSyncAppImpl) InjectDbDataSyncTaskRepo(repo repository.DataSyncTask) { func (d *dataSyncAppImpl) InjectDbDataSyncTaskRepo(repo repository.DataSyncTask) {
d.Repo = repo d.Repo = repo
} }
@@ -123,7 +130,23 @@ func (app *dataSyncAppImpl) RunCronJob(id uint64) error {
updSql := "" updSql := ""
orderSql := "" orderSql := ""
if task.UpdFieldVal != "0" && task.UpdFieldVal != "" && task.UpdField != "" { if task.UpdFieldVal != "0" && task.UpdFieldVal != "" && task.UpdField != "" {
updSql = fmt.Sprintf("and %s > '%s'", task.UpdField, task.UpdFieldVal) srcConn, _ := GetDbApp().GetDbConn(uint64(task.SrcDbId), task.SrcDbName)
task.UpdFieldVal = strings.Trim(task.UpdFieldVal, " ")
// 把UpdFieldVal尝试转为int如果可以转为int则不添加引号否则添加引号
if _, err := strconv.Atoi(task.UpdFieldVal); err != nil {
updSql = fmt.Sprintf("and %s > '%s'", task.UpdField, task.UpdFieldVal)
} else {
updSql = fmt.Sprintf("and %s > %s", task.UpdField, task.UpdFieldVal)
}
// 如果是oracle且数据类型是时间类型则需要加上to_date函数
if srcConn.Info.Type == dbi.DbTypeOracle {
// 用正则判断数据类型是时间
if dateTimeReg.MatchString(task.UpdFieldVal) {
updSql = fmt.Sprintf("and %s > to_date('%s','yyyy-mm-dd hh24:mi:ss')", task.UpdField, task.UpdFieldVal)
}
}
orderSql = "order by " + task.UpdField + " asc " orderSql = "order by " + task.UpdField + " asc "
} }
// 组装查询sql // 组装查询sql
@@ -194,8 +217,8 @@ func (app *dataSyncAppImpl) doDataSync(sql string, task *entity.DataSyncTask) (*
// 遍历columns 取task.UpdField的字段类型 // 遍历columns 取task.UpdField的字段类型
updFieldType = dbi.DataTypeString updFieldType = dbi.DataTypeString
for _, column := range columns { for _, column := range columns {
if column.Name == task.UpdField { if strings.ToLower(column.Name) == strings.ToLower(task.UpdField) {
updFieldType = srcDialect.GetDataType(column.Type) updFieldType = srcDialect.GetDataConverter().GetDataType(column.Type)
break break
} }
} }
@@ -204,7 +227,7 @@ func (app *dataSyncAppImpl) doDataSync(sql string, task *entity.DataSyncTask) (*
total++ total++
result = append(result, row) result = append(result, row)
if total%batchSize == 0 { if total%batchSize == 0 {
if err := app.srcData2TargetDb(result, fieldMap, updFieldType, task, srcDialect, targetConn, targetDbTx); err != nil { if err := app.srcData2TargetDb(result, fieldMap, columns, updFieldType, task, srcDialect, targetConn, targetDbTx); err != nil {
return err return err
} }
@@ -226,7 +249,7 @@ func (app *dataSyncAppImpl) doDataSync(sql string, task *entity.DataSyncTask) (*
// 处理剩余的数据 // 处理剩余的数据
if len(result) > 0 { if len(result) > 0 {
if err := app.srcData2TargetDb(result, fieldMap, updFieldType, task, srcDialect, targetConn, targetDbTx); err != nil { if err := app.srcData2TargetDb(result, fieldMap, queryColumns, updFieldType, task, srcDialect, targetConn, targetDbTx); err != nil {
targetDbTx.Rollback() targetDbTx.Rollback()
return syncLog, err return syncLog, err
} }
@@ -246,10 +269,16 @@ func (app *dataSyncAppImpl) doDataSync(sql string, task *entity.DataSyncTask) (*
return syncLog, nil return syncLog, nil
} }
func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap []map[string]string, updFieldType dbi.DataType, task *entity.DataSyncTask, srcDialect dbi.Dialect, targetDbConn *dbi.DbConn, targetDbTx *sql.Tx) error { func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap []map[string]string, columns []*dbi.QueryColumn, updFieldType dbi.DataType, task *entity.DataSyncTask, srcDialect dbi.Dialect, targetDbConn *dbi.DbConn, targetDbTx *sql.Tx) error {
var data = make([]map[string]any, 0)
// 遍历res组装插入sql // 遍历src字段列表取出字段对应的类型
var srcColumnTypes = make(map[string]string)
for _, column := range columns {
srcColumnTypes[column.Name] = column.Type
}
// 遍历res组装数据
var data = make([]map[string]any, 0)
for _, record := range srcRes { for _, record := range srcRes {
var rowData = make(map[string]any) var rowData = make(map[string]any)
// 遍历字段映射, target字段的值为src字段取值 // 遍历字段映射, target字段的值为src字段取值
@@ -262,18 +291,23 @@ func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap [
data = append(data, rowData) data = append(data, rowData)
} }
// 解决字段大小写问题
updFieldVal := srcRes[len(srcRes)-1][strings.ToUpper(task.UpdField)]
if updFieldVal == "" {
updFieldVal = srcRes[len(srcRes)-1][strings.ToLower(task.UpdField)]
}
updFieldVal := fmt.Sprintf("%v", srcRes[len(srcRes)-1][task.UpdField]) task.UpdFieldVal = srcDialect.GetDataConverter().FormatData(updFieldVal, updFieldType)
updFieldVal = srcDialect.FormatStrData(updFieldVal, updFieldType)
task.UpdFieldVal = updFieldVal
// 获取目标库字段数组 // 获取目标库字段数组
targetWrapColumns := make([]string, 0) targetWrapColumns := make([]string, 0)
// 获取源库字段数组 // 获取源库字段数组
srcColumns := make([]string, 0) srcColumns := make([]string, 0)
srcFieldTypes := make(map[string]dbi.DataType)
for _, item := range fieldMap { for _, item := range fieldMap {
targetField := item["target"] targetField := item["target"]
srcField := item["target"] srcField := item["target"]
srcFieldTypes[srcField] = srcDialect.GetDataConverter().GetDataType(srcColumnTypes[item["src"]])
targetWrapColumns = append(targetWrapColumns, targetDbConn.Info.Type.QuoteIdentifier(targetField)) targetWrapColumns = append(targetWrapColumns, targetDbConn.Info.Type.QuoteIdentifier(targetField))
srcColumns = append(srcColumns, srcField) srcColumns = append(srcColumns, srcField)
} }
@@ -283,7 +317,9 @@ func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap [
for _, record := range data { for _, record := range data {
rawValue := make([]any, 0) rawValue := make([]any, 0)
for _, column := range srcColumns { for _, column := range srcColumns {
rawValue = append(rawValue, record[column]) // 某些情况如oracle需要转换时间类型的字符串为time类型
res := srcDialect.GetDataConverter().ParseData(record[column], srcFieldTypes[column])
rawValue = append(rawValue, res)
} }
values = append(values, rawValue) values = append(values, rawValue)
} }
@@ -294,6 +330,12 @@ func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap [
return err return err
} }
// 运行完成一轮就记录一下修改字段最大值
taskParam1 := new(entity.DataSyncTask)
taskParam1.Id = task.Id
taskParam1.UpdFieldVal = task.UpdFieldVal
_ = app.UpdateById(context.Background(), taskParam1)
// 运行过程中,判断状态是否为已关闭,是则结束运行,否则继续运行 // 运行过程中,判断状态是否为已关闭,是则结束运行,否则继续运行
taskParam, _ := app.GetById(new(entity.DataSyncTask), task.Id) taskParam, _ := app.GetById(new(entity.DataSyncTask), task.Id)
if taskParam.RunningState == entity.DataSyncTaskRunStateStop { if taskParam.RunningState == entity.DataSyncTaskRunStateStop {

View File

@@ -163,7 +163,9 @@ func doSelect(ctx context.Context, selectStmt *sqlparser.Select, execSqlReq *DbS
len(strings.Split(selectExprsStr, ",")) > 1 { len(strings.Split(selectExprsStr, ",")) > 1 {
// 如果配置为0则不校验分页参数 // 如果配置为0则不校验分页参数
maxCount := config.GetDbQueryMaxCount() maxCount := config.GetDbQueryMaxCount()
if maxCount != 0 { // 哪些数据库跳过校验
skipped := dbi.DbTypeOracle == execSqlReq.DbConn.Info.Type
if maxCount != 0 && !skipped {
limit := selectStmt.Limit limit := selectStmt.Limit
if limit == nil { if limit == nil {
return nil, errorx.NewBiz("请完善分页信息后执行") return nil, errorx.NewBiz("请完善分页信息后执行")

View File

@@ -66,9 +66,23 @@ type DbCopyTable struct {
CopyData bool `json:"copyData"` // 是否复制数据 CopyData bool `json:"copyData"` // 是否复制数据
} }
// 数据转换器
type DataConverter interface {
// 获取数据对应的类型
// @param dbColumnType 数据库原始列类型如varchar等
GetDataType(dbColumnType string) DataType
// 根据数据类型格式化指定数据
FormatData(dbColumnValue any, dataType DataType) string
// 根据数据类型解析数据为符合要求的指定类型等
ParseData(dbColumnValue any, dataType DataType) any
}
// -----------------------------------元数据接口定义------------------------------------------ // -----------------------------------元数据接口定义------------------------------------------
// 数据库方言、元信息接口(表、列、获取表数据等元信息) // 数据库方言、元信息接口(表、列、获取表数据等元信息)
type Dialect interface { type Dialect interface {
// 获取数据库服务实例信息 // 获取数据库服务实例信息
GetDbServer() (*DbServer, error) GetDbServer() (*DbServer, error)
@@ -101,9 +115,7 @@ type Dialect interface {
// 批量保存数据 // 批量保存数据
BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error)
GetDataType(dbColumnType string) DataType GetDataConverter() DataConverter
FormatStrData(dbColumnValue string, dataType DataType) string
CopyTable(copy *DbCopyTable) error CopyTable(copy *DbCopyTable) error
} }

View File

@@ -255,24 +255,16 @@ func (dd *DMDialect) GetDbProgram() dbi.DbProgram {
panic("implement me") panic("implement me")
} }
func (dd *DMDialect) GetDataType(dbColumnType string) dbi.DataType { var (
if regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`).MatchString(dbColumnType) { // 数字类型
return dbi.DataTypeNumber numberRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`)
}
// 日期时间类型 // 日期时间类型
if regexp.MustCompile(`(?i)datetime|timestamp`).MatchString(dbColumnType) { datetimeRegexp = regexp.MustCompile(`(?i)datetime|timestamp`)
return dbi.DataTypeDateTime
}
// 日期类型 // 日期类型
if regexp.MustCompile(`(?i)date`).MatchString(dbColumnType) { dateRegexp = regexp.MustCompile(`(?i)date`)
return dbi.DataTypeDate
}
// 时间类型 // 时间类型
if regexp.MustCompile(`(?i)time`).MatchString(dbColumnType) { timeRegexp = regexp.MustCompile(`(?i)time`)
return dbi.DataTypeTime )
}
return dbi.DataTypeString
}
func (dd *DMDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) { func (dd *DMDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) {
// 执行批量insert sql // 执行批量insert sql
@@ -299,18 +291,46 @@ func (dd *DMDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string,
return int64(effRows), nil return int64(effRows), nil
} }
func (dd *DMDialect) FormatStrData(dbColumnValue string, dataType dbi.DataType) string { type DataConverter struct {
}
func (dd *DMDialect) GetDataConverter() dbi.DataConverter {
return new(DataConverter)
}
func (dd *DataConverter) GetDataType(dbColumnType string) dbi.DataType {
if numberRegexp.MatchString(dbColumnType) {
return dbi.DataTypeNumber
}
if datetimeRegexp.MatchString(dbColumnType) {
return dbi.DataTypeDateTime
}
if dateRegexp.MatchString(dbColumnType) {
return dbi.DataTypeDate
}
if timeRegexp.MatchString(dbColumnType) {
return dbi.DataTypeTime
}
return dbi.DataTypeString
}
func (dd *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) string {
str := anyx.ToString(dbColumnValue)
switch dataType { switch dataType {
case dbi.DataTypeDateTime: // "2024-01-02T22:08:22.275697+08:00" case dbi.DataTypeDateTime: // "2024-01-02T22:08:22.275697+08:00"
res, _ := time.Parse(time.RFC3339, dbColumnValue) res, _ := time.Parse(time.RFC3339, str)
return res.Format(time.DateTime) return res.Format(time.DateTime)
case dbi.DataTypeDate: // "2024-01-02T00:00:00+08:00" case dbi.DataTypeDate: // "2024-01-02T00:00:00+08:00"
res, _ := time.Parse(time.RFC3339, dbColumnValue) res, _ := time.Parse(time.RFC3339, str)
return res.Format(time.DateOnly) return res.Format(time.DateOnly)
case dbi.DataTypeTime: // "0000-01-01T22:08:22.275688+08:00" case dbi.DataTypeTime: // "0000-01-01T22:08:22.275688+08:00"
res, _ := time.Parse(time.RFC3339, dbColumnValue) res, _ := time.Parse(time.RFC3339, str)
return res.Format(time.TimeOnly) return res.Format(time.TimeOnly)
} }
return str
}
func (dd *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any {
return dbColumnValue return dbColumnValue
} }

View File

@@ -177,25 +177,6 @@ func (md *MysqlDialect) GetDbProgram() dbi.DbProgram {
return NewDbProgramMysql(md.dc) return NewDbProgramMysql(md.dc)
} }
func (md *MysqlDialect) GetDataType(dbColumnType string) dbi.DataType {
if regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`).MatchString(dbColumnType) {
return dbi.DataTypeNumber
}
// 日期时间类型
if regexp.MustCompile(`(?i)datetime|timestamp`).MatchString(dbColumnType) {
return dbi.DataTypeDateTime
}
// 日期类型
if regexp.MustCompile(`(?i)date`).MatchString(dbColumnType) {
return dbi.DataTypeDate
}
// 时间类型
if regexp.MustCompile(`(?i)time`).MatchString(dbColumnType) {
return dbi.DataTypeTime
}
return dbi.DataTypeString
}
func (md *MysqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) { func (md *MysqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) {
// 生成占位符字符串:如:(?,?) // 生成占位符字符串:如:(?,?)
// 重复字符串并用逗号连接 // 重复字符串并用逗号连接
@@ -221,8 +202,48 @@ func (md *MysqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []stri
return md.dc.TxExec(tx, sqlStr, args...) return md.dc.TxExec(tx, sqlStr, args...)
} }
func (md *MysqlDialect) FormatStrData(dbColumnValue string, dataType dbi.DataType) string { var (
// mysql不需要格式化时间日期等 // 数字类型
numberRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`)
// 日期时间类型
datetimeRegexp = regexp.MustCompile(`(?i)datetime|timestamp`)
// 日期类型
dateRegexp = regexp.MustCompile(`(?i)date`)
// 时间类型
timeRegexp = regexp.MustCompile(`(?i)time`)
)
type DataConverter struct {
}
func (md *MysqlDialect) GetDataConverter() dbi.DataConverter {
return new(DataConverter)
}
func (dc *DataConverter) GetDataType(dbColumnType string) dbi.DataType {
if numberRegexp.MatchString(dbColumnType) {
return dbi.DataTypeNumber
}
// 日期时间类型
if datetimeRegexp.MatchString(dbColumnType) {
return dbi.DataTypeDateTime
}
// 日期类型
if dateRegexp.MatchString(dbColumnType) {
return dbi.DataTypeDate
}
// 时间类型
if timeRegexp.MatchString(dbColumnType) {
return dbi.DataTypeTime
}
return dbi.DataTypeString
}
func (dc *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) string {
return anyx.ToString(dbColumnValue)
}
func (dc *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any {
return dbColumnValue return dbColumnValue
} }

View File

@@ -8,7 +8,6 @@ import (
"mayfly-go/pkg/errorx" "mayfly-go/pkg/errorx"
"mayfly-go/pkg/utils/anyx" "mayfly-go/pkg/utils/anyx"
"mayfly-go/pkg/utils/collx" "mayfly-go/pkg/utils/collx"
"reflect"
"regexp" "regexp"
"strings" "strings"
"time" "time"
@@ -257,25 +256,6 @@ func (od *OracleDialect) GetDbProgram() dbi.DbProgram {
panic("implement me") panic("implement me")
} }
func (od *OracleDialect) GetDataType(dbColumnType string) dbi.DataType {
if regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`).MatchString(dbColumnType) {
return dbi.DataTypeNumber
}
// 日期时间类型
if regexp.MustCompile(`(?i)datetime|timestamp`).MatchString(dbColumnType) {
return dbi.DataTypeDateTime
}
// 日期类型
if regexp.MustCompile(`(?i)date`).MatchString(dbColumnType) {
return dbi.DataTypeDate
}
// 时间类型
if regexp.MustCompile(`(?i)time`).MatchString(dbColumnType) {
return dbi.DataTypeTime
}
return dbi.DataTypeString
}
func (od *OracleDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) { func (od *OracleDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) {
//INSERT ALL //INSERT ALL
//INTO my_table(field_1,field_2) VALUES (value_1,value_2) //INTO my_table(field_1,field_2) VALUES (value_1,value_2)
@@ -301,20 +281,6 @@ func (od *OracleDialect) BatchInsert(tx *sql.Tx, tableName string, columns []str
for i := 0; i < len(args); i += len(columns) { for i := 0; i < len(args); i += len(columns) {
var placeholder []string var placeholder []string
for j := 0; j < len(columns); j++ { for j := 0; j < len(columns); j++ {
// 判断字符串数据格式是时间"2023-06-25 10:40:10" 占位符需要变成 to_date(:x, 'fmt')
if reflect.TypeOf(args[i+j]) == reflect.TypeOf("") {
if regexp.MustCompile(`^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}$`).MatchString(args[i+j].(string)) {
placeholder = append(placeholder, fmt.Sprintf("to_date(:%d, 'yyyy-mm-dd hh24:mi:ss')", i+j+1))
} else if regexp.MustCompile(`^\d{4}-\d{2}-\d{2}$`).MatchString(args[i+j].(string)) {
// 只有年月日的数据oracle会自动补零时分秒2024-01-02: to_date('2024-01-02','yyyy-mm-dd') 输出2024-01-02 00:00:00
placeholder = append(placeholder, fmt.Sprintf("to_date(:%d, 'yyyy-mm-dd')", i+j+1))
} else if regexp.MustCompile(`^\d{2}:\d{2}:\d{2}$`).MatchString(args[i+j].(string)) {
// 只有时间的数据oracle会拼接当前月份的年月日如当前月份是2024-01: to_date('13:23:11','hh24:mi:ss') 输出2024-01-01 13:23:11
placeholder = append(placeholder, fmt.Sprintf("to_date(:%d, 'hh24:mi:ss')", i+j+1))
}
continue
}
placeholder = append(placeholder, fmt.Sprintf(":%d", i+j+1)) placeholder = append(placeholder, fmt.Sprintf(":%d", i+j+1))
} }
sqlArr = append(sqlArr, fmt.Sprintf("INTO %s (%s) VALUES (%s)", od.dc.Info.Type.QuoteIdentifier(tableName), strings.Join(columns, ","), strings.Join(placeholder, ","))) sqlArr = append(sqlArr, fmt.Sprintf("INTO %s (%s) VALUES (%s)", od.dc.Info.Type.QuoteIdentifier(tableName), strings.Join(columns, ","), strings.Join(placeholder, ",")))
@@ -326,17 +292,47 @@ func (od *OracleDialect) BatchInsert(tx *sql.Tx, tableName string, columns []str
return res, err return res, err
} }
func (od *OracleDialect) FormatStrData(dbColumnValue string, dataType dbi.DataType) string { var (
// 数字类型
numberTypeRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`)
// 日期时间类型
datetimeTypeRegexp = regexp.MustCompile(`(?i)date|timestamp`)
)
type DataConverter struct {
}
func (od *OracleDialect) GetDataConverter() dbi.DataConverter {
return new(DataConverter)
}
func (dc *DataConverter) GetDataType(dbColumnType string) dbi.DataType {
if numberTypeRegexp.MatchString(dbColumnType) {
return dbi.DataTypeNumber
}
// 日期时间类型
if datetimeTypeRegexp.MatchString(dbColumnType) {
return dbi.DataTypeDateTime
}
return dbi.DataTypeString
}
func (dc *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) string {
str := anyx.ToString(dbColumnValue)
switch dataType { switch dataType {
// oracle把日期类型数据格式化输出
case dbi.DataTypeDateTime: // "2024-01-02T22:08:22.275697+08:00" case dbi.DataTypeDateTime: // "2024-01-02T22:08:22.275697+08:00"
res, _ := time.Parse(time.RFC3339, dbColumnValue) res, _ := time.Parse(time.RFC3339, str)
return res.Format(time.DateTime) return res.Format(time.DateTime)
case dbi.DataTypeDate: // "2024-01-02T00:00:00+08:00" }
res, _ := time.Parse(time.RFC3339, dbColumnValue) return str
return res.Format(time.DateOnly) }
case dbi.DataTypeTime: // "0000-01-01T22:08:22.275688+08:00"
res, _ := time.Parse(time.RFC3339, dbColumnValue) func (dc *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any {
return res.Format(time.TimeOnly) // oracle把日期类型的数据转化为time类型
if dataType == dbi.DataTypeDateTime {
res, _ := time.Parse(time.RFC3339, anyx.ConvString(dbColumnValue))
return res
} }
return dbColumnValue return dbColumnValue
} }
@@ -345,7 +341,7 @@ func (od *OracleDialect) CopyTable(copy *dbi.DbCopyTable) error {
// 生成新表名,为老表明+_copy_时间戳 // 生成新表名,为老表明+_copy_时间戳
newTableName := strings.ToUpper(copy.TableName + "_copy_" + time.Now().Format("20060102150405")) newTableName := strings.ToUpper(copy.TableName + "_copy_" + time.Now().Format("20060102150405"))
condition := "" condition := ""
if copy.CopyData { if !copy.CopyData {
condition = " where 1 = 2" condition = " where 1 = 2"
} }
_, err := od.dc.Exec(fmt.Sprintf("create table \"%s\" as select * from \"%s\" %s", newTableName, copy.TableName, condition)) _, err := od.dc.Exec(fmt.Sprintf("create table \"%s\" as select * from \"%s\" %s", newTableName, copy.TableName, condition))

View File

@@ -26,8 +26,8 @@ type PgsqlDialect struct {
dc *dbi.DbConn dc *dbi.DbConn
} }
func (pd *PgsqlDialect) GetDbServer() (*dbi.DbServer, error) { func (md *PgsqlDialect) GetDbServer() (*dbi.DbServer, error) {
_, res, err := pd.dc.Query("SHOW server_version") _, res, err := md.dc.Query("SHOW server_version")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -37,8 +37,8 @@ func (pd *PgsqlDialect) GetDbServer() (*dbi.DbServer, error) {
return ds, nil return ds, nil
} }
func (pd *PgsqlDialect) GetDbNames() ([]string, error) { func (md *PgsqlDialect) GetDbNames() ([]string, error) {
_, res, err := pd.dc.Query("SELECT datname AS dbname FROM pg_database WHERE datistemplate = false AND has_database_privilege(datname, 'CONNECT')") _, res, err := md.dc.Query("SELECT datname AS dbname FROM pg_database WHERE datistemplate = false AND has_database_privilege(datname, 'CONNECT')")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -52,8 +52,8 @@ func (pd *PgsqlDialect) GetDbNames() ([]string, error) {
} }
// 获取表基础元信息, 如表名等 // 获取表基础元信息, 如表名等
func (pd *PgsqlDialect) GetTables() ([]dbi.Table, error) { func (md *PgsqlDialect) GetTables() ([]dbi.Table, error) {
_, res, err := pd.dc.Query(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_INFO_KEY)) _, res, err := md.dc.Query(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_INFO_KEY))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -73,13 +73,13 @@ func (pd *PgsqlDialect) GetTables() ([]dbi.Table, error) {
} }
// 获取列元信息, 如列名等 // 获取列元信息, 如列名等
func (pd *PgsqlDialect) GetColumns(tableNames ...string) ([]dbi.Column, error) { func (md *PgsqlDialect) GetColumns(tableNames ...string) ([]dbi.Column, error) {
dbType := pd.dc.Info.Type dbType := md.dc.Info.Type
tableName := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string { tableName := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string {
return fmt.Sprintf("'%s'", dbType.RemoveQuote(val)) return fmt.Sprintf("'%s'", dbType.RemoveQuote(val))
}), ",") }), ",")
_, res, err := pd.dc.Query(fmt.Sprintf(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_COLUMN_MA_KEY), tableName)) _, res, err := md.dc.Query(fmt.Sprintf(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_COLUMN_MA_KEY), tableName))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -100,8 +100,8 @@ func (pd *PgsqlDialect) GetColumns(tableNames ...string) ([]dbi.Column, error) {
return columns, nil return columns, nil
} }
func (pd *PgsqlDialect) GetPrimaryKey(tablename string) (string, error) { func (md *PgsqlDialect) GetPrimaryKey(tablename string) (string, error) {
columns, err := pd.GetColumns(tablename) columns, err := md.GetColumns(tablename)
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -118,8 +118,8 @@ func (pd *PgsqlDialect) GetPrimaryKey(tablename string) (string, error) {
} }
// 获取表索引信息 // 获取表索引信息
func (pd *PgsqlDialect) GetTableIndex(tableName string) ([]dbi.Index, error) { func (md *PgsqlDialect) GetTableIndex(tableName string) ([]dbi.Index, error) {
_, res, err := pd.dc.Query(fmt.Sprintf(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_INDEX_INFO_KEY), tableName)) _, res, err := md.dc.Query(fmt.Sprintf(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_INDEX_INFO_KEY), tableName))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -155,17 +155,17 @@ func (pd *PgsqlDialect) GetTableIndex(tableName string) ([]dbi.Index, error) {
} }
// 获取建表ddl // 获取建表ddl
func (pd *PgsqlDialect) GetTableDDL(tableName string) (string, error) { func (md *PgsqlDialect) GetTableDDL(tableName string) (string, error) {
_, err := pd.dc.Exec(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_DDL_KEY)) _, err := md.dc.Exec(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_DDL_KEY))
if err != nil { if err != nil {
return "", err return "", err
} }
_, schemaRes, _ := pd.dc.Query("select current_schema() as schema") _, schemaRes, _ := md.dc.Query("select current_schema() as schema")
schemaName := schemaRes[0]["schema"].(string) schemaName := schemaRes[0]["schema"].(string)
ddlSql := fmt.Sprintf("select showcreatetable('%s','%s') as sql", schemaName, tableName) ddlSql := fmt.Sprintf("select showcreatetable('%s','%s') as sql", schemaName, tableName)
_, res, err := pd.dc.Query(ddlSql) _, res, err := md.dc.Query(ddlSql)
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -173,14 +173,14 @@ func (pd *PgsqlDialect) GetTableDDL(tableName string) (string, error) {
return res[0]["sql"].(string), nil return res[0]["sql"].(string), nil
} }
func (pd *PgsqlDialect) WalkTableRecord(tableName string, walkFn dbi.WalkQueryRowsFunc) error { func (md *PgsqlDialect) WalkTableRecord(tableName string, walkFn dbi.WalkQueryRowsFunc) error {
return pd.dc.WalkQueryRows(context.Background(), fmt.Sprintf("SELECT * FROM %s", tableName), walkFn) return md.dc.WalkQueryRows(context.Background(), fmt.Sprintf("SELECT * FROM %s", tableName), walkFn)
} }
// 获取pgsql当前连接的库可访问的schemaNames // 获取pgsql当前连接的库可访问的schemaNames
func (pd *PgsqlDialect) GetSchemas() ([]string, error) { func (md *PgsqlDialect) GetSchemas() ([]string, error) {
sql := dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_DB_SCHEMAS) sql := dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_DB_SCHEMAS)
_, res, err := pd.dc.Query(sql) _, res, err := md.dc.Query(sql)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -192,30 +192,11 @@ func (pd *PgsqlDialect) GetSchemas() ([]string, error) {
} }
// GetDbProgram 获取数据库程序模块,用于数据库备份与恢复 // GetDbProgram 获取数据库程序模块,用于数据库备份与恢复
func (pd *PgsqlDialect) GetDbProgram() dbi.DbProgram { func (md *PgsqlDialect) GetDbProgram() dbi.DbProgram {
panic("implement me") panic("implement me")
} }
func (pd *PgsqlDialect) GetDataType(dbColumnType string) dbi.DataType { func (md *PgsqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) {
if regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`).MatchString(dbColumnType) {
return dbi.DataTypeNumber
}
// 日期时间类型
if regexp.MustCompile(`(?i)datetime|timestamp`).MatchString(dbColumnType) {
return dbi.DataTypeDateTime
}
// 日期类型
if regexp.MustCompile(`(?i)date`).MatchString(dbColumnType) {
return dbi.DataTypeDate
}
// 时间类型
if regexp.MustCompile(`(?i)time`).MatchString(dbColumnType) {
return dbi.DataTypeTime
}
return dbi.DataTypeString
}
func (pd *PgsqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) {
// 执行批量insert sql跟mysql一样 pg或高斯支持批量insert语法 // 执行批量insert sql跟mysql一样 pg或高斯支持批量insert语法
// insert into table_name (column1, column2, ...) values (value1, value2, ...), (value1, value2, ...), ... // insert into table_name (column1, column2, ...) values (value1, value2, ...), (value1, value2, ...), ...
@@ -235,37 +216,79 @@ func (pd *PgsqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []stri
placeholders = append(placeholders, "("+strings.Join(placeholder, ", ")+")") placeholders = append(placeholders, "("+strings.Join(placeholder, ", ")+")")
} }
sqlStr := fmt.Sprintf("insert into %s (%s) values %s", pd.dc.Info.Type.QuoteIdentifier(tableName), strings.Join(columns, ","), strings.Join(placeholders, ", ")) sqlStr := fmt.Sprintf("insert into %s (%s) values %s", md.dc.Info.Type.QuoteIdentifier(tableName), strings.Join(columns, ","), strings.Join(placeholders, ", "))
// 执行批量insert sql // 执行批量insert sql
return pd.dc.TxExec(tx, sqlStr, args...) return md.dc.TxExec(tx, sqlStr, args...)
} }
func (pd *PgsqlDialect) FormatStrData(dbColumnValue string, dataType dbi.DataType) string { var (
// 数字类型
numberRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`)
// 日期时间类型
datetimeRegexp = regexp.MustCompile(`(?i)datetime|timestamp`)
// 日期类型
dateRegexp = regexp.MustCompile(`(?i)date`)
// 时间类型
timeRegexp = regexp.MustCompile(`(?i)time`)
)
type DataConverter struct {
}
func (md *PgsqlDialect) GetDataConverter() dbi.DataConverter {
return new(DataConverter)
}
func (dc *DataConverter) GetDataType(dbColumnType string) dbi.DataType {
if numberRegexp.MatchString(dbColumnType) {
return dbi.DataTypeNumber
}
// 日期时间类型
if datetimeRegexp.MatchString(dbColumnType) {
return dbi.DataTypeDateTime
}
// 日期类型
if dateRegexp.MatchString(dbColumnType) {
return dbi.DataTypeDate
}
// 时间类型
if timeRegexp.MatchString(dbColumnType) {
return dbi.DataTypeTime
}
return dbi.DataTypeString
}
func (dc *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) string {
str := fmt.Sprintf("%v", dbColumnValue)
switch dataType { switch dataType {
case dbi.DataTypeDateTime: // "2024-01-02T22:16:28.545377+08:00" case dbi.DataTypeDateTime: // "2024-01-02T22:16:28.545377+08:00"
res, _ := time.Parse(time.RFC3339, dbColumnValue) res, _ := time.Parse(time.RFC3339, str)
return res.Format(time.DateTime) return res.Format(time.DateTime)
case dbi.DataTypeDate: // "2024-01-02T00:00:00Z" case dbi.DataTypeDate: // "2024-01-02T00:00:00Z"
res, _ := time.Parse(time.RFC3339, dbColumnValue) res, _ := time.Parse(time.RFC3339, str)
return res.Format(time.DateOnly) return res.Format(time.DateOnly)
case dbi.DataTypeTime: // "0000-01-01T22:16:28.545075+08:00" case dbi.DataTypeTime: // "0000-01-01T22:16:28.545075+08:00"
res, _ := time.Parse(time.RFC3339, dbColumnValue) res, _ := time.Parse(time.RFC3339, str)
return res.Format(time.TimeOnly) return res.Format(time.TimeOnly)
} }
return anyx.ConvString(dbColumnValue)
}
func (dc *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any {
return dbColumnValue return dbColumnValue
} }
func (pd *PgsqlDialect) IsGauss() bool { func (md *PgsqlDialect) IsGauss() bool {
return strings.Contains(pd.dc.Info.Params, "gauss") return strings.Contains(md.dc.Info.Params, "gauss")
} }
func (pd *PgsqlDialect) CopyTable(copy *dbi.DbCopyTable) error { func (md *PgsqlDialect) CopyTable(copy *dbi.DbCopyTable) error {
tableName := copy.TableName tableName := copy.TableName
// 生成新表名,为老表明+_copy_时间戳 // 生成新表名,为老表明+_copy_时间戳
newTableName := tableName + "_copy_" + time.Now().Format("20060102150405") newTableName := tableName + "_copy_" + time.Now().Format("20060102150405")
// 执行根据旧表创建新表 // 执行根据旧表创建新表
_, err := pd.dc.Exec(fmt.Sprintf("create table %s (like %s)", newTableName, tableName)) _, err := md.dc.Exec(fmt.Sprintf("create table %s (like %s)", newTableName, tableName))
if err != nil { if err != nil {
return err return err
} }
@@ -273,12 +296,12 @@ func (pd *PgsqlDialect) CopyTable(copy *dbi.DbCopyTable) error {
// 复制数据 // 复制数据
if copy.CopyData { if copy.CopyData {
go func() { go func() {
_, _ = pd.dc.Exec(fmt.Sprintf("insert into %s select * from %s", newTableName, tableName)) _, _ = md.dc.Exec(fmt.Sprintf("insert into %s select * from %s", newTableName, tableName))
}() }()
} }
// 查询旧表的自增字段名 重新设置新表的序列序列器 // 查询旧表的自增字段名 重新设置新表的序列序列器
_, res, err := pd.dc.Query(fmt.Sprintf("select column_name from information_schema.columns where table_name = '%s' and column_default like 'nextval%%'", tableName)) _, res, err := md.dc.Query(fmt.Sprintf("select column_name from information_schema.columns where table_name = '%s' and column_default like 'nextval%%'", tableName))
if err != nil { if err != nil {
return err return err
} }
@@ -288,7 +311,7 @@ func (pd *PgsqlDialect) CopyTable(copy *dbi.DbCopyTable) error {
if colName != "" { if colName != "" {
// 查询自增列当前最大值 // 查询自增列当前最大值
_, maxRes, err := pd.dc.Query(fmt.Sprintf("select max(%s) max_val from %s", colName, tableName)) _, maxRes, err := md.dc.Query(fmt.Sprintf("select max(%s) max_val from %s", colName, tableName))
if err != nil { if err != nil {
return err return err
} }
@@ -304,12 +327,12 @@ func (pd *PgsqlDialect) CopyTable(copy *dbi.DbCopyTable) error {
newSeqName := fmt.Sprintf("%s_%s_copy_seq", newTableName, colName) newSeqName := fmt.Sprintf("%s_%s_copy_seq", newTableName, colName)
// 创建自增序列,当前最大值为旧表最大值 // 创建自增序列,当前最大值为旧表最大值
_, err = pd.dc.Exec(fmt.Sprintf("CREATE SEQUENCE %s START %d INCREMENT 1", newSeqName, maxVal)) _, err = md.dc.Exec(fmt.Sprintf("CREATE SEQUENCE %s START %d INCREMENT 1", newSeqName, maxVal))
if err != nil { if err != nil {
return err return err
} }
// 将新表的自增主键序列与主键列相关联 // 将新表的自增主键序列与主键列相关联
_, err = pd.dc.Exec(fmt.Sprintf("alter table %s alter column %s set default nextval('%s')", newTableName, colName, newSeqName)) _, err = md.dc.Exec(fmt.Sprintf("alter table %s alter column %s set default nextval('%s')", newTableName, colName, newSeqName))
if err != nil { if err != nil {
return err return err
} }

View File

@@ -196,16 +196,6 @@ func (sd *SqliteDialect) GetDbProgram() dbi.DbProgram {
panic("implement me") panic("implement me")
} }
func (sd *SqliteDialect) GetDataType(dbColumnType string) dbi.DataType {
if regexp.MustCompile(`(?i)int`).MatchString(dbColumnType) {
return dbi.DataTypeNumber
}
if regexp.MustCompile(`(?i)datetime`).MatchString(dbColumnType) {
return dbi.DataTypeDateTime
}
return dbi.DataTypeString
}
func (sd *SqliteDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) { func (sd *SqliteDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) {
// 执行批量insert sql跟mysql一样 支持批量insert语法 // 执行批量insert sql跟mysql一样 支持批量insert语法
// 生成占位符字符串:如:(?,?) // 生成占位符字符串:如:(?,?)
@@ -231,18 +221,47 @@ func (sd *SqliteDialect) BatchInsert(tx *sql.Tx, tableName string, columns []str
return sd.dc.TxExec(tx, sqlStr, args...) return sd.dc.TxExec(tx, sqlStr, args...)
} }
func (sd *SqliteDialect) FormatStrData(dbColumnValue string, dataType dbi.DataType) string { var (
// 数字类型
numberRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit|real`)
// 日期时间类型
datetimeRegexp = regexp.MustCompile(`(?i)datetime`)
)
type DataConverter struct {
}
func (sd *SqliteDialect) GetDataConverter() dbi.DataConverter {
return new(DataConverter)
}
func (dc *DataConverter) GetDataType(dbColumnType string) dbi.DataType {
if numberRegexp.MatchString(dbColumnType) {
return dbi.DataTypeNumber
}
if datetimeRegexp.MatchString(dbColumnType) {
return dbi.DataTypeDateTime
}
return dbi.DataTypeString
}
func (dc *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) string {
str := anyx.ToString(dbColumnValue)
switch dataType { switch dataType {
case dbi.DataTypeDateTime: // "2024-01-02T22:08:22.275697+08:00" case dbi.DataTypeDateTime: // "2024-01-02T22:08:22.275697+08:00"
res, _ := time.Parse(time.RFC3339, dbColumnValue) res, _ := time.Parse(time.RFC3339, str)
return res.Format(time.DateTime) return res.Format(time.DateTime)
case dbi.DataTypeDate: // "2024-01-02T00:00:00+08:00" case dbi.DataTypeDate: // "2024-01-02T00:00:00+08:00"
res, _ := time.Parse(time.RFC3339, dbColumnValue) res, _ := time.Parse(time.RFC3339, str)
return res.Format(time.DateOnly) return res.Format(time.DateOnly)
case dbi.DataTypeTime: // "0000-01-01T22:08:22.275688+08:00" case dbi.DataTypeTime: // "0000-01-01T22:08:22.275688+08:00"
res, _ := time.Parse(time.RFC3339, dbColumnValue) res, _ := time.Parse(time.RFC3339, str)
return res.Format(time.TimeOnly) return res.Format(time.TimeOnly)
} }
return str
}
func (dc *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any {
return dbColumnValue return dbColumnValue
} }