Skip to content

Commit 1affd14

Browse files
committed
fix: diagm for repeated indices and initial tests
1 parent c536e57 commit 1affd14

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

src/stdlibs/LinearAlgebra.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,9 +274,20 @@ function LinearAlgebra._diagm(
274274
shape, kv::Pair{<:Integer,<:AnyTracedRArray{T,1}}...
275275
) where {T}
276276
m, n = LinearAlgebra.diagm_size(shape, kv...)
277+
278+
# For repeated indices we need to aggregate the values
279+
kv_updated = Dict{Integer,AnyTracedRArray{T,1}}()
280+
for (k, v) in kv
281+
if haskey(kv_updated, k)
282+
kv_updated[k] = kv_updated[k] + v
283+
else
284+
kv_updated[k] = v
285+
end
286+
end
287+
277288
scatter_indices = Matrix{Int64}[]
278289
concat_inputs = MLIR.IR.Value[]
279-
for (k, v) in kv
290+
for (k, v) in pairs(kv_updated)
280291
push!(scatter_indices, diagonal_indices_zero_indexed(m, n, k)[1:length(v), :])
281292
push!(concat_inputs, get_mlir_data(v))
282293
end

test/integration/linear_algebra.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,17 @@ end
130130
@test @jit(diagm(4, 5, x_ra)) diagm(4, 5, x)
131131
@test @jit(diagm(6, 6, x_ra)) diagm(6, 6, x)
132132
@test_throws DimensionMismatch @jit(diagm(3, 3, x_ra))
133+
134+
x1 = rand(3)
135+
x2 = rand(3)
136+
x3 = rand(2)
137+
x_ra1 = Reactant.to_rarray(x1)
138+
x_ra2 = Reactant.to_rarray(x2)
139+
x_ra3 = Reactant.to_rarray(x3)
140+
141+
@test @jit(diagm(1 => x_ra1)) diagm(1 => x1)
142+
@test @jit(diagm(1 => x_ra1, -1 => x_ra3)) diagm(1 => x1, -1 => x3)
143+
@test @jit(diagm(1 => x_ra1, 1 => x_ra2)) diagm(1 => x1, 1 => x2)
133144
end
134145

135146
# TODO: Currently Diagonal(x) * x goes down the generic matmul path but it should clearly be

0 commit comments

Comments
 (0)