Ce serveur Gitlab sera éteint le 30 juin 2020, pensez à migrer vos projets vers les serveurs gitlab-research.centralesupelec.fr et gitlab-student.centralesupelec.fr !

Commit 2fc99c6d authored by Simon Byrne's avatar Simon Byrne

change Truncated to parametric type

parent 5f47074c
......@@ -75,8 +75,7 @@ export # types
Skellam,
TDist,
Triangular,
TruncatedNormal,
TruncatedUnivariateDistribution,
Truncated,
Uniform,
Weibull,
Wishart,
......
abstract TruncatedContinuousUnivariateDistribution <: ContinuousUnivariateDistribution
abstract TruncatedDiscreteUnivariateDistribution <: DiscreteUnivariateDistribution
typealias TruncatedUnivariateDistribution Union(TruncatedContinuousUnivariateDistribution, TruncatedDiscreteUnivariateDistribution)
macro truncate(dname::Any)
new_dname = esc(symbol(string("Truncated", string(dname))))
# TODO: Are we not supposed to run eval() in a macro?
if eval(dname) <: ContinuousUnivariateDistribution
dtype = esc(TruncatedContinuousUnivariateDistribution)
else
dtype = esc(TruncatedDiscreteUnivariateDistribution)
end
dname = esc(dname)
quote
immutable $new_dname <: $dtype
untruncated::$dname
lower::Float64
upper::Float64
nc::Float64 # Normalization constant
function ($new_dname)(d::$dname, l::Real, u::Real, nc::Real)
if l >= u
error("upper must be > lower")
end
new(d, float64(l), float64(u), float64(nc))
end
end
function ($new_dname)(d::$dname, l::Real, u::Real)
return ($new_dname)(d, l, u, cdf(d, u) - cdf(d, l))
immutable Truncated{D<:UnivariateDistribution} <: UnivariateDistribution
untruncated::D
lower::Float64
upper::Float64
nc::Float64
function Truncated{T<:UnivariateDistribution}(d::T, l::Real, u::Real, nc::Real)
if l >= u
error("upper must be > lower")
end
new(d, float64(l), float64(u), float64(nc))
end
end
Truncated(d::UnivariateDistribution, l::Real, u::Real, nc::Real) = Truncated{typeof(d)}(d,l,u,nc)
Truncated(d::UnivariateDistribution, l::Real, u::Real) = Truncated{typeof(d)}(d,l,u, cdf(d, u) - cdf(d, l))
function insupport(d::TruncatedUnivariateDistribution, x::Number)
function insupport(d::Truncated, x::Number)
return x >= d.lower && x <= d.upper && insupport(d.untruncated, x)
end
function pdf(d::TruncatedUnivariateDistribution, x::Real)
function pdf(d::Truncated, x::Real)
if !insupport(d, x)
return 0.0
else
......@@ -42,7 +27,7 @@ function pdf(d::TruncatedUnivariateDistribution, x::Real)
end
end
function logpdf(d::TruncatedUnivariateDistribution, x::Real)
function logpdf(d::Truncated, x::Real)
if !insupport(d, x)
return -Inf
else
......@@ -50,7 +35,7 @@ function logpdf(d::TruncatedUnivariateDistribution, x::Real)
end
end
function cdf(d::TruncatedUnivariateDistribution, x::Real)
function cdf(d::Truncated, x::Real)
if x < d.lower
return 0.0
elseif x > d.upper
......@@ -60,15 +45,15 @@ function cdf(d::TruncatedUnivariateDistribution, x::Real)
end
end
function quantile(d::TruncatedUnivariateDistribution, p::Real)
function quantile(d::Truncated, p::Real)
top = cdf(d.untruncated, d.upper)
bottom = cdf(d.untruncated, d.lower)
return quantile(d.untruncated, bottom + p * (top - bottom))
end
median(d::TruncatedUnivariateDistribution) = quantile(d, 0.5)
median(d::Truncated) = quantile(d, 0.5)
function rand(d::TruncatedUnivariateDistribution)
function rand(d::Truncated)
while true
r = rand(d.untruncated)
if d.lower <= r <= d.upper
......@@ -76,3 +61,12 @@ function rand(d::TruncatedUnivariateDistribution)
end
end
end
# from fallbacks
function rand{D<:ContinuousUnivariateDistribution}(d::Truncated{D}, dims::Dims)
return rand!(d, Array(Float64, dims))
end
function rand{D<:DiscreteUnivariateDistribution}(d::Truncated{D}, dims::Dims)
return rand!(d, Array(Int, dims))
end
@truncate Normal
function entropy(d::TruncatedNormal)
function entropy(d::Truncated{Normal})
s = std(d.untruncated)
a = d.lower
b = d.upper
......@@ -13,12 +13,12 @@ function entropy(d::TruncatedNormal)
0.5 * (a_phi_a - b_phi_b) / z - 0.5 * ((phi_a - phi_b) / z)^2
end
function mean(d::TruncatedNormal)
function mean(d::Truncated{Normal})
delta = pdf(d.untruncated, d.lower) - pdf(d.untruncated, d.upper)
return mean(d.untruncated) + delta * var(d.untruncated) / d.nc
end
function modes(d::TruncatedNormal)
function modes(d::Truncated{Normal})
mu = mean(d.untruncated)
if d.upper < mu
return [d.upper]
......@@ -29,14 +29,14 @@ function modes(d::TruncatedNormal)
end
end
function rand(d::TruncatedNormal)
function rand(d::Truncated{Normal})
mu = mean(d.untruncated)
sigma = std(d.untruncated)
z = randnt((d.lower - mu) / sigma, (d.upper - mu) / sigma)
return mu + sigma * z
end
function var(d::TruncatedNormal)
function var(d::Truncated{Normal})
s = std(d.untruncated)
a = d.lower
b = d.upper
......
using Distributions
using Base.Test
for d in (TruncatedNormal(Normal(0, 1), -1, 1),
TruncatedNormal(Normal(3, 10), 7, 8),
TruncatedNormal(Normal(-5, 1), -Inf, -10))
for d in (Truncated(Normal(0, 1), -1, 1),
Truncated(Normal(3, 10), 7, 8),
Truncated(Normal(-5, 1), -Inf, -10))
@test all(insupport(d, rand(d, 1000)))
end
d = TruncatedNormal(Normal(0, 1), -0.1, +0.1)
d = Truncated(Normal(0, 1), -0.1, +0.1)
@test pdf(d, 0.0) > pdf(Normal(0, 1), 0.0)
@test pdf(d, -1.0) == 0.0
......
......@@ -97,9 +97,9 @@ for d in [Arcsine(),
Triangular(3.0, 1.0),
Triangular(3.0, 2.0),
Triangular(10.0, 10.0),
TruncatedNormal(Normal(0, 1), -3, 3),
# TruncatedNormal(Normal(-100, 1), 0, 1),
TruncatedNormal(Normal(27, 3), 0, Inf),
Truncated(Normal(0, 1), -3, 3),
# Truncated(Normal(-100, 1), 0, 1),
Truncated(Normal(27, 3), 0, Inf),
Uniform(0.0, 1.0),
Uniform(3.0, 17.0),
Uniform(3.0, 3.1),
......@@ -112,14 +112,14 @@ for d in [Arcsine(),
# println(d)
n = length(pp)
is_continuous = isa(d, ContinuousDistribution)
is_discrete = isa(d, DiscreteDistribution)
is_continuous = isa(d, Truncated) ? isa(d.untruncated, ContinuousDistribution) : isa(d, ContinuousDistribution)
is_discrete = isa(d, Truncated) ? isa(d.untruncated, DiscreteDistribution) : isa(d, DiscreteDistribution)
@assert is_continuous == !is_discrete
sample_ty = is_continuous ? Float64 : Int
# avoid checking high order moments for LogNormal and Logistic
avoid_highord = isa(d, LogNormal) || isa(d, Logistic) || isa(d, TruncatedNormal)
avoid_highord = isa(d, LogNormal) || isa(d, Logistic) || isa(d, Truncated)
#####
#
......
......@@ -116,9 +116,9 @@ for d in [Arcsine(),
Triangular(3.0, 1.0),
Triangular(3.0, 2.0),
Triangular(10.0, 10.0),
# TruncatedNormal(Normal(0, 1), -3, 3),
# TruncatedNormal(Normal(-100, 1), 0, 1),
# TruncatedNormal(Normal(27, 3), 0, Inf),
# Truncated(Normal(0, 1), -3, 3),
# Truncated(Normal(-100, 1), 0, 1),
# Truncated(Normal(27, 3), 0, Inf),
Uniform(0.0, 1.0),
Uniform(3.0, 17.0),
Uniform(3.0, 3.1),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment