diff --git a/.evergreen/.evg.yml b/.evergreen/.evg.yml index 7de253baac6..b0772260acd 100644 --- a/.evergreen/.evg.yml +++ b/.evergreen/.evg.yml @@ -764,6 +764,16 @@ functions: set +o xtrace MONGODB_URI="${MONGODB_URI}" KMS_TLS_ERROR_TYPE=${KMS_TLS_ERROR_TYPE} .evergreen/run-kms-tls-tests.sh + "run-kms-retry-test": + - command: shell.exec + type: "test" + params: + working_dir: "src" + script: | + ${PREPARE_SHELL} + set +o xtrace + MONGODB_URI="${MONGODB_URI}" .evergreen/run-kms-retry-tests.sh + "run-csfle-aws-from-environment-test": - command: shell.exec type: "test" @@ -1632,6 +1642,17 @@ tasks: AUTH: "noauth" SSL: "nossl" + - name: "test-kms-retry-task" + tags: [ "kms-retry" ] + commands: + - func: "start-mongo-orchestration" + vars: + TOPOLOGY: "server" + AUTH: "noauth" + SSL: "nossl" + - func: "start-csfle-servers" + - func: "run-kms-retry-test" + - name: "test-csfle-aws-from-environment-task" tags: [ "csfle-aws-from-environment" ] commands: @@ -2528,6 +2549,12 @@ buildvariants: tasks: - name: ".kms-tls" + - matrix_name: "kms-retry-test" + matrix_spec: { os: "linux", version: [ "5.0" ], topology: [ "standalone" ] } + display_name: "CSFLE KMS Retry" + tasks: + - name: ".kms-retry" + - matrix_name: "csfle-aws-from-environment-test" matrix_spec: { os: "linux", version: [ "5.0" ], topology: [ "standalone" ] } display_name: "CSFLE AWS From Environment" diff --git a/.evergreen/run-kms-retry-tests.sh b/.evergreen/run-kms-retry-tests.sh new file mode 100755 index 00000000000..c9a96ddafa2 --- /dev/null +++ b/.evergreen/run-kms-retry-tests.sh @@ -0,0 +1,44 @@ +#!/bin/bash + +# Don't trace since the URI contains a password that shouldn't show up in the logs +set -o errexit # Exit the script with error if any of the commands fail + +# Supported/used environment variables: +# MONGODB_URI Set the suggested connection MONGODB_URI (including credentials and topology info) + +############################################ +# Main Program # +############################################ +RELATIVE_DIR_PATH="$(dirname "${BASH_SOURCE:-$0}")" +. "${RELATIVE_DIR_PATH}/setup-env.bash" +echo "Running KMS Retry tests" + +cp ${JAVA_HOME}/lib/security/cacerts mongo-truststore +${JAVA_HOME}/bin/keytool -importcert -trustcacerts -file ${DRIVERS_TOOLS}/.evergreen/x509gen/ca.pem -keystore mongo-truststore -storepass changeit -storetype JKS -noprompt + +export GRADLE_EXTRA_VARS="-Pssl.enabled=true -Pssl.trustStoreType=jks -Pssl.trustStore=`pwd`/mongo-truststore -Pssl.trustStorePassword=changeit" + +./gradlew -version + +# Disable errexit so both suites run and their exit codes can be captured below. +set +o errexit + +./gradlew --stacktrace --info ${GRADLE_EXTRA_VARS} -Dorg.mongodb.test.uri=${MONGODB_URI} \ + -Dorg.mongodb.test.kms.retry.ca.path="${DRIVERS_TOOLS}/.evergreen/x509gen/ca.pem" \ + driver-sync:cleanTest driver-sync:test --tests ClientSideEncryptionKmsRetryProseTest +first=$? +echo $first + +./gradlew --stacktrace --info ${GRADLE_EXTRA_VARS} -Dorg.mongodb.test.uri=${MONGODB_URI} \ + -Dorg.mongodb.test.kms.retry.ca.path="${DRIVERS_TOOLS}/.evergreen/x509gen/ca.pem" \ + driver-reactive-streams:cleanTest driver-reactive-streams:test --tests ClientSideEncryptionKmsRetryProseTest +second=$? +echo $second + +if [ $first -ne 0 ]; then + exit $first +elif [ $second -ne 0 ]; then + exit $second +else + exit 0 +fi diff --git a/driver-core/src/main/com/mongodb/internal/connection/tlschannel/impl/TlsChannelImpl.java b/driver-core/src/main/com/mongodb/internal/connection/tlschannel/impl/TlsChannelImpl.java index 20bc69e81f0..2daec943e5d 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/tlschannel/impl/TlsChannelImpl.java +++ b/driver-core/src/main/com/mongodb/internal/connection/tlschannel/impl/TlsChannelImpl.java @@ -236,10 +236,13 @@ public long read(ByteBufferSet dest) throws IOException { case NOT_HANDSHAKING: case FINISHED: UnwrapResult res = readAndUnwrap(Optional.of(dest)); + bytesToReturn = res.bytesProduced; if (res.wasClosed) { - return -1; + // JAVA-5391: return any bytes produced alongside close_notify; the next read + // sees shutdownReceived and returns -1. Fixed in upstream marianobarrios/tls-channel; + // this is the minimal patch until the vendored snapshot is refreshed. + return bytesToReturn > 0 ? bytesToReturn : -1; } - bytesToReturn = res.bytesProduced; handshakeStatus = res.lastHandshakeStatus; break; case NEED_TASK: diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/KeyManagementService.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/KeyManagementService.java index 67ebf421c9c..076c8743f1b 100644 --- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/KeyManagementService.java +++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/KeyManagementService.java @@ -48,10 +48,12 @@ import java.io.Closeable; import java.nio.channels.CompletionHandler; import java.nio.channels.InterruptedByTimeoutException; +import java.time.Duration; import java.util.List; import java.util.Map; import static java.util.Collections.singletonList; +import static java.util.concurrent.TimeUnit.MICROSECONDS; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.bson.assertions.Assertions.assertTrue; @@ -74,6 +76,29 @@ public void close() { } Mono decryptKey(final MongoKeyDecryptor keyDecryptor, @Nullable final Timeout operationTimeout) { + return Mono.defer(() -> { + long sleepMicros = keyDecryptor.sleepMicroseconds(); + if (sleepMicros > 0 && operationTimeout != null) { + operationTimeout.run(MICROSECONDS, + () -> { }, + remainingMicros -> { + if (remainingMicros < sleepMicros) { + throw TimeoutContext.createMongoTimeoutException(TIMEOUT_ERROR_MESSAGE); + } + }, + () -> { + throw TimeoutContext.createMongoTimeoutException(TIMEOUT_ERROR_MESSAGE); + }); + } + Mono attempt = sleepMicros > 0 + ? Mono.delay(Duration.ofNanos(MICROSECONDS.toNanos(sleepMicros))) + .then(attemptDecryptKey(keyDecryptor, operationTimeout)) + : attemptDecryptKey(keyDecryptor, operationTimeout); + return attempt; + }).onErrorMap(this::unWrapException); + } + + private Mono attemptDecryptKey(final MongoKeyDecryptor keyDecryptor, @Nullable final Timeout operationTimeout) { SocketSettings socketSettings = SocketSettings.builder() .connectTimeout(timeoutMillis, MILLISECONDS) .readTimeout(timeoutMillis, MILLISECONDS) @@ -85,84 +110,119 @@ Mono decryptKey(final MongoKeyDecryptor keyDecryptor, @Nullable final Time LOGGER.info("Connecting to KMS server at " + serverAddress); - return Mono.create(sink -> { - Stream stream = streamFactory.create(serverAddress); + return Mono.create(sink -> { OperationContext operationContext = createOperationContext(operationTimeout, socketSettings); + Stream stream = streamFactory.create(serverAddress); stream.openAsync(operationContext, new AsyncCompletionHandler() { @Override public void completed(@Nullable final Void ignored) { - streamWrite(stream, keyDecryptor, operationContext, sink); + try { + streamWrite(stream, keyDecryptor, operationContext, sink); + } catch (Throwable t) { + stream.close(); + sink.error(t); + } } @Override public void failed(final Throwable t) { stream.close(); - handleError(t, operationContext, sink); + failOrHandleError(t, keyDecryptor, operationContext, sink); } }); - }).onErrorMap(this::unWrapException); + }).flatMap(shouldRetry -> { + if (shouldRetry) { + return decryptKey(keyDecryptor, operationTimeout); + } + return Mono.empty(); + }); } private void streamWrite(final Stream stream, final MongoKeyDecryptor keyDecryptor, - final OperationContext operationContext, final MonoSink sink) { + final OperationContext operationContext, final MonoSink sink) { List byteBufs = singletonList(new ByteBufNIO(keyDecryptor.getMessage())); stream.writeAsync(byteBufs, operationContext, new AsyncCompletionHandler() { @Override public void completed(@Nullable final Void aVoid) { - streamRead(stream, keyDecryptor, operationContext, sink); + try { + int bytesNeeded = keyDecryptor.bytesNeeded(); + int readSize = bytesNeeded > 0 ? bytesNeeded : MongoKeyDecryptor.DEFAULT_KMS_READ_SIZE; + streamRead(stream, keyDecryptor, operationContext, sink, readSize); + } catch (Throwable t) { + stream.close(); + sink.error(t); + } } @Override public void failed(final Throwable t) { stream.close(); - handleError(t, operationContext, sink); + failOrHandleError(t, keyDecryptor, operationContext, sink); } }); } private void streamRead(final Stream stream, final MongoKeyDecryptor keyDecryptor, - final OperationContext operationContext, final MonoSink sink) { - int bytesNeeded = keyDecryptor.bytesNeeded(); - if (bytesNeeded > 0) { - AsynchronousChannelStream asyncStream = (AsynchronousChannelStream) stream; - ByteBuf buffer = asyncStream.getBuffer(bytesNeeded); - long readTimeoutMS = operationContext.getTimeoutContext().getReadTimeoutMS(); - asyncStream.getChannel().read(buffer.asNIO(), readTimeoutMS, MILLISECONDS, null, - new CompletionHandler() { - - @Override - public void completed(final Integer integer, final Void aVoid) { - if (integer == -1) { - sink.error(new MongoException( - "Unexpected end of stream from KMS provider " + keyDecryptor.getKmsProvider())); - return; - } - buffer.flip(); - try { - keyDecryptor.feed(buffer.asNIO()); - buffer.release(); - streamRead(stream, keyDecryptor, operationContext, sink); - } catch (Throwable t) { - sink.error(t); - } - } - - @Override - public void failed(final Throwable t, final Void aVoid) { - buffer.release(); - stream.close(); - handleError(t, operationContext, sink); - } - }); - } else { + final OperationContext operationContext, final MonoSink sink, + final int readSize) { + if (readSize <= 0) { stream.close(); - sink.success(); + sink.success(false); + return; } + AsynchronousChannelStream asyncStream = (AsynchronousChannelStream) stream; + ByteBuf buffer = asyncStream.getBuffer(readSize); + long readTimeoutMS = operationContext.getTimeoutContext().getReadTimeoutMS(); + asyncStream.getChannel().read(buffer.asNIO(), readTimeoutMS, MILLISECONDS, null, + new CompletionHandler() { + + @Override + public void completed(final Integer integer, final Void aVoid) { + try { + if (integer == -1) { + buffer.release(); + stream.close(); + MongoException eof = new MongoException("Unexpected end of stream from KMS provider " + + keyDecryptor.getKmsProvider()); + failOrHandleError(eof, keyDecryptor, operationContext, sink); + return; + } + buffer.flip(); + boolean shouldRetry; + try { + shouldRetry = keyDecryptor.feedAndRetry(buffer.asNIO()); + } finally { + buffer.release(); + } + if (shouldRetry) { + stream.close(); + sink.success(true); + } else { + streamRead(stream, keyDecryptor, operationContext, sink, + keyDecryptor.bytesNeeded()); + } + } catch (Throwable t) { + stream.close(); + sink.error(t); + } + } + + @Override + public void failed(final Throwable t, final Void aVoid) { + buffer.release(); + stream.close(); + failOrHandleError(t, keyDecryptor, operationContext, sink); + } + }); } - private static void handleError(final Throwable t, final OperationContext operationContext, final MonoSink sink) { + private static void failOrHandleError(final Throwable t, final MongoKeyDecryptor keyDecryptor, + final OperationContext operationContext, final MonoSink sink) { if (isTimeoutException(t) && operationContext.getTimeoutContext().hasTimeoutMS()) { sink.error(TimeoutContext.createMongoTimeoutException(TIMEOUT_ERROR_MESSAGE, t)); + } else if (keyDecryptor.fail()) { + LOGGER.debug("Retrying KMS request after transient error", t); + sink.success(true); } else { sink.error(t); } diff --git a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/ClientSideEncryptionKmsRetryProseTest.java b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/ClientSideEncryptionKmsRetryProseTest.java new file mode 100644 index 00000000000..1dd4f0f833d --- /dev/null +++ b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/ClientSideEncryptionKmsRetryProseTest.java @@ -0,0 +1,30 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.reactivestreams.client; + +import com.mongodb.ClientEncryptionSettings; +import com.mongodb.client.AbstractClientSideEncryptionKmsRetryProseTest; +import com.mongodb.client.vault.ClientEncryption; +import com.mongodb.reactivestreams.client.syncadapter.SyncClientEncryption; +import com.mongodb.reactivestreams.client.vault.ClientEncryptions; + +public class ClientSideEncryptionKmsRetryProseTest extends AbstractClientSideEncryptionKmsRetryProseTest { + @Override + public ClientEncryption getClientEncryption(final ClientEncryptionSettings settings) { + return new SyncClientEncryption(ClientEncryptions.create(settings)); + } +} diff --git a/driver-sync/src/main/com/mongodb/client/internal/Crypt.java b/driver-sync/src/main/com/mongodb/client/internal/Crypt.java index 67fac13770c..6039d64ae98 100644 --- a/driver-sync/src/main/com/mongodb/client/internal/Crypt.java +++ b/driver-sync/src/main/com/mongodb/client/internal/Crypt.java @@ -24,8 +24,11 @@ import com.mongodb.client.model.vault.EncryptOptions; import com.mongodb.client.model.vault.RewrapManyDataKeyOptions; import com.mongodb.crypt.capi.MongoCryptException; +import com.mongodb.internal.TimeoutContext; import com.mongodb.internal.capi.MongoCryptHelper; import com.mongodb.internal.crypt.capi.MongoCrypt; +import com.mongodb.internal.diagnostics.logging.Logger; +import com.mongodb.internal.diagnostics.logging.Loggers; import com.mongodb.internal.crypt.capi.MongoCryptContext; import com.mongodb.internal.crypt.capi.MongoDataKeyOptions; import com.mongodb.internal.crypt.capi.MongoKeyDecryptor; @@ -38,11 +41,13 @@ import org.bson.RawBsonDocument; import java.io.Closeable; +import java.io.EOFException; import java.io.IOException; import java.io.InputStream; import java.nio.ByteBuffer; import java.util.List; import java.util.Map; +import java.util.concurrent.TimeUnit; import java.util.function.Supplier; import static com.mongodb.assertions.Assertions.assertNotNull; @@ -56,6 +61,8 @@ */ public class Crypt implements Closeable { + private static final Logger LOGGER = Loggers.getLogger("client"); + private static final String TIMEOUT_ERROR_MESSAGE = "KMS key decryption exceeded the timeout limit."; private static final RawBsonDocument EMPTY_RAW_BSON_DOCUMENT = RawBsonDocument.parse("{}"); private final MongoCrypt mongoCrypt; private final Map> kmsProviders; @@ -361,19 +368,83 @@ private void decryptKeys(final MongoCryptContext cryptContext, @Nullable final T } } - private void decryptKey(final MongoKeyDecryptor keyDecryptor, @Nullable final Timeout operationTimeout) throws IOException { - try (InputStream inputStream = keyManagementService.stream(keyDecryptor.getKmsProvider(), keyDecryptor.getHostName(), - keyDecryptor.getMessage(), operationTimeout)) { - int bytesNeeded = keyDecryptor.bytesNeeded(); + private void decryptKey(final MongoKeyDecryptor keyDecryptor, @Nullable final Timeout operationTimeout) + throws IOException, InterruptedException { + while (true) { + long sleepMicros = keyDecryptor.sleepMicroseconds(); + if (sleepMicros > 0) { + if (operationTimeout != null) { + operationTimeout.run(TimeUnit.MICROSECONDS, + () -> { }, + remainingMicros -> { + if (remainingMicros < sleepMicros) { + throw TimeoutContext.createMongoTimeoutException( + TIMEOUT_ERROR_MESSAGE); + } + }, + () -> { + throw TimeoutContext.createMongoTimeoutException( + TIMEOUT_ERROR_MESSAGE); + }); + } + TimeUnit.MICROSECONDS.sleep(sleepMicros); + } + boolean shouldRetry; + try { + shouldRetry = attemptDecryptKey(keyDecryptor, operationTimeout); + } catch (IOException e) { + if (!keyDecryptor.fail()) { + throw e; + } + LOGGER.debug("Retrying KMS request after transient error", e); + continue; + } + if (!shouldRetry) { + return; + } + } + } - while (bytesNeeded > 0) { - byte[] bytes = new byte[bytesNeeded]; + private boolean attemptDecryptKey(final MongoKeyDecryptor keyDecryptor, @Nullable final Timeout operationTimeout) + throws IOException { + Timeout.onExistsAndExpired(operationTimeout, () -> { + throw TimeoutContext.createMongoTimeoutException(TIMEOUT_ERROR_MESSAGE); + }); + // After a fail()-triggered retry, bytesNeeded() may return 0 until feedAndRetry() clears + // the flag, so the do-while guarantees at least one read using DEFAULT_KMS_READ_SIZE. + InputStream inputStream = keyManagementService.stream(keyDecryptor.getKmsProvider(), keyDecryptor.getHostName(), + keyDecryptor.getMessage(), operationTimeout); + Throwable primary = null; + try { + int bytesNeeded = keyDecryptor.bytesNeeded(); + int readSize = bytesNeeded > 0 ? bytesNeeded : MongoKeyDecryptor.DEFAULT_KMS_READ_SIZE; + do { + byte[] bytes = new byte[readSize]; int bytesRead = inputStream.read(bytes, 0, bytes.length); if (bytesRead == -1) { - throw new MongoException("Unexpected end of stream from KMS provider " + keyDecryptor.getKmsProvider()); + throw new EOFException("Unexpected end of stream from KMS provider " + keyDecryptor.getKmsProvider()); + } + if (keyDecryptor.feedAndRetry(ByteBuffer.wrap(bytes, 0, bytesRead))) { + return true; + } + readSize = keyDecryptor.bytesNeeded(); + } while (readSize > 0); + return false; + } catch (Throwable t) { + primary = t; + throw t; + } finally { + // If the feed loop succeeded, suppress close() failures so they do not trigger a retry + // of an already-complete KMS exchange. If the feed loop threw, preserve the primary + // exception and attach any close() failure as suppressed — matching try-with-resources. + try { + inputStream.close(); + } catch (IOException closeException) { + if (primary != null) { + primary.addSuppressed(closeException); + } else { + LOGGER.debug("Ignoring close() failure after successful KMS exchange", closeException); } - keyDecryptor.feed(ByteBuffer.wrap(bytes, 0, bytesRead)); - bytesNeeded = keyDecryptor.bytesNeeded(); } } } diff --git a/driver-sync/src/test/functional/com/mongodb/client/AbstractClientSideEncryptionKmsRetryProseTest.java b/driver-sync/src/test/functional/com/mongodb/client/AbstractClientSideEncryptionKmsRetryProseTest.java new file mode 100644 index 00000000000..e7aa620d8dd --- /dev/null +++ b/driver-sync/src/test/functional/com/mongodb/client/AbstractClientSideEncryptionKmsRetryProseTest.java @@ -0,0 +1,282 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.client; + +import com.mongodb.ClientEncryptionSettings; +import com.mongodb.MongoClientException; +import com.mongodb.MongoOperationTimeoutException; +import com.mongodb.client.model.vault.DataKeyOptions; +import com.mongodb.client.model.vault.EncryptOptions; +import com.mongodb.client.vault.ClientEncryption; +import com.mongodb.lang.NonNull; +import com.mongodb.lang.Nullable; +import org.bson.BsonBinary; +import org.bson.BsonDocument; +import org.bson.BsonInt32; +import org.bson.BsonString; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import javax.net.ssl.HttpsURLConnection; +import javax.net.ssl.SSLContext; +import javax.net.ssl.TrustManagerFactory; +import java.io.FileInputStream; +import java.io.OutputStream; +import java.net.URL; +import java.nio.charset.StandardCharsets; +import java.security.KeyStore; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static com.mongodb.ClusterFixture.getEnv; +import static com.mongodb.ClusterFixture.hasEncryptionTestsEnabled; +import static com.mongodb.ClusterFixture.serverVersionAtLeast; +import static com.mongodb.client.Fixture.getMongoClientSettings; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +/** + * See + * 24. KMS Retry Tests. + * + *

Requires the {@code org.mongodb.test.kms.retry.ca.path} system property pointing to the CA cert for the + * failpoint server. + */ +public abstract class AbstractClientSideEncryptionKmsRetryProseTest { + + private static final String FAILPOINT_SERVER_ADDRESS = "127.0.0.1:9003"; + private static final String FAILPOINT_URL_BASE = "https://" + FAILPOINT_SERVER_ADDRESS; + + @NonNull + protected abstract ClientEncryption getClientEncryption(ClientEncryptionSettings settings); + + @BeforeEach + public void setUp() { + assumeTrue(System.getProperty("org.mongodb.test.kms.retry.ca.path") != null, + "org.mongodb.test.kms.retry.ca.path system property is not set"); + } + + /** + * Case 1: createDataKey and encrypt with TCP retry. + */ + @ParameterizedTest(name = "Case 1: TCP retry with {0}") + @ValueSource(strings = {"aws", "azure", "gcp"}) + public void testCreateDataKeyAndEncryptWithTcpRetry(final String provider) { + assumeTrue(hasEncryptionTestsEnabled()); + assumeTrue(serverVersionAtLeast(4, 2)); + + try (ClientEncryption clientEncryption = createClientEncryptionForRetryTest()) { + setFailpoint("network", 1); + BsonBinary keyId = assertDoesNotThrow( + () -> clientEncryption.createDataKey(provider, getDataKeyOptions(provider))); + + setFailpoint("network", 1); + assertDoesNotThrow( + () -> clientEncryption.encrypt(new BsonInt32(123), + new EncryptOptions("AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic").keyId(keyId))); + } + } + + /** + * Case 2: createDataKey and encrypt with HTTP retry. + */ + @ParameterizedTest(name = "Case 2: HTTP retry with {0}") + @ValueSource(strings = {"aws", "azure", "gcp"}) + public void testCreateDataKeyAndEncryptWithHttpRetry(final String provider) { + assumeTrue(hasEncryptionTestsEnabled()); + assumeTrue(serverVersionAtLeast(4, 2)); + + try (ClientEncryption clientEncryption = createClientEncryptionForRetryTest()) { + setFailpoint("http", 1); + BsonBinary keyId = assertDoesNotThrow( + () -> clientEncryption.createDataKey(provider, getDataKeyOptions(provider))); + + setFailpoint("http", 1); + assertDoesNotThrow( + () -> clientEncryption.encrypt(new BsonInt32(123), + new EncryptOptions("AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic").keyId(keyId))); + } + } + + /** + * Case 3: createDataKey fails after too many retries. + */ + @ParameterizedTest(name = "Case 3: Exhausted retries with {0}") + @ValueSource(strings = {"aws", "azure", "gcp"}) + public void testCreateDataKeyFailsAfterTooManyRetries(final String provider) { + assumeTrue(hasEncryptionTestsEnabled()); + assumeTrue(serverVersionAtLeast(4, 2)); + + try (ClientEncryption clientEncryption = createClientEncryptionForRetryTest()) { + setFailpoint("network", 4); + assertThrows(MongoClientException.class, + () -> clientEncryption.createDataKey(provider, getDataKeyOptions(provider))); + } + } + + /** + * Prose test: createDataKey surfaces MongoOperationTimeoutException when the operation timeout expires + * mid-retry. Configures a 100ms operation timeout and a failpoint that triggers repeated network errors; + * the cumulative retry backoff must push the operation past its deadline, and the expiry check at the + * top of each retry iteration must surface MongoOperationTimeoutException rather than MongoClientException. + */ + @Test + public void testCreateDataKeyTimesOutDuringRetry() { + assumeTrue(hasEncryptionTestsEnabled()); + assumeTrue(serverVersionAtLeast(4, 2)); + + try (ClientEncryption clientEncryption = createClientEncryptionForRetryTest(100L)) { + setFailpoint("network", 4); + assertThrows(MongoOperationTimeoutException.class, + () -> clientEncryption.createDataKey("aws", getDataKeyOptions("aws"))); + } + } + + private ClientEncryption createClientEncryptionForRetryTest() { + return createClientEncryptionForRetryTest(null); + } + + private ClientEncryption createClientEncryptionForRetryTest(@Nullable final Long timeoutMS) { + Map> kmsProviders = getKmsProvidersForRetryTest(); + SSLContext failpointSslContext = createFailpointSslContext(); + Map kmsProviderSslContextMap = new HashMap<>(); + kmsProviderSslContextMap.put("aws", failpointSslContext); + kmsProviderSslContextMap.put("azure", failpointSslContext); + kmsProviderSslContextMap.put("gcp", failpointSslContext); + + ClientEncryptionSettings.Builder builder = ClientEncryptionSettings.builder() + .keyVaultMongoClientSettings(getMongoClientSettings()) + .keyVaultNamespace("keyvault.datakeys") + .kmsProviders(kmsProviders) + .kmsProviderSslContextMap(kmsProviderSslContextMap); + if (timeoutMS != null) { + builder.timeout(timeoutMS, TimeUnit.MILLISECONDS); + } + + return getClientEncryption(builder.build()); + } + + private static Map> getKmsProvidersForRetryTest() { + return new HashMap>() {{ + put("aws", new HashMap() {{ + put("accessKeyId", getEnv("AWS_ACCESS_KEY_ID")); + put("secretAccessKey", getEnv("AWS_SECRET_ACCESS_KEY")); + }}); + put("azure", new HashMap() {{ + put("tenantId", getEnv("AZURE_TENANT_ID")); + put("clientId", getEnv("AZURE_CLIENT_ID")); + put("clientSecret", getEnv("AZURE_CLIENT_SECRET")); + put("identityPlatformEndpoint", FAILPOINT_SERVER_ADDRESS); + }}); + put("gcp", new HashMap() {{ + put("email", getEnv("GCP_EMAIL")); + put("privateKey", getEnv("GCP_PRIVATE_KEY")); + put("endpoint", FAILPOINT_SERVER_ADDRESS); + }}); + }}; + } + + private static DataKeyOptions getDataKeyOptions(final String provider) { + BsonDocument masterKey; + switch (provider) { + case "aws": + masterKey = new BsonDocument() + .append("region", new BsonString("foo")) + .append("key", new BsonString("bar")) + .append("endpoint", new BsonString(FAILPOINT_SERVER_ADDRESS)); + break; + case "azure": + masterKey = new BsonDocument() + .append("keyVaultEndpoint", new BsonString(FAILPOINT_SERVER_ADDRESS)) + .append("keyName", new BsonString("foo")); + break; + case "gcp": + masterKey = new BsonDocument() + .append("projectId", new BsonString("foo")) + .append("location", new BsonString("bar")) + .append("keyRing", new BsonString("baz")) + .append("keyName", new BsonString("qux")) + .append("endpoint", new BsonString(FAILPOINT_SERVER_ADDRESS)); + break; + default: + throw new UnsupportedOperationException("Unsupported KMS provider: " + provider); + } + return new DataKeyOptions().masterKey(masterKey); + } + + private static void setFailpoint(final String failpointType, final int count) { + try { + SSLContext sslContext = createFailpointSslContext(); + URL url = new URL(FAILPOINT_URL_BASE + "/set_failpoint/" + failpointType); + HttpsURLConnection connection = (HttpsURLConnection) url.openConnection(); + try { + connection.setConnectTimeout(10_000); + connection.setReadTimeout(10_000); + connection.setSSLSocketFactory(sslContext.getSocketFactory()); + connection.setHostnameVerifier((hostname, session) -> true); + connection.setRequestMethod("POST"); + connection.setDoOutput(true); + connection.setRequestProperty("Content-Type", "application/json"); + + byte[] body = ("{\"count\": " + count + "}").getBytes(StandardCharsets.UTF_8); + connection.setRequestProperty("Content-Length", String.valueOf(body.length)); + + try (OutputStream os = connection.getOutputStream()) { + os.write(body); + } + + int responseCode = connection.getResponseCode(); + assertEquals(200, responseCode, "Failed to set KMS failpoint, HTTP status: " + responseCode); + } finally { + connection.disconnect(); + } + } catch (Exception e) { + throw new RuntimeException("Failed to set KMS failpoint", e); + } + } + + private static SSLContext createFailpointSslContext() { + try { + String caCertPath = System.getProperty("org.mongodb.test.kms.retry.ca.path"); + CertificateFactory cf = CertificateFactory.getInstance("X.509"); + X509Certificate caCert; + try (FileInputStream fis = new FileInputStream(caCertPath)) { + caCert = (X509Certificate) cf.generateCertificate(fis); + } + + KeyStore trustStore = KeyStore.getInstance(KeyStore.getDefaultType()); + trustStore.load(null, null); + trustStore.setCertificateEntry("ca", caCert); + + TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + tmf.init(trustStore); + + SSLContext sslContext = SSLContext.getInstance("TLS"); + sslContext.init(null, tmf.getTrustManagers(), null); + return sslContext; + } catch (Exception e) { + throw new RuntimeException("Failed to create SSL context for failpoint server", e); + } + } +} diff --git a/driver-sync/src/test/functional/com/mongodb/client/ClientSideEncryptionKmsRetryProseTest.java b/driver-sync/src/test/functional/com/mongodb/client/ClientSideEncryptionKmsRetryProseTest.java new file mode 100644 index 00000000000..7b51bf915e5 --- /dev/null +++ b/driver-sync/src/test/functional/com/mongodb/client/ClientSideEncryptionKmsRetryProseTest.java @@ -0,0 +1,28 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.client; + +import com.mongodb.ClientEncryptionSettings; +import com.mongodb.client.vault.ClientEncryption; +import com.mongodb.client.vault.ClientEncryptions; + +public class ClientSideEncryptionKmsRetryProseTest extends AbstractClientSideEncryptionKmsRetryProseTest { + @Override + public ClientEncryption getClientEncryption(final ClientEncryptionSettings settings) { + return ClientEncryptions.create(settings); + } +} diff --git a/mongodb-crypt/src/main/com/mongodb/internal/crypt/capi/CAPI.java b/mongodb-crypt/src/main/com/mongodb/internal/crypt/capi/CAPI.java index 41cc8ced31b..280f3ea4205 100644 --- a/mongodb-crypt/src/main/com/mongodb/internal/crypt/capi/CAPI.java +++ b/mongodb-crypt/src/main/com/mongodb/internal/crypt/capi/CAPI.java @@ -22,6 +22,7 @@ import com.sun.jna.Native; import com.sun.jna.Pointer; import com.sun.jna.PointerType; +import com.sun.jna.ptr.ByteByReference; import com.sun.jna.ptr.PointerByReference; //CHECKSTYLE:OFF @@ -486,6 +487,21 @@ public interface mongocrypt_random_fn extends Callback { public static native void mongocrypt_setopt_bypass_query_analysis (mongocrypt_t crypt); + /** + * Opt-into handling the MONGOCRYPT_CTX_NEED_KMS state with retry logic. + * + *

If opted in, KMS requests will include retry information accessible via + * {@link #mongocrypt_kms_ctx_usleep}, {@link #mongocrypt_kms_ctx_feed_with_retry}, + * and {@link #mongocrypt_kms_ctx_fail}. + * + * @param crypt The @ref mongocrypt_t object to update + * @param enable Whether to enable KMS retry + * @return A boolean indicating success. If false, an error status is set. + * @since 5.8 + */ + public static native boolean + mongocrypt_setopt_retry_kms(mongocrypt_t crypt, boolean enable); + /** * Set the expiration time for the data encryption key cache. Defaults to 60 seconds if not set. * @@ -1164,6 +1180,47 @@ public interface mongocrypt_random_fn extends Callback { public static native boolean mongocrypt_kms_ctx_feed(mongocrypt_kms_ctx_t kms, mongocrypt_binary_t bytes); + /** + * Get the number of microseconds to sleep before sending the next KMS request. + * + *

Requires {@link #mongocrypt_setopt_retry_kms} to be enabled. + * A return value of 0 indicates no delay is needed. + * + * @param kms The @ref mongocrypt_kms_ctx_t. + * @return The number of microseconds to sleep, or 0. + * @since 5.8 + */ + public static native long + mongocrypt_kms_ctx_usleep(mongocrypt_kms_ctx_t kms); + + /** + * Feed bytes from the HTTP response, with retry support. + * + *

Requires {@link #mongocrypt_setopt_retry_kms} to be enabled. + * + * @param kms The @ref mongocrypt_kms_ctx_t. + * @param bytes The bytes to feed. + * @param should_retry Receives whether the driver should retry the KMS request. + * @return A boolean indicating success. + * @since 5.8 + */ + public static native boolean + mongocrypt_kms_ctx_feed_with_retry(mongocrypt_kms_ctx_t kms, + mongocrypt_binary_t bytes, + ByteByReference should_retry); + + /** + * Signal to libmongocrypt that a network error occurred on this KMS request. + * + *

Requires {@link #mongocrypt_setopt_retry_kms} to be enabled. + * + * @param kms The @ref mongocrypt_kms_ctx_t. + * @return True if the request should be retried, false if retries are exhausted. + * @since 5.8 + */ + public static native boolean + mongocrypt_kms_ctx_fail(mongocrypt_kms_ctx_t kms); + /** * Get the status associated with a @ref mongocrypt_kms_ctx_t object. diff --git a/mongodb-crypt/src/main/com/mongodb/internal/crypt/capi/MongoCryptImpl.java b/mongodb-crypt/src/main/com/mongodb/internal/crypt/capi/MongoCryptImpl.java index 774b9e718cb..5731ca20689 100644 --- a/mongodb-crypt/src/main/com/mongodb/internal/crypt/capi/MongoCryptImpl.java +++ b/mongodb-crypt/src/main/com/mongodb/internal/crypt/capi/MongoCryptImpl.java @@ -73,6 +73,7 @@ import static com.mongodb.internal.crypt.capi.CAPI.mongocrypt_setopt_kms_provider_local; import static com.mongodb.internal.crypt.capi.CAPI.mongocrypt_setopt_kms_providers; import static com.mongodb.internal.crypt.capi.CAPI.mongocrypt_setopt_log_handler; +import static com.mongodb.internal.crypt.capi.CAPI.mongocrypt_setopt_retry_kms; import static com.mongodb.internal.crypt.capi.CAPI.mongocrypt_setopt_schema_map; import static com.mongodb.internal.crypt.capi.CAPI.mongocrypt_setopt_set_crypt_shared_lib_path_override; import static com.mongodb.internal.crypt.capi.CAPI.mongocrypt_setopt_use_need_kms_credentials_state; @@ -198,6 +199,8 @@ class MongoCryptImpl implements MongoCrypt { mongocrypt_setopt_use_need_kms_credentials_state(wrapped); } + configure(() -> mongocrypt_setopt_retry_kms(wrapped, true)); + if (options.getKmsProviderOptions() != null) { withBinaryHolder(options.getKmsProviderOptions(), binary -> configure(() -> mongocrypt_setopt_kms_providers(wrapped, binary))); diff --git a/mongodb-crypt/src/main/com/mongodb/internal/crypt/capi/MongoKeyDecryptor.java b/mongodb-crypt/src/main/com/mongodb/internal/crypt/capi/MongoKeyDecryptor.java index 9b0eae6776f..a8470433fe5 100644 --- a/mongodb-crypt/src/main/com/mongodb/internal/crypt/capi/MongoKeyDecryptor.java +++ b/mongodb-crypt/src/main/com/mongodb/internal/crypt/capi/MongoKeyDecryptor.java @@ -24,6 +24,15 @@ */ public interface MongoKeyDecryptor { + /** + * Initial read size after re-sending a KMS request. Matches libmongocrypt's DEFAULT_MAX_KMS_BYTE_REQUEST + * and is used when {@link #bytesNeeded()} still returns 0 because libmongocrypt's should_retry flag has + * not yet been cleared by {@link #feedAndRetry}. + * + * @since 5.8 + */ + int DEFAULT_KMS_READ_SIZE = 1024; + /** * Gets the name of the KMS provider, e.g. "aws" or "kmip" * @@ -73,4 +82,29 @@ public interface MongoKeyDecryptor { * @param bytes the received bytes */ void feed(ByteBuffer bytes); + + /** + * Gets the number of microseconds to sleep before sending the next KMS request. + * + * @return the number of microseconds to sleep, or 0 if no delay is needed + * @since 5.8 + */ + long sleepMicroseconds(); + + /** + * Feed the received bytes to the decryptor, with retry support. + * + * @param bytes the received bytes + * @return true if the KMS request should be retried + * @since 5.8 + */ + boolean feedAndRetry(ByteBuffer bytes); + + /** + * Signal to libmongocrypt that a network error occurred on this KMS request. + * + * @return true if the request should be retried, false if retries are exhausted + * @since 5.8 + */ + boolean fail(); } diff --git a/mongodb-crypt/src/main/com/mongodb/internal/crypt/capi/MongoKeyDecryptorImpl.java b/mongodb-crypt/src/main/com/mongodb/internal/crypt/capi/MongoKeyDecryptorImpl.java index 1411adffc21..f1ab70c3dbc 100644 --- a/mongodb-crypt/src/main/com/mongodb/internal/crypt/capi/MongoKeyDecryptorImpl.java +++ b/mongodb-crypt/src/main/com/mongodb/internal/crypt/capi/MongoKeyDecryptorImpl.java @@ -22,6 +22,7 @@ import com.mongodb.internal.crypt.capi.CAPI.mongocrypt_kms_ctx_t; import com.mongodb.internal.crypt.capi.CAPI.mongocrypt_status_t; import com.sun.jna.Pointer; +import com.sun.jna.ptr.ByteByReference; import com.sun.jna.ptr.PointerByReference; import java.nio.ByteBuffer; @@ -30,9 +31,12 @@ import static com.mongodb.internal.crypt.capi.CAPI.mongocrypt_binary_new; import static com.mongodb.internal.crypt.capi.CAPI.mongocrypt_kms_ctx_bytes_needed; import static com.mongodb.internal.crypt.capi.CAPI.mongocrypt_kms_ctx_endpoint; +import static com.mongodb.internal.crypt.capi.CAPI.mongocrypt_kms_ctx_fail; import static com.mongodb.internal.crypt.capi.CAPI.mongocrypt_kms_ctx_feed; +import static com.mongodb.internal.crypt.capi.CAPI.mongocrypt_kms_ctx_feed_with_retry; import static com.mongodb.internal.crypt.capi.CAPI.mongocrypt_kms_ctx_get_kms_provider; import static com.mongodb.internal.crypt.capi.CAPI.mongocrypt_kms_ctx_message; +import static com.mongodb.internal.crypt.capi.CAPI.mongocrypt_kms_ctx_usleep; import static com.mongodb.internal.crypt.capi.CAPI.mongocrypt_kms_ctx_status; import static com.mongodb.internal.crypt.capi.CAPI.mongocrypt_status_code; import static com.mongodb.internal.crypt.capi.CAPI.mongocrypt_status_destroy; @@ -42,6 +46,11 @@ import static com.mongodb.internal.crypt.capi.CAPIHelper.toByteBuffer; import static org.bson.assertions.Assertions.notNull; +/** + * Note: Not thread-safe: methods mutate the underlying native {@code mongocrypt_kms_ctx_t} and must be invoked serially. + * Callers perform retries sequentially — the sync driver in a {@code while} loop and the reactive driver via + * {@code Mono.flatMap} — so no external synchronization is required. + */ class MongoKeyDecryptorImpl implements MongoKeyDecryptor { private final mongocrypt_kms_ctx_t wrapped; @@ -96,6 +105,29 @@ public void feed(final ByteBuffer bytes) { } } + @Override + public long sleepMicroseconds() { + return mongocrypt_kms_ctx_usleep(wrapped); + } + + @Override + public boolean feedAndRetry(final ByteBuffer bytes) { + try (BinaryHolder binaryHolder = toBinary(bytes)) { + // Default 0 means "do not retry"; libmongocrypt writes 1 only when the driver should retry. + ByteByReference shouldRetry = new ByteByReference(); + boolean success = mongocrypt_kms_ctx_feed_with_retry(wrapped, binaryHolder.getBinary(), shouldRetry); + if (!success) { + throwExceptionFromStatus(); + } + return shouldRetry.getValue() != 0; + } + } + + @Override + public boolean fail() { + return mongocrypt_kms_ctx_fail(wrapped); + } + private void throwExceptionFromStatus() { mongocrypt_status_t status = mongocrypt_status_new(); mongocrypt_kms_ctx_status(wrapped, status);