refactor: 新增base.Repo与base.App,重构repo与app层代码

This commit is contained in:
meilin.huang
2023-10-26 17:15:49 +08:00
parent 10f6b03fb5
commit a1303b52eb
115 changed files with 1867 additions and 1696 deletions

View File

@@ -7,8 +7,9 @@ import (
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
"mayfly-go/internal/machine/infrastructure/machine"
"mayfly-go/pkg/biz"
"mayfly-go/pkg/base"
"mayfly-go/pkg/cache"
"mayfly-go/pkg/errorx"
"mayfly-go/pkg/logx"
"mayfly-go/pkg/model"
"mayfly-go/pkg/utils/collx"
@@ -20,79 +21,71 @@ import (
)
type Db interface {
base.App[*entity.Db]
// 分页获取
GetPageList(condition *entity.DbQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) *model.PageResult[any]
GetPageList(condition *entity.DbQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error)
Count(condition *entity.DbQuery) int64
// 根据条件获取
GetDbBy(condition *entity.Db, cols ...string) error
// 根据id获取
GetById(id uint64, cols ...string) *entity.Db
Save(entity *entity.Db)
Save(entity *entity.Db) error
// 删除数据库信息
Delete(id uint64)
Delete(id uint64) error
// 获取数据库连接实例
// @param id 数据库实例id
// @param db 数据库
GetDbConnection(dbId uint64, dbName string) *DbConnection
GetDbConnection(dbId uint64, dbName string) (*DbConnection, error)
}
func newDbApp(dbRepo repository.Db, dbSqlRepo repository.DbSql, dbInstanceApp Instance) Db {
return &dbAppImpl{
dbRepo: dbRepo,
app := &dbAppImpl{
dbSqlRepo: dbSqlRepo,
dbInstanceApp: dbInstanceApp,
}
app.Repo = dbRepo
return app
}
type dbAppImpl struct {
dbRepo repository.Db
base.AppImpl[*entity.Db, repository.Db]
dbSqlRepo repository.DbSql
dbInstanceApp Instance
}
// 分页获取数据库信息列表
func (d *dbAppImpl) GetPageList(condition *entity.DbQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) *model.PageResult[any] {
return d.dbRepo.GetDbList(condition, pageParam, toEntity, orderBy...)
func (d *dbAppImpl) GetPageList(condition *entity.DbQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) {
return d.GetRepo().GetDbList(condition, pageParam, toEntity, orderBy...)
}
func (d *dbAppImpl) Count(condition *entity.DbQuery) int64 {
return d.dbRepo.Count(condition)
return d.GetRepo().Count(condition)
}
// 根据条件获取
func (d *dbAppImpl) GetDbBy(condition *entity.Db, cols ...string) error {
return d.dbRepo.GetDb(condition, cols...)
}
// 根据id获取
func (d *dbAppImpl) GetById(id uint64, cols ...string) *entity.Db {
return d.dbRepo.GetById(id, cols...)
}
func (d *dbAppImpl) Save(dbEntity *entity.Db) {
func (d *dbAppImpl) Save(dbEntity *entity.Db) error {
// 查找是否存在
oldDb := &entity.Db{Name: dbEntity.Name, InstanceId: dbEntity.InstanceId}
err := d.GetDbBy(oldDb)
err := d.GetBy(oldDb)
if dbEntity.Id == 0 {
biz.IsTrue(err != nil, "该实例下数据库名已存在")
d.dbRepo.Insert(dbEntity)
return
if err == nil {
return errorx.NewBiz("该实例下数据库名已存在")
}
return d.Insert(dbEntity)
}
// 如果存在该库,则校验修改的库是否为该库
if err == nil {
biz.IsTrue(oldDb.Id == dbEntity.Id, "该实例下数据库名已存在")
if err == nil && oldDb.Id != dbEntity.Id {
return errorx.NewBiz("该实例下数据库名已存在")
}
dbId := dbEntity.Id
old := d.GetById(dbId)
old, err := d.GetById(new(entity.Db), dbId)
if err != nil {
return errorx.NewBiz("该数据库不存在")
}
oldDbs := strings.Split(old.Database, " ")
newDbs := strings.Split(dbEntity.Database, " ")
@@ -105,27 +98,30 @@ func (d *dbAppImpl) Save(dbEntity *entity.Db) {
// 关闭数据库连接
CloseDb(dbEntity.Id, v)
// 删除该库关联的所有sql记录
d.dbSqlRepo.DeleteBy(&entity.DbSql{DbId: dbId, Db: v})
d.dbSqlRepo.DeleteByCond(&entity.DbSql{DbId: dbId, Db: v})
}
d.dbRepo.Update(dbEntity)
return d.UpdateById(dbEntity)
}
func (d *dbAppImpl) Delete(id uint64) {
db := d.GetById(id)
func (d *dbAppImpl) Delete(id uint64) error {
db, err := d.GetById(new(entity.Db), id)
if err != nil {
return errorx.NewBiz("该数据库不存在")
}
dbs := strings.Split(db.Database, " ")
for _, v := range dbs {
// 关闭连接
CloseDb(id, v)
}
d.dbRepo.Delete(id)
// 删除该库下用户保存的所有sql信息
d.dbSqlRepo.DeleteBy(&entity.DbSql{DbId: id})
d.dbSqlRepo.DeleteByCond(&entity.DbSql{DbId: id})
return d.DeleteById(id)
}
var mutex sync.Mutex
func (d *dbAppImpl) GetDbConnection(dbId uint64, dbName string) *DbConnection {
func (d *dbAppImpl) GetDbConnection(dbId uint64, dbName string) (*DbConnection, error) {
cacheKey := GetDbCacheKey(dbId, dbName)
// Id不为0则为需要缓存
@@ -133,18 +129,24 @@ func (d *dbAppImpl) GetDbConnection(dbId uint64, dbName string) *DbConnection {
if needCache {
load, ok := dbCache.Get(cacheKey)
if ok {
return load.(*DbConnection)
return load.(*DbConnection), nil
}
}
mutex.Lock()
defer mutex.Unlock()
db := d.GetById(dbId)
biz.NotNil(db, "数据库信息不存在")
biz.IsTrue(strings.Contains(" "+db.Database+" ", " "+dbName+" "), "未配置数据库【%s】的操作权限", dbName)
db, err := d.GetById(new(entity.Db), dbId)
if err != nil {
return nil, errorx.NewBiz("数据库信息不存在")
}
if !strings.Contains(" "+db.Database+" ", " "+dbName+" ") {
return nil, errorx.NewBiz("未配置数据库【%s】的操作权限", dbName)
}
instance := d.dbInstanceApp.GetById(db.InstanceId)
biz.NotNil(instance, "数据库实例不存在")
instance, err := d.dbInstanceApp.GetById(new(entity.Instance), db.InstanceId)
if err != nil {
return nil, errorx.NewBiz("数据库实例不存在")
}
// 密码解密
instance.PwdDecrypt()
@@ -156,7 +158,7 @@ func (d *dbAppImpl) GetDbConnection(dbId uint64, dbName string) *DbConnection {
if err != nil {
dbi.Close()
logx.Errorf("连接db失败: %s:%d/%s", dbInfo.Host, dbInfo.Port, dbName)
panic(biz.NewBizErr(fmt.Sprintf("数据库连接失败: %s", err.Error())))
return nil, errorx.NewBiz(fmt.Sprintf("数据库连接失败: %s", err.Error()))
}
// 最大连接周期超过时间的连接就close
@@ -171,7 +173,7 @@ func (d *dbAppImpl) GetDbConnection(dbId uint64, dbName string) *DbConnection {
if needCache {
dbCache.Put(cacheKey, dbi)
}
return dbi
return dbi, nil
}
//---------------------------------------- db instance ------------------------------------

View File

@@ -6,7 +6,7 @@ import (
"mayfly-go/internal/db/config"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
"mayfly-go/pkg/biz"
"mayfly-go/pkg/errorx"
"mayfly-go/pkg/model"
"strconv"
"strings"
@@ -51,7 +51,7 @@ type DbSqlExec interface {
DeleteBy(condition *entity.DbSqlExec)
// 分页获取
GetPageList(condition *entity.DbSqlExecQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) *model.PageResult[any]
GetPageList(condition *entity.DbSqlExecQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error)
}
func newDbSqlExecApp(dbExecSqlRepo repository.DbSqlExec) DbSqlExec {
@@ -91,7 +91,9 @@ func (d *dbSqlExecAppImpl) Exec(execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error)
// 如果配置为0则不校验分页参数
maxCount := config.GetDbQueryMaxCount()
if maxCount != 0 {
biz.IsTrue(strings.Contains(lowerSql, "limit"), "请完善分页信息后执行")
if !strings.Contains(lowerSql, "limit") {
return nil, errorx.NewBiz("请完善分页信息后执行")
}
}
}
var execErr error
@@ -148,10 +150,10 @@ func (d *dbSqlExecAppImpl) saveSqlExecLog(isQuery bool, dbSqlExecRecord *entity.
}
func (d *dbSqlExecAppImpl) DeleteBy(condition *entity.DbSqlExec) {
d.dbSqlExecRepo.DeleteBy(condition)
d.dbSqlExecRepo.DeleteByCond(condition)
}
func (d *dbSqlExecAppImpl) GetPageList(condition *entity.DbSqlExecQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) *model.PageResult[any] {
func (d *dbSqlExecAppImpl) GetPageList(condition *entity.DbSqlExecQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) {
return d.dbSqlExecRepo.GetPageList(condition, pageParam, toEntity, orderBy...)
}
@@ -163,10 +165,18 @@ func doSelect(selectStmt *sqlparser.Select, execSqlReq *DbSqlExecReq) (*DbSqlExe
maxCount := config.GetDbQueryMaxCount()
if maxCount != 0 {
limit := selectStmt.Limit
biz.NotNil(limit, "请完善分页信息后执行")
if limit == nil {
return nil, errorx.NewBiz("请完善分页信息后执行")
}
count, err := strconv.Atoi(sqlparser.String(limit.Rowcount))
biz.ErrIsNil(err, "分页参数有误")
biz.IsTrue(count <= maxCount, "查询结果集数需小于系统配置的%d条", maxCount)
if err != nil {
return nil, errorx.NewBiz("分页参数有误")
}
if count > maxCount {
return nil, errorx.NewBiz("查询结果集数需小于系统配置的%d条", maxCount)
}
}
}
@@ -193,7 +203,9 @@ func doUpdate(update *sqlparser.Update, execSqlReq *DbSqlExecReq, dbSqlExec *ent
// 可能使用别名,故空格切割
tableName := strings.Split(tableStr, " ")[0]
where := sqlparser.String(update.Where)
biz.IsTrue(len(where) > 0, "SQL[%s]未执行. 请完善 where 条件后再执行", execSqlReq.Sql)
if len(where) == 0 {
return nil, errorx.NewBiz("SQL[%s]未执行. 请完善 where 条件后再执行", execSqlReq.Sql)
}
updateExprs := update.Exprs
updateColumns := make([]string, 0)
@@ -202,7 +214,10 @@ func doUpdate(update *sqlparser.Update, execSqlReq *DbSqlExecReq, dbSqlExec *ent
}
// 获取表主键列名,排除使用别名
primaryKey := dbConn.GetMeta().GetPrimaryKey(tableName)
primaryKey, err := dbConn.GetMeta().GetPrimaryKey(tableName)
if err != nil {
return nil, errorx.NewBiz("获取表主键信息失败")
}
updateColumnsAndPrimaryKey := strings.Join(updateColumns, ",") + "," + primaryKey
// 查询要更新字段数据的旧值,以及主键值
@@ -228,7 +243,9 @@ func doDelete(delete *sqlparser.Delete, execSqlReq *DbSqlExecReq, dbSqlExec *ent
// 可能使用别名,故空格切割
table := strings.Split(tableStr, " ")[0]
where := sqlparser.String(delete.Where)
biz.IsTrue(len(where) > 0, "SQL[%s]未执行. 请完善 where 条件后再执行", execSqlReq.Sql)
if len(where) == 0 {
return nil, errorx.NewBiz("SQL[%s]未执行. 请完善 where 条件后再执行", execSqlReq.Sql)
}
// 查询删除数据
selectSql := fmt.Sprintf("SELECT * FROM %s %s LIMIT 200", tableStr, where)

View File

@@ -5,67 +5,56 @@ import (
"fmt"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
"mayfly-go/pkg/biz"
"mayfly-go/pkg/base"
"mayfly-go/pkg/errorx"
"mayfly-go/pkg/model"
)
type Instance interface {
base.App[*entity.Instance]
// GetPageList 分页获取数据库实例
GetPageList(condition *entity.InstanceQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) *model.PageResult[any]
GetPageList(condition *entity.InstanceQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error)
Count(condition *entity.InstanceQuery) int64
// GetInstanceBy 根据条件获取数据库实例
GetInstanceBy(condition *entity.Instance, cols ...string) error
// GetById 根据id获取数据库实例
GetById(id uint64, cols ...string) *entity.Instance
Save(instanceEntity *entity.Instance)
Save(instanceEntity *entity.Instance) error
// Delete 删除数据库信息
Delete(id uint64)
Delete(id uint64) error
// GetDatabases 获取数据库实例的所有数据库列表
GetDatabases(entity *entity.Instance) []string
GetDatabases(entity *entity.Instance) ([]string, error)
}
func newInstanceApp(InstanceRepo repository.Instance) Instance {
return &instanceAppImpl{
instanceRepo: InstanceRepo,
}
func newInstanceApp(instanceRepo repository.Instance) Instance {
app := new(instanceAppImpl)
app.Repo = instanceRepo
return app
}
type instanceAppImpl struct {
instanceRepo repository.Instance
base.AppImpl[*entity.Instance, repository.Instance]
}
// GetPageList 分页获取数据库实例
func (app *instanceAppImpl) GetPageList(condition *entity.InstanceQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) *model.PageResult[any] {
return app.instanceRepo.GetInstanceList(condition, pageParam, toEntity, orderBy...)
func (app *instanceAppImpl) GetPageList(condition *entity.InstanceQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) {
return app.GetRepo().GetInstanceList(condition, pageParam, toEntity, orderBy...)
}
func (app *instanceAppImpl) Count(condition *entity.InstanceQuery) int64 {
return app.instanceRepo.Count(condition)
return app.CountByCond(condition)
}
// GetInstanceBy 根据条件获取数据库实例
func (app *instanceAppImpl) GetInstanceBy(condition *entity.Instance, cols ...string) error {
return app.instanceRepo.GetInstance(condition, cols...)
}
// GetById 根据id获取数据库实例
func (app *instanceAppImpl) GetById(id uint64, cols ...string) *entity.Instance {
return app.instanceRepo.GetById(id, cols...)
}
func (app *instanceAppImpl) Save(instanceEntity *entity.Instance) {
func (app *instanceAppImpl) Save(instanceEntity *entity.Instance) error {
// 默认tcp连接
instanceEntity.Network = instanceEntity.GetNetwork()
// 测试连接
if instanceEntity.Password != "" {
testConnection(instanceEntity)
if err := testConnection(instanceEntity); err != nil {
return errorx.NewBiz("数据库连接失败: %s", err.Error())
}
}
// 查找是否存在该库
@@ -74,24 +63,28 @@ func (app *instanceAppImpl) Save(instanceEntity *entity.Instance) {
oldInstance.SshTunnelMachineId = instanceEntity.SshTunnelMachineId
}
err := app.GetInstanceBy(oldInstance)
err := app.GetBy(oldInstance)
if instanceEntity.Id == 0 {
biz.NotEmpty(instanceEntity.Password, "密码不能为空")
biz.IsTrue(err != nil, "该数据库实例已存在")
instanceEntity.PwdEncrypt()
app.instanceRepo.Insert(instanceEntity)
} else {
// 如果存在该库,则校验修改的库是否为该库
if instanceEntity.Password == "" {
return errorx.NewBiz("密码不能为空")
}
if err == nil {
biz.IsTrue(oldInstance.Id == instanceEntity.Id, "该数据库实例已存在")
return errorx.NewBiz("该数据库实例已存在")
}
instanceEntity.PwdEncrypt()
app.instanceRepo.Update(instanceEntity)
return app.Insert(instanceEntity)
}
// 如果存在该库,则校验修改的库是否为该库
if err == nil && oldInstance.Id != instanceEntity.Id {
return errorx.NewBiz("该数据库实例已存在")
}
instanceEntity.PwdEncrypt()
return app.UpdateById(instanceEntity)
}
func (app *instanceAppImpl) Delete(id uint64) {
app.instanceRepo.Delete(id)
func (app *instanceAppImpl) Delete(id uint64) error {
return app.DeleteById(id)
}
// getInstanceConn 获取数据库连接数据库实例
@@ -119,14 +112,17 @@ func getInstanceConn(instance *entity.Instance, db string) (*sql.DB, error) {
return conn, nil
}
func testConnection(d *entity.Instance) {
func testConnection(d *entity.Instance) error {
// 不指定数据库名称
conn, err := getInstanceConn(d, "")
biz.ErrIsNilAppendErr(err, "数据库连接失败: %s")
if err != nil {
return err
}
defer conn.Close()
return nil
}
func (app *instanceAppImpl) GetDatabases(ed *entity.Instance) []string {
func (app *instanceAppImpl) GetDatabases(ed *entity.Instance) ([]string, error) {
ed.Network = ed.GetNetwork()
databases := make([]string, 0)
var dbConn *sql.DB
@@ -134,14 +130,18 @@ func (app *instanceAppImpl) GetDatabases(ed *entity.Instance) []string {
getDatabasesSql := ed.Type.StmtSelectDbName()
dbConn, err := getInstanceConn(ed, metaDb)
biz.ErrIsNilAppendErr(err, "数据库连接失败: %s")
if err != nil {
return nil, errorx.NewBiz("数据库连接失败: %s", err.Error())
}
defer dbConn.Close()
_, res, err := selectDataByDb(dbConn, getDatabasesSql)
biz.ErrIsNilAppendErr(err, "获取数据库列表失败")
if err != nil {
return nil, err
}
for _, re := range res {
databases = append(databases, re["dbname"].(string))
}
return databases
return databases, nil
}

View File

@@ -44,22 +44,22 @@ type Index struct {
type DbMetadata interface {
// 获取表基础元信息
GetTables() []Table
GetTables() ([]Table, error)
// 获取指定表名的所有列元信息
GetColumns(tableNames ...string) []Column
GetColumns(tableNames ...string) ([]Column, error)
// 获取表主键字段名,没有主键标识则默认第一个字段
GetPrimaryKey(tablename string) string
GetPrimaryKey(tablename string) (string, error)
// 获取表信息比GetTables获取更详细的表信息
GetTableInfos() []Table
GetTableInfos() ([]Table, error)
// 获取表索引信息
GetTableIndex(tableName string) []Index
GetTableIndex(tableName string) ([]Index, error)
// 获取建表ddl
GetCreateTableDdl(tableName string) string
GetCreateTableDdl(tableName string) (string, error)
// 获取指定表的数据-分页查询
// @return columns: 列字段名result: 结果集error: 错误

View File

@@ -6,7 +6,7 @@ import (
"fmt"
"mayfly-go/internal/db/domain/entity"
machineapp "mayfly-go/internal/machine/application"
"mayfly-go/pkg/biz"
"mayfly-go/pkg/errorx"
"mayfly-go/pkg/utils/anyx"
"net"
@@ -16,7 +16,10 @@ import (
func getMysqlDB(d *entity.Instance, db string) (*sql.DB, error) {
// SSH Conect
if d.SshTunnelMachineId > 0 {
sshTunnelMachine := machineapp.GetMachineApp().GetSshTunnelMachine(d.SshTunnelMachineId)
sshTunnelMachine, err := machineapp.GetMachineApp().GetSshTunnelMachine(d.SshTunnelMachineId)
if err != nil {
return nil, err
}
mysql.RegisterDialContext(d.Network, func(ctx context.Context, addr string) (net.Conn, error) {
return sshTunnelMachine.GetDialConn("tcp", addr)
})
@@ -43,9 +46,11 @@ type MysqlMetadata struct {
}
// 获取表基础元信息, 如表名等
func (mm *MysqlMetadata) GetTables() []Table {
func (mm *MysqlMetadata) GetTables() ([]Table, error) {
_, res, err := mm.di.SelectData(GetLocalSql(MYSQL_META_FILE, MYSQL_TABLE_MA_KEY))
biz.ErrIsNilAppendErr(err, "获取表基本信息失败: %s")
if err != nil {
return nil, err
}
tables := make([]Table, 0)
for _, re := range res {
@@ -54,11 +59,11 @@ func (mm *MysqlMetadata) GetTables() []Table {
TableComment: anyx.ConvString(re["tableComment"]),
})
}
return tables
return tables, nil
}
// 获取列元信息, 如列名等
func (mm *MysqlMetadata) GetColumns(tableNames ...string) []Column {
func (mm *MysqlMetadata) GetColumns(tableNames ...string) ([]Column, error) {
tableName := ""
for i := 0; i < len(tableNames); i++ {
if i != 0 {
@@ -68,7 +73,10 @@ func (mm *MysqlMetadata) GetColumns(tableNames ...string) []Column {
}
_, res, err := mm.di.SelectData(fmt.Sprintf(GetLocalSql(MYSQL_META_FILE, MYSQL_COLUMN_MA_KEY), tableName))
biz.ErrIsNilAppendErr(err, "获取数据库列信息失败: %s")
if err != nil {
return nil, err
}
columns := make([]Column, 0)
for _, re := range res {
columns = append(columns, Column{
@@ -81,26 +89,34 @@ func (mm *MysqlMetadata) GetColumns(tableNames ...string) []Column {
ColumnDefault: anyx.ConvString(re["columnDefault"]),
})
}
return columns
return columns, nil
}
// 获取表主键字段名,不存在主键标识则默认第一个字段
func (mm *MysqlMetadata) GetPrimaryKey(tablename string) string {
columns := mm.GetColumns(tablename)
biz.IsTrue(len(columns) > 0, "[%s] 表不存在", tablename)
func (mm *MysqlMetadata) GetPrimaryKey(tablename string) (string, error) {
columns, err := mm.GetColumns(tablename)
if err != nil {
return "", err
}
if len(columns) == 0 {
return "", errorx.NewBiz("[%s] 表不存在", tablename)
}
for _, v := range columns {
if v.ColumnKey == "PRI" {
return v.ColumnName
return v.ColumnName, nil
}
}
return columns[0].ColumnName
return columns[0].ColumnName, nil
}
// 获取表信息比GetTableMetedatas获取更详细的表信息
func (mm *MysqlMetadata) GetTableInfos() []Table {
func (mm *MysqlMetadata) GetTableInfos() ([]Table, error) {
_, res, err := mm.di.SelectData(GetLocalSql(MYSQL_META_FILE, MYSQL_TABLE_INFO_KEY))
biz.ErrIsNilAppendErr(err, "获取表信息失败: %s")
if err != nil {
return nil, err
}
tables := make([]Table, 0)
for _, re := range res {
@@ -113,13 +129,16 @@ func (mm *MysqlMetadata) GetTableInfos() []Table {
IndexLength: anyx.ConvInt64(re["indexLength"]),
})
}
return tables
return tables, nil
}
// 获取表索引信息
func (mm *MysqlMetadata) GetTableIndex(tableName string) []Index {
func (mm *MysqlMetadata) GetTableIndex(tableName string) ([]Index, error) {
_, res, err := mm.di.SelectData(fmt.Sprintf(GetLocalSql(MYSQL_META_FILE, MYSQL_INDEX_INFO_KEY), tableName))
biz.ErrIsNilAppendErr(err, "获取表索引信息失败: %s")
if err != nil {
return nil, err
}
indexs := make([]Index, 0)
for _, re := range res {
indexs = append(indexs, Index{
@@ -147,14 +166,16 @@ func (mm *MysqlMetadata) GetTableIndex(tableName string) []Index {
result = append(result, v)
}
}
return result
return result, nil
}
// 获取建表ddl
func (mm *MysqlMetadata) GetCreateTableDdl(tableName string) string {
func (mm *MysqlMetadata) GetCreateTableDdl(tableName string) (string, error) {
_, res, err := mm.di.SelectData(fmt.Sprintf("show create table `%s` ", tableName))
biz.ErrIsNilAppendErr(err, "获取表结构失败: %s")
return res[0]["Create Table"].(string) + ";"
if err != nil {
return "", err
}
return res[0]["Create Table"].(string) + ";", nil
}
func (mm *MysqlMetadata) GetTableRecord(tableName string, pageNum, pageSize int) ([]string, []map[string]any, error) {

View File

@@ -6,7 +6,7 @@ import (
"fmt"
"mayfly-go/internal/db/domain/entity"
machineapp "mayfly-go/internal/machine/application"
"mayfly-go/pkg/biz"
"mayfly-go/pkg/errorx"
"mayfly-go/pkg/utils/anyx"
"mayfly-go/pkg/utils/collx"
"net"
@@ -49,7 +49,11 @@ func (d *PqSqlDialer) Open(name string) (driver.Conn, error) {
}
func (pd *PqSqlDialer) Dial(network, address string) (net.Conn, error) {
if sshConn, err := machineapp.GetMachineApp().GetSshTunnelMachine(pd.sshTunnelMachineId).GetDialConn("tcp", address); err == nil {
sshTunnel, err := machineapp.GetMachineApp().GetSshTunnelMachine(pd.sshTunnelMachineId)
if err != nil {
return nil, err
}
if sshConn, err := sshTunnel.GetDialConn("tcp", address); err == nil {
return sshConn, nil
} else {
return nil, err
@@ -75,9 +79,11 @@ type PgsqlMetadata struct {
}
// 获取表基础元信息, 如表名等
func (pm *PgsqlMetadata) GetTables() []Table {
func (pm *PgsqlMetadata) GetTables() ([]Table, error) {
_, res, err := pm.di.SelectData(GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_MA_KEY))
biz.ErrIsNilAppendErr(err, "获取表基本信息失败: %s")
if err != nil {
return nil, err
}
tables := make([]Table, 0)
for _, re := range res {
@@ -86,11 +92,11 @@ func (pm *PgsqlMetadata) GetTables() []Table {
TableComment: anyx.ConvString(re["tableComment"]),
})
}
return tables
return tables, nil
}
// 获取列元信息, 如列名等
func (pm *PgsqlMetadata) GetColumns(tableNames ...string) []Column {
func (pm *PgsqlMetadata) GetColumns(tableNames ...string) ([]Column, error) {
tableName := ""
for i := 0; i < len(tableNames); i++ {
if i != 0 {
@@ -100,7 +106,10 @@ func (pm *PgsqlMetadata) GetColumns(tableNames ...string) []Column {
}
_, res, err := pm.di.SelectData(fmt.Sprintf(GetLocalSql(PGSQL_META_FILE, PGSQL_COLUMN_MA_KEY), tableName))
biz.ErrIsNilAppendErr(err, "获取数据库列信息失败: %s")
if err != nil {
return nil, err
}
columns := make([]Column, 0)
for _, re := range res {
columns = append(columns, Column{
@@ -113,25 +122,32 @@ func (pm *PgsqlMetadata) GetColumns(tableNames ...string) []Column {
ColumnDefault: anyx.ConvString(re["columnDefault"]),
})
}
return columns
return columns, nil
}
func (pm *PgsqlMetadata) GetPrimaryKey(tablename string) string {
columns := pm.GetColumns(tablename)
biz.IsTrue(len(columns) > 0, "[%s] 表不存在", tablename)
func (pm *PgsqlMetadata) GetPrimaryKey(tablename string) (string, error) {
columns, err := pm.GetColumns(tablename)
if err != nil {
return "", err
}
if len(columns) == 0 {
return "", errorx.NewBiz("[%s] 表不存在", tablename)
}
for _, v := range columns {
if v.ColumnKey == "PRI" {
return v.ColumnName
return v.ColumnName, nil
}
}
return columns[0].ColumnName
return columns[0].ColumnName, nil
}
// 获取表信息比GetTables获取更详细的表信息
func (pm *PgsqlMetadata) GetTableInfos() []Table {
func (pm *PgsqlMetadata) GetTableInfos() ([]Table, error) {
_, res, err := pm.di.SelectData(GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_INFO_KEY))
biz.ErrIsNilAppendErr(err, "获取表信息失败: %s")
if err != nil {
return nil, err
}
tables := make([]Table, 0)
for _, re := range res {
@@ -144,13 +160,16 @@ func (pm *PgsqlMetadata) GetTableInfos() []Table {
IndexLength: anyx.ConvInt64(re["indexLength"]),
})
}
return tables
return tables, nil
}
// 获取表索引信息
func (pm *PgsqlMetadata) GetTableIndex(tableName string) []Index {
func (pm *PgsqlMetadata) GetTableIndex(tableName string) ([]Index, error) {
_, res, err := pm.di.SelectData(fmt.Sprintf(GetLocalSql(PGSQL_META_FILE, PGSQL_INDEX_INFO_KEY), tableName))
biz.ErrIsNilAppendErr(err, "获取表索引信息失败: %s")
if err != nil {
return nil, err
}
indexs := make([]Index, 0)
for _, re := range res {
indexs = append(indexs, Index{
@@ -162,22 +181,26 @@ func (pm *PgsqlMetadata) GetTableIndex(tableName string) []Index {
SeqInIndex: anyx.ConvInt(re["seqInIndex"]),
})
}
return indexs
return indexs, nil
}
// 获取建表ddl
func (pm *PgsqlMetadata) GetCreateTableDdl(tableName string) string {
func (pm *PgsqlMetadata) GetCreateTableDdl(tableName string) (string, error) {
_, err := pm.di.Exec(GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_DDL_KEY))
biz.ErrIsNilAppendErr(err, "创建ddl函数失败: %s")
if err != nil {
return "", err
}
_, schemaRes, _ := pm.di.SelectData("select current_schema() as schema")
schemaName := schemaRes[0]["schema"].(string)
ddlSql := fmt.Sprintf("select showcreatetable('%s','%s') as sql", schemaName, tableName)
_, res, err := pm.di.SelectData(ddlSql)
if err != nil {
return "", err
}
biz.ErrIsNilAppendErr(err, "获取表ddl失败: %s")
return res[0]["sql"].(string)
return res[0]["sql"].(string), nil
}
func (pm *PgsqlMetadata) GetTableRecord(tableName string, pageNum, pageSize int) ([]string, []map[string]any, error) {