refactor: dbm重构、调整metadata与dialect接口

This commit is contained in:
meilin.huang
2024-11-01 17:27:22 +08:00
parent af14be9801
commit 74ae031853
36 changed files with 1216 additions and 1384 deletions

View File

@@ -2,7 +2,6 @@ package postgres
import (
"fmt"
"io"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/errorx"
"mayfly-go/pkg/logx"
@@ -21,13 +20,13 @@ const (
PGSQL_COLUMN_MA_KEY = "PGSQL_COLUMN_MA"
)
type PgsqlMetaData struct {
dbi.DefaultMetaData
type PgsqlMetadata struct {
dbi.DefaultMetadata
dc *dbi.DbConn
}
func (pd *PgsqlMetaData) GetDbServer() (*dbi.DbServer, error) {
func (pd *PgsqlMetadata) GetDbServer() (*dbi.DbServer, error) {
_, res, err := pd.dc.Query("SELECT version() as server_version")
if err != nil {
return nil, err
@@ -38,7 +37,7 @@ func (pd *PgsqlMetaData) GetDbServer() (*dbi.DbServer, error) {
return ds, nil
}
func (pd *PgsqlMetaData) GetDbNames() ([]string, error) {
func (pd *PgsqlMetadata) GetDbNames() ([]string, error) {
_, res, err := pd.dc.Query("SELECT datname AS dbname FROM pg_database WHERE datistemplate = false AND has_database_privilege(datname, 'CONNECT')")
if err != nil {
return nil, err
@@ -52,10 +51,10 @@ func (pd *PgsqlMetaData) GetDbNames() ([]string, error) {
return databases, nil
}
func (pd *PgsqlMetaData) GetTables(tableNames ...string) ([]dbi.Table, error) {
meta := pd.dc.GetMetaData()
func (pd *PgsqlMetadata) GetTables(tableNames ...string) ([]dbi.Table, error) {
dialect := pd.dc.GetDialect()
names := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string {
return fmt.Sprintf("'%s'", meta.RemoveQuote(val))
return fmt.Sprintf("'%s'", dialect.RemoveQuote(val))
}), ",")
var res []map[string]any
@@ -86,10 +85,10 @@ func (pd *PgsqlMetaData) GetTables(tableNames ...string) ([]dbi.Table, error) {
}
// 获取列元信息, 如列名等
func (pd *PgsqlMetaData) GetColumns(tableNames ...string) ([]dbi.Column, error) {
meta := pd.dc.GetMetaData()
func (pd *PgsqlMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error) {
dialect := pd.dc.GetDialect()
tableName := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string {
return fmt.Sprintf("'%s'", meta.RemoveQuote(val))
return fmt.Sprintf("'%s'", dialect.RemoveQuote(val))
}), ",")
_, res, err := pd.dc.Query(fmt.Sprintf(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_COLUMN_MA_KEY), tableName))
@@ -97,7 +96,7 @@ func (pd *PgsqlMetaData) GetColumns(tableNames ...string) ([]dbi.Column, error)
return nil, err
}
columnHelper := meta.GetColumnHelper()
columnHelper := dialect.GetColumnHelper()
columns := make([]dbi.Column, 0)
for _, re := range res {
column := dbi.Column{
@@ -119,7 +118,7 @@ func (pd *PgsqlMetaData) GetColumns(tableNames ...string) ([]dbi.Column, error)
return columns, nil
}
func (pd *PgsqlMetaData) GetPrimaryKey(tablename string) (string, error) {
func (pd *PgsqlMetadata) GetPrimaryKey(tablename string) (string, error) {
columns, err := pd.GetColumns(tablename)
if err != nil {
return "", err
@@ -137,7 +136,7 @@ func (pd *PgsqlMetaData) GetPrimaryKey(tablename string) (string, error) {
}
// 获取表索引信息
func (pd *PgsqlMetaData) GetTableIndex(tableName string) ([]dbi.Index, error) {
func (pd *PgsqlMetadata) GetTableIndex(tableName string) ([]dbi.Index, error) {
_, res, err := pd.dc.Query(fmt.Sprintf(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_INDEX_INFO_KEY), tableName))
if err != nil {
return nil, err
@@ -174,172 +173,8 @@ func (pd *PgsqlMetaData) GetTableIndex(tableName string) ([]dbi.Index, error) {
return result, nil
}
func (pd *PgsqlMetaData) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Table) []string {
meta := pd.dc.GetMetaData()
creates := make([]string, 0)
drops := make([]string, 0)
comments := make([]string, 0)
for _, index := range indexs {
unique := ""
if index.IsUnique {
unique = "unique"
}
// 如果索引名存在,先删除索引
drops = append(drops, fmt.Sprintf("drop index if exists %s.%s", pd.dc.Info.CurrentSchema(), index.IndexName))
// 取出列名,添加引号
cols := strings.Split(index.ColumnName, ",")
colNames := make([]string, len(cols))
for i, name := range cols {
colNames[i] = meta.QuoteIdentifier(name)
}
// 创建索引
creates = append(creates, fmt.Sprintf("CREATE %s INDEX %s on %s.%s(%s)", unique, meta.QuoteIdentifier(index.IndexName), meta.QuoteIdentifier(pd.dc.Info.CurrentSchema()), meta.QuoteIdentifier(tableInfo.TableName), strings.Join(colNames, ",")))
if index.IndexComment != "" {
comment := meta.QuoteEscape(index.IndexComment)
comments = append(comments, fmt.Sprintf("COMMENT ON INDEX %s.%s IS '%s'", pd.dc.Info.CurrentSchema(), index.IndexName, comment))
}
}
sqlArr := make([]string, 0)
if len(drops) > 0 {
sqlArr = append(sqlArr, drops...)
}
if len(creates) > 0 {
sqlArr = append(sqlArr, creates...)
}
if len(comments) > 0 {
sqlArr = append(sqlArr, comments...)
}
return sqlArr
}
func (pd *PgsqlMetaData) genColumnBasicSql(column dbi.Column) string {
meta := pd.dc.GetMetaData()
colName := meta.QuoteIdentifier(column.ColumnName)
dataType := string(column.DataType)
// 如果数据类型是数字,则去掉长度
if collx.ArrayAnyMatches([]string{"int"}, strings.ToLower(dataType)) {
column.NumPrecision = 0
column.CharMaxLength = 0
}
// 如果是自增类型需要转换为serial
if column.IsIdentity {
if dataType == "int4" {
column.DataType = "serial"
} else if dataType == "int2" {
column.DataType = "smallserial"
} else if dataType == "int8" {
column.DataType = "bigserial"
} else {
column.DataType = "bigserial"
}
return fmt.Sprintf(" %s %s NOT NULL", colName, column.GetColumnType())
}
nullAble := ""
if !column.Nullable {
nullAble = " NOT NULL"
}
defVal := "" // 默认值需要判断引号,如函数是不需要引号的 // 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值
if column.ColumnDefault != "" && !strings.Contains(column.ColumnDefault, "(") {
mark := false
// 哪些字段类型默认值需要加引号
if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, dataType) {
// 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号
if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) &&
collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(column.ColumnDefault)) {
mark = false
} else {
mark = true
}
}
// 如果数据类型是日期时间,则写死默认值函数
if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) {
column.ColumnDefault = "CURRENT_TIMESTAMP"
}
if mark {
defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault)
} else {
defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault)
}
}
// 如果是varchar长度翻倍防止报错
if collx.ArrayAnyMatches([]string{"char"}, strings.ToLower(dataType)) {
column.CharMaxLength = column.CharMaxLength * 2
}
columnSql := fmt.Sprintf(" %s %s%s%s", colName, column.GetColumnType(), nullAble, defVal)
return columnSql
}
func (pd *PgsqlMetaData) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table, dropBeforeCreate bool) []string {
meta := pd.dc.GetMetaData()
quoteTableName := meta.QuoteIdentifier(tableInfo.TableName)
sqlArr := make([]string, 0)
if dropBeforeCreate {
sqlArr = append(sqlArr, fmt.Sprintf("DROP TABLE IF EXISTS %s", quoteTableName))
}
// 组装建表语句
createSql := fmt.Sprintf("CREATE TABLE %s (\n", quoteTableName)
fields := make([]string, 0)
pks := make([]string, 0)
columnComments := make([]string, 0)
commentTmp := "comment on column %s.%s is '%s'"
for _, column := range columns {
if column.IsPrimaryKey {
pks = append(pks, meta.QuoteIdentifier(column.ColumnName))
}
fields = append(fields, pd.genColumnBasicSql(column))
// 防止注释内含有特殊字符串导致sql出错
if column.ColumnComment != "" {
comment := meta.QuoteEscape(column.ColumnComment)
columnComments = append(columnComments, fmt.Sprintf(commentTmp, quoteTableName, meta.QuoteIdentifier(column.ColumnName), comment))
}
}
createSql += strings.Join(fields, ",\n")
if len(pks) > 0 {
createSql += fmt.Sprintf(", \nPRIMARY KEY (%s)", strings.Join(pks, ","))
}
createSql += "\n)"
tableCommentSql := ""
if tableInfo.TableComment != "" {
commentTmp := "comment on table %s is '%s'"
tableCommentSql = fmt.Sprintf(commentTmp, quoteTableName, meta.QuoteEscape(tableInfo.TableComment))
}
// create
sqlArr = append(sqlArr, createSql)
// table comment
if tableCommentSql != "" {
sqlArr = append(sqlArr, tableCommentSql)
}
// column comment
if len(columnComments) > 0 {
sqlArr = append(sqlArr, columnComments...)
}
return sqlArr
}
// 获取建表ddl
func (pd *PgsqlMetaData) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) {
func (pd *PgsqlMetadata) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) {
// 1.获取表信息
tbs, err := pd.GetTables(tableName)
tableInfo := &dbi.Table{}
@@ -356,7 +191,8 @@ func (pd *PgsqlMetaData) GetTableDDL(tableName string, dropBeforeCreate bool) (s
logx.Errorf("获取列信息失败, %s", tableName)
return "", err
}
tableDDLArr := pd.GenerateTableDDL(columns, *tableInfo, dropBeforeCreate)
dialect := pd.dc.GetDialect()
tableDDLArr := dialect.GenerateTableDDL(columns, *tableInfo, dropBeforeCreate)
// 3.获取索引信息
indexs, err := pd.GetTableIndex(tableName)
if err != nil {
@@ -364,12 +200,12 @@ func (pd *PgsqlMetaData) GetTableDDL(tableName string, dropBeforeCreate bool) (s
return "", err
}
// 组装返回
tableDDLArr = append(tableDDLArr, pd.GenerateIndexDDL(indexs, *tableInfo)...)
tableDDLArr = append(tableDDLArr, dialect.GenerateIndexDDL(indexs, *tableInfo)...)
return strings.Join(tableDDLArr, ";\n"), nil
}
// 获取pgsql当前连接的库可访问的schemaNames
func (pd *PgsqlMetaData) GetSchemas() ([]string, error) {
func (pd *PgsqlMetadata) GetSchemas() ([]string, error) {
sql := dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_DB_SCHEMAS)
_, res, err := pd.dc.Query(sql)
if err != nil {
@@ -382,7 +218,7 @@ func (pd *PgsqlMetaData) GetSchemas() ([]string, error) {
return schemaNames, nil
}
func (pd *PgsqlMetaData) DefaultDb() string {
func (pd *PgsqlMetadata) GetDefaultDb() string {
switch pd.dc.Info.Type {
case dbi.DbTypePostgres, dbi.DbTypeGauss:
return "postgres"
@@ -394,28 +230,3 @@ func (pd *PgsqlMetaData) DefaultDb() string {
return ""
}
}
func (pd *PgsqlMetaData) AfterDumpInsert(writer io.Writer, tableName string, columns []dbi.Column) {
// 设置自增序列当前值
for _, column := range columns {
if column.IsIdentity {
seq := fmt.Sprintf("SELECT setval('%s_%s_seq', (SELECT max(%s) FROM %s));\n", tableName, column.ColumnName, column.ColumnName, tableName)
writer.Write([]byte(seq))
}
}
writer.Write([]byte("COMMIT;\n"))
}
func (pd *PgsqlMetaData) GetDataHelper() dbi.DataHelper {
return new(DataHelper)
}
func (pd *PgsqlMetaData) GetColumnHelper() dbi.ColumnHelper {
return new(ColumnHelper)
}
func (pd *PgsqlMetaData) GetDumpHelper() dbi.DumpHelper {
return new(DumpHelper)
}