Pyro-Compatible Distributions
This interface provides a number of PyTorch-style distributions that use funsors internally to perform inference. These high-level objects are based on a wrapping class, FunsorDistribution, which wraps a funsor in a PyTorch-distributions-compatible interface. FunsorDistribution objects can be used directly in Pyro models (using the standard Pyro backend).
FunsorDistribution Base Class
class FunsorDistribution(funsor_dist, batch_shape=torch.Size([]), event_shape=torch.Size([]), dtype='real', validate_args=None)
Bases: pyro.distributions.torch_distribution.TorchDistribution
Distribution wrapper around a Funsor for use in Pyro code. This is typically used as a base class for specific funsor inference algorithms wrapped in a distribution interface.
Parameters:
- funsor_dist (funsor.terms.Funsor) – A funsor with an input named “value” that is treated as a random variable. The distribution should be normalized over “value”.
- batch_shape (torch.Size) – The distribution’s batch shape. This must be in the same order as the input of the funsor_dist, but may contain extra dims of size 1.
- event_shape – The distribution’s event shape.
arg_constraints = {}
support
Hidden Markov Models
class DiscreteHMM(initial_logits, transition_logits, observation_dist, validate_args=None)
Bases: funsor.pyro.distribution.FunsorDistribution
Hidden Markov Model with discrete latent state and arbitrary observation distribution. This uses [1] to parallelize over time, achieving O(log(time)) parallel complexity.
The event_shape of this distribution includes time on the left:
    event_shape = (num_steps,) + observation_dist.event_shape
This distribution supports any combination of homogeneous/heterogeneous time dependency of transition_logits and observation_dist. However, because time is included in this distribution’s event_shape, the homogeneous+homogeneous case will have a broadcastable event_shape with num_steps = 1, allowing log_prob() to work with arbitrary length data:
    # homogeneous + homogeneous case:
    event_shape = (1,) + observation_dist.event_shape
This class should be interchangeable with pyro.distributions.hmm.DiscreteHMM.
References:
[1] Simo Sarkka, Angel F. Garcia-Fernandez (2019) “Temporal Parallelization of Bayesian Filters and Smoothers” https://arxiv.org/pdf/1905.13002.pdf
Parameters:
- initial_logits (Tensor) – A logits tensor for an initial categorical distribution over latent states. Should have rightmost size state_dim and be broadcastable to batch_shape + (state_dim,).
- transition_logits (Tensor) – A logits tensor for transition conditional distributions between latent states. Should have rightmost shape (state_dim, state_dim) (old, new), and be broadcastable to batch_shape + (num_steps, state_dim, state_dim).
- observation_dist (Distribution) – A conditional distribution of observed data conditioned on latent state. The .batch_shape should have rightmost size state_dim and be broadcastable to batch_shape + (num_steps, state_dim). The .event_shape may be arbitrary.
has_rsample
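To make the time-parallel claim concrete, here is a minimal pure-Python sketch of the quantity log_prob() evaluates for a homogeneous model. It is not the library's implementation. The helper names (logsumexp, log_matmul, tree_reduce, hmm_log_prob) are hypothetical, the logits are assumed to be normalized log-probabilities, and the model is assumed to take one transition before each observation. Each time step contributes one (state_dim, state_dim) factor, and because the log-space matrix product is associative, the factors can be combined pairwise in O(log(num_steps)) rounds.

```python
import math


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


def log_matmul(a, b):
    """Log-space matrix product: out[i][k] = logsumexp_j(a[i][j] + b[j][k])."""
    return [[logsumexp([a[i][j] + b[j][k] for j in range(len(b))])
             for k in range(len(b[0]))] for i in range(len(a))]


def tree_reduce(factors):
    """Combine per-step factors pairwise, preserving time order.

    The products within one round are independent of each other, so a
    parallel machine needs only O(log(num_steps)) rounds.
    Assumes at least one factor.
    """
    while len(factors) > 1:
        paired = [log_matmul(factors[i], factors[i + 1])
                  for i in range(0, len(factors) - 1, 2)]
        if len(factors) % 2:  # carry an unpaired trailing factor forward
            paired.append(factors[-1])
        factors = paired
    return factors[0]


def hmm_log_prob(initial_logits, transition_logits, obs_logits):
    """Log marginal likelihood of a sequence (hypothetical helper).

    obs_logits[t][k] is log p(x[t] | state k). The homogeneous
    transition_logits[old][new] is reused at every time step, which is
    why data of arbitrary length can be scored.
    """
    n = len(initial_logits)
    factors = [[[transition_logits[i][j] + step[j] for j in range(n)]
                for i in range(n)] for step in obs_logits]
    total = tree_reduce(factors)
    return logsumexp([initial_logits[i] + total[i][j]
                      for i in range(n) for j in range(n)])


# Example: two latent states, three observed time steps.
log = math.log
initial_logits = [log(0.6), log(0.4)]
transition_logits = [[log(0.7), log(0.3)], [log(0.2), log(0.8)]]
obs_logits = [[log(0.9), log(0.1)], [log(0.5), log(0.5)], [log(0.2), log(0.7)]]
log_p = hmm_log_prob(initial_logits, transition_logits, obs_logits)
```

The pairwise reduction here is a simplified stand-in for the parallel scan of [1]; it yields only the final marginal, not per-step filtering results.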
class GaussianHMM(initial_dist, transition_matrix, transition_dist, observation_matrix, observation_dist, validate_args=None)
Bases: funsor.pyro.distribution.FunsorDistribution
Hidden Markov Model with Gaussians for initial, transition, and observation distributions. This adapts [1] to parallelize over time to achieve O(log(time)) parallel complexity; however, it differs in that it tracks the log normalizer to ensure log_prob() is differentiable.
This corresponds to the generative model:
    z = initial_dist.sample()
    x = []
    for t in range(num_steps):
        z = z @ transition_matrix + transition_dist.sample()
        x.append(z @ observation_matrix + observation_dist.sample())
The event_shape of this distribution includes time on the left:
    event_shape = (num_steps,) + observation_dist.event_shape
This distribution supports any combination of homogeneous/heterogeneous time dependency of transition_dist and observation_dist. However, because time is included in this distribution’s event_shape, the homogeneous+homogeneous case will have a broadcastable event_shape with num_steps = 1, allowing log_prob() to work with arbitrary length data:
    event_shape = (1, obs_dim)  # homogeneous + homogeneous case
This class should be compatible with pyro.distributions.hmm.GaussianHMM, but additionally supports funsor adjoint algorithms.
References:
[1] Simo Sarkka, Angel F. Garcia-Fernandez (2019) “Temporal Parallelization of Bayesian Filters and Smoothers” https://arxiv.org/pdf/1905.13002.pdf
Parameters:
- initial_dist (MultivariateNormal) – A distribution over initial states. This should have batch_shape broadcastable to self.batch_shape. This should have event_shape (hidden_dim,).
- transition_matrix (Tensor) – A linear transformation of hidden state. This should have shape broadcastable to self.batch_shape + (num_steps, hidden_dim, hidden_dim) where the rightmost dims are ordered (old, new).
- transition_dist (MultivariateNormal) – A process noise distribution. This should have batch_shape broadcastable to self.batch_shape + (num_steps,). This should have event_shape (hidden_dim,).
- observation_matrix (Tensor) – A linear transformation from hidden to observed state. This should have shape broadcastable to self.batch_shape + (num_steps, hidden_dim, obs_dim).
- observation_dist (MultivariateNormal or Normal) – An observation noise distribution. This should have batch_shape broadcastable to self.batch_shape + (num_steps,). This should have event_shape (obs_dim,).
has_rsample = True
arg_constraints = {}
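As a rough illustration of what log_prob() computes for the generative model above, the following sketch scores a sequence under a scalar (hidden_dim = obs_dim = 1), homogeneous model with the classic sequential Kalman recursion. This is an assumption-laden stand-in, not the parallel-scan algorithm the class uses: the function names and the scalar parameters (m0, p0, a, q, h, r) are hypothetical, and both noise distributions are assumed to have zero mean. Each step adds one Gaussian log normalizer to the running total.

```python
import math


def normal_log_pdf(x, mean, var):
    """Log density of a scalar Normal with the given mean and variance."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)


def scalar_gaussian_hmm_log_prob(xs, m0, p0, a, q, h, r):
    """Log marginal likelihood of observations xs (hypothetical helper).

    Model: z ~ Normal(m0, p0); then for each step
        z = a * z + Normal(0, q)      # transition
        x = h * z + Normal(0, r)      # observation
    where p0, q, r are variances.
    """
    mean, var = m0, p0
    total = 0.0
    for x in xs:
        # Predict: push the current belief through the transition.
        mean, var = a * mean, a * a * var + q
        # Score: marginal density of the observation under the prediction.
        s = h * h * var + r
        total += normal_log_pdf(x, h * mean, s)
        # Update: condition the belief on the observation.
        gain = var * h / s
        mean = mean + gain * (x - h * mean)
        var = (1 - gain * h) * var
    return total


# Example: score a two-step sequence.
log_p = scalar_gaussian_hmm_log_prob([1.0, -0.2], m0=0.5, p0=2.0,
                                     a=0.9, q=0.3, h=1.5, r=0.4)
```

Because the parameters are the same at every step, the same function scores sequences of any length, which parallels the num_steps = 1 broadcasting behavior described above.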
class GaussianMRF(initial_dist, transition_dist, observation_dist, validate_args=None)
Bases: funsor.pyro.distribution.FunsorDistribution
Temporal Markov Random Field with Gaussian factors for initial, transition, and observation distributions. This adapts [1] to parallelize over time to achieve O(log(time)) parallel complexity; however, it differs in that it tracks the log normalizer to ensure log_prob() is differentiable.
The event_shape of this distribution includes time on the left:
    event_shape = (num_steps,) + observation_dist.event_shape
This distribution supports any combination of homogeneous/heterogeneous time dependency of transition_dist and observation_dist. However, because time is included in this distribution’s event_shape, the homogeneous+homogeneous case will have a broadcastable event_shape with num_steps = 1, allowing log_prob() to work with arbitrary length data:
    event_shape = (1, obs_dim)  # homogeneous + homogeneous case
This class should be compatible with pyro.distributions.hmm.GaussianMRF, but additionally supports funsor adjoint algorithms.
References:
[1] Simo Sarkka, Angel F. Garcia-Fernandez (2019) “Temporal Parallelization of Bayesian Filters and Smoothers” https://arxiv.org/pdf/1905.13002.pdf
Parameters:
- initial_dist (MultivariateNormal) – A distribution over initial states. This should have batch_shape broadcastable to self.batch_shape. This should have event_shape (hidden_dim,).
- transition_dist (MultivariateNormal) – A joint distribution factor over a pair of successive time steps. This should have batch_shape broadcastable to self.batch_shape + (num_steps,). This should have event_shape (hidden_dim + hidden_dim,) (old+new).
- observation_dist (MultivariateNormal) – A joint distribution factor over a hidden and an observed state. This should have batch_shape broadcastable to self.batch_shape + (num_steps,). This should have event_shape (hidden_dim + obs_dim,).
has_rsample = True
class SwitchingLinearHMM(initial_logits, initial_mvn, transition_logits, transition_matrix, transition_mvn, observation_matrix, observation_mvn, exact=False, validate_args=None)
Bases: funsor.pyro.distribution.FunsorDistribution
Switching Linear Dynamical System represented as a Hidden Markov Model.
This corresponds to the generative model:
    z = Categorical(logits=initial_logits).sample()
    y = initial_mvn[z].sample()
    x = []
    for t in range(num_steps):
        z = Categorical(logits=transition_logits[t, z]).sample()
        y = y @ transition_matrix[t, z] + transition_mvn[t, z].sample()
        x.append(y @ observation_matrix[t, z] + observation_mvn[t, z].sample())
Viewed as a dynamic Bayesian network:
    z[t-1] ----> z[t] ---> z[t+1]        Discrete latent class
       |  \       |  \       |   \
       | y[t-1] ----> y[t] ----> y[t+1]  Gaussian latent state
       |  /       |  /       |   /
       V /        V /        V  /
    x[t-1]       x[t]      x[t+1]        Gaussian observation
Let class be the latent class, state be the latent multivariate normal state, and value be the observed multivariate normal value.
Parameters:
- initial_logits (Tensor) – Represents p(class[0]).
- initial_mvn (MultivariateNormal) – Represents p(state[0] | class[0]).
- transition_logits (Tensor) – Represents p(class[t+1] | class[t]).
- transition_matrix (Tensor) –
- transition_mvn (MultivariateNormal) – Together with transition_matrix, this represents p(state[t], state[t+1] | class[t]).
- observation_matrix (Tensor) –
- observation_mvn (MultivariateNormal) – Together with observation_matrix, this represents p(value[t+1], state[t+1] | class[t+1]).
- exact (bool) – If True, perform exact inference at cost exponential in num_steps. If False, use a moment_matching() approximation and use a parallel scan algorithm to reduce parallel complexity to logarithmic in num_steps. Defaults to False.
has_rsample = True
arg_constraints = {}
filter(value)
Compute posterior over final state given a sequence of observations.
Parameters: value (Tensor) – A sequence of observations.
Returns: A posterior distribution over latent states at the final time step, represented as a pair (cat, mvn), where cat is a Categorical distribution over mixture components and mvn is a MultivariateNormal with rightmost batch dimension ranging over mixture components. This can then be used to initialize a sequential Pyro model for prediction.
Return type: tuple
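The exact=False option relies on a moment-matching approximation, which this page does not spell out. As a generic illustration of the idea only (not the library's moment_matching() itself), the sketch below collapses a scalar Gaussian mixture into the single Gaussian with the same mean and variance. The function name and its list-based arguments are hypothetical.

```python
import math


def moment_match(logits, means, variances):
    """Collapse a scalar Gaussian mixture to one Gaussian (hypothetical helper).

    logits are unnormalized log mixture weights. Returns the (mean, variance)
    of the single Gaussian whose first two moments match the mixture.
    """
    # Normalize the weights stably by subtracting the largest logit.
    top = max(logits)
    weights = [math.exp(l - top) for l in logits]
    norm = sum(weights)
    weights = [w / norm for w in weights]
    # First moment: weighted average of component means.
    mean = sum(w * mu for w, mu in zip(weights, means))
    # Second moment: each component contributes var + mean**2.
    second = sum(w * (var + mu * mu)
                 for w, mu, var in zip(weights, means, variances))
    return mean, second - mean * mean


# Two equally weighted components centered at -1 and +1.
mean, var = moment_match([0.0, 0.0], [-1.0, 1.0], [0.5, 0.5])
```

Collapsing the mixture in this way keeps the belief state a fixed size from one step to the next, instead of letting the number of mixture components grow with the number of possible class sequences.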
Conversion Utilities
This module follows a convention for converting between funsors and PyTorch distribution objects. This convention is compatible with NumPy/PyTorch-style broadcasting. Following PyTorch distributions (and TensorFlow distributions), we consider “event shapes” to be on the right and broadcast-compatible “batch shapes” to be on the left.
This module also aims to be forgiving in inputs and pedantic in outputs:
- Methods accept either the superclass torch.distributions.Distribution objects or the subclass pyro.distributions.TorchDistribution objects.
- Methods return only the narrower subclass pyro.distributions.TorchDistribution objects.
tensor_to_funsor(tensor, event_inputs=(), event_output=0, dtype='real')
Convert a torch.Tensor to a funsor.tensor.Tensor.
Note this should not touch data, but may trigger a torch.Tensor.reshape() op.
Parameters:
- tensor (torch.Tensor) – A PyTorch tensor.
- event_inputs (tuple) – A tuple of names for rightmost tensor dimensions. If tensor has these names, they will be converted to result.inputs.
- event_output (int) – The number of tensor dimensions assigned to result.output. These must be on the right of any event_inputs dimensions.
Returns: A funsor.
Return type: funsor.tensor.Tensor
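The layout convention described here (batch dimensions on the left, then the named event_inputs dimensions, then the event_output dimensions on the far right) can be sketched on a bare shape tuple. The helper split_shape below is hypothetical and only partitions a shape; it makes no claim about how the library names or stores batch dimensions.

```python
def split_shape(shape, num_event_inputs, event_output):
    """Partition a shape into (batch, event_inputs, output) parts.

    Hypothetical helper: reads the shape right to left, taking
    event_output dims for the output and num_event_inputs dims for the
    named event inputs; whatever remains on the left is batch shape.
    """
    shape = tuple(shape)
    n_event = num_event_inputs + event_output
    if n_event > len(shape):
        raise ValueError("shape has fewer dims than the event dims requested")
    first = len(shape) - n_event
    second = len(shape) - event_output
    return shape[:first], shape[first:second], shape[second:]


# A tensor of shape (5, 1, 3, 4, 2) with one named event input
# and a rank-2 output.
batch, inputs, output = split_shape((5, 1, 3, 4, 2), 1, 2)
```

With event_output=0 the output part is the empty tuple, which corresponds to a scalar output.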
funsor_to_tensor(funsor_, ndims, event_inputs=())
Convert a funsor.tensor.Tensor to a torch.Tensor.
Note this should not touch data, but may trigger a torch.Tensor.reshape() op.
Parameters:
- funsor_ (funsor.tensor.Tensor) – A funsor.
- ndims (int) – The number of result dims, == result.dim().
- event_inputs (tuple) – Names assigned to rightmost dimensions.
Returns: A PyTorch tensor.
Return type: torch.Tensor
dist_to_funsor(pyro_dist, event_inputs=())
Convert a PyTorch distribution to a Funsor.
Parameters: pyro_dist (torch.distributions.Distribution) – A PyTorch distribution.
Returns: A funsor.
Return type: funsor.terms.Funsor
mvn_to_funsor(pyro_dist, event_inputs=(), real_inputs={})
Convert a joint torch.distributions.MultivariateNormal distribution into a Funsor with multiple real inputs.
This should satisfy:
    sum(d.num_elements for d in real_inputs.values()) == pyro_dist.event_shape[0]
Parameters:
- pyro_dist (torch.distributions.MultivariateNormal) – A multivariate normal distribution over one or more variables of real or vector or tensor type.
- event_inputs (tuple) – A tuple of names for rightmost dimensions. These will be assigned to result.inputs of type Bint.
- real_inputs (OrderedDict) – A dict mapping real variable name to appropriately sized Real. The sum of all .numel() of all real inputs should be equal to the pyro_dist dimension.
Returns: A funsor with given real_inputs and possibly additional Bint inputs.
Return type: funsor.terms.Funsor
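The size constraint on real_inputs can be checked without any library objects. The sketch below is a hypothetical helper that takes plain shape tuples in place of Real domains and assumes the variables occupy the event dimension in dict order. It returns the slice of the flat event vector that each named variable would cover.

```python
from collections import OrderedDict
from math import prod


def real_input_slices(real_shapes, event_size):
    """Map each named real input to its slice of the flat MVN event.

    Hypothetical helper: real_shapes maps names to shape tuples, used
    here in place of funsor domains, and layout in dict order is an
    assumption of this sketch.
    """
    slices = OrderedDict()
    start = 0
    for name, shape in real_shapes.items():
        size = prod(shape)  # prod(()) == 1, so a scalar takes one slot
        slices[name] = slice(start, start + size)
        start += size
    if start != event_size:
        raise ValueError(
            f"real inputs cover {start} elements, expected {event_size}")
    return slices


# A scalar, a length-3 vector and a 2x2 matrix fill an 8-dimensional event.
shapes = OrderedDict([("x", ()), ("y", (3,)), ("z", (2, 2))])
slices = real_input_slices(shapes, event_size=8)
```

A mismatch between the total element count and the event size raises ValueError, which corresponds to violating the equality stated above.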
funsor_to_mvn(gaussian, ndims, event_inputs=())
Convert a Funsor to a pyro.distributions.MultivariateNormal, dropping the normalization constant.
Parameters:
- gaussian (funsor.gaussian.Gaussian or funsor.joint.Joint) – A Gaussian funsor.
- ndims (int) – The number of batch dimensions in the result.
- event_inputs (tuple) – A tuple of names to assign to rightmost dimensions.
Returns: A multivariate normal distribution.
Return type: pyro.distributions.MultivariateNormal
funsor_to_cat_and_mvn(funsor_, ndims, event_inputs)
Convert a labeled Gaussian mixture model to a pair of distributions.
Parameters:
- funsor_ (funsor.joint.Joint) – A Gaussian mixture funsor.
- ndims (int) – The number of batch dimensions in the result.
Returns: A pair (cat, mvn), where cat is a Categorical distribution over mixture components and mvn is a MultivariateNormal with rightmost batch dimension ranging over mixture components.
class AffineNormal(matrix, loc, scale, value_x, value_y)
Bases: funsor.terms.Funsor
Represents a conditional diagonal normal distribution over a random variable Y whose mean is an affine function of a random variable X. The likelihood of X is thus:
    AffineNormal(matrix, loc, scale).condition(y).log_density(x)
which is equivalent to:
    Normal(x @ matrix + loc, scale).to_event(1).log_prob(y)
Parameters:
- matrix (Funsor) – A transformation from X to Y. Should have rightmost shape (x_dim, y_dim).
- loc (Funsor) – A constant offset for Y’s mean. Should have output shape (y_dim,).
- scale (Funsor) – Standard deviation for Y. Should have output shape (y_dim,).
- value_x (Funsor) – A value X.
- value_y (Funsor) – A value Y.
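For readers who want to see the equivalent expression Normal(x @ matrix + loc, scale).to_event(1).log_prob(y) spelled out, here is a plain-Python sketch of the same arithmetic on lists. The function name is hypothetical, and the sketch handles a single unbatched x and y only.

```python
import math


def affine_normal_log_prob(x, matrix, loc, scale, y):
    """Diagonal-normal log density of y with mean x @ matrix + loc.

    Hypothetical helper: x has length x_dim, matrix is a list of x_dim
    rows of length y_dim, and loc, scale, y have length y_dim. The
    per-coordinate log densities are summed, as to_event(1) would do.
    """
    total = 0.0
    for j in range(len(y)):
        mean = sum(x[i] * matrix[i][j] for i in range(len(x))) + loc[j]
        total += (-0.5 * ((y[j] - mean) / scale[j]) ** 2
                  - math.log(scale[j])
                  - 0.5 * math.log(2 * math.pi))
    return total


# x_dim = 2, y_dim = 3; y is placed exactly at the affine mean.
x = [1.0, 2.0]
matrix = [[1.0, 0.0, 1.0],
          [0.0, 1.0, 1.0]]
log_p = affine_normal_log_prob(x, matrix, [0.0, 0.0, 0.0],
                               [1.0, 1.0, 1.0], [1.0, 2.0, 3.0])
```

Viewed as a function of x with y held fixed, the same number is the likelihood of X mentioned above.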
matrix_and_mvn_to_funsor(matrix, mvn, event_dims=(), x_name='value_x', y_name='value_y')
Convert a noisy affine function to a Gaussian. The noisy affine function is defined as:
    y = x @ matrix + mvn.sample()
The result is a non-normalized Gaussian funsor with two real inputs, x_name and y_name, corresponding to a conditional distribution of real vector y given real vector x.
Parameters:
- matrix (torch.Tensor) – A matrix with rightmost shape (x_size, y_size).
- mvn (torch.distributions.MultivariateNormal or torch.distributions.Independent of torch.distributions.Normal) – A multivariate normal distribution with event_shape == (y_size,).
- event_dims (tuple) – A tuple of names for rightmost dimensions. These will be assigned to result.inputs of type Bint.
- x_name (str) – The name of the x random variable.
- y_name (str) – The name of the y random variable.
Returns: A funsor with given real_inputs and possibly additional Bint inputs.
Return type: funsor.terms.Funsor
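One way to see why such a conditional is a Gaussian in (x, y) jointly yet cannot be normalized over both variables is to write it as a quadratic form. The sketch below does this for the scalar case y = m * x + noise with noise ~ Normal(loc, var). All names are hypothetical, and the sketch does not claim to reproduce the library's internal representation.

```python
import math


def affine_gaussian_info_form(m, loc, var):
    """Quadratic-form coefficients for log Normal(y; m * x + loc, var).

    Returns (precision, info_vec, log_const) over v = (x, y) such that
    the log density equals
        -0.5 * v^T precision v + info_vec . v + log_const
    """
    precision = [[m * m / var, -m / var],
                 [-m / var, 1.0 / var]]
    info_vec = [-m * loc / var, loc / var]
    log_const = -0.5 * loc * loc / var - 0.5 * math.log(2 * math.pi * var)
    return precision, info_vec, log_const


def evaluate(precision, info_vec, log_const, x, y):
    """Evaluate the quadratic form at the point (x, y)."""
    v = [x, y]
    quad = sum(v[i] * precision[i][j] * v[j]
               for i in range(2) for j in range(2))
    return -0.5 * quad + sum(info_vec[i] * v[i] for i in range(2)) + log_const


precision, info_vec, log_const = affine_gaussian_info_form(m=2.0, loc=0.5, var=0.25)
value = evaluate(precision, info_vec, log_const, x=0.3, y=-1.2)
# The 2x2 precision has rank one, so its determinant is zero.
determinant = (precision[0][0] * precision[1][1]
               - precision[0][1] * precision[1][0])
```

The density integrates to one over y for each fixed x, but the rank-deficient precision means it has no finite integral over (x, y) together.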