diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml
index 577e262bf2..0c64328623 100644
--- a/.buildkite/pipeline.yml
+++ b/.buildkite/pipeline.yml
@@ -1,5 +1,5 @@
 steps:
-  - label: "GPU integration with julia v1.6"
+  - label: "GPU on julia v1.6"
     plugins:
       - JuliaCI/julia#v1:
          # Drop default "registries" directory, so it is not persisted from execution to execution
@@ -12,7 +12,7 @@ steps:
       cuda: "*"
     timeout_in_minutes: 60
 
-  - label: "GPU integration with julia v1"
+  - label: "GPU on julia v1"
     plugins:
       - JuliaCI/julia#v1:
           version: "1"
@@ -26,14 +26,16 @@ steps:
       JULIA_CUDA_USE_BINARYBUILDER: "true"
     timeout_in_minutes: 60
 
-  # - label: "GPU nightly"
-  #   plugins:
-  #     - JuliaCI/julia#v1:
-  #         version: "nightly"
-  #     - JuliaCI/julia-test#v1: ~
-  #   agents:
-  #     queue: "juliagpu"
-  #     cuda: "*"
-  #   timeout_in_minutes: 60
+  - label: "GPU on julia nightly"
+    plugins:
+      - JuliaCI/julia#v1:
+          version: "nightly"
+      - JuliaCI/julia-test#v1: ~
+    agents:
+      queue: "juliagpu"
+      cuda: "*"
+    env:
+      JULIA_CUDA_USE_BINARYBUILDER: "true"
+    timeout_in_minutes: 60
 env:
   SECRET_CODECOV_TOKEN: "fAV/xwuaV0l5oaIYSAXRQIor8h7yHdlrpLUZFwNVnchn7rDk9UZoz0oORG9vlKLc1GK2HhaPRAy+fTkJ3GM/8Y0phHh3ANK8f5UsGm2DUTNsnf6u9izgnwnoRTcsWu+vSO0fyYrxBvBCoJwljL+yZbDFz3oE16DP7HPIzxfQagm+o/kMEszVuoUXhuLXXH0LxT6pXl214qjqs04HfMRmKIIiup48NB6fBLdhGlQz64MdMNHBfgDa/fafB7eNvn0X6pEOxysoy6bDQLUhKelOXgcDx1UsTo34Yiqr+QeJPAeKcO//PWurwQhPoUoHfLad2da9DN4uQk4YQLqAlcIuAA==;U2FsdGVkX1+mRXF2c9soCXT7DYymY3msM+vrpaifiTp8xA+gMpbQ0G63WY3tJ+6V/fJcVnxYoKZVXbjcg8fl4Q=="
diff --git a/Project.toml b/Project.toml
index 0b02d1b583..feefd026b8 100644
--- a/Project.toml
+++ b/Project.toml
@@ -23,6 +23,10 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
+[weakdeps]
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
+
 [compat]
 Adapt = "3.0"
 CUDA = "3"
@@ -41,13 +45,18 @@ StatsBase = "0.33"
 Zygote = "0.6.49"
 julia = "1.6"
 
+[extensions]
+CUDAExt = ["CUDA", "NNlibCUDA"]
+
 [extras]
 ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"]
+test = ["Test", "CUDA", "NNlibCUDA", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"]
diff --git a/ext/CUDAExt/CUDAExt.jl b/ext/CUDAExt/CUDAExt.jl
new file mode 100644
index 0000000000..e9a4efd0b0
--- /dev/null
+++ b/ext/CUDAExt/CUDAExt.jl
@@ -0,0 +1,27 @@
+module CUDAExt
+
+using CUDA
+import NNlib, NNlibCUDA
+
+using Flux
+import Flux: adapt_storage, _gpu, FluxCPUAdaptor, _isleaf, dropout_mask, _dropout_mask
+
+using Adapt
+using ChainRulesCore
+using Random
+using Zygote
+
+const use_cuda = Ref{Union{Nothing,Bool}}(nothing)
+
+include("utils.jl")
+include("functor.jl")
+
+include("layers/normalise.jl")
+
+include("cudnn.jl")
+
+function __init__()
+  Flux.cuda_loaded[] = true
+end
+
+end # module
diff --git a/src/cuda/cudnn.jl b/ext/CUDAExt/cudnn.jl
similarity index 100%
rename from src/cuda/cudnn.jl
rename to ext/CUDAExt/cudnn.jl
diff --git a/ext/CUDAExt/functor.jl b/ext/CUDAExt/functor.jl
new file mode 100644
index 0000000000..0e270a2dce
--- /dev/null
+++ b/ext/CUDAExt/functor.jl
@@ -0,0 +1,46 @@
+struct FluxCUDAAdaptor end
+adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x)
+adapt_storage(to::FluxCUDAAdaptor, x::Zygote.FillArrays.AbstractFill) = CUDA.cu(collect(x))
+if VERSION >= v"1.7"
+  adapt_storage(to::FluxCUDAAdaptor, x::Random.TaskLocalRNG) = CUDA.default_rng()
+else
+  adapt_storage(to::FluxCUDAAdaptor, x::Random._GLOBAL_RNG) = CUDA.default_rng()
+end
+adapt_storage(to::FluxCUDAAdaptor, x::CUDA.RNG) = x
+adapt_storage(to::FluxCUDAAdaptor, x::AbstractRNG) =
+  error("Cannot map RNG of type $(typeof(x)) to GPU. GPU execution only supports Random.default_rng().")
+
+# TODO: figure out the correct design for OneElement
+adapt_storage(to::FluxCUDAAdaptor, x::Zygote.OneElement) = CUDA.cu(collect(x))
+
+
+adapt_storage(to::FluxCPUAdaptor, x::T) where T <: CUDA.CUSPARSE.CUDA.CUSPARSE.AbstractCuSparseMatrix = adapt(Array, x)
+adapt_storage(to::FluxCPUAdaptor, x::CUDA.RNG) = Random.default_rng()
+
+function ChainRulesCore.rrule(::Type{Array}, x::CUDA.CuArray)
+  Array(x), dx -> (NoTangent(), CUDA.cu(unthunk(dx)),)
+end
+
+function ChainRulesCore.rrule(::typeof(Adapt.adapt_storage), to::FluxCPUAdaptor, x::CUDA.AbstractGPUArray)
+  adapt_storage(to, x), dx -> (NoTangent(), NoTangent(), adapt_storage(FluxCUDAAdaptor(), unthunk(dx)),)
+end
+
+
+function _gpu(x)
+  check_use_cuda()
+  use_cuda[] ? fmap(x -> Adapt.adapt(FluxCUDAAdaptor(), x), x; exclude = _isleaf) : x
+end
+
+function check_use_cuda()
+  if use_cuda[] === nothing
+    use_cuda[] = CUDA.functional()
+    if use_cuda[] && !CUDA.has_cudnn()
+      @warn "CUDA.jl found cuda, but did not find libcudnn. Some functionality will not be available."
+    end
+    if !(use_cuda[])
+      @info """The GPU function is being called but the GPU is not accessible.
+      Defaulting back to the CPU. (No action is required if you want to run on the CPU).""" maxlog=1
+    end
+  end
+end
+ChainRulesCore.@non_differentiable check_use_cuda()
diff --git a/ext/CUDAExt/layers/normalise.jl b/ext/CUDAExt/layers/normalise.jl
new file mode 100644
index 0000000000..37283d06c2
--- /dev/null
+++ b/ext/CUDAExt/layers/normalise.jl
@@ -0,0 +1,3 @@
+dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
+dropout_mask(rng, x::CuArray, p; kwargs...) =
+  throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only supports CUDA.RNG for CuArrays."))
diff --git a/ext/CUDAExt/utils.jl b/ext/CUDAExt/utils.jl
new file mode 100644
index 0000000000..1553e0f88d
--- /dev/null
+++ b/ext/CUDAExt/utils.jl
@@ -0,0 +1 @@
+Flux.rng_from_array(::CuArray) = CUDA.default_rng()
diff --git a/src/Flux.jl b/src/Flux.jl
index 66796491dd..0d8bf00489 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -39,9 +39,6 @@ include("train.jl")
 using .Train
 # using .Train: setup, @train_autodiff
 
-using CUDA
-const use_cuda = Ref{Union{Nothing,Bool}}(nothing)
-
 using Adapt, Functors, OneHotArrays
 include("utils.jl")
 include("functor.jl")
@@ -67,6 +64,9 @@ using .Losses # TODO: stop importing Losses in Flux's namespace in v0.12
 
 include("deprecations.jl")
-include("cuda/cuda.jl")
+# If package extensions are not supported in this Julia version
+if !isdefined(Base, :get_extension)
+  include("../ext/CUDAExt/CUDAExt.jl")
+end
 
 end # module
diff --git a/src/cuda/cuda.jl b/src/cuda/cuda.jl
deleted file mode 100644
index 6e18a066af..0000000000
--- a/src/cuda/cuda.jl
+++ /dev/null
@@ -1,11 +0,0 @@
-module CUDAint
-
-using ..CUDA
-
-import ..Flux: Flux
-using ChainRulesCore
-import NNlib, NNlibCUDA
-
-include("cudnn.jl")
-
-end
diff --git a/src/functor.jl b/src/functor.jl
index 986574e33e..31419ec66a 100644
--- a/src/functor.jl
+++ b/src/functor.jl
@@ -90,45 +90,20 @@ end
 # Allows caching of the parameters when params is called within gradient() to fix #2040.
 # @non_differentiable params(m...) # https://github.com/FluxML/Flux.jl/pull/2054
-# That speeds up implicit use, and silently breaks explicit use. 
+# That speeds up implicit use, and silently breaks explicit use.
 # From @macroexpand Zygote.@nograd params(m...) and https://github.com/FluxML/Zygote.jl/pull/1248
 Zygote._pullback(::Zygote.Context{true}, ::typeof(params), m...) = params(m), _ -> nothing
 
-struct FluxCUDAAdaptor end
-adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x)
-adapt_storage(to::FluxCUDAAdaptor, x::Zygote.FillArrays.AbstractFill) = CUDA.cu(collect(x))
-if VERSION >= v"1.7"
-  adapt_storage(to::FluxCUDAAdaptor, x::Random.TaskLocalRNG) = CUDA.default_rng()
-else
-  adapt_storage(to::FluxCUDAAdaptor, x::Random._GLOBAL_RNG) = CUDA.default_rng()
-end
-adapt_storage(to::FluxCUDAAdaptor, x::CUDA.RNG) = x
-adapt_storage(to::FluxCUDAAdaptor, x::AbstractRNG) =
-  error("Cannot map RNG of type $(typeof(x)) to GPU. GPU execution only supports Random.default_rng().")
-
-# TODO: figure out the correct design for OneElement
-adapt_storage(to::FluxCUDAAdaptor, x::Zygote.OneElement) = CUDA.cu(collect(x))
-
 struct FluxCPUAdaptor end
 
 # define rules for handling structured arrays
 adapt_storage(to::FluxCPUAdaptor, x::AbstractArray) = adapt(Array, x)
 adapt_storage(to::FluxCPUAdaptor, x::AbstractRange) = x
 adapt_storage(to::FluxCPUAdaptor, x::Zygote.FillArrays.AbstractFill) = x
-adapt_storage(to::FluxCPUAdaptor, x::T) where T <: CUDA.CUSPARSE.CUDA.CUSPARSE.AbstractCuSparseMatrix = adapt(Array, x)
 adapt_storage(to::FluxCPUAdaptor, x::Zygote.OneElement) = x
 adapt_storage(to::FluxCPUAdaptor, x::AbstractSparseArray) = x
-adapt_storage(to::FluxCPUAdaptor, x::CUDA.RNG) = Random.default_rng()
 adapt_storage(to::FluxCPUAdaptor, x::AbstractRNG) = x
 
-function ChainRulesCore.rrule(::Type{Array}, x::CUDA.CuArray)
-  Array(x), dx -> (NoTangent(), CUDA.cu(unthunk(dx)),)
-end
-
-function ChainRulesCore.rrule(::typeof(Adapt.adapt_storage), to::FluxCPUAdaptor, x::CUDA.AbstractGPUArray)
-  adapt_storage(to, x), dx -> (NoTangent(), NoTangent(), adapt_storage(FluxCUDAAdaptor(), unthunk(dx)),)
-end
-
 # CPU/GPU movement conveniences
 
 """
@@ -163,11 +138,17 @@ _isbitsarray(x) = false
 _isleaf(::AbstractRNG) = true
 _isleaf(x) = _isbitsarray(x) || Functors.isleaf(x)
 
+const cuda_loaded = Ref{Bool}(false)
+
 """
     gpu(x)
 
+Requires `CUDA` and `NNlibCUDA` to be loaded:
+```julia-repl
+julia> using Flux, CUDA, NNlibCUDA
+```
 Moves `m` to the current GPU device, if available. It is a no-op otherwise.
-See the [CUDA.jl docs](https://juliagpu.github.io/CUDA.jl/stable/usage/multigpu/) 
+See the [CUDA.jl docs](https://juliagpu.github.io/CUDA.jl/stable/usage/multigpu/)
 to help identify the current device.
 
 This works for functions, and any struct marked with [`@functor`](@ref).
@@ -187,23 +168,17 @@ CuArray{Float32, 2}
 ```
 """
 function gpu(x)
-  check_use_cuda()
-  use_cuda[] ? fmap(x -> Adapt.adapt(FluxCUDAAdaptor(), x), x; exclude = _isleaf) : x
-end
-
-function check_use_cuda()
-  if use_cuda[] === nothing
-    use_cuda[] = CUDA.functional()
-    if use_cuda[] && !CUDA.has_cudnn()
-      @warn "CUDA.jl found cuda, but did not find libcudnn. Some functionality will not be available."
-    end
-    if !(use_cuda[])
-      @info """The GPU function is being called but the GPU is not accessible.
-      Defaulting back to the CPU. (No action is required if you want to run on the CPU).""" maxlog=1
-    end
+  if cuda_loaded[]
+    return _gpu(x)
+  else
+    @info """
+    `Flux.gpu` was called, but the `CUDA` and `NNlibCUDA` packages
+    must be loaded to access GPU functionality.""" maxlog=1
+    return x
   end
 end
-ChainRulesCore.@non_differentiable check_use_cuda()
+
+function _gpu end
 
 # Precision
diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index 89eee976ee..d5e68df041 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -36,9 +36,6 @@ function dropout(rng, x, p; dims=:, active::Bool=true)
 end
 
 dropout(x, p; kwargs...) = dropout(rng_from_array(x), x, p; kwargs...)
-dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
-dropout_mask(rng, x::CuArray, p; kwargs...) =
-  throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only support CUDA.RNG for CuArrays."))
 dropout_mask(rng, x, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
 function _dropout_mask(rng, x, p; dims=:)
   realfptype = float(real(eltype(x)))
@@ -56,9 +53,9 @@ ChainRulesCore.@non_differentiable dropout_mask(::Any, ::Any, ::Any)
 Dropout layer. 
 
 While training, for each input, this layer either sets that input to `0` (with probability
-`p`) or scales it by `1 / (1 - p)`. To apply dropout along certain dimension(s), specify the 
+`p`) or scales it by `1 / (1 - p)`. To apply dropout along certain dimension(s), specify the
 `dims` keyword. e.g. `Dropout(p; dims = 3)` will randomly zero out entire channels on WHCN input
-(also called 2D dropout). This is used as a regularisation, i.e. it reduces overfitting during 
+(also called 2D dropout). This is used as a regularisation, i.e. it reduces overfitting during
 training.
 
 In the forward pass, this layer applies the [`Flux.dropout`](@ref) function. See that for more
diff --git a/src/losses/Losses.jl b/src/losses/Losses.jl
index 3d8f6f8149..5b4a1d697b 100644
--- a/src/losses/Losses.jl
+++ b/src/losses/Losses.jl
@@ -5,7 +5,6 @@ using Zygote
 using Zygote: @adjoint
 using ChainRulesCore
 using ..Flux: ofeltype, epseltype
-using CUDA
 using NNlib: logsoftmax, logσ, ctc_loss, ctc_alpha, ∇ctc_loss
 import Base.Broadcast: broadcasted
 
diff --git a/src/utils.jl b/src/utils.jl
index 884fcd7465..99ef13fdb4 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -44,7 +44,6 @@ The current defaults are:
   - Julia version is >= 1.7: `Random.default_rng()`
 """
 rng_from_array(::AbstractArray) = default_rng_value()
-rng_from_array(::CuArray) = CUDA.default_rng()
 
 @non_differentiable rng_from_array(::Any)
 
@@ -226,7 +225,7 @@ ChainRulesCore.@non_differentiable kaiming_normal(::Any...)
 """
     truncated_normal([rng = default_rng_value()], size...; mean = 0, std = 1, lo = -2, hi = 2) -> Array
     truncated_normal([rng]; kw...) -> Function
-    
+
 Return an `Array{Float32}` of the given `size` where each element is drawn from a
 truncated normal distribution. The numbers are distributed like
 `filter(x -> lo<=x<=hi, mean .+ std .* randn(100))`.
@@ -393,7 +392,7 @@ Has the following behaviour
 * 2D: An identity matrix (useful for an identity matrix multiplication)
 * More than 2D: A dense block array of center tap spatial filters (useful for an identity convolution)
 
-Some caveats: 
+Some caveats:
 
 * Not all layers will be identity mapping when used with this init. Exceptions
   include recurrent layers and normalization layers.
diff --git a/test/cuda/cuda.jl b/test/cuda/cuda.jl
index 2b4fec6e4c..5a8df0029a 100644
--- a/test/cuda/cuda.jl
+++ b/test/cuda/cuda.jl
@@ -1,5 +1,5 @@
 using Flux, Test
-using Flux.CUDA
+using CUDA
 using Flux: cpu, gpu
 using Statistics: mean
 using LinearAlgebra: I, cholesky, Cholesky
@@ -91,7 +91,7 @@ end
   struct SimpleBits
     field::Int32
   end
-  
+
   @test gpu((;a=ones(1))).a isa CuVector{Float32}
   @test gpu((;a=['a', 'b', 'c'])).a isa CuVector{Char}
   @test gpu((;a=[SimpleBits(1)])).a isa CuVector{SimpleBits}
diff --git a/test/cuda/layers.jl b/test/cuda/layers.jl
index 8024681a06..8eaa7d8523 100644
--- a/test/cuda/layers.jl
+++ b/test/cuda/layers.jl
@@ -61,7 +61,7 @@ function gpu_gradtest(name::String, layers::Vector, x_cpu = nothing, args...; te
       if isnothing(gs_cpu[p_cpu])
         @test isnothing(gs_gpu[p_gpu])
       else
-        @test gs_gpu[p_gpu] isa Flux.CUDA.CuArray
+        @test gs_gpu[p_gpu] isa CUDA.CuArray
         if test_cpu
          @test Array(gs_gpu[p_gpu]) ≈ gs_cpu[p_cpu] rtol=1f-3 atol=1f-3
        end
@@ -259,7 +259,7 @@ end
    input = randn(10, 10, 10, 10) |> gpu
    layer_gpu = Parallel(+, zero, identity) |> gpu
    @test layer_gpu(input) == input
-    @test layer_gpu(input) isa Flux.CUDA.CuArray
+    @test layer_gpu(input) isa CUDA.CuArray
  end
 
 @testset "vararg input" begin
diff --git a/test/runtests.jl b/test/runtests.jl
index 29b2bad311..b0a30fe969 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -7,6 +7,7 @@ using Random, Statistics, LinearAlgebra
 using IterTools: ncycle
 using Zygote
 using CUDA
+using NNlibCUDA # required to trigger CUDAExt
 
 Random.seed!(0)
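A minimal usage sketch (not part of the diff) of the loading behaviour this change introduces. It assumes Julia 1.9+ where package extensions are available, a functional CUDA setup, and the package names declared in `[weakdeps]` above:

```julia
using Flux

x = rand(Float32, 3, 3)

# Before the weak dependencies are loaded, CUDAExt has not been activated,
# so Flux.cuda_loaded[] is false and `gpu` is a no-op that logs an @info once.
@assert gpu(x) === x

# Loading both weak dependencies triggers CUDAExt, whose __init__ sets
# Flux.cuda_loaded[] = true and routes `gpu` to `_gpu` from the extension.
using CUDA, NNlibCUDA

y = gpu(x)          # a CuArray when CUDA.functional() is true
@assert cpu(y) ≈ x  # round-trip back to the host
```

The `cuda_loaded` Ref together with the empty `function _gpu end` stub is what lets `gpu` remain defined in Flux itself while the CUDA-specific method lives in the extension.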