Bayesian Multinomial Regression

This notebook is an introductory tutorial to Bayesian multinomial regression with RxInfer.

using RxInfer, Plots, StableRNGs, Distributions, ExponentialFamily, StatsPlots
import ExponentialFamily: softmax

Model Description

The key innovation in Linderman et al. (2015) is extending the Pólya-gamma augmentation scheme to the multinomial case. This allows us to transform the non-conjugate multinomial likelihood into a conditionally conjugate form by introducing auxiliary Pólya-gamma random variables.
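For context, the augmentation rests on the Pólya-gamma integral identity of Polson et al. (2013): for $b > 0$,

$\frac{(e^{\psi})^{a}}{(1+e^{\psi})^{b}} = 2^{-b} e^{\kappa \psi} \int_0^\infty e^{-\omega \psi^2 / 2}\, p(\omega \mid b, 0)\, d\omega, \qquad \kappa = a - b/2,$

where $p(\omega \mid b, 0)$ is the Pólya-gamma $\text{PG}(b, 0)$ density. Conditioned on the auxiliary variable $\omega$, the likelihood is Gaussian in $\psi$, which restores conjugacy with a Gaussian prior.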

The multinomial regression model with Pólya-gamma augmentation can be written as: $p(y \mid \psi, N) = \text{Multinomial}(y \mid N, \pi_{SB}(\psi))$

where:

  • $y$ is a $K$-dimensional vector of count data with $N$ total counts
  • $\psi$ is a $(K-1)$-dimensional Gaussian random variable
  • $\pi_{SB}$ is the logistic stick-breaking transform that maps $\psi$ to the $K$-simplex
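Concretely, $\pi_{SB}$ builds the simplex coordinates recursively (this is exactly what the ψ_to_π helper later in this notebook computes):

$\pi_k = \sigma(\psi_k)\left(1 - \sum_{j<k} \pi_j\right) \text{ for } k = 1, \dots, K-1, \qquad \pi_K = 1 - \sum_{j<K} \pi_j,$

where $\sigma(x) = 1/(1+e^{-x})$ is the logistic function.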

Implementation

In this notebook, we implement the Pólya-gamma augmented Bayesian multinomial regression model in RxInfer, performing inference via message passing to estimate the posterior distribution of the regression coefficients.

function generate_multinomial_data(rng=StableRNG(123); N=20, k=9, nsamples=1000)
    Ψ = randn(rng, k)                            # latent scores
    p = softmax(Ψ)                               # event probabilities on the K-simplex
    X = rand(rng, Multinomial(N, p), nsamples)   # k × nsamples matrix of counts
    X = [X[:, i] for i in 1:size(X, 2)]          # split into a vector of count vectors
    return X, Ψ, p
end
generate_multinomial_data (generic function with 2 methods)
nsamples = 5000
N = 30
k = 40
X, Ψ, p = generate_multinomial_data(N=N, k=k, nsamples=nsamples);
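As a quick sanity check (a minimal sketch, using the variables defined above), the generated data should consist of nsamples count vectors, each with k categories summing to N:

@assert length(X) == nsamples        # one count vector per sample
@assert all(sum.(X) .== N)           # each vector has N total counts
@assert all(x -> length(x) == k, X)  # k categories per observation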

The MultinomialPolya factor node is used to model the multinomial likelihood under the Pólya-gamma augmentation.

Because the likelihood and the prior distribution are non-conjugate, we need a more involved inference algorithm. RxInfer provides an Expectation Propagation (EP) [2] algorithm to infer the posterior distribution. Because EP is an approximation, we must specify an inbound message for the regression coefficients when using the MultinomialPolya factor node. This is done through the dependencies keyword argument when the node is created: ReactiveMP.jl provides the RequireMessageFunctionalDependencies type, which specifies the inbound message for the regression coefficients ψ. Refer to the ReactiveMP.jl documentation for more information.

@model function multinomial_model(obs, N, ξ_ψ, W_ψ)
    # Gaussian prior on ψ in weighted-mean/precision (natural) parametrization
    ψ ~ MvNormalWeightedMeanPrecision(ξ_ψ, W_ψ)
    # RequireMessageFunctionalDependencies supplies the initial inbound message for ψ
    obs .~ MultinomialPolya(N, ψ) where {dependencies = RequireMessageFunctionalDependencies(ψ = MvNormalWeightedMeanPrecision(ξ_ψ, W_ψ))}
end
result = infer(
    model = multinomial_model(ξ_ψ=zeros(k-1), W_ψ=rand(Wishart(3, diageye(k-1))), N=N),
    data = (obs=X, ),
    iterations = 50,
    free_energy = true,
    showprogress = true,
    options = (
        limit_stack_depth = 100,
    )
)
Inference results:
  Posteriors       | available for (ψ)
  Free Energy:     | Real[446752.0, 2.93105e5, 2.38735e5, 2.13138e5, 1.99154e5, 1.90762e5, 1.85387e5, 1.81773e5, 1.79249e5, 1.77433e5  …  1.71044e5, 1.71037e5, 1.71031e5, 1.71026e5, 171021.0, 1.71017e5, 1.71013e5, 1.7101e5, 1.71007e5, 1.71004e5]
plot(result.free_energy,
     title="Free Energy Over Iterations",
     xlabel="Iteration",
     ylabel="Free Energy",
     linewidth=2,
     legend=false,
     grid=true)

To validate the fit, we compute the posterior predictive message for the counts with the @call_rule macro and compare its probabilities against the true data-generating probabilities:

predictive = @call_rule MultinomialPolya(:x, Marginalisation) (q_N = PointMass(N), q_ψ = result.posteriors[:ψ][end], meta = MultinomialPolyaMeta(21))
println("Estimated data generation probabilities: $(predictive.p)")
println("True data generation probabilities: $(p)")
Estimated data generation probabilities: [0.012601220188727679, 0.027489195868565644, 0.004566841738928611, 0.013076116285030493, 0.013880420079730175, 0.036795349165354006, 0.007056952070568315, 0.007124677330215452, 0.005927849573067121, 0.004353707576408949, 0.005670178740663258, 0.0039692833707634795, 0.004287069918550092, 0.03679235925983153, 0.10821015909607469, 0.07256355459309397, 0.026569437362601162, 0.023851402189567994, 0.011002922644901037, 0.00877082499370087, 0.03914669689477235, 0.005395740113056405, 0.008646922691700237, 0.02609601761956201, 0.006609443200408948, 0.008283145927923162, 0.009144759121316539, 0.007243580137600973, 0.01738580203883556, 0.00722616821251049, 0.009385422432723894, 0.003454764186611513, 0.012108819458038723, 0.010861958381053185, 0.09473610866461152, 0.0433383032117478, 0.1310451471481578, 0.026997650771561903, 0.030261130317039378, 0.06807289742442318]
True data generation probabilities: [0.012475572764691347, 0.02759115956301153, 0.004030932560100506, 0.013008651265311708, 0.012888510278451618, 0.037656116813111006, 0.007242363105598982, 0.006930069564505769, 0.00538389836228327, 0.0036198124274772225, 0.005212387391120808, 0.003185556887255863, 0.003820168769118259, 0.036849638787622915, 0.109428569898501, 0.07260753875224316, 0.026079268674281158, 0.024477855252934583, 0.010207778995219957, 0.008532295265944583, 0.040242532118754906, 0.005181587450423221, 0.008207391370854009, 0.02741148713822125, 0.006623087410725917, 0.008367702714634162, 0.009668643362989908, 0.007171783607096945, 0.016985615150215773, 0.007080691453323701, 0.008297044496975403, 0.0037359000700039487, 0.011142755810390478, 0.010256554277897088, 0.09528238587772694, 0.04369806970660494, 0.13308101804159636, 0.02665693577960761, 0.030479170124456504, 0.06920149865871575]
mse = mean((predictive.p - p).^2);
println("MSE between estimated and true data generation probabilities: $mse")
MSE between estimated and true data generation probabilities: 4.753719933551976e-7
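Beyond the MSE, a quick visual comparison can help; a minimal sketch using the Plots.jl bar recipe (with the variables defined above):

bar(1:k, p, alpha=0.5, label="true p", xlabel="category", ylabel="probability",
    title="Estimated vs. true probabilities")
bar!(1:k, predictive.p, alpha=0.5, label="estimated p")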
Next, we extend the model to full multinomial regression, where the stick-breaking coordinates Ψ[i] depend on the covariates X[i] through a feature map ϕ and the regression coefficients β:

@model function multinomial_regression(obs, N, X, ϕ, ξβ, Wβ)
    β ~ MvNormalWeightedMeanPrecision(ξβ, Wβ)
    for i in eachindex(obs)
        Ψ[i] := ϕ(X[i])*β
        obs[i] ~ MultinomialPolya(N, Ψ[i]) where {dependencies = RequireMessageFunctionalDependencies(ψ = MvNormalWeightedMeanPrecision(zeros(length(obs[i])-1), diageye(length(obs[i])-1)))}
    end
end
function generate_regression_data(rng=StableRNG(123); ϕ=identity, N=3, k=5, nsamples=1000)
    β = randn(rng, k)                       # true regression coefficients
    X = randn(rng, nsamples, k, k)          # covariate matrices
    X = [X[i, :, :] for i in 1:size(X, 1)]  # one k×k matrix per observation
    Ψ = ϕ.(X)                               # apply the feature map to each covariate matrix
    p = map(x -> logistic_stick_breaking(x * β), Ψ)
    return map(x -> rand(rng, Multinomial(N, x)), p), X, β, p
end
generate_regression_data (generic function with 2 methods)
ϕ = x -> sin(x)
obs_regression, X_regression, β_regression, p_regression = generate_regression_data(;nsamples = 5000, ϕ = ϕ);
reg_results = infer(  
    model = multinomial_regression(N = 3, ϕ = ϕ, ξβ = zeros(5), Wβ = rand(Wishart(5, diageye(5)))),
    data = (obs=obs_regression,X = X_regression ),
    iterations = 20,
    free_energy = true,
    showprogress = true,
    returnvars = KeepLast(),
    options = (
        limit_stack_depth = 100,
    ) 
)
Inference results:
  Posteriors       | available for (Ψ, β)
  Free Energy:     | Real[11950.2, 11584.2, 11501.8, 11480.8, 11475.2, 11473.6, 11473.2, 11473.0, 11473.0, 11473.0, 11473.0, 11473.0, 11473.0, 11473.0, 11473.0, 11473.0, 11473.0, 11473.0, 11473.0, 11473.0]
println("estimated β: with mean and covariance: $(mean_cov(reg_results.posteriors[:β]))")
println("true β: $(β_regression)")
estimated β: with mean and covariance: ([-0.11524613426609875, 0.6627909161
82894, -1.253992306308758, -0.08494945053643042, -0.07987663990112007], [0.
0001480080581601405 -2.128236825580651e-6 3.705231919246632e-6 -1.605600469
615616e-6 3.203248802444071e-6; -2.128236825580651e-6 0.0001516518154129740
6 -1.92911093804846e-5 -2.0138681363160548e-7 1.3030205609665011e-6; 3.7052
31919246632e-6 -1.92911093804846e-5 0.00017971570179046073 4.41736124107205
85e-6 5.043199211790552e-7; -1.605600469615616e-6 -2.0138681363160548e-7 4.
4173612410720585e-6 0.00014010055212099626 3.2078570809395344e-6; 3.2032488
02444071e-6 1.3030205609665011e-6 5.043199211790552e-7 3.2078570809395344e-
6 0.00013952833424015363])
true β: [-0.12683768965424458, 0.6668851724871252, -1.2566124895590247, -0.
08499562516549662, -0.094274004848194]
plot(reg_results.free_energy,
     title="Free Energy Over Iterations",
     xlabel="Iteration",
     ylabel="Free Energy",
     linewidth=2,
     legend=false,
     grid=true)

mse_β = mean((mean(reg_results.posteriors[:β]) - β_regression).^2)
println("MSE of β estimate: $mse_β")
MSE of β estimate: 7.305574015635415e-5
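For a per-coefficient view, one can also compare the posterior mean against the true coefficients; a minimal sketch, assuming LinearAlgebra is loaded for diag:

using LinearAlgebra: diag

m_β, V_β = mean_cov(reg_results.posteriors[:β])
scatter(1:length(m_β), m_β, yerror=2 .* sqrt.(diag(V_β)),
        label="posterior mean ± 2σ", xlabel="coefficient index", ylabel="value")
scatter!(1:length(β_regression), β_regression, label="true β")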

Since the logistic stick-breaking transformation is invertible, we can visualize how a Gaussian prior on the stick-breaking coordinates ψ maps to a distribution over the probability simplex, and vice versa.


# Helper functions: the logistic map, its inverse, and the stick-breaking transforms
σ(x) = 1 / (1 + exp(-x))
σ_inv(x) = log(x / (1 - x))

function jacobian_det(π)
    K = length(π)
    det = 1.0
    for k in 1:(K-1)
        num = 1 - sum(π[1:(k-1)])
        den = π[k] * (1 - sum(π[1:k]))
        det *= num / den
    end
    return det
end
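# Note: this Jacobian enters the standard change-of-variables formula implemented
# by compute_simplex_density below: p(π) = N(ψ(π) | 0, Σ) · |det ∂ψ/∂π|.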

function ψ_to_π(ψ::Vector{Float64})
    K = length(ψ) + 1
    π = zeros(K)
    for k in 1:(K-1)
        π[k] = σ(ψ[k]) * (1 - sum(π[1:(k-1)]))
    end
    π[K] = 1 - sum(π[1:(K-1)])
    return π
end

function π_to_ψ(π)
    K = length(π)
    ψ = zeros(K-1)
    ψ[1] = σ_inv(π[1])
    for k in 2:(K-1)
        ψ[k] = σ_inv(π[k] / (1 - sum(π[1:(k-1)])))
    end
    return ψ
end
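
# Quick round-trip check (hypothetical values): the two maps invert each other.
ψ_demo = [0.3, -1.2]
π_demo = ψ_to_π(ψ_demo)           # point on the 3-simplex
@assert sum(π_demo) ≈ 1.0         # valid probability vector
@assert π_to_ψ(π_demo) ≈ ψ_demo   # inverse recovers ψ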

# Function to compute density in simplex coordinates
function compute_simplex_density(x::Float64, y::Float64, Σ::Matrix{Float64})
    # Check if point is inside triangle
    if y < 0 || y > 1 || x < 0 || x > 1 || (x + y) > 1
        return 0.0
    end
    
    # Convert from simplex coordinates to π
    π1 = x
    π2 = y
    π3 = 1 - x - y
    
    try
        ψ = π_to_ψ([π1, π2, π3])
        # Compute Gaussian density
        dist = MvNormal(zeros(2), Σ)
        return pdf(dist, ψ) * abs(jacobian_det([π1, π2, π3]))
    catch
        return 0.0
    end
   
end

function plot_transformed_densities()
    # Create three different covariance matrices
    # For higher variance values, the densities need rescaling for proper visualization.
    σ² = 1.0
    Σ_corr = [σ² 0.9σ²; 0.9σ² σ²]
    Σ_anticorr = [σ² -0.9σ²; -0.9σ² σ²]
    Σ_uncorr = [σ² 0.0; 0.0 σ²]
    
    # Plot Gaussian densities
    # Evaluation grid for the Gaussian prior contours
    ψ1 = range(-4sqrt(σ²), 4sqrt(σ²), length=100)
    ψ2 = range(-4sqrt(σ²), 4sqrt(σ²), length=100)
    
    p1 = contour(ψ1, ψ2, (x,y) -> pdf(MvNormal(zeros(2), Σ_corr), [x,y]),
                 title="Correlated Prior", xlabel="ψ₁", ylabel="ψ₂")
    p2 = contour(ψ1, ψ2, (x,y) -> pdf(MvNormal(zeros(2), Σ_anticorr), [x,y]),
                 title="Anti-correlated Prior", xlabel="ψ₁", ylabel="ψ₂")
    p3 = contour(ψ1, ψ2, (x,y) -> pdf(MvNormal(zeros(2), Σ_uncorr), [x,y]),
                 title="Uncorrelated Prior", xlabel="ψ₁", ylabel="ψ₂")
    
    # Plot simplex densities
    n_points = 500
    x = range(0, 1, length=n_points)
    y = range(0, 1, length=n_points)
    
    # Plot simplices
    p4 = contour(x, y, (x,y) -> compute_simplex_density(x, y, Σ_corr),
                 title="Correlated Simplex")
    
    # Add simplex (triangle) boundaries
    plot!(p4, [0,1,0,0], [0,0,1,0], color=:black, label="")
    
    p5 = contour(x, y, (x,y) -> compute_simplex_density(x, y, Σ_anticorr),
                 title="Anti-correlated Simplex")
    plot!(p5, [0,1,0,0], [0,0,1,0], color=:black, label="")
    
    p6 = contour(x, y, (x,y) -> compute_simplex_density(x, y, Σ_uncorr),
                 title="Uncorrelated Simplex")
    plot!(p6, [0,1,0,0], [0,0,1,0], color=:black, label="")
    
    # Combine all plots
    plot(p1, p2, p3, p4, p5, p6, layout=(2,3), size=(900,600))
end

# Generate the plots
plot_transformed_densities()


Contributing

This example was automatically generated from a Jupyter notebook in the RxInferExamples.jl repository.

We welcome and encourage contributions! You can help by:

  • Improving this example
  • Creating new examples
  • Reporting issues or bugs
  • Suggesting enhancements

Visit our GitHub repository to get started. Together we can make RxInfer.jl even better! 💪


Environment

This example was executed in a clean, isolated environment. Below are the exact package versions used:

For reproducibility:

  • Use the same package versions when running locally
  • Report any issues with package compatibility
Status `~/work/RxInferExamples.jl/RxInferExamples.jl/docs/src/categories/basic_examples/bayesian_multinomial_regression/Project.toml`
  [31c24e10] Distributions v0.25.120
  [62312e5e] ExponentialFamily v2.0.5
  [91a5bcdd] Plots v1.40.14
  [86711068] RxInfer v4.5.0
  [860ef19b] StableRNGs v1.0.3
  [f3b207a7] StatsPlots v0.15.7