|
18 | 18 | function materialize_traced_array(
|
19 | 19 | x::LinearAlgebra.Tridiagonal{T,TracedRArray{T,1}}
|
20 | 20 | ) where {T}
|
21 |
| - scatter_indices = vcat( |
22 |
| - diagonal_indices_zero_indexed(size(x, 1), size(x, 2), -1), |
23 |
| - diagonal_indices_zero_indexed(size(x, 1), size(x, 2), 0), |
24 |
| - diagonal_indices_zero_indexed(size(x, 1), size(x, 2), 1), |
25 |
| - ) |
26 |
| - scatter_indices = Ops.constant(scatter_indices) |
27 |
| - |
28 |
| - updates = TracedRArray{T,1}( |
29 |
| - (), |
30 |
| - MLIR.IR.result( |
31 |
| - MLIR.Dialects.stablehlo.concatenate( |
32 |
| - [x.dl.mlir_data, x.d.mlir_data, x.du.mlir_data]; dimension=0 |
33 |
| - ), |
34 |
| - 1, |
35 |
| - ), |
36 |
| - (size(scatter_indices, 1),), |
37 |
| - ) |
38 |
| - |
39 |
| - return simple_scatter_op(size(x), scatter_indices, updates) |
| 21 | + return LinearAlgebra.diagm(-1 => x.dl, 0 => x.d, 1 => x.du) |
40 | 22 | end
|
41 | 23 |
|
42 | 24 | for (AT, comp) in ((:LowerTriangular, "GE"), (:UpperTriangular, "LE"))
|
@@ -251,13 +233,23 @@ function LinearAlgebra.diag(x::AnyTracedRArray{T,2}, k::Integer=0) where {T}
|
251 | 233 | return TracedRArray{T,1}((), res, (diag_length,))
|
252 | 234 | end
|
253 | 235 |
|
254 |
| -function LinearAlgebra.diagm(v::AnyTracedRArray{T,1}) where {T} |
255 |
| - return LinearAlgebra.diagm(length(v), length(v), v) |
256 |
| -end |
257 |
| -function LinearAlgebra.diagm(m::Integer, n::Integer, v::AnyTracedRArray{T,1}) where {T} |
258 |
| - m, n = LinearAlgebra.diagm_size((m, n), 0 => v) # size check |
259 |
| - indices = Ops.constant(diagonal_indices_zero_indexed(m, n, 0)[1:length(v), :]) |
260 |
| - return simple_scatter_op((m, n), indices, materialize_traced_array(v)) |
| 236 | +function LinearAlgebra._diagm( |
| 237 | + shape, kv::Pair{<:Integer,<:AnyTracedRArray{T,1}}... |
| 238 | +) where {T} |
| 239 | + m, n = LinearAlgebra.diagm_size(shape, kv...) |
| 240 | + scatter_indices = Matrix{Int64}[] |
| 241 | + concat_inputs = MLIR.IR.Value[] |
| 242 | + for (k, v) in kv |
| 243 | + push!(scatter_indices, diagonal_indices_zero_indexed(m, n, k)[1:length(v), :]) |
| 244 | + push!(concat_inputs, get_mlir_data(v)) |
| 245 | + end |
| 246 | + scatter_indices = Ops.constant(reduce(vcat, scatter_indices)) |
| 247 | + values = TracedRArray{T,1}( |
| 248 | + (), |
| 249 | + MLIR.IR.result(MLIR.Dialects.stablehlo.concatenate(concat_inputs; dimension=0), 1), |
| 250 | + (size(scatter_indices, 1),), |
| 251 | + ) |
| 252 | + return simple_scatter_op((m, n), scatter_indices, values) |
261 | 253 | end
|
262 | 254 |
|
263 | 255 | # Common Utilities
|
@@ -309,7 +301,7 @@ function simple_scatter_op(
|
309 | 301 | return TracedRArray{T,2}((), res, shape)
|
310 | 302 | end
|
311 | 303 |
|
312 |
| -# The cartesian version doesn't exist in julia 1.10 |
| 304 | +## The cartesian version doesn't exist in julia 1.10 |
313 | 305 | function diagonal_indices_zero_indexed(m::Integer, n::Integer, k::Integer=0)
|
314 | 306 | Cstart = CartesianIndex(1 + max(0, -k), 1 + max(0, k))
|
315 | 307 | Cstep = CartesianIndex(1, 1)
|
|
0 commit comments