Ce serveur Gitlab sera éteint le 30 juin 2020, pensez à migrer vos projets vers les serveurs gitlab-research.centralesupelec.fr et gitlab-student.centralesupelec.fr !

Commit c9889458 authored by Dahua Lin's avatar Dahua Lin

New implementation of Wishart

parent 9dca9376
......@@ -14,3 +14,9 @@ end
@Base.deprecate logpmf logpdf
@Base.deprecate logpmf! logpmf!
@Base.deprecate pmf pdf
#### Deprecate on 0.6 (to be removed on 0.7)
@Base.deprecate expected_logdet meanlogdet
......@@ -79,25 +79,3 @@ function rand!(IW::InverseWishart, X::Array{Matrix{Float64}})
end
var(IW::InverseWishart) = error("Not yet implemented")
# because X == X' keeps failing due to floating point nonsense
function isApproxSymmmetric(a::Matrix{Float64})
tmp = true
for j in 2:size(a, 1)
for i in 1:(j - 1)
tmp &= abs(a[i, j] - a[j, i]) < 1e-8
end
end
return tmp
end
# because isposdef keeps giving the wrong answer for samples
# from Wishart and InverseWisharts
hasCholesky(a::Matrix{Float64}) = isa(trycholfact(a), Cholesky)
function trycholfact(a::Matrix{Float64})
try cholfact(a)
catch e
return e
end
end
##############################################################################
# Wishart distribution
#
# Wishart Distribution
# following the Wikipedia parameterization
#
# Parameters nu and S such that E(X) = nu * S
# See the rwish and dwish implementation in R's MCMCPack
# This parametrization differs from Bernardo & Smith p 435
# in this way: (nu, S) = (2.0 * alpha, 0.5 * beta^-1)
#
##############################################################################
immutable Wishart <: ContinuousMatrixDistribution
nu::Float64
Schol::Cholesky{Float64}
function Wishart(n::Real, Sc::Cholesky{Float64})
if n > size(Sc, 1) - 1
new(float64(n), Sc)
else
error("Wishart parameters must be df > p - 1")
end
end
immutable Wishart{ST<:AbstractPDMat} <: ContinuousMatrixDistribution
df::Float64 # degree of freedom
S::ST # the scale matrix
c0::Float64 # the logarithm of normalizing constant in pdf
end
Wishart(nu::Real, S::Matrix{Float64}) = Wishart(nu, cholfact(S))
#### Constructors
show(io::IO, d::Wishart) = show_multline(io, d, [(:nu, d.nu), (:S, full(d.Schol))])
function Wishart{ST<:AbstractPDMat}(df::Real, S::ST)
p = dim(S)
df > p - 1 || error("df should be greater than dim - 1.")
Wishart{ST}(df, S, _wishart_c0(df, S))
end
Wishart(df::Real, S::Matrix{Float64}) = Wishart(df, PDMat(S))
dim(W::Wishart) = size(W.Schol, 1)
size(W::Wishart) = size(W.Schol)
Wishart(df::Real, S::Cholesky) = Wishart(df, PDMat(S))
function insupport(W::Wishart, X::Matrix{Float64})
return size(X) == size(W) && isApproxSymmmetric(X) && hasCholesky(X)
end
# This just checks if X could come from any Wishart
function insupport(::Type{Wishart}, X::Matrix{Float64})
return size(X, 1) == size(X, 2) && isApproxSymmmetric(X) && hasCholesky(X)
function _wishart_c0(df::Float64, S::AbstractPDMat)
h_df = df / 2
p = dim(S)
h_df * (logdet(S) + p * logtwo) + lpgamma(p, h_df)
end
mean(w::Wishart) = w.nu * (w.Schol[:U]' * w.Schol[:U])
function expected_logdet(W::Wishart)
logd = 0.
d = dim(W)
#### Properties
for i=1:d
logd += digamma(0.5 * (W.nu + 1 - i))
end
insupport(::Type{Wishart}, X::Matrix{Float64}) = isposdef(X)
insupport(d::Wishart, X::Matrix{Float64}) = size(X) == size(d) && isposdef(X)
logd += d * log(2)
logd += logdet(W.Schol)
dim(d::Wishart) = dim(d.S)
size(d::Wishart) = (p = dim(d); (p, p))
return logd
end
function lognorm(W::Wishart)
d = dim(W)
return (W.nu / 2) * logdet(W.Schol) + (d * W.nu / 2) * log(2) + lpgamma(d, W.nu / 2)
end
#### Show
show(io::IO, d::Wishart) = show_multline(io, d, [(:df, d.df), (:S, full(d.S))])
#### Statistics
mean(d::Wishart) = d.df * full(d.S)
function _logpdf{T<:Real}(W::Wishart, X::DenseMatrix{T})
Xchol = trycholfact(X)
if size(X) == size(W) && isApproxSymmmetric(X) && isa(Xchol, Cholesky)
d = dim(W)
logd = -lognorm(W)
logd += 0.5 * (W.nu - d - 1.0) * logdet(Xchol)
logd -= 0.5 * trace(W.Schol \ X)
return logd
else
return -Inf
function meanlogdet(d::Wishart)
p = dim(d)
df = d.df
v = logdet(d.S) + p * logtwo
for i = 1:p
v += digamma(0.5 * (df - (i - 1)))
end
return v
end
function rand(w::Wishart)
p = size(w.Schol, 1)
X = zeros(p, p)
for ii in 1:p
X[ii, ii] = sqrt(rand(Chisq(w.nu - ii + 1)))
end
if p > 1
for col in 2:p
for row in 1:(col - 1)
X[row, col] = randn()
end
end
end
Z = X * w.Schol[:U]
return At_mul_B(Z, Z)
function entropy(d::Wishart)
p = dim(d)
df = d.df
d.c0 - 0.5 * (df - p - 1) * meanlogdet(d) + 0.5 * df * p
end
function entropy(W::Wishart)
d = dim(W)
return lognorm(W) - (W.nu - d - 1) / 2 * expected_logdet(W) + W.nu * d / 2
#### Evaluation
function _logpdf(d::Wishart, X::DenseMatrix{Float64})
Xcf = cholfact(X)
df = d.df
p = dim(d)
0.5 * ((df - (p + 1)) * logdet(Xcf) - trace(d.S \ X)) - d.c0
end
var(w::Wishart) = error("Not yet implemented")
#### Sampling
function rand(d::Wishart)
Z = unwhiten!(d.S, _wishart_genA(dim(d), d.df))
A_mul_Bt(Z, Z)
end
function _wishart_genA(p::Int, df::Float64)
# Generate the matrix A in the Bartlett decomposition
#
# A is a lower triangular matrix, with
#
# A(i, j) ~ sqrt of Chisq(df - i + 1) when i == j
# ~ Normal() when i > j
#
A = zeros(p, p)
for i = 1:p
@inbounds A[i,i] = sqrt(rand(Chisq(df - i + 1.0)))
end
for j = 1:p-1, i = j+1:p
@inbounds A[i,j] = randn()
end
return A
end
......@@ -158,3 +158,29 @@ function simpson(f::AbstractVector{Float64}, h::Float64)
return s * h / 3.0
end
# because X == X' keeps failing due to floating point nonsense
function isApproxSymmmetric(a::Matrix{Float64})
tmp = true
for j in 2:size(a, 1)
for i in 1:(j - 1)
tmp &= abs(a[i, j] - a[j, i]) < 1e-8
end
end
return tmp
end
# because isposdef keeps giving the wrong answer for samples
# from Wishart and InverseWisharts
hasCholesky(a::Matrix{Float64}) = isa(trycholfact(a), Cholesky)
function trycholfact(a::Matrix{Float64})
try cholfact(a)
catch e
return e
end
end
......@@ -17,7 +17,6 @@ tests = [
"conjugates",
"conjugates_normal",
"conjugates_mvnormal",
"wishart",
"mixture",
"gradlogpdf"]
......
# Tests on Wishart distributions
using Distributions
using Base.Test
V = [[2. 1.], [1. 2.]]
W = Wishart(3., V)
# logdet
@test_approx_eq expected_logdet(W) 1.9441809588650447
# entropy
@test_approx_eq entropy(W) 7.178942679971454
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment