Skip to content

Commit f39adeb

Browse files
committed
fix: incorrect rebase
1 parent 1c0d744 commit f39adeb

File tree

3 files changed

+14
-12
lines changed

3 files changed

+14
-12
lines changed

src/Ops.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1414,12 +1414,12 @@ instead.
14141414
#! format: off
14151415
scatter_dimension_numbers = MLIR.API.stablehloScatterDimensionNumbersGet(
14161416
MLIR.IR.context(),
1417-
0, Int64[],
1418-
N, collect(Int64, 0:(N - 1)),
1419-
0, Int64[],
1420-
0, Int64[],
1421-
N, collect(Int64, 0:(N - 1)),
1422-
1
1417+
Int64(0), Int64[],
1418+
Int64(N), collect(Int64, 0:(N - 1)),
1419+
Int64(0), Int64[],
1420+
Int64(0), Int64[],
1421+
Int64(N), collect(Int64, 0:(N - 1)),
1422+
Int64(1)
14231423
)
14241424
#! format: on
14251425

src/TracedUtils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ function set_mlir_data!(
5858
return x
5959
end
6060

61-
function set_mlir_data!(x::AnyTracedRArray, data)
61+
function set_mlir_data!(x::AnyTracedRArray{T}, data) where {T}
6262
setindex!(x, TracedRArray{T}(data), axes(x)...)
6363
return x
6464
end

src/stdlibs/LinearAlgebra.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ function TracedUtils.materialize_traced_array(x::Diagonal{T,TracedRArray{T,1}})
2828
return diagm(parent(x))
2929
end
3030

31-
function materialize_traced_array(x::Tridiagonal{T,TracedRArray{T,1}}) where {T}
31+
function TracedUtils.materialize_traced_array(x::Tridiagonal{T,TracedRArray{T,1}}) where {T}
3232
return diagm(-1 => x.dl, 0 => x.d, 1 => x.du)
3333
end
3434

@@ -108,7 +108,7 @@ for (AT, dcomp, ocomp) in (
108108
(:UpperTriangular, "LE", "GT"),
109109
(:UnitUpperTriangular, "LT", "GE"),
110110
)
111-
@eval function set_mlir_data!(
111+
@eval function TracedUtils.set_mlir_data!(
112112
x::LinearAlgebra.$(AT){T,TracedRArray{T,2}}, data
113113
) where {T}
114114
tdata = TracedRArray{T}(data)
@@ -126,7 +126,9 @@ for (AT, dcomp, ocomp) in (
126126
end
127127
end
128128

129-
function set_mlir_data!(x::LinearAlgebra.Symmetric{T,TracedRArray{T,2}}, data) where {T}
129+
function TracedUtils.set_mlir_data!(
130+
x::LinearAlgebra.Symmetric{T,TracedRArray{T,2}}, data
131+
) where {T}
130132
if x.uplo == 'L'
131133
set_mlir_data!(LinearAlgebra.LowerTriangular(parent(x)), data)
132134
else
@@ -135,7 +137,7 @@ function set_mlir_data!(x::LinearAlgebra.Symmetric{T,TracedRArray{T,2}}, data) w
135137
return x
136138
end
137139

138-
function set_mlir_data!(x::Tridiagonal{T,TracedRArray{T,1}}, data) where {T}
140+
function TracedUtils.set_mlir_data!(x::Tridiagonal{T,TracedRArray{T,1}}, data) where {T}
139141
tdata = TracedRArray{T}(data)
140142
set_mlir_data!(x.dl, diag(tdata, -1).mlir_data)
141143
set_mlir_data!(x.d, diag(tdata, 0).mlir_data)
@@ -249,7 +251,7 @@ function LinearAlgebra.diag(x::AnyTracedRArray{T,2}, k::Integer=0) where {T}
249251
# <unknown>:0: note: see current operation: %0 = "tensor.empty"() : () -> tensor<0xf64>
250252
length(indices) 0 && return TracedUtils.promote_to(TracedRArray{T,1}, T[])
251253

252-
return Ops.gather_getindex(x, promote_to(TracedRArray{Int,2}, indices))
254+
return Ops.gather_getindex(x, TracedUtils.promote_to(TracedRArray{Int,2}, indices))
253255
end
254256

255257
function LinearAlgebra._diagm(

0 commit comments

Comments
 (0)