Skip to content

Commit 2515747

Browse files
author
Oren (electricessence)
committed
Added ForceBatch() to BatchChannelReader.
Extracted buffering and batch readers to extend functionality.
1 parent 0343fe1 commit 2515747

6 files changed

Lines changed: 260 additions & 130 deletions

File tree

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Diagnostics.Contracts;
4+
using System.Threading.Channels;
5+
6+
namespace Open.ChannelExtensions
7+
{
8+
/// <summary>
9+
/// A ChannelReader that batches results.
10+
/// Use the .Batch extension instead of constructing this directly.
11+
/// </summary>
12+
public class BatchingChannelReader<T> : BufferingChannelReader<T, List<T>>
13+
{
14+
private readonly int _batchSize;
15+
private List<T>? _current;
16+
17+
/// <summary>
18+
/// Constructs a BatchingChannelReader.
19+
/// Use the .Batch extension instead of constructing this directly.
20+
/// </summary>
21+
public BatchingChannelReader(ChannelReader<T> source, int batchSize, bool singleReader, bool syncCont = false) : base(source, singleReader, syncCont)
22+
{
23+
if (batchSize < 1) throw new ArgumentOutOfRangeException(nameof(batchSize), batchSize, "Must be at least 1.");
24+
Contract.EndContractBlock();
25+
26+
_batchSize = batchSize;
27+
_current = source.Completion.IsCompleted ? null : new List<T>(batchSize);
28+
}
29+
30+
/// <summary>
31+
/// If no full batch is waiting, will force buffering any batch that has at least one item.
32+
/// Returns true if anything was added to the buffer.
33+
/// </summary>
34+
public bool ForceBatch()
35+
{
36+
if (Buffer == null || Buffer.Reader.Completion.IsCompleted) return false;
37+
if (TryPipeItems()) return true;
38+
39+
lock (Buffer)
40+
{
41+
if (Buffer.Reader.Completion.IsCompleted) return false;
42+
if (TryPipeItems()) return true;
43+
var c = _current;
44+
if (c == null || c.Count == 0 || Buffer.Reader.Completion.IsCompleted)
45+
return false;
46+
c.TrimExcess();
47+
_current = new List<T>(_batchSize);
48+
Buffer.Writer.TryWrite(c);
49+
}
50+
51+
return true;
52+
}
53+
54+
/// <inheritdoc />
55+
protected override bool TryPipeItems()
56+
{
57+
if (_current == null || Buffer == null || Buffer.Reader.Completion.IsCompleted)
58+
return false;
59+
60+
lock (Buffer)
61+
{
62+
var c = _current;
63+
if (c == null || Buffer.Reader.Completion.IsCompleted)
64+
return false;
65+
66+
var source = Source;
67+
if (source == null || source.Completion.IsCompleted)
68+
{
69+
// All finished, release the last batch to the buffer.
70+
c.TrimExcess();
71+
_current = null;
72+
if (c.Count == 0)
73+
return false;
74+
75+
Buffer.Writer.TryWrite(c);
76+
return true;
77+
}
78+
79+
while (source.TryRead(out T item))
80+
{
81+
if (c.Count == _batchSize)
82+
{
83+
_current = new List<T>(_batchSize) { item };
84+
Buffer.Writer.TryWrite(c);
85+
return true;
86+
}
87+
88+
c.Add(item);
89+
}
90+
91+
return false;
92+
}
93+
}
94+
}
95+
}

Open.ChannelExtensions/BufferingChannelReader.cs

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,28 @@
66

77
namespace Open.ChannelExtensions
88
{
9-
abstract class BufferingChannelReader<TIn, TOut> : ChannelReader<TOut>
9+
/// <summary>
10+
/// Base class for buffering results of a source ChannelReader.
11+
/// </summary>
12+
/// <typeparam name="TIn">The input type of the buffer.</typeparam>
13+
/// <typeparam name="TOut">The output type of the buffer.</typeparam>
14+
public abstract class BufferingChannelReader<TIn, TOut> : ChannelReader<TOut>
1015
{
11-
protected ChannelReader<TIn>? Source;
12-
protected readonly Channel<TOut>? Buffer;
13-
public BufferingChannelReader(ChannelReader<TIn> source, bool singleReader, bool syncCont = false)
16+
/// <summary>
17+
/// The source of the buffer.
18+
/// </summary>
19+
protected ChannelReader<TIn>? Source { get; set; }
20+
21+
/// <summary>
22+
/// The internal channel used for buffering.
23+
/// </summary>
24+
protected Channel<TOut>? Buffer { get; }
25+
26+
27+
/// <summary>
28+
/// Base constructor for a BufferingChannelReader.
29+
/// </summary>
30+
protected BufferingChannelReader(ChannelReader<TIn> source, bool singleReader, bool syncCont = false)
1431
{
1532
Source = source ?? throw new ArgumentNullException(nameof(source));
1633
Contract.EndContractBlock();
@@ -37,10 +54,16 @@ public BufferingChannelReader(ChannelReader<TIn> source, bool singleReader, bool
3754
}
3855
}
3956

57+
/// <inheritdoc />
4058
public override Task Completion => Buffer?.Reader.Completion ?? Task.CompletedTask;
4159

60+
/// <summary>
61+
/// The method that triggers adding entries to the buffer.
62+
/// </summary>
63+
/// <returns></returns>
4264
protected abstract bool TryPipeItems();
4365

66+
/// <inheritdoc />
4467
public override bool TryRead(out TOut item)
4568
{
4669
if (Buffer != null) do
@@ -54,6 +77,7 @@ public override bool TryRead(out TOut item)
5477
return false;
5578
}
5679

80+
/// <inheritdoc />
5781
public override ValueTask<bool> WaitToReadAsync(CancellationToken cancellationToken = default)
5882
{
5983
if (Buffer == null || Buffer.Reader.Completion.IsCompleted)

0 commit comments

Comments
 (0)