From 8bcf749c0455e5c824c4c993fad9069b80433157 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Mon, 1 Aug 2022 21:22:00 +0800 Subject: [PATCH 1/3] Preserve more `FieldArray`s with parametric `eltype`. And return a `MArray` for mutable `FieldArray` --- src/FieldArray.jl | 29 ++++++++++++++++++----------- test/FieldMatrix.jl | 5 ++--- test/FieldVector.jl | 7 +++---- 3 files changed, 23 insertions(+), 18 deletions(-) diff --git a/src/FieldArray.jl b/src/FieldArray.jl index e8b66850..09a4078f 100644 --- a/src/FieldArray.jl +++ b/src/FieldArray.jl @@ -125,14 +125,21 @@ Base.cconvert(::Type{<:Ptr}, a::FieldArray) = Base.RefValue(a) Base.unsafe_convert(::Type{Ptr{T}}, m::Base.RefValue{FA}) where {N,T,D,FA<:FieldArray{N,T,D}} = Ptr{T}(Base.unsafe_convert(Ptr{FA}, m)) -# We can automatically preserve FieldArrays in array operations which do not -# change their eltype or Size. This should cover all non-parametric FieldArray, -# but for those which are parametric on the eltype the user will still need to -# overload similar_type themselves. -similar_type(::Type{A}, ::Type{T}, S::Size) where {N, T, A<:FieldArray{N, T}} = - _fieldarray_similar_type(A, T, S, Size(A)) - -# Extra layer of dispatch to match NewSize and OldSize -_fieldarray_similar_type(A, T, NewSize::S, OldSize::S) where {S} = A -_fieldarray_similar_type(A, T, NewSize, OldSize) = - default_similar_type(T, NewSize, length_val(NewSize)) +# We can preserve FieldArrays in array operations which do not change their `Size` and `eltype`. +# FieldArrays with parametric `eltype` would be adapted to the new `eltype` automatically. +# Otherwise, we fallback to `S/MArray` based on it's mutability. +function similar_type(::Type{A}, ::Type{T}, S::Size) where {T,A<:FieldArray} + A′ = Base.typeintersect(base_type(A), StaticArray{Tuple{Tuple(S)...},T,length(S)}) + isabstracttype(A′) || A′ === Union{} || return A′ + if ismutabletype(A) + return mutable_similar_type(T, S, length_val(S)) + else + return default_similar_type(T, S, length_val(S)) + end +end + +# return `Union{}` for Union Type. Otherwise return the constructor with no parameters. +@pure base_type(@nospecialize(T::Type)) = (T′ = Base.unwrap_unionall(T); T′ isa DataType ? T′.name.wrapper : Union{}) +if VERSION < v"1.7" + @pure ismutabletype(@nospecialize(T::Type)) = (T′ = Base.unwrap_unionall(T); T′ isa DataType && T′.mutable) +end diff --git a/test/FieldMatrix.jl b/test/FieldMatrix.jl index 4ed8c15d..fc08317e 100644 --- a/test/FieldMatrix.jl +++ b/test/FieldMatrix.jl @@ -59,7 +59,6 @@ yy::T end - StaticArrays.similar_type(::Type{<:Tensor2x2}, ::Type{T}, s::Size{(2,2)}) where {T} = Tensor2x2{T} end) p = Tensor2x2(0.0, 0.0, 0.0, 0.0) @@ -83,8 +82,8 @@ @test @inferred(similar_type(Tensor2x2{Float64})) == Tensor2x2{Float64} @test @inferred(similar_type(Tensor2x2{Float64}, Float32)) == Tensor2x2{Float32} - @test @inferred(similar_type(Tensor2x2{Float64}, Size(3,3))) == SMatrix{3,3,Float64,9} - @test @inferred(similar_type(Tensor2x2{Float64}, Float32, Size(4,4))) == SMatrix{4,4,Float32,16} + @test @inferred(similar_type(Tensor2x2{Float64}, Size(3, 3))) == MMatrix{3,3,Float64,9} + @test @inferred(similar_type(Tensor2x2{Float64}, Float32, Size(4, 4))) == MMatrix{4,4,Float32,16} # eltype promotion @test Tuple(@inferred(Tensor2x2(1., 2, 3, 4f0))) === (1.,2.,3.,4.) diff --git a/test/FieldVector.jl b/test/FieldVector.jl index 318035cb..17f5f664 100644 --- a/test/FieldVector.jl +++ b/test/FieldVector.jl @@ -63,7 +63,6 @@ y::T end - StaticArrays.similar_type(::Type{<:Point2D}, ::Type{T}, s::Size{(2,)}) where {T} = Point2D{T} end) p = Point2D(0.0, 0.0) @@ -86,8 +85,8 @@ @test @inferred(similar_type(Point2D{Float64})) == Point2D{Float64} @test @inferred(similar_type(Point2D{Float64}, Float32)) == Point2D{Float32} - @test @inferred(similar_type(Point2D{Float64}, Size(4))) == SVector{4,Float64} - @test @inferred(similar_type(Point2D{Float64}, Float32, Size(4))) == SVector{4,Float32} + @test @inferred(similar_type(Point2D{Float64}, Size(4))) == MVector{4,Float64} + @test @inferred(similar_type(Point2D{Float64}, Float32, Size(4))) == MVector{4,Float32} # eltype promotion @test Point2D(1f0, 2) isa Point2D{Float32} @@ -122,7 +121,7 @@ # No similar_type defined - test fallback codepath end) - @test @inferred(similar_type(FVT{Float64}, Float32)) == SVector{2,Float32} # Fallback code path + @test @inferred(similar_type(FVT{Float64}, Float32)) == FVT{Float32} @test @inferred(similar_type(FVT{Float64}, Size(2))) == FVT{Float64} @test @inferred(similar_type(FVT{Float64}, Size(3))) == SVector{3,Float64} @test @inferred(similar_type(FVT{Float64}, Float32, Size(3))) == SVector{3,Float32} From 5efcdaaa6bbb3aa000f9ba8af269359eaf77cd24 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Tue, 2 Aug 2022 09:40:58 +0800 Subject: [PATCH 2/3] Add more checks for `similar_type` with `eltype` changed. The `typeintersect` must return a concrete type now. And it's `fieldtypes` should match our doc example. --- src/FieldArray.jl | 10 ++++++---- test/FieldVector.jl | 24 ++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/src/FieldArray.jl b/src/FieldArray.jl index 09a4078f..6fd4389b 100644 --- a/src/FieldArray.jl +++ b/src/FieldArray.jl @@ -125,12 +125,14 @@ Base.cconvert(::Type{<:Ptr}, a::FieldArray) = Base.RefValue(a) Base.unsafe_convert(::Type{Ptr{T}}, m::Base.RefValue{FA}) where {N,T,D,FA<:FieldArray{N,T,D}} = Ptr{T}(Base.unsafe_convert(Ptr{FA}, m)) -# We can preserve FieldArrays in array operations which do not change their `Size` and `eltype`. -# FieldArrays with parametric `eltype` would be adapted to the new `eltype` automatically. -# Otherwise, we fallback to `S/MArray` based on it's mutability. function similar_type(::Type{A}, ::Type{T}, S::Size) where {T,A<:FieldArray} + # We can preserve FieldArrays in array operations which do not change their `Size` and `eltype`. + has_eltype(A) && eltype(A) === T && has_size(A) && Size(A) === S && return A + # FieldArrays with parametric `eltype` would be adapted to the new `eltype` automatically. A′ = Base.typeintersect(base_type(A), StaticArray{Tuple{Tuple(S)...},T,length(S)}) - isabstracttype(A′) || A′ === Union{} || return A′ + # But extra parameters are disallowed here. Also we check `fieldtypes` to make sure the result is valid. + isconcretetype(A′) && fieldtypes(A′) === ntuple(Returns(T), Val(prod(S))) && return A′ + # Otherwise, we fallback to `S/MArray` based on it's mutability. if ismutabletype(A) return mutable_similar_type(T, S, length_val(S)) else diff --git a/test/FieldVector.jl b/test/FieldVector.jl index 17f5f664..26a52aa5 100644 --- a/test/FieldVector.jl +++ b/test/FieldVector.jl @@ -126,4 +126,28 @@ @test @inferred(similar_type(FVT{Float64}, Size(3))) == SVector{3,Float64} @test @inferred(similar_type(FVT{Float64}, Float32, Size(3))) == SVector{3,Float32} end + + @testset "similar_type for some ill FieldVector" begin + # extra parameters + struct IllFV{T,N} <: FieldVector{3,T} + x::T + y::T + z::T + end + + @test @inferred(similar_type(IllFV{Float64}, Float64)) == IllFV{Float64} + @test @inferred(similar_type(IllFV{Float64,Int}, Float64)) == IllFV{Float64,Int} + @test @inferred(similar_type(IllFV{Float64}, Float32)) == SVector{3,Float32} + @test @inferred(similar_type(IllFV{Float64,Int}, Float32)) == SVector{3,Float32} + + # invalid `eltype` + struct IllFV2{T} <: FieldVector{3,T} + x::Int + y::Float64 + z::Int8 + end + + @test @inferred(similar_type(IllFV2{Float64}, Float64)) == IllFV2{Float64} + @test @inferred(similar_type(IllFV2{Float64}, Float32)) == SVector{3,Float32} + end end From ac2f60e450f080d876dcd7110c4e0b954dc1c203 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Tue, 2 Aug 2022 10:42:39 +0800 Subject: [PATCH 3/3] Fix inference on 1.6/1.7 --- src/FieldArray.jl | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/FieldArray.jl b/src/FieldArray.jl index 6fd4389b..d1f4757b 100644 --- a/src/FieldArray.jl +++ b/src/FieldArray.jl @@ -131,7 +131,7 @@ function similar_type(::Type{A}, ::Type{T}, S::Size) where {T,A<:FieldArray} # FieldArrays with parametric `eltype` would be adapted to the new `eltype` automatically. A′ = Base.typeintersect(base_type(A), StaticArray{Tuple{Tuple(S)...},T,length(S)}) # But extra parameters are disallowed here. Also we check `fieldtypes` to make sure the result is valid. - isconcretetype(A′) && fieldtypes(A′) === ntuple(Returns(T), Val(prod(S))) && return A′ + isconcretetype(A′) && fieldtypes(A′) === ntuple(_ -> T, Val(prod(S))) && return A′ # Otherwise, we fallback to `S/MArray` based on it's mutability. if ismutabletype(A) return mutable_similar_type(T, S, length_val(S)) @@ -141,7 +141,12 @@ function similar_type(::Type{A}, ::Type{T}, S::Size) where {T,A<:FieldArray} end # return `Union{}` for Union Type. Otherwise return the constructor with no parameters. -@pure base_type(@nospecialize(T::Type)) = (T′ = Base.unwrap_unionall(T); T′ isa DataType ? T′.name.wrapper : Union{}) -if VERSION < v"1.7" - @pure ismutabletype(@nospecialize(T::Type)) = (T′ = Base.unwrap_unionall(T); T′ isa DataType && T′.mutable) +@pure base_type(@nospecialize(T::Type)) = (T′ = Base.unwrap_unionall(T); +T′ isa DataType ? T′.name.wrapper : Union{}) +if VERSION < v"1.8" + fieldtypes(::Type{T}) where {T} = ntuple(i -> fieldtype(T, i), Val(fieldcount(T))) + @eval @pure function ismutabletype(@nospecialize(T::Type)) + T′ = Base.unwrap_unionall(T) + T′ isa DataType && $(VERSION < v"1.7" ? :(T′.mutable) : :(T′.name.flags & 0x2 == 0x2)) + end end