Note
Go to the end to download the full example code
Example: Discrete HMM
import argparse
from collections import OrderedDict
import torch
import funsor
import funsor.ops as ops
import funsor.torch.distributions as dist
from funsor.interpreter import reinterpret
from funsor.optimizer import apply_optimizer
def main(args):
    """Train the probability tables of a two-state discrete HMM with Adam.

    The model's log-probability is accumulated as a funsor expression, every
    hidden-state variable is reduced with ``ops.logaddexp``, and the negative
    of the resulting scalar is minimized.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``hidden_dim``, ``lazy``, ``time_steps``, ``train_steps`` and
            ``learning_rate``.

    Raises:
        ValueError: If ``args.hidden_dim`` differs from the number of rows in
            the hard-coded probability tables.
    """
    # Initial probability tables; one row per hidden state.
    trans_init = [[0.2, 0.8], [0.7, 0.3]]
    emit_init = [[0.4, 0.6], [0.1, 0.9]]
    num_states = len(trans_init)
    num_obs = len(emit_init[0])

    # The tables have a fixed size, so a different --hidden-dim cannot work.
    # Fail here with a readable message instead of deep inside funsor.
    if args.hidden_dim != num_states:
        raise ValueError(
            "--hidden-dim must be {} to match the built-in transition and "
            "emission tables, got {}".format(num_states, args.hidden_dim)
        )

    funsor.set_backend("torch")

    # Declare parameters.
    trans_probs = torch.tensor(trans_init, requires_grad=True)
    emit_probs = torch.tensor(emit_init, requires_grad=True)
    params = [trans_probs, emit_probs]

    # A discrete HMM model.
    def model(data):
        """Return the log-probability of ``data`` with all variables reduced."""
        log_prob = funsor.to_funsor(0.0)

        # Transition distribution, indexed by the previous hidden state.
        trans = dist.Categorical(
            probs=funsor.Tensor(
                trans_probs,
                inputs=OrderedDict([("prev", funsor.Bint[args.hidden_dim])]),
            )
        )

        # Emission distribution, indexed by the current hidden state.
        emit = dist.Categorical(
            probs=funsor.Tensor(
                emit_probs,
                inputs=OrderedDict([("latent", funsor.Bint[args.hidden_dim])]),
            )
        )

        # The chain starts from the fixed state 0.
        x_curr = funsor.Number(0, args.hidden_dim)
        for t, y in enumerate(data):
            x_prev = x_curr

            # A delayed sample statement.
            x_curr = funsor.Variable("x_{}".format(t), funsor.Bint[args.hidden_dim])
            log_prob += trans(prev=x_prev, value=x_curr)

            # In eager mode, reduce the previous state as soon as its
            # transition term has been added. The isinstance check skips the
            # first step, where x_prev is the constant initial state.
            if not args.lazy and isinstance(x_prev, funsor.Variable):
                log_prob = log_prob.reduce(ops.logaddexp, x_prev.name)

            # Observation cardinality comes from the emission table's width.
            log_prob += emit(latent=x_curr, value=funsor.Tensor(y, dtype=num_obs))

        # Reduce whatever variables are still free.
        log_prob = log_prob.reduce(ops.logaddexp)
        return log_prob

    # Train model parameters.
    data = torch.ones(args.time_steps, dtype=torch.long)
    optim = torch.optim.Adam(params, lr=args.learning_rate)
    for _ in range(args.train_steps):
        optim.zero_grad()
        if args.lazy:
            # Build the expression under the lazy interpretation, pass it
            # through apply_optimizer, then reinterpret it outside the lazy
            # context to obtain a concrete value.
            with funsor.interpretations.lazy:
                log_prob = apply_optimizer(model(data))
            log_prob = reinterpret(log_prob)
        else:
            log_prob = model(data)
        assert not log_prob.inputs, "free variables remain"
        loss = -log_prob.data
        loss.backward()
        optim.step()
if __name__ == "__main__":
    # Command-line entry point: parse the options and run main().
    parser = argparse.ArgumentParser(description="Discrete HMM example")
    parser.add_argument(
        "-t",
        "--time-steps",
        default=10,
        type=int,
        help="length of the observation sequence",
    )
    parser.add_argument(
        "-n",
        "--train-steps",
        default=101,
        type=int,
        help="number of optimizer steps",
    )
    parser.add_argument(
        "-lr",
        "--learning-rate",
        default=0.05,
        type=float,
        help="learning rate passed to the Adam optimizer",
    )
    parser.add_argument(
        "-d",
        "--hidden-dim",
        default=2,
        type=int,
        help="number of hidden states",
    )
    parser.add_argument(
        "--lazy",
        action="store_true",
        help="build the model lazily and optimize it before evaluating",
    )
    # main() does not read this flag; it is kept so existing command lines
    # that pass it continue to parse.
    parser.add_argument(
        "--filter",
        action="store_true",
        help="accepted for compatibility; not used by this example",
    )
    parser.add_argument(
        "--xfail-if-not-implemented",
        action="store_true",
        help="print XFAIL instead of failing on NotImplementedError",
    )
    args = parser.parse_args()

    if args.xfail_if_not_implemented:
        # Report an unimplemented code path as an expected failure.
        try:
            main(args)
        except NotImplementedError:
            print("XFAIL")
    else:
        main(args)