Skip to content

Commit 4ea69d5

Browse files
pieveraplavin
andauthored
improve type stability of cat (#258)
* improve type stability * bump version Co-authored-by: Alexander Plavin <[email protected]>
1 parent 7446973 commit 4ea69d5

File tree

3 files changed

+11
-7
lines changed

3 files changed

+11
-7
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "StructArrays"
22
uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
3-
version = "0.6.13"
3+
version = "0.6.14"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/structarray.jl

+5-4
Original file line numberDiff line numberDiff line change
@@ -450,12 +450,13 @@ function Base.sizehint!(s::StructArray, i::Integer)
450450
end
451451

452452
for op in [:cat, :hcat, :vcat]
453+
curried_op = Symbol(:curried, op)
453454
@eval begin
454-
function Base.$op(arg1::StructArray, argsrest::StructArray...; kwargs...)
455-
args = (arg1, argsrest...)
456-
f = key -> $op((getproperty(t, key) for t in args)...; kwargs...)
455+
function Base.$op(arg::StructArray, others::StructArray...; kwargs...)
456+
$curried_op(A...) = $op(A...; kwargs...)
457+
args = (arg, others...)
457458
T = mapreduce(eltype, promote_type, args)
458-
StructArray{T}(map(f, propertynames(arg1)))
459+
StructArray{T}(map($curried_op, map(components, args)...))
459460
end
460461
end
461462
end

test/runtests.jl

+5-2
Original file line numberDiff line numberDiff line change
@@ -636,8 +636,11 @@ end
636636
horizontal_concat = StructArray{Pair{Int, String}}(([3 1; 5 6], ["a" "a"; "b" "b"]))
637637
@test cat(t, t2; dims=2)::StructArray == horizontal_concat == hcat(t, t2)
638638
@test hcat(t, t2) isa StructArray
639-
640-
# check that cat(dims=1) doesn't commit type piracy (#254)
639+
t3 = StructArray(x=view([1], 1:1:1), y=view([:a], 1:1:1))
640+
@test @inferred(vcat(t3)) == t3
641+
@inferred vcat(t3, t3)
642+
@inferred vcat(t3, collect(t3))
643+
# Check that `cat(dims=1)` doesn't commit type piracy (#254)
641644
# We only test that this works, the return value is immaterial
642645
@test cat(dims=1) == vcat()
643646
end

0 commit comments

Comments
 (0)