@@ -14,24 +14,24 @@ using ..TracedUtils: TracedUtils, get_mlir_data, materialize_traced_array, set_m
14
14
using LinearAlgebra
15
15
16
16
# Various Wrapper Arrays defined in LinearAlgebra
17
- function materialize_traced_array (
17
+ function TracedUtils . materialize_traced_array (
18
18
x:: Transpose{TracedRNumber{T},TracedRArray{T,N}}
19
19
) where {T,N}
20
20
px = parent (x)
21
21
A = ndims (px) == 1 ? reshape (px, :, 1 ) : px
22
22
return permutedims (A, (2 , 1 ))
23
23
end
24
24
25
- function materialize_traced_array (
25
+ function TracedUtils . materialize_traced_array (
26
26
x:: Adjoint{TracedRNumber{T},TracedRArray{T,N}}
27
27
) where {T,N}
28
28
return conj (materialize_traced_array (transpose (parent (x))))
29
29
end
30
30
31
- function materialize_traced_array (
32
- x:: LinearAlgebra. Diagonal{TracedRNumber{T},TracedRArray{T,1}}
31
+ function TracedUtils . materialize_traced_array (
32
+ x:: Diagonal{TracedRNumber{T},TracedRArray{T,1}}
33
33
) where {T}
34
- return LinearAlgebra . diagm (parent (x))
34
+ return diagm (parent (x))
35
35
end
36
36
37
37
function TracedUtils. materialize_traced_array (x:: Tridiagonal{T,TracedRArray{T,1}} ) where {T}
@@ -42,7 +42,7 @@ for (AT, comp) in ((:LowerTriangular, "GE"), (:UpperTriangular, "LE"))
42
42
uAT = Symbol (:Unit , AT)
43
43
@eval begin
44
44
function TracedUtils. materialize_traced_array (
45
- x:: $ (AT){T ,TracedRArray{T,2 }}
45
+ x:: $ (AT){TracedRNumber{T} ,TracedRArray{T,2 }}
46
46
) where {T}
47
47
m, n = size (x)
48
48
row_idxs = Ops. iota (Int, [m, n]; iota_dimension= 1 )
@@ -52,7 +52,7 @@ for (AT, comp) in ((:LowerTriangular, "GE"), (:UpperTriangular, "LE"))
52
52
end
53
53
54
54
function TracedUtils. materialize_traced_array (
55
- x:: $ (uAT){T ,TracedRArray{T,2 }}
55
+ x:: $ (uAT){TracedRNumber{T} ,TracedRArray{T,2 }}
56
56
) where {T}
57
57
m, n = size (x)
58
58
row_idxs = Ops. iota (Int, [m, n]; iota_dimension= 1 )
@@ -64,7 +64,9 @@ for (AT, comp) in ((:LowerTriangular, "GE"), (:UpperTriangular, "LE"))
64
64
end
65
65
end
66
66
67
- function TracedUtils. materialize_traced_array (x:: Symmetric{T,TracedRArray{T,2}} ) where {T}
67
+ function TracedUtils. materialize_traced_array (
68
+ x:: Symmetric{TracedRNumber{T},TracedRArray{T,2}}
69
+ ) where {T}
68
70
m, n = size (x)
69
71
row_idxs = Ops. iota (Int, [m, n]; iota_dimension= 1 )
70
72
col_idxs = Ops. iota (Int, [m, n]; iota_dimension= 2 )
@@ -107,7 +109,9 @@ function TracedUtils.set_mlir_data!(
107
109
return x
108
110
end
109
111
110
- function TracedUtils. set_mlir_data! (x:: Diagonal{TracedRNumber{T},TracedRArray{T,1}} , data) where {T}
112
+ function TracedUtils. set_mlir_data! (
113
+ x:: Diagonal{TracedRNumber{T},TracedRArray{T,1}} , data
114
+ ) where {T}
111
115
parent (x). mlir_data = diag (TracedRArray {T} (data)). mlir_data
112
116
return x
113
117
end
@@ -119,7 +123,7 @@ for (AT, dcomp, ocomp) in (
119
123
(:UnitUpperTriangular , " LT" , " GE" ),
120
124
)
121
125
@eval function TracedUtils. set_mlir_data! (
122
- x:: LinearAlgebra. $ (AT){T ,TracedRArray{T,2 }}, data
126
+ x:: $ (AT){TracedRNumber{T} ,TracedRArray{T,2 }}, data
123
127
) where {T}
124
128
tdata = TracedRArray {T} (data)
125
129
z = zero (tdata)
@@ -137,17 +141,19 @@ for (AT, dcomp, ocomp) in (
137
141
end
138
142
139
143
function TracedUtils. set_mlir_data! (
140
- x:: LinearAlgebra. Symmetric{T ,TracedRArray{T,2}} , data
144
+ x:: Symmetric{TracedRNumber{T} ,TracedRArray{T,2}} , data
141
145
) where {T}
142
146
if x. uplo == ' L'
143
- set_mlir_data! (LinearAlgebra . LowerTriangular (parent (x)), data)
147
+ set_mlir_data! (LowerTriangular (parent (x)), data)
144
148
else
145
- set_mlir_data! (LinearAlgebra . UpperTriangular (parent (x)), data)
149
+ set_mlir_data! (UpperTriangular (parent (x)), data)
146
150
end
147
151
return x
148
152
end
149
153
150
- function TracedUtils. set_mlir_data! (x:: Tridiagonal{T,TracedRArray{T,1}} , data) where {T}
154
+ function TracedUtils. set_mlir_data! (
155
+ x:: Tridiagonal{TracedRNumber{T},TracedRArray{T,1}} , data
156
+ ) where {T}
151
157
tdata = TracedRArray {T} (data)
152
158
set_mlir_data! (x. dl, diag (tdata, - 1 ). mlir_data)
153
159
set_mlir_data! (x. d, diag (tdata, 0 ). mlir_data)
0 commit comments