@@ -11,7 +11,7 @@ mutable struct TracedRArray{T,N} <: RArray{T,N}
11
11
) where {T,N}
12
12
shape = Tuple (shape)
13
13
if ! isnothing (mlir_data)
14
- @assert size (MLIR. IR. type (mlir_data)) == shape
14
+ @assert size (MLIR. IR. type (mlir_data)) == shape " Expected: $(shape) , got: $( size (MLIR . IR . type (mlir_data))) "
15
15
end
16
16
return new {T,N} (paths, mlir_data, shape)
17
17
end
@@ -114,21 +114,34 @@ function Base.getindex(a::TracedRArray{T,0}) where {T}
114
114
return TracedRNumber {T} ((), a. mlir_data)
115
115
end
116
116
117
- # XXX : We want to support https://github.com/EnzymeAD/Reactant.jl/issues/242 eventually
118
117
function Base. getindex (a:: TracedRArray{T,N} , indices:: Vararg{Any,N} ) where {T,N}
119
118
indices = map (enumerate (indices)) do (idx, i)
120
119
i isa Colon && return 1 : size (a, idx)
121
120
i isa CartesianIndex && return Tuple (i)
122
121
return i
123
122
end
124
123
124
+ non_contiguous_getindex = false
125
125
for idxs in indices
126
126
idxs isa Number && continue
127
127
contiguous = all (isone, diff (idxs))
128
128
# XXX : We want to throw error even for dynamic indexing
129
- if typeof (contiguous) <: Bool
130
- contiguous || error (" non-contiguous indexing is not supported" )
129
+ if typeof (contiguous) <: Bool && ! contiguous
130
+ non_contiguous_getindex = true
131
+ break
132
+ end
133
+ end
134
+
135
+ if non_contiguous_getindex
136
+ indices_tuples = collect (Iterators. product (indices... ))
137
+ indices = Matrix {Int} (undef, (length (indices_tuples), 2 ))
138
+ for (i, idx) in enumerate (indices_tuples)
139
+ indices[i, 1 ] = idx[1 ] - 1
140
+ indices[i, 2 ] = idx[2 ] - 1
131
141
end
142
+ indices = promote_to (TracedRArray{Int,2 }, indices)
143
+ res = Ops. gather_getindex (a, indices)
144
+ return Ops. reshape (res, size (indices_tuples)... )
132
145
end
133
146
134
147
start_indices = map (indices) do i
@@ -179,7 +192,7 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {
179
192
indices[i, 1 ] = idx[1 ] - 1
180
193
indices[i, 2 ] = idx[2 ] - 1
181
194
end
182
- indices = promote_to (TracedRArray{Int, 2 }, indices)
195
+ indices = promote_to (TracedRArray{Int,2 }, indices)
183
196
res = Ops. scatter_setindex (a, indices, Ops. reshape (v, length (v)))
184
197
a. mlir_data = res. mlir_data
185
198
return v
0 commit comments