Recurrent Switching Linear Dynamical System

This is an experimental example of a Recurrent Switching Linear Dynamical System (RSLDS) model. The notebook requires patches to RxInfer and ReactiveMP, which are condensed in the hidden blocks below.
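
For orientation, the generative structure implemented in the hidden model block below can be sketched as follows (a schematic only; the exact priors, parameterisations, and variational constraints are in the code):

    u[t]   ~ Normal(Φ x[t], w⁻¹)                  # recurrent layer (Φ = reshape(ϕ); a softdot when n_switches == 1)
    s[t]   ~ MultinomialPolya(1, u[t])            # regime indicator
    s[t+1] ~ DiscreteTransition(s[t], P)          # HMM layer over regimes
    x[t+1] ~ Normal(A[s[t+1]] x[t], Λ[s[t+1]]⁻¹)  # regime-gated linear dynamics (Gate node selects among H, Λ)
    obs[t] ~ Normal(C x[t+1], R⁻¹)                # observation layer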

Hidden block of RxInfer & ReactiveMP patches and extensions - click to expand

using ExponentialFamily, RxInfer, BayesBase, GraphPPL
import ReactiveMP: AbstractFactorNode, NodeInterface, IndexedNodeInterface, FactorNodeActivationOptions, Marginalisation,
    Deterministic, PredefinedNodeFunctionalForm, FunctionalDependencies, collect_functional_dependencies, activate!, functional_dependencies,
    collect_latest_messages, collect_latest_marginals, marginalrule, rule, name, getinboundinterfaces, clustername, getdependecies,
    messagein, ManyOf, getvariable
import ExponentialFamily: getnaturalparameters, exponential_family_typetag
export Gate, GateNode

# Mixture functional form: `Gate` is a deterministic selector node with interfaces
# (:out, :switch, :inputs). Conditioned on the switch it routes one of its N inputs
# to the output, so outgoing messages become `MixtureDistribution`s (see the rules below).
struct Gate{N} end

ReactiveMP.as_node_symbol(::Type{<:Gate}) = :Gate
ReactiveMP.interfaces(::Type{<:Gate}) = Val((:out, :switch, :inputs))
ReactiveMP.alias_interface(::Type{<:Gate}, ::Int64, name::Symbol) = name
ReactiveMP.is_predefined_node(::Type{<:Gate}) = ReactiveMP.PredefinedNodeFunctionalForm()
ReactiveMP.sdtype(::Type{<:Gate}) = ReactiveMP.Deterministic()
ReactiveMP.collect_factorisation(::Type{<:Gate}, factorization) = GateNodeFactorisation()

struct GateNodeFactorisation end

struct GateNode{N} <: ReactiveMP.AbstractFactorNode
    out    :: ReactiveMP.NodeInterface
    switch :: ReactiveMP.NodeInterface
    inputs :: NTuple{N, ReactiveMP.IndexedNodeInterface}
end 

ReactiveMP.functionalform(factornode::GateNode{N}) where {N} = Gate{N}
ReactiveMP.getinterfaces(factornode::GateNode) = (factornode.out, factornode.switch, factornode.inputs...)
ReactiveMP.sdtype(factornode::GateNode) = ReactiveMP.Deterministic()

ReactiveMP.interfaceindices(factornode::GateNode, iname::Symbol)                       = (ReactiveMP.interfaceindex(factornode, iname),)
ReactiveMP.interfaceindices(factornode::GateNode, inames::NTuple{N, Symbol}) where {N} = map(iname -> ReactiveMP.interfaceindex(factornode, iname), inames)

ReactiveMP.interfaceindex(factornode::GateNode, iname::Symbol) = begin
    if iname === :out
        return 1
    elseif iname === :switch
        return 2
    elseif iname === :inputs
        return 3
    end
end

ReactiveMP.factornode(::Type{<:Gate}, interfaces, factorization) = begin
    outinterface = interfaces[findfirst(((name, variable),) -> name == :out, interfaces)]
    switchinterface = interfaces[findfirst(((name, variable),) -> name == :switch, interfaces)]
    inputinterfaces = filter(((name, variable),) -> name == :inputs, interfaces)
    N = length(inputinterfaces)
    return GateNode(ReactiveMP.NodeInterface(outinterface...), ReactiveMP.NodeInterface(switchinterface...), ntuple(i -> ReactiveMP.IndexedNodeInterface(i, ReactiveMP.NodeInterface(inputinterfaces[i]...)), N))
end

struct GateNodeInboundInterfaces end

ReactiveMP.getinboundinterfaces(::GateNode) = GateNodeInboundInterfaces()
ReactiveMP.clustername(::GateNodeInboundInterfaces) = :switch_inputs


struct GateNodeFunctionalDependencies <: FunctionalDependencies end

ReactiveMP.collect_functional_dependencies(::GateNode, ::Nothing) = GateNodeFunctionalDependencies()
ReactiveMP.collect_functional_dependencies(::GateNode, ::GateNodeFunctionalDependencies) = GateNodeFunctionalDependencies()
ReactiveMP.collect_functional_dependencies(::GateNode, ::Any) =
    error("The functional dependencies for GateNode must be either `Nothing` or `GateNodeFunctionalDependencies`")

ReactiveMP.activate!(factornode::GateNode, options::FactorNodeActivationOptions) = begin
    dependencies = ReactiveMP.collect_functional_dependencies(factornode, ReactiveMP.getdependecies(options))
    return ReactiveMP.activate!(dependencies, factornode, options)
end


ReactiveMP.functional_dependencies(::GateNodeFunctionalDependencies, factornode::GateNode{N}, interface, iindex::Int) where {N} = begin
    message_dependencies = if iindex === 1
        # the :out message depends on the messages from all inputs
        (factornode.inputs, )
    elseif iindex === 2
        # the :switch message depends on the :out message and the input messages
        (factornode.out, factornode.inputs)
    elseif 2 < iindex <= N + 2
        # the k'th input message depends on the :out message
        (factornode.out, )
    else
        error("Bad index in functional_dependencies for GateNode")
    end

    marginal_dependencies = if iindex === 1
        # the :out message additionally depends on the marginal of :switch
        (factornode.switch, )
    elseif iindex === 2
        # the :switch message requires no marginals
        ( )
    elseif 2 < iindex <= N + 2
        # the k'th input message additionally depends on the marginal of :switch
        (factornode.switch,)
    else
        error("Bad index in functional_dependencies for GateNode")
    end

    return message_dependencies, marginal_dependencies
end

    return message_dependencies, marginal_dependencies
end


ReactiveMP.collect_latest_messages(::GateNodeFunctionalDependencies, factornode::GateNode{N}, messages::Tuple{NodeInterface}) where {N} = begin
    outputinterface = messages[1]

    msgs_names = Val{(name(outputinterface),)}()
    msgs_observable = combineLatestUpdates((messagein(outputinterface),), PushNew())
    return msgs_names, msgs_observable
end

ReactiveMP.collect_latest_marginals(::GateNodeFunctionalDependencies, factornode::GateNode{N}, marginals::Tuple{NodeInterface}) where {N} = begin
    switchinterface = marginals[1]

    marginal_names = Val{(name(switchinterface),)}()
    marginal_observable = combineLatestUpdates((
        getmarginal(getvariable(switchinterface), IncludeAll()),
    ), PushNew())

    return marginal_names, marginal_observable
end

ReactiveMP.collect_latest_marginals(::GateNodeFunctionalDependencies, factornode::GateNode{N}, marginals::NTuple{N,IndexedNodeInterface}) where {N} = begin
    inputsinterfaces = marginals
    
    marginal_names = Val{(name(first(inputsinterfaces)),)}()
    marginal_observable = combineLatest(map(input -> getmarginal(getvariable(input), IncludeAll()), inputsinterfaces), PushNew()) |> map_to((ManyOf(map(input -> getmarginal(getvariable(input), IncludeAll()), inputsinterfaces)),))
    
    return marginal_names, marginal_observable
end

ReactiveMP.collect_latest_messages(::GateNodeFunctionalDependencies, factornode::GateNode{N}, messages::Tuple{NodeInterface, NTuple{N, IndexedNodeInterface}}) where {N} = begin
    output_or_switch_interface = messages[1]
    inputsinterfaces = messages[2]

    msgs_names = Val{(name(output_or_switch_interface), name(inputsinterfaces[1]))}()
    msgs_observable =
        combineLatest(
            (messagein(output_or_switch_interface), combineLatest(map(input -> messagein(input), inputsinterfaces), PushNew())),
            PushNew()
        ) |> map_to((
            messagein(output_or_switch_interface), 
            ManyOf(map(input -> messagein(input), inputsinterfaces))
        ))
    
    return msgs_names, msgs_observable
end

ReactiveMP.collect_latest_messages(::GateNodeFunctionalDependencies, factornode::GateNode{N}, messages::Tuple{NTuple{N,IndexedNodeInterface}}) where {N} = begin
    inputsinterfaces = messages[1]
    
    msgs_names = Val{(name(first(inputsinterfaces)),)}()
    msgs_observable = combineLatest(map(input -> messagein(input), inputsinterfaces), PushNew()) |> map_to((ManyOf(map(input -> messagein(input), inputsinterfaces)),))
    
    return msgs_names, msgs_observable
end


ReactiveMP.marginalrule(fform::Type{<:Gate}, on::Val{:switch_inputs}, mnames::Any, messages::Any, qnames::Nothing, marginals::Nothing, meta::Nothing, __node::Any) = begin
    # messages[1] is the :out message; it is not needed for this joint marginal
    m_switch = getdata(messages[2])
    m_inputs = getdata.(messages[3:end])

    return FactorizedJoint((m_inputs..., m_switch))
end

ReactiveMP.@rule Gate(:out, Marginalisation) (q_switch::Any, m_inputs::ManyOf{N, Any}) where {N} = begin
    return MixtureDistribution(collect(m_inputs), probvec(q_switch))
end


ReactiveMP.@rule Gate(:switch, Marginalisation) (m_out::Any, m_inputs::ManyOf{N, Any}) where {N} = begin
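    # logscales[k] is the log-normalisation constant of the product m_out * m_inputs[k],
    # i.e. the (approximate) log-evidence for switch state k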
    logscales = map(input -> compute_logscale(prod(GenericProd(),m_out,input), m_out, input), m_inputs)
    p = softmax(collect(logscales))
    return Multinomial(1, p)
end


ReactiveMP.@rule Gate((:inputs, k), Marginalisation) (m_out::Any, q_switch::Any,) = begin
    z = probvec(q_switch)[k]
    ef_out = convert(ExponentialFamilyDistribution, m_out)
    η      = getnaturalparameters(ef_out)
    ef_opt = ExponentialFamilyDistribution(exponential_family_typetag(ef_out), η * z)

    return convert(Distribution, ef_opt)
end


ReactiveMP.@rule typeof(*)(:out, Marginalisation) (m_A::PointMass, m_in::MixtureDistribution, meta::Any) = begin 
    comps = BayesBase.components(m_in)
    new_components = similar(comps)
    @inbounds for (i,component) in enumerate(comps)
        new_components[i] = @call_rule typeof(*)(:out, Marginalisation) (m_A = m_A, m_in = component, meta = meta)
    end
    dist = MixtureDistribution(new_components, BayesBase.weights(m_in))
    return dist
end

ReactiveMP.@rule typeof(dot)(:out, Marginalisation) (m_in1::MixtureDistribution, m_in2::PointMass, meta::Any) = begin 
    comps = BayesBase.components(m_in1)
    new_components = []
    @inbounds for (i, component) in enumerate(comps)
        push!(new_components, @call_rule typeof(dot)(:out, Marginalisation) (m_in1 = component, m_in2 = m_in2, meta = meta))
    end
    
    mixture = MixtureDistribution(new_components, BayesBase.weights(m_in1))

    return mixture
end



@rule typeof(dot)(:in1, Marginalisation) (m_out::MixtureDistribution, m_in2::PointMass, meta::Any) = begin 
    comps = BayesBase.components(m_out)
    weights = BayesBase.weights(m_out)
    new_comps = []
    for (comp, weight) in zip(comps, weights)
        new_comp = @call_rule typeof(dot)(:in1, Marginalisation) (m_out = comp, m_in2 = m_in2, meta = meta)
        push!(new_comps, new_comp)
    end
    return MixtureDistribution(new_comps, weights)
end

function BayesBase.prod(::BayesBase.UnspecifiedProd, left::GaussianDistributionsFamily, right::MixtureDistribution)
    comps = BayesBase.components(right)
    weights = BayesBase.weights(right)
    new_comps = []
    for comp in comps
        new_comp = prod(GenericProd(),left, comp)
        push!(new_comps, new_comp)
    end
    
    return MixtureDistribution(new_comps, weights)
end

BayesBase.prod(::BayesBase.UnspecifiedProd, left::MixtureDistribution, right::GaussianDistributionsFamily) = prod(GenericProd(),right, left)
BayesBase.paramfloattype(::MixtureDistribution) = Float64


import ExponentialFamily.LogExpFunctions: logsumexp
function BayesBase.prod(::GenericProd, left::Categorical, right::Multinomial)
    @assert right.n == 1
    right_cat = Categorical(right.p)

    p = prod(GenericProd(), left, right_cat).p
    return Multinomial(1, p)
end



BayesBase.prod(::GenericProd, left::Multinomial, right::Categorical) = prod(GenericProd(), right, left)
function BayesBase.prod(::GenericProd, left::Multinomial, right::Multinomial) 
    @assert left.n == right.n

    p = left.p .* right.p
    p = p ./ sum(p)
    return Multinomial(left.n, p)
end

BayesBase.prod(::BayesBase.UnspecifiedProd, left::Multinomial, right::Multinomial) = prod(GenericProd(), left, right)


function BayesBase.compute_logscale(dist1::Multinomial, dist2::Multinomial, dist3::Multinomial) 
    logp1 = log.(dist1.p) - log(dist1.p[end])
    logp2 = log.(dist2.p) - log(dist2.p[end])
    logp3 = log.(dist3.p) - log(dist3.p[end])
    return logsumexp(logp1) - logsumexp(logp2) - logsumexp(logp3)
end

BayesBase.compute_logscale(d1::ExponentialFamily.WishartFast, d2::ExponentialFamily.WishartFast, d3::ExponentialFamily.WishartFast) = begin
    return logpartition(convert(ExponentialFamilyDistribution, d1)) - logpartition(convert(ExponentialFamilyDistribution, d2)) - logpartition(convert(ExponentialFamilyDistribution, d3))
end


ExponentialFamily.probvec(d::Multinomial) = d.p

@rule ContinuousTransition(:W, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_a::MixtureDistribution, meta::Any) = begin 
    q_a_normal = convert(promote_variate_type(typeof(mean(q_a)), NormalMeanPrecision), mean(q_a), precision(q_a))
    return @call_rule ContinuousTransition(:W, Marginalisation) (q_y_x = q_y_x, q_a = q_a_normal, meta = meta)
end

@rule ContinuousTransition(:y, Marginalisation) (m_x::MultivariateNormalDistributionsFamily, q_a::MixtureDistribution, q_W::Any, meta::Any) = begin 
    q_a_normal = convert(promote_variate_type(typeof(mean(q_a)), NormalMeanPrecision), mean(q_a), precision(q_a))
    return @call_rule ContinuousTransition(:y, Marginalisation) (m_x = m_x, q_a = q_a_normal, q_W = q_W, meta = meta)
end


@rule ContinuousTransition(:a, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_a::MixtureDistribution, q_W::Any, meta::Any) = begin 
    q_a_normal = convert(promote_variate_type(typeof(mean(q_a)), NormalMeanPrecision), mean(q_a), precision(q_a))
    return @call_rule ContinuousTransition(:a, Marginalisation) (q_y_x = q_y_x, q_a = q_a_normal, q_W = q_W, meta = meta)
end

@rule ContinuousTransition(:x, Marginalisation) (m_y::MultivariateNormalDistributionsFamily , q_a::MixtureDistribution, q_W::Any, meta::Any) = begin 
    q_a_normal = convert(promote_variate_type(typeof(mean(q_a)), NormalMeanPrecision), mean(q_a), precision(q_a))
    return @call_rule ContinuousTransition(:x, Marginalisation) (m_y = m_y, q_a = q_a_normal, q_W = q_W, meta = meta)
end

@rule ContinuousTransition(:y, Marginalisation) (m_x::MixtureDistribution, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::Any) = begin 
    m_x_normal = convert(promote_variate_type(typeof(mean(m_x)), NormalMeanPrecision), mean(m_x), precision(m_x))
    return @call_rule ContinuousTransition(:y, Marginalisation) (m_x = m_x_normal, q_a = q_a, q_W = q_W, meta = meta)
end


@marginalrule ContinuousTransition(:y_x) (m_y::MultivariateNormalDistributionsFamily, m_x::MixtureDistribution, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::Any) = begin 
    m_x_normal = convert(promote_variate_type(typeof(mean(m_x)), NormalMeanPrecision), mean(m_x), precision(m_x))
    return @call_marginalrule ContinuousTransition(:y_x) (m_y = m_y, m_x = m_x_normal, q_a = q_a, q_W = q_W, meta = meta)
end

@rule typeof(+)(:out, Marginalisation) (m_in1::MultivariateNormalDistributionsFamily, m_in2::MixtureDistribution, ) = begin 
    return @call_rule typeof(+)(:out, Marginalisation) (m_in1 = m_in1, m_in2 = convert(promote_variate_type(typeof(mean(m_in2)), NormalMeanPrecision), mean(m_in2), precision(m_in2)))
end

@rule typeof(+)(:in1, Marginalisation) (m_out::MultivariateNormalDistributionsFamily, m_in2::MixtureDistribution, ) = begin 
    return @call_rule typeof(+)(:in1, Marginalisation) (m_out = convert(promote_variate_type(typeof(mean(m_out)), NormalMeanPrecision), mean(m_out), precision(m_out)), m_in2 = convert(promote_variate_type(typeof(mean(m_in2)), NormalMeanPrecision), mean(m_in2), precision(m_in2)))
end

@rule DiscreteTransition(:out, Marginalisation) (m_in::Multinomial, q_a::DirichletCollection, ) = begin 
    @assert m_in.n == 1
    p = probvec(m_in)
    m_in_cat = Categorical(p)
    return @call_rule DiscreteTransition(:out, Marginalisation) (m_in = m_in_cat, q_a = q_a)
end

@rule DiscreteTransition(:in, Marginalisation) (m_out::Multinomial, q_a::DirichletCollection, ) = begin 
    @assert m_out.n == 1
    p = probvec(m_out)
    m_out_cat = Categorical(p)
    return @call_rule DiscreteTransition(:in, Marginalisation) (m_out = m_out_cat, q_a = q_a)
end

@marginalrule DiscreteTransition(:out_in) (m_out::Multinomial, m_in::Multinomial, q_a::DirichletCollection, ) = begin 
    @assert m_out.n == 1 && m_in.n == 1
    p_out = probvec(m_out)
    p_in = probvec(m_in)
    m_out_cat = Categorical(p_out)
    m_in_cat = Categorical(p_in)
    return @call_marginalrule DiscreteTransition(:out_in) (m_out = m_out_cat, m_in = m_in_cat, q_a = q_a)
end


Base.length(d::MixtureDistribution) = length(d.components)
Base.ndims(d::MixtureDistribution) = first(size(first(d.components)))

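# Weighted average of component entropies: a lower bound on the true mixture entropy (the mixing term is omitted).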
BayesBase.entropy(d::MixtureDistribution) = mapreduce((c,w) -> w * BayesBase.entropy(c), +, d.components, d.weights)

BayesBase.mean(f::F, itr::MixtureDistribution) where {F} = mapreduce((c,w) -> w * mean(f, c), +, itr.components, itr.weights)

# Helper for building a "sticky" (row-unnormalised) transition prior:
# 0.5 on off-diagonal entries, 1.0 on the diagonal.
function create_P_matrix(n_switches)
    P = zeros(n_switches, n_switches)
    for i in 1:n_switches
        P[i,:] = 0.5 * ones(n_switches)
        P[i,i] = 1.0
    end
    return P
end

function BayesBase.mean(mixture::MixtureDistribution)
    component_means = mean.(BayesBase.components(mixture))
    component_weights = BayesBase.weights(mixture)
    return mapreduce((m,w) -> w*m, +, component_means, component_weights)
end

function BayesBase.cov(mixture::MixtureDistribution)
    component_cov = cov.(BayesBase.components(mixture))
    component_means = mean.(BayesBase.components(mixture))
    component_weights = BayesBase.weights(mixture)
    mixture_mean = mean(mixture)
    return mapreduce((v,m,w) -> w*(v + m*m'), +, component_cov, component_means, component_weights) - mixture_mean*mixture_mean'
end

BayesBase.precision(mixture::MixtureDistribution) = inv(cov(mixture))

function BayesBase.var(mixture::MixtureDistribution)
    component_vars = var.(BayesBase.components(mixture))
    component_means = mean.(BayesBase.components(mixture))
    component_weights = BayesBase.weights(mixture)
    mixture_mean = mean(mixture)
    return mapreduce((v,m,w) -> w*(v + m.^2), +, component_vars, component_means, component_weights) - mixture_mean.^2
end
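
As a quick sanity check of the mixture-moment definitions above, here is a hypothetical two-component example (assuming the patches in this block have been loaded; `MixtureDistribution` and `NormalMeanVariance` are re-exported by RxInfer):

mix = MixtureDistribution([NormalMeanVariance(0.0, 1.0), NormalMeanVariance(4.0, 1.0)], [0.5, 0.5])
mean(mix)   # 0.5 * 0.0 + 0.5 * 4.0 = 2.0
var(mix)    # 0.5 * (1 + 0.0^2) + 0.5 * (1 + 4.0^2) - 2.0^2 = 5.0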

Hidden block of RSLDS Model Specification - click to expand

import ExponentialFamily: softmax

"""
    RSLDSHyperparameters{T}

Structure containing hyperparameters for the Recurrent Switching Linear Dynamical System (RSLDS) model.

# Fields
- `a_w::T = 2.0`: Shape parameter for the Gamma prior on precision parameter w (when n_switches=1)
- `b_w::T = 2.0`: Rate parameter for the Gamma prior on precision parameter w (when n_switches=1)
- `Ψ_w::Matrix{T}`: Scale matrix for the Wishart prior on precision matrix w (when n_switches>1)
- `Ψ_R::Union{Matrix{T}, T}`: Scale matrix/parameter for the Wishart/Gamma prior on observation precision
- `ν_R::T`: Degrees of freedom for the Wishart prior on observation precision
- `α::Matrix{T}`: Parameter matrix for the Dirichlet prior on transition probabilities
- `C::Matrix{T}`: Observation matrix mapping latent states to observations
"""
Base.@kwdef struct RSLDSHyperparameters{T}
    a_w::T = 2.0
    b_w::T = 2.0
    Ψ_w::Matrix{T}
    Ψ_R::Union{Matrix{T}, T}
    ν_R::T
    α::Matrix{T}
    C::Matrix{T}
end

"""
    get_hyperparameters(hyperparameters::RSLDSHyperparameters)

Extract all hyperparameters from the RSLDSHyperparameters structure.

# Arguments
- `hyperparameters::RSLDSHyperparameters`: Structure containing the hyperparameters

# Returns
A tuple containing all hyperparameters in the order: a_w, b_w, Ψ_w, Ψ_R, ν_R, α, C
"""
function get_hyperparameters(hyperparameters::RSLDSHyperparameters)
    return hyperparameters.a_w, hyperparameters.b_w, hyperparameters.Ψ_w, hyperparameters.Ψ_R, hyperparameters.ν_R, hyperparameters.α, hyperparameters.C
end

"""
    default_hyperparameters(n_switches, obs_dim, dim_latent)

Create a default set of hyperparameters for the RSLDS model.

# Arguments
- `n_switches`: Number of switching states in the model
- `obs_dim`: Dimension of the observation space
- `dim_latent`: Dimension of the latent state space

# Returns
An RSLDSHyperparameters structure with default values
"""
function default_hyperparameters(n_switches, obs_dim, dim_latent)
    return RSLDSHyperparameters(
        a_w = 2.0,
        b_w = 2.0,
        Ψ_w = diageye(n_switches),
        Ψ_R = diageye(obs_dim),
        ν_R = obs_dim + 2.0,
        α = ones(n_switches+1, n_switches+1),
        C = diageye(obs_dim, dim_latent)
    )
end


@model function rslds_model_learning(obs, n_obs, n_switches, dim_latent, η, Ψ, hyperparameters, learn_observation_covariance)
    local H, A, Λ, u
    transformation  = (x) -> reshape(x, (dim_latent, dim_latent))
    transformation2 = (x) -> reshape(x, (n_switches, dim_latent))
    ## Hyperparameters
    a_w, b_w, Ψ_w, Ψ_R, ν_R, α, C = get_hyperparameters(hyperparameters)
    ## Priors on the parameters 
    if n_switches == 1
        w ~ GammaShapeRate(a_w, b_w)
    else
        w ~ Wishart(n_switches+2, Ψ_w)
    end 

    if learn_observation_covariance
        if n_obs == 1
            R ~ GammaShapeRate(ν_R, Ψ_R)
        else
            R ~ Wishart(ν_R, Ψ_R) 
        end
    else
        R = Ψ_R
    end
    
    for k in 1:n_switches+1
        H[k] ~ MvNormalMeanCovariance(zeros(dim_latent^2), diageye(dim_latent^2))
        Λ[k] ~ Wishart(dim_latent+2, diageye(dim_latent))
    end
    P ~ DirichletCollection(α)
    ϕ ~ MvNormalMeanCovariance(zeros(dim_latent*n_switches), diageye(dim_latent*n_switches))
    ## States Initialisation 
    x[1] ~ MvNormalMeanCovariance(zeros(dim_latent), diageye(dim_latent))
    for t in eachindex(obs)  
        ## Recurrent Layer
        if n_switches == 1
            u[t] ~ softdot(ϕ, x[t], w)
        else
            u[t] ~ ContinuousTransition(x[t], ϕ, w) where {meta = CTMeta(transformation2)}
        end     
        s[t] ~ MultinomialPolya(1, u[t]) where {dependencies = RequireMessageFunctionalDependencies(ψ = convert(promote_variate_type(typeof(η), NormalWeightedMeanPrecision), η, Ψ))}   
        s[t+1] ~ DiscreteTransition(s[t], P)
        ## Transition Layer
        A[t] := Gate(switch=s[t+1], inputs=H)
        B[t] := Gate(switch=s[t+1], inputs=Λ)
        x[t+1] ~ ContinuousTransition(x[t], A[t], B[t]) where {meta = CTMeta(transformation)}
        ## Observation Layer
        obs[t] ~ MvNormalMeanPrecision(C*x[t+1], R)
    end
end

@constraints  function rslds_learning_constraints(learn_observation_covariance)
    if learn_observation_covariance
        q(x,s,u,ϕ,w,P,H,A,Λ,B,R) = q(x,u)q(A)q(s)q(ϕ)q(w)q(P)q(H)q(Λ)q(B)q(R)
    else
        q(x,s,u,ϕ,w,P,H,A,Λ,B) = q(x,u)q(A)q(s)q(ϕ)q(w)q(P)q(H)q(Λ)q(B)
    end
end

@initialization function rslds_learning_initmarginals(n_switches, dim_latent, obs_dim, learn_observation_covariance; rng = StableRNG(42))    
    q(x) = vague(MvNormalWeightedMeanPrecision, dim_latent)
    q(s) = Multinomial(1,softmax(randn(rng, n_switches+1)))
    q(ϕ) = vague(MvNormalWeightedMeanPrecision, dim_latent*(n_switches))
    if n_switches == 1
        q(w) = vague(GammaShapeRate)
    else
        q(w) = vague(Wishart, n_switches)   
    end
    q(A) = vague(MvNormalWeightedMeanPrecision, dim_latent^2)
    q(P) = DirichletCollection(ones(n_switches+1,n_switches+1))
    q(Λ) = vague(Wishart, dim_latent)
    q(H) = vague(MvNormalWeightedMeanPrecision, dim_latent^2)
    q(B) = vague(Wishart, dim_latent)
    if learn_observation_covariance
        if obs_dim == 1
            q(R) = vague(GammaShapeRate)
        else
            q(R) = Wishart(obs_dim+2, diageye(obs_dim))
        end
    end
end;



"""
    fit_rslds(data, n_switches, dim_latent, n_obs; kwargs...)

Fit a Recurrent Switching Linear Dynamical System (RSLDS) model to the provided data.

# Arguments
- `data`: Time series observation data
- `n_switches`: Total number of switching regimes. Note: internally the model uses
  `n_switches - 1`, because the MultinomialPolya construction adds one extra dimension
  that represents the recurrent influence on state transitions.
- `dim_latent`: Dimension of the latent state space
- `n_obs`: Dimension of the observation space

# Keyword Arguments
- `iterations::Int = 60`: Number of inference iterations
- `η = nothing`: Mean parameter for the functional dependency in MultinomialPolya
- `Ψ = nothing`: Precision parameter for the functional dependency in MultinomialPolya
- `hyperparameters = nothing`: Custom hyperparameters for the model
- `progress::Bool = false`: Whether to show progress during inference
- `learn_observation_covariance::Bool = false`: Whether to learn the observation covariance

# Returns
The result of the inference procedure
"""
function fit_rslds(data, n_switches, dim_latent, n_obs; iterations = 60, η = nothing, Ψ = nothing, hyperparameters = nothing, progress = false, learn_observation_covariance = false)
    @assert n_switches > 1 "n_switches must be greater than 1"
    # We subtract 1 from n_switches because the MultinomialPolya distribution
    # internally adds an extra dimension to represent the recurrent influence
    # on state transitions. This convention allows the model to maintain the
    # correct dimensionality while incorporating the recurrent dynamics.
    n_switches = n_switches - 1

    if hyperparameters === nothing
        hyperparameters = default_hyperparameters(n_switches, length(data[1]), dim_latent)
    end

    if η === nothing
        if n_switches == 1
            η = 0.0
        else
            η = zeros(n_switches)
        end
    end
    if Ψ === nothing
        if n_switches == 1
            Ψ = 0.0001
        else
            Ψ = 0.0001*diageye(n_switches)
        end
    end
    model = rslds_model_learning(n_obs = n_obs, n_switches = n_switches, dim_latent = dim_latent, η = η, Ψ = Ψ, hyperparameters = hyperparameters, learn_observation_covariance = learn_observation_covariance)
    constraints = rslds_learning_constraints(learn_observation_covariance)
    initmarginals = rslds_learning_initmarginals(n_switches, dim_latent, n_obs, learn_observation_covariance)
    
    return infer(model = model, data = (obs=data, ), constraints = constraints, initialization = initmarginals, iterations = iterations,
    showprogress = progress,
    returnvars = KeepEach(),
    free_energy = true,
    options = (limit_stack_depth = 100,)
    )
end

function states_to_categorical(states)
    return [argmax(states[t].p) for t in 1:length(states)]
end

Hidden block of Generating Synthetic Data - click to expand

using StableRNGs

function generate_switching_data(T, A1, A2, c, Q, R, x_0;rng = StableRNG(42))
    # Initialize arrays to store states and observations
    x = zeros(2, T)  # State matrix: 2 dimensions × T timesteps
    y = zeros(2, T)  # Observation matrix: 2 dimensions × T timesteps
    
    # Set initial state
    x[:,1] = x_0
    
    # Generate state transitions and observations
    for t in 2:T
        # Switch dynamics multiple times through the sequence
        if t < T/3 || (t >= T/2 && t < 3T/4)
            x[:,t] = A2 * x[:,t-1] + rand(rng,MvNormal(zeros(2), Q))  # First regime
        else
            x[:,t] = A1 * x[:,t-1] + rand(rng,MvNormal(zeros(2), Q))  # Second regime
        end
        
        # Generate observation from current state
        y[:,t] = c * x[:,t] + rand(rng,MvNormal(zeros(2), R))
    end

    return x, y
end

# System parameters
T = 500  # Time horizon
θ = π / 15  # Rotation angle

# Define system matrices
A1 = [cos(θ) -sin(θ); sin(θ) cos(θ)]   # Rotation matrix
A2 = [0.4 -0.01; 0.01 0.2]             # Contractive dynamics
c  = [0.6 -0.02; -0.02 0.3]            # Observation/distortion matrix

# Noise parameters
Q   = [1.0 0.0; 0.0 1.0]               # State noise covariance
R   = [1.0 0.0; 0.0 1.0]               # Observation noise covariance
x_0 = [0.0, 0.0]                       # Initial state vector

# Generate synthetic data
x, y = generate_switching_data(T, A1, A2, c, Q, R, x_0)
y = [y[:,i] for i in 1:T]
x = [x[:,i] for i in 1:T]
hyperparameters = RSLDSHyperparameters(
    a_w = 0.01,
    b_w = 0.01,
    Ψ_w = 10.0*diageye(2), # n-1  
    Ψ_R = inv(R),
    ν_R = 4.0,
    α = ones(2,2), # n
    C = c
)
Main.anonymous.RSLDSHyperparameters{Float64}(0.01, 0.01, [10.0 0.0; 0.0 10.0], [1.0 0.0; 0.0 1.0], 4.0, [1.0 1.0; 1.0 1.0], [0.6 -0.02; -0.02 0.3])
rslds_result = fit_rslds(y, 2, 2, 2; iterations = 150, hyperparameters = hyperparameters, progress = true)
Inference results:
  Posteriors       | available for (ϕ, w, P, A, s, H, Λ, B, u, x)
  Free Energy:     | Real[386913.0, 75580.8, 2809.27, 2010.43, 1920.47, 1865.65, 1893.55, 1933.31, 1972.41, 2010.39  …  1724.41, 1725.22, 1726.01, 1726.79, 1727.56, 1728.32, 1729.06, 1729.79, 1730.51, 1731.22]
using Plots
switching_state_posterior = rslds_result.posteriors[:s][end];
states = states_to_categorical(switching_state_posterior);
scatter(states, label="Estimated Regimes", color="blue", linewidth=2)

continuous_state_posterior = rslds_result.posteriors[:x][end];
index = 1
from = 1
to = 500

m_continuous = getindex.(mean.(continuous_state_posterior), index);
var_continuous = getindex.(var.(continuous_state_posterior), index);
plot(m_continuous[from+1:to], ribbon=sqrt.(var_continuous[from+1:to]), label="Estimated States", color="blue",fillalpha=0.2, linewidth=2)
plot!(getindex.(x,index)[from:to], label="True States", color="green", linewidth=1)
scatter!(getindex.(y,index)[from:to], label="Observed Data", color="black", ms=1.3)
lens!([10,50],[-3, 3],inset = (1, bbox(0.07, 0.6, 0.3, 0.3)), )

index = 2
m_continuous = getindex.(mean.(continuous_state_posterior), index);
var_continuous = getindex.(var.(continuous_state_posterior), index);
plot(m_continuous[from+1:to], ribbon=sqrt.(var_continuous[from+1:to]), label="Estimated States", color="blue",fillalpha=0.2, linewidth=2)
plot!(getindex.(x,index)[from:to], label="True States", color="green", linewidth=1)
scatter!(getindex.(y,index)[from:to], label="Observed Data", color="black", ms=1.3)
lens!([350,400],[-3, 3],inset = (1, bbox(0.07, 0.6, 0.3, 0.3)), )

println("Estimated continuous transition matrices:")
println("----------------------------------------")
for i in 1:length(rslds_result.posteriors[:H][end])
    println("Matrix $i:")
    println(reshape(mean(rslds_result.posteriors[:H][end][i]), 2, 2))
    println()
end
Estimated continuous transition matrices:
----------------------------------------
Matrix 1:
[0.9801827287287415 -0.20963425131252503; 0.19050473128110149 0.9791493851524095]

Matrix 2:
[0.5040325201488759 -0.3213042852544048; 0.10781013310896664 0.5351266295677637]
println("Estimated discrete transition matrix for HMM layer:")
println("----------------------------------------")
println(mean(rslds_result.posteriors[:P][end]))
Estimated discrete transition matrix for HMM layer:
----------------------------------------
[0.9866080421546887 0.008636213099697618; 0.01339195784531141 0.9913637869003024]

Contributing

This example was automatically generated from a Jupyter notebook in the RxInferExamples.jl repository.

We welcome and encourage contributions! You can help by:

  • Improving this example
  • Creating new examples
  • Reporting issues or bugs
  • Suggesting enhancements

Visit our GitHub repository to get started. Together we can make RxInfer.jl even better! 💪


Environment

This example was executed in a clean, isolated environment. Below are the exact package versions used:

For reproducibility:

  • Use the same package versions when running locally
  • Report any issues with package compatibility

Status `~/work/RxInferExamples.jl/RxInferExamples.jl/docs/src/categories/experimental_examples/recurrent_switching_linear_dynamical_system/Project.toml`
  [b4ee3484] BayesBase v1.5.4
  [62312e5e] ExponentialFamily v2.0.5
  [b3f8163a] GraphPPL v4.6.2
  [91a5bcdd] Plots v1.40.13
  [a194aa59] ReactiveMP v5.4.3
  [86711068] RxInfer v4.4.2
  [860ef19b] StableRNGs v1.0.2