Skip to content

Commit d4e7c76

Browse files
CUDA kernels take 3 (#427)
* CUDA take 3
* conditional run cuda
* Update test/integration/cuda.jl
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* bump enzymexla
* fix
* fix gpu reg
* Update BUILD
* Update BUILD
* Update Project.toml
* Update ReactantCUDAExt.jl
* Apply suggestions from code review
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* fix reactant method blocker
* Update ReactantCUDAExt.jl
* only do compile
* use names in cache
* Apply suggestions from code review
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* cleanup further gc issues
* Apply suggestions from code review
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* fix
---------
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent b0a58bd commit d4e7c76

File tree

11 files changed

+88
-85
lines changed

11 files changed

+88
-85
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ PythonCall = "0.9"
6060
Random = "1.10"
6161
Random123 = "1.7"
6262
ReactantCore = "0.1.3"
63-
Reactant_jll = "0.0.32"
63+
Reactant_jll = "0.0.33"
6464
Scratch = "1.2"
6565
Statistics = "1.10"
6666
YaoBlocks = "0.13"

deps/ReactantExtra/.bazelrc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ build -c opt
1818
build:cuda --repo_env TF_NEED_CUDA=1
1919
build:cuda --repo_env TF_NVCC_CLANG=1
2020
build:cuda --repo_env TF_NCCL_USE_STUB=1
21-
build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2"
22-
build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1"
21+
build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.6.2"
22+
build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.6.0"
2323
# "sm" means we emit only cubin, which is forward compatible within a GPU generation.
2424
# "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations.
2525
build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90"

deps/ReactantExtra/BUILD

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ cc_toolchain_config(
5454
coverage_link_flags = ["--coverage"],
5555
cpu = "k8",
5656
cxx_builtin_include_directories = [
57+
"/opt/x86_64-linux-musl/x86_64-linux-musl/include/c++/13.2.0",
58+
"/opt/x86_64-linux-musl/x86_64-linux-musl/include/c++/13.2.0/x86_64-linux-musl",
59+
"/opt/x86_64-linux-musl/x86_64-linux-musl/include/c++/13.2.0/backward",
5760
"/opt/x86_64-linux-musl/x86_64-linux-musl/include/c++/10.2.0",
5861
"/opt/x86_64-linux-musl/x86_64-linux-musl/include/c++/10.2.0/x86_64-linux-musl",
5962
"/opt/x86_64-linux-musl/x86_64-linux-musl/include/c++/10.2.0/backward",
@@ -149,14 +152,14 @@ cc_toolchain_config(
149152
abi_libc_version = "local",
150153
abi_version = "local",
151154
cxx_builtin_include_directories = [
152-
"/opt/BB_TARGET/lib/gcc/BB_TARGET/10.2.0/include",
153-
"/opt/BB_TARGET/lib/gcc/BB_TARGET/10.2.0/include-fixed",
155+
"/opt/BB_TARGET/lib/gcc/BB_TARGET/13.2.0/include",
156+
"/opt/BB_TARGET/lib/gcc/BB_TARGET/13.2.0/include-fixed",
154157
"/opt/BB_TARGET/BB_TARGET/include",
155158
"/opt/BB_TARGET/BB_TARGET/sys-root/usr/include",
156-
"/opt/BB_TARGET/BB_TARGET/include/c++/10.2.0",
157-
"/opt/BB_TARGET/BB_TARGET/include/c++/10.2.0/BB_TARGET",
158-
"/opt/BB_TARGET/BB_TARGET/include/c++/10.2.0/backward",
159-
"/opt/BB_TARGET/BB_TARGET/include/c++/10.2.0/parallel"
159+
"/opt/BB_TARGET/BB_TARGET/include/c++/13.2.0",
160+
"/opt/BB_TARGET/BB_TARGET/include/c++/13.2.0/BB_TARGET",
161+
"/opt/BB_TARGET/BB_TARGET/include/c++/13.2.0/backward",
162+
"/opt/BB_TARGET/BB_TARGET/include/c++/13.2.0/parallel"
160163
],
161164
tool_paths = {
162165
"ar": "/opt/bin/BB_FULL_TARGET/ar",
@@ -193,14 +196,14 @@ cc_toolchain_config(
193196
"-Wno-free-nonheap-object",
194197
"-fno-omit-frame-pointer",
195198
# TODO cxx_builtin_include_directories doesn't seem to be working, so we add the INCLUDE_PATHs manually
196-
"-isystem /opt/BB_TARGET/lib/gcc/BB_TARGET/10.2.0/include",
197-
"-isystem /opt/BB_TARGET/lib/gcc/BB_TARGET/10.2.0/include-fixed",
199+
"-isystem /opt/BB_TARGET/lib/gcc/BB_TARGET/13.2.0/include",
200+
"-isystem /opt/BB_TARGET/lib/gcc/BB_TARGET/13.2.0/include-fixed",
198201
"-isystem /opt/BB_TARGET/BB_TARGET/include",
199202
"-isystem /opt/BB_TARGET/BB_TARGET/sys-root/usr/include",
200-
"-isystem /opt/BB_TARGET/BB_TARGET/include/c++/10.2.0",
201-
"-isystem /opt/BB_TARGET/BB_TARGET/include/c++/10.2.0/BB_TARGET",
202-
"-isystem /opt/BB_TARGET/BB_TARGET/include/c++/10.2.0/backward",
203-
"-isystem /opt/BB_TARGET/BB_TARGET/include/c++/10.2.0/parallel",
203+
"-isystem /opt/BB_TARGET/BB_TARGET/include/c++/13.2.0",
204+
"-isystem /opt/BB_TARGET/BB_TARGET/include/c++/13.2.0/BB_TARGET",
205+
"-isystem /opt/BB_TARGET/BB_TARGET/include/c++/13.2.0/backward",
206+
"-isystem /opt/BB_TARGET/BB_TARGET/include/c++/13.2.0/parallel",
204207
],
205208
opt_compile_flags = [
206209
"-g0",
@@ -361,6 +364,7 @@ cc_library(
361364

362365
) + [
363366
"@enzyme_ad//src/enzyme_ad/jax:RegistryUtils.cpp",
367+
"@enzyme_ad//src/enzyme_ad/jax:gpu.cc",
364368
# "@com_google_protobuf//:src/google/protobuf/io/coded_stream.cc",
365369
# "@xla//xla:xla.pb.cc",
366370
"@xla//xla:xla_data.pb.cc",
@@ -429,7 +433,7 @@ cc_library(
429433
"-Wl,-exported_symbol,_ifrt_*",
430434
"-Wl,-exported_symbol,_RegisterCustomCallTarget",
431435
"-Wl,-exported_symbol,_ConvertLLVMToMLIR",
432-
"-Wl,-exported_symbol,_EnzymeGPUCustomCall",
436+
"-Wl,-exported_symbol,_RegisterEnzymeXLAGPUHandler",
433437
"-Wl,-exported_symbol,_ReactantThrowError",
434438
]}),
435439
deps = [
@@ -469,6 +473,9 @@ cc_library(
469473
"@llvm-project//llvm:X86CodeGen",
470474
"@enzyme_ad//src/enzyme_ad/jax:TransformOps",
471475
"@enzyme_ad//src/enzyme_ad/jax:XLADerivatives",
476+
# "@enzyme_ad//src/enzyme_ad/jax:gpu",
477+
"@xla//xla/ffi/api:ffi",
478+
"@xla//xla/ffi:ffi_api",
472479
"@stablehlo//:chlo_ops",
473480
"@xla//xla/pjrt:pjrt_api",
474481
"@xla//xla/pjrt:pjrt_c_api_client",

deps/ReactantExtra/WORKSPACE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ http_archive(
99
urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
1010
)
1111

12-
ENZYMEXLA_COMMIT = "b6d6563aa3a3050474a4250bf18322f7ebf0b486"
12+
ENZYMEXLA_COMMIT = "74046d05089c02946058f8fd94ed23efd0bf3ccc"
1313
ENZYMEXLA_SHA256 = ""
1414

1515
http_archive(

deps/build_local.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ else
9393
run(
9494
Cmd(
9595
`bazel build $(arg) -c $(build_kind) --action_env=JULIA=$(Base.julia_cmd().exec[1])
96+
--repo_env=GCC_HOST_COMPILER_PATH=/usr/bin/gcc
97+
--repo_env=CC=/home/wmoses/llvms/llvm16-r/clang+llvm-16.0.2-x86_64-linux-gnu-ubuntu-22.04/bin/clang
9698
--repo_env HERMETIC_PYTHON_VERSION="3.10"
9799
--check_visibility=false --verbose_failures :libReactantExtra.so`;
98100
dir=source_dir,

ext/ReactantCUDAExt.jl

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -218,11 +218,11 @@ function Adapt.adapt_storage(::CUDA.KernelAdaptor, xs::TracedRArray{T,N}) where
218218
return res
219219
end
220220

221-
const _kernel_instances = Dict{Any,Any}()
222-
221+
# Since we cache these objects we cannot cache data containing MLIR operations (e.g. the entry must be a string
222+
# and not the operation itself).
223223
struct LLVMFunc{F,tt}
224224
f::Union{F,Nothing}
225-
entry::MLIR.IR.Operation
225+
entry::String
226226
end
227227

228228
const GPUCompiler = CUDA.GPUCompiler
@@ -324,9 +324,9 @@ function compile(job)
324324
)::MLIR.API.MlirOperation
325325

326326
entry = MLIR.IR.Operation(linkRes)
327-
328-
entry
327+
String(Reactant.TracedUtils.get_attribute_by_name(linkRes, "sym_name"))
329328
end
329+
330330
return LLVMFunc{job.source.specTypes.parameters[1],job.source.specTypes}(nothing, entry)
331331
end
332332

@@ -378,9 +378,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
378378

379379
output_operand_aliases = MLIR.IR.Attribute(aliases)
380380

381-
fname = Reactant.TracedUtils.get_attribute_by_name(func.entry, "sym_name")
382-
# Force public for now while we don't have real users
383-
# MLIR.IR.rmattr!(func.entry, "sym_visibility")
381+
fname = func.entry
384382

385383
operands = MLIR.IR.Value[]
386384
for idx in
@@ -460,25 +458,27 @@ Reactant.@reactant_overlay @noinline function CUDA.cufunction(
460458
end
461459

462460
function __init__()
463-
handle = Reactant.XLA.Libdl.dlopen(CUDA.CUDA_Driver_jll.libcuda; throw_error=false)
464-
if handle === nothing
465-
handle = C_NULL
466-
end
467-
ptr1 = Reactant.XLA.Libdl.dlsym(handle, "cuLaunchKernel"; throw_error=false)
468-
if ptr1 === nothing
469-
ptr1 = C_NULL
470-
end
471-
ptr2 = Reactant.XLA.Libdl.dlsym(handle, "cuModuleLoadData"; throw_error=false)
472-
if ptr2 === nothing
473-
ptr2 = C_NULL
474-
end
475-
ptr3 = Reactant.XLA.Libdl.dlsym(handle, "cuModuleGetFunction"; throw_error=false)
476-
if ptr3 === nothing
477-
ptr3 = C_NULL
461+
if CUDA.CUDA_Driver_jll.libcuda !== nothing
462+
handle = Reactant.XLA.Libdl.dlopen(CUDA.CUDA_Driver_jll.libcuda; throw_error=false)
463+
if handle === nothing
464+
handle = C_NULL
465+
end
466+
ptr1 = Reactant.XLA.Libdl.dlsym(handle, "cuLaunchKernel"; throw_error=false)
467+
if ptr1 === nothing
468+
ptr1 = C_NULL
469+
end
470+
ptr2 = Reactant.XLA.Libdl.dlsym(handle, "cuModuleLoadData"; throw_error=false)
471+
if ptr2 === nothing
472+
ptr2 = C_NULL
473+
end
474+
ptr3 = Reactant.XLA.Libdl.dlsym(handle, "cuModuleGetFunction"; throw_error=false)
475+
if ptr3 === nothing
476+
ptr3 = C_NULL
477+
end
478+
Reactant.Compiler.cuLaunch[] = Base.reinterpret(UInt, ptr1)
479+
Reactant.Compiler.cuModule[] = Base.reinterpret(UInt, ptr2)
480+
Reactant.Compiler.cuFunc[] = Base.reinterpret(UInt, ptr3)
478481
end
479-
Reactant.Compiler.cuLaunch[] = Base.reinterpret(UInt, ptr1)
480-
Reactant.Compiler.cuModule[] = Base.reinterpret(UInt, ptr2)
481-
Reactant.Compiler.cuFunc[] = Base.reinterpret(UInt, ptr3)
482482
return nothing
483483
end
484484

src/XLA.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -144,11 +144,7 @@ function __init__()
144144
end
145145
end
146146

147-
@ccall MLIR.API.mlir_c.RegisterCustomCallTarget(
148-
"enzymexla_gpu"::Cstring,
149-
cglobal((:EnzymeGPUCustomCall, MLIR.API.mlir_c))::Ptr{Cvoid},
150-
"CUDA"::Cstring,
151-
)::Cvoid
147+
@ccall MLIR.API.mlir_c.RegisterEnzymeXLAGPUHandler()::Cvoid
152148

153149
# This wasn't properly exported on macos, we'll remove the try once macOS JLL
154150
# has the fix.

src/utils.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,10 @@ function call_with_reactant_generator(
364364
ir, rt = CC.typeinf_ircode(interp, mi, nothing)
365365
end
366366

367-
ir, any_changed = rewrite_insts!(ir, interp)
367+
if !is_reactant_method(mi::Core.MethodInstance)
368+
ir, any_changed = rewrite_insts!(ir, interp)
369+
end
370+
368371
src = ccall(:jl_new_code_info_uninit, Ref{CC.CodeInfo}, ())
369372
src.slotnames = fill(:none, length(ir.argtypes) + 1)
370373
src.slotflags = fill(zero(UInt8), length(ir.argtypes))

test/cuda.jl

Lines changed: 0 additions & 36 deletions
This file was deleted.

test/integration/cuda.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
using Reactant
2+
using Test
3+
using CUDA
4+
5+
function square_kernel!(x, y)
6+
i = threadIdx().x
7+
x[i] *= y[i]
8+
sync_threads()
9+
return nothing
10+
end
11+
12+
# basic squaring on GPU
13+
function square!(x, y)
14+
@cuda blocks = 1 threads = length(x) square_kernel!(x, y)
15+
return nothing
16+
end
17+
18+
@testset "Square Kernel" begin
19+
oA = collect(1:1:64)
20+
A = Reactant.to_rarray(oA)
21+
B = Reactant.to_rarray(100 .* oA)
22+
if CUDA.functional()
23+
@jit square!(A, B)
24+
@test all(Array(A) .≈ (oA .* oA .* 100))
25+
@test all(Array(B) .≈ (oA .* 100))
26+
else
27+
@code_hlo optimize = :before_kernel square!(A, B)
28+
end
29+
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
5959
end
6060

6161
if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration"
62+
# Temporarily disabled as minutiae are debugged
63+
# @safetestset "CUDA" include("integration/cuda.jl")
6264
@safetestset "Linear Algebra" include("integration/linear_algebra.jl")
6365
@safetestset "AbstractFFTs" include("integration/fft.jl")
6466
@safetestset "Random" include("integration/random.jl")

0 commit comments

Comments
 (0)