79708504

Date: 2025-07-21 03:07:18
Score: 0.5
Natty:
Report link

Following Fildor's advice, I decided to implement a source generator to solve this issue. I've never implemented one before, so I'm open to feedback on how to improve this code. In particular, I'd be interested in a working example of how to verify that the ImmutableArray<> type really is the one from System.Collections.Immutable.

This code should account for edge cases such as the record being nested within interfaces/classes/records, the record being generic, the ImmutableArrays being nullable, custom Equals/GetHashCode methods already existing, etc.

I decided against locating the records via an attribute, because I want this to be the default behaviour for all of my libraries. In my codebase I also have the partial check disabled. I want to be notified when I've created a record that has an ImmutableArray property/field but no Equals override, since I find the lack of sequence equality 'surprising' behaviour even though it is the .NET default.

I've also added a custom Key attribute for decorating properties/fields when a subset of them uniquely identifies the record. This makes GetHashCode cheaper, because the remaining fields are excluded from the hash.

Here is the code for the source generator:

using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Text;
using System.Threading;

namespace SourceGenerator;

/// <summary>
/// Incremental source generator that gives partial records containing
/// <c>System.Collections.Immutable.ImmutableArray&lt;T&gt;</c> members element-wise
/// <c>Equals</c> and <c>GetHashCode</c> implementations, instead of the default
/// reference comparison of the arrays.
/// </summary>
/// <remarks>
/// Candidates are pre-filtered syntactically (<c>Predicate</c>, cheap, runs on every edit) and then
/// confirmed with the semantic model (<c>Transform</c>), so only the real
/// <c>System.Collections.Immutable.ImmutableArray&lt;T&gt;</c> triggers generation.
/// Only the matched partial declaration's members and its positional parameters are inspected.
/// NOTE(review): if several partial declarations of the same record each match, each one produces
/// a file with the same hint name and its own Equals/GetHashCode; that layout is not supported.
/// </remarks>
[Generator(LanguageNames.CSharp)]
public sealed class RecordEqualityGenerator : IIncrementalGenerator {
    private const string Indent = "    ";

    // Names of the private helpers emitted into each record; prefixed to avoid clashing with user members.
    private const string EqualsHelperName = "__RecordEqualityGenerator_AreEqual";
    private const string HashHelperName = "__RecordEqualityGenerator_GetHashCode";

    /// <summary>An instance member of the record that takes part in equality.</summary>
    private sealed record Field {
        internal Field(string name, bool isImmutableArray, bool isKey) {
            Name = name;
            IsImmutableArray = isImmutableArray;
            IsKey = isKey;
        }

        /// <summary>Member name exactly as written in source (an <c>@</c> prefix is kept so it can be emitted verbatim).</summary>
        internal string Name { get; }

        /// <summary>Whether the member is an <c>ImmutableArray&lt;T&gt;</c> or <c>ImmutableArray&lt;T&gt;?</c>.</summary>
        internal bool IsImmutableArray { get; }

        /// <summary>Whether the member carries the <c>[Key]</c> attribute.</summary>
        internal bool IsKey { get; }
    }

    /// <summary>Everything <c>Execute</c> needs to emit code for one record declaration.</summary>
    private sealed record Target {
        internal Target(RecordDeclarationSyntax record, ImmutableArray<Field> fields, bool isSealed, bool hasBaseRecord) {
            Record = record;
            Fields = fields;
            IsSealed = isSealed;
            HasBaseRecord = hasBaseRecord;
        }

        internal RecordDeclarationSyntax Record { get; }
        internal ImmutableArray<Field> Fields { get; }
        internal bool IsSealed { get; }
        internal bool HasBaseRecord { get; }
    }

    public void Initialize(IncrementalGeneratorInitializationContext context) {
        var provider = context.SyntaxProvider
            .CreateSyntaxProvider(Predicate, Transform)
            .Where(static target => target is not null)
            .Select(static (target, _) => target!);
        context.RegisterSourceOutput(provider, Execute);
    }

    /// <summary>
    /// Cheap syntactic filter: a partial record (inside partial containers) that declares something that
    /// looks like an ImmutableArray instance member and lacks an explicit Equals and/or GetHashCode.
    /// </summary>
    private static bool Predicate(SyntaxNode node, CancellationToken _) {
        if (node is not RecordDeclarationSyntax record)
            return false;

        // The generated file re-declares the record and every containing type as `partial`;
        // if any of them is not partial the output cannot compile.
        if (!IsPartial(record) || !record.Ancestors().OfType<TypeDeclarationSyntax>().All(IsPartial))
            return false;

        var hasCandidate = (record.ParameterList?.Parameters.Any(p => LooksLikeImmutableArray(p.Type)) ?? false)
            || record.Members.Any(IsImmutableArrayCandidate);
        if (!hasCandidate)
            return false;

        var (hasEquals, hasGetHashCode) = HasEqualsAndGetHashCodeMethods(record);
        return !hasEquals || !hasGetHashCode;
    }

    /// <summary>
    /// Semantic confirmation: collects the participating members and keeps the record only if at least
    /// one of them really is <c>System.Collections.Immutable.ImmutableArray&lt;T&gt;</c>.
    /// </summary>
    /// <returns>The emission model, or <see langword="null"/> when nothing should be generated.</returns>
    private static Target? Transform(GeneratorSyntaxContext ctx, CancellationToken cancellationToken) {
        var record = (RecordDeclarationSyntax)ctx.Node;
        var model = ctx.SemanticModel;

        var fields = GetFields(record, model, cancellationToken);
        if (!fields.Any(f => f.IsImmutableArray))
            return null;

        // `sealed` may be written on a different partial declaration, so ask the symbol.
        var symbol = model.GetDeclaredSymbol(record, cancellationToken) as INamedTypeSymbol;
        var isSealed = symbol?.IsSealed ?? HasModifier(record.Modifiers, "sealed");
        var hasBaseRecord = symbol?.BaseType is { IsRecord: true };
        return new Target(record, fields, isSealed, hasBaseRecord);
    }

    /// <summary>Writes the generated partial declaration for one record.</summary>
    private static void Execute(SourceProductionContext context, Target target) {
        var record = target.Record;
        var (hasEquals, hasGetHashCode) = HasEqualsAndGetHashCodeMethods(record);

        // Member text is built without indentation; the nesting depth is applied when it is written out.
        var members = new List<string>();
        if (!hasEquals)
            AppendEquals(members, target);
        if (!hasGetHashCode) {
            if (members.Count > 0)
                members.Add(string.Empty);
            AppendGetHashCode(members, target.Fields);
        }
        AppendHelpers(members, target.Fields, needsEquals: !hasEquals, needsHash: !hasGetHashCode);

        var code = new StringBuilder()
            .AppendLine("// <auto-generated/>")
            // Generated sources start with nullable annotations disabled; `R?` would otherwise warn.
            .AppendLine("#nullable enable")
            .AppendLine()
            .AppendLine("using SourceGenerator;")
            .AppendLine("using System.Linq;")
            .AppendLine();

        var scopes = GetContainingDeclarations(record).Select(GetScopeHeader).ToImmutableArray();

        var indent = string.Empty;
        foreach (var (header, opensBlock) in scopes) {
            _ = code.Append(indent).AppendLine(header);
            if (opensBlock)
                indent += Indent;
            else
                _ = code.AppendLine(); // blank line after a file-scoped namespace
        }

        foreach (var line in members)
            _ = code.AppendLine(line.Length == 0 ? line : indent + line);

        foreach (var (_, opensBlock) in scopes.Reverse()) {
            if (!opensBlock)
                continue;
            indent = indent.Substring(0, indent.Length - Indent.Length);
            _ = code.Append(indent).AppendLine("}");
        }

        context.AddSource(GetFilename(record), code.ToString());
    }

    /// <summary>Emits the strongly-typed <c>Equals</c> for a record class or record struct.</summary>
    private static void AppendEquals(List<string> lines, Target target) {
        var record = target.Record;
        var recordName = GetNameWithGenericParameters(record);
        var comparison = string.Join(" && ", GetEqualityTerms(target.Fields));

        if (IsRecordStruct(record)) {
            // Record structs synthesize Equals(R other); there is no null, reference or contract check to make.
            lines.Add($"public bool Equals({recordName} other) => {comparison};");
            return;
        }

        // C# requires the explicit Equals(R?) of a non-sealed record to allow overriding.
        var modifiers = target.IsSealed ? "public" : "public virtual";
        lines.Add($"{modifiers} bool Equals({recordName}? other) {{");
        lines.Add("    if (ReferenceEquals(this, other))");
        lines.Add("        return true;");
        lines.Add("    if (other is null || EqualityContract != other.EqualityContract)");
        lines.Add("        return false;");
        if (target.HasBaseRecord) {
            // Mirror the compiler-synthesized chain so members declared on the base record are compared too.
            lines.Add("    if (!base.Equals(other))");
            lines.Add("        return false;");
        }
        lines.Add(string.Empty);
        lines.Add($"    return {comparison};");
        lines.Add("}");
    }

    /// <summary>Equality terms: scalar members first (cheap, usually short-circuit), then element-wise array checks.</summary>
    private static IEnumerable<string> GetEqualityTerms(ImmutableArray<Field> fields) {
        foreach (var field in fields.Where(f => !f.IsImmutableArray))
            yield return $"{EqualsHelperName}({field.Name}, other.{field.Name})";
        foreach (var field in fields.Where(f => f.IsImmutableArray))
            yield return $"{field.Name}.SequenceEqual(other.{field.Name})";
    }

    /// <summary>
    /// Emits <c>GetHashCode</c> over the hashed members. Equals compares a superset of these members,
    /// so equal instances always hash alike (base-record members are deliberately not hashed).
    /// </summary>
    private static void AppendGetHashCode(List<string> lines, ImmutableArray<Field> fields) {
        var hashes = GetHashes(fields);
        if (hashes.Length == 1) {
            lines.Add($"public override int GetHashCode() => {hashes[0]};");
            return;
        }

        lines.Add("public override int GetHashCode() {");
        lines.Add("    const int mod = 92821;");
        lines.Add("    var hash = 17;");
        lines.Add(string.Empty);
        lines.Add("    unchecked {");
        foreach (var hash in hashes)
            lines.Add($"        hash = hash * mod + {hash};");
        lines.Add("    }");
        lines.Add(string.Empty);
        lines.Add("    return hash;");
        lines.Add("}");
    }

    /// <summary>
    /// Emits generic helpers based on <c>EqualityComparer&lt;T&gt;.Default</c>. Using them works for
    /// unconstrained generics and types without <c>==</c>, and is null-safe when hashing.
    /// </summary>
    private static void AppendHelpers(List<string> lines, ImmutableArray<Field> fields, bool needsEquals, bool needsHash) {
        if (needsEquals && fields.Any(f => !f.IsImmutableArray)) {
            lines.Add(string.Empty);
            lines.Add($"private static bool {EqualsHelperName}<T>(T left, T right) =>");
            lines.Add("    global::System.Collections.Generic.EqualityComparer<T>.Default.Equals(left, right);");
        }

        if (needsHash && GetHashedFields(fields).Any(f => !f.IsImmutableArray)) {
            lines.Add(string.Empty);
            lines.Add($"private static int {HashHelperName}<T>(T value) =>");
            lines.Add("    value is null ? 0 : global::System.Collections.Generic.EqualityComparer<T>.Default.GetHashCode(value);");
        }
    }

    /// <summary>The <c>[Key]</c> members when any exist, otherwise every participating member.</summary>
    private static ImmutableArray<Field> GetHashedFields(ImmutableArray<Field> fields) {
        var keys = fields.Where(f => f.IsKey).ToImmutableArray();
        return keys.IsEmpty ? fields : keys;
    }

    /// <summary>One hash expression per hashed member: scalars first, then arrays.</summary>
    private static ImmutableArray<string> GetHashes(ImmutableArray<Field> fields) {
        var hashed = GetHashedFields(fields);
        return hashed.Where(f => !f.IsImmutableArray).Select(f => $"{HashHelperName}({f.Name})")
            .Concat(hashed.Where(f => f.IsImmutableArray).Select(f => $"{f.Name}.GenerateSequenceHashCode()"))
            .ToImmutableArray();
    }

    /// <summary>Whether this declaration already has a user-written <c>Equals(R)</c> / <c>GetHashCode()</c>.</summary>
    private static (bool HasEquals, bool HasGetHashCode) HasEqualsAndGetHashCodeMethods(RecordDeclarationSyntax record) {
        var hasEquals = false;
        var hasGetHashCode = false;
        foreach (var member in record.Members) {
            hasEquals |= IsEqualsMethod(member, record);
            hasGetHashCode |= IsGetHashCodeMethod(member);
        }
        return (hasEquals, hasGetHashCode);
    }

    /// <summary>
    /// Detects <c>public bool Equals(R? other)</c> (record class) or <c>public bool Equals(R other)</c>.
    /// For record classes a parameter written without <c>?</c> is the same signature in an oblivious context.
    /// </summary>
    private static bool IsEqualsMethod(MemberDeclarationSyntax member, RecordDeclarationSyntax record) {
        if (member is not MethodDeclarationSyntax method || method.Identifier.ValueText != "Equals")
            return false;
        if (!HasModifier(method.Modifiers, "public") || HasModifier(method.Modifiers, "static"))
            return false;
        if (method.ReturnType is not PredefinedTypeSyntax returnType || returnType.Keyword.ValueText != "bool")
            return false;
        if (method.ParameterList.Parameters.Count != 1)
            return false;

        var parameterType = method.ParameterList.Parameters[0].Type;
        // For record structs, Equals(R?) is a Nullable<R> overload that does not replace the synthesized one.
        if (!IsRecordStruct(record) && parameterType is NullableTypeSyntax nullable)
            parameterType = nullable.ElementType;
        return parameterType is not null && IsSelfType(parameterType, record);
    }

    /// <summary>Whether <paramref name="type"/> names the record itself, with its own type parameters in order.</summary>
    private static bool IsSelfType(TypeSyntax type, RecordDeclarationSyntax record) {
        var typeParameters = record.TypeParameterList?.Parameters.Select(p => p.Identifier.ValueText).ToArray() ?? Array.Empty<string>();
        if (typeParameters.Length == 0)
            return type is IdentifierNameSyntax identifier && identifier.Identifier.ValueText == record.Identifier.ValueText;

        return type is GenericNameSyntax generic
            && generic.Identifier.ValueText == record.Identifier.ValueText
            && generic.TypeArgumentList.Arguments
                .Select(argument => argument is IdentifierNameSyntax name ? name.Identifier.ValueText : null)
                .SequenceEqual(typeParameters);
    }

    /// <summary>Detects <c>public override int GetHashCode()</c>.</summary>
    private static bool IsGetHashCodeMethod(MemberDeclarationSyntax member) =>
        member is MethodDeclarationSyntax method
        && method.Identifier.ValueText == "GetHashCode"
        && method.ParameterList.Parameters.Count == 0
        && HasModifier(method.Modifiers, "public")
        && HasModifier(method.Modifiers, "override");

    /// <summary>
    /// Members that take part in equality, mirroring compiler record semantics: positional parameters
    /// (unless an explicit member of the same name replaces them), instance fields and instance
    /// auto-properties. Static, const, abstract and computed properties are skipped.
    /// NOTE(review): properties using the C# 14 <c>field</c> keyword have accessor bodies and are skipped.
    /// </summary>
    private static ImmutableArray<Field> GetFields(RecordDeclarationSyntax record, SemanticModel model, CancellationToken cancellationToken) {
        var fields = ImmutableArray.CreateBuilder<Field>();

        var declaredNames = new HashSet<string>(StringComparer.Ordinal);
        foreach (var member in record.Members) {
            if (member is PropertyDeclarationSyntax property) {
                _ = declaredNames.Add(property.Identifier.ValueText);
            }
            else if (member is FieldDeclarationSyntax field) {
                foreach (var variable in field.Declaration.Variables)
                    _ = declaredNames.Add(variable.Identifier.ValueText);
            }
        }

        if (record.ParameterList is { } parameterList) {
            foreach (var parameter in parameterList.Parameters) {
                if (parameter.Type is null || declaredNames.Contains(parameter.Identifier.ValueText))
                    continue;
                fields.Add(new(parameter.Identifier.Text, IsImmutableArrayType(parameter.Type, model, cancellationToken), HasKeyAttribute(parameter.AttributeLists)));
            }
        }

        foreach (var member in record.Members) {
            switch (member) {
                case PropertyDeclarationSyntax property when IsInstanceStorage(property.Modifiers) && IsAutoProperty(property):
                    fields.Add(new(property.Identifier.Text, IsImmutableArrayType(property.Type, model, cancellationToken), HasKeyAttribute(property.AttributeLists)));
                    break;
                case FieldDeclarationSyntax field when IsInstanceStorage(field.Modifiers): {
                    var isImmutableArray = IsImmutableArrayType(field.Declaration.Type, model, cancellationToken);
                    var isKey = HasKeyAttribute(field.AttributeLists);
                    foreach (var variable in field.Declaration.Variables)
                        fields.Add(new(variable.Identifier.Text, isImmutableArray, isKey));
                    break;
                }
            }
        }

        return fields.ToImmutable();
    }

    /// <summary>Syntactic pre-check used by <c>Predicate</c> for a body member.</summary>
    private static bool IsImmutableArrayCandidate(MemberDeclarationSyntax member) => member switch {
        PropertyDeclarationSyntax property => IsInstanceStorage(property.Modifiers) && IsAutoProperty(property) && LooksLikeImmutableArray(property.Type),
        FieldDeclarationSyntax field => IsInstanceStorage(field.Modifiers) && LooksLikeImmutableArray(field.Declaration.Type),
        _ => false,
    };

    /// <summary>
    /// Matches <c>ImmutableArray&lt;X&gt;</c>, optionally nullable and/or namespace- or alias-qualified.
    /// Purely syntactic; <c>IsImmutableArrayType</c> makes the final decision.
    /// </summary>
    private static bool LooksLikeImmutableArray(TypeSyntax? type) {
        if (type is NullableTypeSyntax nullable)
            type = nullable.ElementType;

        SimpleNameSyntax? name = type switch {
            QualifiedNameSyntax qualified => qualified.Right,
            AliasQualifiedNameSyntax alias => alias.Name,
            SimpleNameSyntax simple => simple,
            _ => null,
        };
        return name is GenericNameSyntax generic
            && generic.Identifier.ValueText == "ImmutableArray"
            && generic.TypeArgumentList.Arguments.Count == 1;
    }

    /// <summary>Binds <paramref name="type"/> and checks it against the real ImmutableArray type.</summary>
    private static bool IsImmutableArrayType(TypeSyntax type, SemanticModel model, CancellationToken cancellationToken) =>
        IsImmutableArraySymbol(model.GetTypeInfo(type, cancellationToken).Type);

    /// <summary>
    /// True for <c>System.Collections.Immutable.ImmutableArray&lt;T&gt;</c> or its <c>Nullable&lt;&gt;</c> wrapper.
    /// The check compares name, arity and namespace rather than calling <c>GetTypeByMetadataName</c>, which
    /// returns null when several referenced assemblies define the type.
    /// </summary>
    private static bool IsImmutableArraySymbol(ITypeSymbol? type) {
        if (type is INamedTypeSymbol { OriginalDefinition: { SpecialType: SpecialType.System_Nullable_T } } nullable)
            type = nullable.TypeArguments[0];

        return type is INamedTypeSymbol { Name: "ImmutableArray", Arity: 1, ContainingType: null } named
            && named.ContainingNamespace.ToDisplayString() == "System.Collections.Immutable";
    }

    /// <summary>Whether a member with these modifiers has per-instance storage.</summary>
    private static bool IsInstanceStorage(SyntaxTokenList modifiers) =>
        !modifiers.Any(m => m.Text is "static" or "const" or "abstract");

    /// <summary>Whether the property is an auto-property (no accessor or expression bodies), i.e. it has a backing field.</summary>
    private static bool IsAutoProperty(PropertyDeclarationSyntax property) =>
        property.ExpressionBody is null
        && property.AccessorList is { } accessors
        && accessors.Accessors.All(a => a.Body is null && a.ExpressionBody is null);

    /// <summary>Detects <c>[Key]</c>, <c>[KeyAttribute]</c> or a qualified form such as <c>[SourceGenerator.Key]</c>.</summary>
    private static bool HasKeyAttribute(SyntaxList<AttributeListSyntax> attributeLists) =>
        attributeLists.SelectMany(list => list.Attributes).Any(attribute => GetSimpleName(attribute.Name) is "Key" or "KeyAttribute");

    /// <summary>Rightmost identifier of a possibly qualified name.</summary>
    private static string? GetSimpleName(NameSyntax name) => name switch {
        QualifiedNameSyntax qualified => qualified.Right.Identifier.ValueText,
        AliasQualifiedNameSyntax alias => alias.Name.Identifier.ValueText,
        SimpleNameSyntax simple => simple.Identifier.ValueText,
        _ => null,
    };

    private static bool IsPartial(TypeDeclarationSyntax type) => HasModifier(type.Modifiers, "partial");

    private static bool IsRecordStruct(RecordDeclarationSyntax record) => record.ClassOrStructKeyword.Text == "struct";

    private static bool HasModifier(SyntaxTokenList modifiers, string text) => modifiers.Any(m => m.Text == text);

    /// <summary>Namespaces and types enclosing the record (the record included), outermost first.</summary>
    private static ImmutableArray<SyntaxNode> GetContainingDeclarations(RecordDeclarationSyntax record) =>
        record.AncestorsAndSelf()
            .Where(node => node is BaseNamespaceDeclarationSyntax or TypeDeclarationSyntax)
            .Reverse()
            .ToImmutableArray();

    /// <summary>Header line re-declaring one scope, and whether it opens a brace block.</summary>
    private static (string Header, bool OpensBlock) GetScopeHeader(SyntaxNode node) => node switch {
        FileScopedNamespaceDeclarationSyntax ns => ($"namespace {ns.Name};", false),
        BaseNamespaceDeclarationSyntax ns => ($"namespace {ns.Name} {{", true),
        // Variance must be repeated on partial interfaces, so it is kept in headers.
        TypeDeclarationSyntax type => ($"{GetModifierPrefix(type)}{GetTypeKeyword(type)} {GetNameWithGenericParameters(type, includeVariance: true)} {{", true),
        _ => throw new ArgumentException($"Unexpected scope node {node.GetType().Name}.", nameof(node)),
    };

    /// <summary><c>class</c>/<c>struct</c>/<c>interface</c>/<c>record</c>, plus <c>class</c>/<c>struct</c> after <c>record</c> when written.</summary>
    private static string GetTypeKeyword(TypeDeclarationSyntax type) =>
        type is RecordDeclarationSyntax record && record.ClassOrStructKeyword.Text.Length > 0
            ? $"{record.Keyword.Text} {record.ClassOrStructKeyword.Text}"
            : type.Keyword.Text;

    /// <summary>Type name with its type parameters, e.g. <c>Foo&lt;T, U&gt;</c> (optionally <c>Foo&lt;out T&gt;</c>).</summary>
    private static string GetNameWithGenericParameters(TypeDeclarationSyntax type, bool includeVariance = false) {
        var list = type.TypeParameterList;
        if (list is null || list.Parameters.Count == 0)
            return type.Identifier.Text;

        var parameters = list.Parameters.Select(p =>
            includeVariance && p.VarianceKeyword.Text.Length > 0
                ? $"{p.VarianceKeyword.Text} {p.Identifier.Text}"
                : p.Identifier.Text);
        return $"{type.Identifier.Text}<{string.Join(", ", parameters)}>";
    }

    /// <summary>The declaration's modifiers, each followed by a space (source order, so <c>partial</c> stays last).</summary>
    private static string GetModifierPrefix(TypeDeclarationSyntax type) =>
        string.Concat(type.Modifiers.Select(modifier => modifier.Text + " "));

    /// <summary>Hint name built from the scope chain, e.g. <c>My.Namespace.Outer[T].Inner.g.cs</c>.</summary>
    private static string GetFilename(RecordDeclarationSyntax record) {
        var parts = GetContainingDeclarations(record).Select(node => node switch {
            BaseNamespaceDeclarationSyntax ns => ns.Name.ToString(),
            TypeDeclarationSyntax type => GetFilenamePartWithGenericParameters(type),
            _ => string.Empty,
        });
        return $"{string.Join(".", parts)}.g.cs";
    }

    /// <summary>Type name with type parameters in square brackets (angle brackets are not allowed in hint names).</summary>
    private static string GetFilenamePartWithGenericParameters(TypeDeclarationSyntax type) {
        if (type.TypeParameterList == null || !type.TypeParameterList.Parameters.Any())
            return type.Identifier.Text;

        var parameters = type.TypeParameterList.Parameters.Select(p => p.Identifier.Text);
        var parametersText = string.Join(", ", parameters);
        return $"{type.Identifier.Text}[{parametersText}]";
    }
}

This is the attribute I implemented to enable more efficient GetHashCode generation when a subset of fields uniquely identifies the record:

using System;

namespace SourceGenerator;

/// <summary>
/// Marks a property or field as part of a record's identifying key.
/// When at least one member of a record carries this attribute, the generated
/// <c>GetHashCode</c> hashes only the marked members; <c>Equals</c> still compares every member.
/// </summary>
[AttributeUsage(
    AttributeTargets.Field | AttributeTargets.Property,
    Inherited = false,
    AllowMultiple = false)]
public sealed class KeyAttribute : Attribute {
}

And these are the extension methods for generating a hash code for an ImmutableArray from the elements it contains, and for wrapping sequence equality on nullable ImmutableArrays:

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;

namespace SourceGenerator;

/// <summary>
/// Helpers called by the code that <c>RecordEqualityGenerator</c> emits, so that
/// <see cref="ImmutableArray{T}"/> members are compared and hashed by their elements
/// rather than by reference.
/// </summary>
public static class Extensions {
    /// <summary>Element-wise equality for nullable immutable arrays.</summary>
    /// <param name="self">First array, or <see langword="null"/>.</param>
    /// <param name="other">Second array, or <see langword="null"/>.</param>
    /// <param name="comparer">Element comparer; <see langword="null"/> means the default comparer.</param>
    /// <returns>
    /// <see langword="true"/> when both are <see langword="null"/>, or both have values whose
    /// elements are pairwise equal in order; otherwise <see langword="false"/>.
    /// </returns>
    public static bool SequenceEqual<T>(this Nullable<ImmutableArray<T>> self, Nullable<ImmutableArray<T>> other, IEqualityComparer<T>? comparer = null) {
        if (!self.HasValue)
            return !other.HasValue;
        if (!other.HasValue)
            return false;
        return self.Value.SequenceEqual(other.Value, comparer);
    }

    /// <summary>Order-sensitive hash of the array's elements, consistent with element-wise equality.</summary>
    /// <returns>
    /// 0 for a default (uninitialized) array, 17 for an empty array, otherwise the elements
    /// combined as <c>hash * 92821 + elementHash</c> with overflow ignored.
    /// </returns>
    public static int GenerateSequenceHashCode<T>(this ImmutableArray<T> values) where T : notnull {
        // Enumerating a default ImmutableArray throws; GetHashCode should not, so treat it like "no value".
        if (values.IsDefault)
            return 0;

        const int mod = 92821;
        var hash = 17;
        // Same comparer SequenceEqual uses by default; it also hashes a null element to 0 instead of throwing.
        var comparer = EqualityComparer<T>.Default;

        unchecked {
            foreach (var value in values)
                hash = hash * mod + comparer.GetHashCode(value);
        }

        return hash;
    }

    /// <summary>Hash of a nullable immutable array: 0 when it has no value, otherwise its element hash.</summary>
    public static int GenerateSequenceHashCode<T>(this Nullable<ImmutableArray<T>> values) where T : notnull =>
        // Unwrap with .Value: calling the extension on `values` itself binds back to this overload
        // and recurses until the stack overflows.
        values.HasValue ? values.Value.GenerateSequenceHashCode() : 0;
}
Reasons:
  • RegEx Blacklisted phrase (1): I want
  • Long answer (-1):
  • Has code block (-0.5):
  • Self-answer (0.5):
  • Low reputation (0.5):
Posted by: Bioinformagician