Skip to content

Commit 6dd2f4d

Browse files
committed
groupBy impl
1 parent cf4c398 commit 6dd2f4d

2 files changed

Lines changed: 271 additions & 0 deletions

File tree

src/FSharp.Control.AsyncSeq/AsyncSeq.fs

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ type IAsyncEnumerable<'T> =
2525
type AsyncSeq<'T> = IAsyncEnumerable<'T>
2626
// abstract GetEnumerator : unit -> IAsyncEnumerator<'T>
2727

28+
type AsyncSeqSrc<'a> = private { mutable tl : TaskCompletionSource<('a * AsyncSeqSrc<'a>) option> }
29+
2830
[<AutoOpen>]
2931
module internal Utils =
3032
module internal Choice =
@@ -79,6 +81,138 @@ module internal Utils =
7981
elif i = 1 then return (Choice2Of2 (b.Result, a))
8082
else return! failwith (sprintf "unreachable, i = %d" i) }
8183

84+
85+
type internal MbReq<'a> =
86+
| Put of 'a
87+
| Take of AsyncReplyChannel<'a>
88+
89+
/// An unbounded FIFO mailbox.
90+
type Mb<'a> internal () =
91+
92+
let agent = MailboxProcessor.Start <| fun agent ->
93+
94+
let queue = new Queue<_>()
95+
96+
let rec loop () = async {
97+
match queue.Count with
98+
| 0 -> do! tryReceive ()
99+
| _ -> do! trySendOrReceive ()
100+
return! loop () }
101+
102+
and tryReceive () =
103+
agent.Scan (function
104+
| Put (a) -> Some (receive(a))
105+
| _ -> None)
106+
107+
and receive (a:'a) = async {
108+
return queue.Enqueue a }
109+
110+
and send (rep:AsyncReplyChannel<'a>) = async {
111+
let a = queue.Dequeue ()
112+
return rep.Reply a }
113+
114+
and trySendOrReceive () = async {
115+
let! msg = agent.Receive ()
116+
match msg with
117+
| Put a -> return! receive a
118+
| Take rep -> return! send rep }
119+
120+
loop ()
121+
122+
member __.Put (a:'a) =
123+
agent.Post (Put a)
124+
125+
member __.Take =
126+
agent.PostAndAsyncReply (fun ch -> Take ch)
127+
128+
interface IDisposable with
129+
member __.Dispose () = (agent :> IDisposable).Dispose()
130+
131+
132+
/// Operations on unbounded FIFO mailboxes.
133+
module Mb =
134+
135+
/// Creates a new unbounded mailbox.
136+
let create () = new Mb<'a> ()
137+
138+
/// Puts a message into a mailbox, no waiting.
139+
let put (a:'a) (mb:Mb<'a>) = mb.Put a
140+
141+
/// Creates an async computation that completes when a message is available in a mailbox.
142+
let take (mb:Mb<'a>) = mb.Take
143+
144+
145+
type internal BoundedMbReq<'a> =
146+
| Put of 'a * AsyncReplyChannel<unit>
147+
| Take of AsyncReplyChannel<'a>
148+
149+
type BoundedMb<'a> internal (capacity:int) =
150+
do if capacity <= 0 then invalidArg "capacity" "must be greater than 0"
151+
152+
let agent = MailboxProcessor.Start <| fun agent ->
153+
154+
let queue = new Queue<_>()
155+
156+
let rec loop () = async {
157+
match queue.Count with
158+
| 0 -> do! tryReceive ()
159+
| n when n = capacity -> do! trySend ()
160+
| _ -> do! trySendOrReceive ()
161+
return! loop () }
162+
163+
and tryReceive () =
164+
agent.Scan (function
165+
| Put (a,rep) -> Some (receive (a,rep))
166+
| _ -> None)
167+
168+
and receive (a:'a, rep:AsyncReplyChannel<unit>) = async {
169+
queue.Enqueue a
170+
rep.Reply () }
171+
172+
and trySend () =
173+
agent.Scan (function
174+
| Take rep -> Some (send rep)
175+
| _ -> None)
176+
177+
and send (rep:AsyncReplyChannel<'a>) = async {
178+
let a = queue.Dequeue ()
179+
return rep.Reply a }
180+
181+
and trySendOrReceive () = async {
182+
let! msg = agent.Receive ()
183+
match msg with
184+
| Put (a,rep) -> return! receive (a,rep)
185+
| Take rep -> return! send rep }
186+
187+
loop ()
188+
189+
member __.Put (a:'a) =
190+
agent.PostAndAsyncReply (fun ch -> Put (a, ch))
191+
192+
member __.Take =
193+
agent.PostAndAsyncReply (fun ch -> Take ch)
194+
195+
interface IDisposable with
196+
member __.Dispose () = (agent :> IDisposable).Dispose()
197+
198+
/// Operations on bounded FIFO mailboxes.
199+
module BoundedMb =
200+
201+
/// Creates a new unbounded mailbox.
202+
let create (capacity:int) = new BoundedMb<'a> (capacity)
203+
204+
/// Puts a message into a mailbox, no waiting.
205+
let put (a:'a) (mb:BoundedMb<'a>) = mb.Put a
206+
207+
/// Creates an async computation that completes when a message is available in a mailbox.
208+
let take (mb:BoundedMb<'a>) = mb.Take
209+
210+
211+
212+
213+
214+
215+
82216
/// Module with helper functions for working with asynchronous sequences
83217
module AsyncSeq =
84218

@@ -549,6 +683,38 @@ module AsyncSeq =
549683
i := i.Value + 1L
550684
yield v }
551685

686+
let mapAsyncParallel (f:'a -> Async<'b>) (s:AsyncSeq<'a>) = asyncSeq {
687+
use mb = Mb.create ()
688+
do! s |> iterAsync (fun a -> async {
689+
let! b = Async.StartChild (f a)
690+
mb |> Mb.put (Some b) })
691+
mb.Put None
692+
let rec loop () = asyncSeq {
693+
let! b = Mb.take mb
694+
match b with
695+
| None -> ()
696+
| Some b ->
697+
let! b = b
698+
yield b
699+
yield! loop () }
700+
yield! loop () }
701+
702+
let mapAsyncParallelBounded (parallelism:int) (f:'a -> Async<'b>) (s:AsyncSeq<'a>) = asyncSeq {
703+
use mb = BoundedMb.create (parallelism)
704+
do! s |> iterAsync (fun a -> async {
705+
let! b = Async.StartChild (f a)
706+
do! mb |> BoundedMb.put (Some b) })
707+
do! mb |> BoundedMb.put None
708+
let rec loop () = asyncSeq {
709+
let! b = BoundedMb.take mb
710+
match b with
711+
| None -> ()
712+
| Some b ->
713+
let! b = b
714+
yield b
715+
yield! loop () }
716+
yield! loop () }
717+
552718
let chooseAsync f (source : AsyncSeq<'T>) : AsyncSeq<'R> = asyncSeq {
553719
for itm in source do
554720
let! v = f itm
@@ -1200,6 +1366,62 @@ module AsyncSeq =
12001366
}
12011367

12021368

1369+
module AsyncSeqSrcImpl =
1370+
1371+
let create () : AsyncSeqSrc<'a> =
1372+
{ tl = new TaskCompletionSource<_>() }
1373+
1374+
let put (a:'a) (s:AsyncSeqSrc<'a>) : unit =
1375+
let s' = create ()
1376+
s.tl.SetResult(Some(a, s'))
1377+
s.tl <- s'.tl
1378+
1379+
let close (s:AsyncSeqSrc<'a>) : unit =
1380+
s.tl.SetResult(None)
1381+
1382+
let rec toAsyncSeq (s:AsyncSeqSrc<'a>) : AsyncSeq<'a> = asyncSeq {
1383+
let! next = s.tl.Task |> Async.AwaitTask
1384+
match next with
1385+
| None -> ()
1386+
| Some (a,tl) ->
1387+
yield a
1388+
yield! toAsyncSeq tl }
1389+
1390+
1391+
type private Group<'k, 'a> = { key : 'k ; src : AsyncSeqSrc<'a> }
1392+
1393+
let groupByAsync (p:'a -> Async<'k>) (s:AsyncSeq<'a>) : AsyncSeq<'k * AsyncSeq<'a>> = asyncSeq {
1394+
let groups = Collections.Generic.Dictionary<'k, Group<'k, 'a>>()
1395+
let close g =
1396+
groups.Remove(g.key) |> ignore
1397+
AsyncSeqSrcImpl.close g.src
1398+
use enum = s.GetEnumerator()
1399+
let rec go () = asyncSeq {
1400+
let! next = enum.MoveNext ()
1401+
match next with
1402+
| None ->
1403+
groups.Values |> Seq.toArray |> Array.iter close
1404+
| Some a ->
1405+
let! k = p a
1406+
let mutable g = Unchecked.defaultof<_>
1407+
if groups.TryGetValue(k, &g) then
1408+
AsyncSeqSrcImpl.put a g.src
1409+
yield! go ()
1410+
else
1411+
let src = AsyncSeqSrcImpl.create ()
1412+
AsyncSeqSrcImpl.put a src
1413+
let g = { key = k ; src = src }
1414+
groups.Add(k, g)
1415+
yield k,src |> AsyncSeqSrcImpl.toAsyncSeq
1416+
yield! go () }
1417+
yield! go () }
1418+
1419+
let groupBy (p:'a -> 'k) (s:AsyncSeq<'a>) : AsyncSeq<'k * AsyncSeq<'a>> =
1420+
groupByAsync (p >> async.Return) s
1421+
1422+
1423+
1424+
12031425
[<AutoOpen>]
12041426
module AsyncSeqExtensions =
12051427
let asyncSeq = new AsyncSeq.AsyncSeqBuilder()
@@ -1209,6 +1431,13 @@ module AsyncSeqExtensions =
12091431
member x.For (seq:AsyncSeq<'T>, action:'T -> Async<unit>) =
12101432
seq |> AsyncSeq.iterAsync action
12111433

1434+
module AsyncSeqSrc =
1435+
1436+
let create () = AsyncSeq.AsyncSeqSrcImpl.create ()
1437+
let put a s = AsyncSeq.AsyncSeqSrcImpl.put a s
1438+
let close s = AsyncSeq.AsyncSeqSrcImpl.close s
1439+
let toAsyncSeq s = AsyncSeq.AsyncSeqSrcImpl.toAsyncSeq s
1440+
12121441
module Seq =
12131442

12141443
let ofAsyncSeq (source : AsyncSeq<'T>) =

src/FSharp.Control.AsyncSeq/AsyncSeq.fsi

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,29 @@ module AsyncSeq =
380380
[<System.Obsolete("Use .GetEnumerator directly") >]
381381
val getIterator : source:AsyncSeq<'T> -> (unit -> Async<'T option>)
382382

383+
/// Builds a new asynchronous sequence whose elements are generated by
384+
/// applying the specified function to all elements of the input sequence.
385+
///
386+
/// The function is applied to elements in order but in parallel - without waiting
387+
/// for a prior mapping to complete.
388+
val mapAsyncParallel : mapping:('T -> Async<'U>) -> s:AsyncSeq<'T> -> AsyncSeq<'U>
389+
390+
/// Builds a new asynchronous sequence whose elements are generated by
391+
/// applying the specified function to all elements of the input sequence.
392+
///
393+
/// The function is applied to elements in order but in parallel - without waiting
394+
/// for a prior mapping to complete. Up to the specified number of mappings will
395+
/// occur in parallel before non-blocking waiting.
396+
val mapAsyncParallelBounded : parallelism:int -> mapping:('T -> Async<'U>) -> s:AsyncSeq<'T> -> AsyncSeq<'U>
397+
398+
/// Applies a key-generating function to each element and returns an async sequence containing unique keys
399+
/// and async sequences containing elements corresponding to the key.
400+
val groupByAsync<'T, 'Key when 'Key : equality> : projection:('T -> Async<'Key>) -> source:AsyncSeq<'T> -> AsyncSeq<'Key * AsyncSeq<'T>>
401+
402+
/// Applies a key-generating function to each element and returns an async sequence containing unique keys
403+
/// and async sequences containing elements corresponding to the key.
404+
val groupBy<'T, 'Key when 'Key : equality> : projection:('T -> 'Key) -> source:AsyncSeq<'T> -> AsyncSeq<'Key * AsyncSeq<'T>>
405+
383406
/// An automatically-opened module tht contains the `asyncSeq` builder and an extension method
384407
[<AutoOpen>]
385408
module AsyncSeqExtensions =
@@ -396,3 +419,22 @@ module Seq =
396419
/// The elements of the asynchronous sequence are consumed lazily.
397420
val ofAsyncSeq : source:AsyncSeq<'T> -> seq<'T>
398421

422+
423+
/// An async sequence source.
424+
type AsyncSeqSrc<'T>
425+
426+
/// Operations on async sequence sources.
427+
module AsyncSeqSrc =
428+
429+
/// Creates a new async sequence source.
430+
val create : unit -> AsyncSeqSrc<'T>
431+
432+
/// Puts an item into the async sequence source causing any created async sequences to yield the item.
433+
val put : item:'T -> src:AsyncSeqSrc<'T> -> unit
434+
435+
/// Closes the async sequence source casuing any created async sequences to terminate.
436+
val close : src:AsyncSeqSrc<'T> -> unit
437+
438+
/// Creates an async sequence which yields values as they are put into the source and terminates
439+
/// when the source is closed.
440+
val toAsyncSeq : src:AsyncSeqSrc<'T> -> AsyncSeq<'T>

0 commit comments

Comments
 (0)