Ce serveur Gitlab sera éteint le 30 juin 2020, pensez à migrer vos projets vers les serveurs gitlab-research.centralesupelec.fr et gitlab-student.centralesupelec.fr !

Commit e9b6874d authored by Dahua Lin

tweak Dirichlet MLE

parent 00f01ead
......@@ -87,6 +87,7 @@ export # types
entropy, # entropy of distribution in nats
fit, # fit a distribution to data (using default method)
fit_mle, # fit a distribution to data using MLE
fit_mle!, # fit a distribution to data using MLE (inplace update to initial guess)
fit_map, # fit a distribution to data using MAP
freecumulant, # free cumulants of distribution
insupport, # predicate, is x in the support of the distribution?
......
......@@ -157,6 +157,70 @@ function rand!(d::Dirichlet, X::Matrix)
return X
end
#####
#
# Algorithm: Newton-Raphson
#
#####
# Fit a Dirichlet distribution by maximum likelihood using Newton-Raphson
# iterations, refining `alpha` in place.
#
# Arguments:
#   dty     -- the distribution type `Dirichlet`; not read in the body, it only
#              selects this method
#   alpha   -- initial guess of alpha; overwritten with the refined estimate
#   Elogp   -- expectation/mean of log(p); must have the same length as `alpha`
#   maxiter -- maximum number of Newton iterations to run
#   tol     -- the loop stops once the largest absolute gradient entry,
#              evaluated at the start of an iteration, is below this value
#
# Returns `Dirichlet(alpha)`, constructed from the updated `alpha`.
# Throws `ArgumentError` when `alpha` and `Elogp` differ in length.
#
# NOTE: the result is returned whether or not the tolerance was reached within
# `maxiter` iterations; no error or warning is raised on non-convergence.
# NOTE(review): the Newton step is applied without any positivity safeguard on
# `alpha`; confirm that callers supply a reasonable initial guess.
function fit_mle!(
    dty::Type{Dirichlet},
    alpha::Vector{Float64},    # initial guess of alpha
    Elogp::Vector{Float64};    # expectation/mean of log(p)
    maxiter::Int=25, tol::Float64=1.0e-8)

    K = length(alpha)
    if length(Elogp) != K
        throw(ArgumentError("Inconsistent argument dimensions."))
    end

    # Work buffers; every entry is written before it is read in each iteration.
    g = similar(alpha)     # gradient entries
    iq = similar(alpha)    # -1 / trigamma(alpha[k])

    t = 0
    converged = false
    while !converged && t < maxiter
        t += 1

        # compute gradient & Hessian terms
        # (b is computed as well)
        a0 = sum(alpha)
        digam_a0 = digamma(a0)
        iz = 1.0 / trigamma(a0)

        # The accumulators must restart from zero on every iteration: they are
        # sums over the current gradient, not running totals across iterations.
        gnorm = 0.0
        b = 0.0
        iqs = 0.0

        for k = 1:K
            ak = alpha[k]
            g[k] = gk = digam_a0 - digamma(ak) + Elogp[k]
            iq[k] = -1.0 / trigamma(ak)

            b += gk * iq[k]
            iqs += iq[k]

            # track the max-abs gradient entry for the convergence check
            agk = abs(gk)
            if agk > gnorm
                gnorm = agk
            end
        end
        b /= (iz + iqs)

        # update alpha
        for k = 1:K
            alpha[k] -= (g[k] - b) * iq[k]
        end

        # determine convergence (based on the gradient before this update)
        converged = gnorm < tol
    end

    return Dirichlet(alpha)
end
function fit_mle{T <: Real}(::Type{Dirichlet}, P::Matrix{T}; maxiter::Int=25, tol::Float64=1.0e-8)
K, N = size(P)
......@@ -187,29 +251,7 @@ function fit_mle{T <: Real}(::Type{Dirichlet}, P::Matrix{T}; maxiter::Int=25, to
alpha[k] *= alpha0
end
iteration = 0
converged = false
while !converged && iteration < maxiter
iteration += 1
alpha0 = sum(alpha)
b = 0.0
iqs = 0.0
iz = 1.0 / (N * trigamma(alpha0))
dgalpha0 = digamma(alpha0)
for k in 1:K
g[k] = N * (dgalpha0 - digamma(alpha[k]) + lpbar[k])
q[k] = -N * trigamma(alpha[k])
b += g[k] / q[k]
iqs += 1.0 / q[k]
end
b /= (iz + iqs)
for k in 1:K
alpha[k] -= (g[k] - b) / q[k]
end
if amax(g) > tol
converged = true
end
end
return Dirichlet(alpha)
end
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment