We read every piece of feedback, and take your input very seriously.
For a list of all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account.
# Minimal reproducer: reverse-mode Enzyme gradient traced by Reactant, where the
# differentiated closure also passes updated RNG state out through a captured `Ref`.
# (`Test` is imported but not used in this snippet.)
using Enzyme, Reactant, Test, Random

# Forward pass: adds random values drawn from a copy of `st.rng` to `x`.
# Returns `(x .+ noise, (; rng))`, i.e. the output plus a new state NamedTuple holding
# the RNG copy that `rand!` advanced (rather than `st.rng` itself).
function simple_forward(x, st)
    rng = copy(st.rng)
    y = similar(x)
    rand!(rng, y)
    return x .+ y, (; rng)
end

# Returns a tuple: (Enzyme gradient of `sum(abs2, simple_forward(x, st)[1])`, updated state).
# `lfn` must return the scalar loss being differentiated, so the updated state is
# extracted through a side effect on the captured `stₙ` instead of the return value.
function gradient_fn(x, st)
    stₙ = Ref{Any}(nothing)
    function lfn(x, st_old)
        y, st_new = simple_forward(x, st_old)
        # NOTE(review): this stores a value produced while tracing the differentiated
        # function into a Ref that outlives it. In the HLO printed for this call, `main`
        # uses `%21#0`, which is defined only inside the `..._autodiff` function, so the
        # traced RNG state escapes its function scope. This line triggers the issue.
        stₙ[] = st_new
        return sum(abs2, y)
    end
    # `st` is wrapped in `Const`, so it is treated as non-differentiable; only `x` is active.
    return Enzyme.gradient(Reverse, lfn, x, Const(st)), stₙ[]
end

x = Reactant.to_rarray(rand(2, 2))
st = (; rng=Reactant.ConcreteRNG())
# Print the optimized StableHLO module that Reactant traces for this call.
@code_hlo optimize = true gradient_fn(x, st)
"builtin.module"() <{sym_name = "reactant_gradien..."}> ({ "func.func"() <{function_type = (tensor<f64>, tensor<f64>) -> (tensor<f64>, tensor<f64>, tensor<f64>), sym_name = "+_broadcast_scalar", sym_visibility = "private"}> ({ ^bb0(%arg7: tensor<f64>, %arg8: tensor<f64>): %45 = "stablehlo.add"(%arg7, %arg8) : (tensor<f64>, tensor<f64>) -> tensor<f64> "func.return"(%45, %arg7, %arg8) : (tensor<f64>, tensor<f64>, tensor<f64>) -> () }) : () -> () "func.func"() <{function_type = (tensor<f64>) -> (tensor<f64>, tensor<f64>), sym_name = "abs2_broadcast_scalar", sym_visibility = "private"}> ({ ^bb0(%arg6: tensor<f64>): %43 = "stablehlo.abs"(%arg6) : (tensor<f64>) -> tensor<f64> %44 = "stablehlo.multiply"(%43, %43) : (tensor<f64>, tensor<f64>) -> tensor<f64> "func.return"(%44, %arg6) : (tensor<f64>, tensor<f64>) -> () }) : () -> () "func.func"() <{function_type = (tensor<2x2xf64>, tensor<2xui64>) -> (tensor<f64>, tensor<2xui64>, tensor<2x2xf64>), sym_name = "Const{var\22#lfn#5\22{Base.RefValue{Any}}}(var\22#lfn#5\22{Base.RefValue{Any}}(Base.RefValue{Any}(nothing)))_autodiff", sym_visibility = "private"}> ({ ^bb0(%arg2: tensor<2x2xf64>, %arg3: tensor<2xui64>): %18 = "stablehlo.transpose"(%arg2) <{permutation = array<i64: 1, 0>}> : (tensor<2x2xf64>) -> tensor<2x2xf64> %19 = "stablehlo.transpose"(%arg3) <{permutation = array<i64: 0>}> : (tensor<2xui64>) -> tensor<2xui64> %20 = "stablehlo.constant"() <{value = dense<0.000000e+00> : tensor<2x2xf64>}> : () -> tensor<2x2xf64> %21:2 = "stablehlo.rng_bit_generator"(%19) <{rng_algorithm = #stablehlo<rng_algorithm DEFAULT>}> : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>) %22 = "stablehlo.constant"() <{value = dense<12> : tensor<2x2xui64>}> : () -> tensor<2x2xui64> %23 = "stablehlo.shift_right_logical"(%21#1, %22) : (tensor<2x2xui64>, tensor<2x2xui64>) -> tensor<2x2xui64> %24 = "stablehlo.constant"() <{value = dense<4607182418800017408> : tensor<2x2xui64>}> : () -> tensor<2x2xui64> %25 = "stablehlo.or"(%23, %24) : 
(tensor<2x2xui64>, tensor<2x2xui64>) -> tensor<2x2xui64> %26 = "stablehlo.bitcast_convert"(%25) : (tensor<2x2xui64>) -> tensor<2x2xf64> %27 = "stablehlo.constant"() <{value = dense<1.000000e+00> : tensor<2x2xf64>}> : () -> tensor<2x2xf64> %28 = "stablehlo.subtract"(%26, %27) : (tensor<2x2xf64>, tensor<2x2xf64>) -> tensor<2x2xf64> %29 = "stablehlo.constant"() <{value = dense<0.000000e+00> : tensor<2x2xf64>}> : () -> tensor<2x2xf64> %30 = "stablehlo.broadcast_in_dim"(%18) <{broadcast_dimensions = array<i64: 0, 1>}> : (tensor<2x2xf64>) -> tensor<2x2xf64> %31 = "stablehlo.broadcast_in_dim"(%28) <{broadcast_dimensions = array<i64: 0, 1>}> : (tensor<2x2xf64>) -> tensor<2x2xf64> %32:3 = "enzyme.batch"(%30, %31) <{batch_shape = array<i64: 2, 2>, fn = @"+_broadcast_scalar"}> : (tensor<2x2xf64>, tensor<2x2xf64>) -> (tensor<2x2xf64>, tensor<2x2xf64>, tensor<2x2xf64>) %33 = "stablehlo.convert"(%32#0) : (tensor<2x2xf64>) -> tensor<2x2xf64> %34 = "stablehlo.constant"() <{value = dense<0.000000e+00> : tensor<f64>}> : () -> tensor<f64> %35 = "stablehlo.constant"() <{value = dense<0.000000e+00> : tensor<2x2xf64>}> : () -> tensor<2x2xf64> %36 = "stablehlo.broadcast_in_dim"(%33) <{broadcast_dimensions = array<i64: 0, 1>}> : (tensor<2x2xf64>) -> tensor<2x2xf64> %37:2 = "enzyme.batch"(%36) <{batch_shape = array<i64: 2, 2>, fn = @abs2_broadcast_scalar}> : (tensor<2x2xf64>) -> (tensor<2x2xf64>, tensor<2x2xf64>) %38 = "stablehlo.convert"(%37#0) : (tensor<2x2xf64>) -> tensor<2x2xf64> %39 = "stablehlo.reduce"(%38, %34) <{dimensions = array<i64: 0, 1>}> ({ ^bb0(%arg4: tensor<f64>, %arg5: tensor<f64>): %42 = "stablehlo.add"(%arg4, %arg5) : (tensor<f64>, tensor<f64>) -> tensor<f64> "stablehlo.return"(%42) : (tensor<f64>) -> () }) : (tensor<2x2xf64>, tensor<f64>) -> tensor<f64> %40 = "stablehlo.transpose"(%21#0) <{permutation = array<i64: 0>}> : (tensor<2xui64>) -> tensor<2xui64> %41 = "stablehlo.transpose"(%18) <{permutation = array<i64: 1, 0>}> : (tensor<2x2xf64>) -> tensor<2x2xf64> 
"func.return"(%39, %40, %41) : (tensor<f64>, tensor<2xui64>, tensor<2x2xf64>) -> () }) : () -> () "func.func"() <{arg_attrs = [{tf.aliasing_output = 2 : i32}, {tf.aliasing_output = 3 : i32}], function_type = (tensor<2x2xf64>, tensor<2xui64>) -> (tensor<2x2xf64>, tensor<2xui64>, tensor<2x2xf64>, tensor<2xui64>), sym_name = "main"}> ({ ^bb0(%arg0: tensor<2x2xf64>, %arg1: tensor<2xui64>): %0 = "stablehlo.transpose"(%arg0) <{permutation = array<i64: 1, 0>}> : (tensor<2x2xf64>) -> tensor<2x2xf64> %1 = "stablehlo.transpose"(%arg1) <{permutation = array<i64: 0>}> : (tensor<2xui64>) -> tensor<2xui64> %2 = "stablehlo.constant"() <{value = dense<0.000000e+00> : tensor<f64>}> : () -> tensor<f64> %3 = "stablehlo.constant"() <{value = dense<0.000000e+00> : tensor<2x2xf64>}> : () -> tensor<2x2xf64> %4 = "stablehlo.constant"() <{value = dense<0.000000e+00> : tensor<f64>}> : () -> tensor<f64> %5 = "stablehlo.broadcast_in_dim"(%4) <{broadcast_dimensions = array<i64>}> : (tensor<f64>) -> tensor<2x2xf64> %6 = "stablehlo.constant"() <{value = dense<1.000000e+00> : tensor<f64>}> : () -> tensor<f64> %7 = "stablehlo.transpose"(%0) <{permutation = array<i64: 1, 0>}> : (tensor<2x2xf64>) -> tensor<2x2xf64> %8 = "stablehlo.transpose"(%1) <{permutation = array<i64: 0>}> : (tensor<2xui64>) -> tensor<2xui64> %9 = "stablehlo.transpose"(%5) <{permutation = array<i64: 1, 0>}> : (tensor<2x2xf64>) -> tensor<2x2xf64> %10:3 = "enzyme.autodiff"(%7, %8, %6, %9) <{activity = [#enzyme<activity enzyme_active>, #enzyme<activity enzyme_const>], fn = @"Const{var\22#lfn#5\22{Base.RefValue{Any}}}(var\22#lfn#5\22{Base.RefValue{Any}}(Base.RefValue{Any}(nothing)))_autodiff", ret_activity = [#enzyme<activity enzyme_activenoneed>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_active>], width = 1 : i64}> : (tensor<2x2xf64>, tensor<2xui64>, tensor<f64>, tensor<2x2xf64>) -> (tensor<2xui64>, tensor<2x2xf64>, tensor<2x2xf64>) %11 = "stablehlo.transpose"(%10#0) <{permutation = array<i64: 0>}> : (tensor<2xui64>) 
-> tensor<2xui64> %12 = "stablehlo.transpose"(%10#1) <{permutation = array<i64: 1, 0>}> : (tensor<2x2xf64>) -> tensor<2x2xf64> %13 = "stablehlo.transpose"(%10#2) <{permutation = array<i64: 1, 0>}> : (tensor<2x2xf64>) -> tensor<2x2xf64> %14 = "stablehlo.transpose"(%13) <{permutation = array<i64: 1, 0>}> : (tensor<2x2xf64>) -> tensor<2x2xf64> %15 = "stablehlo.transpose"(%21#0) <{permutation = array<i64: 0>}> : (tensor<2xui64>) -> tensor<2xui64> %16 = "stablehlo.transpose"(%12) <{permutation = array<i64: 1, 0>}> : (tensor<2x2xf64>) -> tensor<2x2xf64> %17 = "stablehlo.transpose"(%11) <{permutation = array<i64: 0>}> : (tensor<2xui64>) -> tensor<2xui64> "func.return"(%14, %15, %16, %17) : (tensor<2x2xf64>, tensor<2xui64>, tensor<2x2xf64>, tensor<2xui64>) -> () }) : () -> () }) {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} : () -> ()
Cross-reference: LuxDL/Lux.jl#1337
The text was updated successfully, but these errors were encountered:
Successfully merging a pull request may close this issue.
Cross-reference: LuxDL/Lux.jl#1337
The text was updated successfully, but these errors were encountered: