Ce serveur Gitlab sera éteint le 30 juin 2020, pensez à migrer vos projets vers les serveurs gitlab-research.centralesupelec.fr et gitlab-student.centralesupelec.fr !

Commit 33611bd1 authored by Dahua Lin

Merge pull request #302 from JuliaStats/dh/vmf2

Reimplement von Mises-Fisher Distribution
parents 929401de 5ae0432c
......@@ -141,6 +141,7 @@ export
cquantile, # complementary quantile (i.e. using prob in right hand tail)
cumulant, # cumulants of distribution
complete, # turn an incomplete formulation into a complete distribution
concentration, # the concentration parameter
dim, # sample dimension of multivariate distribution
entropy, # entropy of distribution in nats
fit, # fit a distribution to data (using default method)
......@@ -183,6 +184,7 @@ export
sqmahal, # squared Mahalanobis distance to Gaussian center
sqmahal!, # inplace evaluation of sqmahal
mean, # mean of distribution
meandir, # mean direction (of a spherical distribution)
meanform, # convert a normal distribution from canonical form to mean form
median, # median of distribution
mgf, # moment generating function
......
# Von-Mises Fisher: a multivariate distribution useful in directional statistics
# Useful notes:
# http://www.mitsuba-renderer.org/~wenzel/vmf.pdf
# Some of the code adapted from http://www.unc.edu/~sungkyu/manifolds/randvonMisesFisherm.m
# as well as the movMF R package.
# von Mises-Fisher distribution is useful for directional statistics
#
# The implementation here follows:
#
# - Wikipedia:
# http://en.wikipedia.org/wiki/Von_Mises–Fisher_distribution
#
# - R's movMF package's document:
# http://cran.r-project.org/web/packages/movMF/vignettes/movMF.pdf
#
# - Wenzel Jakob's notes:
# http://www.mitsuba-renderer.org/~wenzel/files/vmf.pdf
#
immutable VonMisesFisher <: ContinuousMultivariateDistribution
mu::Vector{Float64}
kappa::Float64
μ::Vector{Float64}
κ::Float64
logCκ::Float64
function VonMisesFisher{T <: Real}(mu::Vector{T}, kappa::Float64)
mu = mu ./ norm(mu)
if kappa < 0
throw(ArgumentError("kappa must be a nonnegative real number."))
function VonMisesFisher(μ::Vector{Float64}, κ::Float64; checknorm::Bool=true)
if checknorm
isunitvec(μ) || error("μ must be a unit vector")
end
new(float64(mu), kappa)
κ > 0 || error("κ must be positive.")
new(μ, κ, vmflck(length(μ), κ))
end
end
length(d::VonMisesFisher) = length(d.mu)
mean(d::VonMisesFisher) = d.mu
scale(d::VonMisesFisher) = d.kappa
VonMisesFisher{T<:Real}(μ::Vector{T}, κ::Real) = VonMisesFisher(float64(μ), float64(κ))
# Support check: `x` is in the support when it lies on the unit sphere, i.e. its
# Euclidean norm equals 1 up to a tolerance of 1e-8.
# The norm (not the sum of the entries) is what has to be compared with 1: a sum test
# accepts non-unit vectors such as [0.5, 0.5] and rejects unit vectors such as [0.6, 0.8].
insupport{T<:Real}(d::VonMisesFisher, x::AbstractVector{T}) = abs(norm(x) - 1.) < 1e-8
VonMisesFisher(θ::Vector{Float64}) = (κ = vecnorm(θ); VonMisesFisher(scale(θ, 1.0 / κ), κ))
VonMisesFisher{T<:Real}(θ::Vector{T}) = VonMisesFisher(float64(θ))
# Log-density of the von Mises-Fisher distribution `d` evaluated at `x`.
#
# Arguments:
#   d      -- the distribution; `d.mu` and `d.kappa` are read
#   x      -- evaluation point; NOTE(review): assumed to be a unit vector, not checked here
#   stable -- only relevant when length(d) == 3: use the form of the normalizer that
#             never evaluates exp(kappa) (as suggested in Wenzel Jakob's notes,
#             http://www.mitsuba-renderer.org/~wenzel/vmf.pdf)
#
# Returns kappa * dot(mu, x) + log C_p(kappa), where p = length(d).
function _logpdf{T<:Real}(d::VonMisesFisher, x::DenseVector{T}; stable=true)
    p = length(d)
    kappa = d.kappa

    if abs(kappa - 0.0) < eps()
        # kappa == 0 is the uniform distribution on the unit sphere in R^p; its
        # log-density is minus the log surface area, 2 * pi^(p/2) / gamma(p/2).
        # For p == 3 this is log(0.25 / pi).
        return lgamma(p/2) - log(2) - (p/2) * log(pi)
    end

    if p == 3
        if stable
            # log(kappa / (2*pi*(exp(kappa) - exp(-kappa)))) with exp(kappa) factored out
            logCpk = log(kappa) - log(2*pi) - kappa - log(1 - exp(-2*kappa))
        else
            # closed form as described on Wikipedia
            logCpk = log(kappa) - log(2 * pi * (exp(kappa) - exp(-kappa)))
        end
    else
        # general case: the normalizer involves the *modified* Bessel function of
        # the first kind (besseli), not besselj
        logCpk = (p/2 - 1) * log(kappa) - (p/2) * log(2*pi) - log(besseli(p/2 - 1, kappa))
    end

    return kappa * dot(d.mu, x) + logCpk
end
show(io::IO, d::VonMisesFisher) = show(io, d, (:μ, :κ))
### Basic properties
# sampling (TODO: make it consistent with the common API)
length(d::VonMisesFisher) = length(d.μ)
function rand(d::VonMisesFisher, n::Int)
randvonMisesFisher(n, d.kappa, d.mu)
meandir(d::VonMisesFisher) = d.μ
concentration(d::VonMisesFisher) = d.κ
insupport{T<:Real}(d::VonMisesFisher, x::DenseVector{T}) = isunitvec(x)
### Evaluation
# Log normalizing constant of the von Mises-Fisher density for a general dimension p:
#
#   (p/2 - 1) * log(κ) - (p/2) * log(2π) - log(besseli(p/2 - 1, κ))
#
function _vmflck(p, κ)
    half_p = 0.5 * p
    order = half_p - 1.0
    log_bessel = log(besseli(order, κ))
    return order * log(κ) - half_p * log(2π) - log_bessel
end
# Log normalizing constant specialised to dimension 3; this form needs no Bessel
# function. (`log1mexp(x)` is presumably log(1 - exp(x)) -- defined outside this file.)
function _vmflck3(κ)
    return log(κ) - log2π - κ - log1mexp(-2.0 * κ)
end

# Log normalizing constant for dimension p: picks the closed form when p == 3 and the
# Bessel-based formula otherwise. The result is asserted to be a Float64.
function vmflck(p, κ)
    if p == 3
        return _vmflck3(κ)::Float64
    else
        return _vmflck(p, κ)::Float64
    end
end
function randvonMisesFisher(n, kappa, mu)
m = length(mu)
w = rW(n, kappa, m)
v = rand(MvNormal(zeros(m-1), eye(m-1)), n)
# normalize each column of v
for j = 1:n
s = 0.
vj = view(v,:,j)
for i = 1:size(v,1)
s += abs2(vj[i])
end
s = sqrt(s)
for i = 1:size(v,1)
vj[i] /= s
end
end
v = v'
# Log-density at `x`: the cached log normalizer `d.logCκ` plus κ times the inner
# product of the mean direction and `x`.
function _logpdf{T<:Real}(d::VonMisesFisher, x::DenseVector{T})
    inner = dot(d.μ, x)
    return d.logCκ + d.κ * inner
end
### Sampling
# Build a VonMisesFisherSampler from the distribution's mean direction and concentration.
function sampler(d::VonMisesFisher)
    return VonMisesFisherSampler(d.μ, d.κ)
end

# In-place sampling into a vector: construct a sampler and delegate to it.
function _rand!(d::VonMisesFisher, x::DenseVector)
    spl = sampler(d)
    return _rand!(spl, x)
end

# In-place sampling into a matrix: construct one sampler and delegate to it.
function _rand!(d::VonMisesFisher, x::DenseMatrix)
    spl = sampler(d)
    return _rand!(spl, x)
end
### Estimation
r = sqrt(1.0 .- w .^ 2)
for j = 1:size(v,2) v[:,j] = v[:,j] .* r; end
x = hcat(v, w)
mu = mu / norm(mu)
return rotMat(mu)'*x'
# Fit a von Mises-Fisher distribution to the columns of `X` (one sample per column).
#
# The mean direction is the normalized sum of the columns; the concentration is
# obtained from `_vmf_estkappa` using the length of the averaged sum.
function fit_mle(::Type{VonMisesFisher}, X::Matrix{Float64})
    nsamples = size(X, 2)
    resultant = vec(sum(X, 2))
    rlen = vecnorm(resultant)
    ρ = rlen / nsamples

    # `scale!` rescales `resultant` in place; μ refers to that same vector
    μ = scale!(resultant, 1.0 / rlen)
    κ = _vmf_estkappa(length(μ), ρ)
    return VonMisesFisher(μ, κ)
end
# Randomly sample W
function rW(n, kappa, m)
y = zeros(n)
l = kappa;
d = m - 1;
b = (- 2. * l + sqrt(4. * l * l + d * d)) / d;
x = (1. - b) / (1. + b);
c = l * x + d * log(1. - x * x);
w = 0
for i=1:n
done = false
while !done
z = rand(Beta(d / 2., d / 2.))
w = (1. - (1. + b) * z) / (1. - (1. - b) * z);
u = rand()
if l * w + d * log(1. - x * w) - c >= log(u)
done = true
end
# Convenience method for any real element type: convert the data with `float64`
# and forward to the Matrix{Float64} method.
function fit_mle{T<:Real}(::Type{VonMisesFisher}, X::Matrix{T})
    return fit_mle(VonMisesFisher, float64(X))
end
# Estimate the concentration κ of a p-dimensional von Mises-Fisher distribution
# from ρ, the length of the averaged sample sum.
#
# Uses the fixed-point iteration algorithm in the following paper:
#
#   Akihiro Tanabe, Kenji Fukumizu, Shigeyuki Oba, Takashi Takenouchi, and Shin Ishii
#   Parameter estimation for von Mises-Fisher distributions.
#   Computational Statistics, 2007, Vol. 22:145-157.
#
# Starting from κ = ρ (p - ρ^2) / (1 - ρ^2), κ is repeatedly multiplied by
# ρ / A(κ) (see `_vmfA`) until that factor is within 1e-12 of 1, or until 200
# iterations have been performed. Returns the final κ.
function _vmf_estkappa(p::Int, ρ::Float64)
    maxiter = 200
    half_p = 0.5 * p

    ρ2 = abs2(ρ)
    κ = ρ * (p - ρ2) / (1 - ρ2)

    for iter = 1:maxiter
        a = ρ / _vmfA(half_p, κ)
        κ *= a
        if abs(a - 1.0) < 1.0e-12
            break
        end
    end
    return κ
end
# Rotation helper function
# Build a d-by-d matrix from the direction `b`, where d = length(b).
#
# After normalizing `b`:
#   a     -- the last coordinate axis [0, ..., 0, 1] (as a d-by-1 matrix)
#   alpha -- the angle between `a` and `b`
#   c     -- the unit vector along the part of `b` that is orthogonal to `a`
# and the result is
#   I + sin(alpha) * (a*c' - c*a') + (cos(alpha) - 1) * (a*a' + c*c')
#
# NOTE(review): this looks like a rotation acting in the plane spanned by `a` and `b`;
# confirm the intended direction against the caller before reusing it.
# NOTE(review): if `b` is parallel to `a`, `c` is the zero vector before normalization
# and `c / norm(c)` yields NaN -- TODO confirm callers never pass such a `b`.
function rotMat(b)
d = length(b)
b= b/norm(b)
a = [zeros(d-1,1); 1]
# a'*b is a one-element array here, hence the [1]
alpha = acos(a'*b)[1]
# remove the component of b along a, then normalize what is left
c = b - a * (a'*b); c = c / norm(c)
A = a*c' - c*a'
return eye(d) + sin(alpha)*A + (cos(alpha) - 1)*(a*a' +c*c')
end
# Ratio besseli(p/2, κ) / besseli(p/2 - 1, κ); the first argument is already p/2.
function _vmfA(half_p::Float64, κ::Float64)
    numer = besseli(half_p, κ)
    denom = besseli(half_p - 1.0, κ)
    return numer / denom
end
# Each row of x assumed to be ~ VonMisesFisher(mu, kappa)
# MLE notes from: http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.186.1887&rep=rep1&type=pdf
# Fit a von Mises-Fisher distribution to the rows of `x` (one sample per row).
#
# The mean direction is the normalized column-wise sum of `x`; the concentration is
# the closed-form approximation rbar * (p - rbar^2) / (1 - rbar^2), with rbar the
# length of the averaged sum (Eqn. 4 of the MLE notes referenced for this method).
function fit_mle(::Type{VonMisesFisher}, x::Matrix{Float64})
    nobs, p = size(x)
    colsum = sum(x, 1)
    len = norm(colsum)

    mu = colsum[:] / len
    rbar = len / nobs
    kappa0 = rbar * (p - rbar^2) / (1 - rbar^2)

    # TODO: Include a few Newton steps to get a better approximation.
    # A(p,kappa) = besselj(p/2, kappa) / besselj(p/2-1, kappa)
    # apk0 = A(p,kappa0)
    # kappa1 = kappa0 + (apk0 - rbar) / (1 - apk0^2 - (p-1)*apk0/kappa0)
    # apk1 = A(p,kappa1)
    # kappa2 = kappa1 + (apk1 - rbar) / (1 - apk1^2 - (p-1)*apk1/kappa1)
    return VonMisesFisher(mu, kappa0)
end
......@@ -6,7 +6,8 @@ for fname in ["categorical.jl",
"exponential.jl",
"gamma.jl",
"multinomial.jl",
"vonmises.jl"]
"vonmises.jl",
"vonmisesfisher.jl"]
include(joinpath("samplers", fname))
end
# Sampler for von Mises-Fisher
#
# Holds the dimension and concentration together with quantities derived from them
# (and from the mean direction) by the outer constructor
# VonMisesFisherSampler(μ, κ), so that they can be reused across draws.
immutable VonMisesFisherSampler
p::Int # the dimension
κ::Float64 # the concentration
b::Float64 # value of _vmf_bval(p, κ)
x0::Float64 # (1 - b) / (1 + b)
c::Float64 # κ * x0 + (p - 1) * log1p(-x0^2)
Q::Matrix{Float64} # matrix from _vmf_rotmat(μ); _rand! multiplies each draw by it
end
# Construct a sampler for mean direction μ and concentration κ, precomputing the
# constants used by the rejection step (`b`, `x0`, `c`) and the matrix `Q` that maps
# a draw from the canonical frame to the frame of μ.
function VonMisesFisherSampler(μ::Vector{Float64}, κ::Float64)
    dim = length(μ)
    bval = _vmf_bval(dim, κ)
    x0 = (1.0 - bval) / (1.0 + bval)
    logterm = (dim - 1) * log1p(-abs2(x0))
    cval = κ * x0 + logterm
    return VonMisesFisherSampler(dim, κ, bval, x0, cval, _vmf_rotmat(μ))
end
# Draw one sample into `x`, using `t` as a scratch vector of the same length.
#
# `t` is filled with the draw in the canonical frame: t[1] is the value produced by
# `_vmf_genw`, and t[2:p] are standard normal variates rescaled so that `t` has unit
# norm. `x` then receives spl.Q * t. Returns `x`.
function _rand!(spl::VonMisesFisherSampler, x::DenseVector, t::DenseVector)
    dim = spl.p
    w = _vmf_genw(spl)
    t[1] = w

    # Gaussian entries for the remaining coordinates, accumulating their squared norm
    sumsq = 0.0
    for i = 2:dim
        g = randn()
        t[i] = g
        sumsq += abs2(g)
    end

    # rescale t[2:p] so that w^2 + |t[2:p]|^2 == 1
    factor = sqrt((1.0 - abs2(w)) / sumsq)
    for i = 2:dim
        t[i] *= factor
    end

    # move from the canonical frame to the frame of the mean direction
    A_mul_B!(x, spl.Q, t)
    return x
end
# Draw one sample into `x`, allocating the scratch vector internally.
function _rand!(spl::VonMisesFisherSampler, x::DenseVector)
    workspace = Array(Float64, length(x))
    return _rand!(spl, x, workspace)
end
# Fill every column of `x` with an independent draw, sharing one scratch vector
# across all columns. Returns `x`.
function _rand!(spl::VonMisesFisherSampler, x::DenseMatrix)
    nrows, ncols = size(x)
    workspace = Array(Float64, nrows)
    j = 1
    while j <= ncols
        _rand!(spl, view(x, :, j), workspace)
        j += 1
    end
    return x
end
### Core computation
# The constant b used by the sampler:
#   (p - 1) / (2κ + sqrt(4κ^2 + (p - 1)^2))
function _vmf_bval(p::Int, κ::Real)
    dm1 = p - 1
    denom = 2.0 * κ + sqrt(4 * abs2(κ) + abs2(dm1))
    return dm1 / denom
end
# Generate the W value -- the key step in simulating vMF (following movMF's document).
#
# Rejection sampling: draw z from Beta((p-1)/2, (p-1)/2), map it to
#   w = (1 - (1 + b) z) / (1 - (1 - b) z),
# and keep w unless  κ w + (p - 1) log(1 - x0 w) - c  is below the log of a fresh
# uniform variate, in which case a new z is drawn.
function _vmf_genw(p, b, x0, c, κ)
    dm1 = p - 1
    half = dm1 / 2.0
    proposal = Beta(half, half)

    while true
        z = rand(proposal)
        w = (1.0 - (1.0 + b) * z) / (1.0 - (1.0 - b) * z)
        score = κ * w + dm1 * log(1 - x0 * w) - c
        # written as a negated `<` so the acceptance rule is exactly the complement
        # of the rejection test
        if !(score < log(rand()))
            return w::Float64
        end
    end
end
# Generate W using the constants stored in the sampler.
function _vmf_genw(s::VonMisesFisherSampler)
    return _vmf_genw(s.p, s.b, s.x0, s.c, s.κ)
end
# Construct an orthogonal matrix Q such that Q * [1,0,...,0]^T == u.
#
# Argument:
#   u -- the target direction; NOTE(review): assumed to be a unit vector (the
#        constructor of the distribution checks this by default) -- not verified here
#
# Strategy: construct a full-rank matrix with first column being u, and then
# perform QR factorization. The first column of the resulting Q is u up to sign;
# the sign is corrected at the end.
function _vmf_rotmat(u::Vector{Float64})
    p = length(u)
    A = zeros(p, p)
    copy!(view(A,:,1), u)

    # let k be the index of the entry with max abs
    k = 1
    a = abs(u[1])
    for i = 2:p
        @inbounds ai = abs(u[i])
        if ai > a
            k = i
            a = ai
        end
    end

    # other columns of A will be filled with indicator vectors, except the one
    # that activates the k-th entry; since u[k] != 0, the columns are then
    # linearly independent
    i = 1
    for j = 2:p
        if i == k
            i += 1
        end
        A[i, j] = 1.0
        # advance to the next row so that each filler column is a *different*
        # indicator vector (otherwise A is rank-deficient for p >= 3)
        i += 1
    end

    # perform QR factorization
    Q = full(qrfact!(A)[:Q])
    if dot(view(Q,:,1), u) < 0.0  # the first column was negated
        for i = 1:p
            @inbounds Q[i,1] = -Q[i,1]
        end
    end
    return Q
end
......@@ -23,6 +23,8 @@ convert{T}(::Type{Vector{T}}, v::ZeroVector{T}) = full(v)
type NoArgCheck end
# Return true when the Euclidean norm of `v` is within 1e-12 of 1.
# The absolute deviation is tested so that vectors that are too short (including
# the zero vector) are rejected as well as vectors that are too long.
isunitvec{T}(v::AbstractVector{T}) = abs(vecnorm(v) - 1.0) < 1.0e-12
function allfinite{T<:Real}(x::Array{T})
for i = 1 : length(x)
if !(isfinite(x[i]))
......
......@@ -3,49 +3,73 @@
using Distributions
using Base.Test
D = 3
mu = randn(D)
mu = mu / norm(mu)
kappa = 100.0
d = VonMisesFisher(mu, kappa)
# Basics
@test length(d) == D
@test d.kappa == kappa
@test_approx_eq d.mu mean(d)
@test_approx_eq norm(d.mu) 1.0
# MLE
x = rand(d, 10_000)
dmle = fit_mle(VonMisesFisher, x')
@test all(abs(mean(d) - mean(dmle)) .< .01)
@test_approx_eq norm(dmle.mu) 1.0
#@test abs(scale(dmle) - scale(d)) < .01 * scale(d) # within 1%? not always...
# Density
# TODO: Check against R's movMF. (Currently I'm a bit suspicious about their code.)
# > set.seed(1)
# > mu=c(1,0,0)
# > kappa=1.
# > x = rmovMF(1, mu, kappa)
# > x
# [,1] [,2] [,3]
# [1,] 0.1772372 -0.4566632 0.871806
# > dmovMF(x, mu, kappa)
# [1] 1.015923
# WEIRD:
# > dmovMF(x, mu, 100)
# [1] 1.015923
# > dmovMF(x, mu, 1)
# [1] 1.015923
# mu = [1.0, 0., 0.]
# kappa = 1.0
# x = [0.1772372, -0.4566632, 0.871806]
# d = VonMisesFisher(mu, kappa)
#@test abs(logpdf(d, x) - 1.015923) < .00001
# Reference implementation of the normalizing constant, evaluated directly (not in
# the log domain):  κ^(p/2 - 1) / ((2π)^(p/2) * besseli(p/2 - 1, κ))
function vmfCp(p::Int, κ::Float64)
    order = p/2 - 1
    numer = κ ^ order
    denom = (2π)^(p/2) * besseli(order, κ)
    return numer / denom
end

# Reference density at `x`, used to cross-check `logpdf`.
function safe_vmfpdf(μ::Vector, κ::Float64, x::Vector)
    return vmfCp(length(μ), κ) * exp(κ * dot(μ, x))
end
# Generate a p-by-n matrix of test points: standard normal entries with every
# column rescaled to unit norm.
function gen_vmf_tdata(n::Int, p::Int)
    X = randn(p, n)
    for j = 1:n
        col = X[:, j]
        X[:, j] = col ./ vecnorm(col)
    end
    return X
end
# Exercise VonMisesFisher for dimension p and concentration κ.
#
#   n  -- number of points used for the density and sampling checks
#   ns -- number of samples drawn for the maximum-likelihood check
#
# Checks, in order: accessors, construction from θ = κ μ, the cached log normalizer,
# logpdf against the reference density, unit norm of samples, and recovery of
# (μ, κ) by fit_mle.
function test_vonmisesfisher(p::Int, κ::Float64, n::Int, ns::Int)
    # random unit mean direction
    μ = randn(p)
    μ = μ ./ vecnorm(μ)

    d = VonMisesFisher(μ, κ)
    @test length(d) == p
    @test meandir(d) == μ
    @test concentration(d) == κ

    # construction from the single vector θ = κ μ
    d_theta = VonMisesFisher(κ * μ)
    @test length(d_theta) == p
    @test_approx_eq meandir(d_theta) μ
    @test_approx_eq concentration(d_theta) κ

    @test_approx_eq_eps d.logCκ log(vmfCp(p, κ)) 1.0e-12

    # log-density: pointwise and batched, against the reference implementation
    pts = gen_vmf_tdata(n, p)
    expected = zeros(n)
    for i = 1:n
        pt = pts[:, i]
        expected[i] = log(safe_vmfpdf(μ, κ, pt))
        @test_approx_eq logpdf(d, pt) expected[i]
    end
    @test_approx_eq logpdf(d, pts) expected

    # sampling: every draw must be a unit vector
    single = rand(d)
    @test_approx_eq vecnorm(single) 1.0

    draws = rand(d, n)
    for i = 1:n
        @test_approx_eq vecnorm(draws[:, i]) 1.0
    end

    # MLE: parameters recovered to within 0.01 (μ) and 1% (κ)
    fitdata = rand(d, ns)
    d_est = fit_mle(VonMisesFisher, fitdata)
    @test isa(d_est, VonMisesFisher)
    @test_approx_eq_eps d_est.μ μ 0.01
    @test_approx_eq_eps d_est.κ κ κ * 0.01
end
## General testing
# number of evaluation points and number of samples for the MLE check
n = 1000
ns = 10^6

# run the test for each (dimension, concentration) pair
for case in [(2, 1.0), (2, 5.0), (3, 1.0), (3, 5.0), (5, 2.0)]
    test_vonmisesfisher(case[1], case[2], n, ns)
end
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment