Skip to content

Commit 0c94353

Browse files
committed
Support algebra operations on expansions
1 parent 3a5eaa6 commit 0c94353

File tree

4 files changed

+72
-4
lines changed

4 files changed

+72
-4
lines changed

src/ContinuumArrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import Base: @_inline_meta, @_propagate_inbounds_meta, axes, getindex, convert,
77
import Base.Broadcast: materialize, BroadcastStyle, broadcasted
88
import LazyArrays: MemoryLayout, Applied, ApplyStyle, flatten, _flatten, colsupport,
99
adjointlayout, LdivApplyStyle, arguments, _arguments, call, broadcastlayout, layout_getindex,
10-
sublayout, sub_materialize, ApplyLayout, BroadcastLayout, combine_mul_styles
10+
sublayout, sub_materialize, ApplyLayout, BroadcastLayout, combine_mul_styles, applylayout
1111
import LinearAlgebra: pinv
1212
import BandedMatrices: AbstractBandedLayout, _BandedMatrix
1313
import FillArrays: AbstractFill, getindex_value, SquareEye

src/bases/bases.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,53 @@ function copy(L::Ldiv{<:AbstractBasisLayout,BroadcastLayout{typeof(*)},<:Abstrac
149149
T \ L.B[p]
150150
end
151151

152+
153+
##
154+
# Algebra
155+
##
156+
157+
# struct ExpansionLayout <: MemoryLayout end
158+
# applylayout(::Type{typeof(*)}, ::BasisLayout, _) = ExpansionLayout()
159+
160+
const Expansion{T,Space<:Basis,Coeffs<:AbstractVector} = ApplyQuasiVector{T,typeof(*),<:Tuple{Space,Coeffs}}
161+
162+
for op in (:*, :\)
163+
@eval function broadcasted(::LazyQuasiArrayStyle{1}, ::typeof($op), x::Number, f::Expansion)
164+
S,c = arguments(f)
165+
S * broadcast($op, x, c)
166+
end
167+
end
168+
for op in (:*, :/)
169+
@eval function broadcasted(::LazyQuasiArrayStyle{1}, ::typeof($op), f::Expansion, x::Number)
170+
S,c = arguments(f)
171+
S * broadcast($op, c, x)
172+
end
173+
end
174+
175+
176+
function broadcastbasis(::typeof(+), a, b)
177+
a b && error("Overload broadcastbasis(::typeof(+), ::$(typeof(a)), ::$(typeof(b)))")
178+
a
179+
end
180+
181+
broadcastbasis(::typeof(-), a, b) = broadcastbasis(+, a, b)
182+
183+
for op in (:+, :-)
184+
@eval function broadcasted(::LazyQuasiArrayStyle{1}, ::typeof($op), f::Expansion, g::Expansion)
185+
S,c = arguments(f)
186+
T,d = arguments(g)
187+
ST = broadcastbasis($op, S, T)
188+
ST * $op((ST \ S) * c , (ST \ T) * d)
189+
end
190+
end
191+
192+
@eval function ==(f::Expansion, g::Expansion)
193+
S,c = arguments(f)
194+
T,d = arguments(g)
195+
ST = broadcastbasis(+, S, T)
196+
(ST \ S) * c == (ST \ T) * d
197+
end
198+
152199
## materialize views
153200

154201
# materialize(S::SubQuasiArray{<:Any,2,<:ApplyQuasiArray{<:Any,2,typeof(*),<:Tuple{<:Basis,<:Any}}}) =

src/operators.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ function diff(d::AbstractQuasiVector)
158158
Derivative(x)*d
159159
end
160160

161+
^(D::Derivative, k::Integer) = ApplyQuasiArray(^, D, k)
161162

162163
# struct Multiplication{T,F,A} <: AbstractQuasiMatrix{T}
163164
# f::F

test/runtests.jl

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using ContinuumArrays, QuasiArrays, LazyArrays, IntervalSets, FillArrays, LinearAlgebra, BandedMatrices, FastTransforms, Test
22
import ContinuumArrays: ℵ₁, materialize, SimplifyStyle, AffineQuasiVector, BasisLayout, AdjointBasisLayout, SubBasisLayout,
3-
MappedBasisLayout, igetindex, TransformFactorization, Weight, WeightedBasisLayout
3+
MappedBasisLayout, igetindex, TransformFactorization, Weight, WeightedBasisLayout, Expansion
44
import QuasiArrays: SubQuasiArray, MulQuasiMatrix, Vec, Inclusion, QuasiDiagonal, LazyQuasiArrayApplyStyle, LazyQuasiArrayStyle
55
import LazyArrays: MemoryLayout, ApplyStyle, Applied, colsupport, arguments, ApplyLayout, LdivApplyStyle
66

@@ -124,10 +124,28 @@ end
124124
end
125125
end
126126

127-
@testset "Derivative" begin
127+
@testset "Algebra" begin
128128
L = LinearSpline([1,2,3])
129129
f = L*[1,2,4]
130+
g = L*[5,6,7]
131+
132+
@test f isa Expansion
133+
@test 2f isa Expansion
134+
@test f*2 isa Expansion
135+
@test 2\f isa Expansion
136+
@test f/2 isa Expansion
137+
@test f+g isa Expansion
138+
@test f-g isa Expansion
130139
@test f[1.2] == 1.2
140+
@test (2f)[1.2] == (f*2)[1.2] == 2.4
141+
@test (2\f)[1.2] == (f/2)[1.2] == 0.6
142+
@test (f+g)[1.2] f[1.2] + g[1.2]
143+
@test (f-g)[1.2] f[1.2] - g[1.2]
144+
end
145+
146+
@testset "Derivative" begin
147+
L = LinearSpline([1,2,3])
148+
f = L*[1,2,4]
131149

132150
D = Derivative(axes(L,1))
133151
@test ApplyStyle(*,typeof(D),typeof(L)) isa SimplifyStyle
@@ -151,11 +169,13 @@ end
151169
@test fp[1.1] 1
152170
@test fp[2.2] 2
153171

154-
155172
fp = D*f
156173
@test length(fp.args) == 2
157174
@test fp[1.1] 1
158175
@test fp[2.2] 2
176+
177+
@test D^2 isa ApplyQuasiMatrix{Float64,typeof(*)}
178+
159179
end
160180

161181
@testset "Weak Laplacian" begin

0 commit comments

Comments
 (0)