Skip to content
12 changes: 11 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,14 @@

### Fixed
- Added the `WType` argument to the constructor of the Wavelet operator in the `STL_2D_Kernel_Torch` dataclass. Currently available: Morlet
- Removed the `mean_ref` variable to avoid numerical instability
- Removed the `mean_ref` variable to avoid numerical instability


## [v1.4.0] - 2026-04-29
### Added
- Added the ability to perform synthesis from the statistics of a target map.
- Added and fully completed a user notebook demonstrating synthesis from statistics.
- Extended the `flatten` method to optionally preserve the batch dimension and to handle flattening of complex coefficients.

### Fixed
- Fixed incorrect synthesis behavior in mono-channel FFT and cross-channel FFT.
48 changes: 46 additions & 2 deletions STL_main/STL_2D_FFT_Torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,50 @@ def set_fourier_status(self, target_fourier_status, inplace=False):

return data

###########################################################################
def divide(self, data2, epsilon=1e-8, pow=1.0, inplace=False):
"""
Divide self.array by data2.array raised to a power, with a small epsilon added to the denominator for numerical stability.

- If `data2.array` is in real space:
`self.array` is converted to real space (if needed), and the division
is performed in real space.

- If `data2.array` is in Fourier space:
`self.array` is converted to Fourier space (if needed), and the division
is performed in Fourier space (i.e., deconvolution in real space).

Parameters
----------
data2 : STL_2D_FFT_Torch
Another instance whose array is used as the denominator. Its Fourier
status determines the computation domain.
epsilon : float, optional
Small constant added to the denominator for numerical stability
(default is 1e-8).
power : float, optional
Exponent applied to the denominator (default is 1).
inplace : bool
If True, performs the operation in-place and returns self.
If False, returns a new instance.

Returns
-------
STL_2D_FFT_Torch
Result of the division in the appropriate domain.
"""

# convert self to the Fourier status of data2
data1 = self.set_fourier_status(
target_fourier_status=data2.fourier_status, inplace=inplace
)

# perform the division in the appropriate domain
data1.array = data1.array / (data2.array + epsilon) ** pow
data1.dtype = data1.array.dtype

return data1

###########################################################################
def get_wavelet_op(self, *args, **kwargs):

Expand Down Expand Up @@ -904,15 +948,15 @@ def standardize(self, data, mean_field, inplace=False, dim=None):

mean = self.mean(l_data) # [Nb,Nc]
if mean_field:
mean = mean.mean(dim=0) # [Nc]
mean = mean.mean(dim=0, keepdim=True) # [1,Nc]

l_data.array = (
l_data.array - mean[..., None, None]
) # centering first because no remove_mean in cov

var = self.cov(l_data, l_data)
if mean_field:
var = var.mean(dim=0) # [Nc]
var = var.mean(dim=0, keepdim=True) # [1,Nc]

std = torch.sqrt(var)

Expand Down
38 changes: 36 additions & 2 deletions STL_main/STL_2D_Kernel_Torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,40 @@ def modulus(self, inplace=False):

return data

def divide(self, data2, epsilon=1e-8, pow=1.0, inplace=False):
"""
Divide self.array by data2.array raised to a power, with a small epsilon added to the denominator for numerical stability.

Division is still performed in real space

Parameters
----------
data2 : STL_2D_FFT_Torch
Another instance whose array is used as the denominator. Its Fourier
status determines the computation domain.
epsilon : float, optional
Small constant added to the denominator for numerical stability
(default is 1e-8).
power : float, optional
Exponent applied to the denominator (default is 1).
inplace : bool
If True, performs the operation in-place and returns self.
If False, returns a new instance.

Returns
-------
STL_2D_FFT_Torch
Result of the division in the appropriate domain.
"""

data1 = self.copy(empty=False) if not inplace else self

# Apply the division in real space
data1.array = data1.array / (data2.array + epsilon) ** pow
data1.dtype = data1.array.dtype

return data1

def get_wavelet_op(
self,
J=None,
Expand Down Expand Up @@ -768,15 +802,15 @@ def standardize(self, data, mean_field, inplace=False, dim=None):

mean = self.mean(l_data) # [Nb,Nc]
if mean_field:
mean = mean.mean(dim=0) # [Nc]
mean = mean.mean(dim=0, keepdim=True) # [1,Nc]

l_data.array = (
l_data.array - mean[..., None, None]
) # centering first because no remove_mean in cov

var = self.cov(l_data, l_data)
if mean_field:
var = var.mean(dim=0) # [Nc]
var = var.mean(dim=0, keepdim=True) # [1,Nc]

std = torch.sqrt(var)

Expand Down
41 changes: 25 additions & 16 deletions STL_main/ST_Operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ def from_ST_Statistics(self, st_stat):
def apply(
self,
data,
standardize=False,
SC=None,
has_fewer_convolutions=None,
norm=None,
Expand Down Expand Up @@ -385,24 +386,35 @@ def apply(
else compute_cross_matrix.to(device=data.device)
)

# Initialize ST statistics values
# Add readability w.r.t. having it in the ST statistics initilization
if standardize:
standardized = True
l_data, mean_pre_std, std_pre_std = self.wavelet_op.standardize(
data, mean_field=False, inplace=False
)
else:
l_data = data.copy()
standardized = False
mean_pre_std, std_pre_std = None, None

# Create a ST_statistics instance
data_st = ST_Statistics(
self.DT,
data.__class__,
N0,
J,
L,
WType,
SC,
Nb,
Nc,
self.wavelet_op,
SC,
has_fewer_convolutions,
compute_cross_matrix,
compute_PS,
self.n_bins,
standardized,
mean_pre_std,
std_pre_std,
)

# Initialize ST statistics values
# Add readability w.r.t. having it in the ST statistics initilization
l_data = data.copy()

# Systematic statistics (data supposed to be real)
assert (
data.array.is_complex() == False
Expand All @@ -417,7 +429,7 @@ def apply(
)

if SC == "ScatCov":
# data_st.S1 = bk.zeros((Nb, Nc, J, L)) + bk.nan
# data_st.S1 = bk.zeros((Nb, Nc, J, L)) + bk.nan
data_st.S1 = (
bk.zeros((Nb, Nc, Nc, J, L), dtype=bk._DEFAULT_COMPLEX_DTYPE) + bk.nan
)
Expand Down Expand Up @@ -496,12 +508,9 @@ def apply(
if (
compute_cross_matrix * (~bk.eye(Nc, dtype=bool, device=data.device))
).any():
data_l1_modulus_square_rooted = data_l1.copy(empty=True)
data_l1_modulus_square_rooted.array = data_l1.array * (
data_l1m[j3].array + 1e-8
) ** (
-0.5
) # (Nb,Nc,L,N3)
data_l1_modulus_square_rooted = data_l1.divide(
data_l1m[j3], epsilon=1e-8, pow=0.5, inplace=False
) # [Nb,Nc,L,N3]

self.wavelet_op._compute_and_store_cross_cov(
data_l1_modulus_square_rooted,
Expand Down
Loading
Loading