This example was automatically generated from a Jupyter notebook in the RxInferExamples.jl repository.
We welcome and encourage contributions! You can help by:
- Improving this example
- Creating new examples
- Reporting issues or bugs
- Suggesting enhancements
Visit our GitHub repository to get started. Together we can make RxInfer.jl even better! 💪
Gaussian Mixture
This notebook illustrates how to use the NormalMixture node in RxInfer.jl for both univariate and multivariate observations.
Load packages
using RxInfer, Plots, Random, LinearAlgebra, StableRNGs, LaTeXStrings
Univariate Gaussian Mixture Model
Consider the data set of length $N$ observed below.
function generate_univariate_data(nr_samples; rng = MersenneTwister(123))
    # data generating parameters
    class = [1/3, 2/3]
    mean1, mean2 = -10, 10
    precision = 1.777
    # generate data: draw a component label, then sample from that component
    z = rand(rng, Categorical(class), nr_samples)
    y = zeros(nr_samples)
    for k in 1:nr_samples
        # Normal takes a standard deviation, so convert the precision
        y[k] = rand(rng, Normal(z[k] == 1 ? mean1 : mean2, 1/sqrt(precision)))
    end
    return y
end
data_univariate = generate_univariate_data(100)
histogram(data_univariate, bins=50, label="data", normed=true)
xlims!(minimum(data_univariate), maximum(data_univariate))
ylims!(0, Inf)
ylabel!("probability density")
Model specification
The goal here is to create a model for the data set above. In this case a Gaussian mixture model with $K$ components seems to suit the situation well. We specify the factorized model as
$p(y, z, s, m, w) = \prod_{n=1}^N \bigg(p(y_n \mid m, w, z_n) p(z_n \mid s) \bigg)\prod_{k=1}^K \bigg(p(m_k) p(w_k) \bigg) p(s),$
where the individual terms are specified as
$\begin{aligned} p(s) &= \mathrm{Beta}(s \mid \alpha_s, \beta_s) \\ p(m_{k}) &= \mathcal{N}(m_k \mid \mu_k, \sigma_k^2) \\ p(w_{k}) &= \Gamma(w_k \mid \alpha_k, \beta_k) \\ p(z_n \mid s) &= \mathrm{Ber}(z_n \mid s) \\ p(y_n \mid m, w, z_n) &= \prod_{k=1}^K \mathcal{N}\left(y_n \mid m_{k}, w_{k}^{-1}\right)^{z_{nk}} \end{aligned}$
The set of observations $y = \{y_1, y_2, \ldots, y_N\}$ is modeled by a mixture of Gaussian distributions, parameterized by means $m = \{m_1, m_2, \ldots, m_K\}$ and precisions $w = \{ w_1, w_2, \ldots, w_K\}$, where $k$ denotes the component index. The active component is selected per observation by the indicator variable $z_n$, which is a one-of-$K$ encoded vector satisfying $\sum_{k=1}^K z_{nk} = 1$ and $z_{nk} \in \{0, 1\} \forall k$. We put a hyperprior $s$ on these indicators, which represents the relative occurrence of the different realizations of $z_n$. For the two-component model used here ($K = 2$), $z_n$ reduces to a single binary variable, which is why it is modeled with a Bernoulli distribution and $s$ with a Beta prior.
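The role of the one-of-$K$ indicator in the observation model can be sketched in a few lines of plain Julia. This is only an illustration of the formula above, not how the NormalMixture node is implemented; the helper names `normal_logpdf_precision` and `selected_loglikelihood` and the indicator value are hypothetical.

```julia
# Log-density of a univariate Gaussian parameterized by mean `m` and precision `w`.
normal_logpdf_precision(y, m, w) = 0.5 * (log(w) - log(2π) - w * (y - m)^2)

# Log-likelihood of one observation given a one-of-K indicator vector `z`:
# log p(y | m, w, z) = sum_k z[k] * log N(y | m[k], 1 / w[k]).
function selected_loglikelihood(y, m, w, z)
    @assert sum(z) == 1 "z must be one-of-K encoded"
    return sum(z[k] * normal_logpdf_precision(y, m[k], w[k]) for k in eachindex(z))
end

m = [-10.0, 10.0]     # component means used by the data generator above
w = [1.777, 1.777]    # component precisions used by the data generator above
z = [0, 1]            # hypothetical indicator: the second component is active

ll = selected_loglikelihood(9.5, m, w, z)
```

Because exactly one entry of `z` equals one, the sum keeps the log-density of the selected component and drops all others.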
We implement this model with uninformative values for the hyperparameters as follows:
@model function univariate_gaussian_mixture_model(y)
    # prior on the mixing coefficient
    s ~ Beta(1.0, 1.0)
    # priors on the mean and precision of each component
    m[1] ~ Normal(mean = -2.0, variance = 1e3)
    w[1] ~ Gamma(shape = 0.01, rate = 0.01)
    m[2] ~ Normal(mean = 2.0, variance = 1e3)
    w[2] ~ Gamma(shape = 0.01, rate = 0.01)
    for i in eachindex(y)
        z[i] ~ Bernoulli(s)
        y[i] ~ NormalMixture(switch = z[i], m = m, p = w)
    end
end
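To make "uninformative" concrete, the moments of the priors above can be computed by hand. The sketch below uses only the standard shape-rate formulas for the Gamma distribution; the helper names are hypothetical and not part of RxInfer.jl.

```julia
# Mean and variance of a Gamma distribution in the shape-rate parameterization.
gamma_mean(shape, rate) = shape / rate
gamma_var(shape, rate) = shape / rate^2

prior_w_mean = gamma_mean(0.01, 0.01)   # prior expectation of each precision
prior_w_var  = gamma_var(0.01, 0.01)    # prior variance of each precision
prior_m_std  = sqrt(1e3)                # prior standard deviation of each mean
```

The precision prior has mean 1 and variance 100, and the prior standard deviation of each mean is about 31.6, which is wide compared with the true means of -10 and 10.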
Probabilistic inference
In order to fit the model to the data, we are interested in computing the posterior distribution $p(z, s, m, w \mid y)$. However, computation of this term is intractable. Therefore, it is approximated by a naive mean-field approximation, specified as
$p(z, s, m, w \mid y) \approx \prod_{n=1}^N q(z_n) \prod_{k=1}^K \bigg(q(m_k) q(w_k)\bigg) q(s),$
with the functional forms
$\begin{aligned} q(s) &= \mathrm{Beta}(s \mid \hat{\alpha}_s, \hat{\beta}_s) \\ q(m_k) &= \mathcal{N}(m_k \mid \hat{\mu}_k, \hat{\sigma}^2_k) \\ q(w_k) &= \Gamma (w_k \mid \hat{\alpha}_k, \hat{\beta}_k) \\ q(z_n) &= \mathrm{Ber}(z_n \mid \hat{p}_n) \end{aligned}$
In order to get the inference procedure started, these marginal distributions need to be initialized.
n_iterations = 10
# initial marginals for the variational updates
init = @initialization begin
    q(s) = vague(Beta)
    q(m) = [NormalMeanVariance(-2.0, 1e3), NormalMeanVariance(2.0, 1e3)]
    q(w) = [vague(GammaShapeRate), vague(GammaShapeRate)]
end
results_univariate = infer(
    model = univariate_gaussian_mixture_model(),
    constraints = MeanField(),
    data = (y = data_univariate,),
    initialization = init,
    iterations = n_iterations,
    free_energy = true
)
Inference results:
Posteriors | available for (m, w, s, z)
Free Energy: | Real[262.985, 206.795, 183.824, 140.402, 138.188, 138.184, 138.184, 138.184, 138.184, 138.184]
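The free energy trace gives a simple way to judge convergence. A minimal sketch follows, assuming a hypothetical helper `first_converged_iteration` and a hypothetical tolerance of `1e-2`; neither is part of RxInfer.jl.

```julia
# Return the first iteration whose free energy differs from the previous
# iteration by less than `tol`, or `nothing` if no such iteration exists.
function first_converged_iteration(free_energy; tol = 1e-2)
    for i in 2:length(free_energy)
        if abs(free_energy[i] - free_energy[i - 1]) < tol
            return i
        end
    end
    return nothing
end

# Values copied from the inference output above.
fe = [262.985, 206.795, 183.824, 140.402, 138.188, 138.184, 138.184, 138.184, 138.184, 138.184]

converged_at = first_converged_iteration(fe)
```

With this tolerance the trace above is flagged as converged at iteration 6, so the remaining iterations do not change the printed free energy.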
The inference results are shown below as a function of the iteration number.
m1 = [results_univariate.posteriors[:m][i][1] for i in 1:n_iterations]
m2 = [results_univariate.posteriors[:m][i][2] for i in 1:n_iterations]
w1 = [results_univariate.posteriors[:w][i][1] for i in 1:n_iterations]
w2 = [results_univariate.posteriors[:w][i][2] for i in 1:n_iterations];
# ribbons show one posterior standard deviation around the posterior mean
mp = plot(mean.(m1), ribbon = std.(m1), label = L"posterior $m_1$")
mp = plot!(mp, mean.(m2), ribbon = std.(m2), label = L"posterior $m_2$")
mp = plot!(mp, [ -10 ], seriestype = :hline, label = L"true $m_1$")
mp = plot!(mp, [ 10 ], seriestype = :hline, label = L"true $m_2$")
wp = plot(mean.(w1), ribbon = std.(w1), label = L"posterior $w_1$", legend = :bottomright, ylim = (-1, 3))
wp = plot!(wp, mean.(w2), ribbon = std.(w2), label = L"posterior $w_2$")
wp = plot!(wp, [ 1.777 ], seriestype = :hline, label = L"true $w_1$")
wp = plot!(wp, [ 1.777 ], seriestype = :hline, label = L"true $w_2$")
swp = plot(mean.(results_univariate.posteriors[:s]), ribbon = std.(results_univariate.posteriors[:s]), label = L"posterior $s$")
swp = plot!(swp, [ 2/3 ], seriestype = :hline, label = L"true $s$")
fep = plot(results_univariate.free_energy, label = "Free Energy", legend = :topright)
plot(mp, wp, swp, fep, layout = @layout([ a b; c d ]), size = (800, 400))
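For reference, the mean-field factorization used in this section implies a closed-form coordinate update for the indicator marginals $q(z_n)$. The block below states the standard textbook result for this factorization as a sketch. RxInfer.jl obtains its updates through message passing, so this is not a description of its internals. The symbols $\rho_{nk}$, $s_k$ (the prior probability of component $k$, which for $K = 2$ is $s$ or $1 - s$) and $\psi$ (the digamma function) are notation introduced here.

```latex
\begin{aligned}
\log \rho_{nk} &= \mathbb{E}[\log s_k] + \tfrac{1}{2}\,\mathbb{E}[\log w_k]
  - \tfrac{1}{2}\log(2\pi)
  - \tfrac{1}{2}\,\mathbb{E}[w_k]\left((y_n - \hat{\mu}_k)^2 + \hat{\sigma}_k^2\right), \\
q(z_n = k) &= \frac{\rho_{nk}}{\sum_{j=1}^{K} \rho_{nj}}, \\
\mathbb{E}[w_k] &= \frac{\hat{\alpha}_k}{\hat{\beta}_k}, \qquad
\mathbb{E}[\log w_k] = \psi(\hat{\alpha}_k) - \log \hat{\beta}_k, \\
\mathbb{E}[\log s] &= \psi(\hat{\alpha}_s) - \psi(\hat{\alpha}_s + \hat{\beta}_s), \qquad
\mathbb{E}[\log (1 - s)] = \psi(\hat{\beta}_s) - \psi(\hat{\alpha}_s + \hat{\beta}_s).
\end{aligned}
```

The term $\hat{\sigma}_k^2$ inside the quadratic form accounts for the remaining uncertainty about the component mean, so an uncertain component explains every observation less sharply.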
Multivariate Gaussian Mixture Model
The above example can also be extended to the multivariate case. Consider the data set below.
function generate_multivariate_data(nr_samples; rng = MersenneTwister(123))
    L = 50.0
    nr_mixtures = 6
    probvec = normalize!(ones(nr_mixtures), 1)
    switch = Categorical(probvec)
    # place the components evenly on a circle of radius L
    gaussians = map(1:nr_mixtures) do index
        angle = 2π / nr_mixtures * (index - 1)
        basis_v = L * [ 1.0, 0.0 ]
        R = [ cos(angle) -sin(angle); sin(angle) cos(angle) ]
        mean = R * basis_v
        covariance = Matrix(Hermitian(R * [ 10.0 0.0; 0.0 20.0 ] * transpose(R)))
        return MvNormal(mean, covariance)
    end
    z = rand(rng, switch, nr_samples)
    y = Vector{Vector{Float64}}(undef, nr_samples)
    for n in 1:nr_samples
        y[n] = rand(rng, gaussians[z[n]])
    end
    return y
end
data_multivariate = generate_multivariate_data(500)
sdim(n) = (a) -> map(d -> d[n], a) # helper function
scatter(data_multivariate |> sdim(1), data_multivariate |> sdim(2), ms = 2, alpha = 0.4, size = (600, 400), legend=false)
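The covariance construction in the generator can be checked in isolation. The sketch below uses only the LinearAlgebra standard library and rotates the axis-aligned covariance by one step of $2\pi/6$; the variable names are hypothetical.

```julia
using LinearAlgebra

theta = 2π / 6                       # rotation used for the second component
R = [cos(theta) -sin(theta); sin(theta) cos(theta)]
D = [10.0 0.0; 0.0 20.0]             # axis-aligned covariance from the generator

rotated = R * D * transpose(R)               # may be asymmetric by rounding error
covariance = Matrix(Hermitian(rotated))      # exactly symmetric copy

eigenvalues = eigvals(Hermitian(covariance)) # sorted in ascending order
```

The rotation changes the orientation of the component but not its spread: the eigenvalues remain 10 and 20. Wrapping the product in `Hermitian` guards against small floating-point asymmetries that a covariance-matrix constructor could otherwise reject.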
Model specification
The goal here is to create a model for the data set above. In this case a Gaussian mixture model with $K$ components seems to suit the situation well. We specify the factorized model as
$p(y, z, s, m, W) = \prod_{n=1}^N \bigg(p(y_n \mid m, W, z_n) p(z_n \mid s) \bigg)\prod_{k=1}^K \bigg(p(m_k) p(W_k) \bigg) p(s),$
where the individual terms are specified as
$\begin{aligned} p(s) &= \mathrm{Dir}(s \mid \alpha_s) \\ p(m_{k}) &= \mathcal{N}(m_k \mid \mu_k, \Sigma_k) \\ p(W_{k}) &= \mathcal{W}(W_k \mid V_k, \nu_k) \\ p(z_n \mid s) &= \mathrm{Cat}(z_n \mid s) \\ p(y_n \mid m, W, z_n) &= \prod_{k=1}^K \mathcal{N}\left(y_n \mid m_{k}, W_{k}^{-1}\right)^{z_{nk}} \end{aligned}$
The set of observations $y = \{y_1, y_2, \ldots, y_N\}$ is modeled by a mixture of Gaussian distributions, parameterized by means $m = \{m_1, m_2, \ldots, m_K\}$ and precisions $W = \{ W_1, W_2, \ldots, W_K\}$, where $k$ denotes the component index. This component is selected per observation by the indicator variable $z_n$, which is a one-of-$K$ encoded vector satisfying $\sum_{k=1}^K z_{nk} = 1$ and $z_{nk} \in \{0, 1\} \forall k$. We put a hyperprior on these variables, termed $s$, which represents the relative occurrence of the different realizations of $z_n$.
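Since the components are parameterized by precision matrices, the Gaussian log-density can be evaluated without inverting $W_k$. A minimal sketch using only the LinearAlgebra standard library is shown below; `mvnormal_logpdf_precision` is a hypothetical helper, not an RxInfer.jl function.

```julia
using LinearAlgebra

# Log-density of a multivariate Gaussian with mean `m` and precision matrix `W`:
# log N(y | m, W^{-1}) = 0.5 * (logdet(W) - D * log(2π) - (y - m)' * W * (y - m)).
function mvnormal_logpdf_precision(y, m, W)
    d = y - m
    return 0.5 * (logdet(W) - length(y) * log(2π) - dot(d, W, d))
end

W = inv([10.0 0.0; 0.0 20.0])   # precision of the unrotated component above
m = [50.0, 0.0]                 # mean of the first component (L = 50, angle 0)

at_mean = mvnormal_logpdf_precision(m, m, W)
```

At the mean the quadratic form vanishes, so the value depends only on the dimension and the determinant of the precision matrix.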
@model function multivariate_gaussian_mixture_model(nr_mixtures, priors, y)
    local m
    local w
    # priors on the mean and precision matrix of each component
    for k in 1:nr_mixtures
        m[k] ~ priors[k]
        w[k] ~ Wishart(3, 1e2*diagm(ones(2)))
    end
    # prior on the mixing coefficients
    s ~ Dirichlet(ones(nr_mixtures))
    for n in eachindex(y)
        z[n] ~ Categorical(s)
        y[n] ~ NormalMixture(switch = z[n], m = m, p = w)
    end
end
Probabilistic inference
In order to fit the model to the data, we are interested in computing the posterior distribution $p(z, s, m, W \mid y)$. However, computation of this term is intractable. Therefore, it is approximated by a naive mean-field approximation, specified as
$p(z, s, m, W \mid y) \approx \prod_{n=1}^N q(z_n) \prod_{k=1}^K \bigg(q(m_k) q(W_k)\bigg) q(s),$
with the functional forms
$\begin{aligned} q(s) &= \mathrm{Dir}(s \mid \hat{\alpha}_s) \\ q(m_k) &= \mathcal{N}(m_k \mid \hat{\mu}_k, \hat{\Sigma}_k) \\ q(W_k) &= \mathcal{W}(W_k \mid \hat{V}_k, \hat{\nu}_k) \\ q(z_n) &= \mathrm{Cat}(z_n \mid \hat{p}_n) \end{aligned}$
In order to get the inference procedure started, these marginal distributions need to be initialized.
rng = MersenneTwister(121)
priors = [MvNormal([cos(k*2π/6), sin(k*2π/6)], diagm(1e2 * ones(2))) for k in 1:6]
# initial marginals for the variational updates
init = @initialization begin
    q(s) = vague(Dirichlet, 6)
    q(m) = priors
    q(w) = Wishart(3, diagm(1e2 * ones(2)))
end
results_multivariate = infer(
    model = multivariate_gaussian_mixture_model(
        nr_mixtures = 6,
        priors = priors,
    ),
    data = (y = data_multivariate,),
    constraints = MeanField(),
    initialization = init,
    iterations = 50,
    free_energy = true
)
Inference results:
Posteriors | available for (w, m, s, z)
Free Energy: | Real[3927.95, 3884.37, 3884.37, 3884.37, 3884.37, 3884.37, 3884.37, 3884.37, 3884.37, 3884.37 … 3884.37, 3884.37, 3884.37, 3884.37, 3884.37, 3884.37, 3884.37, 3884.37, 3884.37, 3884.37]
The inference results are shown below.
p_data = scatter(data_multivariate |> sdim(1), data_multivariate |> sdim(2), ms = 2, alpha = 0.4, legend=false, title="Data", xlims=(-75, 75), ylims=(-75, 75))
p_result = plot(xlims = (-75, 75), ylims = (-75, 75), title="Inference result", legend=false, colorbar = false)
for (e_m, e_w) in zip(results_multivariate.posteriors[:m][end], results_multivariate.posteriors[:w][end])
    # use the posterior mean and the expected inverse precision as covariance
    gaussian = MvNormal(mean(e_m), Matrix(Hermitian(mean(inv, e_w))))
    global p_result = contour!(p_result, range(-75, 75, step = 0.25), range(-75, 75, step = 0.25), (x, y) -> pdf(gaussian, [ x, y ]), title="Inference result", legend=false, levels = 7, colorbar = false)
end
p_fe = plot(results_multivariate.free_energy, label = "Free Energy")
plot(p_data, p_result, p_fe, layout = @layout([ a b; c ]))
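The plotting loop uses `mean(inv, e_w)`, the expected inverse of the precision matrix, as the covariance of each contour. For a Wishart distribution with $\nu$ degrees of freedom and scale matrix $V$ this expectation has a closed form. The sketch below evaluates it with the standard library only. It assumes that `mean(inv, e_w)` matches this closed form, and the helper name and the posterior parameters are hypothetical.

```julia
using LinearAlgebra

# Closed-form E[W^{-1}] for W ~ Wishart(ν, V) with scale matrix V (valid for ν > D + 1).
function expected_inverse_wishart(ν, V)
    D = size(V, 1)
    ν > D + 1 || throw(ArgumentError("E[W^{-1}] requires ν > D + 1"))
    return inv(V) / (ν - D - 1)
end

# Hypothetical posterior parameters, chosen only for illustration.
ν = 5.0
V = [0.02 0.0; 0.0 0.01]

expected_cov = expected_inverse_wishart(ν, V)   # expected covariance E[W^{-1}]
naive_cov = inv(ν * V)                          # inverse of E[W] = ν V, for comparison
```

The two matrices differ because the expectation of an inverse is not the inverse of the expectation. In this illustration the expected covariance is larger than the inverse of the expected precision.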
This example was executed in a clean, isolated environment. Below are the exact package versions used:
For reproducibility:
- Use the same package versions when running locally
- Report any issues with package compatibility
Status `~/work/RxInferExamples.jl/RxInferExamples.jl/docs/src/categories/problem_specific/gaussian_mixture/Project.toml`
[b964fa9f] LaTeXStrings v1.4.0
[91a5bcdd] Plots v1.40.9
[86711068] RxInfer v4.2.0
[860ef19b] StableRNGs v1.0.2
[37e2e46d] LinearAlgebra v1.11.0
[9a3f8284] Random v1.11.0