  */
 
 import org.apache.spark.sql.*;
+import org.apache.spark.sql.streaming.StreamingQueryException;
 import org.apache.spark.sql.types.StructType;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -57,22 +58,28 @@ public final class BatchCollect extends SortOperation {
     private static final Logger LOGGER = LoggerFactory.getLogger(BatchCollect.class);
     private Dataset<Row> savedDs = null;
     private Dataset<Row> lastRowDs = null;
+    private Dataset<Row> outputDs = null;
     private final String sortColumn;
-    private final int numberOfRows;
+    private final int defaultLimit;
+    private final int postBcLimit;
     private StructType inputSchema;
     private boolean sortedBySingleColumn = false;
 
-    public BatchCollect(String sortColumn, int numberOfRows) {
-        this(sortColumn, numberOfRows, new ArrayList<>());
+    public BatchCollect(String sortColumn, int defaultLimit) {
+        this(sortColumn, defaultLimit, 0, new ArrayList<>());
     }
 
-    public BatchCollect(String sortColumn, int numberOfRows, List<SortByClause> listOfSortByClauses) {
-        super(listOfSortByClauses);
+    public BatchCollect(String sortColumn, int defaultLimit, int postBcLimit) {
+        this(sortColumn, defaultLimit, postBcLimit, new ArrayList<>());
+    }
 
-        LOGGER.info("Initialized BatchCollect based on column " + sortColumn + " and a limit of " + numberOfRows + " row(s)." +
-                " SortByClauses included: " + (listOfSortByClauses != null ? listOfSortByClauses.size() : "<null>"));
+    public BatchCollect(String sortColumn, int defaultLimit, int postBcLimit, List<SortByClause> listOfSortByClauses) {
+        super(listOfSortByClauses);
+        LOGGER.info("Initialized BatchCollect based on column <[{}]> and a limit of <[{}]> row(s). SortByClauses included: <[{}]>. Post batchcollect limit of <[{}]> row(s)",
+                sortColumn, defaultLimit, (listOfSortByClauses != null ? listOfSortByClauses.size() : "null"), postBcLimit);
         this.sortColumn = sortColumn;
-        this.numberOfRows = numberOfRows;
+        this.defaultLimit = defaultLimit;
+        this.postBcLimit = postBcLimit;
     }
 
     /**
@@ -84,31 +91,35 @@ public BatchCollect(String sortColumn, int numberOfRows, List<SortByClause> list
      */
     public Dataset<Row> call(Dataset<Row> df, Long id, boolean skipLimiting) {
         Dataset<Row> rv;
-        if (skipLimiting) {
-            this.processAggregated(df);
-        }
-        else {
-            this.collect(df, id);
-        }
+        this.collect(df, id, Collections.emptyList(), skipLimiting);
 
         if (this.lastRowDs != null) {
-            rv = this.savedDs.union(this.lastRowDs);
+            rv = this.outputDs.union(this.lastRowDs);
         } else {
-            rv = this.savedDs;
+            rv = this.outputDs;
         }
 
         return rv;
     }
 
-    public void collect(Dataset<Row> batchDF, Long batchId) {
+    public void collect(Dataset<Row> batchDF, Long batchId, List<AbstractStep> postBcSteps, boolean skipLimiting) {
+        // Apply post-batchcollect limit if steps are present, otherwise use the default.
+        // limit<=0 means no limit
+        final int limit;
+        if (!postBcSteps.isEmpty()) {
+            limit = this.postBcLimit;
+        } else {
+            limit = this.defaultLimit;
+        }
+
         // check that sortColumn (_time) exists,
         // and get the sortColId
         // otherwise, no sorting will be done.
         if (this.inputSchema == null) {
             this.inputSchema = batchDF.schema();
         }
 
-        if (this.getListOfSortByClauses() == null || this.getListOfSortByClauses().size() < 1) {
+        if (this.getListOfSortByClauses() == null || this.getListOfSortByClauses().isEmpty()) {
             for (String field : this.inputSchema.fieldNames()) {
                 if (field.equals(this.sortColumn)) {
                     this.sortedBySingleColumn = true;
@@ -117,38 +128,37 @@ public void collect(Dataset<Row> batchDF, Long batchId) {
             }
         }
 
-        List<Row> collected = orderDataset(batchDF).limit(numberOfRows).collectAsList();
+        Dataset<Row> orderedDs = orderDataset(batchDF);
+        if (!skipLimiting && limit > 0) {
+            orderedDs = orderedDs.limit(limit);
+        }
+        List<Row> collected = orderedDs.collectAsList();
         Dataset<Row> createdDsFromCollected = SparkSession.builder().getOrCreate().createDataFrame(collected, this.inputSchema);
-
+        Dataset<Row> current;
         if (this.savedDs == null) {
-            this.savedDs = createdDsFromCollected;
+            current = createdDsFromCollected;
         }
         else {
-            this.savedDs = savedDs.union(createdDsFromCollected);
+            current = savedDs.union(createdDsFromCollected);
         }
 
-        this.savedDs = orderDataset(this.savedDs).limit(numberOfRows);
-
-    }
-
-    // Call this instead of collect to skip limiting (for aggregatesUsed=true)
-    // TODO remove this
-    public void processAggregated(Dataset<Row> ds) {
-        if (this.inputSchema == null) {
-            this.inputSchema = ds.schema();
+        current = orderDataset(current);
+        if (!skipLimiting && limit > 0) {
+            current = current.limit(limit);
         }
-
-        List<Row> collected = orderDataset(ds).collectAsList();
-        Dataset<Row> createdDsFromCollected = SparkSession.builder().getOrCreate().createDataFrame(collected, this.inputSchema);
-
-        if (this.savedDs == null) {
-            this.savedDs = createdDsFromCollected;
-        }
-        else {
-            this.savedDs = savedDs.union(createdDsFromCollected);
+        this.savedDs = current;
+
+        // Post batchCollect steps processing
+        Dataset<Row> rv = current;
+        for (final AbstractStep step : postBcSteps) {
+            try {
+                rv = step.get(rv);
+            } catch (StreamingQueryException e) {
+                throw new IllegalStateException("Exception occurred while running post-batchcollect steps: ", e);
+            }
         }
 
-        this.savedDs = orderDataset(this.savedDs);
+        this.outputDs = rv;
     }
 
     private Dataset<Row> orderDataset(Dataset<Row> ds) {
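The new collect() signature above runs each post-batchcollect step through step.get() and rethrows any StreamingQueryException as an IllegalStateException. As a rough, hypothetical sketch (the diff does not show AbstractStep; it is assumed here to expose an overridable Dataset<Row> get(Dataset<Row>) that may throw StreamingQueryException), a minimal step could look like this:

// Hypothetical sketch, not part of this change: assumes AbstractStep declares
// Dataset<Row> get(Dataset<Row> dataset) throws StreamingQueryException.
public final class DropNullSortColumnStep extends AbstractStep {
    private final String sortColumn;

    public DropNullSortColumnStep(String sortColumn) {
        this.sortColumn = sortColumn;
    }

    @Override
    public Dataset<Row> get(Dataset<Row> dataset) throws StreamingQueryException {
        // Runs after BatchCollect has ordered and limited the batch;
        // whatever is returned here becomes outputDs in collect().
        return dataset.filter(dataset.col(sortColumn).isNotNull());
    }
}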
@@ -162,16 +172,17 @@ private Dataset<Row> orderDataset(Dataset<Row> ds) {
     public Dataset<Row> getCollectedAsDataframe() {
         Dataset<Row> rv;
         if (this.lastRowDs != null) {
-            rv = this.savedDs.union(this.lastRowDs);
+            rv = this.outputDs.union(this.lastRowDs);
         } else {
-            rv = this.savedDs;
+            rv = this.outputDs;
         }
         return rv;
     }
 
     public void clear() {
         LOGGER.info("dpf_02 cleared");
         this.savedDs = null;
+        this.outputDs = null;
         this.lastRowDs = null;
         this.inputSchema = null;
     }
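For orientation only, a hedged usage sketch of the revised API; the variable names batchDf and batchId, the limit values, and the hypothetical DropNullSortColumnStep from the earlier sketch are illustrative assumptions, not code from this change:

// Illustrative only; assumes a Dataset<Row> batchDf and Long batchId exist,
// and that java.util.ArrayList and java.util.Collections are imported.
BatchCollect batchCollect = new BatchCollect("_time", 10000, 500, new ArrayList<>());

// Streaming-sink path: call() forwards to collect() with an empty step list,
// so the default limit (10000) applies while skipLimiting is false.
Dataset<Row> limited = batchCollect.call(batchDf, batchId, false);

// Direct path with post-batchcollect steps: postBcLimit (500) is used instead,
// and each step's get() transforms the collected dataset before it becomes the output.
batchCollect.collect(batchDf, batchId,
        Collections.<AbstractStep>singletonList(new DropNullSortColumnStep("_time")), false);
Dataset<Row> result = batchCollect.getCollectedAsDataframe();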