diff --git a/Directory.Packages.props b/Directory.Packages.props index b58c9e3..7b20a07 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -12,9 +12,8 @@ - - + + - diff --git a/src/Injectio.Generators/AnalyzerReleases.Shipped.md b/src/Injectio.Generators/AnalyzerReleases.Shipped.md new file mode 100644 index 0000000..f50bb1f --- /dev/null +++ b/src/Injectio.Generators/AnalyzerReleases.Shipped.md @@ -0,0 +1,2 @@ +; Shipped analyzer releases +; https://github.com/dotnet/roslyn-analyzers/blob/main/src/Microsoft.CodeAnalysis.Analyzers/ReleaseTrackingAnalyzers.Help.md diff --git a/src/Injectio.Generators/AnalyzerReleases.Unshipped.md b/src/Injectio.Generators/AnalyzerReleases.Unshipped.md new file mode 100644 index 0000000..7ad52dc --- /dev/null +++ b/src/Injectio.Generators/AnalyzerReleases.Unshipped.md @@ -0,0 +1,16 @@ +; Unshipped analyzer release +; https://github.com/dotnet/roslyn-analyzers/blob/main/src/Microsoft.CodeAnalysis.Analyzers/ReleaseTrackingAnalyzers.Help.md + +### New Rules + +Rule ID | Category | Severity | Notes +--------|----------|----------|------- +INJECT0001 | Injectio | Warning | RegisterServices method has invalid signature +INJECT0002 | Injectio | Warning | RegisterServices method has invalid second parameter +INJECT0003 | Injectio | Warning | RegisterServices method has too many parameters +INJECT0004 | Injectio | Warning | Factory method not found +INJECT0005 | Injectio | Warning | Factory method must be static +INJECT0006 | Injectio | Warning | Factory method has invalid signature +INJECT0007 | Injectio | Warning | Implementation does not implement service type +INJECT0008 | Injectio | Warning | Implementation type is abstract +INJECT0009 | Injectio | Warning | RegisterServices on non-static method in abstract class diff --git a/src/Injectio.Generators/DiagnosticDescriptors.cs b/src/Injectio.Generators/DiagnosticDescriptors.cs new file mode 100644 index 0000000..08e6d4e --- /dev/null +++ b/src/Injectio.Generators/DiagnosticDescriptors.cs @@ -0,0 +1,80 @@ +using Microsoft.CodeAnalysis; + +namespace Injectio.Generators; + +public static class DiagnosticDescriptors +{ + private const string Category = "Injectio"; + + public static readonly DiagnosticDescriptor InvalidMethodSignature = new( + id: "INJECT0001", + title: "RegisterServices method has invalid signature", + messageFormat: "Method '{0}' marked with [RegisterServices] must have IServiceCollection as its first parameter", + category: Category, + defaultSeverity: DiagnosticSeverity.Warning, + isEnabledByDefault: true); + + public static readonly DiagnosticDescriptor InvalidMethodSecondParameter = new( + id: "INJECT0002", + title: "RegisterServices method has invalid second parameter", + messageFormat: "Method '{0}' marked with [RegisterServices] has an invalid second parameter; expected a string collection (e.g., IEnumerable)", + category: Category, + defaultSeverity: DiagnosticSeverity.Warning, + isEnabledByDefault: true); + + public static readonly DiagnosticDescriptor MethodTooManyParameters = new( + id: "INJECT0003", + title: "RegisterServices method has too many parameters", + messageFormat: "Method '{0}' marked with [RegisterServices] has {1} parameters; expected 1 or 2", + category: Category, + defaultSeverity: DiagnosticSeverity.Warning, + isEnabledByDefault: true); + + public static readonly DiagnosticDescriptor FactoryMethodNotFound = new( + id: "INJECT0004", + title: "Factory method not found", + messageFormat: "Factory method '{0}' was not found on type '{1}'", + category: Category, + defaultSeverity: DiagnosticSeverity.Warning, + isEnabledByDefault: true); + + public static readonly DiagnosticDescriptor FactoryMethodNotStatic = new( + id: "INJECT0005", + title: "Factory method must be static", + messageFormat: "Factory method '{0}' on type '{1}' must be static", + category: Category, + defaultSeverity: DiagnosticSeverity.Warning, + isEnabledByDefault: true); + + public static readonly DiagnosticDescriptor FactoryMethodInvalidSignature = new( + id: "INJECT0006", + title: "Factory method has invalid signature", + messageFormat: "Factory method '{0}' on type '{1}' must accept IServiceProvider as its first parameter and optionally object? as its second parameter", + category: Category, + defaultSeverity: DiagnosticSeverity.Warning, + isEnabledByDefault: true); + + public static readonly DiagnosticDescriptor ServiceTypeMismatch = new( + id: "INJECT0007", + title: "Implementation does not implement service type", + messageFormat: "Type '{0}' does not implement or inherit from service type '{1}'", + category: Category, + defaultSeverity: DiagnosticSeverity.Warning, + isEnabledByDefault: true); + + public static readonly DiagnosticDescriptor AbstractImplementationType = new( + id: "INJECT0008", + title: "Implementation type is abstract", + messageFormat: "Implementation type '{0}' is abstract and cannot be instantiated without a factory method", + category: Category, + defaultSeverity: DiagnosticSeverity.Warning, + isEnabledByDefault: true); + + public static readonly DiagnosticDescriptor RegisterServicesMethodOnAbstractClass = new( + id: "INJECT0009", + title: "RegisterServices on non-static method in abstract class", + messageFormat: "Method '{0}' marked with [RegisterServices] is a non-static method on abstract class '{1}'; the class cannot be instantiated", + category: Category, + defaultSeverity: DiagnosticSeverity.Warning, + isEnabledByDefault: true); +} diff --git a/src/Injectio.Generators/Injectio.Generators.csproj b/src/Injectio.Generators/Injectio.Generators.csproj index 9ebcf3c..0479dc4 100644 --- a/src/Injectio.Generators/Injectio.Generators.csproj +++ b/src/Injectio.Generators/Injectio.Generators.csproj @@ -27,4 +27,9 @@ + + + + + diff --git a/src/Injectio.Generators/ServiceRegistrationAnalyzer.cs b/src/Injectio.Generators/ServiceRegistrationAnalyzer.cs new file mode 100644 index 0000000..025c05a --- /dev/null +++ b/src/Injectio.Generators/ServiceRegistrationAnalyzer.cs @@ -0,0 +1,383 @@ +using System.Collections.Immutable; + +using Injectio.Generators.Extensions; + +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.Diagnostics; + +namespace Injectio.Generators; + +[DiagnosticAnalyzer(LanguageNames.CSharp)] +public class ServiceRegistrationAnalyzer : DiagnosticAnalyzer +{ + public override ImmutableArray SupportedDiagnostics { get; } = + ImmutableArray.Create( + DiagnosticDescriptors.InvalidMethodSignature, + DiagnosticDescriptors.InvalidMethodSecondParameter, + DiagnosticDescriptors.MethodTooManyParameters, + DiagnosticDescriptors.FactoryMethodNotFound, + DiagnosticDescriptors.FactoryMethodNotStatic, + DiagnosticDescriptors.FactoryMethodInvalidSignature, + DiagnosticDescriptors.ServiceTypeMismatch, + DiagnosticDescriptors.AbstractImplementationType, + DiagnosticDescriptors.RegisterServicesMethodOnAbstractClass); + + public override void Initialize(AnalysisContext context) + { + context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.None); + context.EnableConcurrentExecution(); + + context.RegisterSymbolAction(AnalyzeMethod, SymbolKind.Method); + context.RegisterSymbolAction(AnalyzeNamedType, SymbolKind.NamedType); + } + + private static void AnalyzeMethod(SymbolAnalysisContext context) + { + if (context.Symbol is not IMethodSymbol methodSymbol) + return; + + var attributes = methodSymbol.GetAttributes(); + var isKnown = false; + + foreach (var attribute in attributes) + { + if (SymbolHelpers.IsMethodAttribute(attribute)) + { + isKnown = true; + break; + } + } + + if (!isKnown) + return; + + var location = methodSymbol.Locations.Length > 0 + ? methodSymbol.Locations[0] + : Location.None; + + // warn if non-static method on abstract class + if (!methodSymbol.IsStatic && methodSymbol.ContainingType.IsAbstract) + { + context.ReportDiagnostic(Diagnostic.Create( + DiagnosticDescriptors.RegisterServicesMethodOnAbstractClass, + location, + methodSymbol.Name, + methodSymbol.ContainingType.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat))); + } + + ValidateMethod(context, methodSymbol, location); + } + + private static void ValidateMethod(SymbolAnalysisContext context, IMethodSymbol methodSymbol, Location location) + { + if (methodSymbol.Parameters.Length > 2) + { + context.ReportDiagnostic(Diagnostic.Create( + DiagnosticDescriptors.MethodTooManyParameters, + location, + methodSymbol.Name, + methodSymbol.Parameters.Length.ToString())); + return; + } + + if (methodSymbol.Parameters.Length == 0) + { + context.ReportDiagnostic(Diagnostic.Create( + DiagnosticDescriptors.InvalidMethodSignature, + location, + methodSymbol.Name)); + return; + } + + var hasServiceCollection = SymbolHelpers.IsServiceCollection(methodSymbol.Parameters[0]); + + if (!hasServiceCollection) + { + context.ReportDiagnostic(Diagnostic.Create( + DiagnosticDescriptors.InvalidMethodSignature, + location, + methodSymbol.Name)); + return; + } + + if (methodSymbol.Parameters.Length == 2) + { + var hasTagCollection = SymbolHelpers.IsStringCollection(methodSymbol.Parameters[1]); + + if (!hasTagCollection) + { + context.ReportDiagnostic(Diagnostic.Create( + DiagnosticDescriptors.InvalidMethodSecondParameter, + location, + methodSymbol.Name)); + } + } + } + + private static void AnalyzeNamedType(SymbolAnalysisContext context) + { + if (context.Symbol is not INamedTypeSymbol classSymbol) + return; + + if (classSymbol.IsStatic) + return; + + var attributes = classSymbol.GetAttributes(); + + foreach (var attribute in attributes) + { + if (!SymbolHelpers.IsKnownAttribute(attribute, out _)) + continue; + + var location = classSymbol.Locations.Length > 0 + ? classSymbol.Locations[0] + : Location.None; + + AnalyzeRegistrationAttribute(context, classSymbol, attribute, location); + } + } + + private static void AnalyzeRegistrationAttribute( + SymbolAnalysisContext context, + INamedTypeSymbol classSymbol, + AttributeData attribute, + Location location) + { + var serviceTypes = new HashSet(); + string? implementationType = null; + string? implementationFactory = null; + string? registrationStrategy = null; + + var attributeClass = attribute.AttributeClass; + if (attributeClass is { IsGenericType: true } && attributeClass.TypeArguments.Length == attributeClass.TypeParameters.Length) + { + for (var index = 0; index < attributeClass.TypeParameters.Length; index++) + { + var typeParameter = attributeClass.TypeParameters[index]; + var typeArgument = attributeClass.TypeArguments[index]; + + if (typeParameter.Name == "TService" || index == 0) + { + serviceTypes.Add(typeArgument.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat)); + } + else if (typeParameter.Name == "TImplementation" || index == 1) + { + implementationType = typeArgument.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + } + } + } + + foreach (var parameter in attribute.NamedArguments) + { + var name = parameter.Key; + var value = parameter.Value.Value; + + if (string.IsNullOrEmpty(name) || value == null) + continue; + + switch (name) + { + case "ServiceType": + var serviceTypeSymbol = value as INamedTypeSymbol; + var serviceType = serviceTypeSymbol?.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat) ?? value.ToString(); + serviceTypes.Add(serviceType); + break; + case "ImplementationType": + var implSymbol = value as INamedTypeSymbol; + implementationType = implSymbol?.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat) ?? value.ToString(); + break; + case "Factory": + implementationFactory = value.ToString(); + break; + case "Registration": + registrationStrategy = SymbolHelpers.ResolveRegistrationStrategy(value); + break; + } + } + + // resolve effective implementation type + var implTypeName = implementationType.IsNullOrWhiteSpace() + ? classSymbol.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat) + : implementationType!; + + // determine effective registration strategy + if (registrationStrategy == null && implementationType == null && serviceTypes.Count == 0) + registrationStrategy = KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName; + + // add interface-based service types for validation + bool includeInterfaces = registrationStrategy is KnownTypes.RegistrationStrategyImplementedInterfacesShortName + or KnownTypes.RegistrationStrategySelfWithInterfacesShortName + or KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName; + + if (includeInterfaces) + { + foreach (var iface in classSymbol.AllInterfaces) + { + if (iface.ConstructedFrom.ToString() == "System.IEquatable") + continue; + + serviceTypes.Add(iface.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat)); + } + } + + bool includeSelf = registrationStrategy is KnownTypes.RegistrationStrategySelfShortName + or KnownTypes.RegistrationStrategySelfWithInterfacesShortName + or KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName; + + if (includeSelf || serviceTypes.Count == 0) + serviceTypes.Add(implTypeName); + + // validate abstract implementation type without factory + if (classSymbol.IsAbstract && implementationFactory.IsNullOrWhiteSpace() && implTypeName == classSymbol.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat)) + { + context.ReportDiagnostic(Diagnostic.Create( + DiagnosticDescriptors.AbstractImplementationType, + location, + implTypeName)); + } + + // validate factory method + if (implementationFactory.HasValue()) + { + ValidateFactoryMethod(context, classSymbol, implementationFactory!, location); + } + + // validate service type assignability + ValidateServiceTypes(context, classSymbol, serviceTypes, location); + } + + private static void ValidateFactoryMethod( + SymbolAnalysisContext context, + INamedTypeSymbol classSymbol, + string factoryMethodName, + Location location) + { + var className = classSymbol.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + var members = classSymbol.GetMembers(factoryMethodName); + var factoryMethods = new List(); + + foreach (var member in members) + { + if (member is IMethodSymbol method) + factoryMethods.Add(method); + } + + if (factoryMethods.Count == 0) + { + context.ReportDiagnostic(Diagnostic.Create( + DiagnosticDescriptors.FactoryMethodNotFound, + location, + factoryMethodName, + className)); + return; + } + + // find at least one valid overload; only report if none exist + var hasStaticOverload = false; + + foreach (var method in factoryMethods) + { + if (!method.IsStatic) + continue; + + hasStaticOverload = true; + + if (method.Parameters.Length is not (1 or 2)) + continue; + + if (!SymbolHelpers.IsServiceProvider(method.Parameters[0])) + continue; + + // validate second parameter is object? (for keyed services) + if (method.Parameters.Length == 2 + && method.Parameters[1].Type.SpecialType != SpecialType.System_Object) + continue; + + // found a valid overload + return; + } + + context.ReportDiagnostic(Diagnostic.Create( + hasStaticOverload + ? DiagnosticDescriptors.FactoryMethodInvalidSignature + : DiagnosticDescriptors.FactoryMethodNotStatic, + location, + factoryMethodName, + className)); + } + + private static void ValidateServiceTypes( + SymbolAnalysisContext context, + INamedTypeSymbol classSymbol, + HashSet serviceTypes, + Location location) + { + var implTypeName = classSymbol.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + + foreach (var serviceType in serviceTypes) + { + if (serviceType == implTypeName) + continue; + + var implementsService = false; + + foreach (var iface in classSymbol.AllInterfaces) + { + var ifaceName = iface.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + if (ifaceName == serviceType) + { + implementsService = true; + break; + } + + // also check unbound generic form (e.g. IOpenGeneric<> vs IOpenGeneric) + var unboundIface = SymbolHelpers.ToUnboundGenericType(iface); + if (!SymbolEqualityComparer.Default.Equals(unboundIface, iface)) + { + var unboundName = unboundIface.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + if (unboundName == serviceType) + { + implementsService = true; + break; + } + } + } + + if (!implementsService) + { + var baseType = classSymbol.BaseType; + while (baseType is not null) + { + var baseName = baseType.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + if (baseName == serviceType) + { + implementsService = true; + break; + } + + var unboundBase = SymbolHelpers.ToUnboundGenericType(baseType); + if (!SymbolEqualityComparer.Default.Equals(unboundBase, baseType)) + { + var unboundBaseName = unboundBase.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); + if (unboundBaseName == serviceType) + { + implementsService = true; + break; + } + } + + baseType = baseType.BaseType; + } + } + + if (!implementsService) + { + context.ReportDiagnostic(Diagnostic.Create( + DiagnosticDescriptors.ServiceTypeMismatch, + location, + implTypeName, + serviceType)); + } + } + } +} diff --git a/src/Injectio.Generators/ServiceRegistrationGenerator.cs b/src/Injectio.Generators/ServiceRegistrationGenerator.cs index 0556fcf..d4be12c 100644 --- a/src/Injectio.Generators/ServiceRegistrationGenerator.cs +++ b/src/Injectio.Generators/ServiceRegistrationGenerator.cs @@ -14,11 +14,6 @@ namespace Injectio.Generators; [Generator] public class ServiceRegistrationGenerator : IIncrementalGenerator { - private static readonly SymbolDisplayFormat _fullyQualifiedNullableFormat = - SymbolDisplayFormat.FullyQualifiedFormat.AddMiscellaneousOptions( - SymbolDisplayMiscellaneousOptions.IncludeNullableReferenceTypeModifier - ); - public void Initialize(IncrementalGeneratorInitializationContext context) { // find all classes and methods with attributes @@ -129,7 +124,7 @@ private static bool SyntacticPredicate(SyntaxNode syntaxNode, CancellationToken // make sure attribute is for registration var attributes = methodSymbol.GetAttributes(); - var isKnown = attributes.Any(IsMethodAttribute); + var isKnown = attributes.Any(SymbolHelpers.IsMethodAttribute); if (!isKnown) return null; @@ -139,7 +134,7 @@ private static bool SyntacticPredicate(SyntaxNode syntaxNode, CancellationToken var registration = new ModuleRegistration ( - ClassName: methodSymbol.ContainingType.ToDisplayString(_fullyQualifiedNullableFormat), + ClassName: methodSymbol.ContainingType.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat), MethodName: methodSymbol.Name, IsStatic: methodSymbol.IsStatic, HasTagCollection: hasTagCollection @@ -184,7 +179,7 @@ private static (bool isValid, bool hasTagCollection) ValidateMethod(IMethodSymbo if (methodSymbol.Parameters.Length is 1 or 2) { var parameterSymbol = methodSymbol.Parameters[0]; - hasServiceCollection = IsServiceCollection(parameterSymbol); + hasServiceCollection = SymbolHelpers.IsServiceCollection(parameterSymbol); } if (methodSymbol.Parameters.Length is 1) @@ -194,9 +189,9 @@ private static (bool isValid, bool hasTagCollection) ValidateMethod(IMethodSymbo if (methodSymbol.Parameters.Length is 2) { var parameterSymbol = methodSymbol.Parameters[1]; - hasTagCollection = IsStringCollection(parameterSymbol); + hasTagCollection = SymbolHelpers.IsStringCollection(parameterSymbol); - // to be valid, parameter 0 must be service collection and parameter 1 must be string collection, + // to be valid, parameter 0 must be service collection and parameter 1 must be string collection, return (hasServiceCollection && hasTagCollection, hasTagCollection); } @@ -207,7 +202,7 @@ private static (bool isValid, bool hasTagCollection) ValidateMethod(IMethodSymbo private static ServiceRegistration? CreateServiceRegistration(INamedTypeSymbol classSymbol, AttributeData attribute) { // check for known attribute - if (!IsKnownAttribute(attribute, out var serviceLifetime)) + if (!SymbolHelpers.IsKnownAttribute(attribute, out var serviceLifetime)) return null; // defaults @@ -231,12 +226,12 @@ private static (bool isValid, bool hasTagCollection) ValidateMethod(IMethodSymbo if (typeParameter.Name == "TService" || index == 0) { - var service = typeArgument.ToDisplayString(_fullyQualifiedNullableFormat); + var service = typeArgument.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); serviceTypes.Add(service); } else if (typeParameter.Name == "TImplementation" || index == 1) { - implementationType = typeArgument.ToDisplayString(_fullyQualifiedNullableFormat); + implementationType = typeArgument.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); } } } @@ -256,7 +251,7 @@ private static (bool isValid, bool hasTagCollection) ValidateMethod(IMethodSymbo var serviceTypeSymbol = value as INamedTypeSymbol; isOpenGeneric = isOpenGeneric || IsOpenGeneric(serviceTypeSymbol); - var serviceType = serviceTypeSymbol?.ToDisplayString(_fullyQualifiedNullableFormat) ?? value.ToString(); + var serviceType = serviceTypeSymbol?.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat) ?? value.ToString(); serviceTypes.Add(serviceType); break; case "ServiceKey": @@ -266,7 +261,7 @@ private static (bool isValid, bool hasTagCollection) ValidateMethod(IMethodSymbo var implementationTypeSymbol = value as INamedTypeSymbol; isOpenGeneric = isOpenGeneric || IsOpenGeneric(implementationTypeSymbol); - implementationType = implementationTypeSymbol?.ToDisplayString(_fullyQualifiedNullableFormat) ?? value.ToString(); + implementationType = implementationTypeSymbol?.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat) ?? value.ToString(); break; case "Factory": implementationFactory = value.ToString(); @@ -275,7 +270,7 @@ private static (bool isValid, bool hasTagCollection) ValidateMethod(IMethodSymbo duplicateStrategy = ResolveDuplicateStrategy(value); break; case "Registration": - registrationStrategy = ResolveRegistrationStrategy(value); + registrationStrategy = SymbolHelpers.ResolveRegistrationStrategy(value); break; case "Tags": var tagsItems = value @@ -304,9 +299,9 @@ private static (bool isValid, bool hasTagCollection) ValidateMethod(IMethodSymbo // no implementation type set, use class attribute is on if (implementationType.IsNullOrWhiteSpace()) { - var unboundType = ToUnboundGenericType(classSymbol); + var unboundType = SymbolHelpers.ToUnboundGenericType(classSymbol); isOpenGeneric = isOpenGeneric || IsOpenGeneric(unboundType); - implementationType = unboundType.ToDisplayString(_fullyQualifiedNullableFormat); + implementationType = unboundType.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); } // add implemented interfaces @@ -321,10 +316,10 @@ or KnownTypes.RegistrationStrategySelfWithInterfacesShortName if (implementedInterface.ConstructedFrom.ToString() == "System.IEquatable") continue; - var unboundInterface = ToUnboundGenericType(implementedInterface); + var unboundInterface = SymbolHelpers.ToUnboundGenericType(implementedInterface); isOpenGeneric = isOpenGeneric || IsOpenGeneric(unboundInterface); - var interfaceName = unboundInterface.ToDisplayString(_fullyQualifiedNullableFormat); + var interfaceName = unboundInterface.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat); serviceTypes.Add(interfaceName); } } @@ -354,136 +349,6 @@ or KnownTypes.RegistrationStrategySelfWithInterfacesShortName IsOpenGeneric: isOpenGeneric); } - private static INamedTypeSymbol ToUnboundGenericType(INamedTypeSymbol typeSymbol) - { - if (!typeSymbol.IsGenericType || typeSymbol.IsUnboundGenericType) - return typeSymbol; - - foreach (var typeArgument in typeSymbol.TypeArguments) - { - // If TypeKind is TypeParameter, it's actually the name of a locally declared type-parameter -> placeholder - if (typeArgument.TypeKind != TypeKind.TypeParameter) - return typeSymbol; - } - - return typeSymbol.ConstructUnboundGenericType(); - } - - private static bool IsKnownAttribute(AttributeData attribute, out string serviceLifetime) - { - if (IsSingletonAttribute(attribute)) - { - serviceLifetime = KnownTypes.ServiceLifetimeSingletonFullName; - return true; - } - - if (IsScopedAttribute(attribute)) - { - serviceLifetime = KnownTypes.ServiceLifetimeScopedFullName; - return true; - } - - if (IsTransientAttribute(attribute)) - { - serviceLifetime = KnownTypes.ServiceLifetimeTransientFullName; - return true; - } - - serviceLifetime = KnownTypes.ServiceLifetimeTransientFullName; - return false; - } - - private static bool IsTransientAttribute(AttributeData attribute) - { - return attribute?.AttributeClass is - { - Name: KnownTypes.TransientAttributeShortName or KnownTypes.TransientAttributeTypeName, - ContainingNamespace: - { - Name: "Attributes", - ContainingNamespace.Name: "Injectio" - } - }; - } - - private static bool IsSingletonAttribute(AttributeData attribute) - { - return attribute?.AttributeClass is - { - Name: KnownTypes.SingletonAttributeShortName or KnownTypes.SingletonAttributeTypeName, - ContainingNamespace: - { - Name: "Attributes", - ContainingNamespace.Name: "Injectio" - } - }; - } - - private static bool IsScopedAttribute(AttributeData attribute) - { - return attribute?.AttributeClass is - { - Name: KnownTypes.ScopedAttributeShortName or KnownTypes.ScopedAttributeTypeName, - ContainingNamespace: - { - Name: "Attributes", - ContainingNamespace.Name: "Injectio" - } - }; - } - - private static bool IsMethodAttribute(AttributeData attribute) - { - return attribute?.AttributeClass is - { - Name: KnownTypes.ModuleAttributeShortName or KnownTypes.ModuleAttributeTypeName, - ContainingNamespace: - { - Name: "Attributes", - ContainingNamespace.Name: "Injectio" - } - }; - } - - private static bool IsServiceCollection(IParameterSymbol parameterSymbol) - { - return parameterSymbol?.Type is - { - Name: "IServiceCollection" or "ServiceCollection", - ContainingNamespace: - { - Name: "DependencyInjection", - ContainingNamespace: - { - Name: "Extensions", - ContainingNamespace.Name: "Microsoft" - } - } - }; - } - - private static bool IsStringCollection(IParameterSymbol parameterSymbol) - { - var type = parameterSymbol?.Type as INamedTypeSymbol; - - return type is - { - Name: "IEnumerable" or "IReadOnlySet" or "IReadOnlyCollection" or "ICollection" or "ISet" or "HashSet", - IsGenericType: true, - TypeArguments.Length: 1, - TypeParameters.Length: 1, - ContainingNamespace: - { - Name: "Generic", - ContainingNamespace: - { - Name: "Collections", - ContainingNamespace.Name: "System" - } - } - }; - } - private static bool IsOpenGeneric(INamedTypeSymbol? typeSymbol) { if (typeSymbol is null) @@ -510,21 +375,4 @@ private static string ResolveDuplicateStrategy(object? value) _ => KnownTypes.DuplicateStrategySkipShortName }; } - - private static string ResolveRegistrationStrategy(object? value) - { - return value switch - { - int v => v switch - { - KnownTypes.RegistrationStrategySelfValue => KnownTypes.RegistrationStrategySelfShortName, - KnownTypes.RegistrationStrategyImplementedInterfacesValue => KnownTypes.RegistrationStrategyImplementedInterfacesShortName, - KnownTypes.RegistrationStrategySelfWithInterfacesValue => KnownTypes.RegistrationStrategySelfWithInterfacesShortName, - KnownTypes.RegistrationStrategySelfWithProxyFactoryValue => KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName, - _ => KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName - }, - string text => text, - _ => KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName - }; - } } diff --git a/src/Injectio.Generators/SymbolHelpers.cs b/src/Injectio.Generators/SymbolHelpers.cs new file mode 100644 index 0000000..54cc867 --- /dev/null +++ b/src/Injectio.Generators/SymbolHelpers.cs @@ -0,0 +1,177 @@ +using Microsoft.CodeAnalysis; + +namespace Injectio.Generators; + +internal static class SymbolHelpers +{ + public static readonly SymbolDisplayFormat FullyQualifiedNullableFormat = + SymbolDisplayFormat.FullyQualifiedFormat.AddMiscellaneousOptions( + SymbolDisplayMiscellaneousOptions.IncludeNullableReferenceTypeModifier + ); + + public static bool IsMethodAttribute(AttributeData attribute) + { + return attribute?.AttributeClass is + { + Name: KnownTypes.ModuleAttributeShortName or KnownTypes.ModuleAttributeTypeName, + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace.Name: "Injectio" + } + }; + } + + public static bool IsTransientAttribute(AttributeData attribute) + { + return attribute?.AttributeClass is + { + Name: KnownTypes.TransientAttributeShortName or KnownTypes.TransientAttributeTypeName, + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace.Name: "Injectio" + } + }; + } + + public static bool IsSingletonAttribute(AttributeData attribute) + { + return attribute?.AttributeClass is + { + Name: KnownTypes.SingletonAttributeShortName or KnownTypes.SingletonAttributeTypeName, + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace.Name: "Injectio" + } + }; + } + + public static bool IsScopedAttribute(AttributeData attribute) + { + return attribute?.AttributeClass is + { + Name: KnownTypes.ScopedAttributeShortName or KnownTypes.ScopedAttributeTypeName, + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace.Name: "Injectio" + } + }; + } + + public static bool IsKnownAttribute(AttributeData attribute, out string serviceLifetime) + { + if (IsSingletonAttribute(attribute)) + { + serviceLifetime = KnownTypes.ServiceLifetimeSingletonFullName; + return true; + } + + if (IsScopedAttribute(attribute)) + { + serviceLifetime = KnownTypes.ServiceLifetimeScopedFullName; + return true; + } + + if (IsTransientAttribute(attribute)) + { + serviceLifetime = KnownTypes.ServiceLifetimeTransientFullName; + return true; + } + + serviceLifetime = KnownTypes.ServiceLifetimeTransientFullName; + return false; + } + + public static bool IsServiceCollection(IParameterSymbol parameterSymbol) + { + return parameterSymbol?.Type is + { + Name: "IServiceCollection" or "ServiceCollection", + ContainingNamespace: + { + Name: "DependencyInjection", + ContainingNamespace: + { + Name: "Extensions", + ContainingNamespace.Name: "Microsoft" + } + } + }; + } + + public static bool IsStringCollection(IParameterSymbol parameterSymbol) + { + var type = parameterSymbol?.Type as INamedTypeSymbol; + + if (type is not + { + Name: "IEnumerable" or "IReadOnlySet" or "IReadOnlyCollection" or "ICollection" or "ISet" or "HashSet", + IsGenericType: true, + TypeArguments.Length: 1, + TypeParameters.Length: 1, + ContainingNamespace: + { + Name: "Generic", + ContainingNamespace: + { + Name: "Collections", + ContainingNamespace.Name: "System" + } + } + }) + { + return false; + } + + // verify the generic argument is string + return type.TypeArguments[0].SpecialType == SpecialType.System_String; + } + + public static bool IsServiceProvider(IParameterSymbol parameterSymbol) + { + return parameterSymbol?.Type is + { + Name: "IServiceProvider", + ContainingNamespace: + { + Name: "System", + ContainingNamespace.IsGlobalNamespace: true + } + }; + } + + public static INamedTypeSymbol ToUnboundGenericType(INamedTypeSymbol typeSymbol) + { + if (!typeSymbol.IsGenericType || typeSymbol.IsUnboundGenericType) + return typeSymbol; + + foreach (var typeArgument in typeSymbol.TypeArguments) + { + // If TypeKind is TypeParameter, it's actually the name of a locally declared type-parameter -> placeholder + if (typeArgument.TypeKind != TypeKind.TypeParameter) + return typeSymbol; + } + + return typeSymbol.ConstructUnboundGenericType(); + } + + public static string ResolveRegistrationStrategy(object? value) + { + return value switch + { + int v => v switch + { + KnownTypes.RegistrationStrategySelfValue => KnownTypes.RegistrationStrategySelfShortName, + KnownTypes.RegistrationStrategyImplementedInterfacesValue => KnownTypes.RegistrationStrategyImplementedInterfacesShortName, + KnownTypes.RegistrationStrategySelfWithInterfacesValue => KnownTypes.RegistrationStrategySelfWithInterfacesShortName, + KnownTypes.RegistrationStrategySelfWithProxyFactoryValue => KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName, + _ => KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName + }, + string text => text, + _ => KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName + }; + } +} diff --git a/tests/Injectio.Acceptance.Tests/DependencyInjectionBase.cs b/tests/Injectio.Acceptance.Tests/DependencyInjectionBase.cs index 882030c..bcbba2f 100644 --- a/tests/Injectio.Acceptance.Tests/DependencyInjectionBase.cs +++ b/tests/Injectio.Acceptance.Tests/DependencyInjectionBase.cs @@ -1,11 +1,9 @@ -using Xunit.Abstractions; - using XUnit.Hosting; namespace Injectio.Acceptance.Tests; [Collection(DependencyInjectionCollection.CollectionName)] -public abstract class DependencyInjectionBase(ITestOutputHelper output, DependencyInjectionFixture fixture) - : TestHostBase(output, fixture) +public abstract class DependencyInjectionBase(DependencyInjectionFixture fixture) + : TestHostBase(fixture) { } diff --git a/tests/Injectio.Acceptance.Tests/Injectio.Acceptance.Tests.csproj b/tests/Injectio.Acceptance.Tests/Injectio.Acceptance.Tests.csproj index 8eb37db..2cdb10b 100644 --- a/tests/Injectio.Acceptance.Tests/Injectio.Acceptance.Tests.csproj +++ b/tests/Injectio.Acceptance.Tests/Injectio.Acceptance.Tests.csproj @@ -18,12 +18,8 @@ runtime; build; native; contentfiles; analyzers; buildtransitive - + - - all - runtime; build; native; contentfiles; analyzers; buildtransitive - diff --git a/tests/Injectio.Acceptance.Tests/LibraryServiceTests.cs b/tests/Injectio.Acceptance.Tests/LibraryServiceTests.cs index 1efb62e..4b7f6e0 100644 --- a/tests/Injectio.Acceptance.Tests/LibraryServiceTests.cs +++ b/tests/Injectio.Acceptance.Tests/LibraryServiceTests.cs @@ -4,12 +4,10 @@ using Microsoft.Extensions.DependencyInjection; -using Xunit.Abstractions; - namespace Injectio.Acceptance.Tests; [Collection(DependencyInjectionCollection.CollectionName)] -public class LibraryServiceTests(ITestOutputHelper output, DependencyInjectionFixture fixture) : DependencyInjectionBase(output, fixture) +public class LibraryServiceTests(DependencyInjectionFixture fixture) : DependencyInjectionBase(fixture) { [Fact] public void ShouldResolveService() diff --git a/tests/Injectio.Acceptance.Tests/LocalServiceTests.cs b/tests/Injectio.Acceptance.Tests/LocalServiceTests.cs index 76dfe99..2fbf1f9 100644 --- a/tests/Injectio.Acceptance.Tests/LocalServiceTests.cs +++ b/tests/Injectio.Acceptance.Tests/LocalServiceTests.cs @@ -4,12 +4,10 @@ using Microsoft.Extensions.DependencyInjection; -using Xunit.Abstractions; - namespace Injectio.Acceptance.Tests; [Collection(DependencyInjectionCollection.CollectionName)] -public class LocalServiceTests(ITestOutputHelper output, DependencyInjectionFixture fixture) : DependencyInjectionBase(output, fixture) +public class LocalServiceTests(DependencyInjectionFixture fixture) : DependencyInjectionBase(fixture) { [Fact] public void ShouldResolveLocalService() diff --git a/tests/Injectio.Tests/Injectio.Tests.csproj b/tests/Injectio.Tests/Injectio.Tests.csproj index 6bea383..236d23b 100644 --- a/tests/Injectio.Tests/Injectio.Tests.csproj +++ b/tests/Injectio.Tests/Injectio.Tests.csproj @@ -14,12 +14,8 @@ - - - - runtime; build; native; contentfiles; analyzers; buildtransitive - all - + + runtime; build; native; contentfiles; analyzers; buildtransitive all diff --git a/tests/Injectio.Tests/ServiceRegistrationDiagnosticTests.cs b/tests/Injectio.Tests/ServiceRegistrationDiagnosticTests.cs new file mode 100644 index 0000000..0aa4b15 --- /dev/null +++ b/tests/Injectio.Tests/ServiceRegistrationDiagnosticTests.cs @@ -0,0 +1,424 @@ +using System; +using System.Collections.Immutable; +using System.Linq; +using System.Threading.Tasks; + +using AwesomeAssertions; + +using Injectio.Attributes; +using Injectio.Generators; + +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.Diagnostics; +using Microsoft.Extensions.DependencyInjection; + +using Xunit; + +namespace Injectio.Tests; + +public class ServiceRegistrationDiagnosticTests +{ + [Fact] + public async Task DiagnoseRegisterServicesInvalidFirstParameter() + { + var source = @" +using Injectio.Attributes; + +namespace Injectio.Sample; + +public static class RegistrationModule +{ + [RegisterServices] + public static void Register(string test) + { + } +} +"; + + var diagnostics = await GetDiagnosticsAsync(source); + + diagnostics.Should().ContainSingle(d => d.Id == "INJECT0001"); + } + + [Fact] + public async Task DiagnoseRegisterServicesInvalidSecondParameter() + { + var source = @" +using Injectio.Attributes; +using Microsoft.Extensions.DependencyInjection; + +namespace Injectio.Sample; + +public static class RegistrationModule +{ + [RegisterServices] + public static void Register(IServiceCollection services, string test) + { + } +} +"; + + var diagnostics = await GetDiagnosticsAsync(source); + + diagnostics.Should().ContainSingle(d => d.Id == "INJECT0002"); + } + + [Fact] + public async Task DiagnoseRegisterServicesTooManyParameters() + { + var source = @" +using Injectio.Attributes; +using Microsoft.Extensions.DependencyInjection; + +namespace Injectio.Sample; + +public static class RegistrationModule +{ + [RegisterServices] + public static void Register(IServiceCollection services, string a, string b) + { + } +} +"; + + var diagnostics = await GetDiagnosticsAsync(source); + + diagnostics.Should().ContainSingle(d => d.Id == "INJECT0003"); + } + + [Fact] + public async Task DiagnoseRegisterServicesNoParameters() + { + var source = @" +using Injectio.Attributes; + +namespace Injectio.Sample; + +public static class RegistrationModule +{ + [RegisterServices] + public static void Register() + { + } +} +"; + + var diagnostics = await GetDiagnosticsAsync(source); + + diagnostics.Should().ContainSingle(d => d.Id == "INJECT0001"); + } + + [Fact] + public async Task DiagnoseFactoryMethodNotFound() + { + var source = @" +using Injectio.Attributes; + +namespace Injectio.Sample; + +public interface IService { } + +[RegisterTransient(ServiceType = typeof(IService), Factory = ""NonExistentMethod"")] +public class MyService : IService +{ +} +"; + + var diagnostics = await GetDiagnosticsAsync(source); + + diagnostics.Should().ContainSingle(d => d.Id == "INJECT0004"); + } + + [Fact] + public async Task DiagnoseFactoryMethodNotStatic() + { + var source = @" +using System; +using Injectio.Attributes; + +namespace Injectio.Sample; + +public interface IService { } + +[RegisterTransient(ServiceType = typeof(IService), Factory = nameof(ServiceFactory))] +public class MyService : IService +{ + public IService ServiceFactory(IServiceProvider serviceProvider) + { + return new MyService(); + } +} +"; + + var diagnostics = await GetDiagnosticsAsync(source); + + diagnostics.Should().ContainSingle(d => d.Id == "INJECT0005"); + } + + [Fact] + public async Task DiagnoseFactoryMethodInvalidSignature() + { + var source = @" +using Injectio.Attributes; + +namespace Injectio.Sample; + +public interface IService { } + +[RegisterTransient(ServiceType = typeof(IService), Factory = nameof(ServiceFactory))] +public class MyService : IService +{ + public static IService ServiceFactory(string notServiceProvider) + { + return new MyService(); + } +} +"; + + var diagnostics = await GetDiagnosticsAsync(source); + + diagnostics.Should().ContainSingle(d => d.Id == "INJECT0006"); + } + + [Fact] + public async Task DiagnoseServiceTypeMismatch() + { + var source = @" +using Injectio.Attributes; + +namespace Injectio.Sample; + +public interface IService { } +public interface IOtherService { } + +[RegisterTransient(ServiceType = typeof(IOtherService))] +public class MyService : IService +{ +} +"; + + var diagnostics = await GetDiagnosticsAsync(source); + + diagnostics.Should().ContainSingle(d => d.Id == "INJECT0007"); + } + + [Fact] + public async Task DiagnoseRegisterServicesOnAbstractClassNonStaticMethod() + { + var source = @" +using Injectio.Attributes; +using Microsoft.Extensions.DependencyInjection; + +namespace Injectio.Sample; + +public abstract class RegistrationModule +{ + [RegisterServices] + public void Register(IServiceCollection services) + { + } +} +"; + + var diagnostics = await GetDiagnosticsAsync(source); + + diagnostics.Should().ContainSingle(d => d.Id == "INJECT0009"); + } + + [Fact] + public async Task DiagnoseAbstractImplementationTypeWithoutFactory() + { + var source = @" +using Injectio.Attributes; + +namespace Injectio.Sample; + +public interface IService { } + +[RegisterSingleton] +public abstract class AbstractService : IService +{ +} +"; + + var diagnostics = await GetDiagnosticsAsync(source); + + diagnostics.Should().ContainSingle(d => d.Id == "INJECT0008"); + } + + [Fact] + public async Task NoDiagnosticsForAbstractImplementationTypeWithFactory() + { + var source = @" +using System; +using Injectio.Attributes; + +namespace Injectio.Sample; + +public interface IService { } + +[RegisterSingleton(ServiceType = typeof(IService), Factory = nameof(Create))] +public abstract class AbstractService : IService +{ + public static IService Create(IServiceProvider serviceProvider) + { + return null!; + } +} +"; + + var diagnostics = await GetDiagnosticsAsync(source); + + diagnostics.Should().BeEmpty(); + } + + [Fact] + public async Task NoDiagnosticsForValidRegistration() + { + var source = @" +using Injectio.Attributes; + +namespace Injectio.Sample; + +public interface IService { } + +[RegisterSingleton] +public class MyService : IService +{ +} +"; + + var diagnostics = await GetDiagnosticsAsync(source); + + diagnostics.Should().BeEmpty(); + } + + [Fact] + public async Task NoDiagnosticsForValidFactory() + { + var source = @" +using System; +using Injectio.Attributes; + +namespace Injectio.Sample; + +public interface IService { } + +[RegisterTransient(ServiceType = typeof(IService), Factory = nameof(ServiceFactory))] +public class MyService : IService +{ + public static IService ServiceFactory(IServiceProvider serviceProvider) + { + return new MyService(); + } +} +"; + + var diagnostics = await GetDiagnosticsAsync(source); + + diagnostics.Should().BeEmpty(); + } + + [Fact] + public async Task NoDiagnosticsForValidKeyedFactory() + { + var source = @" +using System; +using Injectio.Attributes; + +namespace Injectio.Sample; + +public interface IService { } + +[RegisterTransient(ServiceType = typeof(IService), ServiceKey = ""key"", Factory = nameof(ServiceFactory))] +public class MyService : IService +{ + public static IService ServiceFactory(IServiceProvider serviceProvider, object? serviceKey) + { + return new MyService(); + } +} +"; + + var diagnostics = await GetDiagnosticsAsync(source); + + diagnostics.Should().BeEmpty(); + } + + [Fact] + public async Task NoDiagnosticsForValidRegisterServicesMethod() + { + var source = @" +using Injectio.Attributes; +using Microsoft.Extensions.DependencyInjection; + +namespace Injectio.Sample; + +public static class RegistrationModule +{ + [RegisterServices] + public static void Register(IServiceCollection services) + { + } +} +"; + + var diagnostics = await GetDiagnosticsAsync(source); + + diagnostics.Should().BeEmpty(); + } + + [Fact] + public async Task NoDiagnosticsForValidRegisterServicesWithTags() + { + var source = @" +using System.Collections.Generic; +using Injectio.Attributes; +using Microsoft.Extensions.DependencyInjection; + +namespace Injectio.Sample; + +public static class RegistrationModule +{ + [RegisterServices] + public static void Register(IServiceCollection services, IEnumerable tags) + { + } +} +"; + + var diagnostics = await GetDiagnosticsAsync(source); + + diagnostics.Should().BeEmpty(); + } + + private static async Task> GetDiagnosticsAsync(string source) + { + var syntaxTree = CSharpSyntaxTree.ParseText(source); + var references = AppDomain.CurrentDomain.GetAssemblies() + .Where(assembly => !assembly.IsDynamic && !string.IsNullOrWhiteSpace(assembly.Location)) + .Select(assembly => MetadataReference.CreateFromFile(assembly.Location)) + .Concat(new[] + { + MetadataReference.CreateFromFile(typeof(ServiceRegistrationGenerator).Assembly.Location), + MetadataReference.CreateFromFile(typeof(RegisterServicesAttribute).Assembly.Location), + MetadataReference.CreateFromFile(typeof(IServiceCollection).Assembly.Location), + }); + + var compilation = CSharpCompilation.Create( + "Test.Diagnostics", + new[] { syntaxTree }, + references, + new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)); + + var analyzer = new ServiceRegistrationAnalyzer(); + var compilationWithAnalyzers = compilation.WithAnalyzers(ImmutableArray.Create(analyzer)); + var diagnostics = await compilationWithAnalyzers.GetAnalyzerDiagnosticsAsync(); + + // return only Injectio diagnostics + return diagnostics + .Where(d => d.Id.StartsWith("INJECT")) + .ToImmutableArray(); + } +}