diff --git a/src/continuous.jl b/src/continuous.jl index a5d04ad..8907694 100644 --- a/src/continuous.jl +++ b/src/continuous.jl @@ -1,4 +1,4 @@ -#OU process +# OU process struct OrnsteinUhlenbeckDiffusion{T <: Real} <: GaussianStateProcess mean::T volatility::T @@ -13,19 +13,25 @@ var(model::OrnsteinUhlenbeckDiffusion) = (model.volatility^2) / (2 * model.rever eq_dist(model::OrnsteinUhlenbeckDiffusion) = Normal(model.mean,sqrt(var(model))) +# These are for nested broadcasting +elmwiseadd(x, y) = x .+ y +elmwisesub(x, y) = x .- y +elmwisemul(x, y) = x .* y +elmwisediv(x, y) = x ./ y + function forward(process::OrnsteinUhlenbeckDiffusion, x_s::AbstractArray, s::Real, t::Real) μ, σ, θ = process.mean, process.volatility, process.reversion - mean = @. exp(-(t - s) * θ) * (x_s - μ) + μ - var = similar(mean) - var .= ((1 - exp(-2(t - s) * θ)) * σ^2) / 2θ + # exp(-(t - s) * θ) * (x_s - μ) + μ + mean = elmwiseadd.(elmwisemul.(exp(-(t - s) * θ), elmwisesub.(x_s, μ)), μ) + var = ((1 - exp(-2(t - s) * θ)) * σ^2) / 2θ return GaussianVariables(mean, var) end function backward(process::OrnsteinUhlenbeckDiffusion, x_t::AbstractArray, s::Real, t::Real) μ, σ, θ = process.mean, process.volatility, process.reversion - mean = @. exp((t - s) * θ) * (x_t - μ) + μ - var = similar(mean) - var .= -(σ^2 / 2θ) + (σ^2 * exp(2(t - s) * θ)) / 2θ + # @. exp((t - s) * θ) * (x_t - μ) + μ + mean = elmwiseadd.(elmwisemul.(exp((t - s) * θ), elmwisesub.(x_t, μ)), μ) + var = -(σ^2 / 2θ) + (σ^2 * exp(2(t - s) * θ)) / 2θ return (μ = mean, σ² = var) end diff --git a/src/loss.jl b/src/loss.jl index 6b38f91..8c371e4 100644 --- a/src/loss.jl +++ b/src/loss.jl @@ -37,6 +37,16 @@ function standardloss( return scaledloss(loss(x̂, parent(x)), maskedindices(x), (t -> scaler(p, t)).(t)) end +function standardloss( + p::OrnsteinUhlenbeckDiffusion, + t::Union{Real,AbstractVector{<:Real}}, + x̂::AbstractArray{<: SVector}, x::AbstractArray{<: SVector}; + scaler=defaultscaler) + loss(x̂, x) = norm.(x̂ .- x).^2 + # ugly syntax but scaler.(p, t) is not differentiable with Zygote.jl for some reason + return scaledloss(loss(x̂, parent(x)), maskedindices(x), (t -> scaler(p, t)).(t)) +end + defaultscaler(p::RotationDiffusion, t::Real) = sqrt(1 - exp(-t * p.rate * 5)) function standardloss( diff --git a/src/randomvariable.jl b/src/randomvariable.jl index 89892a0..4b277b3 100644 --- a/src/randomvariable.jl +++ b/src/randomvariable.jl @@ -1,19 +1,20 @@ # Random Variables # ---------------- -struct GaussianVariables{T, A <: AbstractArray{T}} +struct GaussianVariables{A, B} # μ and σ² must have the same size - μ::A # mean - σ²::A # variance + μ::A # mean (array) + σ²::B # variance (scalar) end Base.size(X::GaussianVariables) = size(X.μ) -sample(rng::AbstractRNG, X::GaussianVariables{T}) where T = randn(rng, T, size(X)) .* .√X.σ² .+ X.μ +sample(rng::AbstractRNG, X::GaussianVariables) = + elmwisemul.(randn(rng, eltype(X.μ), size(X)), √X.σ²) .+ X.μ function combine(X::GaussianVariables, lik) - σ² = @. inv(inv(X.σ²) + inv(lik.σ²)) - μ = @. σ² * (X.μ / X.σ² + lik.μ / lik.σ²) + σ² = inv(inv(X.σ²) + inv(lik.σ²)) + μ = elmwisemul.(σ², elmwisediv.(X.μ, X.σ²) .+ elmwisediv.(lik.μ, lik.σ²)) return GaussianVariables(μ, σ²) end diff --git a/test/runtests.jl b/test/runtests.jl index c773432..3890d27 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -74,6 +74,13 @@ end x = one(QuatRotation{Float32}) t = 0.29999998f0 @test sampleforward(diffusion, t, [x]) isa Vector + + # three-dimensional diffusion + diffusion = OrnsteinUhlenbeckDiffusion(0.0) + x_0 = fill(zero(SVector{3, Float64}), 2) + x_t = sampleforward(diffusion, 1.0, x_0) + @test x_t isa typeof(x_0) + @test size(x_t) == size(x_0) end @testset "Discrete Diffusions" begin @@ -175,6 +182,12 @@ end x = samplebackward((x, t) -> x + randn(size(x)), process, [1/8, 1/4, 1/2, 1/1], x_t) @test size(x) == size(x_t) @test x isa Matrix + + process = OrnsteinUhlenbeckDiffusion(0.0) + x_t = randn(SVector{3, Float64}, 4, 10) + x = samplebackward((x, t) -> x + randn(eltype(x), size(x)), process, [1/8, 1/4, 1/2, 1/1], x_t) + @test size(x) == size(x_t) + @test x isa Matrix end @testset "Masked Diffusion" begin @@ -244,8 +257,8 @@ end end @testset "Loss" begin - p = OrnsteinUhlenbeckDiffusion(0.0, 1.0, 0.5) - x_0 = randn(5, 10) + p = OrnsteinUhlenbeckDiffusion(0.0) + x_0 = zeros(5, 10) t = rand(10) @test standardloss(p, t, x_0, x_0) == 0 x = rand(5, 10) @@ -253,14 +266,25 @@ end # unmasked elements don't contribute to the loss x = copy(x_0) - m = x_0 .< 0 + m = rand(size(x)...) .< 0.5 + x[.!m] .= 1 x_0 = mask(x_0, m) - x[.!m] .= 0 @test standardloss(p, t, x, x_0) == 0 + @test standardloss(p, t, x, parent(x_0)) > 0 - # but masked elements do - x[m] .= 0 + p = OrnsteinUhlenbeckDiffusion(0.0) + x_0 = fill(zero(SVector{3, Float64}), 10) + t = rand(10) + @test standardloss(p, t, x_0, x_0) == 0 + x = [rand(SVector{3, Float64}) for _ in eachindex(x_0)] @test standardloss(p, t, x, x_0) > 0 + + x = copy(x_0) + m = rand(size(x)...) .< 0.5 + x[.!m] .= (ones(SVector{3, Float64}),) + x_0 = mask(x_0, m) + @test standardloss(p, t, x, x_0) == 0 + @test standardloss(p, t, x, parent(x_0)) > 0 end @testset "Autodiff" begin