 import org.apache.spark.sql.catalyst.encoders.RowEncoder;
 import org.apache.spark.sql.execution.streaming.MemoryStream;
 import org.apache.spark.sql.streaming.StreamingQuery;
-import org.apache.spark.sql.streaming.StreamingQueryException;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.MetadataBuilder;
 import org.apache.spark.sql.types.StructField;
 import java.time.LocalDateTime;
 import java.time.ZoneOffset;
 import java.util.ArrayList;
-import java.util.LinkedList;
-import java.util.List;
-import java.util.TreeMap;
 
 public class BatchCollectTest {
 
@@ -87,11 +83,10 @@ public class BatchCollectTest {
             new StructField("offset", DataTypes.LongType, false, new MetadataBuilder().build())
         }
     );
-
-    //@Test
-    public void testCollect() throws StreamingQueryException, InterruptedException {
-
-        SparkSession sparkSession = SparkSession.builder().master("local[*]").getOrCreate();
+
+    @Test
+    public void testCollectAsDataframe() {
+        SparkSession sparkSession = SparkSession.builder().master("local[*]").getOrCreate();
         SQLContext sqlContext = sparkSession.sqlContext();
 
         sparkSession.sparkContext().setLogLevel("ERROR");
@@ -102,13 +97,11 @@ public void testCollect() throws StreamingQueryException, InterruptedException {
 
         BatchCollect batchCollect = new BatchCollect("_time", 100, null);
         Dataset<Row> rowDataset = rowMemoryStream.toDF();
-        StreamingQuery streamingQuery = startStream(rowDataset, batchCollect);
+        StreamingQuery streamingQuery = startStream(rowDataset, batchCollect, false);
 
         long run = 0;
-        long counter = 1;
+        long counter = 0;
         while (streamingQuery.isActive()) {
-            //System.out.println(batchCollect.getCollected().size());
-
             Timestamp time = Timestamp.valueOf(LocalDateTime.ofInstant(Instant.now(), ZoneOffset.UTC));
             if (run == 3) {
                 // make run 3 to be latest always
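For context on the "latest always" trick: 13851486065 epoch-seconds falls around the year 2408, so rows stamped with it always sort newer than rows stamped with Instant.now(). A minimal standalone check (illustration only, not part of the change):

    import java.time.Instant;
    import java.time.LocalDateTime;
    import java.time.ZoneOffset;

    public class Run3TimestampCheck {
        public static void main(String[] args) {
            // prints a date around the year 2408, far ahead of any Instant.now()
            System.out.println(LocalDateTime.ofInstant(
                    Instant.ofEpochSecond(13851486065L), ZoneOffset.UTC));
        }
    }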
@@ -136,31 +129,32 @@ public void testCollect() throws StreamingQueryException, InterruptedException {
                 counter = 0;
             }
             counter++;
-            streamingQuery.processAllAvailable();
 
             if (run == 10) {
                 // 10 runs only
                 // wait until the source feeds them all?
                 // TODO there must be a better way?
-                // streamingQuery.processAllAvailable();
-                streamingQuery.stop();
-                streamingQuery.awaitTermination();
+                streamingQuery.processAllAvailable();
+                streamingQuery.stop();
+                Assertions.assertDoesNotThrow(() -> streamingQuery.awaitTermination());
             }
         }
 
-
-        LinkedList<Integer> runs = new LinkedList<>();
-        runs.add(3);
-        runs.add(6);
-        runs.add(7);
-        runs.add(8);
-        runs.add(9);
-        verifyRuns(batchCollect, runs);
+        Dataset<Row> collectedAsDF = batchCollect.getCollectedAsDataframe();
+        Assertions.assertEquals(100, collectedAsDF.count());
+
+        // assert that batches are correct (the newest 100 rows of data)
+        // batch number 3 is the newest in the test, others are in the order of creation
+        Assertions.assertEquals(20, collectedAsDF.filter(functions.col("partition").equalTo("3")).count());
+        Assertions.assertEquals(20, collectedAsDF.filter(functions.col("partition").equalTo("6")).count());
+        Assertions.assertEquals(20, collectedAsDF.filter(functions.col("partition").equalTo("7")).count());
+        Assertions.assertEquals(20, collectedAsDF.filter(functions.col("partition").equalTo("8")).count());
+        Assertions.assertEquals(20, collectedAsDF.filter(functions.col("partition").equalTo("9")).count());
     }
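The five per-partition assertions could also be folded into one loop; a minimal sketch, where assertPartitionCounts is a hypothetical helper name and not part of this commit:

    // hypothetical helper: asserts each listed partition holds exactly
    // `expected` rows in the collected dataframe
    private static void assertPartitionCounts(Dataset<Row> df, long expected, String... partitions) {
        for (String partition : partitions) {
            Assertions.assertEquals(expected,
                    df.filter(functions.col("partition").equalTo(partition)).count(),
                    "unexpected row count for partition " + partition);
        }
    }

    // usage: assertPartitionCounts(collectedAsDF, 20, "3", "6", "7", "8", "9");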
-
+
     @Test
-    public void testCollectAsDataframe() throws StreamingQueryException, InterruptedException {
-        SparkSession sparkSession = SparkSession.builder().master("local[*]").getOrCreate();
+    public void testSkipLimiting() {
+        SparkSession sparkSession = SparkSession.builder().master("local[*]").getOrCreate();
         SQLContext sqlContext = sparkSession.sqlContext();
 
         sparkSession.sparkContext().setLogLevel("ERROR");
@@ -169,19 +163,25 @@ public void testCollectAsDataframe() throws StreamingQueryException, Interrupted
         MemoryStream<Row> rowMemoryStream =
                 new MemoryStream<>(1, sqlContext, encoder);
 
-        BatchCollect batchCollect = new BatchCollect("_time", 100, null);
+        BatchCollect batchCollect = new BatchCollect("_time", 5, new ArrayList<>());
         Dataset<Row> rowDataset = rowMemoryStream.toDF();
-        StreamingQuery streamingQuery = startStream(rowDataset, batchCollect);
+
+        // Skip limiting here
+        StreamingQuery streamingQuery = startStream(rowDataset, batchCollect, true);
 
         long run = 0;
         long counter = 0;
         while (streamingQuery.isActive()) {
-            //System.out.println(batchCollect.getCollected().size());
-
             Timestamp time = Timestamp.valueOf(LocalDateTime.ofInstant(Instant.now(), ZoneOffset.UTC));
             if (run == 3) {
                 // make run 3 to be latest always
                 time = Timestamp.valueOf(LocalDateTime.ofInstant(Instant.ofEpochSecond(13851486065L + counter), ZoneOffset.UTC));
+            } else if (run == 10) {
+                // 10 runs only
+                streamingQuery.processAllAvailable();
+                streamingQuery.stop();
+                Assertions.assertDoesNotThrow(() -> streamingQuery.awaitTermination());
+                break;
             }
 
             rowMemoryStream.addData(
@@ -199,27 +199,31 @@ public void testCollectAsDataframe() throws StreamingQueryException, Interrupted
                 )
             );
 
+            counter++;
+
             // create 20 events for 10 runs
             if (counter == 20) {
                 run++;
                 counter = 0;
             }
-            counter++;
-
-            if (run == 10) {
-                // 10 runs only
-                // wait until the source feeds them all?
-                // TODO there must be a better way?
-                streamingQuery.processAllAvailable();
-                streamingQuery.stop();
-                streamingQuery.awaitTermination();
-            }
         }
-
+
         Dataset<Row> collectedAsDF = batchCollect.getCollectedAsDataframe();
-        collectedAsDF.show(5, true);
-        Assertions.assertTrue(collectedAsDF instanceof Dataset);
-        //Assertions.assertEquals(200, collectedAsDF.count());
+
+        // all the rows are in the dataset, so the limit of 5 rows was not applied
+        Assertions.assertEquals(200, collectedAsDF.count());
+
+        // assert that batches are correct (all the rows, 10 batches)
+        Assertions.assertEquals(20, collectedAsDF.filter(functions.col("partition").equalTo("0")).count());
+        Assertions.assertEquals(20, collectedAsDF.filter(functions.col("partition").equalTo("1")).count());
+        Assertions.assertEquals(20, collectedAsDF.filter(functions.col("partition").equalTo("2")).count());
+        Assertions.assertEquals(20, collectedAsDF.filter(functions.col("partition").equalTo("3")).count());
+        Assertions.assertEquals(20, collectedAsDF.filter(functions.col("partition").equalTo("4")).count());
+        Assertions.assertEquals(20, collectedAsDF.filter(functions.col("partition").equalTo("5")).count());
+        Assertions.assertEquals(20, collectedAsDF.filter(functions.col("partition").equalTo("6")).count());
+        Assertions.assertEquals(20, collectedAsDF.filter(functions.col("partition").equalTo("7")).count());
+        Assertions.assertEquals(20, collectedAsDF.filter(functions.col("partition").equalTo("8")).count());
+        Assertions.assertEquals(20, collectedAsDF.filter(functions.col("partition").equalTo("9")).count());
     }
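Design note on the shutdown sequence shared by both tests: the old test methods declared throws StreamingQueryException, InterruptedException; the rewritten ones instead wrap awaitTermination() in Assertions.assertDoesNotThrow, which accepts a throwing Executable, so the checked exception both satisfies the compiler and surfaces as a test failure if it ever fires. The pattern, annotated (illustration only):

    streamingQuery.processAllAvailable(); // block until everything added to the MemoryStream is processed
    streamingQuery.stop();                // request termination of the query
    Assertions.assertDoesNotThrow(() -> streamingQuery.awaitTermination()); // fail the test on StreamingQueryException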
 
     private Seq<Row> makeRows(Timestamp _time,
@@ -255,46 +259,18 @@ private Seq<Row> makeRows(Timestamp _time,
     }
 
 
-    private StreamingQuery startStream(Dataset<Row> rowDataset, BatchCollect batchCollect) {
+    private StreamingQuery startStream(Dataset<Row> rowDataset, BatchCollect batchCollect, boolean skipLimiting) {
         return rowDataset
                 .writeStream()
                 .foreachBatch(
                         new VoidFunction2<Dataset<Row>, Long>() {
                             @Override
-                            public void call(Dataset<Row> batchDF, Long batchId) throws Exception {
-                                batchCollect.collect(batchDF, batchId);
+                            public void call(Dataset<Row> batchDF, Long batchId) {
+                                batchCollect.call(batchDF, batchId, skipLimiting);
                             }
                         }
                 )
                 .outputMode("append")
                 .start();
     }
-
-    private void verifyRuns(BatchCollect batchCollect, LinkedList<Integer> runs) {
-        // test that 0-4 batches added data to 100 slots
-        List<Row> collectedList = batchCollect.getCollected();
-
-        TreeMap<Integer, Long> runToRow = new TreeMap<>();
-
-        int arraySize = collectedList.size();
-        while (arraySize != 0) {
-            Row row = collectedList.get(arraySize - 1);
-            int rowRun = Integer.parseInt(row.getString(6));
-
-            if (runToRow.containsKey(rowRun)) {
-                long value = runToRow.get(rowRun);
-                value++;
-                runToRow.put(rowRun, value);
-            }
-            else {
-                runToRow.put(rowRun, 1L);
-            }
-            arraySize--;
-
-        }
-
-        for (int run : runs) {
-            Assertions.assertEquals(20, runToRow.get(run), "batch " + run + " contained other than 20 messages");
-        }
-    }
 }
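Taken together, the two tests pin down the collector's observable contract: call() accumulates each micro-batch, and unless skipLimiting is set, only the newest rows by the configured time column are kept, up to the row limit. A rough sketch of that contract, inferred from the assertions alone (the real BatchCollect lives elsewhere in the repository and may differ; the class name here is hypothetical, and the third constructor argument seen in the tests is omitted because its role is not visible in this diff):

    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.functions;

    // Illustration only, inferred from the tests above.
    public class BatchCollectContractSketch {
        private final String timeColumn; // "_time" in the tests
        private final int limit;         // 100 and 5 in the tests
        private Dataset<Row> collected;  // union of every micro-batch seen so far

        public BatchCollectContractSketch(String timeColumn, int limit) {
            this.timeColumn = timeColumn;
            this.limit = limit;
        }

        public void call(Dataset<Row> batchDF, Long batchId, boolean skipLimiting) {
            collected = (collected == null) ? batchDF : collected.union(batchDF);
            if (!skipLimiting) {
                // keep only the newest `limit` rows; this is why run 3's
                // far-future timestamps always survive in testCollectAsDataframe
                collected = collected.orderBy(functions.col(timeColumn).desc()).limit(limit);
            }
        }

        public Dataset<Row> getCollectedAsDataframe() {
            return collected;
        }
    }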