diff --git a/server/internal/db/api/db.go b/server/internal/db/api/db.go index 4b7732f7..0947edd6 100644 --- a/server/internal/db/api/db.go +++ b/server/internal/db/api/db.go @@ -4,7 +4,6 @@ import ( "fmt" "github.com/gin-gonic/gin" "github.com/kanzihuang/vitess/go/vt/sqlparser" - "github.com/lib/pq" "io" "mayfly-go/internal/db/api/form" "mayfly-go/internal/db/api/vo" @@ -326,24 +325,6 @@ func (d *Db) DumpSql(rc *req.Ctx) { rc.ReqParam = collx.Kvs("db", db, "databases", dbNamesStr, "tables", tablesStr, "dumpType", dumpType) } -func escapeSql(dbType string, sql string) string { - if dbType == entity.DbTypePostgres { - return pq.QuoteLiteral(sql) - } else { - sql = strings.ReplaceAll(sql, `\`, `\\`) - sql = strings.ReplaceAll(sql, `'`, `''`) - return "'" + sql + "'" - } -} - -func quoteTable(dbType string, table string) string { - if dbType == entity.DbTypePostgres { - return "\"" + table + "\"" - } else { - return "`" + table + "`" - } -} - func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []string, needStruct bool, needData bool, switchDb bool) { dbConn := d.DbApp.GetDbConnection(dbId, dbName) writer.WriteString("\n-- ----------------------------") @@ -355,14 +336,12 @@ func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []str if switchDb { switch dbConn.Info.Type { case entity.DbTypeMysql: - writer.WriteString(fmt.Sprintf("USE `%s`;\n", dbName)) + writer.WriteString(fmt.Sprintf("USE %s;\n", entity.DbTypeMysql.QuoteIdentifier(dbName))) default: biz.IsTrue(false, "同时导出多个数据库,数据库类型必须为 %s", entity.DbTypeMysql) } } - if dbConn.Info.Type == entity.DbTypeMysql { - writer.WriteString("\nSET FOREIGN_KEY_CHECKS = 0;\n") - } + writer.WriteString(dbConn.Info.Type.StmtSetForeignKeyChecks(false)) dbMeta := dbConn.GetMeta() if len(tables) == 0 { @@ -375,7 +354,7 @@ func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []str for _, table := range tables { writer.TryFlush() - quotedTable := quoteTable(dbConn.Info.Type, table) + quotedTable := dbConn.Info.Type.QuoteIdentifier(table) if needStruct { writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- 表结构: %s \n-- ----------------------------\n", table)) writer.WriteString(fmt.Sprintf("DROP TABLE IF EXISTS %s;\n", quotedTable)) @@ -398,7 +377,7 @@ func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []str } strValue, ok := value.(string) if ok { - strValue = escapeSql(dbConn.Info.Type, strValue) + strValue = dbConn.Info.Type.QuoteLiteral(strValue) values = append(values, strValue) } else { values = append(values, stringx.AnyToStr(value)) @@ -408,9 +387,7 @@ func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []str }) writer.WriteString("COMMIT;\n") } - if dbConn.Info.Type == entity.DbTypeMysql { - writer.WriteString("\nSET FOREIGN_KEY_CHECKS = 1;\n") - } + writer.WriteString(dbConn.Info.Type.StmtSetForeignKeyChecks(true)) } // @router /api/db/:dbId/t-metadata [get] diff --git a/server/internal/db/api/db_test.go b/server/internal/db/api/db_test.go index ec6b9480..30b30ee6 100644 --- a/server/internal/db/api/db_test.go +++ b/server/internal/db/api/db_test.go @@ -9,46 +9,46 @@ import ( func Test_escapeSql(t *testing.T) { tests := []struct { name string - dbType string + dbType entity.DBType sql string want string }{ { - dbType: entity.DbTypeMysql, + dbType: entity.DBTypeMysql{}, sql: "\\a\\b", want: "'\\\\a\\\\b'", }, { - dbType: entity.DbTypeMysql, + dbType: entity.DBTypeMysql{}, sql: "'a'", want: "'''a'''", }, { name: "不间断空格", - dbType: entity.DbTypeMysql, + dbType: entity.DBTypeMysql{}, sql: "a\u00A0b", want: "'a\u00A0b'", }, { - dbType: entity.DbTypePostgres, + dbType: entity.DBTypePostgres{}, sql: "\\a\\b", want: " E'\\\\a\\\\b'", }, { - dbType: entity.DbTypePostgres, + dbType: entity.DBTypePostgres{}, sql: "'a'", want: "'''a'''", }, { name: "不间断空格", - dbType: entity.DbTypePostgres, + dbType: entity.DBTypePostgres{}, sql: "a\u00A0b", want: "'a\u00A0b'", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := escapeSql(tt.dbType, tt.sql) + got := tt.dbType.QuoteLiteral(tt.sql) require.Equal(t, tt.want, got) }) } diff --git a/server/internal/db/application/db.go b/server/internal/db/application/db.go index c95a6956..9d549432 100644 --- a/server/internal/db/application/db.go +++ b/server/internal/db/application/db.go @@ -144,6 +144,7 @@ func (d *dbAppImpl) GetDbConnection(dbId uint64, dbName string) *DbConnection { biz.IsTrue(strings.Contains(" "+db.Database+" ", " "+dbName+" "), "未配置数据库【%s】的操作权限", dbName) instance := d.dbInstanceApp.GetById(db.InstanceId) + biz.NotNil(instance, "数据库实例不存在") // 密码解密 instance.PwdDecrypt() @@ -178,7 +179,7 @@ func (d *dbAppImpl) GetDbConnection(dbId uint64, dbName string) *DbConnection { type DbInfo struct { Id uint64 Name string - Type string // 类型,mysql oracle等 + Type entity.DbType // 类型,mysql oracle等 Host string Port int Network string @@ -242,14 +243,14 @@ func (d *DbConnection) Exec(sql string) (int64, error) { // 获取数据库元信息实现接口 func (d *DbConnection) GetMeta() DbMetadata { - dbType := d.Info.Type - if dbType == entity.DbTypeMysql { + switch d.Info.Type { + case entity.DbTypeMysql: return &MysqlMetadata{di: d} - } - if dbType == entity.DbTypePostgres { + case entity.DbTypePostgres: return &PgsqlMetadata{di: d} + default: + panic(fmt.Sprintf("invalid database type: %s", d.Info.Type)) } - return nil } // 关闭连接 diff --git a/server/internal/db/application/instance.go b/server/internal/db/application/instance.go index a3c8c548..6efa06ea 100644 --- a/server/internal/db/application/instance.go +++ b/server/internal/db/application/instance.go @@ -2,6 +2,7 @@ package application import ( "database/sql" + "fmt" "mayfly-go/internal/db/domain/entity" "mayfly-go/internal/db/domain/repository" "mayfly-go/pkg/biz" @@ -97,10 +98,13 @@ func (app *instanceAppImpl) Delete(id uint64) { func getInstanceConn(instance *entity.Instance, db string) (*sql.DB, error) { var conn *sql.DB var err error - if instance.Type == entity.DbTypeMysql { + switch instance.Type { + case entity.DbTypeMysql: conn, err = getMysqlDB(instance, db) - } else if instance.Type == entity.DbTypePostgres { + case entity.DbTypePostgres: conn, err = getPgsqlDB(instance, db) + default: + panic(fmt.Sprintf("invalid database type: %s", instance.Type)) } if err != nil { @@ -126,15 +130,8 @@ func (app *instanceAppImpl) GetDatabases(ed *entity.Instance) []string { ed.Network = ed.GetNetwork() databases := make([]string, 0) var dbConn *sql.DB - var metaDb string - var getDatabasesSql string - if ed.Type == entity.DbTypeMysql { - metaDb = "information_schema" - getDatabasesSql = "SELECT SCHEMA_NAME AS dbname FROM SCHEMATA" - } else { - metaDb = "postgres" - getDatabasesSql = "SELECT datname AS dbname FROM pg_database" - } + metaDb := ed.Type.MetaDbName() + getDatabasesSql := ed.Type.StmtSelectDbName() dbConn, err := getInstanceConn(ed, metaDb) biz.ErrIsNilAppendErr(err, "数据库连接失败: %s") diff --git a/server/internal/db/application/mysql_meta.go b/server/internal/db/application/mysql_meta.go index 22c18693..66cd658b 100644 --- a/server/internal/db/application/mysql_meta.go +++ b/server/internal/db/application/mysql_meta.go @@ -26,7 +26,7 @@ func getMysqlDB(d *entity.Instance, db string) (*sql.DB, error) { if d.Params != "" { dsn = fmt.Sprintf("%s&%s", dsn, d.Params) } - return sql.Open(d.Type, dsn) + return sql.Open(string(d.Type), dsn) } // ---------------------------------- mysql元数据 ----------------------------------- diff --git a/server/internal/db/application/pgsql_meta.go b/server/internal/db/application/pgsql_meta.go index 50056465..6503488e 100644 --- a/server/internal/db/application/pgsql_meta.go +++ b/server/internal/db/application/pgsql_meta.go @@ -17,7 +17,7 @@ import ( ) func getPgsqlDB(d *entity.Instance, db string) (*sql.DB, error) { - driverName := d.Type + driverName := string(d.Type) // SSH Conect if d.SshTunnelMachineId > 0 { // 如果使用了隧道,则使用`postgres:ssh:隧道机器id`注册名 diff --git a/server/internal/db/domain/entity/db_type.go b/server/internal/db/domain/entity/db_type.go new file mode 100644 index 00000000..bc97e0ee --- /dev/null +++ b/server/internal/db/domain/entity/db_type.go @@ -0,0 +1,96 @@ +package entity + +import ( + "fmt" + "github.com/lib/pq" + "strings" +) + +type DbType string + +const ( + DbTypeMysql DbType = "mysql" + DbTypePostgres DbType = "postgres" +) + +func (dbType DbType) MetaDbName() string { + switch dbType { + case DbTypeMysql: + return "information_schema" + case DbTypePostgres: + return "postgres" + default: + panic(fmt.Sprintf("invalid database type: %s", dbType)) + } +} + +func (dbType DbType) QuoteIdentifier(name string) string { + switch dbType { + case DbTypeMysql: + return quoteIdentifier(name, "`") + case DbTypePostgres: + return pq.QuoteIdentifier(name) + default: + panic(fmt.Sprintf("invalid database type: %s", dbType)) + } +} + +func (dbType DbType) QuoteLiteral(literal string) string { + switch dbType { + case DbTypeMysql: + literal = strings.ReplaceAll(literal, `\`, `\\`) + literal = strings.ReplaceAll(literal, `'`, `''`) + return "'" + literal + "'" + case DbTypePostgres: + return pq.QuoteLiteral(literal) + default: + panic(fmt.Sprintf("invalid database type: %s", dbType)) + } +} + +func (dbType DbType) StmtSelectDbName() string { + switch dbType { + case DbTypeMysql: + return "SELECT SCHEMA_NAME AS dbname FROM SCHEMATA" + case DbTypePostgres: + return "SELECT datname AS dbname FROM pg_database" + default: + panic(fmt.Sprintf("invalid database type: %s", dbType)) + } +} + +// QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be +// used as part of an SQL statement. For example: +// +// tblname := "my_table" +// data := "my_data" +// quoted := pq.QuoteIdentifier(tblname) +// err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data) +// +// Any double quotes in name will be escaped. The quoted identifier will be +// case sensitive when used in a query. If the input string contains a zero +// byte, the result will be truncated immediately before it. +func quoteIdentifier(name, quoter string) string { + end := strings.IndexRune(name, 0) + if end > -1 { + name = name[:end] + } + return quoter + strings.Replace(name, quoter, quoter+quoter, -1) + quoter +} + +func (dbType DbType) StmtSetForeignKeyChecks(check bool) string { + switch dbType { + case DbTypeMysql: + if check { + return "\nSET FOREIGN_KEY_CHECKS = 1;\n" + } else { + return "\nSET FOREIGN_KEY_CHECKS = 0;\n" + } + case DbTypePostgres: + // not currently supported postgres + return "" + default: + panic(fmt.Sprintf("invalid database type: %s", dbType)) + } + +} diff --git a/server/internal/db/domain/entity/instance.go b/server/internal/db/domain/entity/instance.go index 2401aaab..ff7c353c 100644 --- a/server/internal/db/domain/entity/instance.go +++ b/server/internal/db/domain/entity/instance.go @@ -10,7 +10,7 @@ type Instance struct { model.Model Name string `orm:"column(name)" json:"name"` - Type string `orm:"column(type)" json:"type"` // 类型,mysql oracle等 + Type DbType `orm:"column(type)" json:"type"` // 类型,mysql oracle等 Host string `orm:"column(host)" json:"host"` Port int `orm:"column(port)" json:"port"` Network string `orm:"column(network)" json:"network"` @@ -47,8 +47,3 @@ func (d *Instance) PwdDecrypt() { // 密码替换为解密后的密码 d.Password = utils.PwdAesDecrypt(d.Password) } - -const ( - DbTypeMysql = "mysql" - DbTypePostgres = "postgres" -)