From 77f99cc5caa519116daab3e850fd1b468b0fa971 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Wed, 29 Apr 2026 17:49:06 +0100 Subject: [PATCH] Retry KMS requests on transient errors Add libmongocrypt CAPI bindings for KMS retry support and wire retry logic through the sync and reactive driver stacks. Transient KMS HTTP and network errors are retried with backoff delays managed by libmongocrypt; retry is enabled unconditionally. - Add native bindings: mongocrypt_setopt_retry_kms, mongocrypt_kms_ctx_usleep, mongocrypt_kms_ctx_feed_with_retry, mongocrypt_kms_ctx_fail - Add sleepMicroseconds(), feedAndRetry(), fail() to MongoKeyDecryptor - Enable KMS retry unconditionally in MongoCryptImpl - Rewrite sync Crypt.decryptKey() with retry loop, timeout-aware - Add retry logic to reactive KeyManagementService.decryptKey() - Fix TlsChannelImpl.read() to preserve bytes delivered alongside close_notify (already fixed upstream in marianobarrios/tls-channel) - Add spec Section 24 KMS retry integration tests (sync + reactive) - Add Evergreen CI task for KMS retry tests JAVA-5391 --- .evergreen/.evg.yml | 27 ++ .evergreen/run-kms-retry-tests.sh | 44 +++ .../tlschannel/impl/TlsChannelImpl.java | 7 +- .../internal/crypt/KeyManagementService.java | 148 ++++++--- ...ClientSideEncryptionKmsRetryProseTest.java | 30 ++ .../com/mongodb/client/internal/Crypt.java | 89 +++++- ...ClientSideEncryptionKmsRetryProseTest.java | 282 ++++++++++++++++++ ...ClientSideEncryptionKmsRetryProseTest.java | 28 ++ .../com/mongodb/internal/crypt/capi/CAPI.java | 57 ++++ .../internal/crypt/capi/MongoCryptImpl.java | 3 + .../crypt/capi/MongoKeyDecryptor.java | 34 +++ .../crypt/capi/MongoKeyDecryptorImpl.java | 32 ++ 12 files changed, 726 insertions(+), 55 deletions(-) create mode 100755 .evergreen/run-kms-retry-tests.sh create mode 100644 driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/ClientSideEncryptionKmsRetryProseTest.java create mode 100644 driver-sync/src/test/functional/com/mongodb/client/AbstractClientSideEncryptionKmsRetryProseTest.java create mode 100644 driver-sync/src/test/functional/com/mongodb/client/ClientSideEncryptionKmsRetryProseTest.java diff --git a/.evergreen/.evg.yml b/.evergreen/.evg.yml index 7de253baac6..b0772260acd 100644 --- a/.evergreen/.evg.yml +++ b/.evergreen/.evg.yml @@ -764,6 +764,16 @@ functions: set +o xtrace MONGODB_URI="${MONGODB_URI}" KMS_TLS_ERROR_TYPE=${KMS_TLS_ERROR_TYPE} .evergreen/run-kms-tls-tests.sh + "run-kms-retry-test": + - command: shell.exec + type: "test" + params: + working_dir: "src" + script: | + ${PREPARE_SHELL} + set +o xtrace + MONGODB_URI="${MONGODB_URI}" .evergreen/run-kms-retry-tests.sh + "run-csfle-aws-from-environment-test": - command: shell.exec type: "test" @@ -1632,6 +1642,17 @@ tasks: AUTH: "noauth" SSL: "nossl" + - name: "test-kms-retry-task" + tags: [ "kms-retry" ] + commands: + - func: "start-mongo-orchestration" + vars: + TOPOLOGY: "server" + AUTH: "noauth" + SSL: "nossl" + - func: "start-csfle-servers" + - func: "run-kms-retry-test" + - name: "test-csfle-aws-from-environment-task" tags: [ "csfle-aws-from-environment" ] commands: @@ -2528,6 +2549,12 @@ buildvariants: tasks: - name: ".kms-tls" + - matrix_name: "kms-retry-test" + matrix_spec: { os: "linux", version: [ "5.0" ], topology: [ "standalone" ] } + display_name: "CSFLE KMS Retry" + tasks: + - name: ".kms-retry" + - matrix_name: "csfle-aws-from-environment-test" matrix_spec: { os: "linux", version: [ "5.0" ], topology: [ "standalone" ] } display_name: "CSFLE AWS From Environment" diff --git a/.evergreen/run-kms-retry-tests.sh b/.evergreen/run-kms-retry-tests.sh new file mode 100755 index 00000000000..c9a96ddafa2 --- /dev/null +++ b/.evergreen/run-kms-retry-tests.sh @@ -0,0 +1,44 @@ +#!/bin/bash + +# Don't trace since the URI contains a password that shouldn't show up in the logs +set -o errexit # Exit the script with error if any of the commands fail + +# Supported/used environment variables: +# MONGODB_URI Set the suggested connection MONGODB_URI (including credentials and topology info) + +############################################ +# Main Program # +############################################ +RELATIVE_DIR_PATH="$(dirname "${BASH_SOURCE:-$0}")" +. "${RELATIVE_DIR_PATH}/setup-env.bash" +echo "Running KMS Retry tests" + +cp ${JAVA_HOME}/lib/security/cacerts mongo-truststore +${JAVA_HOME}/bin/keytool -importcert -trustcacerts -file ${DRIVERS_TOOLS}/.evergreen/x509gen/ca.pem -keystore mongo-truststore -storepass changeit -storetype JKS -noprompt + +export GRADLE_EXTRA_VARS="-Pssl.enabled=true -Pssl.trustStoreType=jks -Pssl.trustStore=`pwd`/mongo-truststore -Pssl.trustStorePassword=changeit" + +./gradlew -version + +# Disable errexit so both suites run and their exit codes can be captured below. +set +o errexit + +./gradlew --stacktrace --info ${GRADLE_EXTRA_VARS} -Dorg.mongodb.test.uri=${MONGODB_URI} \ + -Dorg.mongodb.test.kms.retry.ca.path="${DRIVERS_TOOLS}/.evergreen/x509gen/ca.pem" \ + driver-sync:cleanTest driver-sync:test --tests ClientSideEncryptionKmsRetryProseTest +first=$? +echo $first + +./gradlew --stacktrace --info ${GRADLE_EXTRA_VARS} -Dorg.mongodb.test.uri=${MONGODB_URI} \ + -Dorg.mongodb.test.kms.retry.ca.path="${DRIVERS_TOOLS}/.evergreen/x509gen/ca.pem" \ + driver-reactive-streams:cleanTest driver-reactive-streams:test --tests ClientSideEncryptionKmsRetryProseTest +second=$? +echo $second + +if [ $first -ne 0 ]; then + exit $first +elif [ $second -ne 0 ]; then + exit $second +else + exit 0 +fi diff --git a/driver-core/src/main/com/mongodb/internal/connection/tlschannel/impl/TlsChannelImpl.java b/driver-core/src/main/com/mongodb/internal/connection/tlschannel/impl/TlsChannelImpl.java index 20bc69e81f0..2daec943e5d 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/tlschannel/impl/TlsChannelImpl.java +++ b/driver-core/src/main/com/mongodb/internal/connection/tlschannel/impl/TlsChannelImpl.java @@ -236,10 +236,13 @@ public long read(ByteBufferSet dest) throws IOException { case NOT_HANDSHAKING: case FINISHED: UnwrapResult res = readAndUnwrap(Optional.of(dest)); + bytesToReturn = res.bytesProduced; if (res.wasClosed) { - return -1; + // JAVA-5391: return any bytes produced alongside close_notify; the next read + // sees shutdownReceived and returns -1. Fixed in upstream marianobarrios/tls-channel; + // this is the minimal patch until the vendored snapshot is refreshed. + return bytesToReturn > 0 ? bytesToReturn : -1; } - bytesToReturn = res.bytesProduced; handshakeStatus = res.lastHandshakeStatus; break; case NEED_TASK: diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/KeyManagementService.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/KeyManagementService.java index 67ebf421c9c..076c8743f1b 100644 --- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/KeyManagementService.java +++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/KeyManagementService.java @@ -48,10 +48,12 @@ import java.io.Closeable; import java.nio.channels.CompletionHandler; import java.nio.channels.InterruptedByTimeoutException; +import java.time.Duration; import java.util.List; import java.util.Map; import static java.util.Collections.singletonList; +import static java.util.concurrent.TimeUnit.MICROSECONDS; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.bson.assertions.Assertions.assertTrue; @@ -74,6 +76,29 @@ public void close() { } Mono decryptKey(final MongoKeyDecryptor keyDecryptor, @Nullable final Timeout operationTimeout) { + return Mono.defer(() -> { + long sleepMicros = keyDecryptor.sleepMicroseconds(); + if (sleepMicros > 0 && operationTimeout != null) { + operationTimeout.run(MICROSECONDS, + () -> { }, + remainingMicros -> { + if (remainingMicros < sleepMicros) { + throw TimeoutContext.createMongoTimeoutException(TIMEOUT_ERROR_MESSAGE); + } + }, + () -> { + throw TimeoutContext.createMongoTimeoutException(TIMEOUT_ERROR_MESSAGE); + }); + } + Mono attempt = sleepMicros > 0 + ? Mono.delay(Duration.ofNanos(MICROSECONDS.toNanos(sleepMicros))) + .then(attemptDecryptKey(keyDecryptor, operationTimeout)) + : attemptDecryptKey(keyDecryptor, operationTimeout); + return attempt; + }).onErrorMap(this::unWrapException); + } + + private Mono attemptDecryptKey(final MongoKeyDecryptor keyDecryptor, @Nullable final Timeout operationTimeout) { SocketSettings socketSettings = SocketSettings.builder() .connectTimeout(timeoutMillis, MILLISECONDS) .readTimeout(timeoutMillis, MILLISECONDS) @@ -85,84 +110,119 @@ Mono decryptKey(final MongoKeyDecryptor keyDecryptor, @Nullable final Time LOGGER.info("Connecting to KMS server at " + serverAddress); - return Mono.create(sink -> { - Stream stream = streamFactory.create(serverAddress); + return Mono.create(sink -> { OperationContext operationContext = createOperationContext(operationTimeout, socketSettings); + Stream stream = streamFactory.create(serverAddress); stream.openAsync(operationContext, new AsyncCompletionHandler() { @Override public void completed(@Nullable final Void ignored) { - streamWrite(stream, keyDecryptor, operationContext, sink); + try { + streamWrite(stream, keyDecryptor, operationContext, sink); + } catch (Throwable t) { + stream.close(); + sink.error(t); + } } @Override public void failed(final Throwable t) { stream.close(); - handleError(t, operationContext, sink); + failOrHandleError(t, keyDecryptor, operationContext, sink); } }); - }).onErrorMap(this::unWrapException); + }).flatMap(shouldRetry -> { + if (shouldRetry) { + return decryptKey(keyDecryptor, operationTimeout); + } + return Mono.empty(); + }); } private void streamWrite(final Stream stream, final MongoKeyDecryptor keyDecryptor, - final OperationContext operationContext, final MonoSink sink) { + final OperationContext operationContext, final MonoSink sink) { List byteBufs = singletonList(new ByteBufNIO(keyDecryptor.getMessage())); stream.writeAsync(byteBufs, operationContext, new AsyncCompletionHandler() { @Override public void completed(@Nullable final Void aVoid) { - streamRead(stream, keyDecryptor, operationContext, sink); + try { + int bytesNeeded = keyDecryptor.bytesNeeded(); + int readSize = bytesNeeded > 0 ? bytesNeeded : MongoKeyDecryptor.DEFAULT_KMS_READ_SIZE; + streamRead(stream, keyDecryptor, operationContext, sink, readSize); + } catch (Throwable t) { + stream.close(); + sink.error(t); + } } @Override public void failed(final Throwable t) { stream.close(); - handleError(t, operationContext, sink); + failOrHandleError(t, keyDecryptor, operationContext, sink); } }); } private void streamRead(final Stream stream, final MongoKeyDecryptor keyDecryptor, - final OperationContext operationContext, final MonoSink sink) { - int bytesNeeded = keyDecryptor.bytesNeeded(); - if (bytesNeeded > 0) { - AsynchronousChannelStream asyncStream = (AsynchronousChannelStream) stream; - ByteBuf buffer = asyncStream.getBuffer(bytesNeeded); - long readTimeoutMS = operationContext.getTimeoutContext().getReadTimeoutMS(); - asyncStream.getChannel().read(buffer.asNIO(), readTimeoutMS, MILLISECONDS, null, - new CompletionHandler() { - - @Override - public void completed(final Integer integer, final Void aVoid) { - if (integer == -1) { - sink.error(new MongoException( - "Unexpected end of stream from KMS provider " + keyDecryptor.getKmsProvider())); - return; - } - buffer.flip(); - try { - keyDecryptor.feed(buffer.asNIO()); - buffer.release(); - streamRead(stream, keyDecryptor, operationContext, sink); - } catch (Throwable t) { - sink.error(t); - } - } - - @Override - public void failed(final Throwable t, final Void aVoid) { - buffer.release(); - stream.close(); - handleError(t, operationContext, sink); - } - }); - } else { + final OperationContext operationContext, final MonoSink sink, + final int readSize) { + if (readSize <= 0) { stream.close(); - sink.success(); + sink.success(false); + return; } + AsynchronousChannelStream asyncStream = (AsynchronousChannelStream) stream; + ByteBuf buffer = asyncStream.getBuffer(readSize); + long readTimeoutMS = operationContext.getTimeoutContext().getReadTimeoutMS(); + asyncStream.getChannel().read(buffer.asNIO(), readTimeoutMS, MILLISECONDS, null, + new CompletionHandler() { + + @Override + public void completed(final Integer integer, final Void aVoid) { + try { + if (integer == -1) { + buffer.release(); + stream.close(); + MongoException eof = new MongoException("Unexpected end of stream from KMS provider " + + keyDecryptor.getKmsProvider()); + failOrHandleError(eof, keyDecryptor, operationContext, sink); + return; + } + buffer.flip(); + boolean shouldRetry; + try { + shouldRetry = keyDecryptor.feedAndRetry(buffer.asNIO()); + } finally { + buffer.release(); + } + if (shouldRetry) { + stream.close(); + sink.success(true); + } else { + streamRead(stream, keyDecryptor, operationContext, sink, + keyDecryptor.bytesNeeded()); + } + } catch (Throwable t) { + stream.close(); + sink.error(t); + } + } + + @Override + public void failed(final Throwable t, final Void aVoid) { + buffer.release(); + stream.close(); + failOrHandleError(t, keyDecryptor, operationContext, sink); + } + }); } - private static void handleError(final Throwable t, final OperationContext operationContext, final MonoSink sink) { + private static void failOrHandleError(final Throwable t, final MongoKeyDecryptor keyDecryptor, + final OperationContext operationContext, final MonoSink sink) { if (isTimeoutException(t) && operationContext.getTimeoutContext().hasTimeoutMS()) { sink.error(TimeoutContext.createMongoTimeoutException(TIMEOUT_ERROR_MESSAGE, t)); + } else if (keyDecryptor.fail()) { + LOGGER.debug("Retrying KMS request after transient error", t); + sink.success(true); } else { sink.error(t); } diff --git a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/ClientSideEncryptionKmsRetryProseTest.java b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/ClientSideEncryptionKmsRetryProseTest.java new file mode 100644 index 00000000000..1dd4f0f833d --- /dev/null +++ b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/ClientSideEncryptionKmsRetryProseTest.java @@ -0,0 +1,30 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.reactivestreams.client; + +import com.mongodb.ClientEncryptionSettings; +import com.mongodb.client.AbstractClientSideEncryptionKmsRetryProseTest; +import com.mongodb.client.vault.ClientEncryption; +import com.mongodb.reactivestreams.client.syncadapter.SyncClientEncryption; +import com.mongodb.reactivestreams.client.vault.ClientEncryptions; + +public class ClientSideEncryptionKmsRetryProseTest extends AbstractClientSideEncryptionKmsRetryProseTest { + @Override + public ClientEncryption getClientEncryption(final ClientEncryptionSettings settings) { + return new SyncClientEncryption(ClientEncryptions.create(settings)); + } +} diff --git a/driver-sync/src/main/com/mongodb/client/internal/Crypt.java b/driver-sync/src/main/com/mongodb/client/internal/Crypt.java index 67fac13770c..6039d64ae98 100644 --- a/driver-sync/src/main/com/mongodb/client/internal/Crypt.java +++ b/driver-sync/src/main/com/mongodb/client/internal/Crypt.java @@ -24,8 +24,11 @@ import com.mongodb.client.model.vault.EncryptOptions; import com.mongodb.client.model.vault.RewrapManyDataKeyOptions; import com.mongodb.crypt.capi.MongoCryptException; +import com.mongodb.internal.TimeoutContext; import com.mongodb.internal.capi.MongoCryptHelper; import com.mongodb.internal.crypt.capi.MongoCrypt; +import com.mongodb.internal.diagnostics.logging.Logger; +import com.mongodb.internal.diagnostics.logging.Loggers; import com.mongodb.internal.crypt.capi.MongoCryptContext; import com.mongodb.internal.crypt.capi.MongoDataKeyOptions; import com.mongodb.internal.crypt.capi.MongoKeyDecryptor; @@ -38,11 +41,13 @@ import org.bson.RawBsonDocument; import java.io.Closeable; +import java.io.EOFException; import java.io.IOException; import java.io.InputStream; import java.nio.ByteBuffer; import java.util.List; import java.util.Map; +import java.util.concurrent.TimeUnit; import java.util.function.Supplier; import static com.mongodb.assertions.Assertions.assertNotNull; @@ -56,6 +61,8 @@ */ public class Crypt implements Closeable { + private static final Logger LOGGER = Loggers.getLogger("client"); + private static final String TIMEOUT_ERROR_MESSAGE = "KMS key decryption exceeded the timeout limit."; private static final RawBsonDocument EMPTY_RAW_BSON_DOCUMENT = RawBsonDocument.parse("{}"); private final MongoCrypt mongoCrypt; private final Map> kmsProviders; @@ -361,19 +368,83 @@ private void decryptKeys(final MongoCryptContext cryptContext, @Nullable final T } } - private void decryptKey(final MongoKeyDecryptor keyDecryptor, @Nullable final Timeout operationTimeout) throws IOException { - try (InputStream inputStream = keyManagementService.stream(keyDecryptor.getKmsProvider(), keyDecryptor.getHostName(), - keyDecryptor.getMessage(), operationTimeout)) { - int bytesNeeded = keyDecryptor.bytesNeeded(); + private void decryptKey(final MongoKeyDecryptor keyDecryptor, @Nullable final Timeout operationTimeout) + throws IOException, InterruptedException { + while (true) { + long sleepMicros = keyDecryptor.sleepMicroseconds(); + if (sleepMicros > 0) { + if (operationTimeout != null) { + operationTimeout.run(TimeUnit.MICROSECONDS, + () -> { }, + remainingMicros -> { + if (remainingMicros < sleepMicros) { + throw TimeoutContext.createMongoTimeoutException( + TIMEOUT_ERROR_MESSAGE); + } + }, + () -> { + throw TimeoutContext.createMongoTimeoutException( + TIMEOUT_ERROR_MESSAGE); + }); + } + TimeUnit.MICROSECONDS.sleep(sleepMicros); + } + boolean shouldRetry; + try { + shouldRetry = attemptDecryptKey(keyDecryptor, operationTimeout); + } catch (IOException e) { + if (!keyDecryptor.fail()) { + throw e; + } + LOGGER.debug("Retrying KMS request after transient error", e); + continue; + } + if (!shouldRetry) { + return; + } + } + } - while (bytesNeeded > 0) { - byte[] bytes = new byte[bytesNeeded]; + private boolean attemptDecryptKey(final MongoKeyDecryptor keyDecryptor, @Nullable final Timeout operationTimeout) + throws IOException { + Timeout.onExistsAndExpired(operationTimeout, () -> { + throw TimeoutContext.createMongoTimeoutException(TIMEOUT_ERROR_MESSAGE); + }); + // After a fail()-triggered retry, bytesNeeded() may return 0 until feedAndRetry() clears + // the flag, so the do-while guarantees at least one read using DEFAULT_KMS_READ_SIZE. + InputStream inputStream = keyManagementService.stream(keyDecryptor.getKmsProvider(), keyDecryptor.getHostName(), + keyDecryptor.getMessage(), operationTimeout); + Throwable primary = null; + try { + int bytesNeeded = keyDecryptor.bytesNeeded(); + int readSize = bytesNeeded > 0 ? bytesNeeded : MongoKeyDecryptor.DEFAULT_KMS_READ_SIZE; + do { + byte[] bytes = new byte[readSize]; int bytesRead = inputStream.read(bytes, 0, bytes.length); if (bytesRead == -1) { - throw new MongoException("Unexpected end of stream from KMS provider " + keyDecryptor.getKmsProvider()); + throw new EOFException("Unexpected end of stream from KMS provider " + keyDecryptor.getKmsProvider()); + } + if (keyDecryptor.feedAndRetry(ByteBuffer.wrap(bytes, 0, bytesRead))) { + return true; + } + readSize = keyDecryptor.bytesNeeded(); + } while (readSize > 0); + return false; + } catch (Throwable t) { + primary = t; + throw t; + } finally { + // If the feed loop succeeded, suppress close() failures so they do not trigger a retry + // of an already-complete KMS exchange. If the feed loop threw, preserve the primary + // exception and attach any close() failure as suppressed — matching try-with-resources. + try { + inputStream.close(); + } catch (IOException closeException) { + if (primary != null) { + primary.addSuppressed(closeException); + } else { + LOGGER.debug("Ignoring close() failure after successful KMS exchange", closeException); } - keyDecryptor.feed(ByteBuffer.wrap(bytes, 0, bytesRead)); - bytesNeeded = keyDecryptor.bytesNeeded(); } } } diff --git a/driver-sync/src/test/functional/com/mongodb/client/AbstractClientSideEncryptionKmsRetryProseTest.java b/driver-sync/src/test/functional/com/mongodb/client/AbstractClientSideEncryptionKmsRetryProseTest.java new file mode 100644 index 00000000000..e7aa620d8dd --- /dev/null +++ b/driver-sync/src/test/functional/com/mongodb/client/AbstractClientSideEncryptionKmsRetryProseTest.java @@ -0,0 +1,282 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.client; + +import com.mongodb.ClientEncryptionSettings; +import com.mongodb.MongoClientException; +import com.mongodb.MongoOperationTimeoutException; +import com.mongodb.client.model.vault.DataKeyOptions; +import com.mongodb.client.model.vault.EncryptOptions; +import com.mongodb.client.vault.ClientEncryption; +import com.mongodb.lang.NonNull; +import com.mongodb.lang.Nullable; +import org.bson.BsonBinary; +import org.bson.BsonDocument; +import org.bson.BsonInt32; +import org.bson.BsonString; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import javax.net.ssl.HttpsURLConnection; +import javax.net.ssl.SSLContext; +import javax.net.ssl.TrustManagerFactory; +import java.io.FileInputStream; +import java.io.OutputStream; +import java.net.URL; +import java.nio.charset.StandardCharsets; +import java.security.KeyStore; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static com.mongodb.ClusterFixture.getEnv; +import static com.mongodb.ClusterFixture.hasEncryptionTestsEnabled; +import static com.mongodb.ClusterFixture.serverVersionAtLeast; +import static com.mongodb.client.Fixture.getMongoClientSettings; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +/** + * See + * 24. KMS Retry Tests. + * + *

Requires the {@code org.mongodb.test.kms.retry.ca.path} system property pointing to the CA cert for the + * failpoint server. + */ +public abstract class AbstractClientSideEncryptionKmsRetryProseTest { + + private static final String FAILPOINT_SERVER_ADDRESS = "127.0.0.1:9003"; + private static final String FAILPOINT_URL_BASE = "https://" + FAILPOINT_SERVER_ADDRESS; + + @NonNull + protected abstract ClientEncryption getClientEncryption(ClientEncryptionSettings settings); + + @BeforeEach + public void setUp() { + assumeTrue(System.getProperty("org.mongodb.test.kms.retry.ca.path") != null, + "org.mongodb.test.kms.retry.ca.path system property is not set"); + } + + /** + * Case 1: createDataKey and encrypt with TCP retry. + */ + @ParameterizedTest(name = "Case 1: TCP retry with {0}") + @ValueSource(strings = {"aws", "azure", "gcp"}) + public void testCreateDataKeyAndEncryptWithTcpRetry(final String provider) { + assumeTrue(hasEncryptionTestsEnabled()); + assumeTrue(serverVersionAtLeast(4, 2)); + + try (ClientEncryption clientEncryption = createClientEncryptionForRetryTest()) { + setFailpoint("network", 1); + BsonBinary keyId = assertDoesNotThrow( + () -> clientEncryption.createDataKey(provider, getDataKeyOptions(provider))); + + setFailpoint("network", 1); + assertDoesNotThrow( + () -> clientEncryption.encrypt(new BsonInt32(123), + new EncryptOptions("AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic").keyId(keyId))); + } + } + + /** + * Case 2: createDataKey and encrypt with HTTP retry. + */ + @ParameterizedTest(name = "Case 2: HTTP retry with {0}") + @ValueSource(strings = {"aws", "azure", "gcp"}) + public void testCreateDataKeyAndEncryptWithHttpRetry(final String provider) { + assumeTrue(hasEncryptionTestsEnabled()); + assumeTrue(serverVersionAtLeast(4, 2)); + + try (ClientEncryption clientEncryption = createClientEncryptionForRetryTest()) { + setFailpoint("http", 1); + BsonBinary keyId = assertDoesNotThrow( + () -> clientEncryption.createDataKey(provider, getDataKeyOptions(provider))); + + setFailpoint("http", 1); + assertDoesNotThrow( + () -> clientEncryption.encrypt(new BsonInt32(123), + new EncryptOptions("AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic").keyId(keyId))); + } + } + + /** + * Case 3: createDataKey fails after too many retries. + */ + @ParameterizedTest(name = "Case 3: Exhausted retries with {0}") + @ValueSource(strings = {"aws", "azure", "gcp"}) + public void testCreateDataKeyFailsAfterTooManyRetries(final String provider) { + assumeTrue(hasEncryptionTestsEnabled()); + assumeTrue(serverVersionAtLeast(4, 2)); + + try (ClientEncryption clientEncryption = createClientEncryptionForRetryTest()) { + setFailpoint("network", 4); + assertThrows(MongoClientException.class, + () -> clientEncryption.createDataKey(provider, getDataKeyOptions(provider))); + } + } + + /** + * Prose test: createDataKey surfaces MongoOperationTimeoutException when the operation timeout expires + * mid-retry. Configures a 100ms operation timeout and a failpoint that triggers repeated network errors; + * the cumulative retry backoff must push the operation past its deadline, and the expiry check at the + * top of each retry iteration must surface MongoOperationTimeoutException rather than MongoClientException. + */ + @Test + public void testCreateDataKeyTimesOutDuringRetry() { + assumeTrue(hasEncryptionTestsEnabled()); + assumeTrue(serverVersionAtLeast(4, 2)); + + try (ClientEncryption clientEncryption = createClientEncryptionForRetryTest(100L)) { + setFailpoint("network", 4); + assertThrows(MongoOperationTimeoutException.class, + () -> clientEncryption.createDataKey("aws", getDataKeyOptions("aws"))); + } + } + + private ClientEncryption createClientEncryptionForRetryTest() { + return createClientEncryptionForRetryTest(null); + } + + private ClientEncryption createClientEncryptionForRetryTest(@Nullable final Long timeoutMS) { + Map> kmsProviders = getKmsProvidersForRetryTest(); + SSLContext failpointSslContext = createFailpointSslContext(); + Map kmsProviderSslContextMap = new HashMap<>(); + kmsProviderSslContextMap.put("aws", failpointSslContext); + kmsProviderSslContextMap.put("azure", failpointSslContext); + kmsProviderSslContextMap.put("gcp", failpointSslContext); + + ClientEncryptionSettings.Builder builder = ClientEncryptionSettings.builder() + .keyVaultMongoClientSettings(getMongoClientSettings()) + .keyVaultNamespace("keyvault.datakeys") + .kmsProviders(kmsProviders) + .kmsProviderSslContextMap(kmsProviderSslContextMap); + if (timeoutMS != null) { + builder.timeout(timeoutMS, TimeUnit.MILLISECONDS); + } + + return getClientEncryption(builder.build()); + } + + private static Map> getKmsProvidersForRetryTest() { + return new HashMap>() {{ + put("aws", new HashMap() {{ + put("accessKeyId", getEnv("AWS_ACCESS_KEY_ID")); + put("secretAccessKey", getEnv("AWS_SECRET_ACCESS_KEY")); + }}); + put("azure", new HashMap() {{ + put("tenantId", getEnv("AZURE_TENANT_ID")); + put("clientId", getEnv("AZURE_CLIENT_ID")); + put("clientSecret", getEnv("AZURE_CLIENT_SECRET")); + put("identityPlatformEndpoint", FAILPOINT_SERVER_ADDRESS); + }}); + put("gcp", new HashMap() {{ + put("email", getEnv("GCP_EMAIL")); + put("privateKey", getEnv("GCP_PRIVATE_KEY")); + put("endpoint", FAILPOINT_SERVER_ADDRESS); + }}); + }}; + } + + private static DataKeyOptions getDataKeyOptions(final String provider) { + BsonDocument masterKey; + switch (provider) { + case "aws": + masterKey = new BsonDocument() + .append("region", new BsonString("foo")) + .append("key", new BsonString("bar")) + .append("endpoint", new BsonString(FAILPOINT_SERVER_ADDRESS)); + break; + case "azure": + masterKey = new BsonDocument() + .append("keyVaultEndpoint", new BsonString(FAILPOINT_SERVER_ADDRESS)) + .append("keyName", new BsonString("foo")); + break; + case "gcp": + masterKey = new BsonDocument() + .append("projectId", new BsonString("foo")) + .append("location", new BsonString("bar")) + .append("keyRing", new BsonString("baz")) + .append("keyName", new BsonString("qux")) + .append("endpoint", new BsonString(FAILPOINT_SERVER_ADDRESS)); + break; + default: + throw new UnsupportedOperationException("Unsupported KMS provider: " + provider); + } + return new DataKeyOptions().masterKey(masterKey); + } + + private static void setFailpoint(final String failpointType, final int count) { + try { + SSLContext sslContext = createFailpointSslContext(); + URL url = new URL(FAILPOINT_URL_BASE + "/set_failpoint/" + failpointType); + HttpsURLConnection connection = (HttpsURLConnection) url.openConnection(); + try { + connection.setConnectTimeout(10_000); + connection.setReadTimeout(10_000); + connection.setSSLSocketFactory(sslContext.getSocketFactory()); + connection.setHostnameVerifier((hostname, session) -> true); + connection.setRequestMethod("POST"); + connection.setDoOutput(true); + connection.setRequestProperty("Content-Type", "application/json"); + + byte[] body = ("{\"count\": " + count + "}").getBytes(StandardCharsets.UTF_8); + connection.setRequestProperty("Content-Length", String.valueOf(body.length)); + + try (OutputStream os = connection.getOutputStream()) { + os.write(body); + } + + int responseCode = connection.getResponseCode(); + assertEquals(200, responseCode, "Failed to set KMS failpoint, HTTP status: " + responseCode); + } finally { + connection.disconnect(); + } + } catch (Exception e) { + throw new RuntimeException("Failed to set KMS failpoint", e); + } + } + + private static SSLContext createFailpointSslContext() { + try { + String caCertPath = System.getProperty("org.mongodb.test.kms.retry.ca.path"); + CertificateFactory cf = CertificateFactory.getInstance("X.509"); + X509Certificate caCert; + try (FileInputStream fis = new FileInputStream(caCertPath)) { + caCert = (X509Certificate) cf.generateCertificate(fis); + } + + KeyStore trustStore = KeyStore.getInstance(KeyStore.getDefaultType()); + trustStore.load(null, null); + trustStore.setCertificateEntry("ca", caCert); + + TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + tmf.init(trustStore); + + SSLContext sslContext = SSLContext.getInstance("TLS"); + sslContext.init(null, tmf.getTrustManagers(), null); + return sslContext; + } catch (Exception e) { + throw new RuntimeException("Failed to create SSL context for failpoint server", e); + } + } +} diff --git a/driver-sync/src/test/functional/com/mongodb/client/ClientSideEncryptionKmsRetryProseTest.java b/driver-sync/src/test/functional/com/mongodb/client/ClientSideEncryptionKmsRetryProseTest.java new file mode 100644 index 00000000000..7b51bf915e5 --- /dev/null +++ b/driver-sync/src/test/functional/com/mongodb/client/ClientSideEncryptionKmsRetryProseTest.java @@ -0,0 +1,28 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.client; + +import com.mongodb.ClientEncryptionSettings; +import com.mongodb.client.vault.ClientEncryption; +import com.mongodb.client.vault.ClientEncryptions; + +public class ClientSideEncryptionKmsRetryProseTest extends AbstractClientSideEncryptionKmsRetryProseTest { + @Override + public ClientEncryption getClientEncryption(final ClientEncryptionSettings settings) { + return ClientEncryptions.create(settings); + } +} diff --git a/mongodb-crypt/src/main/com/mongodb/internal/crypt/capi/CAPI.java b/mongodb-crypt/src/main/com/mongodb/internal/crypt/capi/CAPI.java index 41cc8ced31b..280f3ea4205 100644 --- a/mongodb-crypt/src/main/com/mongodb/internal/crypt/capi/CAPI.java +++ b/mongodb-crypt/src/main/com/mongodb/internal/crypt/capi/CAPI.java @@ -22,6 +22,7 @@ import com.sun.jna.Native; import com.sun.jna.Pointer; import com.sun.jna.PointerType; +import com.sun.jna.ptr.ByteByReference; import com.sun.jna.ptr.PointerByReference; //CHECKSTYLE:OFF @@ -486,6 +487,21 @@ public interface mongocrypt_random_fn extends Callback { public static native void mongocrypt_setopt_bypass_query_analysis (mongocrypt_t crypt); + /** + * Opt-into handling the MONGOCRYPT_CTX_NEED_KMS state with retry logic. + * + *

If opted in, KMS requests will include retry information accessible via + * {@link #mongocrypt_kms_ctx_usleep}, {@link #mongocrypt_kms_ctx_feed_with_retry}, + * and {@link #mongocrypt_kms_ctx_fail}. + * + * @param crypt The @ref mongocrypt_t object to update + * @param enable Whether to enable KMS retry + * @return A boolean indicating success. If false, an error status is set. + * @since 5.8 + */ + public static native boolean + mongocrypt_setopt_retry_kms(mongocrypt_t crypt, boolean enable); + /** * Set the expiration time for the data encryption key cache. Defaults to 60 seconds if not set. * @@ -1164,6 +1180,47 @@ public interface mongocrypt_random_fn extends Callback { public static native boolean mongocrypt_kms_ctx_feed(mongocrypt_kms_ctx_t kms, mongocrypt_binary_t bytes); + /** + * Get the number of microseconds to sleep before sending the next KMS request. + * + *

Requires {@link #mongocrypt_setopt_retry_kms} to be enabled. + * A return value of 0 indicates no delay is needed. + * + * @param kms The @ref mongocrypt_kms_ctx_t. + * @return The number of microseconds to sleep, or 0. + * @since 5.8 + */ + public static native long + mongocrypt_kms_ctx_usleep(mongocrypt_kms_ctx_t kms); + + /** + * Feed bytes from the HTTP response, with retry support. + * + *

Requires {@link #mongocrypt_setopt_retry_kms} to be enabled. + * + * @param kms The @ref mongocrypt_kms_ctx_t. + * @param bytes The bytes to feed. + * @param should_retry Receives whether the driver should retry the KMS request. + * @return A boolean indicating success. + * @since 5.8 + */ + public static native boolean + mongocrypt_kms_ctx_feed_with_retry(mongocrypt_kms_ctx_t kms, + mongocrypt_binary_t bytes, + ByteByReference should_retry); + + /** + * Signal to libmongocrypt that a network error occurred on this KMS request. + * + *

Requires {@link #mongocrypt_setopt_retry_kms} to be enabled. + * + * @param kms The @ref mongocrypt_kms_ctx_t. + * @return True if the request should be retried, false if retries are exhausted. + * @since 5.8 + */ + public static native boolean + mongocrypt_kms_ctx_fail(mongocrypt_kms_ctx_t kms); + /** * Get the status associated with a @ref mongocrypt_kms_ctx_t object. diff --git a/mongodb-crypt/src/main/com/mongodb/internal/crypt/capi/MongoCryptImpl.java b/mongodb-crypt/src/main/com/mongodb/internal/crypt/capi/MongoCryptImpl.java index 774b9e718cb..5731ca20689 100644 --- a/mongodb-crypt/src/main/com/mongodb/internal/crypt/capi/MongoCryptImpl.java +++ b/mongodb-crypt/src/main/com/mongodb/internal/crypt/capi/MongoCryptImpl.java @@ -73,6 +73,7 @@ import static com.mongodb.internal.crypt.capi.CAPI.mongocrypt_setopt_kms_provider_local; import static com.mongodb.internal.crypt.capi.CAPI.mongocrypt_setopt_kms_providers; import static com.mongodb.internal.crypt.capi.CAPI.mongocrypt_setopt_log_handler; +import static com.mongodb.internal.crypt.capi.CAPI.mongocrypt_setopt_retry_kms; import static com.mongodb.internal.crypt.capi.CAPI.mongocrypt_setopt_schema_map; import static com.mongodb.internal.crypt.capi.CAPI.mongocrypt_setopt_set_crypt_shared_lib_path_override; import static com.mongodb.internal.crypt.capi.CAPI.mongocrypt_setopt_use_need_kms_credentials_state; @@ -198,6 +199,8 @@ class MongoCryptImpl implements MongoCrypt { mongocrypt_setopt_use_need_kms_credentials_state(wrapped); } + configure(() -> mongocrypt_setopt_retry_kms(wrapped, true)); + if (options.getKmsProviderOptions() != null) { withBinaryHolder(options.getKmsProviderOptions(), binary -> configure(() -> mongocrypt_setopt_kms_providers(wrapped, binary))); diff --git a/mongodb-crypt/src/main/com/mongodb/internal/crypt/capi/MongoKeyDecryptor.java b/mongodb-crypt/src/main/com/mongodb/internal/crypt/capi/MongoKeyDecryptor.java index 9b0eae6776f..a8470433fe5 100644 --- a/mongodb-crypt/src/main/com/mongodb/internal/crypt/capi/MongoKeyDecryptor.java +++ b/mongodb-crypt/src/main/com/mongodb/internal/crypt/capi/MongoKeyDecryptor.java @@ -24,6 +24,15 @@ */ public interface MongoKeyDecryptor { + /** + * Initial read size after re-sending a KMS request. Matches libmongocrypt's DEFAULT_MAX_KMS_BYTE_REQUEST + * and is used when {@link #bytesNeeded()} still returns 0 because libmongocrypt's should_retry flag has + * not yet been cleared by {@link #feedAndRetry}. + * + * @since 5.8 + */ + int DEFAULT_KMS_READ_SIZE = 1024; + /** * Gets the name of the KMS provider, e.g. "aws" or "kmip" * @@ -73,4 +82,29 @@ public interface MongoKeyDecryptor { * @param bytes the received bytes */ void feed(ByteBuffer bytes); + + /** + * Gets the number of microseconds to sleep before sending the next KMS request. + * + * @return the number of microseconds to sleep, or 0 if no delay is needed + * @since 5.8 + */ + long sleepMicroseconds(); + + /** + * Feed the received bytes to the decryptor, with retry support. + * + * @param bytes the received bytes + * @return true if the KMS request should be retried + * @since 5.8 + */ + boolean feedAndRetry(ByteBuffer bytes); + + /** + * Signal to libmongocrypt that a network error occurred on this KMS request. + * + * @return true if the request should be retried, false if retries are exhausted + * @since 5.8 + */ + boolean fail(); } diff --git a/mongodb-crypt/src/main/com/mongodb/internal/crypt/capi/MongoKeyDecryptorImpl.java b/mongodb-crypt/src/main/com/mongodb/internal/crypt/capi/MongoKeyDecryptorImpl.java index 1411adffc21..f1ab70c3dbc 100644 --- a/mongodb-crypt/src/main/com/mongodb/internal/crypt/capi/MongoKeyDecryptorImpl.java +++ b/mongodb-crypt/src/main/com/mongodb/internal/crypt/capi/MongoKeyDecryptorImpl.java @@ -22,6 +22,7 @@ import com.mongodb.internal.crypt.capi.CAPI.mongocrypt_kms_ctx_t; import com.mongodb.internal.crypt.capi.CAPI.mongocrypt_status_t; import com.sun.jna.Pointer; +import com.sun.jna.ptr.ByteByReference; import com.sun.jna.ptr.PointerByReference; import java.nio.ByteBuffer; @@ -30,9 +31,12 @@ import static com.mongodb.internal.crypt.capi.CAPI.mongocrypt_binary_new; import static com.mongodb.internal.crypt.capi.CAPI.mongocrypt_kms_ctx_bytes_needed; import static com.mongodb.internal.crypt.capi.CAPI.mongocrypt_kms_ctx_endpoint; +import static com.mongodb.internal.crypt.capi.CAPI.mongocrypt_kms_ctx_fail; import static com.mongodb.internal.crypt.capi.CAPI.mongocrypt_kms_ctx_feed; +import static com.mongodb.internal.crypt.capi.CAPI.mongocrypt_kms_ctx_feed_with_retry; import static com.mongodb.internal.crypt.capi.CAPI.mongocrypt_kms_ctx_get_kms_provider; import static com.mongodb.internal.crypt.capi.CAPI.mongocrypt_kms_ctx_message; +import static com.mongodb.internal.crypt.capi.CAPI.mongocrypt_kms_ctx_usleep; import static com.mongodb.internal.crypt.capi.CAPI.mongocrypt_kms_ctx_status; import static com.mongodb.internal.crypt.capi.CAPI.mongocrypt_status_code; import static com.mongodb.internal.crypt.capi.CAPI.mongocrypt_status_destroy; @@ -42,6 +46,11 @@ import static com.mongodb.internal.crypt.capi.CAPIHelper.toByteBuffer; import static org.bson.assertions.Assertions.notNull; +/** + * Note: Not thread-safe: methods mutate the underlying native {@code mongocrypt_kms_ctx_t} and must be invoked serially. + * Callers perform retries sequentially — the sync driver in a {@code while} loop and the reactive driver via + * {@code Mono.flatMap} — so no external synchronization is required. + */ class MongoKeyDecryptorImpl implements MongoKeyDecryptor { private final mongocrypt_kms_ctx_t wrapped; @@ -96,6 +105,29 @@ public void feed(final ByteBuffer bytes) { } } + @Override + public long sleepMicroseconds() { + return mongocrypt_kms_ctx_usleep(wrapped); + } + + @Override + public boolean feedAndRetry(final ByteBuffer bytes) { + try (BinaryHolder binaryHolder = toBinary(bytes)) { + // Default 0 means "do not retry"; libmongocrypt writes 1 only when the driver should retry. + ByteByReference shouldRetry = new ByteByReference(); + boolean success = mongocrypt_kms_ctx_feed_with_retry(wrapped, binaryHolder.getBinary(), shouldRetry); + if (!success) { + throwExceptionFromStatus(); + } + return shouldRetry.getValue() != 0; + } + } + + @Override + public boolean fail() { + return mongocrypt_kms_ctx_fail(wrapped); + } + private void throwExceptionFromStatus() { mongocrypt_status_t status = mongocrypt_status_new(); mongocrypt_kms_ctx_status(wrapped, status);