@@ -77,12 +77,22 @@ public void close() {
7777
7878 Mono <Void > decryptKey (final MongoKeyDecryptor keyDecryptor , @ Nullable final Timeout operationTimeout ) {
7979 return Mono .defer (() -> {
80- Timeout .onExistsAndExpired (operationTimeout , () -> {
81- throw TimeoutContext .createMongoTimeoutException (TIMEOUT_ERROR_MESSAGE );
82- });
8380 long sleepMicros = keyDecryptor .sleepMicroseconds ();
81+ if (sleepMicros > 0 && operationTimeout != null ) {
82+ operationTimeout .run (MICROSECONDS ,
83+ () -> { },
84+ remainingMicros -> {
85+ if (remainingMicros < sleepMicros ) {
86+ throw TimeoutContext .createMongoTimeoutException (TIMEOUT_ERROR_MESSAGE );
87+ }
88+ },
89+ () -> {
90+ throw TimeoutContext .createMongoTimeoutException (TIMEOUT_ERROR_MESSAGE );
91+ });
92+ }
8493 Mono <Void > attempt = sleepMicros > 0
85- ? Mono .delay (Duration .ofNanos (MICROSECONDS .toNanos (sleepMicros ))).then (attemptDecryptKey (keyDecryptor , operationTimeout ))
94+ ? Mono .delay (Duration .ofNanos (MICROSECONDS .toNanos (sleepMicros )))
95+ .then (attemptDecryptKey (keyDecryptor , operationTimeout ))
8696 : attemptDecryptKey (keyDecryptor , operationTimeout );
8797 return attempt .onErrorMap (this ::unWrapException );
8898 });
@@ -101,22 +111,23 @@ private Mono<Void> attemptDecryptKey(final MongoKeyDecryptor keyDecryptor, @Null
101111 LOGGER .info ("Connecting to KMS server at " + serverAddress );
102112
103113 return Mono .<Boolean >create (sink -> {
104- Stream stream = streamFactory .create (serverAddress );
105114 OperationContext operationContext = createOperationContext (operationTimeout , socketSettings );
115+ Stream stream = streamFactory .create (serverAddress );
106116 stream .openAsync (operationContext , new AsyncCompletionHandler <Void >() {
107117 @ Override
108118 public void completed (@ Nullable final Void ignored ) {
109- streamWrite (stream , keyDecryptor , operationContext , sink );
119+ try {
120+ streamWrite (stream , keyDecryptor , operationContext , sink );
121+ } catch (Throwable t ) {
122+ stream .close ();
123+ sink .error (t );
124+ }
110125 }
111126
112127 @ Override
113128 public void failed (final Throwable t ) {
114129 stream .close ();
115- if (keyDecryptor .fail ()) {
116- sink .success (true );
117- } else {
118- handleError (t , operationContext , sink );
119- }
130+ failOrHandleError (t , keyDecryptor , operationContext , sink );
120131 }
121132 });
122133 }).flatMap (shouldRetry -> {
@@ -133,89 +144,84 @@ private void streamWrite(final Stream stream, final MongoKeyDecryptor keyDecrypt
133144 stream .writeAsync (byteBufs , operationContext , new AsyncCompletionHandler <Void >() {
134145 @ Override
135146 public void completed (@ Nullable final Void aVoid ) {
136- streamRead (stream , keyDecryptor , operationContext , sink , true );
147+ try {
148+ int readSize = Math .max (keyDecryptor .bytesNeeded (), MongoKeyDecryptor .DEFAULT_KMS_READ_SIZE );
149+ streamRead (stream , keyDecryptor , operationContext , sink , readSize );
150+ } catch (Throwable t ) {
151+ stream .close ();
152+ sink .error (t );
153+ }
137154 }
138155
139156 @ Override
140157 public void failed (final Throwable t ) {
141158 stream .close ();
142- if (keyDecryptor .fail ()) {
143- sink .success (true );
144- } else {
145- handleError (t , operationContext , sink );
146- }
159+ failOrHandleError (t , keyDecryptor , operationContext , sink );
147160 }
148161 });
149162 }
150163
151164 private void streamRead (final Stream stream , final MongoKeyDecryptor keyDecryptor ,
152165 final OperationContext operationContext , final MonoSink <Boolean > sink ,
153- final boolean firstRead ) {
154- int bytesNeeded = keyDecryptor .bytesNeeded ();
155- // After a fail()-triggered retry, libmongocrypt sets should_retry=true which causes
156- // bytesNeeded() to return 0 until feedAndRetry() clears the flag. The first read of a
157- // freshly-written request must happen regardless of bytesNeeded(); only recursive reads
158- // can trust bytesNeeded()==0 as "done".
159- if (bytesNeeded > 0 || firstRead ) {
160- int readSize = bytesNeeded > 0 ? bytesNeeded : MongoKeyDecryptor .DEFAULT_KMS_READ_SIZE ;
161- AsynchronousChannelStream asyncStream = (AsynchronousChannelStream ) stream ;
162- ByteBuf buffer = asyncStream .getBuffer (readSize );
163- long readTimeoutMS = operationContext .getTimeoutContext ().getReadTimeoutMS ();
164- asyncStream .getChannel ().read (buffer .asNIO (), readTimeoutMS , MILLISECONDS , null ,
165- new CompletionHandler <Integer , Void >() {
166+ final int readSize ) {
167+ if (readSize <= 0 ) {
168+ stream .close ();
169+ sink .success (false );
170+ return ;
171+ }
172+ AsynchronousChannelStream asyncStream = (AsynchronousChannelStream ) stream ;
173+ ByteBuf buffer = asyncStream .getBuffer (readSize );
174+ long readTimeoutMS = operationContext .getTimeoutContext ().getReadTimeoutMS ();
175+ asyncStream .getChannel ().read (buffer .asNIO (), readTimeoutMS , MILLISECONDS , null ,
176+ new CompletionHandler <Integer , Void >() {
166177
167- @ Override
168- public void completed (final Integer integer , final Void aVoid ) {
178+ @ Override
179+ public void completed (final Integer integer , final Void aVoid ) {
180+ try {
169181 if (integer == -1 ) {
170182 buffer .release ();
171183 stream .close ();
172- if (keyDecryptor .fail ()) {
173- sink .success (true );
174- } else {
175- sink .error (new MongoException (
176- "Unexpected end of stream from KMS provider "
177- + keyDecryptor .getKmsProvider ()));
178- }
184+ MongoException eof = new MongoException (
185+ "Unexpected end of stream from KMS provider "
186+ + keyDecryptor .getKmsProvider ());
187+ failOrHandleError (eof , keyDecryptor , operationContext , sink );
179188 return ;
180189 }
181190 buffer .flip ();
182191 boolean shouldRetry ;
183192 try {
184193 shouldRetry = keyDecryptor .feedAndRetry (buffer .asNIO ());
185- } catch (Throwable t ) {
186- stream .close ();
187- sink .error (t );
188- return ;
189194 } finally {
190195 buffer .release ();
191196 }
192197 if (shouldRetry ) {
193198 stream .close ();
194199 sink .success (true );
195200 } else {
196- streamRead (stream , keyDecryptor , operationContext , sink , false );
201+ streamRead (stream , keyDecryptor , operationContext , sink ,
202+ keyDecryptor .bytesNeeded ());
197203 }
198- }
199-
200- @ Override
201- public void failed (final Throwable t , final Void aVoid ) {
202- buffer .release ();
204+ } catch (Throwable t ) {
203205 stream .close ();
204- if (keyDecryptor .fail ()) {
205- sink .success (true );
206- } else {
207- handleError (t , operationContext , sink );
208- }
206+ sink .error (t );
209207 }
210- });
211- } else {
212- stream .close ();
213- sink .success (false );
214- }
208+ }
209+
210+ @ Override
211+ public void failed (final Throwable t , final Void aVoid ) {
212+ buffer .release ();
213+ stream .close ();
214+ failOrHandleError (t , keyDecryptor , operationContext , sink );
215+ }
216+ });
215217 }
216218
217- private static void handleError (final Throwable t , final OperationContext operationContext , final MonoSink <Boolean > sink ) {
218- if (isTimeoutException (t ) && operationContext .getTimeoutContext ().hasTimeoutMS ()) {
219+ private static void failOrHandleError (final Throwable t , final MongoKeyDecryptor keyDecryptor ,
220+ final OperationContext operationContext , final MonoSink <Boolean > sink ) {
221+ if (keyDecryptor .fail ()) {
222+ LOGGER .debug ("Retrying KMS request after transient error" , t );
223+ sink .success (true );
224+ } else if (isTimeoutException (t ) && operationContext .getTimeoutContext ().hasTimeoutMS ()) {
219225 sink .error (TimeoutContext .createMongoTimeoutException (TIMEOUT_ERROR_MESSAGE , t ));
220226 } else {
221227 sink .error (t );
0 commit comments