From 0a9077364536d7c81b902a8225d5a7540b931635 Mon Sep 17 00:00:00 2001 From: Shubham Chandravanshi Date: Sun, 17 May 2026 19:41:49 +0530 Subject: [PATCH 1/2] Add configurable anisotropic downsampling support to AutoencoderKL and relevant testcases Signed-off-by: Shubham Chandravanshi --- monai/networks/nets/autoencoderkl.py | 246 ++++++++++++++++- tests/networks/nets/test_autoencoderkl.py | 316 ++++++++++++++++++++++ 2 files changed, 554 insertions(+), 8 deletions(-) diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py index 11b4fcfc9e..6c9e93a633 100644 --- a/monai/networks/nets/autoencoderkl.py +++ b/monai/networks/nets/autoencoderkl.py @@ -25,10 +25,167 @@ __all__ = ["AutoencoderKL"] +def _validate_kernel_stride_parameters( + kernel_size: int | tuple[int, ...] | None, + stride: int | tuple[int, ...] | None, + spatial_dims: int, + param_name: str = "parameter", +) -> tuple[tuple[int, ...], tuple[int, ...]]: + """ + Validate and normalize kernel_size and stride parameters. + + Args: + kernel_size: int or tuple of ints representing kernel size + stride: int or tuple of ints representing stride + spatial_dims: number of spatial dimensions + param_name: name of parameter for error messages + + Returns: + Tuple of (normalized_kernel_size, normalized_stride) + + Raises: + ValueError: if parameters are invalid + """ + if kernel_size is None or stride is None: + return None, None + + # Normalize kernel_size to tuple + if isinstance(kernel_size, int): + kernel_size_tuple = (kernel_size,) * spatial_dims + else: + kernel_size_tuple = tuple(kernel_size) + + # Normalize stride to tuple + if isinstance(stride, int): + stride_tuple = (stride,) * spatial_dims + else: + stride_tuple = tuple(stride) + + # Validate lengths + if len(kernel_size_tuple) != spatial_dims: + raise ValueError(f"{param_name} kernel_size must have length {spatial_dims}, got {len(kernel_size_tuple)}") + if len(stride_tuple) != spatial_dims: + raise ValueError(f"{param_name} stride must have length {spatial_dims}, got {len(stride_tuple)}") + + # Validate kernel sizes are odd + for i, k in enumerate(kernel_size_tuple): + if k % 2 == 0: + raise ValueError(f"{param_name} kernel_size at dimension {i} must be odd, got {k}") + + # Validate all values are positive integers + for i, (k, s) in enumerate(zip(kernel_size_tuple, stride_tuple)): + if not isinstance(k, int) or k <= 0: + raise ValueError(f"{param_name} kernel_size at dimension {i} must be positive int, got {k}") + if not isinstance(s, int) or s <= 0: + raise ValueError(f"{param_name} stride at dimension {i} must be positive int, got {s}") + + return kernel_size_tuple, stride_tuple + + +def _compute_padding(kernel_size: tuple[int, ...]) -> tuple[int, ...]: + """ + Compute symmetric padding from kernel size. + + For odd kernel sizes, padding = kernel_size // 2 on all sides. + + Args: + kernel_size: tuple of odd integers + + Returns: + Tuple of padding values (one per dimension) + """ + padding = tuple(k // 2 for k in kernel_size) + return padding + + +def _normalize_downsample_parameters( + downsample_parameters: list[dict] | dict | None, + num_levels: int, + spatial_dims: int, + default_kernel_size: int = 3, + default_stride: int = 2, +) -> list[dict]: + """ + Normalize downsampling parameters to canonical internal representation. + + Accepts: + - None: use defaults for all levels + - Single dict: apply same params to all levels + - List of dicts: one dict per level + + Each dict can specify: + - "kernel_size": int or tuple + - "stride": int or tuple + - "padding": int or tuple (auto-computed if omitted) + + Returns: + List of dicts with normalized keys: + - Each dict has "kernel_size", "stride", "padding" as tuples + - Length equals num_levels + + Raises: + ValueError: if parameters are invalid or inconsistent + """ + if downsample_parameters is None: + # Default: use provided defaults for all levels + default_ks_tuple, default_s_tuple = _validate_kernel_stride_parameters( + default_kernel_size, default_stride, spatial_dims + ) + default_padding = _compute_padding(default_ks_tuple) + return [ + {"kernel_size": default_ks_tuple, "stride": default_s_tuple, "padding": default_padding} + for _ in range(num_levels) + ] + + # If single dict, apply to all levels + if isinstance(downsample_parameters, dict): + params_list = [downsample_parameters] * num_levels + else: + params_list = list(downsample_parameters) + + # Validate we have the right number of levels + if len(params_list) != num_levels: + raise ValueError(f"Expected {num_levels} downsampling parameter dicts (one per level), got {len(params_list)}") + + # Normalize each dict + normalized = [] + for i, params in enumerate(params_list): + if not isinstance(params, dict): + raise ValueError(f"Downsampling parameters at level {i} must be dict, got {type(params)}") + + kernel_size = params.get("kernel_size", default_kernel_size) + stride = params.get("stride", default_stride) + padding = params.get("padding", None) + + # Validate and normalize kernel_size and stride + ks_tuple, s_tuple = _validate_kernel_stride_parameters(kernel_size, stride, spatial_dims, f"Level {i}") + + # Compute padding if not provided + if padding is None: + padding_tuple = _compute_padding(ks_tuple) + else: + # Normalize provided padding + if isinstance(padding, int): + padding_tuple = (padding,) * spatial_dims + else: + padding_tuple = tuple(padding) + + if len(padding_tuple) != spatial_dims: + raise ValueError(f"Level {i} padding must have length {spatial_dims}, got {len(padding_tuple)}") + + normalized.append({"kernel_size": ks_tuple, "stride": s_tuple, "padding": padding_tuple}) + + return normalized + + class AsymmetricPad(nn.Module): """ Pad the input tensor asymmetrically along every spatial dimension. + .. deprecated:: 0.10.0 + This class is deprecated and no longer used by `AEKLDownsample`. + Use configurable kernel_size and stride parameters instead (see `AEKLDownsample`). + Args: spatial_dims: number of spatial dimensions, could be 1, 2, or 3. """ @@ -49,24 +206,46 @@ class AEKLDownsample(nn.Module): Args: spatial_dims: number of spatial dimensions (1D, 2D, 3D). in_channels: number of input channels. + kernel_size: kernel size for the convolution. Can be int or tuple. Default: 3. + stride: stride for the convolution. Can be int or tuple. Default: 2. + padding: padding for the convolution. If None, computed from kernel_size. Default: None. """ - def __init__(self, spatial_dims: int, in_channels: int) -> None: + def __init__( + self, + spatial_dims: int, + in_channels: int, + kernel_size: int | tuple[int, ...] = 3, + stride: int | tuple[int, ...] = 2, + padding: int | tuple[int, ...] | None = None, + ) -> None: super().__init__() - self.pad = AsymmetricPad(spatial_dims=spatial_dims) + + # Validate and normalize kernel_size and stride + kernel_size_tuple, stride_tuple = _validate_kernel_stride_parameters( + kernel_size, stride, spatial_dims, "AEKLDownsample" + ) + + # Compute padding if not provided + if padding is None: + padding_tuple = _compute_padding(kernel_size_tuple) + else: + if isinstance(padding, int): + padding_tuple = (padding,) * spatial_dims + else: + padding_tuple = tuple(padding) self.conv = Convolution( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=in_channels, - strides=2, - kernel_size=3, - padding=0, + strides=stride_tuple, + kernel_size=kernel_size_tuple, + padding=padding_tuple, conv_only=True, ) def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.pad(x) x = self.conv(x) return x @@ -160,6 +339,7 @@ class Encoder(nn.Module): use_combined_linear: whether to use a single linear layer for qkv projection, default to False. use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). + downsample_parameters: list of dicts specifying kernel_size, stride, padding for each downsampling level. """ def __init__( @@ -176,6 +356,7 @@ def __init__( include_fc: bool = True, use_combined_linear: bool = False, use_flash_attention: bool = False, + downsample_parameters: list[dict] | None = None, ) -> None: super().__init__() self.spatial_dims = spatial_dims @@ -187,6 +368,15 @@ def __init__( self.norm_eps = norm_eps self.attention_levels = attention_levels + # Normalize downsampling parameters + num_downsample_levels = len(channels) - 1 + normalized_downsample_params = _normalize_downsample_parameters( + downsample_parameters, num_downsample_levels, spatial_dims + ) + + # Store for decoder to use + self.downsample_parameters = normalized_downsample_params + blocks: list[nn.Module] = [] # Initial convolution blocks.append( @@ -203,6 +393,7 @@ def __init__( # Residual and downsampling blocks output_channel = channels[0] + downsample_idx = 0 for i in range(len(channels)): input_channel = output_channel output_channel = channels[i] @@ -233,7 +424,19 @@ def __init__( ) if not is_final_block: - blocks.append(AEKLDownsample(spatial_dims=spatial_dims, in_channels=input_channel)) + # Use downsampling parameters for this level + downsample_params = normalized_downsample_params[downsample_idx] + blocks.append( + AEKLDownsample( + spatial_dims=spatial_dims, + in_channels=input_channel, + kernel_size=downsample_params["kernel_size"], + stride=downsample_params["stride"], + padding=downsample_params["padding"], + ) + ) + downsample_idx += 1 + # Non-local attention block if with_nonlocal_attn is True: blocks.append( @@ -307,6 +510,7 @@ class Decoder(nn.Module): use_combined_linear: whether to use a single linear layer for qkv projection, default to False. use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). + downsample_parameters: list of dicts with encoder downsampling parameters (strides). """ def __init__( @@ -324,6 +528,7 @@ def __init__( include_fc: bool = True, use_combined_linear: bool = False, use_flash_attention: bool = False, + downsample_parameters: list[dict] | None = None, ) -> None: super().__init__() self.spatial_dims = spatial_dims @@ -335,6 +540,12 @@ def __init__( self.norm_eps = norm_eps self.attention_levels = attention_levels + # Normalize downsampling parameters to get strides for upsampling + num_downsample_levels = len(channels) - 1 + normalized_downsample_params = _normalize_downsample_parameters( + downsample_parameters, num_downsample_levels, spatial_dims + ) + reversed_block_out_channels = list(reversed(channels)) blocks: list[nn.Module] = [] @@ -387,6 +598,10 @@ def __init__( reversed_attention_levels = list(reversed(attention_levels)) reversed_num_res_blocks = list(reversed(num_res_blocks)) block_out_ch = reversed_block_out_channels[0] + + # Reverse downsample parameters for use during upsampling + reversed_downsample_params = list(reversed(normalized_downsample_params)) + for i in range(len(reversed_block_out_channels)): block_in_ch = block_out_ch block_out_ch = reversed_block_out_channels[i] @@ -418,6 +633,10 @@ def __init__( ) if not is_final_block: + # Use stride from encoder downsample as scale_factor for upsampling + # reversed_downsample_params[i] corresponds to the downsampling level we need to upsample + upsampling_stride = reversed_downsample_params[i]["stride"] + if use_convtranspose: blocks.append( Upsample( @@ -441,7 +660,7 @@ def __init__( in_channels=block_in_ch, out_channels=block_in_ch, interp_mode="nearest", - scale_factor=2.0, + scale_factor=tuple(float(s) for s in upsampling_stride), post_conv=post_conv, align_corners=None, ) @@ -492,6 +711,10 @@ class AutoencoderKL(nn.Module): use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False. use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). + downsample_parameters: downsampling parameters for each level. Can be: + - None: use default (kernel_size=3, stride=2 for all levels) + - dict: apply same parameters to all levels (e.g., {"kernel_size": (3,3,1), "stride": (2,2,1)}) + - list of dicts: one dict per downsampling level with keys "kernel_size", "stride", "padding" """ def __init__( @@ -512,6 +735,7 @@ def __init__( include_fc: bool = True, use_combined_linear: bool = False, use_flash_attention: bool = False, + downsample_parameters: list[dict] | dict | None = None, ) -> None: super().__init__() @@ -544,7 +768,12 @@ def __init__( include_fc=include_fc, use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, + downsample_parameters=downsample_parameters, ) + + # Get downsampling parameters from encoder to ensure decoder uses the same strides + encoder_downsample_params = self.encoder.downsample_parameters + self.decoder: nn.Module = Decoder( spatial_dims=spatial_dims, channels=channels, @@ -559,6 +788,7 @@ def __init__( include_fc=include_fc, use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, + downsample_parameters=encoder_downsample_params, ) self.quant_conv_mu = Convolution( spatial_dims=spatial_dims, diff --git a/tests/networks/nets/test_autoencoderkl.py b/tests/networks/nets/test_autoencoderkl.py index af0c55d6ec..cb66c74a0e 100644 --- a/tests/networks/nets/test_autoencoderkl.py +++ b/tests/networks/nets/test_autoencoderkl.py @@ -428,6 +428,322 @@ def test_load_old_state_dict_proj_attn_discarded_when_no_out_proj(self): any("out_proj" in k for k in loaded), "out_proj should not exist in a model built with include_fc=False" ) + # New tests for downsampling parameters + def test_backward_compatibility_default_behavior(self): + """Test that default behavior (no downsample_parameters) is unchanged.""" + input_param = { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + } + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + # Test with standard input shape + x = torch.randn(1, 1, 16, 16).to(device) + result = net.forward(x) + # With default stride=2 and 2 downsampling levels (for 3 channel groups), + # latent shape should be 16 / 2 / 2 = 4 + self.assertEqual(result[0].shape, (1, 1, 16, 16)) + self.assertEqual(result[1].shape, (1, 4, 4, 4)) + + def test_anisotropic_stride_2d(self): + """Test 2D anisotropic stride (2,1) at first level.""" + input_param = { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + } + # Downsampling: level 0 uses (2,1), level 1 uses (2,2) + downsample_params = [{"kernel_size": 3, "stride": (2, 1)}, {"kernel_size": 3, "stride": (2, 2)}] + input_param["downsample_parameters"] = downsample_params + net = AutoencoderKL(**input_param).to(device) + + with eval_mode(net): + x = torch.randn(1, 1, 32, 32).to(device) + result = net.forward(x) + # After level 0: 32/2=16, 32/1=32 + # After level 1: 16/2=8, 32/2=16 + self.assertEqual(result[0].shape, (1, 1, 32, 32)) + self.assertEqual(result[1].shape, (1, 4, 8, 16)) + + def test_anisotropic_stride_3d(self): + """Test 3D anisotropic stride (2,2,1) - common for thick slice spacing.""" + input_param = { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + } + # Preserve z-dimension with stride=1 + downsample_params = [ + {"kernel_size": (3, 3, 1), "stride": (2, 2, 1)}, + {"kernel_size": (3, 3, 1), "stride": (2, 2, 1)}, + ] + input_param["downsample_parameters"] = downsample_params + net = AutoencoderKL(**input_param).to(device) + + with eval_mode(net): + x = torch.randn(1, 1, 32, 32, 64).to(device) + result = net.forward(x) + # After level 0: 32/2=16, 32/2=16, 64/1=64 + # After level 1: 16/2=8, 16/2=8, 64/1=64 + self.assertEqual(result[0].shape, (1, 1, 32, 32, 64)) + self.assertEqual(result[1].shape, (1, 4, 8, 8, 64)) + + def test_mixed_anisotropic_downsample_parameters(self): + """Test per-level configuration with mixed parameters.""" + input_param = { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + } + # Level 0: preserve z, Level 1: isotropic + downsample_params = [ + {"kernel_size": (3, 3, 1), "stride": (2, 2, 1)}, + {"kernel_size": (3, 3, 3), "stride": (2, 2, 2)}, + ] + input_param["downsample_parameters"] = downsample_params + net = AutoencoderKL(**input_param).to(device) + + with eval_mode(net): + x = torch.randn(1, 1, 32, 32, 32).to(device) + result = net.forward(x) + # After level 0: 32/2=16, 32/2=16, 32/1=32 + # After level 1: 16/2=8, 16/2=8, 32/2=16 + self.assertEqual(result[0].shape, (1, 1, 32, 32, 32)) + self.assertEqual(result[1].shape, (1, 4, 8, 8, 16)) + + def test_single_dict_applied_to_all_levels(self): + """Test that single dict is applied to all downsampling levels.""" + input_param = { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + } + # Single dict: apply (3,3) kernel with stride (2,1) to all levels + downsample_params = {"kernel_size": (3, 3), "stride": (2, 1)} + input_param["downsample_parameters"] = downsample_params + net = AutoencoderKL(**input_param).to(device) + + with eval_mode(net): + x = torch.randn(1, 1, 32, 32).to(device) + result = net.forward(x) + # After level 0: 32/2=16, 32/1=32 + # After level 1: 16/2=8, 32/1=32 + self.assertEqual(result[0].shape, (1, 1, 32, 32)) + self.assertEqual(result[1].shape, (1, 4, 8, 32)) + + def test_validation_even_kernel_raises_error(self): + """Test that even kernel sizes raise ValueError.""" + input_param = { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + } + downsample_params = [{"kernel_size": 4, "stride": 2}] # Even kernel + input_param["downsample_parameters"] = downsample_params + + with self.assertRaises(ValueError): + AutoencoderKL(**input_param) + + def test_validation_invalid_tuple_length_raises_error(self): + """Test that invalid tuple length raises ValueError.""" + input_param = { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + } + # 3D but only 2 values in tuple + downsample_params = [{"kernel_size": (3, 3), "stride": (2, 2)}] + input_param["downsample_parameters"] = downsample_params + + with self.assertRaises(ValueError): + AutoencoderKL(**input_param) + + def test_validation_wrong_num_levels_raises_error(self): + """Test that wrong number of downsampling parameter dicts raises error.""" + input_param = { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), # 3 channels = 2 downsampling levels + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + } + # Only 1 dict but need 2 + downsample_params = [{"kernel_size": 3, "stride": 2}] + input_param["downsample_parameters"] = downsample_params + + with self.assertRaises(ValueError): + AutoencoderKL(**input_param) + + def test_reconstruction_with_anisotropic_downsampling(self): + """Test that reconstruction shape matches input with anisotropic downsampling.""" + input_param = { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + } + downsample_params = [ + {"kernel_size": (3, 3, 1), "stride": (2, 2, 1)}, + {"kernel_size": (3, 3, 1), "stride": (2, 2, 1)}, + ] + input_param["downsample_parameters"] = downsample_params + net = AutoencoderKL(**input_param).to(device) + + with eval_mode(net): + x = torch.randn(1, 1, 64, 64, 128).to(device) + reconstruction = net.reconstruct(x) + self.assertEqual(reconstruction.shape, x.shape) + + def test_encode_decode_with_anisotropic_downsampling(self): + """Test encode/decode cycle with anisotropic downsampling.""" + input_param = { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + } + downsample_params = [{"kernel_size": (3, 3), "stride": (2, 1)}, {"kernel_size": (3, 3), "stride": (2, 2)}] + input_param["downsample_parameters"] = downsample_params + net = AutoencoderKL(**input_param).to(device) + + with eval_mode(net): + x = torch.randn(1, 1, 32, 32).to(device) + z_mu, z_sigma = net.encode(x) + z = net.sampling(z_mu, z_sigma) + reconstruction = net.decode(z) + self.assertEqual(reconstruction.shape, x.shape) + + def test_reconstruction_robustness_anisotropic_non_power_of_two_odd_dims(self): + """ + Test reconstruction shape consistency with: + - Anisotropic multi-level downsampling config + - Non-power-of-two spatial dimensions (but stride-compatible) + - Mixed even/odd dimensions + + This rigorously validates encoder-decoder symmetry under challenging conditions. + + Note: Dimensions must be compatible with the stride pattern: + - Stride (2,2,1) -> (2,2,2) means dims must be divisible by (4,4,2) + - Using 60 (=4*15), 68 (=4*17), 96 (=2*48) to maximize coverage + """ + input_param = { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + } + + # Anisotropic config: preserve Z dimension at level 0, isotropic at level 1 + downsample_params = [ + {"kernel_size": (3, 3, 1), "stride": (2, 2, 1)}, + {"kernel_size": (3, 3, 3), "stride": (2, 2, 2)}, + ] + input_param["downsample_parameters"] = downsample_params + net = AutoencoderKL(**input_param).to(device) + + with eval_mode(net): + # Stride-compatible dimensions: + # Level 0: stride (2,2,1) -> need height/width divisible by 2 + # Level 1: stride (2,2,2) -> need result divisible by 2 again + # Final requirement: dims divisible by (4, 4, 2) + # Using: 60=4*15 (not power of 2), 68=4*17 (not power of 2), 96=2*48 + x = torch.randn(1, 1, 60, 68, 96).to(device) + + # Forward pass + z_mu, z_sigma = net.encode(x) + z = net.sampling(z_mu, z_sigma) + reconstruction = net.decode(z) + + # Verify shape consistency - reconstruction should match input exactly + self.assertEqual( + reconstruction.shape, + x.shape, + f"Reconstruction shape {reconstruction.shape} does not match input shape {x.shape}", + ) + + # Also test via reconstruct method + reconstruction2 = net.reconstruct(x) + self.assertEqual( + reconstruction2.shape, + x.shape, + f"Reconstruct shape {reconstruction2.shape} does not match input shape {x.shape}", + ) + + # Verify latent shape makes sense: + # 60 -> 30 (stride=2) -> 15 (stride=2) + # 68 -> 34 (stride=2) -> 17 (stride=2) + # 96 -> 96 (stride=1) -> 48 (stride=2) + expected_latent_h = 15 + expected_latent_w = 17 + expected_latent_d = 48 + + self.assertEqual( + z_mu.shape[2], + expected_latent_h, + f"Latent H shape mismatch: got {z_mu.shape[2]}, expected {expected_latent_h}", + ) + self.assertEqual( + z_mu.shape[3], + expected_latent_w, + f"Latent W shape mismatch: got {z_mu.shape[3]}, expected {expected_latent_w}", + ) + self.assertEqual( + z_mu.shape[4], + expected_latent_d, + f"Latent D shape mismatch: got {z_mu.shape[4]}, expected {expected_latent_d}", + ) + if __name__ == "__main__": unittest.main() From 6bc4ebc226d33dba05c6a0e347742178d67b0a02 Mon Sep 17 00:00:00 2001 From: Shubham Chandravanshi Date: Sun, 17 May 2026 20:41:27 +0530 Subject: [PATCH 2/2] Update AutoencoderKL test configuration and docstrings Signed-off-by: Shubham Chandravanshi --- monai/networks/nets/autoencoderkl.py | 83 ++++++++++++++++++++--- tests/networks/nets/test_autoencoderkl.py | 30 +++++++- 2 files changed, 103 insertions(+), 10 deletions(-) diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py index 6c9e93a633..66aa59564b 100644 --- a/monai/networks/nets/autoencoderkl.py +++ b/monai/networks/nets/autoencoderkl.py @@ -30,7 +30,7 @@ def _validate_kernel_stride_parameters( stride: int | tuple[int, ...] | None, spatial_dims: int, param_name: str = "parameter", -) -> tuple[tuple[int, ...], tuple[int, ...]]: +) -> tuple[tuple[int, ...] | None, tuple[int, ...] | None]: """ Validate and normalize kernel_size and stride parameters. @@ -84,15 +84,16 @@ def _validate_kernel_stride_parameters( def _compute_padding(kernel_size: tuple[int, ...]) -> tuple[int, ...]: """ - Compute symmetric padding from kernel size. + Compute symmetric padding for odd kernel sizes. - For odd kernel sizes, padding = kernel_size // 2 on all sides. + Padding is derived as: + padding[d] = kernel_size[d] // 2 Args: - kernel_size: tuple of odd integers + kernel_size: Kernel size for each spatial dimension. Returns: - Tuple of padding values (one per dimension) + Tuple of padding values for each spatial dimension. """ padding = tuple(k // 2 for k in kernel_size) return padding @@ -119,9 +120,9 @@ def _normalize_downsample_parameters( - "padding": int or tuple (auto-computed if omitted) Returns: - List of dicts with normalized keys: - - Each dict has "kernel_size", "stride", "padding" as tuples - - Length equals num_levels + List of dicts with normalized keys: + - Each dict has "kernel_size", "stride", "padding" as tuples + - Length equals num_levels Raises: ValueError: if parameters are invalid or inconsistent @@ -195,6 +196,15 @@ def __init__(self, spatial_dims: int) -> None: self.pad = (0, 1) * spatial_dims def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Apply asymmetric padding to the input tensor. + + Args: + x: Input tensor. + + Returns: + Padded tensor. + """ x = nn.functional.pad(x, self.pad, mode="constant", value=0.0) return x @@ -246,6 +256,15 @@ def __init__( ) def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Apply convolutional downsampling. + + Args: + x: Input tensor. + + Returns: + Downsampled tensor. + """ x = self.conv(x) return x @@ -486,6 +505,15 @@ def __init__( self.blocks = nn.ModuleList(blocks) def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward input through encoder blocks. + + Args: + x: Input tensor. + + Returns: + Encoded latent representation. + """ for block in self.blocks: x = block(x) return x @@ -682,6 +710,15 @@ def __init__( self.blocks = nn.ModuleList(blocks) def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward latent representation through decoder blocks. + + Args: + x: Latent tensor. + + Returns: + Reconstructed image tensor. + """ for block in self.blocks: x = block(x) return x @@ -890,17 +927,47 @@ def decode(self, z: torch.Tensor) -> torch.Tensor: return dec def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Encode, sample, and reconstruct an input image. + + Args: + x: Input tensor of shape BxCx[SPATIAL_DIMS]. + + Returns: + Tuple containing: + - reconstructed image + - latent mean + - latent standard deviation + """ z_mu, z_sigma = self.encode(x) z = self.sampling(z_mu, z_sigma) reconstruction = self.decode(z) return reconstruction, z_mu, z_sigma def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor: + """ + Encode an input image into latent space representation. + + Args: + x: Input tensor. + + Returns: + Sampled latent tensor. + """ z_mu, z_sigma = self.encode(x) z = self.sampling(z_mu, z_sigma) return z def decode_stage_2_outputs(self, z: torch.Tensor) -> torch.Tensor: + """ + Decode latent representation into image space. + + Args: + z: Latent tensor. + + Returns: + Decoded image tensor. + """ image = self.decode(z) return image diff --git a/tests/networks/nets/test_autoencoderkl.py b/tests/networks/nets/test_autoencoderkl.py index cb66c74a0e..36272bd7cf 100644 --- a/tests/networks/nets/test_autoencoderkl.py +++ b/tests/networks/nets/test_autoencoderkl.py @@ -440,6 +440,8 @@ def test_backward_compatibility_default_behavior(self): "attention_levels": (False, False, False), "num_res_blocks": 1, "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, } net = AutoencoderKL(**input_param).to(device) with eval_mode(net): @@ -462,6 +464,8 @@ def test_anisotropic_stride_2d(self): "attention_levels": (False, False, False), "num_res_blocks": 1, "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, } # Downsampling: level 0 uses (2,1), level 1 uses (2,2) downsample_params = [{"kernel_size": 3, "stride": (2, 1)}, {"kernel_size": 3, "stride": (2, 2)}] @@ -487,6 +491,8 @@ def test_anisotropic_stride_3d(self): "attention_levels": (False, False, False), "num_res_blocks": 1, "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, } # Preserve z-dimension with stride=1 downsample_params = [ @@ -515,6 +521,8 @@ def test_mixed_anisotropic_downsample_parameters(self): "attention_levels": (False, False, False), "num_res_blocks": 1, "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, } # Level 0: preserve z, Level 1: isotropic downsample_params = [ @@ -543,6 +551,8 @@ def test_single_dict_applied_to_all_levels(self): "attention_levels": (False, False, False), "num_res_blocks": 1, "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, } # Single dict: apply (3,3) kernel with stride (2,1) to all levels downsample_params = {"kernel_size": (3, 3), "stride": (2, 1)} @@ -568,8 +578,11 @@ def test_validation_even_kernel_raises_error(self): "attention_levels": (False, False, False), "num_res_blocks": 1, "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, } - downsample_params = [{"kernel_size": 4, "stride": 2}] # Even kernel + + downsample_params = [{"kernel_size": 4, "stride": 2}, {"kernel_size": 3, "stride": 2}] # Even kernel input_param["downsample_parameters"] = downsample_params with self.assertRaises(ValueError): @@ -586,9 +599,14 @@ def test_validation_invalid_tuple_length_raises_error(self): "attention_levels": (False, False, False), "num_res_blocks": 1, "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, } # 3D but only 2 values in tuple - downsample_params = [{"kernel_size": (3, 3), "stride": (2, 2)}] + downsample_params = [ + {"kernel_size": (3, 3), "stride": (2, 2)}, # Invalid: 2 values for 3D + {"kernel_size": (3, 3, 3), "stride": (2, 2, 2)}, + ] input_param["downsample_parameters"] = downsample_params with self.assertRaises(ValueError): @@ -605,6 +623,8 @@ def test_validation_wrong_num_levels_raises_error(self): "attention_levels": (False, False, False), "num_res_blocks": 1, "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, } # Only 1 dict but need 2 downsample_params = [{"kernel_size": 3, "stride": 2}] @@ -624,6 +644,8 @@ def test_reconstruction_with_anisotropic_downsampling(self): "attention_levels": (False, False, False), "num_res_blocks": 1, "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, } downsample_params = [ {"kernel_size": (3, 3, 1), "stride": (2, 2, 1)}, @@ -648,6 +670,8 @@ def test_encode_decode_with_anisotropic_downsampling(self): "attention_levels": (False, False, False), "num_res_blocks": 1, "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, } downsample_params = [{"kernel_size": (3, 3), "stride": (2, 1)}, {"kernel_size": (3, 3), "stride": (2, 2)}] input_param["downsample_parameters"] = downsample_params @@ -682,6 +706,8 @@ def test_reconstruction_robustness_anisotropic_non_power_of_two_odd_dims(self): "attention_levels": (False, False, False), "num_res_blocks": 1, "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, } # Anisotropic config: preserve Z dimension at level 0, isotropic at level 1