Skip to content

Commit e006a95

Browse files
committed
Added @materialize convenience DSL
1 parent 7336bc5 commit e006a95

File tree

4 files changed

+189
-3
lines changed

4 files changed

+189
-3
lines changed

src/LazyArrays.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,11 @@ end
6060

6161
export Mul, Applied, MulArray, MulVector, MulMatrix, InvMatrix, PInvMatrix,
6262
Hcat, Vcat, Kron, BroadcastArray, BroadcastMatrix, BroadcastVector, cache, Ldiv, Inv, PInv, Diff, Cumsum,
63-
applied, materialize, materialize!, ApplyArray, ApplyMatrix, ApplyVector, apply, , @~, LazyArray
63+
applied, materialize, materialize!, @materialize, ApplyArray, ApplyMatrix, ApplyVector, apply, , @~, LazyArray
6464

6565

6666
include("lazyapplying.jl")
67+
include("materialize_dsl.jl")
6768
include("lazybroadcasting.jl")
6869
include("linalg/linalg.jl")
6970
include("cache.jl")

src/materialize_dsl.jl

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# For unparametrized destination types
2+
generate_copyto!_signature(dest, dest_type::Symbol, Msig) =
3+
:(Base.copyto!($(dest)::$(dest_type), applied_obj::$(Msig)))
4+
5+
# For parametrized destination types
6+
function generate_copyto!_signature(dest, dest_type::Expr, Msig)
7+
dest_type.head == :curly ||
8+
throw(ArgumentError("Invalid destination specification $(dest)::$(dest_type)"))
9+
:(Base.copyto!($(dest)::$(dest_type), applied_obj::$(Msig)) where {$(dest_type.args[2:end]...)})
10+
end
11+
12+
function generate_copyto!(body, factor_names, Msig)
13+
body.head == :(->) ||
14+
throw(ArgumentError("Invalid copyto! specification"))
15+
body.args[1].head == :(::) ||
16+
throw(ArgumentError("Invalid destination specification $(body.args[1])"))
17+
(dest,dest_type) = body.args[1].args
18+
copyto!_signature = generate_copyto!_signature(dest, dest_type, Msig)
19+
f_body = quote
20+
axes($dest) == axes(applied_obj) || throw(DimensionMismatch("axes must be same"))
21+
$(factor_names) = applied_obj.args
22+
$(body.args[2].args...)
23+
$(dest)
24+
end
25+
Expr(:function, copyto!_signature, f_body)
26+
end
27+
28+
"""
29+
@materialize function op(args...)
30+
31+
This macro simplifies the setup of a few functions necessary for the
32+
materialization of [`Applied`](@ref) objects:
33+
34+
- `ApplyStyle`, used to ensure dispatch of the applied object to the
35+
routines below
36+
37+
- `copyto!(dest::DestType, applied_obj::Applied{...,op})` performs the
38+
actual materialization of `applied_obj` into the destination object
39+
which has been generated by
40+
41+
- `similar` which usually returns a suitable matrix
42+
43+
- `materialize` which makes use of the above functions
44+
45+
# Example
46+
47+
```julia
48+
@materialize function *(Ac::MyAdjointBasis,
49+
O::MyOperator,
50+
B::MyBasis)
51+
MyApplyStyle # An instance of this type will be returned by ApplyStyle
52+
T -> begin # generates similar
53+
A = parent(Ac)
54+
parent(A) == parent(B) ||
55+
throw(ArgumentError("Incompatible bases"))
56+
57+
# There may be different matrices best representing different
58+
# situations:
59+
if ...
60+
Diagonal(Vector{T}(undef, size(B,1)))
61+
else
62+
Tridiagonal(Vector{T}(undef, size(B,1)-1),
63+
Vector{T}(undef, size(B,1)),
64+
Vector{T}(undef, size(B,1)-1))
65+
end
66+
end
67+
dest::Diagonal{T} -> begin # generate copyto!(dest::Diagonal{T}, ...) where T
68+
dest.diag .= 1
69+
end
70+
dest::Tridiagonal{T} -> begin # generate copyto!(dest::Tridiagonal{T}, ...) where T
71+
dest.dl .= -2
72+
dest.ev .= 1
73+
dest.du .= 3
74+
end
75+
end
76+
```
77+
"""
78+
macro materialize(expr)
79+
expr.head == :function || expr.head == :(=) || error("Must start with a function")
80+
@assert expr.args[1].head == :call
81+
op = expr.args[1].args[1]
82+
83+
bodies = filter(e -> !(e isa LineNumberNode), expr.args[2].args)
84+
length(bodies) < 3 &&
85+
throw(ArgumentError("At least three blocks required (ApplyStyle, similar, and at least one copyto!)"))
86+
87+
factor_types = :(<:Tuple{})
88+
factor_names = :(())
89+
apply_style = first(bodies)
90+
apply_style_fun = :(LazyArrays.ApplyStyle(::typeof($op)) = $(apply_style)())
91+
92+
# Generate Applied signature
93+
for arg in expr.args[1].args[2:end]
94+
arg isa Expr && arg.head == :(::) ||
95+
throw(ArgumentError("Invalid argument specification $(arg)"))
96+
arg_name, arg_typ = arg.args
97+
push!(factor_types.args[1].args, :(<:$(arg_typ)))
98+
push!(factor_names.args, arg_name)
99+
push!(apply_style_fun.args[1].args, :(::Type{<:$(arg_typ)}))
100+
end
101+
Msig = :(LazyArrays.Applied{$(apply_style), typeof($op), $(factor_types)})
102+
103+
sim_body = bodies[2]
104+
sim_body.head == :(->) ||
105+
throw(ArgumentError("Invalid similar specification"))
106+
T = first(sim_body.args)
107+
108+
copytos! = map(body -> generate_copyto!(body, factor_names, Msig), bodies[3:end])
109+
110+
f = quote
111+
$(apply_style_fun)
112+
113+
function Base.similar(applied_obj::$Msig, ::Type{$T}=eltype(applied_obj)) where $T
114+
$(factor_names) = applied_obj.args
115+
$(sim_body.args[2])
116+
end
117+
118+
$(copytos!...)
119+
120+
LazyArrays.materialize(applied_obj::$Msig) =
121+
copyto!(similar(applied_obj, eltype(applied_obj)), applied_obj)
122+
end
123+
esc(f)
124+
end

test/materialize_dsl.jl

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
struct MyOperator{T}
2+
n::Int
3+
kind::Symbol
4+
end
5+
6+
Base.axes(O::MyOperator) = (Base.OneTo(O.n),Base.OneTo(O.n))
7+
Base.axes(O::MyOperator,i) = axes(O)[i]
8+
Base.size(O::MyOperator) = (O.n,O.n)
9+
Base.eltype(::MyOperator{T}) where T = T
10+
11+
struct MyApplyStyle <: ApplyStyle end
12+
13+
@materialize function *(Ac::Adjoint{<:Any,<:AbstractMatrix},
14+
O::MyOperator,
15+
B::AbstractMatrix)
16+
MyApplyStyle
17+
T -> begin
18+
A = parent(Ac)
19+
20+
if O.kind == :diagonal
21+
Diagonal(Vector{T}(undef, O.n))
22+
else
23+
Tridiagonal(Vector{T}(undef, O.n-1),
24+
Vector{T}(undef, O.n),
25+
Vector{T}(undef, O.n-1))
26+
end
27+
end
28+
dest::Diagonal{T} -> begin
29+
dest.diag .= 1
30+
end
31+
dest::Tridiagonal{T} -> begin
32+
dest.dl .= -2
33+
dest.d .= 1
34+
dest.du .= 3
35+
end
36+
end
37+
38+
@testset "Materialize DSL" begin
39+
o = ones(10)
40+
M = ones(10,10)
41+
D = MyOperator{Float64}(10, :diagonal)
42+
T = MyOperator{ComplexF64}(10, :tridiagonal)
43+
44+
@test LazyArrays.ApplyStyle(*, typeof(M'), typeof(D), typeof(M)) == MyApplyStyle()
45+
@test LazyArrays.ApplyStyle(*, typeof(M'), typeof(T), typeof(M)) == MyApplyStyle()
46+
47+
d = apply(*, M', D, M)
48+
@test d isa Diagonal{Float64}
49+
@test all(d.diag .== 1)
50+
51+
t = apply(*, M', T, M)
52+
@test t isa Tridiagonal
53+
@test all(t.dl .== -2)
54+
@test all(t.d .== 1)
55+
@test all(t.du .== 3)
56+
57+
= ones(11,11)
58+
@test_throws DimensionMismatch apply(*, M̃', D, M̃)
59+
end

test/runtests.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using Test, LinearAlgebra, LazyArrays, StaticArrays, FillArrays, ArrayLayouts
22
import LazyArrays: CachedArray, colsupport, rowsupport, LazyArrayStyle, broadcasted,
3-
PaddedLayout, ApplyLayout, BroadcastLayout, AddArray, LazyLayout
3+
PaddedLayout, ApplyLayout, BroadcastLayout, AddArray, LazyLayout,
4+
ApplyStyle
45

56
@testset "Lazy MemoryLayout" begin
67
@testset "ApplyArray" begin
@@ -25,6 +26,7 @@ import LazyArrays: CachedArray, colsupport, rowsupport, LazyArrayStyle, broadcas
2526
end
2627
end
2728
include("applytests.jl")
29+
include("materialize_dsl.jl")
2830
include("multests.jl")
2931
include("ldivtests.jl")
3032
include("addtests.jl")
@@ -341,4 +343,4 @@ end
341343
@test exp.(transpose(v)) isa BroadcastMatrix
342344
@test exp.(M') isa BroadcastMatrix
343345
@test exp.(transpose(M)) isa BroadcastMatrix
344-
end
346+
end

0 commit comments

Comments
 (0)