mirror of
				https://gitee.com/dromara/mayfly-go
				synced 2025-11-04 08:20:25 +08:00 
			
		
		
		
	refactor: dbm重构、调整metadata与dialect接口
This commit is contained in:
		@@ -4,8 +4,6 @@ import (
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"mayfly-go/internal/db/dbm/dbi"
 | 
			
		||||
	"mayfly-go/internal/db/dbm/sqlparser"
 | 
			
		||||
	"mayfly-go/internal/db/dbm/sqlparser/pgsql"
 | 
			
		||||
	"mayfly-go/pkg/utils/anyx"
 | 
			
		||||
	"mayfly-go/pkg/utils/collx"
 | 
			
		||||
	"strings"
 | 
			
		||||
@@ -52,7 +50,7 @@ func (pd *PgsqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []stri
 | 
			
		||||
		suffix = pd.pgsqlOnDuplicateStrategySql(duplicateStrategy, tableName, columns)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	sqlStr := fmt.Sprintf("insert into %s (%s) values %s %s", pd.dc.GetMetaData().QuoteIdentifier(tableName), strings.Join(columns, ","), strings.Join(placeholders, ", "), suffix)
 | 
			
		||||
	sqlStr := fmt.Sprintf("insert into %s (%s) values %s %s", pd.QuoteIdentifier(tableName), strings.Join(columns, ","), strings.Join(placeholders, ", "), suffix)
 | 
			
		||||
	// 执行批量insert sql
 | 
			
		||||
 | 
			
		||||
	return pd.dc.TxExec(tx, sqlStr, args...)
 | 
			
		||||
@@ -86,7 +84,7 @@ func (pd *PgsqlDialect) pgsqlOnDuplicateStrategySql(duplicateStrategy int, table
 | 
			
		||||
// 高斯db唯一键冲突策略,使用ON DUPLICATE KEY UPDATE 参考:https://support.huaweicloud.com/distributed-devg-v3-gaussdb/gaussdb-12-0607.html#ZH-CN_TOPIC_0000001633948138
 | 
			
		||||
func (pd *PgsqlDialect) gaussOnDuplicateStrategySql(duplicateStrategy int, tableName string, columns []string) string {
 | 
			
		||||
	suffix := ""
 | 
			
		||||
	metadata := pd.dc.GetMetaData()
 | 
			
		||||
	metadata := pd.dc.GetMetadata()
 | 
			
		||||
	if duplicateStrategy == dbi.DuplicateStrategyIgnore {
 | 
			
		||||
		suffix = " \n ON DUPLICATE KEY UPDATE NOTHING"
 | 
			
		||||
	} else if duplicateStrategy == dbi.DuplicateStrategyUpdate {
 | 
			
		||||
@@ -110,7 +108,7 @@ func (pd *PgsqlDialect) gaussOnDuplicateStrategySql(duplicateStrategy int, table
 | 
			
		||||
		suffix = " \n ON DUPLICATE KEY UPDATE "
 | 
			
		||||
		for i, col := range columns {
 | 
			
		||||
			// ON DUPLICATE KEY UPDATE语句不支持更新唯一键字段,所以得去掉
 | 
			
		||||
			if !collx.ArrayContains(uniqueColumns, metadata.RemoveQuote(strings.ToLower(col))) {
 | 
			
		||||
			if !collx.ArrayContains(uniqueColumns, pd.RemoveQuote(strings.ToLower(col))) {
 | 
			
		||||
				suffix += fmt.Sprintf("%s = excluded.%s", col, col)
 | 
			
		||||
				if i < len(columns)-1 {
 | 
			
		||||
					suffix += ", "
 | 
			
		||||
@@ -178,17 +176,101 @@ func (pd *PgsqlDialect) CopyTable(copy *dbi.DbCopyTable) error {
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pd *PgsqlDialect) CreateTable(commonColumns []dbi.Column, tableInfo dbi.Table, dropOldTable bool) (int, error) {
 | 
			
		||||
	meta := pd.dc.GetMetaData()
 | 
			
		||||
	sqlArr := meta.GenerateTableDDL(commonColumns, tableInfo, dropOldTable)
 | 
			
		||||
	_, err := pd.dc.Exec(strings.Join(sqlArr, ";"))
 | 
			
		||||
	return len(sqlArr), err
 | 
			
		||||
func (pd *PgsqlDialect) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table, dropBeforeCreate bool) []string {
 | 
			
		||||
	quoteTableName := pd.QuoteIdentifier(tableInfo.TableName)
 | 
			
		||||
 | 
			
		||||
	sqlArr := make([]string, 0)
 | 
			
		||||
	if dropBeforeCreate {
 | 
			
		||||
		sqlArr = append(sqlArr, fmt.Sprintf("DROP TABLE IF EXISTS %s", quoteTableName))
 | 
			
		||||
	}
 | 
			
		||||
	// 组装建表语句
 | 
			
		||||
	createSql := fmt.Sprintf("CREATE TABLE %s (\n", quoteTableName)
 | 
			
		||||
	fields := make([]string, 0)
 | 
			
		||||
	pks := make([]string, 0)
 | 
			
		||||
	columnComments := make([]string, 0)
 | 
			
		||||
	commentTmp := "comment on column %s.%s is '%s'"
 | 
			
		||||
 | 
			
		||||
	for _, column := range columns {
 | 
			
		||||
		if column.IsPrimaryKey {
 | 
			
		||||
			pks = append(pks, pd.QuoteIdentifier(column.ColumnName))
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		fields = append(fields, pd.genColumnBasicSql(column))
 | 
			
		||||
 | 
			
		||||
		// 防止注释内含有特殊字符串导致sql出错
 | 
			
		||||
		if column.ColumnComment != "" {
 | 
			
		||||
			comment := pd.QuoteEscape(column.ColumnComment)
 | 
			
		||||
			columnComments = append(columnComments, fmt.Sprintf(commentTmp, quoteTableName, pd.QuoteIdentifier(column.ColumnName), comment))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	createSql += strings.Join(fields, ",\n")
 | 
			
		||||
	if len(pks) > 0 {
 | 
			
		||||
		createSql += fmt.Sprintf(", \nPRIMARY KEY (%s)", strings.Join(pks, ","))
 | 
			
		||||
	}
 | 
			
		||||
	createSql += "\n)"
 | 
			
		||||
 | 
			
		||||
	tableCommentSql := ""
 | 
			
		||||
	if tableInfo.TableComment != "" {
 | 
			
		||||
		commentTmp := "comment on table %s is '%s'"
 | 
			
		||||
		tableCommentSql = fmt.Sprintf(commentTmp, quoteTableName, pd.QuoteEscape(tableInfo.TableComment))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// create
 | 
			
		||||
	sqlArr = append(sqlArr, createSql)
 | 
			
		||||
 | 
			
		||||
	// table comment
 | 
			
		||||
	if tableCommentSql != "" {
 | 
			
		||||
		sqlArr = append(sqlArr, tableCommentSql)
 | 
			
		||||
	}
 | 
			
		||||
	// column comment
 | 
			
		||||
	if len(columnComments) > 0 {
 | 
			
		||||
		sqlArr = append(sqlArr, columnComments...)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return sqlArr
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pd *PgsqlDialect) CreateIndex(tableInfo dbi.Table, indexs []dbi.Index) error {
 | 
			
		||||
	sqlArr := pd.dc.GetMetaData().GenerateIndexDDL(indexs, tableInfo)
 | 
			
		||||
	_, err := pd.dc.Exec(strings.Join(sqlArr, ";"))
 | 
			
		||||
	return err
 | 
			
		||||
func (pd *PgsqlDialect) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Table) []string {
 | 
			
		||||
	creates := make([]string, 0)
 | 
			
		||||
	drops := make([]string, 0)
 | 
			
		||||
	comments := make([]string, 0)
 | 
			
		||||
	for _, index := range indexs {
 | 
			
		||||
		unique := ""
 | 
			
		||||
		if index.IsUnique {
 | 
			
		||||
			unique = "unique"
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 如果索引名存在,先删除索引
 | 
			
		||||
		drops = append(drops, fmt.Sprintf("drop index if exists %s.%s", pd.dc.Info.CurrentSchema(), index.IndexName))
 | 
			
		||||
 | 
			
		||||
		// 取出列名,添加引号
 | 
			
		||||
		cols := strings.Split(index.ColumnName, ",")
 | 
			
		||||
		colNames := make([]string, len(cols))
 | 
			
		||||
		for i, name := range cols {
 | 
			
		||||
			colNames[i] = pd.QuoteIdentifier(name)
 | 
			
		||||
		}
 | 
			
		||||
		// 创建索引
 | 
			
		||||
		creates = append(creates, fmt.Sprintf("CREATE %s INDEX %s on %s.%s(%s)", unique, pd.QuoteIdentifier(index.IndexName), pd.QuoteIdentifier(pd.dc.Info.CurrentSchema()), pd.QuoteIdentifier(tableInfo.TableName), strings.Join(colNames, ",")))
 | 
			
		||||
		if index.IndexComment != "" {
 | 
			
		||||
			comment := pd.QuoteEscape(index.IndexComment)
 | 
			
		||||
			comments = append(comments, fmt.Sprintf("COMMENT ON INDEX %s.%s IS '%s'", pd.dc.Info.CurrentSchema(), index.IndexName, comment))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	sqlArr := make([]string, 0)
 | 
			
		||||
 | 
			
		||||
	if len(drops) > 0 {
 | 
			
		||||
		sqlArr = append(sqlArr, drops...)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(creates) > 0 {
 | 
			
		||||
		sqlArr = append(sqlArr, creates...)
 | 
			
		||||
	}
 | 
			
		||||
	if len(comments) > 0 {
 | 
			
		||||
		sqlArr = append(sqlArr, comments...)
 | 
			
		||||
	}
 | 
			
		||||
	return sqlArr
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pd *PgsqlDialect) UpdateSequence(tableName string, columns []dbi.Column) {
 | 
			
		||||
@@ -199,6 +281,77 @@ func (pd *PgsqlDialect) UpdateSequence(tableName string, columns []dbi.Column) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pd *PgsqlDialect) GetSQLParser() sqlparser.SqlParser {
 | 
			
		||||
	return new(pgsql.PgsqlParser)
 | 
			
		||||
func (pd *PgsqlDialect) GetDataHelper() dbi.DataHelper {
 | 
			
		||||
	return new(DataHelper)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pd *PgsqlDialect) GetColumnHelper() dbi.ColumnHelper {
 | 
			
		||||
	return new(ColumnHelper)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pd *PgsqlDialect) GetDumpHelper() dbi.DumpHelper {
 | 
			
		||||
	return new(DumpHelper)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pd *PgsqlDialect) genColumnBasicSql(column dbi.Column) string {
 | 
			
		||||
	colName := pd.QuoteIdentifier(column.ColumnName)
 | 
			
		||||
	dataType := string(column.DataType)
 | 
			
		||||
 | 
			
		||||
	// 如果数据类型是数字,则去掉长度
 | 
			
		||||
	if collx.ArrayAnyMatches([]string{"int"}, strings.ToLower(dataType)) {
 | 
			
		||||
		column.NumPrecision = 0
 | 
			
		||||
		column.CharMaxLength = 0
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 如果是自增类型,需要转换为serial
 | 
			
		||||
	if column.IsIdentity {
 | 
			
		||||
		if dataType == "int4" {
 | 
			
		||||
			column.DataType = "serial"
 | 
			
		||||
		} else if dataType == "int2" {
 | 
			
		||||
			column.DataType = "smallserial"
 | 
			
		||||
		} else if dataType == "int8" {
 | 
			
		||||
			column.DataType = "bigserial"
 | 
			
		||||
		} else {
 | 
			
		||||
			column.DataType = "bigserial"
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return fmt.Sprintf(" %s %s NOT NULL", colName, column.GetColumnType())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	nullAble := ""
 | 
			
		||||
	if !column.Nullable {
 | 
			
		||||
		nullAble = " NOT NULL"
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defVal := "" // 默认值需要判断引号,如函数是不需要引号的 // 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值
 | 
			
		||||
	if column.ColumnDefault != "" && !strings.Contains(column.ColumnDefault, "(") {
 | 
			
		||||
		mark := false
 | 
			
		||||
		// 哪些字段类型默认值需要加引号
 | 
			
		||||
		if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, dataType) {
 | 
			
		||||
			// 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号
 | 
			
		||||
			if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) &&
 | 
			
		||||
				collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(column.ColumnDefault)) {
 | 
			
		||||
				mark = false
 | 
			
		||||
			} else {
 | 
			
		||||
				mark = true
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		// 如果数据类型是日期时间,则写死默认值函数
 | 
			
		||||
		if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) {
 | 
			
		||||
			column.ColumnDefault = "CURRENT_TIMESTAMP"
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if mark {
 | 
			
		||||
			defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault)
 | 
			
		||||
		} else {
 | 
			
		||||
			defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 如果是varchar,长度翻倍,防止报错
 | 
			
		||||
	if collx.ArrayAnyMatches([]string{"char"}, strings.ToLower(dataType)) {
 | 
			
		||||
		column.CharMaxLength = column.CharMaxLength * 2
 | 
			
		||||
	}
 | 
			
		||||
	columnSql := fmt.Sprintf(" %s %s%s%s", colName, column.GetColumnType(), nullAble, defVal)
 | 
			
		||||
	return columnSql
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -85,8 +85,8 @@ func (pm *Meta) GetDialect(conn *dbi.DbConn) dbi.Dialect {
 | 
			
		||||
	return &PgsqlDialect{dc: conn}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pm *Meta) GetMetaData(conn *dbi.DbConn) *dbi.MetaDataX {
 | 
			
		||||
	return dbi.NewMetaDataX(&PgsqlMetaData{dc: conn})
 | 
			
		||||
func (pm *Meta) GetMetadata(conn *dbi.DbConn) dbi.Metadata {
 | 
			
		||||
	return &PgsqlMetadata{dc: conn}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// pgsql dialer
 | 
			
		||||
 
 | 
			
		||||
@@ -2,7 +2,6 @@ package postgres
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"mayfly-go/internal/db/dbm/dbi"
 | 
			
		||||
	"mayfly-go/pkg/errorx"
 | 
			
		||||
	"mayfly-go/pkg/logx"
 | 
			
		||||
@@ -21,13 +20,13 @@ const (
 | 
			
		||||
	PGSQL_COLUMN_MA_KEY  = "PGSQL_COLUMN_MA"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type PgsqlMetaData struct {
 | 
			
		||||
	dbi.DefaultMetaData
 | 
			
		||||
type PgsqlMetadata struct {
 | 
			
		||||
	dbi.DefaultMetadata
 | 
			
		||||
 | 
			
		||||
	dc *dbi.DbConn
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pd *PgsqlMetaData) GetDbServer() (*dbi.DbServer, error) {
 | 
			
		||||
func (pd *PgsqlMetadata) GetDbServer() (*dbi.DbServer, error) {
 | 
			
		||||
	_, res, err := pd.dc.Query("SELECT version() as server_version")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
@@ -38,7 +37,7 @@ func (pd *PgsqlMetaData) GetDbServer() (*dbi.DbServer, error) {
 | 
			
		||||
	return ds, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pd *PgsqlMetaData) GetDbNames() ([]string, error) {
 | 
			
		||||
func (pd *PgsqlMetadata) GetDbNames() ([]string, error) {
 | 
			
		||||
	_, res, err := pd.dc.Query("SELECT datname AS dbname FROM pg_database WHERE datistemplate = false AND has_database_privilege(datname, 'CONNECT')")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
@@ -52,10 +51,10 @@ func (pd *PgsqlMetaData) GetDbNames() ([]string, error) {
 | 
			
		||||
	return databases, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pd *PgsqlMetaData) GetTables(tableNames ...string) ([]dbi.Table, error) {
 | 
			
		||||
	meta := pd.dc.GetMetaData()
 | 
			
		||||
func (pd *PgsqlMetadata) GetTables(tableNames ...string) ([]dbi.Table, error) {
 | 
			
		||||
	dialect := pd.dc.GetDialect()
 | 
			
		||||
	names := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string {
 | 
			
		||||
		return fmt.Sprintf("'%s'", meta.RemoveQuote(val))
 | 
			
		||||
		return fmt.Sprintf("'%s'", dialect.RemoveQuote(val))
 | 
			
		||||
	}), ",")
 | 
			
		||||
 | 
			
		||||
	var res []map[string]any
 | 
			
		||||
@@ -86,10 +85,10 @@ func (pd *PgsqlMetaData) GetTables(tableNames ...string) ([]dbi.Table, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取列元信息, 如列名等
 | 
			
		||||
func (pd *PgsqlMetaData) GetColumns(tableNames ...string) ([]dbi.Column, error) {
 | 
			
		||||
	meta := pd.dc.GetMetaData()
 | 
			
		||||
func (pd *PgsqlMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error) {
 | 
			
		||||
	dialect := pd.dc.GetDialect()
 | 
			
		||||
	tableName := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string {
 | 
			
		||||
		return fmt.Sprintf("'%s'", meta.RemoveQuote(val))
 | 
			
		||||
		return fmt.Sprintf("'%s'", dialect.RemoveQuote(val))
 | 
			
		||||
	}), ",")
 | 
			
		||||
 | 
			
		||||
	_, res, err := pd.dc.Query(fmt.Sprintf(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_COLUMN_MA_KEY), tableName))
 | 
			
		||||
@@ -97,7 +96,7 @@ func (pd *PgsqlMetaData) GetColumns(tableNames ...string) ([]dbi.Column, error)
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	columnHelper := meta.GetColumnHelper()
 | 
			
		||||
	columnHelper := dialect.GetColumnHelper()
 | 
			
		||||
	columns := make([]dbi.Column, 0)
 | 
			
		||||
	for _, re := range res {
 | 
			
		||||
		column := dbi.Column{
 | 
			
		||||
@@ -119,7 +118,7 @@ func (pd *PgsqlMetaData) GetColumns(tableNames ...string) ([]dbi.Column, error)
 | 
			
		||||
	return columns, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pd *PgsqlMetaData) GetPrimaryKey(tablename string) (string, error) {
 | 
			
		||||
func (pd *PgsqlMetadata) GetPrimaryKey(tablename string) (string, error) {
 | 
			
		||||
	columns, err := pd.GetColumns(tablename)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
@@ -137,7 +136,7 @@ func (pd *PgsqlMetaData) GetPrimaryKey(tablename string) (string, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取表索引信息
 | 
			
		||||
func (pd *PgsqlMetaData) GetTableIndex(tableName string) ([]dbi.Index, error) {
 | 
			
		||||
func (pd *PgsqlMetadata) GetTableIndex(tableName string) ([]dbi.Index, error) {
 | 
			
		||||
	_, res, err := pd.dc.Query(fmt.Sprintf(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_INDEX_INFO_KEY), tableName))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
@@ -174,172 +173,8 @@ func (pd *PgsqlMetaData) GetTableIndex(tableName string) ([]dbi.Index, error) {
 | 
			
		||||
	return result, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pd *PgsqlMetaData) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Table) []string {
 | 
			
		||||
	meta := pd.dc.GetMetaData()
 | 
			
		||||
	creates := make([]string, 0)
 | 
			
		||||
	drops := make([]string, 0)
 | 
			
		||||
	comments := make([]string, 0)
 | 
			
		||||
	for _, index := range indexs {
 | 
			
		||||
		unique := ""
 | 
			
		||||
		if index.IsUnique {
 | 
			
		||||
			unique = "unique"
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 如果索引名存在,先删除索引
 | 
			
		||||
		drops = append(drops, fmt.Sprintf("drop index if exists %s.%s", pd.dc.Info.CurrentSchema(), index.IndexName))
 | 
			
		||||
 | 
			
		||||
		// 取出列名,添加引号
 | 
			
		||||
		cols := strings.Split(index.ColumnName, ",")
 | 
			
		||||
		colNames := make([]string, len(cols))
 | 
			
		||||
		for i, name := range cols {
 | 
			
		||||
			colNames[i] = meta.QuoteIdentifier(name)
 | 
			
		||||
		}
 | 
			
		||||
		// 创建索引
 | 
			
		||||
		creates = append(creates, fmt.Sprintf("CREATE %s INDEX %s on %s.%s(%s)", unique, meta.QuoteIdentifier(index.IndexName), meta.QuoteIdentifier(pd.dc.Info.CurrentSchema()), meta.QuoteIdentifier(tableInfo.TableName), strings.Join(colNames, ",")))
 | 
			
		||||
		if index.IndexComment != "" {
 | 
			
		||||
			comment := meta.QuoteEscape(index.IndexComment)
 | 
			
		||||
			comments = append(comments, fmt.Sprintf("COMMENT ON INDEX %s.%s IS '%s'", pd.dc.Info.CurrentSchema(), index.IndexName, comment))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	sqlArr := make([]string, 0)
 | 
			
		||||
 | 
			
		||||
	if len(drops) > 0 {
 | 
			
		||||
		sqlArr = append(sqlArr, drops...)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(creates) > 0 {
 | 
			
		||||
		sqlArr = append(sqlArr, creates...)
 | 
			
		||||
	}
 | 
			
		||||
	if len(comments) > 0 {
 | 
			
		||||
		sqlArr = append(sqlArr, comments...)
 | 
			
		||||
	}
 | 
			
		||||
	return sqlArr
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pd *PgsqlMetaData) genColumnBasicSql(column dbi.Column) string {
 | 
			
		||||
	meta := pd.dc.GetMetaData()
 | 
			
		||||
	colName := meta.QuoteIdentifier(column.ColumnName)
 | 
			
		||||
	dataType := string(column.DataType)
 | 
			
		||||
 | 
			
		||||
	// 如果数据类型是数字,则去掉长度
 | 
			
		||||
	if collx.ArrayAnyMatches([]string{"int"}, strings.ToLower(dataType)) {
 | 
			
		||||
		column.NumPrecision = 0
 | 
			
		||||
		column.CharMaxLength = 0
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 如果是自增类型,需要转换为serial
 | 
			
		||||
	if column.IsIdentity {
 | 
			
		||||
		if dataType == "int4" {
 | 
			
		||||
			column.DataType = "serial"
 | 
			
		||||
		} else if dataType == "int2" {
 | 
			
		||||
			column.DataType = "smallserial"
 | 
			
		||||
		} else if dataType == "int8" {
 | 
			
		||||
			column.DataType = "bigserial"
 | 
			
		||||
		} else {
 | 
			
		||||
			column.DataType = "bigserial"
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return fmt.Sprintf(" %s %s NOT NULL", colName, column.GetColumnType())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	nullAble := ""
 | 
			
		||||
	if !column.Nullable {
 | 
			
		||||
		nullAble = " NOT NULL"
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defVal := "" // 默认值需要判断引号,如函数是不需要引号的 // 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值
 | 
			
		||||
	if column.ColumnDefault != "" && !strings.Contains(column.ColumnDefault, "(") {
 | 
			
		||||
		mark := false
 | 
			
		||||
		// 哪些字段类型默认值需要加引号
 | 
			
		||||
		if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, dataType) {
 | 
			
		||||
			// 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号
 | 
			
		||||
			if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) &&
 | 
			
		||||
				collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(column.ColumnDefault)) {
 | 
			
		||||
				mark = false
 | 
			
		||||
			} else {
 | 
			
		||||
				mark = true
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		// 如果数据类型是日期时间,则写死默认值函数
 | 
			
		||||
		if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) {
 | 
			
		||||
			column.ColumnDefault = "CURRENT_TIMESTAMP"
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if mark {
 | 
			
		||||
			defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault)
 | 
			
		||||
		} else {
 | 
			
		||||
			defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 如果是varchar,长度翻倍,防止报错
 | 
			
		||||
	if collx.ArrayAnyMatches([]string{"char"}, strings.ToLower(dataType)) {
 | 
			
		||||
		column.CharMaxLength = column.CharMaxLength * 2
 | 
			
		||||
	}
 | 
			
		||||
	columnSql := fmt.Sprintf(" %s %s%s%s", colName, column.GetColumnType(), nullAble, defVal)
 | 
			
		||||
	return columnSql
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pd *PgsqlMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table, dropBeforeCreate bool) []string {
 | 
			
		||||
	meta := pd.dc.GetMetaData()
 | 
			
		||||
	quoteTableName := meta.QuoteIdentifier(tableInfo.TableName)
 | 
			
		||||
 | 
			
		||||
	sqlArr := make([]string, 0)
 | 
			
		||||
	if dropBeforeCreate {
 | 
			
		||||
		sqlArr = append(sqlArr, fmt.Sprintf("DROP TABLE IF EXISTS %s", quoteTableName))
 | 
			
		||||
	}
 | 
			
		||||
	// 组装建表语句
 | 
			
		||||
	createSql := fmt.Sprintf("CREATE TABLE %s (\n", quoteTableName)
 | 
			
		||||
	fields := make([]string, 0)
 | 
			
		||||
	pks := make([]string, 0)
 | 
			
		||||
	columnComments := make([]string, 0)
 | 
			
		||||
	commentTmp := "comment on column %s.%s is '%s'"
 | 
			
		||||
 | 
			
		||||
	for _, column := range columns {
 | 
			
		||||
		if column.IsPrimaryKey {
 | 
			
		||||
			pks = append(pks, meta.QuoteIdentifier(column.ColumnName))
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		fields = append(fields, pd.genColumnBasicSql(column))
 | 
			
		||||
 | 
			
		||||
		// 防止注释内含有特殊字符串导致sql出错
 | 
			
		||||
		if column.ColumnComment != "" {
 | 
			
		||||
			comment := meta.QuoteEscape(column.ColumnComment)
 | 
			
		||||
			columnComments = append(columnComments, fmt.Sprintf(commentTmp, quoteTableName, meta.QuoteIdentifier(column.ColumnName), comment))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	createSql += strings.Join(fields, ",\n")
 | 
			
		||||
	if len(pks) > 0 {
 | 
			
		||||
		createSql += fmt.Sprintf(", \nPRIMARY KEY (%s)", strings.Join(pks, ","))
 | 
			
		||||
	}
 | 
			
		||||
	createSql += "\n)"
 | 
			
		||||
 | 
			
		||||
	tableCommentSql := ""
 | 
			
		||||
	if tableInfo.TableComment != "" {
 | 
			
		||||
		commentTmp := "comment on table %s is '%s'"
 | 
			
		||||
		tableCommentSql = fmt.Sprintf(commentTmp, quoteTableName, meta.QuoteEscape(tableInfo.TableComment))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// create
 | 
			
		||||
	sqlArr = append(sqlArr, createSql)
 | 
			
		||||
 | 
			
		||||
	// table comment
 | 
			
		||||
	if tableCommentSql != "" {
 | 
			
		||||
		sqlArr = append(sqlArr, tableCommentSql)
 | 
			
		||||
	}
 | 
			
		||||
	// column comment
 | 
			
		||||
	if len(columnComments) > 0 {
 | 
			
		||||
		sqlArr = append(sqlArr, columnComments...)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return sqlArr
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取建表ddl
 | 
			
		||||
func (pd *PgsqlMetaData) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) {
 | 
			
		||||
 | 
			
		||||
func (pd *PgsqlMetadata) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) {
 | 
			
		||||
	// 1.获取表信息
 | 
			
		||||
	tbs, err := pd.GetTables(tableName)
 | 
			
		||||
	tableInfo := &dbi.Table{}
 | 
			
		||||
@@ -356,7 +191,8 @@ func (pd *PgsqlMetaData) GetTableDDL(tableName string, dropBeforeCreate bool) (s
 | 
			
		||||
		logx.Errorf("获取列信息失败, %s", tableName)
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
	tableDDLArr := pd.GenerateTableDDL(columns, *tableInfo, dropBeforeCreate)
 | 
			
		||||
	dialect := pd.dc.GetDialect()
 | 
			
		||||
	tableDDLArr := dialect.GenerateTableDDL(columns, *tableInfo, dropBeforeCreate)
 | 
			
		||||
	// 3.获取索引信息
 | 
			
		||||
	indexs, err := pd.GetTableIndex(tableName)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -364,12 +200,12 @@ func (pd *PgsqlMetaData) GetTableDDL(tableName string, dropBeforeCreate bool) (s
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
	// 组装返回
 | 
			
		||||
	tableDDLArr = append(tableDDLArr, pd.GenerateIndexDDL(indexs, *tableInfo)...)
 | 
			
		||||
	tableDDLArr = append(tableDDLArr, dialect.GenerateIndexDDL(indexs, *tableInfo)...)
 | 
			
		||||
	return strings.Join(tableDDLArr, ";\n"), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取pgsql当前连接的库可访问的schemaNames
 | 
			
		||||
func (pd *PgsqlMetaData) GetSchemas() ([]string, error) {
 | 
			
		||||
func (pd *PgsqlMetadata) GetSchemas() ([]string, error) {
 | 
			
		||||
	sql := dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_DB_SCHEMAS)
 | 
			
		||||
	_, res, err := pd.dc.Query(sql)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -382,7 +218,7 @@ func (pd *PgsqlMetaData) GetSchemas() ([]string, error) {
 | 
			
		||||
	return schemaNames, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pd *PgsqlMetaData) DefaultDb() string {
 | 
			
		||||
func (pd *PgsqlMetadata) GetDefaultDb() string {
 | 
			
		||||
	switch pd.dc.Info.Type {
 | 
			
		||||
	case dbi.DbTypePostgres, dbi.DbTypeGauss:
 | 
			
		||||
		return "postgres"
 | 
			
		||||
@@ -394,28 +230,3 @@ func (pd *PgsqlMetaData) DefaultDb() string {
 | 
			
		||||
		return ""
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pd *PgsqlMetaData) AfterDumpInsert(writer io.Writer, tableName string, columns []dbi.Column) {
 | 
			
		||||
 | 
			
		||||
	// 设置自增序列当前值
 | 
			
		||||
	for _, column := range columns {
 | 
			
		||||
		if column.IsIdentity {
 | 
			
		||||
			seq := fmt.Sprintf("SELECT setval('%s_%s_seq', (SELECT max(%s) FROM %s));\n", tableName, column.ColumnName, column.ColumnName, tableName)
 | 
			
		||||
			writer.Write([]byte(seq))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	writer.Write([]byte("COMMIT;\n"))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pd *PgsqlMetaData) GetDataHelper() dbi.DataHelper {
 | 
			
		||||
	return new(DataHelper)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pd *PgsqlMetaData) GetColumnHelper() dbi.ColumnHelper {
 | 
			
		||||
	return new(ColumnHelper)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pd *PgsqlMetaData) GetDumpHelper() dbi.DumpHelper {
 | 
			
		||||
	return new(DumpHelper)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user