feat: partial NNlib.gather support + better indexing support #252
Changes from all commits
@@ -296,4 +296,64 @@ function NNlib.make_causal_mask(x::AnyTracedRArray; dims::Int=2)
     )
 end
 
+# XXX: reevaluate this manual optimization once
+# https://github.com/EnzymeAD/Enzyme-JAX/issues/164 is handled
+function NNlib.gather!(
+    dst::TracedRArray{T1,2},
+    src::AnyTracedRArray{T2,2},
+    idxs::Union{AbstractUnitRange{<:Number}},
+) where {T1,T2}
+    dst.mlir_data = src[:, idxs].mlir_data
+    return dst
+end
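For context, a minimal sketch of how this fast path is reached from user code (illustrative, not part of the diff; assumes the usual `Reactant.@compile` workflow):

    using NNlib, Reactant

    x = Reactant.ConcreteRArray(rand(Float32, 4, 6))

    # Contiguous unit-range indices dispatch to the method above, so the
    # gather lowers to a plain column slice `src[:, 2:4]` rather than a
    # gather op.
    f(src) = NNlib.gather(src, 2:4)

    f_compiled = Reactant.@compile f(x)
    f_compiled(x)  # 4×3 result: columns 2:4 of x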
+
+function NNlib.gather!(
+    dst::TracedRArray{T1,2}, src::AnyTracedRArray{T2,2}, idxs::AbstractVector{<:Number}
+) where {T1,T2}
+    dims = NNlib.scatter_dims(src, dst, idxs)
+    @assert dims == 1 # scatter_dims lets us do some size checks, so we call that function
+    idxs = (Reactant.promote_to(TracedRArray{Int,1}, idxs) .- 1).mlir_data
+    slice_sizes = Reactant.promote_to(TracedRArray{Int,1}, [size(src, 1), 1]).mlir_data
+
+    #! format: off
+    dimension_numbers = MLIR.API.stablehloGatherDimensionNumbersGet(
+        MLIR.IR.context(),
+        Int64(1), Int64[0],
+        Int64(1), Int64[1],
+        Int64(0), Int64[],
+        Int64(0), Int64[],
+        Int64(1), Int64[1],
+        Int64(1)
+    )
+    #! format: on
+
+    res = MLIR.IR.result(
+        Reactant.MLIR.Dialects.stablehlo.dynamic_gather(
+            src.mlir_data, idxs, slice_sizes; dimension_numbers
+        ),
+        1,
+    )
+    dst.mlir_data = res
+    return dst
+end
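The gather dimension numbers above, together with `slice_sizes = [size(src, 1), 1]`, encode that each (zero-based) index picks out one whole column of `src`. A plain-NNlib reference for the result being computed (illustrative, not part of the diff):

    using NNlib

    src  = Float32[1 2 3;
                   4 5 6]
    idxs = [3, 1]

    # For a 2-D source and an integer vector, NNlib.gather selects whole
    # columns, i.e. it matches src[:, idxs].
    NNlib.gather(src, idxs) == src[:, idxs]  # true: Float32[3 1; 6 4]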
+
+# XXX: For performance, use `stablehlo.dynamic_gather` or at least use a traced loop
+# instead of unrolling the loop (the case for AbstractArray can just use
+# `stablehlo.gather`). See above for the special-case implementation that is optimized.
+function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractArray)
+    @warn "Using fallback implementation of `gather!` using `stablehlo.dynamic_slice`. \
+           This case is not optimized and will be slow." maxlog = 1
+    dims = NNlib.scatter_dims(src, dst, idxs)
+    colons = ntuple(Returns(Colon()), dims)
+    start_sizes = ntuple(i -> size(src, i), dims)
+    results = map(CartesianIndices(idxs)) do k
+        res = src[colons..., Tuple(idxs[k])...]
+        res isa TracedRNumber && (res = Reactant.broadcast_to_size(res, (1,)))
+        return reshape(res, start_sizes..., :)
+    end
+    res = reshape(cat(results...; dims=(dims + 1)), size(dst))
+    dst.mlir_data = res.mlir_data
+    return dst
+end

Review thread on the `@warn` call:
- "If you'd like, you can put this behind a global to make sure it's only printed once (I think other indexing does that). Though I'm also fine with having it always warn."
- "Ah, that's what maxlog does!"

Review thread on the fallback implementation:
- "Can we emit a warning here at least in the interim?"
- "I don't think this is the right way to go. Even for a small test case (nanoGPT) it takes forever to compile. Let me try to understand the stablehlo gather and get it fixed."
- "👍, if you'd like, feel free to separate the dynamic_slice stuff into a different PR."
- "Optimized the common cases and printing a warning for the other cases."
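The compile-time blowup discussed in the review thread comes from the unrolling: the traced program receives one slice per index. A plain-Julia model of the fallback's behavior for the 2-D case (hypothetical helper, not part of the PR):

    # One slice per index, concatenated along the trailing dimension; the
    # emitted IR therefore grows linearly with length(idxs).
    function gather_unrolled(src::AbstractMatrix, idxs::AbstractVector{<:Integer})
        slices = [reshape(src[:, i], :, 1) for i in idxs]
        return reduce(hcat, slices)
    end

    gather_unrolled(Float32[1 2 3; 4 5 6], [3, 1])  # Float32[3.0 1.0; 6.0 4.0]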
 
 end # module ReactantNNlibExt
@@ -56,7 +56,9 @@ function get_ancestor_indices(x::WrappedTracedRArray, indices...)
     return get_ancestor_indices(parent(x), Base.reindex(parentindices(x), indices)...)
 end
 
-function Base.getindex(a::TracedRArray{T,N}, index::Vararg{Int,N}) where {T,N}
+function Base.getindex(
+    a::TracedRArray{T,N}, index::Vararg{Union{Int,TracedRNumber{Int}},N}
+) where {T,N}
     @warn(
         """Performing scalar indexing on task $(current_task()).
 Invocation resulted in scalar indexing of a TracedRArray.
 
@@ -65,49 +67,59 @@ Such implementations *do not* execute on device, but very slowly on the CPU,
 and require expensive copies and synchronization each time and therefore should be avoided."""
     )
 
+    start_indices = [promote_to(TracedRNumber{Int}, i - 1).mlir_data for i in index]
+    slice_sizes = [Int64(1) for _ in index]
+
     res1 = MLIR.IR.result(
-        MLIR.Dialects.stablehlo.slice(
-            a.mlir_data;
-            start_indices=MLIR.IR.DenseArrayAttribute([Int64(i - 1) for i in index]),
-            limit_indices=MLIR.IR.DenseArrayAttribute([Int64(i) for i in index]),
-            strides=MLIR.IR.DenseArrayAttribute([Int64(1) for i in index]),
-        ),
-        1,
+        MLIR.Dialects.stablehlo.dynamic_slice(a.mlir_data, start_indices; slice_sizes), 1
     )
     res2 = MLIR.IR.result(
         MLIR.Dialects.stablehlo.reshape(
             res1; result_0=MLIR.IR.TensorType(Int64[], eltype(MLIR.IR.type(res1)))
         ),
         1,
     )
 
     return TracedRNumber{T}((), res2)
 end
 
+function Base.getindex(a::TracedRArray{T,0}) where {T}
+    return TracedRNumber{T}((), a.mlir_data)
+end
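The switch from `stablehlo.slice` to `stablehlo.dynamic_slice` is what enables `TracedRNumber{Int}` indices: `slice` bakes its start indices into the IR as constant attributes, while `dynamic_slice` takes them as runtime scalar operands. A sketch of the user-visible effect (illustrative, not part of the diff; assumes `ConcreteRNumber` for passing a runtime scalar):

    using Reactant

    g(a, i) = a[i, 1]  # i is traced, so its value is unknown at trace time

    x = Reactant.ConcreteRArray(rand(Float32, 5, 5))
    i = Reactant.ConcreteRNumber(3)

    g_compiled = Reactant.@compile g(x, i)
    g_compiled(x, i)  # scalar x[3, 1], via stablehlo.dynamic_slice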
 
 # XXX: We want to support https://github.com/EnzymeAD/Reactant.jl/issues/242 eventually
 function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
-    indices = [i isa Colon ? (1:size(a, idx)) : i for (idx, i) in enumerate(indices)]
+    indices = map(enumerate(indices)) do (idx, i)
+        i isa Colon && return 1:size(a, idx)
+        i isa CartesianIndex && return Tuple(i)
+        return i
+    end
+
+    foreach(indices) do idxs
+        idxs isa Number && return nothing
+        contiguous = all(isone, diff(idxs))
+        # XXX: We want to throw error even for dynamic indexing
+        if typeof(contiguous) <: Bool
+            contiguous || error("non-contiguous indexing is not supported")
+        end
+    end
+
+    start_indices = map(indices) do i
+        return promote_to(TracedRNumber{Int}, first(i) - 1).mlir_data
+    end
+    slice_sizes = [Int64(length(i)) for i in indices]
     res = MLIR.IR.result(
-        MLIR.Dialects.stablehlo.slice(
-            a.mlir_data;
-            start_indices=MLIR.IR.DenseArrayAttribute([
-                Int64(first(i) - 1) for i in indices
-            ]),
-            limit_indices=MLIR.IR.DenseArrayAttribute([Int64(last(i)) for i in indices]),
-            strides=MLIR.IR.DenseArrayAttribute([Int64(1) for i in indices]),
-        ),
-        1,
+        MLIR.Dialects.stablehlo.dynamic_slice(a.mlir_data, start_indices; slice_sizes), 1
     )
 
     x = TracedRArray{T,N}((), res, Tuple(length.(indices)))
-    ddims = findall(x -> x isa Integer, indices)
-    !isempty(ddims) && return dropdims(x; dims=Tuple(ddims))
+    ddims = findall(Base.Fix2(isa, Integer), indices)
+    isempty(ddims) || return dropdims(x; dims=Tuple(ddims))
     return x
 end

Review comment on this method: "Yeah, I think this should be doable with gather. That part I'm less confident we have all the optimization rules to lower into dynamic slice."
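The method mirrors standard Julia indexing semantics: `Colon`s become full ranges, contiguous ranges become a single `dynamic_slice`, and plain-integer positions are dropped from the result shape. A plain-Julia reference for the shape behavior (not part of the diff):

    a = rand(Float32, 4, 5, 6)

    size(a[:, 2:3, 4])      # (4, 2): the integer index drops dimension 3
    size(a[1:1, 2:3, 4:4])  # (1, 2, 1): ranges keep their dimensions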
 
 # Prevent ambiguity
-function Base.getindex(a::WrappedTracedRArray, index::Int...)
+function Base.getindex(a::WrappedTracedRArray, index::Union{Int,TracedRNumber{Int}}...)
     return getindex(ancestor(a), get_ancestor_indices(a, index...)...)
 end
 
@@ -116,7 +128,9 @@ function Base.getindex(a::WrappedTracedRArray, indices...)
 end
 
 function Base.setindex!(
-    a::TracedRArray{T,N}, v, indices::Vararg{Union{Base.AbstractUnitRange,Colon,Int},N}
+    a::TracedRArray{T,N},
+    v,
+    indices::Vararg{Union{Base.AbstractUnitRange,Colon,Int,TracedRNumber{Int}},N},
 ) where {T,N}
     indices = map(enumerate(indices)) do (idx, i)
         i isa Int ? (i:i) : (i isa Colon ? (1:size(a, idx)) : i)
 
@@ -138,13 +152,17 @@ function Base.setindex!(
 end
 
 function Base.setindex!(
-    a::AnyTracedRArray{T,N}, v, indices::Vararg{Union{Base.AbstractUnitRange,Colon,Int},N}
+    a::AnyTracedRArray{T,N},
+    v,
+    indices::Vararg{Union{Base.AbstractUnitRange,Colon,Int,TracedRNumber{Int}},N},
 ) where {T,N}
     ancestor_indices = get_ancestor_indices(a, indices...)
     setindex!(ancestor(a), v, ancestor_indices...)
     return a
 end
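With `TracedRNumber{Int}` admitted in the index `Vararg`, assignments at traced positions can now be traced too. A hedged sketch (illustrative, not from the PR; assumes the same `@compile`/`ConcreteRNumber` workflow as above):

    using Reactant

    function set_entry!(a, i)
        a[i, 1] = zero(eltype(a))  # i::TracedRNumber{Int} is now accepted
        return a
    end

    x = Reactant.ConcreteRArray(ones(Float32, 3, 3))
    i = Reactant.ConcreteRNumber(2)
    set_compiled = Reactant.@compile set_entry!(x, i)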
 
+Base.Tuple(x::TracedRArray) = ntuple(Base.Fix1(Base.getindex, x), length(x))
 
 Base.size(x::TracedRArray) = x.shape
 
 Base.copy(A::TracedRArray{T,N}) where {T,N} = TracedRArray{T,N}((), A.mlir_data, size(A))
 
@@ -699,7 +717,7 @@ end
 
 function broadcast_to_size(arg::T, rsize) where {T<:Number}
     attr = MLIR.IR.DenseElementsAttribute(Base.fill(arg, Tuple(rsize)))
-    return arg = TracedRArray{T,length(rsize)}(
+    return TracedRArray{T,length(rsize)}(
         (), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1), rsize
     )
 end
 
@@ -711,6 +729,11 @@ function broadcast_to_size(arg::TracedRNumber, rsize)
     )
 end
 
+function broadcast_to_size(arg::AnyTracedRArray{T,0}, rsize) where {T}
+    arg = materialize_traced_array(arg)
+    return broadcast_to_size(TracedRNumber{T}((), arg.mlir_data), rsize)
+end
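The new 0-dimensional method re-wraps the scalar payload of a 0-dim traced array as a `TracedRNumber` so it can reuse the scalar broadcast path above. The plain-Julia behavior being matched (not part of the diff):

    z = fill(1.5f0)            # 0-dimensional Array{Float32,0}
    z .+ zeros(Float32, 2, 3)  # 2×3 matrix of 1.5f0: 0-dim arrays broadcast like scalars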
 
 function broadcast_to_size(arg::AnyTracedRArray, rsize)
     arg = materialize_traced_array(arg)
     size(arg) == rsize && return arg
 
@@ -856,3 +879,6 @@ for (minT, maxT) in Iterators.product((Number, TracedRNumber), (Number, TracedRNumber))
         return x
     end
 end
 
+Base.all(f::Function, x::TracedRArray) = mapreduce(f, &, x)
+Base.any(f::Function, x::TracedRArray) = mapreduce(f, |, x)
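These definitions route predicate reductions through the traced `mapreduce`, so `all`/`any` stay in the compiled program instead of falling back to scalar iteration. The equivalence being relied on, in plain Julia (not part of the diff):

    x = Float32[1 2; 3 4]

    all(>(0), x) == mapreduce(>(0), &, x)  # true
    any(>(3), x) == mapreduce(>(3), |, x)  # true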