diff --git a/server/internal/db/application/db.go b/server/internal/db/application/db.go index a4438132..d64cdf88 100644 --- a/server/internal/db/application/db.go +++ b/server/internal/db/application/db.go @@ -165,7 +165,7 @@ func (d *dbAppImpl) GetDatabases(ed *entity.Db) []string { biz.ErrIsNilAppendErr(err, "数据库连接失败: %s") defer dbConn.Close() - _, res, err := SelectDataByDb(dbConn, getDatabasesSql) + _, res, err := SelectDataByDb(dbConn, getDatabasesSql, true) biz.ErrIsNilAppendErr(err, "获取数据库列表失败") for _, re := range res { databases = append(databases, re["dbname"].(string)) @@ -218,6 +218,67 @@ func (da *dbAppImpl) GetDbInstance(id uint64, db string) *DbInstance { return dbi } +//---------------------------------------- db instance ------------------------------------ + +// db实例 +type DbInstance struct { + Id string + Type string + ProjectId uint64 + db *sql.DB + sshTunnelMachineId uint64 +} + +// 执行查询语句 +// 依次返回 列名数组,结果map,错误 +func (d *DbInstance) SelectData(execSql string) ([]string, []map[string]interface{}, error) { + return SelectDataByDb(d.db, execSql, false) +} + +// 将查询结果映射至struct,可具体参考sqlx库 +func (d *DbInstance) SelectData2Struct(execSql string, dest interface{}) error { + return Select2StructByDb(d.db, execSql, dest) +} + +// 执行内部查询语句,不返回列名以及不限制行数 +// 依次返回 结果map,错误 +func (d *DbInstance) innerSelect(execSql string) ([]map[string]interface{}, error) { + _, res, err := SelectDataByDb(d.db, execSql, true) + return res, err +} + +// 执行 update, insert, delete,建表等sql +// 返回影响条数和错误 +func (d *DbInstance) Exec(sql string) (int64, error) { + res, err := d.db.Exec(sql) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + +// 获取数据库元信息实现接口 +func (di *DbInstance) GetMeta() DbMetadata { + dbType := di.Type + if dbType == entity.DbTypeMysql { + return &MysqlMetadata{di: di} + } + if dbType == entity.DbTypePostgres { + return &PgsqlMetadata{di: di} + } + return nil +} + +// 关闭连接 +func (d *DbInstance) Close() { + if d.db != nil { + if err := d.db.Close(); err != nil { + global.Log.Errorf("关闭数据库实例[%s]连接失败: %s", d.Id, err.Error()) + } + d.db = nil + } +} + //------------------------------------------------------------------------------ // 单次最大查询数据集 @@ -292,7 +353,28 @@ func GetDbConn(d *entity.Db, db string) (*sql.DB, error) { return DB, nil } -func SelectDataByDb(db *sql.DB, selectSql string) ([]string, []map[string]interface{}, error) { +// 获取dataSourceName +func getDsn(d *entity.Db, db string) string { + var dsn string + if d.Type == entity.DbTypeMysql { + dsn = fmt.Sprintf("%s:%s@%s(%s:%d)/%s?timeout=8s", d.Username, d.Password, d.Network, d.Host, d.Port, db) + if d.Params != "" { + dsn = fmt.Sprintf("%s&%s", dsn, d.Params) + } + return dsn + } + + if d.Type == entity.DbTypePostgres { + dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", d.Host, d.Port, d.Username, d.Password, db) + if d.Params != "" { + dsn = fmt.Sprintf("%s %s", dsn, strings.Join(strings.Split(d.Params, "&"), " ")) + } + return dsn + } + return "" +} + +func SelectDataByDb(db *sql.DB, selectSql string, isInner bool) ([]string, []map[string]interface{}, error) { rows, err := db.Query(selectSql) if err != nil { return nil, nil, err @@ -322,7 +404,10 @@ func SelectDataByDb(db *sql.DB, selectSql string) ([]string, []map[string]interf rowNum := 0 for rows.Next() { rowNum++ - biz.IsTrue(rowNum <= Max_Rows, "结果集 > 2000, 请完善条件或分页信息") + // 非内部sql,则校验返回结果数量 + if !isInner { + biz.IsTrue(rowNum <= Max_Rows, "结果集 > 2000, 请完善条件或分页信息") + } // 不Scan也会导致等待,该链接实际处于未工作的状态,然后也会导致连接数迅速达到最大 err := rows.Scan(scans...) @@ -397,10 +482,11 @@ func valueConvert(data []byte, colType *sql.ColumnType) interface{} { return stringV } -func innerSelectByDb(db *sql.DB, selectSql string) ([]map[string]interface{}, error) { +// 查询数据结果映射至struct。可参考sqlx库 +func Select2StructByDb(db *sql.DB, selectSql string, dest interface{}) error { rows, err := db.Query(selectSql) if err != nil { - return nil, err + return err } // rows对象一定要close掉,如果出错,不关掉则会很迅速的达到设置最大连接数, // 后面的链接过来直接报错或拒绝,实际上也没有起效果 @@ -409,35 +495,12 @@ func innerSelectByDb(db *sql.DB, selectSql string) ([]map[string]interface{}, er rows.Close() } }() - colTypes, _ := rows.ColumnTypes() - // 这里表示一行填充数据 - scans := make([]interface{}, len(colTypes)) - // 这里表示一行所有列的值,用[]byte表示 - vals := make([][]byte, len(colTypes)) - // 这里scans引用vals,把数据填充到[]byte里 - for k := range vals { - scans[k] = &vals[k] - } + return scanAll(rows, dest, false) +} - result := make([]map[string]interface{}, 0) - for rows.Next() { - // 不Scan也会导致等待,该链接实际处于未工作的状态,然后也会导致连接数迅速达到最大 - err := rows.Scan(scans...) - if err != nil { - return nil, err - } - // 每行数据 - rowData := make(map[string]interface{}) - // 把vals中的数据复制到row中 - for i, v := range vals { - colType := colTypes[i] - colName := colType.Name() - rowData[colName] = valueConvert(v, colType) - } - // 放入结果集 - result = append(result, rowData) - } - return result, nil +// 删除db缓存并关闭该数据库所有连接 +func CloseDb(dbId uint64, db string) { + dbCache.Delete(GetDbCacheKey(dbId, db)) } type PqSqlDialer struct { @@ -452,85 +515,7 @@ func (pd *PqSqlDialer) Dial(network, address string) (net.Conn, error) { return nil, err } } + func (pd *PqSqlDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { return pd.Dial(network, address) } - -// db实例 -type DbInstance struct { - Id string - Type string - ProjectId uint64 - db *sql.DB - sshTunnelMachineId uint64 -} - -// 执行查询语句 -// 依次返回 列名数组,结果map,错误 -func (d *DbInstance) SelectData(execSql string) ([]string, []map[string]interface{}, error) { - return SelectDataByDb(d.db, execSql) -} - -// 执行内部查询语句,不返回列名以及不限制行数 -// 依次返回 结果map,错误 -func (d *DbInstance) innerSelect(execSql string) ([]map[string]interface{}, error) { - return innerSelectByDb(d.db, execSql) -} - -// 执行 update, insert, delete,建表等sql -// 返回影响条数和错误 -func (d *DbInstance) Exec(sql string) (int64, error) { - res, err := d.db.Exec(sql) - if err != nil { - return 0, err - } - return res.RowsAffected() -} - -// 获取数据库元信息实现接口 -func (di *DbInstance) GetMeta() DbMetadata { - dbType := di.Type - if dbType == entity.DbTypeMysql { - return &MysqlMetadata{di: di} - } - if dbType == entity.DbTypePostgres { - return &PgsqlMetadata{di: di} - } - return nil -} - -// 关闭连接 -func (d *DbInstance) Close() { - if d.db != nil { - if err := d.db.Close(); err != nil { - global.Log.Errorf("关闭数据库实例[%s]连接失败: %s", d.Id, err.Error()) - } - d.db = nil - } -} - -// 获取dataSourceName -func getDsn(d *entity.Db, db string) string { - var dsn string - if d.Type == entity.DbTypeMysql { - dsn = fmt.Sprintf("%s:%s@%s(%s:%d)/%s?timeout=8s", d.Username, d.Password, d.Network, d.Host, d.Port, db) - if d.Params != "" { - dsn = fmt.Sprintf("%s&%s", dsn, d.Params) - } - return dsn - } - - if d.Type == entity.DbTypePostgres { - dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", d.Host, d.Port, d.Username, d.Password, db) - if d.Params != "" { - dsn = fmt.Sprintf("%s %s", dsn, strings.Join(strings.Split(d.Params, "&"), " ")) - } - return dsn - } - return "" -} - -// 删除db缓存并关闭该数据库所有连接 -func CloseDb(dbId uint64, db string) { - dbCache.Delete(GetDbCacheKey(dbId, db)) -} diff --git a/server/internal/db/application/meta.go b/server/internal/db/application/meta.go index 5fb36768..677388d2 100644 --- a/server/internal/db/application/meta.go +++ b/server/internal/db/application/meta.go @@ -1,9 +1,11 @@ package application -// -----------------------------------元数据------------------------------------------- +// -----------------------------------元数据接口定义------------------------------------------ // 数据库元信息接口(表、列等元信息) -// 所有数据查出来直接用map接收,注意map的key需要统一 +// 所有数据查出来直接用map接收,注意不同数据库实现该接口返回的map中的key需要统一. +// 即: 使用别名统一即可。如table_name AS tableName type DbMetadata interface { + // 获取表基础元信息 // 表名: tableName, 备注: tableComment GetTables() []map[string]interface{} @@ -24,6 +26,3 @@ type DbMetadata interface { // 获取建表ddl GetCreateTableDdl(tableName string) []map[string]interface{} } - -// 默认每次查询列元信息数量 -const DEFAULT_COLUMN_SIZE = 2000 diff --git a/server/internal/db/application/mysql_meta.go b/server/internal/db/application/mysql_meta.go index f1b43011..388e59c6 100644 --- a/server/internal/db/application/mysql_meta.go +++ b/server/internal/db/application/mysql_meta.go @@ -8,9 +8,7 @@ import ( // ---------------------------------- mysql元数据 ----------------------------------- const ( // mysql 表信息元数据 - MYSQL_TABLE_MA = `SELECT table_name tableName, engine, table_comment tableComment, - create_time createTime from information_schema.tables - WHERE table_schema = (SELECT database())` + MYSQL_TABLE_MA = `SELECT table_name tableName, table_comment tableComment from information_schema.tables WHERE table_schema = (SELECT database())` // mysql 表信息 MYSQL_TABLE_INFO = `SELECT table_name tableName, table_comment tableComment, table_rows tableRows, diff --git a/server/internal/db/application/pgsql_meta.go b/server/internal/db/application/pgsql_meta.go index b58f7e3b..60f3d1de 100644 --- a/server/internal/db/application/pgsql_meta.go +++ b/server/internal/db/application/pgsql_meta.go @@ -88,5 +88,6 @@ func (pm *PgsqlMetadata) GetTableIndex(tableName string) []map[string]interface{ // 获取建表ddl func (mm *PgsqlMetadata) GetCreateTableDdl(tableName string) []map[string]interface{} { + biz.IsTrue(tableName == "", "暂不支持获取pgsql建表DDL") return nil } diff --git a/server/internal/db/application/sqlx.go b/server/internal/db/application/sqlx.go new file mode 100644 index 00000000..f46c4432 --- /dev/null +++ b/server/internal/db/application/sqlx.go @@ -0,0 +1,630 @@ +package application + +import ( + "database/sql" + "errors" + "fmt" + "reflect" + "runtime" + "strings" + "sync" +) + +// 将结果scan至结构体,copy至 sqlx库: https://github.com/jmoiron/sqlx +func scanAll(rows *sql.Rows, dest interface{}, structOnly bool) error { + var v, vp reflect.Value + + value := reflect.ValueOf(dest) + + // json.Unmarshal returns errors for these + if value.Kind() != reflect.Ptr { + return errors.New("must pass a pointer, not a value, to StructScan destination") + } + if value.IsNil() { + return errors.New("nil pointer passed to StructScan destination") + } + direct := reflect.Indirect(value) + + slice, err := baseType(value.Type(), reflect.Slice) + if err != nil { + return err + } + direct.SetLen(0) + + isPtr := slice.Elem().Kind() == reflect.Ptr + base := Deref(slice.Elem()) + scannable := isScannable(base) + + if structOnly && scannable { + return structOnlyError(base) + } + + columns, err := rows.Columns() + if err != nil { + return err + } + + // if it's a base type make sure it only has 1 column; if not return an error + if scannable && len(columns) > 1 { + return fmt.Errorf("non-struct dest type %s with >1 columns (%d)", base.Kind(), len(columns)) + } + + if !scannable { + var values []interface{} + var m *Mapper = mapper() + + fields := m.TraversalsByName(base, columns) + // if we are not unsafe and are missing fields, return an error + if f, err := missingFields(fields); err != nil { + return fmt.Errorf("missing destination name %s in %T", columns[f], dest) + } + values = make([]interface{}, len(columns)) + + for rows.Next() { + // create a new struct type (which returns PtrTo) and indirect it + vp = reflect.New(base) + v = reflect.Indirect(vp) + + err = fieldsByTraversal(v, fields, values, true) + if err != nil { + return err + } + + // scan into the struct field pointers and append to our results + err = rows.Scan(values...) + if err != nil { + return err + } + + if isPtr { + direct.Set(reflect.Append(direct, vp)) + } else { + direct.Set(reflect.Append(direct, v)) + } + } + } else { + for rows.Next() { + vp = reflect.New(base) + err = rows.Scan(vp.Interface()) + if err != nil { + return err + } + // append + if isPtr { + direct.Set(reflect.Append(direct, vp)) + } else { + direct.Set(reflect.Append(direct, reflect.Indirect(vp))) + } + } + } + + return rows.Err() +} + +func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) { + t = Deref(t) + if t.Kind() != expected { + return nil, fmt.Errorf("expected %s but got %s", expected, t.Kind()) + } + return t, nil +} + +// structOnlyError returns an error appropriate for type when a non-scannable +// struct is expected but something else is given +func structOnlyError(t reflect.Type) error { + isStruct := t.Kind() == reflect.Struct + isScanner := reflect.PtrTo(t).Implements(_scannerInterface) + if !isStruct { + return fmt.Errorf("expected %s but got %s", reflect.Struct, t.Kind()) + } + if isScanner { + return fmt.Errorf("structscan expects a struct dest but the provided struct type %s implements scanner", t.Name()) + } + return fmt.Errorf("expected a struct, but struct %s has no exported fields", t.Name()) +} + +var _scannerInterface = reflect.TypeOf((*sql.Scanner)(nil)).Elem() + +func isScannable(t reflect.Type) bool { + if reflect.PtrTo(t).Implements(_scannerInterface) { + return true + } + if t.Kind() != reflect.Struct { + return true + } + + // it's not important that we use the right mapper for this particular object, + // we're only concerned on how many exported fields this struct has + return len(mapper().TypeMap(t).Index) == 0 +} + +var NameMapper = strings.ToLower +var origMapper = reflect.ValueOf(NameMapper) + +// Rather than creating on init, this is created when necessary so that +// importers have time to customize the NameMapper. +var mpr *Mapper + +// mprMu protects mpr. +var mprMu sync.Mutex + +// mapper returns a valid mapper using the configured NameMapper func. +func mapper() *Mapper { + mprMu.Lock() + defer mprMu.Unlock() + + if mpr == nil { + mpr = NewMapperFunc("db", NameMapper) + } else if origMapper != reflect.ValueOf(NameMapper) { + // if NameMapper has changed, create a new mapper + mpr = NewMapperFunc("db", NameMapper) + origMapper = reflect.ValueOf(NameMapper) + } + return mpr +} + +func missingFields(transversals [][]int) (field int, err error) { + for i, t := range transversals { + if len(t) == 0 { + return i, errors.New("missing field") + } + } + return 0, nil +} + +// fieldsByName fills a values interface with fields from the passed value based +// on the traversals in int. If ptrs is true, return addresses instead of values. +// We write this instead of using FieldsByName to save allocations and map lookups +// when iterating over many rows. Empty traversals will get an interface pointer. +// Because of the necessity of requesting ptrs or values, it's considered a bit too +// specialized for inclusion in reflectx itself. +func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error { + v = reflect.Indirect(v) + if v.Kind() != reflect.Struct { + return errors.New("argument not a struct") + } + + for i, traversal := range traversals { + if len(traversal) == 0 { + values[i] = new(interface{}) + continue + } + f := FieldByIndexes(v, traversal) + if ptrs { + values[i] = f.Addr().Interface() + } else { + values[i] = f.Interface() + } + } + return nil +} + +// A FieldInfo is metadata for a struct field. +type FieldInfo struct { + Index []int + Path string + Field reflect.StructField + Zero reflect.Value + Name string + Options map[string]string + Embedded bool + Children []*FieldInfo + Parent *FieldInfo +} + +// A StructMap is an index of field metadata for a struct. +type StructMap struct { + Tree *FieldInfo + Index []*FieldInfo + Paths map[string]*FieldInfo + Names map[string]*FieldInfo +} + +// GetByPath returns a *FieldInfo for a given string path. +func (f StructMap) GetByPath(path string) *FieldInfo { + return f.Paths[path] +} + +// GetByTraversal returns a *FieldInfo for a given integer path. It is +// analogous to reflect.FieldByIndex, but using the cached traversal +// rather than re-executing the reflect machinery each time. +func (f StructMap) GetByTraversal(index []int) *FieldInfo { + if len(index) == 0 { + return nil + } + + tree := f.Tree + for _, i := range index { + if i >= len(tree.Children) || tree.Children[i] == nil { + return nil + } + tree = tree.Children[i] + } + return tree +} + +// Mapper is a general purpose mapper of names to struct fields. A Mapper +// behaves like most marshallers in the standard library, obeying a field tag +// for name mapping but also providing a basic transform function. +type Mapper struct { + cache map[reflect.Type]*StructMap + tagName string + tagMapFunc func(string) string + mapFunc func(string) string + mutex sync.Mutex +} + +// NewMapper returns a new mapper using the tagName as its struct field tag. +// If tagName is the empty string, it is ignored. +func NewMapper(tagName string) *Mapper { + return &Mapper{ + cache: make(map[reflect.Type]*StructMap), + tagName: tagName, + } +} + +// NewMapperTagFunc returns a new mapper which contains a mapper for field names +// AND a mapper for tag values. This is useful for tags like json which can +// have values like "name,omitempty". +func NewMapperTagFunc(tagName string, mapFunc, tagMapFunc func(string) string) *Mapper { + return &Mapper{ + cache: make(map[reflect.Type]*StructMap), + tagName: tagName, + mapFunc: mapFunc, + tagMapFunc: tagMapFunc, + } +} + +// NewMapperFunc returns a new mapper which optionally obeys a field tag and +// a struct field name mapper func given by f. Tags will take precedence, but +// for any other field, the mapped name will be f(field.Name) +func NewMapperFunc(tagName string, f func(string) string) *Mapper { + return &Mapper{ + cache: make(map[reflect.Type]*StructMap), + tagName: tagName, + mapFunc: f, + } +} + +// TypeMap returns a mapping of field strings to int slices representing +// the traversal down the struct to reach the field. +func (m *Mapper) TypeMap(t reflect.Type) *StructMap { + m.mutex.Lock() + mapping, ok := m.cache[t] + if !ok { + mapping = getMapping(t, m.tagName, m.mapFunc, m.tagMapFunc) + m.cache[t] = mapping + } + m.mutex.Unlock() + return mapping +} + +// FieldMap returns the mapper's mapping of field names to reflect values. Panics +// if v's Kind is not Struct, or v is not Indirectable to a struct kind. +func (m *Mapper) FieldMap(v reflect.Value) map[string]reflect.Value { + v = reflect.Indirect(v) + mustBe(v, reflect.Struct) + + r := map[string]reflect.Value{} + tm := m.TypeMap(v.Type()) + for tagName, fi := range tm.Names { + r[tagName] = FieldByIndexes(v, fi.Index) + } + return r +} + +// FieldByName returns a field by its mapped name as a reflect.Value. +// Panics if v's Kind is not Struct or v is not Indirectable to a struct Kind. +// Returns zero Value if the name is not found. +func (m *Mapper) FieldByName(v reflect.Value, name string) reflect.Value { + v = reflect.Indirect(v) + mustBe(v, reflect.Struct) + + tm := m.TypeMap(v.Type()) + fi, ok := tm.Names[name] + if !ok { + return v + } + return FieldByIndexes(v, fi.Index) +} + +// FieldsByName returns a slice of values corresponding to the slice of names +// for the value. Panics if v's Kind is not Struct or v is not Indirectable +// to a struct Kind. Returns zero Value for each name not found. +func (m *Mapper) FieldsByName(v reflect.Value, names []string) []reflect.Value { + v = reflect.Indirect(v) + mustBe(v, reflect.Struct) + + tm := m.TypeMap(v.Type()) + vals := make([]reflect.Value, 0, len(names)) + for _, name := range names { + fi, ok := tm.Names[name] + if !ok { + vals = append(vals, *new(reflect.Value)) + } else { + vals = append(vals, FieldByIndexes(v, fi.Index)) + } + } + return vals +} + +// TraversalsByName returns a slice of int slices which represent the struct +// traversals for each mapped name. Panics if t is not a struct or Indirectable +// to a struct. Returns empty int slice for each name not found. +func (m *Mapper) TraversalsByName(t reflect.Type, names []string) [][]int { + r := make([][]int, 0, len(names)) + m.TraversalsByNameFunc(t, names, func(_ int, i []int) error { + if i == nil { + r = append(r, []int{}) + } else { + r = append(r, i) + } + + return nil + }) + return r +} + +// TraversalsByNameFunc traverses the mapped names and calls fn with the index of +// each name and the struct traversal represented by that name. Panics if t is not +// a struct or Indirectable to a struct. Returns the first error returned by fn or nil. +func (m *Mapper) TraversalsByNameFunc(t reflect.Type, names []string, fn func(int, []int) error) error { + t = Deref(t) + mustBe(t, reflect.Struct) + tm := m.TypeMap(t) + for i, name := range names { + fi, ok := tm.Names[name] + if !ok { + if err := fn(i, nil); err != nil { + return err + } + } else { + if err := fn(i, fi.Index); err != nil { + return err + } + } + } + return nil +} + +// FieldByIndexes returns a value for the field given by the struct traversal +// for the given value. +func FieldByIndexes(v reflect.Value, indexes []int) reflect.Value { + for _, i := range indexes { + v = reflect.Indirect(v).Field(i) + // if this is a pointer and it's nil, allocate a new value and set it + if v.Kind() == reflect.Ptr && v.IsNil() { + alloc := reflect.New(Deref(v.Type())) + v.Set(alloc) + } + if v.Kind() == reflect.Map && v.IsNil() { + v.Set(reflect.MakeMap(v.Type())) + } + } + return v +} + +// FieldByIndexesReadOnly returns a value for a particular struct traversal, +// but is not concerned with allocating nil pointers because the value is +// going to be used for reading and not setting. +func FieldByIndexesReadOnly(v reflect.Value, indexes []int) reflect.Value { + for _, i := range indexes { + v = reflect.Indirect(v).Field(i) + } + return v +} + +// Deref is Indirect for reflect.Types +func Deref(t reflect.Type) reflect.Type { + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + return t +} + +// -- helpers & utilities -- + +type kinder interface { + Kind() reflect.Kind +} + +// mustBe checks a value against a kind, panicing with a reflect.ValueError +// if the kind isn't that which is required. +func mustBe(v kinder, expected reflect.Kind) { + if k := v.Kind(); k != expected { + panic(&reflect.ValueError{Method: methodName(), Kind: k}) + } +} + +// methodName returns the caller of the function calling methodName +func methodName() string { + pc, _, _, _ := runtime.Caller(2) + f := runtime.FuncForPC(pc) + if f == nil { + return "unknown method" + } + return f.Name() +} + +type typeQueue struct { + t reflect.Type + fi *FieldInfo + pp string // Parent path +} + +// A copying append that creates a new slice each time. +func apnd(is []int, i int) []int { + x := make([]int, len(is)+1) + copy(x, is) + x[len(x)-1] = i + return x +} + +type mapf func(string) string + +// parseName parses the tag and the target name for the given field using +// the tagName (eg 'json' for `json:"foo"` tags), mapFunc for mapping the +// field's name to a target name, and tagMapFunc for mapping the tag to +// a target name. +func parseName(field reflect.StructField, tagName string, mapFunc, tagMapFunc mapf) (tag, fieldName string) { + // first, set the fieldName to the field's name + fieldName = field.Name + // if a mapFunc is set, use that to override the fieldName + if mapFunc != nil { + fieldName = mapFunc(fieldName) + } + + // if there's no tag to look for, return the field name + if tagName == "" { + return "", fieldName + } + + // if this tag is not set using the normal convention in the tag, + // then return the fieldname.. this check is done because according + // to the reflect documentation: + // If the tag does not have the conventional format, + // the value returned by Get is unspecified. + // which doesn't sound great. + if !strings.Contains(string(field.Tag), tagName+":") { + return "", fieldName + } + + // at this point we're fairly sure that we have a tag, so lets pull it out + tag = field.Tag.Get(tagName) + + // if we have a mapper function, call it on the whole tag + // XXX: this is a change from the old version, which pulled out the name + // before the tagMapFunc could be run, but I think this is the right way + if tagMapFunc != nil { + tag = tagMapFunc(tag) + } + + // finally, split the options from the name + parts := strings.Split(tag, ",") + fieldName = parts[0] + + return tag, fieldName +} + +// parseOptions parses options out of a tag string, skipping the name +func parseOptions(tag string) map[string]string { + parts := strings.Split(tag, ",") + options := make(map[string]string, len(parts)) + if len(parts) > 1 { + for _, opt := range parts[1:] { + // short circuit potentially expensive split op + if strings.Contains(opt, "=") { + kv := strings.Split(opt, "=") + options[kv[0]] = kv[1] + continue + } + options[opt] = "" + } + } + return options +} + +// getMapping returns a mapping for the t type, using the tagName, mapFunc and +// tagMapFunc to determine the canonical names of fields. +func getMapping(t reflect.Type, tagName string, mapFunc, tagMapFunc mapf) *StructMap { + m := []*FieldInfo{} + + root := &FieldInfo{} + queue := []typeQueue{} + queue = append(queue, typeQueue{Deref(t), root, ""}) + +QueueLoop: + for len(queue) != 0 { + // pop the first item off of the queue + tq := queue[0] + queue = queue[1:] + + // ignore recursive field + for p := tq.fi.Parent; p != nil; p = p.Parent { + if tq.fi.Field.Type == p.Field.Type { + continue QueueLoop + } + } + + nChildren := 0 + if tq.t.Kind() == reflect.Struct { + nChildren = tq.t.NumField() + } + tq.fi.Children = make([]*FieldInfo, nChildren) + + // iterate through all of its fields + for fieldPos := 0; fieldPos < nChildren; fieldPos++ { + + f := tq.t.Field(fieldPos) + + // parse the tag and the target name using the mapping options for this field + tag, name := parseName(f, tagName, mapFunc, tagMapFunc) + + // if the name is "-", disabled via a tag, skip it + if name == "-" { + continue + } + + fi := FieldInfo{ + Field: f, + Name: name, + Zero: reflect.New(f.Type).Elem(), + Options: parseOptions(tag), + } + + // if the path is empty this path is just the name + if tq.pp == "" { + fi.Path = fi.Name + } else { + fi.Path = tq.pp + "." + fi.Name + } + + // skip unexported fields + if len(f.PkgPath) != 0 && !f.Anonymous { + continue + } + + // bfs search of anonymous embedded structs + if f.Anonymous { + pp := tq.pp + if tag != "" { + pp = fi.Path + } + + fi.Embedded = true + fi.Index = apnd(tq.fi.Index, fieldPos) + nChildren := 0 + ft := Deref(f.Type) + if ft.Kind() == reflect.Struct { + nChildren = ft.NumField() + } + fi.Children = make([]*FieldInfo, nChildren) + queue = append(queue, typeQueue{Deref(f.Type), &fi, pp}) + } else if fi.Zero.Kind() == reflect.Struct || (fi.Zero.Kind() == reflect.Ptr && fi.Zero.Type().Elem().Kind() == reflect.Struct) { + fi.Index = apnd(tq.fi.Index, fieldPos) + fi.Children = make([]*FieldInfo, Deref(f.Type).NumField()) + queue = append(queue, typeQueue{Deref(f.Type), &fi, fi.Path}) + } + + fi.Index = apnd(tq.fi.Index, fieldPos) + fi.Parent = tq.fi + tq.fi.Children[fieldPos] = &fi + m = append(m, &fi) + } + } + + flds := &StructMap{Index: m, Tree: root, Paths: map[string]*FieldInfo{}, Names: map[string]*FieldInfo{}} + for _, fi := range flds.Index { + // check if nothing has already been pushed with the same path + // sometimes you can choose to override a type using embedded struct + fld, ok := flds.Paths[fi.Path] + if !ok || fld.Embedded { + flds.Paths[fi.Path] = fi + if fi.Name != "" && !fi.Embedded { + flds.Names[fi.Path] = fi + } + } + } + + return flds +}