Skip to content

Commit 8930bc9

Browse files
committed
Prefix scan for various data types with inclusive/exclusive option
This commit improves the existing ``jit_scan()`` function with support for various data types: - int32/uint32 - uint64 - float - double The user can now also specify whether the scan should be inclusive or exclusive. Finally, the commit adds comments to facilitate future modifications of this code.
1 parent 8ecaaf2 commit 8930bc9

16 files changed

+10810
-3640
lines changed

include/drjit-core/array.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ Array empty(size_t size) {
364364
: AllocType::HostAsync,
365365
byte_size);
366366
return Array::steal(
367-
jit_var_map_mem(Array::Backend, Array::Type, ptr, size, 1));
367+
jit_var_mem_map(Array::Backend, Array::Type, ptr, size, 1));
368368
}
369369

370370
template <typename Array>

include/drjit-core/jit.h

+31-17
Original file line numberDiff line numberDiff line change
@@ -1592,30 +1592,44 @@ extern JIT_EXPORT void jit_memcpy_async(JIT_ENUM JitBackend backend, void *dst,
15921592
*/
15931593
extern JIT_EXPORT void jit_reduce(JIT_ENUM JitBackend backend, JIT_ENUM VarType type,
15941594
JIT_ENUM ReduceOp rtype,
1595-
const void *ptr, uint32_t size, void *out);
1595+
const void *in, uint32_t size, void *out);
15961596

1597-
/**
1598-
* \brief Perform an exclusive scan / prefix sum over an unsigned 32 bit integer
1599-
* array
1597+
/** \brief Compute a prefix sum over the given input array
1598+
*
1599+
* Both exclusive and inclusive variants are supported. If desired, the scan
1600+
* can be performed in-place (i.e., <tt>out == in</tt>). The operation runs
1601+
* asynchronously.
1602+
*
1603+
* The operation is currently implemented for the following numeric types:
1604+
* ``VarType::Int32``, ``VarType::UInt32``, ``VarType::UInt64``,
1605+
* ``VarType::Float32``, and ``VarType::Float64``.
16001606
*
1601-
* If desired, the scan can be performed in-place (i.e. <tt>in == out</tt>).
1602-
* Note that the CUDA implementation will round up \c size to the maximum of
1603-
* the following three values for performance reasons:
1607+
* Note that the CUDA implementation may round \c size to the maximum of the
1608+
* following three values for performance and implementation-related reasons
1609+
* (the prefix sum uses a tree-based parallelization scheme).
16041610
*
1605-
* - the value 4,
1611+
* - the value 4
16061612
* - the next highest power of two (when size <= 4096),
16071613
* - the next highest multiple of 2K (when size > 4096),
16081614
*
16091615
* For this reason, the supplied memory regions must be sufficiently large
1610-
* to avoid both out-of-bounds reads and writes. This is not an issue for
1611-
* memory obtained using \ref jit_malloc(), which internally rounds
1612-
* allocations to the next largest power of two and enforces a 64 byte minimum
1613-
* allocation size.
1614-
*
1615-
* Runs asynchronously.
1616-
*/
1617-
extern JIT_EXPORT void jit_scan_u32(JIT_ENUM JitBackend backend, const uint32_t *in,
1618-
uint32_t size, uint32_t *out);
1616+
* to avoid out-of-bounds reads and writes. This is not an issue for memory
1617+
* obtained using \ref jit_malloc(), which internally rounds allocations to the
1618+
* next largest power of two and enforces a 64 byte minimum allocation size.
1619+
*
1620+
* The CUDA backend implementation for *large* numeric types (double precision
1621+
* floats, 64 bit integers) has the following technical limitation: when
1622+
* reducing 64-bit integers, their values must be smaller than 2**62. When
1623+
* reducing double precision arrays, the two least significant mantissa bits
1624+
* are clamped to zero when forwarding the prefix from one 512-wide block to
1625+
* the next (at a very minor loss in accuracy). The reason is that the
1626+
* operation requires two status bits to coordinate the prefix and status of
1627+
* each 512-wide block, and those must each fit into a single 64 bit value
1628+
* (128-bit writes aren't guaranteed to be atomic).
1629+
*/
1630+
extern JIT_EXPORT void jit_scan(JIT_ENUM JitBackend backend,
1631+
JIT_ENUM VarType type, int exclusive,
1632+
const void *in, uint32_t size, void *out);
16191633

16201634
/**
16211635
* \brief Compress a mask into a list of nonzero indices

resources/Makefile

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
COMPUTE_CAPABILITY=compute_70
22
CUDA_VER=10.2
3-
NVCC=/usr/local/cuda-$(CUDA_VER)/bin/nvcc -m64 --ptx --expt-relaxed-constexpr
3+
NVCC=/usr/local/cuda-$(CUDA_VER)/bin/nvcc -m64 --ptx --expt-relaxed-constexpr -std=c++14
44

55
all: kernels.h
66

resources/common.h

+4
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
#include <limits>
66

77
#define KERNEL extern "C" __global__
8+
#define DEVICE __device__
9+
#define FINLINE __forceinline__
10+
#define WARP_SIZE 32
11+
#define FULL_MASK 0xffffffff
812

913
template <typename T> struct SharedMemory {
1014
__device__ inline static T *get() {

resources/compress.cuh

+10
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,16 @@
1010

1111
#include "common.h"
1212

13+
DEVICE FINLINE void store_cg(uint64_t *ptr, uint64_t val) {
14+
asm volatile("st.cg.u64 [%0], %1;" : : "l"(ptr), "l"(val));
15+
}
16+
17+
DEVICE FINLINE uint64_t load_cg(uint64_t *ptr) {
18+
uint64_t retval;
19+
asm volatile("ld.cg.u64 %0, [%1];" : "=l"(retval) : "l"(ptr));
20+
return retval;
21+
}
22+
1323
KERNEL void compress_small(const uint8_t *in, uint32_t *out, uint32_t size, uint32_t *count_out) {
1424
uint32_t *shared = SharedMemory<uint32_t>::get();
1525

0 commit comments

Comments
 (0)