Operations
Operation classes
- class LogAbsDetJacobianOp(*args, **kwargs)
Bases: BinaryOp
- static default(x, y, fn)
- dispatcher = <dispatched log_abs_det_jacobian>
- name = 'log_abs_det_jacobian'
- signature = <Signature (x, y, fn)>
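LogAbsDetJacobianOp has no description here. Judging from its name and its (x, y, fn) signature, it evaluates log |det J| of a transform fn at an input x whose image is y = fn(x). The pure-Python sketch below illustrates that quantity for an elementwise exp transform, where the Jacobian is diagonal and log |det J| reduces to the sum of x. ExpTransform, log_abs_det_jacobian and numeric_ladj are hypothetical names for this example, and having the op defer to fn.log_abs_det_jacobian(x, y) is an assumption, not funsor's confirmed implementation.

```python
import math


class ExpTransform:
    """Hypothetical elementwise transform y = exp(x) on flat lists (not a funsor class)."""

    def __call__(self, x):
        return [math.exp(xi) for xi in x]

    def log_abs_det_jacobian(self, x, y):
        # The Jacobian of an elementwise map is diagonal with entries
        # dy_i/dx_i = exp(x_i), so log|det J| = sum_i x_i.
        # y is accepted only to match the (x, y) interface.
        return sum(x)


def log_abs_det_jacobian(x, y, fn):
    """Hypothetical binary op with signature (x, y, fn) that defers to fn."""
    return fn.log_abs_det_jacobian(x, y)


def numeric_ladj(fn, x, h=1e-6):
    """Central-difference estimate of log|det J|, valid only for elementwise fn."""
    total = 0.0
    for i in range(len(x)):
        xp, xm = list(x), list(x)
        xp[i] += h
        xm[i] -= h
        deriv = (fn(xp)[i] - fn(xm)[i]) / (2 * h)
        total += math.log(abs(deriv))
    return total


x = [0.5, -1.0, 2.0]
fn = ExpTransform()
y = fn(x)
print(log_abs_det_jacobian(x, y, fn))  # 1.5
print(numeric_ladj(fn, x))             # close to 1.5
```

The finite-difference check is a sanity test of the analytic formula; it only holds for elementwise transforms, because it ignores off-diagonal Jacobian entries.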
- class Op(*args, **kwargs)
Bases: object
Abstract base class for all mathematical operations on ground terms.
Ops take arity-many leftmost positional args that may be funsors, followed by additional non-funsor args and kwargs. The additional args and kwargs must have default values.
When wrapping new backend ops, keep in mind these restrictions, which may require you to wrap backend functions before making them into ops:
  - Create new ops only by decorating a default implementation with @UnaryOp.make, @BinaryOp.make, etc.
  - Register backend-specific implementations via @my_op.register(type1), @my_op.register(type1, type2), etc. for arity 1, 2, etc. Patterns may include only the first arity-many types.
  - Only the first arity-many arguments may be funsors. Remaining args and kwargs must all be ground Python data.
- Variables
arity (int) – The number of funsor arguments this op takes. Must be defined by subclasses.
- Parameters
*args
**kwargs – All extra arguments to this op, excluding the arguments up to .arity.
- arity = NotImplemented
- register(*pattern)
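The restrictions on Op can be made concrete with a deliberately simplified pure-Python sketch of the same pattern: an op is created by decorating a default implementation, type-specific implementations are registered on patterns of exactly arity types, and dispatch looks only at the first arity arguments, while the remaining arguments are keyword arguments with defaults. MiniOp, MiniUnaryOp, MiniBinaryOp and scale_add are hypothetical names; this is not funsor's Op class or its dispatch machinery, and plain Python lists stand in for a backend array type.

```python
import itertools


class MiniOp:
    """Hypothetical, simplified arity-aware op; not funsor's Op class."""

    arity = NotImplemented  # must be defined by subclasses

    def __init__(self, default, name):
        self.default = default
        self.name = name
        self.registry = {}  # maps a tuple of `arity` types to an implementation

    @classmethod
    def make(cls, fn):
        # New ops are created only by decorating a default implementation.
        return cls(fn, fn.__name__)

    def register(self, *pattern):
        # A pattern may name only the first `arity` argument types.
        if len(pattern) != self.arity:
            raise ValueError(f"{self.name} takes {self.arity} pattern type(s)")

        def decorator(impl):
            self.registry[pattern] = impl
            return impl

        return decorator

    def __call__(self, *args, **kwargs):
        # Dispatch looks only at the leading `arity` args; the rest are ground data.
        heads = args[: self.arity]
        for types in itertools.product(*(type(a).__mro__ for a in heads)):
            if types in self.registry:
                return self.registry[types](*args, **kwargs)
        return self.default(*args, **kwargs)


class MiniUnaryOp(MiniOp):
    arity = 1


class MiniBinaryOp(MiniOp):
    arity = 2


@MiniBinaryOp.make
def scale_add(x, y, scale=1.0):
    # Default implementation, used for plain Python numbers.
    return scale * (x + y)


@scale_add.register(list, list)
def _scale_add_list(x, y, scale=1.0):
    # "Backend-specific" implementation for Python lists.
    return [scale * (a + b) for a, b in zip(x, y)]


print(scale_add(1, 2, scale=2.0))            # 6.0
print(scale_add([1, 2], [3, 4], scale=0.5))  # [2.0, 3.0]
```

Because dispatch inspects only the leading arity arguments, the non-funsor scale keyword never influences which implementation runs, mirroring the rule that trailing args must be ground Python data with default values.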
- class TransformOp(*args, **kwargs)
Bases: UnaryOp
- set_inv(fn)
- Parameters
fn (callable) – A function that takes an argument y and returns a value x such that y = self(x).
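The contract of set_inv, registering an fn such that fn(self(x)) recovers x, can be sketched in plain Python. MiniTransformOp and tanh_inv are hypothetical names; the tanh/atanh pair is chosen only because both appear among the builtin operations, and returning fn from set_inv so that it also works as a decorator is a design choice of this sketch rather than documented behavior.

```python
import math


class MiniTransformOp:
    """Hypothetical unary transform op with a settable inverse; not funsor's class."""

    def __init__(self, fn):
        self.fn = fn
        self._inv = None

    def __call__(self, x):
        return self.fn(x)

    def set_inv(self, fn):
        # fn takes y and returns x such that y == self(x).
        if not callable(fn):
            raise TypeError("inverse must be callable")
        self._inv = fn
        return fn  # returning fn lets set_inv be used as a decorator

    @property
    def inv(self):
        if self._inv is None:
            raise NotImplementedError("no inverse has been set")
        return self._inv


tanh_op = MiniTransformOp(math.tanh)


@tanh_op.set_inv
def tanh_inv(y):
    return math.atanh(y)


y = tanh_op(0.3)
print(tanh_op.inv(y))  # recovers 0.3 up to rounding
```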
- class WrappedTransformOp(*args, **kwargs)
Bases: TransformOp
Wrapper for a backend Transform object that provides .inv and .log_abs_det_jacobian. This additionally validates shapes on the first __call__().
- static default(x, fn, *, validate_args=True)
- dispatcher = <dispatched wrapped_transform>
- property inv
- property log_abs_det_jacobian
- name = 'wrapped_transform'
- signature = <Signature (x, fn, *, validate_args=True)>
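The wrapper idea behind WrappedTransformOp (expose a backend transform's .inv and .log_abs_det_jacobian, and validate shapes once, on the first call) can be sketched without any backend. AffineTransform and MiniWrappedTransform are hypothetical stand-ins that operate on flat Python lists, and "shape" is reduced to list length; a real backend Transform would act on tensors with batch and event shapes.

```python
import math


class AffineTransform:
    """Hypothetical backend transform y = loc + scale * x on flat lists."""

    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

    def __call__(self, x):
        return [self.loc + self.scale * xi for xi in x]

    @property
    def inv(self):
        # x = (y - loc) / scale, written as another affine map.
        return AffineTransform(-self.loc / self.scale, 1.0 / self.scale)

    def log_abs_det_jacobian(self, x, y):
        # Diagonal Jacobian with every entry equal to `scale`.
        return len(x) * math.log(abs(self.scale))


class MiniWrappedTransform:
    """Hypothetical wrapper exposing .inv and .log_abs_det_jacobian of a backend transform."""

    def __init__(self, transform, validate_args=True):
        self.transform = transform
        self.validate_args = validate_args
        self._validated = False

    def __call__(self, x):
        y = self.transform(x)
        if self.validate_args and not self._validated:
            # Validate only on the first call; later calls skip the check.
            if len(y) != len(x):
                raise ValueError(f"shape mismatch: input {len(x)}, output {len(y)}")
            self._validated = True
        return y

    @property
    def inv(self):
        return MiniWrappedTransform(self.transform.inv, self.validate_args)

    def log_abs_det_jacobian(self, x, y):
        return self.transform.log_abs_det_jacobian(x, y)


t = MiniWrappedTransform(AffineTransform(loc=1.0, scale=2.0))
x = [0.0, 1.5]
y = t(x)                             # [1.0, 4.0]
print(t.inv(y))                      # [0.0, 1.5]
print(t.log_abs_det_jacobian(x, y))  # 2 * log(2)
```

Checking shapes only on the first call keeps the per-call overhead low once a transform has been shown to preserve shape.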
Builtin operations
- abs = ops.abs
Return the absolute value of the argument.
- add = ops.add
Same as a + b.
- and_ = ops.and_
Same as a & b.
- atanh = ops.atanh
Return the inverse hyperbolic tangent of x.
- eq = ops.eq
Same as a == b.
- exp = ops.exp
Return e raised to the power of x.
- floordiv = ops.floordiv
Same as a // b.
- ge = ops.ge
Same as a >= b.
- getitem = ops.getitem
- getslice = ops.getslice
- gt = ops.gt
Same as a > b.
- invert = ops.invert
Same as ~a.
- le = ops.le
Same as a <= b.
- lgamma = ops.lgamma
Natural logarithm of absolute value of Gamma function at x.
- log = ops.log
- log1p = ops.log1p
Return the natural logarithm of 1+x (base e).
The result is computed in a way which is accurate for x near zero.
- lshift = ops.lshift
Same as a << b.
- lt = ops.lt
Same as a < b.
- matmul = ops.matmul
Same as a @ b.
- max = ops.max
- min = ops.min
- mod = ops.mod
Same as a % b.
- mul = ops.mul
Same as a * b.
- ne = ops.ne
Same as a != b.
- neg = ops.neg
Same as -a.
- null = ops.null
Placeholder associative op that unifies with any other op.
- or_ = ops.or_
Same as a | b.
- pos = ops.pos
Same as +a.
- pow = ops.pow
Same as a ** b.
- reciprocal = ops.reciprocal
- rshift = ops.rshift
Same as a >> b.
- safediv = ops.safediv
- safesub = ops.safesub
- sigmoid = ops.sigmoid
- sqrt = ops.sqrt
Return the square root of x.
- sub = ops.sub
Same as a - b.
- tanh = ops.tanh
Return the hyperbolic tangent of x.
- truediv = ops.truediv
Same as a / b.
- xor = ops.xor
Same as a ^ b.
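The log1p entry is the one builtin description that makes a numerical claim. A quick check with Python's standard math module (not a funsor backend) shows why a dedicated log1p matters: for tiny x, forming 1 + x first rounds away x entirely, while log1p keeps it.

```python
import math

x = 1e-17
naive = math.log(1.0 + x)  # 1.0 + 1e-17 rounds to exactly 1.0 in double precision
accurate = math.log1p(x)   # evaluates log(1 + x) without forming 1 + x first

print(naive)     # 0.0, so all information about x is lost
print(accurate)  # approximately 1e-17
```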
Array operations
- all = ops.all
- amax = ops.amax
- amin = ops.amin
- any = ops.any
- argmax = ops.argmax
- argmin = ops.argmin
- astype = ops.astype
- cat = ops.cat
- cholesky = ops.cholesky
Like numpy.linalg.cholesky() but uses sqrt for scalar (1x1) matrices.
- cholesky_inverse = ops.cholesky_inverse
Like torch.cholesky_inverse() but supports batching and gradients.
- cholesky_solve = ops.cholesky_solve
- clamp = ops.clamp
- detach = ops.detach
- diagonal = ops.diagonal
- einsum = ops.einsum
- expand = ops.expand
- finfo = ops.finfo
- flip = ops.flip
- full_like = ops.full_like
- isnan = ops.isnan
- logaddexp = ops.logaddexp
- logsumexp = ops.logsumexp
- mean = ops.mean
- new_arange = ops.new_arange
- new_eye = ops.new_eye
- new_full = ops.new_full
- new_zeros = ops.new_zeros
- permute = ops.permute
- prod = ops.prod
- qr = ops.qr
- randn = ops.randn
- sample = ops.sample
- scatter = ops.scatter
- scatter_add = ops.scatter_add
- stack = ops.stack
- std = ops.std
- sum = ops.sum
- transpose = ops.transpose
- triangular_inv = ops.triangular_inv
- triangular_solve = ops.triangular_solve
- unsqueeze = ops.unsqueeze
- var = ops.var
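Only cholesky and cholesky_inverse carry descriptions in this list. The pure-Python sketch below illustrates the math behind them under simplifying assumptions: matrices are lists of rows, there is no batching and no gradient support, and cholesky_sketch, lower_triangular_inv and cholesky_inverse_sketch are hypothetical local helpers, not the funsor ops. It shows the 1x1 shortcut, where the Cholesky factor of [[a]] is [[sqrt(a)]], and how the inverse of A = L L^T is recovered from its lower-triangular factor as A^-1 = L^-T L^-1, which is what torch.cholesky_inverse() computes for a lower-triangular input.

```python
import math


def cholesky_sketch(a):
    """Lower-triangular L with a == L @ L^T, for symmetric positive-definite a."""
    n = len(a)
    if n == 1:
        # Scalar (1x1) matrix: the Cholesky factor is just the square root.
        return [[math.sqrt(a[0][0])]]
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                L[i][j] = math.sqrt(a[i][i] - s)
            else:
                L[i][j] = (a[i][j] - s) / L[j][j]
    return L


def lower_triangular_inv(L):
    """Inverse of a lower-triangular matrix, one column at a time by forward substitution."""
    n = len(L)
    inv = [[0.0] * n for _ in range(n)]
    for col in range(n):
        for i in range(n):
            rhs = 1.0 if i == col else 0.0
            s = sum(L[i][k] * inv[k][col] for k in range(i))
            inv[i][col] = (rhs - s) / L[i][i]
    return inv


def cholesky_inverse_sketch(L):
    """Given lower-triangular L with A == L @ L^T, return A^-1 = L^-T @ L^-1."""
    Linv = lower_triangular_inv(L)
    n = len(L)
    # (L^-T L^-1)[i][j] = sum_k Linv[k][i] * Linv[k][j]
    return [[sum(Linv[k][i] * Linv[k][j] for k in range(n)) for j in range(n)]
            for i in range(n)]


A = [[4.0, 2.0], [2.0, 3.0]]
L = cholesky_sketch(A)              # [[2.0, 0.0], [1.0, sqrt(2)]]
A_inv = cholesky_inverse_sketch(L)  # close to [[0.375, -0.25], [-0.25, 0.5]]
print(cholesky_sketch([[9.0]]))     # [[3.0]]
```

Going through the triangular factor avoids inverting A directly; the same two steps (triangular inverse, then a product with its transpose) are what a batched, differentiable backend version would vectorize.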