Merge pull request #72 from kanzihuang/refactor-dbtype

refactor: 实现 DbType 类型,集中处理部分差异化的数据库操作
This commit is contained in:
may-fly
2023-10-15 19:41:19 -05:00
committed by GitHub
8 changed files with 127 additions and 61 deletions

View File

@@ -4,7 +4,6 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/kanzihuang/vitess/go/vt/sqlparser"
"github.com/lib/pq"
"io"
"mayfly-go/internal/db/api/form"
"mayfly-go/internal/db/api/vo"
@@ -326,24 +325,6 @@ func (d *Db) DumpSql(rc *req.Ctx) {
rc.ReqParam = collx.Kvs("db", db, "databases", dbNamesStr, "tables", tablesStr, "dumpType", dumpType)
}
func escapeSql(dbType string, sql string) string {
if dbType == entity.DbTypePostgres {
return pq.QuoteLiteral(sql)
} else {
sql = strings.ReplaceAll(sql, `\`, `\\`)
sql = strings.ReplaceAll(sql, `'`, `''`)
return "'" + sql + "'"
}
}
func quoteTable(dbType string, table string) string {
if dbType == entity.DbTypePostgres {
return "\"" + table + "\""
} else {
return "`" + table + "`"
}
}
func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []string, needStruct bool, needData bool, switchDb bool) {
dbConn := d.DbApp.GetDbConnection(dbId, dbName)
writer.WriteString("\n-- ----------------------------")
@@ -355,14 +336,12 @@ func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []str
if switchDb {
switch dbConn.Info.Type {
case entity.DbTypeMysql:
writer.WriteString(fmt.Sprintf("USE `%s`;\n", dbName))
writer.WriteString(fmt.Sprintf("USE %s;\n", entity.DbTypeMysql.QuoteIdentifier(dbName)))
default:
biz.IsTrue(false, "同时导出多个数据库,数据库类型必须为 %s", entity.DbTypeMysql)
}
}
if dbConn.Info.Type == entity.DbTypeMysql {
writer.WriteString("\nSET FOREIGN_KEY_CHECKS = 0;\n")
}
writer.WriteString(dbConn.Info.Type.StmtSetForeignKeyChecks(false))
dbMeta := dbConn.GetMeta()
if len(tables) == 0 {
@@ -375,7 +354,7 @@ func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []str
for _, table := range tables {
writer.TryFlush()
quotedTable := quoteTable(dbConn.Info.Type, table)
quotedTable := dbConn.Info.Type.QuoteIdentifier(table)
if needStruct {
writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- 表结构: %s \n-- ----------------------------\n", table))
writer.WriteString(fmt.Sprintf("DROP TABLE IF EXISTS %s;\n", quotedTable))
@@ -398,7 +377,7 @@ func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []str
}
strValue, ok := value.(string)
if ok {
strValue = escapeSql(dbConn.Info.Type, strValue)
strValue = dbConn.Info.Type.QuoteLiteral(strValue)
values = append(values, strValue)
} else {
values = append(values, stringx.AnyToStr(value))
@@ -408,9 +387,7 @@ func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []str
})
writer.WriteString("COMMIT;\n")
}
if dbConn.Info.Type == entity.DbTypeMysql {
writer.WriteString("\nSET FOREIGN_KEY_CHECKS = 1;\n")
}
writer.WriteString(dbConn.Info.Type.StmtSetForeignKeyChecks(true))
}
// @router /api/db/:dbId/t-metadata [get]

View File

@@ -9,46 +9,46 @@ import (
func Test_escapeSql(t *testing.T) {
tests := []struct {
name string
dbType string
dbType entity.DBType
sql string
want string
}{
{
dbType: entity.DbTypeMysql,
dbType: entity.DBTypeMysql{},
sql: "\\a\\b",
want: "'\\\\a\\\\b'",
},
{
dbType: entity.DbTypeMysql,
dbType: entity.DBTypeMysql{},
sql: "'a'",
want: "'''a'''",
},
{
name: "不间断空格",
dbType: entity.DbTypeMysql,
dbType: entity.DBTypeMysql{},
sql: "a\u00A0b",
want: "'a\u00A0b'",
},
{
dbType: entity.DbTypePostgres,
dbType: entity.DBTypePostgres{},
sql: "\\a\\b",
want: " E'\\\\a\\\\b'",
},
{
dbType: entity.DbTypePostgres,
dbType: entity.DBTypePostgres{},
sql: "'a'",
want: "'''a'''",
},
{
name: "不间断空格",
dbType: entity.DbTypePostgres,
dbType: entity.DBTypePostgres{},
sql: "a\u00A0b",
want: "'a\u00A0b'",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := escapeSql(tt.dbType, tt.sql)
got := tt.dbType.QuoteLiteral(tt.sql)
require.Equal(t, tt.want, got)
})
}

View File

@@ -144,6 +144,7 @@ func (d *dbAppImpl) GetDbConnection(dbId uint64, dbName string) *DbConnection {
biz.IsTrue(strings.Contains(" "+db.Database+" ", " "+dbName+" "), "未配置数据库【%s】的操作权限", dbName)
instance := d.dbInstanceApp.GetById(db.InstanceId)
biz.NotNil(instance, "数据库实例不存在")
// 密码解密
instance.PwdDecrypt()
@@ -178,7 +179,7 @@ func (d *dbAppImpl) GetDbConnection(dbId uint64, dbName string) *DbConnection {
type DbInfo struct {
Id uint64
Name string
Type string // 类型mysql oracle等
Type entity.DbType // 类型mysql oracle等
Host string
Port int
Network string
@@ -242,14 +243,14 @@ func (d *DbConnection) Exec(sql string) (int64, error) {
// 获取数据库元信息实现接口
func (d *DbConnection) GetMeta() DbMetadata {
dbType := d.Info.Type
if dbType == entity.DbTypeMysql {
switch d.Info.Type {
case entity.DbTypeMysql:
return &MysqlMetadata{di: d}
}
if dbType == entity.DbTypePostgres {
case entity.DbTypePostgres:
return &PgsqlMetadata{di: d}
default:
panic(fmt.Sprintf("invalid database type: %s", d.Info.Type))
}
return nil
}
// 关闭连接

View File

@@ -2,6 +2,7 @@ package application
import (
"database/sql"
"fmt"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
"mayfly-go/pkg/biz"
@@ -97,10 +98,13 @@ func (app *instanceAppImpl) Delete(id uint64) {
func getInstanceConn(instance *entity.Instance, db string) (*sql.DB, error) {
var conn *sql.DB
var err error
if instance.Type == entity.DbTypeMysql {
switch instance.Type {
case entity.DbTypeMysql:
conn, err = getMysqlDB(instance, db)
} else if instance.Type == entity.DbTypePostgres {
case entity.DbTypePostgres:
conn, err = getPgsqlDB(instance, db)
default:
panic(fmt.Sprintf("invalid database type: %s", instance.Type))
}
if err != nil {
@@ -126,15 +130,8 @@ func (app *instanceAppImpl) GetDatabases(ed *entity.Instance) []string {
ed.Network = ed.GetNetwork()
databases := make([]string, 0)
var dbConn *sql.DB
var metaDb string
var getDatabasesSql string
if ed.Type == entity.DbTypeMysql {
metaDb = "information_schema"
getDatabasesSql = "SELECT SCHEMA_NAME AS dbname FROM SCHEMATA"
} else {
metaDb = "postgres"
getDatabasesSql = "SELECT datname AS dbname FROM pg_database"
}
metaDb := ed.Type.MetaDbName()
getDatabasesSql := ed.Type.StmtSelectDbName()
dbConn, err := getInstanceConn(ed, metaDb)
biz.ErrIsNilAppendErr(err, "数据库连接失败: %s")

View File

@@ -26,7 +26,7 @@ func getMysqlDB(d *entity.Instance, db string) (*sql.DB, error) {
if d.Params != "" {
dsn = fmt.Sprintf("%s&%s", dsn, d.Params)
}
return sql.Open(d.Type, dsn)
return sql.Open(string(d.Type), dsn)
}
// ---------------------------------- mysql元数据 -----------------------------------

View File

@@ -17,7 +17,7 @@ import (
)
func getPgsqlDB(d *entity.Instance, db string) (*sql.DB, error) {
driverName := d.Type
driverName := string(d.Type)
// SSH Conect
if d.SshTunnelMachineId > 0 {
// 如果使用了隧道,则使用`postgres:ssh:隧道机器id`注册名

View File

@@ -0,0 +1,96 @@
package entity
import (
"fmt"
"github.com/lib/pq"
"strings"
)
type DbType string
const (
DbTypeMysql DbType = "mysql"
DbTypePostgres DbType = "postgres"
)
func (dbType DbType) MetaDbName() string {
switch dbType {
case DbTypeMysql:
return "information_schema"
case DbTypePostgres:
return "postgres"
default:
panic(fmt.Sprintf("invalid database type: %s", dbType))
}
}
func (dbType DbType) QuoteIdentifier(name string) string {
switch dbType {
case DbTypeMysql:
return quoteIdentifier(name, "`")
case DbTypePostgres:
return pq.QuoteIdentifier(name)
default:
panic(fmt.Sprintf("invalid database type: %s", dbType))
}
}
func (dbType DbType) QuoteLiteral(literal string) string {
switch dbType {
case DbTypeMysql:
literal = strings.ReplaceAll(literal, `\`, `\\`)
literal = strings.ReplaceAll(literal, `'`, `''`)
return "'" + literal + "'"
case DbTypePostgres:
return pq.QuoteLiteral(literal)
default:
panic(fmt.Sprintf("invalid database type: %s", dbType))
}
}
func (dbType DbType) StmtSelectDbName() string {
switch dbType {
case DbTypeMysql:
return "SELECT SCHEMA_NAME AS dbname FROM SCHEMATA"
case DbTypePostgres:
return "SELECT datname AS dbname FROM pg_database"
default:
panic(fmt.Sprintf("invalid database type: %s", dbType))
}
}
// QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be
// used as part of an SQL statement. For example:
//
// tblname := "my_table"
// data := "my_data"
// quoted := pq.QuoteIdentifier(tblname)
// err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data)
//
// Any double quotes in name will be escaped. The quoted identifier will be
// case sensitive when used in a query. If the input string contains a zero
// byte, the result will be truncated immediately before it.
func quoteIdentifier(name, quoter string) string {
end := strings.IndexRune(name, 0)
if end > -1 {
name = name[:end]
}
return quoter + strings.Replace(name, quoter, quoter+quoter, -1) + quoter
}
func (dbType DbType) StmtSetForeignKeyChecks(check bool) string {
switch dbType {
case DbTypeMysql:
if check {
return "\nSET FOREIGN_KEY_CHECKS = 1;\n"
} else {
return "\nSET FOREIGN_KEY_CHECKS = 0;\n"
}
case DbTypePostgres:
// not currently supported postgres
return ""
default:
panic(fmt.Sprintf("invalid database type: %s", dbType))
}
}

View File

@@ -10,7 +10,7 @@ type Instance struct {
model.Model
Name string `orm:"column(name)" json:"name"`
Type string `orm:"column(type)" json:"type"` // 类型mysql oracle等
Type DbType `orm:"column(type)" json:"type"` // 类型mysql oracle等
Host string `orm:"column(host)" json:"host"`
Port int `orm:"column(port)" json:"port"`
Network string `orm:"column(network)" json:"network"`
@@ -47,8 +47,3 @@ func (d *Instance) PwdDecrypt() {
// 密码替换为解密后的密码
d.Password = utils.PwdAesDecrypt(d.Password)
}
const (
DbTypeMysql = "mysql"
DbTypePostgres = "postgres"
)