|
18 | 18 | function materialize_traced_array(
|
19 | 19 | x::LinearAlgebra.Tridiagonal{T,TracedRArray{T,1}}
|
20 | 20 | ) where {T}
|
21 |
| - scatter_indices = vcat( |
22 |
| - diagonal_indices_zero_indexed(size(x, 1), size(x, 2), -1), |
23 |
| - diagonal_indices_zero_indexed(size(x, 1), size(x, 2), 0), |
24 |
| - diagonal_indices_zero_indexed(size(x, 1), size(x, 2), 1), |
25 |
| - ) |
26 |
| - scatter_indices = Ops.constant(scatter_indices) |
27 |
| - |
28 |
| - updates = TracedRArray{T,1}( |
29 |
| - (), |
30 |
| - MLIR.IR.result( |
31 |
| - MLIR.Dialects.stablehlo.concatenate( |
32 |
| - [x.dl.mlir_data, x.d.mlir_data, x.du.mlir_data]; dimension=0 |
33 |
| - ), |
34 |
| - 1, |
35 |
| - ), |
36 |
| - (size(scatter_indices, 1),), |
37 |
| - ) |
38 |
| - |
39 |
| - return simple_scatter_op(size(x), scatter_indices, updates) |
| 21 | + return LinearAlgebra.diagm(-1 => x.dl, 0 => x.d, 1 => x.du) |
40 | 22 | end
|
41 | 23 |
|
42 | 24 | for (AT, comp) in ((:LowerTriangular, "GE"), (:UpperTriangular, "LE"))
|
@@ -251,13 +233,23 @@ function LinearAlgebra.diag(x::AnyTracedRArray{T,2}, k::Integer=0) where {T}
|
251 | 233 | return TracedRArray{T,1}((), res, (diag_length,))
|
252 | 234 | end
|
253 | 235 |
|
254 |
| -function LinearAlgebra.diagm(v::AnyTracedRArray{T,1}) where {T} |
255 |
| - return LinearAlgebra.diagm(length(v), length(v), v) |
256 |
| -end |
257 |
| -function LinearAlgebra.diagm(m::Integer, n::Integer, v::AnyTracedRArray{T,1}) where {T} |
258 |
| - m, n = LinearAlgebra.diagm_size((m, n), 0 => v) # size check |
259 |
| - indices = Ops.constant(diagonal_indices_zero_indexed(m, n, 0)[1:length(v), :]) |
260 |
| - return simple_scatter_op((m, n), indices, materialize_traced_array(v)) |
| 236 | +function LinearAlgebra._diagm( |
| 237 | + shape, kv::Pair{<:Integer,<:AnyTracedRArray{T,1}}... |
| 238 | +) where {T} |
| 239 | + m, n = LinearAlgebra.diagm_size(shape, kv...) |
| 240 | + scatter_indices = Matrix{Int64}[] |
| 241 | + concat_inputs = MLIR.IR.Value[] |
| 242 | + for (k, v) in kv |
| 243 | + push!(scatter_indices, diagonal_indices_zero_indexed(m, n, k)[1:length(v), :]) |
| 244 | + push!(concat_inputs, get_mlir_data(v)) |
| 245 | + end |
| 246 | + scatter_indices = Ops.constant(reduce(vcat, scatter_indices)) |
| 247 | + values = TracedRArray{T,1}( |
| 248 | + (), |
| 249 | + MLIR.IR.result(MLIR.Dialects.stablehlo.concatenate(concat_inputs; dimension=0), 1), |
| 250 | + (size(scatter_indices, 1),), |
| 251 | + ) |
| 252 | + return simple_scatter_op((m, n), scatter_indices, values) |
261 | 253 | end
|
262 | 254 |
|
263 | 255 | # Common Utilities
|
@@ -309,15 +301,14 @@ function simple_scatter_op(
|
309 | 301 | return TracedRArray{T,2}((), res, shape)
|
310 | 302 | end
|
311 | 303 |
|
312 |
| -# The cartesian version doesn't exist in julia 1.10 |
| 304 | +## The cartesian version doesn't exist in julia 1.10 |
313 | 305 | function diagonal_indices_zero_indexed(m::Integer, n::Integer, k::Integer=0)
|
314 |
| - Cstart = CartesianIndex(1 + max(0, -k), 1 + max(0, k)) |
315 |
| - Cstep = CartesianIndex(1, 1) |
316 |
| - res = StepRangeLen(Cstart, Cstep, max(0, k <= 0 ? min(m + k, n) : min(m, n - k))) |
317 |
| - indices = Matrix{Int}(undef, (length(res), 2)) |
318 |
| - for (i, idx) in enumerate(res) |
319 |
| - indices[i, 1] = idx[1] - 1 |
320 |
| - indices[i, 2] = idx[2] - 1 |
| 306 | + idx1, idx2 = 1 + max(0, -k), 1 + max(0, k) |
| 307 | + L = max(0, k ≤ 0 ? min(m + k, n) : min(m, n - k)) |
| 308 | + indices = Matrix{Int}(undef, (L, 2)) |
| 309 | + for i in axes(indices, 1) |
| 310 | + indices[i, 1] = idx1 + i - 2 |
| 311 | + indices[i, 2] = idx2 + i - 2 |
321 | 312 | end
|
322 | 313 | return indices
|
323 | 314 | end
|
0 commit comments