mirror of
https://gitee.com/dromara/mayfly-go
synced 2025-11-02 15:30:25 +08:00
Merge pull request #72 from kanzihuang/refactor-dbtype
refactor: 实现 DbType 类型,集中处理部分差异化的数据库操作
This commit is contained in:
@@ -4,7 +4,6 @@ import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/kanzihuang/vitess/go/vt/sqlparser"
|
||||
"github.com/lib/pq"
|
||||
"io"
|
||||
"mayfly-go/internal/db/api/form"
|
||||
"mayfly-go/internal/db/api/vo"
|
||||
@@ -326,24 +325,6 @@ func (d *Db) DumpSql(rc *req.Ctx) {
|
||||
rc.ReqParam = collx.Kvs("db", db, "databases", dbNamesStr, "tables", tablesStr, "dumpType", dumpType)
|
||||
}
|
||||
|
||||
func escapeSql(dbType string, sql string) string {
|
||||
if dbType == entity.DbTypePostgres {
|
||||
return pq.QuoteLiteral(sql)
|
||||
} else {
|
||||
sql = strings.ReplaceAll(sql, `\`, `\\`)
|
||||
sql = strings.ReplaceAll(sql, `'`, `''`)
|
||||
return "'" + sql + "'"
|
||||
}
|
||||
}
|
||||
|
||||
func quoteTable(dbType string, table string) string {
|
||||
if dbType == entity.DbTypePostgres {
|
||||
return "\"" + table + "\""
|
||||
} else {
|
||||
return "`" + table + "`"
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []string, needStruct bool, needData bool, switchDb bool) {
|
||||
dbConn := d.DbApp.GetDbConnection(dbId, dbName)
|
||||
writer.WriteString("\n-- ----------------------------")
|
||||
@@ -355,14 +336,12 @@ func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []str
|
||||
if switchDb {
|
||||
switch dbConn.Info.Type {
|
||||
case entity.DbTypeMysql:
|
||||
writer.WriteString(fmt.Sprintf("USE `%s`;\n", dbName))
|
||||
writer.WriteString(fmt.Sprintf("USE %s;\n", entity.DbTypeMysql.QuoteIdentifier(dbName)))
|
||||
default:
|
||||
biz.IsTrue(false, "同时导出多个数据库,数据库类型必须为 %s", entity.DbTypeMysql)
|
||||
}
|
||||
}
|
||||
if dbConn.Info.Type == entity.DbTypeMysql {
|
||||
writer.WriteString("\nSET FOREIGN_KEY_CHECKS = 0;\n")
|
||||
}
|
||||
writer.WriteString(dbConn.Info.Type.StmtSetForeignKeyChecks(false))
|
||||
|
||||
dbMeta := dbConn.GetMeta()
|
||||
if len(tables) == 0 {
|
||||
@@ -375,7 +354,7 @@ func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []str
|
||||
|
||||
for _, table := range tables {
|
||||
writer.TryFlush()
|
||||
quotedTable := quoteTable(dbConn.Info.Type, table)
|
||||
quotedTable := dbConn.Info.Type.QuoteIdentifier(table)
|
||||
if needStruct {
|
||||
writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- 表结构: %s \n-- ----------------------------\n", table))
|
||||
writer.WriteString(fmt.Sprintf("DROP TABLE IF EXISTS %s;\n", quotedTable))
|
||||
@@ -398,7 +377,7 @@ func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []str
|
||||
}
|
||||
strValue, ok := value.(string)
|
||||
if ok {
|
||||
strValue = escapeSql(dbConn.Info.Type, strValue)
|
||||
strValue = dbConn.Info.Type.QuoteLiteral(strValue)
|
||||
values = append(values, strValue)
|
||||
} else {
|
||||
values = append(values, stringx.AnyToStr(value))
|
||||
@@ -408,9 +387,7 @@ func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []str
|
||||
})
|
||||
writer.WriteString("COMMIT;\n")
|
||||
}
|
||||
if dbConn.Info.Type == entity.DbTypeMysql {
|
||||
writer.WriteString("\nSET FOREIGN_KEY_CHECKS = 1;\n")
|
||||
}
|
||||
writer.WriteString(dbConn.Info.Type.StmtSetForeignKeyChecks(true))
|
||||
}
|
||||
|
||||
// @router /api/db/:dbId/t-metadata [get]
|
||||
|
||||
@@ -9,46 +9,46 @@ import (
|
||||
func Test_escapeSql(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
dbType string
|
||||
dbType entity.DBType
|
||||
sql string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
dbType: entity.DbTypeMysql,
|
||||
dbType: entity.DBTypeMysql{},
|
||||
sql: "\\a\\b",
|
||||
want: "'\\\\a\\\\b'",
|
||||
},
|
||||
{
|
||||
dbType: entity.DbTypeMysql,
|
||||
dbType: entity.DBTypeMysql{},
|
||||
sql: "'a'",
|
||||
want: "'''a'''",
|
||||
},
|
||||
{
|
||||
name: "不间断空格",
|
||||
dbType: entity.DbTypeMysql,
|
||||
dbType: entity.DBTypeMysql{},
|
||||
sql: "a\u00A0b",
|
||||
want: "'a\u00A0b'",
|
||||
},
|
||||
{
|
||||
dbType: entity.DbTypePostgres,
|
||||
dbType: entity.DBTypePostgres{},
|
||||
sql: "\\a\\b",
|
||||
want: " E'\\\\a\\\\b'",
|
||||
},
|
||||
{
|
||||
dbType: entity.DbTypePostgres,
|
||||
dbType: entity.DBTypePostgres{},
|
||||
sql: "'a'",
|
||||
want: "'''a'''",
|
||||
},
|
||||
{
|
||||
name: "不间断空格",
|
||||
dbType: entity.DbTypePostgres,
|
||||
dbType: entity.DBTypePostgres{},
|
||||
sql: "a\u00A0b",
|
||||
want: "'a\u00A0b'",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := escapeSql(tt.dbType, tt.sql)
|
||||
got := tt.dbType.QuoteLiteral(tt.sql)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -144,6 +144,7 @@ func (d *dbAppImpl) GetDbConnection(dbId uint64, dbName string) *DbConnection {
|
||||
biz.IsTrue(strings.Contains(" "+db.Database+" ", " "+dbName+" "), "未配置数据库【%s】的操作权限", dbName)
|
||||
|
||||
instance := d.dbInstanceApp.GetById(db.InstanceId)
|
||||
biz.NotNil(instance, "数据库实例不存在")
|
||||
// 密码解密
|
||||
instance.PwdDecrypt()
|
||||
|
||||
@@ -178,7 +179,7 @@ func (d *dbAppImpl) GetDbConnection(dbId uint64, dbName string) *DbConnection {
|
||||
type DbInfo struct {
|
||||
Id uint64
|
||||
Name string
|
||||
Type string // 类型,mysql oracle等
|
||||
Type entity.DbType // 类型,mysql oracle等
|
||||
Host string
|
||||
Port int
|
||||
Network string
|
||||
@@ -242,14 +243,14 @@ func (d *DbConnection) Exec(sql string) (int64, error) {
|
||||
|
||||
// 获取数据库元信息实现接口
|
||||
func (d *DbConnection) GetMeta() DbMetadata {
|
||||
dbType := d.Info.Type
|
||||
if dbType == entity.DbTypeMysql {
|
||||
switch d.Info.Type {
|
||||
case entity.DbTypeMysql:
|
||||
return &MysqlMetadata{di: d}
|
||||
}
|
||||
if dbType == entity.DbTypePostgres {
|
||||
case entity.DbTypePostgres:
|
||||
return &PgsqlMetadata{di: d}
|
||||
default:
|
||||
panic(fmt.Sprintf("invalid database type: %s", d.Info.Type))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 关闭连接
|
||||
|
||||
@@ -2,6 +2,7 @@ package application
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/domain/entity"
|
||||
"mayfly-go/internal/db/domain/repository"
|
||||
"mayfly-go/pkg/biz"
|
||||
@@ -97,10 +98,13 @@ func (app *instanceAppImpl) Delete(id uint64) {
|
||||
func getInstanceConn(instance *entity.Instance, db string) (*sql.DB, error) {
|
||||
var conn *sql.DB
|
||||
var err error
|
||||
if instance.Type == entity.DbTypeMysql {
|
||||
switch instance.Type {
|
||||
case entity.DbTypeMysql:
|
||||
conn, err = getMysqlDB(instance, db)
|
||||
} else if instance.Type == entity.DbTypePostgres {
|
||||
case entity.DbTypePostgres:
|
||||
conn, err = getPgsqlDB(instance, db)
|
||||
default:
|
||||
panic(fmt.Sprintf("invalid database type: %s", instance.Type))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -126,15 +130,8 @@ func (app *instanceAppImpl) GetDatabases(ed *entity.Instance) []string {
|
||||
ed.Network = ed.GetNetwork()
|
||||
databases := make([]string, 0)
|
||||
var dbConn *sql.DB
|
||||
var metaDb string
|
||||
var getDatabasesSql string
|
||||
if ed.Type == entity.DbTypeMysql {
|
||||
metaDb = "information_schema"
|
||||
getDatabasesSql = "SELECT SCHEMA_NAME AS dbname FROM SCHEMATA"
|
||||
} else {
|
||||
metaDb = "postgres"
|
||||
getDatabasesSql = "SELECT datname AS dbname FROM pg_database"
|
||||
}
|
||||
metaDb := ed.Type.MetaDbName()
|
||||
getDatabasesSql := ed.Type.StmtSelectDbName()
|
||||
|
||||
dbConn, err := getInstanceConn(ed, metaDb)
|
||||
biz.ErrIsNilAppendErr(err, "数据库连接失败: %s")
|
||||
|
||||
@@ -26,7 +26,7 @@ func getMysqlDB(d *entity.Instance, db string) (*sql.DB, error) {
|
||||
if d.Params != "" {
|
||||
dsn = fmt.Sprintf("%s&%s", dsn, d.Params)
|
||||
}
|
||||
return sql.Open(d.Type, dsn)
|
||||
return sql.Open(string(d.Type), dsn)
|
||||
}
|
||||
|
||||
// ---------------------------------- mysql元数据 -----------------------------------
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
)
|
||||
|
||||
func getPgsqlDB(d *entity.Instance, db string) (*sql.DB, error) {
|
||||
driverName := d.Type
|
||||
driverName := string(d.Type)
|
||||
// SSH Conect
|
||||
if d.SshTunnelMachineId > 0 {
|
||||
// 如果使用了隧道,则使用`postgres:ssh:隧道机器id`注册名
|
||||
|
||||
96
server/internal/db/domain/entity/db_type.go
Normal file
96
server/internal/db/domain/entity/db_type.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package entity
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/lib/pq"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type DbType string
|
||||
|
||||
const (
|
||||
DbTypeMysql DbType = "mysql"
|
||||
DbTypePostgres DbType = "postgres"
|
||||
)
|
||||
|
||||
func (dbType DbType) MetaDbName() string {
|
||||
switch dbType {
|
||||
case DbTypeMysql:
|
||||
return "information_schema"
|
||||
case DbTypePostgres:
|
||||
return "postgres"
|
||||
default:
|
||||
panic(fmt.Sprintf("invalid database type: %s", dbType))
|
||||
}
|
||||
}
|
||||
|
||||
func (dbType DbType) QuoteIdentifier(name string) string {
|
||||
switch dbType {
|
||||
case DbTypeMysql:
|
||||
return quoteIdentifier(name, "`")
|
||||
case DbTypePostgres:
|
||||
return pq.QuoteIdentifier(name)
|
||||
default:
|
||||
panic(fmt.Sprintf("invalid database type: %s", dbType))
|
||||
}
|
||||
}
|
||||
|
||||
func (dbType DbType) QuoteLiteral(literal string) string {
|
||||
switch dbType {
|
||||
case DbTypeMysql:
|
||||
literal = strings.ReplaceAll(literal, `\`, `\\`)
|
||||
literal = strings.ReplaceAll(literal, `'`, `''`)
|
||||
return "'" + literal + "'"
|
||||
case DbTypePostgres:
|
||||
return pq.QuoteLiteral(literal)
|
||||
default:
|
||||
panic(fmt.Sprintf("invalid database type: %s", dbType))
|
||||
}
|
||||
}
|
||||
|
||||
func (dbType DbType) StmtSelectDbName() string {
|
||||
switch dbType {
|
||||
case DbTypeMysql:
|
||||
return "SELECT SCHEMA_NAME AS dbname FROM SCHEMATA"
|
||||
case DbTypePostgres:
|
||||
return "SELECT datname AS dbname FROM pg_database"
|
||||
default:
|
||||
panic(fmt.Sprintf("invalid database type: %s", dbType))
|
||||
}
|
||||
}
|
||||
|
||||
// QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be
|
||||
// used as part of an SQL statement. For example:
|
||||
//
|
||||
// tblname := "my_table"
|
||||
// data := "my_data"
|
||||
// quoted := pq.QuoteIdentifier(tblname)
|
||||
// err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data)
|
||||
//
|
||||
// Any double quotes in name will be escaped. The quoted identifier will be
|
||||
// case sensitive when used in a query. If the input string contains a zero
|
||||
// byte, the result will be truncated immediately before it.
|
||||
func quoteIdentifier(name, quoter string) string {
|
||||
end := strings.IndexRune(name, 0)
|
||||
if end > -1 {
|
||||
name = name[:end]
|
||||
}
|
||||
return quoter + strings.Replace(name, quoter, quoter+quoter, -1) + quoter
|
||||
}
|
||||
|
||||
func (dbType DbType) StmtSetForeignKeyChecks(check bool) string {
|
||||
switch dbType {
|
||||
case DbTypeMysql:
|
||||
if check {
|
||||
return "\nSET FOREIGN_KEY_CHECKS = 1;\n"
|
||||
} else {
|
||||
return "\nSET FOREIGN_KEY_CHECKS = 0;\n"
|
||||
}
|
||||
case DbTypePostgres:
|
||||
// not currently supported postgres
|
||||
return ""
|
||||
default:
|
||||
panic(fmt.Sprintf("invalid database type: %s", dbType))
|
||||
}
|
||||
|
||||
}
|
||||
@@ -10,7 +10,7 @@ type Instance struct {
|
||||
model.Model
|
||||
|
||||
Name string `orm:"column(name)" json:"name"`
|
||||
Type string `orm:"column(type)" json:"type"` // 类型,mysql oracle等
|
||||
Type DbType `orm:"column(type)" json:"type"` // 类型,mysql oracle等
|
||||
Host string `orm:"column(host)" json:"host"`
|
||||
Port int `orm:"column(port)" json:"port"`
|
||||
Network string `orm:"column(network)" json:"network"`
|
||||
@@ -47,8 +47,3 @@ func (d *Instance) PwdDecrypt() {
|
||||
// 密码替换为解密后的密码
|
||||
d.Password = utils.PwdAesDecrypt(d.Password)
|
||||
}
|
||||
|
||||
const (
|
||||
DbTypeMysql = "mysql"
|
||||
DbTypePostgres = "postgres"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user