@@ -25,7 +25,13 @@ type IAsyncEnumerable<'T> =
2525type AsyncSeq < 'T > = IAsyncEnumerable< 'T>
2626// abstract GetEnumerator : unit -> IAsyncEnumerator<'T>
2727
28- type AsyncSeqSrc < 'a > = private { mutable tail : TaskCompletionSource <( 'a * AsyncSeqSrc < 'a >) option > }
28+ type AsyncSeqSrc < 'a > = private { mutable tail : AsyncSeqSrcNode < 'a > }
29+
30+ and private AsyncSeqSrcNode < 'a > =
31+ struct
32+ val tcs : TaskCompletionSource <( 'a * AsyncSeqSrcNode < 'a >) option >
33+ new ( tcs ) = { tcs = tcs }
34+ end
2935
3036[<AutoOpen>]
3137module internal Utils =
@@ -93,7 +99,7 @@ module internal Utils =
9399
94100 let queue = new Queue<_>()
95101
96- let rec loop () = async {
102+ let rec loop () = async {
97103 match queue.Count with
98104 | 0 -> do ! tryReceive ()
99105 | _ -> do ! trySendOrReceive ()
@@ -140,6 +146,7 @@ module internal Utils =
140146
141147 /// Creates an async computation that completes when a message is available in a mailbox.
142148 let take ( mb : Mb < 'a >) = mb.Take
149+
143150
144151
145152 type internal BoundedMbReq < 'a > =
@@ -153,6 +160,14 @@ module internal Utils =
153160
154161 let queue = new Queue<_>()
155162
163+ let receive ( a : 'a , rep : AsyncReplyChannel < unit >) = async {
164+ queue.Enqueue a
165+ return rep.Reply () }
166+
167+ let send ( rep : AsyncReplyChannel < 'a >) = async {
168+ let a = queue.Dequeue ()
169+ return rep.Reply a }
170+
156171 let rec loop () = async {
157172 match queue.Count with
158173 | 0 -> do ! tryReceive ()
@@ -165,19 +180,11 @@ module internal Utils =
165180 | Put ( a, rep) -> Some ( receive ( a, rep))
166181 | _ -> None)
167182
168- and receive ( a : 'a , rep : AsyncReplyChannel < unit >) = async {
169- queue.Enqueue a
170- rep.Reply () }
171-
172183 and trySend () =
173184 agent.Scan ( function
174185 | Take rep -> Some ( send rep)
175186 | _ -> None)
176187
177- and send ( rep : AsyncReplyChannel < 'a >) = async {
178- let a = queue.Dequeue ()
179- return rep.Reply a }
180-
181188 and trySendOrReceive () = async {
182189 let! msg = agent.Receive ()
183190 match msg with
@@ -198,10 +205,10 @@ module internal Utils =
198205 /// Operations on bounded FIFO mailboxes.
199206 module BoundedMb =
200207
201- /// Creates a new unbounded mailbox.
208+ /// Creates a new bounded mailbox.
202209 let create ( capacity : int ) = new BoundedMb< 'a> ( capacity)
203210
204- /// Puts a message into a mailbox, no waiting.
211+ /// Puts a message into a mailbox, waiting if at capacity .
205212 let put ( a : 'a ) ( mb : BoundedMb < 'a >) = mb.Put a
206213
207214 /// Creates an async computation that completes when a message is available in a mailbox.
@@ -1367,59 +1374,71 @@ module AsyncSeq =
13671374
13681375
13691376 module AsyncSeqSrcImpl =
1377+
1378+ let private createNode () =
1379+ new AsyncSeqSrcNode<_>( new TaskCompletionSource<_>())
13701380
13711381 let create () : AsyncSeqSrc < 'a > =
1372- { tail = new TaskCompletionSource <_> () }
1373-
1374- let put ( a : 'a ) ( s : AsyncSeqSrc < 'a >) : unit =
1375- let newTail = create ()
1382+ { tail = createNode () }
1383+
1384+ let put ( a : 'a ) ( s : AsyncSeqSrc < 'a >) =
1385+ let newTail = createNode ()
13761386 let tail = s.tail
1377- s.tail <- newTail.tail
1378- tail.SetResult( Some( a, newTail))
1387+ s.tail <- newTail
1388+ tail.tcs. SetResult( Some( a, newTail))
13791389
13801390 let close ( s : AsyncSeqSrc < 'a >) : unit =
1381- s.tail.SetResult( None)
1391+ s.tail.tcs. SetResult( None)
13821392
13831393 let fail ( ex : exn ) ( s : AsyncSeqSrc < 'a >) : unit =
1384- s.tail.SetException ex
1394+ s.tail.tcs. SetException( ex )
13851395
1386- let rec toAsyncSeq ( s : AsyncSeqSrc < 'a >) : AsyncSeq < 'a > =
1387- let tail = s.tail
1396+ let rec private toAsyncSeqImpl ( s : AsyncSeqSrcNode < 'a >) : AsyncSeq < 'a > =
13881397 asyncSeq {
1389- let! next = tail .Task |> Async.AwaitTask
1398+ let! next = s.tcs .Task |> Async.AwaitTask
13901399 match next with
13911400 | None -> ()
13921401 | Some ( a, tl) ->
13931402 yield a
1394- yield ! toAsyncSeq tl }
1403+ yield ! toAsyncSeqImpl tl }
1404+
1405+ let toAsyncSeq ( s : AsyncSeqSrc < 'a >) : AsyncSeq < 'a > =
1406+ toAsyncSeqImpl s.tail
1407+
13951408
13961409
13971410 type private Group < 'k , 'a > = { key : 'k ; src : AsyncSeqSrc < 'a > }
13981411
13991412 let groupByAsync ( p : 'a -> Async < 'k >) ( s : AsyncSeq < 'a >) : AsyncSeq < 'k * AsyncSeq < 'a >> = asyncSeq {
14001413 let groups = Collections.Generic.Dictionary< 'k, Group< 'k, 'a>>()
1401- let close g =
1402- groups.Remove( g.key) |> ignore
1403- AsyncSeqSrcImpl.close g.src
1414+ let close group =
1415+ groups.Remove( group.key) |> ignore
1416+ AsyncSeqSrcImpl.close group.src
1417+ let closeGroups () =
1418+ groups.Values |> Seq.toArray |> Array.iter close
14041419 use enum = s.GetEnumerator()
1405- let rec go () = asyncSeq {
1406- let! next = enum .MoveNext ()
1407- match next with
1408- | None ->
1409- groups.Values |> Seq.toArray |> Array.iter close
1410- | Some a ->
1411- let! k = p a
1412- let mutable g = Unchecked.defaultof<_>
1413- if groups.TryGetValue( k, & g) then
1414- AsyncSeqSrcImpl.put a g.src
1415- yield ! go ()
1416- else
1417- let src = AsyncSeqSrcImpl.create ()
1418- AsyncSeqSrcImpl.put a src
1419- let g = { key = k ; src = src }
1420- groups.Add( k, g)
1421- yield k, src |> AsyncSeqSrcImpl.toAsyncSeq
1422- yield ! go () }
1420+ let rec go () = asyncSeq {
1421+ try
1422+ let! next = enum .MoveNext ()
1423+ match next with
1424+ | None -> closeGroups ()
1425+ | Some a ->
1426+ let! key = p a
1427+ let mutable group = Unchecked.defaultof<_>
1428+ if groups.TryGetValue( key, & group) then
1429+ AsyncSeqSrcImpl.put a group.src
1430+ yield ! go ()
1431+ else
1432+ let src = AsyncSeqSrcImpl.create ()
1433+ let subSeq = src |> AsyncSeqSrcImpl.toAsyncSeq
1434+ AsyncSeqSrcImpl.put a src
1435+ let group = { key = key ; src = src }
1436+ groups.Add( key, group)
1437+ yield key, subSeq
1438+ yield ! go ()
1439+ with ex ->
1440+ closeGroups ()
1441+ raise ex }
14231442 yield ! go () }
14241443
14251444 let groupBy ( p : 'a -> 'k ) ( s : AsyncSeq < 'a >) : AsyncSeq < 'k * AsyncSeq < 'a >> =
@@ -1443,6 +1462,7 @@ module AsyncSeqSrc =
14431462 let put a s = AsyncSeq.AsyncSeqSrcImpl.put a s
14441463 let close s = AsyncSeq.AsyncSeqSrcImpl.close s
14451464 let toAsyncSeq s = AsyncSeq.AsyncSeqSrcImpl.toAsyncSeq s
1465+ let fail e s = AsyncSeq.AsyncSeqSrcImpl.fail e s
14461466
14471467module Seq =
14481468
0 commit comments