-
Notifications
You must be signed in to change notification settings - Fork 22
Use Bijection
for thunk body cache
#1338
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
mmm all these format comments seem to come from code unrelated to this PR... |
src/Tracing.jl
Outdated
K = dict_key(T) | ||
V = dict_value(T) | ||
if V === nothing | ||
|
||
K_traced = if !isnothing(K) | ||
traced_type_inner(K, seen, mode, track_numbers, sharding, runtime) | ||
else | ||
nothing | ||
end | ||
V_traced = if !isnothing(V) | ||
traced_type_inner(V, seen, mode, track_numbers, sharding, runtime) | ||
else | ||
nothing | ||
end | ||
|
||
if K == K_traced && V == V_traced | ||
return T | ||
end | ||
|
||
dictty = if T isa UnionAll | ||
T.body.name.wrapper | ||
else | ||
K = dict_key(T) | ||
V2 = traced_type_inner(V, seen, mode, track_numbers, sharding, runtime) | ||
if V == V2 | ||
return T | ||
end | ||
dictty = if T isa UnionAll | ||
T.body.name.wrapper | ||
else | ||
T.name.wrapper | ||
end | ||
if K !== nothing | ||
return dictty{K,V2} | ||
else | ||
return (dictty{KT,V2} where {KT}) | ||
end | ||
T.name.wrapper | ||
end | ||
|
||
if isnothing(K_traced) && isnothing(V_traced) | ||
return (dictty{Kt,Vt} where {Kt,Vt}) | ||
elseif isnothing(K_traced) | ||
return (dictty{Kt,V_traced} where {Kt}) | ||
elseif isnothing(V_traced) | ||
return (dictty{K_traced,Vt} where {Vt}) | ||
else | ||
return dictty{K_traced,V_traced} | ||
end | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this changes the behavior of tracing for Dict
such that the key type is also traced.
it is required for Bijection
to work.
@wsmoses is that alright?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@wsmoses I managed to get tracing working with general AbstractDict
. Instead of tracing collect(dict)
which is a Vector{<:Pair}
, it does so by tracing the Pair
s iteratively and enumerating them. Also, I had to refactor traced_getfield(::AbstractDict)
to accept a Integer
which is the iteration index.
When I introduced traced_getfield
, it was to "fake" fields. It's behavior seems to have evolved but I have the feeling that for this case it doesn't fit perfectly. The reason is that we are now passing a iteration index (and before this PR it should be sth like traced_getindex
).
Finally, I didn't manage to get it with inplace mutations (i.e. adding a new field), so adding or removing entries might be problematic.
For example, this works
julia> function combine_ab_to_c!(d)
d[:a] = d[:a] + d[:b]
return d
end
julia> @jit combine_ab_to_c!(d)
Bijection{Symbol, ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Dict{Symbol, ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, IdDict{ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Symbol}} with 2 entries:
:a => ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}([5.0])
:b => ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}([1.0])
but this doesn't work
julia> function combine_ab_to_c!(d)
d[:c] = d[:a] + d[:b]
return d
end
combine_ab_to_c! (generic function with 1 method)
julia> @jit combine_ab_to_c!(d)
ERROR: ArgumentError: collection must be non-empty
Stacktrace:
[1] first
@ ./abstractarray.jl:473 [inlined]
[2] traced_getfield
@ ~/Developer/Reactant.jl/src/Compiler.jl:37 [inlined]
[3] macro expansion
@ ~/Developer/Reactant.jl/src/Compiler.jl:2914 [inlined]
[4] (::Reactant.Compiler.Thunk{…})(args::Bijection{…})
@ Reactant.Compiler ~/Developer/Reactant.jl/src/Compiler.jl:3473
[5] top-level scope
@ ~/Developer/Reactant.jl/src/Compiler.jl:2333
Some type information was truncated. Use `show(err)` to see complete types.
The
__thunk_fwd_body_cache
and__thunk_rev_body_cache
internal variables can be unified as aBijection
.@wsmoses The new Bijections.jl 0.2 release also has all I required from BijectiveDicts.jl for Tenet.jl, so I'm refactoring my package with it.
If we are able to correctly and performantly trace over a
Bijection
(more specifically, aBijection{X, TracedRArray, Dict{X, TracedRArray}, IdDict{TracedRArray, X}}
whereX
can be whatever, then I shouldn't need any Reactant tracing customization in Tenet.EDIT: I've generalized some tracing functions to work with
Bijection
too.