Add configurable anisotropic downsampling support to AutoencoderKL an…#8856
shubham-61969 wants to merge 2 commits
…d relevant testcases Signed-off-by: Shubham Chandravanshi <[email protected]>
📝 Walkthrough
This PR adds per-level configurable downsampling and upsampling parameters to AutoencoderKL. New validation utilities enforce odd kernel sizes and normalize parameters across spatial dimensions. AEKLDownsample is refactored to accept kernel_size, stride, and padding. Encoder normalizes and applies per-level parameters; Decoder reverses them to compute per-dimension upsampling scale_factors. AutoencoderKL exposes downsample_parameters and wires encoder/decoder consistently. Tests cover valid anisotropic configs, validation errors, and reconstruction/shape robustness.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
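The decoder behaviour summarised in the walkthrough (reversing the encoder's per-level strides to obtain per-dimension scale factors) can be sketched in plain Python. This is only an illustration: `mirror_strides` is a hypothetical helper name, not a function from the PR, and it assumes each level's stride has already been normalized to one integer per spatial dimension.

```python
from typing import Sequence


def mirror_strides(encoder_strides: Sequence[Sequence[int]]) -> list[tuple[float, ...]]:
    """Hypothetical helper: derive decoder scale factors from encoder strides.

    The decoder undoes the encoder levels in reverse order, so the list of
    strides is reversed and each entry is converted to floats.
    """
    return [tuple(float(s) for s in stride) for stride in reversed(encoder_strides)]


# Two encoder levels: the first leaves the last axis untouched (anisotropic),
# the second halves every axis.
encoder_strides = [(2, 2, 1), (2, 2, 2)]
scale_factors = mirror_strides(encoder_strides)
print(scale_factors)  # [(2.0, 2.0, 2.0), (2.0, 2.0, 1.0)]
```

The first decoder level therefore upsamples all three axes, and only the last one leaves the third axis at its original resolution.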
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
monai/networks/nets/autoencoderkl.py (1)
636-667: ⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift
ConvTranspose path ignores anisotropic stride.
When use_convtranspose=True, the Upsample call doesn't receive the per-level stride and defaults to stride=2. This breaks anisotropic configurations (e.g., stride=(2,2,1)). The upsampling_stride computed on line 638 is unused in this branch, while the nontrainable path correctly applies it as scale_factor.
Proposed fix
     if use_convtranspose:
         blocks.append(
             Upsample(
    -            spatial_dims=spatial_dims, mode="deconv", in_channels=block_in_ch, out_channels=block_in_ch
    +            spatial_dims=spatial_dims,
    +            mode="deconv",
    +            in_channels=block_in_ch,
    +            out_channels=block_in_ch,
    +            scale_factor=tuple(float(s) for s in upsampling_stride),
             )
         )
Note: Anisotropic stride tests exist but don't exercise the convtranspose path, leaving this bug untested.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@monai/networks/nets/autoencoderkl.py` around lines 636 - 667, The convtranspose branch for upsampling ignores the computed per-level upsampling_stride (variable upsampling_stride) and always uses the default stride, breaking anisotropic cases; modify the use_convtranspose branch in the loop that builds blocks so the Upsample(...) call for mode="deconv" receives the per-level scale/stride (e.g., pass scale_factor=tuple(float(s) for s in upsampling_stride) or the appropriate strides argument accepted by Upsample) so it uses the anisotropic upsampling_stride instead of the hardcoded default; update the Upsample(...) invocation in the use_convtranspose True branch (the block creating Upsample with mode="deconv" and in_channels=block_in_ch) to include that scale/stride parameter.
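The effect of the stride mismatch described in this comment can be checked with the standard transposed-convolution size formula, out = (in - 1) * stride - 2 * padding + (kernel - 1) + output_padding + 1 (dilation 1). The sketch below is plain arithmetic, not MONAI code; it assumes the deconv layer uses kernel size equal to stride and zero padding, which is an assumption made for illustration.

```python
def deconv_out_size(size: int, kernel: int, stride: int, padding: int = 0, output_padding: int = 0) -> int:
    """Output length of a transposed convolution along one axis (dilation 1)."""
    return (size - 1) * stride - 2 * padding + (kernel - 1) + output_padding + 1


def deconv_out_shape(shape: tuple[int, ...], stride: tuple[int, ...]) -> tuple[int, ...]:
    # Assumption for this sketch: kernel == stride and padding == 0 on every axis.
    return tuple(deconv_out_size(n, s, s) for n, s in zip(shape, stride))


# Latent shape after one encoder level with stride (2, 2, 1) on a (32, 32, 8) volume.
latent = (16, 16, 8)

isotropic = deconv_out_shape(latent, (2, 2, 2))    # default x2 on every axis
anisotropic = deconv_out_shape(latent, (2, 2, 1))  # mirrors the encoder stride

print(isotropic)    # (32, 32, 16): last axis is twice the original size
print(anisotropic)  # (32, 32, 8): matches the original volume
```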
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tests/networks/nets/test_autoencoderkl.py`:
- Around line 560-576: The test test_validation_even_kernel_raises_error
currently fails for the wrong reason because the supplied downsample_parameters
list length doesn't match the expected number of downsampling levels for the
provided channels, so the level-count validation triggers before kernel-size
validation; update the test to supply a downsample_parameters list whose length
matches the required number of levels for AutoencoderKL (e.g., for
channels=(4,4,4) provide two dicts) and ensure at least one dict uses an even
"kernel_size" (e.g., 4) so that instantiating AutoencoderKL(...) raises the
intended ValueError about even kernel sizes rather than the level-count
mismatch.
- Around line 578-595: The test
test_validation_invalid_tuple_length_raises_error is failing because the
level-count mismatch validation runs before the tuple-length check; to reach the
tuple-length validation you must provide two downsample parameter dicts in
downsample_parameters so the number of levels matches attention_levels and
channels length, then still include invalid tuple lengths (e.g., kernel_size and
stride with only 2 elements) to trigger ValueError from AutoencoderKL; update
the downsample_params used in the test (referenced variable downsample_params
and class AutoencoderKL) to contain two dicts with the bad tuples so the
tuple-length validation is exercised.
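The ordering problem behind both test comments can be reproduced with a small stand-in validator. Everything in the sketch below is hypothetical (the function names, the error messages, and the exact order of checks are assumptions, not the PR's code); it only illustrates how an early level-count check masks the later kernel-size check when the test passes too few dicts.

```python
def _as_tuple(value, spatial_dims):
    # Broadcast a single int to every spatial dimension; keep sequences as given.
    return (value,) * spatial_dims if isinstance(value, int) else tuple(value)


def validate_downsample_parameters(params, num_levels, spatial_dims):
    """Hypothetical validator: the level count is checked before per-level values."""
    if len(params) != num_levels:
        raise ValueError(f"expected {num_levels} levels, got {len(params)}")
    for level, entry in enumerate(params):
        kernel = _as_tuple(entry["kernel_size"], spatial_dims)
        stride = _as_tuple(entry["stride"], spatial_dims)
        if len(kernel) != spatial_dims or len(stride) != spatial_dims:
            raise ValueError(f"level {level}: expected {spatial_dims} values per parameter")
        if any(k % 2 == 0 for k in kernel):
            raise ValueError(f"level {level}: kernel_size must be odd")


def error_message(params, num_levels=2, spatial_dims=3):
    try:
        validate_downsample_parameters(params, num_levels, spatial_dims)
    except ValueError as exc:
        return str(exc)
    return None


even_kernel = {"kernel_size": 4, "stride": 2}
valid = {"kernel_size": 3, "stride": 2}

# One dict for two levels: the count check fires, the even kernel is never inspected.
masked = error_message([even_kernel])
# Two dicts: the count check passes and the even-kernel check is reached.
reached = error_message([even_kernel, valid])

print(masked)   # expected 2 levels, got 1
print(reached)  # level 0: kernel_size must be odd
```

A test that only asserts `ValueError` passes in both cases, which is why matching on the message (or supplying the right number of levels) is needed to exercise the intended branch.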
📒 Files selected for processing (2)
- monai/networks/nets/autoencoderkl.py
- tests/networks/nets/test_autoencoderkl.py
Signed-off-by: Shubham Chandravanshi <[email protected]>
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
monai/networks/nets/autoencoderkl.py (1)
668-692: ⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift
Pass anisotropic stride to deconv branch.
The deconv upsampling ignores upsampling_stride (line 667) and defaults to isotropic ×2, while the nontrainable branch correctly passes scale_factor=tuple(float(s) for s in upsampling_stride). For anisotropic configs like (2, 2, 1), deconv will upscale incorrectly.
Add scale_factor=tuple(float(s) for s in upsampling_stride) to the deconv Upsample call (lines 668-673).
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@monai/networks/nets/autoencoderkl.py` around lines 668 - 692, The deconv branch inside the upsampling construction (when use_convtranspose is True) currently creates an Upsample(mode="deconv", ...) but omits the anisotropic upsampling factor; pass the same computed scale factor used by the nontrainable branch by adding scale_factor=tuple(float(s) for s in upsampling_stride) to that Upsample(...) call so Upsample(mode="deconv", ...) uses the correct anisotropic upsampling_stride value (refer to the Upsample instantiation, use_convtranspose flag, and the upsampling_stride variable).
📒 Files selected for processing (2)
- monai/networks/nets/autoencoderkl.py
- tests/networks/nets/test_autoencoderkl.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/networks/nets/test_autoencoderkl.py
    def _compute_padding(kernel_size: tuple[int, ...]) -> tuple[int, ...]:
        """
        Compute symmetric padding for odd kernel sizes.

        Padding is derived as:
            padding[d] = kernel_size[d] // 2

        Args:
            kernel_size: Kernel size for each spatial dimension.

        Returns:
            Tuple of padding values for each spatial dimension.
        """
        padding = tuple(k // 2 for k in kernel_size)
        return padding
Static symmetric padding still breaks odd/non-divisible reconstruction sizes.
padding[d] = kernel_size[d] // 2 makes the encoder output ceil(n / stride[d]), while the decoder only multiplies by stride[d]. For (kernel=3, stride=2), a size-5 axis becomes 3 after encode and 6 after decode, so reconstruct() can still return the wrong spatial shape on non-divisible inputs. This needs per-stage target sizes or crop/output-padding metadata, not just fixed padding.
Also applies to: 691-692
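The size arithmetic in this comment follows from the usual convolution output formula, out = floor((n + 2 * padding - kernel) / stride) + 1. A minimal sketch, assuming the decoder simply multiplies each axis by the stride (as the comment states) and using hypothetical helper names:

```python
def conv_out_size(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a strided convolution along one axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def round_trip(size: int, kernel: int = 3, stride: int = 2) -> tuple[int, int]:
    """Encode with symmetric padding kernel // 2, then decode by multiplying by stride."""
    encoded = conv_out_size(size, kernel, stride, padding=kernel // 2)
    decoded = encoded * stride
    return encoded, decoded


for n in (4, 5, 8, 9):
    encoded, decoded = round_trip(n)
    print(f"input={n} encoded={encoded} decoded={decoded} exact={decoded == n}")
```

Even lengths survive the round trip, while odd lengths such as 5 come back one element too long (5 to 3 to 6), which is the mismatch the comment describes.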
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@monai/networks/nets/autoencoderkl.py` around lines 85 - 99, The current
_compute_padding that returns padding = tuple(k // 2 for k in kernel_size)
produces symmetric padding only and does not preserve spatial sizes for
non-divisible inputs; update the encoder/decoder to record per-stage spatial
outputs (target sizes) during encoding and use those targets in reconstruct() to
compute and apply either per-stage output_padding for ConvTranspose (based on
stride and recorded encoder sizes) or explicit cropping after upsampling, rather
than relying on fixed symmetric padding. Specifically, modify the code paths
around _compute_padding and the encoder forward that produces ceil(n/stride) to
store each intermediate spatial shape, and update the decoder/ConvTranspose
reconstruction logic (where output_padding or cropping is applied) to use those
stored sizes to guarantee exact recovery for kernels/strides such as
kernel=3,stride=2 (also fix the similar logic referenced at lines 691-692).
        if downsample_parameters is None:
            # Default: use provided defaults for all levels
            default_ks_tuple, default_s_tuple = _validate_kernel_stride_parameters(
                default_kernel_size, default_stride, spatial_dims
            )
            default_padding = _compute_padding(default_ks_tuple)
            return [
                {"kernel_size": default_ks_tuple, "stride": default_s_tuple, "padding": default_padding}
                for _ in range(num_levels)
            ]
downsample_parameters=None no longer preserves the legacy path.
This now normalizes the default case to kernel_size=3, stride=2, padding=1, but the previous implementation was AsymmetricPad((0, 1) * spatial_dims) + Conv(..., kernel_size=3, stride=2, padding=0). Those are not equivalent: a length-5 axis went to 2 before and goes to 3 here, and even even-sized inputs see different border context. Existing checkpoints on the default config will silently change behavior unless the legacy default is special-cased.
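The difference between the two defaults can be verified with one-axis arithmetic. The sketch below is not MONAI code; it models the legacy path as zero padding before and one element after the axis (the (0, 1) pad described above) and the new path as one element on each side, with hypothetical helper names.

```python
def conv_out_size(size: int, kernel: int, stride: int, pad_before: int, pad_after: int) -> int:
    """Output length of a strided convolution with possibly asymmetric padding."""
    return (size + pad_before + pad_after - kernel) // stride + 1


def legacy_out(size: int) -> int:
    # Legacy default: pad (0, 1) on the axis, then kernel_size=3, stride=2, padding=0.
    return conv_out_size(size, kernel=3, stride=2, pad_before=0, pad_after=1)


def symmetric_out(size: int) -> int:
    # New default: kernel_size=3, stride=2, padding=1 on both sides.
    return conv_out_size(size, kernel=3, stride=2, pad_before=1, pad_after=1)


def first_window(pad_before: int, kernel: int = 3) -> list[int]:
    # Input indices seen by the first kernel position; negative indices are padding.
    return list(range(-pad_before, -pad_before + kernel))


print(legacy_out(5), symmetric_out(5))   # 2 3: odd lengths change size
print(legacy_out(4), symmetric_out(4))   # 2 2: even lengths keep the same size
print(first_window(0), first_window(1))  # [0, 1, 2] [-1, 0, 1]: windows are shifted
```

So even when the output size matches, each output element is computed from a window shifted by one input position, which is the "different border context" that affects existing checkpoints.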
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@monai/networks/nets/autoencoderkl.py` around lines 130 - 139, The current
None-handling for downsample_parameters normalizes to symmetric padding via
_validate_kernel_stride_parameters/_compute_padding, changing legacy behavior;
revert to special-casing the legacy default when downsample_parameters is None
by returning per-level entries that match the original AsymmetricPad + Conv
semantics (kernel_size=3, stride=2, padding=0) instead of computed symmetric
padding—use the symbols downsample_parameters,
default_kernel_size/default_stride, spatial_dims and num_levels to locate the
branch and ensure each returned dict keeps padding=0 (caller is expected to
apply the AsymmetricPad((0,1)*spatial_dims) behavior externally) so existing
checkpoints keep the same behavior.
…d relevant testcases
Fixes #8447.
Description
This PR adds configurable anisotropic downsampling support to AutoencoderKL.
Previously, AutoencoderKL hardcoded:
- kernel_size=3
- stride=2
This PR introduces configurable per-level and per-dimension downsampling parameters while preserving backward compatibility and encoder-decoder spatial consistency.
Key changes:
- Added configurable downsampling parameters for AEKLDownsample
- Added helper utilities for:
- Added support for anisotropic configurations such as:
  - stride=(2,2,1)
  - kernel_size=(3,3,1)
- Removed dependency on hardcoded asymmetric padding for configurable paths
- Updated decoder upsampling to automatically mirror encoder downsampling configuration
- Added validation for:
- Added comprehensive tests covering:
This is particularly useful for medical imaging workloads with anisotropic voxel spacing such as CT and MRI volumes.
Types of changes
- `./runtests.sh -f -u --net --coverage`
- `./runtests.sh --quick --unittests --disttests`
- `make html` command in the `docs/` folder