diff --git a/Directory.Packages.props b/Directory.Packages.props
index b58c9e3..7b20a07 100644
--- a/Directory.Packages.props
+++ b/Directory.Packages.props
@@ -12,9 +12,8 @@
-
-
+
+
-
diff --git a/src/Injectio.Generators/AnalyzerReleases.Shipped.md b/src/Injectio.Generators/AnalyzerReleases.Shipped.md
new file mode 100644
index 0000000..f50bb1f
--- /dev/null
+++ b/src/Injectio.Generators/AnalyzerReleases.Shipped.md
@@ -0,0 +1,2 @@
+; Shipped analyzer releases
+; https://github.com/dotnet/roslyn-analyzers/blob/main/src/Microsoft.CodeAnalysis.Analyzers/ReleaseTrackingAnalyzers.Help.md
diff --git a/src/Injectio.Generators/AnalyzerReleases.Unshipped.md b/src/Injectio.Generators/AnalyzerReleases.Unshipped.md
new file mode 100644
index 0000000..7ad52dc
--- /dev/null
+++ b/src/Injectio.Generators/AnalyzerReleases.Unshipped.md
@@ -0,0 +1,16 @@
+; Unshipped analyzer release
+; https://github.com/dotnet/roslyn-analyzers/blob/main/src/Microsoft.CodeAnalysis.Analyzers/ReleaseTrackingAnalyzers.Help.md
+
+### New Rules
+
+Rule ID | Category | Severity | Notes
+--------|----------|----------|-------
+INJECT0001 | Injectio | Warning | RegisterServices method has invalid signature
+INJECT0002 | Injectio | Warning | RegisterServices method has invalid second parameter
+INJECT0003 | Injectio | Warning | RegisterServices method has too many parameters
+INJECT0004 | Injectio | Warning | Factory method not found
+INJECT0005 | Injectio | Warning | Factory method must be static
+INJECT0006 | Injectio | Warning | Factory method has invalid signature
+INJECT0007 | Injectio | Warning | Implementation does not implement service type
+INJECT0008 | Injectio | Warning | Implementation type is abstract
+INJECT0009 | Injectio | Warning | RegisterServices on non-static method in abstract class
diff --git a/src/Injectio.Generators/DiagnosticDescriptors.cs b/src/Injectio.Generators/DiagnosticDescriptors.cs
new file mode 100644
index 0000000..08e6d4e
--- /dev/null
+++ b/src/Injectio.Generators/DiagnosticDescriptors.cs
@@ -0,0 +1,80 @@
+using Microsoft.CodeAnalysis;
+
+namespace Injectio.Generators;
+
+public static class DiagnosticDescriptors
+{
+ private const string Category = "Injectio";
+
+ public static readonly DiagnosticDescriptor InvalidMethodSignature = new(
+ id: "INJECT0001",
+ title: "RegisterServices method has invalid signature",
+ messageFormat: "Method '{0}' marked with [RegisterServices] must have IServiceCollection as its first parameter",
+ category: Category,
+ defaultSeverity: DiagnosticSeverity.Warning,
+ isEnabledByDefault: true);
+
+ public static readonly DiagnosticDescriptor InvalidMethodSecondParameter = new(
+ id: "INJECT0002",
+ title: "RegisterServices method has invalid second parameter",
+ messageFormat: "Method '{0}' marked with [RegisterServices] has an invalid second parameter; expected a string collection (e.g., IEnumerable)",
+ category: Category,
+ defaultSeverity: DiagnosticSeverity.Warning,
+ isEnabledByDefault: true);
+
+ public static readonly DiagnosticDescriptor MethodTooManyParameters = new(
+ id: "INJECT0003",
+ title: "RegisterServices method has too many parameters",
+ messageFormat: "Method '{0}' marked with [RegisterServices] has {1} parameters; expected 1 or 2",
+ category: Category,
+ defaultSeverity: DiagnosticSeverity.Warning,
+ isEnabledByDefault: true);
+
+ public static readonly DiagnosticDescriptor FactoryMethodNotFound = new(
+ id: "INJECT0004",
+ title: "Factory method not found",
+ messageFormat: "Factory method '{0}' was not found on type '{1}'",
+ category: Category,
+ defaultSeverity: DiagnosticSeverity.Warning,
+ isEnabledByDefault: true);
+
+ public static readonly DiagnosticDescriptor FactoryMethodNotStatic = new(
+ id: "INJECT0005",
+ title: "Factory method must be static",
+ messageFormat: "Factory method '{0}' on type '{1}' must be static",
+ category: Category,
+ defaultSeverity: DiagnosticSeverity.Warning,
+ isEnabledByDefault: true);
+
+ public static readonly DiagnosticDescriptor FactoryMethodInvalidSignature = new(
+ id: "INJECT0006",
+ title: "Factory method has invalid signature",
+ messageFormat: "Factory method '{0}' on type '{1}' must accept IServiceProvider as its first parameter and optionally object? as its second parameter",
+ category: Category,
+ defaultSeverity: DiagnosticSeverity.Warning,
+ isEnabledByDefault: true);
+
+ public static readonly DiagnosticDescriptor ServiceTypeMismatch = new(
+ id: "INJECT0007",
+ title: "Implementation does not implement service type",
+ messageFormat: "Type '{0}' does not implement or inherit from service type '{1}'",
+ category: Category,
+ defaultSeverity: DiagnosticSeverity.Warning,
+ isEnabledByDefault: true);
+
+ public static readonly DiagnosticDescriptor AbstractImplementationType = new(
+ id: "INJECT0008",
+ title: "Implementation type is abstract",
+ messageFormat: "Implementation type '{0}' is abstract and cannot be instantiated without a factory method",
+ category: Category,
+ defaultSeverity: DiagnosticSeverity.Warning,
+ isEnabledByDefault: true);
+
+ public static readonly DiagnosticDescriptor RegisterServicesMethodOnAbstractClass = new(
+ id: "INJECT0009",
+ title: "RegisterServices on non-static method in abstract class",
+ messageFormat: "Method '{0}' marked with [RegisterServices] is a non-static method on abstract class '{1}'; the class cannot be instantiated",
+ category: Category,
+ defaultSeverity: DiagnosticSeverity.Warning,
+ isEnabledByDefault: true);
+}
diff --git a/src/Injectio.Generators/Injectio.Generators.csproj b/src/Injectio.Generators/Injectio.Generators.csproj
index 9ebcf3c..0479dc4 100644
--- a/src/Injectio.Generators/Injectio.Generators.csproj
+++ b/src/Injectio.Generators/Injectio.Generators.csproj
@@ -27,4 +27,9 @@
+
+
+
+
+
diff --git a/src/Injectio.Generators/ServiceRegistrationAnalyzer.cs b/src/Injectio.Generators/ServiceRegistrationAnalyzer.cs
new file mode 100644
index 0000000..025c05a
--- /dev/null
+++ b/src/Injectio.Generators/ServiceRegistrationAnalyzer.cs
@@ -0,0 +1,383 @@
+using System.Collections.Immutable;
+
+using Injectio.Generators.Extensions;
+
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.Diagnostics;
+
+namespace Injectio.Generators;
+
+[DiagnosticAnalyzer(LanguageNames.CSharp)]
+public class ServiceRegistrationAnalyzer : DiagnosticAnalyzer
+{
+ public override ImmutableArray SupportedDiagnostics { get; } =
+ ImmutableArray.Create(
+ DiagnosticDescriptors.InvalidMethodSignature,
+ DiagnosticDescriptors.InvalidMethodSecondParameter,
+ DiagnosticDescriptors.MethodTooManyParameters,
+ DiagnosticDescriptors.FactoryMethodNotFound,
+ DiagnosticDescriptors.FactoryMethodNotStatic,
+ DiagnosticDescriptors.FactoryMethodInvalidSignature,
+ DiagnosticDescriptors.ServiceTypeMismatch,
+ DiagnosticDescriptors.AbstractImplementationType,
+ DiagnosticDescriptors.RegisterServicesMethodOnAbstractClass);
+
+ public override void Initialize(AnalysisContext context)
+ {
+ context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.None);
+ context.EnableConcurrentExecution();
+
+ context.RegisterSymbolAction(AnalyzeMethod, SymbolKind.Method);
+ context.RegisterSymbolAction(AnalyzeNamedType, SymbolKind.NamedType);
+ }
+
+ private static void AnalyzeMethod(SymbolAnalysisContext context)
+ {
+ if (context.Symbol is not IMethodSymbol methodSymbol)
+ return;
+
+ var attributes = methodSymbol.GetAttributes();
+ var isKnown = false;
+
+ foreach (var attribute in attributes)
+ {
+ if (SymbolHelpers.IsMethodAttribute(attribute))
+ {
+ isKnown = true;
+ break;
+ }
+ }
+
+ if (!isKnown)
+ return;
+
+ var location = methodSymbol.Locations.Length > 0
+ ? methodSymbol.Locations[0]
+ : Location.None;
+
+ // warn if non-static method on abstract class
+ if (!methodSymbol.IsStatic && methodSymbol.ContainingType.IsAbstract)
+ {
+ context.ReportDiagnostic(Diagnostic.Create(
+ DiagnosticDescriptors.RegisterServicesMethodOnAbstractClass,
+ location,
+ methodSymbol.Name,
+ methodSymbol.ContainingType.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat)));
+ }
+
+ ValidateMethod(context, methodSymbol, location);
+ }
+
+ private static void ValidateMethod(SymbolAnalysisContext context, IMethodSymbol methodSymbol, Location location)
+ {
+ if (methodSymbol.Parameters.Length > 2)
+ {
+ context.ReportDiagnostic(Diagnostic.Create(
+ DiagnosticDescriptors.MethodTooManyParameters,
+ location,
+ methodSymbol.Name,
+ methodSymbol.Parameters.Length.ToString()));
+ return;
+ }
+
+ if (methodSymbol.Parameters.Length == 0)
+ {
+ context.ReportDiagnostic(Diagnostic.Create(
+ DiagnosticDescriptors.InvalidMethodSignature,
+ location,
+ methodSymbol.Name));
+ return;
+ }
+
+ var hasServiceCollection = SymbolHelpers.IsServiceCollection(methodSymbol.Parameters[0]);
+
+ if (!hasServiceCollection)
+ {
+ context.ReportDiagnostic(Diagnostic.Create(
+ DiagnosticDescriptors.InvalidMethodSignature,
+ location,
+ methodSymbol.Name));
+ return;
+ }
+
+ if (methodSymbol.Parameters.Length == 2)
+ {
+ var hasTagCollection = SymbolHelpers.IsStringCollection(methodSymbol.Parameters[1]);
+
+ if (!hasTagCollection)
+ {
+ context.ReportDiagnostic(Diagnostic.Create(
+ DiagnosticDescriptors.InvalidMethodSecondParameter,
+ location,
+ methodSymbol.Name));
+ }
+ }
+ }
+
+ private static void AnalyzeNamedType(SymbolAnalysisContext context)
+ {
+ if (context.Symbol is not INamedTypeSymbol classSymbol)
+ return;
+
+ if (classSymbol.IsStatic)
+ return;
+
+ var attributes = classSymbol.GetAttributes();
+
+ foreach (var attribute in attributes)
+ {
+ if (!SymbolHelpers.IsKnownAttribute(attribute, out _))
+ continue;
+
+ var location = classSymbol.Locations.Length > 0
+ ? classSymbol.Locations[0]
+ : Location.None;
+
+ AnalyzeRegistrationAttribute(context, classSymbol, attribute, location);
+ }
+ }
+
+ private static void AnalyzeRegistrationAttribute(
+ SymbolAnalysisContext context,
+ INamedTypeSymbol classSymbol,
+ AttributeData attribute,
+ Location location)
+ {
+ var serviceTypes = new HashSet();
+ string? implementationType = null;
+ string? implementationFactory = null;
+ string? registrationStrategy = null;
+
+ var attributeClass = attribute.AttributeClass;
+ if (attributeClass is { IsGenericType: true } && attributeClass.TypeArguments.Length == attributeClass.TypeParameters.Length)
+ {
+ for (var index = 0; index < attributeClass.TypeParameters.Length; index++)
+ {
+ var typeParameter = attributeClass.TypeParameters[index];
+ var typeArgument = attributeClass.TypeArguments[index];
+
+ if (typeParameter.Name == "TService" || index == 0)
+ {
+ serviceTypes.Add(typeArgument.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat));
+ }
+ else if (typeParameter.Name == "TImplementation" || index == 1)
+ {
+ implementationType = typeArgument.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat);
+ }
+ }
+ }
+
+ foreach (var parameter in attribute.NamedArguments)
+ {
+ var name = parameter.Key;
+ var value = parameter.Value.Value;
+
+ if (string.IsNullOrEmpty(name) || value == null)
+ continue;
+
+ switch (name)
+ {
+ case "ServiceType":
+ var serviceTypeSymbol = value as INamedTypeSymbol;
+ var serviceType = serviceTypeSymbol?.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat) ?? value.ToString();
+ serviceTypes.Add(serviceType);
+ break;
+ case "ImplementationType":
+ var implSymbol = value as INamedTypeSymbol;
+ implementationType = implSymbol?.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat) ?? value.ToString();
+ break;
+ case "Factory":
+ implementationFactory = value.ToString();
+ break;
+ case "Registration":
+ registrationStrategy = SymbolHelpers.ResolveRegistrationStrategy(value);
+ break;
+ }
+ }
+
+ // resolve effective implementation type
+ var implTypeName = implementationType.IsNullOrWhiteSpace()
+ ? classSymbol.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat)
+ : implementationType!;
+
+ // determine effective registration strategy
+ if (registrationStrategy == null && implementationType == null && serviceTypes.Count == 0)
+ registrationStrategy = KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName;
+
+ // add interface-based service types for validation
+ bool includeInterfaces = registrationStrategy is KnownTypes.RegistrationStrategyImplementedInterfacesShortName
+ or KnownTypes.RegistrationStrategySelfWithInterfacesShortName
+ or KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName;
+
+ if (includeInterfaces)
+ {
+ foreach (var iface in classSymbol.AllInterfaces)
+ {
+ if (iface.ConstructedFrom.ToString() == "System.IEquatable")
+ continue;
+
+ serviceTypes.Add(iface.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat));
+ }
+ }
+
+ bool includeSelf = registrationStrategy is KnownTypes.RegistrationStrategySelfShortName
+ or KnownTypes.RegistrationStrategySelfWithInterfacesShortName
+ or KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName;
+
+ if (includeSelf || serviceTypes.Count == 0)
+ serviceTypes.Add(implTypeName);
+
+ // validate abstract implementation type without factory
+ if (classSymbol.IsAbstract && implementationFactory.IsNullOrWhiteSpace() && implTypeName == classSymbol.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat))
+ {
+ context.ReportDiagnostic(Diagnostic.Create(
+ DiagnosticDescriptors.AbstractImplementationType,
+ location,
+ implTypeName));
+ }
+
+ // validate factory method
+ if (implementationFactory.HasValue())
+ {
+ ValidateFactoryMethod(context, classSymbol, implementationFactory!, location);
+ }
+
+ // validate service type assignability
+ ValidateServiceTypes(context, classSymbol, serviceTypes, location);
+ }
+
+ private static void ValidateFactoryMethod(
+ SymbolAnalysisContext context,
+ INamedTypeSymbol classSymbol,
+ string factoryMethodName,
+ Location location)
+ {
+ var className = classSymbol.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat);
+ var members = classSymbol.GetMembers(factoryMethodName);
+ var factoryMethods = new List();
+
+ foreach (var member in members)
+ {
+ if (member is IMethodSymbol method)
+ factoryMethods.Add(method);
+ }
+
+ if (factoryMethods.Count == 0)
+ {
+ context.ReportDiagnostic(Diagnostic.Create(
+ DiagnosticDescriptors.FactoryMethodNotFound,
+ location,
+ factoryMethodName,
+ className));
+ return;
+ }
+
+ // find at least one valid overload; only report if none exist
+ var hasStaticOverload = false;
+
+ foreach (var method in factoryMethods)
+ {
+ if (!method.IsStatic)
+ continue;
+
+ hasStaticOverload = true;
+
+ if (method.Parameters.Length is not (1 or 2))
+ continue;
+
+ if (!SymbolHelpers.IsServiceProvider(method.Parameters[0]))
+ continue;
+
+ // validate second parameter is object? (for keyed services)
+ if (method.Parameters.Length == 2
+ && method.Parameters[1].Type.SpecialType != SpecialType.System_Object)
+ continue;
+
+ // found a valid overload
+ return;
+ }
+
+ context.ReportDiagnostic(Diagnostic.Create(
+ hasStaticOverload
+ ? DiagnosticDescriptors.FactoryMethodInvalidSignature
+ : DiagnosticDescriptors.FactoryMethodNotStatic,
+ location,
+ factoryMethodName,
+ className));
+ }
+
+ private static void ValidateServiceTypes(
+ SymbolAnalysisContext context,
+ INamedTypeSymbol classSymbol,
+ HashSet serviceTypes,
+ Location location)
+ {
+ var implTypeName = classSymbol.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat);
+
+ foreach (var serviceType in serviceTypes)
+ {
+ if (serviceType == implTypeName)
+ continue;
+
+ var implementsService = false;
+
+ foreach (var iface in classSymbol.AllInterfaces)
+ {
+ var ifaceName = iface.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat);
+ if (ifaceName == serviceType)
+ {
+ implementsService = true;
+ break;
+ }
+
+ // also check unbound generic form (e.g. IOpenGeneric<> vs IOpenGeneric)
+ var unboundIface = SymbolHelpers.ToUnboundGenericType(iface);
+ if (!SymbolEqualityComparer.Default.Equals(unboundIface, iface))
+ {
+ var unboundName = unboundIface.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat);
+ if (unboundName == serviceType)
+ {
+ implementsService = true;
+ break;
+ }
+ }
+ }
+
+ if (!implementsService)
+ {
+ var baseType = classSymbol.BaseType;
+ while (baseType is not null)
+ {
+ var baseName = baseType.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat);
+ if (baseName == serviceType)
+ {
+ implementsService = true;
+ break;
+ }
+
+ var unboundBase = SymbolHelpers.ToUnboundGenericType(baseType);
+ if (!SymbolEqualityComparer.Default.Equals(unboundBase, baseType))
+ {
+ var unboundBaseName = unboundBase.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat);
+ if (unboundBaseName == serviceType)
+ {
+ implementsService = true;
+ break;
+ }
+ }
+
+ baseType = baseType.BaseType;
+ }
+ }
+
+ if (!implementsService)
+ {
+ context.ReportDiagnostic(Diagnostic.Create(
+ DiagnosticDescriptors.ServiceTypeMismatch,
+ location,
+ implTypeName,
+ serviceType));
+ }
+ }
+ }
+}
diff --git a/src/Injectio.Generators/ServiceRegistrationGenerator.cs b/src/Injectio.Generators/ServiceRegistrationGenerator.cs
index 0556fcf..d4be12c 100644
--- a/src/Injectio.Generators/ServiceRegistrationGenerator.cs
+++ b/src/Injectio.Generators/ServiceRegistrationGenerator.cs
@@ -14,11 +14,6 @@ namespace Injectio.Generators;
[Generator]
public class ServiceRegistrationGenerator : IIncrementalGenerator
{
- private static readonly SymbolDisplayFormat _fullyQualifiedNullableFormat =
- SymbolDisplayFormat.FullyQualifiedFormat.AddMiscellaneousOptions(
- SymbolDisplayMiscellaneousOptions.IncludeNullableReferenceTypeModifier
- );
-
public void Initialize(IncrementalGeneratorInitializationContext context)
{
// find all classes and methods with attributes
@@ -129,7 +124,7 @@ private static bool SyntacticPredicate(SyntaxNode syntaxNode, CancellationToken
// make sure attribute is for registration
var attributes = methodSymbol.GetAttributes();
- var isKnown = attributes.Any(IsMethodAttribute);
+ var isKnown = attributes.Any(SymbolHelpers.IsMethodAttribute);
if (!isKnown)
return null;
@@ -139,7 +134,7 @@ private static bool SyntacticPredicate(SyntaxNode syntaxNode, CancellationToken
var registration = new ModuleRegistration
(
- ClassName: methodSymbol.ContainingType.ToDisplayString(_fullyQualifiedNullableFormat),
+ ClassName: methodSymbol.ContainingType.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat),
MethodName: methodSymbol.Name,
IsStatic: methodSymbol.IsStatic,
HasTagCollection: hasTagCollection
@@ -184,7 +179,7 @@ private static (bool isValid, bool hasTagCollection) ValidateMethod(IMethodSymbo
if (methodSymbol.Parameters.Length is 1 or 2)
{
var parameterSymbol = methodSymbol.Parameters[0];
- hasServiceCollection = IsServiceCollection(parameterSymbol);
+ hasServiceCollection = SymbolHelpers.IsServiceCollection(parameterSymbol);
}
if (methodSymbol.Parameters.Length is 1)
@@ -194,9 +189,9 @@ private static (bool isValid, bool hasTagCollection) ValidateMethod(IMethodSymbo
if (methodSymbol.Parameters.Length is 2)
{
var parameterSymbol = methodSymbol.Parameters[1];
- hasTagCollection = IsStringCollection(parameterSymbol);
+ hasTagCollection = SymbolHelpers.IsStringCollection(parameterSymbol);
- // to be valid, parameter 0 must be service collection and parameter 1 must be string collection,
+ // to be valid, parameter 0 must be service collection and parameter 1 must be string collection,
return (hasServiceCollection && hasTagCollection, hasTagCollection);
}
@@ -207,7 +202,7 @@ private static (bool isValid, bool hasTagCollection) ValidateMethod(IMethodSymbo
private static ServiceRegistration? CreateServiceRegistration(INamedTypeSymbol classSymbol, AttributeData attribute)
{
// check for known attribute
- if (!IsKnownAttribute(attribute, out var serviceLifetime))
+ if (!SymbolHelpers.IsKnownAttribute(attribute, out var serviceLifetime))
return null;
// defaults
@@ -231,12 +226,12 @@ private static (bool isValid, bool hasTagCollection) ValidateMethod(IMethodSymbo
if (typeParameter.Name == "TService" || index == 0)
{
- var service = typeArgument.ToDisplayString(_fullyQualifiedNullableFormat);
+ var service = typeArgument.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat);
serviceTypes.Add(service);
}
else if (typeParameter.Name == "TImplementation" || index == 1)
{
- implementationType = typeArgument.ToDisplayString(_fullyQualifiedNullableFormat);
+ implementationType = typeArgument.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat);
}
}
}
@@ -256,7 +251,7 @@ private static (bool isValid, bool hasTagCollection) ValidateMethod(IMethodSymbo
var serviceTypeSymbol = value as INamedTypeSymbol;
isOpenGeneric = isOpenGeneric || IsOpenGeneric(serviceTypeSymbol);
- var serviceType = serviceTypeSymbol?.ToDisplayString(_fullyQualifiedNullableFormat) ?? value.ToString();
+ var serviceType = serviceTypeSymbol?.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat) ?? value.ToString();
serviceTypes.Add(serviceType);
break;
case "ServiceKey":
@@ -266,7 +261,7 @@ private static (bool isValid, bool hasTagCollection) ValidateMethod(IMethodSymbo
var implementationTypeSymbol = value as INamedTypeSymbol;
isOpenGeneric = isOpenGeneric || IsOpenGeneric(implementationTypeSymbol);
- implementationType = implementationTypeSymbol?.ToDisplayString(_fullyQualifiedNullableFormat) ?? value.ToString();
+ implementationType = implementationTypeSymbol?.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat) ?? value.ToString();
break;
case "Factory":
implementationFactory = value.ToString();
@@ -275,7 +270,7 @@ private static (bool isValid, bool hasTagCollection) ValidateMethod(IMethodSymbo
duplicateStrategy = ResolveDuplicateStrategy(value);
break;
case "Registration":
- registrationStrategy = ResolveRegistrationStrategy(value);
+ registrationStrategy = SymbolHelpers.ResolveRegistrationStrategy(value);
break;
case "Tags":
var tagsItems = value
@@ -304,9 +299,9 @@ private static (bool isValid, bool hasTagCollection) ValidateMethod(IMethodSymbo
// no implementation type set, use class attribute is on
if (implementationType.IsNullOrWhiteSpace())
{
- var unboundType = ToUnboundGenericType(classSymbol);
+ var unboundType = SymbolHelpers.ToUnboundGenericType(classSymbol);
isOpenGeneric = isOpenGeneric || IsOpenGeneric(unboundType);
- implementationType = unboundType.ToDisplayString(_fullyQualifiedNullableFormat);
+ implementationType = unboundType.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat);
}
// add implemented interfaces
@@ -321,10 +316,10 @@ or KnownTypes.RegistrationStrategySelfWithInterfacesShortName
if (implementedInterface.ConstructedFrom.ToString() == "System.IEquatable")
continue;
- var unboundInterface = ToUnboundGenericType(implementedInterface);
+ var unboundInterface = SymbolHelpers.ToUnboundGenericType(implementedInterface);
isOpenGeneric = isOpenGeneric || IsOpenGeneric(unboundInterface);
- var interfaceName = unboundInterface.ToDisplayString(_fullyQualifiedNullableFormat);
+ var interfaceName = unboundInterface.ToDisplayString(SymbolHelpers.FullyQualifiedNullableFormat);
serviceTypes.Add(interfaceName);
}
}
@@ -354,136 +349,6 @@ or KnownTypes.RegistrationStrategySelfWithInterfacesShortName
IsOpenGeneric: isOpenGeneric);
}
- private static INamedTypeSymbol ToUnboundGenericType(INamedTypeSymbol typeSymbol)
- {
- if (!typeSymbol.IsGenericType || typeSymbol.IsUnboundGenericType)
- return typeSymbol;
-
- foreach (var typeArgument in typeSymbol.TypeArguments)
- {
- // If TypeKind is TypeParameter, it's actually the name of a locally declared type-parameter -> placeholder
- if (typeArgument.TypeKind != TypeKind.TypeParameter)
- return typeSymbol;
- }
-
- return typeSymbol.ConstructUnboundGenericType();
- }
-
- private static bool IsKnownAttribute(AttributeData attribute, out string serviceLifetime)
- {
- if (IsSingletonAttribute(attribute))
- {
- serviceLifetime = KnownTypes.ServiceLifetimeSingletonFullName;
- return true;
- }
-
- if (IsScopedAttribute(attribute))
- {
- serviceLifetime = KnownTypes.ServiceLifetimeScopedFullName;
- return true;
- }
-
- if (IsTransientAttribute(attribute))
- {
- serviceLifetime = KnownTypes.ServiceLifetimeTransientFullName;
- return true;
- }
-
- serviceLifetime = KnownTypes.ServiceLifetimeTransientFullName;
- return false;
- }
-
- private static bool IsTransientAttribute(AttributeData attribute)
- {
- return attribute?.AttributeClass is
- {
- Name: KnownTypes.TransientAttributeShortName or KnownTypes.TransientAttributeTypeName,
- ContainingNamespace:
- {
- Name: "Attributes",
- ContainingNamespace.Name: "Injectio"
- }
- };
- }
-
- private static bool IsSingletonAttribute(AttributeData attribute)
- {
- return attribute?.AttributeClass is
- {
- Name: KnownTypes.SingletonAttributeShortName or KnownTypes.SingletonAttributeTypeName,
- ContainingNamespace:
- {
- Name: "Attributes",
- ContainingNamespace.Name: "Injectio"
- }
- };
- }
-
- private static bool IsScopedAttribute(AttributeData attribute)
- {
- return attribute?.AttributeClass is
- {
- Name: KnownTypes.ScopedAttributeShortName or KnownTypes.ScopedAttributeTypeName,
- ContainingNamespace:
- {
- Name: "Attributes",
- ContainingNamespace.Name: "Injectio"
- }
- };
- }
-
- private static bool IsMethodAttribute(AttributeData attribute)
- {
- return attribute?.AttributeClass is
- {
- Name: KnownTypes.ModuleAttributeShortName or KnownTypes.ModuleAttributeTypeName,
- ContainingNamespace:
- {
- Name: "Attributes",
- ContainingNamespace.Name: "Injectio"
- }
- };
- }
-
- private static bool IsServiceCollection(IParameterSymbol parameterSymbol)
- {
- return parameterSymbol?.Type is
- {
- Name: "IServiceCollection" or "ServiceCollection",
- ContainingNamespace:
- {
- Name: "DependencyInjection",
- ContainingNamespace:
- {
- Name: "Extensions",
- ContainingNamespace.Name: "Microsoft"
- }
- }
- };
- }
-
- private static bool IsStringCollection(IParameterSymbol parameterSymbol)
- {
- var type = parameterSymbol?.Type as INamedTypeSymbol;
-
- return type is
- {
- Name: "IEnumerable" or "IReadOnlySet" or "IReadOnlyCollection" or "ICollection" or "ISet" or "HashSet",
- IsGenericType: true,
- TypeArguments.Length: 1,
- TypeParameters.Length: 1,
- ContainingNamespace:
- {
- Name: "Generic",
- ContainingNamespace:
- {
- Name: "Collections",
- ContainingNamespace.Name: "System"
- }
- }
- };
- }
-
private static bool IsOpenGeneric(INamedTypeSymbol? typeSymbol)
{
if (typeSymbol is null)
@@ -510,21 +375,4 @@ private static string ResolveDuplicateStrategy(object? value)
_ => KnownTypes.DuplicateStrategySkipShortName
};
}
-
- private static string ResolveRegistrationStrategy(object? value)
- {
- return value switch
- {
- int v => v switch
- {
- KnownTypes.RegistrationStrategySelfValue => KnownTypes.RegistrationStrategySelfShortName,
- KnownTypes.RegistrationStrategyImplementedInterfacesValue => KnownTypes.RegistrationStrategyImplementedInterfacesShortName,
- KnownTypes.RegistrationStrategySelfWithInterfacesValue => KnownTypes.RegistrationStrategySelfWithInterfacesShortName,
- KnownTypes.RegistrationStrategySelfWithProxyFactoryValue => KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName,
- _ => KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName
- },
- string text => text,
- _ => KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName
- };
- }
}
diff --git a/src/Injectio.Generators/SymbolHelpers.cs b/src/Injectio.Generators/SymbolHelpers.cs
new file mode 100644
index 0000000..54cc867
--- /dev/null
+++ b/src/Injectio.Generators/SymbolHelpers.cs
@@ -0,0 +1,177 @@
+using Microsoft.CodeAnalysis;
+
+namespace Injectio.Generators;
+
+internal static class SymbolHelpers
+{
+ public static readonly SymbolDisplayFormat FullyQualifiedNullableFormat =
+ SymbolDisplayFormat.FullyQualifiedFormat.AddMiscellaneousOptions(
+ SymbolDisplayMiscellaneousOptions.IncludeNullableReferenceTypeModifier
+ );
+
+ public static bool IsMethodAttribute(AttributeData attribute)
+ {
+ return attribute?.AttributeClass is
+ {
+ Name: KnownTypes.ModuleAttributeShortName or KnownTypes.ModuleAttributeTypeName,
+ ContainingNamespace:
+ {
+ Name: "Attributes",
+ ContainingNamespace.Name: "Injectio"
+ }
+ };
+ }
+
+ public static bool IsTransientAttribute(AttributeData attribute)
+ {
+ return attribute?.AttributeClass is
+ {
+ Name: KnownTypes.TransientAttributeShortName or KnownTypes.TransientAttributeTypeName,
+ ContainingNamespace:
+ {
+ Name: "Attributes",
+ ContainingNamespace.Name: "Injectio"
+ }
+ };
+ }
+
+ public static bool IsSingletonAttribute(AttributeData attribute)
+ {
+ return attribute?.AttributeClass is
+ {
+ Name: KnownTypes.SingletonAttributeShortName or KnownTypes.SingletonAttributeTypeName,
+ ContainingNamespace:
+ {
+ Name: "Attributes",
+ ContainingNamespace.Name: "Injectio"
+ }
+ };
+ }
+
+ public static bool IsScopedAttribute(AttributeData attribute)
+ {
+ return attribute?.AttributeClass is
+ {
+ Name: KnownTypes.ScopedAttributeShortName or KnownTypes.ScopedAttributeTypeName,
+ ContainingNamespace:
+ {
+ Name: "Attributes",
+ ContainingNamespace.Name: "Injectio"
+ }
+ };
+ }
+
+ public static bool IsKnownAttribute(AttributeData attribute, out string serviceLifetime)
+ {
+ if (IsSingletonAttribute(attribute))
+ {
+ serviceLifetime = KnownTypes.ServiceLifetimeSingletonFullName;
+ return true;
+ }
+
+ if (IsScopedAttribute(attribute))
+ {
+ serviceLifetime = KnownTypes.ServiceLifetimeScopedFullName;
+ return true;
+ }
+
+ if (IsTransientAttribute(attribute))
+ {
+ serviceLifetime = KnownTypes.ServiceLifetimeTransientFullName;
+ return true;
+ }
+
+ serviceLifetime = KnownTypes.ServiceLifetimeTransientFullName;
+ return false;
+ }
+
+ public static bool IsServiceCollection(IParameterSymbol parameterSymbol)
+ {
+ return parameterSymbol?.Type is
+ {
+ Name: "IServiceCollection" or "ServiceCollection",
+ ContainingNamespace:
+ {
+ Name: "DependencyInjection",
+ ContainingNamespace:
+ {
+ Name: "Extensions",
+ ContainingNamespace.Name: "Microsoft"
+ }
+ }
+ };
+ }
+
+ public static bool IsStringCollection(IParameterSymbol parameterSymbol)
+ {
+ var type = parameterSymbol?.Type as INamedTypeSymbol;
+
+ if (type is not
+ {
+ Name: "IEnumerable" or "IReadOnlySet" or "IReadOnlyCollection" or "ICollection" or "ISet" or "HashSet",
+ IsGenericType: true,
+ TypeArguments.Length: 1,
+ TypeParameters.Length: 1,
+ ContainingNamespace:
+ {
+ Name: "Generic",
+ ContainingNamespace:
+ {
+ Name: "Collections",
+ ContainingNamespace.Name: "System"
+ }
+ }
+ })
+ {
+ return false;
+ }
+
+ // verify the generic argument is string
+ return type.TypeArguments[0].SpecialType == SpecialType.System_String;
+ }
+
+ public static bool IsServiceProvider(IParameterSymbol parameterSymbol)
+ {
+ return parameterSymbol?.Type is
+ {
+ Name: "IServiceProvider",
+ ContainingNamespace:
+ {
+ Name: "System",
+ ContainingNamespace.IsGlobalNamespace: true
+ }
+ };
+ }
+
+ public static INamedTypeSymbol ToUnboundGenericType(INamedTypeSymbol typeSymbol)
+ {
+ if (!typeSymbol.IsGenericType || typeSymbol.IsUnboundGenericType)
+ return typeSymbol;
+
+ foreach (var typeArgument in typeSymbol.TypeArguments)
+ {
+ // If TypeKind is TypeParameter, it's actually the name of a locally declared type-parameter -> placeholder
+ if (typeArgument.TypeKind != TypeKind.TypeParameter)
+ return typeSymbol;
+ }
+
+ return typeSymbol.ConstructUnboundGenericType();
+ }
+
+ public static string ResolveRegistrationStrategy(object? value)
+ {
+ return value switch
+ {
+ int v => v switch
+ {
+ KnownTypes.RegistrationStrategySelfValue => KnownTypes.RegistrationStrategySelfShortName,
+ KnownTypes.RegistrationStrategyImplementedInterfacesValue => KnownTypes.RegistrationStrategyImplementedInterfacesShortName,
+ KnownTypes.RegistrationStrategySelfWithInterfacesValue => KnownTypes.RegistrationStrategySelfWithInterfacesShortName,
+ KnownTypes.RegistrationStrategySelfWithProxyFactoryValue => KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName,
+ _ => KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName
+ },
+ string text => text,
+ _ => KnownTypes.RegistrationStrategySelfWithProxyFactoryShortName
+ };
+ }
+}
diff --git a/tests/Injectio.Acceptance.Tests/DependencyInjectionBase.cs b/tests/Injectio.Acceptance.Tests/DependencyInjectionBase.cs
index 882030c..bcbba2f 100644
--- a/tests/Injectio.Acceptance.Tests/DependencyInjectionBase.cs
+++ b/tests/Injectio.Acceptance.Tests/DependencyInjectionBase.cs
@@ -1,11 +1,9 @@
-using Xunit.Abstractions;
-
using XUnit.Hosting;
namespace Injectio.Acceptance.Tests;
[Collection(DependencyInjectionCollection.CollectionName)]
-public abstract class DependencyInjectionBase(ITestOutputHelper output, DependencyInjectionFixture fixture)
- : TestHostBase(output, fixture)
+public abstract class DependencyInjectionBase(DependencyInjectionFixture fixture)
+ : TestHostBase(fixture)
{
}
diff --git a/tests/Injectio.Acceptance.Tests/Injectio.Acceptance.Tests.csproj b/tests/Injectio.Acceptance.Tests/Injectio.Acceptance.Tests.csproj
index 8eb37db..2cdb10b 100644
--- a/tests/Injectio.Acceptance.Tests/Injectio.Acceptance.Tests.csproj
+++ b/tests/Injectio.Acceptance.Tests/Injectio.Acceptance.Tests.csproj
@@ -18,12 +18,8 @@
runtime; build; native; contentfiles; analyzers; buildtransitive
-
+
-
- all
- runtime; build; native; contentfiles; analyzers; buildtransitive
-
diff --git a/tests/Injectio.Acceptance.Tests/LibraryServiceTests.cs b/tests/Injectio.Acceptance.Tests/LibraryServiceTests.cs
index 1efb62e..4b7f6e0 100644
--- a/tests/Injectio.Acceptance.Tests/LibraryServiceTests.cs
+++ b/tests/Injectio.Acceptance.Tests/LibraryServiceTests.cs
@@ -4,12 +4,10 @@
using Microsoft.Extensions.DependencyInjection;
-using Xunit.Abstractions;
-
namespace Injectio.Acceptance.Tests;
[Collection(DependencyInjectionCollection.CollectionName)]
-public class LibraryServiceTests(ITestOutputHelper output, DependencyInjectionFixture fixture) : DependencyInjectionBase(output, fixture)
+public class LibraryServiceTests(DependencyInjectionFixture fixture) : DependencyInjectionBase(fixture)
{
[Fact]
public void ShouldResolveService()
diff --git a/tests/Injectio.Acceptance.Tests/LocalServiceTests.cs b/tests/Injectio.Acceptance.Tests/LocalServiceTests.cs
index 76dfe99..2fbf1f9 100644
--- a/tests/Injectio.Acceptance.Tests/LocalServiceTests.cs
+++ b/tests/Injectio.Acceptance.Tests/LocalServiceTests.cs
@@ -4,12 +4,10 @@
using Microsoft.Extensions.DependencyInjection;
-using Xunit.Abstractions;
-
namespace Injectio.Acceptance.Tests;
[Collection(DependencyInjectionCollection.CollectionName)]
-public class LocalServiceTests(ITestOutputHelper output, DependencyInjectionFixture fixture) : DependencyInjectionBase(output, fixture)
+public class LocalServiceTests(DependencyInjectionFixture fixture) : DependencyInjectionBase(fixture)
{
[Fact]
public void ShouldResolveLocalService()
diff --git a/tests/Injectio.Tests/Injectio.Tests.csproj b/tests/Injectio.Tests/Injectio.Tests.csproj
index 6bea383..236d23b 100644
--- a/tests/Injectio.Tests/Injectio.Tests.csproj
+++ b/tests/Injectio.Tests/Injectio.Tests.csproj
@@ -14,12 +14,8 @@
-
-
-
- runtime; build; native; contentfiles; analyzers; buildtransitive
- all
-
+
+
runtime; build; native; contentfiles; analyzers; buildtransitive
all
diff --git a/tests/Injectio.Tests/ServiceRegistrationDiagnosticTests.cs b/tests/Injectio.Tests/ServiceRegistrationDiagnosticTests.cs
new file mode 100644
index 0000000..0aa4b15
--- /dev/null
+++ b/tests/Injectio.Tests/ServiceRegistrationDiagnosticTests.cs
@@ -0,0 +1,424 @@
+using System;
+using System.Collections.Immutable;
+using System.Linq;
+using System.Threading.Tasks;
+
+using AwesomeAssertions;
+
+using Injectio.Attributes;
+using Injectio.Generators;
+
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp;
+using Microsoft.CodeAnalysis.Diagnostics;
+using Microsoft.Extensions.DependencyInjection;
+
+using Xunit;
+
+namespace Injectio.Tests;
+
+public class ServiceRegistrationDiagnosticTests
+{
+ [Fact]
+ public async Task DiagnoseRegisterServicesInvalidFirstParameter()
+ {
+ var source = @"
+using Injectio.Attributes;
+
+namespace Injectio.Sample;
+
+public static class RegistrationModule
+{
+ [RegisterServices]
+ public static void Register(string test)
+ {
+ }
+}
+";
+
+ var diagnostics = await GetDiagnosticsAsync(source);
+
+ diagnostics.Should().ContainSingle(d => d.Id == "INJECT0001");
+ }
+
+ [Fact]
+ public async Task DiagnoseRegisterServicesInvalidSecondParameter()
+ {
+ var source = @"
+using Injectio.Attributes;
+using Microsoft.Extensions.DependencyInjection;
+
+namespace Injectio.Sample;
+
+public static class RegistrationModule
+{
+ [RegisterServices]
+ public static void Register(IServiceCollection services, string test)
+ {
+ }
+}
+";
+
+ var diagnostics = await GetDiagnosticsAsync(source);
+
+ diagnostics.Should().ContainSingle(d => d.Id == "INJECT0002");
+ }
+
+ [Fact]
+ public async Task DiagnoseRegisterServicesTooManyParameters()
+ {
+ var source = @"
+using Injectio.Attributes;
+using Microsoft.Extensions.DependencyInjection;
+
+namespace Injectio.Sample;
+
+public static class RegistrationModule
+{
+ [RegisterServices]
+ public static void Register(IServiceCollection services, string a, string b)
+ {
+ }
+}
+";
+
+ var diagnostics = await GetDiagnosticsAsync(source);
+
+ diagnostics.Should().ContainSingle(d => d.Id == "INJECT0003");
+ }
+
+ [Fact]
+ public async Task DiagnoseRegisterServicesNoParameters()
+ {
+ var source = @"
+using Injectio.Attributes;
+
+namespace Injectio.Sample;
+
+public static class RegistrationModule
+{
+ [RegisterServices]
+ public static void Register()
+ {
+ }
+}
+";
+
+ var diagnostics = await GetDiagnosticsAsync(source);
+
+ diagnostics.Should().ContainSingle(d => d.Id == "INJECT0001");
+ }
+
+ [Fact]
+ public async Task DiagnoseFactoryMethodNotFound()
+ {
+ var source = @"
+using Injectio.Attributes;
+
+namespace Injectio.Sample;
+
+public interface IService { }
+
+[RegisterTransient(ServiceType = typeof(IService), Factory = ""NonExistentMethod"")]
+public class MyService : IService
+{
+}
+";
+
+ var diagnostics = await GetDiagnosticsAsync(source);
+
+ diagnostics.Should().ContainSingle(d => d.Id == "INJECT0004");
+ }
+
+ [Fact]
+ public async Task DiagnoseFactoryMethodNotStatic()
+ {
+ var source = @"
+using System;
+using Injectio.Attributes;
+
+namespace Injectio.Sample;
+
+public interface IService { }
+
+[RegisterTransient(ServiceType = typeof(IService), Factory = nameof(ServiceFactory))]
+public class MyService : IService
+{
+ public IService ServiceFactory(IServiceProvider serviceProvider)
+ {
+ return new MyService();
+ }
+}
+";
+
+ var diagnostics = await GetDiagnosticsAsync(source);
+
+ diagnostics.Should().ContainSingle(d => d.Id == "INJECT0005");
+ }
+
+ [Fact]
+ public async Task DiagnoseFactoryMethodInvalidSignature()
+ {
+ var source = @"
+using Injectio.Attributes;
+
+namespace Injectio.Sample;
+
+public interface IService { }
+
+[RegisterTransient(ServiceType = typeof(IService), Factory = nameof(ServiceFactory))]
+public class MyService : IService
+{
+ public static IService ServiceFactory(string notServiceProvider)
+ {
+ return new MyService();
+ }
+}
+";
+
+ var diagnostics = await GetDiagnosticsAsync(source);
+
+ diagnostics.Should().ContainSingle(d => d.Id == "INJECT0006");
+ }
+
+ [Fact]
+ public async Task DiagnoseServiceTypeMismatch()
+ {
+ var source = @"
+using Injectio.Attributes;
+
+namespace Injectio.Sample;
+
+public interface IService { }
+public interface IOtherService { }
+
+[RegisterTransient(ServiceType = typeof(IOtherService))]
+public class MyService : IService
+{
+}
+";
+
+ var diagnostics = await GetDiagnosticsAsync(source);
+
+ diagnostics.Should().ContainSingle(d => d.Id == "INJECT0007");
+ }
+
+ [Fact]
+ public async Task DiagnoseRegisterServicesOnAbstractClassNonStaticMethod()
+ {
+ var source = @"
+using Injectio.Attributes;
+using Microsoft.Extensions.DependencyInjection;
+
+namespace Injectio.Sample;
+
+public abstract class RegistrationModule
+{
+ [RegisterServices]
+ public void Register(IServiceCollection services)
+ {
+ }
+}
+";
+
+ var diagnostics = await GetDiagnosticsAsync(source);
+
+ diagnostics.Should().ContainSingle(d => d.Id == "INJECT0009");
+ }
+
+ [Fact]
+ public async Task DiagnoseAbstractImplementationTypeWithoutFactory()
+ {
+ var source = @"
+using Injectio.Attributes;
+
+namespace Injectio.Sample;
+
+public interface IService { }
+
+[RegisterSingleton]
+public abstract class AbstractService : IService
+{
+}
+";
+
+ var diagnostics = await GetDiagnosticsAsync(source);
+
+ diagnostics.Should().ContainSingle(d => d.Id == "INJECT0008");
+ }
+
+ [Fact]
+ public async Task NoDiagnosticsForAbstractImplementationTypeWithFactory()
+ {
+ var source = @"
+using System;
+using Injectio.Attributes;
+
+namespace Injectio.Sample;
+
+public interface IService { }
+
+[RegisterSingleton(ServiceType = typeof(IService), Factory = nameof(Create))]
+public abstract class AbstractService : IService
+{
+ public static IService Create(IServiceProvider serviceProvider)
+ {
+ return null!;
+ }
+}
+";
+
+ var diagnostics = await GetDiagnosticsAsync(source);
+
+ diagnostics.Should().BeEmpty();
+ }
+
+ [Fact]
+ public async Task NoDiagnosticsForValidRegistration()
+ {
+ var source = @"
+using Injectio.Attributes;
+
+namespace Injectio.Sample;
+
+public interface IService { }
+
+[RegisterSingleton]
+public class MyService : IService
+{
+}
+";
+
+ var diagnostics = await GetDiagnosticsAsync(source);
+
+ diagnostics.Should().BeEmpty();
+ }
+
+ [Fact]
+ public async Task NoDiagnosticsForValidFactory()
+ {
+ var source = @"
+using System;
+using Injectio.Attributes;
+
+namespace Injectio.Sample;
+
+public interface IService { }
+
+[RegisterTransient(ServiceType = typeof(IService), Factory = nameof(ServiceFactory))]
+public class MyService : IService
+{
+ public static IService ServiceFactory(IServiceProvider serviceProvider)
+ {
+ return new MyService();
+ }
+}
+";
+
+ var diagnostics = await GetDiagnosticsAsync(source);
+
+ diagnostics.Should().BeEmpty();
+ }
+
+ [Fact]
+ public async Task NoDiagnosticsForValidKeyedFactory()
+ {
+ var source = @"
+using System;
+using Injectio.Attributes;
+
+namespace Injectio.Sample;
+
+public interface IService { }
+
+[RegisterTransient(ServiceType = typeof(IService), ServiceKey = ""key"", Factory = nameof(ServiceFactory))]
+public class MyService : IService
+{
+ public static IService ServiceFactory(IServiceProvider serviceProvider, object? serviceKey)
+ {
+ return new MyService();
+ }
+}
+";
+
+ var diagnostics = await GetDiagnosticsAsync(source);
+
+ diagnostics.Should().BeEmpty();
+ }
+
+ [Fact]
+ public async Task NoDiagnosticsForValidRegisterServicesMethod()
+ {
+ var source = @"
+using Injectio.Attributes;
+using Microsoft.Extensions.DependencyInjection;
+
+namespace Injectio.Sample;
+
+public static class RegistrationModule
+{
+ [RegisterServices]
+ public static void Register(IServiceCollection services)
+ {
+ }
+}
+";
+
+ var diagnostics = await GetDiagnosticsAsync(source);
+
+ diagnostics.Should().BeEmpty();
+ }
+
+ [Fact]
+ public async Task NoDiagnosticsForValidRegisterServicesWithTags()
+ {
+ var source = @"
+using System.Collections.Generic;
+using Injectio.Attributes;
+using Microsoft.Extensions.DependencyInjection;
+
+namespace Injectio.Sample;
+
+public static class RegistrationModule
+{
+ [RegisterServices]
+ public static void Register(IServiceCollection services, IEnumerable tags)
+ {
+ }
+}
+";
+
+ var diagnostics = await GetDiagnosticsAsync(source);
+
+ diagnostics.Should().BeEmpty();
+ }
+
+ private static async Task> GetDiagnosticsAsync(string source)
+ {
+ var syntaxTree = CSharpSyntaxTree.ParseText(source);
+ var references = AppDomain.CurrentDomain.GetAssemblies()
+ .Where(assembly => !assembly.IsDynamic && !string.IsNullOrWhiteSpace(assembly.Location))
+ .Select(assembly => MetadataReference.CreateFromFile(assembly.Location))
+ .Concat(new[]
+ {
+ MetadataReference.CreateFromFile(typeof(ServiceRegistrationGenerator).Assembly.Location),
+ MetadataReference.CreateFromFile(typeof(RegisterServicesAttribute).Assembly.Location),
+ MetadataReference.CreateFromFile(typeof(IServiceCollection).Assembly.Location),
+ });
+
+ var compilation = CSharpCompilation.Create(
+ "Test.Diagnostics",
+ new[] { syntaxTree },
+ references,
+ new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));
+
+ var analyzer = new ServiceRegistrationAnalyzer();
+ var compilationWithAnalyzers = compilation.WithAnalyzers(ImmutableArray.Create(analyzer));
+ var diagnostics = await compilationWithAnalyzers.GetAnalyzerDiagnosticsAsync();
+
+ // return only Injectio diagnostics
+ return diagnostics
+ .Where(d => d.Id.StartsWith("INJECT"))
+ .ToImmutableArray();
+ }
+}