From e5eeb57740b0f8bf53c1de3701a0eeeb36b5bf91 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Fri, 15 Jan 2021 19:43:27 +0100 Subject: [PATCH 01/25] fancy show for Chain & layers --- src/layers/basic.jl | 47 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 31b25ef1df..1689efe551 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -421,4 +421,49 @@ function Base.show(io::IO, m::Parallel) print(io, "Parallel(", m.connection, ", ") join(io, m.layers, ", ") print(io, ")") -end \ No newline at end of file +end + +Base.show(io::IO, m::MIME"text/plain", c::Chain) = _big_show(io, c) + +function _big_show(io::IO, c::Union{Chain, Parallel, SkipConnection}, indent=0) + print(io, " "^indent, nameof(typeof(c)), "(") + c isa Chain ? println(io) : println(io, c.connection, ",") + if c.layers isa Tuple + for x in c.layers + _big_show(io, x, indent+2) + end + else + _big_show(io, c.layers, indent+2) + end + print(io, " "^indent, ")") + indent == 0 ? 
_big_finale(io, params(c)) : println(io, ",") +end + +function _big_show(io::IO, layer, indent=0) + str = sprint(show, layer, context=nothing) + print(io, " "^indent, str, ",") + if !isempty(params(layer)) + print(" "^(31 - indent - length(str))) + pars = underscorise(sum(length, params(layer))) + printstyled(io, "# ", pars, " parameters", color=:light_black) + if !all(x -> all(isfinite, x), params(layer)) + printstyled(io, " (some NaN or Inf)", color=:red) + elseif all(x -> all(iszero, x), params(layer)) + printstyled(io, " (all zero)", color=:light_black) + end + end + println(io) +end + +function _big_finale(io::IO, ps) + num = length(ps) + num < 3 && return println(io) + pars = underscorise(sum(length, ps)) + bytes = sum(sizeof, ps) + print(io, " "^15) + printstyled(io, "# Total: ", num, " arrays, "; color=:light_black) + printstyled(io, pars, " parameters, ", Base.format_bytes(bytes); color=:light_black) +end + +underscorise(n::Integer) = + join(reverse(join.(reverse.(Iterators.partition(digits(n), 3)))), '_') From 2a9a01135dfc78055c3f9db4b5e0d3525421b9a3 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Fri, 15 Jan 2021 19:43:53 +0100 Subject: [PATCH 02/25] pirate show for Params & Grads --- src/layers/basic.jl | 55 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 1689efe551..184bc0e266 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -467,3 +467,58 @@ end underscorise(n::Integer) = join(reverse(join.(reverse.(Iterators.partition(digits(n), 3)))), '_') + +Base.show(io::IO, m::MIME"text/plain", p::Zygote.Params) = _param_show(io, p, "Params", true) + +function _param_show(io::IO, p, name::String, iter::Bool) + println(io, name, "(") + spad = maximum(length∘summary, p) + ipad = length(string(length(p))) + 2 + for (i,x) in enumerate(p) + if iter + printstyled(io, " ", lpad(string("[",i,"]"), ipad), color=:light_black) + end + str = sprint(show, x) + str = 
length(str) < 32 ? str : str[1:32] * "…" + print(io, " ", rpad(summary(x), spad), " ", str) + if any(isnan, x) + printstyled(io, " (some NaN)", color=:red) + elseif any(isinf, x) + printstyled(io, " (some Inf)", color=:red) + elseif !isempty(x) && all(iszero, x) + printstyled(io, " (all zero)", color=:light_black) + end + println(io) + end + print(io, ")") + pars = underscorise(sum(length, p)) + bytes = sum(sizeof, p) + printstyled(io, " "^15, "# Total: ", pars, " parameters, ", Base.format_bytes(bytes); color=:light_black) +end + +function Base.show(io::IO, m::MIME"text/plain", g::Zygote.Grads) + println(io, "Zygote.Grads(") + pars, bytes, spad = 0, 0, 0 + for k in keys(g.grads) + x = + pars += length(g[k]) + bytes += sizeof(g[k]) + spad = max(spad, length(summary(g[k]))) + end + for k in keys(g.grads) + x = g[k] + str = sprint(show, x) + str = length(str) < 32 ? str : str[1:32] * "…" + print(io, " ", rpad(summary(x), spad), " ", str) + if any(isnan, x) + printstyled(io, " (some NaN)", color=:red) + elseif any(isinf, x) + printstyled(io, " (some Inf)", color=:red) + elseif !isempty(x) && all(iszero, x) + printstyled(io, " (all zero)", color=:light_black) + end + println(io) + end + print(io, ")") + printstyled(io, " "^15, "# Total: ", pars, " parameters, ", Base.format_bytes(bytes); color=:light_black) +end From 6d051ba98259e4d4556d29e8751da672ade979b1 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Fri, 15 Jan 2021 22:04:47 +0100 Subject: [PATCH 03/25] tweaks & bug fixes --- src/layers/basic.jl | 39 ++++++++++++++++++++++++--------------- 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 184bc0e266..1c87393e44 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -423,17 +423,19 @@ function Base.show(io::IO, m::Parallel) print(io, ")") end + Base.show(io::IO, m::MIME"text/plain", c::Chain) = _big_show(io, c) function _big_show(io::IO, c::Union{Chain, Parallel, SkipConnection}, indent=0) 
print(io, " "^indent, nameof(typeof(c)), "(") - c isa Chain ? println(io) : println(io, c.connection, ",") - if c.layers isa Tuple + c isa Parallel ? println(io, c.connection, ",") : println(io) # Parallel's connection is 1st arg + if c isa SkipConnection + _big_show(io, c.layers, indent+2) + _big_show(io, c.connection, indent+2) # SkipConnection's connection is 2nd arg + else for x in c.layers _big_show(io, x, indent+2) end - else - _big_show(io, c.layers, indent+2) end print(io, " "^indent, ")") indent == 0 ? _big_finale(io, params(c)) : println(io, ",") @@ -443,7 +445,7 @@ function _big_show(io::IO, layer, indent=0) str = sprint(show, layer, context=nothing) print(io, " "^indent, str, ",") if !isempty(params(layer)) - print(" "^(31 - indent - length(str))) + print(" "^max(2, 39 - indent - length(str))) pars = underscorise(sum(length, params(layer))) printstyled(io, "# ", pars, " parameters", color=:light_black) if !all(x -> all(isfinite, x), params(layer)) @@ -460,7 +462,7 @@ function _big_finale(io::IO, ps) num < 3 && return println(io) pars = underscorise(sum(length, ps)) bytes = sum(sizeof, ps) - print(io, " "^15) + print(io, " "^19) printstyled(io, "# Total: ", num, " arrays, "; color=:light_black) printstyled(io, pars, " parameters, ", Base.format_bytes(bytes); color=:light_black) end @@ -471,29 +473,32 @@ underscorise(n::Integer) = Base.show(io::IO, m::MIME"text/plain", p::Zygote.Params) = _param_show(io, p, "Params", true) function _param_show(io::IO, p, name::String, iter::Bool) + length(p) == 0 && return print(io, name, "([])") println(io, name, "(") - spad = maximum(length∘summary, p) ipad = length(string(length(p))) + 2 - for (i,x) in enumerate(p) + spad = min(40-6-ipad, maximum(length∘summary, p)) + wid = get(io, :displaysize, displaysize())[2] + for (i,x) in enumerate(p) if iter printstyled(io, " ", lpad(string("[",i,"]"), ipad), color=:light_black) end - str = sprint(show, x) - str = length(str) < 32 ? 
str : str[1:32] * "…" - print(io, " ", rpad(summary(x), spad), " ", str) + desc = Base._truncate_at_width_or_chars(summary(x), spad) + data = sprint(show, x, context=IOContext(io, :compact => true, :limit => true, :typeinfo => eltype(x)), sizehint=0) + str = Base._truncate_at_width_or_chars(data, min(30, wid-40-14)) + print(io, " ", rpad(desc, spad), " ", str) if any(isnan, x) printstyled(io, " (some NaN)", color=:red) elseif any(isinf, x) printstyled(io, " (some Inf)", color=:red) elseif !isempty(x) && all(iszero, x) - printstyled(io, " (all zero)", color=:light_black) + printstyled(io, " (all zero)", color=:cyan) end println(io) end print(io, ")") pars = underscorise(sum(length, p)) bytes = sum(sizeof, p) - printstyled(io, " "^15, "# Total: ", pars, " parameters, ", Base.format_bytes(bytes); color=:light_black) + printstyled(io, " "^19, "# Total: ", pars, " parameters, ", Base.format_bytes(bytes); color=:light_black) end function Base.show(io::IO, m::MIME"text/plain", g::Zygote.Grads) @@ -515,10 +520,14 @@ function Base.show(io::IO, m::MIME"text/plain", g::Zygote.Grads) elseif any(isinf, x) printstyled(io, " (some Inf)", color=:red) elseif !isempty(x) && all(iszero, x) - printstyled(io, " (all zero)", color=:light_black) + printstyled(io, " (all zero)", color=:cyan) end println(io) end print(io, ")") - printstyled(io, " "^15, "# Total: ", pars, " parameters, ", Base.format_bytes(bytes); color=:light_black) + printstyled(io, " "^19, "# Total: ", pars, " parameters, ", Base.format_bytes(bytes); color=:light_black) end + + + + From 60a8f6b1904a85c133508823cf09295248def33d Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Fri, 15 Jan 2021 22:44:39 +0100 Subject: [PATCH 04/25] more recursive big_show --- src/functor.jl | 3 +-- src/layers/basic.jl | 26 +++++++++++--------------- 2 files changed, 12 insertions(+), 17 deletions(-) diff --git a/src/functor.jl b/src/functor.jl index dd6374353a..30305e2173 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -1,8 +1,7 @@ 
import Adapt: adapt, adapt_storage using LinearAlgebra: Cholesky using Zygote: IdSet -import Functors: @functor, functor, fmap -import Functors +import Functors: Functors, @functor, functor, fmap trainable(m) = functor(m)[1] diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 1c87393e44..d80201b279 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -426,22 +426,20 @@ end Base.show(io::IO, m::MIME"text/plain", c::Chain) = _big_show(io, c) -function _big_show(io::IO, c::Union{Chain, Parallel, SkipConnection}, indent=0) - print(io, " "^indent, nameof(typeof(c)), "(") - c isa Parallel ? println(io, c.connection, ",") : println(io) # Parallel's connection is 1st arg - if c isa SkipConnection - _big_show(io, c.layers, indent+2) - _big_show(io, c.connection, indent+2) # SkipConnection's connection is 2nd arg - else - for x in c.layers - _big_show(io, x, indent+2) - end +function _big_show(io::IO, obj, indent=0) + children = Flux.trainable(obj) + if all(c -> isleaf(c) || isa(c,Tuple), children) # need isa(c,Tuple) to get Conv right + return _layer_show(io, obj, indent) + end + println(io, " "^indent, nameof(typeof(obj)), "(") + for c in children + _big_show(io, c, indent+2) end print(io, " "^indent, ")") - indent == 0 ? _big_finale(io, params(c)) : println(io, ",") + indent == 0 ? 
_big_finale(io, params(obj)) : println(io, ",") end -function _big_show(io::IO, layer, indent=0) +function _layer_show(io::IO, layer, indent=0) str = sprint(show, layer, context=nothing) print(io, " "^indent, str, ",") if !isempty(params(layer)) @@ -451,7 +449,7 @@ function _big_show(io::IO, layer, indent=0) if !all(x -> all(isfinite, x), params(layer)) printstyled(io, " (some NaN or Inf)", color=:red) elseif all(x -> all(iszero, x), params(layer)) - printstyled(io, " (all zero)", color=:light_black) + printstyled(io, " (all zero)", color=:cyan) end end println(io) @@ -529,5 +527,3 @@ function Base.show(io::IO, m::MIME"text/plain", g::Zygote.Grads) end - - From f84e0e5ac59af757f324090fe9ff64cd5c3a9af6 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Fri, 15 Jan 2021 23:50:56 +0100 Subject: [PATCH 05/25] close brackets without wasting lines --- src/layers/basic.jl | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index d80201b279..a7b48abede 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -423,27 +423,36 @@ function Base.show(io::IO, m::Parallel) print(io, ")") end +for T in [ + :Chain, :Parallel, :SkipConnection, + :Conv, :ConvTranspose, :CrossCor, :Dense, + :BatchNorm, :LayerNorm, + ] + @eval Base.show(io::IO, m::MIME"text/plain", x::$T) = _big_show(io, x) +end -Base.show(io::IO, m::MIME"text/plain", c::Chain) = _big_show(io, c) - -function _big_show(io::IO, obj, indent=0) +function _big_show(io::IO, obj, indent=0, toclose=0) children = Flux.trainable(obj) if all(c -> isleaf(c) || isa(c,Tuple), children) # need isa(c,Tuple) to get Conv right - return _layer_show(io, obj, indent) + return _layer_show(io, obj, indent, toclose) end println(io, " "^indent, nameof(typeof(obj)), "(") - for c in children - _big_show(io, c, indent+2) + for (i,c) in enumerate(children) + close = i==length(children) && indent>0 + _big_show(io, c, indent+2, close ? 
toclose+1 : 0) + end + if indent == 0 + print(io, ")") + _big_finale(io, params(obj)) end - print(io, " "^indent, ")") - indent == 0 ? _big_finale(io, params(obj)) : println(io, ",") end -function _layer_show(io::IO, layer, indent=0) - str = sprint(show, layer, context=nothing) - print(io, " "^indent, str, ",") +function _layer_show(io::IO, layer, indent=0, toclose=0) + str = sprint(show, layer, context=nothing) * ",)"^toclose + print(io, " "^indent, str, indent==0 ? "" : ",") + tab = indent==0 ? 20 : 39 # when inside Chain, move all parameter counts out to 40 if !isempty(params(layer)) - print(" "^max(2, 39 - indent - length(str))) + print(" "^max(2, tab - indent - length(str))) pars = underscorise(sum(length, params(layer))) printstyled(io, "# ", pars, " parameters", color=:light_black) if !all(x -> all(isfinite, x), params(layer)) From 12bad6d72c17333187b48c6a92f9b82862bf21da Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 16 Jan 2021 11:57:15 +0100 Subject: [PATCH 06/25] tweaks --- src/layers/basic.jl | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index a7b48abede..a5f546636c 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -425,8 +425,8 @@ end for T in [ :Chain, :Parallel, :SkipConnection, - :Conv, :ConvTranspose, :CrossCor, :Dense, - :BatchNorm, :LayerNorm, + :Conv, :ConvTranspose, :CrossCor, :DepthwiseConv, :Dense, + :BatchNorm, :LayerNorm, :InstanceNorm, :GroupNorm, ] @eval Base.show(io::IO, m::MIME"text/plain", x::$T) = _big_show(io, x) end @@ -447,6 +447,9 @@ function _big_show(io::IO, obj, indent=0, toclose=0) end end +# Opt out of being printed as a container: +_big_show(io::IO, l::LayerNorm, i=0, t=0) = _layer_show(io, l, i, t) + function _layer_show(io::IO, layer, indent=0, toclose=0) str = sprint(show, layer, context=nothing) * ",)"^toclose print(io, " "^indent, str, indent==0 ? 
"" : ",") @@ -461,7 +464,7 @@ function _layer_show(io::IO, layer, indent=0, toclose=0) printstyled(io, " (all zero)", color=:cyan) end end - println(io) + indent==0 || println(io) end function _big_finale(io::IO, ps) @@ -477,21 +480,21 @@ end underscorise(n::Integer) = join(reverse(join.(reverse.(Iterators.partition(digits(n), 3)))), '_') -Base.show(io::IO, m::MIME"text/plain", p::Zygote.Params) = _param_show(io, p, "Params", true) +Base.show(io::IO, m::MIME"text/plain", p::Zygote.Params) = _param_show(io, p, true) -function _param_show(io::IO, p, name::String, iter::Bool) - length(p) == 0 && return print(io, name, "([])") - println(io, name, "(") +function _param_show(io::IO, p, iter::Bool) + length(p) == 0 && return print(io, typeof(p), "([])") + println(io, typeof(p), "([") ipad = length(string(length(p))) + 2 spad = min(40-6-ipad, maximum(length∘summary, p)) - wid = get(io, :displaysize, displaysize())[2] + wid = get(io, :displaysize, (0,100))[2] for (i,x) in enumerate(p) if iter printstyled(io, " ", lpad(string("[",i,"]"), ipad), color=:light_black) end desc = Base._truncate_at_width_or_chars(summary(x), spad) data = sprint(show, x, context=IOContext(io, :compact => true, :limit => true, :typeinfo => eltype(x)), sizehint=0) - str = Base._truncate_at_width_or_chars(data, min(30, wid-40-14)) + str = Base._truncate_at_width_or_chars(data, min(30, wid-40-12)) print(io, " ", rpad(desc, spad), " ", str) if any(isnan, x) printstyled(io, " (some NaN)", color=:red) @@ -502,10 +505,10 @@ function _param_show(io::IO, p, name::String, iter::Bool) end println(io) end - print(io, ")") + print(io, "])") pars = underscorise(sum(length, p)) - bytes = sum(sizeof, p) - printstyled(io, " "^19, "# Total: ", pars, " parameters, ", Base.format_bytes(bytes); color=:light_black) + bytes = Base.format_bytes(sum(sizeof, p)) + printstyled(io, " "^18, "# Total: ", pars, " parameters, ", bytes; color=:light_black) end function Base.show(io::IO, m::MIME"text/plain", g::Zygote.Grads) From 
d3e2ee89626cb7a3e8f3a6c7264027c6783d5961 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 16 Jan 2021 11:57:59 +0100 Subject: [PATCH 07/25] don't print BatchNorm as if it takes a keyword, ditto GroupNorm --- src/layers/normalise.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index dbd67240c3..8e7e4b0e71 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -278,8 +278,8 @@ testmode!(m::BatchNorm, mode=true) = function Base.show(io::IO, l::BatchNorm) print(io, "BatchNorm($(l.chs)") - l.λ == identity || print(io, ", $(l.λ)") - hasaffine(l) || print(io, ", affine=false") + (l.λ == identity) || print(io, ", $(l.λ)") + hasaffine(l) || print(io, ", affine=false") # ?? print(io, ")") end @@ -443,8 +443,9 @@ testmode!(m::GroupNorm, mode = true) = (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m) function Base.show(io::IO, l::GroupNorm) + # print(io, "GroupNorm($(join(size(l.β), ", "))", ", ", l.G) print(io, "GroupNorm($(l.chs), $(l.G)") - l.λ == identity || print(io, ", $(l.λ)") + l.λ == identity || print(io, ", ", l.λ) hasaffine(l) || print(io, ", affine=false") print(io, ")") end From d76da82c73bd4a8791cf48634faf4082e4396b24 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 16 Jan 2021 11:59:06 +0100 Subject: [PATCH 08/25] do print Conv keywords, including bias --- src/layers/conv.jl | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index bef5d94b62..2ca3c7990c 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -160,10 +160,17 @@ end function Base.show(io::IO, l::Conv) print(io, "Conv(", size(l.weight)[1:ndims(l.weight)-2]) print(io, ", ", size(l.weight, ndims(l.weight)-1), "=>", size(l.weight, ndims(l.weight))) - l.σ == identity || print(io, ", ", l.σ) + _print_conv_opt(io, l) print(io, ")") end +function _print_conv_opt(io::IO, l) + l.σ == identity || print(io, 
", ", l.σ) + all(==(0), l.pad) || print(io, ", pad=", _maybetuple_string(l.pad)) + all(==(1), l.stride) || print(io, ", stride=", _maybetuple_string(l.stride)) + all(==(1), l.dilation) || print(io, ", dilation=", _maybetuple_string(l.dilation)) + l.bias == Zeros() && print(io, ", bias=false") +end """ ConvTranspose(filter, in => out, σ=identity; stride=1, pad=0, dilation=1, [bias, init]) @@ -256,7 +263,7 @@ end function Base.show(io::IO, l::ConvTranspose) print(io, "ConvTranspose(", size(l.weight)[1:ndims(l.weight)-2]) print(io, ", ", size(l.weight, ndims(l.weight)), "=>", size(l.weight, ndims(l.weight)-1)) - l.σ == identity || print(io, ", ", l.σ) + _print_conv_opt(io, l) print(io, ")") end @@ -349,7 +356,7 @@ end function Base.show(io::IO, l::DepthwiseConv) print(io, "DepthwiseConv(", size(l.weight)[1:end-2]) print(io, ", ", size(l.weight)[end], "=>", prod(size(l.weight)[end-1:end])) - l.σ == identity || print(io, ", ", l.σ) + _print_conv_opt(io, l) print(io, ")") end @@ -430,7 +437,7 @@ end function Base.show(io::IO, l::CrossCor) print(io, "CrossCor(", size(l.weight)[1:ndims(l.weight)-2]) print(io, ", ", size(l.weight, ndims(l.weight)-1), "=>", size(l.weight, ndims(l.weight))) - l.σ == identity || print(io, ", ", l.σ) + _print_conv_opt(io, l) print(io, ")") end From 3316366cd1e325dccd7c10436f0781e3fa604cfb Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 16 Jan 2021 12:10:10 +0100 Subject: [PATCH 09/25] move to file show.jl --- src/Flux.jl | 1 + src/layers/basic.jl | 117 ------------------------------------------- src/layers/show.jl | 118 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 119 insertions(+), 117 deletions(-) create mode 100644 src/layers/show.jl diff --git a/src/Flux.jl b/src/Flux.jl index 0689beb278..025d69ae36 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -43,6 +43,7 @@ include("layers/conv.jl") include("layers/recurrent.jl") include("layers/normalise.jl") include("layers/upsample.jl") +include("layers/show.jl") 
include("outputsize.jl") diff --git a/src/layers/basic.jl b/src/layers/basic.jl index a5f546636c..629d218924 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -422,120 +422,3 @@ function Base.show(io::IO, m::Parallel) join(io, m.layers, ", ") print(io, ")") end - -for T in [ - :Chain, :Parallel, :SkipConnection, - :Conv, :ConvTranspose, :CrossCor, :DepthwiseConv, :Dense, - :BatchNorm, :LayerNorm, :InstanceNorm, :GroupNorm, - ] - @eval Base.show(io::IO, m::MIME"text/plain", x::$T) = _big_show(io, x) -end - -function _big_show(io::IO, obj, indent=0, toclose=0) - children = Flux.trainable(obj) - if all(c -> isleaf(c) || isa(c,Tuple), children) # need isa(c,Tuple) to get Conv right - return _layer_show(io, obj, indent, toclose) - end - println(io, " "^indent, nameof(typeof(obj)), "(") - for (i,c) in enumerate(children) - close = i==length(children) && indent>0 - _big_show(io, c, indent+2, close ? toclose+1 : 0) - end - if indent == 0 - print(io, ")") - _big_finale(io, params(obj)) - end -end - -# Opt out of being printed as a container: -_big_show(io::IO, l::LayerNorm, i=0, t=0) = _layer_show(io, l, i, t) - -function _layer_show(io::IO, layer, indent=0, toclose=0) - str = sprint(show, layer, context=nothing) * ",)"^toclose - print(io, " "^indent, str, indent==0 ? "" : ",") - tab = indent==0 ? 
20 : 39 # when inside Chain, move all parameter counts out to 40 - if !isempty(params(layer)) - print(" "^max(2, tab - indent - length(str))) - pars = underscorise(sum(length, params(layer))) - printstyled(io, "# ", pars, " parameters", color=:light_black) - if !all(x -> all(isfinite, x), params(layer)) - printstyled(io, " (some NaN or Inf)", color=:red) - elseif all(x -> all(iszero, x), params(layer)) - printstyled(io, " (all zero)", color=:cyan) - end - end - indent==0 || println(io) -end - -function _big_finale(io::IO, ps) - num = length(ps) - num < 3 && return println(io) - pars = underscorise(sum(length, ps)) - bytes = sum(sizeof, ps) - print(io, " "^19) - printstyled(io, "# Total: ", num, " arrays, "; color=:light_black) - printstyled(io, pars, " parameters, ", Base.format_bytes(bytes); color=:light_black) -end - -underscorise(n::Integer) = - join(reverse(join.(reverse.(Iterators.partition(digits(n), 3)))), '_') - -Base.show(io::IO, m::MIME"text/plain", p::Zygote.Params) = _param_show(io, p, true) - -function _param_show(io::IO, p, iter::Bool) - length(p) == 0 && return print(io, typeof(p), "([])") - println(io, typeof(p), "([") - ipad = length(string(length(p))) + 2 - spad = min(40-6-ipad, maximum(length∘summary, p)) - wid = get(io, :displaysize, (0,100))[2] - for (i,x) in enumerate(p) - if iter - printstyled(io, " ", lpad(string("[",i,"]"), ipad), color=:light_black) - end - desc = Base._truncate_at_width_or_chars(summary(x), spad) - data = sprint(show, x, context=IOContext(io, :compact => true, :limit => true, :typeinfo => eltype(x)), sizehint=0) - str = Base._truncate_at_width_or_chars(data, min(30, wid-40-12)) - print(io, " ", rpad(desc, spad), " ", str) - if any(isnan, x) - printstyled(io, " (some NaN)", color=:red) - elseif any(isinf, x) - printstyled(io, " (some Inf)", color=:red) - elseif !isempty(x) && all(iszero, x) - printstyled(io, " (all zero)", color=:cyan) - end - println(io) - end - print(io, "])") - pars = underscorise(sum(length, p)) - 
bytes = Base.format_bytes(sum(sizeof, p)) - printstyled(io, " "^18, "# Total: ", pars, " parameters, ", bytes; color=:light_black) -end - -function Base.show(io::IO, m::MIME"text/plain", g::Zygote.Grads) - println(io, "Zygote.Grads(") - pars, bytes, spad = 0, 0, 0 - for k in keys(g.grads) - x = - pars += length(g[k]) - bytes += sizeof(g[k]) - spad = max(spad, length(summary(g[k]))) - end - for k in keys(g.grads) - x = g[k] - str = sprint(show, x) - str = length(str) < 32 ? str : str[1:32] * "…" - print(io, " ", rpad(summary(x), spad), " ", str) - if any(isnan, x) - printstyled(io, " (some NaN)", color=:red) - elseif any(isinf, x) - printstyled(io, " (some Inf)", color=:red) - elseif !isempty(x) && all(iszero, x) - printstyled(io, " (all zero)", color=:cyan) - end - println(io) - end - print(io, ")") - printstyled(io, " "^19, "# Total: ", pars, " parameters, ", Base.format_bytes(bytes); color=:light_black) -end - - diff --git a/src/layers/show.jl b/src/layers/show.jl new file mode 100644 index 0000000000..1eba82f353 --- /dev/null +++ b/src/layers/show.jl @@ -0,0 +1,118 @@ + +for T in [ + :Chain, :Parallel, :SkipConnection, + :Conv, :ConvTranspose, :CrossCor, :DepthwiseConv, :Dense, + :BatchNorm, :LayerNorm, :InstanceNorm, :GroupNorm, + ] + @eval Base.show(io::IO, m::MIME"text/plain", x::$T) = _big_show(io, x) +end + +function _big_show(io::IO, obj, indent=0, toclose=0) + children = Flux.trainable(obj) + if all(c -> isleaf(c) || isa(c,Tuple), children) # need isa(c,Tuple) to get Conv right + return _layer_show(io, obj, indent, toclose) + end + println(io, " "^indent, nameof(typeof(obj)), "(") + for (i,c) in enumerate(children) + close = i==length(children) && indent>0 + _big_show(io, c, indent+2, close ? 
toclose+1 : 0) + end + if indent == 0 + print(io, ")") + _big_finale(io, params(obj)) + end +end + +# Opt out of being printed as a container: +_big_show(io::IO, l::LayerNorm, i=0, t=0) = _layer_show(io, l, i, t) + +# used both within Chain printing, and alone at top level. +function _layer_show(io::IO, layer, indent=0, toclose=0) + str = sprint(show, layer, context=nothing) * ",)"^toclose + print(io, " "^indent, str, indent==0 ? "" : ",") + tab = indent==0 ? 20 : 39 # when inside Chain, move all parameter counts out to 40 + if !isempty(params(layer)) + print(" "^max(2, tab - indent - length(str))) + pars = underscorise(sum(length, params(layer))) + printstyled(io, "# ", pars, " parameters", color=:light_black) + if !all(x -> all(isfinite, x), params(layer)) + printstyled(io, " (some NaN or Inf)", color=:red) + elseif all(x -> all(iszero, x), params(layer)) + printstyled(io, " (all zero)", color=:cyan) + end + end + indent==0 || println(io) +end + +function _big_finale(io::IO, ps) + num = length(ps) + num < 3 && return println(io) + pars = underscorise(sum(length, ps)) + bytes = sum(sizeof, ps) + print(io, " "^19) + printstyled(io, "# Total: ", num, " arrays, "; color=:light_black) + printstyled(io, pars, " parameters, ", Base.format_bytes(bytes); color=:light_black) +end + +underscorise(n::Integer) = + join(reverse(join.(reverse.(Iterators.partition(digits(n), 3)))), '_') + +# Zygote's containers + +Base.show(io::IO, m::MIME"text/plain", p::Zygote.Params) = _param_show(io, p, true) + +function _param_show(io::IO, p, iter::Bool) + length(p) == 0 && return print(io, typeof(p), "([])") + println(io, typeof(p), "([") + ipad = length(string(length(p))) + 2 + spad = min(40-6-ipad, maximum(length∘summary, p)) + wid = get(io, :displaysize, (0,100))[2] + for (i,x) in enumerate(p) + if iter + printstyled(io, " ", lpad(string("[",i,"]"), ipad), color=:light_black) + end + desc = Base._truncate_at_width_or_chars(summary(x), spad) + data = sprint(show, x, context=IOContext(io, 
:compact => true, :limit => true, :typeinfo => eltype(x)), sizehint=0) + str = Base._truncate_at_width_or_chars(data, min(30, wid-40-12)) + print(io, " ", rpad(desc, spad), " ", str) + if any(isnan, x) + printstyled(io, " (some NaN)", color=:red) + elseif any(isinf, x) + printstyled(io, " (some Inf)", color=:red) + elseif !isempty(x) && all(iszero, x) + printstyled(io, " (all zero)", color=:cyan) + end + println(io) + end + print(io, "])") + pars = underscorise(sum(length, p)) + bytes = Base.format_bytes(sum(sizeof, p)) + printstyled(io, " "^18, "# Total: ", pars, " parameters, ", bytes; color=:light_black) +end + +function Base.show(io::IO, m::MIME"text/plain", g::Zygote.Grads) + println(io, "Zygote.Grads(") + pars, bytes, spad = 0, 0, 0 + for k in keys(g.grads) + pars += length(g[k]) + bytes += sizeof(g[k]) + spad = max(spad, length(summary(g[k]))) + end + for k in keys(g.grads) + x = g[k] + str = sprint(show, x) + str = length(str) < 32 ? str : str[1:32] * "…" + print(io, " ", rpad(summary(x), spad), " ", str) + if any(isnan, x) + printstyled(io, " (some NaN)", color=:red) + elseif any(isinf, x) + printstyled(io, " (some Inf)", color=:red) + elseif !isempty(x) && all(iszero, x) + printstyled(io, " (all zero)", color=:cyan) + end + println(io) + end + print(io, ")") + printstyled(io, " "^19, "# Total: ", pars, " parameters, ", Base.format_bytes(bytes); color=:light_black) +end + From a442b71a04a14cfe40f8fe04775a6f300071a448 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 16 Jan 2021 16:23:08 +0100 Subject: [PATCH 10/25] lots of fixes --- src/layers/basic.jl | 2 + src/layers/show.jl | 95 ++++++++++++++++++++++----------------------- 2 files changed, 48 insertions(+), 49 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 629d218924..7456de1447 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -417,6 +417,8 @@ Parallel(connection, layers...) 
= Parallel(connection, layers) Base.getindex(m::Parallel, i::Integer) = m.layers[i] Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i]...) +trainable(m::Parallel) = (m.connection, m.layers...) + function Base.show(io::IO, m::Parallel) print(io, "Parallel(", m.connection, ", ") join(io, m.layers, ", ") diff --git a/src/layers/show.jl b/src/layers/show.jl index 1eba82f353..b38cfa9b56 100644 --- a/src/layers/show.jl +++ b/src/layers/show.jl @@ -1,15 +1,15 @@ for T in [ - :Chain, :Parallel, :SkipConnection, + :Chain, :Parallel, :SkipConnection, :Recur, :Conv, :ConvTranspose, :CrossCor, :DepthwiseConv, :Dense, :BatchNorm, :LayerNorm, :InstanceNorm, :GroupNorm, ] @eval Base.show(io::IO, m::MIME"text/plain", x::$T) = _big_show(io, x) end -function _big_show(io::IO, obj, indent=0, toclose=0) - children = Flux.trainable(obj) - if all(c -> isleaf(c) || isa(c,Tuple), children) # need isa(c,Tuple) to get Conv right +function _big_show(io::IO, obj, indent::Int=0, toclose::Int=0) + children = trainable(obj) + if all(c -> isleaf(c) || _show_leaflike(c), children) return _layer_show(io, obj, indent, toclose) end println(io, " "^indent, nameof(typeof(obj)), "(") @@ -23,65 +23,47 @@ function _big_show(io::IO, obj, indent=0, toclose=0) end end -# Opt out of being printed as a container: -_big_show(io::IO, l::LayerNorm, i=0, t=0) = _layer_show(io, l, i, t) +_show_leaflike(::Any) = false +_show_leaflike(::Tuple{Vararg{<:Number}}) = true # stride of Conv +_show_leaflike(::Tuple{Vararg{<:AbstractArray}}) = true # parameters of LSTMcell +_show_leaflike(::Diagonal) = true # appears inside LayerNorm # used both within Chain printing, and alone at top level. -function _layer_show(io::IO, layer, indent=0, toclose=0) +function _layer_show(io::IO, layer, indent::Int=0, toclose::Int=0) str = sprint(show, layer, context=nothing) * ",)"^toclose print(io, " "^indent, str, indent==0 ? "" : ",") - tab = indent==0 ? 
20 : 39 # when inside Chain, move all parameter counts out to 40 if !isempty(params(layer)) - print(" "^max(2, tab - indent - length(str))) - pars = underscorise(sum(length, params(layer))) - printstyled(io, "# ", pars, " parameters", color=:light_black) - if !all(x -> all(isfinite, x), params(layer)) - printstyled(io, " (some NaN or Inf)", color=:red) - elseif all(x -> all(iszero, x), params(layer)) - printstyled(io, " (all zero)", color=:cyan) - end + print(" "^max(2, (indent==0 ? 20 : 39) - indent - length(str))) + printstyled(io, "# ", underscorise(sum(length, params(layer))), " parameters", color=:light_black) + _nan_show(io, params(layer)) end indent==0 || println(io) end function _big_finale(io::IO, ps) - num = length(ps) - num < 3 && return println(io) + length(ps) < 3 && return pars = underscorise(sum(length, ps)) - bytes = sum(sizeof, ps) - print(io, " "^19) - printstyled(io, "# Total: ", num, " arrays, "; color=:light_black) - printstyled(io, pars, " parameters, ", Base.format_bytes(bytes); color=:light_black) + bytes = Base.format_bytes(sum(sizeof, ps)) + printstyled(io, " "^19, "# Total: ", length(ps), " arrays, ", pars, " parameters, ", bytes; color=:light_black) end -underscorise(n::Integer) = - join(reverse(join.(reverse.(Iterators.partition(digits(n), 3)))), '_') - # Zygote's containers -Base.show(io::IO, m::MIME"text/plain", p::Zygote.Params) = _param_show(io, p, true) +Base.show(io::IO, m::MIME"text/plain", p::Zygote.Params) = _param_show(io, p) -function _param_show(io::IO, p, iter::Bool) +function _param_show(io::IO, p) length(p) == 0 && return print(io, typeof(p), "([])") println(io, typeof(p), "([") ipad = length(string(length(p))) + 2 spad = min(40-6-ipad, maximum(length∘summary, p)) - wid = get(io, :displaysize, (0,100))[2] - for (i,x) in enumerate(p) - if iter - printstyled(io, " ", lpad(string("[",i,"]"), ipad), color=:light_black) - end + wid = get(io, :displaysize, (0,100))[2] # not certain this is working + for (i,x) in enumerate(p) + 
printstyled(io, " ", lpad(string("[",i,"]"), ipad), color=:light_black) desc = Base._truncate_at_width_or_chars(summary(x), spad) data = sprint(show, x, context=IOContext(io, :compact => true, :limit => true, :typeinfo => eltype(x)), sizehint=0) str = Base._truncate_at_width_or_chars(data, min(30, wid-40-12)) print(io, " ", rpad(desc, spad), " ", str) - if any(isnan, x) - printstyled(io, " (some NaN)", color=:red) - elseif any(isinf, x) - printstyled(io, " (some Inf)", color=:red) - elseif !isempty(x) && all(iszero, x) - printstyled(io, " (all zero)", color=:cyan) - end + _nan_show(io, x) println(io) end print(io, "])") @@ -100,19 +82,34 @@ function Base.show(io::IO, m::MIME"text/plain", g::Zygote.Grads) end for k in keys(g.grads) x = g[k] - str = sprint(show, x) - str = length(str) < 32 ? str : str[1:32] * "…" - print(io, " ", rpad(summary(x), spad), " ", str) - if any(isnan, x) - printstyled(io, " (some NaN)", color=:red) - elseif any(isinf, x) - printstyled(io, " (some Inf)", color=:red) - elseif !isempty(x) && all(iszero, x) - printstyled(io, " (all zero)", color=:cyan) - end + str = Base._truncate_at_width_or_chars(sprint(show, x), 32) # ?? 
+ # print(io, " ", rpad(summary(x), spad), " ", str) + print(io, " ", rpad(summary(x), 20-4), " ", str) + _nan_show(io, x) println(io) end print(io, ")") printstyled(io, " "^19, "# Total: ", pars, " parameters, ", Base.format_bytes(bytes); color=:light_black) end +# utility functions + +underscorise(n::Integer) = + join(reverse(join.(reverse.(Iterators.partition(digits(n), 3)))), '_') + +function _nan_show(io::IO, x) + if !isempty(x) && _all(iszero, x) + printstyled(io, " (all zero)", color=:cyan) + elseif _any(isnan, x) + printstyled(io, " (some NaN)", color=:red) + elseif _any(isinf, x) + printstyled(io, " (some Inf)", color=:red) + end +end + +_any(f, xs::AbstractArray{<:Number}) = any(f, xs) +_any(f, xs::Union{Tuple,NamedTuple,Zygote.Params}) = any(x -> _any(f, x), xs) +_any(f, x::Number) = f(x) +_any(f, x) = false + +_all(f, xs) = !_any(!f, xs) From 1731feadf4ac738fa796e17f16e40666f5e4990c Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 16 Jan 2021 21:02:23 +0100 Subject: [PATCH 11/25] restore closing brackets to their own lines --- src/layers/show.jl | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/layers/show.jl b/src/layers/show.jl index b38cfa9b56..b6ff6da548 100644 --- a/src/layers/show.jl +++ b/src/layers/show.jl @@ -7,19 +7,20 @@ for T in [ @eval Base.show(io::IO, m::MIME"text/plain", x::$T) = _big_show(io, x) end -function _big_show(io::IO, obj, indent::Int=0, toclose::Int=0) +function _big_show(io::IO, obj, indent::Int=0) children = trainable(obj) if all(c -> isleaf(c) || _show_leaflike(c), children) - return _layer_show(io, obj, indent, toclose) + return _layer_show(io, obj, indent) end println(io, " "^indent, nameof(typeof(obj)), "(") for (i,c) in enumerate(children) - close = i==length(children) && indent>0 - _big_show(io, c, indent+2, close ? 
toclose+1 : 0) + _big_show(io, c, indent+2) end if indent == 0 print(io, ")") _big_finale(io, params(obj)) + else + println(io, " "^indent, "),") end end @@ -29,8 +30,9 @@ _show_leaflike(::Tuple{Vararg{<:AbstractArray}}) = true # parameters of LSTMcel _show_leaflike(::Diagonal) = true # appears inside LayerNorm # used both within Chain printing, and alone at top level. -function _layer_show(io::IO, layer, indent::Int=0, toclose::Int=0) - str = sprint(show, layer, context=nothing) * ",)"^toclose +function _layer_show(io::IO, layer, indent::Int=0) + # str = sprint(show, layer, context=io) + str = string(layer) print(io, " "^indent, str, indent==0 ? "" : ",") if !isempty(params(layer)) print(" "^max(2, (indent==0 ? 20 : 39) - indent - length(str))) From d5a9051b375e001546b99152d909257826cd9d46 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 16 Jan 2021 21:03:19 +0100 Subject: [PATCH 12/25] check IOContext to disable within arrays etc. --- src/layers/show.jl | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/src/layers/show.jl b/src/layers/show.jl index b6ff6da548..fe5ba944e4 100644 --- a/src/layers/show.jl +++ b/src/layers/show.jl @@ -4,7 +4,15 @@ for T in [ :Conv, :ConvTranspose, :CrossCor, :DepthwiseConv, :Dense, :BatchNorm, :LayerNorm, :InstanceNorm, :GroupNorm, ] - @eval Base.show(io::IO, m::MIME"text/plain", x::$T) = _big_show(io, x) + @eval function Base.show(io::IO, m::MIME"text/plain", x::$T) + if get(io, :typeinfo, nothing) === nothing # e.g. top level in REPL + _big_show(io, x) + elseif !get(io, :compact, false) # e.g. 
printed inside a Vector, but not a Matrix + _layer_show(io, x) + else + show(io, x) + end + end end function _big_show(io::IO, obj, indent::Int=0) @@ -25,18 +33,17 @@ function _big_show(io::IO, obj, indent::Int=0) end _show_leaflike(::Any) = false -_show_leaflike(::Tuple{Vararg{<:Number}}) = true # stride of Conv -_show_leaflike(::Tuple{Vararg{<:AbstractArray}}) = true # parameters of LSTMcell -_show_leaflike(::Diagonal) = true # appears inside LayerNorm +_show_leaflike(::Tuple{Vararg{<:Number}}) = true # e.g. stride of Conv +_show_leaflike(::Tuple{Vararg{<:AbstractArray}}) = true # e.g. parameters of LSTMcell +_show_leaflike(::Diagonal) = true # appears inside LayerNorm, but let's collapse that. -# used both within Chain printing, and alone at top level. function _layer_show(io::IO, layer, indent::Int=0) # str = sprint(show, layer, context=io) str = string(layer) print(io, " "^indent, str, indent==0 ? "" : ",") if !isempty(params(layer)) - print(" "^max(2, (indent==0 ? 20 : 39) - indent - length(str))) - printstyled(io, "# ", underscorise(sum(length, params(layer))), " parameters", color=:light_black) + print(io, " "^max(2, (indent==0 ? 20 : 39) - indent - length(str))) + printstyled(io, "# ", underscorise(sum(length, params(layer))), " parameters"; color=:light_black) _nan_show(io, params(layer)) end indent==0 || println(io) @@ -51,13 +58,15 @@ end # Zygote's containers -Base.show(io::IO, m::MIME"text/plain", p::Zygote.Params) = _param_show(io, p) +function Base.show(io::IO, m::MIME"text/plain", p::Zygote.Params) + get(io, :typeinfo, nothing) === nothing ? 
_param_show(io, p) : show(io, p) +end function _param_show(io::IO, p) length(p) == 0 && return print(io, typeof(p), "([])") println(io, typeof(p), "([") ipad = length(string(length(p))) + 2 - spad = min(40-6-ipad, maximum(length∘summary, p)) + spad = min(50-6-ipad, maximum(length∘summary, p)) wid = get(io, :displaysize, (0,100))[2] # not certain this is working for (i,x) in enumerate(p) printstyled(io, " ", lpad(string("[",i,"]"), ipad), color=:light_black) @@ -75,6 +84,10 @@ function _param_show(io::IO, p) end function Base.show(io::IO, m::MIME"text/plain", g::Zygote.Grads) + get(io, :typeinfo, nothing) === nothing ? _grad_show(io, g) : show(io, g) +end + +function _grad_show(io::IO, g) println(io, "Zygote.Grads(") pars, bytes, spad = 0, 0, 0 for k in keys(g.grads) From e071efd52fbc466b17ad38f1d5de2cd3f8449813 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 23 Jan 2021 11:26:26 +0100 Subject: [PATCH 13/25] rm Zygote's types --- src/layers/show.jl | 58 +++------------------------------------------- 1 file changed, 3 insertions(+), 55 deletions(-) diff --git a/src/layers/show.jl b/src/layers/show.jl index fe5ba944e4..002d89dca6 100644 --- a/src/layers/show.jl +++ b/src/layers/show.jl @@ -1,7 +1,7 @@ for T in [ - :Chain, :Parallel, :SkipConnection, :Recur, - :Conv, :ConvTranspose, :CrossCor, :DepthwiseConv, :Dense, + :Chain, :Parallel, :SkipConnection, :Recur, # container types, as entry points for _big_show + :Conv, :ConvTranspose, :CrossCor, :DepthwiseConv, :Dense, # others really sent to _layer_show :BatchNorm, :LayerNorm, :InstanceNorm, :GroupNorm, ] @eval function Base.show(io::IO, m::MIME"text/plain", x::$T) @@ -38,8 +38,7 @@ _show_leaflike(::Tuple{Vararg{<:AbstractArray}}) = true # e.g. parameters of LS _show_leaflike(::Diagonal) = true # appears inside LayerNorm, but let's collapse that. 
function _layer_show(io::IO, layer, indent::Int=0) - # str = sprint(show, layer, context=io) - str = string(layer) + str = sprint(show, layer, context=io) print(io, " "^indent, str, indent==0 ? "" : ",") if !isempty(params(layer)) print(io, " "^max(2, (indent==0 ? 20 : 39) - indent - length(str))) @@ -56,57 +55,6 @@ function _big_finale(io::IO, ps) printstyled(io, " "^19, "# Total: ", length(ps), " arrays, ", pars, " parameters, ", bytes; color=:light_black) end -# Zygote's containers - -function Base.show(io::IO, m::MIME"text/plain", p::Zygote.Params) - get(io, :typeinfo, nothing) === nothing ? _param_show(io, p) : show(io, p) -end - -function _param_show(io::IO, p) - length(p) == 0 && return print(io, typeof(p), "([])") - println(io, typeof(p), "([") - ipad = length(string(length(p))) + 2 - spad = min(50-6-ipad, maximum(length∘summary, p)) - wid = get(io, :displaysize, (0,100))[2] # not certain this is working - for (i,x) in enumerate(p) - printstyled(io, " ", lpad(string("[",i,"]"), ipad), color=:light_black) - desc = Base._truncate_at_width_or_chars(summary(x), spad) - data = sprint(show, x, context=IOContext(io, :compact => true, :limit => true, :typeinfo => eltype(x)), sizehint=0) - str = Base._truncate_at_width_or_chars(data, min(30, wid-40-12)) - print(io, " ", rpad(desc, spad), " ", str) - _nan_show(io, x) - println(io) - end - print(io, "])") - pars = underscorise(sum(length, p)) - bytes = Base.format_bytes(sum(sizeof, p)) - printstyled(io, " "^18, "# Total: ", pars, " parameters, ", bytes; color=:light_black) -end - -function Base.show(io::IO, m::MIME"text/plain", g::Zygote.Grads) - get(io, :typeinfo, nothing) === nothing ? 
_grad_show(io, g) : show(io, g) -end - -function _grad_show(io::IO, g) - println(io, "Zygote.Grads(") - pars, bytes, spad = 0, 0, 0 - for k in keys(g.grads) - pars += length(g[k]) - bytes += sizeof(g[k]) - spad = max(spad, length(summary(g[k]))) - end - for k in keys(g.grads) - x = g[k] - str = Base._truncate_at_width_or_chars(sprint(show, x), 32) # ?? - # print(io, " ", rpad(summary(x), spad), " ", str) - print(io, " ", rpad(summary(x), 20-4), " ", str) - _nan_show(io, x) - println(io) - end - print(io, ")") - printstyled(io, " "^19, "# Total: ", pars, " parameters, ", Base.format_bytes(bytes); color=:light_black) -end - # utility functions underscorise(n::Integer) = From 5f3586d8074fb9b80e57ae4ad2093c065fabb476 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 23 Jan 2021 11:28:18 +0100 Subject: [PATCH 14/25] conv doctests --- src/layers/conv.jl | 56 +++++++++++++++++++++++++--------------------- 1 file changed, 30 insertions(+), 26 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 2ca3c7990c..788ec2222d 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -67,7 +67,7 @@ See also [`ConvTranspose`](@ref), [`DepthwiseConv`](@ref), [`CrossCor`](@ref). julia> xs = rand(Float32, 100, 100, 3, 50); # a batch of images julia> lay = Conv((5,5), 3 => 7, relu; bias=false) -Conv((5, 5), 3=>7, relu) +Conv((5, 5), 3 => 7, relu) # 525 parameters julia> lay(xs) |> size (96, 96, 7, 50) @@ -98,7 +98,7 @@ end Conv(weight::AbstractArray, [bias, activation; stride, pad, dilation]) Constructs a convolutional layer with the given weight and bias. -Accepts the same keywords (and has the same defaults) as the `Conv((4,4), 3=>7, relu)` +Accepts the same keywords (and has the same defaults) as the `Conv((4,4), 3 => 7, relu)` method. 
# Examples @@ -108,7 +108,7 @@ julia> weight = rand(3, 4, 5); julia> bias = zeros(5); julia> c1 = Conv(weight, bias, sigmoid) # expects 1 spatial dimension -Conv((3,), 4=>5, σ) +Conv((3,), 4 => 5, σ) # 65 parameters julia> c1(randn(100, 4, 64)) |> size (98, 5, 64) @@ -134,7 +134,7 @@ function Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity end """ - convfilter(filter::Tuple, in=>out) + convfilter(filter::Tuple, in => out) Constructs a standard convolutional weight matrix with given `filter` and channels from `in` to `out`. @@ -159,7 +159,7 @@ end function Base.show(io::IO, l::Conv) print(io, "Conv(", size(l.weight)[1:ndims(l.weight)-2]) - print(io, ", ", size(l.weight, ndims(l.weight)-1), "=>", size(l.weight, ndims(l.weight))) + print(io, ", ", size(l.weight, ndims(l.weight)-1), " => ", size(l.weight, ndims(l.weight))) _print_conv_opt(io, l) print(io, ")") end @@ -191,15 +191,15 @@ See also [`Conv`](@ref) for more detailed description of keywords. julia> xs = rand(Float32, 100, 100, 3, 50); # a batch of 50 RGB images julia> lay = ConvTranspose((5,5), 3 => 7, relu) -ConvTranspose((5, 5), 3=>7, relu) +ConvTranspose((5, 5), 3 => 7, relu) # 532 parameters julia> lay(xs) |> size (104, 104, 7, 50) -julia> ConvTranspose((5,5), 3=>7, stride=2)(xs) |> size +julia> ConvTranspose((5,5), 3 => 7, stride=2)(xs) |> size (203, 203, 7, 50) -julia> ConvTranspose((5,5), 3=>7, stride=3, pad=SamePad())(xs) |> size +julia> ConvTranspose((5,5), 3 => 7, stride=3, pad=SamePad())(xs) |> size (300, 300, 7, 50) ``` """ @@ -216,7 +216,7 @@ end ConvTranspose(weight::AbstractArray, [bias, activation; stride, pad, dilation]) Constructs a layer with the given weight and bias arrays. -Accepts the same keywords as the `ConvTranspose((4,4), 3=>7, relu)` method. +Accepts the same keywords as the `ConvTranspose((4,4), 3 => 7, relu)` method. 
""" function ConvTranspose(w::AbstractArray{T,N}, bias = true, σ = identity; stride = 1, pad = 0, dilation = 1) where {T,N} @@ -262,7 +262,7 @@ end function Base.show(io::IO, l::ConvTranspose) print(io, "ConvTranspose(", size(l.weight)[1:ndims(l.weight)-2]) - print(io, ", ", size(l.weight, ndims(l.weight)), "=>", size(l.weight, ndims(l.weight)-1)) + print(io, ", ", size(l.weight, ndims(l.weight)), " => ", size(l.weight, ndims(l.weight)-1)) _print_conv_opt(io, l) print(io, ")") end @@ -273,7 +273,7 @@ function calc_padding(::Type{ConvTranspose}, pad::SamePad, k::NTuple{N,T}, dilat end """ - DepthwiseConv(filter, in=>out, σ=identity; stride=1, pad=0, dilation=1, [bias, init]) + DepthwiseConv(filter, in => out, σ=identity; stride=1, pad=0, dilation=1, [bias, init]) Depthwise convolutional layer. `filter` is a tuple of integers specifying the size of the convolutional kernel, while @@ -291,7 +291,7 @@ See also [`Conv`](@ref) for more detailed description of keywords. julia> xs = rand(Float32, 100, 100, 3, 50); # a batch of 50 RGB images julia> lay = DepthwiseConv((5,5), 3 => 6, relu; bias=false) -DepthwiseConv((5, 5), 3=>6, relu) +DepthwiseConv((5, 5), 3 => 6, relu) # 150 parameters julia> lay(xs) |> size (96, 96, 6, 50) @@ -313,7 +313,7 @@ end DepthwiseConv(weight::AbstractArray, bias, [activation; stride, pad, dilation]) Constructs a layer with the given weight and bias arrays. -Accepts the same keywords as the `DepthwiseConv((4,4), 3=>6, relu)` method. +Accepts the same keywords as the `DepthwiseConv((4,4), 3 => 6, relu)` method. """ function DepthwiseConv(w::AbstractArray{T,N}, bias = true, σ = identity; stride = 1, pad = 0, dilation = 1) where {T,N} @@ -334,7 +334,7 @@ end @functor DepthwiseConv """ - depthwiseconvfilter(filter::Tuple, in=>out) + depthwiseconvfilter(filter::Tuple, in => out) Constructs a depthwise convolutional weight array defined by `filter` and channels from `in` to `out`. 
@@ -355,7 +355,7 @@ end function Base.show(io::IO, l::DepthwiseConv) print(io, "DepthwiseConv(", size(l.weight)[1:end-2]) - print(io, ", ", size(l.weight)[end], "=>", prod(size(l.weight)[end-1:end])) + print(io, ", ", size(l.weight)[end], " => ", prod(size(l.weight)[end-1:end])) _print_conv_opt(io, l) print(io, ")") end @@ -379,12 +379,12 @@ See also [`Conv`](@ref) for more detailed description of keywords. julia> xs = rand(Float32, 100, 100, 3, 50); # a batch of 50 RGB images julia> lay = CrossCor((5,5), 3 => 6, relu; bias=false) -CrossCor((5, 5), 3=>6, relu) +CrossCor((5, 5), 3 => 6, relu) # 450 parameters julia> lay(xs) |> size (96, 96, 6, 50) -julia> CrossCor((5,5), 3=>7, stride=3, pad=(2,0))(xs) |> size +julia> CrossCor((5,5), 3 => 7, stride=3, pad=(2,0))(xs) |> size (34, 32, 7, 50) ``` """ @@ -401,7 +401,7 @@ end CrossCor(weight::AbstractArray, [bias, activation; stride, pad, dilation]) Constructs a layer with the given weight and bias arrays. -Accepts the same keywords as the `CrossCor((4,4), 3=>7, relu)` method. +Accepts the same keywords as the `CrossCor((4,4), 3 => 7, relu)` method. """ function CrossCor(w::AbstractArray{T,N}, bias = true, σ = identity; stride = 1, pad = 0, dilation = 1) where {T,N} @@ -436,7 +436,7 @@ end function Base.show(io::IO, l::CrossCor) print(io, "CrossCor(", size(l.weight)[1:ndims(l.weight)-2]) - print(io, ", ", size(l.weight, ndims(l.weight)-1), "=>", size(l.weight, ndims(l.weight))) + print(io, ", ", size(l.weight, ndims(l.weight)-1), " => ", size(l.weight, ndims(l.weight))) _print_conv_opt(io, l) print(io, ")") end @@ -536,8 +536,7 @@ See also [`MaxPool`](@ref), [`GlobalMeanPool`](@ref). ```jldoctest julia> xs = rand(Float32, 100, 100, 3, 50); -julia> m = Chain(Conv((3,3), 3=>7), GlobalMaxPool()) -Chain(Conv((3, 3), 3=>7), GlobalMaxPool()) +julia> m = Chain(Conv((3,3), 3 => 7), GlobalMaxPool()); julia> m(xs) |> size (1, 1, 7, 50) @@ -574,8 +573,7 @@ by performing mean pooling on the complete (w,h)-shaped feature maps. 
```jldoctest julia> xs = rand(Float32, 100, 100, 3, 50); -julia> m = Chain(Conv((3,3), 3=>7), GlobalMeanPool()) -Chain(Conv((3, 3), 3=>7), GlobalMeanPool()) +julia> m = Chain(Conv((3,3), 3 => 7), GlobalMeanPool()); julia> m(xs) |> size (1, 1, 7, 50) @@ -618,8 +616,11 @@ See also [`Conv`](@ref), [`MeanPool`](@ref), [`AdaptiveMaxPool`](@ref), [`Global ```jldoctest julia> xs = rand(Float32, 100, 100, 3, 50); # batch of 50 RGB images -julia> m = Chain(Conv((5, 5), 3=>7, pad=SamePad()), MaxPool((5, 5), pad=SamePad())) -Chain(Conv((5, 5), 3=>7), MaxPool((5, 5), pad=2)) +julia> m = Chain(Conv((5, 5), 3 => 7, pad=SamePad()), MaxPool((5, 5), pad=SamePad())) +Chain( + Conv((5, 5), 3 => 7, pad=2), # 532 parameters + MaxPool((5, 5), pad=2), +) julia> m[1](xs) |> size (100, 100, 7, 50) @@ -681,7 +682,10 @@ See also [`Conv`](@ref), [`MaxPool`](@ref), [`AdaptiveMeanPool`](@ref). julia> xs = rand(Float32, 100, 100, 3, 50); julia> m = Chain(Conv((5,5), 3 => 7), MeanPool((5,5), pad=SamePad())) -Chain(Conv((5, 5), 3=>7), MeanPool((5, 5), pad=2)) +Chain( + Conv((5, 5), 3 => 7), # 532 parameters + MeanPool((5, 5), pad=2), +) julia> m[1](xs) |> size (96, 96, 7, 50) From 92ac152dafd8fa5de8c8e66755cd8280ccb08df8 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 19 May 2021 22:03:37 -0400 Subject: [PATCH 15/25] fix after rebase, tidy up & simplify --- src/functor.jl | 2 +- src/layers/show.jl | 60 ++++++++++++++++++++++++++++------------------ 2 files changed, 38 insertions(+), 24 deletions(-) diff --git a/src/functor.jl b/src/functor.jl index 30305e2173..99d6411cfa 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -1,7 +1,7 @@ import Adapt: adapt, adapt_storage using LinearAlgebra: Cholesky using Zygote: IdSet -import Functors: Functors, @functor, functor, fmap +import Functors: Functors, @functor, functor, fmap, isleaf trainable(m) = functor(m)[1] diff --git a/src/layers/show.jl b/src/layers/show.jl index 002d89dca6..e70b888be3 100644 
--- a/src/layers/show.jl +++ b/src/layers/show.jl @@ -1,8 +1,6 @@ for T in [ - :Chain, :Parallel, :SkipConnection, :Recur, # container types, as entry points for _big_show - :Conv, :ConvTranspose, :CrossCor, :DepthwiseConv, :Dense, # others really sent to _layer_show - :BatchNorm, :LayerNorm, :InstanceNorm, :GroupNorm, + :Chain, :Parallel, :SkipConnection, :Recur # container types ] @eval function Base.show(io::IO, m::MIME"text/plain", x::$T) if get(io, :typeinfo, nothing) === nothing # e.g. top level in REPL @@ -17,25 +15,39 @@ end function _big_show(io::IO, obj, indent::Int=0) children = trainable(obj) - if all(c -> isleaf(c) || _show_leaflike(c), children) - return _layer_show(io, obj, indent) - end - println(io, " "^indent, nameof(typeof(obj)), "(") - for (i,c) in enumerate(children) - _big_show(io, c, indent+2) - end - if indent == 0 - print(io, ")") - _big_finale(io, params(obj)) + if all(_show_leaflike, children) + _layer_show(io, obj, indent) else - println(io, " "^indent, "),") + println(io, " "^indent, nameof(typeof(obj)), "(") + for c in children + _big_show(io, c, indent+2) + end + if indent == 0 + print(io, ")") + _big_finale(io, params(obj)) + else + println(io, " "^indent, "),") + end end end -_show_leaflike(::Any) = false -_show_leaflike(::Tuple{Vararg{<:Number}}) = true # e.g. stride of Conv +_show_leaflike(x) = isleaf(x) # mostly follow Functors, except for: +_show_leaflike(::Tuple{Vararg{<:Number}}) = true # e.g. stride of Conv _show_leaflike(::Tuple{Vararg{<:AbstractArray}}) = true # e.g. parameters of LSTMcell -_show_leaflike(::Diagonal) = true # appears inside LayerNorm, but let's collapse that. 
+_show_leaflike(::Diagonal) = true # appears inside LayerNorm + +for T in [ + :Conv, :ConvTranspose, :CrossCor, :DepthwiseConv, :Dense, + :BatchNorm, :LayerNorm, :InstanceNorm, :GroupNorm, + ] + @eval function Base.show(io::IO, m::MIME"text/plain", x::$T) + if !get(io, :compact, false) + _layer_show(io, x) + else + show(io, x) + end + end +end function _layer_show(io::IO, layer, indent::Int=0) str = sprint(show, layer, context=io) @@ -49,10 +61,11 @@ function _layer_show(io::IO, layer, indent::Int=0) end function _big_finale(io::IO, ps) - length(ps) < 3 && return - pars = underscorise(sum(length, ps)) - bytes = Base.format_bytes(sum(sizeof, ps)) - printstyled(io, " "^19, "# Total: ", length(ps), " arrays, ", pars, " parameters, ", bytes; color=:light_black) + if length(ps) > 2 + pars = underscorise(sum(length, ps)) + bytes = Base.format_bytes(sum(sizeof, ps)) + printstyled(io, " "^19, "# Total: ", length(ps), " arrays, ", pars, " parameters, ", bytes; color=:light_black) + end end # utility functions @@ -71,8 +84,9 @@ function _nan_show(io::IO, x) end _any(f, xs::AbstractArray{<:Number}) = any(f, xs) -_any(f, xs::Union{Tuple,NamedTuple,Zygote.Params}) = any(x -> _any(f, x), xs) +# _any(f, xs::Union{Tuple,NamedTuple,Zygote.Params}) = any(x -> _any(f, x), xs) +_any(f, xs) = any(x -> _any(f, x), xs) _any(f, x::Number) = f(x) -_any(f, x) = false +# _any(f, x) = false _all(f, xs) = !_any(!f, xs) From 7f54144dd377e576fe3a841f483864579b227cd0 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 19 May 2021 22:03:58 -0400 Subject: [PATCH 16/25] add some tests --- test/layers/show.jl | 70 +++++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 2 files changed, 71 insertions(+) create mode 100644 test/layers/show.jl diff --git a/test/layers/show.jl b/test/layers/show.jl new file mode 100644 index 0000000000..9c689eba49 --- /dev/null +++ b/test/layers/show.jl @@ -0,0 +1,70 @@ + +@testset "layer printing" begin # 
2-arg show, defined with layers

+  @test repr(Dense(2,3)) == "Dense(2, 3)"
+  @test repr(Chain(Dense(2,3))) == "Chain(Dense(2, 3))"
+
+end
+@testset "nested model printing" begin # 3-arg show, defined in show.jl
+
+  # Dense -- has parameter count, but not inside a matrix
+
+  toplevel_dense = repr("text/plain", Dense(2,3))
+  @test occursin("Dense(2, 3)", toplevel_dense)
+  @test occursin("# 9 parameters", toplevel_dense)
+
+  @test Meta.isexpr(Meta.parse(toplevel_dense), :call) # comment is ignored
+
+  vector_dense = repr("text/plain", [Dense(2,3), Dense(2,3)])
+  @test occursin("Dense(2, 3)", vector_dense)
+  @test occursin("# 9 parameters", vector_dense)
+
+  matrix_dense = repr("text/plain", fill(Dense(2,3), 3, 3))
+  @test occursin("Dense(2, 3)", matrix_dense)
+  @test !occursin("# 9 parameters", matrix_dense)
+
+  tuple_dense = repr("text/plain", tuple(Dense(2,3)))
+  @test occursin("Dense(2, 3)", tuple_dense)
+  @test !occursin("# 9 parameters", tuple_dense)
+
+  # Chain -- gets split over lines at top level only
+
+  toplevel_chain = repr("text/plain", Chain(Dense(2,3)))
+  @test occursin("Chain(\n  Dense(2, 3)", toplevel_chain)
+  @test occursin("# 9 parameters", toplevel_chain)
+  @test !occursin("# Total:", toplevel_chain)
+
+  vector_chain = repr("text/plain", [Chain(Dense(2,3)), Chain(Dense(2,3))])
+  @test occursin("Chain(Dense(2, 3))", vector_chain)
+  @test occursin("# 9 parameters", vector_chain)
+  @test !occursin("# Total:", vector_chain)
+
+  matrix_chain = repr("text/plain", fill(Chain(Dense(2,3)), 3,3))
+  @test occursin("Chain(Dense(2, 3))", matrix_chain)
+  @test !occursin("# 9 parameters", matrix_chain)
+  @test !occursin("# Total:", matrix_chain)
+
+  # ... 
and only long enough chains get + + longchain = Chain(Dense(2, 3), Dense(3, 4), Dense(4, 5), softmax) + + toplevel_longchain = repr("text/plain", longchain) + @test occursin("Chain(\n Dense(2, 3)", toplevel_longchain) + @test occursin("# 9 parameters", toplevel_longchain) + @test occursin("# Total: 6 arrays, 50 parameters", toplevel_longchain) + + vector_longchain = repr("text/plain", [longchain, longchain]) # pretty ugly in reality + @test occursin("Chain(Dense(2, 3)", vector_longchain) + @test occursin("# 50 parameters", vector_longchain) + @test !occursin("# 9 parameters", vector_longchain) + @test !occursin("# Total:", vector_longchain) + + matrix_longchain = repr("text/plain", fill(longchain, 3,3)) + @test occursin("Chain(Dense(2, 3)", matrix_longchain) + @test !occursin("# 9 parameters", matrix_longchain) + @test !occursin("# Total:", matrix_longchain) + + @test Meta.isexpr(Meta.parse(toplevel_longchain), :call) # comments are ignored + @test Meta.parse(toplevel_longchain).args[1] == :Chain + +end diff --git a/test/runtests.jl b/test/runtests.jl index a40433d0f1..0d02323807 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -35,6 +35,7 @@ end include("layers/recurrent.jl") include("layers/conv.jl") include("layers/upsample.jl") + include("layers/show.jl") end @testset "outputsize" begin From 8f11ab81a67f9dc2c0901ec67b760ed3e5dd5a41 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 2 Jul 2021 19:37:44 -0400 Subject: [PATCH 17/25] change to use Base.summarysize --- src/layers/show.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/layers/show.jl b/src/layers/show.jl index e70b888be3..0e17cb46d2 100644 --- a/src/layers/show.jl +++ b/src/layers/show.jl @@ -24,7 +24,7 @@ function _big_show(io::IO, obj, indent::Int=0) end if indent == 0 print(io, ")") - _big_finale(io, params(obj)) + _big_finale(io, obj) else println(io, " "^indent, "),") end @@ -60,10 +60,11 @@ function 
_layer_show(io::IO, layer, indent::Int=0) indent==0 || println(io) end -function _big_finale(io::IO, ps) +function _big_finale(io::IO, m) + ps = params(m) if length(ps) > 2 pars = underscorise(sum(length, ps)) - bytes = Base.format_bytes(sum(sizeof, ps)) + bytes = Base.format_bytes(Base.summarysize(m)) printstyled(io, " "^19, "# Total: ", length(ps), " arrays, ", pars, " parameters, ", bytes; color=:light_black) end end From dc5eca2c3266fbf3ae4fd85ebdfad8885eacd1b6 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 2 Jul 2021 21:23:56 -0400 Subject: [PATCH 18/25] count non-trainable parameters --- src/layers/show.jl | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/layers/show.jl b/src/layers/show.jl index 0e17cb46d2..ee3f6732ba 100644 --- a/src/layers/show.jl +++ b/src/layers/show.jl @@ -65,10 +65,20 @@ function _big_finale(io::IO, m) if length(ps) > 2 pars = underscorise(sum(length, ps)) bytes = Base.format_bytes(Base.summarysize(m)) - printstyled(io, " "^19, "# Total: ", length(ps), " arrays, ", pars, " parameters, ", bytes; color=:light_black) + noncnt = _childarray_sum(_->1, m) - length(ps) + if noncnt > 0 + nonparam = underscorise(_childarray_sum(length, m) - sum(length, ps)) + printstyled(io, " "^19, "# Total: ", length(ps), " trainable arrays, ", pars, " parameters,\n"; color=:light_black) + printstyled(io, " "^20, "# plus ", noncnt, " non-trainable, ", nonparam, " parameters, total size ", bytes; color=:light_black) + else + printstyled(io, " "^19, "# Total: ", length(ps), " arrays, ", pars, " parameters, ", bytes; color=:light_black) + end end end +_childarray_sum(f, x::AbstractArray) = f(x) # count includes non-trainable arrays excluded from params +_childarray_sum(f, x) = isleaf(x) ? 
0 : sum(y -> _childarray_sum(f, y), Functors.children(x)) + # utility functions underscorise(n::Integer) = From 9690dfb59d5ea0c41200efe4173079f90c3d8a4f Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 2 Jul 2021 21:36:24 -0400 Subject: [PATCH 19/25] reduce indent --- src/layers/show.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/layers/show.jl b/src/layers/show.jl index ee3f6732ba..797b7e0299 100644 --- a/src/layers/show.jl +++ b/src/layers/show.jl @@ -68,15 +68,15 @@ function _big_finale(io::IO, m) noncnt = _childarray_sum(_->1, m) - length(ps) if noncnt > 0 nonparam = underscorise(_childarray_sum(length, m) - sum(length, ps)) - printstyled(io, " "^19, "# Total: ", length(ps), " trainable arrays, ", pars, " parameters,\n"; color=:light_black) - printstyled(io, " "^20, "# plus ", noncnt, " non-trainable, ", nonparam, " parameters, total size ", bytes; color=:light_black) + printstyled(io, " "^09, "# Total: ", length(ps), " trainable arrays, with ", pars, " parameters,\n"; color=:light_black) + printstyled(io, " "^10, "# plus ", noncnt, " non-trainable, ", nonparam, " parameters, total size ", bytes; color=:light_black) else printstyled(io, " "^19, "# Total: ", length(ps), " arrays, ", pars, " parameters, ", bytes; color=:light_black) end end end -_childarray_sum(f, x::AbstractArray) = f(x) # count includes non-trainable arrays excluded from params +_childarray_sum(f, x::AbstractArray) = f(x) _childarray_sum(f, x) = isleaf(x) ? 
0 : sum(y -> _childarray_sum(f, y), Functors.children(x)) # utility functions From 3e9c4818374f68a9133a061b3b03da9eb3e067aa Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 3 Jul 2021 00:08:13 -0400 Subject: [PATCH 20/25] note non-trainables on every layer which has such --- src/layers/show.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/layers/show.jl b/src/layers/show.jl index 797b7e0299..d3f5efa5bb 100644 --- a/src/layers/show.jl +++ b/src/layers/show.jl @@ -55,6 +55,10 @@ function _layer_show(io::IO, layer, indent::Int=0) if !isempty(params(layer)) print(io, " "^max(2, (indent==0 ? 20 : 39) - indent - length(str))) printstyled(io, "# ", underscorise(sum(length, params(layer))), " parameters"; color=:light_black) + nonparam = _childarray_sum(length, layer) - sum(length, params(layer)) + if nonparam > 0 + printstyled(io, ", plus ", underscorise(nonparam); color=:light_black) + end _nan_show(io, params(layer)) end indent==0 || println(io) From 409c11e0c82a3b83be19d92a6b86d9a660fabe75 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 3 Jul 2021 00:09:07 -0400 Subject: [PATCH 21/25] print some totals not in light_black --- src/layers/show.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/layers/show.jl b/src/layers/show.jl index d3f5efa5bb..8de9ff1c4d 100644 --- a/src/layers/show.jl +++ b/src/layers/show.jl @@ -72,10 +72,13 @@ function _big_finale(io::IO, m) noncnt = _childarray_sum(_->1, m) - length(ps) if noncnt > 0 nonparam = underscorise(_childarray_sum(length, m) - sum(length, ps)) - printstyled(io, " "^09, "# Total: ", length(ps), " trainable arrays, with ", pars, " parameters,\n"; color=:light_black) - printstyled(io, " "^10, "# plus ", noncnt, " non-trainable, ", nonparam, " parameters, total size ", bytes; color=:light_black) + printstyled(io, " "^09, "# Total: ", length(ps), " trainable arrays, with "; 
color=:light_black) + println(io, pars, " parameters") + printstyled(io, " "^10, "# plus ", noncnt, " non-trainable, ", nonparam, " parameters, summarysize "; color=:light_black) + print(io, bytes) else - printstyled(io, " "^19, "# Total: ", length(ps), " arrays, ", pars, " parameters, ", bytes; color=:light_black) + printstyled(io, " "^19, "# Total: ", length(ps), " arrays, "; color=:light_black) + print(io, pars, " parameters, ", bytes) end end end From f65d5147c6b8a72eec47b4c999d56c733635be83 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 3 Jul 2021 00:47:01 -0400 Subject: [PATCH 22/25] extra when not indented --- src/layers/show.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layers/show.jl b/src/layers/show.jl index 8de9ff1c4d..56fc78bafe 100644 --- a/src/layers/show.jl +++ b/src/layers/show.jl @@ -57,7 +57,7 @@ function _layer_show(io::IO, layer, indent::Int=0) printstyled(io, "# ", underscorise(sum(length, params(layer))), " parameters"; color=:light_black) nonparam = _childarray_sum(length, layer) - sum(length, params(layer)) if nonparam > 0 - printstyled(io, ", plus ", underscorise(nonparam); color=:light_black) + printstyled(io, ", plus ", underscorise(nonparam), indent==0 ? 
" non-trainable" : ""; color=:light_black) end _nan_show(io, params(layer)) end From 2423c1e51f22d2bd6ade4f453e417af4a6f0641a Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 3 Jul 2021 02:09:27 -0400 Subject: [PATCH 23/25] fix all doctests --- docs/src/utilities.md | 2 +- src/layers/basic.jl | 10 +++++++--- src/layers/conv.jl | 10 +++++----- src/utils.jl | 33 ++++++++++++++++++++------------- 4 files changed, 33 insertions(+), 22 deletions(-) diff --git a/docs/src/utilities.md b/docs/src/utilities.md index 896ba38761..01bb34618c 100644 --- a/docs/src/utilities.md +++ b/docs/src/utilities.md @@ -28,7 +28,7 @@ To change the default on an applicable layer, pass the desired function with the ```jldoctest; setup = :(using Flux) julia> conv = Conv((3, 3), 1 => 8, relu; init=Flux.glorot_normal) -Conv((3, 3), 1=>8, relu) +Conv((3, 3), 1 => 8, relu) # 80 parameters ``` ```@docs diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 7456de1447..70af256991 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -89,7 +89,7 @@ The weight matrix and/or the bias vector (of length `out`) may also be provided # Examples ```jldoctest julia> d = Dense(5, 2) -Dense(5, 2) +Dense(5, 2) # 12 parameters julia> d(rand(Float32, 5, 64)) |> size (2, 64) @@ -98,7 +98,7 @@ julia> d(rand(Float32, 5, 1, 1, 64)) |> size # treated as three batch dimension (2, 1, 1, 64) julia> d1 = Dense(ones(2, 5), false, tanh) # using provided weight matrix -Dense(5, 2, tanh; bias=false) +Dense(5, 2, tanh; bias=false) # 10 parameters julia> d1(ones(5)) 2-element Vector{Float64}: @@ -395,7 +395,11 @@ julia> size(model(rand(3))) (17,) julia> model = Parallel(+, Dense(10, 2), Dense(5, 2)) -Parallel(+, Dense(10, 2), Dense(5, 2)) +Parallel( + +, + Dense(10, 2), # 22 parameters + Dense(5, 2), # 12 parameters +) # Total: 4 arrays, 34 parameters, 392 bytes julia> size(model(rand(10), rand(5))) (2,) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 
788ec2222d..208ab7a265 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -67,7 +67,7 @@ See also [`ConvTranspose`](@ref), [`DepthwiseConv`](@ref), [`CrossCor`](@ref). julia> xs = rand(Float32, 100, 100, 3, 50); # a batch of images julia> lay = Conv((5,5), 3 => 7, relu; bias=false) -Conv((5, 5), 3 => 7, relu) # 525 parameters +Conv((5, 5), 3 => 7, relu, bias=false) # 525 parameters julia> lay(xs) |> size (96, 96, 7, 50) @@ -291,7 +291,7 @@ See also [`Conv`](@ref) for more detailed description of keywords. julia> xs = rand(Float32, 100, 100, 3, 50); # a batch of 50 RGB images julia> lay = DepthwiseConv((5,5), 3 => 6, relu; bias=false) -DepthwiseConv((5, 5), 3 => 6, relu) # 150 parameters +DepthwiseConv((5, 5), 3 => 6, relu, bias=false) # 150 parameters julia> lay(xs) |> size (96, 96, 6, 50) @@ -379,7 +379,7 @@ See also [`Conv`](@ref) for more detailed description of keywords. julia> xs = rand(Float32, 100, 100, 3, 50); # a batch of 50 RGB images julia> lay = CrossCor((5,5), 3 => 6, relu; bias=false) -CrossCor((5, 5), 3 => 6, relu) # 450 parameters +CrossCor((5, 5), 3 => 6, relu, bias=false) # 450 parameters julia> lay(xs) |> size (96, 96, 6, 50) @@ -618,7 +618,7 @@ julia> xs = rand(Float32, 100, 100, 3, 50); # batch of 50 RGB images julia> m = Chain(Conv((5, 5), 3 => 7, pad=SamePad()), MaxPool((5, 5), pad=SamePad())) Chain( - Conv((5, 5), 3 => 7, pad=2), # 532 parameters + Conv((5, 5), 3 => 7, pad=2), # 532 parameters MaxPool((5, 5), pad=2), ) @@ -683,7 +683,7 @@ julia> xs = rand(Float32, 100, 100, 3, 50); julia> m = Chain(Conv((5,5), 3 => 7), MeanPool((5,5), pad=SamePad())) Chain( - Conv((5, 5), 3 => 7), # 532 parameters + Conv((5, 5), 3 => 7), # 532 parameters MeanPool((5, 5), pad=2), ) diff --git a/src/utils.jl b/src/utils.jl index 06b2bb01b0..901b9807f6 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -13,14 +13,12 @@ This function is mainly used by weight initializers, e.g., [`kaiming_normal`](@r # Examples ```jldoctest -julia> layer = Dense(10, 20) 
-Dense(10, 20) +julia> layer = Dense(10, 20); julia> Flux.nfan(size(layer.W)) (10, 20) -julia> layer = Conv((3, 3), 2=>10) -Conv((3, 3), 2=>10) +julia> layer = Conv((3, 3), 2=>10); julia> Flux.nfan(size(layer.weight)) (18, 90) @@ -506,7 +504,7 @@ julia> Flux.chunk(1:10, 3) 9:10 julia> Flux.chunk(collect(1:10), 3) -3-element Vector{SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}}: +3-element Vector{SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}}: [1, 2, 3, 4] [5, 6, 7, 8] [9, 10] @@ -720,19 +718,25 @@ over specific modules or subsets of the parameters # Examples ```jldoctest -julia> m1 = Chain(Dense(28^2, 64), BatchNorm(64, relu)) -Chain(Dense(784, 64), BatchNorm(64, relu)) +julia> m1 = Chain(Dense(28^2, 64), BatchNorm(64, relu)); julia> m2 = Chain(m1, Dense(64, 10)) -Chain(Chain(Dense(784, 64), BatchNorm(64, relu)), Dense(64, 10)) +Chain( + Chain( + Dense(784, 64), # 50_240 parameters + BatchNorm(64, relu), # 128 parameters, plus 128 + ), + Dense(64, 10), # 650 parameters +) # Total: 6 trainable arrays, with 51_018 parameters + # plus 2 non-trainable, 128 parameters, summarysize 200.312 KiB julia> Flux.modules(m2) 5-element Vector{Any}: - Chain(Chain(Dense(784, 64), BatchNorm(64, relu)), Dense(64, 10)) - Chain(Dense(784, 64), BatchNorm(64, relu)) - Dense(784, 64) - BatchNorm(64, relu) - Dense(64, 10) + Chain(Chain(Dense(784, 64), BatchNorm(64, relu)), Dense(64, 10)) # 51_018 parameters, plus 128 non-trainable + Chain(Dense(784, 64), BatchNorm(64, relu)) # 50_368 parameters, plus 128 non-trainable + Dense(784, 64) # 50_240 parameters + BatchNorm(64, relu) # 128 parameters, plus 128 non-trainable + Dense(64, 10) # 650 parameters julia> L2(m) = sum(sum(abs2, l.weight) for l in Flux.modules(m) if l isa Dense) L2 (generic function with 1 method) @@ -760,6 +764,7 @@ julia> loss() = rand(); julia> trigger = Flux.patience(() -> loss() < 1, 3); + julia> Flux.@epochs 10 begin trigger() && break end @@ -796,6 +801,7 @@ julia> loss = let l = 
0 julia> es = Flux.early_stopping(loss, 3); + julia> Flux.@epochs 10 begin es() && break end @@ -836,6 +842,7 @@ julia> f = let v = 10 julia> trigger = Flux.plateau(f, 3; init_score=10, min_dist=18); + julia> Flux.@epochs 10 begin trigger() && break end From a5e39df12d9031573827e640c82ff469aa996eb5 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 3 Jul 2021 02:14:49 -0400 Subject: [PATCH 24/25] rm comment --- src/layers/normalise.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 8e7e4b0e71..1d0e8f73b9 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -279,7 +279,7 @@ testmode!(m::BatchNorm, mode=true) = function Base.show(io::IO, l::BatchNorm) print(io, "BatchNorm($(l.chs)") (l.λ == identity) || print(io, ", $(l.λ)") - hasaffine(l) || print(io, ", affine=false") # ?? + hasaffine(l) || print(io, ", affine=false") print(io, ")") end From c91e731caf5eadba39765fdfb13bc9ae73f262da Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 3 Jul 2021 02:21:32 -0400 Subject: [PATCH 25/25] tweaks --- src/layers/basic.jl | 2 +- src/layers/show.jl | 8 ++++---- src/utils.jl | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 70af256991..8f4b1053b7 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -399,7 +399,7 @@ Parallel( +, Dense(10, 2), # 22 parameters Dense(5, 2), # 12 parameters -) # Total: 4 arrays, 34 parameters, 392 bytes +) # Total: 4 arrays, 34 parameters, 392 bytes. 
julia> size(model(rand(10), rand(5))) (2,) diff --git a/src/layers/show.jl b/src/layers/show.jl index 56fc78bafe..40d49dd9d1 100644 --- a/src/layers/show.jl +++ b/src/layers/show.jl @@ -72,13 +72,13 @@ function _big_finale(io::IO, m) noncnt = _childarray_sum(_->1, m) - length(ps) if noncnt > 0 nonparam = underscorise(_childarray_sum(length, m) - sum(length, ps)) - printstyled(io, " "^09, "# Total: ", length(ps), " trainable arrays, with "; color=:light_black) - println(io, pars, " parameters") + printstyled(io, " "^09, "# Total: ", length(ps), " trainable arrays, "; color=:light_black) + println(io, pars, " parameters,") printstyled(io, " "^10, "# plus ", noncnt, " non-trainable, ", nonparam, " parameters, summarysize "; color=:light_black) - print(io, bytes) + print(io, bytes, ".") else printstyled(io, " "^19, "# Total: ", length(ps), " arrays, "; color=:light_black) - print(io, pars, " parameters, ", bytes) + print(io, pars, " parameters, ", bytes, ".") end end end diff --git a/src/utils.jl b/src/utils.jl index 901b9807f6..73acd45f96 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -727,8 +727,8 @@ Chain( BatchNorm(64, relu), # 128 parameters, plus 128 ), Dense(64, 10), # 650 parameters -) # Total: 6 trainable arrays, with 51_018 parameters - # plus 2 non-trainable, 128 parameters, summarysize 200.312 KiB +) # Total: 6 trainable arrays, 51_018 parameters, + # plus 2 non-trainable, 128 parameters, summarysize 200.312 KiB. julia> Flux.modules(m2) 5-element Vector{Any}: