@@ -445,19 +445,25 @@ end
445
445
# broadcast
446
446
import Base. Broadcast: BroadcastStyle, ArrayStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle
447
447
448
- struct StructArrayStyle{Style } <: AbstractArrayStyle{Any } end
448
+ struct StructArrayStyle{S, N } <: AbstractArrayStyle{N } end
449
449
450
- @inline combine_style_types (:: Type{A} , args... ) where A<: AbstractArray =
450
+ # Here we define the dimension tracking behavior of StructArrayStyle
451
+ function StructArrayStyle {S, M} (:: Val{N} ) where {S, M, N}
452
+ T = S <: AbstractArrayStyle{M} ? typeof (S (Val (N))) : S
453
+ return StructArrayStyle {T, N} ()
454
+ end
455
+
456
+ @inline combine_style_types (:: Type{A} , args... ) where {A<: AbstractArray } =
451
457
combine_style_types (BroadcastStyle (A), args... )
452
- @inline combine_style_types (s:: BroadcastStyle , :: Type{A} , args... ) where A<: AbstractArray =
458
+ @inline combine_style_types (s:: BroadcastStyle , :: Type{A} , args... ) where { A<: AbstractArray } =
453
459
combine_style_types (Broadcast. result_style (s, BroadcastStyle (A)), args... )
454
460
combine_style_types (s:: BroadcastStyle ) = s
455
461
456
- Base. @pure cst (:: Type{SA} ) where SA = combine_style_types (array_types (SA). parameters... )
462
+ Base. @pure cst (:: Type{SA} ) where {SA} = combine_style_types (array_types (SA). parameters... )
457
463
458
- BroadcastStyle (:: Type{SA} ) where SA<: StructArray = StructArrayStyle {typeof(cst(SA))} ()
464
+ BroadcastStyle (:: Type{SA} ) where { SA<: StructArray } = StructArrayStyle {typeof(cst(SA)), ndims(SA )} ()
459
465
460
- Base. similar (bc:: Broadcasted{StructArrayStyle{S}} , :: Type{ElType} ) where {S<: DefaultArrayStyle ,N, ElType} =
466
+ Base. similar (bc:: Broadcasted{StructArrayStyle{S, N }} , :: Type{ElType} ) where {S<: DefaultArrayStyle , N, ElType} =
461
467
isstructtype (ElType) ? similar (StructArray{ElType}, axes (bc)) : similar (Array{ElType}, axes (bc))
462
468
463
469
# for aliasing analysis during broadcast
0 commit comments