Ce serveur Gitlab sera éteint le 30 juin 2020, pensez à migrer vos projets vers les serveurs gitlab-research.centralesupelec.fr et gitlab-student.centralesupelec.fr !

Commit 4bf313fc authored by Dahua Lin

add logpdf & logpdf! for MultivariateNormal and Dirichlet for efficient batch evaluation

parent bd4827b8
......@@ -85,7 +85,9 @@ export # types
logccdf, # ccdf returning log-probability
logcdf, # cdf returning log-probability
logpdf, # log probability density
logpdf!, # evaluate log pdf to provided storage
logpmf, # log probability mass
logpmf!, # evaluate log pmf to provided storage
mean, # mean of distribution
median, # median of distribution
mgf, # moment generating function
......@@ -471,7 +473,7 @@ end
for f in (:pdf, :logpdf, :cdf, :logcdf,
:ccdf, :logccdf, :quantile, :cquantile,
:invlogcdf, :invlogccdf)
@eval begin
@eval begin
function ($f)(d::UnivariateDistribution, x::AbstractArray)
res = Array(Float64, size(x))
for i in 1:length(res)
......@@ -482,8 +484,38 @@ for f in (:pdf, :logpdf, :cdf, :logcdf,
end
end
# Evaluate the log-density of the univariate distribution `d` at every element
# of `x` and store the results in `r` (same linear index as in `x`).
#
# Throws an ArgumentError when `x` and `r` do not have the same size.
function logpdf!(d::UnivariateDistribution, x::AbstractArray, r::AbstractArray)
    size(r) == size(x) || throw(ArgumentError("Inconsistent array dimensions."))
    for idx in 1:length(x)
        r[idx] = logpdf(d, x[idx])
    end
end
# Generic batch evaluation: treat each column of `x` as one sample and return a
# Float64 vector holding the log-density of `d` at each column.
function logpdf(d::MultivariateDistribution, x::AbstractMatrix)
    ncols::Int = size(x, 2)
    out = Array(Float64, ncols)
    for col in 1:ncols
        # x[:, col] extracts the col-th sample as a vector
        out[col] = logpdf(d, x[:, col])
    end
    return out
end
# Generic in-place batch evaluation: treat each column of `x` as one sample and
# write the log-density of `d` at column j into `r[j]`.
#
# Throws an ArgumentError when `r` does not have one slot per column of `x`.
function logpdf!(d::MultivariateDistribution, x::AbstractMatrix, r::AbstractArray)
    ncols::Int = size(x, 2)
    length(r) == ncols || throw(ArgumentError("Inconsistent array dimensions."))
    for col in 1:ncols
        r[col] = logpdf(d, x[:, col])
    end
end
# For discrete distributions the "mass" functions are plain aliases that
# forward every argument to the corresponding "density" function.
function pmf(d::DiscreteDistribution, args...)
    return pdf(d, args...)
end

function logpmf(d::DiscreteDistribution, args...)
    return logpdf(d, args...)
end

function logpmf!(d::DiscreteDistribution, args...)
    return logpdf!(d, args...)
end

# Entropy of `d` divided by log(2).
function binary_entropy(d::Distribution)
    return entropy(d) / log(2)
end
......@@ -1113,10 +1145,12 @@ MultivariateNormal() = MultivariateNormal(zeros(2), eye(2))
# Mean vector, as stored in the distribution.
function mean(d::MultivariateNormal)
    return d.mean
end

# Rebuild U'U from the upper factor of the stored factorization
# (presumably the covariance matrix -- confirm against the constructor).
function var(d::MultivariateNormal)
    upper = d.covchol[:U]
    return upper' * upper
end

# Draw one sample: standard-normal noise transformed by U' and shifted by the mean.
function rand(d::MultivariateNormal)
    noise = randn(length(d.mean))
    return d.mean + d.covchol[:U]' * noise
end
function rand!(d::MultivariateNormal, X::Matrix)
k = length(mean(d))
m, n = size(X)
......@@ -1124,13 +1158,47 @@ function rand!(d::MultivariateNormal, X::Matrix)
if n == k return randn!(X) * d.covchol[:U] + d.mean'[ones(Int,m),:] end
error("Wrong dimensions")
end
# Log-density of the multivariate normal `d` at a single point `x`.
#
# With Sigma = U'U, the quadratic form u' * inv(Sigma) * u equals w'w where
# U'w = u, so a single triangular solve replaces the full linear solve.
# This is the same computation the batch `logpdf!` method performs per column.
#
# NOTE(review): assumes `d.covchol.UL` holds the *upper* triangular factor;
# confirm that the factorization is never stored in lower form.
function logpdf{T <: Real}(d::MultivariateNormal, x::Vector{T})
    k = length(d.mean)
    # `u` is a freshly allocated array, so it can be overwritten in place:
    # only the transformed version is used afterwards, never the original.
    u = x - d.mean
    # Solve U' * w = u, overwriting u with w.
    Base.LinAlg.LAPACK.trtrs!('U', 'T', 'N', d.covchol.UL, u)
    -0.5 * k * log(2.0pi) - sum(log(diag(d.covchol.UL))) - 0.5 * dot(u, u)
end
# Batch log-density of the multivariate normal `d`.
#
# Each column of `x` is one sample; the log-density of column j is written to
# `r[j]`. All columns are centred into one work matrix and transformed with a
# single triangular solve, which is what makes the batch version efficient.
#
# Throws an ArgumentError when the number of rows of `x` differs from the
# dimension of `d`, or when `r` does not have one slot per column of `x`.
#
# NOTE(review): assumes `d.covchol.UL` holds the *upper* triangular factor;
# confirm that the factorization is never stored in lower form.
function logpdf!{T <: Real}(d::MultivariateNormal, x::Matrix{T}, r::AbstractVector)
    mu::Vector{Float64} = d.mean
    k = length(mu)
    if size(x, 1) != k
        throw(ArgumentError("The dimension of x is inconsistent with d."))
    end
    n = size(x, 2)
    # Validate the output buffer up front: otherwise a short `r` fails only
    # after all the work is done, and a long `r` silently keeps stale entries.
    if length(r) != n
        throw(ArgumentError("Inconsistent array dimensions."))
    end

    # u[:,j] = x[:,j] - mu, written into a Float64 work matrix
    u = Array(Float64, k, n)
    for j = 1 : n
        for i = 1 : k
            u[i, j] = x[i, j] - mu[i]
        end
    end

    # Solve U' * W = u for every column at once, overwriting u with W.
    Base.LinAlg.LAPACK.trtrs!('U', 'T', 'N', d.covchol.UL, u)

    # Constant shared by all samples.
    c::Float64 = -0.5 * k * log(2.0pi) - sum(log(diag(d.covchol.UL)))
    for j = 1 : n
        # squared norm of the j-th transformed column
        dot_uj = 0.
        for i = 1 : k
            dot_uj += u[i,j] * u[i,j]
        end
        r[j] = c - 0.5 * dot_uj
    end
end
# Allocating batch evaluation: one log-density per column of `x`, returned as a
# freshly allocated Float64 vector filled by the in-place `logpdf!`.
function logpdf{T <: Real}(d::MultivariateNormal, x::Matrix{T})
    out = Array(Float64, size(x, 2))
    logpdf!(d, x, out)
    return out
end
# Density at a single point, obtained by exponentiating the log-density.
pdf{T <: Real}(d::MultivariateNormal, x::Vector{T}) = exp(logpdf(d, x))
function cdf{T <: Real}(d::MultivariateNormal, x::Vector{T})
k = length(d.mean)
if k > 3; error("Dimension larger than three is not supported yet"); end
......@@ -1648,12 +1716,16 @@ end
# Dirichlet distribution parameterised by the vector `alpha`.
immutable Dirichlet <: ContinuousMultivariateDistribution
    alpha::Vector{Float64}

    # Validate the parameters and store them as Float64.
    # Every element must be strictly positive: the normalising constant uses
    # lgamma(alpha[i]), which is infinite at zero.
    function Dirichlet{T <: Real}(alpha::Vector{T})
        for el in alpha
            # `!(el > 0.)` rather than `el <= 0.` so that NaN is rejected too
            if !(el > 0.)
                error("Dirichlet: elements of alpha must be positive")
            end
        end
        new(float64(alpha))
    end

    # construct a symmetric Dirichlet distribution
    Dirichlet{T <: Real}(d::Int, alpha::T) = Dirichlet(fill(alpha, d))
end

# Dirichlet of dimension `dim` with every element of alpha equal to one.
Dirichlet(dim::Integer) = Dirichlet(ones(dim))
......@@ -1683,22 +1755,36 @@ function insupport{T <: Real}(d::Dirichlet, x::Vector{T})
return true
end
# Log-density of the Dirichlet distribution `d` at a single point `x`.
#
# Raises an error when `x` is not in the support of `d`.
function logpdf{T <: Real}(d::Dirichlet, x::Vector{T})
    if !insupport(d, x)
        error("x not in the support of Dirichlet distribution")
    end
    # log of the normalising constant: sum(lgamma(alpha)) - lgamma(sum(alpha))
    b = sum(lgamma(d.alpha)) - lgamma(sum(d.alpha))
    dot((d.alpha - 1), log(x)) - b
end

# Density obtained by exponentiating the log-density. This is the only `pdf`
# method for this signature: a second definition based on
# prod(gamma(alpha)) / gamma(sum(alpha)) would merely be overwritten by it.
pdf{T <: Real}(d::Dirichlet, x::Vector{T}) = exp(logpdf(d, x))
# Batch log-density of the Dirichlet distribution `d`.
#
# Each column of `x` is one sample; the result for column j ends up in `r[j]`.
# Throws an ArgumentError when the number of rows of `x` differs from
# length(d.alpha) or when `r` does not have one slot per column of `x`.
#
# NOTE(review): unlike the single-point method, no `insupport` check is made
# on the columns of `x` -- confirm that this is intended.
function logpdf!{T <: Real}(d::Dirichlet, x::Matrix{T}, r::Vector{T})
    if size(x, 1) != length(d.alpha) || length(r) != size(x, 2)
        throw(ArgumentError("Inconsistent argument dimensions."))
    end
    ncols = size(x, 2)
    # log of the normalising constant, shared by all samples
    lognorm::Float64 = sum(lgamma(d.alpha)) - lgamma(sum(d.alpha))
    # NOTE(review): relies on the three-argument At_mul_B storing
    # log(x)' * (alpha - 1) into r -- confirm for the targeted Julia version.
    At_mul_B(r, log(x), d.alpha - 1.)
    for j in 1:ncols
        r[j] -= lognorm
    end
end
# Allocating batch evaluation: one log-density per column of `x`, returned as a
# freshly allocated Float64 vector filled by `logpdf!`.
function logpdf{T <: Real}(d::Dirichlet, x::Matrix{T})
    out = Array(Float64, size(x, 2))
    logpdf!(d, x, out)
    return out
end
function rand(d::Dirichlet)
x = [rand(Gamma(el)) for el in d.alpha]
x ./ sum(x)
......
......@@ -136,11 +136,31 @@ for ll in (LogitLink(), ProbitLink()#, CloglogLink() # edge cases for CloglogLin
end
end
# Multivariate normal

d = MultivariateNormal(zeros(2), eye(2))

# Reference values are 1/(2pi), exp(-1/2)/(2pi) and exp(-1)/(2pi).
# Integer-valued and floating-point inputs are both checked, with the same
# tolerance (note that 10e-3 would be 0.01, far too loose for these values).
@test abs(pdf(d, [0, 0]) - 0.159155) < 1.0e-5
@test abs(pdf(d, [1, 0]) - 0.0965324) < 1.0e-5
@test abs(pdf(d, [1, 1]) - 0.0585498) < 1.0e-5
@test abs(pdf(d, [0., 0.]) - 0.159155) < 1.0e-5
@test abs(pdf(d, [1., 0.]) - 0.0965324) < 1.0e-5
@test abs(pdf(d, [1., 1.]) - 0.0585498) < 1.0e-5

d = MultivariateNormal(zeros(3), [4. -2. -1.; -2. 5. -1.; -1. -1. 6.])
@test_approx_eq logpdf(d, [3., 4., 5.]) (-15.75539253001834)

# Batch evaluation must agree with evaluating each column separately.
x = [3. 4. 5.; 1. 2. 3.; -4. -3. -2.; -1. -3. -2.]'
r0 = zeros(4)
for i = 1 : 4
    r0[i] = logpdf(d, x[:,i])
end
@test_approx_eq logpdf(d, x) r0
# Dirichlet

# Batch evaluation must agree with evaluating each column separately.
d = Dirichlet([1.5, 2.0, 2.5])
x = [0.2 0.5 0.3; 0.1 0.5 0.4; 0.8 0.1 0.1; 0.05 0.15 0.8]'   # one sample per column
r0 = zeros(size(x, 2))
for i in 1:length(r0)
    r0[i] = logpdf(d, x[:, i])
end
@test_approx_eq logpdf(d, x) r0
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment