Skip to content

Commit 995d957

Browse files
committed
feat: generalize diagm
1 parent a9a8b2a commit 995d957

File tree

1 file changed

+25
-34
lines changed

1 file changed

+25
-34
lines changed

src/linear_algebra.jl

Lines changed: 25 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -18,25 +18,7 @@ end
1818
function materialize_traced_array(
1919
x::LinearAlgebra.Tridiagonal{T,TracedRArray{T,1}}
2020
) where {T}
21-
scatter_indices = vcat(
22-
diagonal_indices_zero_indexed(size(x, 1), size(x, 2), -1),
23-
diagonal_indices_zero_indexed(size(x, 1), size(x, 2), 0),
24-
diagonal_indices_zero_indexed(size(x, 1), size(x, 2), 1),
25-
)
26-
scatter_indices = Ops.constant(scatter_indices)
27-
28-
updates = TracedRArray{T,1}(
29-
(),
30-
MLIR.IR.result(
31-
MLIR.Dialects.stablehlo.concatenate(
32-
[x.dl.mlir_data, x.d.mlir_data, x.du.mlir_data]; dimension=0
33-
),
34-
1,
35-
),
36-
(size(scatter_indices, 1),),
37-
)
38-
39-
return simple_scatter_op(size(x), scatter_indices, updates)
21+
return LinearAlgebra.diagm(-1 => x.dl, 0 => x.d, 1 => x.du)
4022
end
4123

4224
for (AT, comp) in ((:LowerTriangular, "GE"), (:UpperTriangular, "LE"))
@@ -251,13 +233,23 @@ function LinearAlgebra.diag(x::AnyTracedRArray{T,2}, k::Integer=0) where {T}
251233
return TracedRArray{T,1}((), res, (diag_length,))
252234
end
253235

254-
function LinearAlgebra.diagm(v::AnyTracedRArray{T,1}) where {T}
255-
return LinearAlgebra.diagm(length(v), length(v), v)
256-
end
257-
function LinearAlgebra.diagm(m::Integer, n::Integer, v::AnyTracedRArray{T,1}) where {T}
258-
m, n = LinearAlgebra.diagm_size((m, n), 0 => v) # size check
259-
indices = Ops.constant(diagonal_indices_zero_indexed(m, n, 0)[1:length(v), :])
260-
return simple_scatter_op((m, n), indices, materialize_traced_array(v))
236+
function LinearAlgebra._diagm(
237+
shape, kv::Pair{<:Integer,<:AnyTracedRArray{T,1}}...
238+
) where {T}
239+
m, n = LinearAlgebra.diagm_size(shape, kv...)
240+
scatter_indices = Matrix{Int64}[]
241+
concat_inputs = MLIR.IR.Value[]
242+
for (k, v) in kv
243+
push!(scatter_indices, diagonal_indices_zero_indexed(m, n, k)[1:length(v), :])
244+
push!(concat_inputs, get_mlir_data(v))
245+
end
246+
scatter_indices = Ops.constant(reduce(vcat, scatter_indices))
247+
values = TracedRArray{T,1}(
248+
(),
249+
MLIR.IR.result(MLIR.Dialects.stablehlo.concatenate(concat_inputs; dimension=0), 1),
250+
(size(scatter_indices, 1),),
251+
)
252+
return simple_scatter_op((m, n), scatter_indices, values)
261253
end
262254

263255
# Common Utilities
@@ -309,15 +301,14 @@ function simple_scatter_op(
309301
return TracedRArray{T,2}((), res, shape)
310302
end
311303

312-
# The cartesian version doesn't exist in julia 1.10
304+
## The cartesian version doesn't exist in julia 1.10
313305
function diagonal_indices_zero_indexed(m::Integer, n::Integer, k::Integer=0)
314-
Cstart = CartesianIndex(1 + max(0, -k), 1 + max(0, k))
315-
Cstep = CartesianIndex(1, 1)
316-
res = StepRangeLen(Cstart, Cstep, max(0, k <= 0 ? min(m + k, n) : min(m, n - k)))
317-
indices = Matrix{Int}(undef, (length(res), 2))
318-
for (i, idx) in enumerate(res)
319-
indices[i, 1] = idx[1] - 1
320-
indices[i, 2] = idx[2] - 1
306+
idx1, idx2 = 1 + max(0, -k), 1 + max(0, k)
307+
L = max(0, k 0 ? min(m + k, n) : min(m, n - k))
308+
indices = Matrix{Int}(undef, (L, 2))
309+
for i in axes(indices, 1)
310+
indices[i, 1] = idx1 + i - 2
311+
indices[i, 2] = idx2 + i - 2
321312
end
322313
return indices
323314
end

0 commit comments

Comments
 (0)