Ce serveur Gitlab sera éteint le 30 juin 2020, pensez à migrer vos projets vers les serveurs gitlab-research.centralesupelec.fr et gitlab-student.centralesupelec.fr !

Commit b1acfe3c authored by Chris DuBois's avatar Chris DuBois

Add VonMisesFisher, a ContinuousMultivariateDistribution for unit vectors.

parent cfcd8dc5
......@@ -13,7 +13,8 @@ tests = [
"conjugates",
"kolmogorov",
"edgeworth",
"matrix"]
"matrix",
"vonmisesfisher"]
println("Running tests:")
......
......@@ -78,6 +78,7 @@ export
Triangular,
Truncated,
Uniform,
VonMisesFisher,
Weibull,
Wishart,
QQPair,
......@@ -240,6 +241,7 @@ include(joinpath("univariate", "weibull.jl"))
include(joinpath("multivariate", "dirichlet.jl"))
include(joinpath("multivariate", "multinomial.jl"))
include(joinpath("multivariate", "multivariatenormal.jl"))
include(joinpath("multivariate", "vonmisesfisher.jl"))
# Matrix distributions
include(joinpath("matrix", "inversewishart.jl"))
......
# Von-Mises Fisher: a multivariate distribution useful in directional statistics
# Useful notes:
# http://www.mitsuba-renderer.org/~wenzel/vmf.pdf
# Some of the code adapted from http://www.unc.edu/~sungkyu/manifolds/randvonMisesFisherm.m
# as well as the movMF R package.
immutable VonMisesFisher <: ContinuousMultivariateDistribution
mu::Vector{Float64}
kappa::Float64
function VonMisesFisher{T <: Real}(mu::Vector{T}, kappa::Float64)
mu = mu / norm(mu)
if kappa < 0
throw(ArgumentError("kappa must be a nonnegative real number."))
end
new(float64(mu), kappa)
end
end
dim(d::VonMisesFisher) = length(d.mu)
mean(d::VonMisesFisher) = d.mu
scale(d::VonMisesFisher) = d.kappa
insupport{T <: Real}(d::VonMisesFisher, x::Vector{T}) = abs(sum(x) - 1.) < 1e-8
function rand(d::VonMisesFisher, n::Int64)
randvonMisesFisher(n, d.kappa, d.mu)
end
function logpdf(d::VonMisesFisher, x::Vector{Float64}; stable=true)
if abs(d.kappa - 0.0) < eps() return 1/4/pi; end
if stable
# As suggested by Wenzel Jakob: http://www.mitsuba-renderer.org/~wenzel/vmf.pdf
return d.kappa * dot(d.mu, x) - d.kappa + log(d.kappa) - log(2*pi) - log(1-exp(-2*d.kappa))
else
# As described on Wikipedia
p = dim(d)
logCpk = 0.0
if p == 3
logCpk = log(d.kappa) - log(2 * pi * (exp(kappa) - exp(-kappa)))
else
logCpk = (p/2 - 1) * log(d.kappa) - (p/2) * log(2*pi) - log(besselj(p/2-1, d.kappa))
end
return d.kappa * dot(d.mu, x) + logCpk
end
end
# Helper functions
# Sample n vectors x ~ VonMisesFisher(mu, kappa)
function randvonMisesFisher(n, kappa, mu)
m = length(mu)
w = rW(n, kappa, m)
v = rand(MvNormal(zeros(m-1), eye(m-1)), n)
v = normalize(v',2,2)
r = sqrt(1 - w .^ 2)
for j = 1:size(v,2) v[:,j] = v[:,j] .* r; end
x = hcat(v, w)
mu = mu / norm(mu)
return rotMat(mu)'*x'
end
# Randomly sample W
function rW(n, kappa, m)
y = zeros(n)
l = kappa;
d = m - 1;
b = (- 2. * l + sqrt(4. * l * l + d * d)) / d;
x = (1. - b) / (1. + b);
c = l * x + d * log(1. - x * x);
w = 0
for i=1:n
done = false
while !done
z = rand(Beta(d / 2., d / 2.))
w = (1. - (1. + b) * z) / (1. - (1. - b) * z);
u = rand()
if l * w + d * log(1. - x * w) - c >= log(u)
done = true
end
end
y[i] = w
end
return y
end
# Rotation helper function
function rotMat(b)
d = length(b)
b= b/norm(b)
a = [zeros(d-1,1); 1]
alpha = acos(a'*b)[1]
c = b - a * (a'*b); c = c / norm(c)
A = a*c' - c*a'
return eye(d) + sin(alpha)*A + (cos(alpha) - 1)*(a*a' +c*c')
end
# Each row of x assumed to be ~ VonMisesFisher(mu, kappa)
# MLE notes from: http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.186.1887&rep=rep1&type=pdf
function fit_mle(::Type{VonMisesFisher}, x::Matrix{Float64})
(n,p) = size(x)
sx = sum(x, 1)
mu = sx[:] / norm(sx)
rbar = norm(sx) / n
kappa0 = rbar * (p-rbar^2) / (1-rbar^2) # Eqn. 4
# TODO: Include a few Newton steps to get a better approximation.
# A(p,kappa) = besselj(p/2, kappa) / besselj(p/2-1, kappa)
# apk0 = A(p,kappa0)
# kappa1 = kappa0 + (apk0 - rbar) / (1 - apk0^2 - (p-1)*apk0/kappa0)
# apk1 = A(p,kappa1)
# kappa2 = kappa1 + (apk1 - rbar) / (1 - apk1^2 - (p-1)*apk1/kappa1)
return VonMisesFisher(mu, kappa0)#, kappa1, kappa2)
end
\ No newline at end of file
# Tests for Von-Mises Fisher distribution
using Distributions
using Base.Test
D = 3
mu = randn(D)
mu = mu / norm(mu)
kappa = 100.0
d = VonMisesFisher(mu, kappa)
# Basics
@test dim(d) == D
@test d.kappa == kappa
@test_approx_eq d.mu mean(d)
@test_approx_eq norm(d.mu) 1.0
# MLE
x = rand(d, 10_000)
dmle = fit_mle(VonMisesFisher, x')
@test all(abs(mean(d) - mean(dmle)) .< .01)
@test_approx_eq norm(dmle.mu) 1.0
@test abs(scale(dmle) - scale(d)) < .01 * scale(d) # within 1% ?
# Density
# TODO: Check against R's movMF. (Currently I'm a bit suspicious about their code.)
# > set.seed(1)
# > mu=c(1,0,0)
# > kappa=1.
# > x = rmovMF(1, mu, kappa)
# > x
# [,1] [,2] [,3]
# [1,] 0.1772372 -0.4566632 0.871806
# > dmovMF(x, mu, kappa)
# [1] 1.015923
# WEIRD:
# > dmovMF(x, mu, 100)
# [1] 1.015923
# > dmovMF(x, mu, 1)
# [1] 1.015923
# mu = [1.0, 0., 0.]
# kappa = 1.0
# x = [0.1772372, -0.4566632, 0.871806]
# d = VonMisesFisher(mu, kappa)
#@test abs(logpdf(d, x) - 1.015923) < .00001
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment