Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 11 additions & 82 deletions STL_main/STL_2D_Kernel_Torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,17 +591,12 @@ def _find_mask(self, data):
else:
raise ValueError("len(data.conv_history) must be 0, 1 or 2.")

def _build_bump_steerable_wavelet_kernel(self):
N = 256
def _build_wavelet_kernel_from_ifft_crop(self, fft_wavelet_builder, N=256):
assert (
self.KERNELSZ % 2 == 1
), "KERNELSZ must be odd to have a well-defined center."
w = self.KERNELSZ // 2
kernel = WaveletOperator2D_FFT_torch.bump_steerable_bank(
self.J, self.L, size=(N, N)
)[
0
] # [L, N, N]
kernel = fft_wavelet_builder(self.J, self.L, size=(N, N))[0] # [L, N, N]
kernel = torch.fft.fftshift(torch.fft.ifft2(kernel, dim=(-2, -1)), dim=(-2, -1))
Comment on lines +599 to 600
kernel = kernel[
:, N // 2 - w : N // 2 + w + 1, N // 2 - w : N // 2 + w + 1
Expand All @@ -610,81 +605,15 @@ def _build_bump_steerable_wavelet_kernel(self):
kernel = kernel.to(device=self.device)
return kernel.unsqueeze(0) # (1, L, K, K)

def _build_morlet_wavelet_kernel(self, sigma=1):
"""Create a 2D Wavelet kernel."""

# Morlay wavelet
coords = (
torch.arange(self.KERNELSZ, device=self.device, dtype=self.dtype)
- (self.KERNELSZ - 1) / 2.0
)
yy, xx = torch.meshgrid(coords, coords, indexing="ij")

# Gaussian envelope
gaussian_envelope_tmp = torch.exp(-2 * (xx**2 + yy**2) / (self.L * sigma**2))
gaussian_envelope = torch.exp(-8 * (xx**2 + yy**2) / (self.L * sigma**2))

# Orientations
angles = (
torch.arange(self.L, device=self.device, dtype=self.dtype)
/ self.L
* torch.pi
)
angles *= -1 # align with bump-steerable wavelet orientations

# Morlet wavelet: exp(i*k0*x_rot) * gaussian_envelope
# x_rot is the coordinate along the orientation direction
x_rot = xx[None, :, :] * torch.cos(angles[:, None, None]) + yy[
None, :, :
] * torch.sin(angles[:, None, None])

# Complex Morlet wavelet
kernel_tmp = (
torch.exp(1j * 0.75 * np.pi * x_rot) * gaussian_envelope_tmp[None, :, :]
) # (L, K, K)

# y: (L, 3K, 3K)
y = torch.zeros(
[self.L, self.KERNELSZ * 3, self.KERNELSZ * 3],
device=self.device,
dtype=kernel_tmp.dtype,
)
y[:, self.KERNELSZ : self.KERNELSZ * 2, self.KERNELSZ : self.KERNELSZ * 2] = (
kernel_tmp
)

# conv2d expects 4D input: (N, C, H, W)
y4 = y.unsqueeze(1) # (L, 1, 3K, 3K)

# weight: (C_out=1, C_in=1, 5, 5)
w = gaussian_envelope.unsqueeze(0).unsqueeze(0) # (1,1,5,5)
w = w.to(dtype=kernel_tmp.dtype) # optional: cast to complex if you want

# No padding needed because you already embedded into a larger array
y4 = F.conv2d(input=y4, weight=w, stride=1, padding=0) # (L,1,3K-4,3K-4)

# Back to (L, H, W)
y = y4.squeeze(1) # (L, 3K-4, 3K-4)

# IMPORTANT: indices shift because output is smaller by 4 pixels
# You want the central KxK block corresponding to original center
kernel = y[
:,
(self.KERNELSZ - 2) : (self.KERNELSZ - 2) + self.KERNELSZ,
(self.KERNELSZ - 2) : (self.KERNELSZ - 2) + self.KERNELSZ,
] # (L, K, K)
######################

# Remove DC component (admissibility condition)
kernel = kernel - torch.mean(kernel, dim=(1, 2))[:, None, None]

# L2 normalization
kernel /= 2 * self.L

# ad hoc normalisation
kernel /= 0.8

return kernel.reshape(1, self.L, self.KERNELSZ, self.KERNELSZ) # (1, L, K, K)
def _build_bump_steerable_wavelet_kernel(self):
return self._build_wavelet_kernel_from_ifft_crop(
fft_wavelet_builder=WaveletOperator2D_FFT_torch.bump_steerable_bank
) # (1, L, K, K)

def _build_morlet_wavelet_kernel(self):
return self._build_wavelet_kernel_from_ifft_crop(
fft_wavelet_builder=WaveletOperator2D_FFT_torch.gaussian_bank
) # (1, L, K, K)

def _crop(self, array, border):
"""
Expand Down
6 changes: 3 additions & 3 deletions docs/user_notebook/scattering_stats.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 11,
"id": "e75901a1",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -723,7 +723,7 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 12,
"id": "44fbfbed",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -751,7 +751,7 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 13,
"id": "3dbc99ba",
"metadata": {},
"outputs": [
Expand Down
Loading
Loading