mirror of
https://gitee.com/dromara/mayfly-go
synced 2025-11-04 00:10:25 +08:00
320 lines
9.9 KiB
Go
320 lines
9.9 KiB
Go
package ioc
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"fmt"
|
||
"mayfly-go/pkg/contextx"
|
||
"mayfly-go/pkg/logx"
|
||
"mayfly-go/pkg/utils/collx"
|
||
"mayfly-go/pkg/utils/structx"
|
||
"reflect"
|
||
"strings"
|
||
"sync"
|
||
|
||
"golang.org/x/sync/errgroup"
|
||
)
|
||
|
||
// Markers the container looks for when resolving injection points.
const (
	// InjectTag is the struct-tag key looked up on fields that request injection.
	InjectTag = "inject"

	// InjectMethodPrefix is the method-name prefix that marks a method as an
	// injection setter.
	InjectMethodPrefix = "Inject"

	// ByTypeComponentName is the tag value that requests injection by type
	// instead of by component name.
	ByTypeComponentName = "T"
)
|
||
|
||
// 容器
|
||
type Container struct {
|
||
mu sync.RWMutex
|
||
components map[string]*Component
|
||
}
|
||
|
||
func NewContainer() *Container {
|
||
return &Container{
|
||
components: make(map[string]*Component),
|
||
}
|
||
}
|
||
|
||
// 注册实例至实例容器
|
||
func (c *Container) Register(bean any, opts ...ComponentOption) {
|
||
c.mu.Lock()
|
||
defer c.mu.Unlock()
|
||
|
||
component := NewComponent(bean, opts...)
|
||
|
||
componentName := component.Name
|
||
indirectCType := structx.IndirectType(component.GetType())
|
||
// 组件名为空,则取组件类型`包名路径.名称`作为组件名
|
||
if componentName == "" {
|
||
componentName = fmt.Sprintf("%s.%s", indirectCType.PkgPath(), indirectCType.Name())
|
||
component.Name = componentName
|
||
}
|
||
|
||
if _, ok := c.components[componentName]; ok {
|
||
logx.Warnf("the component name [%s] has been registered to the container. Repeat the registration...", componentName)
|
||
}
|
||
|
||
logx.Debugf("ioc register : %s = %s.%s", componentName, indirectCType.PkgPath(), indirectCType.Name())
|
||
c.components[componentName] = component
|
||
}
|
||
|
||
// Inject 注册对象实例的字段含有注入标签或者Setter方法,则注入对应组件实例
|
||
func (c *Container) Inject(obj any) error {
|
||
objValue := reflect.ValueOf(obj)
|
||
if structx.Indirect(objValue).Kind() != reflect.Struct {
|
||
return nil
|
||
}
|
||
|
||
ctx := contextx.NewTraceId()
|
||
if err := c.injectWithField(ctx, objValue); err != nil {
|
||
return err
|
||
}
|
||
if err := c.injectWithMethod(ctx, objValue); err != nil {
|
||
return err
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// 对所有组件实例执行Inject。即为实例字段注入依赖的组件实例
|
||
func (c *Container) InjectComponents() error {
|
||
componentsGroups := collx.ArraySplit[*Component](collx.MapValues(c.components), 10)
|
||
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
defer cancel()
|
||
errGroup, _ := errgroup.WithContext(ctx)
|
||
|
||
for _, components := range componentsGroups {
|
||
errGroup.Go(func() error {
|
||
for _, v := range components {
|
||
if err := c.Inject(v.Value); err != nil {
|
||
cancel() // 取消所有协程的执行
|
||
return err
|
||
}
|
||
}
|
||
return nil
|
||
})
|
||
}
|
||
|
||
if err := errGroup.Wait(); err != nil {
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// Get 根据组件实例名,获取对应实例信息
|
||
func (c *Container) Get(name string) (any, error) {
|
||
if comp, err := c.GetComponent(name); err == nil {
|
||
return comp.Value, nil
|
||
} else {
|
||
return nil, err
|
||
}
|
||
}
|
||
|
||
// GetComponent 根据组件名,获取对应组件信息
|
||
func (c *Container) GetComponent(name string) (*Component, error) {
|
||
c.mu.RLock()
|
||
defer c.mu.RUnlock()
|
||
|
||
component, ok := c.components[name]
|
||
if !ok {
|
||
return nil, errors.New("component not found: " + name)
|
||
}
|
||
|
||
return component, nil
|
||
}
|
||
|
||
// GetByType 根据组件实例类型获取组件实例
|
||
func (c *Container) GetByType(fieldType reflect.Type) (any, error) {
|
||
if comp, err := c.GetComponentByType(fieldType); err == nil {
|
||
return comp.Value, nil
|
||
} else {
|
||
return nil, err
|
||
}
|
||
}
|
||
|
||
// GetComponentByType 根据组件实例类型获取组件信息
|
||
func (c *Container) GetComponentByType(fieldType reflect.Type) (*Component, error) {
|
||
c.mu.RLock()
|
||
defer c.mu.RUnlock()
|
||
|
||
for _, component := range c.components {
|
||
if component.GetType().AssignableTo(fieldType) {
|
||
return component, nil
|
||
}
|
||
}
|
||
|
||
return nil, errors.New("component type not found: " + fmt.Sprintf("%s.%s", fieldType.PkgPath(), fieldType.Name()))
|
||
}
|
||
|
||
// GetBeansByType 根据组件实例类型获取所有对应类型的组件实例
|
||
func (c *Container) GetBeansByType(fieldType reflect.Type) []any {
|
||
return collx.ArrayMap(c.GetComponentsByType(fieldType), func(comp *Component) any { return comp.Value })
|
||
}
|
||
|
||
// GetComponentsByType 根据组件实例类型获取指定类型的所有组件信息
|
||
func (c *Container) GetComponentsByType(fieldType reflect.Type) []*Component {
|
||
c.mu.RLock()
|
||
defer c.mu.RUnlock()
|
||
|
||
components := make([]*Component, 0)
|
||
for _, component := range c.components {
|
||
if component.GetType().AssignableTo(fieldType) {
|
||
components = append(components, component)
|
||
}
|
||
}
|
||
|
||
return components
|
||
}
|
||
|
||
// injectWithField 根据实例字段的inject:"xxx"标签进行依赖注入
|
||
func (c *Container) injectWithField(context context.Context, objValue reflect.Value) error {
|
||
objValue = structx.Indirect(objValue)
|
||
objType := objValue.Type()
|
||
|
||
logx.DebugfContext(context, "start ioc inject with field: %s.%s", objType.PkgPath(), objType.Name())
|
||
|
||
for i := 0; i < objType.NumField(); i++ {
|
||
field := objType.Field(i)
|
||
fieldValue := objValue.Field(i)
|
||
|
||
// 检查字段是否是通过组合包含在当前结构体中的,即嵌套结构体
|
||
if field.Anonymous && structx.IndirectType(field.Type).Kind() == reflect.Struct {
|
||
c.injectWithField(context, fieldValue)
|
||
continue
|
||
}
|
||
|
||
componentName, ok := field.Tag.Lookup(InjectTag)
|
||
if !ok {
|
||
continue
|
||
}
|
||
|
||
// 如果组件名为指定的根据类型注入值,则根据类型注入
|
||
if componentName == ByTypeComponentName {
|
||
if err := c.injectByType(context, objType, field, fieldValue); err != nil {
|
||
return err
|
||
}
|
||
continue
|
||
}
|
||
|
||
if err := c.injectByName(context, objType, field, fieldValue, componentName); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// injectByName 根据实例组件名进行依赖注入
|
||
func (c *Container) injectByName(context context.Context, structType reflect.Type, field reflect.StructField, fieldValue reflect.Value, componentName string) error {
|
||
// inject tag字段名为空则默认为字段名
|
||
if componentName == "" {
|
||
componentName = field.Name
|
||
}
|
||
|
||
injectInfo := fmt.Sprintf("ioc field inject by name => [%s -> %s.%s#%s]", componentName, structType.PkgPath(), structType.Name(), field.Name)
|
||
|
||
component, err := c.GetComponent(componentName)
|
||
if err != nil {
|
||
return fmt.Errorf("%s error: %s", injectInfo, err.Error())
|
||
}
|
||
|
||
// 判断字段类型与需要注入的组件类型是否为可赋值关系
|
||
componentType := component.GetType()
|
||
if !componentType.AssignableTo(field.Type) {
|
||
indirectComponentType := structx.IndirectType(componentType)
|
||
return fmt.Errorf("%s error: injection types are inconsistent(Expected type -> %s.%s, Component type -> %s.%s)", injectInfo, field.Type.PkgPath(), field.Type.Name(), indirectComponentType.PkgPath(), indirectComponentType.Name())
|
||
}
|
||
|
||
logx.DebugfContext(context, "ioc field inject by name => [%s (%s) -> %s.%s#%s]", componentName, getComponentValueDesc(componentType), structType.PkgPath(), structType.Name(), field.Name)
|
||
|
||
if err := setFieldValue(fieldValue, component.Value); err != nil {
|
||
return fmt.Errorf("%s error: %s", injectInfo, err.Error())
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// injectByType 根据实例类型进行依赖注入
|
||
func (c *Container) injectByType(context context.Context, structType reflect.Type, field reflect.StructField, fieldValue reflect.Value) error {
|
||
fieldType := field.Type
|
||
|
||
injectInfo := fmt.Sprintf("ioc field inject by type => [%s.%s -> %s.%s#%s]", fieldType.PkgPath(), fieldType.Name(), structType.PkgPath(), structType.Name(), field.Name)
|
||
|
||
component, err := c.GetComponentByType(fieldType)
|
||
if err != nil {
|
||
return fmt.Errorf("%s error: %s", injectInfo, err.Error())
|
||
}
|
||
|
||
logx.DebugfContext(context, "ioc field inject by type => [%s.%s (%s) -> %s.%s#%s]", fieldType.PkgPath(), fieldType.Name(), getComponentValueDesc(component.GetType()), structType.PkgPath(), structType.Name(), field.Name)
|
||
|
||
if err := setFieldValue(fieldValue, component.Value); err != nil {
|
||
return fmt.Errorf("%s error: %s", injectInfo, err.Error())
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// 根据实例的Inject方法进行依赖注入
|
||
func (c *Container) injectWithMethod(context context.Context, objValue reflect.Value) error {
|
||
objType := objValue.Type()
|
||
|
||
for i := 0; i < objType.NumMethod(); i++ {
|
||
method := objType.Method(i)
|
||
|
||
methodName := method.Name
|
||
// 不是以指定方法名前缀开头的函数,则默认跳过
|
||
if !strings.HasPrefix(methodName, InjectMethodPrefix) {
|
||
continue
|
||
}
|
||
|
||
// 获取组件名,InjectTestApp -> TestApp
|
||
componentName := methodName[6:]
|
||
|
||
injectInfo := fmt.Sprintf("ioc method inject [%s.%s#%s(%s)]", objType.Elem().PkgPath(), objType.Elem().Name(), methodName, componentName)
|
||
logx.DebugfContext(context, injectInfo)
|
||
|
||
if method.Type.NumIn() != 2 {
|
||
logx.WarnfContext(context, "%s error: the method cannot be injected if it does not have one parameter", injectInfo)
|
||
continue
|
||
}
|
||
|
||
component, err := c.Get(componentName)
|
||
if err != nil {
|
||
return fmt.Errorf("%s error: %s", injectInfo, err.Error())
|
||
}
|
||
|
||
componentType := reflect.TypeOf(component)
|
||
// 期望的组件类型,即参数入参类型
|
||
expectedComponentType := method.Type.In(1)
|
||
if !componentType.AssignableTo(expectedComponentType) {
|
||
componentType = structx.IndirectType(componentType)
|
||
return fmt.Errorf("%s error: injection types are inconsistent(Expected type -> %s.%s, Component type -> %s.%s)", injectInfo, expectedComponentType.PkgPath(), expectedComponentType.Name(), componentType.PkgPath(), componentType.Name())
|
||
}
|
||
|
||
method.Func.Call([]reflect.Value{objValue, reflect.ValueOf(component)})
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// setFieldValue stores component into the struct field described by
// fieldValue.
//
// Fields that reflection refuses to set directly (unexported fields) are
// written through a pointer rebuilt with reflect.NewAt, which requires the
// field to be addressable. An error, never a panic, is returned when
// fieldValue is the zero reflect.Value, when it is neither settable nor
// addressable, or when component is nil.
//
// The caller must ensure component's type is assignable to the field type.
func setFieldValue(fieldValue reflect.Value, component any) error {
	const errNotSettable = "the field is invalid or a non-exportable type"

	// A zero Value has no type; every later call on it would panic.
	if !fieldValue.IsValid() {
		return errors.New(errNotSettable)
	}

	if !fieldValue.CanSet() {
		// Addr panics on non-addressable values (e.g. a field reached through
		// a struct passed by value), so bail out before calling it.
		if !fieldValue.CanAddr() {
			return errors.New(errNotSettable)
		}

		// Unexported field: rebuild a settable Value over the same memory.
		fieldValue = reflect.NewAt(fieldValue.Type(), fieldValue.Addr().UnsafePointer()).Elem()
		if !fieldValue.IsValid() || !fieldValue.CanSet() {
			return errors.New(errNotSettable)
		}
	}

	value := reflect.ValueOf(component)
	// reflect.ValueOf(nil) is the zero Value, which Set rejects with a panic.
	if !value.IsValid() {
		return errors.New("the component to inject is nil")
	}

	fieldValue.Set(value)
	return nil
}
|
||
|
||
func getComponentValueDesc(componentValueType reflect.Type) string {
|
||
if componentValueType.Kind() == reflect.Ptr {
|
||
componentValueType = structx.IndirectType(componentValueType)
|
||
return fmt.Sprintf("*%s.%s", componentValueType.PkgPath(), componentValueType.Name())
|
||
}
|
||
return fmt.Sprintf("%s.%s", componentValueType.PkgPath(), componentValueType.Name())
|
||
}
|