diff --git a/src/GNNGraphs/transform.jl b/src/GNNGraphs/transform.jl index fe802a752..5a1068243 100644 --- a/src/GNNGraphs/transform.jl +++ b/src/GNNGraphs/transform.jl @@ -253,8 +253,33 @@ julia> g12.ndata.x 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ``` """ -Flux.batch(gs::Vector{<:GNNGraph}) = blockdiag(gs...) - +function Flux.batch(gs::Vector{<:GNNGraph}) + nodes = [g.num_nodes for g in gs] + + if all(y -> isa(y, COO_T), [g.graph for g in gs] ) + edge_indices = [edge_index(g) for g in gs] + nodesum = cumsum([0, nodes...])[1:end-1] + s = cat_features([ei[1] .+ nodesum[ii] for (ii, ei) in enumerate(edge_indices)]) + t = cat_features([ei[2] .+ nodesum[ii] for (ii, ei) in enumerate(edge_indices)]) + w = reduce(vcat, [get_edge_weight(g) for g in gs]) + w = w isa Vector{Nothing} ? nothing : w + graph = (s, t, w) + graph_indicator = vcat([ones_like(ei[1],Int,nodes[ii]) .+ (ii - 1) for (ii,ei) in enumerate(edge_indices)]...) + elseif all(y -> isa(y, ADJMAT_T), [g.graph for g in gs] ) + graph = blockdiag([g.graph for g in gs]...) + graph_indicator = vcat([ones_like(graph,Int,nodes[ii]) .+ (ii - 1) for ii in 1:length(nodes)]...) + end + + GNNGraph(graph, + sum(nodes), + sum([g.num_edges for g in gs]), + sum([g.num_graphs for g in gs]), + graph_indicator, + cat_features([g.ndata for g in gs]), + cat_features([g.edata for g in gs]), + cat_features([g.gdata for g in gs]), + ) +end """ unbatch(g::GNNGraph) diff --git a/src/GNNGraphs/utils.jl b/src/GNNGraphs/utils.jl index 159227466..60d6e819d 100644 --- a/src/GNNGraphs/utils.jl +++ b/src/GNNGraphs/utils.jl @@ -24,6 +24,18 @@ function cat_features(x1::NamedTuple, x2::NamedTuple) NamedTuple(k => cat_features(getfield(x1,k), getfield(x2,k)) for k in keys(x1)) end +function cat_features(xs::Vector{<:NamedTuple}) + symbols = [sort(collect(keys(x))) for x in xs] + all(y->y==symbols[1], symbols) || @error "cannot concatenate feature data with different keys" + length(xs) == 1 && return xs[1] + + # concatenate + syms = symbols[1] + NamedTuple( + k => cat_features([x[k] for x in xs]) for (ii,k) in enumerate(syms) + ) +end + # Turns generic type into named tuple normalize_graphdata(data::Nothing; kws...) = NamedTuple() diff --git a/test/GNNGraphs/transform.jl b/test/GNNGraphs/transform.jl index 75c062eee..6b24ad466 100644 --- a/test/GNNGraphs/transform.jl +++ b/test/GNNGraphs/transform.jl @@ -25,6 +25,7 @@ g12 = Flux.batch([g1, g2]) g12b = blockdiag(g1, g2) + @test g12 == g12b g123 = Flux.batch([g1, g2, g3]) @test g123.graph_indicator == [fill(1, 10); fill(2, 4); fill(3, 7)]