diff --git a/tests/test_2D_FFT_Torch.py b/tests/test_2D_FFT_Torch.py index 7c04220..808a39f 100644 --- a/tests/test_2D_FFT_Torch.py +++ b/tests/test_2D_FFT_Torch.py @@ -62,9 +62,9 @@ def test_DataClass_downsample(): data, dg_out, inplace=False, target_fourier_status=False ) diff = np.asarray( - data_downsampled.array - data.array[..., :: 2**dg_out, :: 2**dg_out] + (data_downsampled.array - data.array[..., :: 2**dg_out, :: 2**dg_out]).cpu() ) - assert np.all(np.abs(diff) < threshold * np.abs(np.asarray(data_downsampled.array))) + assert np.all(np.abs(diff) < threshold * np.abs(np.asarray(data_downsampled.array.cpu()))) if __name__ == "__main__":