Skip to content

Commit 9f9f41c

Browse files
authored
Merge pull request #129 from fverdugo/non_isbitstype
Support for gather/scatter with Non isbitstype objects (and more)
2 parents f03ff94 + 1567473 commit 9f9f41c

File tree

7 files changed

+107
-18
lines changed

7 files changed

+107
-18
lines changed

CHANGELOG.md

+8-2
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,16 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
## [0.4.1] - Unreleased
9+
10+
### Added
11+
12+
- Gather/scatter for non isbitstype objects.
13+
- Function `find_local_indices`.
814

915
## [0.4.0] - 2024-01-21
1016

11-
## Changed
17+
### Changed
1218

1319
- Major refactoring of `PSparseMatrix` (and, to a lesser extent, of `PVector`).
1420
The old code is still available (but deprecated) and can be recovered by applying this renaming to your code base:
@@ -21,7 +27,7 @@ The old code is still available (but deprecated), and can be recovered applying
2127
The previous "monolithic" storage is not implemented anymore for the new version of `PSparseMatrix`, but can be implemented in the new setup if needed.
2228
- `emit` renamed to `multicast`. The former name is still available but deprecated.
2329

24-
## Added
30+
### Added
2531

2632
- Efficient re-construction of `PSparseMatrix` and `PVector` objects.
2733
- Functions `assemble` and `consistent` (allocating versions of `assemble!` and `consistent!` with a slightly different

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "PartitionedArrays"
22
uuid = "5a9dfac6-5c52-46f7-8278-5e2210713be9"
33
authors = ["Francesc Verdugo <[email protected]> and contributors"]
4-
version = "0.4.0"
4+
version = "0.4.1"
55

66
[deps]
77
CircularArrays = "7a955b69-7140-5f4e-a0ed-f168c5e2e749"

docs/src/reference/primitives.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ scatter!
1616
allocate_scatter
1717
```
1818

19-
## Emit
19+
## Multicast
2020

2121
```@docs
2222
multicast

src/PartitionedArrays.jl

+1
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ export assemble
130130
export consistent
131131
export repartition
132132
export repartition!
133+
export find_local_indices
133134
include("p_vector.jl")
134135

135136
export OldPSparseMatrix

src/mpi_array.jl

+30-14
Original file line numberDiff line numberDiff line change
@@ -281,13 +281,21 @@ function gather_impl!(
281281
comm = snd.comm
282282
if isa(destination,Integer)
283283
root = destination-1
284-
if MPI.Comm_rank(comm) == root
285-
@assert length(rcv.item) == MPI.Comm_size(comm)
286-
rcv.item[destination] = snd.item
287-
rcv_buffer = MPI.UBuffer(rcv.item,1)
288-
MPI.Gather!(MPI.IN_PLACE,rcv_buffer,root,comm)
284+
if isbitstype(T)
285+
if MPI.Comm_rank(comm) == root
286+
@assert length(rcv.item) == MPI.Comm_size(comm)
287+
rcv.item[destination] = snd.item
288+
rcv_buffer = MPI.UBuffer(rcv.item,1)
289+
MPI.Gather!(MPI.IN_PLACE,rcv_buffer,root,comm)
290+
else
291+
MPI.Gather!(snd.item_ref,nothing,root,comm)
292+
end
289293
else
290-
MPI.Gather!(snd.item_ref,nothing,root,comm)
294+
if MPI.Comm_rank(comm) == root
295+
rcv.item[:] = MPI.gather(snd.item,comm;root)
296+
else
297+
MPI.gather(snd.item,comm;root)
298+
end
291299
end
292300
else
293301
@assert destination === :all
@@ -330,17 +338,25 @@ end
330338
function scatter_impl!(
331339
rcv::MPIArray,snd::MPIArray,
332340
source,::Type{T}) where T
333-
@assert source !== :all "All to all not implemented"
334-
@assert rcv.comm === snd.comm
335-
@assert eltype(snd.item) == typeof(rcv.item)
336341
comm = snd.comm
337342
root = source - 1
338-
if MPI.Comm_rank(comm) == root
339-
snd_buffer = MPI.UBuffer(snd.item,1)
340-
rcv.item = snd.item[source]
341-
MPI.Scatter!(snd_buffer,MPI.IN_PLACE,root,comm)
343+
@assert source !== :all "All to all not implemented"
344+
@assert rcv.comm === snd.comm
345+
if isbitstype(T)
346+
@assert eltype(snd.item) == typeof(rcv.item)
347+
if MPI.Comm_rank(comm) == root
348+
snd_buffer = MPI.UBuffer(snd.item,1)
349+
rcv.item = snd.item[source]
350+
MPI.Scatter!(snd_buffer,MPI.IN_PLACE,root,comm)
351+
else
352+
MPI.Scatter!(nothing,rcv.item_ref,root,comm)
353+
end
342354
else
343-
MPI.Scatter!(nothing,rcv.item_ref,root,comm)
355+
if MPI.Comm_rank(comm) == root
356+
rcv.item_ref[] = MPI.scatter(snd.item,comm;root)
357+
else
358+
rcv.item_ref[] = MPI.scatter(nothing,comm;root)
359+
end
344360
end
345361
rcv
346362
end

src/p_vector.jl

+40
Original file line numberDiff line numberDiff line change
@@ -1076,4 +1076,44 @@ function repartition!(w::PVector,v::PVector,cache;reversed=false)
10761076
end
10771077
end
10781078

1079+
"""
    find_local_indices(node_to_mask::PVector)

Number the entries of `node_to_mask` whose mask value is `true`.

The selected own entries of every part are counted and used to build a new
partition (`variable_partition`) for the selected entries ("dofs"). Ghost
entries that are selected on their owner are added as ghost dofs.

Return the tuple `(dof_to_local_node, node_to_local_dof)`:

- `dof_to_local_node`: `PVector` over the dof partition holding, for each local
  dof, the local id of the corresponding entry of `node_to_mask`.
- `node_to_local_dof`: `PVector` over the partition of `node_to_mask` holding,
  for each local entry, its local dof id; entries that are not selected keep
  the initial value `0`.
"""
function find_local_indices(node_to_mask::PVector)
    # How many own entries are selected in each part, and in total.
    own_dof_counts = map(count,own_values(node_to_mask))
    total_dofs = sum(own_dof_counts)
    dof_partition = variable_partition(own_dof_counts,total_dofs)
    node_partition = partition(axes(node_to_mask,1))

    # Global dof id per node; 0 marks "not selected".
    node_to_global_dof = pzeros(Int,node_partition)
    map(
        own_values(node_to_global_dof),
        own_values(node_to_mask),
        dof_partition,
    ) do own_ids,own_mask,dofs
        own_ids[own_mask] = own_to_global(dofs)
    end
    # Propagate the ids chosen by the owners to the ghost copies.
    wait(consistent!(node_to_global_dof))

    # Extend each dof range with the ghost nodes that received a nonzero id.
    dof_partition = map(
        ghost_values(node_to_global_dof),
        node_partition,
        dof_partition,
    ) do ghost_ids,nodes,dofs
        selected = findall(!=(0),ghost_ids)
        selected_owners = view(ghost_to_owner(nodes),selected)
        selected_dofs = view(ghost_ids,selected)
        union_ghost(dofs,selected_dofs,selected_owners)
    end

    # NOTE(review): the result is discarded; presumably this call stores the
    # node partition's assembly graph on `dof_partition` for later reuse — confirm.
    neighbors = assembly_graph(node_partition)
    assembly_neighbors(dof_partition;neighbors)

    node_to_local_dof = pzeros(Int32,node_partition)
    dof_to_local_node = pzeros(Int32,dof_partition)
    map(
        partition(node_to_global_dof),
        partition(node_to_local_dof),
        partition(dof_to_local_node),
        dof_partition,
    ) do global_ids,node_to_dof,dof_to_node,dofs
        global_to_local_dof = global_to_local(dofs)
        for node in 1:length(global_ids)
            gid = global_ids[node]
            if gid != 0
                lid = global_to_local_dof[gid]
                node_to_dof[node] = lid
                dof_to_node[lid] = node
            end
        end
    end
    return dof_to_local_node, node_to_local_dof
end
10791119

test/primitives_tests.jl

+26
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11

22
using Test
33

4+
# Minimal wrapper around a `Vector`, used to exercise gather/scatter with a
# type for which `isbitstype` is false.
struct NonIsBitsType{T}
    data::Vector{T}
end
# Compare by content rather than by identity of the wrapped vector.
Base.:(==)(a::NonIsBitsType,b::NonIsBitsType) = a.data == b.data
# Keep `hash` consistent with `==` so equal values behave as equal keys in
# `Dict`/`Set`.
Base.hash(a::NonIsBitsType,h::UInt) = hash(a.data,hash(:NonIsBitsType,h))
8+
49
function primitives_tests(distribute)
510

611
rank = distribute(LinearIndices((2,2)))
@@ -61,6 +66,27 @@ function primitives_tests(distribute)
6166
@test rcv == [[1],[1,2],[1,2,3],[1,2,3,4]]
6267
end
6368

69+
snd2 = map(rank) do rank
70+
NonIsBitsType([2])
71+
end
72+
rcv2 = gather(snd2)
73+
snd3 = scatter(rcv2)
74+
map(snd2,snd3) do snd2,snd3
75+
@test snd2 == snd3
76+
end
77+
78+
np = length(rank)
79+
rcv3 = map_main(rank) do rank
80+
fill(NonIsBitsType([2]),np)
81+
end
82+
snd3 = allocate_scatter(rcv3)
83+
scatter!(snd3,rcv3)
84+
snd3 = scatter(rcv3)
85+
rcv4 = gather(snd3)
86+
map(rcv4,rcv2) do rcv4,rcv2
87+
@test rcv4 == rcv2
88+
end
89+
6490
rcv = multicast(rank,source=2)
6591
map(rcv) do rcv
6692
@test rcv == 2

0 commit comments

Comments
 (0)