Ce serveur Gitlab sera éteint le 30 juin 2020, pensez à migrer vos projets vers les serveurs gitlab-research.centralesupelec.fr et gitlab-student.centralesupelec.fr !

Commit 11632dfb by John Myles White

### Liberalize inputs for MultivariateNormal

`Add draft show() method for Distributions`
parent ee61a809
 ... ... @@ -85,6 +85,7 @@ export # types import Base.mean, Base.median, Base.quantile import Base.rand, Base.std, Base.var, Base.integer_valued import Base.show include("tvpack.jl") ... ... @@ -787,6 +788,20 @@ function rand(d::MixtureModel) i = rand(Categorical(d.probs)) rand(d.components[i]) end function mean(d::MixtureModel) m = 0.0 for i in 1:length(d.components) m += mean(d.components[i]) * d.probs[i] end return m end function var(d::MixtureModel) m = 0.0 for i in 1:length(d.components) m += var(d.components[i]) * d.probs[i]^2 end return m end type MultivariateNormal <: ContinuousMultivariateDistribution mean::Vector{Float64} ... ... @@ -814,13 +829,13 @@ function rand(d::MultivariateNormal) return d.mean + d.covchol.LR'z end function logpdf(d::MultivariateNormal, x::Vector{Float64}) function logpdf{T <: Real}(d::MultivariateNormal, x::Vector{T}) k = length(d.mean) z = d.covchol.LR \ (x - d.mean) return -0.5 * k * log(2.0pi) - sum(log(diag(d.covchol.LR))) - 0.5 * dot(z,z) end pdf(d::MultivariateNormal, x::Vector{Float64}) = exp(logpdf(d, x)) function cdf(d::MultivariateNormal, x::Vector{Float64}) pdf{T <: Real}(d::MultivariateNormal, x::Vector{T}) = exp(logpdf(d, x)) function cdf{T <: Real}(d::MultivariateNormal, x::Vector{T}) k = length(d.mean) if k > 3; error("Dimension larger than three is not supported yet"); end stddev = sqrt(diag(var(d))) ... ... @@ -1299,4 +1314,28 @@ canonicallink(d::Normal) = IdentityLink() canonicallink(d::Bernoulli) = LogitLink() canonicallink(d::Poisson) = LogLink() function show(io::IO, d::Distribution) print(io, @sprintf "%s distribution\n" typeof(d)) for parameter in typeof(d).names if isa(d.(parameter), AbstractArray) param = strcat(ucfirst(string(parameter)), ":\n", d.(parameter), "\n") else param = strcat(ucfirst(string(parameter)), ": ", d.(parameter), "\n") end print(io, param) end m = mean(d) if isa(m, AbstractArray) print(io, strcat("Mean:\n", m, "\n")) else print(io, strcat("Mean: ", m, "\n")) end v = var(d) if isa(v, AbstractArray) print(io, strcat("Variance:\n", v)) else print(io, strcat("Variance: ", v)) end end end #module
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!