Skip to content

Commit 316b7b8

Browse files
Added Merge extension with tests.
1 parent 7cb47fd commit 316b7b8

3 files changed

Lines changed: 189 additions & 2 deletions

File tree

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
namespace Open.ChannelExtensions.Tests;
2+
3+
public static class MergeTests
4+
{
5+
[Fact()]
6+
public static async Task BasicMergeTest()
7+
{
8+
// Arrange
9+
const int total = 3000000;
10+
const int bound = 100;
11+
12+
// 3 channels
13+
var c1 = Channel.CreateBounded<int>(bound);
14+
var c2 = Channel.CreateBounded<int>(bound);
15+
var c3 = Channel.CreateBounded<int>(bound);
16+
var writers = new[] { c1.Writer, c2.Writer, c3.Writer };
17+
18+
// 3 readers
19+
var merging = new[] { c1.Reader, c2.Reader, c3.Reader }.Merge().ToListAsync(total);
20+
21+
// Act
22+
await Parallel.ForAsync(0, total,
23+
(i, token) => writers[i % 3].WriteAsync(i, token));
24+
25+
foreach (var writer in writers)
26+
writer.Complete();
27+
28+
var merged = await merging;
29+
merged.Sort();
30+
31+
// Assert
32+
Assert.Equal(total, merged.Count);
33+
Assert.True(Enumerable.Range(0, total).SequenceEqual(merged));
34+
}
35+
36+
[Fact()]
37+
public static async Task ExceptionPropagationTest()
38+
{
39+
// Arrange
40+
const int total = 3000000;
41+
const int bound = 100;
42+
43+
// 3 channels
44+
var c1 = Channel.CreateBounded<int>(bound);
45+
var c2 = Channel.CreateBounded<int>(bound);
46+
var c3 = Channel.CreateBounded<int>(bound);
47+
var writers = new[] { c1.Writer, c2.Writer, c3.Writer };
48+
49+
// 3 readers
50+
var merging = new[] { c1.Reader, c2.Reader, c3.Reader }.Merge();
51+
var list = merging.ToListAsync(total);
52+
53+
// Act
54+
await Assert.ThrowsAsync<ChannelClosedException>(() => Parallel.ForAsync(0, total,
55+
async (i, token) =>
56+
{
57+
var w = writers[i % 3];
58+
if (i == total / 2)
59+
w.Complete(new Exception("Test"));
60+
else
61+
await w.WriteAsync(i, token).ConfigureAwait(false);
62+
}));
63+
64+
// Assert
65+
await Assert.ThrowsAsync<Exception>(list.AsTask);
66+
await Assert.ThrowsAsync<Exception>(() => merging.Completion);
67+
}
68+
}
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
namespace Open.ChannelExtensions;
2+
3+
public static partial class Extensions
4+
{
5+
sealed class MergingChannelReader<T> : ChannelReader<T>
6+
{
7+
public MergingChannelReader(IEnumerable<ChannelReader<T>> sources)
8+
{
9+
if (sources is null) throw new ArgumentNullException(nameof(sources));
10+
Contract.EndContractBlock();
11+
12+
var active = sources.Where(s =>
13+
{
14+
Debug.Assert(s is not null);
15+
return s.Completion.Status != TaskStatus.RanToCompletion;
16+
}).ToArray();
17+
18+
_count = active.Length;
19+
20+
if (_count == 0)
21+
{
22+
_sources = [];
23+
Completion = Task.CompletedTask;
24+
return;
25+
}
26+
27+
_sources = active;
28+
29+
// Capture the initial list of completions.
30+
var completions = active.Select(e => e.Completion).ToList();
31+
32+
// Create a task that completes when any of the sources are faulted:
33+
Completion = Task.Run(async () =>
34+
{
35+
// Wait for any of the tasks to complete and let it throw if it faults.
36+
while (completions.Count != 0)
37+
{
38+
var completed = await Task.WhenAny(completions).ConfigureAwait(false);
39+
// Propagate the exception.
40+
await completed.ConfigureAwait(false);
41+
completions.Remove(completed);
42+
}
43+
});
44+
}
45+
46+
private readonly ChannelReader<T>[] _sources;
47+
public override Task Completion { get; }
48+
49+
readonly int _count;
50+
int _next = -1;
51+
52+
public override bool TryRead(out T item)
53+
{
54+
// Try as many times as there are sources before giving up.
55+
for (var attempt = 0; attempt < _count; attempt++)
56+
{
57+
// If the value overflows, it will be negative, which is fine, we'll adapt.
58+
var i = Interlocked.Increment(ref _next) % _count;
59+
if (i < 0) i += _count;
60+
var source = _sources[i];
61+
62+
if (source.TryRead(out T? s))
63+
{
64+
item = s;
65+
return true;
66+
}
67+
}
68+
69+
item = default!;
70+
return false;
71+
}
72+
73+
public override ValueTask<bool> WaitToReadAsync(CancellationToken cancellationToken = default)
74+
{
75+
var completion = Completion;
76+
if (Completion.IsCompleted)
77+
{
78+
return completion.IsFaulted
79+
? new ValueTask<bool>(Task.FromException<bool>(completion.Exception!))
80+
: new ValueTask<bool>(false);
81+
}
82+
83+
if (cancellationToken.IsCancellationRequested)
84+
return new ValueTask<bool>(Task.FromCanceled<bool>(cancellationToken));
85+
86+
// Not complete or cancelled? Wait on the sources.
87+
return WaitToReadAsyncCore(cancellationToken);
88+
}
89+
90+
private async ValueTask<bool> WaitToReadAsyncCore(CancellationToken cancellationToken)
91+
{
92+
retry:
93+
// We don't care about ones that have already completed.
94+
var active = _sources.Where(s => s.Completion.Status != TaskStatus.RanToCompletion).ToArray();
95+
if (active.Length == 0) return false;
96+
97+
var next = await Task.WhenAny(active.Select(s => s.WaitToReadAsync(cancellationToken).AsTask())).ConfigureAwait(false);
98+
99+
// Allow for possible exception to be thrown.
100+
var result = await next.ConfigureAwait(false);
101+
if (result) return true;
102+
103+
// If result was false, then there's one less and we should try again.
104+
goto retry;
105+
}
106+
}
107+
108+
/// <summary>
109+
/// Reads from multiple sources in a round-robin fashion.
110+
/// </summary>
111+
/// <typeparam name="T">The source type.</typeparam>
112+
/// <param name="sources">The channels to read from.</param>
113+
/// <returns>
114+
/// A <see cref="ChannelReader{T}"/> that reads from all sources in a round-robin fashion.
115+
/// </returns>
116+
public static ChannelReader<T> Merge<T>(this IEnumerable<ChannelReader<T>> sources)
117+
=> new MergingChannelReader<T>(sources);
118+
}

Open.ChannelExtensions/Extensions.Read.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1058,13 +1058,14 @@ public static ValueTask<long> ReadAllAsLines<T>(this Channel<T, string> channel,
10581058
/// </summary>
10591059
/// <typeparam name="T">The item type.</typeparam>
10601060
/// <param name="reader">The channel reader to read from.</param>
1061+
/// <param name="initialCapacity">An optional capacity to initialze the list with.</param>
10611062
/// <returns>A list containing all the items from the completed channel.</returns>
1062-
public static async ValueTask<List<T>> ToListAsync<T>(this ChannelReader<T> reader)
1063+
public static async ValueTask<List<T>> ToListAsync<T>(this ChannelReader<T> reader, int initialCapacity = 0)
10631064
{
10641065
if (reader is null) throw new ArgumentNullException(nameof(reader));
10651066
Contract.EndContractBlock();
10661067

1067-
var list = new List<T>();
1068+
List<T> list = initialCapacity == 0 ? new() : new(initialCapacity);
10681069
await ReadAll(reader, list.Add).ConfigureAwait(false);
10691070
return list;
10701071
}

0 commit comments

Comments
 (0)