Skip to content

Commit 92ac152

Browse files
committed
fix after rebase, tidy up & simplify
1 parent 5f3586d commit 92ac152

File tree

2 files changed

+38
-24
lines changed

2 files changed

+38
-24
lines changed

src/functor.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import Adapt: adapt, adapt_storage
22
using LinearAlgebra: Cholesky
33
using Zygote: IdSet
4-
import Functors: Functors, @functor, functor, fmap
4+
import Functors: Functors, @functor, functor, fmap, isleaf
55

66
trainable(m) = functor(m)[1]
77

src/layers/show.jl

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11

22
for T in [
3-
:Chain, :Parallel, :SkipConnection, :Recur, # container types, as entry points for _big_show
4-
:Conv, :ConvTranspose, :CrossCor, :DepthwiseConv, :Dense, # others really sent to _layer_show
5-
:BatchNorm, :LayerNorm, :InstanceNorm, :GroupNorm,
3+
:Chain, :Parallel, :SkipConnection, :Recur # container types
64
]
75
@eval function Base.show(io::IO, m::MIME"text/plain", x::$T)
86
if get(io, :typeinfo, nothing) === nothing # e.g. top level in REPL
@@ -17,25 +15,39 @@ end
1715

1816
function _big_show(io::IO, obj, indent::Int=0)
1917
children = trainable(obj)
20-
if all(c -> isleaf(c) || _show_leaflike(c), children)
21-
return _layer_show(io, obj, indent)
22-
end
23-
println(io, " "^indent, nameof(typeof(obj)), "(")
24-
for (i,c) in enumerate(children)
25-
_big_show(io, c, indent+2)
26-
end
27-
if indent == 0
28-
print(io, ")")
29-
_big_finale(io, params(obj))
18+
if all(_show_leaflike, children)
19+
_layer_show(io, obj, indent)
3020
else
31-
println(io, " "^indent, "),")
21+
println(io, " "^indent, nameof(typeof(obj)), "(")
22+
for c in children
23+
_big_show(io, c, indent+2)
24+
end
25+
if indent == 0
26+
print(io, ")")
27+
_big_finale(io, params(obj))
28+
else
29+
println(io, " "^indent, "),")
30+
end
3231
end
3332
end
3433

35-
_show_leaflike(::Any) = false
36-
_show_leaflike(::Tuple{Vararg{<:Number}}) = true # e.g. stride of Conv
34+
_show_leaflike(x) = isleaf(x) # mostly follow Functors, except for:
35+
_show_leaflike(::Tuple{Vararg{<:Number}}) = true # e.g. stride of Conv
3736
_show_leaflike(::Tuple{Vararg{<:AbstractArray}}) = true # e.g. parameters of LSTMcell
38-
_show_leaflike(::Diagonal) = true # appears inside LayerNorm, but let's collapse that.
37+
_show_leaflike(::Diagonal) = true # appears inside LayerNorm
38+
39+
for T in [
40+
:Conv, :ConvTranspose, :CrossCor, :DepthwiseConv, :Dense,
41+
:BatchNorm, :LayerNorm, :InstanceNorm, :GroupNorm,
42+
]
43+
@eval function Base.show(io::IO, m::MIME"text/plain", x::$T)
44+
if !get(io, :compact, false)
45+
_layer_show(io, x)
46+
else
47+
show(io, x)
48+
end
49+
end
50+
end
3951

4052
function _layer_show(io::IO, layer, indent::Int=0)
4153
str = sprint(show, layer, context=io)
@@ -49,10 +61,11 @@ function _layer_show(io::IO, layer, indent::Int=0)
4961
end
5062

5163
function _big_finale(io::IO, ps)
52-
length(ps) < 3 && return
53-
pars = underscorise(sum(length, ps))
54-
bytes = Base.format_bytes(sum(sizeof, ps))
55-
printstyled(io, " "^19, "# Total: ", length(ps), " arrays, ", pars, " parameters, ", bytes; color=:light_black)
64+
if length(ps) > 2
65+
pars = underscorise(sum(length, ps))
66+
bytes = Base.format_bytes(sum(sizeof, ps))
67+
printstyled(io, " "^19, "# Total: ", length(ps), " arrays, ", pars, " parameters, ", bytes; color=:light_black)
68+
end
5669
end
5770

5871
# utility functions
@@ -71,8 +84,9 @@ function _nan_show(io::IO, x)
7184
end
7285

7386
_any(f, xs::AbstractArray{<:Number}) = any(f, xs)
74-
_any(f, xs::Union{Tuple,NamedTuple,Zygote.Params}) = any(x -> _any(f, x), xs)
87+
# _any(f, xs::Union{Tuple,NamedTuple,Zygote.Params}) = any(x -> _any(f, x), xs)
88+
_any(f, xs) = any(x -> _any(f, x), xs)
7589
_any(f, x::Number) = f(x)
76-
_any(f, x) = false
90+
# _any(f, x) = false
7791

7892
_all(f, xs) = !_any(!f, xs)

0 commit comments

Comments
 (0)