diff --git a/docs/src/api.md b/docs/src/api.md
index 9373d231a..4c37d8c6f 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -13,13 +13,22 @@
 @uniform
 @groupsize
 @ndrange
-synchronize
-allocate
+```
+
+### Reduction
+
+```@docs
+@groupreduce
+@warp_groupreduce
+KernelAbstractions.shfl_down
+KernelAbstractions.supports_warp_reduction
 ```
 
 ## Host language
 
 ```@docs
+synchronize
+allocate
 KernelAbstractions.zeros
 ```
 
diff --git a/src/KernelAbstractions.jl b/src/KernelAbstractions.jl
index b82dadc54..29e0d0905 100644
--- a/src/KernelAbstractions.jl
+++ b/src/KernelAbstractions.jl
@@ -793,6 +793,8 @@ argconvert(k::Kernel{T}, arg) where {T} =
 supports_enzyme(::Backend) = false
 function __fake_compiler_job end
 
+include("groupreduction.jl")
+
 ###
 # Extras
 # - LoopInfo
diff --git a/src/groupreduction.jl b/src/groupreduction.jl
new file mode 100644
index 000000000..ed42a5377
--- /dev/null
+++ b/src/groupreduction.jl
@@ -0,0 +1,120 @@
+export @groupreduce, @warp_groupreduce
+
+"""
+    @groupreduce op val [groupsize]
+
+Perform group reduction of `val` using `op`.
+
+# Arguments
+
+- `op` is the binary reduction operator; it should be associative for the result to be well-defined.
+
+- `groupsize` specifies size of the workgroup.
+  If a kernel does not specify `groupsize` statically, then it is required to
+  provide `groupsize`.
+  It can also be used to perform a reduction across the first `groupsize` threads
+  (if `groupsize < @groupsize()`).
+
+# Returns
+
+Result of the reduction.
+"""
+macro groupreduce(op, val)
+    :(__thread_groupreduce($(esc(:__ctx__)), $(esc(op)), $(esc(val)), Val(prod($groupsize($(esc(:__ctx__)))))))  # workgroup size comes from the kernel's static context
+end
+macro groupreduce(op, val, groupsize)
+    :(__thread_groupreduce($(esc(:__ctx__)), $(esc(op)), $(esc(val)), Val($(esc(groupsize)))))  # caller supplies the reduction width explicitly
+end
+
+function __thread_groupreduce(__ctx__, op, val::T, ::Val{groupsize}) where {T, groupsize}
+    storage = @localmem T groupsize  # workgroup-local scratch: one slot per participating thread
+
+    local_idx = @index(Local)
+    @inbounds local_idx ≤ groupsize && (storage[local_idx] = val)  # threads beyond `groupsize` do not participate
+    @synchronize()
+
+    s::UInt64 = groupsize ÷ 0x02  # halving tree-reduction stride; NOTE(review): an odd stride leaves slot s+1 uncombined — assumes power-of-two groupsize, confirm
+    while s > 0x00
+        if (local_idx - 0x01) < s
+            other_idx = local_idx + s
+            if other_idx ≤ groupsize
+                @inbounds storage[local_idx] = op(storage[local_idx], storage[other_idx])
+            end
+        end
+        @synchronize()
+        s >>= 0x01
+    end
+
+    if local_idx == 0x01
+        @inbounds val = storage[local_idx]
+    end
+    return val  # only thread 1 returns the reduced value; other threads return their own `val`
+end
+
+# Warp groupreduce.
+
+"""
+    @warp_groupreduce op val neutral [groupsize]
+
+Perform group reduction of `val` using `op`.
+Each warp within a workgroup performs its own reduction using [`shfl_down`](@ref) intrinsic,
+followed by final reduction over results of individual warp reductions.
+
+!!! note
+
+    Use [`supports_warp_reduction`](@ref) to query if given backend supports warp reduction.
+"""
+macro warp_groupreduce(op, val, neutral)
+    :(__warp_groupreduce($(esc(:__ctx__)), $(esc(op)), $(esc(val)), $(esc(neutral)), Val(prod($groupsize($(esc(:__ctx__)))))))  # workgroup size comes from the kernel's static context
+end
+macro warp_groupreduce(op, val, neutral, groupsize)
+    :(__warp_groupreduce($(esc(:__ctx__)), $(esc(op)), $(esc(val)), $(esc(neutral)), Val($(esc(groupsize)))))  # caller supplies the reduction width explicitly
+end
+
+"""
+    shfl_down(val::T, offset::Integer)::T where T
+
+Read `val` from a lane with higher id given by `offset`.
+"""
+function shfl_down end
+supports_warp_reduction() = false  # NOTE(review): zero-arg method is undocumented — presumably a device-side default; confirm it is intentional
+
+"""
+    supports_warp_reduction(::Backend)
+
+Query if given backend supports [`shfl_down`](@ref) intrinsic and thus warp reduction.
+"""
+supports_warp_reduction(::Backend) = false
+
+# Assume warp is 32 lanes.
+const __warpsize = UInt32(32)  # fixed assumption — not queried from the backend
+# Maximum number of warps (for a groupsize = 1024).
+const __warp_bins = UInt32(32)
+
+@inline function __warp_reduce(val, op)
+    offset::UInt32 = __warpsize ÷ 0x02  # butterfly shuffle reduction within a single warp
+    while offset > 0x00
+        val = op(val, shfl_down(val, offset))
+        offset >>= 0x01
+    end
+    return val  # lane 1 ends up with the warp-wide reduction
+end
+
+function __warp_groupreduce(__ctx__, op, val::T, neutral::T, ::Val{groupsize}) where {T, groupsize}
+    storage = @localmem T __warp_bins  # one bin per warp, shared across the workgroup
+
+    local_idx = @index(Local)
+    lane = (local_idx - 0x01) % __warpsize + 0x01  # 1-based lane index within the warp
+    warp_id = (local_idx - 0x01) ÷ __warpsize + 0x01  # 1-based warp index within the workgroup
+
+    # Each warp performs a reduction and writes results into its own bin in `storage`.
+    val = __warp_reduce(val, op)
+    @inbounds lane == 0x01 && (storage[warp_id] = val)  # lane 1 publishes its warp's partial result
+    @synchronize()
+
+    # Final reduction of the `storage` on the first warp.
+    within_storage = (local_idx - 0x01) < groupsize ÷ __warpsize  # warp count = groupsize ÷ 32; assumes groupsize is a multiple of __warpsize
+    @inbounds val = within_storage ? storage[lane] : neutral  # pad remaining lanes with the neutral element
+    warp_id == 0x01 && (val = __warp_reduce(val, op))
+    return val  # only lane 1 of warp 1 holds the full group reduction
+end
diff --git a/test/groupreduce.jl b/test/groupreduce.jl
new file mode 100644
index 000000000..658dea402
--- /dev/null
+++ b/test/groupreduce.jl
@@ -0,0 +1,70 @@
+@kernel cpu=false function groupreduce_1!(y, x, op, neutral)
+    i = @index(Global)
+    val = i > length(x) ? neutral : x[i]  # out-of-range work-items contribute the neutral element
+    res = @groupreduce(op, val)
+    i == 1 && (y[1] = res)  # only the first work-item writes the result
+end
+
+@kernel cpu=false function groupreduce_2!(y, x, op, neutral, ::Val{groupsize}) where {groupsize}
+    i = @index(Global)
+    val = i > length(x) ? neutral : x[i]
+    res = @groupreduce(op, val, groupsize)  # reduce only the first `groupsize` threads
+    i == 1 && (y[1] = res)
+end
+
+@kernel cpu=false function warp_groupreduce_1!(y, x, op, neutral)
+    i = @index(Global)
+    val = i > length(x) ? neutral : x[i]
+    res = @warp_groupreduce(op, val, neutral)
+    i == 1 && (y[1] = res)
+end
+
+@kernel cpu=false function warp_groupreduce_2!(y, x, op, neutral, ::Val{groupsize}) where {groupsize}
+    i = @index(Global)
+    val = i > length(x) ? neutral : x[i]
+    res = @warp_groupreduce(op, val, neutral, groupsize)
+    i == 1 && (y[1] = res)
+end
+
+function groupreduce_testsuite(backend, AT)
+    # TODO should be a better way of querying max groupsize
+    groupsizes = "$backend" == "oneAPIBackend" ?
+        (256,) :
+        (256, 512, 1024)
+
+    @testset "@groupreduce" begin
+        @testset "T=$T, n=$n" for T in (Float16, Float32, Int16, Int32, Int64), n in groupsizes
+            x = AT(ones(T, n))  # all-ones input, so a sum-reduction over k elements equals k
+            y = AT(zeros(T, 1))
+            neutral = zero(T)
+            op = +
+
+            groupreduce_1!(backend(), n)(y, x, op, neutral; ndrange = n)
+            @test Array(y)[1] == n
+
+            for groupsize in (64, 128)
+                groupreduce_2!(backend())(y, x, op, neutral, Val(groupsize); ndrange = n)
+                @test Array(y)[1] == groupsize  # partial reduction over the first `groupsize` threads
+            end
+        end
+    end
+
+    if KernelAbstractions.supports_warp_reduction(backend())
+        @testset "@warp_groupreduce" begin
+            @testset "T=$T, n=$n" for T in (Float16, Float32, Int16, Int32, Int64), n in groupsizes
+                x = AT(ones(T, n))
+                y = AT(zeros(T, 1))
+                neutral = zero(T)
+                op = +
+
+                warp_groupreduce_1!(backend(), n)(y, x, op, neutral; ndrange = n)
+                @test Array(y)[1] == n
+
+                for groupsize in (64, 128)
+                    warp_groupreduce_2!(backend())(y, x, op, neutral, Val(groupsize); ndrange = n)
+                    @test Array(y)[1] == groupsize
+                end
+            end
+        end
+    end
+end
diff --git a/test/testsuite.jl b/test/testsuite.jl
index 29f780272..528321ea0 100644
--- a/test/testsuite.jl
+++ b/test/testsuite.jl
@@ -38,6 +38,7 @@ include("reflection.jl")
 include("examples.jl")
 include("convert.jl")
 include("specialfunctions.jl")
+include("groupreduce.jl")
 
 function testsuite(backend, backend_str, backend_mod, AT, DAT; skip_tests = Set{String}())
     @conditional_testset "Unittests" skip_tests begin
@@ -92,6 +93,13 @@ function testsuite(backend, backend_str, backend_mod, AT, DAT; skip_tests = Set{
         examples_testsuite(backend_str)
     end
 
+    # TODO @index(Local) only works as a top-level expression on CPU.
+    if backend != CPU  # see TODO above: @groupreduce uses @index(Local) in non-top-level position, unsupported on CPU
+        @conditional_testset "@groupreduce" skip_tests begin
+            groupreduce_testsuite(backend, AT)
+        end
+    end
+
     return
 end