
Commit 85e1603

Fix BatchCollect limiting the dataset even if skipLimiting is set to true (#23)
* Fix limiting the dataset even if skipLimiting is set to true
* Remove getCollected(), fix BatchCollect tests
1 parent 2396554 commit 85e1603

2 files changed

Lines changed: 55 additions & 84 deletions
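
The intent of the change, sketched below: the per-batch row limit should only be applied when the caller has not asked to skip it, instead of unconditionally inside processAggregated(). This is an illustrative fragment, not the actual BatchCollect source; call(), orderDataset() and numberOfRows appear in the diff, while collectAndSave() is a hypothetical stand-in for the rest of the aggregation path.

    // Illustrative sketch only (assumed structure, not the real implementation)
    public void call(Dataset<Row> batchDF, Long batchId, boolean skipLimiting) {
        Dataset<Row> ordered = orderDataset(batchDF);
        if (!skipLimiting) {
            // truncate to the configured window only when limiting is wanted
            ordered = ordered.limit(numberOfRows);
        }
        // before this commit, processAggregated() re-applied .limit(numberOfRows)
        // unconditionally, so skipLimiting had no effect on the collected result
        collectAndSave(ordered); // hypothetical helper
    }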


src/main/java/com/teragrep/functions/dpf_02/BatchCollect.java

Lines changed: 2 additions & 7 deletions
@@ -138,7 +138,7 @@ public void processAggregated(Dataset<Row> ds) {
             this.inputSchema = ds.schema();
         }

-        List<Row> collected = orderDataset(ds).limit(numberOfRows).collectAsList();
+        List<Row> collected = orderDataset(ds).collectAsList();
         Dataset<Row> createdDsFromCollected = SparkSession.builder().getOrCreate().createDataFrame(collected, this.inputSchema);

         if (this.savedDs == null) {
@@ -148,7 +148,7 @@ public void processAggregated(Dataset<Row> ds) {
             this.savedDs = savedDs.union(createdDsFromCollected);
         }

-        this.savedDs = orderDataset(this.savedDs).limit(numberOfRows);
+        this.savedDs = orderDataset(this.savedDs);
     }

     private Dataset<Row> orderDataset(Dataset<Row> ds) {
@@ -159,11 +159,6 @@ private Dataset<Row> orderDataset(Dataset<Row> ds) {
         }
     }

-    // TODO: Remove
-    public List<Row> getCollected() {
-        return getCollectedAsDataframe().collectAsList();
-    }
-
     public Dataset<Row> getCollectedAsDataframe() {
         Dataset<Row> rv;
         if (this.lastRowDs != null) {
src/test/java/BatchCollectTest.java

Lines changed: 53 additions & 77 deletions
@@ -51,7 +51,6 @@
 import org.apache.spark.sql.catalyst.encoders.RowEncoder;
 import org.apache.spark.sql.execution.streaming.MemoryStream;
 import org.apache.spark.sql.streaming.StreamingQuery;
-import org.apache.spark.sql.streaming.StreamingQueryException;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.MetadataBuilder;
 import org.apache.spark.sql.types.StructField;
@@ -66,9 +65,6 @@
 import java.time.LocalDateTime;
 import java.time.ZoneOffset;
 import java.util.ArrayList;
-import java.util.LinkedList;
-import java.util.List;
-import java.util.TreeMap;

 public class BatchCollectTest {

@@ -87,11 +83,10 @@ public class BatchCollectTest {
             new StructField("offset", DataTypes.LongType, false, new MetadataBuilder().build())
         }
     );
-
-    //@Test
-    public void testCollect() throws StreamingQueryException, InterruptedException {
-
-        SparkSession sparkSession = SparkSession.builder().master("local[*]").getOrCreate();
+
+    @Test
+    public void testCollectAsDataframe() {
+        SparkSession sparkSession = SparkSession.builder().master("local[*]").getOrCreate();
         SQLContext sqlContext = sparkSession.sqlContext();

         sparkSession.sparkContext().setLogLevel("ERROR");
@@ -102,13 +97,11 @@ public void testCollect() throws StreamingQueryException, InterruptedException {

         BatchCollect batchCollect = new BatchCollect("_time", 100, null);
         Dataset<Row> rowDataset = rowMemoryStream.toDF();
-        StreamingQuery streamingQuery = startStream(rowDataset, batchCollect);
+        StreamingQuery streamingQuery = startStream(rowDataset, batchCollect, false);

         long run = 0;
-        long counter = 1;
+        long counter = 0;
         while (streamingQuery.isActive()) {
-            //System.out.println(batchCollect.getCollected().size());
-
             Timestamp time = Timestamp.valueOf(LocalDateTime.ofInstant(Instant.now(), ZoneOffset.UTC));
             if (run == 3) {
                 // make run 3 to be latest always
@@ -136,31 +129,32 @@ public void testCollect() throws StreamingQueryException, InterruptedException {
                 counter = 0;
             }
             counter++;
-            streamingQuery.processAllAvailable();

             if (run == 10) {
                 // 10 runs only
                 // wait until the source feeds them all?
                 // TODO there must be a better way?
-                // streamingQuery.processAllAvailable();
-                streamingQuery.stop();
-                streamingQuery.awaitTermination();
+                streamingQuery.processAllAvailable();
+                streamingQuery.stop();
+                Assertions.assertDoesNotThrow(() -> streamingQuery.awaitTermination());
             }
         }

-
-        LinkedList<Integer> runs = new LinkedList<>();
-        runs.add(3);
-        runs.add(6);
-        runs.add(7);
-        runs.add(8);
-        runs.add(9);
-        verifyRuns(batchCollect, runs);
+        Dataset<Row> collectedAsDF = batchCollect.getCollectedAsDataframe();
+        Assertions.assertEquals(100, collectedAsDF.count());
+
+        // assert that batches are correct (the newest 100 rows of data)
+        // batch number 3 is the newest in the test, others are in the order of creation
+        Assertions.assertEquals(20, collectedAsDF.filter(functions.col("partition").equalTo("3")).count());
+        Assertions.assertEquals(20, collectedAsDF.filter(functions.col("partition").equalTo("6")).count());
+        Assertions.assertEquals(20, collectedAsDF.filter(functions.col("partition").equalTo("7")).count());
+        Assertions.assertEquals(20, collectedAsDF.filter(functions.col("partition").equalTo("8")).count());
+        Assertions.assertEquals(20, collectedAsDF.filter(functions.col("partition").equalTo("9")).count());
     }
-
+
     @Test
-    public void testCollectAsDataframe() throws StreamingQueryException, InterruptedException {
-        SparkSession sparkSession = SparkSession.builder().master("local[*]").getOrCreate();
+    public void testSkipLimiting() {
+        SparkSession sparkSession = SparkSession.builder().master("local[*]").getOrCreate();
         SQLContext sqlContext = sparkSession.sqlContext();

         sparkSession.sparkContext().setLogLevel("ERROR");
@@ -169,19 +163,25 @@ public void testCollectAsDataframe() throws StreamingQueryException, InterruptedException {
         MemoryStream<Row> rowMemoryStream =
                 new MemoryStream<>(1, sqlContext, encoder);

-        BatchCollect batchCollect = new BatchCollect("_time", 100, null);
+        BatchCollect batchCollect = new BatchCollect("_time", 5, new ArrayList<>());
         Dataset<Row> rowDataset = rowMemoryStream.toDF();
-        StreamingQuery streamingQuery = startStream(rowDataset, batchCollect);
+
+        // Skip limiting here
+        StreamingQuery streamingQuery = startStream(rowDataset, batchCollect, true);

         long run = 0;
         long counter = 0;
         while (streamingQuery.isActive()) {
-            //System.out.println(batchCollect.getCollected().size());
-
             Timestamp time = Timestamp.valueOf(LocalDateTime.ofInstant(Instant.now(), ZoneOffset.UTC));
             if (run == 3) {
                 // make run 3 to be latest always
                 time = Timestamp.valueOf(LocalDateTime.ofInstant(Instant.ofEpochSecond(13851486065L+counter), ZoneOffset.UTC));
+            } else if (run == 10) {
+                // 10 runs only
+                streamingQuery.processAllAvailable();
+                streamingQuery.stop();
+                Assertions.assertDoesNotThrow(() -> streamingQuery.awaitTermination());
+                break;
             }

             rowMemoryStream.addData(
@@ -199,27 +199,31 @@ public void testCollectAsDataframe() throws StreamingQueryException, InterruptedException {
                 )
             );

+            counter++;
+
             // create 20 events for 10 runs
             if (counter == 20) {
                 run++;
                 counter = 0;
             }
-            counter++;
-
-            if (run == 10) {
-                // 10 runs only
-                // wait until the source feeds them all?
-                // TODO there must be a better way?
-                streamingQuery.processAllAvailable();
-                streamingQuery.stop();
-                streamingQuery.awaitTermination();
-            }
         }
-
+
         Dataset<Row> collectedAsDF = batchCollect.getCollectedAsDataframe();
-        collectedAsDF.show(5, true);
-        Assertions.assertTrue(collectedAsDF instanceof Dataset);
-        //Assertions.assertEquals(200, collectedAsDF.count());
+
+        // all the rows in the dataset, the limit of 5 rows is therefore not applied
+        Assertions.assertEquals(200, collectedAsDF.count());
+
+        // assert that batches are correct (all the rows, 10 batches)
+        Assertions.assertEquals(20, collectedAsDF.filter(functions.col("partition").equalTo("0")).count());
+        Assertions.assertEquals(20, collectedAsDF.filter(functions.col("partition").equalTo("1")).count());
+        Assertions.assertEquals(20, collectedAsDF.filter(functions.col("partition").equalTo("2")).count());
+        Assertions.assertEquals(20, collectedAsDF.filter(functions.col("partition").equalTo("3")).count());
+        Assertions.assertEquals(20, collectedAsDF.filter(functions.col("partition").equalTo("4")).count());
+        Assertions.assertEquals(20, collectedAsDF.filter(functions.col("partition").equalTo("5")).count());
+        Assertions.assertEquals(20, collectedAsDF.filter(functions.col("partition").equalTo("6")).count());
+        Assertions.assertEquals(20, collectedAsDF.filter(functions.col("partition").equalTo("7")).count());
+        Assertions.assertEquals(20, collectedAsDF.filter(functions.col("partition").equalTo("8")).count());
+        Assertions.assertEquals(20, collectedAsDF.filter(functions.col("partition").equalTo("9")).count());
     }

     private Seq<Row> makeRows(Timestamp _time,
@@ -255,46 +259,18 @@ private Seq<Row> makeRows(Timestamp _time,
     }


-    private StreamingQuery startStream(Dataset<Row> rowDataset, BatchCollect batchCollect) {
+    private StreamingQuery startStream(Dataset<Row> rowDataset, BatchCollect batchCollect, boolean skipLimiting) {
         return rowDataset
                 .writeStream()
                 .foreachBatch(
                         new VoidFunction2<Dataset<Row>, Long>() {
                             @Override
-                            public void call(Dataset<Row> batchDF, Long batchId) throws Exception {
-                                batchCollect.collect(batchDF, batchId);
+                            public void call(Dataset<Row> batchDF, Long batchId) {
+                                batchCollect.call(batchDF, batchId, skipLimiting);
                             }
                         }
                 )
                 .outputMode("append")
                 .start();
     }
-
-    private void verifyRuns(BatchCollect batchCollect, LinkedList<Integer> runs) {
-        // test that 0-4 batches added data to 100 slots
-        List<Row> collectedList = batchCollect.getCollected();
-
-        TreeMap<Integer, Long> runToRow = new TreeMap<>();
-
-        int arraySize = collectedList.size();
-        while (arraySize != 0) {
-            Row row = collectedList.get(arraySize - 1);
-            int rowRun = Integer.parseInt(row.getString(6));
-
-            if(runToRow.containsKey(rowRun)) {
-                long value = runToRow.get(rowRun);
-                value++;
-                runToRow.put(rowRun, value);
-            }
-            else {
-                runToRow.put(rowRun, 1L);
-            }
-            arraySize--;
-
-        }
-
-        for(int run : runs) {
-            Assertions.assertEquals(20, runToRow.get(run), "batch "+ run +" contained other than 20 messages");
-        }
-    }
 }
