Skip to content

Commit a4a74ab

Browse files
authored
Merge pull request #205 from PartitionedArrays/adapt
Adding some prelimiar Adapt statements
2 parents 3a6eb12 + 00e7256 commit a4a74ab

2 files changed

Lines changed: 53 additions & 3 deletions

File tree

src/adapt.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,43 @@ function Adapt.adapt_structure(to,v::PSparseMatrix)
5050
row_par = v.row_partition
5151
PSparseMatrix(matrix_partition,row_par,col_par,v.assembled)
5252
end
53+
54+
function Adapt.adapt_structure(to,v::PVector)
55+
new_local_values = map(local_values(v)) do myvals
56+
Adapt.adapt_structure(to,myvals)
57+
end
58+
new_cache = Adapt.adapt_structure(to,v.cache)
59+
new_v = PVector(new_local_values,v.index_partition, new_cache)
60+
new_v
61+
end
62+
63+
function Adapt.adapt_structure(to, cache::SplitVectorAssemblyCache)
64+
# Adapt all the components
65+
neighbors_snd = cache.neighbors_snd
66+
neighbors_rcv = cache.neighbors_rcv
67+
buffer_snd = map(cache.buffer_snd) do ja
68+
Adapt.adapt_structure(to, ja)
69+
end
70+
buffer_rcv = map(cache.buffer_rcv) do ja
71+
Adapt.adapt_structure(to, ja)
72+
end
73+
exchange_setup = cache.exchange_setup
74+
ghost_indices_snd = map(cache.ghost_indices_snd) do ja
75+
Adapt.adapt_structure(to, ja)
76+
end
77+
own_indices_rcv = map(cache.own_indices_rcv) do ja
78+
Adapt.adapt_structure(to, ja)
79+
end
80+
81+
# Create new cache with adapted components
82+
SplitVectorAssemblyCache(
83+
neighbors_snd,
84+
neighbors_rcv,
85+
ghost_indices_snd,
86+
own_indices_rcv,
87+
buffer_snd,
88+
buffer_rcv,
89+
exchange_setup,
90+
false
91+
)
92+
end

test/adapt_tests.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ function Adapt.adapt_storage(::Type{<:FakeCuVector},x::AbstractArray)
1414
end
1515

1616
function adapt_tests(distribute)
17-
18-
rank = distribute(LinearIndices((2,2)))
17+
parts_per_dir = (2,2)
18+
rank = distribute(LinearIndices(parts_per_dir))
1919

2020
a = [[1,2],[3,4,5],Int[],[3,4]]
2121
b = JaggedArray(a)
@@ -61,4 +61,14 @@ function adapt_tests(distribute)
6161
@test typeof(val_b) == FakeCuVector{typeof(val_a)}
6262
@test val_b.vector == val_a
6363
end
64-
end
64+
65+
p = prod(parts_per_dir)
66+
ranks = distribute(LinearIndices((p,)))
67+
nodes_per_dir = map(i->2*i,parts_per_dir)
68+
args = laplacian_fdm(nodes_per_dir,parts_per_dir,ranks)
69+
A = psparse(args...) |> fetch
70+
Adapt.adapt(FakeCuVector, A)
71+
b = pzeros(axes(A, 2), split_format=true)
72+
Adapt.adapt(FakeCuVector, b)
73+
74+
end

0 commit comments

Comments
 (0)