@@ -494,7 +494,8 @@ function Base.showarg(io::IO, s::StructArray{T}, toplevel) where T
494
494
end
495
495
496
496
# broadcast
497
- import Base. Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, Unknown
497
+ import Base. Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, Unknown, ArrayConflict
498
+ using Base. Broadcast: combine_styles
498
499
499
500
struct StructArrayStyle{S, N} <: AbstractArrayStyle{N} end
500
501
@@ -524,6 +525,82 @@ Base.@pure cst(::Type{SA}) where {SA} = combine_style_types(array_types(SA).para
524
525
525
526
BroadcastStyle (:: Type{SA} ) where {SA<: StructArray } = StructArrayStyle {typeof(cst(SA)), ndims(SA)} ()
526
527
528
+ """
529
+ always_struct_broadcast(style::BroadcastStyle)
530
+
531
+ Check if `style` supports struct-broadcast natively, which means:
532
+ 1) `Base.copy` is not overloaded.
533
+ 2) `Base.similar` is defined.
534
+ 3) `Base.copyto!` supports `StructArray`s as broadcasted arguments.
535
+
536
+ If any of the above conditions are not met, then this function should
537
+ not be overloaded.
538
+ In that case, try to overload [`try_struct_copy`](@ref) to support out-of-place
539
+ struct-broadcast.
540
+ """
541
+ always_struct_broadcast (:: Any ) = false
542
+ always_struct_broadcast (:: DefaultArrayStyle ) = true
543
+ always_struct_broadcast (:: ArrayConflict ) = true
544
+
545
+ """
546
+ try_struct_copy(bc::Broadcasted)
547
+
548
+ Entry for non-native outplace struct-broadcast.
549
+
550
+ See also [`always_struct_broadcast`](@ref).
551
+ """
552
+ try_struct_copy (bc:: Broadcasted ) = copy (bc)
553
+
554
+ function Base. copy (bc:: Broadcasted{StructArrayStyle{S, N}} ) where {S, N}
555
+ if always_struct_broadcast (S ())
556
+ return invoke (copy, Tuple{Broadcasted}, bc)
557
+ else
558
+ return try_struct_copy (replace_structarray (bc))
559
+ end
560
+ end
561
+
562
+ """
563
+ replace_structarray(bc::Broadcasted)
564
+
565
+ An internal function transforms the `Broadcasted` with `StructArray` into
566
+ an equivalent one without it. This is not a must if the root `BroadcastStyle`
567
+ supports `AbstractArray`. But some `BroadcastStyle` limits the input array types,
568
+ e.g. `StaticArrayStyle`, thus we have to omit all `StructArray`.
569
+ """
570
+ function replace_structarray (bc:: Broadcasted{Style} ) where {Style}
571
+ args = replace_structarray_args (bc. args)
572
+ Style′ = parent_style (Style ())
573
+ return Broadcasted {Style′} (bc. f, args, bc. axes)
574
+ end
575
+ function replace_structarray (A:: StructArray )
576
+ f = Instantiator (eltype (A))
577
+ args = Tuple (components (A))
578
+ Style = typeof (combine_styles (args... ))
579
+ return Broadcasted {Style} (f, args, axes (A))
580
+ end
581
+ replace_structarray (@nospecialize (A)) = A
582
+
583
+ replace_structarray_args (args:: Tuple ) = (replace_structarray (args[1 ]), replace_structarray_args (tail (args))... )
584
+ replace_structarray_args (:: Tuple{} ) = ()
585
+
586
+ parent_style (@nospecialize (x)) = typeof (x)
587
+ parent_style (:: StructArrayStyle{S, N} ) where {S, N} = S
588
+ parent_style (:: StructArrayStyle{S, N} ) where {N, S<: AbstractArrayStyle{N} } = S
589
+ parent_style (:: StructArrayStyle{S, N} ) where {S<: AbstractArrayStyle{Any} , N} = S
590
+ parent_style (:: StructArrayStyle{S, N} ) where {S<: AbstractArrayStyle , N} = typeof (S (Val (N)))
591
+
592
+ # `instantiate` and `_axes` might be overloaded for static axes.
593
+ function Broadcast. instantiate (bc:: Broadcasted{Style} ) where {Style <: StructArrayStyle }
594
+ Style′ = parent_style (Style ())
595
+ bc′ = Broadcast. instantiate (convert (Broadcasted{Style′}, bc))
596
+ return convert (Broadcasted{Style}, bc′)
597
+ end
598
+
599
+ function Broadcast. _axes (bc:: Broadcasted{Style} , :: Nothing ) where {Style <: StructArrayStyle }
600
+ Style′ = parent_style (Style ())
601
+ return Broadcast. _axes (convert (Broadcasted{Style′}, bc), nothing )
602
+ end
603
+
527
604
# Here we use `similar` defined for `S` to build the dest Array.
528
605
function Base. similar (bc:: Broadcasted{StructArrayStyle{S, N}} , :: Type{ElType} ) where {S, N, ElType}
529
606
bc′ = convert (Broadcasted{S}, bc)
@@ -532,12 +609,22 @@ end
532
609
533
610
# Unwrapper to recover the behaviour defined by parent style.
534
611
@inline function Base. copyto! (dest:: AbstractArray , bc:: Broadcasted{StructArrayStyle{S, N}} ) where {S, N}
535
- return copyto! (dest, convert (Broadcasted{S}, bc))
612
+ bc′ = always_struct_broadcast (S ()) ? convert (Broadcasted{S}, bc) : replace_structarray (bc)
613
+ return copyto! (dest, bc′)
536
614
end
537
615
538
616
@inline function Broadcast. materialize! (:: StructArrayStyle{S} , dest, bc:: Broadcasted ) where {S}
539
- return Broadcast. materialize! (S (), dest, bc)
617
+ bc′ = always_struct_broadcast (S ()) ? bc : replace_structarray (bc)
618
+ return Broadcast. materialize! (S (), dest, bc′)
540
619
end
541
620
542
621
# for aliasing analysis during broadcast
622
+ function Broadcast. broadcast_unalias (dest:: StructArray , src:: AbstractArray )
623
+ if dest === src || any (Base. Fix2 (=== , src), components (dest))
624
+ return src
625
+ else
626
+ return Base. unalias (dest, src)
627
+ end
628
+ end
629
+
543
630
Base. dataids (u:: StructArray ) = mapreduce (Base. dataids, (a, b) -> (a... , b... ), values (components (u)), init= ())
0 commit comments