 # Notes on Latent Structure Models and SPIGOT

These notes aim to shed light on the recently proposed structured projected intermediate gradient optimization technique (SPIGOT, Peng et al., 2018). SPIGOT is a variant of the straight-through estimator (Bengio et al., 2013) which bypasses gradients of the argmax function by back-propagating a surrogate "gradient." We provide a new interpretation of the proposed gradient and put this technique into perspective, linking it to other methods for training neural networks with discrete latent variables. As a by-product, we suggest alternate variants of SPIGOT which will be further explored in future work.


## 1 Introduction

In these notes, we assume a general latent structure model involving input variables $x \in \mathcal{X}$, output variables $y \in \mathcal{Y}$, and latent discrete variables $z \in \mathcal{Z}$. We assume that $\mathcal{Z} \subseteq \{0,1\}^K$, where $K \le |\mathcal{Z}|$ (typically, $K \ll |\mathcal{Z}|$): i.e., the latent discrete variable $z$ can be represented as a $K$-dimensional binary vector. This often results from a decomposition of a structure into parts: for example, $z$ could be a dependency tree for a sentence of $n$ words, represented as a vector of size $K = O(n^2)$, indexed by pairs of word indices $(i, j)$, with $z_{ij} = 1$ if arc $i \to j$ belongs to the tree, and $z_{ij} = 0$ otherwise.

#### Notation.

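In the following, we denote the $(d-1)$-dimensional probability simplex by

$$\triangle^d := \Big\{ p \in \mathbb{R}^d : p \ge 0,\ \sum_{i=1}^d p_i = 1 \Big\}.$$

Given $p \in \triangle^d$, we denote the expectation of a function $h$ under the probability distribution $p$ by $\mathbb{E}_{i \sim p}[h(i)] := \sum_{i=1}^d p_i h(i)$. We denote the convex hull of the (finite) set $\mathcal{Z}$ by $\operatorname{conv}(\mathcal{Z}) := \{ \mathbb{E}_{z \sim p}[z] : p \in \triangle^{|\mathcal{Z}|} \}$. The set $\operatorname{conv}(\mathcal{Z})$ can be interpreted as the smallest convex set which contains $\mathcal{Z}$.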

#### Background.

In the literature on structured prediction, the set $\operatorname{conv}(\mathcal{Z})$ is sometimes called the marginal polytope, since any point inside it can be interpreted as some marginal distribution over parts of the structure (arcs) under some distribution over structures. There are three relevant problems that may be formulated in a structured setting:

• Finding the highest scoring structure, a.k.a. maximum a-posteriori (MAP): identify

$$\arg\max_{\mu \in \operatorname{conv}(\mathcal{Z})} s^\top \mu \tag{1}$$

• Marginal inference: finding the (unique) marginal distribution induced by the scores $s$, given by an entropy projection onto the marginal polytope:

 (2)
• SparseMAP: finding the (unique) marginal distribution induced by the scores $s$, given by a Euclidean projection onto the marginal polytope:

$$\arg\max_{\mu \in \operatorname{conv}(\mathcal{Z})} s^\top \mu - \tfrac{1}{2}\|\mu\|^2 \tag{3}$$

#### Unstructured setting.

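We may encode the simple unstructured categorical case by setting $\mathcal{Z} = \{e_1, \dots, e_K\}$, the set of one-hot vectors, which leads to $\operatorname{conv}(\mathcal{Z}) = \triangle^K$. The optimization problems above then recover some well-known transformations, as described in the table below.

| Optimization problem | Unstructured transformation |
|---|---|
| MAP (Eq. 1) | argmax |
| Marginal inference (Eq. 2) | softmax |
| SparseMAP (Eq. 3) | sparsemax |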
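As a concrete illustration of the unstructured case, the sketch below implements the three transformations for a plain list of scores, using only the Python standard library. The function names and the example scores are illustrative assumptions, not part of the original notes; the sparsemax routine follows the sort-and-threshold procedure of Martins and Astudillo (2016).

```python
import math

def argmax_onehot(s):
    """MAP in the unstructured case: one-hot vector of the highest score."""
    best = max(range(len(s)), key=lambda i: s[i])
    return [1.0 if i == best else 0.0 for i in range(len(s))]

def softmax(s):
    """Marginal inference in the unstructured case."""
    m = max(s)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in s]
    total = sum(e)
    return [v / total for v in e]

def sparsemax(s):
    """SparseMAP in the unstructured case: Euclidean projection onto the simplex."""
    u = sorted(s, reverse=True)
    cumsum, tau = 0.0, 0.0
    for k, v in enumerate(u, start=1):
        cumsum += v
        if 1.0 + k * v > cumsum:  # the k-th largest score is still in the support
            tau = (cumsum - 1.0) / k
    return [max(v - tau, 0.0) for v in s]

scores = [1.0, 0.8, -1.0]  # illustrative scores
print(argmax_onehot(scores))  # all mass on the first entry
print(softmax(scores))        # dense: every entry is positive
print(sparsemax(scores))      # sparse: the last entry is exactly zero
```

All three outputs lie in the simplex; they differ in how much probability mass is spread away from the highest-scoring entry.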

## 2 Latent structure model

Throughout, we assume a neural network classifier, parametrized by $\phi$ and $\theta$, which consists of three parts:

• An encoder function $f_\phi$ which, given an input $x$, outputs a vector of “scores” $s \in \mathbb{R}^K$, as $s = f_\phi(x)$;

• An argmax node which, given these scores, outputs the highest-scoring structure:

$$\hat{z}(s) = \arg\max_{z \in \mathcal{Z}} s^\top z. \tag{4}$$

• A decoder function $g_\theta$ which, given $x$ and $z$, makes a prediction $\hat{y}$ as $\hat{y} = g_\theta(x, z)$. We will sometimes write $\hat{y}(z)$ to emphasize the dependency on $z$. For reasons that will become clear in the sequel, we assume that the decoder also accepts convex combinations of latent variables as input, i.e., it may also output predictions $\hat{y}(\mu)$ where $\mu \in \operatorname{conv}(\mathcal{Z})$.

Thus, given input $x$, this network predicts:

$$\hat{y} = g_\theta\Big(x,\ \underbrace{\arg\max_{z \in \mathcal{Z}} f_\phi(x)^\top z}_{\hat{z}(s)}\Big). \tag{5}$$

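To train this network, we assume a loss function $L(\hat{y}, y^\star)$, where $y^\star$ denotes the true output. We want to minimize this loss over the training data, using the gradient backpropagation algorithm.

We assume $\nabla_\theta L(\hat{y}, y^\star)$ is easy to compute: it can be done by performing standard gradient backpropagation from the output layer until the output of the argmax node. The main challenge of this model is to sidestep the argmax node to propagate gradient information to the encoder parameters $\phi$. The argmax output is piecewise constant in $s$, so its Jacobian is zero almost everywhere, and we have:

$$\nabla_\phi L(\hat{y}, y^\star) = \frac{\partial f_\phi(x)}{\partial \phi}\ \underbrace{\frac{\partial \hat{z}(s)}{\partial s}}_{=0}\ \nabla_z L(\hat{y}(\hat{z}), y^\star) = 0, \tag{6}$$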

so no gradient will flow to the encoder. Common approaches to circumvent this problem include:

• Replace the argmax node by a stochastic node where $z$ is a random variable parametrized by $s$ (e.g., using a Gibbs distribution). Then, compute the gradient of the expected loss $\mathbb{E}_s[L(\hat{y}(z), y^\star)]$. This is the approach underlying REINFORCE, score function estimators, and minimum risk training (Williams, 1992; Smith and Eisner, 2006; Stoyanov et al., 2011). Niculae et al. (2018b) explore a sparse alternative to the Gibbs distribution.

• Keep the network deterministic, but do a continuous relaxation of the argmax node, for example replacing it with softmax or sparsemax (Martins and Astudillo, 2016). In the structured case, this gives rise to structured attention networks (Kim et al., 2017) and their sparse variant, SparseMAP (Niculae et al., 2018a). Mathematically, this corresponds to moving the expectation inside the loss, optimizing $L(\hat{y}(\mathbb{E}_s[z]), y^\star)$.

• Keep the argmax node and perform the usual forward computation, but backpropagate a surrogate gradient. This is the approach underlying straight-through estimators (Bengio et al., 2013) and SPIGOT (Peng et al., 2018). We will develop this approach in the remainder of these notes.

In what follows, we assume that:

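• We have access to the gradient $\gamma(\hat{z}) := \nabla_z L(\hat{y}(\hat{z}), y^\star)$. (This gradient would not exist if the decoder were defined only for the vertices $z \in \mathcal{Z}$ and not convex combinations thereof. This assumption is not needed in the minimum risk training approach discussed toward the end of these notes.)

• We want to replace the (zero) gradient $\nabla_s L(\hat{y}(\hat{z}), y^\star)$ by a surrogate $\tilde{\nabla}_s L(\hat{y}(\hat{z}), y^\star)$.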
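To make the zero-gradient problem of Eq. 6 concrete, the following sketch differentiates a toy loss through the argmax node by finite differences, and contrasts it with the gradient with respect to the latent variable, which is available because the decoder accepts convex combinations. The linear decoder, its weights `W`, and the squared-error loss are hypothetical choices made only for illustration.

```python
def argmax_onehot(s):
    """Argmax node of Eq. 4 in the unstructured case."""
    best = max(range(len(s)), key=lambda i: s[i])
    return [1.0 if i == best else 0.0 for i in range(len(s))]

W = [0.0, 1.0, 3.0]  # hypothetical decoder weights

def decoder(mu):
    """Hypothetical linear decoder; it also accepts convex combinations mu."""
    return sum(m * w for m, w in zip(mu, W))

def loss(mu, y_true):
    """Toy squared-error downstream loss."""
    return 0.5 * (decoder(mu) - y_true) ** 2

def grad_wrt_scores(s, y_true, eps=1e-4):
    """Central finite differences of the loss through the argmax node."""
    grad = []
    for k in range(len(s)):
        up = list(s)
        up[k] += eps
        down = list(s)
        down[k] -= eps
        diff = loss(argmax_onehot(up), y_true) - loss(argmax_onehot(down), y_true)
        grad.append(diff / (2 * eps))
    return grad

def grad_wrt_latent(mu, y_true):
    """Analytic gradient of the toy loss with respect to the latent input mu."""
    residual = decoder(mu) - y_true
    return [residual * w for w in W]

s = [1.0, 0.8, -1.0]
print(grad_wrt_scores(s, y_true=2.0))                 # zero in every coordinate
print(grad_wrt_latent(argmax_onehot(s), y_true=2.0))  # informative, nonzero
```

Small perturbations of the scores do not change the argmax, so the first gradient vanishes, while the second one still indicates which latent coordinates would reduce the loss.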

## 3 SPIGOT as the approximate computation of a pulled back loss

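We now provide an interpretation of SPIGOT as the minimization of a “pulled back” loss with respect to the latent variable $z$. SPIGOT uses the following surrogate gradient:

$$\begin{aligned}
\tilde{\nabla}_s L(\hat{y}(\hat{z}), y^\star) &= \hat{z} - \Pi_{\operatorname{conv}(\mathcal{Z})}\big[\hat{z} - \eta \nabla_z L(\hat{y}(\hat{z}), y^\star)\big] \\
&= \hat{z} - \operatorname{SparseMAP}(\hat{z} - \eta\gamma(\hat{z})),
\end{aligned} \tag{7}$$

where $\eta > 0$ is a step size and we used the fact that the SparseMAP transformation (Niculae et al., 2018a) is equivalent to a Euclidean projection, i.e., $\operatorname{SparseMAP}(s) = \Pi_{\operatorname{conv}(\mathcal{Z})}[s]$. This follows from Eq. 3, since $s^\top \mu - \tfrac{1}{2}\|\mu\|^2 = -\tfrac{1}{2}\|\mu - s\|^2 + \tfrac{1}{2}\|s\|^2$.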
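A minimal sketch of the surrogate in Eq. 7 for the unstructured case, where the projection onto the marginal polytope reduces to a projection onto the simplex (sparsemax). The function names, the downstream gradient vector, and the step size are assumptions chosen for illustration; in the structured case the projection would require a SparseMAP solver instead.

```python
def project_simplex(v):
    """Euclidean projection onto the probability simplex (sparsemax)."""
    u = sorted(v, reverse=True)
    cumsum, tau = 0.0, 0.0
    for k, x in enumerate(u, start=1):
        cumsum += x
        if 1.0 + k * x > cumsum:  # the k-th largest entry is in the support
            tau = (cumsum - 1.0) / k
    return [max(x - tau, 0.0) for x in v]

def spigot_surrogate(z_hat, gamma, eta):
    """Eq. 7: z_hat minus the projection of one gradient step from z_hat."""
    shifted = [z - eta * g for z, g in zip(z_hat, gamma)]
    mu_tilde = project_simplex(shifted)
    return [z - m for z, m in zip(z_hat, mu_tilde)]

z_hat = [1.0, 0.0, 0.0]    # current argmax (one-hot)
gamma = [1.0, -1.0, 0.5]   # hypothetical downstream gradient at z_hat
surrogate = spigot_surrogate(z_hat, gamma, eta=1.0)
print(surrogate)  # [1.0, -1.0, 0.0]
```

Because both `z_hat` and the projected point lie in the simplex, the surrogate always sums to zero. In this example the third coordinate of the raw gradient is removed by the projection.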

### 3.1 Intermediate loss on the latent variable

Let us start by stating an obvious fact, which will provide intuition for the rest: if we had supervision for the latent variable $z$ (e.g., if the true label $z^\star$ was revealed to us), we could simply define an intermediate loss $\ell(\hat{z}, z^\star)$ which can induce nonzero updates to the encoder parameters. In fact, if $|\mathcal{Z}|$ is small, we can enumerate all possible values of $z$ and define $z^\star$ as the one that minimizes the downstream loss, $L(\hat{y}(z), y^\star)$, using the current network parameters $\theta$; this would become our “groundtruth.”

While this seems somewhat sensible, we may expect some instability in the beginning of the training process, since the decoder parameters $\theta$ are likely to be very suboptimal at this stage. A more robust procedure is to allow for some label uncertainty: instead of picking a single label $z^\star \in \mathcal{Z}$, pick the convex combination $\mu^\star \in \operatorname{conv}(\mathcal{Z})$ that minimizes $L(\hat{y}(\mu), y^\star)$. In fact, it is likely that the $\mu$ that minimizes the downstream loss will not put all the probability mass on a single label, and we may benefit from that if the downstream loss is what we care about. With this in mind, we define:

$$\mu^\star = \arg\min_{\mu \in \operatorname{conv}(\mathcal{Z})} L(\hat{y}(\mu), y^\star). \tag{8}$$

For most interesting predictive models $\hat{y}(\mu)$, this optimization problem is non-convex and lacks a closed form solution. One common strategy is the projected gradient algorithm, which iteratively performs the following updates:

$$\mu^{(t+1)} = \Pi_{\operatorname{conv}(\mathcal{Z})}\big[\mu^{(t)} - \eta_t \nabla_\mu L(\hat{y}(\mu^{(t)}), y^\star)\big], \tag{9}$$

where $\eta_t$ is a step size and $\Pi_{\mathcal{C}}[v]$ denotes the Euclidean projection of point $v$ onto the set $\mathcal{C}$. With a suitable choice of step sizes, the projected gradient algorithm is guaranteed to converge to a local optimum of Eq. 8. If we initialize $\mu^{(0)} = \hat{z}$ and run a single iteration of projected gradient, we obtain the following estimate of $\mu^\star$:

$$\tilde{\mu} = \Pi_{\operatorname{conv}(\mathcal{Z})}\big[\hat{z} - \eta \nabla_\mu L(\hat{y}(\hat{z}), y^\star)\big]. \tag{10}$$

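We can now treat $\tilde{\mu}$ as if it were the “groundtruth” label distribution, turning the optimization of the encoder into a supervised learning problem. If we use a perceptron loss,

$$\begin{aligned}
\ell_{\mathrm{perc}}(\hat{z}(s), \tilde{\mu}) &= \max_{z \in \mathcal{Z}} s^\top z - s^\top \tilde{\mu} \\
&= s^\top \hat{z}(s) - s^\top \tilde{\mu},
\end{aligned} \tag{11}$$

we obtain the gradient

$$\begin{aligned}
\nabla_s \ell_{\mathrm{perc}}(\hat{z}(s), \tilde{\mu}) &= \hat{z} - \tilde{\mu} \\
&= \hat{z} - \Pi_{\operatorname{conv}(\mathcal{Z})}\big[\hat{z} - \eta \nabla_\mu L(\hat{y}(\hat{z}), y^\star)\big] \\
&= \hat{z} - \operatorname{SparseMAP}(\hat{z} - \eta\gamma(\hat{z})),
\end{aligned} \tag{12}$$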

which is precisely the SPIGOT gradient surrogate presented in Eq. 7. This leads to the following insight into how SPIGOT updates the encoder parameters:

SPIGOT minimizes the perceptron loss between $\hat{z}$ and a pulled back target computed by one projected gradient step on $\min_{\mu \in \operatorname{conv}(\mathcal{Z})} L(\hat{y}(\mu), y^\star)$ starting at $\hat{z}$.
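For readers who want the intermediate step behind the perceptron gradient, the short derivation below spells it out. It relies on two assumptions that the notes leave implicit: the maximizer in Eq. 4 is unique at the current scores, and the pulled back target is treated as a constant (no gradient flows through it).

$$\begin{aligned}
\nabla_s \max_{z \in \mathcal{Z}} s^\top z &= \hat{z}(s) && \text{(pointwise maximum of linear functions, unique maximizer)} \\
\nabla_s \big( s^\top \tilde{\mu} \big) &= \tilde{\mu} && \text{(}\tilde{\mu}\text{ held fixed)} \\
\Rightarrow\quad \nabla_s \ell_{\mathrm{perc}}(\hat{z}(s), \tilde{\mu}) &= \hat{z}(s) - \tilde{\mu}.
\end{aligned}$$

At ties between several maximizers the perceptron loss is not differentiable, and $\hat{z}(s) - \tilde{\mu}$ is then only one valid subgradient.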

This construction suggests some possible alternate strategies. The first results in a well known algorithm, while the rest result in novel variations.


Relaxing the constraint.

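The constraints in Eq. 8 make the optimization problem more complicated. We relax them and define $\mu^\star = \arg\min_{\mu \in \mathbb{R}^K} L(\hat{y}(\mu), y^\star)$. This problem still must be tackled iteratively, but the projection step can now be avoided. One iteration of gradient descent yields $\tilde{\mu} = \hat{z} - \eta\gamma(\hat{z})$. The perceptron update then recovers straight-through (specifically, the “identity” variant of STE, in which the backward pass acts as if $\partial \hat{z}(s) / \partial s = I$; Bengio et al., 2013), via a novel derivation:

$$\nabla_s \ell_{\mathrm{perc}}(\hat{z}(s), \tilde{\mu}) = \hat{z} - (\hat{z} - \eta\gamma(\hat{z})) = \eta\gamma(\hat{z}). \tag{13}$$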

This leads to the following insight into straight-through and its relationship to SPIGOT:

STE minimizes the perceptron loss between the latent $\hat{z}$ and a pulled back target computed by one gradient step on $\min_{\mu \in \mathbb{R}^K} L(\hat{y}(\mu), y^\star)$ starting at $\hat{z}$.
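The relationship between the two surrogates can be checked numerically in the unstructured case. The sketch below is illustrative only: the gradient vectors are made up, and the projection is the simplex projection rather than a structured SparseMAP solver. It shows that the two surrogates agree when the unprojected step already lands inside the simplex, and differ otherwise.

```python
def project_simplex(v):
    """Euclidean projection onto the probability simplex (sparsemax)."""
    u = sorted(v, reverse=True)
    cumsum, tau = 0.0, 0.0
    for k, x in enumerate(u, start=1):
        cumsum += x
        if 1.0 + k * x > cumsum:
            tau = (cumsum - 1.0) / k
    return [max(x - tau, 0.0) for x in v]

def ste_surrogate(gamma, eta):
    """Eq. 13: the straight-through surrogate is the scaled downstream gradient."""
    return [eta * g for g in gamma]

def spigot_surrogate(z_hat, gamma, eta):
    """Eq. 7: same gradient step, followed by a projection onto the simplex."""
    mu_tilde = project_simplex([z - eta * g for z, g in zip(z_hat, gamma)])
    return [z - m for z, m in zip(z_hat, mu_tilde)]

z_hat = [1.0, 0.0, 0.0]
gamma_inside = [0.5, -0.5, 0.0]   # z_hat - gamma stays in the simplex
gamma_outside = [1.0, -1.0, 0.5]  # z_hat - gamma leaves the simplex

print(ste_surrogate(gamma_inside, 1.0), spigot_surrogate(z_hat, gamma_inside, 1.0))
print(ste_surrogate(gamma_outside, 1.0), spigot_surrogate(z_hat, gamma_outside, 1.0))
```

In the second case the projection discards the component of the gradient that points outside the feasible set, which is the practical difference between SPIGOT and the straight-through estimator.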

Instead of a single projected gradient step, we could have run multiple steps of the iteration in Eq. 9. We would expect this to yield an estimate closer to $\mu^\star$, at the cost of more computation.
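The multi-step variant can be sketched as a small loop around the update in Eq. 9. Everything specific here is an assumption for illustration: the toy loss is a quadratic whose minimizer `TARGET` lies inside the simplex, the step size is constant, and the helper names are hypothetical.

```python
def project_simplex(v):
    """Euclidean projection onto the probability simplex (sparsemax)."""
    u = sorted(v, reverse=True)
    cumsum, tau = 0.0, 0.0
    for k, x in enumerate(u, start=1):
        cumsum += x
        if 1.0 + k * x > cumsum:
            tau = (cumsum - 1.0) / k
    return [max(x - tau, 0.0) for x in v]

def pulled_back_target(mu0, grad_fn, eta, num_steps):
    """Run num_steps projected gradient iterations (Eq. 9) starting from mu0."""
    mu = list(mu0)
    for _ in range(num_steps):
        g = grad_fn(mu)
        mu = project_simplex([m - eta * gi for m, gi in zip(mu, g)])
    return mu

TARGET = [0.2, 0.5, 0.3]  # hypothetical minimizer of the toy loss

def toy_grad(mu):
    """Gradient of the toy loss 0.5 * ||mu - TARGET||^2."""
    return [m - t for m, t in zip(mu, TARGET)]

def distance(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) ** 0.5

z_hat = [1.0, 0.0, 0.0]
one_step = pulled_back_target(z_hat, toy_grad, eta=0.5, num_steps=1)
many_steps = pulled_back_target(z_hat, toy_grad, eta=0.5, num_steps=20)
print(distance(one_step, TARGET), distance(many_steps, TARGET))
```

With one step the target is only halfway between `z_hat` and the minimizer; with twenty steps it is essentially at the minimizer, at twenty times the cost in downstream gradient evaluations.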

Different initialization.

The projected gradient update in Eq. 10 uses $\hat{z}$ as the initial point. This is a sensible choice if we believe the encoder prediction is close enough to the “groundtruth” $\mu^\star$, and it is computationally convenient because $\hat{z}$ has already been computed in the forward propagation step and can be cached. However, other initializations are possible, for example , or .

Different intermediate loss function.

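For simplicity, consider the unstructured case. Let $p(s) = \operatorname{softmax}(s)$. If we use the cross-entropy loss instead of the perceptron loss, we get

$$\begin{aligned}
\nabla_s \ell_{\mathrm{cross}}(p(s), \tilde{p}) &= p(s) - \tilde{p} \\
&= p(s) - \Pi_{\triangle^K}\big[p(s) - \eta \nabla_p L(\hat{y}(p(s)), y^\star)\big] \\
&= p(s) - \operatorname{sparsemax}(p(s) - \eta\gamma(p(s))).
\end{aligned} \tag{14}$$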

This generalizes easily to the CRF loss in the structured case.
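A minimal sketch of the cross-entropy variant in Eq. 14 for the unstructured case follows. The decoder weights `W`, the target `Y_TRUE`, the squared-error loss, and the step size are hypothetical and serve only to produce a downstream gradient.

```python
import math

def softmax(s):
    m = max(s)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in s]
    total = sum(e)
    return [v / total for v in e]

def project_simplex(v):
    """Euclidean projection onto the probability simplex (sparsemax)."""
    u = sorted(v, reverse=True)
    cumsum, tau = 0.0, 0.0
    for k, x in enumerate(u, start=1):
        cumsum += x
        if 1.0 + k * x > cumsum:
            tau = (cumsum - 1.0) / k
    return [max(x - tau, 0.0) for x in v]

W = [0.0, 1.0, 3.0]  # hypothetical decoder weights
Y_TRUE = 2.0         # hypothetical true output

def gamma(p):
    """Gradient of the toy loss 0.5 * (W.p - Y_TRUE)^2 with respect to p."""
    residual = sum(w * pi for w, pi in zip(W, p)) - Y_TRUE
    return [residual * w for w in W]

def cross_entropy_surrogate(s, eta):
    """Eq. 14: p(s) minus the projection of one gradient step taken from p(s)."""
    p = softmax(s)
    g = gamma(p)
    p_tilde = project_simplex([pi - eta * gi for pi, gi in zip(p, g)])
    return [pi - qi for pi, qi in zip(p, p_tilde)]

surrogate = cross_entropy_surrogate([1.0, 0.8, -1.0], eta=0.1)
print(surrogate)
```

In this toy setup the prediction is too small, so the surrogate is negative on the coordinate with the largest decoder weight: a descent step on the scores raises that score.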

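Also in the unstructured case, the exponentiated gradient algorithm (Kivinen and Warmuth, 1997) tackles the constrained optimization problem in Eq. 8 with the following multiplicative updates:

$$p^{(t+1)} \propto p^{(t)} \exp\big(-\eta_t \nabla_p L(\hat{y}(p^{(t)}), y^\star)\big), \tag{15}$$

where products and exponentials are taken elementwise and each point $p^{(t)}$ is strictly positive. This includes the initializer $p^{(0)}$, so we cannot have $p^{(0)} = \hat{z}$; for this reason we assume $p^{(0)} = p(s) = \operatorname{softmax}(s)$. A single iteration of exponentiated gradient with this initialization gives:

$$\begin{aligned}
\tilde{p} &\propto p(s) \exp\big(-\eta \nabla_p L(\hat{y}(p(s)), y^\star)\big), \quad \text{i.e.,} \\
\tilde{p} &= \operatorname{softmax}(\log p(s) - \eta\gamma(p(s))) \\
&= \operatorname{softmax}(s - \eta\gamma(p(s))).
\end{aligned} \tag{16}$$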

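With the cross-entropy loss, i.e. the Kullback-Leibler divergence $\mathrm{KL}(\tilde{p} \,\|\, p(s))$, we obtain:

$$\begin{aligned}
\nabla_s \ell_{\mathrm{cross}}(p(s), \tilde{p}) &= p(s) - \tilde{p} \\
&= p(s) - \operatorname{softmax}(s - \eta\gamma(p(s))) \\
&= \operatorname{softmax}(s) - \operatorname{softmax}(s - \eta\gamma(p(s))),
\end{aligned} \tag{17}$$

i.e., the surrogate gradient is the difference between a softmax and a softmax with “perturbed” scores. This generalizes to an instance of mirror descent with Kullback-Leibler projections in the structured case.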
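The identity behind Eq. 16, namely that one multiplicative update started at softmax(s) equals a softmax of shifted scores, can be verified numerically. The scores, the gradient vector standing in for the downstream gradient, and the step size below are made-up values; the function names are illustrative.

```python
import math

def softmax(s):
    m = max(s)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in s]
    total = sum(e)
    return [v / total for v in e]

def exponentiated_gradient_step(p, gamma, eta):
    """Eq. 15: multiplicative update followed by renormalization."""
    w = [pi * math.exp(-eta * gi) for pi, gi in zip(p, gamma)]
    total = sum(w)
    return [v / total for v in w]

def eg_surrogate(s, gamma, eta):
    """Eq. 17: softmax(s) minus softmax of the perturbed scores."""
    p = softmax(s)
    p_tilde = softmax([si - eta * gi for si, gi in zip(s, gamma)])
    return [pi - qi for pi, qi in zip(p, p_tilde)]

s = [1.0, 0.8, -1.0]
gamma = [0.0, -1.4, -4.1]  # hypothetical downstream gradient at softmax(s)
eta = 0.1

via_update = exponentiated_gradient_step(softmax(s), gamma, eta)
via_softmax = softmax([si - eta * gi for si, gi in zip(s, gamma)])
print(via_update)
print(via_softmax)  # matches the line above up to rounding
print(eg_surrogate(s, gamma, eta))
```

Unlike the sparsemax-based variant, the pulled back target here is always strictly positive, so the surrogate never has exactly zero entries unless the gradient is constant across coordinates.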

## 4 Relation to other methods for latent structure models

### 4.1 Continuous relaxation of argmax

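To simplify, let us consider the case where $z$ is a categorical variable. If we replace the argmax node by a continuous transformation $\rho(s)$ (e.g., a softmax with a temperature), the gradient $\nabla_s L(\hat{y}(\rho(s)), y^\star)$ can be exactly computed by the chain rule:

$$\nabla_s L(\hat{y}(\rho(s)), y^\star) = J_\rho(s)\, \nabla_z L(\hat{y}(\rho(s)), y^\star), \tag{18}$$

where $J_\rho(s)$ is the Jacobian of the transformation $\rho$ at point $s$. (For softmax and sparsemax this Jacobian is symmetric, so no transpose is needed.)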
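As a sanity check of Eq. 18, the sketch below takes the transformation to be a plain softmax, builds its Jacobian explicitly, and compares the chain-rule gradient with central finite differences. The linear decoder, its weights, and the squared-error loss are hypothetical stand-ins for the downstream network.

```python
import math

def softmax(s):
    m = max(s)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in s]
    total = sum(e)
    return [v / total for v in e]

def softmax_jacobian(s):
    """J[i][j] = d p_i / d s_j = p_i * (1[i == j] - p_j)."""
    p = softmax(s)
    n = len(p)
    return [[p[i] * ((1.0 if i == j else 0.0) - p[j]) for j in range(n)]
            for i in range(n)]

W = [0.0, 1.0, 3.0]  # hypothetical decoder weights
Y_TRUE = 2.0         # hypothetical true output

def loss(p):
    return 0.5 * (sum(w * pi for w, pi in zip(W, p)) - Y_TRUE) ** 2

def loss_grad(p):
    residual = sum(w * pi for w, pi in zip(W, p)) - Y_TRUE
    return [residual * w for w in W]

def chain_rule_grad(s):
    """Gradient w.r.t. s as J^T g; J is symmetric for softmax, so this is J g."""
    J = softmax_jacobian(s)
    g = loss_grad(softmax(s))
    n = len(s)
    return [sum(J[i][j] * g[i] for i in range(n)) for j in range(n)]

def finite_difference_grad(s, eps=1e-5):
    grad = []
    for k in range(len(s)):
        up = list(s)
        up[k] += eps
        down = list(s)
        down[k] -= eps
        grad.append((loss(softmax(up)) - loss(softmax(down))) / (2 * eps))
    return grad

s = [1.0, 0.8, -1.0]
print(chain_rule_grad(s))
print(finite_difference_grad(s))  # agrees with the line above up to rounding
```

In contrast with the argmax node, the relaxed network has a nonzero gradient with respect to the scores, at the price of feeding the decoder a dense vector instead of a discrete structure.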

### 4.2 Minimum risk training

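In this case, the network has a stochastic node $z \sim p(s)$, with $p(s) = \rho(s)$ as above. The gradient of the risk with respect to $s$ is:

$$\begin{aligned}
\nabla_s \mathbb{E}_s[L(\hat{y}(z), y^\star)] &= \sum_z L(\hat{y}(z), y^\star)\, \nabla_s p_z(s) \\
&= J_\rho(s)\, \ell,
\end{aligned} \tag{19}$$

where $\ell$ is a vector whose $z$-th entry contains the loss value $L(\hat{y}(z), y^\star)$.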

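Another way of writing the gradient above, noting that $\nabla_s p_z(s) = p_z(s) \nabla_s \log p_z(s)$, is:

$$\begin{aligned}
\nabla_s \mathbb{E}_s[L(\hat{y}(z), y^\star)] &= \sum_z L(\hat{y}(z), y^\star)\, \nabla_s p_z(s) \\
&= \mathbb{E}_s\big[L(\hat{y}(z), y^\star)\, \nabla_s \log p_z(s)\big].
\end{aligned} \tag{20}$$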
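The equivalence of the two forms of the risk gradient (Eqs. 19 and 20) can be checked by exact enumeration when the latent variable is categorical with a softmax distribution. The loss values below are invented for illustration, and the function names are hypothetical; for a softmax, the gradient of the log-probability of outcome z is the one-hot vector of z minus the probability vector.

```python
import math

def softmax(s):
    m = max(s)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in s]
    total = sum(e)
    return [v / total for v in e]

def risk_grad_jacobian_form(s, losses):
    """Eq. 19: sum over z of loss_z * grad_s p_z(s)."""
    p = softmax(s)
    n = len(p)
    grad = [0.0] * n
    for z in range(n):
        for j in range(n):
            dpz_dsj = p[z] * ((1.0 if z == j else 0.0) - p[j])
            grad[j] += losses[z] * dpz_dsj
    return grad

def risk_grad_score_function_form(s, losses):
    """Eq. 20: exact expectation of loss_z * grad_s log p_z(s) under p(s)."""
    p = softmax(s)
    n = len(p)
    grad = [0.0] * n
    for z in range(n):
        grad_log_pz = [(1.0 if z == j else 0.0) - p[j] for j in range(n)]
        for j in range(n):
            grad[j] += p[z] * losses[z] * grad_log_pz[j]
    return grad

s = [1.0, 0.8, -1.0]
losses = [0.3, 1.2, 2.0]  # hypothetical loss value for each outcome z

print(risk_grad_jacobian_form(s, losses))
print(risk_grad_score_function_form(s, losses))  # same vector up to rounding
```

In practice the expectation in Eq. 20 is estimated by sampling rather than enumerated, which is what makes the score function form usable when the set of structures is large.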

It is interesting to compare this gradient with the SPIGOT surrogate gradient in Eq. 7. Here too, a “pulled-back loss” (now ) is used in the gradient computation, this time as part of a weighted sum, where the weights are the reward $L(\hat{y}(z), y^\star)$ and the probability $p_z(s)$. For example, if the downstream loss minimizer is $z^\star$ and all $z \ne z^\star$ are equally bad (i.e., if they have the same loss), then we obtain

$$\nabla_s \mathbb{E}_s[L(\hat{y}(z), y^\star)] \propto p_{z^\star}(s)\, \nabla_s \log p_{z^\star}(s). \tag{21}$$
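The step leading to Eq. 21 can be written out as follows. The shorthand is introduced here only for this derivation: $L_z$ denotes the loss value of structure $z$, $L_{z^\star}$ the loss of the minimizer, and $c$ the common loss of all other structures.

$$\begin{aligned}
\sum_z L_z \nabla_s p_z(s)
&= c \sum_z \nabla_s p_z(s) + (L_{z^\star} - c)\, \nabla_s p_{z^\star}(s) \\
&= (L_{z^\star} - c)\, \nabla_s p_{z^\star}(s) && \text{since } \sum_z \nabla_s p_z(s) = \nabla_s 1 = 0 \\
&= (L_{z^\star} - c)\, p_{z^\star}(s)\, \nabla_s \log p_{z^\star}(s).
\end{aligned}$$

Because $L_{z^\star} \le c$, the proportionality constant in Eq. 21 is non-positive, so a gradient descent step on the risk increases the log-probability of $z^\star$.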