Ce serveur Gitlab sera éteint le 30 juin 2020, pensez à migrer vos projets vers les serveurs gitlab-research.centralesupelec.fr et gitlab-student.centralesupelec.fr !

Commit 4bf313fc by Dahua Lin

### add logpdf & logpdf! for MultivariateNormal and Dirichlet for efficient batch evaluation

parent bd4827b8
 ... ... @@ -85,7 +85,9 @@ export # types logccdf, # ccdf returning log-probability logcdf, # cdf returning log-probability logpdf, # log probability density logpdf!, # evaluate log pdf to provided storage logpmf, # log probability mass logpmf!, # evaluate log pmf to provided storage mean, # mean of distribution median, # median of distribution mgf, # moment generating function ... ... @@ -471,7 +473,7 @@ end for f in (:pdf, :logpdf, :cdf, :logcdf, :ccdf, :logccdf, :quantile, :cquantile, :invlogcdf, :invlogccdf) @eval begin @eval begin function (\$f)(d::UnivariateDistribution, x::AbstractArray) res = Array(Float64, size(x)) for i in 1:length(res) ... ... @@ -482,8 +484,38 @@ for f in (:pdf, :logpdf, :cdf, :logcdf, end end function logpdf!(d::UnivariateDistribution, x::AbstractArray, r::AbstractArray) if size(x) != size(r) throw(ArgumentError("Inconsistent array dimensions.")) end for i = 1 : length(x) r[i] = logpdf(d, x[i]) end end function logpdf(d::MultivariateDistribution, x::AbstractMatrix) n::Int = size(x, 2) r = Array(Float64, n) for i = 1 : n r[i] = logpdf(d, x[:,i]) end r end function logpdf!(d::MultivariateDistribution, x::AbstractMatrix, r::AbstractArray) n::Int = size(x, 2) if length(r) != n throw(ArgumentError("Inconsistent array dimensions.")) end for i = 1 : n r[i] = logpdf(d, x[:,i]) end end pmf(d::DiscreteDistribution, args::Any...) = pdf(d, args...) logpmf(d::DiscreteDistribution, args::Any...) = logpdf(d, args...) logpmf!(d::DiscreteDistribution, args::Any...) = logpdf!(d, args...) binary_entropy(d::Distribution) = entropy(d) / log(2) ... ... @@ -1113,10 +1145,12 @@ MultivariateNormal() = MultivariateNormal(zeros(2), eye(2)) mean(d::MultivariateNormal) = d.mean var(d::MultivariateNormal) = (U = d.covchol[:U]; U'U) function rand(d::MultivariateNormal) z = randn(length(d.mean)) return d.mean + d.covchol[:U]'z end function rand!(d::MultivariateNormal, X::Matrix) k = length(mean(d)) m, n = size(X) ... ... @@ -1124,13 +1158,47 @@ function rand!(d::MultivariateNormal, X::Matrix) if n == k return randn!(X) * d.covchol[:U] + d.mean'[ones(Int,m),:] end error("Wrong dimensions") end function logpdf{T <: Real}(d::MultivariateNormal, x::Vector{T}) k = length(d.mean) u = x - d.mean z = d.covchol \ u # This is equivalent to inv(cov) * u, but much faster return -0.5 * k * log(2.0pi) - sum(log(diag(d.covchol[:U]))) - 0.5 * dot(u,z) u = x - d.mean # don't have to copy u, as we only use the transformed version (not the original one) Base.LinAlg.LAPACK.trtrs!('U', 'T', 'N', d.covchol.UL, u) -0.5 * k * log(2.0pi) - sum(log(diag(d.covchol.UL))) - 0.5 * dot(u,u) end function logpdf!{T <: Real}(d::MultivariateNormal, x::Matrix{T}, r::AbstractVector) mu::Vector{Float64} = d.mean k = length(mu) if size(x, 1) != k throw(ArgumentError("The dimension of x is inconsistent with d.")) end n = size(x, 2) u = Array(Float64, k, n) for j = 1 : n # u[:,j] = x[:,j] - mu for i = 1 : k u[i, j] = x[i, j] - mu[i] end end Base.LinAlg.LAPACK.trtrs!('U', 'T', 'N', d.covchol.UL, u) c::Float64 = -0.5 * k * log(2.0pi) - sum(log(diag(d.covchol.UL))) for j = 1 : n dot_uj = 0. for i = 1 : k dot_uj += u[i,j] * u[i,j] end r[j] = c - 0.5 * dot_uj end end function logpdf{T <: Real}(d::MultivariateNormal, x::Matrix{T}) r = Array(Float64, size(x, 2)) logpdf!(d, x, r) r end pdf{T <: Real}(d::MultivariateNormal, x::Vector{T}) = exp(logpdf(d, x)) pdf{T<:Real}(d::MultivariateNormal, x::Vector{T}) = exp(logpdf(d, x)) function cdf{T <: Real}(d::MultivariateNormal, x::Vector{T}) k = length(d.mean) if k > 3; error("Dimension larger than three is not supported yet"); end ... ... @@ -1648,12 +1716,16 @@ end immutable Dirichlet <: ContinuousMultivariateDistribution alpha::Vector{Float64} function Dirichlet{T <: Real}(alpha::Vector{T}) for el in alpha if el < 0. error("Dirichlet: elements of alpha must be non-negative") end end new(float64(alpha)) end # construct a symmetric Dirichlet distribution Dirichlet{T <: Real}(d::Int, alpha::T) = Dirichlet(fill(alpha, d)) end Dirichlet(dim::Integer) = Dirichlet(ones(dim)) ... ... @@ -1683,22 +1755,36 @@ function insupport{T <: Real}(d::Dirichlet, x::Vector{T}) return true end function pdf{T <: Real}(d::Dirichlet, x::Vector{T}) if !insupport(d, x) error("x not in the support of Dirichlet distribution") end b = prod(gamma(d.alpha)) / gamma(sum(d.alpha)) (1 / b) * prod(x.^(d.alpha - 1)) end function logpdf{T <: Real}(d::Dirichlet, x::Vector{T}) if !insupport(d, x) error("x not in the support of Dirichlet distribution") end b = sum(lgamma(d.alpha)) - lgamma(sum(d.alpha)) dot((d.alpha - 1), log(x)) - b end pdf{T <: Real}(d::Dirichlet, x::Vector{T}) = exp(logpdf(d, x)) function logpdf!{T <: Real}(d::Dirichlet, x::Matrix{T}, r::Vector{T}) if size(x, 1) != length(d.alpha) throw(ArgumentError("Inconsistent argument dimensions.")) end n = size(x, 2) if length(r) != n throw(ArgumentError("Inconsistent argument dimensions.")) end b::Float64 = sum(lgamma(d.alpha)) - lgamma(sum(d.alpha)) At_mul_B(r, log(x), d.alpha - 1.) for i = 1 : n r[i] -= b end end function logpdf{T <: Real}(d::Dirichlet, x::Matrix{T}) r = Array(Float64, size(x, 2)) logpdf!(d, x, r) r end function rand(d::Dirichlet) x = [rand(Gamma(el)) for el in d.alpha] x ./ sum(x) ... ...
 ... ... @@ -136,11 +136,31 @@ for ll in (LogitLink(), ProbitLink()#, CloglogLink() # edge cases for CloglogLin end end # Multivariate normal d = MultivariateNormal(zeros(2), eye(2)) @test abs(pdf(d, [0, 0]) - 0.159155) < 10e-3 @test abs(pdf(d, [1, 0]) - 0.0965324) < 10e-3 @test abs(pdf(d, [1, 1]) - 0.0585498) < 10e-3 @test abs(pdf(d, [0., 0.]) - 0.159155) < 1.0e-5 @test abs(pdf(d, [1., 0.]) - 0.0965324) < 1.0e-5 @test abs(pdf(d, [1., 1.]) - 0.0585498) < 1.0e-5 d = MultivariateNormal(zeros(3), [4. -2. -1.; -2. 5. -1.; -1. -1. 6.]) @test abs(logpdf(d, [3., 4., 5.]) - (-15.75539253001834)) < 1.0e-10 @test_approx_eq logpdf(d, [3., 4., 5.]) (-15.75539253001834) x = [3. 4. 5.; 1. 2. 3.; -4. -3. -2.; -1. -3. -2.]' r0 = zeros(4) for i = 1 : 4 r0[i] = logpdf(d, x[:,i]) end @test_approx_eq logpdf(d, x) r0 # Dirichlet d = Dirichlet([1.5, 2.0, 2.5]) x = [0.2 0.5 0.3; 0.1 0.5 0.4; 0.8 0.1 0.1; 0.05 0.15 0.8]' r0 = zeros(4) for i = 1 : 4 r0[i] = logpdf(d, x[:,i]) end @test_approx_eq logpdf(d, x) r0
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!