Autograd Basics
Whenever we apply a function to gtn::Graph(s), if at least one of its inputs requires a gradient, the output gtn::Graph records its inputs and a method by which to compute its gradient. Gradients are computed with a GradFunc, a function that takes the gradient with respect to the output gtn::Graph and computes the gradient with respect to each of its inputs.
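As a rough sketch of this mechanism, a GradFunc can be thought of as a callable of the following shape. This is a conceptual illustration only; the exact type in gtn may differ.

// Conceptual sketch, not necessarily gtn's exact definition:
// given the gradient with respect to the output graph (deltas),
// a GradFunc computes and accumulates the gradients of the inputs.
using GradFunc = std::function<void(std::vector<Graph>& inputs, Graph deltas)>;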
To recursively compute the gradient of a gtn::Graph with respect to all of its inputs in the computation graph, call gtn::backward() on the gtn::Graph. This computes the gradients iteratively, applying the chain rule in a reverse topological ordering of the computation graph.
Here is a simple example:
// g1 requires a gradient (calcGrad = true)
auto g1 = Graph(true);
g1.addNode(true);        // start node
g1.addNode(false, true); // accept node
g1.addArc(0, 1, 0, 0, 1.0);

// g2 does not require a gradient (calcGrad = false)
auto g2 = Graph(false);
g2.addNode(true);
g2.addNode(false, true);
g2.addArc(0, 1, 0, 0, 2.0);

// g3 requires a gradient
auto g3 = Graph(true);
g3.addNode(true);
g3.addNode(false, true);
g3.addArc(0, 1, 0, 0, 3.0);

// g4 = -((g1 + g2) - g3)
auto g4 = negate(subtract(add(g1, g2), g3));
g4.backward();
auto g1Grad = g1.grad(); // g1Grad.item() = -1
auto g3Grad = g3.grad(); // g3Grad.item() = 1
Warning
Calling g2.grad() will throw an exception since calcGrad is set to false for that graph.
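If it is not known up front whether a graph was created with gradient computation enabled, one way to guard against this exception is to check first, assuming the calcGrad() accessor shown here is available on gtn::Graph:

// Only query the gradient when the graph was built with calcGrad = true.
if (g2.calcGrad()) {
  auto g2Grad = g2.grad();
}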
Retaining the Computation Graph
The gtn::backward() function takes an optional boolean parameter, retainGraph, which is false by default. When the argument is false, each gtn::Graph is cleared from the computation graph during the backward pass as soon as it is no longer needed. This reduces peak memory usage while computing gradients. Setting retainGraph to true is not recommended unless there is a need to call backward multiple times without redoing the forward computation.
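As a minimal sketch of the retained case, assuming gradients accumulate across backward calls and that a zeroGrad() method is available to reset them:

// Keep the computation graph alive so backward can run a second time
// without redoing the forward computation.
g4.backward(true); // retainGraph = true

// Reset previously accumulated gradients before backpropagating again
// (assumes a zeroGrad() method; gradients would otherwise accumulate).
g1.zeroGrad();
g3.zeroGrad();

g4.backward(true); // valid only because the graph was retained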
Let’s consider the computation graph built by the example above.
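In code, this graph corresponds to naming the intermediate results of the example (gSum and gSub match the labels used in the discussion below):

auto gSum = add(g1, g2);        // gSum = g1 + g2
auto gSub = subtract(gSum, g3); // gSub = gSum - g3
auto g4 = negate(gSub);         // g4 = -((g1 + g2) - g3)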
Assuming we need to call gtn::backward() only once on this graph, the intermediate Graph gSub can be deleted as soon as the gradients of g3 and gSum are computed. This happens when retainGraph is set to false.