2024-01-21 22:52:20 +08:00
package ioc
import (
2024-01-26 17:17:26 +08:00
"context"
2024-01-21 22:52:20 +08:00
"errors"
"fmt"
2024-12-16 23:29:18 +08:00
"mayfly-go/pkg/contextx"
2024-01-21 22:52:20 +08:00
"mayfly-go/pkg/logx"
2024-01-26 17:17:26 +08:00
"mayfly-go/pkg/utils/collx"
2024-01-21 22:52:20 +08:00
"mayfly-go/pkg/utils/structx"
"reflect"
"strings"
2024-01-26 17:17:26 +08:00
"sync"
"golang.org/x/sync/errgroup"
2024-01-21 22:52:20 +08:00
)
2024-12-13 12:15:24 +08:00
// Conventions recognized by the container when wiring dependencies.
const (
	// InjectTag is the struct tag key that marks a field for injection, e.g. `inject:"userApp"`.
	InjectTag = "inject"

	// InjectMethodPrefix marks setter-style injection methods: a method InjectFoo(x) receives
	// the component named "Foo".
	InjectMethodPrefix = "Inject"

	// ByTypeComponentName, when used as the inject tag value, requests injection by the
	// field's type instead of by component name.
	ByTypeComponentName = "T"
)
2024-01-21 22:52:20 +08:00
// 容器
type Container struct {
2024-01-26 17:17:26 +08:00
mu sync . RWMutex
2024-01-21 22:52:20 +08:00
components map [ string ] * Component
}
func NewContainer ( ) * Container {
return & Container {
components : make ( map [ string ] * Component ) ,
}
}
// 注册实例至实例容器
func ( c * Container ) Register ( bean any , opts ... ComponentOption ) {
2024-01-26 17:17:26 +08:00
c . mu . Lock ( )
defer c . mu . Unlock ( )
2024-01-21 22:52:20 +08:00
component := NewComponent ( bean , opts ... )
componentName := component . Name
2024-12-16 23:29:18 +08:00
indirectCType := structx . IndirectType ( component . GetType ( ) )
// 组件名为空,则取组件类型`包名路径.名称`作为组件名
2024-01-21 22:52:20 +08:00
if componentName == "" {
2024-12-16 23:29:18 +08:00
componentName = fmt . Sprintf ( "%s.%s" , indirectCType . PkgPath ( ) , indirectCType . Name ( ) )
2024-01-21 22:52:20 +08:00
component . Name = componentName
}
2024-01-26 17:17:26 +08:00
if _ , ok := c . components [ componentName ] ; ok {
2024-12-13 12:15:24 +08:00
logx . Warnf ( "the component name [%s] has been registered to the container. Repeat the registration..." , componentName )
2024-01-21 22:52:20 +08:00
}
2024-12-13 12:15:24 +08:00
logx . Debugf ( "ioc register : %s = %s.%s" , componentName , indirectCType . PkgPath ( ) , indirectCType . Name ( ) )
2024-01-21 22:52:20 +08:00
c . components [ componentName ] = component
}
2024-12-13 12:15:24 +08:00
// Inject 注册对象实例的字段含有注入标签或者Setter方法, 则注入对应组件实例
2024-01-21 22:52:20 +08:00
func ( c * Container ) Inject ( obj any ) error {
objValue := reflect . ValueOf ( obj )
2024-01-22 11:35:28 +08:00
if structx . Indirect ( objValue ) . Kind ( ) != reflect . Struct {
return nil
}
2024-12-16 23:29:18 +08:00
ctx := contextx . NewTraceId ( )
if err := c . injectWithField ( ctx , objValue ) ; err != nil {
2024-01-21 22:52:20 +08:00
return err
}
2024-12-16 23:29:18 +08:00
if err := c . injectWithMethod ( ctx , objValue ) ; err != nil {
2024-01-21 22:52:20 +08:00
return err
}
return nil
}
// 对所有组件实例执行Inject。即为实例字段注入依赖的组件实例
func ( c * Container ) InjectComponents ( ) error {
2025-04-18 22:07:37 +08:00
componentsGroups := collx . ArraySplit [ * Component ] ( collx . MapValues ( c . components ) , 10 )
2024-01-26 17:17:26 +08:00
ctx , cancel := context . WithCancel ( context . Background ( ) )
defer cancel ( )
errGroup , _ := errgroup . WithContext ( ctx )
for _ , components := range componentsGroups {
errGroup . Go ( func ( ) error {
for _ , v := range components {
if err := c . Inject ( v . Value ) ; err != nil {
cancel ( ) // 取消所有协程的执行
return err
}
}
return nil
} )
2024-01-21 22:52:20 +08:00
}
2024-01-26 17:17:26 +08:00
if err := errGroup . Wait ( ) ; err != nil {
return err
}
2024-01-21 22:52:20 +08:00
return nil
}
2024-12-16 23:29:18 +08:00
// Get 根据组件实例名,获取对应实例信息
2024-01-21 22:52:20 +08:00
func ( c * Container ) Get ( name string ) ( any , error ) {
2024-12-16 23:29:18 +08:00
if comp , err := c . GetComponent ( name ) ; err == nil {
return comp . Value , nil
} else {
return nil , err
}
}
// GetComponent 根据组件名,获取对应组件信息
func ( c * Container ) GetComponent ( name string ) ( * Component , error ) {
2024-01-26 17:17:26 +08:00
c . mu . RLock ( )
defer c . mu . RUnlock ( )
2024-01-21 22:52:20 +08:00
component , ok := c . components [ name ]
if ! ok {
return nil , errors . New ( "component not found: " + name )
}
2024-12-16 23:29:18 +08:00
return component , nil
2024-01-21 22:52:20 +08:00
}
2024-12-13 12:15:24 +08:00
// GetByType 根据组件实例类型获取组件实例
func ( c * Container ) GetByType ( fieldType reflect . Type ) ( any , error ) {
2024-12-16 23:29:18 +08:00
if comp , err := c . GetComponentByType ( fieldType ) ; err == nil {
return comp . Value , nil
} else {
return nil , err
}
}
// GetComponentByType 根据组件实例类型获取组件信息
func ( c * Container ) GetComponentByType ( fieldType reflect . Type ) ( * Component , error ) {
2024-12-13 12:15:24 +08:00
c . mu . RLock ( )
defer c . mu . RUnlock ( )
for _ , component := range c . components {
2024-12-16 23:29:18 +08:00
if component . GetType ( ) . AssignableTo ( fieldType ) {
return component , nil
2024-12-13 12:15:24 +08:00
}
}
return nil , errors . New ( "component type not found: " + fmt . Sprintf ( "%s.%s" , fieldType . PkgPath ( ) , fieldType . Name ( ) ) )
}
2024-12-16 23:29:18 +08:00
// GetBeansByType returns the instances (Component.Value) of every component whose type is
// assignable to fieldType, in the order GetComponentsByType yields them.
func (c *Container) GetBeansByType(fieldType reflect.Type) []any {
	matched := c.GetComponentsByType(fieldType)
	toValue := func(comp *Component) any { return comp.Value }
	return collx.ArrayMap(matched, toValue)
}
// GetComponentsByType returns every registered component whose type is assignable to
// fieldType. The result is never nil; it is an empty slice when nothing matches.
// Order follows map iteration and is therefore unspecified.
func (c *Container) GetComponentsByType(fieldType reflect.Type) []*Component {
	matched := []*Component{}

	c.mu.RLock()
	defer c.mu.RUnlock()
	for _, comp := range c.components {
		if !comp.GetType().AssignableTo(fieldType) {
			continue
		}
		matched = append(matched, comp)
	}
	return matched
}
2025-04-18 22:07:37 +08:00
// injectWithField 根据实例字段的inject:"xxx"标签进行依赖注入
2024-12-16 23:29:18 +08:00
func ( c * Container ) injectWithField ( context context . Context , objValue reflect . Value ) error {
2024-01-21 22:52:20 +08:00
objValue = structx . Indirect ( objValue )
objType := objValue . Type ( )
2024-12-16 23:29:18 +08:00
logx . DebugfContext ( context , "start ioc inject with field: %s.%s" , objType . PkgPath ( ) , objType . Name ( ) )
2024-12-13 12:15:24 +08:00
2024-01-21 22:52:20 +08:00
for i := 0 ; i < objType . NumField ( ) ; i ++ {
field := objType . Field ( i )
2024-12-13 12:15:24 +08:00
fieldValue := objValue . Field ( i )
2024-01-21 22:52:20 +08:00
2024-12-13 12:15:24 +08:00
// 检查字段是否是通过组合包含在当前结构体中的,即嵌套结构体
if field . Anonymous && structx . IndirectType ( field . Type ) . Kind ( ) == reflect . Struct {
2024-12-16 23:29:18 +08:00
c . injectWithField ( context , fieldValue )
2024-01-21 22:52:20 +08:00
continue
}
2024-12-13 12:15:24 +08:00
componentName , ok := field . Tag . Lookup ( InjectTag )
if ! ok {
continue
2024-01-21 22:52:20 +08:00
}
2024-12-13 12:15:24 +08:00
// 如果组件名为指定的根据类型注入值,则根据类型注入
if componentName == ByTypeComponentName {
2024-12-16 23:29:18 +08:00
if err := c . injectByType ( context , objType , field , fieldValue ) ; err != nil {
2024-12-13 12:15:24 +08:00
return err
}
continue
2024-01-21 22:52:20 +08:00
}
2024-12-16 23:29:18 +08:00
if err := c . injectByName ( context , objType , field , fieldValue , componentName ) ; err != nil {
2024-12-13 12:15:24 +08:00
return err
2024-01-21 22:52:20 +08:00
}
2024-12-13 12:15:24 +08:00
}
return nil
}
// injectByName 根据实例组件名进行依赖注入
2024-12-16 23:29:18 +08:00
func ( c * Container ) injectByName ( context context . Context , structType reflect . Type , field reflect . StructField , fieldValue reflect . Value , componentName string ) error {
2024-12-13 12:15:24 +08:00
// inject tag字段名为空则默认为字段名
if componentName == "" {
componentName = field . Name
}
2024-12-16 23:29:18 +08:00
injectInfo := fmt . Sprintf ( "ioc field inject by name => [%s -> %s.%s#%s]" , componentName , structType . PkgPath ( ) , structType . Name ( ) , field . Name )
2024-12-13 12:15:24 +08:00
2024-12-16 23:29:18 +08:00
component , err := c . GetComponent ( componentName )
2024-12-13 12:15:24 +08:00
if err != nil {
return fmt . Errorf ( "%s error: %s" , injectInfo , err . Error ( ) )
}
// 判断字段类型与需要注入的组件类型是否为可赋值关系
2024-12-16 23:29:18 +08:00
componentType := component . GetType ( )
2024-12-13 12:15:24 +08:00
if ! componentType . AssignableTo ( field . Type ) {
indirectComponentType := structx . IndirectType ( componentType )
return fmt . Errorf ( "%s error: injection types are inconsistent(Expected type -> %s.%s, Component type -> %s.%s)" , injectInfo , field . Type . PkgPath ( ) , field . Type . Name ( ) , indirectComponentType . PkgPath ( ) , indirectComponentType . Name ( ) )
}
2025-04-18 22:07:37 +08:00
logx . DebugfContext ( context , "ioc field inject by name => [%s (%s) -> %s.%s#%s]" , componentName , getComponentValueDesc ( componentType ) , structType . PkgPath ( ) , structType . Name ( ) , field . Name )
2024-12-16 23:29:18 +08:00
if err := setFieldValue ( fieldValue , component . Value ) ; err != nil {
2024-12-13 12:15:24 +08:00
return fmt . Errorf ( "%s error: %s" , injectInfo , err . Error ( ) )
}
2024-01-23 19:30:28 +08:00
2024-12-13 12:15:24 +08:00
return nil
}
// injectByType 根据实例类型进行依赖注入
2024-12-16 23:29:18 +08:00
func ( c * Container ) injectByType ( context context . Context , structType reflect . Type , field reflect . StructField , fieldValue reflect . Value ) error {
2024-12-13 12:15:24 +08:00
fieldType := field . Type
2024-12-16 23:29:18 +08:00
injectInfo := fmt . Sprintf ( "ioc field inject by type => [%s.%s -> %s.%s#%s]" , fieldType . PkgPath ( ) , fieldType . Name ( ) , structType . PkgPath ( ) , structType . Name ( ) , field . Name )
2024-12-13 12:15:24 +08:00
2024-12-16 23:29:18 +08:00
component , err := c . GetComponentByType ( fieldType )
2024-12-13 12:15:24 +08:00
if err != nil {
return fmt . Errorf ( "%s error: %s" , injectInfo , err . Error ( ) )
}
2025-04-18 22:07:37 +08:00
logx . DebugfContext ( context , "ioc field inject by type => [%s.%s (%s) -> %s.%s#%s]" , fieldType . PkgPath ( ) , fieldType . Name ( ) , getComponentValueDesc ( component . GetType ( ) ) , structType . PkgPath ( ) , structType . Name ( ) , field . Name )
2024-12-16 23:29:18 +08:00
if err := setFieldValue ( fieldValue , component . Value ) ; err != nil {
2024-12-13 12:15:24 +08:00
return fmt . Errorf ( "%s error: %s" , injectInfo , err . Error ( ) )
2024-01-21 22:52:20 +08:00
}
return nil
}
// 根据实例的Inject方法进行依赖注入
2024-12-16 23:29:18 +08:00
func ( c * Container ) injectWithMethod ( context context . Context , objValue reflect . Value ) error {
2024-01-21 22:52:20 +08:00
objType := objValue . Type ( )
for i := 0 ; i < objType . NumMethod ( ) ; i ++ {
method := objType . Method ( i )
methodName := method . Name
2025-04-18 22:07:37 +08:00
// 不是以指定方法名前缀开头的函数,则默认跳过
if ! strings . HasPrefix ( methodName , InjectMethodPrefix ) {
2024-01-21 22:52:20 +08:00
continue
}
// 获取组件名, InjectTestApp -> TestApp
componentName := methodName [ 6 : ]
injectInfo := fmt . Sprintf ( "ioc method inject [%s.%s#%s(%s)]" , objType . Elem ( ) . PkgPath ( ) , objType . Elem ( ) . Name ( ) , methodName , componentName )
2024-12-16 23:29:18 +08:00
logx . DebugfContext ( context , injectInfo )
2024-01-21 22:52:20 +08:00
if method . Type . NumIn ( ) != 2 {
2024-12-16 23:29:18 +08:00
logx . WarnfContext ( context , "%s error: the method cannot be injected if it does not have one parameter" , injectInfo )
2024-01-21 22:52:20 +08:00
continue
}
component , err := c . Get ( componentName )
if err != nil {
return fmt . Errorf ( "%s error: %s" , injectInfo , err . Error ( ) )
}
componentType := reflect . TypeOf ( component )
// 期望的组件类型,即参数入参类型
expectedComponentType := method . Type . In ( 1 )
if ! componentType . AssignableTo ( expectedComponentType ) {
componentType = structx . IndirectType ( componentType )
2024-12-13 12:15:24 +08:00
return fmt . Errorf ( "%s error: injection types are inconsistent(Expected type -> %s.%s, Component type -> %s.%s)" , injectInfo , expectedComponentType . PkgPath ( ) , expectedComponentType . Name ( ) , componentType . PkgPath ( ) , componentType . Name ( ) )
2024-01-21 22:52:20 +08:00
}
method . Func . Call ( [ ] reflect . Value { objValue , reflect . ValueOf ( component ) } )
}
return nil
}
2024-12-13 12:15:24 +08:00
// setFieldValue assigns component to the struct field represented by fieldValue.
//
// Settable (exported) fields are assigned directly. reflect refuses to Set unexported
// fields, so for those the field's address is re-wrapped with reflect.NewAt. That yields a
// settable view of the same memory, which requires the field to be addressable, i.e. reached
// through a pointer to the enclosing struct.
//
// When fieldValue is the zero reflect.Value or is not addressable, an error is returned
// instead of a panic from Type()/Addr().
func setFieldValue(fieldValue reflect.Value, component any) error {
	if !fieldValue.IsValid() {
		return errors.New("the field is invalid or a non-exportable type")
	}
	if !fieldValue.CanSet() {
		// Not directly settable (typically an unexported field): go through its address.
		if !fieldValue.CanAddr() {
			return errors.New("the field is not addressable; inject into a pointer to the struct")
		}
		fieldValue = reflect.NewAt(fieldValue.Type(), fieldValue.Addr().UnsafePointer()).Elem()
		if !fieldValue.CanSet() {
			return errors.New("the field is invalid or a non-exportable type")
		}
	}
	fieldValue.Set(reflect.ValueOf(component))
	return nil
}
2024-12-16 23:29:18 +08:00
// getComponentValueDesc renders a type for log and error messages as "pkgPath.Name".
// Pointer types are first dereferenced with structx.IndirectType and shown with a leading
// "*", e.g. "*pkgPath.Name".
func getComponentValueDesc(componentValueType reflect.Type) string {
	if componentValueType.Kind() != reflect.Ptr {
		return fmt.Sprintf("%s.%s", componentValueType.PkgPath(), componentValueType.Name())
	}
	elem := structx.IndirectType(componentValueType)
	return fmt.Sprintf("*%s.%s", elem.PkgPath(), elem.Name())
}