Commit f7ca4f72 authored by Dahua Lin's avatar Dahua Lin

use GammaSampler

The introduction of GammaSampler to capture the pre-computation & sampling for gamma random number generation simplify codes that rely on the generation of gamma random numbers.
parent e2856948
......@@ -222,6 +222,7 @@ include("tvpack.jl")
include("utils.jl")
include(joinpath("samplers", "categorical_samplers.jl"))
include(joinpath("samplers", "gamma_sampler.jl"))
# Univariate distributions
include(joinpath("univariate", "arcsine.jl"))
......
# Sampler for drawing random number of a Gamma distribution
# Routines for sampling from Gamma distribution
# A simple method for generating gamma variables - Marsaglia and Tsang (2000)
# http://www.cparity.com/projects/AcmClassification/samples/358414.pdf
# Page 369
# basic simulation loop for pre-computed d and c
const _v13 = 1.0 / 3.0
# a sampler for Gamma(α)
immutable GammaSampler
α::Float64
::Float64 # iα = α > 1 ? 1.0 : 1.0 / α
d::Float64
c::Float64
function GammaSampler(α::Float64)
local ::Float64, d::Float64
if α > 1.0
= 1.0
d = α - _v13
else
= 1.0 / α
d = α + 1.0 - _v13
end
new(α, , d, 1.0 / sqrt(9.0 * d))
end
end
function rand(s::GammaSampler)
d::Float64 = s.d
c::Float64 = s.c
v = 0.0
while true
x = randn()
v = 1.0 + c * x
while v <= 0.0
x = randn()
v = 1.0 + c * x
end
v *= (v * v)
u = rand()
x2 = x^2
if u < 1.0 - 0.331 * x2^2
break
end
if log(u) < 0.5 * x2 + d * (1.0 - v + log(v))
break
end
end
v *= d
if s.α <= 1.0
v *= (rand()^s.)
end
return v::Float64
end
randg(α::Float64) = rand(GammaSampler(α))
randg(α::Real) = rand(float64(α))
......@@ -45,49 +45,15 @@ function rand(d::Beta)
end
function rand!(d::Beta, A::Array{Float64})
α = d.alpha
β = d.beta
sa = GammaSampler(d.alpha)
sb = GammaSampler(d.beta)
da = (α <= 1.0 ? α + 1.0 : α) - 1.0 / 3.0
ca = 1.0 / sqrt(9.0 * da)
db = (β <= 1.0 ? β + 1.0 : β) - 1.0 / 3.0
cb = 1.0 / sqrt(9.0 * db)
n = length(A)
if α > 1.0
if β > 1.0
for i = 1:n
u = randg2(da, ca)
v = randg2(db, cb)
for i = 1:length(A)
u = rand(sa)
v = rand(sb)
@inbounds A[i] = u / (u + v)
end
else
invβ = 1.0 / β
for i = 1:n
u = randg2(da, ca)
v = randg2(db, cb) * (rand()^invβ)
@inbounds A[i] = u / (u + v)
end
end
else
invα = 1.0 / α
if β > 1.0
for i = 1:n
u = randg2(da, ca) * (rand()^invα)
v = randg2(db, cb)
@inbounds A[i] = u / (u + v)
end
else
invβ = 1.0 / β
for i = 1:n
u = randg2(da, ca) * (rand()^invα)
v = randg2(db, cb) * (rand()^invβ)
@inbounds A[i] = u / (u + v)
end
end
end
return A
end
......
......@@ -43,21 +43,16 @@ end
function rand!(d::Chisq, A::Array{Float64})
if d.df == 1
for i in 1:length(A)
A[i] = randn()^2
for i = 1:length(A)
@inbounds A[i] = randn()^2
end
return A
end
if d.df >= 2
dpar = d.df / 2.0 - 1.0 / 3.0
else
error("require degrees of freedom df >= 2")
s = GammaSampler(d.df / 2.0)
for i = 1:length(A)
@inbounds A[i] = 2.0 * rand(s)
end
cpar = 1.0 / sqrt(9.0 * dpar)
for i in 1:length(A)
A[i] = 2.0 * randg2(dpar, cpar)
end
A
return A
end
skewness(d::Chisq) = sqrt(8.0 / d.df)
......
......@@ -43,18 +43,9 @@ modes(d::Gamma) = [mode(d)]
rand(d::Gamma) = d.scale * randg(d.shape)
function rand!(d::Gamma, A::Array{Float64})
α = d.shape
dpar = (α <= 1.0 ? α + 1.0 : α) - 1.0 / 3.0
cpar = 1.0 / sqrt(9.0 * dpar)
n = length(A)
for i in 1:n
A[i] = randg2(dpar, cpar)
end
if α <= 1.0
ainv = 1.0 / α
for i in 1:n
A[i] *= rand()^ainv
end
s = GammaSampler(d.shape)
for i = 1:length(A)
A[i] = rand(s)
end
multiply!(A, d.scale)
end
......
......@@ -32,42 +32,6 @@ function isprobvec(p::Vector{Float64})
return abs(s - 1.0) <= 1.0e-12
end
# Routines for sampling from Gamma distribution
#
# The reason why these functions are in utils.jl instead of gamma.jl is:
# other distributions (e.g. Beta & Dirichlet) also use them
#
# A simple method for generating gamma variables - Marsaglia and Tsang (2000)
# http://www.cparity.com/projects/AcmClassification/samples/358414.pdf
# Page 369
# basic simulation loop for pre-computed d and c
function randg2(d::Float64, c::Float64)
while true
x = v = 0.0
while v <= 0.0
x = randn()
v = 1.0 + c * x
end
v = v^3
U = rand()
x2 = x^2
if U < 1.0 - 0.331 * x2^2 ||
log(U) < 0.5 * x2 + d * (1.0 - v + log(v))
return d * v
end
end
end
# sampling from Gamma(α, 1)
function randg(α::Float64)
dpar = (α <= 1.0 ? α + 1.0 : α) - 1.0 / 3.0
cpar = 1.0 / sqrt(9.0 * dpar)
randg2(dpar, cpar) * (α > 1.0 ? 1.0 : rand()^(1.0 / α))
end
# macros for generating functions for support handling
#
# Both lb & ub should be compile-time constants
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment