diff --git a/docs/src/utilities.md b/docs/src/utilities.md index 896ba38761..01bb34618c 100644 --- a/docs/src/utilities.md +++ b/docs/src/utilities.md @@ -28,7 +28,7 @@ To change the default on an applicable layer, pass the desired function with the ```jldoctest; setup = :(using Flux) julia> conv = Conv((3, 3), 1 => 8, relu; init=Flux.glorot_normal) -Conv((3, 3), 1=>8, relu) +Conv((3, 3), 1 => 8, relu) # 80 parameters ``` ```@docs diff --git a/src/Flux.jl b/src/Flux.jl index 0689beb278..025d69ae36 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -43,6 +43,7 @@ include("layers/conv.jl") include("layers/recurrent.jl") include("layers/normalise.jl") include("layers/upsample.jl") +include("layers/show.jl") include("outputsize.jl") diff --git a/src/functor.jl b/src/functor.jl index dd6374353a..99d6411cfa 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -1,8 +1,7 @@ import Adapt: adapt, adapt_storage using LinearAlgebra: Cholesky using Zygote: IdSet -import Functors: @functor, functor, fmap -import Functors +import Functors: Functors, @functor, functor, fmap, isleaf trainable(m) = functor(m)[1] diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 31b25ef1df..8f4b1053b7 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -89,7 +89,7 @@ The weight matrix and/or the bias vector (of length `out`) may also be provided # Examples ```jldoctest julia> d = Dense(5, 2) -Dense(5, 2) +Dense(5, 2) # 12 parameters julia> d(rand(Float32, 5, 64)) |> size (2, 64) @@ -98,7 +98,7 @@ julia> d(rand(Float32, 5, 1, 1, 64)) |> size # treated as three batch dimension (2, 1, 1, 64) julia> d1 = Dense(ones(2, 5), false, tanh) # using provided weight matrix -Dense(5, 2, tanh; bias=false) +Dense(5, 2, tanh; bias=false) # 10 parameters julia> d1(ones(5)) 2-element Vector{Float64}: @@ -395,7 +395,11 @@ julia> size(model(rand(3))) (17,) julia> model = Parallel(+, Dense(10, 2), Dense(5, 2)) -Parallel(+, Dense(10, 2), Dense(5, 2)) +Parallel( + +, + Dense(10, 2), # 22 parameters + Dense(5, 2), # 12 parameters +) # Total: 4 arrays, 34 parameters, 392 bytes. julia> size(model(rand(10), rand(5))) (2,) @@ -417,8 +421,10 @@ Parallel(connection, layers...) = Parallel(connection, layers) Base.getindex(m::Parallel, i::Integer) = m.layers[i] Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i]...) +trainable(m::Parallel) = (m.connection, m.layers...) + function Base.show(io::IO, m::Parallel) print(io, "Parallel(", m.connection, ", ") join(io, m.layers, ", ") print(io, ")") -end \ No newline at end of file +end diff --git a/src/layers/conv.jl b/src/layers/conv.jl index bef5d94b62..208ab7a265 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -67,7 +67,7 @@ See also [`ConvTranspose`](@ref), [`DepthwiseConv`](@ref), [`CrossCor`](@ref). julia> xs = rand(Float32, 100, 100, 3, 50); # a batch of images julia> lay = Conv((5,5), 3 => 7, relu; bias=false) -Conv((5, 5), 3=>7, relu) +Conv((5, 5), 3 => 7, relu, bias=false) # 525 parameters julia> lay(xs) |> size (96, 96, 7, 50) @@ -98,7 +98,7 @@ end Conv(weight::AbstractArray, [bias, activation; stride, pad, dilation]) Constructs a convolutional layer with the given weight and bias. -Accepts the same keywords (and has the same defaults) as the `Conv((4,4), 3=>7, relu)` +Accepts the same keywords (and has the same defaults) as the `Conv((4,4), 3 => 7, relu)` method. 
# Examples @@ -108,7 +108,7 @@ julia> weight = rand(3, 4, 5); julia> bias = zeros(5); julia> c1 = Conv(weight, bias, sigmoid) # expects 1 spatial dimension -Conv((3,), 4=>5, σ) +Conv((3,), 4 => 5, σ) # 65 parameters julia> c1(randn(100, 4, 64)) |> size (98, 5, 64) @@ -134,7 +134,7 @@ function Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity end """ - convfilter(filter::Tuple, in=>out) + convfilter(filter::Tuple, in => out) Constructs a standard convolutional weight matrix with given `filter` and channels from `in` to `out`. @@ -159,11 +159,18 @@ end function Base.show(io::IO, l::Conv) print(io, "Conv(", size(l.weight)[1:ndims(l.weight)-2]) - print(io, ", ", size(l.weight, ndims(l.weight)-1), "=>", size(l.weight, ndims(l.weight))) - l.σ == identity || print(io, ", ", l.σ) + print(io, ", ", size(l.weight, ndims(l.weight)-1), " => ", size(l.weight, ndims(l.weight))) + _print_conv_opt(io, l) print(io, ")") end +function _print_conv_opt(io::IO, l) + l.σ == identity || print(io, ", ", l.σ) + all(==(0), l.pad) || print(io, ", pad=", _maybetuple_string(l.pad)) + all(==(1), l.stride) || print(io, ", stride=", _maybetuple_string(l.stride)) + all(==(1), l.dilation) || print(io, ", dilation=", _maybetuple_string(l.dilation)) + l.bias == Zeros() && print(io, ", bias=false") +end """ ConvTranspose(filter, in => out, σ=identity; stride=1, pad=0, dilation=1, [bias, init]) @@ -184,15 +191,15 @@ See also [`Conv`](@ref) for more detailed description of keywords. julia> xs = rand(Float32, 100, 100, 3, 50); # a batch of 50 RGB images julia> lay = ConvTranspose((5,5), 3 => 7, relu) -ConvTranspose((5, 5), 3=>7, relu) +ConvTranspose((5, 5), 3 => 7, relu) # 532 parameters julia> lay(xs) |> size (104, 104, 7, 50) -julia> ConvTranspose((5,5), 3=>7, stride=2)(xs) |> size +julia> ConvTranspose((5,5), 3 => 7, stride=2)(xs) |> size (203, 203, 7, 50) -julia> ConvTranspose((5,5), 3=>7, stride=3, pad=SamePad())(xs) |> size +julia> ConvTranspose((5,5), 3 => 7, stride=3, pad=SamePad())(xs) |> size (300, 300, 7, 50) ``` """ @@ -209,7 +216,7 @@ end ConvTranspose(weight::AbstractArray, [bias, activation; stride, pad, dilation]) Constructs a layer with the given weight and bias arrays. -Accepts the same keywords as the `ConvTranspose((4,4), 3=>7, relu)` method. +Accepts the same keywords as the `ConvTranspose((4,4), 3 => 7, relu)` method. """ function ConvTranspose(w::AbstractArray{T,N}, bias = true, σ = identity; stride = 1, pad = 0, dilation = 1) where {T,N} @@ -255,8 +262,8 @@ end function Base.show(io::IO, l::ConvTranspose) print(io, "ConvTranspose(", size(l.weight)[1:ndims(l.weight)-2]) - print(io, ", ", size(l.weight, ndims(l.weight)), "=>", size(l.weight, ndims(l.weight)-1)) - l.σ == identity || print(io, ", ", l.σ) + print(io, ", ", size(l.weight, ndims(l.weight)), " => ", size(l.weight, ndims(l.weight)-1)) + _print_conv_opt(io, l) print(io, ")") end @@ -266,7 +273,7 @@ function calc_padding(::Type{ConvTranspose}, pad::SamePad, k::NTuple{N,T}, dilat end """ - DepthwiseConv(filter, in=>out, σ=identity; stride=1, pad=0, dilation=1, [bias, init]) + DepthwiseConv(filter, in => out, σ=identity; stride=1, pad=0, dilation=1, [bias, init]) Depthwise convolutional layer. `filter` is a tuple of integers specifying the size of the convolutional kernel, while @@ -284,7 +291,7 @@ See also [`Conv`](@ref) for more detailed description of keywords. 
julia> xs = rand(Float32, 100, 100, 3, 50); # a batch of 50 RGB images julia> lay = DepthwiseConv((5,5), 3 => 6, relu; bias=false) -DepthwiseConv((5, 5), 3=>6, relu) +DepthwiseConv((5, 5), 3 => 6, relu, bias=false) # 150 parameters julia> lay(xs) |> size (96, 96, 6, 50) @@ -306,7 +313,7 @@ end DepthwiseConv(weight::AbstractArray, bias, [activation; stride, pad, dilation]) Constructs a layer with the given weight and bias arrays. -Accepts the same keywords as the `DepthwiseConv((4,4), 3=>6, relu)` method. +Accepts the same keywords as the `DepthwiseConv((4,4), 3 => 6, relu)` method. """ function DepthwiseConv(w::AbstractArray{T,N}, bias = true, σ = identity; stride = 1, pad = 0, dilation = 1) where {T,N} @@ -327,7 +334,7 @@ end @functor DepthwiseConv """ - depthwiseconvfilter(filter::Tuple, in=>out) + depthwiseconvfilter(filter::Tuple, in => out) Constructs a depthwise convolutional weight array defined by `filter` and channels from `in` to `out`. @@ -348,8 +355,8 @@ end function Base.show(io::IO, l::DepthwiseConv) print(io, "DepthwiseConv(", size(l.weight)[1:end-2]) - print(io, ", ", size(l.weight)[end], "=>", prod(size(l.weight)[end-1:end])) - l.σ == identity || print(io, ", ", l.σ) + print(io, ", ", size(l.weight)[end], " => ", prod(size(l.weight)[end-1:end])) + _print_conv_opt(io, l) print(io, ")") end @@ -372,12 +379,12 @@ See also [`Conv`](@ref) for more detailed description of keywords. julia> xs = rand(Float32, 100, 100, 3, 50); # a batch of 50 RGB images julia> lay = CrossCor((5,5), 3 => 6, relu; bias=false) -CrossCor((5, 5), 3=>6, relu) +CrossCor((5, 5), 3 => 6, relu, bias=false) # 450 parameters julia> lay(xs) |> size (96, 96, 6, 50) -julia> CrossCor((5,5), 3=>7, stride=3, pad=(2,0))(xs) |> size +julia> CrossCor((5,5), 3 => 7, stride=3, pad=(2,0))(xs) |> size (34, 32, 7, 50) ``` """ @@ -394,7 +401,7 @@ end CrossCor(weight::AbstractArray, [bias, activation; stride, pad, dilation]) Constructs a layer with the given weight and bias arrays. -Accepts the same keywords as the `CrossCor((4,4), 3=>7, relu)` method. +Accepts the same keywords as the `CrossCor((4,4), 3 => 7, relu)` method. """ function CrossCor(w::AbstractArray{T,N}, bias = true, σ = identity; stride = 1, pad = 0, dilation = 1) where {T,N} @@ -429,8 +436,8 @@ end function Base.show(io::IO, l::CrossCor) print(io, "CrossCor(", size(l.weight)[1:ndims(l.weight)-2]) - print(io, ", ", size(l.weight, ndims(l.weight)-1), "=>", size(l.weight, ndims(l.weight))) - l.σ == identity || print(io, ", ", l.σ) + print(io, ", ", size(l.weight, ndims(l.weight)-1), " => ", size(l.weight, ndims(l.weight))) + _print_conv_opt(io, l) print(io, ")") end @@ -529,8 +536,7 @@ See also [`MaxPool`](@ref), [`GlobalMeanPool`](@ref). ```jldoctest julia> xs = rand(Float32, 100, 100, 3, 50); -julia> m = Chain(Conv((3,3), 3=>7), GlobalMaxPool()) -Chain(Conv((3, 3), 3=>7), GlobalMaxPool()) +julia> m = Chain(Conv((3,3), 3 => 7), GlobalMaxPool()); julia> m(xs) |> size (1, 1, 7, 50) @@ -567,8 +573,7 @@ by performing mean pooling on the complete (w,h)-shaped feature maps. 
```jldoctest julia> xs = rand(Float32, 100, 100, 3, 50); -julia> m = Chain(Conv((3,3), 3=>7), GlobalMeanPool()) -Chain(Conv((3, 3), 3=>7), GlobalMeanPool()) +julia> m = Chain(Conv((3,3), 3 => 7), GlobalMeanPool()); julia> m(xs) |> size (1, 1, 7, 50) @@ -611,8 +616,11 @@ See also [`Conv`](@ref), [`MeanPool`](@ref), [`AdaptiveMaxPool`](@ref), [`Global ```jldoctest julia> xs = rand(Float32, 100, 100, 3, 50); # batch of 50 RGB images -julia> m = Chain(Conv((5, 5), 3=>7, pad=SamePad()), MaxPool((5, 5), pad=SamePad())) -Chain(Conv((5, 5), 3=>7), MaxPool((5, 5), pad=2)) +julia> m = Chain(Conv((5, 5), 3 => 7, pad=SamePad()), MaxPool((5, 5), pad=SamePad())) +Chain( + Conv((5, 5), 3 => 7, pad=2), # 532 parameters + MaxPool((5, 5), pad=2), +) julia> m[1](xs) |> size (100, 100, 7, 50) @@ -674,7 +682,10 @@ See also [`Conv`](@ref), [`MaxPool`](@ref), [`AdaptiveMeanPool`](@ref). julia> xs = rand(Float32, 100, 100, 3, 50); julia> m = Chain(Conv((5,5), 3 => 7), MeanPool((5,5), pad=SamePad())) -Chain(Conv((5, 5), 3=>7), MeanPool((5, 5), pad=2)) +Chain( + Conv((5, 5), 3 => 7), # 532 parameters + MeanPool((5, 5), pad=2), +) julia> m[1](xs) |> size (96, 96, 7, 50) diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index dbd67240c3..1d0e8f73b9 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -278,7 +278,7 @@ testmode!(m::BatchNorm, mode=true) = function Base.show(io::IO, l::BatchNorm) print(io, "BatchNorm($(l.chs)") - l.λ == identity || print(io, ", $(l.λ)") + (l.λ == identity) || print(io, ", $(l.λ)") hasaffine(l) || print(io, ", affine=false") print(io, ")") end @@ -443,8 +443,9 @@ testmode!(m::GroupNorm, mode = true) = (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m) function Base.show(io::IO, l::GroupNorm) + # print(io, "GroupNorm($(join(size(l.β), ", "))", ", ", l.G) print(io, "GroupNorm($(l.chs), $(l.G)") - l.λ == identity || print(io, ", $(l.λ)") + l.λ == identity || print(io, ", ", l.λ) hasaffine(l) || print(io, ", affine=false") print(io, ")") end diff --git a/src/layers/show.jl b/src/layers/show.jl new file mode 100644 index 0000000000..40d49dd9d1 --- /dev/null +++ b/src/layers/show.jl @@ -0,0 +1,110 @@ + +for T in [ + :Chain, :Parallel, :SkipConnection, :Recur # container types + ] + @eval function Base.show(io::IO, m::MIME"text/plain", x::$T) + if get(io, :typeinfo, nothing) === nothing # e.g. top level in REPL + _big_show(io, x) + elseif !get(io, :compact, false) # e.g. printed inside a Vector, but not a Matrix + _layer_show(io, x) + else + show(io, x) + end + end +end + +function _big_show(io::IO, obj, indent::Int=0) + children = trainable(obj) + if all(_show_leaflike, children) + _layer_show(io, obj, indent) + else + println(io, " "^indent, nameof(typeof(obj)), "(") + for c in children + _big_show(io, c, indent+2) + end + if indent == 0 + print(io, ")") + _big_finale(io, obj) + else + println(io, " "^indent, "),") + end + end +end + +_show_leaflike(x) = isleaf(x) # mostly follow Functors, except for: +_show_leaflike(::Tuple{Vararg{<:Number}}) = true # e.g. stride of Conv +_show_leaflike(::Tuple{Vararg{<:AbstractArray}}) = true # e.g. 
parameters of LSTMcell +_show_leaflike(::Diagonal) = true # appears inside LayerNorm + +for T in [ + :Conv, :ConvTranspose, :CrossCor, :DepthwiseConv, :Dense, + :BatchNorm, :LayerNorm, :InstanceNorm, :GroupNorm, + ] + @eval function Base.show(io::IO, m::MIME"text/plain", x::$T) + if !get(io, :compact, false) + _layer_show(io, x) + else + show(io, x) + end + end +end + +function _layer_show(io::IO, layer, indent::Int=0) + str = sprint(show, layer, context=io) + print(io, " "^indent, str, indent==0 ? "" : ",") + if !isempty(params(layer)) + print(io, " "^max(2, (indent==0 ? 20 : 39) - indent - length(str))) + printstyled(io, "# ", underscorise(sum(length, params(layer))), " parameters"; color=:light_black) + nonparam = _childarray_sum(length, layer) - sum(length, params(layer)) + if nonparam > 0 + printstyled(io, ", plus ", underscorise(nonparam), indent==0 ? " non-trainable" : ""; color=:light_black) + end + _nan_show(io, params(layer)) + end + indent==0 || println(io) +end + +function _big_finale(io::IO, m) + ps = params(m) + if length(ps) > 2 + pars = underscorise(sum(length, ps)) + bytes = Base.format_bytes(Base.summarysize(m)) + noncnt = _childarray_sum(_->1, m) - length(ps) + if noncnt > 0 + nonparam = underscorise(_childarray_sum(length, m) - sum(length, ps)) + printstyled(io, " "^09, "# Total: ", length(ps), " trainable arrays, "; color=:light_black) + println(io, pars, " parameters,") + printstyled(io, " "^10, "# plus ", noncnt, " non-trainable, ", nonparam, " parameters, summarysize "; color=:light_black) + print(io, bytes, ".") + else + printstyled(io, " "^19, "# Total: ", length(ps), " arrays, "; color=:light_black) + print(io, pars, " parameters, ", bytes, ".") + end + end +end + +_childarray_sum(f, x::AbstractArray) = f(x) +_childarray_sum(f, x) = isleaf(x) ? 
0 : sum(y -> _childarray_sum(f, y), Functors.children(x)) + +# utility functions + +underscorise(n::Integer) = + join(reverse(join.(reverse.(Iterators.partition(digits(n), 3)))), '_') + +function _nan_show(io::IO, x) + if !isempty(x) && _all(iszero, x) + printstyled(io, " (all zero)", color=:cyan) + elseif _any(isnan, x) + printstyled(io, " (some NaN)", color=:red) + elseif _any(isinf, x) + printstyled(io, " (some Inf)", color=:red) + end +end + +_any(f, xs::AbstractArray{<:Number}) = any(f, xs) +# _any(f, xs::Union{Tuple,NamedTuple,Zygote.Params}) = any(x -> _any(f, x), xs) +_any(f, xs) = any(x -> _any(f, x), xs) +_any(f, x::Number) = f(x) +# _any(f, x) = false + +_all(f, xs) = !_any(!f, xs) diff --git a/src/utils.jl b/src/utils.jl index 06b2bb01b0..73acd45f96 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -13,14 +13,12 @@ This function is mainly used by weight initializers, e.g., [`kaiming_normal`](@r # Examples ```jldoctest -julia> layer = Dense(10, 20) -Dense(10, 20) +julia> layer = Dense(10, 20); julia> Flux.nfan(size(layer.W)) (10, 20) -julia> layer = Conv((3, 3), 2=>10) -Conv((3, 3), 2=>10) +julia> layer = Conv((3, 3), 2=>10); julia> Flux.nfan(size(layer.weight)) (18, 90) @@ -506,7 +504,7 @@ julia> Flux.chunk(1:10, 3) 9:10 julia> Flux.chunk(collect(1:10), 3) -3-element Vector{SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}}: +3-element Vector{SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}}: [1, 2, 3, 4] [5, 6, 7, 8] [9, 10] @@ -720,19 +718,25 @@ over specific modules or subsets of the parameters # Examples ```jldoctest -julia> m1 = Chain(Dense(28^2, 64), BatchNorm(64, relu)) -Chain(Dense(784, 64), BatchNorm(64, relu)) +julia> m1 = Chain(Dense(28^2, 64), BatchNorm(64, relu)); julia> m2 = Chain(m1, Dense(64, 10)) -Chain(Chain(Dense(784, 64), BatchNorm(64, relu)), Dense(64, 10)) +Chain( + Chain( + Dense(784, 64), # 50_240 parameters + BatchNorm(64, relu), # 128 parameters, plus 128 + ), + Dense(64, 10), # 650 parameters +) # Total: 6 trainable arrays, 51_018 parameters, + # plus 2 non-trainable, 128 parameters, summarysize 200.312 KiB. 
julia> Flux.modules(m2) 5-element Vector{Any}: - Chain(Chain(Dense(784, 64), BatchNorm(64, relu)), Dense(64, 10)) - Chain(Dense(784, 64), BatchNorm(64, relu)) - Dense(784, 64) - BatchNorm(64, relu) - Dense(64, 10) + Chain(Chain(Dense(784, 64), BatchNorm(64, relu)), Dense(64, 10)) # 51_018 parameters, plus 128 non-trainable + Chain(Dense(784, 64), BatchNorm(64, relu)) # 50_368 parameters, plus 128 non-trainable + Dense(784, 64) # 50_240 parameters + BatchNorm(64, relu) # 128 parameters, plus 128 non-trainable + Dense(64, 10) # 650 parameters julia> L2(m) = sum(sum(abs2, l.weight) for l in Flux.modules(m) if l isa Dense) L2 (generic function with 1 method) @@ -760,6 +764,7 @@ julia> loss() = rand(); julia> trigger = Flux.patience(() -> loss() < 1, 3); + julia> Flux.@epochs 10 begin trigger() && break end @@ -796,6 +801,7 @@ julia> loss = let l = 0 julia> es = Flux.early_stopping(loss, 3); + julia> Flux.@epochs 10 begin es() && break end @@ -836,6 +842,7 @@ julia> f = let v = 10 julia> trigger = Flux.plateau(f, 3; init_score=10, min_dist=18); + julia> Flux.@epochs 10 begin trigger() && break end diff --git a/test/layers/show.jl b/test/layers/show.jl new file mode 100644 index 0000000000..9c689eba49 --- /dev/null +++ b/test/layers/show.jl @@ -0,0 +1,70 @@ + +@testset "layer printing" begin # 2-arg show, defined with layes + + @test repr(Dense(2,3)) == "Dense(2, 3)" + @test repr(Chain(Dense(2,3))) == "Chain(Dense(2, 3))" + +end +@testset "nested model printing" begin # 3-arg show, defined in show.jl + + # Dense -- has parameter count, but not inside a matrix + + toplevel_dense = repr("text/plain", Dense(2,3)) + @test occursin("Dense(2, 3)", toplevel_dense) + @test occursin("# 9 parameters", toplevel_dense) + + @test Meta.isexpr(Meta.parse(toplevel_dense), :call) # comment is ignored + + vector_dense = repr("text/plain", [Dense(2,3), Dense(2,3)]) + @test occursin("Dense(2, 3)", vector_dense) + @test occursin("# 9 parameters", vector_dense) + + matrix_dense = repr("text/plain", fill(Dense(2,3), 3, 3)) + @test occursin("Dense(2, 3)", matrix_dense) + @test !occursin("# 9 parameters", matrix_dense) + + tuple_dense = repr("text/plain", tuple(Dense(2,3))) + @test occursin("Dense(2, 3)", tuple_dense) + @test !occursin("# 9 parameters", tuple_dense) + + # Chain -- gets split over lines at top level only + + toplevel_chain = repr("text/plain", Chain(Dense(2,3))) + @test occursin("Chain(\n Dense(2, 3)", toplevel_chain) + @test occursin("# 9 parameters", toplevel_chain) + @test !occursin("# Total:", toplevel_chain) + + vector_chain = repr("text/plain", [Chain(Dense(2,3)), Chain(Dense(2,3))]) + @test occursin("Chain(Dense(2, 3))", vector_chain) + @test occursin("# 9 parameters", vector_chain) + @test !occursin("# Total:", vector_chain) + + matrix_chain = repr("text/plain", fill(Chain(Dense(2,3)), 3,3)) + @test occursin("Chain(Dense(2, 3))", matrix_chain) + @test !occursin("# 9 parameters", matrix_chain) + @test !occursin("# Total:", matrix_chain) + + # ... 
and only long enough chains get + + longchain = Chain(Dense(2, 3), Dense(3, 4), Dense(4, 5), softmax) + + toplevel_longchain = repr("text/plain", longchain) + @test occursin("Chain(\n Dense(2, 3)", toplevel_longchain) + @test occursin("# 9 parameters", toplevel_longchain) + @test occursin("# Total: 6 arrays, 50 parameters", toplevel_longchain) + + vector_longchain = repr("text/plain", [longchain, longchain]) # pretty ugly in reality + @test occursin("Chain(Dense(2, 3)", vector_longchain) + @test occursin("# 50 parameters", vector_longchain) + @test !occursin("# 9 parameters", vector_longchain) + @test !occursin("# Total:", vector_longchain) + + matrix_longchain = repr("text/plain", fill(longchain, 3,3)) + @test occursin("Chain(Dense(2, 3)", matrix_longchain) + @test !occursin("# 9 parameters", matrix_longchain) + @test !occursin("# Total:", matrix_longchain) + + @test Meta.isexpr(Meta.parse(toplevel_longchain), :call) # comments are ignored + @test Meta.parse(toplevel_longchain).args[1] == :Chain + +end diff --git a/test/runtests.jl b/test/runtests.jl index a40433d0f1..0d02323807 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -35,6 +35,7 @@ end include("layers/recurrent.jl") include("layers/conv.jl") include("layers/upsample.jl") + include("layers/show.jl") end @testset "outputsize" begin
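For readers unfamiliar with the dispatch trick that the new `src/layers/show.jl` relies on (a compact 2-arg `show` everywhere, plus a 3-arg `show(io, MIME"text/plain", x)` that adds `# ... parameters` comments only when the object is printed on its own), here is a minimal self-contained sketch of the same pattern in plain Base Julia. `ToyDense`, `ToyChain`, and `nparams` are hypothetical names used only for illustration; they are not part of Flux or of this patch, and the sketch omits the recursion, indentation, non-trainable counts, and byte totals that `_big_show`, `_layer_show`, and `_big_finale` handle.

```julia
# Toy layer and container types, standing in for Dense and Chain.
struct ToyDense
    W::Matrix{Float64}
    b::Vector{Float64}
end
ToyDense(in::Int, out::Int) = ToyDense(randn(out, in), zeros(out))

struct ToyChain{T<:Tuple}
    layers::T
end
ToyChain(ls...) = ToyChain(ls)

# 2-arg show: short single-line form, used in compact contexts (e.g. inside a Matrix).
Base.show(io::IO, d::ToyDense) = print(io, "ToyDense(", size(d.W, 2), ", ", size(d.W, 1), ")")
Base.show(io::IO, c::ToyChain) = (print(io, "ToyChain("); join(io, c.layers, ", "); print(io, ")"))

nparams(d::ToyDense) = length(d.W) + length(d.b)
nparams(c::ToyChain) = sum(nparams, c.layers)

# 3-arg show for the leaf type: append a parameter-count comment unless :compact is set,
# mirroring how Dense, Conv, etc. dispatch to _layer_show.
function Base.show(io::IO, ::MIME"text/plain", d::ToyDense)
    show(io, d)
    get(io, :compact, false) || printstyled(io, "  # ", nparams(d), " parameters"; color=:light_black)
end

# 3-arg show for the container: split over lines only at top level (no :typeinfo in the
# IOContext), mirroring how Chain and Parallel dispatch to _big_show.
function Base.show(io::IO, ::MIME"text/plain", c::ToyChain)
    if get(io, :typeinfo, nothing) === nothing   # e.g. printed directly at the REPL
        println(io, "ToyChain(")
        for l in c.layers
            print(io, "  ")
            show(io, l)
            printstyled(io, ",  # ", nparams(l), " parameters\n"; color=:light_black)
        end
        print(io, ")")
        printstyled(io, "  # Total: ", nparams(c), " parameters"; color=:light_black)
    else
        show(io, c)                              # compact one-line form inside arrays etc.
    end
end

# At the REPL:
#   ToyChain(ToyDense(2, 3), ToyDense(3, 4))    prints one layer per line with per-layer counts,
#   [ToyChain(ToyDense(2, 3), ToyDense(3, 4))]  stays on one line inside the vector.
```

The same `:typeinfo` / `:compact` checks are what the tests in `test/layers/show.jl` above exercise: parameter comments appear for a layer shown alone or inside a `Vector`, but not inside a `Matrix` or a `Tuple`, and the `# Total:` footer appears only for sufficiently large containers shown at top level.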