Skip to content

Commit 2441686

Browse files
committed
fix: non-contiguous indexing is now supported
1 parent 3b8a215 commit 2441686

File tree

2 files changed

+61
-5
lines changed

2 files changed

+61
-5
lines changed

src/Ops.jl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1278,4 +1278,47 @@ function scatter_setindex(
12781278
)
12791279
end
12801280

1281+
"""
1282+
gather_getindex(src, gather_indices)
1283+
1284+
Uses [`MLIR.Dialects.stablehlo.gather`](@ref) to get the values of `src` at the indices
1285+
specified by `gather_indices`. If the indices are contiguous it is recommended to directly
1286+
use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead.
1287+
"""
1288+
function gather_getindex(
1289+
src::TracedRArray{T,N}, gather_indices::TracedRArray{Int64,2}
1290+
) where {T,N}
1291+
@assert size(gather_indices, 2) == N
1292+
1293+
#! format: off
1294+
dimension_numbers = MLIR.API.stablehloGatherDimensionNumbersGet(
1295+
MLIR.IR.context(),
1296+
Int64(1), Int64[1],
1297+
Int64(N - 1), collect(Int64, 0:(N - 2)),
1298+
Int64(0), Int64[],
1299+
Int64(0), Int64[],
1300+
Int64(N), collect(Int64, 0:(N - 1)),
1301+
1
1302+
)
1303+
#! format: on
1304+
1305+
return reshape(
1306+
TracedRArray{T,2}(
1307+
(),
1308+
MLIR.IR.result(
1309+
MLIR.Dialects.stablehlo.gather(
1310+
src.mlir_data,
1311+
gather_indices.mlir_data;
1312+
dimension_numbers,
1313+
slice_sizes=fill(Int64(1), N),
1314+
indices_are_sorted=false,
1315+
),
1316+
1,
1317+
),
1318+
(size(gather_indices, 1), 1),
1319+
),
1320+
size(gather_indices, 1),
1321+
)
1322+
end
1323+
12811324
end # module Ops

src/TracedRArray.jl

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ mutable struct TracedRArray{T,N} <: RArray{T,N}
1111
) where {T,N}
1212
shape = Tuple(shape)
1313
if !isnothing(mlir_data)
14-
@assert size(MLIR.IR.type(mlir_data)) == shape
14+
@assert size(MLIR.IR.type(mlir_data)) == shape "Expected: $(shape), got: $(size(MLIR.IR.type(mlir_data)))"
1515
end
1616
return new{T,N}(paths, mlir_data, shape)
1717
end
@@ -114,21 +114,34 @@ function Base.getindex(a::TracedRArray{T,0}) where {T}
114114
return TracedRNumber{T}((), a.mlir_data)
115115
end
116116

117-
# XXX: We want to support https://github.com/EnzymeAD/Reactant.jl/issues/242 eventually
118117
function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
119118
indices = map(enumerate(indices)) do (idx, i)
120119
i isa Colon && return 1:size(a, idx)
121120
i isa CartesianIndex && return Tuple(i)
122121
return i
123122
end
124123

124+
non_contiguous_getindex = false
125125
for idxs in indices
126126
idxs isa Number && continue
127127
contiguous = all(isone, diff(idxs))
128128
# XXX: We want to throw error even for dynamic indexing
129-
if typeof(contiguous) <: Bool
130-
contiguous || error("non-contiguous indexing is not supported")
129+
if typeof(contiguous) <: Bool && !contiguous
130+
non_contiguous_getindex = true
131+
break
132+
end
133+
end
134+
135+
if non_contiguous_getindex
136+
indices_tuples = collect(Iterators.product(indices...))
137+
indices = Matrix{Int}(undef, (length(indices_tuples), 2))
138+
for (i, idx) in enumerate(indices_tuples)
139+
indices[i, 1] = idx[1] - 1
140+
indices[i, 2] = idx[2] - 1
131141
end
142+
indices = promote_to(TracedRArray{Int,2}, indices)
143+
res = Ops.gather_getindex(a, indices)
144+
return Ops.reshape(res, size(indices_tuples)...)
132145
end
133146

134147
start_indices = map(indices) do i
@@ -179,7 +192,7 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {
179192
indices[i, 1] = idx[1] - 1
180193
indices[i, 2] = idx[2] - 1
181194
end
182-
indices = promote_to(TracedRArray{Int, 2}, indices)
195+
indices = promote_to(TracedRArray{Int,2}, indices)
183196
res = Ops.scatter_setindex(a, indices, Ops.reshape(v, length(v)))
184197
a.mlir_data = res.mlir_data
185198
return v

0 commit comments

Comments
 (0)