Files
mayfly-go/server/pkg/ioc/ioc.go

320 lines
9.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package ioc
import (
"context"
"errors"
"fmt"
"mayfly-go/pkg/contextx"
"mayfly-go/pkg/logx"
"mayfly-go/pkg/utils/collx"
"mayfly-go/pkg/utils/structx"
"reflect"
"strings"
"sync"
"golang.org/x/sync/errgroup"
)
const (
	// ByTypeComponentName, when used as the value of an inject tag
	// (`inject:"T"`), requests injection by type instead of by component name.
	ByTypeComponentName = "T"
	// InjectMethodPrefix marks setter methods used for injection:
	// a method InjectXxx(v) receives the component registered as "Xxx".
	InjectMethodPrefix = "Inject"
	// InjectTag is the struct tag key that marks a field for injection.
	InjectTag = "inject"
)
// Container is the IoC container: a registry of components keyed by
// component name. mu guards concurrent access to components.
type Container struct {
	mu         sync.RWMutex
	components map[string]*Component
}

// NewContainer returns an empty Container that is ready for Register calls.
func NewContainer() *Container {
	c := new(Container)
	c.components = map[string]*Component{}
	return c
}
// Register wraps bean in a component (configured by opts) and stores it in
// the container. If the resulting component name is empty, the name
// defaults to "<pkgPath>.<TypeName>" of the bean's pointer-stripped type.
// Registering a name that already exists logs a warning and replaces the
// previously registered component.
func (c *Container) Register(bean any, opts ...ComponentOption) {
	c.mu.Lock()
	defer c.mu.Unlock()

	comp := NewComponent(bean, opts...)
	baseType := structx.IndirectType(comp.GetType())
	qualifiedType := fmt.Sprintf("%s.%s", baseType.PkgPath(), baseType.Name())

	// Fall back to the qualified type name when no explicit name was given.
	if comp.Name == "" {
		comp.Name = qualifiedType
	}
	name := comp.Name

	if _, exists := c.components[name]; exists {
		logx.Warnf("the component name [%s] has been registered to the container. Repeat the registration...", name)
	}
	logx.Debugf("ioc register : %s = %s", name, qualifiedType)
	c.components[name] = comp
}
// Inject wires obj's dependencies from the container. It first fills the
// fields that carry an `inject` tag (recursing into embedded structs), then
// calls every Inject<Name> setter method. If obj's pointer-stripped kind is
// not a struct, Inject does nothing and returns nil. Both passes share one
// trace-id context for logging.
func (c *Container) Inject(obj any) error {
	target := reflect.ValueOf(obj)
	if kind := structx.Indirect(target).Kind(); kind != reflect.Struct {
		return nil
	}

	traceCtx := contextx.NewTraceId()
	if err := c.injectWithField(traceCtx, target); err != nil {
		return err
	}
	return c.injectWithMethod(traceCtx, target)
}
// InjectComponents runs Inject on every registered component so that each
// component's dependencies are wired. Components are partitioned by
// collx.ArraySplit and the groups are processed concurrently. The first
// error stops the remaining groups at their next component boundary and is
// returned.
func (c *Container) InjectComponents() error {
	// Snapshot the components under the read lock. The lock must NOT be held
	// while injecting: Inject takes the read lock again for each lookup, and
	// a recursive RLock can deadlock behind a waiting writer.
	c.mu.RLock()
	all := collx.MapValues(c.components)
	c.mu.RUnlock()

	componentsGroups := collx.ArraySplit[*Component](all, 10)

	// The errgroup context is cancelled automatically when any worker
	// returns an error.
	errGroup, ctx := errgroup.WithContext(context.Background())
	for _, components := range componentsGroups {
		components := components // per-iteration copy; required before Go 1.22
		errGroup.Go(func() error {
			for _, v := range components {
				// Stop early once another group has already failed.
				if err := ctx.Err(); err != nil {
					return err
				}
				if err := c.Inject(v.Value); err != nil {
					return err
				}
			}
			return nil
		})
	}
	// Wait returns the first non-nil error, i.e. the original failure rather
	// than the cancellation errors it triggered in other groups.
	return errGroup.Wait()
}
// Get returns the instance (Component.Value) registered under name, or the
// lookup error from GetComponent if no such component exists.
func (c *Container) Get(name string) (any, error) {
	comp, err := c.GetComponent(name)
	if err != nil {
		return nil, err
	}
	return comp.Value, nil
}
// GetComponent looks up the component registered under name. It returns
// the error "component not found: <name>" when the name is unknown.
func (c *Container) GetComponent(name string) (*Component, error) {
	c.mu.RLock()
	comp, found := c.components[name]
	c.mu.RUnlock()

	if found {
		return comp, nil
	}
	return nil, errors.New("component not found: " + name)
}
// GetByType returns the instance of a component whose type is assignable to
// fieldType, or the lookup error from GetComponentByType.
func (c *Container) GetByType(fieldType reflect.Type) (any, error) {
	comp, err := c.GetComponentByType(fieldType)
	if err != nil {
		return nil, err
	}
	return comp.Value, nil
}
// GetComponentByType returns a registered component whose type is
// assignable to fieldType. If none matches, it returns the error
// "component type not found: <pkgPath>.<Name>".
//
// NOTE(review): components live in a map, so when several components are
// assignable to fieldType, which one is returned depends on Go's randomized
// map iteration order.
func (c *Container) GetComponentByType(fieldType reflect.Type) (*Component, error) {
	c.mu.RLock()
	defer c.mu.RUnlock()

	var (
		match *Component
		found bool
	)
	for _, candidate := range c.components {
		if candidate.GetType().AssignableTo(fieldType) {
			match, found = candidate, true
			break
		}
	}
	if !found {
		return nil, fmt.Errorf("component type not found: %s.%s", fieldType.PkgPath(), fieldType.Name())
	}
	return match, nil
}
// GetBeansByType returns the instances of all components whose type is
// assignable to fieldType (see GetComponentsByType for ordering).
func (c *Container) GetBeansByType(fieldType reflect.Type) []any {
	matched := c.GetComponentsByType(fieldType)
	toValue := func(comp *Component) any { return comp.Value }
	return collx.ArrayMap(matched, toValue)
}
// GetComponentsByType returns every registered component whose type is
// assignable to fieldType. The result is never nil; it is empty when nothing
// matches. Ordering follows map iteration and is therefore unspecified.
func (c *Container) GetComponentsByType(fieldType reflect.Type) []*Component {
	c.mu.RLock()
	defer c.mu.RUnlock()

	matches := []*Component{}
	for _, candidate := range c.components {
		if !candidate.GetType().AssignableTo(fieldType) {
			continue
		}
		matches = append(matches, candidate)
	}
	return matches
}
// injectWithField injects dependencies into the fields of objValue that
// carry an `inject:"<name>"` tag:
//   - an empty tag value uses the field name as the component name;
//   - the value ByTypeComponentName ("T") injects the component whose type
//     is assignable to the field type;
//   - any other value is the component name to look up.
//
// Anonymous (embedded) struct fields are processed recursively, and their
// errors are propagated. objValue may be a struct or a pointer to one.
func (c *Container) injectWithField(ctx context.Context, objValue reflect.Value) error {
	objValue = structx.Indirect(objValue)
	objType := objValue.Type()
	logx.DebugfContext(ctx, "start ioc inject with field: %s.%s", objType.PkgPath(), objType.Name())

	for i := 0; i < objType.NumField(); i++ {
		field := objType.Field(i)
		fieldValue := objValue.Field(i)

		// Embedded struct (or pointer to struct): inject into its own fields.
		if field.Anonymous && structx.IndirectType(field.Type).Kind() == reflect.Struct {
			// A nil embedded pointer has no fields to populate.
			if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() {
				continue
			}
			if err := c.injectWithField(ctx, fieldValue); err != nil {
				return err
			}
			continue
		}

		componentName, ok := field.Tag.Lookup(InjectTag)
		if !ok {
			continue
		}

		// The special tag value requests injection by type.
		if componentName == ByTypeComponentName {
			if err := c.injectByType(ctx, objType, field, fieldValue); err != nil {
				return err
			}
			continue
		}

		if err := c.injectByName(ctx, objType, field, fieldValue, componentName); err != nil {
			return err
		}
	}
	return nil
}
// injectByName injects the component registered as componentName into the
// given field. An empty componentName defaults to the field's name.
// structType is the owning struct type; it is used only in log and error
// messages. An error is returned if the component is missing, if its type is
// not assignable to the field type, or if the field cannot be set. Underlying
// errors are wrapped with %w.
func (c *Container) injectByName(ctx context.Context, structType reflect.Type, field reflect.StructField, fieldValue reflect.Value, componentName string) error {
	// An empty inject tag value defaults to the field name.
	if componentName == "" {
		componentName = field.Name
	}
	injectInfo := fmt.Sprintf("ioc field inject by name => [%s -> %s.%s#%s]", componentName, structType.PkgPath(), structType.Name(), field.Name)

	component, err := c.GetComponent(componentName)
	if err != nil {
		return fmt.Errorf("%s error: %w", injectInfo, err)
	}

	// The component must be assignable to the field's declared type.
	componentType := component.GetType()
	if !componentType.AssignableTo(field.Type) {
		// getComponentValueDesc keeps the '*' and the element name of pointer
		// types; reflect.Type.Name()/PkgPath() are empty for them.
		return fmt.Errorf("%s error: injection types are inconsistent(Expected type -> %s, Component type -> %s)", injectInfo, getComponentValueDesc(field.Type), getComponentValueDesc(componentType))
	}

	logx.DebugfContext(ctx, "ioc field inject by name => [%s (%s) -> %s.%s#%s]", componentName, getComponentValueDesc(componentType), structType.PkgPath(), structType.Name(), field.Name)
	if err := setFieldValue(fieldValue, component.Value); err != nil {
		return fmt.Errorf("%s error: %w", injectInfo, err)
	}
	return nil
}
// injectByType injects into field a component whose type is assignable to
// the field's declared type (see GetComponentByType). structType is the
// owning struct type; it is used only in log and error messages. Underlying
// errors are wrapped with %w.
func (c *Container) injectByType(ctx context.Context, structType reflect.Type, field reflect.StructField, fieldValue reflect.Value) error {
	// Describe the field type via getComponentValueDesc so that pointer field
	// types render as "*pkg.Name" rather than an empty ".".
	fieldTypeDesc := getComponentValueDesc(field.Type)
	injectInfo := fmt.Sprintf("ioc field inject by type => [%s -> %s.%s#%s]", fieldTypeDesc, structType.PkgPath(), structType.Name(), field.Name)

	component, err := c.GetComponentByType(field.Type)
	if err != nil {
		return fmt.Errorf("%s error: %w", injectInfo, err)
	}

	logx.DebugfContext(ctx, "ioc field inject by type => [%s (%s) -> %s.%s#%s]", fieldTypeDesc, getComponentValueDesc(component.GetType()), structType.PkgPath(), structType.Name(), field.Name)
	if err := setFieldValue(fieldValue, component.Value); err != nil {
		return fmt.Errorf("%s error: %w", injectInfo, err)
	}
	return nil
}
// injectWithMethod injects dependencies through setter methods of objValue
// named Inject<ComponentName>. For example, InjectUserApp(app) is called
// with the component registered as "UserApp".
//
// Rules applied to each method:
//   - A bare "Inject" method names no component and is skipped.
//   - A method whose parameter count is not exactly one is skipped with a
//     warning.
//   - A missing component, or one not assignable to the parameter type, is
//     returned as an error.
//
// Only methods in objValue's method set are visited, so pointer-receiver
// setters are found only when objValue holds a pointer.
func (c *Container) injectWithMethod(ctx context.Context, objValue reflect.Value) error {
	objType := objValue.Type()
	// Pointer-stripped type for messages. objType.Elem() would panic when
	// objValue is a plain struct exposing value-receiver Inject methods.
	baseType := structx.IndirectType(objType)

	for i := 0; i < objType.NumMethod(); i++ {
		method := objType.Method(i)
		methodName := method.Name
		// Only methods named Inject<Something> are injection setters.
		if !strings.HasPrefix(methodName, InjectMethodPrefix) {
			continue
		}
		// InjectTestApp -> TestApp
		componentName := strings.TrimPrefix(methodName, InjectMethodPrefix)
		if componentName == "" {
			continue
		}

		injectInfo := fmt.Sprintf("ioc method inject [%s.%s#%s(%s)]", baseType.PkgPath(), baseType.Name(), methodName, componentName)
		logx.DebugfContext(ctx, "%s", injectInfo)

		// NumIn counts the receiver, so a one-argument setter has NumIn() == 2.
		if method.Type.NumIn() != 2 {
			logx.WarnfContext(ctx, "%s error: the method cannot be injected if it does not have one parameter", injectInfo)
			continue
		}

		component, err := c.Get(componentName)
		if err != nil {
			return fmt.Errorf("%s error: %w", injectInfo, err)
		}

		// The expected component type is the setter's single parameter type.
		expectedComponentType := method.Type.In(1)
		componentType := reflect.TypeOf(component)
		if !componentType.AssignableTo(expectedComponentType) {
			return fmt.Errorf("%s error: injection types are inconsistent(Expected type -> %s, Component type -> %s)", injectInfo, getComponentValueDesc(expectedComponentType), getComponentValueDesc(componentType))
		}
		method.Func.Call([]reflect.Value{objValue, reflect.ValueOf(component)})
	}
	return nil
}
// setFieldValue assigns component to the struct field represented by
// fieldValue.
//
// Unexported fields are read-only through reflection. When such a field is
// addressable, a settable alias is created with reflect.NewAt over the
// field's address.
//
// An error is returned, instead of a panic, when:
//   - the field is invalid;
//   - the field is unexported and not addressable (e.g. the parent struct
//     was passed by value);
//   - component is nil;
//   - component's type is not assignable to the field's type.
func setFieldValue(fieldValue reflect.Value, component any) error {
	if !fieldValue.IsValid() {
		return errors.New("the field is invalid or a non-exportable type")
	}
	if !fieldValue.CanSet() {
		// Addr would panic on a non-addressable value.
		if !fieldValue.CanAddr() {
			return errors.New("the field is invalid or a non-exportable type")
		}
		// Alias the field through its address to drop the read-only flag.
		fieldValue = reflect.NewAt(fieldValue.Type(), fieldValue.Addr().UnsafePointer()).Elem()
		if !fieldValue.CanSet() {
			return errors.New("the field is invalid or a non-exportable type")
		}
	}

	value := reflect.ValueOf(component)
	if !value.IsValid() {
		return errors.New("the component value is nil")
	}
	if !value.Type().AssignableTo(fieldValue.Type()) {
		return fmt.Errorf("the component type %s is not assignable to the field type %s", value.Type(), fieldValue.Type())
	}
	fieldValue.Set(value)
	return nil
}
// getComponentValueDesc renders a type as "<pkgPath>.<Name>" for use in log
// and error messages. For pointer types, the result is prefixed with "*" and
// the name is taken from the type returned by structx.IndirectType.
func getComponentValueDesc(componentValueType reflect.Type) string {
	prefix := ""
	if componentValueType.Kind() == reflect.Ptr {
		prefix = "*"
		componentValueType = structx.IndirectType(componentValueType)
	}
	return prefix + componentValueType.PkgPath() + "." + componentValueType.Name()
}