package dbm import ( "context" "database/sql" "fmt" "mayfly-go/pkg/errorx" "mayfly-go/pkg/logx" "reflect" "strconv" "strings" ) // db实例连接信息 type DbConn struct { Id string Info *DbInfo db *sql.DB } // 执行查询语句 // 依次返回 列名数组(顺序),结果map,错误 func (d *DbConn) Query(querySql string) ([]string, []map[string]any, error) { return d.QueryContext(context.Background(), querySql) } // 执行查询语句 // 依次返回 列名数组(顺序),结果map,错误 func (d *DbConn) QueryContext(ctx context.Context, querySql string) ([]string, []map[string]any, error) { result := make([]map[string]any, 0, 16) columns, err := walkTableRecord(ctx, d.db, querySql, func(record map[string]any, columns []string) { result = append(result, record) }) if err != nil { return nil, nil, wrapSqlError(err) } return columns, result, nil } // 将查询结果映射至struct,可具体参考sqlx库 func (d *DbConn) Query2Struct(execSql string, dest any) error { rows, err := d.db.Query(execSql) if err != nil { return err } // rows对象一定要close掉,如果出错,不关掉则会很迅速的达到设置最大连接数, // 后面的链接过来直接报错或拒绝,实际上也没有起效果 defer func() { if rows != nil { rows.Close() } }() return scanAll(rows, dest, false) } // WalkTableRecord 遍历表记录 func (d *DbConn) WalkTableRecord(ctx context.Context, selectSql string, walk func(record map[string]any, columns []string)) error { _, err := walkTableRecord(ctx, d.db, selectSql, walk) return err } // 执行 update, insert, delete,建表等sql // 返回影响条数和错误 func (d *DbConn) Exec(sql string) (int64, error) { return d.ExecContext(context.Background(), sql) } // 执行 update, insert, delete,建表等sql // 返回影响条数和错误 func (d *DbConn) ExecContext(ctx context.Context, sql string) (int64, error) { res, err := d.db.ExecContext(ctx, sql) if err != nil { return 0, wrapSqlError(err) } return res.RowsAffected() } // 获取数据库元信息实现接口 func (d *DbConn) GetDialect() DbDialect { switch d.Info.Type { case DbTypeMysql: return &MysqlDialect{dc: d} case DbTypePostgres: return &PgsqlDialect{dc: d} case DM: return &DMDialect{dc: d} default: panic(fmt.Sprintf("invalid database type: %s", d.Info.Type)) } } // 关闭连接 func (d *DbConn) Close() { if d.db != nil { if err := d.db.Close(); err != nil { logx.Errorf("关闭数据库实例[%s]连接失败: %s", d.Id, err.Error()) } d.db = nil } } func walkTableRecord(ctx context.Context, db *sql.DB, selectSql string, walk func(record map[string]any, columns []string)) ([]string, error) { rows, err := db.QueryContext(ctx, selectSql) if err != nil { return nil, err } // rows对象一定要close掉,如果出错,不关掉则会很迅速的达到设置最大连接数, // 后面的链接过来直接报错或拒绝,实际上也没有起效果 defer func() { if rows != nil { rows.Close() } }() colTypes, err := rows.ColumnTypes() if err != nil { return nil, err } lenCols := len(colTypes) // 列名用于前端表头名称按照数据库与查询字段顺序显示 colNames := make([]string, lenCols) // 这里表示一行填充数据 scans := make([]any, lenCols) // 这里表示一行所有列的值,用[]byte表示 values := make([][]byte, lenCols) for k, colType := range colTypes { colNames[k] = colType.Name() // 这里scans引用values,把数据填充到[]byte里 scans[k] = &values[k] } for rows.Next() { // 不Scan也会导致等待,该链接实际处于未工作的状态,然后也会导致连接数迅速达到最大 if err := rows.Scan(scans...); err != nil { return nil, err } // 每行数据 rowData := make(map[string]any, lenCols) // 把values中的数据复制到row中 for i, v := range values { rowData[colTypes[i].Name()] = valueConvert(v, colTypes[i]) } walk(rowData, colNames) } return colNames, nil } // 将查询的值转为对应列类型的实际值,不全部转为字符串 func valueConvert(data []byte, colType *sql.ColumnType) any { if data == nil { return nil } // 列的数据库类型名 colDatabaseTypeName := strings.ToLower(colType.DatabaseTypeName()) // 如果类型是bit,则直接返回第一个字节即可 if strings.Contains(colDatabaseTypeName, "bit") { return data[0] } // 这里把[]byte数据转成string stringV := string(data) if stringV == "" { return "" } colScanType := strings.ToLower(colType.ScanType().Name()) if strings.Contains(colScanType, "int") { // 如果长度超过16位,则返回字符串,因为前端js长度大于16会丢失精度 if len(stringV) > 16 { return stringV } intV, _ := strconv.Atoi(stringV) switch colType.ScanType().Kind() { case reflect.Int8: return int8(intV) case reflect.Uint8: return uint8(intV) case reflect.Int64: return int64(intV) case reflect.Uint64: return uint64(intV) case reflect.Uint: return uint(intV) default: return intV } } if strings.Contains(colScanType, "float") || strings.Contains(colDatabaseTypeName, "decimal") { floatV, _ := strconv.ParseFloat(stringV, 64) return floatV } return stringV } // 包装sql执行相关错误 func wrapSqlError(err error) error { if err == context.Canceled { return errorx.NewBiz("取消执行") } if err == context.DeadlineExceeded { return errorx.NewBiz("执行超时") } return err }