Funsors
Basic Funsors
- class Approximate(*args, **kwargs)
Bases: Funsor
Interpretation-specific approximation with respect to a set of variables.
The default eager interpretation should be exact. The user-facing interface is the Funsor.approximate() method.
- class Cat(name, parts, part_name=None)
Bases: Funsor
Concatenate funsors along an existing input dimension.
- Parameters
- class Funsor(*args, **kwargs)
Bases: object
Abstract base class for immutable functional tensors.
Concrete derived classes must implement __init__() methods taking hashable *args and no optional **kwargs so as to support cons hashing.
Derived classes with .fresh variables must implement an eager_subs() method. Derived classes with .bound variables must implement an _alpha_convert() method.
- Parameters
inputs (OrderedDict) – A mapping from input name to domain. This can be viewed as a typed context or a mapping from free variables to domains.
output (Domain) – An output domain.
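The cons-hashing requirement can be illustrated with a minimal pure-Python sketch. It assumes a hypothetical ConsHashedMeta metaclass and a toy Var class, neither of which is part of funsor: constructing a term twice from equal positional arguments returns the identical object, and unhashable arguments fail immediately because they cannot serve as a cache key.

```python
import weakref


class ConsHashedMeta(type):
    """Hypothetical metaclass: one shared instance per (class, args) key."""

    def __init__(cls, name, bases, dct):
        super().__init__(name, bases, dct)
        # Weak values let instances be garbage-collected once unused.
        cls._cons_cache = weakref.WeakValueDictionary()

    def __call__(cls, *args):
        # Only positional args are accepted; they form the cache key,
        # so they must be hashable (a list argument raises TypeError here).
        instance = cls._cons_cache.get(args)
        if instance is None:
            instance = super().__call__(*args)
            cls._cons_cache[args] = instance
        return instance


class Var(metaclass=ConsHashedMeta):
    """Toy immutable term: a variable name and an output domain label."""

    def __init__(self, name, output):
        self.name = name
        self.output = output


x1 = Var("x", "real")
x2 = Var("x", "real")
y = Var("y", "real")
```

Because identical terms are the same object, equality checks and memoization can rely on object identity.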
- property dtype
- property shape
- property requires_grad
- sample(sampled_vars, sample_inputs=None, rng_key=None)
Create a Monte Carlo approximation to this funsor by replacing functions of sampled_vars with Deltas.
The result is a Funsor with the same .inputs and .output as the original funsor (plus sample_inputs if provided), so that self can be replaced by the sample in expectation computations:

    y = x.sample(sampled_vars)
    assert y.inputs == x.inputs
    assert y.output == x.output
    exact = (x.exp() * integrand).reduce(ops.add)
    approx = (y.exp() * integrand).reduce(ops.add)

If sample_inputs is provided, this creates a batch of samples.
- Parameters
sampled_vars (str, Variable, or set or frozenset thereof) – A set of input variables to sample.
sample_inputs (OrderedDict) – An optional mapping from variable name to Domain over which samples will be batched.
rng_key (None or JAX's random.PRNGKey) – A PRNG state used by the JAX backend to generate random samples.
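The exact-versus-approximate expectation in the sample() example can be mimicked for a single discrete variable. This is a toy sketch in plain Python, not funsor's sampler: the hypothetical sample_as_delta replaces a distribution over Bint[3] by equally weighted point masses, choosing each log density so the total mass matches the original, and the Monte Carlo sum then approximates the exact one.

```python
import math
import random


def exact_expectation(log_weights, integrand):
    """Analog of (x.exp() * integrand).reduce(ops.add) over one Bint input."""
    return sum(math.exp(lw) * integrand(v) for v, lw in enumerate(log_weights))


def sample_as_delta(log_weights, num_samples, rng):
    """Hypothetical helper: replace the density by point masses at sampled values.

    Returns (point, log_density) pairs, a toy stand-in for Delta terms.
    Each log_density is log(total_mass / num_samples), so the point masses
    carry the same total mass as the original (possibly unnormalized) density.
    """
    weights = [math.exp(lw) for lw in log_weights]
    points = rng.choices(range(len(weights)), weights=weights, k=num_samples)
    log_density = math.log(sum(weights) / num_samples)
    return [(p, log_density) for p in points]


def approx_expectation(deltas, integrand):
    """Analog of (y.exp() * integrand).reduce(ops.add) for the sampled y."""
    return sum(math.exp(ld) * integrand(p) for p, ld in deltas)


def square(v):
    return v * v


log_w = [math.log(0.2), math.log(0.5), math.log(0.3)]  # made-up distribution
rng = random.Random(0)  # fixed seed for reproducibility
exact = exact_expectation(log_w, square)  # approximately 1.7 (= 0.5 * 1 + 0.3 * 4)
approx = approx_expectation(sample_as_delta(log_w, 20000, rng), square)
```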
- align(names)
Align this funsor to match given names. This is mainly useful in preparation for extracting .data of a funsor.tensor.Tensor.
- eager_subs(subs)
Internal substitution function. This relies on the user-facing __call__() method to coerce non-Funsors to Funsors. Once all inputs are Funsors, eager_subs() implementations can recurse to call Subs.
- class Independent(*args, **kwargs)
Bases: Funsor
Creates an independent diagonal distribution.
This is equivalent to substitution followed by reduction:

    f = ...  # a batched distribution
    assert f.inputs['x_i'] == Reals[4, 5]
    assert f.inputs['i'] == Bint[3]

    g = Independent(f, 'x', 'i', 'x_i')
    assert g.inputs['x'] == Reals[3, 4, 5]
    assert 'x_i' not in g.inputs
    assert 'i' not in g.inputs

    x = Variable('x', Reals[3, 4, 5])
    g == f(x_i=x['i']).reduce(ops.add, 'i')

- Parameters
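The identity g == f(x_i=x['i']).reduce(ops.add, 'i') can be checked numerically on a toy example. The sketch below uses plain Python with scalar coordinates instead of Reals[4, 5] events, and assumes a hypothetical per-index unit-variance Normal log density f(i, x_i) with made-up means; g substitutes x_i := x[i] and then sums over i.

```python
import math

LOCS = [0.0, 1.0, -2.0]  # made-up means; i ranges over Bint[3]


def f(i, x_i):
    """Toy batched log density: an independent unit-variance Normal per index i."""
    return -0.5 * (x_i - LOCS[i]) ** 2 - 0.5 * math.log(2 * math.pi)


def independent(f, size):
    """Hypothetical helper: g(x) = sum_i f(i, x[i]), i.e. substitute, then reduce."""

    def g(x):
        if len(x) != size:
            raise ValueError("x must have one coordinate per value of i")
        return sum(f(i, x[i]) for i in range(size))

    return g


g = independent(f, len(LOCS))
value = g([0.5, 1.0, -1.0])
```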
- class Lambda(*args, **kwargs)
Bases: Funsor
Lazy inverse to ops.getitem.
This is useful to simulate higher-order functions of integers by representing those functions as arrays.
- Parameters
var (Variable) – A variable to bind.
expr (funsor) – A funsor.
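Representing an integer function as an array is what makes Lambda a lazy inverse of indexing: tabulating a body over i and then indexing at i recovers the body. A minimal pure-Python sketch, assuming hypothetical helpers lam and getitem over a Bint[size] variable, also composes two array-encoded functions, a simple higher-order operation.

```python
def lam(size, body):
    """Hypothetical toy Lambda over a variable i of type Bint[size]: tabulate body(i)."""
    return [body(i) for i in range(size)]


def getitem(array, i):
    """Toy ops.getitem: indexing undoes the tabulation."""
    return array[i]


def compose(f_arr, g_arr):
    """Higher-order operation on array-encoded functions: (f o g)(i) = f(g(i))."""
    return lam(len(g_arr), lambda i: getitem(f_arr, getitem(g_arr, i)))


squares = lam(5, lambda i: i * i)
shift = lam(5, lambda i: (i + 1) % 5)
shift_twice = compose(shift, shift)
```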
- class Number(data, dtype=None)
Bases: Funsor
Funsor backed by a Python number.
- Parameters
data (numbers.Number) – A Python number.
dtype – A nonnegative integer or the string “real”.
- class Reduce(*args, **kwargs)
Bases: Funsor
Lazy reduction over multiple variables.
The user-facing interface is the Funsor.reduce() method.
- Parameters
op (AssociativeOp) – An associative operator.
arg (funsor) – An argument to be reduced.
reduced_vars (frozenset) – A set of variables over which to reduce.
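A toy model of Reduce over named integer inputs, written in plain Python rather than with funsor: the hypothetical reduce_vars takes a table keyed by tuples of input values, folds op over every entry that agrees on the kept inputs, and returns the remaining input names. The ops used here (addition, max, logaddexp) are associative and commutative, so the result does not depend on the order in which entries are combined.

```python
import itertools
import math
from functools import reduce as fold


def logaddexp(a, b):
    """Numerically stable log(exp(a) + exp(b))."""
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def reduce_vars(op, names, table, reduced_vars):
    """Hypothetical eager Reduce: fold op over the inputs named in reduced_vars.

    names: tuple of input names; table: dict from value tuples to numbers.
    Returns (kept_names, reduced_table).
    """
    kept = tuple(n for n in names if n not in reduced_vars)
    keep_idx = [names.index(n) for n in kept]
    groups = {}
    for key, value in table.items():
        groups.setdefault(tuple(key[k] for k in keep_idx), []).append(value)
    return kept, {k: fold(op, vs) for k, vs in groups.items()}


names = ("i", "j")  # i: Bint[2], j: Bint[3]
table = {(i, j): float(10 * i + j) for i, j in itertools.product(range(2), range(3))}
kept, summed = reduce_vars(lambda a, b: a + b, names, table, frozenset({"j"}))
```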
- class Scatter(*args, **kwargs)
Bases: Funsor
Transpose of structurally linear Subs, followed by Reduce.
For injective scatter operations this should satisfy the equation:

    if destin = Scatter(op, subs, source, frozenset())
    then source = Subs(destin, subs)

The reduced_vars argument is merely for computational efficiency; it could always be split out into a separate .reduce(). For example, in the following equation the left-hand side uses much less memory than the right-hand side:

    Scatter(op, subs, source, reduced_vars) ==
        Scatter(op, subs, source, frozenset()).reduce(op, reduced_vars)

Warning
This is currently implemented only for injective scatter operations. In particular, this does not allow accumulation behavior like scatter-add.
Note
Scatter(ops.add, ...) is the funsor analog of numpy.add.at() or torch.index_put() or jax.lax.scatter_add(). For injective substitutions, Scatter(ops.add, ...) is roughly equivalent to the tensor operation:

    result = zeros(...)  # since zero is the additive unit
    result[subs] = source

- Parameters
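The injectivity requirement and the round trip source = Subs(destin, subs) can be seen in a one-dimensional toy. The sketch below is plain Python with hypothetical helpers: scatter_set writes each source entry to its target index in a destination initialized with the additive unit, mirroring result = zeros(...); result[subs] = source, and subs_gather reads the same indices back.

```python
def scatter_set(size, subs, source, unit=0.0):
    """Hypothetical toy injective Scatter into a length-`size` destination.

    subs[k] is the destination index for source[k]; untouched entries keep
    the op's unit (0.0 for addition).
    """
    if len(set(subs)) != len(subs):
        raise ValueError("only injective substitutions are supported")
    destin = [unit] * size
    for k, target in enumerate(subs):
        destin[target] = source[k]
    return destin


def subs_gather(destin, subs):
    """Toy Subs: read destin back at the substituted indices."""
    return [destin[target] for target in subs]


subs = [3, 0, 4]
source = [1.5, -2.0, 7.0]
destin = scatter_set(5, subs, source)
```

Rejecting repeated indices mirrors the warning above: with a collision, a plain assignment would silently keep only the last value instead of accumulating.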
- class Stack(*args, **kwargs)
Bases: Funsor
Stack of funsors along a new input dimension.
- Parameters
- class Slice(name, *args, **kwargs)
Bases: Funsor
Symbolic representation of a Python slice object.
- Parameters
- class Subs(arg, subs)
Bases: Funsor
Lazy substitution of the form x(u=y, v=z).
- Parameters
arg (Funsor) – A funsor being substituted into.
subs (tuple) – A tuple of (name, value) pairs, where name is a string and value can be coerced to a Funsor via to_funsor().
- class Unary(*args, **kwargs)
Bases: Funsor
Lazy unary operation.
- Parameters
op (Op) – A unary operator.
arg (Funsor) – An argument.
- class Variable(*args, **kwargs)
Bases: Funsor
Funsor representing a single free variable.
- Parameters
name (str) – A variable name.
output (funsor.domains.Domain) – A domain.
- to_data(x, name_to_dim=None, **kwargs)
- to_data(x: Funsor, name_to_dim=None)
- to_data(x: Number, name_to_dim=None)
- to_data(x: Tensor, name_to_dim=None)
- to_data(funsor_dist: Distribution, name_to_dim=None)
- to_data(funsor_dist: Independent[Union[Independent, Distribution], str, str, str], name_to_dim=None)
- to_data(funsor_dist: Gaussian, name_to_dim=None)
- to_data(funsor_dist: Contraction[Union[LogaddexpOp, NullOp], AddOp, frozenset, Tuple[Union[Tensor, Number], Gaussian]], name_to_dim=None)
- to_data(funsor_dist: Multinomial, name_to_dim=None)
- to_data(funsor_dist: Delta, name_to_dim=None)
- to_data(expr: Unary[TransformOp, Union[Unary, Variable]], name_to_dim=None)
- to_data(x: Constant, name_to_dim=None)
Extract a Python object from a Funsor.
Raises a ValueError if free variables remain or if the funsor is lazy.
- Parameters
x – An object, possibly a Funsor.
name_to_dim (OrderedDict) – An optional inputs hint.
- Returns
A non-funsor equivalent to x.
- Raises
ValueError if any free variables remain.
PatternMissingError if the funsor is not fully evaluated.
- to_funsor(x, output=None, dim_to_name=None, **kwargs)
- to_funsor(x: Funsor, output=None, dim_to_name=None)
- to_funsor(name: str, output=None)
- to_funsor(x: Number, output=None, dim_to_name=None)
- to_funsor(s: slice, output=None, dim_to_name=None)
- to_funsor(args: tuple, output=None, dim_to_name=None)
- to_funsor(x: generic, output=None, dim_to_name=None)
- to_funsor(x: ndarray, output=None, dim_to_name=None)
- to_funsor(backend_dist: Beta, output=None, dim_to_name=None)
- to_funsor(backend_dist: Cauchy, output=None, dim_to_name=None)
- to_funsor(backend_dist: Chi2, output=None, dim_to_name=None)
- to_funsor(backend_dist: _PyroWrapper_BernoulliProbs, output=None, dim_to_name=None)
- to_funsor(backend_dist: _PyroWrapper_BernoulliLogits, output=None, dim_to_name=None)
- to_funsor(backend_dist: Binomial, output=None, dim_to_name=None)
- to_funsor(backend_dist: Categorical, output=None, dim_to_name=None)
- to_funsor(backend_dist: _PyroWrapper_CategoricalLogits, output=None, dim_to_name=None)
- to_funsor(pyro_dist: Delta, output=None, dim_to_name=None)
- to_funsor(backend_dist: Dirichlet, output=None, dim_to_name=None)
- to_funsor(backend_dist: DirichletMultinomial, output=None, dim_to_name=None)
- to_funsor(backend_dist: Exponential, output=None, dim_to_name=None)
- to_funsor(backend_dist: Gamma, output=None, dim_to_name=None)
- to_funsor(backend_dist: GammaPoisson, output=None, dim_to_name=None)
- to_funsor(backend_dist: Geometric, output=None, dim_to_name=None)
- to_funsor(backend_dist: Gumbel, output=None, dim_to_name=None)
- to_funsor(backend_dist: HalfCauchy, output=None, dim_to_name=None)
- to_funsor(backend_dist: HalfNormal, output=None, dim_to_name=None)
- to_funsor(backend_dist: Laplace, output=None, dim_to_name=None)
- to_funsor(backend_dist: Logistic, output=None, dim_to_name=None)
- to_funsor(backend_dist: LowRankMultivariateNormal, output=None, dim_to_name=None)
- to_funsor(backend_dist: Multinomial, output=None, dim_to_name=None)
- to_funsor(backend_dist: MultivariateNormal, output=None, dim_to_name=None)
- to_funsor(backend_dist: NonreparameterizedBeta, output=None, dim_to_name=None)
- to_funsor(backend_dist: NonreparameterizedDirichlet, output=None, dim_to_name=None)
- to_funsor(backend_dist: NonreparameterizedGamma, output=None, dim_to_name=None)
- to_funsor(backend_dist: NonreparameterizedNormal, output=None, dim_to_name=None)
- to_funsor(backend_dist: Normal, output=None, dim_to_name=None)
- to_funsor(backend_dist: Pareto, output=None, dim_to_name=None)
- to_funsor(backend_dist: Poisson, output=None, dim_to_name=None)
- to_funsor(backend_dist: StudentT, output=None, dim_to_name=None)
- to_funsor(backend_dist: Uniform, output=None, dim_to_name=None)
- to_funsor(backend_dist: VonMises, output=None, dim_to_name=None)
- to_funsor(backend_dist: ContinuousBernoulli, output=None, dim_to_name=None)
- to_funsor(backend_dist: FisherSnedecor, output=None, dim_to_name=None)
- to_funsor(backend_dist: NegativeBinomial, output=None, dim_to_name=None)
- to_funsor(backend_dist: OneHotCategorical, output=None, dim_to_name=None)
- to_funsor(backend_dist: RelaxedBernoulli, output=None, dim_to_name=None)
- to_funsor(backend_dist: Weibull, output=None, dim_to_name=None)
- to_funsor(tfm: Transform, output=None, dim_to_name=None, real_inputs=None)
- to_funsor(tfm: ExpTransform, output=None, dim_to_name=None, real_inputs=None)
- to_funsor(tfm: TanhTransform, output=None, dim_to_name=None, real_inputs=None)
- to_funsor(tfm: SigmoidTransform, output=None, dim_to_name=None, real_inputs=None)
- to_funsor(tfm: _InverseTransform, output=None, dim_to_name=None, real_inputs=None)
- to_funsor(tfm: ComposeTransform, output=None, dim_to_name=None, real_inputs=None)
- to_funsor(backend_dist: ExpandedDistribution, output=None, dim_to_name=None)
- to_funsor(backend_dist: Independent, output=None, dim_to_name=None)
- to_funsor(backend_dist: MaskedDistribution, output=None, dim_to_name=None)
- to_funsor(backend_dist: TransformedDistribution, output=None, dim_to_name=None)
- to_funsor(pyro_dist: Bernoulli, output=None, dim_to_name=None)
- to_funsor(x: ProvenanceTensor, output=None, dim_to_name=None)
- to_funsor(x: Tensor, output=None, dim_to_name=None)
- to_funsor(pyro_dist: FunsorDistribution, output=None, dim_to_name=None)
Convert to a Funsor. Only Funsors and scalars are accepted.
- Parameters
x – An object.
output (funsor.domains.Domain) – An optional output hint.
dim_to_name (OrderedDict) – An optional mapping from negative batch dimensions to name strings.
- Returns
A Funsor equivalent to x.
- Return type
Funsor
- Raises
ValueError
Delta
- class Delta(*args)
Bases: Funsor
Normalized delta distribution binding multiple variables.
There are three syntaxes supported for constructing Deltas:

    Delta(((name1, (point1, log_density1)),
           (name2, (point2, log_density2)),
           (name3, (point3, log_density3))))

or for a single name:

    Delta(name, point, log_density)

or for the default log_density == 0:

    Delta(name, point)

- Parameters
terms (tuple) – A tuple of tuples of the form (name, (point, log_density)).
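The three constructor syntaxes all describe the same terms tuple. As a hypothetical illustration, not funsor's own argument handling, delta_terms below normalizes each form to ((name, (point, log_density)), ...), filling in log_density = 0 for the two-argument form.

```python
def delta_terms(*args):
    """Hypothetical helper normalizing the three documented Delta syntaxes.

      delta_terms(((name, (point, log_density)), ...))
      delta_terms(name, point, log_density)
      delta_terms(name, point)  # log_density defaults to 0
    """
    if len(args) == 1:
        terms = args[0]
    elif len(args) == 2:
        name, point = args
        terms = ((name, (point, 0)),)
    elif len(args) == 3:
        name, point, log_density = args
        terms = ((name, (point, log_density)),)
    else:
        raise TypeError("expected 1, 2, or 3 positional arguments")
    # Rebuild as tuples so the result is hashable, like funsor constructor args.
    return tuple((name, (point, log_density)) for name, (point, log_density) in terms)
```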
Tensor
- Einsum(equation, *operands)
Wrapper around torch.einsum() or np.einsum() to operate on real-valued Funsors.
Note this operates only on the output tensor. To perform sum-product contractions on named dimensions, instead use + and Reduce.
- Parameters
equation (str) – A torch.einsum() or np.einsum() equation.
operands (tuple) – A tuple of input funsors.
- class Function(*args, **kwargs)
Bases: Funsor
Funsor wrapped by a native PyTorch or NumPy function.
Functions are assumed to support broadcasting and can be eagerly evaluated on funsors with free variables of int type (i.e. batch dimensions).
Functions are usually created via the function() decorator.
- class Tensor(data, inputs=None, dtype='real')
Bases: Funsor
Funsor backed by a PyTorch Tensor or a NumPy ndarray.
This follows the torch.distributions convention of arranging named “batch” dimensions on the left and remaining “event” dimensions on the right. The output shape is determined by all remaining dims. For example:

    import torch
    from funsor.domains import Bint, Reals
    from funsor.tensor import Tensor

    data = torch.zeros(5, 4, 3, 2)
    x = Tensor(data, {"i": Bint[5], "j": Bint[4]})
    assert x.output == Reals[3, 2]

Operators like matmul and .sum() operate only on the output shape, and will not change the named inputs.
- Parameters
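The batch-left, event-right convention means the output shape is whatever remains after peeling off one leading dimension per named input. A minimal pure-Python sketch, using a hypothetical split_shape that works on shapes only and requires each batch size to match its input exactly, reproduces the Reals[3, 2] output from the example.

```python
def split_shape(data_shape, inputs):
    """Hypothetical shape rule: one leading batch dim per named input, rest is output.

    inputs maps input names to Bint sizes, in dimension order.
    """
    sizes = tuple(inputs.values())
    batch = tuple(data_shape[:len(sizes)])
    if batch != sizes:
        raise ValueError(f"batch shape {batch} does not match input sizes {sizes}")
    return tuple(data_shape[len(sizes):])


output_shape = split_shape((5, 4, 3, 2), {"i": 5, "j": 4})
```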
- property requires_grad
- new_arange(name, *args, **kwargs)
Helper to create a named torch.arange() or np.arange() funsor. In some cases this can be replaced by a symbolic Slice.
- align_tensor(new_inputs, x, expand=False)
Permute and add dims to a tensor to match desired new_inputs.
- Parameters
new_inputs (OrderedDict) – A target set of inputs.
x (funsor.terms.Funsor) – A Tensor or Number.
expand (bool) – If False (default), set the result size to 1 for any input in new_inputs that x lacks; if True, expand to the size given in new_inputs.
- Returns
A number or torch.Tensor or np.ndarray that can be broadcast to other tensors with inputs new_inputs.
- Return type
int or float or torch.Tensor or np.ndarray
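The permute-then-reshape behavior of align_tensor can be traced on shapes alone. The sketch below is a hypothetical shape-only helper in plain Python: it computes the permutation that brings the batch dimensions of x into new_inputs order and the resulting shape, with size 1 for inputs that x lacks unless expand=True.

```python
def aligned_shape(x_inputs, output_shape, new_inputs, expand=False):
    """Hypothetical shape-only analog of align_tensor.

    x_inputs, new_inputs: dicts {name: Bint size} in dimension order.
    Returns (perm, shape): perm lists x's batch dims in new_inputs order;
    shape has size 1 for inputs x lacks, or their full size if expand=True.
    """
    extra = [name for name in x_inputs if name not in new_inputs]
    if extra:
        raise ValueError(f"inputs {extra} of x are missing from new_inputs")
    x_names = list(x_inputs)
    perm = [x_names.index(name) for name in new_inputs if name in x_inputs]
    batch = tuple(
        x_inputs[name] if name in x_inputs else (size if expand else 1)
        for name, size in new_inputs.items()
    )
    return perm, batch + tuple(output_shape)


perm, shape = aligned_shape({"j": 4, "i": 5}, (3,), {"i": 5, "k": 2, "j": 4})
```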
- align_tensors(*args, **kwargs)
Permute multiple tensors before applying a broadcasted op.
This is mainly useful for implementing eager funsor operations.
- Parameters
*args (funsor.terms.Funsor) – Multiple Tensors and Numbers.
expand (bool) – Whether to expand input tensors. Defaults to False.
- Returns
A pair (inputs, tensors) where tensors are all torch.Tensors or np.ndarrays that can be broadcast together to a single tensor with the given inputs.
- Return type
tuple
- function(*signature)
Decorator to wrap a PyTorch/NumPy function, using either type hints or explicit type annotations.
Example:

    import torch
    import funsor.tensor
    from funsor.domains import Real, Reals

    # Using type hints:
    @funsor.tensor.function
    def matmul(x: Reals[3, 4], y: Reals[4, 5]) -> Reals[3, 5]:
        return torch.matmul(x, y)

    # Using explicit type annotations:
    @funsor.tensor.function(Reals[3, 4], Reals[4, 5], Reals[3, 5])
    def matmul(x, y):
        return torch.matmul(x, y)

    # Pass scale_tril by keyword: the second positional argument is covariance_matrix.
    @funsor.tensor.function(Reals[10], Reals[10, 10], Reals[10], Real)
    def mvn_log_prob(loc, scale_tril, x):
        d = torch.distributions.MultivariateNormal(loc, scale_tril=scale_tril)
        return d.log_prob(x)

To support functions that output nested tuples of tensors, specify a nested Tuple of output types, for example:

    from typing import Tuple
    from funsor.domains import Bint

    @funsor.tensor.function
    def max_and_argmax(x: Reals[8]) -> Tuple[Real, Bint[8]]:
        return torch.max(x, dim=-1)

- Parameters
*signature – A sequence of input domains followed by a final output domain or nested tuple of output domains.
- tensordot(x, y, dims)
Wrapper around torch.tensordot() or np.tensordot() to operate on real-valued Funsors.
Note this operates only on the output tensor. To perform sum-product contractions on named dimensions, instead use + and Reduce.
Arguments should satisfy:

    len(x.shape) >= dims
    len(y.shape) >= dims
    dims == 0 or x.shape[-dims:] == y.shape[:dims]

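The three constraints for tensordot, and the result shape they imply, can be expressed as a small shape checker. This is a hypothetical pure-Python helper, not part of funsor; note that dims == 0 needs special care because x_shape[-0:] is the whole shape rather than an empty one.

```python
def tensordot_shape(x_shape, y_shape, dims):
    """Hypothetical checker: result shape of tensordot(x, y, dims) on output shapes."""
    if len(x_shape) < dims or len(y_shape) < dims:
        raise ValueError("both operands need at least `dims` dimensions")
    # x_shape[-0:] is the whole shape, so dims == 0 must skip this comparison.
    if dims != 0 and tuple(x_shape[-dims:]) != tuple(y_shape[:dims]):
        raise ValueError("trailing dims of x must equal leading dims of y")
    # Slice with an explicit stop: x_shape[:-0] would be empty, not the full shape.
    return tuple(x_shape[:len(x_shape) - dims]) + tuple(y_shape[dims:])
```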
Gaussian
- class BlockMatrix(shape)
Bases: object
Jit-compatible helper to build blockwise matrices. Syntax is similar to torch.zeros():

    x = BlockMatrix((100, 20, 20))
    x[..., 0:4, 0:4] = x11
    x[..., 0:4, 6:10] = x12
    x[..., 6:10, 0:4] = x12.transpose(-1, -2)
    x[..., 6:10, 6:10] = x22
    x = x.as_tensor()
    assert x.shape == (100, 20, 20)

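The slice-assignment pattern can be mimicked without a tensor library. ToyBlockMatrix below is a hypothetical list-based analog with no batch dimensions: it records blocks as they are assigned and assembles a dense zero-filled matrix on request, here producing a symmetric 5 x 5 matrix from 2 x 2 blocks.

```python
class ToyBlockMatrix:
    """Hypothetical list-based analog of BlockMatrix (no batch dims, no jit)."""

    def __init__(self, shape):
        self.shape = shape  # (rows, cols)
        self.parts = {}

    def __setitem__(self, key, block):
        # Record the block under its (row slice, column slice) bounds.
        rows, cols = key
        self.parts[(rows.start, rows.stop, cols.start, cols.stop)] = block

    def as_lists(self):
        """Assemble a dense matrix; entries not covered by any block stay 0.0."""
        n_rows, n_cols = self.shape
        dense = [[0.0] * n_cols for _ in range(n_rows)]
        for (r0, r1, c0, c1), block in self.parts.items():
            if len(block) != r1 - r0 or any(len(row) != c1 - c0 for row in block):
                raise ValueError("block shape does not match its slice")
            for r, row in zip(range(r0, r1), block):
                dense[r][c0:c1] = row
        return dense


def transpose(block):
    return [list(col) for col in zip(*block)]


x11 = [[1.0, 2.0], [2.0, 5.0]]
x12 = [[0.5, 0.0], [0.0, 0.5]]
x22 = [[3.0, 0.0], [0.0, 3.0]]
m = ToyBlockMatrix((5, 5))
m[0:2, 0:2] = x11
m[0:2, 3:5] = x12
m[3:5, 0:2] = transpose(x12)
m[3:5, 3:5] = x22
dense = m.as_lists()
```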
- class BlockVector(shape)
Bases: object
Jit-compatible helper to build blockwise vectors. Syntax is similar to torch.zeros():

    x = BlockVector((100, 20))
    x[..., 0:4] = x1
    x[..., 6:10] = x2
    x = x.as_tensor()
    assert x.shape == (100, 20)

- class Gaussian(white_vec=None, prec_sqrt=None, inputs=None, *, mean=None, info_vec=None, precision=None, scale_tril=None, covariance=None)
Bases: Funsor
Funsor representing a batched Gaussian log-density function.
Gaussians are the internal representation for joint and conditional multivariate normal distributions and multivariate normal likelihoods. Mathematically, a Gaussian represents the quadratic log density function:

    f(x) = -0.5 * || x @ prec_sqrt - white_vec ||^2
         = -0.5 * < x @ prec_sqrt - white_vec | x @ prec_sqrt - white_vec >
         = -0.5 * < x | prec_sqrt @ prec_sqrt.T | x >
           + < x | prec_sqrt | white_vec >
           - 0.5 * || white_vec ||^2

Internally Gaussians use a square root information filter (SRIF) representation consisting of a square root of the precision matrix prec_sqrt and a vector in the whitened space white_vec. This representation allows space-efficient construction of Gaussians with incomplete information, i.e. with zero eigenvalues in the precision matrix. These incomplete log densities arise when making low-dimensional observations of higher-dimensional hidden state. Sampling and marginalization are supported only for full-rank Gaussians or full-rank subsets of Gaussians. See the rank and is_full_rank properties.
Note
Gaussians are not normalized probability distributions; rather, they are canonicalized to evaluate to zero log density at their maximum: f(prec_sqrt \ white_vec) = 0. Not only are Gaussians non-normalized, but they may be rank deficient and non-normalizable, in which case sampling and marginalization are supported only on full-rank subsets of variables.
- Parameters
white_vec (torch.Tensor) – A batched white noise vector, where white_vec = prec_sqrt.T @ mean. Alternatively you can specify one of the kwargs mean or info_vec, which will be converted to white_vec.
prec_sqrt (torch.Tensor) – A batched square root of the positive semidefinite precision matrix. This need not be square, and typically has shape prec_sqrt.shape == white_vec.shape[:-1] + (dim, rank), where dim is the total flattened size of real inputs and rank = white_vec.shape[-1]. Alternatively you can specify one of the kwargs precision, covariance, or scale_tril, which will be converted to prec_sqrt.
inputs (OrderedDict) – Mapping from name to Domain.
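The expansion of f(x) and the zero log density at the mean can be checked numerically. This pure-Python sketch assumes a made-up full-rank 2 x 2 prec_sqrt and mean, builds white_vec = mean @ prec_sqrt (the row-vector form of prec_sqrt.T @ mean), and compares the squared-norm and expanded forms of the log density. The helper names are hypothetical.

```python
def matvec_row(x, m):
    """Row vector times matrix: (x @ m)[k] = sum_i x[i] * m[i][k]."""
    return [sum(x[i] * m[i][k] for i in range(len(x))) for k in range(len(m[0]))]


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def log_density(x, prec_sqrt, white_vec):
    """f(x) = -0.5 * ||x @ prec_sqrt - white_vec||^2 (toy, no batching)."""
    r = [a - b for a, b in zip(matvec_row(x, prec_sqrt), white_vec)]
    return -0.5 * dot(r, r)


def log_density_expanded(x, prec_sqrt, white_vec):
    """Same quantity via -0.5 <x|P P^T|x> + <x|P|w> - 0.5 ||w||^2."""
    xp = matvec_row(x, prec_sqrt)  # x @ P, so <x|P P^T|x> = xp . xp
    return -0.5 * dot(xp, xp) + dot(xp, white_vec) - 0.5 * dot(white_vec, white_vec)


# dim = 2 real coordinates, rank = 2 (full rank); the numbers are made up.
P = [[2.0, 0.0],
     [1.0, 3.0]]
mean = [0.5, -1.0]
w = matvec_row(mean, P)  # white_vec
x = [1.0, 2.0]
```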
- compression_threshold = 2
- classmethod set_compression_threshold(threshold: float)
Context manager to set the rank compression threshold.
To save space, Gaussians compress wide prec_sqrt matrices down to square. However, compression uses a QR decomposition, which can be expensive and has unstable gradients when the resulting precision matrix is rank deficient. To balance space and time costs and numerical stability, compression is triggered only for prec_sqrt matrices whose width-to-height ratio is greater than threshold.
- Parameters
threshold (float) – Defaults to 2. To optimize for space and deterministic computations, set threshold = 1. To optimize for the fewest QR decompositions and numerical stability, set threshold = math.inf.
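The threshold rule reduces to a single comparison on the shape of prec_sqrt, whose height is dim and whose width is rank. The hypothetical should_compress below encodes only that rule; it does not perform the QR compression itself.

```python
import math


def should_compress(dim, rank, threshold=2.0):
    """Hypothetical encoding of the rule for prec_sqrt of shape (dim, rank).

    Compression happens only when width / height, i.e. rank / dim,
    is strictly greater than threshold.
    """
    return rank / dim > threshold


default_case = should_compress(4, 12)                        # ratio 3 vs default 2
space_optimized = should_compress(4, 5, threshold=1)         # any wide matrix
never = should_compress(4, 1000, threshold=math.inf)         # no QR at all
```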
- property rank
- property is_full_rank
Joint
Contraction
- class Contraction(*args, **kwargs)
Bases: Funsor
Declarative representation of a finitary sum-product operation.
After normalization via the normalize() interpretation, contractions will canonically order their terms by type:

    Delta, Number, Tensor, Gaussian

- GaussianMixture
alias of Contraction
Integrate
Constant
- class ConstantMeta(name, bases, dct)
Bases: FunsorMeta
Wrapper to convert const_inputs to a tuple.
- class Constant(const_inputs, arg)
Bases: Funsor
Funsor that is constant with respect to const_inputs.
Constant can be used for provenance tracking.
Examples:

    a = Constant(OrderedDict(x=Real, y=Bint[3]), Number(0))
    a(y=1)  # returns Constant(OrderedDict(x=Real), Number(0))
    a(x=2, y=1)  # returns Number(0)

    d = Tensor(torch.tensor([1, 2, 3]))["y"]
    a + d  # returns Constant(OrderedDict(x=Real), d)

    c = Constant(OrderedDict(x=Bint[3]), Number(1))
    c.reduce(ops.add, "x")  # returns Number(3)

- Parameters
const_inputs (dict) – A mapping from input name (str) to datatype (funsor.domains.Domain).
arg (funsor) – A funsor that is constant with respect to const_inputs.
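The substitution and reduction rules in the Constant examples can be mirrored with a toy class. ToyConstant is hypothetical and plain Python: substituting a constant input drops it without touching the value, and summing over a Bint[n] input multiplies the value by n, matching c.reduce(ops.add, "x") returning Number(3).

```python
class ToyConstant:
    """Hypothetical stand-in for Constant: a value that ignores const_inputs."""

    def __init__(self, const_inputs, value):
        self.const_inputs = dict(const_inputs)  # name -> Bint size (int) or "real"
        self.value = value

    def subs(self, **values):
        """Substituting a constant input just drops it; the value is unchanged."""
        remaining = {k: v for k, v in self.const_inputs.items() if k not in values}
        return ToyConstant(remaining, self.value) if remaining else self.value

    def reduce_add(self, name):
        """Summing over a Bint[n] input adds the constant value n times."""
        size = self.const_inputs[name]
        if not isinstance(size, int):
            raise TypeError("only integer (Bint) inputs can be summed here")
        remaining = {k: v for k, v in self.const_inputs.items() if k != name}
        total = self.value * size
        return ToyConstant(remaining, total) if remaining else total


a = ToyConstant({"x": "real", "y": 3}, 0)
c = ToyConstant({"x": 3}, 1)
```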