Skip to content

Commit 95c82e9

Browse files
author
Oren (electricessence)
committed
Added Join and OfType filters.
1 parent 06d336a commit 95c82e9

4 files changed

Lines changed: 171 additions & 25 deletions

File tree

Open.ChannelExtensions/Extensions.Filter.cs

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,39 +8,36 @@ namespace Open.ChannelExtensions
88
{
99
public static partial class Extensions
1010
{
11-
class TransformingChannelReader<TIn, TOut> : ChannelReader<TOut>
11+
class FilteringChannelReader<T> : ChannelReader<T>
1212
{
13-
public TransformingChannelReader(ChannelReader<TIn> source, Func<TIn, TOut> transform)
13+
public FilteringChannelReader(ChannelReader<T> source, Func<T, bool> predicate)
1414
{
1515
_source = source ?? throw new ArgumentNullException(nameof(source));
16-
_transform = transform ?? throw new ArgumentNullException(nameof(transform));
16+
_predicate = predicate ?? throw new ArgumentNullException(nameof(predicate));
1717
Contract.EndContractBlock();
1818
}
1919

20-
private readonly ChannelReader<TIn> _source;
21-
private readonly Func<TIn, TOut> _transform;
20+
private readonly ChannelReader<T> _source;
21+
private readonly Func<T, bool> _predicate;
2222
public override Task Completion => _source.Completion;
2323

24-
public override bool TryRead(out TOut item)
24+
public override bool TryRead(out T item)
2525
{
26-
if (_source.TryRead(out var e))
26+
while (_source.TryRead(out item))
2727
{
28-
item = _transform(e);
29-
return true;
28+
if (_predicate(item))
29+
return true;
3030
}
3131

3232
item = default;
3333
return false;
3434
}
3535

36-
public override async ValueTask<TOut> ReadAsync(CancellationToken cancellationToken = default)
37-
=> _transform(await _source.ReadAsync(cancellationToken));
38-
3936
public override ValueTask<bool> WaitToReadAsync(CancellationToken cancellationToken = default)
4037
=> _source.WaitToReadAsync(cancellationToken);
4138
}
4239

43-
public static ChannelReader<TOut> Transform<TIn, TOut>(this ChannelReader<TIn> source, Func<TIn, TOut> transform)
44-
=> new TransformingChannelReader<TIn, TOut>(source, transform);
40+
public static ChannelReader<T> Filter<T>(this ChannelReader<T> source, Func<T, bool> predicate)
41+
=> new FilteringChannelReader<T>(source, predicate);
4542
}
4643
}
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Diagnostics.Contracts;
4+
using System.Threading;
5+
using System.Threading.Channels;
6+
using System.Threading.Tasks;
7+
8+
namespace Open.ChannelExtensions
9+
{
10+
public static partial class Extensions
11+
{
12+
class JoiningChannelReader<TList, T> : ChannelReader<T>
13+
where TList : IEnumerable<T>
14+
{
15+
readonly Channel<T> _buffer = Channel.CreateUnbounded<T>(new UnboundedChannelOptions { AllowSynchronousContinuations = true });
16+
public JoiningChannelReader(ChannelReader<TList> source)
17+
{
18+
_source = source ?? throw new ArgumentNullException(nameof(source));
19+
Contract.EndContractBlock();
20+
}
21+
22+
private readonly ChannelReader<TList> _source;
23+
public override Task Completion
24+
=> _source.Completion.ContinueWith(t =>
25+
{
26+
// Need to be sure writing is done before we continue...
27+
lock (_buffer)
28+
{
29+
_buffer.Writer.Complete(t.Exception);
30+
}
31+
return _buffer.Reader.Completion;
32+
})
33+
.Unwrap();
34+
35+
bool TryPipeItems()
36+
{
37+
lock (_buffer)
38+
{
39+
if (!_source.TryRead(out TList batch))
40+
return false;
41+
42+
foreach (var i in batch)
43+
{
44+
// Assume this will always be true for our internal unbound channel.
45+
_buffer.Writer.TryWrite(i);
46+
}
47+
48+
return true;
49+
}
50+
}
51+
52+
public override bool TryRead(out T item)
53+
{
54+
do
55+
{
56+
if (_buffer.Reader.TryRead(out item))
57+
return true;
58+
}
59+
while (TryPipeItems());
60+
61+
item = default;
62+
return false;
63+
}
64+
65+
public override ValueTask<bool> WaitToReadAsync(CancellationToken cancellationToken = default)
66+
{
67+
if (cancellationToken.IsCancellationRequested)
68+
return new ValueTask<bool>(Task.FromCanceled<bool>(cancellationToken));
69+
70+
var b = _buffer.Reader.WaitToReadAsync(cancellationToken);
71+
if (b.IsCompletedSuccessfully || _source.Completion.IsCompleted)
72+
return b;
73+
74+
var s = _source.WaitToReadAsync(cancellationToken);
75+
if (s.IsCompletedSuccessfully && !s.Result)
76+
return b;
77+
78+
return WaitCore();
79+
80+
async ValueTask<bool> WaitCore()
81+
{
82+
cancellationToken.ThrowIfCancellationRequested();
83+
84+
// Not sure if there's a better way to 'WhenAny' with a ValueTask yet.
85+
var bt = b.AsTask();
86+
var st = s.AsTask();
87+
var first = await Task.WhenAny(bt, st);
88+
// Either one? Ok go.
89+
if (first.Result) return true;
90+
// Buffer returned false? We're done.
91+
if (first == bt) return false;
92+
// Second return false? Wait for buffer.
93+
return await bt;
94+
}
95+
}
96+
}
97+
98+
public static ChannelReader<T> Join<TList, T>(this ChannelReader<TList> source)
99+
where TList : IEnumerable<T>
100+
=> new JoiningChannelReader<TList, T>(source);
101+
}
102+
}

Open.ChannelExtensions/Extensions.Transform.cs

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,36 +8,39 @@ namespace Open.ChannelExtensions
88
{
99
public static partial class Extensions
1010
{
11-
class FilteringChannelReader<T> : ChannelReader<T>
11+
class TransformingChannelReader<TIn, TOut> : ChannelReader<TOut>
1212
{
13-
public FilteringChannelReader(ChannelReader<T> source, Func<T, bool> predicate)
13+
public TransformingChannelReader(ChannelReader<TIn> source, Func<TIn, TOut> transform)
1414
{
1515
_source = source ?? throw new ArgumentNullException(nameof(source));
16-
_predicate = predicate ?? throw new ArgumentNullException(nameof(predicate));
16+
_transform = transform ?? throw new ArgumentNullException(nameof(transform));
1717
Contract.EndContractBlock();
1818
}
1919

20-
private readonly ChannelReader<T> _source;
21-
private readonly Func<T, bool> _predicate;
20+
private readonly ChannelReader<TIn> _source;
21+
private readonly Func<TIn, TOut> _transform;
2222
public override Task Completion => _source.Completion;
2323

24-
public override bool TryRead(out T item)
24+
public override bool TryRead(out TOut item)
2525
{
26-
while (_source.TryRead(out item))
26+
if (_source.TryRead(out var e))
2727
{
28-
if (_predicate(item))
29-
return true;
28+
item = _transform(e);
29+
return true;
3030
}
3131

3232
item = default;
3333
return false;
3434
}
3535

36+
public override async ValueTask<TOut> ReadAsync(CancellationToken cancellationToken = default)
37+
=> _transform(await _source.ReadAsync(cancellationToken));
38+
3639
public override ValueTask<bool> WaitToReadAsync(CancellationToken cancellationToken = default)
3740
=> _source.WaitToReadAsync(cancellationToken);
3841
}
3942

40-
public static ChannelReader<T> Filter<T>(this ChannelReader<T> source, Func<T, bool> predicate)
41-
=> new FilteringChannelReader<T>(source, predicate);
43+
public static ChannelReader<TOut> Transform<TIn, TOut>(this ChannelReader<TIn> source, Func<TIn, TOut> transform)
44+
=> new TransformingChannelReader<TIn, TOut>(source, transform);
4245
}
4346
}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
using System;
2+
using System.Diagnostics.Contracts;
3+
using System.Threading;
4+
using System.Threading.Channels;
5+
using System.Threading.Tasks;
6+
7+
namespace Open.ChannelExtensions
8+
{
9+
public static partial class Extensions
10+
{
11+
class TypeFilteringChannelReader<TSource, T> : ChannelReader<T>
12+
{
13+
public TypeFilteringChannelReader(ChannelReader<TSource> source)
14+
{
15+
_source = source ?? throw new ArgumentNullException(nameof(source));
16+
Contract.EndContractBlock();
17+
}
18+
19+
private readonly ChannelReader<TSource> _source;
20+
public override Task Completion => _source.Completion;
21+
22+
public override bool TryRead(out T item)
23+
{
24+
while (_source.TryRead(out TSource s))
25+
{
26+
if(s is T i)
27+
{
28+
item = i;
29+
return true;
30+
}
31+
}
32+
33+
item = default;
34+
return false;
35+
}
36+
37+
public override ValueTask<bool> WaitToReadAsync(CancellationToken cancellationToken = default)
38+
=> _source.WaitToReadAsync(cancellationToken);
39+
}
40+
41+
public static ChannelReader<T> OfType<TSource, T>(this ChannelReader<TSource> source)
42+
=> new TypeFilteringChannelReader<TSource, T>(source);
43+
}
44+
}

0 commit comments

Comments
 (0)