diff --git a/Project.toml b/Project.toml index c2f3ddb4cd..8ad03d669c 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.2.118" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +Bijections = "e2ed5e7c-b2de-5872-ae92-c73ca462fb04" CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6" EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56" @@ -63,6 +64,7 @@ ReactantYaoBlocksExt = "YaoBlocks" AbstractFFTs = "1.5" Adapt = "4.1" ArrayInterface = "7.17.1" +Bijections = "0.2.1" CEnum = "0.5" CUDA = "5.6" Downloads = "1.6" diff --git a/src/Compiler.jl b/src/Compiler.jl index 243ca5a5d5..1914e4e684 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -3,6 +3,7 @@ module Compiler using Reactant_jll using Libdl: dlsym using LinearAlgebra: BLAS +using Bijections import ..Reactant: Reactant, @@ -3238,8 +3239,7 @@ function compile_xla(f, args; client=nothing, serializable::Bool=false, kwargs.. end # inspired by RuntimeGeneratedFunction.jl -const __thunk_fwd_body_cache = Dict{Symbol,Expr}() -const __thunk_rev_body_cache = Dict{Expr,Symbol}() +const __thunk_body_cache = Bijection{Symbol,Expr}() function compile(f, args; sync=false, kwargs...) _, exec, mlir_fn_res, device, client, str = compile_xla(f, args; kwargs...) @@ -3352,12 +3352,11 @@ function compile(f, args; sync=false, kwargs...) display(mlir_fn_res.donated_args_mask) end - fname = if body in keys(__thunk_rev_body_cache) - __thunk_rev_body_cache[body] + fname = if hasvalue(__thunk_body_cache, body) + __thunk_body_cache(body) else fname2 = gensym(Symbol(Symbol(f), :_reactant)) - __thunk_rev_body_cache[body] = fname2 - __thunk_fwd_body_cache[fname2] = body + __thunk_body_cache[fname2] = body fname2 end @@ -3439,7 +3438,7 @@ end ) end end - body = __thunk_fwd_body_cache[tag] + body = __thunk_body_cache[tag] if IsClosure return quote args = (thunk.f, args...)