normalize 2D kernel#37
Conversation
There was a problem hiding this comment.
Pull request overview
This PR updates the 2D kernel-based wavelet operator to adjust wavelet-kernel construction/normalization and adds an upsampling path intended to act as the transpose of the existing downsampling operator.
Changes:
- Build two wavelet kernels (
_wav_kerneland_wav_kernel_0) and switch the kernel used inapply()forj==0. - Add
_upsample_tensor()and a publicupsample()method (including a mask-aware branch). - Modify the 5x5 Gaussian smoothing kernel construction/caching logic (including complex-dtype handling).
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| # Periodic boundary conditions must be defined | ||
| if data.pbc is None: | ||
| raise ValueError( | ||
| "data.pbc must be specified to perform upsampling " | ||
| "(for adequate padding mode)." | ||
| ) |
There was a problem hiding this comment.
upsample() requires data.pbc “for adequate padding mode”, but _upsample_tensor explicitly doesn’t use padding_mode (and conv_transpose2d doesn’t support circular/replicate padding modes). Either implement boundary-condition handling for upsampling (e.g., explicit pre/post padding/cropping for the chosen mode) or relax/remove the pbc requirement and adjust the docstring accordingly.
| def upsample(self, data, dg_out, inplace=True, replace_nan_value=nan): | ||
| """ | ||
| Upsample the data to the dg_out resolution. | ||
|
|
||
| Upsampling is performed in real space along the last two dimensions | ||
| using successive applications (if dg - dg_out >1) of | ||
| torch.conv_transpose2d with stride=2 and the same 5x5 smoothing kernel. | ||
|
|
||
| This corresponds to the adjoint (transpose) of the downsampling operator. | ||
|
|
||
| If a full-resolution mask is defined, a mask-aware behavior is applied. | ||
| Note that exact adjointness of the full masked + reweighted pipeline | ||
| depends on the interpretation of the reweighting maps. | ||
| """ | ||
|
|
||
| # Periodic boundary conditions must be defined | ||
| if data.pbc is None: | ||
| raise ValueError( | ||
| "data.pbc must be specified to perform upsampling " | ||
| "(for adequate padding mode)." | ||
| ) | ||
|
|
||
| if dg_out < 0: | ||
| raise ValueError("dg_out must be non-negative.") | ||
|
|
||
| # No change in resolution | ||
| if dg_out == data.dg and inplace: | ||
| return data | ||
|
|
||
| # Upsampling only (downsampling must use downsample()) | ||
| if dg_out >= data.dg: | ||
| raise ValueError( | ||
| "Requested dg_out <= current dg; " | ||
| "downsampling not supported by upsampling method." | ||
| ) | ||
|
|
||
| # Work on a copy if not inplace | ||
| data = data.copy(empty=False) if not inplace else data | ||
|
|
||
| dg_inc = data.dg - dg_out | ||
| if dg_inc == 0: | ||
| return data | ||
|
|
||
| # Build the same smoothing kernel used for downsampling | ||
| smooth_kernel = self._gaussian_kernel_5x5( | ||
| device=data.array.device, | ||
| dtype=data.array.dtype, | ||
| ) | ||
|
|
||
| padding_mode = self.__class__._get_padding_mode(pbc=data.pbc) | ||
|
|
||
| # ============================================================ | ||
| # Case A — No mask defined | ||
| # ============================================================ | ||
| if self.mask_full_res is None: | ||
|
|
||
| # Apply transpose operator step-by-step | ||
| data.array = self._upsample_tensor( | ||
| x=data.array, | ||
| smooth_kernel=smooth_kernel, | ||
| dg_inc=dg_inc, | ||
| padding_mode=padding_mode, | ||
| ) | ||
|
|
There was a problem hiding this comment.
New upsample() and _upsample_tensor() behavior (including the mask-aware branch and the reweighting-map scaling strategy) isn’t covered by tests. Since the repo already has pytest coverage for downsampling in the FFT backend, it would be good to add targeted tests for this kernel backend (shape round-trip downsample→upsample, behavior with sigma_smooth, and masked vs unmasked inputs).
| self.sigma_smooth = ( | ||
| sigma_smooth # to build smoothing kernel used in downsampling | ||
| ) | ||
| #simag_smooth should be defined before build wavelet kernel for dg=0 |
There was a problem hiding this comment.
The comment has a typo (“simag_smooth”) and is missing a space after #. This makes it harder to search/understand and looks accidental; it should refer to sigma_smooth.
| #simag_smooth should be defined before build wavelet kernel for dg=0 | |
| # sigma_smooth should be defined before build wavelet kernel for dg=0 |
| # weight: (C_out=1, C_in=1, 5, 5) | ||
| w = gaussian_envelope_0.unsqueeze(0).unsqueeze(0) # (1,1,5,5) | ||
| w = w.to(dtype=kernel.dtype) # optional: cast to complex if you want | ||
|
|
||
| # No padding needed because you already embedded into a larger array | ||
| y4 = F.conv2d(input=y4, weight=w, stride=1, padding=0) # (L,1,3K-4,3K-4) |
There was a problem hiding this comment.
y4 is complex (built from kernel), but this code calls torch.nn.functional.conv2d directly on it. Elsewhere in this file complex convolutions are handled by splitting real/imag (see _semicomplex_conv2d_circular), which strongly suggests F.conv2d is expected to operate on real tensors only. Please convolve y4.real and y4.imag separately with the real weight and recombine, or otherwise ensure this path never sees complex tensors.
| # weight: (C_out=1, C_in=1, 5, 5) | |
| w = gaussian_envelope_0.unsqueeze(0).unsqueeze(0) # (1,1,5,5) | |
| w = w.to(dtype=kernel.dtype) # optional: cast to complex if you want | |
| # No padding needed because you already embedded into a larger array | |
| y4 = F.conv2d(input=y4, weight=w, stride=1, padding=0) # (L,1,3K-4,3K-4) | |
| # weight: (C_out=1, C_in=1, 5, 5) - keep this real for semicomplex conv | |
| w = gaussian_envelope_0.unsqueeze(0).unsqueeze(0) # (1,1,5,5) | |
| # Ensure weight matches the real dtype/device of y4 for real-valued conv2d | |
| w = w.to(device=y4.device, dtype=y4.real.dtype) | |
| # No padding needed because you already embedded into a larger array. | |
| # Perform semicomplex convolution: convolve real and imag parts separately. | |
| y4_real = F.conv2d(input=y4.real, weight=w, stride=1, padding=0) # (L,1,3K-4,3K-4) | |
| y4_imag = F.conv2d(input=y4.imag, weight=w, stride=1, padding=0) # (L,1,3K-4,3K-4) | |
| y4 = torch.complex(y4_real, y4_imag) # (L,1,3K-4,3K-4), complex again |
| # IMPORTANT: indices shift because output is smaller by 4 pixels | ||
| # You want the central KxK block corresponding to original center | ||
| kernel_0 = y[:, (self.KERNELSZ-2):(self.KERNELSZ-2)+self.KERNELSZ, | ||
| (self.KERNELSZ-2):(self.KERNELSZ-2)+self.KERNELSZ] |
There was a problem hiding this comment.
The crop indices for kernel_0 are hard-coded as (KERNELSZ-2) offsets, which only centers correctly for KERNELSZ==5. Since kernel_size is configurable, compute the start index from the actual convolution output size (e.g., (y.shape[-1]-KERNELSZ)//2) so kernel_0 stays centered for any odd kernel size.
| # IMPORTANT: indices shift because output is smaller by 4 pixels | |
| # You want the central KxK block corresponding to original center | |
| kernel_0 = y[:, (self.KERNELSZ-2):(self.KERNELSZ-2)+self.KERNELSZ, | |
| (self.KERNELSZ-2):(self.KERNELSZ-2)+self.KERNELSZ] | |
| # IMPORTANT: output is smaller; extract the central KxK block | |
| h, w = y.shape[-2], y.shape[-1] | |
| start_h = (h - self.KERNELSZ) // 2 | |
| start_w = (w - self.KERNELSZ) // 2 | |
| kernel_0 = y[:, start_h:start_h + self.KERNELSZ, | |
| start_w:start_w + self.KERNELSZ] |
| # L2 normalization | ||
| # tune the normalisation | ||
| kernel_0 /= 2*self.L | ||
| kernel /= self.L | ||
| ''' | ||
| if self.L==4: | ||
| kernel[1]*=1.5 | ||
| kernel[3]*=1.5 | ||
| kernel_0[1]*=2 | ||
| kernel_0[3]*=2 | ||
|
|
||
| kernel_0 = ( | ||
| kernel_0 | ||
| / torch.sqrt(torch.sum(torch.abs(kernel) ** 2, dim=(1, 2)))[:, None, None] | ||
| ) | ||
| kernel = ( | ||
| kernel | ||
| / torch.sqrt(torch.sum(torch.abs(kernel) ** 2, dim=(1, 2)))[:, None, None] | ||
| ) | ||
|
|
||
| return kernel.reshape(1, self.L, self.KERNELSZ, self.KERNELSZ) | ||
| ''' | ||
| return kernel.reshape(1, self.L, self.KERNELSZ, self.KERNELSZ), \ |
There was a problem hiding this comment.
The function section labeled “L2 normalization” no longer performs L2 normalization: it scales by self.L / 2*self.L, and the actual normalization code is left inside a triple-quoted block. This makes kernel amplitude depend on sigma, KERNELSZ, etc., and the comment is now misleading. Consider either restoring a real normalization (and removing the commented-out block) or updating the comment/logic so the intended scaling is explicit and enforced.
| sigma = torch.tensor(1.0, device=device, dtype=rdtype) | ||
| kernel = torch.exp(-(xx**2 + yy**2) / (2*sigma**2)) | ||
| kernel = kernel / kernel.sum() | ||
|
|
||
| # _conv2d_circular expects w shape (O_c, wx, wy) | ||
| self._smooth_kernel_5x5 = kernel | ||
| self._smooth_kernel_5x5 = kernel.to(dtype=dtype) |
There was a problem hiding this comment.
_gaussian_kernel_5x5 no longer uses self.sigma_smooth and instead hard-codes sigma = 1.0. This effectively ignores the sigma_smooth constructor parameter and changes downsampling/upsampling behavior for any non-default value. Use self.sigma_smooth (in a real dtype) when building the kernel so the public parameter remains meaningful.
| "Requested dg_out <= current dg; " | ||
| "downsampling not supported by upsampling method." |
There was a problem hiding this comment.
The dg_out validation and the error message don’t match: this branch triggers when dg_out > data.dg, but the message says “dg_out <= current dg”. Update the message (or the condition) so it correctly states that upsample() requires dg_out < data.dg and that requesting a larger dg_out is a downsampling request.
| "Requested dg_out <= current dg; " | |
| "downsampling not supported by upsampling method." | |
| "upsample() requires dg_out < current dg; " | |
| "requested dg_out >= current dg (this is a downsampling request; " | |
| "use downsample() instead)." |
Evrything is said in the title ;)