Skip to content

Commit 6e92e5f

Browse files
authored
[SYCL][E2E] Add free function kernel tests as device function and host function (#18824)
This PR adds new e2e tests for free function kernels extension based on test plan https://github.com/intel/llvm/blob/sycl/sycl/test-e2e/FreeFunctionKernels/test-plan.md#perform-test-that-free-function-kernel-can-be-used-as-device-function-within-another-kernel Extension spec: https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/proposed/sycl_ext_oneapi_free_function_kernels.asciidoc The overall idea behind this test to verify if free function kernel even if marked with one of the properties (`nd_range_kernel` and `single_task_kernel`) can still be used as device or host function.
1 parent 335b7a0 commit 6e92e5f

File tree

2 files changed

+142
-0
lines changed

2 files changed

+142
-0
lines changed
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
// REQUIRES: aspect-usm_shared_allocations
2+
// RUN: %{build} -o %t.out
3+
// RUN: %{run} %t.out
4+
5+
// This test verifies whether free function kernel can be used as device
6+
// function within another kernel or can be used as normal host function.
7+
8+
#include <numeric>
9+
10+
#include <sycl/usm.hpp>
11+
12+
#include <sycl/ext/oneapi/free_function_queries.hpp>
13+
14+
#include "helpers.hpp"
15+
16+
template <typename T, int Dims>
17+
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<Dims>))
18+
void setValues(T *DataPtr, size_t N, T ExpectedResult) {
19+
#if __SYCL_DEVICE_ONLY__
20+
auto GlobalLinId =
21+
syclext::this_work_item::get_nd_item<Dims>().get_global_linear_id();
22+
DataPtr[GlobalLinId] = ExpectedResult;
23+
#else
24+
for (size_t I = 0; I < N; ++I)
25+
DataPtr[I] = ExpectedResult;
26+
#endif
27+
}
28+
29+
template <typename T>
30+
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::single_task_kernel))
31+
void performReverse(T *DataPtr, size_t N) {
32+
for (size_t I = 0, J = N - 1; I < J; ++I, --J) {
33+
std::swap(DataPtr[I], DataPtr[J]);
34+
}
35+
}
36+
37+
namespace ns {
38+
template <typename T>
39+
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::single_task_kernel))
40+
void singleTaskKernel(T *DataPtr, size_t N) {
41+
performReverse(DataPtr, N);
42+
}
43+
} // namespace ns
44+
45+
template <typename T, int Dims>
46+
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<Dims>))
47+
void ndRangekernel(T *DataPtr, size_t N, T ExpectedResult) {
48+
setValues<T, Dims>(DataPtr, N, ExpectedResult);
49+
}
50+
51+
int main() {
52+
int Failed = 0;
53+
constexpr size_t N = 256;
54+
{
55+
constexpr int ExpectedResultValue = 111;
56+
std::array<int, N> Numbers;
57+
std::fill(Numbers.begin(), Numbers.end(), 0);
58+
setValues<int, 1>(Numbers.data(), Numbers.size(), ExpectedResultValue);
59+
Failed += performResultCheck(
60+
N, Numbers.data(),
61+
"setValues() free function kernel used as normal host function",
62+
ExpectedResultValue);
63+
}
64+
65+
{
66+
std::array<int, N> Numbers;
67+
std::iota(Numbers.begin(), Numbers.end(), 0);
68+
std::array<int, N> ExpectedResultValues;
69+
std::iota(ExpectedResultValues.begin(), ExpectedResultValues.end(), 0);
70+
std::reverse(ExpectedResultValues.begin(), ExpectedResultValues.end());
71+
performReverse(Numbers.data(), Numbers.size());
72+
Failed += performResultCheck<N>(
73+
Numbers.data(),
74+
"performReverse() free function kernel used as normal host function",
75+
ExpectedResultValues);
76+
}
77+
78+
sycl::queue Queue;
79+
sycl::context Context = Queue.get_context();
80+
81+
{
82+
sycl::kernel UsedKernel = getKernel<ns::singleTaskKernel<int>>(Context);
83+
std::array<int, N> ExpectedResultValues;
84+
std::iota(ExpectedResultValues.begin(), ExpectedResultValues.end(), 0);
85+
std::reverse(ExpectedResultValues.begin(), ExpectedResultValues.end());
86+
87+
int *DataPtr = sycl::malloc_shared<int>(N, Queue);
88+
std::iota(DataPtr, DataPtr + N, 0);
89+
90+
Queue
91+
.submit([&](sycl::handler &Handler) {
92+
Handler.set_args(DataPtr, N);
93+
Handler.single_task(UsedKernel);
94+
})
95+
.wait();
96+
Failed += performResultCheck<N>(
97+
DataPtr,
98+
"performReverse() free function kernel used as device function within "
99+
"another kernel",
100+
ExpectedResultValues);
101+
sycl::free(DataPtr, Queue);
102+
}
103+
104+
{
105+
constexpr int ExpectedResultValue = 222;
106+
107+
sycl::kernel UsedKernel = getKernel<ndRangekernel<int, 3>>(Context);
108+
109+
sycl::nd_range NdRange{sycl::range{16, 4, 4}, sycl::range{2, 2, 2}};
110+
size_t NumberOfElements = NdRange.get_global_range().size();
111+
int *DataPtr = sycl::malloc_shared<int>(NumberOfElements, Queue);
112+
std::fill(DataPtr, DataPtr + NumberOfElements, 0);
113+
Queue
114+
.submit([&](sycl::handler &Handler) {
115+
Handler.set_args(DataPtr, NumberOfElements, ExpectedResultValue);
116+
Handler.parallel_for(NdRange, UsedKernel);
117+
})
118+
.wait();
119+
120+
Failed += performResultCheck(NumberOfElements, DataPtr,
121+
"setValues() free function kernel used as "
122+
"device function within another kernel",
123+
ExpectedResultValue);
124+
sycl::free(DataPtr, Queue);
125+
}
126+
return Failed;
127+
}

sycl/test-e2e/FreeFunctionKernels/helpers.hpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,21 @@ static int performResultCheck(size_t NumberOfElements, const T *ResultPtr,
2121
return Failed;
2222
}
2323

24+
template <size_t NumOfElements, typename T, typename S>
25+
static int
26+
performResultCheck(const T *ResultPtr, std::string_view TestName,
27+
std::array<S, NumOfElements> ExpectedResultValue) {
28+
int Failed{0};
29+
for (size_t i = 0; i < NumOfElements; i++) {
30+
if (ResultPtr[i] != ExpectedResultValue[i]) {
31+
std::cerr << "Failed " << TestName << " : " << ResultPtr[i]
32+
<< " != " << ExpectedResultValue[i] << std::endl;
33+
++Failed;
34+
}
35+
}
36+
return Failed;
37+
}
38+
2439
template <auto *Func> static sycl::kernel getKernel(sycl::context &Context) {
2540
sycl::kernel_bundle<sycl::bundle_state::executable> KernelBundle =
2641
syclexp::get_kernel_bundle<Func, sycl::bundle_state::executable>(Context);

0 commit comments

Comments
 (0)