This example was automatically generated from a Jupyter notebook in the RxInferExamples.jl repository.
We welcome and encourage contributions! You can help by:
- Improving this example
- Creating new examples
- Reporting issues or bugs
- Suggesting enhancements
Visit our GitHub repository to get started. Together we can make RxInfer.jl even better! 💪
Bayesian Multinomial Regression
This notebook is an introductory tutorial to Bayesian multinomial regression with RxInfer.
using RxInfer, Plots, StableRNGs, Distributions, ExponentialFamily, StatsPlots
import ExponentialFamily: softmax
Model Description
The key innovation in Linderman et al. (2015) is extending the Pólya-gamma augmentation scheme to the multinomial case. This allows us to transform the non-conjugate multinomial likelihood into a conditionally conjugate form by introducing auxiliary Pólya-gamma random variables.
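While the notebook does not derive the augmentation itself, it rests on the Pólya-gamma integral identity of Polson, Scott & Windle (2013): for $b > 0$,

$$\frac{(e^{\psi})^{a}}{(1+e^{\psi})^{b}} = 2^{-b}\, e^{\kappa \psi} \int_{0}^{\infty} e^{-\omega \psi^{2}/2}\, p_{\mathrm{PG}}(\omega \mid b, 0)\, d\omega, \qquad \kappa = a - b/2,$$

where $p_{\mathrm{PG}}(\omega \mid b, 0)$ denotes a Pólya-gamma density. Conditioned on the auxiliary variable $\omega$, the likelihood becomes Gaussian in $\psi$; Linderman et al. apply this identity recursively along a stick-breaking decomposition of the multinomial.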
The multinomial regression model with Pólya-gamma augmentation can be written as: $p(y \mid \psi, N) = \text{Multinomial}(y \mid N, \pi_{\text{SB}}(\psi))$
where:
- $y$ is a $K$-dimensional vector of count data with $N$ total counts
- $\psi$ is a $(K-1)$-dimensional Gaussian random variable
- $\pi_{\text{SB}}$ is the logistic stick-breaking map from $\mathbb{R}^{K-1}$ to the probability simplex
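Concretely, the stick-breaking map used here (it matches the `ψ_to_π` helper defined at the end of this notebook) is

$$\pi_k = \sigma(\psi_k)\Big(1 - \sum_{j<k} \pi_j\Big), \quad k = 1, \dots, K-1, \qquad \pi_K = 1 - \sum_{k=1}^{K-1} \pi_k,$$

where $\sigma(x) = 1/(1+e^{-x})$ is the logistic function.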
Implementation
In this notebook, we will implement the Pólya-gamma augmented Bayesian multinomial regression model with RxInfer, performing inference via message passing to estimate the posterior distribution of the regression coefficients.
function generate_multinomial_data(rng=StableRNG(123); N=20, k=9, nsamples=1000)
    # Draw a random natural parameter vector and map it to the simplex
    Ψ = randn(rng, k)
    p = softmax(Ψ)
    # Sample `nsamples` multinomial observations with N trials each
    X = rand(rng, Multinomial(N, p), nsamples)
    X = [X[:, i] for i in 1:size(X, 2)]
    return X, Ψ, p
end
generate_multinomial_data (generic function with 2 methods)
nsamples = 5000
N = 30
k = 40
X, Ψ, p = generate_multinomial_data(N=N,k=k,nsamples=nsamples);
The `MultinomialPolya` factor node is used to model the likelihood of the multinomial distribution. Due to the non-conjugacy of the likelihood and the prior distribution, we need a more involved inference algorithm. RxInfer provides an Expectation Propagation (EP) [2] algorithm to infer the posterior distribution. Because of EP's approximate nature, we need to specify an inbound message for the regression coefficients when using the `MultinomialPolya` factor node. This is done through the `dependencies` keyword argument during the creation of the `MultinomialPolya` factor node. `ReactiveMP.jl` provides a `RequireMessageFunctionalDependencies` type that is used to specify the inbound message for the regression coefficients `ψ`. Refer to the ReactiveMP.jl documentation for more information.
@model function multinomial_model(obs, N, ξ_ψ, W_ψ)
    ψ ~ MvNormalWeightedMeanPrecision(ξ_ψ, W_ψ)
    obs .~ MultinomialPolya(N, ψ) where {dependencies = RequireMessageFunctionalDependencies(ψ = MvNormalWeightedMeanPrecision(ξ_ψ, W_ψ))}
end
result = infer(
    model = multinomial_model(ξ_ψ = zeros(k-1), W_ψ = rand(Wishart(3, diageye(k-1))), N = N),
    data = (obs = X,),
    iterations = 50,
    free_energy = true,
    showprogress = true,
    options = (
        limit_stack_depth = 100,
    )
)
Inference results:
  Posteriors   | available for (ψ)
  Free Energy: | Real[4.46421e5, 2.92587e5, 2.38082e5, 2.12385e5, 1.98325e5, 1.89875e5, 1.84456e5, 1.80807e5, 1.78257e5, 1.76422e5 … 1.69971e5, 1.69964e5, 1.69958e5, 1.69953e5, 1.69948e5, 1.69944e5, 1.6994e5, 1.69937e5, 1.69934e5, 1.69932e5]
plot(result.free_energy,
    title = "Free Energy Over Iterations",
    xlabel = "Iteration",
    ylabel = "Free Energy",
    linewidth = 2,
    legend = false,
    grid = true,
)
# Compute the posterior predictive message for the observations using the final posterior over ψ
predictive = @call_rule MultinomialPolya(:x, Marginalisation) (q_N = PointMass(N), q_ψ = result.posteriors[:ψ][end], meta = MultinomialPolyaMeta(21))
println("Estimated data generation probabilities: $(predictive.p)")
println("True data generation probabilities: $(p)")
Estimated data generation probabilities: [0.011800058483121777, 0.027578125207094996, 0.004868891497346805, 0.012695477357583574, 0.013339716855034267, 0.0376068412960776, 0.007469757024150875, 0.007176485081071076, 0.005704696368754302, 0.004245544734840519, 0.005559238460360947, 0.0037630434611126483, 0.00440609827282599, 0.03637530997105492, 0.10874935340831442, 0.07222726920915121, 0.02650227162061791, 0.02427652737526112, 0.010340075432247315, 0.008851345963710448, 0.040005519660497015, 0.005221568303640217, 0.00788923808542629, 0.026330673560671473, 0.006484065593229215, 0.008413687465516165, 0.008961804955850874, 0.007178954561542986, 0.0168814473660241, 0.007356538414798726, 0.008918186675458385, 0.00414041737494085, 0.011177427359907116, 0.010680500503165274, 0.09499521693786041, 0.04352011582283553, 0.13269456980417646, 0.027133680300169945, 0.03051998490202622, 0.06796027527252999]
True data generation probabilities: [0.012475572764691347, 0.02759115956301153, 0.004030932560100506, 0.013008651265311708, 0.012888510278451618, 0.037656116813111006, 0.007242363105598982, 0.006930069564505769, 0.00538389836228327, 0.0036198124274772225, 0.005212387391120808, 0.003185556887255863, 0.003820168769118259, 0.036849638787622915, 0.109428569898501, 0.07260753875224316, 0.026079268674281158, 0.024477855252934583, 0.010207778995219957, 0.008532295265944583, 0.040242532118754906, 0.005181587450423221, 0.008207391370854009, 0.02741148713822125, 0.006623087410725917, 0.008367702714634162, 0.009668643362989908, 0.007171783607096945, 0.016985615150215773, 0.007080691453323701, 0.008297044496975403, 0.0037359000700039487, 0.011142755810390478, 0.010256554277897088, 0.09528238587772694, 0.04369806970660494, 0.13308101804159636, 0.02665693577960761, 0.030479170124456504, 0.06920149865871575]
mse = mean((predictive.p - p).^2);
println("MSE between estimated and true data generation probabilities: $mse")
MSE between estimated and true data generation probabilities: 2.183773095307924e-7
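Beyond a single MSE number, we can overlay the two probability vectors. This small visualization is an addition to the original notebook; it uses only `Plots`, which is already loaded:

# Overlay estimated and true category probabilities for a visual check
scatter(1:k, p, label = "true p", xlabel = "category", ylabel = "probability")
scatter!(1:k, predictive.p, label = "estimated p", markershape = :xcross)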
@model function multinomial_regression(obs, N, X, ϕ, ξβ, Wβ)
    β ~ MvNormalWeightedMeanPrecision(ξβ, Wβ)
    for i in eachindex(obs)
        # Linear predictor through the (possibly nonlinear) feature map ϕ
        Ψ[i] := ϕ(X[i]) * β
        obs[i] ~ MultinomialPolya(N, Ψ[i]) where {dependencies = RequireMessageFunctionalDependencies(ψ = MvNormalWeightedMeanPrecision(zeros(length(obs[i]) - 1), diageye(length(obs[i]) - 1)))}
    end
end
function generate_regression_data(rng=StableRNG(123); ϕ=identity, N=3, k=5, nsamples=1000)
    β = randn(rng, k)
    # One k×k design matrix per sample
    X = randn(rng, nsamples, k, k)
    X = [X[i, :, :] for i in 1:size(X, 1)]
    Ψ = ϕ.(X)
    # Map each linear predictor to the simplex via logistic stick-breaking
    p = map(x -> logistic_stick_breaking(x * β), Ψ)
    return map(x -> rand(rng, Multinomial(N, x)), p), X, β, p
end
generate_regression_data (generic function with 2 methods)
ϕ = x -> sin(x)
obs_regression, X_regression, β_regression, p_regression = generate_regression_data(;nsamples = 5000, ϕ = ϕ);
reg_results = infer(
    model = multinomial_regression(N = 3, ϕ = ϕ, ξβ = zeros(5), Wβ = rand(Wishart(5, diageye(5)))),
    data = (obs = obs_regression, X = X_regression),
    iterations = 20,
    free_energy = true,
    showprogress = true,
    returnvars = KeepLast(),
    options = (
        limit_stack_depth = 100,
    )
)
Inference results:
  Posteriors   | available for (Ψ, β)
  Free Energy: | Real[11952.5, 11586.4, 11504.0, 11483.0, 11477.4, 11475.8, 11475.4, 11475.3, 11475.2, 11475.2, 11475.2, 11475.2, 11475.2, 11475.2, 11475.2, 11475.2, 11475.2, 11475.2, 11475.2, 11475.2]
println("estimated β: with mean and covariance: $(mean_cov(reg_results.posteriors[:β]))")
println("true β: $(β_regression)")
estimated β with mean and covariance: ([-0.1144726250363274, 0.6633336714758509, -1.2537729584592794, -0.08556521598971065, -0.07931525266376563], [0.00014804822642633966 -2.14139429120294e-6 3.576070953047402e-6 -1.6066859239393382e-6 3.216698566433965e-6; -2.14139429120294e-6 0.0001517550940392186 -1.93148315212323e-5 -2.461981971234855e-7 1.3075471224134247e-6; 3.576070953047402e-6 -1.93148315212323e-5 0.0001796836720604156 4.4818667755957866e-6 4.226588892163872e-7; -1.6066859239393382e-6 -2.461981971234855e-7 4.4818667755957866e-6 0.00014015901034652314 3.228063638353486e-6; 3.216698566433965e-6 1.3075471224134247e-6 4.226588892163872e-7 3.228063638353486e-6 0.00013948602993677889])
true β: [-0.12683768965424458, 0.6668851724871252, -1.2566124895590247, -0.08499562516549662, -0.094274004848194]
plot(reg_results.free_energy,
    title = "Free Energy Over Iterations",
    xlabel = "Iteration",
    ylabel = "Free Energy",
    linewidth = 2,
    legend = false,
    grid = true,
)
mse_β = mean((mean(reg_results.posteriors[:β]) - β_regression).^2)
println("MSE of β estimate: $mse_β")
MSE of β estimate: 7.953192398538803e-5
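We can also sketch the posterior mean of β with one-standard-deviation error bars against the true coefficients. This snippet is an illustrative addition to the notebook; it assumes the standard-library `LinearAlgebra` for `diag`:

using LinearAlgebra: diag

# Posterior mean and covariance of β, then mean ± std against the true values
m_β, V_β = mean_cov(reg_results.posteriors[:β])
scatter(1:length(m_β), m_β, yerror = sqrt.(diag(V_β)), label = "posterior mean ± std")
scatter!(1:length(m_β), β_regression, label = "true β", xlabel = "coefficient index", ylabel = "value")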
Since the logistic stick-breaking transformation is invertible, we can visualize how it maps a Gaussian prior on the stick-breaking coordinates $\psi$ to a density on the simplex, and vice versa.
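For reference, the density on the simplex follows from the change of variables $p_\pi(\pi) = p_\psi\big(\psi(\pi)\big)\left|\det \partial\psi/\partial\pi\right|$. Since $\psi_k$ depends only on $\pi_1, \dots, \pi_k$, the Jacobian is triangular and its determinant reduces to a product of diagonal terms, which is exactly what the `jacobian_det` helper below computes:

$$\left|\det \frac{\partial \psi}{\partial \pi}\right| = \prod_{k=1}^{K-1} \frac{1 - \sum_{j<k} \pi_j}{\pi_k \left(1 - \sum_{j \le k} \pi_j\right)}.$$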
# Helper functions for the logistic stick-breaking transformation
σ(x) = 1 / (1 + exp(-x))
σ_inv(x) = log(x / (1 - x))

function jacobian_det(π)
    # |det ∂ψ/∂π|: the Jacobian of π → ψ is triangular, so the determinant
    # is the product of its diagonal entries
    K = length(π)
    det = 1.0
    for k in 1:(K-1)
        num = 1 - sum(π[1:(k-1)])
        den = π[k] * (1 - sum(π[1:k]))
        det *= num / den
    end
    return det
end
function ψ_to_π(ψ::Vector{Float64})
    # Forward stick-breaking map: ψ ∈ ℝ^{K-1} → π on the K-simplex
    K = length(ψ) + 1
    π = zeros(K)
    for k in 1:(K-1)
        π[k] = σ(ψ[k]) * (1 - sum(π[1:(k-1)]))
    end
    π[K] = 1 - sum(π[1:(K-1)])
    return π
end

function π_to_ψ(π)
    # Inverse map: recover the stick-breaking coordinates from a simplex point π
    K = length(π)
    ψ = zeros(K-1)
    ψ[1] = σ_inv(π[1])
    for k in 2:(K-1)
        ψ[k] = σ_inv(π[k] / (1 - sum(π[1:(k-1)])))
    end
    return ψ
end
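As a quick sanity check (an illustrative addition, not part of the original notebook), the two maps should be exact inverses of each other:

# Round-trip check: ψ → π → ψ should recover the input up to floating point
ψ_test = randn(StableRNG(1), 3)
@assert π_to_ψ(ψ_to_π(ψ_test)) ≈ ψ_test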
# Function to compute density in simplex coordinates
function compute_simplex_density(x::Float64, y::Float64, Σ::Matrix{Float64})
    # Check if point is inside triangle
    if y < 0 || y > 1 || x < 0 || x > 1 || (x + y) > 1
        return 0.0
    end
    # Convert from simplex coordinates to π
    π1 = x
    π2 = y
    π3 = 1 - x - y
    try
        ψ = π_to_ψ([π1, π2, π3])
        # Compute Gaussian density times the change-of-variables factor
        dist = MvNormal(zeros(2), Σ)
        return pdf(dist, ψ) * abs(jacobian_det([π1, π2, π3]))
    catch
        return 0.0
    end
end
function plot_transformed_densities()
    # Create three different covariance matrices
    # (for larger variances the values need rescaling for proper visualization)
    σ² = 1.0
    Σ_corr = [σ² 0.9σ²; 0.9σ² σ²]
    Σ_anticorr = [σ² -0.9σ²; -0.9σ² σ²]
    Σ_uncorr = [σ² 0.0; 0.0 σ²]

    # Plot Gaussian densities
    ψ1, ψ2 = range(-4sqrt(σ²), 4sqrt(σ²), length=500), range(-4sqrt(σ²), 4sqrt(σ²), length=100)
    p1 = contour(ψ1, ψ2, (x, y) -> pdf(MvNormal(zeros(2), Σ_corr), [x, y]),
        title="Correlated Prior", xlabel="ψ₁", ylabel="ψ₂")
    p2 = contour(ψ1, ψ2, (x, y) -> pdf(MvNormal(zeros(2), Σ_anticorr), [x, y]),
        title="Anti-correlated Prior", xlabel="ψ₁", ylabel="ψ₂")
    p3 = contour(ψ1, ψ2, (x, y) -> pdf(MvNormal(zeros(2), Σ_uncorr), [x, y]),
        title="Uncorrelated Prior", xlabel="ψ₁", ylabel="ψ₂")

    # Plot simplex densities
    n_points = 500
    x = range(0, 1, length=n_points)
    y = range(0, 1, length=n_points)
    p4 = contour(x, y, (x, y) -> compute_simplex_density(x, y, Σ_corr),
        title="Correlated Simplex")
    # Add simplex boundaries
    plot!(p4, [0, 1, 0, 0], [0, 0, 1, 0], color=:black, label="")
    p5 = contour(x, y, (x, y) -> compute_simplex_density(x, y, Σ_anticorr),
        title="Anti-correlated Simplex")
    plot!(p5, [0, 1, 0, 0], [0, 0, 1, 0], color=:black, label="")
    p6 = contour(x, y, (x, y) -> compute_simplex_density(x, y, Σ_uncorr),
        title="Uncorrelated Simplex")
    plot!(p6, [0, 1, 0, 0], [0, 0, 1, 0], color=:black, label="")

    # Combine all plots
    plot(p1, p2, p3, p4, p5, p6, layout=(2, 3), size=(900, 600))
end
# Generate the plots
plot_transformed_densities()
This example was executed in a clean, isolated environment. Below are the exact package versions used:
For reproducibility:
- Use the same package versions when running locally
- Report any issues with package compatibility
Status `~/work/RxInferExamples.jl/RxInferExamples.jl/docs/src/categories/basic_examples/bayesian_multinomial_regression/Project.toml`
[31c24e10] Distributions v0.25.120
[62312e5e] ExponentialFamily v2.0.7
[91a5bcdd] Plots v1.40.17
[86711068] RxInfer v4.5.0
[860ef19b] StableRNGs v1.0.3
[f3b207a7] StatsPlots v0.15.7