I was testing this example from Discourse on the main branch of KernelAbstractions.jl, to see if PoCL helps speed up the matrix case where we lagged behind JAX, but the example stochastically segfaults on my machine.

MWE:
```julia
using KernelAbstractions

abstract type Params end

struct Data
    x1::Float64
    x2::Float64
end

const DataMatVec{Data} = Union{AbstractVector{Data},AbstractMatrix{Data}}

const BLOCK_SIZE = 256

# Interface for primary and secondary pdf: these have to be implemented case by case
function pdf_x1 end
function cdf_x1 end
function pdf_x2_given_x1 end
function cdf_x2_given_x1 end

@kernel function _pdf_x1x2_kernel!(res, data::DataMatVec, params::Params, inv_norm_p1)
    i = @index(Global)
    @inbounds res[i] = pdf_x1(data[i], params) * inv_norm_p1 *
                       pdf_x2_given_x1(data[i], params) * inv(cdf_x2_given_x1(data[i], params))
end

function pdf_x1x2(data::DataMatVec, params::Params)
    backend = get_backend(data)
    res = similar(data, Float64)
    norm_p1 = cdf_x1(params)
    inv_norm_p1 = inv(norm_p1)
    kernel! = _pdf_x1x2_kernel!(backend)
    kernel!(res, data, params, inv_norm_p1; ndrange=size(data), workgroupsize=BLOCK_SIZE)
    res
end

# Core functions
function tpl(x::Float64, alpha, x_min::Float64, x_max::Float64)::Float64
    if x_min < x < x_max
        return @fastmath x^alpha
    else
        return 0.0
    end
end

function tpl_cdf(x::Float64, alpha, x_min::Float64)::Float64
    if alpha == -1
        return @fastmath log(x_min) - log(x)
    else
        return @fastmath (x^(1 + alpha) - x_min^(1 + alpha)) / (1 + alpha)
    end
end

struct TPLParams <: Params
    x_low::Float64
    x_high::Float64
    alpha::Float64
    beta::Float64
end

pdf_x1(data::Data, params::TPLParams) = tpl(data.x1, -params.alpha, params.x_low, params.x_high)
cdf_x1(params::TPLParams) = tpl_cdf(params.x_high, -params.alpha, params.x_low)
pdf_x2_given_x1(data::Data, params::TPLParams) = tpl(data.x2, params.beta, params.x_low, data.x1)
cdf_x2_given_x1(data::Data, params::TPLParams) = tpl_cdf(data.x1, params.beta, params.x_low)

function bench()
    params = TPLParams(0.1, 5.0, 2.0, 1.5)
    x1_vec = range(2, 4, 5000)
    x2_vec = range(1, 3, 5000)
    x1_mat = repeat(x1_vec', 300, 1)
    x2_mat = repeat(x2_vec', 300, 1)
    data_mat = Data.(x1_mat, x2_mat)
    print("Bench CPU Matrix: "); @time pdf_x1x2(data_mat, params)
end
```
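For context, the per-element math is a normalized truncated power law: writing $\gamma$ for the exponent argument `alpha`, `tpl` returns $x^{\gamma}$ on $(x_{\min}, x_{\max})$ (else $0$), and for $\gamma \neq -1$ `tpl_cdf` computes

$$
\int_{x_{\min}}^{x} t^{\gamma}\,\mathrm{d}t \;=\; \frac{x^{1+\gamma} - x_{\min}^{1+\gamma}}{1+\gamma},
$$

so each kernel entry is `pdf_x1 / cdf_x1 * pdf_x2_given_x1 / cdf_x2_given_x1`, pure `Float64` arithmetic. Since `ndrange = size(data)` and `res = similar(data, Float64)`, the `@inbounds` accesses should all be in range, so I don't see how the user code itself could fault.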
```
julia> versioninfo()
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 12 × AMD Ryzen 5 5600X 6-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 6 default, 0 interactive, 3 GC (on 12 virtual cores)
Environment:
  JULIA_NUM_THREADS = 6
```
```
(AK_CUDA_example_discourse) pkg> st
Status `~/Nextcloud/Julia/scrap/AK_CUDA_example_discourse/Project.toml`
  [63c18a36] KernelAbstractions v0.10.0-dev `https://github.com/JuliaGPU/KernelAbstractions.jl.git#main`
```
```
julia> bench()
Bench CPU Matrix:  10.136724 seconds (8.23 M allocations: 421.202 MiB, 0.86% gc time, 98.64% compilation time: 2% of which was recompilation)

julia> bench()
Bench CPU Matrix:  0.065198 seconds (65 allocations: 11.449 MiB, 6.96% gc time)

julia> bench()
Bench CPU Matrix:
[11033] signal 11 (1): Segmentation fault
in expression starting at REPL[2]:1
Allocations: 28692815 (Pool: 28692068; Big: 747); GC: 30

zsh: segmentation fault (core dumped)  julia --project=.
```
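The crash is not deterministic (here it hit on the third call). A stress loop along these lines (just a sketch; the iteration count is arbitrary) can be used to retrigger it:

```julia
# Stress loop to provoke the stochastic segfault; the iteration
# count is arbitrary (the session above crashed on the third call).
for i in 1:20
    @info "bench run" i
    bench()
end
```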