Skip to content

Commit 9b86913

Browse files
committed
Code review updates
1 parent 895b339 commit 9b86913

3 files changed

Lines changed: 97 additions & 79 deletions

File tree

driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/KeyManagementService.java

Lines changed: 67 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,22 @@ public void close() {
7777

7878
Mono<Void> decryptKey(final MongoKeyDecryptor keyDecryptor, @Nullable final Timeout operationTimeout) {
7979
return Mono.defer(() -> {
80-
Timeout.onExistsAndExpired(operationTimeout, () -> {
81-
throw TimeoutContext.createMongoTimeoutException(TIMEOUT_ERROR_MESSAGE);
82-
});
8380
long sleepMicros = keyDecryptor.sleepMicroseconds();
81+
if (sleepMicros > 0 && operationTimeout != null) {
82+
operationTimeout.run(MICROSECONDS,
83+
() -> { },
84+
remainingMicros -> {
85+
if (remainingMicros < sleepMicros) {
86+
throw TimeoutContext.createMongoTimeoutException(TIMEOUT_ERROR_MESSAGE);
87+
}
88+
},
89+
() -> {
90+
throw TimeoutContext.createMongoTimeoutException(TIMEOUT_ERROR_MESSAGE);
91+
});
92+
}
8493
Mono<Void> attempt = sleepMicros > 0
85-
? Mono.delay(Duration.ofNanos(MICROSECONDS.toNanos(sleepMicros))).then(attemptDecryptKey(keyDecryptor, operationTimeout))
94+
? Mono.delay(Duration.ofNanos(MICROSECONDS.toNanos(sleepMicros)))
95+
.then(attemptDecryptKey(keyDecryptor, operationTimeout))
8696
: attemptDecryptKey(keyDecryptor, operationTimeout);
8797
return attempt.onErrorMap(this::unWrapException);
8898
});
@@ -101,22 +111,23 @@ private Mono<Void> attemptDecryptKey(final MongoKeyDecryptor keyDecryptor, @Null
101111
LOGGER.info("Connecting to KMS server at " + serverAddress);
102112

103113
return Mono.<Boolean>create(sink -> {
104-
Stream stream = streamFactory.create(serverAddress);
105114
OperationContext operationContext = createOperationContext(operationTimeout, socketSettings);
115+
Stream stream = streamFactory.create(serverAddress);
106116
stream.openAsync(operationContext, new AsyncCompletionHandler<Void>() {
107117
@Override
108118
public void completed(@Nullable final Void ignored) {
109-
streamWrite(stream, keyDecryptor, operationContext, sink);
119+
try {
120+
streamWrite(stream, keyDecryptor, operationContext, sink);
121+
} catch (Throwable t) {
122+
stream.close();
123+
sink.error(t);
124+
}
110125
}
111126

112127
@Override
113128
public void failed(final Throwable t) {
114129
stream.close();
115-
if (keyDecryptor.fail()) {
116-
sink.success(true);
117-
} else {
118-
handleError(t, operationContext, sink);
119-
}
130+
failOrHandleError(t, keyDecryptor, operationContext, sink);
120131
}
121132
});
122133
}).flatMap(shouldRetry -> {
@@ -133,89 +144,84 @@ private void streamWrite(final Stream stream, final MongoKeyDecryptor keyDecrypt
133144
stream.writeAsync(byteBufs, operationContext, new AsyncCompletionHandler<Void>() {
134145
@Override
135146
public void completed(@Nullable final Void aVoid) {
136-
streamRead(stream, keyDecryptor, operationContext, sink, true);
147+
try {
148+
int readSize = Math.max(keyDecryptor.bytesNeeded(), MongoKeyDecryptor.DEFAULT_KMS_READ_SIZE);
149+
streamRead(stream, keyDecryptor, operationContext, sink, readSize);
150+
} catch (Throwable t) {
151+
stream.close();
152+
sink.error(t);
153+
}
137154
}
138155

139156
@Override
140157
public void failed(final Throwable t) {
141158
stream.close();
142-
if (keyDecryptor.fail()) {
143-
sink.success(true);
144-
} else {
145-
handleError(t, operationContext, sink);
146-
}
159+
failOrHandleError(t, keyDecryptor, operationContext, sink);
147160
}
148161
});
149162
}
150163

151164
private void streamRead(final Stream stream, final MongoKeyDecryptor keyDecryptor,
152165
final OperationContext operationContext, final MonoSink<Boolean> sink,
153-
final boolean firstRead) {
154-
int bytesNeeded = keyDecryptor.bytesNeeded();
155-
// After a fail()-triggered retry, libmongocrypt sets should_retry=true which causes
156-
// bytesNeeded() to return 0 until feedAndRetry() clears the flag. The first read of a
157-
// freshly-written request must happen regardless of bytesNeeded(); only recursive reads
158-
// can trust bytesNeeded()==0 as "done".
159-
if (bytesNeeded > 0 || firstRead) {
160-
int readSize = bytesNeeded > 0 ? bytesNeeded : MongoKeyDecryptor.DEFAULT_KMS_READ_SIZE;
161-
AsynchronousChannelStream asyncStream = (AsynchronousChannelStream) stream;
162-
ByteBuf buffer = asyncStream.getBuffer(readSize);
163-
long readTimeoutMS = operationContext.getTimeoutContext().getReadTimeoutMS();
164-
asyncStream.getChannel().read(buffer.asNIO(), readTimeoutMS, MILLISECONDS, null,
165-
new CompletionHandler<Integer, Void>() {
166+
final int readSize) {
167+
if (readSize <= 0) {
168+
stream.close();
169+
sink.success(false);
170+
return;
171+
}
172+
AsynchronousChannelStream asyncStream = (AsynchronousChannelStream) stream;
173+
ByteBuf buffer = asyncStream.getBuffer(readSize);
174+
long readTimeoutMS = operationContext.getTimeoutContext().getReadTimeoutMS();
175+
asyncStream.getChannel().read(buffer.asNIO(), readTimeoutMS, MILLISECONDS, null,
176+
new CompletionHandler<Integer, Void>() {
166177

167-
@Override
168-
public void completed(final Integer integer, final Void aVoid) {
178+
@Override
179+
public void completed(final Integer integer, final Void aVoid) {
180+
try {
169181
if (integer == -1) {
170182
buffer.release();
171183
stream.close();
172-
if (keyDecryptor.fail()) {
173-
sink.success(true);
174-
} else {
175-
sink.error(new MongoException(
176-
"Unexpected end of stream from KMS provider "
177-
+ keyDecryptor.getKmsProvider()));
178-
}
184+
MongoException eof = new MongoException(
185+
"Unexpected end of stream from KMS provider "
186+
+ keyDecryptor.getKmsProvider());
187+
failOrHandleError(eof, keyDecryptor, operationContext, sink);
179188
return;
180189
}
181190
buffer.flip();
182191
boolean shouldRetry;
183192
try {
184193
shouldRetry = keyDecryptor.feedAndRetry(buffer.asNIO());
185-
} catch (Throwable t) {
186-
stream.close();
187-
sink.error(t);
188-
return;
189194
} finally {
190195
buffer.release();
191196
}
192197
if (shouldRetry) {
193198
stream.close();
194199
sink.success(true);
195200
} else {
196-
streamRead(stream, keyDecryptor, operationContext, sink, false);
201+
streamRead(stream, keyDecryptor, operationContext, sink,
202+
keyDecryptor.bytesNeeded());
197203
}
198-
}
199-
200-
@Override
201-
public void failed(final Throwable t, final Void aVoid) {
202-
buffer.release();
204+
} catch (Throwable t) {
203205
stream.close();
204-
if (keyDecryptor.fail()) {
205-
sink.success(true);
206-
} else {
207-
handleError(t, operationContext, sink);
208-
}
206+
sink.error(t);
209207
}
210-
});
211-
} else {
212-
stream.close();
213-
sink.success(false);
214-
}
208+
}
209+
210+
@Override
211+
public void failed(final Throwable t, final Void aVoid) {
212+
buffer.release();
213+
stream.close();
214+
failOrHandleError(t, keyDecryptor, operationContext, sink);
215+
}
216+
});
215217
}
216218

217-
private static void handleError(final Throwable t, final OperationContext operationContext, final MonoSink<Boolean> sink) {
218-
if (isTimeoutException(t) && operationContext.getTimeoutContext().hasTimeoutMS()) {
219+
private static void failOrHandleError(final Throwable t, final MongoKeyDecryptor keyDecryptor,
220+
final OperationContext operationContext, final MonoSink<Boolean> sink) {
221+
if (keyDecryptor.fail()) {
222+
LOGGER.debug("Retrying KMS request after transient error", t);
223+
sink.success(true);
224+
} else if (isTimeoutException(t) && operationContext.getTimeoutContext().hasTimeoutMS()) {
219225
sink.error(TimeoutContext.createMongoTimeoutException(TIMEOUT_ERROR_MESSAGE, t));
220226
} else {
221227
sink.error(t);

driver-sync/src/main/com/mongodb/client/internal/Crypt.java

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
import com.mongodb.internal.TimeoutContext;
2828
import com.mongodb.internal.capi.MongoCryptHelper;
2929
import com.mongodb.internal.crypt.capi.MongoCrypt;
30+
import com.mongodb.internal.diagnostics.logging.Logger;
31+
import com.mongodb.internal.diagnostics.logging.Loggers;
3032
import com.mongodb.internal.crypt.capi.MongoCryptContext;
3133
import com.mongodb.internal.crypt.capi.MongoDataKeyOptions;
3234
import com.mongodb.internal.crypt.capi.MongoKeyDecryptor;
@@ -59,6 +61,8 @@
5961
*/
6062
public class Crypt implements Closeable {
6163

64+
private static final Logger LOGGER = Loggers.getLogger("client");
65+
private static final String TIMEOUT_ERROR_MESSAGE = "KMS key decryption exceeded the timeout limit.";
6266
private static final RawBsonDocument EMPTY_RAW_BSON_DOCUMENT = RawBsonDocument.parse("{}");
6367
private final MongoCrypt mongoCrypt;
6468
private final Map<String, Map<String, Object>> kmsProviders;
@@ -367,11 +371,22 @@ private void decryptKeys(final MongoCryptContext cryptContext, @Nullable final T
367371
private void decryptKey(final MongoKeyDecryptor keyDecryptor, @Nullable final Timeout operationTimeout)
368372
throws IOException, InterruptedException {
369373
while (true) {
370-
Timeout.onExistsAndExpired(operationTimeout, () -> {
371-
throw TimeoutContext.createMongoTimeoutException("KMS key decryption exceeded the timeout limit.");
372-
});
373374
long sleepMicros = keyDecryptor.sleepMicroseconds();
374375
if (sleepMicros > 0) {
376+
if (operationTimeout != null) {
377+
operationTimeout.run(TimeUnit.MICROSECONDS,
378+
() -> { },
379+
remainingMicros -> {
380+
if (remainingMicros < sleepMicros) {
381+
throw TimeoutContext.createMongoTimeoutException(
382+
TIMEOUT_ERROR_MESSAGE);
383+
}
384+
},
385+
() -> {
386+
throw TimeoutContext.createMongoTimeoutException(
387+
TIMEOUT_ERROR_MESSAGE);
388+
});
389+
}
375390
TimeUnit.MICROSECONDS.sleep(sleepMicros);
376391
}
377392
boolean shouldRetry;
@@ -381,6 +396,7 @@ private void decryptKey(final MongoKeyDecryptor keyDecryptor, @Nullable final Ti
381396
if (!keyDecryptor.fail()) {
382397
throw e;
383398
}
399+
LOGGER.debug("Retrying KMS request after transient error", e);
384400
continue;
385401
}
386402
if (!shouldRetry) {
@@ -391,31 +407,27 @@ private void decryptKey(final MongoKeyDecryptor keyDecryptor, @Nullable final Ti
391407

392408
private boolean attemptDecryptKey(final MongoKeyDecryptor keyDecryptor, @Nullable final Timeout operationTimeout)
393409
throws IOException {
394-
// After a fail()-triggered retry, libmongocrypt sets should_retry=true which causes
395-
// bytesNeeded() to return 0 until feedAndRetry() clears the flag. The first read of a
396-
// freshly-written request must happen regardless of bytesNeeded(); only subsequent
397-
// iterations can trust bytesNeeded()==0 as "done".
410+
Timeout.onExistsAndExpired(operationTimeout, () -> {
411+
throw TimeoutContext.createMongoTimeoutException(TIMEOUT_ERROR_MESSAGE);
412+
});
413+
// After a fail()-triggered retry, bytesNeeded() may return 0 until feedAndRetry() clears
414+
// the flag, so the do-while guarantees at least one read using DEFAULT_KMS_READ_SIZE.
398415
InputStream inputStream = keyManagementService.stream(keyDecryptor.getKmsProvider(), keyDecryptor.getHostName(),
399416
keyDecryptor.getMessage(), operationTimeout);
400417
Throwable primary = null;
401418
try {
402-
int bytesNeeded = keyDecryptor.bytesNeeded();
403-
boolean firstRead = true;
404-
while (bytesNeeded > 0 || firstRead) {
405-
int readSize = bytesNeeded > 0 ? bytesNeeded : MongoKeyDecryptor.DEFAULT_KMS_READ_SIZE;
419+
int readSize = Math.max(keyDecryptor.bytesNeeded(), MongoKeyDecryptor.DEFAULT_KMS_READ_SIZE);
420+
do {
406421
byte[] bytes = new byte[readSize];
407422
int bytesRead = inputStream.read(bytes, 0, bytes.length);
408423
if (bytesRead == -1) {
409-
// Surface EOF as IOException so the retry loop in decryptKey() can call fail() on it,
410-
// the same way connect/write/read IOExceptions are handled.
411424
throw new EOFException("Unexpected end of stream from KMS provider " + keyDecryptor.getKmsProvider());
412425
}
413426
if (keyDecryptor.feedAndRetry(ByteBuffer.wrap(bytes, 0, bytesRead))) {
414427
return true;
415428
}
416-
bytesNeeded = keyDecryptor.bytesNeeded();
417-
firstRead = false;
418-
}
429+
readSize = keyDecryptor.bytesNeeded();
430+
} while (readSize > 0);
419431
return false;
420432
} catch (Throwable t) {
421433
primary = t;

driver-sync/src/test/functional/com/mongodb/client/AbstractClientSideEncryptionKmsRetryTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
import static com.mongodb.ClusterFixture.serverVersionAtLeast;
5353
import static com.mongodb.client.Fixture.getMongoClientSettings;
5454
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
55-
import static org.junit.jupiter.api.Assertions.assertNotNull;
55+
import static org.junit.jupiter.api.Assertions.assertEquals;
5656
import static org.junit.jupiter.api.Assertions.assertThrows;
5757
import static org.junit.jupiter.api.Assumptions.assumeTrue;
5858

@@ -240,7 +240,7 @@ private static void setFailpoint(final String failpointType, final int count) {
240240
}
241241

242242
int responseCode = connection.getResponseCode();
243-
assertNotNull(responseCode);
243+
assertEquals(200, responseCode, "Failed to set KMS failpoint, HTTP status: " + responseCode);
244244
connection.disconnect();
245245
} catch (Exception e) {
246246
throw new RuntimeException("Failed to set KMS failpoint", e);

0 commit comments

Comments
 (0)