Following Fildor's advice, I decided to implement a source generator to solve this issue. I've never implemented one before, so I'm open to feedback on how to improve this code. Specifically, if anyone has a working example of how to ensure that the `ImmutableArray<>` type is indeed the `System.Collections.Immutable` variant, I'd be interested.
This code should account for edge cases such as the record being nested within interfaces/classes/records, the record being generic, the ImmutableArrays being nullable, custom Equals/GetHashCode methods already existing, etc.
I decided against locating the records via an attribute, as I want this to be the default behaviour for all of my libraries. I also have the partial check disabled in my codebase, as I want to be notified when I've created a record that has an ImmutableArray property/field and no override of the equals check, as I feel the lack of sequence equality is 'surprising' behaviour despite being the .NET default.
I've also added a custom `Key` attribute to decorate properties/fields when a subset of them uniquely identifies the record. This makes `GetHashCode` more performant, because the remaining fields are excluded from the hash.
Here is the code for the source generator:
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Text;
using System.Threading;
namespace SourceGenerator;
/// <summary>
/// Incremental source generator that gives partial records holding <c>ImmutableArray&lt;T&gt;</c> members
/// element-wise (sequence) equality. It emits <c>Equals(R?)</c> and/or <c>GetHashCode()</c> unless the
/// record declaration already provides them.
/// </summary>
/// <remarks>
/// Detection is purely syntactic. A member counts as an immutable array when its declared type is spelled
/// <c>ImmutableArray&lt;T&gt;</c>, optionally namespace-qualified and/or nullable. The namespace the name
/// binds to is not verified; that would require the semantic model.
/// Only the declaration being processed is inspected. Members declared in other partial declarations of the
/// same record, and members inherited from a base record, are therefore not compared. Each qualifying partial
/// declaration is processed on its own.
/// </remarks>
[Generator(LanguageNames.CSharp)]
public sealed class RecordEqualityGenerator : IIncrementalGenerator {
    /// <summary>A piece of record state that participates in the generated equality.</summary>
    private sealed record Field {
        internal Field(string name, bool isImmutableArray, bool isNullable, bool isKey) {
            Name = name;
            IsImmutableArray = isImmutableArray;
            IsNullable = isNullable;
            IsKey = isKey;
        }

        /// <summary>Identifier as written in source, including a leading <c>@</c> if present.</summary>
        internal string Name { get; }

        /// <summary>Whether the declared type is <c>ImmutableArray&lt;T&gt;</c> or <c>ImmutableArray&lt;T&gt;?</c>.</summary>
        internal bool IsImmutableArray { get; }

        /// <summary>Whether the declared immutable-array type is written with <c>?</c>. Always false for non-arrays.</summary>
        internal bool IsNullable { get; }

        /// <summary>Whether the member carries the <c>Key</c> attribute.</summary>
        internal bool IsKey { get; }
    }

    /// <inheritdoc/>
    public void Initialize(IncrementalGeneratorInitializationContext context) {
        var provider = context.SyntaxProvider.CreateSyntaxProvider(Predicate, Transform);
        context.RegisterSourceOutput(provider, Execute);
    }

    /// <summary>
    /// Syntactic filter. It selects partial record classes (nested only inside partial types) that hold at
    /// least one immutable-array member and lack a custom <c>Equals(R?)</c> or <c>GetHashCode()</c>.
    /// </summary>
    private static bool Predicate(SyntaxNode node, CancellationToken _) {
        if (node is not RecordDeclarationSyntax record)
            return false;
        // `record struct` equality has a different shape (non-nullable parameter, no EqualityContract,
        // no reference identity), so the code emitted by Execute would not compile for it.
        if (record.ClassOrStructKeyword.Text == "struct")
            return false;
        // The output re-declares the record and every containing type as partial declarations, so all of
        // them must be partial; otherwise the output clashes with the original declarations.
        if (!IsPartial(record) || !record.Ancestors().OfType<TypeDeclarationSyntax>().All(IsPartial))
            return false;
        if (!GetFields(record).Any(f => f.IsImmutableArray))
            return false;
        var (hasEquals, hasGetHashCode) = HasEqualsAndGetHashCodeMethods(record);
        return !hasEquals || !hasGetHashCode;
    }

    private static RecordDeclarationSyntax Transform(GeneratorSyntaxContext ctx, CancellationToken _) => (RecordDeclarationSyntax)ctx.Node;

    /// <summary>
    /// Emits one source file that re-opens the record (inside copies of its namespace and containing types)
    /// and adds whichever of <c>Equals(R?)</c> / <c>GetHashCode()</c> the declaration does not already define.
    /// </summary>
    private static void Execute(SourceProductionContext context, RecordDeclarationSyntax record) {
        const string Indent = "    ";
        var (hasEquals, hasGetHashCode) = HasEqualsAndGetHashCodeMethods(record);
        var fields = GetFields(record);
        var scopes = GetScopes(record);

        // Generated files are nullable-oblivious unless they opt in explicitly; without the directive the
        // `R?` parameter below produces CS8669.
        var code = new StringBuilder()
            .AppendLine("// <auto-generated/>")
            .AppendLine("#nullable enable")
            .AppendLine()
            .AppendLine("using SourceGenerator;")
            .AppendLine("using System.Linq;")
            .AppendLine();

        var indent = string.Empty;
        foreach (var scope in scopes) {
            _ = code.AppendLine($"{indent}{scope}");
            if (scope.Last() == '{')
                indent += Indent;
        }

        if (!hasEquals) {
            AppendEquals(code, indent, record, fields);
            if (!hasGetHashCode)
                _ = code.AppendLine();
        }
        if (!hasGetHashCode)
            AppendGetHashCode(code, indent, fields);

        // Close every brace-delimited scope; a file-scoped namespace has no closing brace.
        foreach (var scope in scopes) {
            if (scope.Last() != '{')
                continue;
            indent = indent.Substring(0, indent.Length - Indent.Length);
            _ = code.AppendLine($"{indent}}}");
        }

        context.AddSource(GetFilename(record), code.ToString());
    }

    /// <summary>
    /// Appends a strongly typed <c>Equals(R?)</c> that mirrors compiler-synthesized record equality
    /// (matching <c>EqualityContract</c>, <c>EqualityComparer&lt;T&gt;.Default</c> per member), except that
    /// immutable arrays are compared element by element. All state is accessed through <c>this.</c> and
    /// <c>other.</c>, so member names never collide with generated locals.
    /// </summary>
    private static void AppendEquals(StringBuilder code, string indent, RecordDeclarationSyntax record, ImmutableArray<Field> fields) {
        // CS8872: in a record that is not sealed, Equals(R?) must allow overriding.
        // NOTE(review): only this declaration's modifiers are visible; `sealed` on another part is missed.
        var virtualModifier = record.Modifiers.Any(m => m.Text == "sealed") ? string.Empty : "virtual ";
        // Cheap scalar comparisons come before the element-wise array comparisons.
        var comparisons = new[] { "EqualityContract == other.EqualityContract" }
            .Concat(fields.Where(f => !f.IsImmutableArray).Select(f => $"__FieldEquals(this.{f.Name}, other.{f.Name})"))
            .Concat(fields.Where(f => f.IsImmutableArray).Select(GetSequenceEqualityText));

        _ = code
            .AppendLine($"{indent}public {virtualModifier}bool Equals({GetNameWithGenericParameters(record)}? other) {{")
            .AppendLine($"{indent}    if (other is null)")
            .AppendLine($"{indent}        return false;")
            .AppendLine($"{indent}    if (global::System.Object.ReferenceEquals(this, other))")
            .AppendLine($"{indent}        return true;")
            .AppendLine()
            .AppendLine($"{indent}    return {string.Join(" && ", comparisons)};");
        if (fields.Any(f => !f.IsImmutableArray)) {
            // A generic local function lets the compiler infer each member's type, so the comparer works for
            // unconstrained type parameters and types without an `==` operator, without boxing.
            _ = code
                .AppendLine()
                .AppendLine($"{indent}    static bool __FieldEquals<T>(T left, T right) => global::System.Collections.Generic.EqualityComparer<T>.Default.Equals(left, right);");
        }
        _ = code.AppendLine($"{indent}}}");
    }

    /// <summary>
    /// Appends a <c>GetHashCode()</c> override combining the hashed fields (see <see cref="GetHashedFields"/>)
    /// with a multiply-and-add scheme.
    /// </summary>
    private static void AppendGetHashCode(StringBuilder code, string indent, ImmutableArray<Field> fields) {
        var hashed = GetHashedFields(fields);
        _ = code
            .AppendLine($"{indent}public override int GetHashCode() {{")
            .AppendLine($"{indent}    const int multiplier = 92821;")
            .AppendLine($"{indent}    var hash = 17;")
            .AppendLine()
            .AppendLine($"{indent}    unchecked {{");
        foreach (var field in hashed)
            _ = code.AppendLine($"{indent}        hash = hash * multiplier + {GetHashText(field)};");
        _ = code
            .AppendLine($"{indent}    }}")
            .AppendLine()
            .AppendLine($"{indent}    return hash;");
        if (hashed.Any(f => !f.IsImmutableArray)) {
            // Null-safe: a null reference-typed member hashes to 0 instead of throwing.
            _ = code
                .AppendLine()
                .AppendLine($"{indent}    static int __FieldHash<T>(T value) => value is null ? 0 : global::System.Collections.Generic.EqualityComparer<T>.Default.GetHashCode(value);");
        }
        _ = code.AppendLine($"{indent}}}");
    }

    /// <summary>Builds an element-wise equality expression comparing <c>this.X</c> with <c>other.X</c>.</summary>
    private static string GetSequenceEqualityText(Field field) {
        var self = $"this.{field.Name}";
        var other = $"other.{field.Name}";
        if (!field.IsNullable)
            return ArrayEquality(self, other);
        var selfValue = self + ".Value";
        var otherValue = other + ".Value";
        return $"({self}.HasValue ? {other}.HasValue && {ArrayEquality(selfValue, otherValue)} : !{other}.HasValue)";

        // Comparing a default (uninitialized) ImmutableArray throws, so IsDefault is checked first:
        // two default arrays are equal, while a default array and an initialized array are not.
        static string ArrayEquality(string left, string right) =>
            $"({left}.IsDefault ? {right}.IsDefault : !{right}.IsDefault && {left}.SequenceEqual({right}))";
    }

    /// <summary>Builds the hash expression for one field, guarding null and default immutable arrays.</summary>
    private static string GetHashText(Field field) {
        var self = $"this.{field.Name}";
        if (!field.IsImmutableArray)
            return $"__FieldHash({self})";
        // Only the non-nullable GenerateSequenceHashCode overload is called; null and default arrays hash to 0.
        if (!field.IsNullable)
            return $"({self}.IsDefault ? 0 : {self}.GenerateSequenceHashCode())";
        return $"({self}.HasValue && !{self}.Value.IsDefault ? {self}.Value.GenerateSequenceHashCode() : 0)";
    }

    /// <summary>
    /// Selects the fields that feed GetHashCode: only the <c>Key</c>-marked ones when any exist, otherwise all.
    /// Non-array fields are ordered before immutable arrays.
    /// </summary>
    private static ImmutableArray<Field> GetHashedFields(ImmutableArray<Field> fields) {
        var keys = fields.Where(f => f.IsKey).ToImmutableArray();
        var selected = keys.IsEmpty ? fields : keys;
        return selected.Where(f => !f.IsImmutableArray)
            .Concat(selected.Where(f => f.IsImmutableArray))
            .ToImmutableArray();
    }

    /// <summary>Reports whether the declaration already defines a matching Equals(R?) and/or GetHashCode().</summary>
    private static (bool HasEquals, bool HasGetHashCode) HasEqualsAndGetHashCodeMethods(RecordDeclarationSyntax record) =>
        (record.Members.Any(m => IsEqualsMethod(m, record)), record.Members.Any(IsGetHashCodeMethod));

    /// <summary>
    /// Whether <paramref name="member"/> is a public instance <c>bool Equals(R? other)</c> for this record.
    /// Both <c>R?</c> and <c>R</c> are accepted, because nullable-oblivious code writes the parameter
    /// without <c>?</c>.
    /// </summary>
    private static bool IsEqualsMethod(MemberDeclarationSyntax member, RecordDeclarationSyntax record) {
        if (member is not MethodDeclarationSyntax method || method.Identifier.Text != "Equals")
            return false;
        if (!method.Modifiers.Any(m => m.Text == "public") || method.Modifiers.Any(m => m.Text == "static"))
            return false;
        if (method.ReturnType is not PredefinedTypeSyntax { Keyword: { Text: "bool" } })
            return false;
        if (method.ParameterList.Parameters.Count != 1)
            return false;

        var parameterType = method.ParameterList.Parameters[0].Type;
        if (parameterType is NullableTypeSyntax nullable)
            parameterType = nullable.ElementType;

        var typeParameters = record.TypeParameterList;
        if (typeParameters is null || typeParameters.Parameters.Count == 0)
            return parameterType is IdentifierNameSyntax identifier && identifier.Identifier.Text == record.Identifier.Text;

        // A generic record must be spelled R<T1, ..., Tn> with exactly its own type parameters.
        return parameterType is GenericNameSyntax generic
            && generic.Identifier.Text == record.Identifier.Text
            && generic.TypeArgumentList.Arguments
                .Select(a => (a as IdentifierNameSyntax)?.Identifier.Text)
                .SequenceEqual(typeParameters.Parameters.Select(p => p.Identifier.Text));
    }

    /// <summary>Whether <paramref name="member"/> is <c>public override int GetHashCode()</c> (parameterless).</summary>
    private static bool IsGetHashCodeMethod(MemberDeclarationSyntax member) =>
        member is MethodDeclarationSyntax { Identifier: { Text: "GetHashCode" } } method
        && method.ParameterList.Parameters.Count == 0
        && method.Modifiers.Any(m => m.Text == "public")
        && method.Modifiers.Any(m => m.Text == "override");

    private static bool IsPartial(TypeDeclarationSyntax type) => type.Modifiers.Any(m => m.Text == "partial");

    /// <summary>
    /// Returns the declaration headers from the outermost namespace down to the record itself. Brace-delimited
    /// scopes end with <c>{</c>; a file-scoped namespace ends with a newline.
    /// </summary>
    private static ImmutableArray<string> GetScopes(RecordDeclarationSyntax record) {
        var scopes = new List<string>();
        for (SyntaxNode? current = record; current != null; current = current.Parent) {
            var scope = current switch {
                FileScopedNamespaceDeclarationSyntax node => $"namespace {node.Name};{Environment.NewLine}",
                NamespaceDeclarationSyntax node => $"namespace {node.Name} {{",
                // Covers classes, structs, interfaces, records and record structs alike.
                TypeDeclarationSyntax node => $"{GetModifierPrefix(node)}{GetTypeKeywords(node)} {GetNameWithGenericParameters(node)} {{",
                _ => null
            };
            if (scope != null)
                scopes.Add(scope);
        }
        scopes.Reverse();
        return scopes.ToImmutableArray();
    }

    /// <summary>The declaration keyword(s): <c>class</c>, <c>struct</c>, <c>interface</c>, <c>record</c>, <c>record class</c> or <c>record struct</c>.</summary>
    private static string GetTypeKeywords(TypeDeclarationSyntax type) =>
        type is RecordDeclarationSyntax record && record.ClassOrStructKeyword.Text.Length > 0
            ? $"{record.Keyword.Text} {record.ClassOrStructKeyword.Text}"
            : type.Keyword.Text;

    /// <summary>Type name as written in C# source, e.g. <c>Foo&lt;T, U&gt;</c>.</summary>
    private static string GetNameWithGenericParameters(TypeDeclarationSyntax type) => FormatName(type, '<', '>');

    /// <summary>All modifiers of <paramref name="type"/>, each followed by a space.</summary>
    private static string GetModifierPrefix(TypeDeclarationSyntax type) =>
        string.Concat(type.Modifiers.Select(m => m.Text + " "));

    /// <summary>
    /// Collects the state compared by the generated members. Positional parameters come first (unless a
    /// member with the same name is declared explicitly), followed by instance properties and fields. Static
    /// and const members are skipped: they are not per-instance state, and accessing them through an instance
    /// reference is CS0176.
    /// </summary>
    private static ImmutableArray<Field> GetFields(RecordDeclarationSyntax record) {
        var declaredNames = new HashSet<string>();
        var members = new List<Field>();
        foreach (var member in record.Members) {
            switch (member) {
                case PropertyDeclarationSyntax property:
                    _ = declaredNames.Add(property.Identifier.ValueText);
                    if (IsInstanceMember(property))
                        members.Add(CreateField(property.Identifier.Text, property.Type, property.AttributeLists));
                    break;
                case FieldDeclarationSyntax field:
                    foreach (var variable in field.Declaration.Variables) {
                        _ = declaredNames.Add(variable.Identifier.ValueText);
                        if (IsInstanceMember(field))
                            members.Add(CreateField(variable.Identifier.Text, field.Declaration.Type, field.AttributeLists));
                    }
                    break;
            }
        }

        var fields = ImmutableArray.CreateBuilder<Field>();
        if (record.ParameterList is { } parameterList) {
            foreach (var parameter in parameterList.Parameters) {
                // An explicitly declared member replaces the property a positional parameter would synthesize.
                if (!declaredNames.Contains(parameter.Identifier.ValueText))
                    fields.Add(CreateField(parameter.Identifier.Text, parameter.Type, parameter.AttributeLists));
            }
        }
        fields.AddRange(members);
        return fields.ToImmutable();
    }

    private static bool IsInstanceMember(MemberDeclarationSyntax member) =>
        !member.Modifiers.Any(m => m.Text == "static" || m.Text == "const");

    private static Field CreateField(string name, TypeSyntax? type, SyntaxList<AttributeListSyntax> attributeLists) {
        var isImmutableArray = IsImmutableArrayType(type, out var isNullable);
        var isKey = attributeLists.SelectMany(l => l.Attributes).Any(IsKeyAttribute);
        return new Field(name, isImmutableArray, isNullable, isKey);
    }

    /// <summary>
    /// Whether <paramref name="type"/> is spelled <c>ImmutableArray&lt;T&gt;</c>, optionally qualified
    /// (<c>X.ImmutableArray&lt;T&gt;</c>, <c>alias::ImmutableArray&lt;T&gt;</c>) and/or followed by <c>?</c>.
    /// </summary>
    private static bool IsImmutableArrayType(TypeSyntax? type, out bool isNullable) {
        var nullable = type as NullableTypeSyntax;
        var underlying = nullable?.ElementType ?? type;
        var generic = underlying switch {
            GenericNameSyntax name => name,
            QualifiedNameSyntax { Right: GenericNameSyntax name } => name,
            AliasQualifiedNameSyntax { Name: GenericNameSyntax name } => name,
            _ => null
        };
        var isImmutableArray = generic is not null
            && generic.Identifier.Text == "ImmutableArray"
            && generic.TypeArgumentList.Arguments.Count == 1;
        isNullable = isImmutableArray && nullable is not null;
        return isImmutableArray;
    }

    /// <summary>Matches <c>[Key]</c>, <c>[KeyAttribute]</c> and qualified forms such as <c>[SourceGenerator.Key]</c>.</summary>
    private static bool IsKeyAttribute(AttributeSyntax attribute) {
        var simpleName = attribute.Name switch {
            QualifiedNameSyntax qualified => qualified.Right,
            AliasQualifiedNameSyntax aliased => aliased.Name,
            SimpleNameSyntax simple => simple,
            _ => null
        };
        return simpleName?.Identifier.Text is "Key" or "KeyAttribute";
    }

    /// <summary>
    /// Hint name built from the namespace and type-nesting path, e.g. <c>Ns.Outer.Rec[T].g.cs</c>. Brackets
    /// stand in for angle brackets, which are not usable in file names.
    /// </summary>
    private static string GetFilename(RecordDeclarationSyntax record) {
        var parts = new List<string>();
        for (SyntaxNode? current = record; current != null; current = current.Parent) {
            var part = current switch {
                FileScopedNamespaceDeclarationSyntax node => node.Name.ToString(),
                NamespaceDeclarationSyntax node => node.Name.ToString(),
                TypeDeclarationSyntax node => GetFilenamePartWithGenericParameters(node),
                _ => null
            };
            if (part != null)
                parts.Add(part);
        }
        parts.Reverse();
        return $"{string.Join(".", parts)}.g.cs";
    }

    private static string GetFilenamePartWithGenericParameters(TypeDeclarationSyntax type) => FormatName(type, '[', ']');

    /// <summary>Type name followed by its type-parameter names between <paramref name="open"/> and <paramref name="close"/>.</summary>
    private static string FormatName(TypeDeclarationSyntax type, char open, char close) {
        var typeParameters = type.TypeParameterList;
        if (typeParameters is null || typeParameters.Parameters.Count == 0)
            return type.Identifier.Text;
        var names = string.Join(", ", typeParameters.Parameters.Select(p => p.Identifier.Text));
        return $"{type.Identifier.Text}{open}{names}{close}";
    }
}
This is the attribute I implemented to enable more efficient `GetHashCode` generation when a subset of fields uniquely identifies the record:
using System;
namespace SourceGenerator;
/// <summary>
/// Marks a record property or field as part of the record's identifying key.
/// </summary>
/// <remarks>
/// When at least one member of a record carries this attribute, the generated <c>GetHashCode</c> hashes
/// only the marked members. The generated <c>Equals</c> still compares every member.
/// </remarks>
[AttributeUsage(validOn: AttributeTargets.Field | AttributeTargets.Property, Inherited = false, AllowMultiple = false)]
public sealed class KeyAttribute : Attribute
{
}
And these are the extension methods for generating a hash code for an `ImmutableArray` based on the elements it contains, and for wrapping sequence equality on nullable `ImmutableArray`s:
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
namespace SourceGenerator;
/// <summary>
/// Element-wise equality and hashing helpers for <see cref="ImmutableArray{T}"/>, used by the code that
/// RecordEqualityGenerator emits.
/// </summary>
public static class Extensions {
    /// <summary>
    /// Compares two optional immutable arrays element by element.
    /// </summary>
    /// <param name="self">First array, or <see langword="null"/>.</param>
    /// <param name="other">Second array, or <see langword="null"/>.</param>
    /// <param name="comparer">Element comparer; <see langword="null"/> uses the default comparer.</param>
    /// <returns>
    /// <see langword="true"/> when both are <see langword="null"/>, both are default (uninitialized) arrays,
    /// or both contain pairwise-equal elements in the same order; otherwise <see langword="false"/>.
    /// </returns>
    public static bool SequenceEqual<T>(this Nullable<ImmutableArray<T>> self, Nullable<ImmutableArray<T>> other, IEqualityComparer<T>? comparer = null) {
        if (!self.HasValue || !other.HasValue)
            return self.HasValue == other.HasValue;

        var left = self.Value;
        var right = other.Value;
        // Comparing a default (uninitialized) ImmutableArray throws NullReferenceException, so two default
        // arrays are treated as equal, and a default array never equals an initialized one.
        if (left.IsDefault || right.IsDefault)
            return left.IsDefault == right.IsDefault;
        return left.SequenceEqual(right, comparer);
    }

    /// <summary>
    /// Computes an order-sensitive hash code from the elements of <paramref name="values"/>.
    /// </summary>
    /// <returns>0 for a default (uninitialized) array; otherwise a multiply-and-add hash seeded with 17.</returns>
    /// <remarks>Null elements contribute 0 instead of throwing.</remarks>
    public static int GenerateSequenceHashCode<T>(this ImmutableArray<T> values) {
        // Enumerating a default ImmutableArray throws NullReferenceException.
        if (values.IsDefault)
            return 0;

        const int multiplier = 92821;
        var hash = 17;
        unchecked {
            foreach (var value in values)
                hash = hash * multiplier + (value is null ? 0 : value.GetHashCode());
        }
        return hash;
    }

    /// <summary>
    /// Computes the element-based hash code of an optional immutable array; <see langword="null"/> hashes to 0.
    /// </summary>
    public static int GenerateSequenceHashCode<T>(this Nullable<ImmutableArray<T>> values) =>
        // `.Value` is required: calling the extension on `values` itself would bind back to this
        // Nullable overload and recurse until the stack overflows.
        values.HasValue ? values.Value.GenerateSequenceHashCode() : 0;
}