Contributing

This example was automatically generated from a Jupyter notebook in the RxInferExamples.jl repository.

We welcome and encourage contributions! You can help by:

  • Improving this example
  • Creating new examples
  • Reporting issues or bugs
  • Suggesting enhancements

Visit our GitHub repository to get started. Together we can make RxInfer.jl even better! 💪


Incomplete Data

This tutorial demonstrates how to handle incomplete observations (missing data) using RxInfer.jl. Missing data is a common challenge in real-world applications. Traditional approaches often involve imputation or deletion of incomplete observations, which can lead to biased results. Bayesian inference provides a principled way to handle missing data by treating missing values as latent variables and marginalizing over them.
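In symbols, marginalizing means integrating the missing entries out of the likelihood instead of replacing them with point estimates. The notation here is introduced for illustration: θ stands for all unknown parameters, and y_obs and y_mis stand for the observed and missing entries.

```latex
p(\theta \mid y_{\text{obs}}) \;\propto\; p(\theta) \int p(y_{\text{obs}}, y_{\text{mis}} \mid \theta)\, \mathrm{d}y_{\text{mis}} \;=\; p(\theta)\, p(y_{\text{obs}} \mid \theta)
```

Ignoring the mechanism that produced the gaps is justified when the data are missing at random. In this tutorial, the missing entries are chosen uniformly at random and independently of their values, so that condition holds.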

Problem Setup

We'll work with a hierarchical multivariate Gaussian model:

  • Precision matrix Λ follows a Wishart distribution (the conjugate prior for the precision of a multivariate Gaussian)
  • Mean vector m follows a multivariate normal distribution
  • Latent states x[i] are drawn from MvNormal(m, Λ⁻¹)
  • Each observation y[i,j] is the j-th component of x[i], and some of these entries are missing
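The model can be written out as follows. Here d is the dimension (`dim`), N is the number of samples, e_j is the j-th standard basis vector, and γ is the fixed, very large observation precision (`huge` in the code). The Wishart is parameterized by degrees of freedom and scale matrix, as in Distributions.jl's `Wishart(df, S)`.

```latex
\begin{aligned}
\Lambda &\sim \mathcal{W}(d,\, I_d), \\
m &\sim \mathcal{N}(0,\, I_d), \\
x_i \mid m, \Lambda &\sim \mathcal{N}(m,\, \Lambda^{-1}), && i = 1, \dots, N, \\
y_{ij} \mid x_i &\sim \mathcal{N}(e_j^{\top} x_i,\, \gamma^{-1}), && j = 1, \dots, d.
\end{aligned}
```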

Model Definition

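using RxInfer, LinearAlgebra

@model function incomplete_data(y, dim)
    # Prior on the precision matrix: Wishart with dim degrees of freedom and identity scale
    Λ ~ Wishart(dim, diagm(ones(dim)))
    # Prior on the mean vector: standard multivariate normal
    m ~ MvNormal(mean=zeros(dim), precision=diagm(ones(dim)))
    for i in 1:size(y, 1)
        # Complete (latent) state of sample i
        x[i] ~ MvNormal(mean=m, precision=Λ)
        for j in 1:dim
            # y[i, j] ≈ x[i][j]: inner product with the j-th basis vector, observed with very high precision
            y[i, j] ~ softdot(x[i], StandardBasisVector(dim, j), huge)
        end
    end
end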

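The softdot node models a Gaussian observation of the inner product of its first two arguments, with the third argument as its precision, so here y[i,j] ~ N(x[i]ᵀe_j, 1/huge). Pairing x[i] with StandardBasisVector(dim, j) makes the inner product equal to the j-th component of x[i]. The fixed precision huge makes the observation noise negligible, so effectively y[i,j] = x[i][j].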
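The extraction step can be checked in plain Julia. This sketch uses a hypothetical `basis_vector` helper that returns an ordinary `Vector{Float64}`. It stands in for RxInfer's `StandardBasisVector`, which is assumed here to behave the same way in a dot product.

```julia
using LinearAlgebra

# Hypothetical helper: the j-th standard basis vector of length d as a plain vector
# (a stand-in for StandardBasisVector(d, j) used in the model).
basis_vector(d, j) = [k == j ? 1.0 : 0.0 for k in 1:d]

x = [13.0, 1.0, 5.0, 4.0, -20.0, 10.0]

# The inner product with e_j picks out the j-th component of x.
dot(x, basis_vector(length(x), 3))  # 5.0

# The same holds for every j.
all(dot(x, basis_vector(length(x), j)) == x[j] for j in eachindex(x))  # true
```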

Data Generation

Let's generate synthetic data with known ground truth parameters:

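n_samples = 100

# Ground-truth parameters
real_m = [13.0, 1.0, 5.0, 4.0, -20.0, 10.0]
dimension = length(real_m)
real_Λ = diagm(ones(dimension))

# Complete latent samples x[i] ~ MvNormal(real_m, inv(real_Λ))
real_x = [rand(MvNormal(real_m, inv(real_Λ))) for _ in 1:n_samples]
# Converting the element type allocates new inner vectors, so real_x is left untouched below
incomplete_x = Vector{Vector{Union{Float64, Missing}}}(copy(real_x))

# Hide one randomly chosen entry in every sample
for i in 1:n_samples
    incomplete_x[i][rand(1:dimension)] = missing
end

# Create a matrix instead of vector of vectors (rows are samples, columns are dimensions)
observations = Matrix{Union{Float64, Missing}}(undef, n_samples, dimension)

for i in 1:n_samples
    for j in 1:dimension
        observations[i, j] = incomplete_x[i][j]
    end
end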

Key insight: Each sample has exactly one missing element, chosen randomly. This creates a challenging scenario where every observation is incomplete, but different dimensions are missing across samples.
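The following is a standalone sketch of the same generation step that depends only on the standard library. It builds the observation matrix directly and then checks the missingness pattern. It assumes real_Λ = I, as above, so sampling from MvNormal(real_m, I) reduces to adding standard normal noise. The fixed seed is an addition for reproducibility. The per-dimension counts show how much information each dimension retains: with one entry hidden per row, each dimension is still observed in roughly (d − 1)/d of the samples.

```julia
using Random

Random.seed!(42)  # fixed seed (our addition) so the pattern is reproducible

n_samples = 100
real_m = [13.0, 1.0, 5.0, 4.0, -20.0, 10.0]
dimension = length(real_m)

# With identity covariance, a draw from MvNormal(real_m, I) is real_m .+ randn(dimension).
obs = Matrix{Union{Float64, Missing}}(undef, n_samples, dimension)
for i in 1:n_samples
    obs[i, :] = real_m .+ randn(dimension)
    obs[i, rand(1:dimension)] = missing  # hide one random entry per sample
end

# Exactly one missing entry per sample...
missing_per_row = vec(count(ismissing, obs; dims=2))
# ...and how often each dimension is actually observed across the dataset.
observed_per_dim = vec(count(!ismissing, obs; dims=1))
```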

Inference Configuration

# Structured factorization: (x, m) keep a joint posterior, while Λ gets its own independent factor.
constraints = @constraints begin
    q(x, m, Λ) = q(x, m)q(Λ)
end

# q(Λ) and q(x, m) are updated in turn, so q(Λ) needs an initial value.
init = @initialization begin
    q(Λ) = Wishart(dimension, diagm(ones(dimension)))
end

result = infer(
    model          = incomplete_data(dim=dimension),
    data           = (y=observations,),
    constraints    = constraints,
    initialization = init,
    iterations     = 100,
    showprogress   = true,
);
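In variational terms, the constraint defines the family 𝒬 over which the variational free energy F is minimized. The formulation below is the standard one, written in this example's notation with y standing for the observed entries only.

```latex
\begin{aligned}
q^{*} &= \arg\min_{q \in \mathcal{Q}} F[q], \\
F[q] &= \mathbb{E}_{q}\!\left[\log q(x, m, \Lambda) - \log p(y, x, m, \Lambda)\right], \\
\mathcal{Q} &= \{\, q : q(x, m, \Lambda) = q(x, m)\, q(\Lambda) \,\}.
\end{aligned}
```

Minimizing F over the two factors alternates between updating q(x, m) given q(Λ) and updating q(Λ) given q(x, m). This is why q(Λ) must be initialized and why several iterations are run.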

Results Analysis

Recovered Parameters

# Extract final posterior estimates (after the last iteration)
# Plug-in covariance estimate: inverse of the posterior mean of Λ
estimated_covariance = inv(mean(result.posteriors[:Λ][end]))
estimated_mean = mean(result.posteriors[:m][end])

println("True mean: ", real_m)
println("Estimated mean: ", estimated_mean)
println()
println("True covariance (diagonal): ", diag(inv(real_Λ)))
println("Estimated covariance (diagonal): ", diag(estimated_covariance))
True mean: [13.0, 1.0, 5.0, 4.0, -20.0, 10.0]
Estimated mean: [8392.43781358009, -20056.09243531769, 17562.677303417844, 9981.161131341998, -14044.404140823552, -26472.09699565004]

True covariance (diagonal): [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
Estimated covariance (diagonal): [7.479314485923238e9, 8.945486281523256e8, 1.3503456216696682e9, 3.2242131352711415e9, 2.1408785703256266e9, 8.370026375480433e8]
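Point estimates alone hide how uncertain the posterior is. A quick sanity check is to form per-dimension intervals from the posterior of m and ask whether they contain the true values. The helpers below are hypothetical and work only on a mean vector and a covariance matrix, so they need nothing beyond LinearAlgebra. The toy numbers stand in for a real posterior.

```julia
using LinearAlgebra

# Hypothetical helper: per-dimension central intervals mean ± z * std for a
# Gaussian with mean vector μ and covariance matrix Σ (z = 1.96 gives ≈ 95%).
function credible_intervals(μ::AbstractVector, Σ::AbstractMatrix; z = 1.96)
    σ = sqrt.(diag(Σ))
    return [(μ[i] - z * σ[i], μ[i] + z * σ[i]) for i in eachindex(μ)]
end

# Hypothetical helper: does each interval contain the corresponding true value?
covers(intervals, truth) = [lo <= t <= hi for ((lo, hi), t) in zip(intervals, truth)]

# Toy numbers standing in for mean(q) and cov(q) of a posterior q.
μ_toy = [0.0, 1.0]
Σ_toy = [4.0 0.0; 0.0 1.0]
ci = credible_intervals(μ_toy, Σ_toy)
covers(ci, [1.0, 5.0])  # [true, false]
```

With RxInfer loaded, `q = result.posteriors[:m][end]` followed by `covers(credible_intervals(mean(q), cov(q)), real_m)` should apply the same check to the fitted model. This assumes the posterior type supports `cov`, which is not verified here.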

Convergence Analysis

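The model is designed to recover both the mean vector and the covariance structure from incomplete observations, with the Bayesian treatment accounting for the extra uncertainty that missing data introduces. However, the output recorded above does not bear this out. The estimated means have magnitudes in the thousands to tens of thousands, and the estimated variances are on the order of 10⁸ to 10⁹. By contrast, the true means range from -20 to 13 and the true variances are all 1. In this run the inference evidently did not produce sensible estimates, so they should be inspected (for example, across iterations) before any conclusions are drawn.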
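A simple way to check convergence is to track how much the posterior mean of m moves from one iteration to the next. The helper below is hypothetical, not an RxInfer utility. `result.posteriors[:m]` holds one posterior per iteration (the code above takes the last one with `[end]`), so the trace can be built with `mean.(result.posteriors[:m])`. Changes that do not shrink toward zero indicate the iterations have not settled.

```julia
using LinearAlgebra

# Hypothetical helper: largest absolute change in the posterior mean between
# consecutive iterations. `trace[k]` is the mean vector after iteration k.
function mean_changes(trace::AbstractVector{<:AbstractVector{<:Real}})
    return [norm(trace[k] - trace[k-1], Inf) for k in 2:length(trace)]
end

# Toy trace that settles down: the changes shrink to zero.
toy_trace = [[1.0, 2.0], [1.5, 2.5], [1.75, 2.75], [1.75, 2.75]]
mean_changes(toy_trace)  # [0.5, 0.25, 0.0]
```

Applied to the fitted model, `mean_changes(mean.(result.posteriors[:m]))` would give one value per iteration after the first.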

Key Takeaways

  1. Missing data as latent variables: RxInfer treats missing observations as latent variables, avoiding the need for explicit imputation.

  2. Principled uncertainty quantification: The posterior distributions capture both parameter uncertainty and uncertainty due to missing data.

  3. Computational efficiency: The structured factorization q(x, m)q(Λ) lets message passing break inference into local updates, an approach designed to scale to larger and higher-dimensional problems.

  4. Robustness: The model is set up to handle data in which every observation has missing elements. This is possible because each dimension is still observed in many other samples, so information is pooled across the dataset.

# Simple plotting code for the RxInfer incomplete data tutorial
using Plots, Distributions

function plot_posterior_distributions(result, real_m, real_Λ, max_dim=3)
    # Final posterior over the mean vector m (real_Λ is accepted but not used)
    final_m_posterior = result.posteriors[:m][end]

    p1 = plot(title="Posterior Distribution of Mean (first $max_dim dimensions)",
              xlabel="Value", ylabel="Density")

    for i in 1:max_dim
        # Marginal of q(m) for dimension i: its variance comes from the posterior
        # covariance of m, not from inv(E[Λ]), which describes the spread of x
        marginal_mean = mean(final_m_posterior)[i]
        marginal_var = cov(final_m_posterior)[i, i]

        # Plot the Gaussian density over ±3 standard deviations
        x_range = range(marginal_mean - 3*sqrt(marginal_var),
                        marginal_mean + 3*sqrt(marginal_var), length=100)
        gaussian = Normal(marginal_mean, sqrt(marginal_var))
        plot!(p1, x_range, pdf.(gaussian, x_range),
              label="Dimension $i", linewidth=2, color=i)

        # Dashed vertical line at the true value, in the same color
        vline!(p1, [real_m[i]], color=i, linestyle=:dash, alpha=0.7,
               linewidth=2, label="")
    end

    return p1
end
plot_posterior_distributions(result, real_m, real_Λ, 6)



Environment

This example was executed in a clean, isolated environment. Below are the exact package versions used:

For reproducibility:

  • Use the same package versions when running locally
  • Report any issues with package compatibility
Status `~/work/RxInferExamples.jl/RxInferExamples.jl/docs/src/categories/basic_examples/incomplete_data/Project.toml`
  [31c24e10] Distributions v0.25.120
  [91a5bcdd] Plots v1.40.14
  [86711068] RxInfer v4.5.0
  [37e2e46d] LinearAlgebra v1.11.0