mirror of
https://gitee.com/dromara/mayfly-go
synced 2025-11-04 00:10:25 +08:00
refactor: dbm
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"mayfly-go/internal/machine/mcm"
|
||||
"mayfly-go/pkg/errorx"
|
||||
"mayfly-go/pkg/logx"
|
||||
@@ -23,19 +24,33 @@ type DbConn struct {
|
||||
// 执行数据库查询返回的列信息
|
||||
type QueryColumn struct {
|
||||
Name string `json:"name"` // 列名
|
||||
Type string `json:"type"` // 类型
|
||||
Type string `json:"type"` // 数据类型
|
||||
|
||||
SqlColType *sql.ColumnType `json:"-"`
|
||||
DbDataType *DbDataType `json:"-"`
|
||||
valuer Valuer `json:"-"`
|
||||
}
|
||||
|
||||
func NewQueryColumn(colName string, col *sql.ColumnType) *QueryColumn {
|
||||
func NewQueryColumn(colName string, columnType *DbDataType) *QueryColumn {
|
||||
return &QueryColumn{
|
||||
Name: col.Name(),
|
||||
Type: col.DatabaseTypeName(),
|
||||
SqlColType: col,
|
||||
Name: colName,
|
||||
Type: columnType.DataType.Name,
|
||||
DbDataType: columnType,
|
||||
valuer: columnType.DataType.Valuer(),
|
||||
}
|
||||
}
|
||||
|
||||
func (qc *QueryColumn) getValuePtr() any {
|
||||
return qc.valuer.NewValuePtr()
|
||||
}
|
||||
|
||||
func (qc *QueryColumn) value() any {
|
||||
return qc.valuer.Value()
|
||||
}
|
||||
|
||||
func (qc *QueryColumn) SQLValue(val any) any {
|
||||
return qc.DbDataType.DataType.SQLValue(val)
|
||||
}
|
||||
|
||||
func (d *DbConn) GetDb() *sql.DB {
|
||||
return d.db
|
||||
}
|
||||
@@ -80,7 +95,7 @@ func (d *DbConn) Query2Struct(execSql string, dest any) error {
|
||||
|
||||
// WalkQueryRows 游标方式遍历查询结果集, walkFn返回error不为nil, 则跳出遍历并取消查询
|
||||
func (d *DbConn) WalkQueryRows(ctx context.Context, querySql string, walkFn WalkQueryRowsFunc, args ...any) ([]*QueryColumn, error) {
|
||||
return walkQueryRows(ctx, d.GetDialect(), d.db, querySql, walkFn, args...)
|
||||
return d.walkQueryRows(ctx, querySql, walkFn, args...)
|
||||
}
|
||||
|
||||
// WalkTableRows 游标方式遍历指定表的结果集, walkFn返回error不为nil, 则跳出遍历并取消查询
|
||||
@@ -138,6 +153,11 @@ func (d *DbConn) GetMetadata() Metadata {
|
||||
return d.Info.Meta.GetMetadata(d)
|
||||
}
|
||||
|
||||
// GetDbDataType 获取定义的数据库数据类型
|
||||
func (d *DbConn) GetDbDataType(dataType string) *DbDataType {
|
||||
return GetDbDataType(d.Info.Type, dataType)
|
||||
}
|
||||
|
||||
// Stats 返回数据库连接状态
|
||||
func (d *DbConn) Stats(ctx context.Context, execSql string, args ...any) sql.DBStats {
|
||||
return d.db.Stats()
|
||||
@@ -149,8 +169,8 @@ func (d *DbConn) Close() {
|
||||
if err := d.db.Close(); err != nil {
|
||||
logx.Errorf("关闭数据库实例[%s]连接失败: %s", d.Id, err.Error())
|
||||
}
|
||||
// 如果是达梦并且使用了ssh隧道,则需要手动将其关闭
|
||||
if d.Info.Type == DbTypeDM && d.Info.SshTunnelMachineId > 0 {
|
||||
// 如果是使用了自己实现的ssh隧道转发,则需要手动将其关闭
|
||||
if d.Info.useSshTunnel {
|
||||
mcm.CloseSshTunnelMachine(d.Info.SshTunnelMachineId, fmt.Sprintf("db:%d", d.Info.Id))
|
||||
}
|
||||
d.db = nil
|
||||
@@ -158,11 +178,11 @@ func (d *DbConn) Close() {
|
||||
}
|
||||
|
||||
// 游标方式遍历查询rows, walkFn error不为nil, 则跳出遍历
|
||||
func walkQueryRows(ctx context.Context, dialect Dialect, db *sql.DB, selectSql string, walkFn WalkQueryRowsFunc, args ...any) ([]*QueryColumn, error) {
|
||||
func (d *DbConn) walkQueryRows(ctx context.Context, selectSql string, walkFn WalkQueryRowsFunc, args ...any) ([]*QueryColumn, error) {
|
||||
cancelCtx, cancelFunc := context.WithCancel(ctx)
|
||||
defer cancelFunc()
|
||||
|
||||
rows, err := db.QueryContext(cancelCtx, selectSql, args...)
|
||||
rows, err := d.db.QueryContext(cancelCtx, selectSql, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -170,8 +190,6 @@ func walkQueryRows(ctx context.Context, dialect Dialect, db *sql.DB, selectSql s
|
||||
// 后面的链接过来直接报错或拒绝,实际上也没有起效果
|
||||
defer rows.Close()
|
||||
|
||||
columnHelper := dialect.GetColumnHelper()
|
||||
|
||||
colTypes, err := rows.ColumnTypes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -188,9 +206,9 @@ func walkQueryRows(ctx context.Context, dialect Dialect, db *sql.DB, selectSql s
|
||||
if colName == "" {
|
||||
colName = fmt.Sprintf("<anonymous%d>", k+1)
|
||||
}
|
||||
qc := NewQueryColumn(colName, colType)
|
||||
qc := NewQueryColumn(colName, d.GetDbDataType(colType.DatabaseTypeName()))
|
||||
cols[k] = qc
|
||||
scans[k] = columnHelper.GetScanDestPtr(qc)
|
||||
scans[k] = qc.getValuePtr()
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
@@ -201,8 +219,8 @@ func walkQueryRows(ctx context.Context, dialect Dialect, db *sql.DB, selectSql s
|
||||
// 每行数据
|
||||
rowData := make(map[string]any, lenCols)
|
||||
// 把values中的数据复制到row中
|
||||
for i, v := range scans {
|
||||
rowData[cols[i].Name] = columnHelper.ConvertScanDestValue(v, cols[i])
|
||||
for i := range scans {
|
||||
rowData[cols[i].Name] = cols[i].value()
|
||||
}
|
||||
if err = walkFn(rowData, cols); err != nil {
|
||||
logx.ErrorfContext(ctx, "[%s] cursor traversal query result set error, exit traversal: %s", selectSql, err.Error())
|
||||
|
||||
Reference in New Issue
Block a user