Skip to content

Commit f5513ff

Browse files
committed
Static v0.8 compatibility
1 parent 99a603b commit f5513ff

File tree

8 files changed

+117
-28
lines changed

8 files changed

+117
-28
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,6 @@ NaNMath = "0.3, 1"
4848
PrettyPrinting = "0.3, 0.4"
4949
Reexport = "1"
5050
SpecialFunctions = "2"
51-
Static = "0.5, 0.6"
51+
Static = "0.5, 0.6, 0.8"
5252
Tricks = "0.1"
5353
julia = "1.3"

src/MeasureBase.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ using PrettyPrinting
3232
const Pretty = PrettyPrinting
3333

3434
using ChainRulesCore
35-
using FillArrays
35+
import FillArrays
3636
using Static
3737
using FunctionChains
3838

@@ -106,6 +106,7 @@ using Compat
106106

107107
using IrrationalConstants
108108

109+
include("static.jl")
109110
include("smf.jl")
110111
include("getdof.jl")
111112
include("transport.jl")

src/combinators/power.jl

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,67 @@
11
import Base
2-
using FillArrays: Fill
3-
# """
4-
# A power measure is a product of a measure with itself. The number of elements in
5-
# the product determines the dimensionality of the resulting support.
62

7-
# Note that power measures are only well-defined for integer powers.
3+
export PowerMeasure
84

9-
# The nth power of a measure μ can be written μ^x.
10-
# """
11-
# PowerMeasure{M,N,D} = ProductMeasure{Fill{M,N,D}}
5+
"""
6+
struct PowerMeasure{M,...} <: AbstractProductMeasure
127
13-
export PowerMeasure
8+
A power measure is a product of a measure with itself. The number of elements in
9+
the product determines the dimensionality of the resulting support.
1410
11+
Note that power measures are only well-defined for integer powers.
12+
13+
The nth power of a measure μ can be written μ^x.
14+
"""
1515
struct PowerMeasure{M,A} <: AbstractProductMeasure
1616
parent::M
1717
axes::A
1818
end
1919

20+
dslength::PowerMeasure) = prod(dssize(μ))
21+
dssize::PowerMeasure) = map(dslength, μ.axes)
22+
2023
function Pretty.tile::PowerMeasure)
2124
sz = length.(μ.axes)
2225
arg1 = Pretty.tile.parent)
2326
arg2 = Pretty.tile(length(sz) == 1 ? only(sz) : sz)
2427
return Pretty.pair_layout(arg1, arg2; sep = " ^ ")
2528
end
2629

30+
# ToDo: Make rand return static arrays for statically-sized power measures.
31+
32+
_cartidxs(axs::Tuple{Vararg{<:AbstractUnitRange,N}}) where {N} = CartesianIndices(map(_dynamic, axs))
33+
2734
function Base.rand(
2835
rng::AbstractRNG,
2936
::Type{T},
3037
d::PowerMeasure{M},
3138
) where {T,M<:AbstractMeasure}
32-
map(CartesianIndices(d.axes)) do _
39+
map(_cartidxs(d.axes)) do _
3340
rand(rng, T, d.parent)
3441
end
3542
end
3643

3744
function Base.rand(rng::AbstractRNG, ::Type{T}, d::PowerMeasure) where {T}
38-
map(CartesianIndices(d.axes)) do _
45+
map(_cartidxs(d.axes)) do _
3946
rand(rng, d.parent)
4047
end
4148
end
4249

50+
@inline _pm_axes(sz::Tuple{Vararg{<:IntegerLike,N}}) where N = map(one_to, sz)
51+
@inline _pm_axes(axs::Tuple{Vararg{<:AbstractUnitRange,N}}) where N = axs
52+
4353
@inline function powermeasure(x::T, sz::Tuple{Vararg{<:Any,N}}) where {T,N}
44-
a = axes(Fill{T,N}(x, sz))
45-
A = typeof(a)
46-
PowerMeasure{T,A}(x, a)
54+
PowerMeasure(x, _pm_axes(sz))
4755
end
4856

49-
marginals(d::PowerMeasure) = Fill(d.parent, d.axes)
57+
58+
marginals(d::PowerMeasure) = fill_with(d.parent, d.axes)
5059

5160
function Base.:^::AbstractMeasure, dims::Tuple{Vararg{<:AbstractArray,N}}) where {N}
5261
powermeasure(μ, dims)
5362
end
5463

55-
Base.:^::AbstractMeasure, dims::Tuple) = powermeasure(μ, Base.OneTo.(dims))
64+
Base.:^::AbstractMeasure, dims::Tuple) = powermeasure(μ, one_to.(dims))
5665
Base.:^::AbstractMeasure, n) = powermeasure(μ, (n,))
5766

5867
# Base.show(io::IO, d::PowerMeasure) = print(io, d.parent, " ^ ", size(d.xs))
@@ -76,7 +85,7 @@ end
7685
end
7786

7887
@inline function logdensity_def(
79-
d::PowerMeasure{M,Tuple{Base.OneTo{StaticInt{N}}}},
88+
d::PowerMeasure{M,Tuple{StaticOneTo{N}}},
8089
x,
8190
) where {M,N}
8291
parent = d.parent
@@ -86,7 +95,7 @@ end
8695
end
8796

8897
@inline function logdensity_def(
89-
d::PowerMeasure{M,NTuple{N,Base.OneTo{StaticInt{0}}}},
98+
d::PowerMeasure{M,NTuple{N,StaticOneTo{0}}},
9099
x,
91100
) where {M,N}
92101
static(0.0)
@@ -110,7 +119,7 @@ end
110119

111120
@inline getdof::PowerMeasure) = getdof.parent) * prod(map(length, μ.axes))
112121

113-
@inline function getdof(::PowerMeasure{<:Any,NTuple{N,Base.OneTo{StaticInt{0}}}}) where {N}
122+
@inline function getdof(::PowerMeasure{<:Any,NTuple{N,StaticOneTo{0}}}) where {N}
114123
static(0)
115124
end
116125

@@ -135,7 +144,7 @@ logdensity_def(::PowerMeasure{P}, x) where {P<:PrimitiveMeasure} = static(0.0)
135144

136145
# To avoid ambiguities
137146
function logdensity_def(
138-
::PowerMeasure{P,Tuple{Vararg{Base.OneTo{Static.StaticInt{0}},N}}},
147+
::PowerMeasure{P,Tuple{Vararg{StaticOneTo{0},N}}},
139148
x,
140149
) where {P<:PrimitiveMeasure,N}
141150
static(0.0)

src/combinators/smart-constructors.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ end
3737
###############################################################################
3838
# ProductMeasure
3939

40-
productmeasure(mar::Fill) = powermeasure(mar.value, mar.axes)
40+
productmeasure(mar::FillArrays.Fill) = powermeasure(mar.value, mar.axes)
4141

4242
function productmeasure(mar::ReadonlyMappedArray{T,N,A,Returns{M}}) where {T,N,A,M}
4343
return powermeasure(mar.f.value, axes(mar.data))

src/domains.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ struct Simplex <: CodimOne end
116116

117117
function zeroset(::Simplex)
118118
f(x::AbstractArray{T}) where {T} = sum(x) - one(T)
119-
∇f(x::AbstractArray{T}) where {T} = Fill(one(T), size(x))
119+
∇f(x::AbstractArray{T}) where {T} = fill_with(one(T), size(x))
120120
ZeroSet(f, ∇f)
121121
end
122122

src/standard/stdmeasure.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ function transport_def(ν::StdMeasure, μ::PowerMeasure{<:StdMeasure}, x)
1313
end
1414

1515
function transport_def::PowerMeasure{<:StdMeasure}, μ::StdMeasure, x)
16-
return Fill(transport_def.parent, μ, only(x)), map(length, ν.axes)...)
16+
return fill_with(transport_def.parent, μ, only(x)), map(length, ν.axes))
1717
end
1818

1919
function transport_def(
@@ -35,7 +35,7 @@ end
3535
# Implement transport_to(NU::Type{<:StdMeasure}, μ) and transport_to(ν, MU::Type{<:StdMeasure}):
3636

3737
_std_measure(::Type{M}, ::StaticInt{1}) where {M<:StdMeasure} = M()
38-
_std_measure(::Type{M}, dof::Integer) where {M<:StdMeasure} = M()^dof
38+
_std_measure(::Type{M}, dof::IntegerLike) where {M<:StdMeasure} = M()^dof
3939
_std_measure_for(::Type{M}, μ::Any) where {M<:StdMeasure} = _std_measure(M, getdof(μ))
4040

4141
function transport_to(::Type{NU}, μ) where {NU<:StdMeasure}
@@ -90,7 +90,7 @@ end
9090
@inline _offset_cumsum(s, x) = (s,)
9191
@inline _offset_cumsum(s) = ()
9292

93-
function _stdvar_viewranges(μs::Tuple, startidx::Integer)
93+
function _stdvar_viewranges(μs::Tuple, startidx::IntegerLike)
9494
N = map(getdof, μs)
9595
offs = _offset_cumsum(startidx, N...)
9696
map((o, n) -> o:o+n-1, offs, N)

src/static.jl

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
"""
2+
MeasureBase.IntegerLike
3+
4+
Equivalent to `Union{Integer,Static.StaticInt}`.
5+
"""
6+
const IntegerLike = Union{Integer,Static.StaticInt}
7+
8+
9+
"""
10+
MeasureBase.one_to(n::IntegerLike)
11+
12+
Creates a range from one to n.
13+
14+
Returns an instance of `Base.OneTo` or `MeasureBase.StaticOneTo`, depending
15+
on the type of `n`.
16+
"""
17+
@inline one_to(n::Integer) = Base.OneTo(n)
18+
19+
@static if isdefined(Static, :SOneTo)
20+
@inline one_to(::Static.StaticInt{N}) where N = Static.SOneTo{N}()
21+
else
22+
@inline one_to(::Static.StaticInt{N}) where N = Base.OneTo(StaticInt{N}())
23+
end
24+
25+
26+
"""
27+
MeasureBase.StaticOneTo(n)
28+
29+
A static range from one to N.
30+
31+
StaticOneTo is a type alias for `Static.SOneTo{N}` (for Static.jl >= v0.8) or
32+
`Base.OneTo{Static.StaticInt{N}}` (older versions of Static.jl).
33+
"""
34+
@static if isdefined(Static, :SOneTo)
35+
const StaticOneTo{N} = Static.SOneTo{N}
36+
else
37+
const StaticOneTo{N} = Base.OneTo{Static.StaticInt{N}}
38+
end
39+
40+
_dynamic(x::Number) = dynamic(x)
41+
_dynamic(::StaticOneTo{N}) where N = Base.OneTo(N)
42+
_dynamic(r::AbstractUnitRange) = minimum(r):maximum(r)
43+
44+
"""
45+
MeasureBase.fill_with(x, sz::NTuple{N,<:IntegerLike}) where N
46+
47+
Creates an array of size `sz` filled with `x`.
48+
49+
Returns an instance of `FillArrays.Fill`.
50+
"""
51+
function fill_with end
52+
53+
@inline fill_with(x::T, sz::Tuple{Vararg{<:IntegerLike,N}}) where {T,N} = fill_with(x, map(one_to, sz))
54+
55+
@inline function fill_with(x::T, axs::Tuple{Vararg{<:AbstractUnitRange,N}}) where {T,N}
56+
# While `FillArrays.Fill` (mostly?) works with axes that are static unit
57+
# ranges, some operations that automatic differentiation requires do fail
58+
# on such instances of `Fill` (e.g. `reshape` from dynamic to static size).
59+
# So need to use standard ranges for the axes for now:
60+
dyn_axs = map(_dynamic, axs)
61+
FillArrays.Fill(x, dyn_axs)
62+
end
63+
64+
65+
"""
66+
MeasureBase.dslength(x)::IntegerLike
67+
68+
Returns the length of `x` as a dynamic or static integer.
69+
"""
70+
dslength(x) = length(x)
71+
dslength(::Static.OptionallyStaticUnitRange{StaticInt{A}, StaticInt{B}}) where {A,B} = static(B - A + 1)
72+
73+
74+
"""
75+
MeasureBase.dssize(x)::Tuple{Vararg{IntegerLike}}
76+
77+
Returns the size of `x` as a tuple of dynamic or static integers.
78+
"""
79+
dssize(x) = size(x)

test/transport.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ using LogExpFunctions: logit
88
using ChainRulesTestUtils
99

1010
@testset "transport_to" begin
11-
test_rrule(MeasureBase._origin_depth, pushfwd(exp, StdUniform()))
11+
test_rrule(MeasureBase._origin_depth, pushfwd(exp, StdUniform()), output_tangent = static(0))
1212

1313
for (f, μ) in [
1414
(logit, StdUniform())

0 commit comments

Comments
 (0)