Skip to content

Commit d4345e7

Browse files
committed
Adds chunkBySize and tests
1 parent 4bdcf15 commit d4345e7

3 files changed

Lines changed: 55 additions & 19 deletions

File tree

src/FSharp.Control.AsyncSeq/AsyncSeq.fs

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -297,20 +297,20 @@ module AsyncSeqOp =
297297
// Enumerator that produces elements by repeatedly applying `f` to a mutable state that starts at `init`.
// `f` returns `Some (value, nextState)` to emit `value` and advance the state, or `None` to signal the end.
// Once `Dispose` has been called, `MoveNext` returns `None` without invoking `f`.
type OptimizedUnfoldEnumerator<'S, 'T> (f:'S -> Async<('T * 'S) option>, init:'S) =
298298
let mutable currentState = init
299299
let mutable disposed = false
300-
300+
301301
interface IAsyncEnumerator<'T> with
302-
member __.MoveNext () : Async<'T option> =
302+
member __.MoveNext () : Async<'T option> =
303303
// After Dispose, report end-of-sequence immediately instead of calling f again.
if disposed then async.Return None
304304
else async {
305305
let! result = f currentState
306306
match result with
307-
| None ->
307+
| None ->
308308
return None
309309
| Some (value, nextState) ->
310310
// Store the successor state so the next MoveNext call continues from it.
currentState <- nextState
311311
return Some value
312312
}
313-
member __.Dispose () =
313+
member __.Dispose () =
314314
disposed <- true
315315

316316
type UnfoldAsyncEnumerator<'S, 'T> (f:'S -> Async<('T * 'S) option>, init:'S) =
@@ -606,10 +606,10 @@ module AsyncSeq =
606606
// Optimized collect implementation using direct field access instead of ref cells
607607
type OptimizedCollectEnumerator<'T, 'U>(f: 'T -> AsyncSeq<'U>, inp: AsyncSeq<'T>) =
608608
// Mutable fields instead of ref cells to reduce allocations
609-
let mutable inputEnumerator: IAsyncEnumerator<'T> option = None
609+
let mutable inputEnumerator: IAsyncEnumerator<'T> option = None
610610
let mutable innerEnumerator: IAsyncEnumerator<'U> option = None
611611
let mutable disposed = false
612-
612+
613613
// Tail-recursive optimization to avoid deep continuation chains
614614
let rec moveNextLoop () : Async<'U option> = async {
615615
if disposed then return None
@@ -642,7 +642,7 @@ module AsyncSeq =
642642
inputEnumerator <- Some newOuter
643643
return! moveNextLoop ()
644644
}
645-
645+
646646
interface IAsyncEnumerator<'U> with
647647
member _.MoveNext() = moveNextLoop ()
648648
member _.Dispose() =
@@ -651,13 +651,13 @@ module AsyncSeq =
651651
match innerEnumerator with
652652
| Some inner -> inner.Dispose(); innerEnumerator <- None
653653
| None -> ()
654-
match inputEnumerator with
654+
match inputEnumerator with
655655
| Some outer -> outer.Dispose(); inputEnumerator <- None
656656
| None -> ()
657657

658658
// Wraps `f` and `inp` in an IAsyncEnumerable<'U>; each GetEnumerator call creates a fresh
// OptimizedCollectEnumerator, so every enumeration starts with its own enumerator state.
// NOTE(review): presumably the enumerator concatenates the sequences `f` produces for each
// element of `inp` - its body is only partly visible in this hunk, confirm there.
let collect (f: 'T -> AsyncSeq<'U>) (inp: AsyncSeq<'T>) : AsyncSeq<'U> =
659659
{ new IAsyncEnumerable<'U> with
660-
member _.GetEnumerator() =
660+
member _.GetEnumerator() =
661661
new OptimizedCollectEnumerator<'T, 'U>(f, inp) :> IAsyncEnumerator<'U> }
662662

663663
// let collect (f: 'T -> AsyncSeq<'U>) (inp: AsyncSeq<'T>) : AsyncSeq<'U> =
@@ -749,7 +749,7 @@ module AsyncSeq =
749749
// Optimized iterAsync implementation to reduce allocations
750750
type internal OptimizedIterAsyncEnumerator<'T>(enumerator: IAsyncEnumerator<'T>, f: 'T -> Async<unit>) =
751751
let mutable disposed = false
752-
752+
753753
member _.IterateAsync() =
754754
let rec loop() = async {
755755
let! next = enumerator.MoveNext()
@@ -760,17 +760,17 @@ module AsyncSeq =
760760
| None -> return ()
761761
}
762762
loop()
763-
763+
764764
interface IDisposable with
765765
member _.Dispose() =
766766
if not disposed then
767767
disposed <- true
768768
enumerator.Dispose()
769769

770-
// Optimized iteriAsync implementation with direct tail recursion
770+
// Optimized iteriAsync implementation with direct tail recursion
771771
type internal OptimizedIteriAsyncEnumerator<'T>(enumerator: IAsyncEnumerator<'T>, f: int -> 'T -> Async<unit>) =
772772
let mutable disposed = false
773-
773+
774774
member _.IterateAsync() =
775775
let rec loop count = async {
776776
let! next = enumerator.MoveNext()
@@ -781,7 +781,7 @@ module AsyncSeq =
781781
| None -> return ()
782782
}
783783
loop 0
784-
784+
785785
interface IDisposable with
786786
member _.Dispose() =
787787
if not disposed then
@@ -798,7 +798,7 @@ module AsyncSeq =
798798
let iterAsync (f: 'T -> Async<unit>) (source: AsyncSeq<'T>) =
799799
match source with
800800
| :? AsyncSeqOp<'T> as source -> source.IterAsync f
801-
| _ ->
801+
| _ ->
802802
async {
803803
let enum = source.GetEnumerator()
804804
use optimizer = new OptimizedIterAsyncEnumerator<_>(enum, f)
@@ -864,7 +864,7 @@ module AsyncSeq =
864864
// Optimized mapAsync enumerator that avoids computation builder overhead
865865
type private OptimizedMapAsyncEnumerator<'T, 'TResult>(source: IAsyncEnumerator<'T>, f: 'T -> Async<'TResult>) =
866866
let mutable disposed = false
867-
867+
868868
interface IAsyncEnumerator<'TResult> with
869869
member _.MoveNext() = async {
870870
let! moveResult = source.MoveNext()
@@ -874,7 +874,7 @@ module AsyncSeq =
874874
let! mapped = f value
875875
return Some mapped
876876
}
877-
877+
878878
member _.Dispose() =
879879
if not disposed then
880880
disposed <- true
@@ -885,7 +885,7 @@ module AsyncSeq =
885885
| :? AsyncSeqOp<'T> as source -> source.MapAsync f
886886
| _ ->
887887
{ new IAsyncEnumerable<'TResult> with
888-
member _.GetEnumerator() =
888+
member _.GetEnumerator() =
889889
new OptimizedMapAsyncEnumerator<'T, 'TResult>(source.GetEnumerator(), f) :> IAsyncEnumerator<'TResult> }
890890

891891
let mapiAsync f (source : AsyncSeq<'T>) : AsyncSeq<'TResult> = asyncSeq {
@@ -1125,6 +1125,22 @@ module AsyncSeq =
11251125
// Synchronous-predicate variant of filterAsync: the result of `f` is lifted into Async
// with async.Return and the work is delegated to filterAsync.
let filter f (source : AsyncSeq<'T>) =
11261126
filterAsync (f >> async.Return) source
11271127

1128+
// Splits `source` into consecutive arrays holding at most `chunkSize` elements each.
// A chunk is yielded once it holds `chunkSize` elements; when the source ends, the elements
// collected so far are yielded as a final, possibly shorter, chunk. Empty chunks are never yielded.
// Calls invalidArg when `chunkSize` is less than one; the check runs when chunkBySize is applied,
// before the asyncSeq block is built.
let chunkBySize (chunkSize: int) (source: AsyncSeq<'T>) : AsyncSeq<'T array> =
1129+
if chunkSize < 1 then
1130+
invalidArg (nameof chunkSize) "must be greater than zero"
1131+
asyncSeq {
1132+
// `use` disposes the source enumerator when the asyncSeq block finishes.
use enumerator = source.GetEnumerator()
1133+
// Set once the source enumerator returns None; stops both loops below.
let mutable isFinished = false
1134+
while not isFinished do
1135+
// Fresh buffer per chunk, pre-sized to the requested chunk size.
let chunk = ResizeArray<'T>(chunkSize)
1136+
while chunk.Count < chunkSize && not isFinished do
1137+
match! enumerator.MoveNext() with
1138+
| Some item -> chunk.Add(item)
1139+
| None -> isFinished <- true
1140+
// Skip the empty buffer left over when the source ends exactly on a chunk boundary.
if chunk.Count > 0 then
1141+
yield chunk.ToArray()
1142+
}
1143+
11281144
#if !FABLE_COMPILER
11291145
let iterAsyncParallel (f:'a -> Async<unit>) (s:AsyncSeq<'a>) : Async<unit> = async {
11301146
use mb = MailboxProcessor.Start (ignore >> async.Return)

src/FSharp.Control.AsyncSeq/AsyncSeq.fsi

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,9 @@ module AsyncSeq =
255255
/// and processes the input element immediately.
256256
val filter : predicate:('T -> bool) -> source:AsyncSeq<'T> -> AsyncSeq<'T>
257257

258+
/// Groups the elements of the input asynchronous sequence into arrays of at most
/// `chunkSize` elements, yielding each array as soon as it is full. When the input
/// ends, any elements already buffered are yielded as a final, shorter array; empty
/// arrays are never yielded. Raises an argument exception if `chunkSize` is less than one.
259+
val chunkBySize : chunkSize: int -> source: AsyncSeq<'T> -> AsyncSeq<'T array>
260+
258261
/// Creates an asynchronous sequence that lazily takes element from an
259262
/// input synchronous sequence and returns them one-by-one.
260263
val ofSeq : source:seq<'T> -> AsyncSeq<'T>
@@ -524,7 +527,7 @@ module AsyncSeq =
524527
/// Builds a new asynchronous sequence whose elements are generated by
525528
/// applying the specified function to all elements of the input sequence.
526529
///
527-
/// The function is applied to elements in parallel, and results are emitted
530+
/// The function is applied to elements in parallel, and results are emitted
528531
/// in the order they complete (unordered), without preserving the original order.
529532
/// This can provide better performance than mapAsyncParallel when order doesn't matter.
530533
/// Parallelism is bound by the ThreadPool.

tests/FSharp.Control.AsyncSeq.Tests/AsyncSeqTests.fs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -889,6 +889,23 @@ let ``AsyncSeq.filter``() =
889889
let expected = ls |> Seq.filter p |> AsyncSeq.ofSeq
890890
Assert.True(EQ expected actual)
891891

892+
// Checks that chunking five elements with a chunk size of 2 yields two full chunks
// followed by a final chunk that holds the single remaining element.
[<Test>]
893+
let ``AsyncSeq.chunkBySize``() =
894+
let input = [ "a"; "b"; "c"; "d"; "e" ]
895+
let actual =
896+
input
897+
|> AsyncSeq.ofSeq
898+
|> AsyncSeq.chunkBySize 2
899+
|> AsyncSeq.toListSynchronously
900+
// Turn each yielded array into a list so it can be compared with the expected list of lists.
|> List.map List.ofSeq
901+
let expected =
902+
[
903+
[ "a"; "b" ]
904+
[ "c"; "d" ]
905+
[ "e" ]
906+
]
907+
Assert.AreEqual(expected, actual)
908+
892909
[<Test>]
893910
let ``AsyncSeq.merge``() =
894911
let ls1 = [1;2;3;4;5]

0 commit comments

Comments
 (0)