In these notes, we assume a general latent structure model involving input variables $x \in \mathcal{X}$, output variables $y \in \mathcal{Y}$, and latent discrete variables $z \in \mathcal{Z}$. We assume that $\mathcal{Z} \subseteq \{0, 1\}^K$, where $K \le |\mathcal{Z}|$ (typically, $K \ll |\mathcal{Z}|$): i.e., the latent discrete variable $z$ can be represented as a $K$-dimensional binary vector. This often results from a decomposition of a structure into parts: for example, $z$ could be a dependency tree for a sentence of $n$ words, represented as a vector of size $K = O(n^2)$, indexed by pairs of word indices $(i, j)$, with $z_{ij} = 1$ if arc $(i, j)$ belongs to the tree, and $z_{ij} = 0$ otherwise.
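To make the part decomposition concrete, here is a minimal Python sketch that encodes a dependency tree as the binary vector $z$. The head-list input format, the virtual root at index 0, and the helper name `tree_to_vector` are illustrative assumptions, not part of the notes.

```python
from itertools import product


def tree_to_vector(heads):
    """Encode a dependency tree as a binary arc-indicator vector z.

    heads[j - 1] is the head of word j (words are numbered 1..n and
    head 0 is a hypothetical virtual root). Parts are all (head, modifier)
    pairs with head in 0..n, modifier in 1..n and head != modifier,
    which gives K = n * n parts.
    """
    n = len(heads)
    parts = [(i, j) for i, j in product(range(n + 1), range(1, n + 1)) if i != j]
    arcs = {(h, j) for j, h in enumerate(heads, start=1)}
    z = [1 if part in arcs else 0 for part in parts]
    return parts, z


# A 3-word sentence where word 2 attaches to the root and
# words 1 and 3 both depend on word 2.
parts, z = tree_to_vector([2, 0, 2])
print(len(parts), sum(z))  # 9 3
```

Every tree activates exactly $n$ of the $K$ parts (one head per word), which is why $K$ grows only quadratically while the number of trees grows exponentially.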
In the literature on structured prediction, the convex hull $\operatorname{conv}(\mathcal{Z})$ is sometimes called the marginal polytope, since any point inside it can be interpreted as some marginal distribution over parts of the structure (arcs) under some distribution over structures. There are three relevant problems that may be formulated in a structured setting, given a vector of scores $s \in \mathbb{R}^K$:
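Finding the highest-scoring structure, a.k.a. maximum a posteriori (MAP) inference: identify
$$\operatorname{MAP}(s) := \arg\max_{z \in \mathcal{Z}} s^\top z; \tag{1}$$
Marginal inference: finding the (unique) marginal distribution induced by the scores $s$, given by an entropy projection onto the marginal polytope:
$$\operatorname{Marg}(s) := \arg\max_{\mu \in \operatorname{conv}(\mathcal{Z})} s^\top \mu + \tilde H(\mu), \tag{2}$$
where $\tilde H(\mu)$ is the maximum entropy among all distributions over structures whose marginals are $\mu$;
SparseMAP: finding the (unique) marginal distribution induced by the scores $s$, given by a Euclidean projection onto the marginal polytope:
$$\operatorname{SparseMAP}(s) := \arg\max_{\mu \in \operatorname{conv}(\mathcal{Z})} s^\top \mu - \tfrac{1}{2}\|\mu\|^2 = \arg\min_{\mu \in \operatorname{conv}(\mathcal{Z})} \|\mu - s\|^2. \tag{3}$$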
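We may encode the simple case of an unstructured categorical variable by setting $\mathcal{Z} = \{e_1, \ldots, e_K\}$ (the standard basis vectors of $\mathbb{R}^K$), which leads to $\operatorname{conv}(\mathcal{Z}) = \Delta^K$, the probability simplex. The optimization problems above then recover some well-known transformations: MAP becomes argmax, marginal inference becomes softmax, and SparseMAP becomes sparsemax (Martins and Astudillo, 2016).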
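In the unstructured case the three problems have simple closed forms. The Python sketch below implements argmax (returned as a one-hot vertex), softmax, and sparsemax; sparsemax uses the standard sort-based Euclidean projection onto the simplex. The function names are our own.

```python
import math


def argmax_onehot(s):
    """MAP in the unstructured case: the vertex e_k with the largest score."""
    k = max(range(len(s)), key=lambda i: s[i])
    return [1.0 if i == k else 0.0 for i in range(len(s))]


def softmax(s):
    """Marginal inference in the unstructured case."""
    m = max(s)  # subtract the max for numerical stability
    e = [math.exp(x - m) for x in s]
    total = sum(e)
    return [x / total for x in e]


def sparsemax(s):
    """SparseMAP in the unstructured case: Euclidean projection onto the simplex."""
    u = sorted(s, reverse=True)
    cumsum, tau = 0.0, 0.0
    for k, u_k in enumerate(u, start=1):
        cumsum += u_k
        t = (cumsum - 1.0) / k
        if u_k > t:  # u_k stays in the support; keep the latest threshold
            tau = t
    return [max(x - tau, 0.0) for x in s]


s = [1.0, 0.8, -1.0]
print(argmax_onehot(s))  # [1.0, 0.0, 0.0]
print(softmax(s))        # dense: every entry is positive
print(sparsemax(s))      # sparse: approximately [0.6, 0.4, 0.0]
```

All three outputs lie in $\Delta^K$; they differ in how much of the probability mass they spread, from a single vertex (argmax) to a sparse combination (sparsemax) to full support (softmax).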
2 Latent structure model
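Throughout, we assume a neural network classifier, parametrized by $\theta$ and $\phi$, which consists of three parts:
An encoder function $f_\theta$ which, given an input $x$, outputs a vector of “scores” $s \in \mathbb{R}^K$, as $s = f_\theta(x)$;
An argmax node which, given these scores, outputs the highest-scoring structure:
$$\hat z(s) := \arg\max_{z \in \mathcal{Z}} s^\top z = \operatorname{MAP}(s); \tag{4}$$
A decoder function $g_\phi$ which, given $x$ and $z \in \mathcal{Z}$, makes a prediction $\hat y \in \mathcal{Y}$ as $\hat y = g_\phi(x, z)$. We will sometimes write $\hat y(z)$ to emphasize the dependency on $z$. For reasons that will become clear in the sequel, we assume that the decoder also accepts convex combinations of latent variables as input, i.e., it may also output predictions $\hat y(\mu) = g_\phi(x, \mu)$ where $\mu \in \operatorname{conv}(\mathcal{Z})$.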
Thus, given input $x$, this network predicts:
$$\hat y = g_\phi\Big(x, \arg\max_{z \in \mathcal{Z}} f_\theta(x)^\top z\Big). \tag{5}$$
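The three-part architecture can be sketched with toy components. Below, a hypothetical linear encoder $s = Wx$ and a hypothetical linear decoder $g_\phi(x, \mu) = v^\top \mu + c^\top x$ stand in for real networks; the linear decoder is an assumption chosen because it is trivially defined on convex combinations $\mu \in \operatorname{conv}(\mathcal{Z})$ as well as on vertices.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def encoder(W, x):
    """Hypothetical linear encoder f_theta(x) = W x, producing K scores."""
    return [dot(row, x) for row in W]


def argmax_onehot(s):
    """The argmax node (unstructured case): a one-hot vertex of the simplex."""
    k = max(range(len(s)), key=lambda i: s[i])
    return [1.0 if i == k else 0.0 for i in range(len(s))]


def decoder(v, c, x, mu):
    """Hypothetical decoder g_phi(x, mu) = v^T mu + c^T x.

    mu may be a vertex z or any convex combination of vertices.
    """
    return dot(v, mu) + dot(c, x)


def predict(W, v, c, x):
    """Forward pass: encoder, then argmax node, then decoder."""
    s = encoder(W, x)
    z_hat = argmax_onehot(s)
    return decoder(v, c, x, z_hat), z_hat


W = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v, c, x = [1.0, 2.0, 4.0], [0.1, 0.1], [0.5, 0.2]
y_hat, z_hat = predict(W, v, c, x)
print(z_hat, y_hat)  # [0.0, 0.0, 1.0] and approximately 4.07
```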
To train this network, we assume a loss function $L(\hat y, y)$, where $y$ denotes the true output. We want to minimize this loss over the training data using the gradient backpropagation algorithm.
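We assume $\nabla_\phi L(\hat y, y)$ is easy to compute: it can be done by performing standard gradient backpropagation from the output layer until the output of the argmax node. The main challenge of this model is to sidestep the argmax node to propagate gradient information to the encoder parameters $\theta$. Indeed, since $\hat z(s)$ is piecewise constant in $s$, its Jacobian $\partial \hat z / \partial s$ is zero almost everywhere, and by the chain rule we have:
$$\nabla_\theta L(\hat y(\hat z), y) = \left(\frac{\partial s}{\partial \theta}\right)^{\!\top} \underbrace{\left(\frac{\partial \hat z}{\partial s}\right)^{\!\top}}_{=\,0} \nabla_z L(\hat y(\hat z), y) = 0, \tag{6}$$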
so no gradient will flow to the encoder. Common approaches to circumvent this problem include:
Replace the argmax node by a stochastic node where $z \sim p_\theta(z \mid x)$ is a random variable parametrized by $\theta$ (e.g., using a Gibbs distribution, $p_\theta(z \mid x) \propto \exp(f_\theta(x)^\top z)$). Then, compute the gradient of the expected loss $\mathbb{E}_{z \sim p_\theta(z \mid x)}[L(\hat y(z), y)]$. This is the approach underlying REINFORCE, score function estimators, and minimum risk training (Williams, 1992; Smith and Eisner, 2006; Stoyanov et al., 2011). Niculae et al. (2018b) explore a sparse alternative to the Gibbs distribution.
Keep the network deterministic, but do a continuous relaxation of the argmax node, for example replacing it with softmax or sparsemax (Martins and Astudillo, 2016). In the structured case, this gives rise to structured attention networks (Kim et al., 2017) and their sparse variant, SparseMAP (Niculae et al., 2018a). Mathematically, this corresponds to moving the expectation inside the loss, optimizing $L\big(\hat y(\mathbb{E}_{z \sim p_\theta(z \mid x)}[z]), y\big)$.
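The two strategies optimize different objectives. The sketch below compares them in the unstructured case with a Gibbs (softmax) distribution, a hypothetical linear decoder $\hat y(\mu) = v^\top \mu$, and a squared loss; all of these choices are illustrative assumptions. Because this particular loss is convex in $\mu$, Jensen's inequality guarantees that the relaxed objective never exceeds the expected loss.

```python
import math


def softmax(s):
    m = max(s)
    e = [math.exp(x - m) for x in s]
    total = sum(e)
    return [x / total for x in e]


def loss(mu, v, y):
    """Squared loss of a hypothetical linear decoder y_hat(mu) = v^T mu."""
    return (sum(a * b for a, b in zip(v, mu)) - y) ** 2


def expected_loss(s, v, y):
    """Stochastic node: E_{z ~ softmax(s)}[L(y_hat(z), y)], summing over one-hot z."""
    p = softmax(s)
    K = len(s)
    vertices = [[1.0 if i == k else 0.0 for i in range(K)] for k in range(K)]
    return sum(p_k * loss(e_k, v, y) for p_k, e_k in zip(p, vertices))


def relaxed_loss(s, v, y):
    """Continuous relaxation: L(y_hat(E[z]), y), where E[z] = softmax(s)."""
    return loss(softmax(s), v, y)


s, v, y = [0.0, 0.0], [0.0, 2.0], 1.0
print(expected_loss(s, v, y))  # 1.0: each vertex misses the target by 1
print(relaxed_loss(s, v, y))   # 0.0: the averaged input hits the target
```

The example shows why the relaxation can be optimistic: the decoder is rewarded for a fractional input that no single discrete $z$ could produce.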
In what follows, we assume that:
We have access to the gradient $\nabla_z L(\hat y(z), y)$ (this gradient would not exist if the decoder were defined only for the vertices $z \in \mathcal{Z}$ and not for convex combinations thereof; this assumption is not needed in the minimum risk training approach discussed toward the end of these notes);
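We want to replace the (zero) gradient $\nabla_s L(\hat y(\hat z), y)$ by a surrogate $\tilde\nabla_s L(\hat y(\hat z), y)$, which is then backpropagated through the encoder to $\theta$.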
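A quick numerical check illustrates both assumptions on a toy model (unstructured argmax, hypothetical linear decoder $\hat y(z) = v^\top z$, squared loss; all assumptions for illustration). Central finite differences of the loss with respect to the scores $s$ are exactly zero, since small perturbations never change $\hat z$, whereas the decoder gradient $\nabla_z L = 2(v^\top z - y)\,v$ is available and nonzero.

```python
def argmax_onehot(s):
    k = max(range(len(s)), key=lambda i: s[i])
    return [1.0 if i == k else 0.0 for i in range(len(s))]


def loss_of_z(z, v, y):
    """Squared loss of the hypothetical linear decoder y_hat(z) = v^T z."""
    return (sum(a * b for a, b in zip(v, z)) - y) ** 2


def grad_z(z, v, y):
    """Closed-form gradient of loss_of_z with respect to z: 2 (v^T z - y) v."""
    r = sum(a * b for a, b in zip(v, z)) - y
    return [2.0 * r * vk for vk in v]


def finite_diff_grad_s(s, v, y, eps=1e-6):
    """Central finite differences of s -> L(y_hat(argmax(s)), y)."""
    g = []
    for k in range(len(s)):
        up, down = list(s), list(s)
        up[k] += eps
        down[k] -= eps
        g.append((loss_of_z(argmax_onehot(up), v, y)
                  - loss_of_z(argmax_onehot(down), v, y)) / (2 * eps))
    return g


s, v, y = [0.5, 0.2, 0.7], [1.0, 2.0, 4.0], 3.0
z_hat = argmax_onehot(s)
print(finite_diff_grad_s(s, v, y))  # [0.0, 0.0, 0.0]: no signal reaches the encoder
print(grad_z(z_hat, v, y))          # [2.0, 4.0, 8.0]: the decoder gradient exists
```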
3 SPIGOT as the approximate computation of a pulled back loss
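We now provide an interpretation of SPIGOT (Peng et al., 2018) as the minimization of a “pulled back” loss with respect to the latent variable $z$. SPIGOT uses the following surrogate gradient:
$$\tilde\nabla_s L(\hat y(\hat z), y) := \hat z - \operatorname{SparseMAP}\big(\hat z - \eta \nabla_z L(\hat y(\hat z), y)\big) = \hat z - \Pi_{\operatorname{conv}(\mathcal{Z})}\big(\hat z - \eta \nabla_z L(\hat y(\hat z), y)\big), \tag{7}$$
where $\eta > 0$ is a step size and we used the fact that the SparseMAP transformation (Niculae et al., 2018a) is equivalent to a Euclidean projection onto $\operatorname{conv}(\mathcal{Z})$, i.e., $\operatorname{SparseMAP}(s) = \Pi_{\operatorname{conv}(\mathcal{Z})}(s)$.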
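In the unstructured case, where $\operatorname{conv}(\mathcal{Z})$ is the simplex and SparseMAP reduces to sparsemax, Eq. 7 takes only a few lines. The sketch below assumes the decoder gradient $\nabla_z L$ is given as a list; the function names and the step size in the example are illustrative.

```python
def argmax_onehot(s):
    k = max(range(len(s)), key=lambda i: s[i])
    return [1.0 if i == k else 0.0 for i in range(len(s))]


def sparsemax(s):
    """Euclidean projection onto the probability simplex (sort-based)."""
    u = sorted(s, reverse=True)
    cumsum, tau = 0.0, 0.0
    for k, u_k in enumerate(u, start=1):
        cumsum += u_k
        t = (cumsum - 1.0) / k
        if u_k > t:
            tau = t
    return [max(x - tau, 0.0) for x in s]


def spigot_surrogate(s, grad_z, eta=1.0):
    """Eq. 7 (unstructured case): z_hat - sparsemax(z_hat - eta * grad_z)."""
    z_hat = argmax_onehot(s)
    target = sparsemax([zk - eta * gk for zk, gk in zip(z_hat, grad_z)])
    return [zk - tk for zk, tk in zip(z_hat, target)]


# The decoder "wants" less of the predicted label 3 and more of label 2.
g = spigot_surrogate([0.5, 0.2, 0.7], grad_z=[0.0, -1.0, 1.0], eta=0.5)
print(g)  # [0.0, -0.5, 0.5]
```

A gradient step $s \leftarrow s - \alpha\, \tilde\nabla_s L$ then raises the score of label 2 and lowers that of label 3, as the decoder gradient suggests.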
3.1 Intermediate loss on the latent variable
Let us start by stating an obvious fact, from which we will draw intuition for the rest: if we had supervision for the latent variable $z$ (e.g., if the true label $z^\star$ were revealed to us), we could simply define an intermediate loss $\ell(s, z^\star)$ which can induce nonzero updates to the encoder parameters. In fact, if $\mathcal{Z}$ is small, we can enumerate all possible values of $z$ and define $z^\star$ as the one that minimizes the downstream loss, $z^\star := \arg\min_{z \in \mathcal{Z}} L(\hat y(z), y)$, using the current network parameters $\phi$; this would become our “groundtruth.”
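When $\mathcal{Z}$ is small, this enumeration is straightforward. The sketch below uses the unstructured case with a hypothetical linear decoder $\hat y(z) = v^\top z$ and squared loss, finds $z^\star$ by brute force, and then, as one possible choice of intermediate loss, uses the perceptron loss $\ell(s, z^\star) = \max_{z} s^\top z - s^\top z^\star$, whose gradient with respect to $s$ is $\hat z - z^\star$.

```python
def one_hot(k, K):
    return [1.0 if i == k else 0.0 for i in range(K)]


def downstream_loss(z, v, y):
    """Squared loss of a hypothetical linear decoder y_hat(z) = v^T z."""
    return (sum(a * b for a, b in zip(v, z)) - y) ** 2


def best_latent(v, y):
    """Brute-force z_star = argmin_z L(y_hat(z), y) over the K one-hot vectors."""
    K = len(v)
    return min((one_hot(k, K) for k in range(K)),
               key=lambda z: downstream_loss(z, v, y))


def perceptron_grad(s, z_star):
    """Gradient of max_z s^T z - s^T z_star with respect to s: z_hat - z_star."""
    z_hat = one_hot(max(range(len(s)), key=lambda i: s[i]), len(s))
    return [a - b for a, b in zip(z_hat, z_star)]


v, y, s = [1.0, 2.0, 4.0], 2.2, [0.5, 0.2, 0.7]
z_star = best_latent(v, y)
print(z_star)                      # [0.0, 1.0, 0.0]: label 2 gives y_hat = 2.0, closest to 2.2
print(perceptron_grad(s, z_star))  # [0.0, -1.0, 1.0]
```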
While this seems somewhat sensible, we may expect some instability at the beginning of the training process, since the decoder parameters are likely to be very suboptimal at this stage. A more robust procedure is to allow for some label uncertainty: instead of picking a single label $z^\star \in \mathcal{Z}$, pick the convex combination $\mu \in \operatorname{conv}(\mathcal{Z})$ that minimizes $L(\hat y(\mu), y)$. In fact, it is likely that the $\mu$ that minimizes the downstream loss will not put all the probability mass on a single label, and we may benefit from that if the downstream loss is what we care about. With this in mind, we define:
$$\mu^\star := \arg\min_{\mu \in \operatorname{conv}(\mathcal{Z})} L(\hat y(\mu), y). \tag{8}$$
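For most interesting predictive models $\hat y(\mu)$, this optimization problem is non-convex and lacks a closed-form solution. One common strategy is the projected gradient algorithm, which iteratively performs the following updates:
$$\mu^{(t+1)} := \Pi_{\operatorname{conv}(\mathcal{Z})}\big(\mu^{(t)} - \eta_t \nabla_\mu L(\hat y(\mu^{(t)}), y)\big), \tag{9}$$
where $\eta_t$ is a step size and $\Pi_{\mathcal{M}}(u) := \arg\min_{\mu \in \mathcal{M}} \|\mu - u\|^2$ denotes the Euclidean projection of point $u$ onto the set $\mathcal{M}$. With a suitable choice of step sizes, the projected gradient algorithm is guaranteed to converge to a stationary point (e.g., a local optimum) of Eq. 8. If we initialize $\mu^{(0)} = \hat z$ and run a single iteration of projected gradient with step size $\eta$, we obtain the following estimate of $\mu^\star$:
$$\tilde\mu := \Pi_{\operatorname{conv}(\mathcal{Z})}\big(\hat z - \eta \nabla_z L(\hat y(\hat z), y)\big). \tag{10}$$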
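Treating $\tilde\mu$ as the “groundtruth” and plugging it into the perceptron loss $\ell(s, \mu) := \max_{z \in \mathcal{Z}} s^\top z - s^\top \mu$, whose gradient with respect to $s$ is $\hat z - \mu$, we get the following gradient:
$$\nabla_s \ell(s, \tilde\mu) = \hat z - \tilde\mu = \hat z - \Pi_{\operatorname{conv}(\mathcal{Z})}\big(\hat z - \eta \nabla_z L(\hat y(\hat z), y)\big),$$
which is precisely the SPIGOT gradient surrogate presented in Eq. 7. This leads to the following insight into how SPIGOT updates the encoder parameters: SPIGOT minimizes the perceptron loss between the predicted structure $\hat z$ and a “pulled back” target $\tilde\mu$, obtained by taking one projected gradient step on the downstream loss of Eq. 8, starting from $\hat z$.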
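The identity between the perceptron gradient and the SPIGOT surrogate can be checked numerically in the unstructured case, where the projection onto $\operatorname{conv}(\mathcal{Z}) = \Delta^K$ is sparsemax. The sketch below uses a hypothetical linear decoder $\hat y(\mu) = v^\top \mu$ with squared loss (so $\nabla_\mu L = 2(v^\top \mu - y)\,v$ in closed form), computes $\tilde\mu$ by one projected step as in Eq. 10, and compares central finite differences of the perceptron loss $\max_k s_k - s^\top \tilde\mu$ with the SPIGOT surrogate. At scores without ties the perceptron loss is locally linear, so the two agree up to rounding.

```python
def sparsemax(s):
    """Euclidean projection onto the probability simplex (sort-based)."""
    u = sorted(s, reverse=True)
    cumsum, tau = 0.0, 0.0
    for k, u_k in enumerate(u, start=1):
        cumsum += u_k
        t = (cumsum - 1.0) / k
        if u_k > t:
            tau = t
    return [max(x - tau, 0.0) for x in s]


def grad_loss(mu, v, y):
    """Gradient of the toy loss (v^T mu - y)^2 with respect to mu."""
    r = sum(a * b for a, b in zip(v, mu)) - y
    return [2.0 * r * vk for vk in v]


def perceptron_loss(s, mu):
    """Unstructured perceptron loss: max_k s_k - s^T mu."""
    return max(s) - sum(a * b for a, b in zip(s, mu))


s, v, y, eta = [0.5, 0.2, 0.7], [1.0, 2.0, 4.0], 3.0, 0.1
k_hat = max(range(len(s)), key=lambda i: s[i])
z_hat = [1.0 if i == k_hat else 0.0 for i in range(len(s))]

# Eq. 10: one projected gradient step from z_hat gives the target mu_tilde.
g = grad_loss(z_hat, v, y)
mu_tilde = sparsemax([zk - eta * gk for zk, gk in zip(z_hat, g)])
spigot = [zk - mk for zk, mk in zip(z_hat, mu_tilde)]  # Eq. 7

# Central finite differences of the perceptron loss, with the target held fixed.
eps = 1e-6
numeric = []
for k in range(len(s)):
    up, down = list(s), list(s)
    up[k] += eps
    down[k] -= eps
    numeric.append((perceptron_loss(up, mu_tilde)
                    - perceptron_loss(down, mu_tilde)) / (2 * eps))

print(mu_tilde)  # approximately [4/15, 1/15, 2/3]
print(spigot)    # approximately [-4/15, -1/15, 1/3]
print(numeric)   # matches spigot up to rounding
```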
This construction suggests some possible alternative strategies. The first results in a well-known algorithm, while the rest result in novel variations.
- Relaxing the constraint.
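The constraints in Eq. 8 make the optimization problem more complicated. We relax them and define $\tilde\mu^\star := \arg\min_{\mu \in \mathbb{R}^K} L(\hat y(\mu), y)$. This problem must still be tackled iteratively, but the projection step can now be avoided. One iteration of gradient descent, initialized at $\hat z$, yields $\tilde\mu = \hat z - \eta \nabla_z L(\hat y(\hat z), y)$. The perceptron update then recovers, via a novel derivation, the straight-through estimator (specifically, its “identity” variant, in which the backward pass acts as if $\partial \hat z / \partial s = I$; Bengio et al., 2013):
$$\nabla_s \ell(s, \tilde\mu) = \hat z - \tilde\mu = \eta \nabla_z L(\hat y(\hat z), y).$$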
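This leads to the following insight into straight-through and its relationship to SPIGOT: straight-through minimizes the same perceptron loss as SPIGOT, but against the unconstrained target $\hat z - \eta \nabla_z L(\hat y(\hat z), y)$; SPIGOT additionally projects this target onto the marginal polytope $\operatorname{conv}(\mathcal{Z})$.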
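The two surrogates can be compared side by side on a hypothetical toy problem (unstructured case, linear decoder $v^\top \mu$, squared loss; all assumptions for illustration). The identity straight-through surrogate is just $\eta \nabla_z L$, while in the unstructured case the SPIGOT surrogate is the difference of two points of the simplex and therefore sums to zero, so it cannot shift all scores in the same direction.

```python
def sparsemax(s):
    """Euclidean projection onto the probability simplex (sort-based)."""
    u = sorted(s, reverse=True)
    cumsum, tau = 0.0, 0.0
    for k, u_k in enumerate(u, start=1):
        cumsum += u_k
        t = (cumsum - 1.0) / k
        if u_k > t:
            tau = t
    return [max(x - tau, 0.0) for x in s]


def grad_loss(mu, v, y):
    """Gradient of the toy loss (v^T mu - y)^2 with respect to mu."""
    r = sum(a * b for a, b in zip(v, mu)) - y
    return [2.0 * r * vk for vk in v]


def surrogates(z_hat, v, y, eta):
    """Return (straight-through, SPIGOT) surrogates for the gradient w.r.t. s."""
    g = grad_loss(z_hat, v, y)
    ste = [eta * gk for gk in g]  # z_hat - (z_hat - eta * g), no projection
    target = sparsemax([zk - eta * gk for zk, gk in zip(z_hat, g)])
    spigot = [zk - tk for zk, tk in zip(z_hat, target)]
    return ste, spigot


ste, spigot = surrogates([0.0, 0.0, 1.0], [1.0, 2.0, 4.0], 3.0, eta=0.1)
print(ste)                    # approximately [0.2, 0.4, 0.8]
print(spigot)                 # approximately [-0.267, -0.067, 0.333]
print(sum(ste), sum(spigot))  # about 1.4 versus about 0.0
```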
- Multiple projected gradient steps.
Instead of a single projected gradient step, we could run multiple steps of the iteration in Eq. 9. We would expect this to yield an estimate closer to $\mu^\star$, at the cost of more computation.
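Running several iterations of Eq. 9 is a small loop. In the sketch below (unstructured case, hypothetical linear decoder $v^\top \mu$ with squared loss), the fixed step size $\eta = 0.02$ is an assumption chosen below $1/(2\|v\|^2)$, the inverse Lipschitz constant of $\nabla_\mu L$ for this toy loss, so that no projected step can increase the loss.

```python
def sparsemax(s):
    """Euclidean projection onto the probability simplex (sort-based)."""
    u = sorted(s, reverse=True)
    cumsum, tau = 0.0, 0.0
    for k, u_k in enumerate(u, start=1):
        cumsum += u_k
        t = (cumsum - 1.0) / k
        if u_k > t:
            tau = t
    return [max(x - tau, 0.0) for x in s]


def loss(mu, v, y):
    """Toy downstream loss (v^T mu - y)^2 of a hypothetical linear decoder."""
    return (sum(a * b for a, b in zip(v, mu)) - y) ** 2


def grad_loss(mu, v, y):
    r = sum(a * b for a, b in zip(v, mu)) - y
    return [2.0 * r * vk for vk in v]


def projected_gradient(mu0, v, y, eta, steps):
    """Run `steps` iterations of Eq. 9 on the simplex, starting from mu0."""
    mu = list(mu0)
    for _ in range(steps):
        g = grad_loss(mu, v, y)
        mu = sparsemax([m - eta * gk for m, gk in zip(mu, g)])
    return mu


v, y, eta = [1.0, 2.0, 4.0], 3.0, 0.02
z_hat = [0.0, 0.0, 1.0]  # argmax of scores such as s = [0.5, 0.2, 0.7]
mu_1 = projected_gradient(z_hat, v, y, eta, steps=1)
mu_50 = projected_gradient(z_hat, v, y, eta, steps=50)
# Multi-step analogue of the SPIGOT surrogate: z_hat - mu^(T).
surrogate_50 = [zk - mk for zk, mk in zip(z_hat, mu_50)]
print(loss(z_hat, v, y), loss(mu_1, v, y), loss(mu_50, v, y))  # non-increasing
print(surrogate_50)
```

Each extra step costs one more decoder forward and backward pass at the current iterate, which is the computational price mentioned above.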
- Different initialization.
The projected gradient update in Eq. 10 uses $\mu^{(0)} = \hat z$ as the initial point. This is a sensible choice if we believe the encoder prediction $\hat z$ is close enough to the “groundtruth” $\mu^\star$, and it is computationally convenient because $\hat z$ and the decoder output $\hat y(\hat z)$ have already been computed in the forward propagation step and can be cached. However, other initializations are possible, for example $\mu^{(0)} = \operatorname{Marg}(s)$ or $\mu^{(0)} = \operatorname{SparseMAP}(s)$.
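Changing the initial point only changes where the projected step starts; the resulting target $\tilde\mu$ still enters the perceptron gradient as $\hat z - \tilde\mu$. The sketch below compares three starting points in the unstructured case, $\hat z$, $\operatorname{softmax}(s)$ and $\operatorname{sparsemax}(s)$ (the unstructured counterparts of MAP, marginal inference and SparseMAP), with a hypothetical linear decoder and squared loss. Note that any $\mu^{(0)} \neq \hat z$ requires an extra decoder evaluation at $\mu^{(0)}$ to obtain $\nabla_\mu L$.

```python
import math


def softmax(s):
    m = max(s)
    e = [math.exp(x - m) for x in s]
    total = sum(e)
    return [x / total for x in e]


def sparsemax(s):
    """Euclidean projection onto the probability simplex (sort-based)."""
    u = sorted(s, reverse=True)
    cumsum, tau = 0.0, 0.0
    for k, u_k in enumerate(u, start=1):
        cumsum += u_k
        t = (cumsum - 1.0) / k
        if u_k > t:
            tau = t
    return [max(x - tau, 0.0) for x in s]


def grad_loss(mu, v, y):
    """Gradient of the toy loss (v^T mu - y)^2 with respect to mu."""
    r = sum(a * b for a, b in zip(v, mu)) - y
    return [2.0 * r * vk for vk in v]


def perceptron_surrogate(s, mu0, v, y, eta):
    """z_hat - Proj(mu0 - eta * grad L(mu0)): one projected step from mu0."""
    k = max(range(len(s)), key=lambda i: s[i])
    z_hat = [1.0 if i == k else 0.0 for i in range(len(s))]
    g = grad_loss(mu0, v, y)  # needs the decoder evaluated at mu0
    target = sparsemax([m - eta * gk for m, gk in zip(mu0, g)])
    return [zk - tk for zk, tk in zip(z_hat, target)]


s, v, y, eta = [0.5, 0.2, 0.7], [1.0, 2.0, 4.0], 3.0, 0.1
inits = {"z_hat": [0.0, 0.0, 1.0], "softmax": softmax(s), "sparsemax": sparsemax(s)}
surrogates = {name: perceptron_surrogate(s, mu0, v, y, eta)
              for name, mu0 in inits.items()}
for name, g in surrogates.items():
    print(name, [round(x, 3) for x in g])
```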
- Different intermediate loss function.
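For simplicity, consider the unstructured case, and let $\tilde\mu$ be given by Eq. 10, i.e., $\tilde\mu = \operatorname{sparsemax}\big(\hat z - \eta \nabla_z L(\hat y(\hat z), y)\big)$. If we use the cross-entropy loss instead of the perceptron loss, we get
$$\nabla_s \ell_{\mathrm{CE}}(s, \tilde\mu) = \operatorname{softmax}(s) - \tilde\mu = \operatorname{softmax}(s) - \operatorname{sparsemax}\big(\hat z - \eta \nabla_z L(\hat y(\hat z), y)\big).$$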
This generalizes easily to the CRF loss in the structured case.
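The cross-entropy gradient with a soft target can be verified numerically in the unstructured case. The sketch below compares central finite differences of $\ell_{\mathrm{CE}}(s, \tilde\mu) = -\sum_k \tilde\mu_k \log \operatorname{softmax}(s)_k$ with the closed form $\operatorname{softmax}(s) - \tilde\mu$; the scores and the target in the simplex are made-up values for illustration, and the function names are our own.

```python
import math


def log_softmax(s):
    m = max(s)
    lse = m + math.log(sum(math.exp(x - m) for x in s))
    return [x - lse for x in s]


def cross_entropy(s, target):
    """l_CE(s, target) = -sum_k target_k * log softmax(s)_k."""
    return -sum(t * ls for t, ls in zip(target, log_softmax(s)))


def ce_grad(s, target):
    """Closed form softmax(s) - target, valid because target sums to 1."""
    return [math.exp(ls) - t for ls, t in zip(log_softmax(s), target)]


s = [0.5, 0.2, 0.7]
mu_tilde = [4 / 15, 1 / 15, 2 / 3]  # illustrative soft target in the simplex
eps = 1e-6
numeric = []
for k in range(len(s)):
    up, down = list(s), list(s)
    up[k] += eps
    down[k] -= eps
    numeric.append((cross_entropy(up, mu_tilde)
                    - cross_entropy(down, mu_tilde)) / (2 * eps))
print(numeric)
print(ce_grad(s, mu_tilde))  # agrees with the finite differences
```

Unlike the perceptron surrogate, this gradient depends on $s$ through $\operatorname{softmax}(s)$, so it vanishes only when the model's softmax matches the target.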
- Exponentiated gradient instead of projected gradient.
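Exponentiated gradient (Kivinen and Warmuth, 1997) replaces the additive projected update of Eq. 9 by a multiplicative one. In the unstructured case, where $\operatorname{conv}(\mathcal{Z}) = \Delta^K$, its iterates are
$$\mu^{(t+1)} \propto \mu^{(t)} \odot \exp\big(-\eta_t \nabla_\mu L(\hat y(\mu^{(t)}), y)\big),$$
where each point $\mu^{(t)}$ is strictly positive. This includes the initializer $\mu^{(0)}$, so we cannot have $\mu^{(0)} = \hat z$; for this reason we assume $\mu^{(0)} = \operatorname{softmax}(s)$. A single iteration of exponentiated gradient with this initialization gives:
$$\tilde\mu = \operatorname{softmax}\big(s - \eta \nabla_\mu L(\hat y(\operatorname{softmax}(s)), y)\big).$$
With the cross-entropy loss, i.e., the Kullback-Leibler divergence, we obtain:
$$\nabla_s \ell_{\mathrm{CE}}(s, \tilde\mu) = \operatorname{softmax}(s) - \operatorname{softmax}\big(s - \eta \nabla_\mu L(\hat y(\operatorname{softmax}(s)), y)\big),$$
i.e., the surrogate gradient is the difference between a softmax and a softmax with “perturbed” scores. In the structured case, this generalizes to an instance of mirror descent with Kullback-Leibler projections.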
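The one-step exponentiated gradient update and the resulting surrogate can be sketched as follows in the unstructured case. The linear decoder with squared loss and the step size are assumptions for illustration; the check confirms that the renormalized multiplicative update from $\operatorname{softmax}(s)$ coincides with $\operatorname{softmax}(s - \eta g)$.

```python
import math


def softmax(s):
    m = max(s)
    e = [math.exp(x - m) for x in s]
    total = sum(e)
    return [x / total for x in e]


def grad_loss(mu, v, y):
    """Gradient of the toy loss (v^T mu - y)^2 with respect to mu."""
    r = sum(a * b for a, b in zip(v, mu)) - y
    return [2.0 * r * vk for vk in v]


def eg_step(mu0, g, eta):
    """One exponentiated gradient step: mu0 * exp(-eta * g), renormalized."""
    w = [m * math.exp(-eta * gk) for m, gk in zip(mu0, g)]
    total = sum(w)
    return [x / total for x in w]


s, v, y, eta = [0.5, 0.2, 0.7], [1.0, 2.0, 4.0], 3.0, 0.1
p = softmax(s)          # mu^(0) = softmax(s), strictly positive
g = grad_loss(p, v, y)  # decoder gradient evaluated at softmax(s)
mu_tilde = eg_step(p, g, eta)
perturbed = softmax([sk - eta * gk for sk, gk in zip(s, g)])
surrogate = [a - b for a, b in zip(p, mu_tilde)]  # softmax(s) - softmax(s - eta * g)
print(mu_tilde)
print(perturbed)  # same vector, up to rounding
print(surrogate)
```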
4 Relation to other methods for latent structure models
4.1 Continuous relaxation of argmax
4.2 Minimum risk training
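In this case, the network has a stochastic node $z \sim p_\theta(z \mid x)$, with the Gibbs distribution $p_\theta(z \mid x) \propto \exp(s^\top z)$ and $s = f_\theta(x)$ as above. Since $\nabla_s \log p_\theta(z \mid x) = z - \operatorname{Marg}(s)$, the gradient of the risk with respect to $s$ is:
$$\nabla_s \mathbb{E}_{z \sim p_\theta(z \mid x)}[L(\hat y(z), y)] = \sum_{z \in \mathcal{Z}} p_\theta(z \mid x)\, L(\hat y(z), y)\, \big(z - \operatorname{Marg}(s)\big) = M\big(\operatorname{diag}(p) - p p^\top\big)\boldsymbol{\ell},$$
where $M \in \{0, 1\}^{K \times |\mathcal{Z}|}$ has the structures $z \in \mathcal{Z}$ as columns, $p \in \Delta^{|\mathcal{Z}|}$ collects the probabilities $p_\theta(z \mid x)$, and $\boldsymbol{\ell} \in \mathbb{R}^{|\mathcal{Z}|}$ is a vector whose $z$th entry contains the loss value $L(\hat y(z), y)$.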
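Another way of writing the gradient above, noting that $\sum_{z \in \mathcal{Z}} p_\theta(z \mid x)\,\big(z - \operatorname{Marg}(s)\big) = 0$, so that any constant can be subtracted from the losses, is:
$$\nabla_s \mathbb{E}_{z \sim p_\theta(z \mid x)}[L(\hat y(z), y)] = \sum_{z \in \mathcal{Z}} p_\theta(z \mid x)\, \big(L(\hat y(z^\star), y) - L(\hat y(z), y)\big)\, \big(\operatorname{Marg}(s) - z\big),$$
where $z^\star := \arg\min_{z \in \mathcal{Z}} L(\hat y(z), y)$ is the downstream loss minimizer and $\operatorname{Marg}(s) - z$ is the gradient of the CRF loss $\ell_{\mathrm{CRF}}(s, z) := \log \sum_{z' \in \mathcal{Z}} \exp(s^\top z') - s^\top z$.
It is interesting to compare this gradient with the SPIGOT surrogate gradient in Eq. 7. Here, too, a “pulled-back loss” (now $\ell_{\mathrm{CRF}}(s, z)$) is used in the gradient computation, this time as part of a weighted sum, where the weights are the reward $L(\hat y(z^\star), y) - L(\hat y(z), y)$ and the probability $p_\theta(z \mid x)$. For example, if the downstream loss minimizer is $z^\star$, with loss $L^\star$, and all $z \neq z^\star$ are equally bad (i.e., they all have the same loss $L_0$), then we obtain
$$\nabla_s \mathbb{E}_{z \sim p_\theta(z \mid x)}[L(\hat y(z), y)] = p_\theta(z^\star \mid x)\,(L_0 - L^\star)\,\big(\operatorname{Marg}(s) - z^\star\big),$$
i.e., the CRF loss gradient toward $z^\star$, scaled by the probability of $z^\star$ and by the loss gap.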
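Both expressions can be checked by enumeration in the unstructured case, where $\mathcal{Z} = \{e_1, \ldots, e_K\}$ and $p_\theta(z \mid x)$ and $\operatorname{Marg}(s)$ both equal $\operatorname{softmax}(s)$. The sketch below compares the score-function form $\sum_z p_\theta(z \mid x)\, L(\hat y(z), y)\,(z - \operatorname{Marg}(s))$ with central finite differences of the risk, and then checks the equally-bad special case $p_\theta(z^\star \mid x)(L_0 - L^\star)(\operatorname{Marg}(s) - z^\star)$. The loss values are made up for illustration.

```python
import math


def softmax(s):
    m = max(s)
    e = [math.exp(x - m) for x in s]
    total = sum(e)
    return [x / total for x in e]


def risk(s, losses):
    """R(s) = sum_k softmax(s)_k * losses[k], the expected loss over one-hot z."""
    return sum(pk * lk for pk, lk in zip(softmax(s), losses))


def risk_grad(s, losses):
    """Score-function form: sum_k p_k * losses[k] * (e_k - p)."""
    p = softmax(s)
    K = len(s)
    grad = [0.0] * K
    for k in range(K):
        for i in range(K):
            e_ki = 1.0 if i == k else 0.0
            grad[i] += p[k] * losses[k] * (e_ki - p[i])
    return grad


s = [0.5, 0.2, 0.7]
losses = [0.5, 2.0, 2.0]  # z_star = e_1 with L_star = 0.5; the others share L_0 = 2.0
eps = 1e-6
numeric = []
for k in range(len(s)):
    up, down = list(s), list(s)
    up[k] += eps
    down[k] -= eps
    numeric.append((risk(up, losses) - risk(down, losses)) / (2 * eps))

p = softmax(s)
special = [p[0] * (2.0 - 0.5) * (p[i] - (1.0 if i == 0 else 0.0)) for i in range(len(s))]
print(numeric)
print(risk_grad(s, losses))
print(special)  # all three agree up to rounding
```

Gradient descent on the risk therefore moves $s$ toward $z^\star$ at a rate proportional to $p_\theta(z^\star \mid x)$, which is small when the model rarely samples the best structure.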
- Bengio et al. [2013] Yoshua Bengio, Nicholas Léonard, and Aaron Courville. Estimating or propagating gradients through stochastic neurons for conditional computation. arXiv preprint arXiv:1308.3432, 2013.
- Kim et al. [2017] Yoon Kim, Carl Denton, Luong Hoang, and Alexander M Rush. Structured attention networks. In Proc. of ICLR, 2017.
- Kivinen and Warmuth [1997] Jyrki Kivinen and Manfred K Warmuth. Exponentiated gradient versus gradient descent for linear predictors. Information and Computation, 132(1):1–63, 1997.
- Martins and Astudillo [2016] André FT Martins and Ramón Fernandez Astudillo. From softmax to sparsemax: A sparse model of attention and multi-label classification. In Proc. of ICML, 2016.
- Niculae et al. [2018a] Vlad Niculae, André FT Martins, Mathieu Blondel, and Claire Cardie. SparseMAP: Differentiable sparse structured inference. In Proc. of ICML, 2018a.
- Niculae et al. [2018b] Vlad Niculae, André FT Martins, and Claire Cardie. Towards dynamic computation graphs via sparse latent structure. In Proc. of EMNLP, 2018b.
- Peng et al. [2018] Hao Peng, Sam Thomson, and Noah A Smith. Backpropagating through structured argmax using a SPIGOT. In Proc. of ACL, 2018.
- Smith and Eisner [2006] David A Smith and Jason Eisner. Minimum risk annealing for training log-linear models. In Proc. of COLING/ACL, 2006.
- Stoyanov et al. [2011] Veselin Stoyanov, Alexander Ropson, and Jason Eisner. Empirical risk minimization of graphical model parameters given approximate inference, decoding, and model structure. In Proc. of AISTATS, 2011.
- Williams [1992] Ronald J Williams. Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine Learning, 8(3-4):229–256, 1992.