Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
DataStructures = "0.17, 0.18"
Documenter = "0.25, 0.26, 0.27"
ForwardDiff = "0.10"
DataStructures = "0.17, 0.18, 0.19, 0.20"
Documenter = "0.25, 0.26, 0.27, 1"
ForwardDiff = "0.10, 1"
PositiveFactorizations = "0.2"
Roots = "1.3, 2"
SpecialFunctions = "0.8.1, 0.9, 1.0, 2.0"
StatsBase = "0.32, 0.33"
SpecialFunctions = "0.8.1, 0.9, 1, 2"
StatsBase = "0.32, 0.33, 0.34"
StatsFuns = "0.9, 1"
julia = "1.3, 1.4, 1.5, 1.6, 1.7"
julia = "1.6, 1.7, 1.8, 1.9, 1.10, 1.11"
176 changes: 90 additions & 86 deletions src/engines/julia/update_rules/delta_sampling.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
export
ruleSPDeltaSOutNM,
ruleSPDeltaSIn1MN,
ruleSPDeltaSInGX,
ruleSPDeltaSOutNMX,
ruleSPDeltaSInMX,
ruleMDeltaSInMGX,
prod!
ruleSPDeltaSOutNM,
ruleSPDeltaSIn1MN,
ruleSPDeltaSInGX,
ruleSPDeltaSOutNMX,
ruleSPDeltaSInMX,
ruleMDeltaSInMGX,
prod!

const default_n_samples = 1000 # Default value for the number of samples

Expand All @@ -15,22 +15,22 @@ const default_n_samples = 1000 # Default value for the number of samples
#----------------------

function ruleSPDeltaSOutNM(g::Function,
msg_out::Nothing,
msg_in1::Message; # Applies to any message except SampleList
dims::Any=nothing,
n_samples=default_n_samples)
msg_out::Nothing,
msg_in1::Message; # Applies to any message except SampleList
dims::Any=nothing,
n_samples=default_n_samples)

samples = g.(sample(msg_in1.dist, n_samples))
weights = ones(n_samples)/n_samples
weights = ones(n_samples) / n_samples

return Message(variateType(dims), SampleList, s=samples, w=weights)
end

function ruleSPDeltaSOutNM(g::Function,
msg_out::Nothing,
msg_in1::Message{SampleList}; # Special case for SampleList
dims::Any=nothing,
n_samples=default_n_samples)
msg_out::Nothing,
msg_in1::Message{SampleList}; # Special case for SampleList
dims::Any=nothing,
n_samples=default_n_samples)

samples = g.(msg_in1.dist.params[:s])
weights = msg_in1.dist.params[:w]
Expand All @@ -39,20 +39,20 @@ function ruleSPDeltaSOutNM(g::Function,
end

function ruleSPDeltaSIn1MN(g::Function,
msg_out::Message,
msg_in1::Nothing;
dims::Any=nothing,
n_samples=default_n_samples)
msg_out::Message,
msg_in1::Nothing;
dims::Any=nothing,
n_samples=default_n_samples)

return Message(variateType(dims), Function, log_pdf = (z)->logPdf(msg_out.dist, g(z)))
return Message(variateType(dims), Function, log_pdf=(z) -> logPdf(msg_out.dist, g(z)))
end

function ruleSPDeltaSInGX(g::Function,
inx::Int64, # Index of inbound interface inx
msg_out::Message,
msgs_in::Vararg{Message{<:Gaussian}};
dims::Any=nothing,
n_samples=default_n_samples)
inx::Int64, # Index of inbound interface inx
msg_out::Message,
msgs_in::Vararg{Message{<:Gaussian}};
dims::Any=nothing,
n_samples=default_n_samples)

# Extract joint statistics of inbound messages
(ms_fw_in, Vs_fw_in) = collectStatistics(msgs_in...) # Return arrays with individual means and covariances
Expand All @@ -69,7 +69,7 @@ function ruleSPDeltaSInGX(g::Function,
# Marginalize joint belief on in's
(m_inx, V_inx) = marginalizeGaussianMV(m_in, V_in, ds, inx)
W_inx = cholinv(V_inx) # Convert to canonical statistics
xi_inx = W_inx*m_inx
xi_inx = W_inx * m_inx

# Divide marginal on inx by forward message
(xi_fw_inx, W_fw_inx) = unsafeWeightedMeanPrecision(msgs_in[inx].dist)
Expand All @@ -80,29 +80,29 @@ function ruleSPDeltaSInGX(g::Function,
end

function ruleSPDeltaSOutNMX(g::Function,
msg_out::Nothing,
msgs_in::Vararg{Message};
dims::Any=nothing,
n_samples=default_n_samples)
msg_out::Nothing,
msgs_in::Vararg{Message};
dims::Any=nothing,
n_samples=default_n_samples)

samples_in = [sample(msg_in.dist, n_samples) for msg_in in msgs_in]
samples = g.(samples_in...)
weights = ones(n_samples)/n_samples
weights = ones(n_samples) / n_samples

return Message(variateType(dims), SampleList, s=samples, w=weights)
end

function ruleSPDeltaSInMX(g::Function,
inx::Int64, # Index of inbound interface inx
msg_out::Message,
msgs_in::Vararg{Message};
dims::Any=nothing,
n_samples=default_n_samples)
inx::Int64, # Index of inbound interface inx
msg_out::Message,
msgs_in::Vararg{Message};
dims::Any=nothing,
n_samples=default_n_samples)

arg_sample = (z) -> begin
samples_in = []
for i=1:length(msgs_in)
if i==inx
for i = 1:length(msgs_in)
if i == inx
push!(samples_in, fill(z, n_samples))
else
push!(samples_in, sample(msgs_in[i].dist, n_samples))
Expand All @@ -112,40 +112,40 @@ function ruleSPDeltaSInMX(g::Function,
return samples_in
end

approximate_pdf(z) = sum(exp.(logPdf.([msg_out.dist],g.(arg_sample(z)...))))/n_samples
approximate_pdf(z) = sum(exp.(logPdf.([msg_out.dist], g.(arg_sample(z)...)))) / n_samples

return Message(variateType(dims), Function, log_pdf = (z)->log(approximate_pdf(z)))
return Message(variateType(dims), Function, log_pdf=(z) -> log(approximate_pdf(z)))
end

# Special case for two inputs with one PointMass (no inx required)
function ruleSPDeltaSInMX(g::Function,
msg_out::Message,
msg_in1::Message{PointMass},
msg_in2::Nothing;
dims::Any=nothing,
n_samples=default_n_samples)
msg_out::Message,
msg_in1::Message{PointMass},
msg_in2::Nothing;
dims::Any=nothing,
n_samples=default_n_samples)

m_in1 = msg_in1.dist.params[:m]

return Message(variateType(dims), Function, log_pdf = (z)->logPdf(msg_out.dist, g(m_in1, z)))
return Message(variateType(dims), Function, log_pdf=(z) -> logPdf(msg_out.dist, g(m_in1, z)))
end

# Special case for two inputs with one PointMass (no inx required)
function ruleSPDeltaSInMX(g::Function,
msg_out::Message,
msg_in1::Nothing,
msg_in2::Message{PointMass};
dims::Any=nothing,
n_samples=default_n_samples)
msg_out::Message,
msg_in1::Nothing,
msg_in2::Message{PointMass};
dims::Any=nothing,
n_samples=default_n_samples)

m_in2 = msg_in2.dist.params[:m]

return Message(variateType(dims), Function, log_pdf = (z)->logPdf(msg_out.dist, g(z, m_in2)))
return Message(variateType(dims), Function, log_pdf=(z) -> logPdf(msg_out.dist, g(z, m_in2)))
end

function ruleMDeltaSInMGX(g::Function,
msg_out::Message,
msgs_in::Vararg{Message{<:Gaussian}})
msg_out::Message,
msgs_in::Vararg{Message{<:Gaussian}})

# Extract joint statistics of inbound messages
(ms_fw_in, Vs_fw_in) = collectStatistics(msgs_in...) # Return arrays with individual means and covariances
Expand Down Expand Up @@ -173,15 +173,15 @@ function collectSumProductNodeInbounds(node::Delta{Sampling}, entry::ScheduleEnt

# Push function to calling signature
# This function needs to be defined in the scope of the user
push!(inbounds, Dict{Symbol, Any}(:g => node.g,
:keyword => false))
push!(inbounds, Dict{Symbol,Any}(:g => node.g,
:keyword => false))

multi_in = isMultiIn(node) # Boolean to indicate a Delta node with multiple stochastic inbounds
inx = findfirst(isequal(entry.interface), node.interfaces) - 1 # Find number of inbound interface; 0 for outbound

if (inx > 0) && multi_in # Multi-inbound backward rule
push!(inbounds, Dict{Symbol, Any}(:inx => inx, # Push inbound identifier
:keyword => false))
push!(inbounds, Dict{Symbol,Any}(:inx => inx, # Push inbound identifier
:keyword => false))
end

interface_to_schedule_entry = current_inference_algorithm.interface_to_schedule_entry
Expand All @@ -205,12 +205,12 @@ function collectSumProductNodeInbounds(node::Delta{Sampling}, entry::ScheduleEnt

# Push custom arguments if defined
if (node.dims !== nothing)
push!(inbounds, Dict{Symbol, Any}(:dims => node.dims[inx + 1],
:keyword => true))
push!(inbounds, Dict{Symbol,Any}(:dims => node.dims[inx+1],
:keyword => true))
end
if (node.n_samples !== nothing)
push!(inbounds, Dict{Symbol, Any}(:n_samples => node.n_samples,
:keyword => true))
push!(inbounds, Dict{Symbol,Any}(:n_samples => node.n_samples,
:keyword => true))
end
return inbounds
end
Expand All @@ -221,19 +221,19 @@ end
#---------------------------

function prod!(
x::Distribution{V, Function},
y::Distribution{V, Function}) where V<:VariateType # log-pdf for z cannot be predefined, because it cannot be overwritten
x::Distribution{V,Function},
y::Distribution{V,Function}) where V<:VariateType # log-pdf for z cannot be predefined, because it cannot be overwritten

return Distribution(V, Function, log_pdf=(s)->x.params[:log_pdf](s) + y.params[:log_pdf](s))
return Distribution(V, Function, log_pdf=(s) -> x.params[:log_pdf](s) + y.params[:log_pdf](s))
end

@symmetrical function prod!(
x::Distribution{Univariate}, # Includes function distributions
y::Distribution{Univariate, <:Gaussian},
z::Distribution{Univariate, Gaussian{Precision}}=Distribution(Univariate, Gaussian{Precision}, m=0.0, w=1.0))
y::Distribution{Univariate,<:Gaussian},
z::Distribution{Univariate,Gaussian{Precision}}=Distribution(Univariate, Gaussian{Precision}, m=0.0, w=1.0))

# Optimize with gradient ascent
log_joint(s) = logPdf(y,s) + logPdf(x,s)
log_joint(s) = logPdf(y, s) + logPdf(x, s)
d_log_joint(s) = ForwardDiff.derivative(log_joint, s)
m_initial = unsafeMean(y)

Expand All @@ -248,11 +248,11 @@ end

@symmetrical function prod!(
x::Distribution{Multivariate}, # Includes function distributions
y::Distribution{Multivariate, <:Gaussian},
z::Distribution{Multivariate, Gaussian{Precision}}=Distribution(Multivariate, Gaussian{Precision}, m=[0.0], w=mat(1.0)))
y::Distribution{Multivariate,<:Gaussian},
z::Distribution{Multivariate,Gaussian{Precision}}=Distribution(Multivariate, Gaussian{Precision}, m=[0.0], w=mat(1.0)))

# Optimize with gradient ascent
log_joint(s) = logPdf(y,s) + logPdf(x,s)
log_joint(s) = logPdf(y, s) + logPdf(x, s)
d_log_joint(s) = ForwardDiff.gradient(log_joint, s)
m_initial = unsafeMean(y)

Expand All @@ -278,31 +278,35 @@ function gradientOptimization(log_joint::Function, d_log_joint::Function, m_init
m_old = m_initial
satisfied = false
step_count = 0
m_latests = if (dim_tot == 1) Queue{Float64}() else Queue{Vector}() end
m_latests = if (dim_tot == 1)
Queue{Float64}()
else
Queue{Vector}()
end

while !satisfied
m_new = m_old .+ step_size.*d_log_joint(m_old)
m_new = m_old .+ step_size .* d_log_joint(m_old)
if log_joint(m_new) > log_joint(m_old)
proposal_step_size = 10*step_size
m_proposal = m_old .+ proposal_step_size.*d_log_joint(m_old)
proposal_step_size = 10 * step_size
m_proposal = m_old .+ proposal_step_size .* d_log_joint(m_old)
if log_joint(m_proposal) > log_joint(m_new)
m_new = m_proposal
step_size = proposal_step_size
end
else
step_size = 0.1*step_size
m_new = m_old .+ step_size.*d_log_joint(m_old)
step_size = 0.1 * step_size
m_new = m_old .+ step_size .* d_log_joint(m_old)
end
step_count += 1
enqueue!(m_latests, m_old)
push!(m_latests, m_old)
if step_count > 10
m_average = sum(x for x in m_latests)./10
if sum(sqrt.(((m_new.-m_average)./m_average).^2)) < dim_tot*0.1
m_average = sum(x for x in m_latests) ./ 10
if sum(sqrt.(((m_new .- m_average) ./ m_average) .^ 2)) < dim_tot * 0.1
satisfied = true
end
dequeue!(m_latests);
popfirst!(m_latests)
end
if step_count > dim_tot*250
if step_count > dim_tot * 250
satisfied = true
end
m_old = m_new
Expand All @@ -317,7 +321,7 @@ end
#--------

function logJointPdfs(m_fw_in::Vector, W_fw_in::AbstractMatrix, dist_out::Distribution, g::Function, ds::Vector)
log_joint(x) = -0.5*sum(intdim.(ds))*log(2pi) + 0.5*logdet(W_fw_in) - 0.5*(x - m_fw_in)'*W_fw_in*(x - m_fw_in) + logPdf(dist_out, g(split(x, ds)...))
log_joint(x) = -0.5 * sum(intdim.(ds)) * log(2pi) + 0.5 * logdet(W_fw_in) - 0.5 * (x - m_fw_in)' * W_fw_in * (x - m_fw_in) + logPdf(dist_out, g(split(x, ds)...))
d_log_joint(x) = ForwardDiff.gradient(log_joint, x)

return (log_joint, d_log_joint)
Expand Down