Skip to main content

自动注入依赖

 

using System.Reflection;
using Microsoft.Extensions.DependencyInjection;

namespace FMS.Server.Common.Extensions.ServiceCollectionExtensions;

/// <summary>
/// <para/>Author: 许植奎
/// <para/>Date: 2025/03/04 16:53:13
/// <para/>Automatic dependency registration. A lifetime marker interface can be inherited by a service interface,
/// or implemented directly by the service class (intended for classes that do not implement a service interface).
/// </summary>
public static class AutoDependencyInjectExtension
{
    /// <summary>
    /// Scans the exported, non-abstract classes of <paramref name="assemblies"/> and registers every class that carries a
    /// lifetime marker (<see cref="ISingletonDependency"/>, <see cref="IScopedDependency"/>, <see cref="ITransientDependency"/>),
    /// either through one of its interfaces or directly on the class. Classes decorated with
    /// <see cref="DependencyKeyAttribute"/> are registered as keyed services, once per key.
    /// </summary>
    /// <param name="services">The service collection that receives the registrations.</param>
    /// <param name="assemblies">The assemblies to scan. An assembly passed more than once is scanned only once.</param>
    /// <returns>The same <paramref name="services"/> instance, for chaining.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="services"/> is null.</exception>
    /// <exception cref="ArgumentException">No assembly was supplied, or the array contains a null entry.</exception>
    /// <exception cref="InvalidOperationException">
    /// A class declares conflicting lifetimes, or its <see cref="DependencyKeyAttribute"/> usage is invalid.
    /// </exception>
    public static IServiceCollection AddAutoDependencyInjection(this IServiceCollection services, params Assembly[] assemblies)
    {
        ArgumentNullException.ThrowIfNull(services);

        if (assemblies is null || assemblies.Length == 0)
            throw new ArgumentException("必须提供至少一个程序集", nameof(assemblies));

        if (assemblies.Any(a => a is null))
            throw new ArgumentException("程序集列表中不能包含 null", nameof(assemblies));

        // Scanning the same assembly twice would add every descriptor twice.
        foreach (var assembly in assemblies.Distinct())
        {
            var types = assembly.GetExportedTypes().Where(t => t is { IsClass: true, IsAbstract: false });

            foreach (var type in types)
            {
                var (lifecycle, interfaceType) = GetServiceMetadata(type);
                if (lifecycle is null)
                    continue;

                var keyAttribute = type.GetCustomAttribute<DependencyKeyAttribute>();
                if (keyAttribute is null)
                {
                    RegisterService(services, type, interfaceType, lifecycle.Value);
                    continue;
                }

                // Without this check, [DependencyKey] with no arguments would silently skip the registration.
                if (keyAttribute.Keys.Length == 0)
                    throw new InvalidOperationException($"{type.FullName} 的[DependencyKey]至少需要一个键");

                foreach (var key in keyAttribute.Keys)
                {
                    ValidateKeyedRegistration(type, interfaceType, key);
                    RegisterKeyedService(services, type, interfaceType, lifecycle.Value, key);
                }
            }
        }

        return services;
    }

    /// <summary>
    /// Determines the lifetime declared by <paramref name="type"/> and the service interface it should be exposed as.
    /// </summary>
    /// <param name="type">A concrete class from a scanned assembly.</param>
    /// <returns>
    /// <c>(null, null)</c> when the class carries no lifetime marker. Otherwise, the lifetime and the chosen service
    /// interface; a null interface means the class is registered as itself.
    /// </returns>
    /// <exception cref="InvalidOperationException">The class inherits more than one distinct lifetime.</exception>
    private static (ServiceLifetime?, Type?) GetServiceMetadata(Type type)
    {
        var allInterfaces = type.GetInterfaces();

        // GetInterfaces() returns every implemented interface, including the marker interfaces themselves.
        // A marker declared directly on the class is therefore found here, as is one inherited through a service interface.
        var lifecycleInterfaces = new List<(ServiceLifetime Lifetime, Type Source)>();
        foreach (var interfaceType in allInterfaces)
        {
            var lifetime = GetDeclaredLifetime(interfaceType);
            if (lifetime is not null)
                lifecycleInterfaces.Add((lifetime.Value, interfaceType));
        }

        if (lifecycleInterfaces.Count == 0)
            return (null, null);

        if (lifecycleInterfaces.Select(x => x.Lifetime).Distinct().Count() > 1)
        {
            var conflictTypes = lifecycleInterfaces.Select(x => $"{x.Lifetime} (来自 {x.Source.Name})");
            throw new InvalidOperationException($"{type.FullName} 存在冲突的生命周期定义: {string.Join("、", conflictTypes)}");
        }

        return (lifecycleInterfaces[0].Lifetime, GetServiceInterface(allInterfaces));
    }

    /// <summary>
    /// Maps a type to the lifetime marker it is assignable to.
    /// Singleton is checked first, then Scoped, then Transient.
    /// </summary>
    /// <returns>The lifetime, or null when the type is not assignable to any marker.</returns>
    private static ServiceLifetime? GetDeclaredLifetime(Type type)
    {
        if (typeof(ISingletonDependency).IsAssignableFrom(type))
            return ServiceLifetime.Singleton;
        if (typeof(IScopedDependency).IsAssignableFrom(type))
            return ServiceLifetime.Scoped;
        if (typeof(ITransientDependency).IsAssignableFrom(type))
            return ServiceLifetime.Transient;
        return null;
    }

    /// <summary>Returns true when <paramref name="type"/> is one of the three lifetime marker interfaces.</summary>
    private static bool IsLifecycleMarker(Type type) =>
        type == typeof(ISingletonDependency)
        || type == typeof(IScopedDependency)
        || type == typeof(ITransientDependency);

    /// <summary>
    /// Chooses which of the class's interfaces to register it under.
    /// </summary>
    /// <param name="interfaces">All interfaces implemented by the class.</param>
    /// <returns>The chosen service interface, or null to register the class as itself.</returns>
    private static Type? GetServiceInterface(IEnumerable<Type> interfaces)
    {
        var candidates = interfaces.Where(p => !IsLifecycleMarker(p)).ToList();

        // Prefer the interface that carries a lifetime marker: it is the contract the author annotated.
        // GetInterfaces() has no guaranteed order, so blindly taking the first candidate could pick e.g. IDisposable.
        return candidates.FirstOrDefault(p => GetDeclaredLifetime(p) is not null)
               // The marker is on the class only: fall back to another interface, but never expose the service
               // as IDisposable/IAsyncDisposable, which are not meaningful service contracts.
               ?? candidates.FirstOrDefault(p => p != typeof(IDisposable) && p != typeof(IAsyncDisposable));
    }

    /// <summary>
    /// Converts a service interface into the form the container expects for open generic implementations.
    /// For a generic type definition such as <c>Repository&lt;T&gt;</c>, GetInterfaces() yields interfaces built over the
    /// class's own type parameters (<c>IRepository&lt;T&gt;</c>), whereas open generic registration needs <c>IRepository&lt;&gt;</c>.
    /// The conversion only happens when the type arguments line up exactly with the class's parameters, so positional
    /// mapping stays correct; otherwise <paramref name="serviceType"/> is returned unchanged.
    /// </summary>
    private static Type ToRegistrationServiceType(Type implementationType, Type serviceType)
    {
        if (implementationType.IsGenericTypeDefinition
            && serviceType.IsGenericType
            && serviceType.ContainsGenericParameters
            && serviceType.GetGenericArguments().SequenceEqual(implementationType.GetGenericArguments()))
        {
            return serviceType.GetGenericTypeDefinition();
        }

        return serviceType;
    }

    /// <summary>
    /// Ensures a keyed registration has a service interface and a non-null key.
    /// </summary>
    /// <exception cref="InvalidOperationException">The class has no service interface, or <paramref name="key"/> is null.</exception>
    private static void ValidateKeyedRegistration(Type implementationType, Type? interfaceType, object key)
    {
        if (interfaceType == null)
        {
            throw new InvalidOperationException($"{implementationType.FullName} 使用[DependencyKey]必须实现接口");
        }

        if (key == null)
        {
            throw new InvalidOperationException($"{implementationType.FullName} 的[DependencyKey]不能为null");
        }
    }

    /// <summary>
    /// Adds a keyed descriptor for <paramref name="implementationType"/> under its service interface.
    /// </summary>
    /// <exception cref="InvalidOperationException"><paramref name="interfaceType"/> is null.</exception>
    private static void RegisterKeyedService(IServiceCollection services, Type implementationType,
        Type? interfaceType, ServiceLifetime lifecycle, object key)
    {
        var serviceType = interfaceType ?? throw new InvalidOperationException("键控服务必须实现接口");

        // The non-generic keyed constructor removes the need for the old MakeGenericMethod/Invoke round-trip.
        // That round-trip wrapped exceptions in TargetInvocationException and could not handle open generic types.
        services.Add(new ServiceDescriptor(
            ToRegistrationServiceType(implementationType, serviceType), key, implementationType, lifecycle));
    }

    /// <summary>
    /// Adds a non-keyed descriptor. The class is registered under <paramref name="interfaceType"/> when one was
    /// chosen, and as itself otherwise.
    /// </summary>
    private static void RegisterService(IServiceCollection services, Type implementationType, Type? interfaceType, ServiceLifetime lifecycle)
    {
        var serviceType = interfaceType is null
            ? implementationType
            : ToRegistrationServiceType(implementationType, interfaceType);
        services.Add(new ServiceDescriptor(serviceType, implementationType, lifecycle));
    }
}

/// <summary>
/// Lifetime marker for singleton services. Any class that implements this interface is registered with a singleton
/// lifetime by <c>AddAutoDependencyInjection</c>, whether the class implements it directly or through a service
/// interface that inherits it.
/// </summary>
public interface ISingletonDependency
{
}

/// <summary>
/// Lifetime marker for transient services. Any class that implements this interface is registered with a transient
/// lifetime by <c>AddAutoDependencyInjection</c>, whether the class implements it directly or through a service
/// interface that inherits it.
/// </summary>
public interface ITransientDependency
{
}

/// <summary>
/// Lifetime marker for scoped services. Any class that implements this interface is registered with a scoped
/// lifetime by <c>AddAutoDependencyInjection</c>, whether the class implements it directly or through a service
/// interface that inherits it.
/// </summary>
public interface IScopedDependency
{
}

/// <summary>
/// Supplies one or more service keys for a class. Each key produces a separate keyed registration.
/// </summary>
[AttributeUsage(AttributeTargets.Class, Inherited = false)]
public class DependencyKeyAttribute : Attribute
{
    /// <summary>
    /// Creates the attribute with the given keys.
    /// </summary>
    /// <param name="key">The service keys; the array itself must not be null.</param>
    /// <exception cref="ArgumentNullException"><paramref name="key"/> is null.</exception>
    public DependencyKeyAttribute(params object[] key)
    {
        ArgumentNullException.ThrowIfNull(key);
        Keys = key;
    }

    /// <summary>The keys passed to the constructor, in declaration order.</summary>
    public object[] Keys { get; }
}