Ce serveur Gitlab sera éteint le 30 juin 2020, pensez à migrer vos projets vers les serveurs gitlab-research.centralesupelec.fr et gitlab-student.centralesupelec.fr !

Commit 11632dfb authored by John Myles White's avatar John Myles White

Liberalize inputs for MultivariateNormal

Add draft show() method for Distributions
parent ee61a809
......@@ -85,6 +85,7 @@ export # types
import Base.mean, Base.median, Base.quantile
import Base.rand, Base.std, Base.var, Base.integer_valued
import Base.show
include("tvpack.jl")
......@@ -787,6 +788,20 @@ function rand(d::MixtureModel)
i = rand(Categorical(d.probs))
rand(d.components[i])
end
function mean(d::MixtureModel)
m = 0.0
for i in 1:length(d.components)
m += mean(d.components[i]) * d.probs[i]
end
return m
end
function var(d::MixtureModel)
m = 0.0
for i in 1:length(d.components)
m += var(d.components[i]) * d.probs[i]^2
end
return m
end
type MultivariateNormal <: ContinuousMultivariateDistribution
mean::Vector{Float64}
......@@ -814,13 +829,13 @@ function rand(d::MultivariateNormal)
return d.mean + d.covchol.LR'z
end
function logpdf(d::MultivariateNormal, x::Vector{Float64})
function logpdf{T <: Real}(d::MultivariateNormal, x::Vector{T})
k = length(d.mean)
z = d.covchol.LR \ (x - d.mean)
return -0.5 * k * log(2.0pi) - sum(log(diag(d.covchol.LR))) - 0.5 * dot(z,z)
end
pdf(d::MultivariateNormal, x::Vector{Float64}) = exp(logpdf(d, x))
function cdf(d::MultivariateNormal, x::Vector{Float64})
pdf{T <: Real}(d::MultivariateNormal, x::Vector{T}) = exp(logpdf(d, x))
function cdf{T <: Real}(d::MultivariateNormal, x::Vector{T})
k = length(d.mean)
if k > 3; error("Dimension larger than three is not supported yet"); end
stddev = sqrt(diag(var(d)))
......@@ -1299,4 +1314,28 @@ canonicallink(d::Normal) = IdentityLink()
canonicallink(d::Bernoulli) = LogitLink()
canonicallink(d::Poisson) = LogLink()
function show(io::IO, d::Distribution)
print(io, @sprintf "%s distribution\n" typeof(d))
for parameter in typeof(d).names
if isa(d.(parameter), AbstractArray)
param = strcat(ucfirst(string(parameter)), ":\n", d.(parameter), "\n")
else
param = strcat(ucfirst(string(parameter)), ": ", d.(parameter), "\n")
end
print(io, param)
end
m = mean(d)
if isa(m, AbstractArray)
print(io, strcat("Mean:\n", m, "\n"))
else
print(io, strcat("Mean: ", m, "\n"))
end
v = var(d)
if isa(v, AbstractArray)
print(io, strcat("Variance:\n", v))
else
print(io, strcat("Variance: ", v))
end
end
end #module
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment