Skip to content
50 changes: 32 additions & 18 deletions STL_main/STL_2D_FFT_Torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,23 +408,26 @@ def gaussian_bank(cls, J, L, size, base_mu=None, base_sigma=None):
filters_bank = torch.fft.fftshift(filters_bank, dim=(-2, -1))
filters_bank[:, :, 0, 0] = 0

# ad hoc normalisation
filters_bank /= 0.8

return filters_bank

@staticmethod
def bump_steerable_2d(omega_grid, c, L, xi0, eps=1e-12):
def bump_steerable_2d(omega_grid, L, xi0, width_factor=2.5, eps=1e-12):
"""
Generate a 2D bump steerable wavelet in Fourier space.

Parameters
----------
omega_grid : torch.Tensor
Grid of frequencies in Fourier space, of shape [Nx, Ny, 2], where the last dimension corresponds to (omega_x, omega_y).
c : float
Normalization constant.
L : int
Number of orientations (steerability order).
xi0 : float
Center frequency xi = (xi0, 0) in Fourier space.
width_factor : float
Optimized constant to follow at best Littlewood-Paley condition (optimized for L=4).
eps : float
Small constant to avoid division by zero in the bump window.

Expand All @@ -441,18 +444,36 @@ def bump_steerable_2d(omega_grid, c, L, xi0, eps=1e-12):

# apply bump window over r: g(r) = exp(-r^2 / (1 - r^2)) * 1_{0<r<1}
r2 = r**2
support_r = (r > 0.0) & (r < 1.0)
support_r = (r >= 0.0) & (r < 1.0)
r2 *= width_factor # optimized parameter (for L=4) to follow at best Littlewood Paley condition, see numerical integration in the notebook "BS_wavelet_kernel.ipynb"
denom = (1.0 - r2).clamp_min(eps)
bump = torch.where(support_r, torch.exp(-r2 / denom), torch.zeros_like(r))

# angular part: cos(theta)^(L-1) where theta is the angle of omega in Fourier space
theta = torch.atan2(omega_grid[..., 1], omega_grid[..., 0])
support_theta = (theta >= -torch.pi / 2) & (theta <= torch.pi / 2)
support_theta = abs(theta) < torch.pi / 2
angular = torch.where(
support_theta, torch.cos(theta).pow(L - 1), torch.zeros_like(theta)
)

return c * bump * angular
weights = bump * angular

# normalize
weights /= weights.abs().max()

# Add a constant to follow Littlewood-Paley condition when L!=4
c = (
(2 ** (L - 1))
* math.factorial(L - 1)
/ math.sqrt(L * math.factorial(2 * (L - 1)))
)
c_L4 = (
(2 ** (4 - 1))
* math.factorial(4 - 1)
/ math.sqrt(4 * math.factorial(2 * (4 - 1)))
)

return weights * (c / c_L4)

@classmethod
def bump_steerable_bank(cls, J, L, size):
Expand Down Expand Up @@ -483,20 +504,14 @@ def bump_steerable_bank(cls, J, L, size):

omega_y = torch.fft.fftfreq(Ny) * Ny
omega_y = torch.fft.fftshift(omega_y) # Shift zero frequency to center
omega_y = torch.flip(omega_y, dims=[0])

Omega_x, Omega_y = torch.meshgrid(omega_x, omega_y, indexing="ij")
omega_grid = torch.stack((Omega_x, Omega_y), dim=-1)

# c value for Littlewood-Paley condition
c = ((1.29**-1) * (2 ** (L - 1)) * math.factorial(L - 1)) / math.sqrt(
L * math.factorial(2 * (L - 1))
)

for j in range(J):
scale_factor = 2**j
for l_idx, l in enumerate(range(L)):
theta = math.pi * l / L
theta = math.pi * l / L + math.pi / 2

cos_theta = torch.cos(torch.tensor(theta))
sin_theta = torch.sin(torch.tensor(theta))
Expand All @@ -510,7 +525,7 @@ def bump_steerable_bank(cls, J, L, size):
scale_factor * omega_grid @ R
) # rotate and dilate the frequency grid
filters_bank[j, l_idx] = torch.fft.fftshift(
cls.bump_steerable_2d(q, c=c, L=L, xi0=xi0)
cls.bump_steerable_2d(q, L=L, xi0=xi0)
)

return filters_bank
Expand Down Expand Up @@ -586,7 +601,7 @@ def __init__(
Parameters
----------
- WType : str
type of wavelets (e.g., "Gaussian" or "Bump-Steerable")
type of wavelets (e.g., "Bump-Steerable" or "Morlet")
- L : int
number of orientations
- J : int
Expand All @@ -602,7 +617,7 @@ def __init__(
- get_crop_border_size_method : function
Method to compute the crop border size.
"""
self.WType = WType # type of wavelets (e.g., "Gaussian" or "Bump-Steerable")
self.WType = WType # type of wavelets (e.g., "Bump-Steerable" or "Morlet")

# Main parameters
self.N0 = N0
Expand Down Expand Up @@ -643,7 +658,7 @@ def _build(self):
- j_to_dg
"""
# Create the full resolution Wavelet set (in fourier space plus fftshifted)
if self.WType == "Gaussian":
if self.WType == "Morlet":
self.wavelet_array = self.__class__.gaussian_bank(
self.J, self.L, self.N0
).to(
Expand Down Expand Up @@ -1009,7 +1024,6 @@ def _compute_and_store_cross_cov(
assert (
data1.array.shape[1] == data2.array.shape[1]
), "data1 and data2 arrays must have the same number of channels."

assert (
data1.array.ndim == data2.array.ndim
), "data1 and data2 arrays must have the same number of dimensions."
Expand Down
Loading
Loading