diff --git a/src/sampling.jl b/src/sampling.jl index 357d8d9c4..f185d37e8 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -1,3 +1,41 @@ +using Base: mightalias + +if isdefined(Base, :require_one_based_indexing) # TODO: use this directly once we require Julia 1.2+ + using Base: require_one_based_indexing +else + require_one_based_indexing(xs...) = + any((!) ∘ isone ∘ firstindex, xs) && throw(ArgumentError("non 1-based arrays are not supported")) +end + +function _validate_sample_inputs(input::AbstractArray, output::AbstractArray, replace::Bool) + mightalias(input, output) && + throw(ArgumentError("destination array must not share memory with the source array")) + require_one_based_indexing(input, output) + n = length(input) + k = length(output) + if !replace && k > n + throw(DimensionMismatch("cannot draw a sample of $k values from an array " * + "with $n values without replacement")) + end + return nothing +end + +function _validate_sample_inputs(input::AbstractArray, weights::AbstractWeights, + output::AbstractArray, replace::Bool) + mightalias(output, weights) && + throw(ArgumentError("destination array must not share memory with weights array")) + _validate_sample_inputs(input, weights) + _validate_sample_inputs(input, output, replace) + return nothing +end + +function _validate_sample_inputs(input::AbstractArray, weights::AbstractWeights) + require_one_based_indexing(weights) + n = length(input) + nw = length(weights) + nw == n || throw(DimensionMismatch("source and weight arrays must have the same length, got $n and $nw")) + return nothing +end ########################################################### # @@ -10,16 +48,16 @@ using Random: Sampler, Random.GLOBAL_RNG ### Algorithms for sampling with replacement function direct_sample!(rng::AbstractRNG, a::UnitRange, x::AbstractArray) - 1 == firstindex(a) == firstindex(x) || - throw(ArgumentError("non 1-based arrays are not supported")) + _validate_sample_inputs(a, x, true) + k = length(x) s = Sampler(rng, 1:length(a)) b = a[1] - 1 if b == 0 - for i = 1:length(x) + for i = 1:k @inbounds x[i] = rand(rng, s) end else - for i = 1:length(x) + for i = 1:k @inbounds x[i] = b + rand(rng, s) end end @@ -36,10 +74,7 @@ and set `x[j] = a[i]`, with `n=length(a)` and `k=length(x)`. This algorithm consumes `k` random numbers. """ function direct_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray) - 1 == firstindex(a) == firstindex(x) || - throw(ArgumentError("non 1-based arrays are not supported")) - Base.mightalias(a, x) && - throw(ArgumentError("output array x must not share memory with input array a")) + _validate_sample_inputs(a, x, true) s = Sampler(rng, 1:length(a)) for i = 1:length(x) @inbounds x[i] = a[rand(rng, s)] @@ -61,11 +96,9 @@ storeindices(n, k, T) = false # order results of a sampler that does not order automatically function sample_ordered!(sampler!, rng::AbstractRNG, a::AbstractArray, x::AbstractArray) - 1 == firstindex(a) == firstindex(x) || - throw(ArgumentError("non 1-based arrays are not supported")) - Base.mightalias(a, x) && - throw(ArgumentError("output array x must not share memory with input array a")) - n, k = length(a), length(x) + _validate_sample_inputs(a, x, true) + n = length(a) + k = length(x) # todo: if eltype(x) <: Real && eltype(a) <: Real, # in some cases it might be faster to check # issorted(a) to see if we can just sort x @@ -140,13 +173,9 @@ memory space. Suitable for the case where memory is tight. """ function knuths_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray; initshuffle::Bool=true) - 1 == firstindex(a) == firstindex(x) || - throw(ArgumentError("non 1-based arrays are not supported")) - Base.mightalias(a, x) && - throw(ArgumentError("output array x must not share memory with input array a")) + _validate_sample_inputs(a, x, false) n = length(a) k = length(x) - k <= n || error("length(x) should not exceed length(a)") # initialize for i = 1:k @@ -200,13 +229,9 @@ faster than Knuth's algorithm especially when `n` is greater than `k`. It is ``O(n)`` for initialization, plus ``O(k)`` for random shuffling """ function fisher_yates_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray) - 1 == firstindex(a) == firstindex(x) || - throw(ArgumentError("non 1-based arrays are not supported")) - Base.mightalias(a, x) && - throw(ArgumentError("output array x must not share memory with input array a")) + _validate_sample_inputs(a, x, false) n = length(a) k = length(x) - k <= n || error("length(x) should not exceed length(a)") inds = Vector{Int}(undef, n) for i = 1:n @@ -240,13 +265,9 @@ However, if `k` is large and approaches ``n``, the rejection rate would increase drastically, resulting in poorer performance. """ function self_avoid_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray) - 1 == firstindex(a) == firstindex(x) || - throw(ArgumentError("non 1-based arrays are not supported")) - Base.mightalias(a, x) && - throw(ArgumentError("output array x must not share memory with input array a")) + _validate_sample_inputs(a, x, false) n = length(a) k = length(x) - k <= n || error("length(x) should not exceed length(a)") s = Set{Int}() sizehint!(s, k) @@ -282,13 +303,9 @@ This algorithm consumes ``O(n)`` random numbers, with `n=length(a)`. The outputs are ordered. """ function seqsample_a!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray) - 1 == firstindex(a) == firstindex(x) || - throw(ArgumentError("non 1-based arrays are not supported")) - Base.mightalias(a, x) && - throw(ArgumentError("output array x must not share memory with input array a")) + _validate_sample_inputs(a, x, false) n = length(a) k = length(x) - k <= n || error("length(x) should not exceed length(a)") i = 0 j = 0 @@ -324,13 +341,9 @@ This algorithm consumes ``O(k^2)`` random numbers, with `k=length(x)`. The outputs are ordered. """ function seqsample_c!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray) - 1 == firstindex(a) == firstindex(x) || - throw(ArgumentError("non 1-based arrays are not supported")) - Base.mightalias(a, x) && - throw(ArgumentError("output array x must not share memory with input array a")) + _validate_sample_inputs(a, x, false) n = length(a) k = length(x) - k <= n || error("length(x) should not exceed length(a)") i = 0 j = 0 @@ -370,13 +383,9 @@ This algorithm consumes ``O(k)`` random numbers, with `k=length(x)`. The outputs are ordered. """ function seqsample_d!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray) - 1 == firstindex(a) == firstindex(x) || - throw(ArgumentError("non 1-based arrays are not supported")) - Base.mightalias(a, x) && - throw(ArgumentError("output array x must not share memory with input array a")) + _validate_sample_inputs(a, x, false) N = length(a) n = length(x) - n <= N || error("length(x) should not exceed length(a)") i = 0 j = 0 @@ -485,11 +494,10 @@ nor share memory with them, or the result may be incorrect. """ function sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray; replace::Bool=true, ordered::Bool=false) - 1 == firstindex(a) == firstindex(x) || - throw(ArgumentError("non 1-based arrays are not supported")) - n = length(a) + _validate_sample_inputs(a, x, replace) k = length(x) k == 0 && return x + n = length(a) if replace # with replacement if ordered @@ -499,8 +507,6 @@ function sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray; end else # without replacement - k <= n || error("Cannot draw more samples without replacement.") - if ordered if n > 10 * k * k seqsample_c!(rng, a, x) @@ -582,8 +588,7 @@ Optionally specify a random number generator `rng` as the first argument (defaults to `Random.GLOBAL_RNG`). """ function sample(rng::AbstractRNG, wv::AbstractWeights) - 1 == firstindex(wv) || - throw(ArgumentError("non 1-based arrays are not supported")) + require_one_based_indexing(wv) t = rand(rng) * sum(wv) n = length(wv) i = 1 @@ -596,7 +601,10 @@ function sample(rng::AbstractRNG, wv::AbstractWeights) end sample(wv::AbstractWeights) = sample(Random.GLOBAL_RNG, wv) -sample(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights) = a[sample(rng, wv)] +function sample(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights) + _validate_sample_inputs(a, wv) + return a[sample(rng, wv)] +end sample(a::AbstractArray, wv::AbstractWeights) = sample(Random.GLOBAL_RNG, a, wv) """ @@ -613,14 +621,7 @@ Noting `k=length(x)` and `n=length(a)`, this algorithm: """ function direct_sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::AbstractArray) - Base.mightalias(a, x) && - throw(ArgumentError("output array x must not share memory with input array a")) - Base.mightalias(x, wv) && - throw(ArgumentError("output array x must not share memory with weights array wv")) - 1 == firstindex(a) == firstindex(wv) == firstindex(x) || - throw(ArgumentError("non 1-based arrays are not supported")) - n = length(a) - length(wv) == n || throw(DimensionMismatch("Inconsistent lengths.")) + _validate_sample_inputs(a, wv, x, true) for i = 1:length(x) x[i] = a[sample(rng, wv)] end @@ -702,14 +703,9 @@ Noting `k=length(x)` and `n=length(a)`, this algorithm takes ``O(n \\log n)`` ti for building the alias table, and then ``O(1)`` to draw each sample. It consumes ``2 k`` random numbers. """ function alias_sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::AbstractArray) - Base.mightalias(a, x) && - throw(ArgumentError("output array x must not share memory with input array a")) - Base.mightalias(x, wv) && - throw(ArgumentError("output array x must not share memory with weights array wv")) - 1 == firstindex(a) == firstindex(wv) == firstindex(x) || - throw(ArgumentError("non 1-based arrays are not supported")) + _validate_sample_inputs(a, wv, x, true) n = length(a) - length(wv) == n || throw(DimensionMismatch("Inconsistent lengths.")) + k = length(x) # create alias table ap = Vector{Float64}(undef, n) @@ -718,7 +714,7 @@ function alias_sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, # sampling s = Sampler(rng, 1:n) - for i = 1:length(x) + for i = 1:k j = rand(rng, s) x[i] = rand(rng) < ap[j] ? a[j] : a[alias[j]] end @@ -740,15 +736,10 @@ and has overall time complexity ``O(n k)``. """ function naive_wsample_norep!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::AbstractArray) - Base.mightalias(a, x) && - throw(ArgumentError("output array x must not share memory with input array a")) - Base.mightalias(x, wv) && - throw(ArgumentError("output array x must not share memory with weights array wv")) - 1 == firstindex(a) == firstindex(wv) == firstindex(x) || - throw(ArgumentError("non 1-based arrays are not supported")) - n = length(a) - length(wv) == n || throw(DimensionMismatch("Inconsistent lengths.")) + _validate_sample_inputs(a, wv, x, false) k = length(x) + k > 0 || return x + n = length(a) w = Vector{Float64}(undef, n) copyto!(w, wv) @@ -786,20 +777,15 @@ processing time to draw ``k`` elements. It consumes ``n`` random numbers. """ function efraimidis_a_wsample_norep!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::AbstractArray) - Base.mightalias(a, x) && - throw(ArgumentError("output array x must not share memory with input array a")) - Base.mightalias(x, wv) && - throw(ArgumentError("output array x must not share memory with weights array wv")) - 1 == firstindex(a) == firstindex(wv) == firstindex(x) || - throw(ArgumentError("non 1-based arrays are not supported")) - n = length(a) - length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv))).")) + _validate_sample_inputs(a, wv, x, false) k = length(x) + k > 0 || return x + n = length(a) # calculate keys for all items keys = randexp(rng, n) for i in 1:n - @inbounds keys[i] = wv.values[i]/keys[i] + @inbounds keys[i] = wv[i]/keys[i] end # return items with largest keys @@ -827,16 +813,10 @@ processing time to draw ``k`` elements. It consumes ``n`` random numbers. """ function efraimidis_ares_wsample_norep!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::AbstractArray) - Base.mightalias(a, x) && - throw(ArgumentError("output array x must not share memory with input array a")) - Base.mightalias(x, wv) && - throw(ArgumentError("output array x must not share memory with weights array wv")) - 1 == firstindex(a) == firstindex(wv) == firstindex(x) || - throw(ArgumentError("non 1-based arrays are not supported")) - n = length(a) - length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv))).")) + _validate_sample_inputs(a, wv, x, false) k = length(x) k > 0 || return x + n = length(a) # initialize priority queue pq = Vector{Pair{Float64,Int}}(undef, k) @@ -844,7 +824,7 @@ function efraimidis_ares_wsample_norep!(rng::AbstractRNG, a::AbstractArray, s = 0 @inbounds for _s in 1:n s = _s - w = wv.values[s] + w = wv[s] w < 0 && error("Negative weight found in weight vector at index $s") if w > 0 i += 1 @@ -859,7 +839,7 @@ function efraimidis_ares_wsample_norep!(rng::AbstractRNG, a::AbstractArray, @inbounds threshold = pq[1].first @inbounds for i in s+1:n - w = wv.values[i] + w = wv[i] w < 0 && error("Negative weight found in weight vector at index $i") w > 0 || continue key = w/randexp(rng) @@ -900,16 +880,10 @@ processing time to draw ``k`` elements. It consumes ``O(k \\log(n / k))`` random function efraimidis_aexpj_wsample_norep!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::AbstractArray; ordered::Bool=false) - Base.mightalias(a, x) && - throw(ArgumentError("output array x must not share memory with input array a")) - Base.mightalias(x, wv) && - throw(ArgumentError("output array x must not share memory with weights array wv")) - 1 == firstindex(a) == firstindex(wv) == firstindex(x) || - throw(ArgumentError("non 1-based arrays are not supported")) - n = length(a) - length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv))).")) + _validate_sample_inputs(a, wv, x, false) k = length(x) k > 0 || return x + n = length(a) # initialize priority queue pq = Vector{Pair{Float64,Int}}(undef, k) @@ -917,7 +891,7 @@ function efraimidis_aexpj_wsample_norep!(rng::AbstractRNG, a::AbstractArray, s = 0 @inbounds for _s in 1:n s = _s - w = wv.values[s] + w = wv[s] w < 0 && error("Negative weight found in weight vector at index $s") if w > 0 i += 1 @@ -933,7 +907,7 @@ function efraimidis_aexpj_wsample_norep!(rng::AbstractRNG, a::AbstractArray, X = threshold*randexp(rng) @inbounds for i in s+1:n - w = wv.values[i] + w = wv[i] w < 0 && error("Negative weight found in weight vector at index $i") w > 0 || continue X -= w @@ -968,10 +942,10 @@ efraimidis_aexpj_wsample_norep!(a::AbstractArray, wv::AbstractWeights, x::Abstra function sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::AbstractArray; replace::Bool=true, ordered::Bool=false) - 1 == firstindex(a) == firstindex(wv) == firstindex(x) || - throw(ArgumentError("non 1-based arrays are not supported")) - n = length(a) + _validate_sample_inputs(a, wv, x, replace) k = length(x) + k > 0 || return x + n = length(a) if replace if ordered @@ -991,7 +965,6 @@ function sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::Abs end end else - k <= n || error("Cannot draw $k samples from $n samples without replacement.") efraimidis_aexpj_wsample_norep!(rng, a, wv, x; ordered=ordered) end return x diff --git a/test/wsampling.jl b/test/wsampling.jl index d1de4c855..e0ff9cc47 100644 --- a/test/wsampling.jl +++ b/test/wsampling.jl @@ -93,45 +93,34 @@ end import StatsBase: naive_wsample_norep!, efraimidis_a_wsample_norep!, efraimidis_ares_wsample_norep!, efraimidis_aexpj_wsample_norep! -n = 10^5 -wv = weights([0.2, 0.8, 0.4, 0.6]) - -a = zeros(Int, 3, n) -for j = 1:n - naive_wsample_norep!(4:7, wv, view(a,:,j)) -end -check_wsample_norep(a, (4, 7), wv, 5.0e-3; ordered=false) -test_rng_use(naive_wsample_norep!, 4:7, wv, zeros(Int, 2)) - -a = zeros(Int, 3, n) -for j = 1:n - efraimidis_a_wsample_norep!(4:7, wv, view(a,:,j)) -end -check_wsample_norep(a, (4, 7), wv, 5.0e-3; ordered=false) -test_rng_use(efraimidis_a_wsample_norep!, 4:7, wv, zeros(Int, 2)) - -a = zeros(Int, 3, n) -for j = 1:n - efraimidis_ares_wsample_norep!(4:7, wv, view(a,:,j)) -end -check_wsample_norep(a, (4, 7), wv, 5.0e-3; ordered=false) -test_rng_use(efraimidis_ares_wsample_norep!, 4:7, wv, zeros(Int, 2)) - -a = zeros(Int, 3, n) -for j = 1:n - efraimidis_aexpj_wsample_norep!(4:7, wv, view(a,:,j)) -end -check_wsample_norep(a, (4, 7), wv, 5.0e-3; ordered=false) -test_rng_use(efraimidis_aexpj_wsample_norep!, 4:7, wv, zeros(Int, 2)) +@testset "Weighted sampling without replacement" begin + n = 10^5 + wv = weights([0.2, 0.8, 0.4, 0.6]) + + @testset "$f" for f in (naive_wsample_norep!, efraimidis_a_wsample_norep!, + efraimidis_ares_wsample_norep!, efraimidis_aexpj_wsample_norep!) + a = zeros(Int, 3, n) + for j = 1:n + f(4:7, wv, view(a,:,j)) + end + check_wsample_norep(a, (4, 7), wv, 5.0e-3; ordered=false) + test_rng_use(f, 4:7, wv, zeros(Int, 2)) + # Check that the function is using the weight vector's own indexing method(s) + # by trying with `UnitWeights`, which doesn't store an underlying array and thus + # doesn't have a `values` field to access. Here we're effectively just ensuring + # there's no error thrown. + @test length(f(rand(4), uweights(Float64, 4), zeros(2))) == 2 + end -a = sample(4:7, wv, 3; replace=false, ordered=false) -check_wsample_norep(a, (4, 7), wv, -1; ordered=false) + a = sample(4:7, wv, 3; replace=false, ordered=false) + check_wsample_norep(a, (4, 7), wv, -1; ordered=false) -for rev in (true, false), T in (Int, Int16, Float64, Float16, BigInt, ComplexF64, Rational{Int}) - r = rev ? reverse(4:7) : (4:7) - r = T===Int ? r : T.(r) - aa = Int.(sample(r, wv, 3; replace=false, ordered=true)) - check_wsample_norep(aa, (4, 7), wv, -1; ordered=true, rev=rev) + for rev in (true, false), T in (Int, Int16, Float64, Float16, BigInt, ComplexF64, Rational{Int}) + r = rev ? reverse(4:7) : (4:7) + r = T===Int ? r : T.(r) + aa = Int.(sample(r, wv, 3; replace=false, ordered=true)) + check_wsample_norep(aa, (4, 7), wv, -1; ordered=true, rev=rev) + end end @testset "validation of inputs" begin @@ -143,6 +132,7 @@ end oz = OffsetArray(z, -4:5) @test_throws ArgumentError sample(weights(ox)) + @test_throws DimensionMismatch sample(x, weights(1:5)) for f in (sample!, wsample!, naive_wsample_norep!, efraimidis_a_wsample_norep!, efraimidis_ares_wsample_norep!, efraimidis_aexpj_wsample_norep!) @@ -158,8 +148,19 @@ end @test_throws ArgumentError f(x, weights(x), x) @test_throws ArgumentError f(y, weights(view(x, 3:5)), view(x, 2:4)) @test_throws ArgumentError f(view(x, 2:4), weights(view(x, 3:5)), view(x, 1:2)) + + # Test that source and weight lengths agree + @test_throws DimensionMismatch f(x, weights(1:5), z) + + # Test that sampling without replacement can't draw more than what's available + if endswith(String(nameof(f)), "_norep!") + @test_throws DimensionMismatch f(x, weights(y), vcat(z, z)) + else + @test_throws DimensionMismatch f(x, weights(y), vcat(z, z); replace=false) + end + # This corner case should theoretically succeed # but it currently fails as Base.mightalias is not smart enough @test_broken f(y, weights(view(x, 5:6)), view(x, 2:4)) end -end \ No newline at end of file +end