Skip to content

Commit 10c0578

Browse files
authored
Merge pull request #23 from alan-turing-institute/fix-slow
Improve performance of scitype on arrays in the mlj convention (resolves #12)
2 parents 6ffed04 + 68ac89f commit 10c0578

File tree

5 files changed

+74
-4
lines changed

5 files changed

+74
-4
lines changed

src/ScientificTypes.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,10 @@ scitype_union(A) = reduce((a,b)->Union{a,b}, (scitype(el) for el in A))
136136
# ## SCITYPES OF TUPLES AND ARRAYS
137137

138138
# The scitype of a tuple is the `Tuple` type whose parameters are the scitypes
# of the individual entries, in order.
scitype(t::Tuple, ::Val) = Tuple{map(scitype, t)...}
139+
140+
# The following fallback can be quite slow. Individual conventions
141+
# will usually be able to find more performant overloads of this
142+
# method:
139143
# Generic array method: the element scitype is whatever `scitype_union`
# computes for `A`; the array dimension `N` is carried over unchanged.
function scitype(A::B, ::Val) where {T,N,B<:AbstractArray{T,N}}
    element_scitype = scitype_union(A)
    return AbstractArray{element_scitype,N}
end
141145

src/conventions/mlj/finite.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,18 @@ function coerce(v, ::Type{T2}; verbosity=1) where T2 <: Union{Missing,Finite}
1818
end
1919
return categorical(v, true, ordered=T2 <: Union{Missing,OrderedFactor})
2020
end
21+
22+
## PERFORMANT SCITYPES FOR ARRAYS
23+
24+
# Scitype of a categorical array under the `:mlj` convention, computed from the
# array's metadata rather than element by element:
#
# - the number of levels becomes the type parameter of the element scitype;
# - ordered arrays give `OrderedFactor`, unordered arrays give `Multiclass`;
# - `Missing` is added to the element scitype when the element type parameter
#   `T` is a `Union` that admits `Missing`.
function scitype(A::B, ::Val{:mlj}) where {T,N,B<:CategoricalArray{T,N}}
    nlevels = length(levels(A))
    factor = isordered(A) ? OrderedFactor{nlevels} : Multiclass{nlevels}
    admits_missing = T isa Union && Missing <: T
    return AbstractArray{admits_missing ? Union{factor,Missing} : factor, N}
end

src/conventions/mlj/images.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,3 @@ scitype(image::AbstractArray{<:Gray,2}, ::Val{:mlj}) =
44
GrayImage{size(image)...}
55
# A two-dimensional array of RGB colorants is a `ColorImage` parameterized by
# its two dimensions, in the order returned by `size`.
function scitype(image::AbstractArray{<:AbstractRGB,2}, ::Val{:mlj})
    dim1, dim2 = size(image)
    return ColorImage{dim1,dim2}
end
7-
8-

src/conventions/mlj/mlj.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,21 @@ _coerce_missing_warn(T) =
55
@warn "Missing values encountered coercing scitype to $T.\n"*
66
"Coerced to Union{Missing,$T} instead. "
77

8+
9+
## PERFORMANT SCITYPES FOR ARRAYS
10+
11+
# Short alias that keeps the method signatures below readable.
const A{T,N} = AbstractArray{T,N}

# Fast paths that read the scitype off the array's element type instead of
# inspecting every element.

# Element type is a subtype of `AbstractFloat` / `Integer`.
scitype(::B, ::Val{:mlj}) where {N,B<:A{<:AbstractFloat,N}} =
    A{Continuous,N}
scitype(::B, ::Val{:mlj}) where {N,B<:A{<:Integer,N}} =
    A{Count,N}

# Element type is `Union{T,Missing}`. `T` has to be a type parameter of the
# method: type parameters are invariant, so the shorthand
# `A{Union{<:AbstractFloat,Missing},N}` only matches arrays whose element type
# is exactly `Union{AbstractFloat,Missing}` and never, for example,
# `Vector{Union{Float64,Missing}}`.
scitype(::B, ::Val{:mlj}) where {N,T<:AbstractFloat,B<:A{Union{T,Missing},N}} =
    A{Union{Continuous,Missing},N}
scitype(::B, ::Val{:mlj}) where {N,T<:Integer,B<:A{Union{T,Missing},N}} =
    A{Union{Count,Missing},N}

# `Union{T,Missing}` collapses to `Missing` when `T == Union{}`, so an array
# with element type `Missing` can match both of the methods above. Resolve that
# case explicitly with the same element-wise union the generic array method
# uses.
scitype(X::B, ::Val{:mlj}) where {N,B<:A{Missing,N}} =
    A{scitype_union(X),N}
21+
22+
823
## COERCE VECTOR TO CONTINUOUS
924

1025
"""

test/runtests.jl

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,53 @@ A = Any[2 4.5;
3737
end
3838

3939
# Scitypes of arrays. Each group checks a concretely typed array and its
# `Any`-typed counterpart.
@testset "Arrays" begin
    @test scitype(A) ==
        AbstractArray{Union{Count, Continuous}, 2}

    @test scitype([1, 2, 3]) ==
        AbstractVector{Count}
    @test scitype([1, missing, 3]) ==
        AbstractVector{Union{Missing,Count}}
    @test scitype(Any[1, 2, 3]) ==
        AbstractVector{Count}
    @test scitype(Any[1, missing, 3]) ==
        AbstractVector{Union{Missing,Count}}

    @test scitype([1.0, 2.0, 3.0]) ==
        AbstractVector{Continuous}
    # Concretely typed `Vector{Union{Float64,Missing}}`; previously this line
    # duplicated the `Any[...]` case below, leaving the typed case untested.
    @test scitype([1.0, missing, 3.0]) ==
        AbstractVector{Union{Missing,Continuous}}
    @test scitype(Any[1.0, 2.0, 3.0]) ==
        AbstractVector{Continuous}
    @test scitype(Any[1.0, missing, 3.0]) ==
        AbstractVector{Union{Missing,Continuous}}

    @test scitype(categorical(1:4)) ==
        AbstractVector{Multiclass{4}}
    @test scitype(Any[categorical(1:4)...]) ==
        AbstractVector{Multiclass{4}}
    @test scitype(categorical([1, missing, 3])) ==
        AbstractVector{Union{Multiclass{2},Missing}}

    @test scitype(categorical(1:4, ordered=true)) ==
        AbstractVector{OrderedFactor{4}}
    @test scitype(Any[categorical(1:4, ordered=true)...]) ==
        AbstractVector{OrderedFactor{4}}
    @test scitype(categorical([1, missing, 3], ordered=true)) ==
        AbstractVector{Union{OrderedFactor{2},Missing}}

end
4376

4477
@testset "Images" begin
4578
black = RGB(0, 0, 0)
4679
color_image = fill(black, (10, 20))
4780
@test scitype(color_image) == ColorImage{10,20}
4881

82+
color_image2 = fill(black, (5, 3))
83+
v = [color_image, color_image2, color_image2]
84+
@test scitype(v) ==
85+
AbstractVector{Union{ColorImage{10,20},ColorImage{5,3}}}
86+
4987
white = Gray(1.0)
5088
gray_image = fill(white, (10, 20))
5189
@test scitype(gray_image) == GrayImage{10,20}

0 commit comments

Comments
 (0)