Ce serveur Gitlab sera éteint le 30 juin 2020, pensez à migrer vos projets vers les serveurs gitlab-research.centralesupelec.fr et gitlab-student.centralesupelec.fr !

Commit 3b9566bd authored by Dahua Lin's avatar Dahua Lin

population sampling migrated out of Distributions.jl to StatsBase.jl

These functions are imported & exported in this package -- so codes that only importing Distributions can still directly use sample & wsample.
parent c5035db4
tests = [
"types",
"utils",
"sample",
"fit",
"discrete",
"univariate",
"truncate",
"multinomial",
"dirichlet",
"mvnormal",
"mvtdist",
"kolmogorov",
"edgeworth",
"matrix",
"vonmisesfisher",
"compoundvariate",
"conjugates",
"fit",
"discrete",
"univariate",
"truncate",
"multinomial",
"dirichlet",
"mvnormal",
"mvtdist",
"kolmogorov",
"edgeworth",
"matrix",
"vonmisesfisher",
"compoundvariate",
"conjugates",
"conjugates_normal",
"conjugates_mvnormal",
"wishart"]
......@@ -24,7 +22,7 @@ tests = [
println("Running tests:")
for t in tests
test_fn = joinpath("test", "$t.jl")
test_fn = joinpath("test", "$t.jl")
println(" * $test_fn")
include(test_fn)
end
......@@ -169,12 +169,11 @@ export
wsample, # weighted sampling from a source array
expected_logdet # expected logarithm of random matrix determinant
import Base.mean, Base.median, Base.quantile, Base.scale
import Base.max, Base.min, Base.maximum, Base.minimum
import Base.Random, Base.rand, Base.rand!, Base.std, Base.var, Base.cor, Base.cov
import Base.show, Base.sprand
import NumericExtensions.dim, NumericExtensions.entropy
import StatsBase.kurtosis, StatsBase.skewness, StatsBase.mode, StatsBase.modes
import Base.Random
import Base: show, scale, sum!, rand, rand!, sprand
import Base: mean, median, maximum, minimum, quantile, std, var, cov, cor
import NumericExtensions: dim, entropy
import StatsBase: kurtosis, skewness, mode, modes, randi, RandIntSampler
#### Distribution type system
......@@ -221,7 +220,6 @@ include("rmath.jl")
include("specialfuns.jl")
include("tvpack.jl")
include("utils.jl")
include("sample.jl")
include(joinpath("samplers", "categorical_samplers.jl"))
......
......@@ -26,9 +26,8 @@ posterior_canon(pri::Dirichlet, ss::MultinomialStats) = DirichletCanon(pri.alpha
function posterior_canon{T<:Real}(pri::Dirichlet, G::Type{Multinomial}, x::Matrix{T})
d = dim(pri)
size(x,1) == d || throw(ArgumentError("Inconsistent argument dimensions."))
a = Array(Float64, d)
add!(sum!(a, x, 2), pri.alpha)
DirichletCanon(a)
a = add!(sum(x, 2), pri.alpha)
DirichletCanon(vec(a))
end
function posterior_canon{T<:Real}(pri::Dirichlet, G::Type{Multinomial}, x::Matrix{T}, w::Array{Float64})
......
# Sample from arbitrary arrays
################################################################
#
# A variety of algorithms for sampling without replacement
#
# They are suited for different cases.
#
# Particularly,
# - Fisher-Yates sampler is suited for general cases
# where n is not overly large
#
# - Self avoiding sampler is suited for cases where k << n
#
################################################################
function pick2!(a::AbstractArray, x::AbstractArray)
# Pick a pair of values without replacement
n0 = length(a)
i1 = randi(n0)
i2 = randi(n0 - 1)
if i2 == i1
i2 = n0
end
x[1] = a[i1]
x[2] = a[i2]
end
## A sampler that implements without-replacement sampling
## via Fisher-Yates shuffling
##
immutable FisherYatesSampler
n::Int
seq::Vector{Int} # Internal sequence for shuffling
FisherYatesSampler(n::Int) = new(n, [1:n])
end
function rand!(s::FisherYatesSampler, a::AbstractArray, x::AbstractArray)
# draw samples without-replacement to x
n::Int = s.n
k::Int = length(x)
if k > n
throw(ArgumentError("Cannot draw more than n samples without replacement."))
end
seq::Vector{Int} = s.seq
for i = 1:k
j = randi(i, n)
sj = seq[j]
x[i] = a[sj]
seq[j] = seq[i]
seq[i] = sj
end
x
end
fisher_yates_sample!(a::AbstractArray, x::AbstractArray) = rand!(FisherYatesSampler(length(a)), a, x)
function self_avoid_sample!{T}(a::AbstractArray{T}, x::AbstractArray)
# This algorithm is suitable when length(x) << length(a)
s = Set{T}()
# sizehint(s, length(x))
rgen = RandIntSampler(length(a))
# first one
idx = rand(rgen)
x[1] = a[idx]
push!(s, idx)
# remaining
for i = 2:length(x)
idx = rand(rgen)
while in(s, idx)
idx = rand(rgen)
end
x[i] = a[idx]
push!(s, idx)
end
x
end
# Ordered sampling without replacement
# Author: Mike Innes
function rand_first_index(n, k)
r = rand()
p = k/n
i = 1
while p < r
i += 1
p += (1-p)k/(n-(i-1))
end
return i
end
function ordered_sample_norep!(xs::AbstractArray, target::AbstractArray)
n = length(xs)
k = length(target)
i = 0
for j in 1:k
step = rand_first_index(n, k)
n -= step
i += step
target[j] = xs[i]
k -= 1
end
return target
end
function ordered_sample_rep!(xs::AbstractArray, target::AbstractArray)
n = length(xs)
n_left = n
k = length(target)
j = 0
for i = 1:n
k > 0 || break
num = i == n ? k : rand(Binomial(k, 1/n_left))
for _ = 1:num
j += 1
target[j] = xs[i]
end
k -= num
n_left -= 1
end
return target
end
###########################################################
#
# Interface functions
#
###########################################################
sample(a::AbstractArray) = a[randi(length(a))]
function sample!(a::AbstractArray, x::AbstractArray; replace=true, ordered=false)
n = length(a)
k = length(x)
if !isempty(x)
if ordered
replace ? ordered_sample_rep!(a, x) : ordered_sample_norep!(a, x)
else
if replace # with replacement
s = RandIntSampler(n)
for i = 1:k
x[i] = a[rand(s)]
end
else # without replacement
if k > n
throw(ArgumentError("n exceeds the length of x"))
end
if k == 1
x[1] = sample(a)
elseif k == 2
pick2!(a, x)
elseif n < k * max(k, 100)
fisher_yates_sample!(a, x)
else
self_avoid_sample!(a, x)
end
end
end
end
x
end
function sample{T}(a::AbstractArray{T}, n::Integer; replace=true, ordered=false)
sample!(a, Array(T, n); replace=replace, ordered=ordered)
end
function sample{T}(a::AbstractArray{T}, dims::Dims; replace=true, ordered=false)
sample!(a, Array(T, dims); replace=replace, ordered=ordered)
end
################################################################
#
# Weighted sampling
#
################################################################
function wsample(ws::AbstractArray; wsum::Number = sum(ws))
t = rand() * wsum
i = 0
p = 0.
while p < t
i += 1
p += ws[i]
end
return i
end
wsample(xs::AbstractArray, ws::AbstractArray; wsum::Number = sum(ws)) =
xs[wsample(ws, wsum=wsum)]
# Author: Mike Innes
function ordered_wsample!(xs::AbstractArray, ws::AbstractArray, target::AbstractArray; wsum::Number = sum(ws))
n = length(xs)
k = length(target)
j = 0
length(ws) == n || throw(ArgumentError("Inconsistent argument dimensions."))
for i = 1:n
k > 0 || break
num = i == n ? k : rand(Binomial(k, ws[i]/wsum))
for _ = 1:num
j += 1
target[j] = xs[i]
end
k -= num
wsum -= ws[i]
end
return target
end
function wsample!(xs::AbstractArray, ws::AbstractArray, target::AbstractArray; wsum::Number = sum(ws), ordered::Bool = false)
k = length(target)
ordered && return ordered_wsample!(xs, ws, target, wsum = wsum)
k > 100 && return ordered_wsample!(xs, ws, target, wsum = wsum) |> shuffle!
for i = 1:k
target[i] = wsample(xs)
end
return target
end
wsample(xs::AbstractArray, ws::AbstractArray, k; wsum::Number = sum(ws), ordered::Bool = false) =
wsample!(xs, ws, similar(xs, k), wsum = wsum, ordered = ordered)
immutable AliasTable <: AbstractCategoricalSampler
accept::Vector{Float64}
alias::Vector{Int}
......
......@@ -33,47 +33,6 @@ function isprobvec(p::Vector{Float64})
end
function _randu(Ku::Uint, U::Uint) # ~ U[0:Ku-1]
x = rand(Uint)
while x > U
x = rand(Uint)
end
rem(x, Ku)
end
function randi(K::Int) # Fast method to draw a random integer from 1:K
Ku = uint(K)
U = div(typemax(Uint), Ku) * Ku
int(_randu(Ku, U)) + 1
end
function randi(a::Int, b::Int) # ~ U[a:b]
Ku = uint(b - a + 1)
U = div(typemax(Uint), Ku) * Ku
int(_randu(Ku, U)) + a
end
immutable RandIntSampler
a::Int
Ku::Uint
U::Uint
function RandIntSampler(K::Int) # 1:K
Ku = uint(K)
U = div(typemax(Uint), Ku) * Ku
new(1, Ku, U)
end
function RandIntSampler(a::Int, b::Int) # a:b
Ku = uint(b - a + 1)
U = div(typemax(Uint), Ku) * Ku
new(a, Ku, U)
end
end
rand(s::RandIntSampler) = int(_randu(s.Ku, s.U)) + s.a
# Routines for sampling from Gamma distribution
#
# The reason why these functions are in utils.jl instead of gamma.jl is:
......
# Test sample functions
using Distributions
using Base.Test
function est_p(x, K)
h = zeros(Int, K)
for xi in x
h[xi] += 1
end
p = h / length(x)
end
#### sample with replacement
n = 10^5
x = sample([10,20,30], n)
@test isa(x, Vector{Int})
@test length(x) == n
h = [sum(x .== 10), sum(x .== 20), sum(x .== 30)]
@test sum(h) == n
ph = h / n
p0 = fill(1/3, 3)
@test_approx_eq_eps ph p0 0.02
#### sample without replacement
# case: K == 2
n = 10^5
x = zeros(Int, 2, n)
for i = 1:n
v = sample(11:15, 2; replace=false)
@assert v[1] != v[2]
x[:,i] = v
end
@test minimum(x) == 11
@test maximum(x) == 15
x[:] -= 10 # brings x to 1:5
p0 = fill(1/5, 5)
@test_approx_eq_eps est_p(x, 5) p0 0.02
# case: K == 4 with moderate a (using Fisher-Yates)
n = 10^5
x = zeros(Int, 4, n)
for i = 1 : n
v = sample(11:20, 4; replace=false)
sv = sort(v)
@assert sv[1] < sv[2] < sv[3] < sv[4]
x[:,i] = v
end
@test minimum(x) == 11
@test maximum(x) == 20
x[:] -= 10
p0 = fill(0.1, 10)
@test_approx_eq_eps est_p(x, 10) p0 0.01
# case: K == 4 with very large a (using self-avoid)
n = 10^4
x = zeros(Int, 4, n)
a = 10^7 + 1
b = 2 * 10^7
for i = 1 : n
v = sample(a:b, 4; replace=false)
sv = sort(v)
@assert sv[1] < sv[2] < sv[3] < sv[4]
x[:,i] = v
end
@test minimum(x) >= a
@test maximum(x) <= b
#### weighted sampling
w = [2., 5., 3.]
n = 10^5
x = wsample([10,20,30], w, n)
h = [sum(x .== 10), sum(x .== 20), sum(x .== 30)]
@test sum(h) == n
p0 = w / sum(w)
ph = h / n
@test_approx_eq_eps ph p0 0.02
using Distributions
using Base.Test
const randi = Distributions.randi
n = 1_000_000
x = Int[randi(10) for i = 1:n]
@test minimum(x) == 1
@test maximum(x) == 10
x = Int[randi(3, 12) for i = 1:n]
@test minimum(x) == 3
@test maximum(x) == 12
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment