Interfacing with PyTorch
Adding a GTN function or layer to PyTorch is just like adding any custom extension to PyTorch. For the details, take a look at an example which constructs a custom loss function in PyTorch with GTN. We'll go over the example here at a high level, with attention to the bits specific to GTN.
First, declare the class, which should inherit from torch.autograd.Function.
class GTNLossFunction(torch.autograd.Function):
    """
    A minimal example of adding a custom loss function built with GTN graphs to
    PyTorch.

    The example is a sequence criterion which computes a loss between a
    frame-level input and a token-level target. The tokens in the target can
    align to one or more frames in the input.
    """
The GTNLossFunction requires static forward and backward methods. The forward method, with some additional comments, is copied below:
@staticmethod
def forward(ctx, inputs, targets):
    B, T, C = inputs.shape
    losses = [None] * B
    emissions_graphs = [None] * B

    # Move data to the host as GTN operations run on the CPU:
    device = inputs.device
    inputs = inputs.cpu()
    targets = targets.cpu()

    # Compute the loss for the b-th example:
    def forward_single(b):
        emissions = gtn.linear_graph(T, C, inputs.requires_grad)

        # *NB* A reference to the `data` should be held explicitly when
        # using `data_ptr()`, otherwise the memory may be claimed before the
        # weights are set. For example, the following is undefined and will
        # likely cause serious issues:
        # `emissions.set_weights(inputs[b].contiguous().data_ptr())`
        data = inputs[b].contiguous()
        emissions.set_weights(data.data_ptr())

        target = GTNLossFunction.make_target_graph(targets[b])

        # Score the target:
        target_score = gtn.forward_score(gtn.intersect(target, emissions))

        # Normalization term:
        norm = gtn.forward_score(emissions)

        # Compute the loss:
        loss = gtn.subtract(norm, target_score)

        # We need to save the `loss` graph to call `gtn.backward` and we
        # need the `emissions` graph to access the gradients:
        losses[b] = loss
        emissions_graphs[b] = emissions

    # Compute the loss in parallel over the batch:
    gtn.parallel_for(forward_single, range(B))

    # Save some graphs and other data for backward:
    ctx.auxiliary_data = (losses, emissions_graphs, inputs.shape)

    # Put losses back in a torch tensor and move them back to the device:
    return torch.tensor([l.item() for l in losses]).to(device)
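The forward method calls a make_target_graph helper which is not shown here. Below is a minimal sketch of one way to implement it, assuming targets[b] is a 1-D tensor of integer token indices: a linear-chain acceptor with a self-loop on each token, matching the docstring's requirement that each token align to one or more frames.

@staticmethod
def make_target_graph(target):
    # Sketch: a linear-chain acceptor over the target tokens. The
    # self-loop on each token lets it align to one or more frames.
    g = gtn.Graph(False)  # no gradient needed with respect to the target
    L = len(target)
    g.add_node(True)  # start node
    for l in range(L):
        g.add_node(False, l == L - 1)   # accept at the final token
        label = int(target[l])
        g.add_arc(l, l + 1, label)      # consume the l-th token
        g.add_arc(l + 1, l + 1, label)  # self-loop: stay for extra frames
    return g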
To perform the backward computation, we save losses, a list which holds the individual graphs storing the loss for each example. We also save the emissions_graphs so that we can access the gradients in order to construct the full gradient with respect to the input tensor.
GTN provides parallel_for() to run computations in parallel. We use it here to parallelize the forward and backward computations over the examples in the batch. Using parallel_for() requires some care. For example, notice the lines:
losses = [None] * B
emissions_graphs = [None] * B
These lists must be preconstructed so that each thread can insert into its own slot rather than concurrently appending to the list, which could cause a race condition during the execution of forward_single, as the sketch below illustrates.
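To make the hazard concrete, here is a sketch of the unsafe pattern next to the safe one, with a hypothetical compute_loss standing in for the per-example work:

# Unsafe: concurrent `append` calls mutate shared list state under
# `parallel_for`, so this races (and the ordering of results would be
# nondeterministic anyway):
#     losses = []
#     def forward_single(b):
#         losses.append(compute_loss(b))  # compute_loss is hypothetical

# Safe: each thread writes only to its own, preallocated slot:
losses = [None] * B
def forward_single(b):
    losses[b] = compute_loss(b)
gtn.parallel_for(forward_single, range(B))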
The backward method is very simple. It just calls backward() on each loss graph and accumulates the gradients into a torch.Tensor.
@staticmethod
def backward(ctx, grad_output):
    losses, emissions_graphs, in_shape = ctx.auxiliary_data
    B, T, C = in_shape
    input_grad = torch.empty((B, T, C))

    # Compute the gradients for each example:
    def backward_single(b):
        gtn.backward(losses[b])
        emissions = emissions_graphs[b]
        grad = emissions.grad().weights_to_numpy()
        input_grad[b] = torch.from_numpy(grad).view(1, T, C)

    # Compute gradients in parallel over the batch:
    gtn.parallel_for(backward_single, range(B))

    # Scale each example's gradient by the incoming gradient; `targets`
    # does not require a gradient, so return `None` in its place:
    return grad_output.unsqueeze(1).unsqueeze(1) * input_grad.to(grad_output.device), None
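As with any torch.autograd.Function, the criterion is invoked through apply. A minimal usage sketch, where the shapes and the fixed target length are made up for illustration:

import torch

# Assumes GTNLossFunction as defined above, with a make_target_graph
# along the lines of the earlier sketch.
B, T, C = 4, 10, 5  # batch size, frames, classes (arbitrary here)
inputs = torch.randn(B, T, C, requires_grad=True)
targets = torch.randint(0, C, (B, 3))  # assume 3 tokens per example

loss = GTNLossFunction.apply(inputs, targets).mean()
loss.backward()
print(inputs.grad.shape)  # torch.Size([4, 10, 5])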