|
1 | | -namespace Open.ChannelExtensions; |
| 1 | +using System.Collections.Immutable; |
| 2 | + |
| 3 | +namespace Open.ChannelExtensions; |
2 | 4 |
|
3 | 5 | public static partial class Extensions |
4 | 6 | { |
5 | | - sealed class MergingChannelReader<T> : ChannelReader<T> |
6 | | - { |
7 | | - public MergingChannelReader(IEnumerable<ChannelReader<T>> sources) |
8 | | - { |
9 | | - if (sources is null) throw new ArgumentNullException(nameof(sources)); |
10 | | - Contract.EndContractBlock(); |
11 | | - |
12 | | - var active = sources.Where(s => |
13 | | - { |
14 | | - Debug.Assert(s is not null); |
15 | | - return s.Completion.Status != TaskStatus.RanToCompletion; |
16 | | - }).ToArray(); |
17 | | - |
18 | | - _count = active.Length; |
19 | | - |
20 | | - if (_count == 0) |
21 | | - { |
22 | | - _sources = []; |
23 | | - Completion = Task.CompletedTask; |
24 | | - return; |
25 | | - } |
26 | | - |
27 | | - _sources = active; |
28 | | - |
29 | | - // Capture the initial list of completions. |
30 | | - var completions = active.Select(e => e.Completion).ToList(); |
31 | | - |
32 | | - // Create a task that completes when any of the sources are faulted: |
33 | | - Completion = Task.Run(async () => |
34 | | - { |
35 | | - // Wait for any of the tasks to complete and let it throw if it faults. |
36 | | - while (completions.Count != 0) |
37 | | - { |
38 | | - var completed = await Task.WhenAny(completions).ConfigureAwait(false); |
39 | | - // Propagate the exception. |
40 | | - await completed.ConfigureAwait(false); |
41 | | - completions.Remove(completed); |
42 | | - } |
43 | | - }); |
44 | | - } |
45 | | - |
46 | | - private readonly ChannelReader<T>[] _sources; |
47 | | - public override Task Completion { get; } |
48 | | - |
49 | | - readonly int _count; |
50 | | - int _next = -1; |
51 | | - |
52 | | - public override bool TryRead(out T item) |
53 | | - { |
54 | | - int previous = -1; |
55 | | - // Try as many times as there are sources before giving up. |
56 | | - for (var attempt = 0; attempt < _count; attempt++) |
57 | | - { |
58 | | - // If the value overflows, it will be negative, which is fine, we'll adapt. |
59 | | - var i = Interlocked.Increment(ref _next) % _count; |
60 | | - if (i < 0) i += _count; |
61 | | - |
62 | | - var source = _sources[i]; |
63 | | - |
64 | | - if (source.TryRead(out T? s)) |
65 | | - { |
66 | | - item = s; |
67 | | - return true; |
68 | | - } |
69 | | - |
70 | | - // Help the round-robin to try each source at least once. |
71 | | - // If previous is not -1 and i is not the next in the sequence, |
72 | | - // then another thread has already tried that source. |
73 | | - if (previous != -1 && (previous + 1) % _count != i) |
74 | | - attempt--; // Allow for an extra attempt. |
75 | | - |
76 | | - previous = i; |
77 | | - } |
78 | | - |
79 | | - item = default!; |
80 | | - return false; |
81 | | - } |
82 | | - |
83 | | - public override ValueTask<bool> WaitToReadAsync(CancellationToken cancellationToken = default) |
84 | | - { |
85 | | - var completion = Completion; |
86 | | - if (Completion.IsCompleted) |
87 | | - { |
88 | | - return completion.IsFaulted |
89 | | - ? new ValueTask<bool>(Task.FromException<bool>(completion.Exception!)) |
90 | | - : new ValueTask<bool>(false); |
91 | | - } |
92 | | - |
93 | | - if (cancellationToken.IsCancellationRequested) |
94 | | - return new ValueTask<bool>(Task.FromCanceled<bool>(cancellationToken)); |
95 | | - |
96 | | - // Not complete or cancelled? Wait on the sources. |
97 | | - return WaitToReadAsyncCore(cancellationToken); |
98 | | - } |
99 | | - |
100 | | - private async ValueTask<bool> WaitToReadAsyncCore(CancellationToken cancellationToken) |
101 | | - { |
102 | | - retry: |
103 | | - // We don't care about ones that have already completed. |
104 | | - var active = _sources.Where(s => s.Completion.Status != TaskStatus.RanToCompletion).ToArray(); |
105 | | - if (active.Length == 0) return false; |
106 | | - |
107 | | - var next = await Task.WhenAny(active.Select(s => s.WaitToReadAsync(cancellationToken).AsTask())).ConfigureAwait(false); |
108 | | - |
109 | | - // Allow for possible exception to be thrown. |
110 | | - var result = await next.ConfigureAwait(false); |
111 | | - if (result) return true; |
112 | | - |
113 | | - // If result was false, then there's one less and we should try again. |
114 | | - goto retry; |
115 | | - } |
116 | | - } |
117 | | - |
118 | 7 | /// <summary> |
119 | | - /// Reads from multiple sources in a round-robin fashion. |
| 8 | + /// Creates a <see cref="MergingChannelReader{T}"/> |
| 9 | + /// that reads from multiple sources in a round-robin fashion. |
120 | 10 | /// </summary> |
121 | 11 | /// <typeparam name="T">The source type.</typeparam> |
122 | 12 | /// <param name="sources">The channels to read from.</param> |
123 | | - /// <returns> |
124 | | - /// A <see cref="ChannelReader{T}"/> that reads from all sources in a round-robin fashion. |
125 | | - /// </returns> |
126 | | - public static ChannelReader<T> Merge<T>(this IEnumerable<ChannelReader<T>> sources) |
127 | | - => new MergingChannelReader<T>(sources); |
| 13 | + public static MergingChannelReader<T> Merge<T>(this IEnumerable<ChannelReader<T>> sources) => new(sources); |
| 14 | + |
| 15 | + /// <summary> |
| 16 | + /// Merges the <paramref name="primary"/> with the <paramref name="secondary"/> |
| 17 | + /// as a <see cref="MergingChannelReader{T}"/> |
| 18 | + /// that reads from multiple sources in a round-robin fashion. |
| 19 | + /// </summary> |
| 20 | + /// <exception cref="ArgumentNullException"> |
| 21 | + /// If the <paramref name="primary"/> |
| 22 | + /// or <paramref name="secondary"/> sources are null. |
| 23 | + /// </exception> |
| 24 | + /// <inheritdoc cref="MergingChannelReader{T}.Merge(ChannelReader{T}, ChannelReader{T}[])"/>/> |
| 25 | + public static MergingChannelReader<T> Merge<T>( |
| 26 | + this ChannelReader<T> primary, |
| 27 | + ChannelReader<T> secondary, |
| 28 | + params ChannelReader<T>[] others) |
| 29 | + { |
| 30 | + if (primary is null) throw new ArgumentNullException(nameof(primary)); |
| 31 | + if (secondary is null) throw new ArgumentNullException(nameof(secondary)); |
| 32 | + Contract.EndContractBlock(); |
| 33 | + |
| 34 | + // Is this already a merging reader? Then recapture the sources so it flattens the hierarchy. |
| 35 | + if (primary is MergingChannelReader<T> mcr) |
| 36 | + return mcr.Merge(secondary, others); |
| 37 | + |
| 38 | + if(others is null || others.Length == 0) |
| 39 | + return new MergingChannelReader<T>(ImmutableArray.Create(primary, secondary)); |
| 40 | + |
| 41 | + var builder = ImmutableArray.CreateBuilder<ChannelReader<T>>(2 + others.Length); |
| 42 | + builder.Add(primary); |
| 43 | + builder.Add(secondary); |
| 44 | + builder.AddRange(others); |
| 45 | + return new MergingChannelReader<T>(builder.MoveToImmutable()); |
| 46 | + } |
128 | 47 | } |
0 commit comments