Skip to content

Commit ef4fedd

Browse files
authored
Apply 3 suggestions from myself
1 parent 7eddf0e commit ef4fedd

File tree

3 files changed

+20
-29
lines changed

3 files changed

+20
-29
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1313
[compat]
1414
ChainRulesCore = "1"
1515
Functors = "0.2.8, 0.3"
16-
Yota = "0.7.3"
16+
Yota = "0.8.0"
1717
Zygote = "0.6.40"
1818
julia = "1.6"
1919

docs/src/index.md

Lines changed: 15 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -92,36 +92,24 @@ end;
9292
Unfortunately this example doesn't actually run right now. This is the error:
9393
```
9494
julia> loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
95-
sum(m(x))
96-
end;
97-
┌ Error: Failed to compile rrule for #233(Chain(Conv((3, 3), 64 => 64, pad=1, bias=false), BatchNorm(64, relu), Conv((3, 3), 64 => 64, pad=1, bias=false), BatchNorm(64)),), extract details via:
98-
│ (f, args) = Yota.RRULE_VIA_AD_STATE[]
99-
└ @ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:160
100-
ERROR: No deriative rule found for op %3 = getfield(%1, :x)::Array{Float32, 4} , try defining it using
101-
102-
ChainRulesCore.rrule(::typeof(getfield), ::Flux.var"#233#234"{Array{Float32, 4}}, ::Symbol) = ...
103-
95+
sum(m(x))
96+
end;
97+
ERROR: BoundsError: attempt to access Nothing at index [1]
10498
Stacktrace:
105-
[1] error(s::String)
106-
@ Base ./error.jl:35
107-
[2] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
108-
@ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:197
109-
[3] back!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Symbol)
110-
@ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:238
111-
[4] gradtape!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Symbol)
112-
@ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:249
113-
[5] gradtape(f::Flux.var"#233#234"{Array{Float32, 4}}, args::Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}; ctx::Yota.GradCtx, seed::Symbol)
114-
@ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:276
115-
[6] make_rrule(f::Function, args::Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}})
116-
@ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:109
117-
[7] rrule_via_ad(#unused#::Yota.YotaRuleConfig, f::Function, args::Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}})
118-
@ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:153
99+
[1] _getfield(value::Nothing, fld::Int64)
100+
@ Yota ~/.julia/packages/Yota/uu3H0/src/helpers.jl:40
101+
[2] mkcall(::Function, ::Umlaut.Variable, ::Vararg{Any}; val::Missing, line::Nothing, kwargs::NamedTuple{(), Tuple{}}, free_kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
102+
@ Umlaut ~/.julia/packages/Umlaut/SvDaQ/src/tape.jl:192
103+
[3] mkcall
104+
@ ~/.julia/packages/Umlaut/SvDaQ/src/tape.jl:174 [inlined]
105+
[4] chainrules_transform!(tape::Umlaut.Tape{Yota.GradCtx})
106+
@ Yota ~/.julia/packages/Yota/uu3H0/src/grad.jl:183
107+
119108
...
120109
121-
(jl_GWa2lX) pkg> st
122-
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_GWa2lX/Project.toml`
123-
⌃ [587475ba] Flux v0.13.4
124-
[cd998857] Yota v0.7.4
110+
(@v1.9) pkg> st Yota
111+
Status `~/.julia/environments/v1.9/Project.toml`
112+
[cd998857] Yota v0.8.0
125113
```
126114

127115
## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl)

test/runtests.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@ struct TwoThirds a; b; c; end
1313
Functors.@functor TwoThirds (a, c)
1414
Optimisers.trainable(x::TwoThirds) = (a = x.a,)
1515

16-
Yota_gradient(f, xs...) = Base.tail(Yota.grad(f, xs...)[2])
16+
Yota_gradient(f, xs...) = map(y2z, Base.tail(Yota.grad(f, xs...)[2]));
17+
y2z(::AbstractZero) = nothing # we don't care about different flavours
18+
y2z(t::Tangent) = map(y2z, ChainRulesCore.backing(canonicalize(t)))
19+
y2z(x) = x
1720

1821
@testset verbose=true "Optimisers.jl" begin
1922
@testset verbose=true "Features" begin

0 commit comments

Comments
 (0)