diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index 66882fb6a8..508ff06b9c 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -50,6 +50,12 @@ jobs:
             version: '1.10'
             assertions: true
             test_group: neural_networks
+          - os: ubuntu-20.04
+            arch: x64
+            libReactant: packaged
+            version: '1.10'
+            assertions: true
+            test_group: integration
           - os: ubuntu-20.04
             arch: x86
             libReactant: packaged
diff --git a/Project.toml b/Project.toml
index 9a1277d9d6..2ecce09b36 100644
--- a/Project.toml
+++ b/Project.toml
@@ -14,6 +14,7 @@ Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
 Preferences = "21216c6a-2e73-6563-6e65-726566657250"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 ReactantCore = "a3311ec8-5e00-46d5-b541-4f83e724a433"
 Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"
 Scratch = "6c6a2e73-6563-6170-7368-637461726353"
@@ -23,17 +24,19 @@ AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
+Random123 = "74087812-796a-5b5d-8853-05524746bad3"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"

-[sources.ReactantCore]
-path = "lib/ReactantCore"
+[sources]
+ReactantCore = {path = "lib/ReactantCore"}

 [extensions]
 ReactantAbstractFFTsExt = "AbstractFFTs"
 ReactantArrayInterfaceExt = "ArrayInterface"
 ReactantCUDAExt = "CUDA"
 ReactantNNlibExt = "NNlib"
+ReactantRandom123Ext = "Random123"
 ReactantStatisticsExt = "Statistics"
 ReactantYaoBlocksExt = "YaoBlocks"

@@ -50,6 +53,8 @@ LinearAlgebra = "1.10"
 NNlib = "0.9.26"
 OrderedCollections = "1"
 Preferences = "1.4"
+Random = "1.10"
+Random123 = "1.7"
 ReactantCore = "0.1.3"
 Reactant_jll = "0.0.26"
 Scratch = "1.2"
diff --git a/docs/make.jl b/docs/make.jl
index 7515a566db..fcbaca60ef 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -43,6 +43,7 @@ pages = [
         ],
         "MLIR API" => "api/mlirc.md",
         "XLA" => "api/xla.md",
+        "Internal API" => "api/internal.md",
     ],
 ]
diff --git a/docs/src/.vitepress/config.mts b/docs/src/.vitepress/config.mts
index 942a9415df..1dc25f2ad5 100644
--- a/docs/src/.vitepress/config.mts
+++ b/docs/src/.vitepress/config.mts
@@ -78,7 +78,8 @@ export default defineConfig({
           { text: "MLIR API", link: "/api/mlirc" },
           { text: "XLA", link: "/api/xla" },
         ],
-      }
+      },
+      { text: "Internal API", link: "/api/internal" },
     ],
   },
   {
@@ -132,6 +133,7 @@ export default defineConfig({
             { text: "XLA", link: "/api/xla" },
           ],
         },
+        { text: "Internal API", link: "/api/internal" },
       ],
     },
   },
diff --git a/docs/src/api/internal.md b/docs/src/api/internal.md
new file mode 100644
index 0000000000..a8788e5fb9
--- /dev/null
+++ b/docs/src/api/internal.md
@@ -0,0 +1,12 @@
+```@meta
+CollapsedDocStrings = true
+```
+
+# Internal API
+
+These functions are not part of the public API and are subject to change at any time.
+
+```@docs
+Reactant.REDUB_ARGUMENTS_NAME
+Reactant.within_reactant_interpreter
+```
diff --git a/ext/ReactantRandom123Ext.jl b/ext/ReactantRandom123Ext.jl
new file mode 100644
index 0000000000..d701fdc7e4
--- /dev/null
+++ b/ext/ReactantRandom123Ext.jl
@@ -0,0 +1,11 @@
+module ReactantRandom123Ext
+
+using Random123: Threefry4x, Threefry2x, Philox4x, Philox2x
+using Reactant: TracedRandom
+
+TracedRandom.rng_algorithm(::Threefry4x) = "THREE_FRY"
+TracedRandom.rng_algorithm(::Threefry2x) = "THREE_FRY"
+TracedRandom.rng_algorithm(::Philox4x) = "PHILOX"
+TracedRandom.rng_algorithm(::Philox2x) = "PHILOX"
+
+end
diff --git a/src/Ops.jl b/src/Ops.jl
index fa8c17b3ca..18ab2d7d4b 100644
--- a/src/Ops.jl
+++ b/src/Ops.jl
@@ -1016,19 +1016,150 @@ end
 end

 # random ops
+"""
+    rng_bit_generator(
+        ::Type{T},
+        seed::TracedRArray{UInt64,1},
+        shape;
+        algorithm::String="DEFAULT",
+        location=mlir_stacktrace("rng_bit_generator", @__FILE__, @__LINE__),
+    )
+
+Generate a random array of type `T` with the given shape and seed. For integer `T`, the
+output bits are uniformly distributed over the full range of `T`; for floating-point `T`,
+the values are uniformly distributed in `[0, 1]`. Returns a NamedTuple with the following
+fields:
+
+- `output_state`: The state of the random number generator after the operation.
+- `output`: The generated array.
+
+# Arguments
+
+- `T`: The type of the generated array.
+- `seed`: The seed for the random number generator.
+- `shape`: The shape of the generated array.
+- `algorithm`: The algorithm to use for generating the random numbers. Defaults to
+  "DEFAULT". Other options include "PHILOX" and "THREE_FRY".
+"""
 @noinline function rng_bit_generator(
+    ::Type{T},
     seed::TracedRArray{UInt64,1},
     shape;
     algorithm::String="DEFAULT",
     location=mlir_stacktrace("rng_bit_generator", @__FILE__, @__LINE__),
-)
-    output = MLIR.IR.TensorType(TracedRArray{UInt64,1}, shape)
+) where {T<:Integer}
+    @assert algorithm in ("DEFAULT", "PHILOX", "THREE_FRY")
+    if algorithm == "PHILOX"
+        @assert length(seed) ∈ (2, 3)
+    elseif algorithm == "THREE_FRY"
+        @assert length(seed) == 2
+    end
+
+    output = MLIR.IR.TensorType(shape, MLIR.IR.Type(T))
+    output_state = MLIR.IR.TensorType(size(seed), MLIR.IR.Type(UInt64))
     rng_algorithm = MLIR.API.stablehloRngAlgorithmAttrGet(MLIR.IR.context(), algorithm)
-    op = stablehlo.rng_bit_generator(seed.mlir_data; output, rng_algorithm, location)
+    op = stablehlo.rng_bit_generator(
+        seed.mlir_data; output, output_state, rng_algorithm, location
+    )
     return (;
-        output_state=TracedRArray{UInt64,1}((), MLIR.IR.result(op, 1), MLIR.IR.size(seed)),
-        output=TracedRArray{T,length(shape)}((), MLIR.IR.result(op, 2), shape),
+        output_state=TracedRArray{UInt64,1}((), MLIR.IR.result(op, 1), size(seed)),
+        output=TracedRArray{T,length(shape)}((), MLIR.IR.result(op, 2), Tuple(shape)),
+    )
+end
+
+@noinline function rng_bit_generator(
+    ::Type{T},
+    seed::TracedRArray{UInt64,1},
+    shape;
+    algorithm::String="DEFAULT",
+    location=mlir_stacktrace("rng_bit_generator", @__FILE__, @__LINE__),
+) where {T<:AbstractFloat}
+    nbits = sizeof(T) * 8
+    uT = nbits == 16 ? UInt16 : (nbits == 32 ? UInt32 : UInt64)
+    (; output_state, output) = rng_bit_generator(uT, seed, shape; algorithm, location)
+    output = divide(
+        convert(TracedRArray{T,ndims(output)}, output),
+        constant(fill(T(typemax(uT)), Tuple(shape)); location),
+    )
+    return (; output_state, output)
+end
+
+"""
+    randn(
+        ::Type{T},
+        seed::TracedRArray{UInt64,1},
+        shape;
+        algorithm::String="DEFAULT",
+        location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
     )
+
+Generate a random array of type `T` with the given shape and seed from a standard normal
+distribution of mean 0 and standard deviation 1. Returns a NamedTuple with the following
+fields:
+
+- `output_state`: The state of the random number generator after the operation.
+- `output`: The generated array.
+
+# Arguments
+
+- `T`: The type of the generated array.
+- `seed`: The seed for the random number generator.
+- `shape`: The shape of the generated array.
+- `algorithm`: The algorithm to use for generating the random numbers. Defaults to
+  "DEFAULT". Other options include "PHILOX" and "THREE_FRY".
+"""
+@noinline function randn(
+    ::Type{T},
+    seed::TracedRArray{UInt64,1},
+    shape;
+    algorithm::String="DEFAULT",
+    location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
+) where {T}
+    res = rng_bit_generator(T, seed, shape; algorithm, location)
+    rand_uniform = res.output
+    seed = res.output_state
+    scaled_uniform = subtract(
+        multiply(rand_uniform, constant(fill(T(2), size(rand_uniform)))),
+        constant(fill(T(1), size(rand_uniform))),
+    )
+    probit = erf_inv(scaled_uniform)
+    rand_normal = multiply(probit, constant(fill(Base.sqrt(T(2)), size(rand_uniform))))
+    return (; output_state=seed, output=rand_normal)
+end
+
+"""
+    randexp(
+        ::Type{T},
+        seed::TracedRArray{UInt64,1},
+        shape;
+        algorithm::String="DEFAULT",
+        location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
+    )
+
+Generate a random array of type `T` with the given shape and seed from an exponential
+distribution with rate 1. Returns a NamedTuple with the following fields:
+
+- `output_state`: The state of the random number generator after the operation.
+- `output`: The generated array.
+
+# Arguments
+
+- `T`: The type of the generated array.
+- `seed`: The seed for the random number generator.
+- `shape`: The shape of the generated array.
+- `algorithm`: The algorithm to use for generating the random numbers. Defaults to
+  "DEFAULT". Other options include "PHILOX" and "THREE_FRY".
+"""
+@noinline function randexp(
+    ::Type{T},
+    seed::TracedRArray{UInt64,1},
+    shape;
+    algorithm::String="DEFAULT",
+    location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
+) where {T}
+    res = rng_bit_generator(T, seed, shape; algorithm, location)
+    rand_uniform = res.output
+    seed = res.output_state
+    rand_exp = negate(log_plus_one(negate(rand_uniform)))
+    return (; output_state=seed, output=rand_exp)
 end

 # functional ops
diff --git a/src/Overlay.jl b/src/Overlay.jl
index 6d4752acd9..b9785b7fa3 100644
--- a/src/Overlay.jl
+++ b/src/Overlay.jl
@@ -3,6 +3,15 @@
 # correctly. Once that (https://github.com/timholy/Revise.jl/issues/646) is resolved
 # we should move all the reactant_overrides to relevant files.

+# Helper function to determine whether we are inside the ReactantInterpreter
+"""
+    within_reactant_interpreter()
+
+Returns `true` if we are currently inside the ReactantInterpreter.
+"""
+@noinline within_reactant_interpreter() = false
+@reactant_overlay @noinline within_reactant_interpreter() = true
+
 # Compiling within a compile should return simply the original function
 @reactant_overlay function Compiler.compile(
     f, args; client=nothing, optimize=true, sync=false
@@ -10,7 +19,7 @@
     return f
 end

-# Enzyme overrides
+# Enzyme.jl overlays
 @reactant_overlay @noinline function Enzyme.autodiff_deferred(
     rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs}
 ) where {FA<:Annotation,A<:Annotation,Nargs}
@@ -22,3 +31,87 @@ end
 ) where {FA<:Annotation,A<:Annotation,Nargs}
     return overload_autodiff(rmode, f, rt, args...)
 end
+
+# Random.jl overlays
+@reactant_overlay @noinline function Random.default_rng()
+    return call_with_reactant(TracedRandom.default_rng)
+end
+
+## Only problematic edge case here is the direct `(rng, A::AbstractArray)` call
+## We can't directly overlay that call without breaking the semantics of in-place update
+for randfun in (:rand, :randn, :randexp)
+    randfun! = Symbol(randfun, :!)
+    overload_randfun = Symbol(:overload_, randfun)
+    overload_randfun! = Symbol(:overload_, randfun!)
+
+    @eval begin
+        @reactant_overlay @noinline function Random.$(randfun)(
+            rng::AbstractRNG, ::Type{T}, dims::Dims
+        ) where {T}
+            if T <: ReactantPrimitive
+                return TracedRandom.$(overload_randfun)(rng, T, dims)
+            end
+            return error(
+                "Reactant doesn't support sampling of $(T) with the current interpreter."
+            )
+            # XXX: The following will lead to an illegal instruction
+            # @warn "Reactant doesn't support sampling of $(T) with the current \
+            #        interpreter. Falling back to native interpreter." maxlog = 1
+            # return Random.$(randfun)(rng, T, dims)
+        end
+
+        @reactant_overlay @noinline function Random.$(randfun)(
+            rng::AbstractRNG, dim1::Integer, dims::Integer...
+        )
+            return TracedRandom.$(overload_randfun)(rng, dim1, dims...)
+        end
+
+        @reactant_overlay @noinline function Random.$(randfun)(
+            rng::AbstractRNG, ::Type{T}, dim1::Integer, dims::Integer...
+        ) where {T}
+            if T <: ReactantPrimitive
+                return TracedRandom.$(overload_randfun)(rng, T, dim1, dims...)
+            end
+            return error(
+                "Reactant doesn't support sampling of $(T) with the current interpreter."
+            )
+            # XXX: The following will lead to an illegal instruction
+            # @warn "Reactant doesn't support sampling of $(T) with the current \
+            #        interpreter. Falling back to native interpreter." maxlog = 1
+            # return Random.$(randfun)(rng, T, dim1, dims...)
+        end
+
+        # scalars
+        @reactant_overlay @noinline function Random.$(randfun)(
+            rng::AbstractRNG, ::Type{T}=Float64
+        ) where {T}
+            if T <: ReactantPrimitive
+                return TracedRandom.$(overload_randfun)(rng, T)
+            end
+            return error(
+                "Reactant doesn't support sampling of $(T) with the current interpreter."
+            )
+            # XXX: The following will lead to an illegal instruction
+            # @warn "Reactant doesn't support sampling of $(T) with the current \
+            #        interpreter. Falling back to native interpreter." maxlog = 1
+            # return Random.$(randfun)(rng, T)
+        end
+
+        # in-place
+        @reactant_overlay @noinline function Random.$(randfun!)(
+            rng::AbstractRNG, A::AnyTracedRArray
+        )
+            return TracedRandom.$(overload_randfun!)(rng, A)
+        end
+
+        # XXX: Uncomment once AbsInt issues with recursive calls are resolved
+        # @reactant_overlay @noinline function Random.$(randfun!)(
+        #     rng::AbstractRNG, A::AbstractArray
+        # )
+        #     @warn "Directly writing to an array using Random.jl functions inside \
+        #            ReactantInterpreter will generate a constant array in the IR. Use with \
+        #            caution." maxlog = 1
+        #     return Random.$(randfun!)(rng, A)
+        # end
+    end
+end
diff --git a/src/Reactant.jl b/src/Reactant.jl
index e7c8805de9..bea0150744 100644
--- a/src/Reactant.jl
+++ b/src/Reactant.jl
@@ -3,6 +3,8 @@ module Reactant
 using ReactantCore: ReactantCore, @trace, MissingTracedValue

 using LinearAlgebra: LinearAlgebra
+using Random: Random, AbstractRNG
+
 using Adapt: Adapt, WrappedArray
 using GPUArraysCore: GPUArraysCore, @allowscalar, allowscalar # keep this import to allow users to do `Reactant.allowscalar(false)`

@@ -122,7 +124,14 @@
 include("TracedRArray.jl")

 include("ConcreteRArray.jl")

-include("linear_algebra.jl")
+mutable struct TracedRNG <: Random.AbstractRNG
+    seed::Union{ConcreteRArray{UInt64,1},TracedRArray{UInt64,1}}
+    const algorithm::String
+end
+
+# StdLib Overloads
+include("stdlibs/LinearAlgebra.jl")
+include("stdlibs/Random.jl")

 const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue}
diff --git a/src/linear_algebra.jl b/src/stdlibs/LinearAlgebra.jl
similarity index 100%
rename from src/linear_algebra.jl
rename to src/stdlibs/LinearAlgebra.jl
diff --git a/src/stdlibs/Random.jl b/src/stdlibs/Random.jl
new file mode 100644
index 0000000000..271b78f802
--- /dev/null
+++ b/src/stdlibs/Random.jl
@@ -0,0 +1,167 @@
+module TracedRandom
+
+# Implementation based on the following:
+# 1. https://github.com/JuliaGPU/CUDA.jl/blob/master/src/random.jl
+# 2. https://github.com/JuliaRandom/Random123.jl/blob/master/src/common.jl
+
+using ..Reactant:
+    Reactant,
+    TracedRArray,
+    TracedRNumber,
+    TracedRNG,
+    AnyTracedRArray,
+    TracedUtils,
+    Ops,
+    ConcreteRArray
+using Random: Random, AbstractRNG
+
+@noinline function make_seed(rng::AbstractRNG=Random.RandomDevice())
+    # XXX: We should really be able to call this here. But with our AbsInt it leads to a
+    #      segfault. So we'll just call it in the rand! method.
+    # return rand(rng, UInt64, 2)
+    seed = Array{UInt64}(undef, 2)
+    Random.rand!(rng, seed)
+    return seed
+end
+
+function Random.seed!(rng::TracedRNG, seed::Number)
+    if seed isa TracedRNumber
+        error("Passing in `TracedRNumber` as a seed is not supported. Please pass in a \
+               `TracedRArray` of the appropriate size instead.")
+    end
+
+    seed = reinterpret(UInt64, Random.hash_seed(seed))
+    seed = if Reactant.within_reactant_interpreter()
+        TracedUtils.promote_to(TracedRArray{UInt64,1}, seed[1:length(rng.seed)])
+    else
+        ConcreteRArray(seed[1:length(rng.seed)])
+    end
+    return Random.seed!(rng, seed)
+end
+
+function Random.seed!(rng::TracedRNG, seed::AbstractArray{<:Integer,1})
+    return Random.seed!(rng, UInt64.(seed))
+end
+
+function Random.seed!(rng::TracedRNG, seed::AbstractArray{UInt64,1})
+    return Random.seed!(rng, TracedUtils.promote_to(TracedRArray{UInt64,1}, seed))
+end
+
+function Random.seed!(
+    rng::TracedRNG, seed::Union{ConcreteRArray{UInt64,1},TracedRArray{UInt64,1}}
+)
+    rng.seed = seed
+    return rng
+end
+
+@noinline TracedRNG() = TracedRNG(ConcreteRArray(make_seed()))
+@noinline TracedRNG(seed::ConcreteRArray{UInt64,1}) = TracedRNG(seed, "DEFAULT")
+
+@noinline function default_rng()
+    Reactant.within_reactant_interpreter() || return TracedRNG()
+    return TracedRNG(TracedUtils.promote_to(TracedRArray{UInt64,1}, make_seed()), "DEFAULT")
+end
+
+@noinline rng_algorithm(rng::TracedRNG) = rng.algorithm
+@noinline rng_algorithm(::AbstractRNG) = "DEFAULT"
+
+@noinline function internal_overload_rand!(
+    rng::TracedRNG, A::AnyTracedRArray{T,N}
+) where {T,N}
+    length(A) == 0 && return A
+    res = Ops.rng_bit_generator(T, rng.seed, [size(A)...]; rng.algorithm)
+    rng.seed = res.output_state
+    TracedUtils.set_mlir_data!(A, res.output.mlir_data)
+    return A
+end
+
+@noinline function internal_overload_randn!(
+    rng::TracedRNG, A::AnyTracedRArray{T,N}
+) where {T,N}
+    length(A) == 0 && return A
+    res = Ops.randn(T, rng.seed, [size(A)...]; rng.algorithm)
+    rng.seed = res.output_state
+    TracedUtils.set_mlir_data!(A, res.output.mlir_data)
+    return A
+end
+
+@noinline function internal_overload_randexp!(
+    rng::TracedRNG, A::AnyTracedRArray{T,N}
+) where {T,N}
+    length(A) == 0 && return A
+    res = Ops.randexp(T, rng.seed, [size(A)...]; rng.algorithm)
+    rng.seed = res.output_state
+    TracedUtils.set_mlir_data!(A, res.output.mlir_data)
+    return A
+end
+
+for randfun in (:rand, :randn, :randexp)
+    randfun! = Symbol(randfun, :!)
+    overload_randfun = Symbol(:internal_overload_, randfun)
+    overload_randfun! = Symbol(:internal_overload_, randfun!)
+
+    @eval begin
+        @noinline function $(overload_randfun)(
+            rng::TracedRNG, ::Type{T}, dims::Dims
+        ) where {T}
+            return $(overload_randfun!)(
+                rng, TracedRArray{T,length(dims)}((), nothing, dims)
+            )
+        end
+
+        @noinline function $(overload_randfun)(rng::TracedRNG, dims::Dims)
+            return $(overload_randfun)(rng, Float64, dims)
+        end
+
+        @noinline function $(overload_randfun)(
+            rng::TracedRNG, dim1::Integer, dims::Integer...
+        )
+            return $(overload_randfun)(rng, Dims((dim1, dims...)))
+        end
+
+        @noinline function $(overload_randfun)(
+            rng::TracedRNG, ::Type{T}, dim1::Integer, dims::Integer...
+        ) where {T}
+            return $(overload_randfun)(rng, T, Dims((dim1, dims...)))
+        end
+
+        @noinline function $(overload_randfun!)(A::AnyTracedRArray)
+            return $(overload_randfun!)(default_rng(), A)
+        end
+
+        # scalars
+        @noinline function $(overload_randfun)(rng::TracedRNG, ::Type{T}=Float64) where {T}
+            A = TracedUtils.promote_to(TracedRArray{T,0}, fill(T(0)))
+            $(overload_randfun!)(rng, A)
+            return TracedRNumber{T}((), A.mlir_data)
+        end
+    end
+end
+
+# Called from the overlaid variants. We use two tiers -- overload_* and
+# internal_overload_* -- to avoid method ambiguities.
+for randfun in (:rand, :randn, :randexp, :rand!, :randn!, :randexp!)
+    overload_randfun = Symbol(:overload_, randfun)
+    internal_overload_randfun = Symbol(:internal_overload_, randfun)
+    @eval begin
+        @noinline function $(overload_randfun)(rng::AbstractRNG, args...)
+            rng = TracedRNG(
+                TracedUtils.promote_to(TracedRArray{UInt64,1}, make_seed(rng)),
+                rng_algorithm(rng),
+            )
+            return $(internal_overload_randfun)(rng, args...)
+        end
+
+        @noinline function $(overload_randfun)(rng::TracedRNG, args...)
+            return $(internal_overload_randfun)(rng, args...)
+        end
+    end
+end
+
+# TODO: At some later point we might want to implement the sampler API as well, since it
+#       makes all RNG implementations work by default. From the post-optimization IR we
+#       need to confirm that the dynamic_update_slice calls are optimized away into a
+#       single `stablehlo.rng_bit_generator` call; verify this against how the seeding is
+#       expected to work.
+
+end
diff --git a/src/utils.jl b/src/utils.jl
index 16b784d587..b8eb028494 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -99,7 +99,8 @@ function should_rewrite_ft(@nospecialize(ft))
         # Don't rewrite primitive ops, tracing utilities, or any MLIR-based functions
         if has_ancestor(mod, Reactant.Ops) ||
             has_ancestor(mod, Reactant.TracedUtils) ||
-            has_ancestor(mod, Reactant.MLIR)
+            has_ancestor(mod, Reactant.MLIR) ||
+            has_ancestor(mod, Reactant.TracedRandom)
             return false
         end
     end
@@ -305,7 +306,7 @@ function call_with_reactant_generator(
     overdubbed_codelocs = Int32[]

     # No method could be found (including in our method table), bail with an error
-    if lookup_result == nothing
+    if lookup_result === nothing
         return stub(world, source, method_error)
     end

@@ -501,7 +502,7 @@ function call_with_reactant_generator(
     # jl_new_opaque_closure forcibly executes in the current world... This means that we won't get the right
     # inner code during compilation without special handling (i.e. call_in_world_total).
-    # Opaque closures also require takign the function argument. We can work around the latter
+    # Opaque closures also require taking the function argument. We can work around the latter
     # if the function is stateless. But regardless, to work around this we sadly create/compile the opaque closure
     oc = if false && Base.issingletontype(args[1])
         res = Core._call_in_world_total(
diff --git a/test/Project.toml b/test/Project.toml
index 9b3c5a6b49..cb0ccc4f69 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -2,10 +2,12 @@
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
 BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
+HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
 InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
@@ -15,10 +17,13 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Random123 = "74087812-796a-5b5d-8853-05524746bad3"
 Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"
 SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
 SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
+StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

 [compat]
diff --git a/test/integration/random.jl b/test/integration/random.jl
new file mode 100644
index 0000000000..275e0e2447
--- /dev/null
+++ b/test/integration/random.jl
@@ -0,0 +1,187 @@
+using Reactant, Test, Random, Random123, StableRNGs, Statistics
+using StatsBase, HypothesisTests, Distributions
+
+# First, test that the overlay works correctly
+@testset "Random.jl Overlay" begin
+    hlo = @code_hlo rand(Float32, 2, 3)
+    @test contains(repr(hlo), "stablehlo.rng_bit_generator")
+
+    hlo = @code_hlo rand(MersenneTwister(), Float32, 2, 3)
+    @test contains(repr(hlo), "stablehlo.rng_bit_generator")
+
+    hlo = @code_hlo rand(2, 3)
+    @test contains(repr(hlo), "stablehlo.rng_bit_generator")
+
+    hlo = @code_hlo rand(MersenneTwister(), 2, 3)
+    @test contains(repr(hlo), "stablehlo.rng_bit_generator")
+
+    hlo = @code_hlo rand(MersenneTwister(), Float64, (2, 3))
+    @test contains(repr(hlo), "stablehlo.rng_bit_generator")
+
+    hlo = @code_hlo rand(MersenneTwister(), Float64)
+    @test contains(repr(hlo), "stablehlo.rng_bit_generator")
+
+    hlo = @code_hlo rand(MersenneTwister())
+    @test contains(repr(hlo), "stablehlo.rng_bit_generator")
+
+    fn(x) = begin
+        # XXX: MersenneTwister without seed leads to illegal instructions
+        rng = MersenneTwister(0)
+        Random.rand!(rng, x)
+        return x
+    end
+    hlo = @code_hlo fn(Reactant.to_rarray(rand(Float64, 2, 3)))
+    @test contains(repr(hlo), "stablehlo.rng_bit_generator")
+
+    fn2() = begin
+        # XXX: MersenneTwister without seed leads to illegal instructions
+        rng = MersenneTwister(0)
+        x = zeros(Float64, 2, 3)
+        Random.rand!(rng, x)
+        return x
+    end
+    hlo = @code_hlo fn2()
+    @test !contains(repr(hlo), "stablehlo.rng_bit_generator")
+end
+
+@testset "Random123" begin
+    hlo = @code_hlo rand(Random123.Threefry4x(), Float32, 2, 3)
+    @test contains(repr(hlo), "stablehlo.rng_bit_generator")
+    @test contains(repr(hlo), "THREE_FRY")
+
+    hlo = @code_hlo rand(Random123.Threefry2x(), Float64, 2, 3)
+    @test contains(repr(hlo), "stablehlo.rng_bit_generator")
+    @test contains(repr(hlo), "THREE_FRY")
+
+    hlo = @code_hlo rand(Random123.Philox4x(), Float64, 2, 3)
+    @test contains(repr(hlo), "stablehlo.rng_bit_generator")
+    @test contains(repr(hlo), "PHILOX")
+
+    hlo = @code_hlo rand(Random123.Philox2x(), Float64, 2, 3)
+    @test contains(repr(hlo), "stablehlo.rng_bit_generator")
+    @test contains(repr(hlo), "PHILOX")
+end
+
+# Next, test that the random number generators actually draw from the correct
+# distributions
+@testset "Uniform Random" begin
+    @testset "Deterministic Seed" begin
+        seed1 = ConcreteRArray(UInt64[1, 3])
+        seed2 = ConcreteRArray(UInt64[1, 5])
+
+        fn(seed) = begin
+            rng = Random.default_rng()
+            Random.seed!(rng, seed)
+            return rand(rng, 10000)
+        end
+
+        fn_compiled = @compile fn(seed1)
+        @test fn_compiled(seed1) ≈ fn_compiled(seed1)
+        @test !(all(Array(fn_compiled(seed1)) .≈ Array(fn_compiled(seed2))))
+    end
+
+    @testset "Correct Distribution" begin
+        X = Array(@jit(rand(StableRNG(0), 10000)))
+        ks_test = ExactOneSampleKSTest(X, Uniform(0.0, 1.0))
+        @test pvalue(ks_test) > 0.05
+    end
+
+    @testset "AutoCorrelation" begin
+        X = Array(@jit(rand(StableRNG(0), 10000)))
+        autocorr = cor(X[1:(end - 1)], X[2:end])
+        @test abs(autocorr) < 0.05
+    end
+
+    @testset "Correct Range" begin
+        X = Array(@jit(rand(StableRNG(0), 10000)))
+        X_min, X_max = extrema(X)
+        @test X_min ≥ 0.0
+        @test X_max ≤ 1.0
+    end
+
+    @testset "Mean & Variance" begin
+        X = Array(@jit(rand(StableRNG(0), 10000)))
+        μ = mean(X)
+        σ² = var(X)
+        @test μ ≈ 0.5 atol = 0.05 rtol = 0.05
+        @test σ² ≈ (1//12) atol = 0.05 rtol = 0.05
+    end
+end
+
+@testset "Normal Distribution" begin
+    @testset "Deterministic Seed" begin
+        seed1 = ConcreteRArray(UInt64[1, 3])
+        seed2 = ConcreteRArray(UInt64[1, 5])
+
+        fn(seed) = begin
+            rng = Random.default_rng()
+            Random.seed!(rng, seed)
+            return randn(rng, 10000)
+        end
+
+        fn_compiled = @compile fn(seed1)
+        @test fn_compiled(seed1) ≈ fn_compiled(seed1)
+        @test !(all(Array(fn_compiled(seed1)) .≈ Array(fn_compiled(seed2))))
+    end
+
+    @testset "Correct Distribution" begin
+        X = Array(@jit(randn(StableRNG(0), 10000)))
+        sw_test = ShapiroWilkTest(X)
+        @test pvalue(sw_test) > 0.05
+    end
+
+    @testset "AutoCorrelation" begin
+        X = Array(@jit(randn(StableRNG(0), 10000)))
+        autocorr = cor(X[1:(end - 1)], X[2:end])
+        @test abs(autocorr) < 0.05
+    end
+
+    @testset "Mean & Variance" begin
+        X = Array(@jit(randn(StableRNG(0), 10000)))
+        μ = mean(X)
+        σ² = var(X)
+        @test μ ≈ 0.0 atol = 0.05 rtol = 0.05
+        @test σ² ≈ 1.0 atol = 0.05 rtol = 0.05
+    end
+end
+
+@testset "Exponential Distribution" begin
+    @testset "Deterministic Seed" begin
+        seed1 = ConcreteRArray(UInt64[1, 3])
+        seed2 = ConcreteRArray(UInt64[1, 5])
+
+        fn(seed) = begin
+            rng = Random.default_rng()
+            Random.seed!(rng, seed)
+            return randexp(rng, 10000)
+        end
+
+        fn_compiled = @compile fn(seed1)
+        @test fn_compiled(seed1) ≈ fn_compiled(seed1)
+        @test !(all(Array(fn_compiled(seed1)) .≈ Array(fn_compiled(seed2))))
+    end
+
+    @testset "Correct Distribution" begin
+        X = Array(@jit(randexp(StableRNG(0), 10000)))
+        ks_test = ExactOneSampleKSTest(X, Exponential(1.0))
+        @test pvalue(ks_test) > 0.05
+    end
+
+    @testset "AutoCorrelation" begin
+        X = Array(@jit(randexp(StableRNG(0), 10000)))
+        autocorr = cor(X[1:(end - 1)], X[2:end])
+        @test abs(autocorr) < 0.05
+    end
+
+    @testset "Correct Range" begin
+        X = Array(@jit(randexp(StableRNG(0), 10000)))
+        X_min, X_max = extrema(X)
+        @test X_min ≥ 0.0
+    end
+
+    @testset "Mean" begin
+        X = Array(@jit(randexp(StableRNG(0), 10000)))
+        μ = mean(X)
+        @test μ ≈ 1.0 atol = 0.05 rtol = 0.05
+    end
+end
diff --git a/test/nn/lux.jl b/test/nn/lux.jl
index 49fa37f52c..7916ce10fd 100644
--- a/test/nn/lux.jl
+++ b/test/nn/lux.jl
@@ -8,7 +8,7 @@ end
 function gradient_loss_function(model, x, y, ps, st)
     dps = Enzyme.make_zero(ps)
     _, res = Enzyme.autodiff(
-        ReverseWithPrimal,
+        set_runtime_activity(ReverseWithPrimal),
         loss_function,
         Active,
         Const(model),
diff --git a/test/ops.jl b/test/ops.jl
index 07f911e88b..82ec4cc8b8 100644
--- a/test/ops.jl
+++ b/test/ops.jl
@@ -538,8 +538,50 @@ end
 end

 @testset "rng_bit_generator" begin
-    # seed = ConcreteRArray([0, 0])
-    # @jit Ops.rng_bit_generator(seed, [2])
+    @testset for (alg, sz) in
+                 [("DEFAULT", 2), ("PHILOX", 2), ("PHILOX", 3), ("THREE_FRY", 2)]
+        genInt32(seed) = Ops.rng_bit_generator(Int32, seed, [2, 4]; algorithm=alg)
+        genInt64(seed) = Ops.rng_bit_generator(Int64, seed, [2, 4]; algorithm=alg)
+        genUInt64(seed) = Ops.rng_bit_generator(UInt64, seed, [2, 4]; algorithm=alg)
+        genFloat32(seed) = Ops.rng_bit_generator(Float32, seed, [2, 4]; algorithm=alg)
+        genFloat64(seed) = Ops.rng_bit_generator(Float64, seed, [2, 4]; algorithm=alg)
+
+        seed = ConcreteRArray(zeros(UInt64, sz))
+
+        res = @jit genInt32(seed)
+        @test res.output_state !== seed
+        @test size(res.output_state) == (sz,)
+        @test res.output isa ConcreteRArray{Int32,2}
+        @test size(res.output) == (2, 4)
+
+        seed = res.output_state
+        res = @jit genInt64(seed)
+        @test res.output_state !== seed
+        @test size(res.output_state) == (sz,)
+        @test res.output isa ConcreteRArray{Int64,2}
+        @test size(res.output) == (2, 4)
+
+        seed = res.output_state
+        res = @jit genUInt64(seed)
+        @test res.output_state !== seed
+        @test size(res.output_state) == (sz,)
+        @test res.output isa ConcreteRArray{UInt64,2}
+        @test size(res.output) == (2, 4)
+
+        seed = res.output_state
+        res = @jit genFloat32(seed)
+        @test res.output_state !== seed
+        @test size(res.output_state) == (sz,)
+        @test res.output isa ConcreteRArray{Float32,2}
+        @test size(res.output) == (2, 4)
+
+        seed = res.output_state
+        res = @jit genFloat64(seed)
+        @test res.output_state !== seed
+        @test size(res.output_state) == (sz,)
+        @test res.output isa ConcreteRArray{Float64,2}
+        @test size(res.output) == (2, 4)
+    end
 end

 @testset "round_nearest_afz" begin
diff --git a/test/runtests.jl b/test/runtests.jl
index fddc963ced..68dfcaead3 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -61,6 +61,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))

 if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration"
     @safetestset "Linear Algebra" include("integration/linear_algebra.jl")
     @safetestset "AbstractFFTs" include("integration/fft.jl")
+    @safetestset "Random" include("integration/random.jl")
 end

 if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks"
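
A note on the floating-point `Ops.rng_bit_generator` method above: it draws raw bits for the matching-width unsigned integer type and divides by that type's `typemax`, yielding uniform values in `[0, 1]`. A plain-Julia sketch of that mapping (the helper name is illustrative, not part of the PR):

```julia
# Map raw random bits to a uniform value in [0, 1], mirroring the
# AbstractFloat method's divide-by-typemax conversion.
bits_to_uniform(::Type{T}, bits::Unsigned) where {T<:AbstractFloat} =
    T(bits) / T(typemax(typeof(bits)))

bits_to_uniform(Float32, 0x80000000)  # ≈ 0.5f0
```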
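`Ops.randn` and `Ops.randexp` are inverse-transform samplers layered on that uniform output: `randn` applies the probit function Φ⁻¹(u) = √2 · erf⁻¹(2u − 1) via `erf_inv`, and `randexp` applies the Exp(1) inverse CDF −log(1 − u) via `log_plus_one` and `negate`. A plain-Julia reference for the same transforms, using `SpecialFunctions.erfinv` (already a test dependency); this is the math the traced ops encode, not the traced implementation itself:

```julia
using SpecialFunctions: erfinv

# Uniform u ∈ (0, 1) → standard normal via the probit function √2 · erf⁻¹(2u − 1)
normal_from_uniform(u) = sqrt(2) * erfinv(2u - 1)

# Uniform u ∈ (0, 1) → Exp(1) via the inverse CDF; -log1p(-u) == -log(1 - u)
exponential_from_uniform(u) = -log1p(-u)
```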
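End to end, the overlay machinery exposes all of this through the ordinary Random.jl API under compilation. A minimal usage sketch mirroring the `Deterministic Seed` testsets in `test/integration/random.jl` (the function name `sample_uniform` is illustrative):

```julia
using Reactant, Random

# During tracing, Random.default_rng() is overlaid to return a TracedRNG whose
# UInt64 seed vector is threaded through stablehlo.rng_bit_generator calls.
function sample_uniform(seed)
    rng = Random.default_rng()
    Random.seed!(rng, seed)
    return rand(rng, 10000)
end

seed = ConcreteRArray(UInt64[1, 3])
sample_compiled = @compile sample_uniform(seed)
sample_compiled(seed)  # re-running with the same seed reproduces the stream
```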