Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
319 changes: 319 additions & 0 deletions src/MoreAsyncLINQ.Generators/AggregateGenerator.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,319 @@
using System.CodeDom.Compiler;
using System.IO;
using System.Linq;
using Microsoft.CodeAnalysis;

namespace MoreAsyncLINQ.Generators;

[Generator]
public class AggregateGenerator : IIncrementalGenerator
{
private const int MinAccumulators = 2;
private const int MaxAccumulators = 8;

private static readonly string[] _ordinals = ["first", "second", "third", "fourth", "fifth", "sixth", "seventh", "eighth"];

public void Initialize(IncrementalGeneratorInitializationContext context) =>
context.RegisterPostInitializationOutput(postInitializationContext =>
{
var source = GenerateOverloads();
postInitializationContext.AddSource("MoreAsyncEnumerable.Aggregate.g.cs", source);
});

private string GenerateOverloads()
{
using var stringWriter = new StringWriter();
using var writer = new IndentedTextWriter(stringWriter);

writer.WriteLine("// <auto-generated/>");
writer.WriteLine("#nullable enable");
writer.WriteLine();
writer.WriteLine("using System;");
writer.WriteLine("using System.Collections.Generic;");
writer.WriteLine("using System.Linq;");
writer.WriteLine("using System.Threading;");
writer.WriteLine("using System.Threading.Tasks;");
writer.WriteLine();
writer.WriteLine("namespace MoreAsyncLINQ;");
writer.WriteLine();
writer.WriteLine("static partial class MoreAsyncEnumerable");
writer.WriteLine("{");
writer.Indent++;

for (var arity = MinAccumulators; arity <= MaxAccumulators; arity++)
{
GenerateSyncOverload(writer, arity);
writer.WriteLine();
GenerateAsyncOverload(writer, arity);

if (arity < MaxAccumulators)
{
writer.WriteLine();
}
}

writer.Indent--;
writer.WriteLine("}");

return stringWriter.ToString();
}

private void GenerateSyncOverload(IndentedTextWriter writer, int arity)
{
WriteXmlDocumentation(writer, arity);
WriteMethodSignature(writer, arity, isAsync: false);

writer.WriteLine("{");
writer.Indent++;

WriteNullChecks(writer, arity);
writer.WriteLine();

writer.WriteLine("return source.IsKnownEmpty()");
writer.Indent++;
writer.WriteLine("? ValueTasks.FromResult(");
writer.Indent++;
writer.WriteLine("resultSelector(");
writer.Indent++;
WriteSeedArgs(writer, arity);
writer.WriteLine("))");
writer.Indent--;
writer.Indent--;
writer.WriteLine(": Core(");
writer.Indent++;
WriteCoreCallArgs(writer, arity);
writer.WriteLine(");");
writer.Indent--;
writer.Indent--;

writer.WriteLine();
WriteCoreMethod(writer, arity, isAsync: false);

writer.Indent--;
writer.WriteLine("}");
}

private void GenerateAsyncOverload(IndentedTextWriter writer, int arity)
{
WriteXmlDocumentation(writer, arity);
WriteMethodSignature(writer, arity, isAsync: true);

writer.WriteLine("{");
writer.Indent++;

WriteNullChecks(writer, arity);
writer.WriteLine();

writer.WriteLine("return Core(");
writer.Indent++;
WriteCoreCallArgs(writer, arity);
writer.WriteLine(");");
writer.Indent--;

writer.WriteLine();
WriteCoreMethod(writer, arity, isAsync: true);

writer.Indent--;
writer.WriteLine("}");
}

private void WriteXmlDocumentation(IndentedTextWriter writer, int arity)
{
writer = new IndentedTextWriter(writer, "/// ");
writer.Indent++;

writer.WriteLine("<summary>");
writer.WriteLine($"Applies {_ordinals[arity - 1]} accumulators sequentially in a single pass over a");
writer.WriteLine("sequence.");
writer.WriteLine("</summary>");
writer.WriteLine("<typeparam name=\"TSource\">The type of elements in <paramref name=\"source\"/>.</typeparam>");

for (var index = 1; index <= arity; index++)
{
writer.WriteLine($"<typeparam name=\"TAccumulate{index}\">The type of {_ordinals[index - 1]} accumulator value.</typeparam>");
}

writer.WriteLine("<typeparam name=\"TResult\">The type of the accumulated result.</typeparam>");
writer.WriteLine("<param name=\"source\">The source sequence</param>");

for (var index = 1; index <= arity; index++)
{
writer.WriteLine($"<param name=\"seed{index}\">The seed value for the {_ordinals[index - 1]} accumulator.</param>");
writer.WriteLine($"<param name=\"accumulator{index}\">The {_ordinals[index - 1]} accumulator.</param>");
}

writer.WriteLine("<param name=\"resultSelector\">");
writer.WriteLine("A function that projects a single result given the result of each");
writer.WriteLine("accumulator.</param>");
writer.WriteLine("<param name=\"cancellationToken\">The optional cancellation token to be used for cancelling the sequence at any time.</param>");
writer.WriteLine("<returns>The value returned by <paramref name=\"resultSelector\"/>.</returns>");
writer.WriteLine("<remarks>");
writer.WriteLine("This operator executes immediately.");
writer.WriteLine("</remarks>");
}

private void WriteMethodSignature(IndentedTextWriter writer, int arity, bool isAsync)
{
writer.WriteLine("public static ValueTask<TResult> AggregateAsync<");
writer.Indent++;

// Type parameters
writer.WriteLine("TSource,");
for (var index = 1; index <= arity; index++)
{
writer.WriteLine($"TAccumulate{index},");
}
writer.WriteLine("TResult>(");

// Method parameters
writer.WriteLine("this IAsyncEnumerable<TSource> source,");

for (var index = 1; index <= arity; index++)
{
writer.WriteLine($"TAccumulate{index} seed{index},");

var funcType =
isAsync
? $"Func<TAccumulate{index}, TSource, CancellationToken, ValueTask<TAccumulate{index}>>"
: $"Func<TAccumulate{index}, TSource, TAccumulate{index}>";

writer.WriteLine($"{funcType} accumulator{index},");
}

var resultSelectorTypes = string.Join(", ", Enumerable.Range(1, arity).Select(index => $"TAccumulate{index}"));
var resultSelectorType =
isAsync
? $"Func<{resultSelectorTypes}, CancellationToken, ValueTask<TResult>>"
: $"Func<{resultSelectorTypes}, TResult>";

writer.WriteLine($"{resultSelectorType} resultSelector,");
writer.WriteLine("CancellationToken cancellationToken = default)");

writer.Indent--;
}

private void WriteNullChecks(IndentedTextWriter writer, int arity)
{
writer.WriteLine("if (source is null) throw new ArgumentNullException(nameof(source));");

for (var index = 1; index <= arity; index++)
{
writer.WriteLine($"if (accumulator{index} is null) throw new ArgumentNullException(nameof(accumulator{index}));");
}

writer.WriteLine("if (resultSelector is null) throw new ArgumentNullException(nameof(resultSelector));");
}

private void WriteCoreCallArgs(IndentedTextWriter writer, int arity)
{
writer.WriteLine("source,");

for (var index = 1; index <= arity; index++)
{
writer.WriteLine($"seed{index},");
writer.WriteLine($"accumulator{index},");
}

writer.WriteLine("resultSelector,");
writer.Write("cancellationToken");
}

private void WriteSeedArgs(IndentedTextWriter writer, int arity)
{
for (var index = 1; index <= arity; index++)
{
writer.Write($"seed{index}");

if (index < arity)
{
writer.WriteLine(",");
}
}
}

private void WriteCoreMethod(IndentedTextWriter writer, int arity, bool isAsync)
{
writer.WriteLine("static async ValueTask<TResult> Core(");
writer.Indent++;

writer.WriteLine("IAsyncEnumerable<TSource> source,");

for (var index = 1; index <= arity; index++)
{
writer.WriteLine($"TAccumulate{index} seed{index},");

var funcType =
isAsync
? $"Func<TAccumulate{index}, TSource, CancellationToken, ValueTask<TAccumulate{index}>>"
: $"Func<TAccumulate{index}, TSource, TAccumulate{index}>";

writer.WriteLine($"{funcType} accumulator{index},");
}

var resultSelectorTypes = string.Join(", ", Enumerable.Range(1, arity).Select(index => $"TAccumulate{index}"));
var resultSelectorType =
isAsync
? $"Func<{resultSelectorTypes}, CancellationToken, ValueTask<TResult>>"
: $"Func<{resultSelectorTypes}, TResult>";

writer.WriteLine($"{resultSelectorType} resultSelector,");
writer.WriteLine("CancellationToken cancellationToken)");

writer.Indent--;
writer.WriteLine("{");
writer.Indent++;

for (var index = 1; index <= arity; index++)
{
writer.WriteLine($"var accumulate{index} = seed{index};");
}

writer.WriteLine();
writer.WriteLine("await foreach (var element in source.WithCancellation(cancellationToken))");
writer.WriteLine("{");
writer.Indent++;

for (var index = 1; index <= arity; index++)
{
writer.WriteLine(
isAsync
? $"accumulate{index} = await accumulator{index}(accumulate{index}, element, cancellationToken);"
: $"accumulate{index} = accumulator{index}(accumulate{index}, element);");
}

writer.Indent--;
writer.WriteLine("}");
writer.WriteLine();

if (isAsync)
{
writer.WriteLine("return await resultSelector(");
writer.Indent++;

for (var index = 1; index <= arity; index++)
{
writer.WriteLine($"accumulate{index},");
}

writer.WriteLine("cancellationToken);");
}
else
{
writer.WriteLine("return resultSelector(");
writer.Indent++;

for (var index = 1; index <= arity; index++)
{
writer.WriteLine(
index < arity
? $"accumulate{index},"
: $"accumulate{index});");
}
}

writer.Indent--;

writer.Indent--;
writer.WriteLine("}");
}
}
Loading