Use Bijection for thunk body cache #1338

Open
wants to merge 17 commits into
base: main

mofeing
Collaborator

@mofeing mofeing commented May 27, 2025

The __thunk_fwd_body_cache and __thunk_rev_body_cache internal variables can be unified as a Bijection.

@wsmoses The new Bijections.jl 0.2 release also has all I required from BijectiveDicts.jl for Tenet.jl, so I'm refactoring my package with it.

If we are able to correctly and performantly trace over a Bijection (more specifically, a Bijection{X, TracedRArray, Dict{X, TracedRArray}, IdDict{TracedRArray, X}}, where X can be anything), then I shouldn't need any Reactant tracing customization in Tenet.

EDIT: I've generalized some tracing functions to work with Bijection too.
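
As a rough illustration of why a forward cache and a reverse cache can collapse into one structure, here is a minimal sketch using only Base. `TinyBijection`, `link!`, `forward`, and `backward` are hypothetical names made up for this example; they are not the Bijections.jl API or Reactant internals. The reverse map is an `IdDict` to mirror the identity-keyed inverse in the type mentioned above.

```julia
# Hypothetical stand-in for a bijective map; NOT the Bijections.jl API.
struct TinyBijection{K,V}
    fwd::Dict{K,V}      # key => value
    rev::IdDict{V,K}    # value (compared by object identity) => key
end

# Convenience constructor for an empty map.
TinyBijection{K,V}() where {K,V} = TinyBijection{K,V}(Dict{K,V}(), IdDict{V,K}())

# Insert a pair, refusing anything that would break the one-to-one property.
function link!(b::TinyBijection, k, v)
    haskey(b.fwd, k) && throw(ArgumentError("key already present"))
    haskey(b.rev, v) && throw(ArgumentError("value already present"))
    b.fwd[k] = v
    b.rev[v] = k
    return b
end

forward(b::TinyBijection, k) = b.fwd[k]    # lookup by key
backward(b::TinyBijection, v) = b.rev[v]   # lookup by value identity

cache = TinyBijection{Symbol,Vector{Float64}}()
body = [1.0, 2.0]
link!(cache, :thunk_body, body)
```

With this sketch, `forward(cache, :thunk_body)` returns the very same array object, and `backward(cache, body)` recovers the key, so a single object answers both lookups that would otherwise need two separately maintained dictionaries.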

@mofeing mofeing requested a review from wsmoses June 2, 2025 21:31
@mofeing
Collaborator Author

mofeing commented Jun 2, 2025

Hmm, all these formatting comments from github-actions[bot] seem to come from code unrelated to this PR...

src/Tracing.jl Outdated
Comment on lines 219 to 252
+    K = dict_key(T)
     V = dict_value(T)
-    if V === nothing
+
+    K_traced = if !isnothing(K)
+        traced_type_inner(K, seen, mode, track_numbers, sharding, runtime)
+    else
+        nothing
+    end
+    V_traced = if !isnothing(V)
+        traced_type_inner(V, seen, mode, track_numbers, sharding, runtime)
+    else
+        nothing
+    end
+
+    if K == K_traced && V == V_traced
         return T
+    end
+
+    dictty = if T isa UnionAll
+        T.body.name.wrapper
     else
-        K = dict_key(T)
-        V2 = traced_type_inner(V, seen, mode, track_numbers, sharding, runtime)
-        if V == V2
-            return T
-        end
-        dictty = if T isa UnionAll
-            T.body.name.wrapper
-        else
-            T.name.wrapper
-        end
-        if K !== nothing
-            return dictty{K,V2}
-        else
-            return (dictty{KT,V2} where {KT})
-        end
+        T.name.wrapper
     end
+
+    if isnothing(K_traced) && isnothing(V_traced)
+        return (dictty{Kt,Vt} where {Kt,Vt})
+    elseif isnothing(K_traced)
+        return (dictty{Kt,V_traced} where {Kt})
+    elseif isnothing(V_traced)
+        return (dictty{K_traced,Vt} where {Vt})
+    else
+        return dictty{K_traced,V_traced}
+    end
 end
Collaborator Author

This changes the behavior of tracing for Dict such that the key type is also traced.
It is required for Bijection to work.

@wsmoses is that alright?
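
To make the question concrete, here is a small type-level sketch of what "also tracing the key type" means. It assumes a toy rule instead of Reactant's `traced_type_inner`: `FakeTraced`, `fake_trace`, and `fake_trace_dict` are hypothetical names used only for illustration. The point is that the inverse map of a Bijection holds the traced type in key position, so a rule that only rewrites the value type would leave it untouched.

```julia
# Hypothetical marker type standing in for a traced array type.
struct FakeTraced{T} end

# Toy type-level "tracing" rule: numbers become FakeTraced, anything else is unchanged.
fake_trace(::Type{T}) where {T<:Number} = FakeTraced{T}
fake_trace(::Type{T}) where {T} = T

# Trace both the key and the value type of a concrete dictionary type.
function fake_trace_dict(::Type{D}) where {K,V,D<:AbstractDict{K,V}}
    Kt = fake_trace(K)
    Vt = fake_trace(V)
    # Nothing changed: keep the original type.
    (K == Kt && V == Vt) && return D
    # Rebuild the same dictionary family with the traced parameters.
    return D.name.wrapper{Kt,Vt}
end

forward_type = fake_trace_dict(Dict{Symbol,Float64})    # value type is rewritten
inverse_type = fake_trace_dict(IdDict{Float64,Symbol})  # key type is rewritten
```

Under this toy rule the forward dictionary type changes in its value parameter and the inverse dictionary type changes in its key parameter, which is the case a value-only rule cannot express.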

Collaborator Author

@mofeing mofeing left a comment

@wsmoses I managed to get tracing working with a general AbstractDict. Instead of tracing collect(dict), which is a Vector{<:Pair}, it traces the Pairs iteratively while enumerating them. I also had to refactor traced_getfield(::AbstractDict) to accept an Integer, which is the iteration index.

When I introduced traced_getfield, it was meant to "fake" fields. Its behavior seems to have evolved, but I have the feeling that it doesn't fit this case perfectly: we are now passing an iteration index (and before this PR it should have been something like traced_getindex).
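
The iteration-index idea can be sketched in plain Julia as follows. This is only an illustration under assumptions: `nth_pair` is a hypothetical helper, not Reactant's `traced_getfield`, and the dictionary contents are made up.

```julia
# Hypothetical helper: fetch the i-th pair of a dictionary in iteration order.
function nth_pair(d::AbstractDict, i::Integer)
    for (j, kv) in enumerate(d)
        j == i && return kv
    end
    throw(BoundsError(d, i))
end

d = Dict(:a => [4.0], :b => [1.0])

# Record (index, key) while walking the pairs, without materializing collect(d).
seen_keys = [(i, first(kv)) for (i, kv) in enumerate(d)]

# Each recorded index leads back to the same pair as long as the
# dictionary has not been modified in between.
for (i, k) in seen_keys
    @assert first(nth_pair(d, i)) === k
end
```

An index recorded this way is only meaningful while the set of entries stays the same, which is one reason to be careful with functions that add or remove entries.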

Finally, I didn't manage to get it working with in-place mutations (i.e. adding a new field), so adding or removing entries might be problematic.

For example, this works

julia> function combine_ab_to_c!(d)
         d[:a] = d[:a] + d[:b]
         return d
       end

julia> @jit combine_ab_to_c!(d)
Bijection{Symbol, ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Dict{Symbol, ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, IdDict{ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Symbol}} with 2 entries:
  :a => ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}([5.0])
  :b => ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}([1.0])

but this doesn't work

julia> function combine_ab_to_c!(d)
         d[:c] = d[:a] + d[:b]
         return d
       end
combine_ab_to_c! (generic function with 1 method)

julia> @jit combine_ab_to_c!(d)
ERROR: ArgumentError: collection must be non-empty
Stacktrace:
 [1] first
   @ ./abstractarray.jl:473 [inlined]
 [2] traced_getfield
   @ ~/Developer/Reactant.jl/src/Compiler.jl:37 [inlined]
 [3] macro expansion
   @ ~/Developer/Reactant.jl/src/Compiler.jl:2914 [inlined]
 [4] (::Reactant.Compiler.Thunk{…})(args::Bijection{…})
   @ Reactant.Compiler ~/Developer/Reactant.jl/src/Compiler.jl:3473
 [5] top-level scope
   @ ~/Developer/Reactant.jl/src/Compiler.jl:2333
Some type information was truncated. Use `show(err)` to see complete types.
