Skip to content

Commit ccc61d4

Browse files
avik-palwsmoses
authored andcommitted
fix: dispatches
1 parent 0e6d376 commit ccc61d4

File tree

1 file changed

+20
-14
lines changed

1 file changed

+20
-14
lines changed

src/stdlibs/LinearAlgebra.jl

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,24 +14,24 @@ using ..TracedUtils: TracedUtils, get_mlir_data, materialize_traced_array, set_m
1414
using LinearAlgebra
1515

1616
# Various Wrapper Arrays defined in LinearAlgebra
17-
function materialize_traced_array(
17+
function TracedUtils.materialize_traced_array(
1818
x::Transpose{TracedRNumber{T},TracedRArray{T,N}}
1919
) where {T,N}
2020
px = parent(x)
2121
A = ndims(px) == 1 ? reshape(px, :, 1) : px
2222
return permutedims(A, (2, 1))
2323
end
2424

25-
function materialize_traced_array(
25+
function TracedUtils.materialize_traced_array(
2626
x::Adjoint{TracedRNumber{T},TracedRArray{T,N}}
2727
) where {T,N}
2828
return conj(materialize_traced_array(transpose(parent(x))))
2929
end
3030

31-
function materialize_traced_array(
32-
x::LinearAlgebra.Diagonal{TracedRNumber{T},TracedRArray{T,1}}
31+
function TracedUtils.materialize_traced_array(
32+
x::Diagonal{TracedRNumber{T},TracedRArray{T,1}}
3333
) where {T}
34-
return LinearAlgebra.diagm(parent(x))
34+
return diagm(parent(x))
3535
end
3636

3737
function TracedUtils.materialize_traced_array(x::Tridiagonal{T,TracedRArray{T,1}}) where {T}
@@ -42,7 +42,7 @@ for (AT, comp) in ((:LowerTriangular, "GE"), (:UpperTriangular, "LE"))
4242
uAT = Symbol(:Unit, AT)
4343
@eval begin
4444
function TracedUtils.materialize_traced_array(
45-
x::$(AT){T,TracedRArray{T,2}}
45+
x::$(AT){TracedRNumber{T},TracedRArray{T,2}}
4646
) where {T}
4747
m, n = size(x)
4848
row_idxs = Ops.iota(Int, [m, n]; iota_dimension=1)
@@ -52,7 +52,7 @@ for (AT, comp) in ((:LowerTriangular, "GE"), (:UpperTriangular, "LE"))
5252
end
5353

5454
function TracedUtils.materialize_traced_array(
55-
x::$(uAT){T,TracedRArray{T,2}}
55+
x::$(uAT){TracedRNumber{T},TracedRArray{T,2}}
5656
) where {T}
5757
m, n = size(x)
5858
row_idxs = Ops.iota(Int, [m, n]; iota_dimension=1)
@@ -64,7 +64,9 @@ for (AT, comp) in ((:LowerTriangular, "GE"), (:UpperTriangular, "LE"))
6464
end
6565
end
6666

67-
function TracedUtils.materialize_traced_array(x::Symmetric{T,TracedRArray{T,2}}) where {T}
67+
function TracedUtils.materialize_traced_array(
68+
x::Symmetric{TracedRNumber{T},TracedRArray{T,2}}
69+
) where {T}
6870
m, n = size(x)
6971
row_idxs = Ops.iota(Int, [m, n]; iota_dimension=1)
7072
col_idxs = Ops.iota(Int, [m, n]; iota_dimension=2)
@@ -107,7 +109,9 @@ function TracedUtils.set_mlir_data!(
107109
return x
108110
end
109111

110-
function TracedUtils.set_mlir_data!(x::Diagonal{TracedRNumber{T},TracedRArray{T,1}}, data) where {T}
112+
function TracedUtils.set_mlir_data!(
113+
x::Diagonal{TracedRNumber{T},TracedRArray{T,1}}, data
114+
) where {T}
111115
parent(x).mlir_data = diag(TracedRArray{T}(data)).mlir_data
112116
return x
113117
end
@@ -119,7 +123,7 @@ for (AT, dcomp, ocomp) in (
119123
(:UnitUpperTriangular, "LT", "GE"),
120124
)
121125
@eval function TracedUtils.set_mlir_data!(
122-
x::LinearAlgebra.$(AT){T,TracedRArray{T,2}}, data
126+
x::$(AT){TracedRNumber{T},TracedRArray{T,2}}, data
123127
) where {T}
124128
tdata = TracedRArray{T}(data)
125129
z = zero(tdata)
@@ -137,17 +141,19 @@ for (AT, dcomp, ocomp) in (
137141
end
138142

139143
function TracedUtils.set_mlir_data!(
140-
x::LinearAlgebra.Symmetric{T,TracedRArray{T,2}}, data
144+
x::Symmetric{TracedRNumber{T},TracedRArray{T,2}}, data
141145
) where {T}
142146
if x.uplo == 'L'
143-
set_mlir_data!(LinearAlgebra.LowerTriangular(parent(x)), data)
147+
set_mlir_data!(LowerTriangular(parent(x)), data)
144148
else
145-
set_mlir_data!(LinearAlgebra.UpperTriangular(parent(x)), data)
149+
set_mlir_data!(UpperTriangular(parent(x)), data)
146150
end
147151
return x
148152
end
149153

150-
function TracedUtils.set_mlir_data!(x::Tridiagonal{T,TracedRArray{T,1}}, data) where {T}
154+
function TracedUtils.set_mlir_data!(
155+
x::Tridiagonal{TracedRNumber{T},TracedRArray{T,1}}, data
156+
) where {T}
151157
tdata = TracedRArray{T}(data)
152158
set_mlir_data!(x.dl, diag(tdata, -1).mlir_data)
153159
set_mlir_data!(x.d, diag(tdata, 0).mlir_data)

0 commit comments

Comments
 (0)