Skip to content

Commit b2c7e54

Browse files
committed
feat: generalize diagm
1 parent a9a8b2a commit b2c7e54

File tree

1 file changed

+19
-27
lines changed

1 file changed

+19
-27
lines changed

src/linear_algebra.jl

Lines changed: 19 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -18,25 +18,7 @@ end
1818
function materialize_traced_array(
1919
x::LinearAlgebra.Tridiagonal{T,TracedRArray{T,1}}
2020
) where {T}
21-
scatter_indices = vcat(
22-
diagonal_indices_zero_indexed(size(x, 1), size(x, 2), -1),
23-
diagonal_indices_zero_indexed(size(x, 1), size(x, 2), 0),
24-
diagonal_indices_zero_indexed(size(x, 1), size(x, 2), 1),
25-
)
26-
scatter_indices = Ops.constant(scatter_indices)
27-
28-
updates = TracedRArray{T,1}(
29-
(),
30-
MLIR.IR.result(
31-
MLIR.Dialects.stablehlo.concatenate(
32-
[x.dl.mlir_data, x.d.mlir_data, x.du.mlir_data]; dimension=0
33-
),
34-
1,
35-
),
36-
(size(scatter_indices, 1),),
37-
)
38-
39-
return simple_scatter_op(size(x), scatter_indices, updates)
21+
return LinearAlgebra.diagm(-1 => x.dl, 0 => x.d, 1 => x.du)
4022
end
4123

4224
for (AT, comp) in ((:LowerTriangular, "GE"), (:UpperTriangular, "LE"))
@@ -251,13 +233,23 @@ function LinearAlgebra.diag(x::AnyTracedRArray{T,2}, k::Integer=0) where {T}
251233
return TracedRArray{T,1}((), res, (diag_length,))
252234
end
253235

254-
function LinearAlgebra.diagm(v::AnyTracedRArray{T,1}) where {T}
255-
return LinearAlgebra.diagm(length(v), length(v), v)
256-
end
257-
function LinearAlgebra.diagm(m::Integer, n::Integer, v::AnyTracedRArray{T,1}) where {T}
258-
m, n = LinearAlgebra.diagm_size((m, n), 0 => v) # size check
259-
indices = Ops.constant(diagonal_indices_zero_indexed(m, n, 0)[1:length(v), :])
260-
return simple_scatter_op((m, n), indices, materialize_traced_array(v))
236+
function LinearAlgebra._diagm(
237+
shape, kv::Pair{<:Integer,<:AnyTracedRArray{T,1}}...
238+
) where {T}
239+
m, n = LinearAlgebra.diagm_size(shape, kv...)
240+
scatter_indices = Matrix{Int64}[]
241+
concat_inputs = MLIR.IR.Value[]
242+
for (k, v) in kv
243+
push!(scatter_indices, diagonal_indices_zero_indexed(m, n, k)[1:length(v), :])
244+
push!(concat_inputs, get_mlir_data(v))
245+
end
246+
scatter_indices = Ops.constant(reduce(vcat, scatter_indices))
247+
values = TracedRArray{T,1}(
248+
(),
249+
MLIR.IR.result(MLIR.Dialects.stablehlo.concatenate(concat_inputs; dimension=0), 1),
250+
(size(scatter_indices, 1),),
251+
)
252+
return simple_scatter_op((m, n), scatter_indices, values)
261253
end
262254

263255
# Common Utilities
@@ -309,7 +301,7 @@ function simple_scatter_op(
309301
return TracedRArray{T,2}((), res, shape)
310302
end
311303

312-
# The cartesian version doesn't exist in julia 1.10
304+
## The cartesian version doesn't exist in julia 1.10
313305
function diagonal_indices_zero_indexed(m::Integer, n::Integer, k::Integer=0)
314306
Cstart = CartesianIndex(1 + max(0, -k), 1 + max(0, k))
315307
Cstep = CartesianIndex(1, 1)

0 commit comments

Comments
 (0)