Skip to content

Commit 59a40b4

Browse files
committed
fix AsyncSeqSrc, add tests
1 parent 3a0ee73 commit 59a40b4

3 files changed

Lines changed: 167 additions & 67 deletions

File tree

src/FSharp.Control.AsyncSeq/AsyncSeq.fs

Lines changed: 65 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,13 @@ type IAsyncEnumerable<'T> =
2525
type AsyncSeq<'T> = IAsyncEnumerable<'T>
2626
// abstract GetEnumerator : unit -> IAsyncEnumerator<'T>
2727

28-
type AsyncSeqSrc<'a> = private { mutable tail : TaskCompletionSource<('a * AsyncSeqSrc<'a>) option> }
28+
type AsyncSeqSrc<'a> = private { mutable tail : AsyncSeqSrcNode<'a> }
29+
30+
and private AsyncSeqSrcNode<'a> =
31+
struct
32+
val tcs : TaskCompletionSource<('a * AsyncSeqSrcNode<'a>) option>
33+
new (tcs) = { tcs = tcs }
34+
end
2935

3036
[<AutoOpen>]
3137
module internal Utils =
@@ -93,7 +99,7 @@ module internal Utils =
9399

94100
let queue = new Queue<_>()
95101

96-
let rec loop () = async {
102+
let rec loop () = async {
97103
match queue.Count with
98104
| 0 -> do! tryReceive ()
99105
| _ -> do! trySendOrReceive ()
@@ -140,6 +146,7 @@ module internal Utils =
140146

141147
/// Creates an async computation that completes when a message is available in a mailbox.
142148
let take (mb:Mb<'a>) = mb.Take
149+
143150

144151

145152
type internal BoundedMbReq<'a> =
@@ -153,6 +160,14 @@ module internal Utils =
153160

154161
let queue = new Queue<_>()
155162

163+
let receive (a:'a, rep:AsyncReplyChannel<unit>) = async {
164+
queue.Enqueue a
165+
return rep.Reply () }
166+
167+
let send (rep:AsyncReplyChannel<'a>) = async {
168+
let a = queue.Dequeue ()
169+
return rep.Reply a }
170+
156171
let rec loop () = async {
157172
match queue.Count with
158173
| 0 -> do! tryReceive ()
@@ -165,19 +180,11 @@ module internal Utils =
165180
| Put (a,rep) -> Some (receive (a,rep))
166181
| _ -> None)
167182

168-
and receive (a:'a, rep:AsyncReplyChannel<unit>) = async {
169-
queue.Enqueue a
170-
rep.Reply () }
171-
172183
and trySend () =
173184
agent.Scan (function
174185
| Take rep -> Some (send rep)
175186
| _ -> None)
176187

177-
and send (rep:AsyncReplyChannel<'a>) = async {
178-
let a = queue.Dequeue ()
179-
return rep.Reply a }
180-
181188
and trySendOrReceive () = async {
182189
let! msg = agent.Receive ()
183190
match msg with
@@ -198,10 +205,10 @@ module internal Utils =
198205
/// Operations on bounded FIFO mailboxes.
199206
module BoundedMb =
200207

201-
/// Creates a new unbounded mailbox.
208+
/// Creates a new bounded mailbox.
202209
let create (capacity:int) = new BoundedMb<'a> (capacity)
203210

204-
/// Puts a message into a mailbox, no waiting.
211+
/// Puts a message into a mailbox, waiting if at capacity.
205212
let put (a:'a) (mb:BoundedMb<'a>) = mb.Put a
206213

207214
/// Creates an async computation that completes when a message is available in a mailbox.
@@ -1367,59 +1374,71 @@ module AsyncSeq =
13671374

13681375

13691376
module AsyncSeqSrcImpl =
1377+
1378+
let private createNode () =
1379+
new AsyncSeqSrcNode<_>(new TaskCompletionSource<_>())
13701380

13711381
let create () : AsyncSeqSrc<'a> =
1372-
{ tail = new TaskCompletionSource<_>() }
1373-
1374-
let put (a:'a) (s:AsyncSeqSrc<'a>) : unit =
1375-
let newTail = create ()
1382+
{ tail = createNode () }
1383+
1384+
let put (a:'a) (s:AsyncSeqSrc<'a>) =
1385+
let newTail = createNode ()
13761386
let tail = s.tail
1377-
s.tail <- newTail.tail
1378-
tail.SetResult(Some(a, newTail))
1387+
s.tail <- newTail
1388+
tail.tcs.SetResult(Some(a, newTail))
13791389

13801390
let close (s:AsyncSeqSrc<'a>) : unit =
1381-
s.tail.SetResult(None)
1391+
s.tail.tcs.SetResult(None)
13821392

13831393
let fail (ex:exn) (s:AsyncSeqSrc<'a>) : unit =
1384-
s.tail.SetException ex
1394+
s.tail.tcs.SetException(ex)
13851395

1386-
let rec toAsyncSeq (s:AsyncSeqSrc<'a>) : AsyncSeq<'a> =
1387-
let tail = s.tail
1396+
let rec private toAsyncSeqImpl (s:AsyncSeqSrcNode<'a>) : AsyncSeq<'a> =
13881397
asyncSeq {
1389-
let! next = tail.Task |> Async.AwaitTask
1398+
let! next = s.tcs.Task |> Async.AwaitTask
13901399
match next with
13911400
| None -> ()
13921401
| Some (a,tl) ->
13931402
yield a
1394-
yield! toAsyncSeq tl }
1403+
yield! toAsyncSeqImpl tl }
1404+
1405+
let toAsyncSeq (s:AsyncSeqSrc<'a>) : AsyncSeq<'a> =
1406+
toAsyncSeqImpl s.tail
1407+
13951408

13961409

13971410
type private Group<'k, 'a> = { key : 'k ; src : AsyncSeqSrc<'a> }
13981411

13991412
let groupByAsync (p:'a -> Async<'k>) (s:AsyncSeq<'a>) : AsyncSeq<'k * AsyncSeq<'a>> = asyncSeq {
14001413
let groups = Collections.Generic.Dictionary<'k, Group<'k, 'a>>()
1401-
let close g =
1402-
groups.Remove(g.key) |> ignore
1403-
AsyncSeqSrcImpl.close g.src
1414+
let close group =
1415+
groups.Remove(group.key) |> ignore
1416+
AsyncSeqSrcImpl.close group.src
1417+
let closeGroups () =
1418+
groups.Values |> Seq.toArray |> Array.iter close
14041419
use enum = s.GetEnumerator()
1405-
let rec go () = asyncSeq {
1406-
let! next = enum.MoveNext ()
1407-
match next with
1408-
| None ->
1409-
groups.Values |> Seq.toArray |> Array.iter close
1410-
| Some a ->
1411-
let! k = p a
1412-
let mutable g = Unchecked.defaultof<_>
1413-
if groups.TryGetValue(k, &g) then
1414-
AsyncSeqSrcImpl.put a g.src
1415-
yield! go ()
1416-
else
1417-
let src = AsyncSeqSrcImpl.create ()
1418-
AsyncSeqSrcImpl.put a src
1419-
let g = { key = k ; src = src }
1420-
groups.Add(k, g)
1421-
yield k,src |> AsyncSeqSrcImpl.toAsyncSeq
1422-
yield! go () }
1420+
let rec go () = asyncSeq {
1421+
try
1422+
let! next = enum.MoveNext ()
1423+
match next with
1424+
| None -> closeGroups ()
1425+
| Some a ->
1426+
let! key = p a
1427+
let mutable group = Unchecked.defaultof<_>
1428+
if groups.TryGetValue(key, &group) then
1429+
AsyncSeqSrcImpl.put a group.src
1430+
yield! go ()
1431+
else
1432+
let src = AsyncSeqSrcImpl.create ()
1433+
let subSeq = src |> AsyncSeqSrcImpl.toAsyncSeq
1434+
AsyncSeqSrcImpl.put a src
1435+
let group = { key = key ; src = src }
1436+
groups.Add(key, group)
1437+
yield key,subSeq
1438+
yield! go ()
1439+
with ex ->
1440+
closeGroups ()
1441+
raise ex }
14231442
yield! go () }
14241443

14251444
let groupBy (p:'a -> 'k) (s:AsyncSeq<'a>) : AsyncSeq<'k * AsyncSeq<'a>> =
@@ -1443,6 +1462,7 @@ module AsyncSeqSrc =
14431462
let put a s = AsyncSeq.AsyncSeqSrcImpl.put a s
14441463
let close s = AsyncSeq.AsyncSeqSrcImpl.close s
14451464
let toAsyncSeq s = AsyncSeq.AsyncSeqSrcImpl.toAsyncSeq s
1465+
let fail e s = AsyncSeq.AsyncSeqSrcImpl.fail e s
14461466

14471467
module Seq =
14481468

src/FSharp.Control.AsyncSeq/AsyncSeq.fsi

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,9 @@ module AsyncSeqSrc =
435435
/// Closes the async sequence source casuing any created async sequences to terminate.
436436
val close : src:AsyncSeqSrc<'T> -> unit
437437

438+
/// Causes async sequence created immediately before the call to raise an exception.
439+
val fail : exn:exn -> src:AsyncSeqSrc<'T> -> unit
440+
438441
/// Creates an async sequence which yields values as they are put into the source and terminates
439-
/// when the source is closed.
442+
/// when the source is closed. This sequence will yield items starting with the next put.
440443
val toAsyncSeq : src:AsyncSeqSrc<'T> -> AsyncSeq<'T>

tests/FSharp.Control.AsyncSeq.Tests/AsyncSeqTests.fs

Lines changed: 98 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,24 @@ let EQ (a:AsyncSeq<'a>) (b:AsyncSeq<'a>) =
3131
type Assert with
3232
/// Determines equality of two async sequences by convering them to lists, ignoring side-effects.
3333
static member AreEqual (expected:AsyncSeq<'a>, actual:AsyncSeq<'a>) =
34-
let exp = expected |> AsyncSeq.toList
35-
let act = actual |> AsyncSeq.toList
34+
Assert.AreEqual (expected, actual, 1000, exnEq=(fun _ _ -> true))
35+
/// Determines equality of two async sequences by convering them to lists, ignoring side-effects.
36+
static member AreEqual (expected:AsyncSeq<'a>, actual:AsyncSeq<'a>, timeout) =
37+
Assert.AreEqual (expected, actual, timeout=timeout, exnEq=(fun _ _ -> true))
38+
/// Determines equality of two async sequences by convering them to lists, ignoring side-effects.
39+
static member AreEqual (expected:AsyncSeq<'a>, actual:AsyncSeq<'a>, timeout, exnEq:exn -> exn -> bool) =
40+
let exp = expected |> AsyncSeq.toListAsync |> Async.Catch
41+
let exp = Async.RunSynchronously (exp, timeout)
42+
let act = actual |> AsyncSeq.toListAsync |> Async.Catch
43+
let act = Async.RunSynchronously(act, timeout)
3644
let message = sprintf "expected=%A actual=%A" exp act
37-
Assert.True((exp = act), message)
45+
match exp,act with
46+
| Choice1Of2 exp, Choice1Of2 act ->
47+
Assert.True((exp = act), message)
48+
| Choice2Of2 exp, Choice2Of2 act ->
49+
Assert.True((exnEq exp act), message)
50+
| _ ->
51+
Assert.Fail(message)
3852

3953

4054

@@ -1077,43 +1091,106 @@ let ``AsyncSeq.take should work``() =
10771091
let ls = ss |> AsyncSeq.toList
10781092
()
10791093

1094+
[<Test>]
1095+
let ``AsyncSeq.mapParallelAsync should maintain order`` () =
1096+
let ls = List.init 500 id
1097+
let expected =
1098+
ls
1099+
|> AsyncSeq.ofSeq
1100+
|> AsyncSeq.mapAsync (async.Return)
1101+
let actual =
1102+
ls
1103+
|> AsyncSeq.ofSeq
1104+
|> AsyncSeq.mapAsyncParallel (async.Return)
1105+
Assert.AreEqual(expected, actual)
1106+
1107+
//[<Test>]
1108+
let ``AsyncSeq.mapParallelAsync should be parallel`` () =
1109+
let parallelism = 3
1110+
let barrier = new Threading.Barrier(parallelism)
1111+
let s = AsyncSeq.init (int64 parallelism) int
1112+
let expected =
1113+
s |> AsyncSeq.map id
1114+
let actual =
1115+
s
1116+
|> AsyncSeq.mapAsyncParallel (fun i -> async { barrier.SignalAndWait () ; return i })
1117+
Assert.AreEqual(expected, actual, timeout=200)
1118+
1119+
//[<Test>]
1120+
let ``AsyncSeq.mapParallelAsyncBounded should maintain order`` () =
1121+
let ls = List.init 500 id
1122+
let expected =
1123+
ls
1124+
|> AsyncSeq.ofSeq
1125+
|> AsyncSeq.mapAsync (async.Return)
1126+
let actual =
1127+
ls
1128+
|> AsyncSeq.ofSeq
1129+
|> AsyncSeq.mapAsyncParallelBounded 10 (async.Return)
1130+
Assert.AreEqual(expected, actual, timeout=200)
1131+
1132+
10801133

10811134
[<Test>]
1082-
let ``AsyncSeqSource.create should create empty sequence`` () =
1135+
let ``AsyncSeqSrc.create should create empty sequence`` () =
10831136
let src = AsyncSeqSrc.create ()
10841137
let s = src |> AsyncSeqSrc.toAsyncSeq
10851138
src |> AsyncSeqSrc.close
10861139
let expected = AsyncSeq.empty
10871140
Assert.True(EQ expected s)
10881141

1089-
10901142
[<Test>]
1091-
let ``AsyncSeqSource.put should yield`` () =
1143+
let ``AsyncSeqSrc.put should yield when tapped before put`` () =
10921144
let item = 1
10931145
let src = AsyncSeqSrc.create ()
1094-
let actual = src |> AsyncSeqSrc.toAsyncSeq
1146+
let actual = src |> AsyncSeqSrc.toAsyncSeq
10951147
src |> AsyncSeqSrc.put item
10961148
src |> AsyncSeqSrc.close
10971149
let expected = AsyncSeq.singleton item
10981150
Assert.AreEqual (expected, actual)
10991151

11001152
[<Test>]
1101-
let ``AsyncSeqSource.put should yield after async sequence is created`` () =
1102-
let item1 = 1
1103-
let item2 = 2
1104-
let src = AsyncSeqSrc.create ()
1105-
src |> AsyncSeqSrc.put item1
1106-
let actual = src |> AsyncSeqSrc.toAsyncSeq
1107-
src |> AsyncSeqSrc.put item2
1108-
src |> AsyncSeqSrc.close
1109-
let expected = AsyncSeq.ofSeq [item2]
1153+
let ``AsyncSeqSrc.put should yield when tapped after put`` () =
1154+
let item = 1
1155+
let src = AsyncSeqSrc.create ()
1156+
src |> AsyncSeqSrc.put item
1157+
let actual = src |> AsyncSeqSrc.toAsyncSeq
1158+
src |> AsyncSeqSrc.close
1159+
let expected = AsyncSeq.empty
1160+
Assert.AreEqual (expected, actual)
1161+
1162+
[<Test>]
1163+
let ``AsyncSeqSrc.fail should throw`` () =
1164+
let item = 1
1165+
let src = AsyncSeqSrc.create ()
1166+
let actual = src |> AsyncSeqSrc.toAsyncSeq
1167+
src |> AsyncSeqSrc.fail (exn("test"))
1168+
let expected = asyncSeq { raise (exn("test")) }
11101169
Assert.AreEqual (expected, actual)
11111170

11121171

11131172
[<Test>]
11141173
let ``AsyncSeq.groupBy should work``() =
1115-
let ls = List.init 4 id
1116-
let p i = i % 2
1117-
let expected = ls |> Seq.groupBy p |> Seq.map (snd >> Seq.toList) |> Seq.toList |> AsyncSeq.ofSeq
1118-
let actual = ls |> AsyncSeq.ofSeq |> AsyncSeq.groupBy p |> AsyncSeq.mapAsyncParallel (snd >> AsyncSeq.toListAsync)
1119-
Assert.True(EQ expected actual)
1174+
let ls = List.init 100 id
1175+
let p i = i % 3
1176+
let expected =
1177+
ls
1178+
|> Seq.groupBy p
1179+
|> Seq.map (snd >> Seq.toList)
1180+
|> Seq.toList
1181+
|> AsyncSeq.ofSeq
1182+
let actual =
1183+
ls
1184+
|> AsyncSeq.ofSeq
1185+
|> AsyncSeq.groupBy p
1186+
|> AsyncSeq.mapAsyncParallel (snd >> AsyncSeq.toListAsync)
1187+
Assert.AreEqual(expected, actual)
1188+
1189+
[<Test>]
1190+
let ``AsyncSeq.groupBy should propagate exception and terminate all groups``() =
1191+
let expected = asyncSeq { raise (exn("test")) }
1192+
let actual =
1193+
asyncSeq { raise (exn("test")) }
1194+
|> AsyncSeq.groupBy (fun i -> i % 3)
1195+
|> AsyncSeq.mapAsyncParallel (snd >> AsyncSeq.toListAsync)
1196+
Assert.AreEqual(expected, actual)

0 commit comments

Comments
 (0)