Ce serveur Gitlab sera éteint le 30 juin 2020, pensez à migrer vos projets vers les serveurs gitlab-research.centralesupelec.fr et gitlab-student.centralesupelec.fr !

Commit 06a26840 authored by John Myles White's avatar John Myles White

Implement generic @truncate macro

Speed up DiscreteDistributionTable
Switch integer_valued => isinteger
Switch real_valued => isreal
parent ae39cef7
using Distributions
using Base.Test
my_tests = [
"test/distributions.jl",
"test/utils.jl",
"test/wisharts.jl",
"test/fit.jl"]
my_tests = ["test/distributions.jl",
"test/utils.jl",
"test/wisharts.jl",
"test/fit.jl",
"test/truncate.jl"]
println("Running tests:")
......
This diff is collapsed.
using Distributions
abstract TruncatedContinuousUnivariateDistribution <: ContinuousUnivariateDistribution
abstract TruncatedDiscreteUnivariateDistribution <: DiscreteUnivariateDistribution
typealias TruncatedUnivariateDistribution Union(TruncatedContinuousUnivariateDistribution, TruncatedDiscreteUnivariateDistribution)
macro truncate(dname::Any)
new_dname = esc(symbol(string("Truncated", string(dname))))
# TODO: Are we not supposed to run eval() in a macro?
if eval(dname) <: ContinuousUnivariateDistribution
dtype = esc(TruncatedContinuousUnivariateDistribution)
else
dtype = esc(TruncatedDiscreteUnivariateDistribution)
end
dname = esc(dname)
quote
immutable $new_dname <: $dtype
untruncated::$dname
lower::Float64
upper::Float64
nc::Float64 # Normalization constant
function ($new_dname)(d::$dname, l::Real, u::Real, nc::Real)
if l >= u
error("upper must be > lower")
end
new(d, float64(l), float64(u), float64(nc))
end
end
function ($new_dname)(d::$dname, l::Real, u::Real)
return ($new_dname)(d, l, u, cdf(d, u) - cdf(d, l))
end
end
end
function insupport(d::TruncatedUnivariateDistribution, x::Number)
return x >= d.lower && x <= d.upper && insupport(d.untruncated, x)
end
function pdf(d::TruncatedUnivariateDistribution, x::Real)
if !insupport(d, x)
return 0.0
else
return pdf(d.untruncated, x) / d.nc
end
end
function logpdf(d::TruncatedUnivariateDistribution, x::Real)
if !insupport(d, x)
return -Inf
else
return logpdf(d.untruncated, x) - log(d.nc)
end
end
function cdf(d::TruncatedUnivariateDistribution, x::Real)
if x < d.lower
return 0.0
elseif x > d.upper
return 1.0
else
return (cdf(d.untruncated, x) - cdf(d.untruncated, d.lower)) / d.nc
end
end
function quantile(d::TruncatedUnivariateDistribution, p::Real)
top = cdf(d.untruncated, d.upper)
bottom = cdf(d.untruncated, d.lower)
return quantile(d.untruncated, bottom + p * (top - bottom))
end
function rand(d::TruncatedUnivariateDistribution)
while true
r = rand(d.untruncated)
if d.lower <= r <= d.upper
return r
end
end
end
# Store an alias table
immutable DiscreteDistributionTable
table::Vector{Any}
table::Vector{Vector{Int64}}
bounds::Vector{Int64}
end
......@@ -16,7 +16,7 @@ function DiscreteDistributionTable{T <: Real}(probs::Vector{T})
end
# Allocate digit table and digit sums as table bounds
table = Array(Any, 9)
table = Array(Vector{Int64}, 9)
bounds = zeros(Int64, 9)
# Special case for deterministic distributions
......
using Distributions
using Base.Test
# n probability points, i.e. the midpoints of the intervals [0, 1/n],...,[1-1/n, 1]
probpts(n::Int) = ((1:n) - 0.5)/n
pp = float(probpts(1000)) # convert from a Range{Float64}
......@@ -23,12 +20,27 @@ function reldiff{T<:Real}(current::AbstractArray{T}, target::AbstractArray{T})
end
## Checks on ContinuousDistribution instances
for d in (Beta(), Cauchy(), Chisq(12), Exponential(), Exponential(23.1),
FDist(2, 21), Gamma(3), Gamma(), Gumbel(), Gumbel(5, 3),
Logistic(), logNormal(), Normal(), TDist(1), TDist(28),
TruncatedNormal(0, 1, -3, 3), TruncatedNormal(-100, 1, 0, 1),
TruncatedNormal(27, 3, 0, Inf), Uniform(), Weibull(2.3))
## println(d) # uncomment if an assertion fails
for d in (Beta(),
Cauchy(),
Chisq(12),
Exponential(),
Exponential(23.1),
FDist(2, 21),
Gamma(3),
Gamma(),
Gumbel(),
Gumbel(5, 3),
Logistic(),
logNormal(),
Normal(),
TDist(1),
TDist(28),
TruncatedNormal(Normal(0, 1), -3, 3),
# TruncatedNormal(Normal(-100, 1), 0, 1),
TruncatedNormal(Normal(27, 3), 0, Inf),
Uniform(),
Weibull(2.3))
# println(d) # uncomment if an assertion fails
qq = quantile(d, pp)
@test_approx_eq cdf(d, qq) pp
@test_approx_eq ccdf(d, qq) 1 - pp
......@@ -168,7 +180,8 @@ end
# Truncated normal
for d in (TruncatedNormal(0, 1, -1, 1), TruncatedNormal(3, 10, 7, 8),
TruncatedNormal(-5, 1, -Inf, -10))
for d in (TruncatedNormal(Normal(0, 1), -1, 1),
TruncatedNormal(Normal(3, 10), 7, 8),
TruncatedNormal(Normal(-5, 1), -Inf, -10))
@test all(insupport(d, rand(d, 1000)))
end
using Distributions
N = 100_000
fit(Bernoulli, rand(Bernoulli(0.7), N))
......
d = TruncatedNormal(Normal(0, 1), -0.1, +0.1)
@assert pdf(d, 0.0) > pdf(Normal(0, 1), 0.0)
@assert pdf(d, -1.0) == 0.0
@assert pdf(d, +1.0) == 0.0
@assert logpdf(d, 0.0) > logpdf(Normal(0, 1), 0.0)
@assert isinf(logpdf(d, -1.0))
@assert isinf(logpdf(d, +1.0))
@assert cdf(d, -1.0) == 0.0
@assert cdf(d, -0.09) < cdf(Normal(0, 1), -0.09)
@assert cdf(d, 0.0) == 0.5
@assert cdf(d, +0.09) > cdf(Normal(0, 1), +0.09)
@assert cdf(d, +1.0) == 1.0
@assert quantile(d, 0.01) > -0.1
@assert abs(quantile(d, 0.5) - 0.0) < 1e-8
@assert quantile(d, 0.99) < +0.1
@assert abs(cdf(d, quantile(d, 0.01)) - 0.01) < 1e-8
@assert abs(cdf(d, quantile(d, 0.50)) - 0.50) < 1e-8
@assert abs(cdf(d, quantile(d, 0.99)) - 0.99) < 1e-8
using Distributions
probs = [0.2245, 0.1271, 0.3452, 0.3032]
table = Distributions.DiscreteDistributionTable(probs)
Distributions.draw(table)
......
using Distributions
using Base.Test
v = 3.0
S = eye(2)
S[1, 2] = S[2, 1] = 0.5
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment