mirror of
				https://gitee.com/dromara/mayfly-go
				synced 2025-11-04 08:20:25 +08:00 
			
		
		
		
	refactor: db代码review
This commit is contained in:
		@@ -165,7 +165,7 @@ func (d *dbAppImpl) GetDatabases(ed *entity.Db) []string {
 | 
			
		||||
	biz.ErrIsNilAppendErr(err, "数据库连接失败: %s")
 | 
			
		||||
	defer dbConn.Close()
 | 
			
		||||
 | 
			
		||||
	_, res, err := SelectDataByDb(dbConn, getDatabasesSql)
 | 
			
		||||
	_, res, err := SelectDataByDb(dbConn, getDatabasesSql, true)
 | 
			
		||||
	biz.ErrIsNilAppendErr(err, "获取数据库列表失败")
 | 
			
		||||
	for _, re := range res {
 | 
			
		||||
		databases = append(databases, re["dbname"].(string))
 | 
			
		||||
@@ -218,6 +218,67 @@ func (da *dbAppImpl) GetDbInstance(id uint64, db string) *DbInstance {
 | 
			
		||||
	return dbi
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//----------------------------------------  db instance  ------------------------------------
 | 
			
		||||
 | 
			
		||||
// db实例
 | 
			
		||||
type DbInstance struct {
 | 
			
		||||
	Id                 string
 | 
			
		||||
	Type               string
 | 
			
		||||
	ProjectId          uint64
 | 
			
		||||
	db                 *sql.DB
 | 
			
		||||
	sshTunnelMachineId uint64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 执行查询语句
 | 
			
		||||
// 依次返回 列名数组,结果map,错误
 | 
			
		||||
func (d *DbInstance) SelectData(execSql string) ([]string, []map[string]interface{}, error) {
 | 
			
		||||
	return SelectDataByDb(d.db, execSql, false)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 将查询结果映射至struct,可具体参考sqlx库
 | 
			
		||||
func (d *DbInstance) SelectData2Struct(execSql string, dest interface{}) error {
 | 
			
		||||
	return Select2StructByDb(d.db, execSql, dest)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 执行内部查询语句,不返回列名以及不限制行数
 | 
			
		||||
// 依次返回 结果map,错误
 | 
			
		||||
func (d *DbInstance) innerSelect(execSql string) ([]map[string]interface{}, error) {
 | 
			
		||||
	_, res, err := SelectDataByDb(d.db, execSql, true)
 | 
			
		||||
	return res, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 执行 update, insert, delete,建表等sql
 | 
			
		||||
// 返回影响条数和错误
 | 
			
		||||
func (d *DbInstance) Exec(sql string) (int64, error) {
 | 
			
		||||
	res, err := d.db.Exec(sql)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	return res.RowsAffected()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取数据库元信息实现接口
 | 
			
		||||
func (di *DbInstance) GetMeta() DbMetadata {
 | 
			
		||||
	dbType := di.Type
 | 
			
		||||
	if dbType == entity.DbTypeMysql {
 | 
			
		||||
		return &MysqlMetadata{di: di}
 | 
			
		||||
	}
 | 
			
		||||
	if dbType == entity.DbTypePostgres {
 | 
			
		||||
		return &PgsqlMetadata{di: di}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 关闭连接
 | 
			
		||||
func (d *DbInstance) Close() {
 | 
			
		||||
	if d.db != nil {
 | 
			
		||||
		if err := d.db.Close(); err != nil {
 | 
			
		||||
			global.Log.Errorf("关闭数据库实例[%s]连接失败: %s", d.Id, err.Error())
 | 
			
		||||
		}
 | 
			
		||||
		d.db = nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//------------------------------------------------------------------------------
 | 
			
		||||
 | 
			
		||||
// 单次最大查询数据集
 | 
			
		||||
@@ -292,7 +353,28 @@ func GetDbConn(d *entity.Db, db string) (*sql.DB, error) {
 | 
			
		||||
	return DB, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SelectDataByDb(db *sql.DB, selectSql string) ([]string, []map[string]interface{}, error) {
 | 
			
		||||
// 获取dataSourceName
 | 
			
		||||
func getDsn(d *entity.Db, db string) string {
 | 
			
		||||
	var dsn string
 | 
			
		||||
	if d.Type == entity.DbTypeMysql {
 | 
			
		||||
		dsn = fmt.Sprintf("%s:%s@%s(%s:%d)/%s?timeout=8s", d.Username, d.Password, d.Network, d.Host, d.Port, db)
 | 
			
		||||
		if d.Params != "" {
 | 
			
		||||
			dsn = fmt.Sprintf("%s&%s", dsn, d.Params)
 | 
			
		||||
		}
 | 
			
		||||
		return dsn
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if d.Type == entity.DbTypePostgres {
 | 
			
		||||
		dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", d.Host, d.Port, d.Username, d.Password, db)
 | 
			
		||||
		if d.Params != "" {
 | 
			
		||||
			dsn = fmt.Sprintf("%s %s", dsn, strings.Join(strings.Split(d.Params, "&"), " "))
 | 
			
		||||
		}
 | 
			
		||||
		return dsn
 | 
			
		||||
	}
 | 
			
		||||
	return ""
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SelectDataByDb(db *sql.DB, selectSql string, isInner bool) ([]string, []map[string]interface{}, error) {
 | 
			
		||||
	rows, err := db.Query(selectSql)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, nil, err
 | 
			
		||||
@@ -322,7 +404,10 @@ func SelectDataByDb(db *sql.DB, selectSql string) ([]string, []map[string]interf
 | 
			
		||||
	rowNum := 0
 | 
			
		||||
	for rows.Next() {
 | 
			
		||||
		rowNum++
 | 
			
		||||
		biz.IsTrue(rowNum <= Max_Rows, "结果集 > 2000, 请完善条件或分页信息")
 | 
			
		||||
		// 非内部sql,则校验返回结果数量
 | 
			
		||||
		if !isInner {
 | 
			
		||||
			biz.IsTrue(rowNum <= Max_Rows, "结果集 > 2000, 请完善条件或分页信息")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 不Scan也会导致等待,该链接实际处于未工作的状态,然后也会导致连接数迅速达到最大
 | 
			
		||||
		err := rows.Scan(scans...)
 | 
			
		||||
@@ -397,10 +482,11 @@ func valueConvert(data []byte, colType *sql.ColumnType) interface{} {
 | 
			
		||||
	return stringV
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func innerSelectByDb(db *sql.DB, selectSql string) ([]map[string]interface{}, error) {
 | 
			
		||||
// 查询数据结果映射至struct。可参考sqlx库
 | 
			
		||||
func Select2StructByDb(db *sql.DB, selectSql string, dest interface{}) error {
 | 
			
		||||
	rows, err := db.Query(selectSql)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	// rows对象一定要close掉,如果出错,不关掉则会很迅速的达到设置最大连接数,
 | 
			
		||||
	// 后面的链接过来直接报错或拒绝,实际上也没有起效果
 | 
			
		||||
@@ -409,35 +495,12 @@ func innerSelectByDb(db *sql.DB, selectSql string) ([]map[string]interface{}, er
 | 
			
		||||
			rows.Close()
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
	colTypes, _ := rows.ColumnTypes()
 | 
			
		||||
	// 这里表示一行填充数据
 | 
			
		||||
	scans := make([]interface{}, len(colTypes))
 | 
			
		||||
	// 这里表示一行所有列的值,用[]byte表示
 | 
			
		||||
	vals := make([][]byte, len(colTypes))
 | 
			
		||||
	// 这里scans引用vals,把数据填充到[]byte里
 | 
			
		||||
	for k := range vals {
 | 
			
		||||
		scans[k] = &vals[k]
 | 
			
		||||
	}
 | 
			
		||||
	return scanAll(rows, dest, false)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
	result := make([]map[string]interface{}, 0)
 | 
			
		||||
	for rows.Next() {
 | 
			
		||||
		// 不Scan也会导致等待,该链接实际处于未工作的状态,然后也会导致连接数迅速达到最大
 | 
			
		||||
		err := rows.Scan(scans...)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		// 每行数据
 | 
			
		||||
		rowData := make(map[string]interface{})
 | 
			
		||||
		// 把vals中的数据复制到row中
 | 
			
		||||
		for i, v := range vals {
 | 
			
		||||
			colType := colTypes[i]
 | 
			
		||||
			colName := colType.Name()
 | 
			
		||||
			rowData[colName] = valueConvert(v, colType)
 | 
			
		||||
		}
 | 
			
		||||
		// 放入结果集
 | 
			
		||||
		result = append(result, rowData)
 | 
			
		||||
	}
 | 
			
		||||
	return result, nil
 | 
			
		||||
// 删除db缓存并关闭该数据库所有连接
 | 
			
		||||
func CloseDb(dbId uint64, db string) {
 | 
			
		||||
	dbCache.Delete(GetDbCacheKey(dbId, db))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type PqSqlDialer struct {
 | 
			
		||||
@@ -452,85 +515,7 @@ func (pd *PqSqlDialer) Dial(network, address string) (net.Conn, error) {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pd *PqSqlDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
 | 
			
		||||
	return pd.Dial(network, address)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// db实例
 | 
			
		||||
type DbInstance struct {
 | 
			
		||||
	Id                 string
 | 
			
		||||
	Type               string
 | 
			
		||||
	ProjectId          uint64
 | 
			
		||||
	db                 *sql.DB
 | 
			
		||||
	sshTunnelMachineId uint64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 执行查询语句
 | 
			
		||||
// 依次返回 列名数组,结果map,错误
 | 
			
		||||
func (d *DbInstance) SelectData(execSql string) ([]string, []map[string]interface{}, error) {
 | 
			
		||||
	return SelectDataByDb(d.db, execSql)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 执行内部查询语句,不返回列名以及不限制行数
 | 
			
		||||
// 依次返回 结果map,错误
 | 
			
		||||
func (d *DbInstance) innerSelect(execSql string) ([]map[string]interface{}, error) {
 | 
			
		||||
	return innerSelectByDb(d.db, execSql)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 执行 update, insert, delete,建表等sql
 | 
			
		||||
// 返回影响条数和错误
 | 
			
		||||
func (d *DbInstance) Exec(sql string) (int64, error) {
 | 
			
		||||
	res, err := d.db.Exec(sql)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	return res.RowsAffected()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取数据库元信息实现接口
 | 
			
		||||
func (di *DbInstance) GetMeta() DbMetadata {
 | 
			
		||||
	dbType := di.Type
 | 
			
		||||
	if dbType == entity.DbTypeMysql {
 | 
			
		||||
		return &MysqlMetadata{di: di}
 | 
			
		||||
	}
 | 
			
		||||
	if dbType == entity.DbTypePostgres {
 | 
			
		||||
		return &PgsqlMetadata{di: di}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 关闭连接
 | 
			
		||||
func (d *DbInstance) Close() {
 | 
			
		||||
	if d.db != nil {
 | 
			
		||||
		if err := d.db.Close(); err != nil {
 | 
			
		||||
			global.Log.Errorf("关闭数据库实例[%s]连接失败: %s", d.Id, err.Error())
 | 
			
		||||
		}
 | 
			
		||||
		d.db = nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取dataSourceName
 | 
			
		||||
func getDsn(d *entity.Db, db string) string {
 | 
			
		||||
	var dsn string
 | 
			
		||||
	if d.Type == entity.DbTypeMysql {
 | 
			
		||||
		dsn = fmt.Sprintf("%s:%s@%s(%s:%d)/%s?timeout=8s", d.Username, d.Password, d.Network, d.Host, d.Port, db)
 | 
			
		||||
		if d.Params != "" {
 | 
			
		||||
			dsn = fmt.Sprintf("%s&%s", dsn, d.Params)
 | 
			
		||||
		}
 | 
			
		||||
		return dsn
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if d.Type == entity.DbTypePostgres {
 | 
			
		||||
		dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", d.Host, d.Port, d.Username, d.Password, db)
 | 
			
		||||
		if d.Params != "" {
 | 
			
		||||
			dsn = fmt.Sprintf("%s %s", dsn, strings.Join(strings.Split(d.Params, "&"), " "))
 | 
			
		||||
		}
 | 
			
		||||
		return dsn
 | 
			
		||||
	}
 | 
			
		||||
	return ""
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 删除db缓存并关闭该数据库所有连接
 | 
			
		||||
func CloseDb(dbId uint64, db string) {
 | 
			
		||||
	dbCache.Delete(GetDbCacheKey(dbId, db))
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,9 +1,11 @@
 | 
			
		||||
package application
 | 
			
		||||
 | 
			
		||||
// -----------------------------------元数据-------------------------------------------
 | 
			
		||||
// -----------------------------------元数据接口定义------------------------------------------
 | 
			
		||||
// 数据库元信息接口(表、列等元信息)
 | 
			
		||||
// 所有数据查出来直接用map接收,注意map的key需要统一
 | 
			
		||||
// 所有数据查出来直接用map接收,注意不同数据库实现该接口返回的map中的key需要统一.
 | 
			
		||||
// 即: 使用别名统一即可。如table_name AS tableName
 | 
			
		||||
type DbMetadata interface {
 | 
			
		||||
 | 
			
		||||
	// 获取表基础元信息
 | 
			
		||||
	// 表名: tableName, 备注: tableComment
 | 
			
		||||
	GetTables() []map[string]interface{}
 | 
			
		||||
@@ -24,6 +26,3 @@ type DbMetadata interface {
 | 
			
		||||
	// 获取建表ddl
 | 
			
		||||
	GetCreateTableDdl(tableName string) []map[string]interface{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 默认每次查询列元信息数量
 | 
			
		||||
const DEFAULT_COLUMN_SIZE = 2000
 | 
			
		||||
 
 | 
			
		||||
@@ -8,9 +8,7 @@ import (
 | 
			
		||||
// ---------------------------------- mysql元数据 -----------------------------------
 | 
			
		||||
const (
 | 
			
		||||
	// mysql 表信息元数据
 | 
			
		||||
	MYSQL_TABLE_MA = `SELECT table_name tableName, engine, table_comment tableComment, 
 | 
			
		||||
	create_time createTime from information_schema.tables
 | 
			
		||||
	WHERE table_schema = (SELECT database())`
 | 
			
		||||
	MYSQL_TABLE_MA = `SELECT table_name tableName, table_comment tableComment from information_schema.tables WHERE table_schema = (SELECT database())`
 | 
			
		||||
 | 
			
		||||
	// mysql 表信息
 | 
			
		||||
	MYSQL_TABLE_INFO = `SELECT table_name tableName, table_comment tableComment, table_rows tableRows,
 | 
			
		||||
 
 | 
			
		||||
@@ -88,5 +88,6 @@ func (pm *PgsqlMetadata) GetTableIndex(tableName string) []map[string]interface{
 | 
			
		||||
 | 
			
		||||
// 获取建表ddl
 | 
			
		||||
func (mm *PgsqlMetadata) GetCreateTableDdl(tableName string) []map[string]interface{} {
 | 
			
		||||
	biz.IsTrue(tableName == "", "暂不支持获取pgsql建表DDL")
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										630
									
								
								server/internal/db/application/sqlx.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										630
									
								
								server/internal/db/application/sqlx.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,630 @@
 | 
			
		||||
package application
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"runtime"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// 将结果scan至结构体,copy至 sqlx库: https://github.com/jmoiron/sqlx
 | 
			
		||||
func scanAll(rows *sql.Rows, dest interface{}, structOnly bool) error {
 | 
			
		||||
	var v, vp reflect.Value
 | 
			
		||||
 | 
			
		||||
	value := reflect.ValueOf(dest)
 | 
			
		||||
 | 
			
		||||
	// json.Unmarshal returns errors for these
 | 
			
		||||
	if value.Kind() != reflect.Ptr {
 | 
			
		||||
		return errors.New("must pass a pointer, not a value, to StructScan destination")
 | 
			
		||||
	}
 | 
			
		||||
	if value.IsNil() {
 | 
			
		||||
		return errors.New("nil pointer passed to StructScan destination")
 | 
			
		||||
	}
 | 
			
		||||
	direct := reflect.Indirect(value)
 | 
			
		||||
 | 
			
		||||
	slice, err := baseType(value.Type(), reflect.Slice)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	direct.SetLen(0)
 | 
			
		||||
 | 
			
		||||
	isPtr := slice.Elem().Kind() == reflect.Ptr
 | 
			
		||||
	base := Deref(slice.Elem())
 | 
			
		||||
	scannable := isScannable(base)
 | 
			
		||||
 | 
			
		||||
	if structOnly && scannable {
 | 
			
		||||
		return structOnlyError(base)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	columns, err := rows.Columns()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// if it's a base type make sure it only has 1 column;  if not return an error
 | 
			
		||||
	if scannable && len(columns) > 1 {
 | 
			
		||||
		return fmt.Errorf("non-struct dest type %s with >1 columns (%d)", base.Kind(), len(columns))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !scannable {
 | 
			
		||||
		var values []interface{}
 | 
			
		||||
		var m *Mapper = mapper()
 | 
			
		||||
 | 
			
		||||
		fields := m.TraversalsByName(base, columns)
 | 
			
		||||
		// if we are not unsafe and are missing fields, return an error
 | 
			
		||||
		if f, err := missingFields(fields); err != nil {
 | 
			
		||||
			return fmt.Errorf("missing destination name %s in %T", columns[f], dest)
 | 
			
		||||
		}
 | 
			
		||||
		values = make([]interface{}, len(columns))
 | 
			
		||||
 | 
			
		||||
		for rows.Next() {
 | 
			
		||||
			// create a new struct type (which returns PtrTo) and indirect it
 | 
			
		||||
			vp = reflect.New(base)
 | 
			
		||||
			v = reflect.Indirect(vp)
 | 
			
		||||
 | 
			
		||||
			err = fieldsByTraversal(v, fields, values, true)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// scan into the struct field pointers and append to our results
 | 
			
		||||
			err = rows.Scan(values...)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if isPtr {
 | 
			
		||||
				direct.Set(reflect.Append(direct, vp))
 | 
			
		||||
			} else {
 | 
			
		||||
				direct.Set(reflect.Append(direct, v))
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		for rows.Next() {
 | 
			
		||||
			vp = reflect.New(base)
 | 
			
		||||
			err = rows.Scan(vp.Interface())
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			// append
 | 
			
		||||
			if isPtr {
 | 
			
		||||
				direct.Set(reflect.Append(direct, vp))
 | 
			
		||||
			} else {
 | 
			
		||||
				direct.Set(reflect.Append(direct, reflect.Indirect(vp)))
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return rows.Err()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) {
 | 
			
		||||
	t = Deref(t)
 | 
			
		||||
	if t.Kind() != expected {
 | 
			
		||||
		return nil, fmt.Errorf("expected %s but got %s", expected, t.Kind())
 | 
			
		||||
	}
 | 
			
		||||
	return t, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// structOnlyError returns an error appropriate for type when a non-scannable
 | 
			
		||||
// struct is expected but something else is given
 | 
			
		||||
func structOnlyError(t reflect.Type) error {
 | 
			
		||||
	isStruct := t.Kind() == reflect.Struct
 | 
			
		||||
	isScanner := reflect.PtrTo(t).Implements(_scannerInterface)
 | 
			
		||||
	if !isStruct {
 | 
			
		||||
		return fmt.Errorf("expected %s but got %s", reflect.Struct, t.Kind())
 | 
			
		||||
	}
 | 
			
		||||
	if isScanner {
 | 
			
		||||
		return fmt.Errorf("structscan expects a struct dest but the provided struct type %s implements scanner", t.Name())
 | 
			
		||||
	}
 | 
			
		||||
	return fmt.Errorf("expected a struct, but struct %s has no exported fields", t.Name())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var _scannerInterface = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
 | 
			
		||||
 | 
			
		||||
func isScannable(t reflect.Type) bool {
 | 
			
		||||
	if reflect.PtrTo(t).Implements(_scannerInterface) {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	if t.Kind() != reflect.Struct {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// it's not important that we use the right mapper for this particular object,
 | 
			
		||||
	// we're only concerned on how many exported fields this struct has
 | 
			
		||||
	return len(mapper().TypeMap(t).Index) == 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var NameMapper = strings.ToLower
 | 
			
		||||
var origMapper = reflect.ValueOf(NameMapper)
 | 
			
		||||
 | 
			
		||||
// Rather than creating on init, this is created when necessary so that
 | 
			
		||||
// importers have time to customize the NameMapper.
 | 
			
		||||
var mpr *Mapper
 | 
			
		||||
 | 
			
		||||
// mprMu protects mpr.
 | 
			
		||||
var mprMu sync.Mutex
 | 
			
		||||
 | 
			
		||||
// mapper returns a valid mapper using the configured NameMapper func.
 | 
			
		||||
func mapper() *Mapper {
 | 
			
		||||
	mprMu.Lock()
 | 
			
		||||
	defer mprMu.Unlock()
 | 
			
		||||
 | 
			
		||||
	if mpr == nil {
 | 
			
		||||
		mpr = NewMapperFunc("db", NameMapper)
 | 
			
		||||
	} else if origMapper != reflect.ValueOf(NameMapper) {
 | 
			
		||||
		// if NameMapper has changed, create a new mapper
 | 
			
		||||
		mpr = NewMapperFunc("db", NameMapper)
 | 
			
		||||
		origMapper = reflect.ValueOf(NameMapper)
 | 
			
		||||
	}
 | 
			
		||||
	return mpr
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func missingFields(transversals [][]int) (field int, err error) {
 | 
			
		||||
	for i, t := range transversals {
 | 
			
		||||
		if len(t) == 0 {
 | 
			
		||||
			return i, errors.New("missing field")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return 0, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// fieldsByName fills a values interface with fields from the passed value based
 | 
			
		||||
// on the traversals in int.  If ptrs is true, return addresses instead of values.
 | 
			
		||||
// We write this instead of using FieldsByName to save allocations and map lookups
 | 
			
		||||
// when iterating over many rows.  Empty traversals will get an interface pointer.
 | 
			
		||||
// Because of the necessity of requesting ptrs or values, it's considered a bit too
 | 
			
		||||
// specialized for inclusion in reflectx itself.
 | 
			
		||||
func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error {
 | 
			
		||||
	v = reflect.Indirect(v)
 | 
			
		||||
	if v.Kind() != reflect.Struct {
 | 
			
		||||
		return errors.New("argument not a struct")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for i, traversal := range traversals {
 | 
			
		||||
		if len(traversal) == 0 {
 | 
			
		||||
			values[i] = new(interface{})
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		f := FieldByIndexes(v, traversal)
 | 
			
		||||
		if ptrs {
 | 
			
		||||
			values[i] = f.Addr().Interface()
 | 
			
		||||
		} else {
 | 
			
		||||
			values[i] = f.Interface()
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// A FieldInfo is metadata for a struct field.
 | 
			
		||||
type FieldInfo struct {
 | 
			
		||||
	Index    []int
 | 
			
		||||
	Path     string
 | 
			
		||||
	Field    reflect.StructField
 | 
			
		||||
	Zero     reflect.Value
 | 
			
		||||
	Name     string
 | 
			
		||||
	Options  map[string]string
 | 
			
		||||
	Embedded bool
 | 
			
		||||
	Children []*FieldInfo
 | 
			
		||||
	Parent   *FieldInfo
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// A StructMap is an index of field metadata for a struct.
 | 
			
		||||
type StructMap struct {
 | 
			
		||||
	Tree  *FieldInfo
 | 
			
		||||
	Index []*FieldInfo
 | 
			
		||||
	Paths map[string]*FieldInfo
 | 
			
		||||
	Names map[string]*FieldInfo
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetByPath returns a *FieldInfo for a given string path.
 | 
			
		||||
func (f StructMap) GetByPath(path string) *FieldInfo {
 | 
			
		||||
	return f.Paths[path]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetByTraversal returns a *FieldInfo for a given integer path.  It is
 | 
			
		||||
// analogous to reflect.FieldByIndex, but using the cached traversal
 | 
			
		||||
// rather than re-executing the reflect machinery each time.
 | 
			
		||||
func (f StructMap) GetByTraversal(index []int) *FieldInfo {
 | 
			
		||||
	if len(index) == 0 {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tree := f.Tree
 | 
			
		||||
	for _, i := range index {
 | 
			
		||||
		if i >= len(tree.Children) || tree.Children[i] == nil {
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
		tree = tree.Children[i]
 | 
			
		||||
	}
 | 
			
		||||
	return tree
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Mapper is a general purpose mapper of names to struct fields.  A Mapper
 | 
			
		||||
// behaves like most marshallers in the standard library, obeying a field tag
 | 
			
		||||
// for name mapping but also providing a basic transform function.
 | 
			
		||||
type Mapper struct {
 | 
			
		||||
	cache      map[reflect.Type]*StructMap
 | 
			
		||||
	tagName    string
 | 
			
		||||
	tagMapFunc func(string) string
 | 
			
		||||
	mapFunc    func(string) string
 | 
			
		||||
	mutex      sync.Mutex
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewMapper returns a new mapper using the tagName as its struct field tag.
 | 
			
		||||
// If tagName is the empty string, it is ignored.
 | 
			
		||||
func NewMapper(tagName string) *Mapper {
 | 
			
		||||
	return &Mapper{
 | 
			
		||||
		cache:   make(map[reflect.Type]*StructMap),
 | 
			
		||||
		tagName: tagName,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewMapperTagFunc returns a new mapper which contains a mapper for field names
 | 
			
		||||
// AND a mapper for tag values.  This is useful for tags like json which can
 | 
			
		||||
// have values like "name,omitempty".
 | 
			
		||||
func NewMapperTagFunc(tagName string, mapFunc, tagMapFunc func(string) string) *Mapper {
 | 
			
		||||
	return &Mapper{
 | 
			
		||||
		cache:      make(map[reflect.Type]*StructMap),
 | 
			
		||||
		tagName:    tagName,
 | 
			
		||||
		mapFunc:    mapFunc,
 | 
			
		||||
		tagMapFunc: tagMapFunc,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewMapperFunc returns a new mapper which optionally obeys a field tag and
 | 
			
		||||
// a struct field name mapper func given by f.  Tags will take precedence, but
 | 
			
		||||
// for any other field, the mapped name will be f(field.Name)
 | 
			
		||||
func NewMapperFunc(tagName string, f func(string) string) *Mapper {
 | 
			
		||||
	return &Mapper{
 | 
			
		||||
		cache:   make(map[reflect.Type]*StructMap),
 | 
			
		||||
		tagName: tagName,
 | 
			
		||||
		mapFunc: f,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TypeMap returns a mapping of field strings to int slices representing
 | 
			
		||||
// the traversal down the struct to reach the field.
 | 
			
		||||
func (m *Mapper) TypeMap(t reflect.Type) *StructMap {
 | 
			
		||||
	m.mutex.Lock()
 | 
			
		||||
	mapping, ok := m.cache[t]
 | 
			
		||||
	if !ok {
 | 
			
		||||
		mapping = getMapping(t, m.tagName, m.mapFunc, m.tagMapFunc)
 | 
			
		||||
		m.cache[t] = mapping
 | 
			
		||||
	}
 | 
			
		||||
	m.mutex.Unlock()
 | 
			
		||||
	return mapping
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// FieldMap returns the mapper's mapping of field names to reflect values.  Panics
 | 
			
		||||
// if v's Kind is not Struct, or v is not Indirectable to a struct kind.
 | 
			
		||||
func (m *Mapper) FieldMap(v reflect.Value) map[string]reflect.Value {
 | 
			
		||||
	v = reflect.Indirect(v)
 | 
			
		||||
	mustBe(v, reflect.Struct)
 | 
			
		||||
 | 
			
		||||
	r := map[string]reflect.Value{}
 | 
			
		||||
	tm := m.TypeMap(v.Type())
 | 
			
		||||
	for tagName, fi := range tm.Names {
 | 
			
		||||
		r[tagName] = FieldByIndexes(v, fi.Index)
 | 
			
		||||
	}
 | 
			
		||||
	return r
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// FieldByName returns a field by its mapped name as a reflect.Value.
 | 
			
		||||
// Panics if v's Kind is not Struct or v is not Indirectable to a struct Kind.
 | 
			
		||||
// Returns zero Value if the name is not found.
 | 
			
		||||
func (m *Mapper) FieldByName(v reflect.Value, name string) reflect.Value {
 | 
			
		||||
	v = reflect.Indirect(v)
 | 
			
		||||
	mustBe(v, reflect.Struct)
 | 
			
		||||
 | 
			
		||||
	tm := m.TypeMap(v.Type())
 | 
			
		||||
	fi, ok := tm.Names[name]
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return v
 | 
			
		||||
	}
 | 
			
		||||
	return FieldByIndexes(v, fi.Index)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// FieldsByName returns a slice of values corresponding to the slice of names
 | 
			
		||||
// for the value.  Panics if v's Kind is not Struct or v is not Indirectable
 | 
			
		||||
// to a struct Kind.  Returns zero Value for each name not found.
 | 
			
		||||
func (m *Mapper) FieldsByName(v reflect.Value, names []string) []reflect.Value {
 | 
			
		||||
	v = reflect.Indirect(v)
 | 
			
		||||
	mustBe(v, reflect.Struct)
 | 
			
		||||
 | 
			
		||||
	tm := m.TypeMap(v.Type())
 | 
			
		||||
	vals := make([]reflect.Value, 0, len(names))
 | 
			
		||||
	for _, name := range names {
 | 
			
		||||
		fi, ok := tm.Names[name]
 | 
			
		||||
		if !ok {
 | 
			
		||||
			vals = append(vals, *new(reflect.Value))
 | 
			
		||||
		} else {
 | 
			
		||||
			vals = append(vals, FieldByIndexes(v, fi.Index))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return vals
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TraversalsByName returns a slice of int slices which represent the struct
 | 
			
		||||
// traversals for each mapped name.  Panics if t is not a struct or Indirectable
 | 
			
		||||
// to a struct.  Returns empty int slice for each name not found.
 | 
			
		||||
func (m *Mapper) TraversalsByName(t reflect.Type, names []string) [][]int {
 | 
			
		||||
	r := make([][]int, 0, len(names))
 | 
			
		||||
	m.TraversalsByNameFunc(t, names, func(_ int, i []int) error {
 | 
			
		||||
		if i == nil {
 | 
			
		||||
			r = append(r, []int{})
 | 
			
		||||
		} else {
 | 
			
		||||
			r = append(r, i)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return nil
 | 
			
		||||
	})
 | 
			
		||||
	return r
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TraversalsByNameFunc traverses the mapped names and calls fn with the index of
 | 
			
		||||
// each name and the struct traversal represented by that name. Panics if t is not
 | 
			
		||||
// a struct or Indirectable to a struct. Returns the first error returned by fn or nil.
 | 
			
		||||
func (m *Mapper) TraversalsByNameFunc(t reflect.Type, names []string, fn func(int, []int) error) error {
 | 
			
		||||
	t = Deref(t)
 | 
			
		||||
	mustBe(t, reflect.Struct)
 | 
			
		||||
	tm := m.TypeMap(t)
 | 
			
		||||
	for i, name := range names {
 | 
			
		||||
		fi, ok := tm.Names[name]
 | 
			
		||||
		if !ok {
 | 
			
		||||
			if err := fn(i, nil); err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
			if err := fn(i, fi.Index); err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// FieldByIndexes returns a value for the field given by the struct traversal
 | 
			
		||||
// for the given value.
 | 
			
		||||
func FieldByIndexes(v reflect.Value, indexes []int) reflect.Value {
 | 
			
		||||
	for _, i := range indexes {
 | 
			
		||||
		v = reflect.Indirect(v).Field(i)
 | 
			
		||||
		// if this is a pointer and it's nil, allocate a new value and set it
 | 
			
		||||
		if v.Kind() == reflect.Ptr && v.IsNil() {
 | 
			
		||||
			alloc := reflect.New(Deref(v.Type()))
 | 
			
		||||
			v.Set(alloc)
 | 
			
		||||
		}
 | 
			
		||||
		if v.Kind() == reflect.Map && v.IsNil() {
 | 
			
		||||
			v.Set(reflect.MakeMap(v.Type()))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return v
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// FieldByIndexesReadOnly returns a value for a particular struct traversal,
 | 
			
		||||
// but is not concerned with allocating nil pointers because the value is
 | 
			
		||||
// going to be used for reading and not setting.
 | 
			
		||||
func FieldByIndexesReadOnly(v reflect.Value, indexes []int) reflect.Value {
 | 
			
		||||
	for _, i := range indexes {
 | 
			
		||||
		v = reflect.Indirect(v).Field(i)
 | 
			
		||||
	}
 | 
			
		||||
	return v
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Deref is Indirect for reflect.Types
 | 
			
		||||
func Deref(t reflect.Type) reflect.Type {
 | 
			
		||||
	if t.Kind() == reflect.Ptr {
 | 
			
		||||
		t = t.Elem()
 | 
			
		||||
	}
 | 
			
		||||
	return t
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -- helpers & utilities --
 | 
			
		||||
 | 
			
		||||
type kinder interface {
 | 
			
		||||
	Kind() reflect.Kind
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// mustBe checks a value against a kind, panicing with a reflect.ValueError
 | 
			
		||||
// if the kind isn't that which is required.
 | 
			
		||||
func mustBe(v kinder, expected reflect.Kind) {
 | 
			
		||||
	if k := v.Kind(); k != expected {
 | 
			
		||||
		panic(&reflect.ValueError{Method: methodName(), Kind: k})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// methodName returns the caller of the function calling methodName
 | 
			
		||||
func methodName() string {
 | 
			
		||||
	pc, _, _, _ := runtime.Caller(2)
 | 
			
		||||
	f := runtime.FuncForPC(pc)
 | 
			
		||||
	if f == nil {
 | 
			
		||||
		return "unknown method"
 | 
			
		||||
	}
 | 
			
		||||
	return f.Name()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type typeQueue struct {
 | 
			
		||||
	t  reflect.Type
 | 
			
		||||
	fi *FieldInfo
 | 
			
		||||
	pp string // Parent path
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// A copying append that creates a new slice each time.
 | 
			
		||||
func apnd(is []int, i int) []int {
 | 
			
		||||
	x := make([]int, len(is)+1)
 | 
			
		||||
	copy(x, is)
 | 
			
		||||
	x[len(x)-1] = i
 | 
			
		||||
	return x
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type mapf func(string) string
 | 
			
		||||
 | 
			
		||||
// parseName parses the tag and the target name for the given field using
 | 
			
		||||
// the tagName (eg 'json' for `json:"foo"` tags), mapFunc for mapping the
 | 
			
		||||
// field's name to a target name, and tagMapFunc for mapping the tag to
 | 
			
		||||
// a target name.
 | 
			
		||||
func parseName(field reflect.StructField, tagName string, mapFunc, tagMapFunc mapf) (tag, fieldName string) {
 | 
			
		||||
	// first, set the fieldName to the field's name
 | 
			
		||||
	fieldName = field.Name
 | 
			
		||||
	// if a mapFunc is set, use that to override the fieldName
 | 
			
		||||
	if mapFunc != nil {
 | 
			
		||||
		fieldName = mapFunc(fieldName)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// if there's no tag to look for, return the field name
 | 
			
		||||
	if tagName == "" {
 | 
			
		||||
		return "", fieldName
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// if this tag is not set using the normal convention in the tag,
 | 
			
		||||
	// then return the fieldname..  this check is done because according
 | 
			
		||||
	// to the reflect documentation:
 | 
			
		||||
	//    If the tag does not have the conventional format,
 | 
			
		||||
	//    the value returned by Get is unspecified.
 | 
			
		||||
	// which doesn't sound great.
 | 
			
		||||
	if !strings.Contains(string(field.Tag), tagName+":") {
 | 
			
		||||
		return "", fieldName
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// at this point we're fairly sure that we have a tag, so lets pull it out
 | 
			
		||||
	tag = field.Tag.Get(tagName)
 | 
			
		||||
 | 
			
		||||
	// if we have a mapper function, call it on the whole tag
 | 
			
		||||
	// XXX: this is a change from the old version, which pulled out the name
 | 
			
		||||
	// before the tagMapFunc could be run, but I think this is the right way
 | 
			
		||||
	if tagMapFunc != nil {
 | 
			
		||||
		tag = tagMapFunc(tag)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// finally, split the options from the name
 | 
			
		||||
	parts := strings.Split(tag, ",")
 | 
			
		||||
	fieldName = parts[0]
 | 
			
		||||
 | 
			
		||||
	return tag, fieldName
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// parseOptions parses options out of a tag string, skipping the name
 | 
			
		||||
func parseOptions(tag string) map[string]string {
 | 
			
		||||
	parts := strings.Split(tag, ",")
 | 
			
		||||
	options := make(map[string]string, len(parts))
 | 
			
		||||
	if len(parts) > 1 {
 | 
			
		||||
		for _, opt := range parts[1:] {
 | 
			
		||||
			// short circuit potentially expensive split op
 | 
			
		||||
			if strings.Contains(opt, "=") {
 | 
			
		||||
				kv := strings.Split(opt, "=")
 | 
			
		||||
				options[kv[0]] = kv[1]
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			options[opt] = ""
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return options
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// getMapping returns a mapping for the t type, using the tagName, mapFunc and
 | 
			
		||||
// tagMapFunc to determine the canonical names of fields.
 | 
			
		||||
func getMapping(t reflect.Type, tagName string, mapFunc, tagMapFunc mapf) *StructMap {
 | 
			
		||||
	m := []*FieldInfo{}
 | 
			
		||||
 | 
			
		||||
	root := &FieldInfo{}
 | 
			
		||||
	queue := []typeQueue{}
 | 
			
		||||
	queue = append(queue, typeQueue{Deref(t), root, ""})
 | 
			
		||||
 | 
			
		||||
QueueLoop:
 | 
			
		||||
	for len(queue) != 0 {
 | 
			
		||||
		// pop the first item off of the queue
 | 
			
		||||
		tq := queue[0]
 | 
			
		||||
		queue = queue[1:]
 | 
			
		||||
 | 
			
		||||
		// ignore recursive field
 | 
			
		||||
		for p := tq.fi.Parent; p != nil; p = p.Parent {
 | 
			
		||||
			if tq.fi.Field.Type == p.Field.Type {
 | 
			
		||||
				continue QueueLoop
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		nChildren := 0
 | 
			
		||||
		if tq.t.Kind() == reflect.Struct {
 | 
			
		||||
			nChildren = tq.t.NumField()
 | 
			
		||||
		}
 | 
			
		||||
		tq.fi.Children = make([]*FieldInfo, nChildren)
 | 
			
		||||
 | 
			
		||||
		// iterate through all of its fields
 | 
			
		||||
		for fieldPos := 0; fieldPos < nChildren; fieldPos++ {
 | 
			
		||||
 | 
			
		||||
			f := tq.t.Field(fieldPos)
 | 
			
		||||
 | 
			
		||||
			// parse the tag and the target name using the mapping options for this field
 | 
			
		||||
			tag, name := parseName(f, tagName, mapFunc, tagMapFunc)
 | 
			
		||||
 | 
			
		||||
			// if the name is "-", disabled via a tag, skip it
 | 
			
		||||
			if name == "-" {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			fi := FieldInfo{
 | 
			
		||||
				Field:   f,
 | 
			
		||||
				Name:    name,
 | 
			
		||||
				Zero:    reflect.New(f.Type).Elem(),
 | 
			
		||||
				Options: parseOptions(tag),
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// if the path is empty this path is just the name
 | 
			
		||||
			if tq.pp == "" {
 | 
			
		||||
				fi.Path = fi.Name
 | 
			
		||||
			} else {
 | 
			
		||||
				fi.Path = tq.pp + "." + fi.Name
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// skip unexported fields
 | 
			
		||||
			if len(f.PkgPath) != 0 && !f.Anonymous {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// bfs search of anonymous embedded structs
 | 
			
		||||
			if f.Anonymous {
 | 
			
		||||
				pp := tq.pp
 | 
			
		||||
				if tag != "" {
 | 
			
		||||
					pp = fi.Path
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				fi.Embedded = true
 | 
			
		||||
				fi.Index = apnd(tq.fi.Index, fieldPos)
 | 
			
		||||
				nChildren := 0
 | 
			
		||||
				ft := Deref(f.Type)
 | 
			
		||||
				if ft.Kind() == reflect.Struct {
 | 
			
		||||
					nChildren = ft.NumField()
 | 
			
		||||
				}
 | 
			
		||||
				fi.Children = make([]*FieldInfo, nChildren)
 | 
			
		||||
				queue = append(queue, typeQueue{Deref(f.Type), &fi, pp})
 | 
			
		||||
			} else if fi.Zero.Kind() == reflect.Struct || (fi.Zero.Kind() == reflect.Ptr && fi.Zero.Type().Elem().Kind() == reflect.Struct) {
 | 
			
		||||
				fi.Index = apnd(tq.fi.Index, fieldPos)
 | 
			
		||||
				fi.Children = make([]*FieldInfo, Deref(f.Type).NumField())
 | 
			
		||||
				queue = append(queue, typeQueue{Deref(f.Type), &fi, fi.Path})
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			fi.Index = apnd(tq.fi.Index, fieldPos)
 | 
			
		||||
			fi.Parent = tq.fi
 | 
			
		||||
			tq.fi.Children[fieldPos] = &fi
 | 
			
		||||
			m = append(m, &fi)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	flds := &StructMap{Index: m, Tree: root, Paths: map[string]*FieldInfo{}, Names: map[string]*FieldInfo{}}
 | 
			
		||||
	for _, fi := range flds.Index {
 | 
			
		||||
		// check if nothing has already been pushed with the same path
 | 
			
		||||
		// sometimes you can choose to override a type using embedded struct
 | 
			
		||||
		fld, ok := flds.Paths[fi.Path]
 | 
			
		||||
		if !ok || fld.Embedded {
 | 
			
		||||
			flds.Paths[fi.Path] = fi
 | 
			
		||||
			if fi.Name != "" && !fi.Embedded {
 | 
			
		||||
				flds.Names[fi.Path] = fi
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return flds
 | 
			
		||||
}
 | 
			
		||||
		Reference in New Issue
	
	Block a user