Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion ext/NFFTGPUArraysExt/NFFTGPUArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ module NFFTGPUArraysExt

using NFFT, NFFT.AbstractNFFTs
using NFFT.SparseArrays, NFFT.LinearAlgebra, NFFT.FFTW
using GPUArrays, Adapt
using GPUArrays, GPUArrays.KernelAbstractions, Adapt
using GPUArrays.KernelAbstractions.Extras: @unroll

include("implementation.jl")
include("precomputation.jl")

end
10 changes: 5 additions & 5 deletions ext/NFFTGPUArraysExt/implementation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ mutable struct GPU_NFFTPlan{T,D, arrTc <: AbstractGPUArray{Complex{T}, D}, vecI
B::SM
end

function AbstractNFFTs.plan_nfft(::NFFTBackend, arr::Type{<:AbstractGPUArray}, k::Matrix{T}, N::NTuple{D,Int}, rest...;
AbstractNFFTs.plan_nfft(b::NFFTBackend, arr::Type{<:AbstractGPUArray}, k::Matrix{T}, N::NTuple{D,Int}, rest...; kargs...) where {T,D} = plan_nfft(b, arr, arr(k), N, rest...; kargs...)
function AbstractNFFTs.plan_nfft(::NFFTBackend, arr::Type{<:AbstractGPUArray}, k::AbstractGPUArray{T}, N::NTuple{D,Int}, rest...;
timing::Union{Nothing,TimingStats} = nothing, kargs...) where {T,D}
t = @elapsed begin
p = GPU_NFFTPlan(arr, k, N, rest...; kargs...)
Expand All @@ -27,7 +28,7 @@ function AbstractNFFTs.plan_nfft(::NFFTBackend, arr::Type{<:AbstractGPUArray}, k
return p
end

function GPU_NFFTPlan(arr, k::Matrix{T}, N::NTuple{D,Int}; dims::Union{Integer,UnitRange{Int64}}=1:D,
function GPU_NFFTPlan(arr, k::AbstractGPUMatrix{T}, N::NTuple{D,Int}; dims::Union{Integer,UnitRange{Int64}}=1:D,
fftflags=nothing, kwargs...) where {T,D}

if dims != 1:D
Expand All @@ -50,10 +51,9 @@ function GPU_NFFTPlan(arr, k::Matrix{T}, N::NTuple{D,Int}; dims::Union{Integer,U

deconvIdx = Int32.(adapt(arr, (deconvolveIdx)))
winHatInvLUT = Complex{T}.(adapt(arr, (windowHatInvLUT[1])))
B_ = Complex{T}.(adapt(arr, (B))) # Bit hacky

GPU_NFFTPlan{T,D, typeof(tmpVec), typeof(deconvIdx), typeof(FP), typeof(BP), typeof(winHatInvLUT), typeof(B_)}(N, NOut, J, k, Ñ, dims_, params, FP, BP, tmpVec, tmpVecHat,
deconvIdx, windowLinInterp, winHatInvLUT, B_)
GPU_NFFTPlan{T,D, typeof(tmpVec), typeof(deconvIdx), typeof(FP), typeof(BP), typeof(winHatInvLUT), typeof(B)}(N, NOut, J, k, Ñ, dims_, params, FP, BP, tmpVec, tmpVecHat,
deconvIdx, windowLinInterp, winHatInvLUT, B)
end

AbstractNFFTs.size_in(p::GPU_NFFTPlan) = p.N
Expand Down
48 changes: 48 additions & 0 deletions ext/NFFTGPUArraysExt/precomputation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
function NFFT.precomputeB(win, k::AbstractGPUArray, N::NTuple{D,Int}, Ñ::NTuple{D,Int}, m, J, σ, K, T) where D
I = similar(k, Int64, (2*m)^D, J)
β = (2*m)^D

# CPU uses a CSC constructor, which is not generically available for GPU (I think)
#Y = similar(k, Int64, J + 1)
#Y .= (0:J) .* β .+ 1
# We have to use the COO constructor and need (2*m)^D * J values:
Y = similar(k, Int64, (2*m)^D * J)
Y .= ((0:β*J-1) .÷ β) .+ 1

V = similar(k, T, (2*m)^D, J)
nProd = ntuple(d-> (d==1) ? 1 : prod(Ñ[1:(d-1)]), D)
L = Val(2*m)

@kernel inbounds = true function precomputeB_kernel(I, V, win, k, Ñ::NTuple{D,Int}, m, σ, nProd, ::Val{Z}) where {D, Z}
idx = @index(Global, Cartesian)
j = idx[2]
linear = idx[1]

prodWin = one(eltype(k))
ζ = 1
tmpIdx = linear - 1 # 0-based for index calculation
@unroll for d = 1:D
l_d = (tmpIdx % Z) + 1 # index in 1:(2*m)
tmpIdx = div(tmpIdx, Z)

kscale = k[d, j] * Ñ[d]
off = floor(Int, kscale) - m + 1

idx_d = rem(l_d + off + Ñ[d] - 1, Ñ[d]) + 1 # periodic wrapped index in 1:Ñ[d]
ζ += (idx_d - 1) * nProd[d]

# accumulate window product
prodWin *= win( (kscale - (l_d-1) - off) / Ñ[d], Ñ[d], m, σ)
end

I[idx] = ζ
V[idx] = prodWin
end

backend = get_backend(k)
kernel = precomputeB_kernel(backend)
kernel(I, V, win, k, Ñ, m, σ, nProd, L, ndrange = size(I))

S = sparse(vec(I), Y, vec(V), prod(Ñ), J)
return S
end
4 changes: 2 additions & 2 deletions src/precomputation.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
### Init some initial parameters necessary to create the plan ###

function initParams(k::Matrix{T}, N::NTuple{D,Int}, dims::Union{Integer,UnitRange{Int64}}=1:D;
function initParams(k::AbstractMatrix{T}, N::NTuple{D,Int}, dims::Union{Integer,UnitRange{Int64}}=1:D;
kargs...) where {D,T}
# convert dims to a unit range
dims_ = (typeof(dims) <: Integer) ? (dims:dims) : dims
Expand Down Expand Up @@ -357,7 +357,7 @@ function precomputeWindowHatInvLUT(windowHatInvLUT, win_hat, N, Ñ, m, σ, T)
end
end

function precomputation(k::Union{Matrix{T},Vector{T}}, N::NTuple{D,Int}, Ñ, params) where {T,D}
function precomputation(k::Union{AbstractMatrix{T},AbstractVector{T}}, N::NTuple{D,Int}, Ñ, params) where {T,D}

m = params.m; σ = params.σ; window=params.window
LUTSize = params.LUTSize; precompute = params.precompute
Expand Down
Loading